Coverage for zanj\torchutil.py: 88%
113 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-14 12:57 -0700
1from __future__ import annotations
3import abc
4import typing
5import warnings
6from typing import Any, Type, TypeVar
8import torch
9from muutils.json_serialize import SerializableDataclass
10from muutils.json_serialize.json_serialize import ObjectPath
11from muutils.json_serialize.util import safe_getsource, string_as_lines, _FORMAT_KEY
13from zanj import ZANJ, register_loader_handler
14from zanj.loading import LoaderHandler, load_item_recursive
16# pylint: disable=protected-access
# alias for `typing.Any`; presumably meant for annotating keyword-argument
# bundles (not referenced in this part of the file -- confirm usage)
KWArgs = typing.Any
def num_params(m: torch.nn.Module, only_trainable: bool = True) -> int:
    """return total number of parameters in a model

    - only counting shared parameters once
    - if `only_trainable` is False, will include parameters with `requires_grad = False`
    - works for modules on the `meta` device, whose parameters have no real storage

    https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model

    Args:
        m: module whose parameters are counted
        only_trainable: if True (default), skip parameters with `requires_grad = False`

    Returns:
        total number of elements over the (deduplicated) parameters
    """
    parameters: list[torch.nn.Parameter] = list(m.parameters())
    if only_trainable:
        parameters = [p for p in parameters if p.requires_grad]

    # deduplicate by underlying memory so that distinct Parameter objects that
    # share storage are counted once. meta-device tensors are not backed by
    # memory, so their `data_ptr()` cannot tell them apart -- keying them by
    # pointer would collapse every meta parameter into a single entry. key
    # those by object identity instead. later entries overwrite earlier ones
    # with the same key, as before.
    unique: dict[object, torch.nn.Parameter] = {}
    for p in parameters:
        if p.device.type == "meta":
            key: object = ("id", id(p))
        else:
            key = ("ptr", p.device, p.data_ptr())
        unique[key] = p

    return sum(p.numel() for p in unique.values())
def get_module_device(
    m: torch.nn.Module,
) -> tuple[bool, torch.device | dict[str, torch.device]]:
    """report which device(s) the parameters of `m` live on

    Returns:
        `(True, device)` when every parameter is on the same device;
        otherwise `(False, {param_name: device})`. A module with no
        parameters also yields `(False, {})`.
    """
    per_param: dict[str, torch.device] = {}
    for param_name, param in m.named_parameters():
        per_param[param_name] = param.device

    if not per_param:
        return False, per_param

    # compare every device against the first parameter's device
    reference: torch.device = next(iter(per_param.values()))
    uniform: bool = all(dev == reference for dev in per_param.values())
    return (True, reference) if uniform else (False, per_param)
