zanj.torchutil
from __future__ import annotations

import abc
import typing
import warnings
from typing import Any, Type, TypeVar

import torch
from muutils.json_serialize import SerializableDataclass
from muutils.json_serialize.json_serialize import ObjectPath
from muutils.json_serialize.util import safe_getsource, string_as_lines, _FORMAT_KEY

from zanj import ZANJ, register_loader_handler
from zanj.loading import LoaderHandler, load_item_recursive

# pylint: disable=protected-access

KWArgs = Any


def num_params(m: torch.nn.Module, only_trainable: bool = True) -> int:
    """return total number of parameters in a model

    - only counting shared parameters once
    - if `only_trainable` is False, will include parameters with `requires_grad = False`

    parameters are deduplicated by the memory they point to, keyed on
    `(device, data_ptr())`. tensors without real backing memory (on the `meta`
    device, or reporting a `data_ptr()` of 0) are keyed on object identity
    instead, since otherwise they would all collapse onto a single key.

    https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
    """
    parameters: list[torch.nn.Parameter] = list(m.parameters())
    if only_trainable:
        parameters = [p for p in parameters if p.requires_grad]

    unique: dict[tuple, torch.nn.Parameter] = dict()
    for p in parameters:
        key: tuple
        if p.device.type == "meta" or p.data_ptr() == 0:
            # no real memory to compare -- fall back to object identity
            key = ("id", id(p))
        else:
            # include the device: raw pointers on different devices may coincide
            key = ("ptr", p.device, p.data_ptr())
        unique[key] = p

    return sum(p.numel() for p in unique.values())


def get_module_device(
    m: torch.nn.Module,
) -> tuple[bool, torch.device | dict[str, torch.device]]:
    """get the device(s) of the module's parameters

    only parameters (from `named_parameters()`) are inspected, not buffers.

    Returns:
        `(True, device)` if every parameter is on the same device, otherwise
        `(False, {parameter_name: device})`. a module without parameters
        gives `(False, {})`.
    """
    devs: dict[str, torch.device] = {name: p.device for name, p in m.named_parameters()}

    if len(devs) == 0:
        return False, devs

    # check if all devices are the same by comparing against one of them
    dev_uni: torch.device = next(iter(devs.values()))

    if all(dev == dev_uni for dev in devs.values()):
        return True, dev_uni
    else:
        return False, devs


T_config = TypeVar("T_config", bound=SerializableDataclass)


class ConfiguredModel(
    torch.nn.Module,
    typing.Generic[T_config],
    metaclass=abc.ABCMeta,
):
    """a model that has a configuration, for saving with ZANJ

    ```python
    @set_config_class(YourConfig)
    class YourModule(ConfiguredModel[YourConfig]):
        def __init__(self, cfg: YourConfig):
            super().__init__(cfg)
    ```

    `__init__()` must initialize the model from a config object only, and call
    `super().__init__(zanj_model_config)`

    If you are inheriting from another class + ConfiguredModel,
    ConfiguredModel must be the first class in the inheritance list
    """

    # dont set this directly, use `set_config_class()` decorator
    _config_class: type | None = None
    # read-only access to the config class registered on this (sub)class
    zanj_config_class = property(lambda self: type(self)._config_class)

    def __init__(self, zanj_model_config: T_config, **kwargs):
        """store the config after checking it against the registered config class

        Args:
            zanj_model_config: config instance; must be an instance of the class
                registered with `set_config_class()`
            **kwargs: forwarded to the next `__init__` in the MRO

        Raises:
            NotImplementedError: if no config class is registered for this class
            TypeError: if `zanj_model_config` is not an instance of the registered class
        """
        super().__init__(**kwargs)
        if self.zanj_config_class is None:
            raise NotImplementedError("you need to set `config_class` for your model")
        if not isinstance(zanj_model_config, self.zanj_config_class):  # type: ignore
            raise TypeError(
                f"config must be an instance of {self.zanj_config_class = }, got {type(zanj_model_config) = }"
            )

        self.zanj_model_config: T_config = zanj_model_config
        self.training_records: dict | None = None

    def serialize(
        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
    ) -> dict[str, Any]:
        """serialize the model into a dict

        Args:
            path: position of this object in the serialized tree; part of the
                serialization interface, not used here
            zanj: part of the serialization interface, not used here

        Returns:
            dict holding the serialized config, descriptive `meta` info, the
            `training_records`, the raw `state_dict()`, and the format key set
            to the class name (the key `get_handler()` checks when loading)
        """
        obj: dict[str, Any] = {
            "zanj_model_config": self.zanj_model_config.serialize(),
            "meta": dict(
                class_name=self.__class__.__name__,
                class_doc=string_as_lines(self.__class__.__doc__),
                class_source=safe_getsource(self.__class__),
                module_name=self.__class__.__module__,
                module_mro=[str(x) for x in self.__class__.__mro__],
                num_params=num_params(self),
                as_str=string_as_lines(str(self)),
            ),
            "training_records": self.training_records,
            "state_dict": self.state_dict(),
            # use the shared constant so this always agrees with `get_handler()`
            _FORMAT_KEY: self.__class__.__name__,
        }
        return obj

    def save(self, file_path: str, zanj: ZANJ | None = None):
        """save the model to `file_path` using `zanj` (a default `ZANJ()` if None)"""
        if zanj is None:
            zanj = ZANJ()
        zanj.save(self.serialize(), file_path)

    def _load_state_dict_wrapper(
        self,
        state_dict: dict[str, torch.Tensor],
        **kwargs,
    ):
        """wrapper for `load_state_dict()` in case you need to override it

        `load()` passes the kwargs found under
        `zanj.custom_settings["_load_state_dict_wrapper"]`; this default
        implementation accepts none.

        Raises:
            AssertionError: if any kwargs are given
        """
        assert len(kwargs) == 0, f"got unexpected kwargs: {kwargs}"
        return self.load_state_dict(state_dict)

    @classmethod
    def load(
        cls, obj: dict[str, Any], path: ObjectPath | None, zanj: ZANJ | None = None
    ) -> "ConfiguredModel":
        """load a model from a serialized object

        Args:
            obj: dict as produced by `serialize()`
            path: position of `obj` in the serialized tree; None is treated as
                the root (the loader lambda in `get_handler()` defaults it to None)
            zanj: used to load nested items; a default `ZANJ()` if None

        Returns:
            a new instance of `cls`, built from the config, with the state dict
            loaded and `training_records` set

        Raises:
            NotImplementedError: if no config class is registered for `cls`
            KeyError: if `obj` lacks `"zanj_model_config"` or `"state_dict"`
        """
        if zanj is None:
            zanj = ZANJ()

        if cls._config_class is None:
            raise NotImplementedError("you need to set `config_class` for your model")

        base_path: tuple = tuple() if path is None else tuple(path)

        # get the config
        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore

        # get the training records
        training_records: typing.Any = load_item_recursive(
            obj.get("training_records", None),
            base_path + ("training_records",),
            zanj,
        )

        # initialize the model from the config alone, then fill in the weights
        model: "ConfiguredModel" = cls(zanj_model_config)

        # load the state dict
        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
            obj["state_dict"],
            base_path + ("state_dict",),
            zanj,
        )

        # extra kwargs for the wrapper may be supplied through `zanj.custom_settings`
        model._load_state_dict_wrapper(
            tensored_state_dict,
            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
        )

        # set the training records
        model.training_records = training_records

        return model

    @classmethod
    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file

        Raises:
            AssertionError: if the loaded object is not an instance of `cls`
        """
        if zanj is None:
            zanj = ZANJ()

        mdl: ConfiguredModel = zanj.read(file_path)
        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
        return mdl

    @classmethod
    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file (deprecated, use `read()` instead)"""
        warnings.warn(
            "load_file() is deprecated, use read() instead",
            DeprecationWarning,
            # attribute the warning to the caller rather than to this line
            stacklevel=2,
        )
        return cls.read(file_path, zanj)

    @classmethod
    def get_handler(cls) -> LoaderHandler:
        """build the ZANJ `LoaderHandler` for this class

        the handler claims dicts whose format key is a str starting with this
        class's name, and loads them through `cls.load`.
        """
        cls_name: str = str(cls.__name__)
        return LoaderHandler(
            # NOTE(review): this is a prefix match, so a handler for `Foo` also
            # accepts items formatted as `FooBar` -- confirm this is intended
            check=lambda json_item, path=None, z=None: (  # type: ignore
                isinstance(json_item, dict)
                and isinstance(json_item.get(_FORMAT_KEY), str)
                and json_item[_FORMAT_KEY].startswith(cls_name)
            ),
            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
            uid=cls_name,
            source_pckg=cls.__module__,
            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
        )

    def num_params(self, only_trainable: bool = True) -> int:
        """return the number of parameters in this model, counting shared ones once

        `only_trainable=False` also counts parameters with `requires_grad = False`
        """
        return num_params(self, only_trainable=only_trainable)


def set_config_class(
    config_class: Type[SerializableDataclass],
) -> typing.Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
    """class decorator: attach `config_class` to a `ConfiguredModel` subclass

    the returned decorator sets `cls._config_class`, registers
    `cls.get_handler()` with `register_loader_handler`, and returns `cls`.

    Raises:
        TypeError: if `config_class` is not a subclass of `SerializableDataclass`
    """
    # check for a class first: `issubclass` on a non-class raises its own,
    # less informative, TypeError
    if not (
        isinstance(config_class, type)
        and issubclass(config_class, SerializableDataclass)
    ):
        raise TypeError(f"{config_class} must be a subclass of SerializableDataclass")

    def wrapper(cls: Type[ConfiguredModel]) -> Type[ConfiguredModel]:
        # set the config class
        cls._config_class = config_class

        # register the handlers
        register_loader_handler(cls.get_handler())

        # return the new class
        return cls

    return wrapper


class ConfigMismatchException(ValueError):
    """raised when two configs don't match; `diff` holds the difference"""

    def __init__(self, msg: str, diff):
        super().__init__(msg)
        self.diff = diff

    def __str__(self):
        return f"{super().__str__()}: {self.diff}"


def assert_model_cfg_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check both models are correct instances and have the same config

    Raises:
        AssertionError: if either model is not a `ConfiguredModel` or lacks a
            `SerializableDataclass` config
        ConfigMismatchException: if the configs don't match, e.diff will contain the diff
    """
    assert isinstance(model_a, ConfiguredModel), "model_a must be a ConfiguredModel"
    assert isinstance(model_a.zanj_model_config, SerializableDataclass), (
        "model_a must have a zanj_model_config"
    )
    assert isinstance(model_b, ConfiguredModel), "model_b must be a ConfiguredModel"
    assert isinstance(model_b.zanj_model_config, SerializableDataclass), (
        "model_b must have a zanj_model_config"
    )

    cls_type: type = type(model_a.zanj_model_config)

    if not (model_a.zanj_model_config == model_b.zanj_model_config):
        raise ConfigMismatchException(
            f"configs of type {type(model_a.zanj_model_config)}, {type(model_b.zanj_model_config)} don't match",
            diff=cls_type.diff(model_a.zanj_model_config, model_b.zanj_model_config),  # type: ignore[attr-defined]
        )


def assert_model_exact_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check the models are exactly equal, including state dict contents

    checks configs via `assert_model_cfg_equality`, then that both state dicts
    have the same keys, then that each pair of entries has the same shape and
    identical elements. prints `passed {key}` / `failed {key}` for every entry.

    Raises:
        ConfigMismatchException: if the configs don't match
        AssertionError: if the state dict keys differ or any entry differs
    """
    assert_model_cfg_equality(model_a, model_b)

    # build each state dict once, rather than once per key inside the loop
    sd_a: dict[str, Any] = model_a.state_dict()
    sd_b: dict[str, Any] = model_b.state_dict()

    model_a_sd_keys: set[str] = set(sd_a.keys())
    model_b_sd_keys: set[str] = set(sd_b.keys())
    assert model_a_sd_keys == model_b_sd_keys, (
        f"state dict keys don't match: {model_a_sd_keys - model_b_sd_keys} / {model_b_sd_keys - model_a_sd_keys}"
    )
    keys_failed: list[str] = list()
    for k, v_a in sd_a.items():
        v_b = sd_b[k]
        # compare shapes explicitly: `==` broadcasts, so differently-shaped
        # tensors could otherwise compare equal (or raise instead of failing)
        if v_a.shape != v_b.shape or not (v_a == v_b).all():
            keys_failed.append(k)
            print(f"failed {k}")
        else:
            print(f"passed {k}")
    assert len(keys_failed) == 0, (
        f"{len(keys_failed)} / {len(model_a_sd_keys)} state dict elements don't match: {keys_failed}"
    )
def num_params(m: torch.nn.Module, only_trainable: bool = True) -> int:
    """return total number of parameters in a model

    - only counting shared parameters once
    - if `only_trainable` is False, will include parameters with `requires_grad = False`

    parameters are deduplicated by the memory they point to, keyed on
    `(device, data_ptr())`. tensors without real backing memory (on the `meta`
    device, or reporting a `data_ptr()` of 0) are keyed on object identity
    instead, since otherwise they would all collapse onto a single key.

    https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
    """
    parameters: list[torch.nn.Parameter] = list(m.parameters())
    if only_trainable:
        parameters = [p for p in parameters if p.requires_grad]

    unique: dict[tuple, torch.nn.Parameter] = dict()
    for p in parameters:
        key: tuple
        if p.device.type == "meta" or p.data_ptr() == 0:
            # no real memory to compare -- fall back to object identity
            key = ("id", id(p))
        else:
            # include the device: raw pointers on different devices may coincide
            key = ("ptr", p.device, p.data_ptr())
        unique[key] = p

    return sum(p.numel() for p in unique.values())
return total number of parameters in a model
- only counting shared parameters once
- if `only_trainable` is False, will include parameters with `requires_grad = False`
https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
def get_module_device(
    m: torch.nn.Module,
) -> tuple[bool, torch.device | dict[str, torch.device]]:
    """report which device(s) the module's parameters live on

    only parameters (from `named_parameters()`) are looked at, not buffers.

    Returns:
        `(True, device)` when every parameter shares one device, otherwise
        `(False, {parameter_name: device})`; a module with no parameters
        yields `(False, {})`.
    """
    device_map: dict[str, torch.device] = dict()
    for param_name, param in m.named_parameters():
        device_map[param_name] = param.device

    # guard clause: nothing to compare
    if not device_map:
        return False, device_map

    reference: torch.device = next(iter(device_map.values()))
    if any(not (d == reference) for d in device_map.values()):
        return False, device_map
    return True, reference
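get the device(s) of a module's parameters: returns `(True, device)` if all parameters are on a single device, otherwise `(False, {name: device})` (an empty dict if the module has no parameters)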
class ConfiguredModel(
    torch.nn.Module,
    typing.Generic[T_config],
    metaclass=abc.ABCMeta,
):
    """a model that has a configuration, for saving with ZANJ

    ```python
    @set_config_class(YourConfig)
    class YourModule(ConfiguredModel[YourConfig]):
        def __init__(self, cfg: YourConfig):
            super().__init__(cfg)
    ```

    `__init__()` must initialize the model from a config object only, and call
    `super().__init__(zanj_model_config)`

    If you are inheriting from another class + ConfiguredModel,
    ConfiguredModel must be the first class in the inheritance list
    """

    # dont set this directly, use `set_config_class()` decorator
    _config_class: type | None = None
    # read-only access to the config class registered on this (sub)class
    zanj_config_class = property(lambda self: type(self)._config_class)

    def __init__(self, zanj_model_config: T_config, **kwargs):
        """store the config after checking it against the registered config class

        Args:
            zanj_model_config: config instance; must be an instance of the class
                registered with `set_config_class()`
            **kwargs: forwarded to the next `__init__` in the MRO

        Raises:
            NotImplementedError: if no config class is registered for this class
            TypeError: if `zanj_model_config` is not an instance of the registered class
        """
        super().__init__(**kwargs)
        if self.zanj_config_class is None:
            raise NotImplementedError("you need to set `config_class` for your model")
        if not isinstance(zanj_model_config, self.zanj_config_class):  # type: ignore
            raise TypeError(
                f"config must be an instance of {self.zanj_config_class = }, got {type(zanj_model_config) = }"
            )

        self.zanj_model_config: T_config = zanj_model_config
        self.training_records: dict | None = None

    def serialize(
        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
    ) -> dict[str, Any]:
        """serialize the model into a dict

        Args:
            path: position of this object in the serialized tree; part of the
                serialization interface, not used here
            zanj: part of the serialization interface, not used here

        Returns:
            dict holding the serialized config, descriptive `meta` info, the
            `training_records`, the raw `state_dict()`, and the format key set
            to the class name (the key `get_handler()` checks when loading)
        """
        obj: dict[str, Any] = {
            "zanj_model_config": self.zanj_model_config.serialize(),
            "meta": dict(
                class_name=self.__class__.__name__,
                class_doc=string_as_lines(self.__class__.__doc__),
                class_source=safe_getsource(self.__class__),
                module_name=self.__class__.__module__,
                module_mro=[str(x) for x in self.__class__.__mro__],
                num_params=num_params(self),
                as_str=string_as_lines(str(self)),
            ),
            "training_records": self.training_records,
            "state_dict": self.state_dict(),
            # use the shared constant so this always agrees with `get_handler()`
            _FORMAT_KEY: self.__class__.__name__,
        }
        return obj

    def save(self, file_path: str, zanj: ZANJ | None = None):
        """save the model to `file_path` using `zanj` (a default `ZANJ()` if None)"""
        if zanj is None:
            zanj = ZANJ()
        zanj.save(self.serialize(), file_path)

    def _load_state_dict_wrapper(
        self,
        state_dict: dict[str, torch.Tensor],
        **kwargs,
    ):
        """wrapper for `load_state_dict()` in case you need to override it

        `load()` passes the kwargs found under
        `zanj.custom_settings["_load_state_dict_wrapper"]`; this default
        implementation accepts none.

        Raises:
            AssertionError: if any kwargs are given
        """
        assert len(kwargs) == 0, f"got unexpected kwargs: {kwargs}"
        return self.load_state_dict(state_dict)

    @classmethod
    def load(
        cls, obj: dict[str, Any], path: ObjectPath | None, zanj: ZANJ | None = None
    ) -> "ConfiguredModel":
        """load a model from a serialized object

        Args:
            obj: dict as produced by `serialize()`
            path: position of `obj` in the serialized tree; None is treated as
                the root (the loader lambda in `get_handler()` defaults it to None)
            zanj: used to load nested items; a default `ZANJ()` if None

        Returns:
            a new instance of `cls`, built from the config, with the state dict
            loaded and `training_records` set

        Raises:
            NotImplementedError: if no config class is registered for `cls`
            KeyError: if `obj` lacks `"zanj_model_config"` or `"state_dict"`
        """
        if zanj is None:
            zanj = ZANJ()

        if cls._config_class is None:
            raise NotImplementedError("you need to set `config_class` for your model")

        base_path: tuple = tuple() if path is None else tuple(path)

        # get the config
        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore

        # get the training records
        training_records: typing.Any = load_item_recursive(
            obj.get("training_records", None),
            base_path + ("training_records",),
            zanj,
        )

        # initialize the model from the config alone, then fill in the weights
        model: "ConfiguredModel" = cls(zanj_model_config)

        # load the state dict
        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
            obj["state_dict"],
            base_path + ("state_dict",),
            zanj,
        )

        # extra kwargs for the wrapper may be supplied through `zanj.custom_settings`
        model._load_state_dict_wrapper(
            tensored_state_dict,
            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
        )

        # set the training records
        model.training_records = training_records

        return model

    @classmethod
    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file

        Raises:
            AssertionError: if the loaded object is not an instance of `cls`
        """
        if zanj is None:
            zanj = ZANJ()

        mdl: ConfiguredModel = zanj.read(file_path)
        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
        return mdl

    @classmethod
    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file (deprecated, use `read()` instead)"""
        warnings.warn(
            "load_file() is deprecated, use read() instead",
            DeprecationWarning,
            # attribute the warning to the caller rather than to this line
            stacklevel=2,
        )
        return cls.read(file_path, zanj)

    @classmethod
    def get_handler(cls) -> LoaderHandler:
        """build the ZANJ `LoaderHandler` for this class

        the handler claims dicts whose format key is a str starting with this
        class's name, and loads them through `cls.load`.
        """
        cls_name: str = str(cls.__name__)
        return LoaderHandler(
            # NOTE(review): this is a prefix match, so a handler for `Foo` also
            # accepts items formatted as `FooBar` -- confirm this is intended
            check=lambda json_item, path=None, z=None: (  # type: ignore
                isinstance(json_item, dict)
                and isinstance(json_item.get(_FORMAT_KEY), str)
                and json_item[_FORMAT_KEY].startswith(cls_name)
            ),
            load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
            uid=cls_name,
            source_pckg=cls.__module__,
            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
        )

    def num_params(self, only_trainable: bool = True) -> int:
        """return the number of parameters in this model, counting shared ones once

        `only_trainable=False` also counts parameters with `requires_grad = False`
        """
        return num_params(self, only_trainable=only_trainable)
a model that has a configuration, for saving with ZANJ
@set_config_class(YourConfig)
class YourModule(ConfiguredModel[YourConfig]):
    def __init__(self, cfg: YourConfig):
        super().__init__(cfg)
`__init__()` must initialize the model from a config object only, and call `super().__init__(zanj_model_config)`
If you are inheriting from another class + ConfiguredModel, ConfiguredModel must be the first class in the inheritance list
def serialize(
    self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
) -> dict[str, Any]:
    """serialize the model into a dict

    Args:
        path: position of this object in the serialized tree; part of the
            serialization interface, not used here
        zanj: part of the serialization interface, not used here

    Returns:
        dict holding the serialized config, descriptive `meta` info, the
        `training_records`, the raw `state_dict()`, and the format key set
        to the class name (the key `get_handler()` checks when loading)
    """
    obj: dict[str, Any] = {
        "zanj_model_config": self.zanj_model_config.serialize(),
        "meta": dict(
            class_name=self.__class__.__name__,
            class_doc=string_as_lines(self.__class__.__doc__),
            class_source=safe_getsource(self.__class__),
            module_name=self.__class__.__module__,
            module_mro=[str(x) for x in self.__class__.__mro__],
            num_params=num_params(self),
            as_str=string_as_lines(str(self)),
        ),
        "training_records": self.training_records,
        "state_dict": self.state_dict(),
        # use the shared constant so this always agrees with `get_handler()`
        _FORMAT_KEY: self.__class__.__name__,
    }
    return obj
@classmethod
def load(
    cls, obj: dict[str, Any], path: ObjectPath | None, zanj: ZANJ | None = None
) -> "ConfiguredModel":
    """load a model from a serialized object

    Args:
        obj: dict as produced by `serialize()`
        path: position of `obj` in the serialized tree; None is treated as
            the root (the loader lambda in `get_handler()` defaults it to None)
        zanj: used to load nested items; a default `ZANJ()` if None

    Returns:
        a new instance of `cls`, built from the config, with the state dict
        loaded and `training_records` set

    Raises:
        NotImplementedError: if no config class is registered for `cls`
        KeyError: if `obj` lacks `"zanj_model_config"` or `"state_dict"`
    """
    if zanj is None:
        zanj = ZANJ()

    if cls._config_class is None:
        raise NotImplementedError("you need to set `config_class` for your model")

    base_path: tuple = tuple() if path is None else tuple(path)

    # get the config
    zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore

    # get the training records
    training_records: typing.Any = load_item_recursive(
        obj.get("training_records", None),
        base_path + ("training_records",),
        zanj,
    )

    # initialize the model from the config alone, then fill in the weights
    model: "ConfiguredModel" = cls(zanj_model_config)

    # load the state dict
    tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
        obj["state_dict"],
        base_path + ("state_dict",),
        zanj,
    )

    # extra kwargs for the wrapper may be supplied through `zanj.custom_settings`
    model._load_state_dict_wrapper(
        tensored_state_dict,
        **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
    )

    # set the training records
    model.training_records = training_records

    return model
load a model from a serialized object
@classmethod
def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
    """load a model stored at `file_path`, using `zanj` (a fresh `ZANJ()` when None)

    Raises:
        AssertionError: when the object read back is not an instance of `cls`
    """
    reader: ZANJ = ZANJ() if zanj is None else zanj
    loaded: ConfiguredModel = reader.read(file_path)
    assert isinstance(loaded, cls), f"loaded object must be a {cls}, got {type(loaded)}"
    return loaded
read a model from a file
@classmethod
def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
    """read a model from a file (deprecated, use `read()` instead)

    emits a `DeprecationWarning` and delegates to `cls.read(file_path, zanj)`.
    """
    warnings.warn(
        "load_file() is deprecated, use read() instead",
        DeprecationWarning,
        # attribute the warning to the caller rather than to this line
        stacklevel=2,
    )
    return cls.read(file_path, zanj)
read a model from a file (deprecated: emits a `DeprecationWarning` and calls `read()`; use `read()` instead)
@classmethod
def get_handler(cls) -> LoaderHandler:
    """build the ZANJ `LoaderHandler` for this class

    the handler claims dicts whose format key is a str starting with this
    class's name, and loads them through `cls.load`.
    """
    cls_name: str = str(cls.__name__)
    return LoaderHandler(
        # NOTE(review): this is a prefix match, so a handler for `Foo` also
        # accepts items formatted as `FooBar` -- confirm this is intended
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and isinstance(json_item.get(_FORMAT_KEY), str)
            and json_item[_FORMAT_KEY].startswith(cls_name)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item, path, z),  # type: ignore
        uid=cls_name,
        source_pckg=cls.__module__,
        desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
    )
Inherited Members
- torch.nn.modules.module.Module
- Module
- dump_patches
- training
- call_super_init
- forward
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- set_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- mtia
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_post_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_pre_hook
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
def set_config_class(
    config_class: Type[SerializableDataclass],
) -> typing.Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
    """class decorator: attach `config_class` to a `ConfiguredModel` subclass

    the returned decorator sets `cls._config_class`, registers
    `cls.get_handler()` with `register_loader_handler`, and returns `cls`.

    Raises:
        TypeError: if `config_class` is not a subclass of `SerializableDataclass`
    """
    # check for a class first: `issubclass` on a non-class raises its own,
    # less informative, TypeError
    if not (
        isinstance(config_class, type)
        and issubclass(config_class, SerializableDataclass)
    ):
        raise TypeError(f"{config_class} must be a subclass of SerializableDataclass")

    def wrapper(cls: Type[ConfiguredModel]) -> Type[ConfiguredModel]:
        # set the config class
        cls._config_class = config_class

        # register the handlers
        register_loader_handler(cls.get_handler())

        # return the new class
        return cls

    return wrapper
class ConfigMismatchException(ValueError):
    """signals that two configs differ; the difference is kept on `diff`"""

    def __init__(self, msg: str, diff):
        # only the message goes into `args`; the diff is kept separately
        super().__init__(msg)
        self.diff = diff

    def __str__(self):
        base_message = super().__str__()
        return "{}: {}".format(base_message, self.diff)
Raised when two model configs don't match; the `diff` attribute holds the difference, and `str()` of the exception appends it to the message.
Inherited Members
- builtins.BaseException
- with_traceback
- add_note
- args
def assert_model_cfg_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """assert that both arguments are `ConfiguredModel`s carrying equal configs

    Raises:
        AssertionError: if either model is not a `ConfiguredModel`, or its
            `zanj_model_config` is not a `SerializableDataclass`
        ConfigMismatchException: if the configs don't match, e.diff will contain the diff
    """
    # validate model_a fully, then model_b, in the same order as before
    for model, label in ((model_a, "model_a"), (model_b, "model_b")):
        assert isinstance(model, ConfiguredModel), f"{label} must be a ConfiguredModel"
        assert isinstance(model.zanj_model_config, SerializableDataclass), (
            f"{label} must have a zanj_model_config"
        )

    cfg_a = model_a.zanj_model_config
    cfg_b = model_b.zanj_model_config

    if cfg_a == cfg_b:
        return

    raise ConfigMismatchException(
        f"configs of type {type(cfg_a)}, {type(cfg_b)} don't match",
        diff=type(cfg_a).diff(cfg_a, cfg_b),  # type: ignore[attr-defined]
    )
check both models are correct instances and have the same config
Raises: ConfigMismatchException: if the configs don't match, e.diff will contain the diff
def assert_model_exact_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check the models are exactly equal, including state dict contents

    checks configs via `assert_model_cfg_equality`, then that both state dicts
    have the same keys, then that each pair of entries has the same shape and
    identical elements. prints `passed {key}` / `failed {key}` for every entry.

    Raises:
        ConfigMismatchException: if the configs don't match
        AssertionError: if the state dict keys differ or any entry differs
    """
    assert_model_cfg_equality(model_a, model_b)

    # build each state dict once, rather than once per key inside the loop
    sd_a = model_a.state_dict()
    sd_b = model_b.state_dict()

    model_a_sd_keys: set[str] = set(sd_a.keys())
    model_b_sd_keys: set[str] = set(sd_b.keys())
    assert model_a_sd_keys == model_b_sd_keys, (
        f"state dict keys don't match: {model_a_sd_keys - model_b_sd_keys} / {model_b_sd_keys - model_a_sd_keys}"
    )
    keys_failed: list[str] = list()
    for k, v_a in sd_a.items():
        v_b = sd_b[k]
        # compare shapes explicitly: `==` broadcasts, so differently-shaped
        # tensors could otherwise compare equal (or raise instead of failing)
        if v_a.shape != v_b.shape or not (v_a == v_b).all():
            keys_failed.append(k)
            print(f"failed {k}")
        else:
            print(f"passed {k}")
    assert len(keys_failed) == 0, (
        f"{len(keys_failed)} / {len(model_a_sd_keys)} state dict elements don't match: {keys_failed}"
    )
check the models are exactly equal, including state dict contents