muutils.tensor_utils
utilities for working with tensors and arrays.

notably:

- `TYPE_TO_JAX_DTYPE`: a mapping from python, numpy, and torch types to `jaxtyping` types
- `DTYPE_MAP`: mapping from string representations of types to their type
- `TORCH_DTYPE_MAP`: mapping from string representations of types to torch types
- `compare_state_dicts`: for comparing two state dicts and giving a detailed error message on whether it was keys, shapes, or values that didn't match
1"""utilities for working with tensors and arrays. 2 3notably: 4 5- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types 6- `DTYPE_MAP` mapping string representations of types to their type 7- `TORCH_DTYPE_MAP` mapping string representations of types to torch types 8- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether if was keys, shapes, or values that didn't match 9 10""" 11 12from __future__ import annotations 13 14import json 15import typing 16 17import jaxtyping 18import numpy as np 19import torch 20 21from muutils.errormode import ErrorMode 22from muutils.dictmagic import dotlist_to_nested_dict 23 24# pylint: disable=missing-class-docstring 25 26 27TYPE_TO_JAX_DTYPE: dict = { 28 float: jaxtyping.Float, 29 int: jaxtyping.Int, 30 jaxtyping.Float: jaxtyping.Float, 31 jaxtyping.Int: jaxtyping.Int, 32 # bool 33 bool: jaxtyping.Bool, 34 jaxtyping.Bool: jaxtyping.Bool, 35 np.bool_: jaxtyping.Bool, 36 torch.bool: jaxtyping.Bool, 37 # numpy float 38 np.float16: jaxtyping.Float, 39 np.float32: jaxtyping.Float, 40 np.float64: jaxtyping.Float, 41 np.half: jaxtyping.Float, 42 np.single: jaxtyping.Float, 43 np.double: jaxtyping.Float, 44 # numpy int 45 np.int8: jaxtyping.Int, 46 np.int16: jaxtyping.Int, 47 np.int32: jaxtyping.Int, 48 np.int64: jaxtyping.Int, 49 np.longlong: jaxtyping.Int, 50 np.short: jaxtyping.Int, 51 np.uint8: jaxtyping.Int, 52 # torch float 53 torch.float: jaxtyping.Float, 54 torch.float16: jaxtyping.Float, 55 torch.float32: jaxtyping.Float, 56 torch.float64: jaxtyping.Float, 57 torch.half: jaxtyping.Float, 58 torch.double: jaxtyping.Float, 59 torch.bfloat16: jaxtyping.Float, 60 # torch int 61 torch.int: jaxtyping.Int, 62 torch.int8: jaxtyping.Int, 63 torch.int16: jaxtyping.Int, 64 torch.int32: jaxtyping.Int, 65 torch.int64: jaxtyping.Int, 66 torch.long: jaxtyping.Int, 67 torch.short: jaxtyping.Int, 68} 69"dict mapping python, numpy, and torch types to `jaxtyping` types" 70 71if np.version.version < "2.0.0": 72 TYPE_TO_JAX_DTYPE[np.float_] = jaxtyping.Float 73 TYPE_TO_JAX_DTYPE[np.int_] = jaxtyping.Int 74 75 76# TODO: add proper type annotations to this signature 77# TODO: maybe get rid of this altogether? 
78def jaxtype_factory( 79 name: str, 80 array_type: type, 81 default_jax_dtype=jaxtyping.Float, 82 legacy_mode: typing.Union[ErrorMode, str] = ErrorMode.WARN, 83) -> type: 84 """usage: 85 ``` 86 ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float) 87 x: ATensor["dim1 dim2", np.float32] 88 ``` 89 """ 90 legacy_mode_ = ErrorMode.from_any(legacy_mode) 91 92 class _BaseArray: 93 """jaxtyping shorthand 94 (backwards compatible with older versions of muutils.tensor_utils) 95 96 default_jax_dtype = {default_jax_dtype} 97 array_type = {array_type} 98 """ 99 100 def __new__(cls, *args, **kwargs): 101 raise TypeError("Type FArray cannot be instantiated.") 102 103 def __init_subclass__(cls, *args, **kwargs): 104 raise TypeError(f"Cannot subclass {cls.__name__}") 105 106 @classmethod 107 def param_info(cls, params) -> str: 108 """useful for error printing""" 109 return "\n".join( 110 f"{k} = {v}" 111 for k, v in { 112 "cls.__name__": cls.__name__, 113 "cls.__doc__": cls.__doc__, 114 "params": params, 115 "type(params)": type(params), 116 }.items() 117 ) 118 119 @typing._tp_cache # type: ignore 120 def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type: # type: ignore 121 # MyTensor["dim1 dim2"] 122 if isinstance(params, str): 123 return default_jax_dtype[array_type, params] 124 125 elif isinstance(params, tuple): 126 if len(params) != 2: 127 raise Exception( 128 f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}" 129 ) 130 131 if isinstance(params[0], str): 132 # MyTensor["dim1 dim2", int] 133 return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]] 134 135 elif isinstance(params[0], tuple): 136 legacy_mode_.process( 137 f"legacy type annotation was used:\n{cls.param_info(params) = }", 138 except_cls=Exception, 139 ) 140 # MyTensor[("dim1", "dim2"), int] 141 shape_anot: list[str] = list() 142 for x in params[0]: 143 if isinstance(x, str): 144 shape_anot.append(x) 145 elif isinstance(x, int): 146 shape_anot.append(str(x)) 147 elif isinstance(x, tuple): 148 shape_anot.append("".join(str(y) for y in x)) 149 else: 150 raise Exception( 151 f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}" 152 ) 153 154 return TYPE_TO_JAX_DTYPE[params[1]][ 155 array_type, " ".join(shape_anot) 156 ] 157 else: 158 raise Exception( 159 f"unexpected type for params:\n{cls.param_info(params)}" 160 ) 161 162 _BaseArray.__name__ = name 163 164 if _BaseArray.__doc__ is None: 165 _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }" 166 167 _BaseArray.__doc__ = _BaseArray.__doc__.format( 168 default_jax_dtype=repr(default_jax_dtype), 169 array_type=repr(array_type), 170 ) 171 172 return _BaseArray 173 174 175if typing.TYPE_CHECKING: 176 # these class definitions are only used here to make pylint happy, 177 # but they make mypy unhappy and there is no way to only run if not mypy 178 # so, later on we have more ignores 179 class ATensor(torch.Tensor): 180 @typing._tp_cache # type: ignore 181 def __class_getitem__(cls, params): 182 raise NotImplementedError() 183 184 class NDArray(torch.Tensor): 185 @typing._tp_cache # type: ignore 186 def __class_getitem__(cls, params): 187 raise NotImplementedError() 188 189 190ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float) # type: ignore[misc, assignment] 191 192NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float) # type: ignore[misc, assignment] 193 194 195def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> 
torch.dtype: 196 """convert numpy dtype to torch dtype""" 197 if isinstance(dtype, torch.dtype): 198 return dtype 199 else: 200 return torch.from_numpy(np.array(0, dtype=dtype)).dtype 201 202 203DTYPE_LIST: list = [ 204 *[ 205 bool, 206 int, 207 float, 208 ], 209 *[ 210 # ---------- 211 # pytorch 212 # ---------- 213 # floats 214 torch.float, 215 torch.float32, 216 torch.float64, 217 torch.half, 218 torch.double, 219 torch.bfloat16, 220 # complex 221 torch.complex64, 222 torch.complex128, 223 # ints 224 torch.int, 225 torch.int8, 226 torch.int16, 227 torch.int32, 228 torch.int64, 229 torch.long, 230 torch.short, 231 # simplest 232 torch.uint8, 233 torch.bool, 234 ], 235 *[ 236 # ---------- 237 # numpy 238 # ---------- 239 # floats 240 np.float16, 241 np.float32, 242 np.float64, 243 np.half, 244 np.single, 245 np.double, 246 # complex 247 np.complex64, 248 np.complex128, 249 # ints 250 np.int8, 251 np.int16, 252 np.int32, 253 np.int64, 254 np.longlong, 255 np.short, 256 # simplest 257 np.uint8, 258 np.bool_, 259 ], 260] 261"list of all the python, numpy, and torch numerical types I could think of" 262 263if np.version.version < "2.0.0": 264 DTYPE_LIST.extend([np.float_, np.int_]) 265 266DTYPE_MAP: dict = { 267 **{str(x): x for x in DTYPE_LIST}, 268 **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"}, 269} 270"mapping from string representations of types to their type" 271 272TORCH_DTYPE_MAP: dict = { 273 key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items() 274} 275"mapping from string representations of types to specifically torch types" 276 277# no idea why we have to do this, smh 278DTYPE_MAP["bool"] = np.bool_ 279TORCH_DTYPE_MAP["bool"] = torch.bool 280 281 282TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = { 283 "Adagrad": torch.optim.Adagrad, 284 "Adam": torch.optim.Adam, 285 "AdamW": torch.optim.AdamW, 286 "SparseAdam": torch.optim.SparseAdam, 287 "Adamax": torch.optim.Adamax, 288 "ASGD": torch.optim.ASGD, 289 "LBFGS": torch.optim.LBFGS, 290 "NAdam": torch.optim.NAdam, 291 "RAdam": torch.optim.RAdam, 292 "RMSprop": torch.optim.RMSprop, 293 "Rprop": torch.optim.Rprop, 294 "SGD": torch.optim.SGD, 295} 296 297 298def pad_tensor( 299 tensor: jaxtyping.Shaped[torch.Tensor, "dim1"], # noqa: F821 300 padded_length: int, 301 pad_value: float = 0.0, 302 rpad: bool = False, 303) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]: # noqa: F821 304 """pad a 1-d tensor on the left with pad_value to length `padded_length` 305 306 set `rpad = True` to pad on the right instead""" 307 308 temp: list[torch.Tensor] = [ 309 torch.full( 310 (padded_length - tensor.shape[0],), 311 pad_value, 312 dtype=tensor.dtype, 313 device=tensor.device, 314 ), 315 tensor, 316 ] 317 318 if rpad: 319 temp.reverse() 320 321 return torch.cat(temp) 322 323 324def lpad_tensor( 325 tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0 326) -> torch.Tensor: 327 """pad a 1-d tensor on the left with pad_value to length `padded_length`""" 328 return pad_tensor(tensor, padded_length, pad_value, rpad=False) 329 330 331def rpad_tensor( 332 tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0 333) -> torch.Tensor: 334 """pad a 1-d tensor on the right with pad_value to length `pad_length`""" 335 return pad_tensor(tensor, pad_length, pad_value, rpad=True) 336 337 338def pad_array( 339 array: jaxtyping.Shaped[np.ndarray, "dim1"], # noqa: F821 340 padded_length: int, 341 pad_value: float = 0.0, 342 rpad: bool = False, 343) -> 
jaxtyping.Shaped[np.ndarray, "padded_length"]: # noqa: F821 344 """pad a 1-d array on the left with pad_value to length `padded_length` 345 346 set `rpad = True` to pad on the right instead""" 347 348 temp: list[np.ndarray] = [ 349 np.full( 350 (padded_length - array.shape[0],), 351 pad_value, 352 dtype=array.dtype, 353 ), 354 array, 355 ] 356 357 if rpad: 358 temp.reverse() 359 360 return np.concatenate(temp) 361 362 363def lpad_array( 364 array: np.ndarray, padded_length: int, pad_value: float = 0.0 365) -> np.ndarray: 366 """pad a 1-d array on the left with pad_value to length `padded_length`""" 367 return pad_array(array, padded_length, pad_value, rpad=False) 368 369 370def rpad_array( 371 array: np.ndarray, pad_length: int, pad_value: float = 0.0 372) -> np.ndarray: 373 """pad a 1-d array on the right with pad_value to length `pad_length`""" 374 return pad_array(array, pad_length, pad_value, rpad=True) 375 376 377def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]: 378 """given a state dict or cache dict, compute the shapes and put them in a nested dict""" 379 return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()}) 380 381 382def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str: 383 """printable version of get_dict_shapes""" 384 return json.dumps( 385 dotlist_to_nested_dict( 386 { 387 k: str( 388 tuple(v.shape) 389 ) # to string, since indent wont play nice with tuples 390 for k, v in d.items() 391 } 392 ), 393 indent=2, 394 ) 395 396 397class StateDictCompareError(AssertionError): 398 """raised when state dicts don't match""" 399 400 pass 401 402 403class StateDictKeysError(StateDictCompareError): 404 """raised when state dict keys don't match""" 405 406 pass 407 408 409class StateDictShapeError(StateDictCompareError): 410 """raised when state dict shapes don't match""" 411 412 pass 413 414 415class StateDictValueError(StateDictCompareError): 416 """raised when state dict values don't match""" 417 418 pass 419 420 421def compare_state_dicts( 422 d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True 423) -> None: 424 """compare two dicts of tensors 425 426 # Parameters: 427 428 - `d1 : dict` 429 - `d2 : dict` 430 - `rtol : float` 431 (defaults to `1e-5`) 432 - `atol : float` 433 (defaults to `1e-8`) 434 - `verbose : bool` 435 (defaults to `True`) 436 437 # Raises: 438 439 - `StateDictKeysError` : keys don't match 440 - `StateDictShapeError` : shapes don't match (but keys do) 441 - `StateDictValueError` : values don't match (but keys and shapes do) 442 """ 443 # check keys match 444 d1_keys: set = set(d1.keys()) 445 d2_keys: set = set(d2.keys()) 446 symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys) 447 keys_diff_1: set = d1_keys - d2_keys 448 keys_diff_2: set = d2_keys - d1_keys 449 # sort sets for easier debugging 450 symmetric_diff = set(sorted(symmetric_diff)) 451 keys_diff_1 = set(sorted(keys_diff_1)) 452 keys_diff_2 = set(sorted(keys_diff_2)) 453 diff_shapes_1: str = ( 454 string_dict_shapes({k: d1[k] for k in keys_diff_1}) 455 if verbose 456 else "(verbose = False)" 457 ) 458 diff_shapes_2: str = ( 459 string_dict_shapes({k: d2[k] for k in keys_diff_2}) 460 if verbose 461 else "(verbose = False)" 462 ) 463 if not len(symmetric_diff) == 0: 464 raise StateDictKeysError( 465 f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}" 466 ) 467 468 # check tensors match 469 shape_failed: list[str] = list() 
470 vals_failed: list[str] = list() 471 for k, v1 in d1.items(): 472 v2 = d2[k] 473 # check shapes first 474 if not v1.shape == v2.shape: 475 shape_failed.append(k) 476 else: 477 # if shapes match, check values 478 if not torch.allclose(v1, v2, rtol=rtol, atol=atol): 479 vals_failed.append(k) 480 481 str_shape_failed: str = ( 482 string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else "" 483 ) 484 str_vals_failed: str = ( 485 string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else "" 486 ) 487 488 if not len(shape_failed) == 0: 489 raise StateDictShapeError( 490 f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}" 491 ) 492 if not len(vals_failed) == 0: 493 raise StateDictValueError( 494 f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}" 495 )
`TYPE_TO_JAX_DTYPE`: dict mapping python, numpy, and torch types to `jaxtyping` types
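For illustration, the lookup collapses concrete dtypes onto the `jaxtyping` annotation classes (a minimal sketch based on the mapping shown in the source above):

```python
import jaxtyping
import numpy as np
import torch

from muutils.tensor_utils import TYPE_TO_JAX_DTYPE

# python, numpy, and torch dtypes all map to jaxtyping.Float / Int / Bool
assert TYPE_TO_JAX_DTYPE[float] is jaxtyping.Float
assert TYPE_TO_JAX_DTYPE[np.float32] is jaxtyping.Float
assert TYPE_TO_JAX_DTYPE[torch.int64] is jaxtyping.Int
assert TYPE_TO_JAX_DTYPE[torch.bool] is jaxtyping.Bool
```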
`jaxtype_factory(name: str, array_type: type, default_jax_dtype=jaxtyping.Float, legacy_mode=ErrorMode.WARN) -> type`

usage:

```
ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
x: ATensor["dim1 dim2", np.float32]
```
`numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype`: convert a numpy dtype to a torch dtype (torch dtypes are passed through unchanged)
`DTYPE_LIST`: list of all the python, numpy, and torch numerical types I could think of
`DTYPE_MAP`: mapping from string representations of types to their type
`TORCH_DTYPE_MAP`: mapping from string representations of types to specifically torch types
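A short sketch of how these lookups behave, following the definitions in the source above (string keys come from `str(dtype)` and, for numpy types, `dtype.__name__`):

```python
import numpy as np
import torch

from muutils.tensor_utils import DTYPE_MAP, TORCH_DTYPE_MAP, numpy_to_torch_dtype

# resolve a dtype from a string, e.g. one read from a config file
assert DTYPE_MAP["float32"] is np.float32           # numpy __name__ key
assert DTYPE_MAP["torch.float32"] is torch.float32  # str(torch.float32) key
assert TORCH_DTYPE_MAP["float32"] is torch.float32  # always resolves to a torch dtype

# numpy dtypes are converted, torch dtypes pass through unchanged
assert numpy_to_torch_dtype(np.dtype("int64")) is torch.int64
assert numpy_to_torch_dtype(torch.bool) is torch.bool
```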
`pad_tensor(tensor, padded_length, pad_value=0.0, rpad=False)`: pad a 1-d tensor on the left with `pad_value` to length `padded_length`; set `rpad = True` to pad on the right instead
`lpad_tensor(tensor, padded_length, pad_value=0.0)`: pad a 1-d tensor on the left with `pad_value` to length `padded_length`
`rpad_tensor(tensor, pad_length, pad_value=0.0)`: pad a 1-d tensor on the right with `pad_value` to length `pad_length`
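For example (a minimal sketch; the output values follow directly from the definitions above):

```python
import torch

from muutils.tensor_utils import lpad_tensor, pad_tensor, rpad_tensor

x = torch.tensor([1.0, 2.0, 3.0])

assert pad_tensor(x, 5).tolist() == [0.0, 0.0, 1.0, 2.0, 3.0]  # left-pad by default
assert lpad_tensor(x, 5, pad_value=9.0).tolist() == [9.0, 9.0, 1.0, 2.0, 3.0]
assert rpad_tensor(x, 5, pad_value=9.0).tolist() == [1.0, 2.0, 3.0, 9.0, 9.0]
```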
`pad_array(array, padded_length, pad_value=0.0, rpad=False)`: pad a 1-d array on the left with `pad_value` to length `padded_length`; set `rpad = True` to pad on the right instead
`lpad_array(array, padded_length, pad_value=0.0)`: pad a 1-d array on the left with `pad_value` to length `padded_length`
`rpad_array(array, pad_length, pad_value=0.0)`: pad a 1-d array on the right with `pad_value` to length `pad_length`
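The array variants mirror the tensor ones (again a minimal sketch):

```python
import numpy as np

from muutils.tensor_utils import lpad_array, pad_array, rpad_array

a = np.array([1.0, 2.0])

assert pad_array(a, 4).tolist() == [0.0, 0.0, 1.0, 2.0]
assert lpad_array(a, 4, pad_value=-1.0).tolist() == [-1.0, -1.0, 1.0, 2.0]
assert rpad_array(a, 4, pad_value=-1.0).tolist() == [1.0, 2.0, -1.0, -1.0]
```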
`get_dict_shapes(d: dict[str, torch.Tensor]) -> dict[str, tuple[int, ...]]`: given a state dict or cache dict, compute the shapes and put them in a nested dict
`string_dict_shapes(d: dict[str, torch.Tensor]) -> str`: printable (JSON-indented) version of `get_dict_shapes`
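For example, on a small state dict (a sketch; it assumes `dotlist_to_nested_dict` from `muutils.dictmagic` groups dotted keys into nested dicts, as its name suggests):

```python
import torch

from muutils.tensor_utils import get_dict_shapes, string_dict_shapes

sd = {
    "embed.weight": torch.zeros(10, 4),
    "head.weight": torch.zeros(4, 10),
    "head.bias": torch.zeros(10),
}

# nested dict of shapes, grouped by the dotted prefixes, roughly:
# {"embed": {"weight": (10, 4)}, "head": {"weight": (4, 10), "bias": (10,)}}
print(get_dict_shapes(sd))

# the same structure as an indented JSON string, handy for error messages
print(string_dict_shapes(sd))
```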
`StateDictCompareError(AssertionError)`: raised when state dicts don't match
`StateDictKeysError(StateDictCompareError)`: raised when state dict keys don't match
`StateDictShapeError(StateDictCompareError)`: raised when state dict shapes don't match
`StateDictValueError(StateDictCompareError)`: raised when state dict values don't match
`compare_state_dicts(d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True) -> None`

compare two dicts of tensors

Parameters:

- `d1 : dict`
- `d2 : dict`
- `rtol : float` (defaults to `1e-5`)
- `atol : float` (defaults to `1e-8`)
- `verbose : bool` (defaults to `True`)

Raises:

- `StateDictKeysError` : keys don't match
- `StateDictShapeError` : shapes don't match (but keys do)
- `StateDictValueError` : values don't match (but keys and shapes do)
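For example (a minimal sketch; the values are chosen so that keys and shapes match but one tensor differs beyond the default tolerances):

```python
import torch

from muutils.tensor_utils import StateDictValueError, compare_state_dicts

d1 = {"w": torch.ones(2, 2), "b": torch.zeros(2)}
d2 = {"w": torch.ones(2, 2), "b": torch.zeros(2) + 1e-3}

try:
    compare_state_dicts(d1, d2)  # default rtol/atol: the values of "b" differ
except StateDictValueError as e:
    print(e)  # message lists the failing keys and (if verbose) their shapes

compare_state_dicts(d1, d2, atol=1e-2)  # passes: difference is within tolerance
```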