# type parameter for `ConfiguredModel`: the config class a model is built
# from, restricted to subclasses of `SerializableDataclass`
T_config = typing.TypeVar("T_config", bound=SerializableDataclass)
class ConfiguredModel(
    torch.nn.Module,
    typing.Generic[T_config],
    metaclass=abc.ABCMeta,
):
    """a model that has a configuration, for saving with ZANJ

    ```python
    @set_config_class(YourConfig)
    class YourModule(ConfiguredModel[YourConfig]):
        def __init__(self, cfg: YourConfig):
            super().__init__(cfg)
    ```

    `__init__()` must initialize the model from a config object only, and call
    `super().__init__(zanj_model_config)`

    If you are inheriting from another class + ConfiguredModel,
    ConfiguredModel must be the first class in the inheritance list

    `serialize()` produces a dict holding the config, some metadata, the
    training records and the state dict. `load()` rebuilds the model by
    calling `cls(config)` and then loading the state dict into it.
    """

    # dont set this directly, use `set_config_class()` decorator
    _config_class: type | None = None
    # read-only view of the config class registered on the concrete subclass
    zanj_config_class = property(lambda self: type(self)._config_class)

    def __init__(self, zanj_model_config: T_config, **kwargs):
        """validate and store the model config

        Args:
            zanj_model_config: config instance; must be an instance of the
                class registered via `set_config_class()`
            **kwargs: forwarded to the next `__init__` in the MRO

        Raises:
            NotImplementedError: if no config class has been registered
            TypeError: if `zanj_model_config` has the wrong type
        """
        super().__init__(**kwargs)
        if self.zanj_config_class is None:
            raise NotImplementedError("you need to set `config_class` for your model")
        if not isinstance(zanj_model_config, self.zanj_config_class):  # type: ignore
            raise TypeError(
                f"config must be an instance of {self.zanj_config_class = }, got {type(zanj_model_config) = }"
            )

        self.zanj_model_config: T_config = zanj_model_config
        # filled in by `load()`; None until then unless assigned elsewhere
        self.training_records: dict | None = None

    def serialize(
        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
    ) -> dict[str, Any]:
        """serialize the model to a dict

        The returned dict contains:

        - the serialized config
        - a `meta` dict: class name, doc, source, module, MRO, trainable
          parameter count and `str(self)`
        - the training records
        - the raw `state_dict()`
        - a `__muutils_format__` entry set to the class name

        `path` and `zanj` are accepted but not used by this implementation.
        """
        # NOTE: previously a `ZANJ()` was constructed here when `zanj` was None,
        # but it was never used -- that dead construction has been removed
        obj = dict(
            zanj_model_config=self.zanj_model_config.serialize(),
            meta=dict(
                class_name=self.__class__.__name__,
                class_doc=string_as_lines(self.__class__.__doc__),
                class_source=safe_getsource(self.__class__),
                module_name=self.__class__.__module__,
                module_mro=[str(x) for x in self.__class__.__mro__],
                num_params=num_params(self),
                as_str=string_as_lines(str(self)),
            ),
            training_records=self.training_records,
            state_dict=self.state_dict(),
            __muutils_format__=self.__class__.__name__,
        )
        return obj

    def save(self, file_path: str, zanj: ZANJ | None = None):
        """serialize the model and write it to `file_path`

        Uses `zanj` if given, otherwise a default `ZANJ()`.
        """
        if zanj is None:
            zanj = ZANJ()
        zanj.save(self.serialize(), file_path)

    def _load_state_dict_wrapper(
        self,
        state_dict: dict[str, torch.Tensor],
        **kwargs,
    ):
        """wrapper for `load_state_dict()` in case you need to override it

        `load()` passes `zanj.custom_settings["_load_state_dict_wrapper"]` as
        `kwargs`; this default implementation accepts none.
        """
        assert len(kwargs) == 0, f"got unexpected kwargs: {kwargs}"
        return self.load_state_dict(state_dict)

    @classmethod
    def load(
        cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
    ) -> "ConfiguredModel":
        """load a model from a serialized object

        Args:
            obj: dict in the shape produced by `serialize()`
            path: location of `obj` within the enclosing ZANJ object; `None`
                is treated as the root path
            zanj: used for loading nested items. Its
                `custom_settings["_load_state_dict_wrapper"]` is passed as
                kwargs to `_load_state_dict_wrapper()`. Defaults to `ZANJ()`.

        Returns:
            a new instance of `cls` with config, state dict and training
            records restored
        """
        if zanj is None:
            zanj = ZANJ()

        # the load lambda built in `get_handler()` defaults `path` to None, and
        # `tuple(None)` would raise -- treat a missing path as the root path
        base_path: tuple = tuple(path) if path is not None else tuple()

        # get the config
        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore

        # get the training records (a missing key is passed on as None)
        training_records: typing.Any = load_item_recursive(
            obj.get("training_records", None),
            base_path + ("training_records",),
            zanj,
        )

        # initialize the model
        model: "ConfiguredModel" = cls(zanj_model_config)

        # load the state dict
        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
            obj["state_dict"],
            base_path + ("state_dict",),
            zanj,
        )

        model._load_state_dict_wrapper(
            tensored_state_dict,
            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
        )

        # set the training records
        model.training_records = training_records

        return model

    @classmethod
    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file

        Raises:
            AssertionError: if the loaded object is not an instance of `cls`
        """
        if zanj is None:
            zanj = ZANJ()

        mdl: ConfiguredModel = zanj.read(file_path)
        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
        return mdl

    @classmethod
    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """deprecated alias of `read()`"""
        # stacklevel=2 so the warning points at the caller, not at this method
        warnings.warn(
            "load_file() is deprecated, use read() instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return cls.read(file_path, zanj)

    @classmethod
    def get_handler(cls) -> LoaderHandler:
        """build a `LoaderHandler` that routes matching dicts to `cls.load()`

        A dict matches when its `_FORMAT_KEY` entry is a string starting with
        the class name (`serialize()` writes the class name under
        `__muutils_format__` -- presumably the same key, TODO confirm).
        """
        cls_name: str = str(cls.__name__)
        # NOTE(review): this is a prefix match, so the handler for `Foo` also
        # accepts items formatted as `FooBar`; which handler wins presumably
        # depends on registration order in zanj -- confirm before relying on it
        return LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore
                isinstance(json_item, dict)
                and _FORMAT_KEY in json_item
                # guard: a non-string format value would otherwise raise
                # AttributeError on `.startswith`
                and isinstance(json_item[_FORMAT_KEY], str)
                and json_item[_FORMAT_KEY].startswith(cls_name)
            ),
            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
            uid=cls_name,
            source_pckg=cls.__module__,
            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
        )

    def num_params(self) -> int:
        """number of trainable parameters, counting shared parameters once"""
        # resolves to the module-level `num_params` function, not this method
        return num_params(self)
def set_config_class(
    config_class: Type[SerializableDataclass],
) -> typing.Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
    """class decorator factory for `ConfiguredModel` subclasses

    The returned decorator does three things:

    - stores `config_class` on the decorated class
    - registers the class's `get_handler()` loader with zanj
    - returns the class unchanged otherwise

    Raises:
        TypeError: if `config_class` is not a subclass of `SerializableDataclass`
    """
    is_valid_config: bool = issubclass(config_class, SerializableDataclass)
    if not is_valid_config:
        raise TypeError(f"{config_class} must be a subclass of SerializableDataclass")

    def _decorate(model_cls: Type[ConfiguredModel]) -> Type[ConfiguredModel]:
        model_cls._config_class = config_class
        handler: LoaderHandler = model_cls.get_handler()
        register_loader_handler(handler)
        return model_cls

    return _decorate
class ConfigMismatchException(ValueError):
    """raised when two model configs differ

    The difference is kept on `.diff` and appended to the message by `__str__`.
    """

    def __init__(self, msg: str, diff):
        super().__init__(msg)
        # stored separately so callers can inspect the difference programmatically
        self.diff = diff

    def __str__(self):
        base_message: str = super().__str__()
        return "{}: {}".format(base_message, self.diff)
def assert_model_cfg_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check both models are correct instances and have the same config

    Raises:
        AssertionError: if either argument is not a `ConfiguredModel`, or its
            `zanj_model_config` is not a `SerializableDataclass`. These checks
            use `assert`, so they are skipped under `python -O`.
        ConfigMismatchException: if the configs don't match, e.diff will contain the diff
    """
    # validate model_a fully before model_b, same order as the checks are listed
    for label, model in (("model_a", model_a), ("model_b", model_b)):
        assert isinstance(model, ConfiguredModel), f"{label} must be a ConfiguredModel"
        assert isinstance(model.zanj_model_config, SerializableDataclass), (
            f"{label} must have a zanj_model_config"
        )

    cfg_a = model_a.zanj_model_config
    cfg_b = model_b.zanj_model_config
    if cfg_a == cfg_b:
        return

    # `diff` is looked up on the type of model_a's config and called with
    # both configs as positional arguments
    config_type: type = type(cfg_a)
    raise ConfigMismatchException(
        f"configs of type {type(cfg_a)}, {type(cfg_b)} don't match",
        diff=config_type.diff(cfg_a, cfg_b),  # type: ignore[attr-defined]
    )
