Coverage for tests\test_zanj_torch.py: 95%
88 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-14 12:57 -0700
1from __future__ import annotations
3from pathlib import Path
5import numpy as np
6import torch
7from muutils.json_serialize import (
8 SerializableDataclass,
9 serializable_dataclass,
10 serializable_field,
11)
12from muutils.tensor_utils import compare_state_dicts
14from zanj import ZANJ
15from zanj.torchutil import (
16 ConfiguredModel,
17 assert_model_exact_equality,
18 set_config_class,
19)
# Seed NumPy's global RNG so any NumPy-based randomness in this module is reproducible.
np.random.seed(seed=0)

# Directory that saved ``.zanj`` test artifacts are written to (a relative path,
# so it resolves against the current working directory).
TEST_DATA_PATH: Path = Path("tests") / "junk_data"
def test_torch_configmodel_minimal():
    """Round-trip a minimal ``ConfiguredModel`` through a ``.zanj`` file.

    Builds a model consisting of a single ``torch.nn.Linear`` layer and saves it
    with ``ZANJ().save``. It then loads the file back twice, once via the
    class-level ``MyNN.read`` and once via the generic ``ZANJ().read``. Each
    loaded model must match the original's config, training records and
    parameters.
    """

    @serializable_dataclass
    class MyNNConfig(SerializableDataclass):
        # NOTE(review): despite its name, this value is used below as the
        # Linear layer's input feature count, not as a number of layers.
        n_layers: int

    @set_config_class(MyNNConfig)
    class MyNN(ConfiguredModel[MyNNConfig]):
        def __init__(self, config: MyNNConfig):
            super().__init__(config)

            self.layer = torch.nn.Linear(config.n_layers, 1)

        def forward(self, x):
            return self.layer(x)

    config: MyNNConfig = MyNNConfig(
        n_layers=2,
    )

    model: MyNN = MyNN(config)

    # Use a file name distinct from the one `test_torch_configmodel` writes.
    # Sharing a path lets the two tests overwrite each other's artifact when
    # they run concurrently (e.g. under pytest-xdist).
    fname: Path = TEST_DATA_PATH / "test_torch_configmodel_minimal.zanj"
    # Make sure the output directory exists on a fresh checkout; this is a
    # no-op if it is already present.
    fname.parent.mkdir(parents=True, exist_ok=True)
    ZANJ().save(model, fname)

    print(f"saved model to {fname}")
    print(f"{model.zanj_model_config = }")

    def _assert_matches_original(loaded: MyNN) -> None:
        # Checks one loaded model against the in-memory original.
        print(f"loaded model from {fname}")
        print(f"{loaded.zanj_model_config = }")

        assert model.zanj_model_config == loaded.zanj_model_config
        assert model.training_records == loaded.training_records

        compare_state_dicts(model.state_dict(), loaded.state_dict())
        assert_model_exact_equality(model, loaded)

    # Both the class-specific reader and the generic ZANJ reader must
    # reproduce the saved model.
    _assert_matches_original(MyNN.read(fname))
    _assert_matches_original(ZANJ().read(fname))
