Coverage for tests\test_zanj_torch_cfgmismatch.py: 95% (59 statements)
coverage.py v7.6.12, created at 2025-02-14 12:57 -0700

from __future__ import annotations

from typing import Any

import torch
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

from zanj.torchutil import (
    ConfigMismatchException,
    ConfiguredModel,
    assert_model_cfg_equality,
    set_config_class,
)


# Minimal config / model pair used by the tests below
@serializable_dataclass
class MyGPTConfig(SerializableDataclass):
    """basic test GPT config"""

    n_layers: int
    n_heads: int
    embedding_size: int
    n_positions: int
    n_vocab: int
    junk_data: Any = serializable_field(default_factory=dict)


@set_config_class(MyGPTConfig)
class MyGPT(ConfiguredModel[MyGPTConfig]):
    """basic "GPT" model"""

    def __init__(self, config: MyGPTConfig):
        super().__init__(config)
        self.transformer = torch.nn.Linear(config.embedding_size, config.n_vocab)

    def forward(self, x):
        return self.transformer(x)


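# A small additional check, sketched here rather than taken from the original
# test suite: it relies only on the classes defined above and on standard torch
# tensor shapes, and verifies that a forward pass maps
# (batch, embedding_size) -> (batch, n_vocab).
def test_forward_pass_shape():
    config = MyGPTConfig(
        n_layers=2,
        n_heads=2,
        embedding_size=16,
        n_positions=16,
        n_vocab=128,
    )
    model = MyGPT(config)
    out = model(torch.zeros(4, config.embedding_size))
    assert out.shape == (4, config.n_vocab)

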
def test_config_mismatch_exception_direct():
    """ConfigMismatchException should store the diff and include it in str()"""
    msg = "Configs don't match"
    diff = {"model_cfg": {"are_weights_processed": {"self": False, "other": True}}}

    exc = ConfigMismatchException(msg, diff)
    assert exc.diff == diff
    assert (
        str(exc)
        == r"Configs don't match: {'model_cfg': {'are_weights_processed': {'self': False, 'other': True}}}"
    )


def test_equal_configs():
    """two models built from the same config should compare equal"""
    config = MyGPTConfig(
        n_layers=2,
        n_heads=2,
        embedding_size=16,
        n_positions=16,
        n_vocab=128,
        junk_data={"a": 1, "b": 2},
    )

    model_a = MyGPT(config)
    model_b = MyGPT(config)

    assert_model_cfg_equality(model_a, model_b)


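# A hedged companion to test_equal_configs, not present in the original file:
# it assumes SerializableDataclass (a dataclass) compares by value, so two
# separately constructed but identical configs should also pass the equality
# assertion.
def test_equal_configs_distinct_objects():
    kwargs = dict(
        n_layers=2,
        n_heads=2,
        embedding_size=16,
        n_positions=16,
        n_vocab=128,
        junk_data={"a": 1, "b": 2},
    )
    assert_model_cfg_equality(
        MyGPT(MyGPTConfig(**kwargs)),
        MyGPT(MyGPTConfig(**kwargs)),
    )

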
def test_unequal_configs():
    """fields that differ should show up in the exception's diff"""
    config_a = MyGPTConfig(
        n_layers=2,
        n_heads=2,
        embedding_size=16,
        n_positions=16,
        n_vocab=128,
        junk_data={"a": 1, "b": 2},
    )
    # a different config
    config_b = MyGPTConfig(
        n_layers=3,
        n_heads=2,
        embedding_size=16,
        n_positions=16,
        n_vocab=128,
        junk_data={"a": 7, "something": "or other"},
    )

    model_a = MyGPT(config_a)
    model_b = MyGPT(config_b)

    try:
        assert_model_cfg_equality(model_a, model_b)
    except ConfigMismatchException as exc:
        assert exc.diff == {
            "n_layers": {"self": 2, "other": 3},
            "junk_data": {
                "self": {"a": 1, "b": 2},
                "other": {"a": 7, "something": "or other"},
            },
        }
    else:
        raise AssertionError("Expected a ConfigMismatchException!")


def test_unequal_configs_2():
    """the diff should also handle a field whose type differs between configs"""
    config_a = MyGPTConfig(
        n_layers=2,
        n_heads=2,
        embedding_size=16,
        n_positions=16,
        n_vocab=128,
        junk_data={"a": 1, "b": 2},
    )
    # a different config
    config_b = MyGPTConfig(
        n_layers=3,
        n_heads=2,
        embedding_size=16,
        n_positions=16,
        n_vocab=128,
        junk_data="this isnt even a dict lol",
    )

    model_a = MyGPT(config_a)
    model_b = MyGPT(config_b)

    try:
        assert_model_cfg_equality(model_a, model_b)
    except ConfigMismatchException as exc:
        assert exc.diff == {
            "n_layers": {"self": 2, "other": 3},
            "junk_data": {
                "self": {"a": 1, "b": 2},
                "other": "this isnt even a dict lol",
            },
        }
    else:
        raise AssertionError("Expected a ConfigMismatchException!")


def test_incorrect_instance():
    """passing something that is not a ConfiguredModel should raise an AssertionError"""
    config = MyGPTConfig(
        n_layers=2,
        n_heads=2,
        embedding_size=16,
        n_positions=16,
        n_vocab=128,
    )

    model_a = MyGPT(config)
    model_b = "Not a ConfiguredModel instance"

    try:
        assert_model_cfg_equality(model_a, model_b)  # type: ignore
    except AssertionError as exc:
        assert str(exc) == "model_b must be a ConfiguredModel"
    else:
        raise AssertionError("Expected an AssertionError!")