docs for zanj v0.4.0
View Source on GitHub

zanj.torchutil


  1from __future__ import annotations
  2
  3import abc
  4import typing
  5import warnings
  6from typing import Any, Type, TypeVar
  7
  8import torch
  9from muutils.json_serialize import SerializableDataclass
 10from muutils.json_serialize.json_serialize import ObjectPath
 11from muutils.json_serialize.util import safe_getsource, string_as_lines, _FORMAT_KEY
 12
 13from zanj import ZANJ, register_loader_handler
 14from zanj.loading import LoaderHandler, load_item_recursive
 15
 16# pylint: disable=protected-access
 17
 18KWArgs = Any
 19
 20
 21def num_params(m: torch.nn.Module, only_trainable: bool = True):
 22    """return total number of parameters in a model
 23
 24    - only counting shared parameters once
 25    - if `only_trainable` is False, will include parameters with `requires_grad = False`
 26
 27    https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
 28    """
 29    parameters: list[torch.nn.Parameter] = list(m.parameters())
 30    if only_trainable:
 31        parameters = [p for p in parameters if p.requires_grad]
 32
 33    unique: list[torch.nn.Parameter] = list(
 34        {p.data_ptr(): p for p in parameters}.values()
 35    )
 36
 37    return sum(p.numel() for p in unique)
 38
 39
 40def get_module_device(
 41    m: torch.nn.Module,
 42) -> tuple[bool, torch.device | dict[str, torch.device]]:
 43    """get the current devices"""
 44
 45    devs: dict[str, torch.device] = {name: p.device for name, p in m.named_parameters()}
 46
 47    if len(devs) == 0:
 48        return False, devs
 49
 50    # check if all devices are the same by getting one device
 51    dev_uni: torch.device = next(iter(devs.values()))
 52
 53    if all(dev == dev_uni for dev in devs.values()):
 54        return True, dev_uni
 55    else:
 56        return False, devs
 57
 58
 59T_config = TypeVar("T_config", bound=SerializableDataclass)
 60
 61
 62class ConfiguredModel(
 63    torch.nn.Module,
 64    typing.Generic[T_config],
 65    metaclass=abc.ABCMeta,
 66):
 67    """a model that has a configuration, for saving with ZANJ
 68
 69    ```python
 70    @set_config_class(YourConfig)
 71    class YourModule(ConfiguredModel[YourConfig]):
 72        def __init__(self, cfg: YourConfig):
 73            super().__init__(cfg)
 74    ```
 75
 76    `__init__()` must initialize the model from a config object only, and call
 77    `super().__init__(zanj_model_config)`
 78
 79    If you are inheriting from another class + ConfiguredModel,
 80    ConfiguredModel must be the first class in the inheritance list
 81    """
 82
 83    # dont set this directly, use `set_config_class()` decorator
 84    _config_class: type | None = None
 85    zanj_config_class = property(lambda self: type(self)._config_class)
 86
 87    def __init__(self, zanj_model_config: T_config, **kwargs):
 88        super().__init__(**kwargs)
 89        if self.zanj_config_class is None:
 90            raise NotImplementedError("you need to set `config_class` for your model")
 91        if not isinstance(zanj_model_config, self.zanj_config_class):  # type: ignore
 92            raise TypeError(
 93                f"config must be an instance of {self.zanj_config_class = }, got {type(zanj_model_config) = }"
 94            )
 95
 96        self.zanj_model_config: T_config = zanj_model_config
 97        self.training_records: dict | None = None
 98
 99    def serialize(
100        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
101    ) -> dict[str, Any]:
102        if zanj is None:
103            zanj = ZANJ()
104        obj = dict(
105            zanj_model_config=self.zanj_model_config.serialize(),
106            meta=dict(
107                class_name=self.__class__.__name__,
108                class_doc=string_as_lines(self.__class__.__doc__),
109                class_source=safe_getsource(self.__class__),
110                module_name=self.__class__.__module__,
111                module_mro=[str(x) for x in self.__class__.__mro__],
112                num_params=num_params(self),
113                as_str=string_as_lines(str(self)),
114            ),
115            training_records=self.training_records,
116            state_dict=self.state_dict(),
117            __muutils_format__=self.__class__.__name__,
118        )
119        return obj
120
121    def save(self, file_path: str, zanj: ZANJ | None = None):
122        if zanj is None:
123            zanj = ZANJ()
124        zanj.save(self.serialize(), file_path)
125
126    def _load_state_dict_wrapper(
127        self,
128        state_dict: dict[str, torch.Tensor],
129        **kwargs,
130    ):
131        """wrapper for `load_state_dict()` in case you need to override it"""
132        assert len(kwargs) == 0, f"got unexpected kwargs: {kwargs}"
133        return self.load_state_dict(state_dict)
134
135    @classmethod
136    def load(
137        cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
138    ) -> "ConfiguredModel":
139        """load a model from a serialized object"""
140
141        if zanj is None:
142            zanj = ZANJ()
143
144        # get the config
145        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore
146
147        # get the training records
148        training_records: typing.Any = load_item_recursive(
149            obj.get("training_records", None),
150            tuple(path) + ("training_records",),
151            zanj,
152        )
153
154        # initialize the model
155        model: "ConfiguredModel" = cls(zanj_model_config)
156
157        # load the state dict
158        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
159            obj["state_dict"],
160            tuple(path) + ("state_dict",),
161            zanj,
162        )
163
164        model._load_state_dict_wrapper(
165            tensored_state_dict,
166            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
167        )
168
169        # set the training records
170        model.training_records = training_records
171
172        return model
173
174    @classmethod
175    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
176        """read a model from a file"""
177        if zanj is None:
178            zanj = ZANJ()
179
180        mdl: ConfiguredModel = zanj.read(file_path)
181        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
182        return mdl
183
184    @classmethod
185    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
186        """read a model from a file"""
187        warnings.warn(
188            "load_file() is deprecated, use read() instead", DeprecationWarning
189        )
190        return cls.read(file_path, zanj)
191
192    @classmethod
193    def get_handler(cls) -> LoaderHandler:
194        cls_name: str = str(cls.__name__)
195        return LoaderHandler(
196            check=lambda json_item, path=None, z=None: (  # type: ignore
197                isinstance(json_item, dict)
198                and _FORMAT_KEY in json_item
199                and json_item[_FORMAT_KEY].startswith(cls_name)
200            ),
201            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
202            uid=cls_name,
203            source_pckg=cls.__module__,
204            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
205        )
206
207    def num_params(self) -> int:
208        return num_params(self)
209
210
def set_config_class(
    config_class: Type[SerializableDataclass],
) -> typing.Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
    """class decorator that registers `config_class` as the config of a `ConfiguredModel`

    The returned decorator sets `_config_class` on the decorated class,
    registers the class's loader handler (from `get_handler()`) via
    `register_loader_handler`, and returns the same class object.

    # Raises:
    - `TypeError` if `config_class` is not a subclass of `SerializableDataclass`
    """
    # `issubclass()` itself raises a generic TypeError for non-class arguments
    # (e.g. a config *instance*); guard so the caller gets the intended message
    if not (
        isinstance(config_class, type)
        and issubclass(config_class, SerializableDataclass)
    ):
        raise TypeError(f"{config_class} must be a subclass of SerializableDataclass")

    def wrapper(cls: Type[ConfiguredModel]) -> Type[ConfiguredModel]:
        # set the config class
        cls._config_class = config_class

        # register the handlers
        register_loader_handler(cls.get_handler())

        # return the same class, now configured
        return cls

    return wrapper
228
229
class ConfigMismatchException(ValueError):
    """raised when two configs that are expected to be equal differ

    `diff` holds whatever description of the difference the raiser supplied;
    it is appended to the message by `__str__`.
    """

    def __init__(self, msg: str, diff):
        super().__init__(msg)
        self.diff = diff

    def __str__(self):
        return f"{super().__str__()}: {self.diff}"

    def __reduce__(self):
        # `args` only holds `msg`, so the default exception reduction would
        # rebuild the object as `cls(msg)` and fail on the missing `diff`;
        # supply both constructor arguments so copy/pickle work
        message = self.args[0] if self.args else ""
        return (type(self), (message, self.diff), dict(self.__dict__))
237
238
def assert_model_cfg_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check both models are correct instances and have the same config

    The instance checks raise `AssertionError` explicitly (rather than via
    `assert` statements) so they are still performed under `python -O`.

    Raises:
        AssertionError: if either model is not a `ConfiguredModel`, or its
            `zanj_model_config` is not a `SerializableDataclass`
        ConfigMismatchException: if the configs don't match, e.diff will contain the diff
    """
    # same order as before: model_a fully checked first, then model_b
    for label, model in (("model_a", model_a), ("model_b", model_b)):
        if not isinstance(model, ConfiguredModel):
            raise AssertionError(f"{label} must be a ConfiguredModel")
        if not isinstance(model.zanj_model_config, SerializableDataclass):
            raise AssertionError(f"{label} must have a zanj_model_config")

    cfg_a = model_a.zanj_model_config
    cfg_b = model_b.zanj_model_config

    # `not (a == b)` on purpose: relies only on the config's `__eq__`
    if not (cfg_a == cfg_b):
        raise ConfigMismatchException(
            f"configs of type {type(cfg_a)}, {type(cfg_b)} don't match",
            # the diff is computed by the config class of model_a
            diff=type(cfg_a).diff(cfg_a, cfg_b),  # type: ignore[attr-defined]
        )
261
262
def assert_model_exact_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check the models are exactly equal, including state dict contents

    First checks the configs via `assert_model_cfg_equality`, then requires the
    two state dicts to have the same keys and, for every key, entries of the
    same shape whose elements all compare equal (`==`, no tolerance).
    Prints a `passed`/`failed` line per state dict key.

    Raises:
        AssertionError: if the state dict keys differ or any entry differs
        ConfigMismatchException: if the configs don't match
    """
    assert_model_cfg_equality(model_a, model_b)

    # build each state dict once instead of on every loop iteration
    sd_a = model_a.state_dict()
    sd_b = model_b.state_dict()

    model_a_sd_keys: set[str] = set(sd_a.keys())
    model_b_sd_keys: set[str] = set(sd_b.keys())
    if model_a_sd_keys != model_b_sd_keys:
        raise AssertionError(
            f"state dict keys don't match: {model_a_sd_keys - model_b_sd_keys} / {model_b_sd_keys - model_a_sd_keys}"
        )

    keys_failed: list[str] = list()
    for k, v_a in sd_a.items():
        v_b = sd_b[k]
        # compare shapes first: `==` alone would broadcast differently shaped
        # tensors and could report them as equal
        if v_a.shape == v_b.shape and bool((v_a == v_b).all()):
            print(f"passed {k}")
        else:
            keys_failed.append(k)
            print(f"failed {k}")

    if len(keys_failed) != 0:
        raise AssertionError(
            f"{len(keys_failed)} / {len(model_a_sd_keys)} state dict elements don't match: {keys_failed}"
        )

KWArgs = typing.Any
def num_params(m: torch.nn.modules.module.Module, only_trainable: bool = True):
22def num_params(m: torch.nn.Module, only_trainable: bool = True):
23    """return total number of parameters in a model
24
25    - only counting shared parameters once
26    - if `only_trainable` is False, will include parameters with `requires_grad = False`
27
28    https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
29    """
30    parameters: list[torch.nn.Parameter] = list(m.parameters())
31    if only_trainable:
32        parameters = [p for p in parameters if p.requires_grad]
33
34    unique: list[torch.nn.Parameter] = list(
35        {p.data_ptr(): p for p in parameters}.values()
36    )
37
38    return sum(p.numel() for p in unique)

return total number of parameters in a model

  • only counting shared parameters once
  • if only_trainable is False, will include parameters with requires_grad = False

https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model

def get_module_device( m: torch.nn.modules.module.Module) -> tuple[bool, torch.device | dict[str, torch.device]]:
41def get_module_device(
42    m: torch.nn.Module,
43) -> tuple[bool, torch.device | dict[str, torch.device]]:
44    """get the current devices"""
45
46    devs: dict[str, torch.device] = {name: p.device for name, p in m.named_parameters()}
47
48    if len(devs) == 0:
49        return False, devs
50
51    # check if all devices are the same by getting one device
52    dev_uni: torch.device = next(iter(devs.values()))
53
54    if all(dev == dev_uni for dev in devs.values()):
55        return True, dev_uni
56    else:
57        return False, devs

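get the devices of the module's parameters: returns (True, device) if all parameters are on the same device, otherwise (False, devs), where devs maps each parameter name to its device (empty if the module has no parameters)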

class ConfiguredModel(torch.nn.modules.module.Module, typing.Generic[~T_config]):
 63class ConfiguredModel(
 64    torch.nn.Module,
 65    typing.Generic[T_config],
 66    metaclass=abc.ABCMeta,
 67):
 68    """a model that has a configuration, for saving with ZANJ
 69
 70    ```python
 71    @set_config_class(YourConfig)
 72    class YourModule(ConfiguredModel[YourConfig]):
 73        def __init__(self, cfg: YourConfig):
 74            super().__init__(cfg)
 75    ```
 76
 77    `__init__()` must initialize the model from a config object only, and call
 78    `super().__init__(zanj_model_config)`
 79
 80    If you are inheriting from another class + ConfiguredModel,
 81    ConfiguredModel must be the first class in the inheritance list
 82    """
 83
 84    # dont set this directly, use `set_config_class()` decorator
 85    _config_class: type | None = None
 86    zanj_config_class = property(lambda self: type(self)._config_class)
 87
 88    def __init__(self, zanj_model_config: T_config, **kwargs):
 89        super().__init__(**kwargs)
 90        if self.zanj_config_class is None:
 91            raise NotImplementedError("you need to set `config_class` for your model")
 92        if not isinstance(zanj_model_config, self.zanj_config_class):  # type: ignore
 93            raise TypeError(
 94                f"config must be an instance of {self.zanj_config_class = }, got {type(zanj_model_config) = }"
 95            )
 96
 97        self.zanj_model_config: T_config = zanj_model_config
 98        self.training_records: dict | None = None
 99
100    def serialize(
101        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
102    ) -> dict[str, Any]:
103        if zanj is None:
104            zanj = ZANJ()
105        obj = dict(
106            zanj_model_config=self.zanj_model_config.serialize(),
107            meta=dict(
108                class_name=self.__class__.__name__,
109                class_doc=string_as_lines(self.__class__.__doc__),
110                class_source=safe_getsource(self.__class__),
111                module_name=self.__class__.__module__,
112                module_mro=[str(x) for x in self.__class__.__mro__],
113                num_params=num_params(self),
114                as_str=string_as_lines(str(self)),
115            ),
116            training_records=self.training_records,
117            state_dict=self.state_dict(),
118            __muutils_format__=self.__class__.__name__,
119        )
120        return obj
121
122    def save(self, file_path: str, zanj: ZANJ | None = None):
123        if zanj is None:
124            zanj = ZANJ()
125        zanj.save(self.serialize(), file_path)
126
127    def _load_state_dict_wrapper(
128        self,
129        state_dict: dict[str, torch.Tensor],
130        **kwargs,
131    ):
132        """wrapper for `load_state_dict()` in case you need to override it"""
133        assert len(kwargs) == 0, f"got unexpected kwargs: {kwargs}"
134        return self.load_state_dict(state_dict)
135
136    @classmethod
137    def load(
138        cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
139    ) -> "ConfiguredModel":
140        """load a model from a serialized object"""
141
142        if zanj is None:
143            zanj = ZANJ()
144
145        # get the config
146        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore
147
148        # get the training records
149        training_records: typing.Any = load_item_recursive(
150            obj.get("training_records", None),
151            tuple(path) + ("training_records",),
152            zanj,
153        )
154
155        # initialize the model
156        model: "ConfiguredModel" = cls(zanj_model_config)
157
158        # load the state dict
159        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
160            obj["state_dict"],
161            tuple(path) + ("state_dict",),
162            zanj,
163        )
164
165        model._load_state_dict_wrapper(
166            tensored_state_dict,
167            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
168        )
169
170        # set the training records
171        model.training_records = training_records
172
173        return model
174
175    @classmethod
176    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
177        """read a model from a file"""
178        if zanj is None:
179            zanj = ZANJ()
180
181        mdl: ConfiguredModel = zanj.read(file_path)
182        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
183        return mdl
184
185    @classmethod
186    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
187        """read a model from a file"""
188        warnings.warn(
189            "load_file() is deprecated, use read() instead", DeprecationWarning
190        )
191        return cls.read(file_path, zanj)
192
193    @classmethod
194    def get_handler(cls) -> LoaderHandler:
195        cls_name: str = str(cls.__name__)
196        return LoaderHandler(
197            check=lambda json_item, path=None, z=None: (  # type: ignore
198                isinstance(json_item, dict)
199                and _FORMAT_KEY in json_item
200                and json_item[_FORMAT_KEY].startswith(cls_name)
201            ),
202            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
203            uid=cls_name,
204            source_pckg=cls.__module__,
205            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
206        )
207
208    def num_params(self) -> int:
209        return num_params(self)

a model that has a configuration, for saving with ZANJ

@set_config_class(YourConfig)
class YourModule(ConfiguredModel[YourConfig]):
    def __init__(self, cfg: YourConfig):
        super().__init__(cfg)

__init__() must initialize the model from a config object only, and call super().__init__(zanj_model_config)

If you are inheriting from another class + ConfiguredModel, ConfiguredModel must be the first class in the inheritance list

zanj_config_class
86    zanj_config_class = property(lambda self: type(self)._config_class)
zanj_model_config: ~T_config
training_records: dict | None
def serialize( self, path: tuple[typing.Union[str, int], ...] = (), zanj: zanj.ZANJ | None = None) -> dict[str, typing.Any]:
100    def serialize(
101        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
102    ) -> dict[str, Any]:
103        if zanj is None:
104            zanj = ZANJ()
105        obj = dict(
106            zanj_model_config=self.zanj_model_config.serialize(),
107            meta=dict(
108                class_name=self.__class__.__name__,
109                class_doc=string_as_lines(self.__class__.__doc__),
110                class_source=safe_getsource(self.__class__),
111                module_name=self.__class__.__module__,
112                module_mro=[str(x) for x in self.__class__.__mro__],
113                num_params=num_params(self),
114                as_str=string_as_lines(str(self)),
115            ),
116            training_records=self.training_records,
117            state_dict=self.state_dict(),
118            __muutils_format__=self.__class__.__name__,
119        )
120        return obj
def save(self, file_path: str, zanj: zanj.ZANJ | None = None):
122    def save(self, file_path: str, zanj: ZANJ | None = None):
123        if zanj is None:
124            zanj = ZANJ()
125        zanj.save(self.serialize(), file_path)
@classmethod
def load( cls, obj: dict[str, typing.Any], path: tuple[typing.Union[str, int], ...], zanj: zanj.ZANJ | None = None) -> ConfiguredModel:
136    @classmethod
137    def load(
138        cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
139    ) -> "ConfiguredModel":
140        """load a model from a serialized object"""
141
142        if zanj is None:
143            zanj = ZANJ()
144
145        # get the config
146        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore
147
148        # get the training records
149        training_records: typing.Any = load_item_recursive(
150            obj.get("training_records", None),
151            tuple(path) + ("training_records",),
152            zanj,
153        )
154
155        # initialize the model
156        model: "ConfiguredModel" = cls(zanj_model_config)
157
158        # load the state dict
159        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
160            obj["state_dict"],
161            tuple(path) + ("state_dict",),
162            zanj,
163        )
164
165        model._load_state_dict_wrapper(
166            tensored_state_dict,
167            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
168        )
169
170        # set the training records
171        model.training_records = training_records
172
173        return model

load a model from a serialized object

@classmethod
def read( cls, file_path: str, zanj: zanj.ZANJ | None = None) -> ConfiguredModel:
175    @classmethod
176    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
177        """read a model from a file"""
178        if zanj is None:
179            zanj = ZANJ()
180
181        mdl: ConfiguredModel = zanj.read(file_path)
182        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
183        return mdl

read a model from a file

@classmethod
def load_file( cls, file_path: str, zanj: zanj.ZANJ | None = None) -> ConfiguredModel:
185    @classmethod
186    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
187        """read a model from a file"""
188        warnings.warn(
189            "load_file() is deprecated, use read() instead", DeprecationWarning
190        )
191        return cls.read(file_path, zanj)

read a model from a file (deprecated: emits a DeprecationWarning and calls read(); use read() instead)

@classmethod
def get_handler(cls) -> zanj.loading.LoaderHandler:
193    @classmethod
194    def get_handler(cls) -> LoaderHandler:
195        cls_name: str = str(cls.__name__)
196        return LoaderHandler(
197            check=lambda json_item, path=None, z=None: (  # type: ignore
198                isinstance(json_item, dict)
199                and _FORMAT_KEY in json_item
200                and json_item[_FORMAT_KEY].startswith(cls_name)
201            ),
202            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
203            uid=cls_name,
204            source_pckg=cls.__module__,
205            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
206        )
def num_params(self) -> int:
208    def num_params(self) -> int:
209        return num_params(self)
Inherited Members
torch.nn.modules.module.Module
Module
dump_patches
training
call_super_init
forward
register_buffer
register_parameter
add_module
register_module
get_submodule
set_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
mtia
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_post_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_pre_hook
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
def set_config_class( config_class: Type[muutils.json_serialize.serializable_dataclass.SerializableDataclass]) -> Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
212def set_config_class(
213    config_class: Type[SerializableDataclass],
214) -> typing.Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
215    if not issubclass(config_class, SerializableDataclass):
216        raise TypeError(f"{config_class} must be a subclass of SerializableDataclass")
217
218    def wrapper(cls: Type[ConfiguredModel]) -> Type[ConfiguredModel]:
219        # set the config class
220        cls._config_class = config_class
221
222        # register the handlers
223        register_loader_handler(cls.get_handler())
224
225        # return the new class
226        return cls
227
228    return wrapper
class ConfigMismatchException(builtins.ValueError):
231class ConfigMismatchException(ValueError):
232    def __init__(self, msg: str, diff):
233        super().__init__(msg)
234        self.diff = diff
235
236    def __str__(self):
237        return f"{super().__str__()}: {self.diff}"

Inappropriate argument value (of correct type).

ConfigMismatchException(msg: str, diff)
232    def __init__(self, msg: str, diff):
233        super().__init__(msg)
234        self.diff = diff
diff
Inherited Members
builtins.BaseException
with_traceback
add_note
args
def assert_model_cfg_equality( model_a: ConfiguredModel, model_b: ConfiguredModel):
240def assert_model_cfg_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
241    """check both models are correct instances and have the same config
242
243    Raises:
244        ConfigMismatchException: if the configs don't match, e.diff will contain the diff
245    """
246    assert isinstance(model_a, ConfiguredModel), "model_a must be a ConfiguredModel"
247    assert isinstance(model_a.zanj_model_config, SerializableDataclass), (
248        "model_a must have a zanj_model_config"
249    )
250    assert isinstance(model_b, ConfiguredModel), "model_b must be a ConfiguredModel"
251    assert isinstance(model_b.zanj_model_config, SerializableDataclass), (
252        "model_b must have a zanj_model_config"
253    )
254
255    cls_type: type = type(model_a.zanj_model_config)
256
257    if not (model_a.zanj_model_config == model_b.zanj_model_config):
258        raise ConfigMismatchException(
259            f"configs of type {type(model_a.zanj_model_config)}, {type(model_b.zanj_model_config)} don't match",
260            diff=cls_type.diff(model_a.zanj_model_config, model_b.zanj_model_config),  # type: ignore[attr-defined]
261        )

check both models are correct instances and have the same config

Raises: ConfigMismatchException: if the configs don't match, e.diff will contain the diff

def assert_model_exact_equality( model_a: ConfiguredModel, model_b: ConfiguredModel):
264def assert_model_exact_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
265    """check the models are exactly equal, including state dict contents"""
266    assert_model_cfg_equality(model_a, model_b)
267
268    model_a_sd_keys: set[str] = set(model_a.state_dict().keys())
269    model_b_sd_keys: set[str] = set(model_b.state_dict().keys())
270    assert model_a_sd_keys == model_b_sd_keys, (
271        f"state dict keys don't match: {model_a_sd_keys - model_b_sd_keys} / {model_b_sd_keys - model_a_sd_keys}"
272    )
273    keys_failed: list[str] = list()
274    for k, v_a in model_a.state_dict().items():
275        v_b = model_b.state_dict()[k]
276        if not (v_a == v_b).all():
277            # if not torch.allclose(v, v_load):
278            keys_failed.append(k)
279            print(f"failed {k}")
280        else:
281            print(f"passed {k}")
282    assert len(keys_failed) == 0, (
283        f"{len(keys_failed)} / {len(model_a_sd_keys)} state dict elements don't match: {keys_failed}"
284    )

check the models are exactly equal, including state dict contents