Coverage for tests\test_zanj_sdc_modelcfg.py: 100%
90 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-14 12:57 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-14 12:57 -0700
1from __future__ import annotations
3import json
4import sys
5import typing
6from pathlib import Path
8import numpy as np
9import torch
10from muutils.json_serialize import (
11 SerializableDataclass,
12 serializable_dataclass,
13 serializable_field,
14)
16from zanj import ZANJ
# Fix numpy's global RNG seed at import time.
np.random.seed(0)

# pylint: disable=missing-function-docstring,missing-class-docstring

# Directory (relative to the working directory) where the tests write files.
TEST_DATA_PATH: Path = Path("tests/junk_data")

# True on Python 3.10+, where dataclasses accept ``kw_only``; passed as the
# ``kw_only`` argument of the serializable dataclasses in this module.
SUPPORTS_KW_ONLY: bool = sys.version_info >= (3, 10)
@serializable_dataclass
class MyModelCfg(SerializableDataclass):
    """Serializable model configuration made of four plain fields."""

    # All fields are required and use the default (de)serialization.
    name: str
    num_layers: int
    hidden_size: int
    dropout: float
@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class TrainCfg(SerializableDataclass):
    """Serializable training configuration holding an optimizer class.

    The ``optimizer`` field is stored as the class's ``__name__`` and
    restored by looking that name up on ``torch.optim``.
    """

    name: str
    weight_decay: float
    # Optimizer *class* (not an instance); defaults to ``torch.optim.Adam``.
    optimizer: typing.Type[torch.optim.Optimizer] = serializable_field(
        default_factory=lambda: torch.optim.Adam,
        serialization_fn=lambda optimizer_cls: optimizer_cls.__name__,
        loading_fn=lambda data: getattr(torch.optim, data["optimizer"]),
    )
    # A fresh dict is built per instance so the default is never shared.
    optimizer_kwargs: typing.Dict[str, typing.Any] = serializable_field(  # type: ignore
        default_factory=lambda: {"lr": 0.000001}
    )
class CustomCfg:
    """Hand-written config object with its own ``serialize``/``load`` pair.

    It is not a dataclass; it exposes two attributes, ``x`` and ``y``, and
    converts to and from a plain ``{"x": ..., "y": ...}`` dict.
    """

    def __init__(self, x: int, y: str):
        self.x = x
        self.y = y

    def __repr__(self) -> str:
        # Deterministic repr (no memory address) so assertion failures and
        # any repr-based output are readable and stable.
        return f"{type(self).__name__}(x={self.x!r}, y={self.y!r})"

    def __eq__(self, other: object) -> bool:
        """Compare field-wise with another ``CustomCfg``.

        Returns ``NotImplemented`` for any other type, so ``cfg == 42`` or
        ``cfg == None`` evaluates to ``False`` instead of raising
        ``AttributeError`` on the missing ``x``/``y`` attributes.
        """
        if not isinstance(other, CustomCfg):
            return NotImplemented
        return self.x == other.x and self.y == other.y

    def serialize(self) -> typing.Dict[str, typing.Any]:
        """Return the instance as a ``{"x": ..., "y": ...}`` dict."""
        return {"x": self.x, "y": self.y}

    @classmethod
    def load(cls, data: typing.Dict[str, typing.Any]) -> "CustomCfg":
        """Build an instance from a dict with ``"x"`` and ``"y"`` keys.

        Raises ``KeyError`` if either key is missing.
        """
        return cls(x=data["x"], y=data["y"])
@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class BasicCfgHolder(SerializableDataclass):
    """Serializable holder nesting two dataclass configs and a custom object.

    ``model`` and ``optimizer`` are serializable dataclasses themselves.
    ``custom`` is a plain ``CustomCfg`` converted through its own
    ``serialize``/``load`` methods.
    """

    model: MyModelCfg
    optimizer: TrainCfg
    # ``custom`` is optional and defaults to None, so both callbacks must
    # pass None through rather than dereferencing it.
    custom: typing.Optional[CustomCfg] = serializable_field(
        default=None,
        serialization_fn=lambda x: None if x is None else x.serialize(),
        loading_fn=lambda data: (
            None if data["custom"] is None else CustomCfg.load(data["custom"])
        ),
    )