def test_torch_configmodel():
    """Round-trip a GPT-style ``ConfiguredModel`` with richer config fields.

    The config stores class-valued fields (a loss class and an optimizer class).
    Each is serialized by ``__name__`` and resolved back by attribute lookup on
    ``torch.nn`` / ``torch.optim``. The model also carries non-empty
    ``training_records``.

    The model is saved with ``ZANJ().save`` and loaded back via ``MyGPT.read``
    and ``ZANJ().read``. Each loaded model must match the original's config,
    training records and parameters.
    """

    @serializable_dataclass
    class MyGPTConfig(SerializableDataclass):
        """basic test GPT config"""

        n_layers: int
        n_heads: int
        embedding_size: int
        # NOTE(review): n_positions and n_vocab are part of the config but are
        # not used by MyGPT below.
        n_positions: int
        n_vocab: int

        # Holds the loss *class* (not an instance). On save it is stored by
        # class name; on load, this field's stored name is looked up on torch.nn.
        # NOTE(review): the annotation names an instance type (`_Loss`) while
        # the value is a class; `type[_Loss]` would be more accurate. It is left
        # unchanged in case muutils validates fields against the annotation.
        loss_factory: torch.nn.modules.loss._Loss = serializable_field(
            default_factory=lambda: torch.nn.CrossEntropyLoss,
            serialization_fn=lambda x: x.__name__,
            loading_fn=lambda x: getattr(torch.nn, x["loss_factory"]),
        )

        # Keyword arguments forwarded to `loss_factory` by the `loss` property.
        loss_kwargs: dict = serializable_field(default_factory=dict)

        @property
        def loss(self):
            """Instantiate the configured loss class with ``loss_kwargs``."""
            return self.loss_factory(**self.loss_kwargs)

        # Holds the optimizer *class*; serialized by name and resolved on
        # torch.optim when loading, mirroring `loss_factory`.
        optim_factory: torch.optim.Optimizer = serializable_field(
            default_factory=lambda: torch.optim.Adam,
            serialization_fn=lambda x: x.__name__,
            loading_fn=lambda x: getattr(torch.optim, x["optim_factory"]),
        )

        # Keyword arguments forwarded to `optim_factory` by `optim()`.
        optim_kwargs: dict = serializable_field(default_factory=dict)

        def optim(self, model):
            """Build the configured optimizer over ``model.parameters()``."""
            return self.optim_factory(model.parameters(), **self.optim_kwargs)  # type: ignore

    @set_config_class(MyGPTConfig)
    class MyGPT(ConfiguredModel[MyGPTConfig]):
        """basic GPT model"""

        def __init__(self, config: MyGPTConfig):
            super().__init__(config)

            # "Decoder-only" here means a torch.nn.Transformer with zero encoder
            # layers and `n_layers` decoder layers.
            self.transformer = torch.nn.Transformer(
                d_model=config.embedding_size,
                nhead=config.n_heads,
                num_encoder_layers=0,
                num_decoder_layers=config.n_layers,
            )

        def forward(self, x):
            # NOTE(review): torch.nn.Transformer.forward expects (src, tgt).
            # This single-argument call looks like it would fail if exercised.
            # The test below never calls forward; only save/load is checked.
            return self.transformer(x)

    config: MyGPTConfig = MyGPTConfig(
        n_layers=2,
        n_heads=2,
        embedding_size=16,
        n_positions=16,
        n_vocab=128,
        loss_factory=torch.nn.CrossEntropyLoss,
    )

    model: MyGPT = MyGPT(config)
    # Non-empty records, so the round-trip checks exercise their serialization.
    model.training_records = dict(loss=[3, 2, 1], accuracy=[0.1, 0.2, 0.3])

    fname: Path = TEST_DATA_PATH / "test_torch_configmodel.zanj"
    # Make sure the output directory exists on a fresh checkout; this is a
    # no-op if it is already present.
    fname.parent.mkdir(parents=True, exist_ok=True)
    ZANJ().save(model, fname)

    print(f"saved model to {fname}")
    print(f"{model.zanj_model_config = }")

    def _assert_matches_original(loaded: MyGPT) -> None:
        # Checks one loaded model against the in-memory original.
        print(f"loaded model from {fname}")
        print(f"{loaded.zanj_model_config = }")

        assert model.zanj_model_config == loaded.zanj_model_config
        assert model.training_records == loaded.training_records

        compare_state_dicts(model.state_dict(), loaded.state_dict())
        assert_model_exact_equality(model, loaded)

    # Both the class-specific reader and the generic ZANJ reader must
    # reproduce the saved model.
    _assert_matches_original(MyGPT.read(fname))
    _assert_matches_original(ZANJ().read(fname))