def assert_model_exact_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check the models are exactly equal, including state dict contents

    The checks run in this order:

    1. configs match, via `assert_model_cfg_equality`
    2. both state dicts have the same keys
    3. every entry has the same shape and elementwise-equal values

    Prints `passed <key>` / `failed <key>` for each state dict entry.

    Raises:
        ConfigMismatchException: if the configs differ
        AssertionError: if the state dict keys differ or any entry mismatches
    """
    assert_model_cfg_equality(model_a, model_b)

    # build each state dict once; previously `model_b.state_dict()` was
    # rebuilt for every key, making the loop quadratic in the number of entries
    sd_a: dict[str, torch.Tensor] = model_a.state_dict()
    sd_b: dict[str, torch.Tensor] = model_b.state_dict()

    model_a_sd_keys: set[str] = set(sd_a.keys())
    model_b_sd_keys: set[str] = set(sd_b.keys())
    assert model_a_sd_keys == model_b_sd_keys, (
        f"state dict keys don't match: {model_a_sd_keys - model_b_sd_keys} / {model_b_sd_keys - model_a_sd_keys}"
    )
    keys_failed: list[str] = list()
    for k, v_a in sd_a.items():
        v_b = sd_b[k]
        # compare shapes first: `==` broadcasts, so e.g. shapes (1,) and (3,)
        # holding the same value would otherwise pass, and non-broadcastable
        # shapes would raise instead of being reported as a failed key
        if v_a.shape != v_b.shape or not (v_a == v_b).all():
            keys_failed.append(k)
            print(f"failed {k}")
        else:
            print(f"passed {k}")
    assert len(keys_failed) == 0, (
        f"{len(keys_failed)} / {len(model_a_sd_keys)} state dict elements don't match: {keys_failed}"
    )