# Fully populated holder shared by the BasicCfgHolder round-trip tests.
instance_basic: BasicCfgHolder = BasicCfgHolder(  # type: ignore
    model=MyModelCfg(  # type: ignore
        name="lstm",
        num_layers=3,
        hidden_size=128,
        dropout=0.1,
    ),
    optimizer=TrainCfg(  # type: ignore
        name="adamw",
        weight_decay=0.2,
        optimizer=torch.optim.AdamW,
        optimizer_kwargs={"lr": 0.0001},
    ),
    custom=CustomCfg(x=42, y="forty-two"),
)
def test_config_holder():
    """Round-trip ``instance_basic`` through a JSON file and compare.

    The instance is serialized, written to disk, read back, and loaded; the
    recovered object must have correctly typed nested fields and compare
    equal to the original.
    """
    # Create the output directory so the test does not depend on another
    # test (or a previous run) having created it first.
    TEST_DATA_PATH.mkdir(parents=True, exist_ok=True)
    json_path = TEST_DATA_PATH / "test_config_holder.json"

    instance_stored = instance_basic.serialize()
    # Explicit encoding keeps the file independent of the platform default.
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(instance_stored, f, indent="\t")
    with open(json_path, "r", encoding="utf-8") as f:
        instance_stored_read = json.load(f)

    recovered = BasicCfgHolder.load(instance_stored_read)
    assert isinstance(recovered.model, MyModelCfg)
    assert isinstance(recovered.optimizer, TrainCfg)
    assert isinstance(recovered.custom, CustomCfg)
    assert recovered.custom.x == 42
    assert instance_basic == recovered
def test_config_holder_zanj():
    """Save ``instance_basic`` with ZANJ, read it back, and compare."""
    archive_path = TEST_DATA_PATH / "test_config_holder.zanj"
    zanj_io = ZANJ()
    zanj_io.save(instance_basic, archive_path)
    loaded = zanj_io.read(archive_path)

    # Nested fields must come back as their original classes.
    assert isinstance(loaded.model, MyModelCfg)
    assert isinstance(loaded.optimizer, TrainCfg)
    assert isinstance(loaded.custom, CustomCfg)
    assert loaded.custom.x == 42
    assert instance_basic == loaded
@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class BaseGPTConfig(SerializableDataclass):
    """Serializable model configuration with five required plain fields."""

    # All fields use the default (de)serialization.
    name: str
    act_fn: str
    d_model: int
    d_head: int
    n_layers: int
def _load_adv_tokenizer(data: typing.Dict[str, typing.Any]) -> None:
    """Loading callback for ``AdvCfgHolder.tokenizer``.

    Returns ``None`` when the stored tokenizer is ``None``. Any other stored
    value raises ``NotImplementedError``, because the field is serialized
    with ``repr`` and no way of rebuilding it from that is implemented.
    """
    if data["tokenizer"] is None:
        return None
    raise NotImplementedError(
        "loading a non-None tokenizer is not implemented for AdvCfgHolder"
    )


@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class AdvCfgHolder(SerializableDataclass):
    """Serializable holder with a nested config and an optional tokenizer.

    ``tokenizer`` is serialized as its ``repr`` (or ``None``). Only the
    ``None`` case can be loaded back.
    """

    model_cfg: BaseGPTConfig
    name: str = serializable_field(default="default")
    tokenizer: typing.Optional[CustomCfg] = serializable_field(
        default=None,
        serialization_fn=lambda x: repr(x) if x is not None else None,
        # A lambda cannot raise, and returning the exception class would
        # silently store it as the field value; use a real function instead.
        loading_fn=_load_adv_tokenizer,
    )
# Holder shared by the AdvCfgHolder round-trip tests; ``name`` is left at its
# default and ``tokenizer`` is explicitly None.
instance_adv: AdvCfgHolder = AdvCfgHolder(  # type: ignore
    tokenizer=None,
    model_cfg=BaseGPTConfig(  # type: ignore
        name="gpt2",
        act_fn="gelu",
        n_layers=3,
        d_model=128,
        d_head=64,
    ),
)
def test_adv_config_holder():
    """Round-trip ``instance_adv`` through a JSON file and compare.

    The instance is serialized, written to disk, read back, and loaded, in
    the same way as ``test_config_holder``, so the JSON actually written is
    what gets validated.
    """
    # Create the output directory so the test does not depend on another
    # test (or a previous run) having created it first.
    TEST_DATA_PATH.mkdir(parents=True, exist_ok=True)
    json_path = TEST_DATA_PATH / "test_adv_config_holder.json"

    instance_stored = instance_adv.serialize()
    # Explicit encoding keeps the file independent of the platform default.
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(instance_stored, f, indent="\t")
    with open(json_path, "r", encoding="utf-8") as f:
        instance_stored_read = json.load(f)

    recovered = AdvCfgHolder.load(instance_stored_read)
    assert isinstance(recovered.model_cfg, BaseGPTConfig)
    assert instance_adv == recovered
def test_adv_config_holder_zanj():
    """Save ``instance_adv`` with ZANJ, read it back, and compare."""
    archive_path = TEST_DATA_PATH / "test_adv_config_holder.zanj"
    zanj_io = ZANJ()
    zanj_io.save(instance_adv, archive_path)
    loaded = zanj_io.read(archive_path)

    # The nested config must come back as its original class.
    assert isinstance(loaded.model_cfg, BaseGPTConfig)
    assert instance_adv == loaded