Coverage for tests\test_isolate_zanj_handler_store.py: 98%
47 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-14 12:57 -0700
1from __future__ import annotations
3import json
4import typing
5import zipfile
6from pathlib import Path
8import numpy as np
9from muutils.json_serialize import (
10 SerializableDataclass,
11 serializable_dataclass,
12 serializable_field,
13)
15from zanj import ZANJ
16from zanj.loading import LOADER_MAP
# pylint: disable=missing-function-docstring,missing-class-docstring

# Fix NumPy's global RNG state at import time. Nothing visible in this module
# draws random numbers -- presumably kept for reproducibility; TODO confirm.
np.random.seed(0)

# Directory (relative to the current working directory) where these tests
# write their .zanj archives.
TEST_DATA_PATH: Path = Path("tests") / "junk_data"
@serializable_dataclass
class Basic(SerializableDataclass):
    """Small serializable dataclass used to check ZANJ save/read round-trips.

    ``test_isolate_handlers`` also expects a ``"Basic(SerializableDataclass)"``
    loader key to exist even though it never saves a ``Basic`` instance --
    presumably registered when this decorated class is defined; TODO confirm.
    """

    a: str
    q: int = 42
    # default_factory builds a fresh empty list per instance instead of
    # sharing one mutable default
    c: typing.List[int] = serializable_field(default_factory=list)
def test_Basic():
    """Save a ``Basic`` instance to a .zanj file, read it back, and require
    the recovered object to compare equal to the original."""
    original = Basic("hello", 42, [1, 2, 3])
    out_path = TEST_DATA_PATH / "test_Basic.zanj"

    zanj = ZANJ()
    zanj.save(original, out_path)
    # keep the original on the left so its __eq__ is the one invoked
    assert original == zanj.read(out_path)
# Debug output at import time: handler keys in LOADER_MAP before ModelCfg is defined.
print([handler_key for handler_key in LOADER_MAP.keys()])
@serializable_dataclass
class ModelCfg(SerializableDataclass):
    """Toy model-configuration dataclass that ``test_isolate_handlers`` saves
    to and reads back from a .zanj file. All fields are required (no defaults).
    """

    name: str
    num_layers: int
    hidden_size: int
    dropout: float
# Debug output at import time: handler keys in LOADER_MAP once ModelCfg is defined.
print([handler_key for handler_key in LOADER_MAP.keys()])
def test_isolate_handlers():
    """Round-trip a ``ModelCfg`` through a .zanj file and check handler storage.

    After the save/read round-trip, the loader keys for both dataclasses in
    this module must be present in ``LOADER_MAP`` and listed under
    ``zanj_cfg.load_handlers`` in the archive's ``__zanj_meta__.json`` --
    including ``Basic``, even though no ``Basic`` instance was saved.
    """
    instance = ModelCfg("lstm", 3, 128, 0.1)

    # debug output: handler keys registered when the test starts
    print(list(LOADER_MAP.keys()))

    z = ZANJ()
    path = TEST_DATA_PATH / "00-test_isolate_handlers.zanj"
    z.save(instance, path)
    recovered = z.read(path)
    assert instance == recovered

    assert "Basic(SerializableDataclass)" in LOADER_MAP
    assert "ModelCfg(SerializableDataclass)" in LOADER_MAP

    # check they are in the zanj file; ZipFile.read returns the member's bytes
    # directly, so no ZipExtFile handle is left unclosed (json.loads accepts bytes)
    with zipfile.ZipFile(path, "r") as zfile:
        zmeta = json.loads(zfile.read("__zanj_meta__.json"))
    load_handlers = zmeta["zanj_cfg"]["load_handlers"]
    assert "Basic(SerializableDataclass)" in load_handlers
    assert "ModelCfg(SerializableDataclass)" in load_handlers
if __name__ == "__main__":
    # Allow running this module without pytest: execute every test in
    # definition order so test_Basic is exercised too.
    test_Basic()
    test_isolate_handlers()