Coverage for zanj\torchutil.py: 88%

113 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-14 12:57 -0700

1from __future__ import annotations 

2 

3import abc 

4import typing 

5import warnings 

6from typing import Any, Type, TypeVar 

7 

8import torch 

9from muutils.json_serialize import SerializableDataclass 

10from muutils.json_serialize.json_serialize import ObjectPath 

11from muutils.json_serialize.util import safe_getsource, string_as_lines, _FORMAT_KEY 

12 

13from zanj import ZANJ, register_loader_handler 

14from zanj.loading import LoaderHandler, load_item_recursive 

15 

16# pylint: disable=protected-access 

17 

# alias for loosely-typed keyword-argument bundles; it resolves to `typing.Any`.
# NOTE(review): not referenced elsewhere in this module -- presumably kept so
# external code can import it; confirm before removing.
KWArgs = Any

19 

20 

def num_params(m: torch.nn.Module, only_trainable: bool = True) -> int:
    """return total number of parameters in a model

    - only counting shared parameters once
    - if `only_trainable` is False, will include parameters with `requires_grad = False`

    Parameters backed by real memory are deduplicated by `(device, data_ptr())`,
    so distinct `Parameter` objects that start at the same address are counted
    once. Parameters with no real storage (meta-device tensors, and
    zero-element tensors) have no meaningful data pointer; meta tensors, for
    example, report a null pointer. Keying those by pointer would merge every
    one of them into a single entry, so they are deduplicated by object
    identity instead.

    https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model

    # Parameters:
     - `m : torch.nn.Module`
        module whose parameters are counted
     - `only_trainable : bool`
        if True, skip parameters with `requires_grad = False`
        (defaults to `True`)

    # Returns:
     - `int`
        total number of elements across the unique parameters
    """
    unique: dict[tuple, torch.nn.Parameter] = {}
    for p in m.parameters():
        if only_trainable and not p.requires_grad:
            continue
        if p.is_meta or p.numel() == 0:
            # no backing storage -> data_ptr() can't tell parameters apart
            key: tuple = ("id", id(p))
        else:
            key = (p.device, p.data_ptr())
        # later entries overwrite earlier ones, matching the original dict build
        unique[key] = p

    return sum(p.numel() for p in unique.values())

38 

39 

40def get_module_device( 

41 m: torch.nn.Module, 

42) -> tuple[bool, torch.device | dict[str, torch.device]]: 

43 """get the current devices""" 

44 

45 devs: dict[str, torch.device] = {name: p.device for name, p in m.named_parameters()} 

46 

47 if len(devs) == 0: 

48 return False, devs 

49 

50 # check if all devices are the same by getting one device 

51 dev_uni: torch.device = next(iter(devs.values())) 

52 

53 if all(dev == dev_uni for dev in devs.values()): 

54 return True, dev_uni 

55 else: 

56 return False, devs 

57 

58 

# type variable for a model's config type: `ConfiguredModel` is generic over it,
# and the bound restricts it to `SerializableDataclass` subclasses (configs are
# written with `.serialize()` and read back with `.load()` by ConfiguredModel)
T_config = TypeVar("T_config", bound=SerializableDataclass)

60 

61 

