import os
import numpy as np
import pandas as pd
import torch
from muutils.json_serialize import (
SerializableDataclass,
serializable_dataclass,
serializable_field,
)
from zanj import ZANJ
@serializable_dataclass
class BasicZanj(SerializableDataclass):
    """Minimal serializable dataclass used to demonstrate a ZANJ round-trip.

    Fields are positional in declaration order: ``a``, ``q``, ``c``.
    """

    # required string field (no default)
    a: str
    # integer field, defaults to 42
    q: int = 42
    # list of ints; `default_factory` gives each instance its own fresh list
    c: list[int] = serializable_field(default_factory=list)
# one ZANJ object handles both writing and reading
zj = ZANJ()

# build an instance, this time passing every field by name
instance: BasicZanj = BasicZanj(a="hello", q=42, c=[1, 2, 3])
path: str = "tests/junk_data/path_to_save_instance.zanj"

# round-trip: write the instance to disk, then load it back
zj.save(instance, path)
recovered: BasicZanj = zj.read(path)

# the loaded object comes back as a BasicZanj, not a plain dict
print(f"{type(recovered) = }")  # BasicZanj
print(f"{os.path.getsize(path) = }")
type(recovered) = <class '__main__.BasicZanj'>
os.path.getsize(path) = 2509
ZANJ intelligently handles nested serializable dataclasses, NumPy arrays, PyTorch tensors, and pandas DataFrames:
@serializable_dataclass
class Complicated(SerializableDataclass):
    """Example dataclass mixing several non-JSON-native field types.

    No custom serialization functions are declared here; per the surrounding
    text, ZANJ is expected to handle these field types on its own.
    """

    name: str
    # numpy arrays
    arr1: np.ndarray
    arr2: np.ndarray
    # pandas dataframes
    iris_data: pd.DataFrame
    brain_data: pd.DataFrame
    # a list of nested serializable dataclasses
    container: list[BasicZanj]
    # a pytorch tensor
    torch_tensor: torch.Tensor
For custom classes, you can specify a `serialization_fn` and a `loading_fn` to handle the logic of converting to and from a JSON-serializable format:
@serializable_dataclass
class Complicated2(SerializableDataclass):
    """Example of a field whose type (``torch.device``) needs custom (de)serialization.

    ``serialization_fn`` is applied to the field's own value, while
    ``loading_fn`` receives the whole serialized dict of the class.
    """

    name: str
    # saved as its string form (e.g. "cpu") and rebuilt with torch.device(...)
    device: torch.device = serializable_field(
        # receives the torch.device value itself, NOT the dataclass instance;
        # `self.device` here would fail because torch.device has no `.device`
        serialization_fn=lambda dev: str(dev),
        # receives the full serialized dict, so look the field up by name
        loading_fn=lambda data: torch.device(data["device"]),
    )
Note that `loading_fn` takes the dictionary of the whole class, whereas `serialization_fn` receives only the value of its own field. This is in case you've stored data in multiple fields of the dict that are needed to reconstruct the object.
Saving Models
First, define a configuration class for your model. This class will hold the parameters for your model and any associated objects (like losses and optimizers). The configuration class should be a subclass of `SerializableDataclass` and use the `serializable_field` function to define fields that need special serialization.
Here's an example that defines a configuration for a simple neural network:
from zanj.torchutil import ConfiguredModel, set_config_class
@serializable_dataclass
class MyNNConfig(SerializableDataclass):
    """Hyperparameters for MyNN, plus the activation and loss to use.

    The activation class and the loss factory are not JSON-serializable
    themselves, so each one is saved as its class name and restored by
    looking that name up in ``torch.nn``.
    """

    input_dim: int
    hidden_dim: int
    output_dim: int

    # saved by class name (e.g. "ReLU"); `loading_fn` gets the whole dict
    act_fn: torch.nn.Module = serializable_field(
        serialization_fn=lambda act_cls: act_cls.__name__,
        loading_fn=lambda data: getattr(torch.nn, data["act_fn"]),
    )

    # keyword arguments passed to `loss_factory` when the loss is built
    loss_kwargs: dict = serializable_field(default_factory=dict)

    # stored by class name just like `act_fn`; defaults to CrossEntropyLoss
    loss_factory: torch.nn.modules.loss._Loss = serializable_field(
        default_factory=lambda: torch.nn.CrossEntropyLoss,
        serialization_fn=lambda loss_cls: loss_cls.__name__,
        loading_fn=lambda data: getattr(torch.nn, data["loss_factory"]),
    )

    @property
    def loss(self):
        """A newly constructed loss: ``loss_factory(**loss_kwargs)`` on every access."""
        return self.loss_factory(**self.loss_kwargs)
Then, define your model class. It should be a subclass of `ConfiguredModel` and use the `set_config_class` decorator to associate it with your configuration class. The `__init__` method should take a single argument, which is an instance of your configuration class. You must also call the superclass `__init__` method with the configuration instance.
@set_config_class(MyNNConfig)
class MyNN(ConfiguredModel[MyNNConfig]):
    """Linear -> activation -> Linear network, fully described by a MyNNConfig."""

    def __init__(self, config: MyNNConfig):
        # the superclass __init__ must run first: it records `config` in the
        # `zanj_model_config` field, and (as with any torch.nn.Module)
        # submodules cannot be assigned before the base init has run
        super().__init__(config)

        # Sequential numbers these 0, 1, 2 in list order
        layers = [
            torch.nn.Linear(config.input_dim, config.hidden_dim),
            config.act_fn(),
            torch.nn.Linear(config.hidden_dim, config.output_dim),
        ]
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the stacked layers and return the result."""
        out = self.net(x)
        return out
You can now create instances of your model, save them to disk, and load them back into memory:
# describe the network: 10 -> 20 -> 2 with ReLU, and a mean-reduced loss
config = MyNNConfig(
    input_dim=10,
    hidden_dim=20,
    output_dim=2,
    act_fn=torch.nn.ReLU,
    loss_kwargs={"reduction": "mean"},
)

# build the model from its config and write it to disk
model = MyNN(config)
fname = "tests/junk_data/path_to_save_model.zanj"
ZANJ().save(model, fname)

# load it back in two ways:
# 1. through the model class's `read()` class method
loaded_model = MyNN.read(fname)
# 2. through a generic ZANJ reader, which infers the stored object's type
#    (and warns if the package that defines it is not installed)
loaded_another_way = ZANJ().read(fname)
print(f"{type(loaded_model) = }")

# run the same random input through the original and both reloaded models
x = torch.randn(config.input_dim)
print(f"{x.shape = }")
out_1, out_2, out_3 = (m(x) for m in (model, loaded_model, loaded_another_way))
print(f"{out_1 = }, {out_2 = }, {out_3 = }")

# both reloaded models must reproduce the original model's output
assert torch.allclose(out_1, out_2)
assert torch.allclose(out_1, out_3)
type(loaded_model) = <class '__main__.MyNN'>
x.shape = torch.Size([10])
out_1 = tensor([ 0.0378, -0.4873], grad_fn=<AddBackward0>), out_2 = tensor([ 0.0378, -0.4873], grad_fn=<AddBackward0>), out_3 = tensor([ 0.0378, -0.4873], grad_fn=<AddBackward0>)