Coverage for muutils\tensor_utils.py: 86%

128 statements  

coverage.py v7.6.1, created at 2025-01-17 01:00 -0700

1"""utilities for working with tensors and arrays. 

2 

3notably: 

4 

5- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types 

6- `DTYPE_MAP` mapping string representations of types to their type 

7- `TORCH_DTYPE_MAP` mapping string representations of types to torch types 

8- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether if was keys, shapes, or values that didn't match 

9 

10""" 

from __future__ import annotations

import json
import typing

import jaxtyping
import numpy as np
import torch

from muutils.errormode import ErrorMode
from muutils.dictmagic import dotlist_to_nested_dict

# pylint: disable=missing-class-docstring


TYPE_TO_JAX_DTYPE: dict = {
    float: jaxtyping.Float,
    int: jaxtyping.Int,
    jaxtyping.Float: jaxtyping.Float,
    jaxtyping.Int: jaxtyping.Int,
    # bool
    bool: jaxtyping.Bool,
    jaxtyping.Bool: jaxtyping.Bool,
    np.bool_: jaxtyping.Bool,
    torch.bool: jaxtyping.Bool,
    # numpy float
    np.float_: jaxtyping.Float,
    np.float16: jaxtyping.Float,
    np.float32: jaxtyping.Float,
    np.float64: jaxtyping.Float,
    np.half: jaxtyping.Float,
    np.single: jaxtyping.Float,
    np.double: jaxtyping.Float,
    # numpy int
    np.int_: jaxtyping.Int,
    np.int8: jaxtyping.Int,
    np.int16: jaxtyping.Int,
    np.int32: jaxtyping.Int,
    np.int64: jaxtyping.Int,
    np.longlong: jaxtyping.Int,
    np.short: jaxtyping.Int,
    np.uint8: jaxtyping.Int,
    # torch float
    torch.float: jaxtyping.Float,
    torch.float16: jaxtyping.Float,
    torch.float32: jaxtyping.Float,
    torch.float64: jaxtyping.Float,
    torch.half: jaxtyping.Float,
    torch.double: jaxtyping.Float,
    torch.bfloat16: jaxtyping.Float,
    # torch int
    torch.int: jaxtyping.Int,
    torch.int8: jaxtyping.Int,
    torch.int16: jaxtyping.Int,
    torch.int32: jaxtyping.Int,
    torch.int64: jaxtyping.Int,
    torch.long: jaxtyping.Int,
    torch.short: jaxtyping.Int,
}
"dict mapping python, numpy, and torch types to `jaxtyping` types"

# TODO: add proper type annotations to this signature
def jaxtype_factory(
    name: str,
    array_type: type,
    default_jax_dtype=jaxtyping.Float,
    legacy_mode: ErrorMode = ErrorMode.WARN,
) -> type:
    """usage:
    ```
    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
    x: ATensor["dim1 dim2", np.float32]
    ```
    """
    legacy_mode = ErrorMode.from_any(legacy_mode)

    class _BaseArray:
        """jaxtyping shorthand
        (backwards compatible with older versions of muutils.tensor_utils)

        default_jax_dtype = {default_jax_dtype}
        array_type = {array_type}
        """

        def __new__(cls, *args, **kwargs):
            raise TypeError(f"Type {cls.__name__} cannot be instantiated.")

        def __init_subclass__(cls, *args, **kwargs):
            raise TypeError(f"Cannot subclass {cls.__name__}")

        @classmethod
        def param_info(cls, params) -> str:
            """useful for error printing"""
            return "\n".join(
                f"{k} = {v}"
                for k, v in {
                    "cls.__name__": cls.__name__,
                    "cls.__doc__": cls.__doc__,
                    "params": params,
                    "type(params)": type(params),
                }.items()
            )

        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type:
            # MyTensor["dim1 dim2"]
            if isinstance(params, str):
                return default_jax_dtype[array_type, params]

            elif isinstance(params, tuple):
                if len(params) != 2:
                    raise Exception(
                        f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
                    )

                if isinstance(params[0], str):
                    # MyTensor["dim1 dim2", int]
                    return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]

                elif isinstance(params[0], tuple):
                    legacy_mode.process(
                        f"legacy type annotation was used:\n{cls.param_info(params) = }",
                        except_cls=Exception,
                    )
                    # MyTensor[("dim1", "dim2"), int]
                    shape_anot: list[str] = list()
                    for x in params[0]:
                        if isinstance(x, str):
                            shape_anot.append(x)
                        elif isinstance(x, int):
                            shape_anot.append(str(x))
                        elif isinstance(x, tuple):
                            shape_anot.append("".join(str(y) for y in x))
                        else:
                            raise Exception(
                                f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
                            )

                    return TYPE_TO_JAX_DTYPE[params[1]][
                        array_type, " ".join(shape_anot)
                    ]
            else:
                raise Exception(
                    f"unexpected type for params:\n{cls.param_info(params)}"
                )

    _BaseArray.__name__ = name

    if _BaseArray.__doc__ is None:
        _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }"

    _BaseArray.__doc__ = _BaseArray.__doc__.format(
        default_jax_dtype=repr(default_jax_dtype),
        array_type=repr(array_type),
    )

    return _BaseArray

if typing.TYPE_CHECKING:
    # these class definitions are only used here to make pylint happy,
    # but they make mypy unhappy and there is no way to only run if not mypy
    # so, later on we have more ignores
    class ATensor(torch.Tensor):
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            raise NotImplementedError()

    class NDArray(torch.Tensor):
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            raise NotImplementedError()


ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)  # type: ignore[misc, assignment]

NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float)  # type: ignore[misc, assignment]
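# Illustrative usage sketch (not part of the original module); the helper name
# `_example_jaxtype_aliases` is hypothetical. It shows the two supported subscript
# forms of the factory-made aliases defined above.
def _example_jaxtype_aliases() -> None:
    # a bare shape string uses the factory's default dtype (jaxtyping.Float here)
    float_ann = ATensor["batch dim"]
    # a (shape, dtype) pair routes the dtype through TYPE_TO_JAX_DTYPE
    int_ann = NDArray["batch dim", int]
    print(float_ann, int_ann)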

def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
    """convert numpy dtype to torch dtype"""
    if isinstance(dtype, torch.dtype):
        return dtype
    else:
        return torch.from_numpy(np.array(0, dtype=dtype)).dtype
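# Illustrative sketch (not part of the original module); `_example_numpy_to_torch_dtype`
# is a hypothetical helper. The zero-dimensional numpy round-trip lets torch pick the
# matching dtype, and torch dtypes pass through unchanged.
def _example_numpy_to_torch_dtype() -> None:
    assert numpy_to_torch_dtype(np.float32) == torch.float32
    assert numpy_to_torch_dtype(np.int64) == torch.int64
    assert numpy_to_torch_dtype(torch.float16) == torch.float16  # already a torch dtype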

198 

199 

200DTYPE_LIST: list = [ 

201 *[ 

202 bool, 

203 int, 

204 float, 

205 ], 

206 *[ 

207 # ---------- 

208 # pytorch 

209 # ---------- 

210 # floats 

211 torch.float, 

212 torch.float32, 

213 torch.float64, 

214 torch.half, 

215 torch.double, 

216 torch.bfloat16, 

217 # complex 

218 torch.complex64, 

219 torch.complex128, 

220 # ints 

221 torch.int, 

222 torch.int8, 

223 torch.int16, 

224 torch.int32, 

225 torch.int64, 

226 torch.long, 

227 torch.short, 

228 # simplest 

229 torch.uint8, 

230 torch.bool, 

231 ], 

232 *[ 

233 # ---------- 

234 # numpy 

235 # ---------- 

236 # floats 

237 np.float_, 

238 np.float16, 

239 np.float32, 

240 np.float64, 

241 np.half, 

242 np.single, 

243 np.double, 

244 # complex 

245 np.complex64, 

246 np.complex128, 

247 # ints 

248 np.int8, 

249 np.int16, 

250 np.int32, 

251 np.int64, 

252 np.int_, 

253 np.longlong, 

254 np.short, 

255 # simplest 

256 np.uint8, 

257 np.bool_, 

258 ], 

259] 

260"list of all the python, numpy, and torch numerical types I could think of" 

261 

262DTYPE_MAP: dict = { 

263 **{str(x): x for x in DTYPE_LIST}, 

264 **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"}, 

265} 

266"mapping from string representations of types to their type" 

267 

268TORCH_DTYPE_MAP: dict = { 

269 key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items() 

270} 

271"mapping from string representations of types to specifically torch types" 

# str(bool) is "<class 'bool'>" and (on numpy 1.x) np.bool_.__name__ is "bool_",
# so neither comprehension above produces a plain "bool" key; add it explicitly
DTYPE_MAP["bool"] = np.bool_
TORCH_DTYPE_MAP["bool"] = torch.bool
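# Illustrative sketch (not part of the original module); `_example_dtype_maps` is a
# hypothetical helper. Both maps are keyed by string forms of types, so config values
# such as "float32" or "torch.float32" resolve to concrete dtypes.
def _example_dtype_maps() -> None:
    assert DTYPE_MAP["float32"] is np.float32            # numpy __name__ key
    assert DTYPE_MAP["torch.float32"] is torch.float32   # str(torch.float32) key
    assert TORCH_DTYPE_MAP["float32"] == torch.float32   # values are always torch dtypes
    assert TORCH_DTYPE_MAP["bool"] == torch.bool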

TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = {
    "Adagrad": torch.optim.Adagrad,
    "Adam": torch.optim.Adam,
    "AdamW": torch.optim.AdamW,
    "SparseAdam": torch.optim.SparseAdam,
    "Adamax": torch.optim.Adamax,
    "ASGD": torch.optim.ASGD,
    "LBFGS": torch.optim.LBFGS,
    "NAdam": torch.optim.NAdam,
    "RAdam": torch.optim.RAdam,
    "RMSprop": torch.optim.RMSprop,
    "Rprop": torch.optim.Rprop,
    "SGD": torch.optim.SGD,
}
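# Illustrative sketch (not part of the original module): the map lets an optimizer class
# be chosen by name, e.g. from a config file; `model` here is a hypothetical torch.nn.Module.
#
#     optimizer = TORCH_OPTIMIZERS_MAP["AdamW"](model.parameters(), lr=1e-3)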

def pad_tensor(
    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
    """pad a 1-d tensor on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[torch.Tensor] = [
        torch.full(
            (padded_length - tensor.shape[0],),
            pad_value,
            dtype=tensor.dtype,
            device=tensor.device,
        ),
        tensor,
    ]

    if rpad:
        temp.reverse()

    return torch.cat(temp)


def lpad_tensor(
    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
    return pad_tensor(tensor, padded_length, pad_value, rpad=False)


def rpad_tensor(
    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
    return pad_tensor(tensor, pad_length, pad_value, rpad=True)
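# Illustrative sketch (not part of the original module); `_example_pad_tensor` is a
# hypothetical helper. Padding goes on the left by default; `rpad_tensor` pads on the right.
def _example_pad_tensor() -> None:
    x = torch.tensor([1.0, 2.0, 3.0])
    assert pad_tensor(x, 5).tolist() == [0.0, 0.0, 1.0, 2.0, 3.0]
    assert rpad_tensor(x, 5, pad_value=9.0).tolist() == [1.0, 2.0, 3.0, 9.0, 9.0]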

def pad_array(
    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
    """pad a 1-d array on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[np.ndarray] = [
        np.full(
            (padded_length - array.shape[0],),
            pad_value,
            dtype=array.dtype,
        ),
        array,
    ]

    if rpad:
        temp.reverse()

    return np.concatenate(temp)


def lpad_array(
    array: np.ndarray, padded_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the left with pad_value to length `padded_length`"""
    return pad_array(array, padded_length, pad_value, rpad=False)


def rpad_array(
    array: np.ndarray, pad_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the right with pad_value to length `pad_length`"""
    return pad_array(array, pad_length, pad_value, rpad=True)
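# Illustrative sketch (not part of the original module); `_example_pad_array` is a
# hypothetical helper mirroring the tensor example above for numpy arrays.
def _example_pad_array() -> None:
    a = np.array([1.0, 2.0, 3.0])
    assert lpad_array(a, 5).tolist() == [0.0, 0.0, 1.0, 2.0, 3.0]
    assert rpad_array(a, 5, pad_value=9.0).tolist() == [1.0, 2.0, 3.0, 9.0, 9.0]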

def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})


def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
    """printable version of get_dict_shapes"""
    return json.dumps(
        dotlist_to_nested_dict(
            {
                k: str(
                    tuple(v.shape)
                )  # to string, since indent won't play nice with tuples
                for k, v in d.items()
            }
        ),
        indent=2,
    )
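# Illustrative sketch (not part of the original module); `_example_dict_shapes` is a
# hypothetical helper. Dotted keys (as produced by torch state dicts) become nested
# dicts of shape tuples.
def _example_dict_shapes() -> None:
    sd = {"layer.weight": torch.zeros(4, 3), "layer.bias": torch.zeros(4)}
    assert get_dict_shapes(sd) == {"layer": {"weight": (4, 3), "bias": (4,)}}
    print(string_dict_shapes(sd))  # same structure as JSON, with shapes stringified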

class StateDictCompareError(AssertionError):
    """raised when state dicts don't match"""

    pass


class StateDictKeysError(StateDictCompareError):
    """raised when state dict keys don't match"""

    pass


class StateDictShapeError(StateDictCompareError):
    """raised when state dict shapes don't match"""

    pass


class StateDictValueError(StateDictCompareError):
    """raised when state dict values don't match"""

    pass

def compare_state_dicts(
    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
) -> None:
    """compare two dicts of tensors

    # Parameters:

    - `d1 : dict`
    - `d2 : dict`
    - `rtol : float`
        (defaults to `1e-5`)
    - `atol : float`
        (defaults to `1e-8`)
    - `verbose : bool`
        (defaults to `True`)

    # Raises:

    - `StateDictKeysError` : keys don't match
    - `StateDictShapeError` : shapes don't match (but keys do)
    - `StateDictValueError` : values don't match (but keys and shapes do)
    """
    # check keys match
    d1_keys: set = set(d1.keys())
    d2_keys: set = set(d2.keys())
    symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys)
    keys_diff_1: set = d1_keys - d2_keys
    keys_diff_2: set = d2_keys - d1_keys
    # sort sets for easier debugging
    symmetric_diff = set(sorted(symmetric_diff))
    keys_diff_1 = set(sorted(keys_diff_1))
    keys_diff_2 = set(sorted(keys_diff_2))
    diff_shapes_1: str = (
        string_dict_shapes({k: d1[k] for k in keys_diff_1})
        if verbose
        else "(verbose = False)"
    )
    diff_shapes_2: str = (
        string_dict_shapes({k: d2[k] for k in keys_diff_2})
        if verbose
        else "(verbose = False)"
    )
    if not len(symmetric_diff) == 0:
        raise StateDictKeysError(
            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
        )

    # check tensors match
    shape_failed: list[str] = list()
    vals_failed: list[str] = list()
    for k, v1 in d1.items():
        v2 = d2[k]
        # check shapes first
        if not v1.shape == v2.shape:
            shape_failed.append(k)
        else:
            # if shapes match, check values
            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
                vals_failed.append(k)

    str_shape_failed: str = (
        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
    )
    str_vals_failed: str = (
        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
    )

    if not len(shape_failed) == 0:
        raise StateDictShapeError(
            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
        )
    if not len(vals_failed) == 0:
        raise StateDictValueError(
            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
        )
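# Illustrative sketch (not part of the original module); `_example_compare_state_dicts`
# is a hypothetical helper. A successful comparison returns None; each failure mode
# raises its own StateDictCompareError subclass.
def _example_compare_state_dicts() -> None:
    sd_a = {"w": torch.ones(2, 2)}
    sd_b = {"w": torch.ones(2, 2)}
    compare_state_dicts(sd_a, sd_b)  # identical, so no exception
    try:
        compare_state_dicts(sd_a, {"w": torch.zeros(2, 2)})
    except StateDictValueError as err:
        print("values differed:", err)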