Installation

Available on PyPI as zanj

pip install zanj
In [1]:
import os

import numpy as np
import pandas as pd
import torch

from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)
from zanj import ZANJ

Usage

Saving a basic object

Any SerializableDataclass whose fields are basic types can be saved as a .zanj file:

In [2]:
@serializable_dataclass
class BasicZanj(SerializableDataclass):
    a: str
    q: int = 42
    c: list[int] = serializable_field(default_factory=list)


# initialize a zanj reader/writer
zj = ZANJ()

# create an instance
instance: BasicZanj = BasicZanj("hello", 42, [1, 2, 3])
path: str = "tests/junk_data/path_to_save_instance.zanj"
zj.save(instance, path)
recovered: BasicZanj = zj.read(path)
In [3]:
print(f"{type(recovered) = }")  # BasicZanj
print(f"{os.path.getsize(path) = }")

type(recovered) = <class '__main__.BasicZanj'>
os.path.getsize(path) = 2509

ZANJ also handles nested serializable dataclasses, NumPy arrays, PyTorch tensors, and pandas DataFrames:

In [4]:
@serializable_dataclass
class Complicated(SerializableDataclass):
    name: str
    arr1: np.ndarray
    arr2: np.ndarray
    iris_data: pd.DataFrame
    brain_data: pd.DataFrame
    container: list[BasicZanj]
    torch_tensor: torch.Tensor

For other types (here, torch.device), you can pass a serialization_fn and a loading_fn to serializable_field to convert the field to and from a JSON-serializable format. The serialization_fn receives the field's value:

In [5]:
@serializable_dataclass
class Complicated2(SerializableDataclass):
    name: str
    device: torch.device = serializable_field(
        # serialization_fn receives the field's value (a torch.device), not the instance
        serialization_fn=lambda device: str(device),
        loading_fn=lambda data: torch.device(data["device"]),
    )

Note that loading_fn receives the serialized dictionary of the whole object, not just this field's value. This lets you reconstruct a field from data stored under several keys of that dictionary.

Saving Models

First, define a configuration class for your model. This class holds the parameters of your model and any associated objects, such as loss functions and optimizers. It should subclass SerializableDataclass, be decorated with @serializable_dataclass, and use serializable_field for any field that needs custom serialization.

Here's an example that defines a configuration for a simple neural network:

In [6]:
from zanj.torchutil import ConfiguredModel, set_config_class


@serializable_dataclass
class MyNNConfig(SerializableDataclass):
    input_dim: int
    hidden_dim: int
    output_dim: int

    # store the activation function by name, reconstruct it by looking it up in torch.nn
    act_fn: torch.nn.Module = serializable_field(
        serialization_fn=lambda x: x.__name__,
        loading_fn=lambda x: getattr(torch.nn, x["act_fn"]),
    )

    # the loss function is also stored by name, along with the kwargs used to build it
    loss_kwargs: dict = serializable_field(default_factory=dict)
    loss_factory: torch.nn.modules.loss._Loss = serializable_field(
        default_factory=lambda: torch.nn.CrossEntropyLoss,
        serialization_fn=lambda x: x.__name__,
        loading_fn=lambda x: getattr(torch.nn, x["loss_factory"]),
    )
    loss = property(lambda self: self.loss_factory(**self.loss_kwargs))

Then, define your model class. It should subclass ConfiguredModel and use the set_config_class decorator to associate it with your configuration class. Its __init__ method should take a single argument besides self, an instance of your configuration class, and must pass that instance on to the superclass __init__.

In [7]:
@set_config_class(MyNNConfig)
class MyNN(ConfiguredModel[MyNNConfig]):
    def __init__(self, config: MyNNConfig):
        # call the superclass init!
        # this stores the config in the zanj_model_config attribute
        super().__init__(config)

        # whatever you want here
        self.net = torch.nn.Sequential(
            torch.nn.Linear(config.input_dim, config.hidden_dim),
            config.act_fn(),
            torch.nn.Linear(config.hidden_dim, config.output_dim),
        )

    def forward(self, x):
        return self.net(x)

You can now create instances of your model, save them to disk, and load them back into memory:

In [8]:
config = MyNNConfig(
    input_dim=10,
    hidden_dim=20,
    output_dim=2,
    act_fn=torch.nn.ReLU,
    loss_kwargs=dict(reduction="mean"),
)

# create your model from the config, and save
model = MyNN(config)
fname = "tests/junk_data/path_to_save_model.zanj"
ZANJ().save(model, fname)
# load by calling the class method `read()`
loaded_model = MyNN.read(fname)
# zanj will actually infer the type of the object in the file
# -- and will warn you if you don't have the correct package installed
loaded_another_way = ZANJ().read(fname)
In [9]:
print(f"{type(loaded_model) = }")
x = torch.randn(config.input_dim)
print(f"{x.shape = }")
out_1 = model(x)
out_2 = loaded_model(x)
out_3 = loaded_another_way(x)

print(f"{out_1 = }, {out_2 = }, {out_3 = }")
assert torch.allclose(out_1, out_2)
assert torch.allclose(out_1, out_3)

type(loaded_model) = <class '__main__.MyNN'>
x.shape = torch.Size([10])
out_1 = tensor([ 0.0378, -0.4873], grad_fn=<AddBackward0>), out_2 = tensor([ 0.0378, -0.4873], grad_fn=<AddBackward0>), out_3 = tensor([ 0.0378, -0.4873], grad_fn=<AddBackward0>)