class ConfiguredModel(
    torch.nn.Module,
    typing.Generic[T_config],
    metaclass=abc.ABCMeta,
):
    """a model that has a configuration, for saving with ZANJ

    ```python
    @set_config_class(YourConfig)
    class YourModule(ConfiguredModel[YourConfig]):
        def __init__(self, cfg: YourConfig):
            super().__init__(cfg)
    ```

    `__init__()` must initialize the model from a config object only, and call
    `super().__init__(zanj_model_config)`

    If you are inheriting from another class + ConfiguredModel,
    ConfiguredModel must be the first class in the inheritance list

    Round trip:
     - `serialize()` / `save()` write the config, some metadata,
       `training_records` and `state_dict()`.
     - `load()` / `read()` rebuild the model from the config, then load the
       state dict and restore `training_records`.
    """

    # dont set this directly, use `set_config_class()` decorator
    _config_class: type | None = None
    # config class registered on the concrete class of `self` (None if unset)
    zanj_config_class = property(lambda self: type(self)._config_class)

    def __init__(self, zanj_model_config: T_config, **kwargs):
        """store `zanj_model_config` after checking it against the registered config class

        extra `kwargs` are forwarded to the next `__init__` in the MRO.

        # Raises:
         - `NotImplementedError` : no config class was registered via `set_config_class()`
         - `TypeError` : `zanj_model_config` is not an instance of the registered config class
        """
        super().__init__(**kwargs)
        if self.zanj_config_class is None:
            raise NotImplementedError(
                "you need to set the config class for your model with the `set_config_class()` decorator"
            )
        if not isinstance(zanj_model_config, self.zanj_config_class):  # type: ignore
            raise TypeError(
                f"config must be an instance of {self.zanj_config_class = }, got {type(zanj_model_config) = }"
            )

        self.zanj_model_config: T_config = zanj_model_config
        # serialized as-is by `serialize()` and restored by `load()`;
        # presumably filled in by training code -- None until then
        self.training_records: dict | None = None

    def serialize(
        self, path: ObjectPath = tuple(), zanj: ZANJ | None = None
    ) -> dict[str, Any]:
        """serialize the model into a dict

        The dict holds:
         - the serialized config
         - descriptive metadata
         - `training_records`
         - the raw `state_dict()` (tensors are not converted here)
         - a `__muutils_format__` entry set to the class name, which
           `get_handler()` presumably matches via `_FORMAT_KEY`

        `path` and `zanj` are accepted for interface compatibility but are not
        used by this method.
        """
        return dict(
            zanj_model_config=self.zanj_model_config.serialize(),
            meta=dict(
                class_name=self.__class__.__name__,
                class_doc=string_as_lines(self.__class__.__doc__),
                class_source=safe_getsource(self.__class__),
                module_name=self.__class__.__module__,
                module_mro=[str(x) for x in self.__class__.__mro__],
                # resolves to the module-level `num_params` function
                num_params=num_params(self),
                as_str=string_as_lines(str(self)),
            ),
            training_records=self.training_records,
            state_dict=self.state_dict(),
            __muutils_format__=self.__class__.__name__,
        )

    def save(self, file_path: str, zanj: ZANJ | None = None):
        """serialize the model and write it to `file_path`

        a fresh `ZANJ()` is used when `zanj` is None.
        """
        if zanj is None:
            zanj = ZANJ()
        zanj.save(self.serialize(), file_path)

    def _load_state_dict_wrapper(
        self,
        state_dict: dict[str, torch.Tensor],
        **kwargs,
    ):
        """wrapper for `load_state_dict()` in case you need to override it

        `load()` passes the contents of
        `zanj.custom_settings["_load_state_dict_wrapper"]` as `kwargs`; this
        default implementation accepts none.
        """
        assert len(kwargs) == 0, f"got unexpected kwargs: {kwargs}"
        return self.load_state_dict(state_dict)

    @classmethod
    def load(
        cls, obj: dict[str, Any], path: ObjectPath, zanj: ZANJ | None = None
    ) -> "ConfiguredModel":
        """load a model from a serialized object

        Loading happens in this order:
         1. The config is loaded with the registered config class's `load()`.
         2. `training_records` is loaded with `load_item_recursive`.
         3. The model is constructed from the config.
         4. The state dict is loaded with `load_item_recursive`, then applied
            through `_load_state_dict_wrapper()`.
         5. `training_records` is set on the model.
        """
        if zanj is None:
            zanj = ZANJ()

        # get the config
        zanj_model_config: T_config = cls._config_class.load(obj["zanj_model_config"])  # type: ignore

        # get the training records
        training_records: typing.Any = load_item_recursive(
            obj.get("training_records", None),
            tuple(path) + ("training_records",),
            zanj,
        )

        # initialize the model
        model: "ConfiguredModel" = cls(zanj_model_config)

        # load the state dict
        tensored_state_dict: dict[str, torch.Tensor] = load_item_recursive(
            obj["state_dict"],
            tuple(path) + ("state_dict",),
            zanj,
        )

        model._load_state_dict_wrapper(
            tensored_state_dict,
            **zanj.custom_settings.get("_load_state_dict_wrapper", dict()),
        )

        # set the training records
        model.training_records = training_records

        return model

    @classmethod
    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file

        # Raises:
         - `AssertionError` : the loaded object is not an instance of `cls`
        """
        if zanj is None:
            zanj = ZANJ()

        mdl: ConfiguredModel = zanj.read(file_path)
        assert isinstance(mdl, cls), f"loaded object must be a {cls}, got {type(mdl)}"
        return mdl

    @classmethod
    def load_file(cls, file_path: str, zanj: ZANJ | None = None) -> "ConfiguredModel":
        """read a model from a file (deprecated alias of `read()`)"""
        warnings.warn(
            "load_file() is deprecated, use read() instead",
            DeprecationWarning,
            # point the warning at the caller, not at this wrapper
            stacklevel=2,
        )
        return cls.read(file_path, zanj)

    @classmethod
    def get_handler(cls) -> LoaderHandler:
        """build the ZANJ `LoaderHandler` that recognizes and loads serialized `cls` instances"""
        cls_name: str = str(cls.__name__)

        def _check(json_item, path=None, z=None) -> bool:
            # only dicts whose format key is a string starting with the class
            # name; a non-string format value is rejected instead of raising
            # NOTE(review): the prefix match means `Foo` also claims items whose
            # format is `FooBar` -- confirm this is intended
            return (
                isinstance(json_item, dict)
                and isinstance(json_item.get(_FORMAT_KEY), str)
                and json_item[_FORMAT_KEY].startswith(cls_name)
            )

        def _load(json_item, path=None, z=None):
            return cls.load(json_item, path, z)  # type: ignore

        return LoaderHandler(
            check=_check,  # type: ignore
            load=_load,  # type: ignore
            uid=cls_name,
            source_pckg=cls.__module__,
            desc=f"{cls.__module__} {cls_name} loader via zanj.torchutil.ConfiguredModel",
        )

    def num_params(self) -> int:
        """number of (trainable, deduplicated) parameters in this model"""
        return num_params(self)

209 

210 

def set_config_class(
    config_class: Type[SerializableDataclass],
) -> typing.Callable[[Type[ConfiguredModel]], Type[ConfiguredModel]]:
    """class decorator attaching a config class to a `ConfiguredModel` subclass

    The returned decorator does two things:
     - stores `config_class` on the decorated class (read back through
       `zanj_config_class`)
     - registers the class's `get_handler()` with ZANJ, so serialized
       instances can be loaded

    # Raises:
     - `TypeError` : `config_class` is not a subclass of `SerializableDataclass`
       (including when it is not a class at all)
    """
    # `issubclass` itself raises an unrelated TypeError for non-class
    # arguments, so check for a class first to keep the message consistent
    if not (
        isinstance(config_class, type)
        and issubclass(config_class, SerializableDataclass)
    ):
        raise TypeError(f"{config_class} must be a subclass of SerializableDataclass")

    def wrapper(cls: Type[ConfiguredModel]) -> Type[ConfiguredModel]:
        # set the config class
        cls._config_class = config_class

        # register the handlers
        register_loader_handler(cls.get_handler())

        # return the same class, now configured
        return cls

    return wrapper

228 

229 

class ConfigMismatchException(ValueError):
    """`ValueError` for two configs that differ; keeps the computed difference

    The difference is stored on `.diff` and appended to the message when the
    exception is converted with `str()`.
    """

    def __init__(self, msg: str, diff):
        super().__init__(msg)
        self.diff = diff

    def __str__(self) -> str:
        base_message: str = super().__str__()
        return "{}: {}".format(base_message, self.diff)

237 

238 

def assert_model_cfg_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """verify both models are `ConfiguredModel`s with equal configs

    Raises:
        AssertionError: if either argument is not a `ConfiguredModel`, or its
            `zanj_model_config` is not a `SerializableDataclass`
        ConfigMismatchException: if the configs compare unequal; e.diff holds
            the result of the config class's `diff()` applied to both configs
    """
    # (model, type-check message, config-check message), checked in order
    checks = (
        (model_a, "model_a must be a ConfiguredModel", "model_a must have a zanj_model_config"),
        (model_b, "model_b must be a ConfiguredModel", "model_b must have a zanj_model_config"),
    )
    for model, type_msg, cfg_msg in checks:
        assert isinstance(model, ConfiguredModel), type_msg
        assert isinstance(model.zanj_model_config, SerializableDataclass), cfg_msg

    cfg_a = model_a.zanj_model_config
    cfg_b = model_b.zanj_model_config
    if cfg_a == cfg_b:
        return

    config_type: type = type(cfg_a)
    raise ConfigMismatchException(
        f"configs of type {type(cfg_a)}, {type(cfg_b)} don't match",
        diff=config_type.diff(cfg_a, cfg_b),  # type: ignore[attr-defined]
    )

261 

262 

def _state_dict_entries_equal(v_a: Any, v_b: Any) -> bool:
    """return whether two state dict entries are exactly equal

    Tensor entries must have identical shapes before being compared
    elementwise. Without the shape check, broadcasting could make different
    tensors look equal (e.g. shape `(1,)` vs `(3,)`), and non-broadcastable
    shapes would raise instead of simply comparing unequal.

    A tensor is never equal to a non-tensor. Two non-tensor entries (e.g.
    extra state) are compared with `==`.
    """
    a_is_tensor: bool = isinstance(v_a, torch.Tensor)
    b_is_tensor: bool = isinstance(v_b, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return v_a.shape == v_b.shape and bool((v_a == v_b).all())
    if a_is_tensor or b_is_tensor:
        return False
    return bool(v_a == v_b)


def assert_model_exact_equality(model_a: ConfiguredModel, model_b: ConfiguredModel):
    """check the models are exactly equal, including state dict contents

    Checks, in order:
     1. The configs match (via `assert_model_cfg_equality`).
     2. Both state dicts have the same keys.
     3. Every entry is exactly equal; tensors must share a shape and match
        elementwise.

    A `passed {key}` / `failed {key}` line is printed for every key.

    Raises:
        ConfigMismatchException: if the configs don't match
        AssertionError: if the models are not valid `ConfiguredModel`s, or the
            state dict keys or any entry differ
    """
    assert_model_cfg_equality(model_a, model_b)

    # build each state dict once -- `state_dict()` walks the whole module tree
    sd_a = model_a.state_dict()
    sd_b = model_b.state_dict()

    model_a_sd_keys: set[str] = set(sd_a.keys())
    model_b_sd_keys: set[str] = set(sd_b.keys())
    assert model_a_sd_keys == model_b_sd_keys, (
        f"state dict keys don't match: {model_a_sd_keys - model_b_sd_keys} / {model_b_sd_keys - model_a_sd_keys}"
    )

    keys_failed: list[str] = list()
    for k, v_a in sd_a.items():
        if _state_dict_entries_equal(v_a, sd_b[k]):
            print(f"passed {k}")
        else:
            keys_failed.append(k)
            print(f"failed {k}")
    assert len(keys_failed) == 0, (
        f"{len(keys_failed)} / {len(model_a_sd_keys)} state dict elements don't match: {keys_failed}"
    )