muutils.tensor_utils
utilities for working with tensors and arrays.
notably:

- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types
- `DTYPE_MAP` mapping string representations of types to their type
- `TORCH_DTYPE_MAP` mapping string representations of types to torch types
- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether it was keys, shapes, or values that didn't match
1"""utilities for working with tensors and arrays. 2 3notably: 4 5- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types 6- `DTYPE_MAP` mapping string representations of types to their type 7- `TORCH_DTYPE_MAP` mapping string representations of types to torch types 8- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether if was keys, shapes, or values that didn't match 9 10""" 11 12from __future__ import annotations 13 14import json 15import typing 16 17import jaxtyping 18import numpy as np 19import torch 20 21from muutils.errormode import ErrorMode 22from muutils.dictmagic import dotlist_to_nested_dict 23 24# pylint: disable=missing-class-docstring 25 26 27TYPE_TO_JAX_DTYPE: dict = { 28 float: jaxtyping.Float, 29 int: jaxtyping.Int, 30 jaxtyping.Float: jaxtyping.Float, 31 jaxtyping.Int: jaxtyping.Int, 32 # bool 33 bool: jaxtyping.Bool, 34 jaxtyping.Bool: jaxtyping.Bool, 35 np.bool_: jaxtyping.Bool, 36 torch.bool: jaxtyping.Bool, 37 # numpy float 38 np.float16: jaxtyping.Float, 39 np.float32: jaxtyping.Float, 40 np.float64: jaxtyping.Float, 41 np.half: jaxtyping.Float, 42 np.single: jaxtyping.Float, 43 np.double: jaxtyping.Float, 44 # numpy int 45 np.int8: jaxtyping.Int, 46 np.int16: jaxtyping.Int, 47 np.int32: jaxtyping.Int, 48 np.int64: jaxtyping.Int, 49 np.longlong: jaxtyping.Int, 50 np.short: jaxtyping.Int, 51 np.uint8: jaxtyping.Int, 52 # torch float 53 torch.float: jaxtyping.Float, 54 torch.float16: jaxtyping.Float, 55 torch.float32: jaxtyping.Float, 56 torch.float64: jaxtyping.Float, 57 torch.half: jaxtyping.Float, 58 torch.double: jaxtyping.Float, 59 torch.bfloat16: jaxtyping.Float, 60 # torch int 61 torch.int: jaxtyping.Int, 62 torch.int8: jaxtyping.Int, 63 torch.int16: jaxtyping.Int, 64 torch.int32: jaxtyping.Int, 65 torch.int64: jaxtyping.Int, 66 torch.long: jaxtyping.Int, 67 torch.short: jaxtyping.Int, 68} 69"dict mapping python, numpy, and torch types to `jaxtyping` types" 70 71if np.version.version < "2.0.0": 72 TYPE_TO_JAX_DTYPE[np.float_] = jaxtyping.Float 73 TYPE_TO_JAX_DTYPE[np.int_] = jaxtyping.Int 74 75 76# TODO: add proper type annotations to this signature 77def jaxtype_factory( 78 name: str, 79 array_type: type, 80 default_jax_dtype=jaxtyping.Float, 81 legacy_mode: ErrorMode = ErrorMode.WARN, 82) -> type: 83 """usage: 84 ``` 85 ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float) 86 x: ATensor["dim1 dim2", np.float32] 87 ``` 88 """ 89 legacy_mode = ErrorMode.from_any(legacy_mode) 90 91 class _BaseArray: 92 """jaxtyping shorthand 93 (backwards compatible with older versions of muutils.tensor_utils) 94 95 default_jax_dtype = {default_jax_dtype} 96 array_type = {array_type} 97 """ 98 99 def __new__(cls, *args, **kwargs): 100 raise TypeError("Type FArray cannot be instantiated.") 101 102 def __init_subclass__(cls, *args, **kwargs): 103 raise TypeError(f"Cannot subclass {cls.__name__}") 104 105 @classmethod 106 def param_info(cls, params) -> str: 107 """useful for error printing""" 108 return "\n".join( 109 f"{k} = {v}" 110 for k, v in { 111 "cls.__name__": cls.__name__, 112 "cls.__doc__": cls.__doc__, 113 "params": params, 114 "type(params)": type(params), 115 }.items() 116 ) 117 118 @typing._tp_cache # type: ignore 119 def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type: # type: ignore 120 # MyTensor["dim1 dim2"] 121 if isinstance(params, str): 122 return default_jax_dtype[array_type, params] 123 124 elif isinstance(params, tuple): 125 if len(params) != 2: 
126 raise Exception( 127 f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}" 128 ) 129 130 if isinstance(params[0], str): 131 # MyTensor["dim1 dim2", int] 132 return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]] 133 134 elif isinstance(params[0], tuple): 135 legacy_mode.process( 136 f"legacy type annotation was used:\n{cls.param_info(params) = }", 137 except_cls=Exception, 138 ) 139 # MyTensor[("dim1", "dim2"), int] 140 shape_anot: list[str] = list() 141 for x in params[0]: 142 if isinstance(x, str): 143 shape_anot.append(x) 144 elif isinstance(x, int): 145 shape_anot.append(str(x)) 146 elif isinstance(x, tuple): 147 shape_anot.append("".join(str(y) for y in x)) 148 else: 149 raise Exception( 150 f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}" 151 ) 152 153 return TYPE_TO_JAX_DTYPE[params[1]][ 154 array_type, " ".join(shape_anot) 155 ] 156 else: 157 raise Exception( 158 f"unexpected type for params:\n{cls.param_info(params)}" 159 ) 160 161 _BaseArray.__name__ = name 162 163 if _BaseArray.__doc__ is None: 164 _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }" 165 166 _BaseArray.__doc__ = _BaseArray.__doc__.format( 167 default_jax_dtype=repr(default_jax_dtype), 168 array_type=repr(array_type), 169 ) 170 171 return _BaseArray 172 173 174if typing.TYPE_CHECKING: 175 # these class definitions are only used here to make pylint happy, 176 # but they make mypy unhappy and there is no way to only run if not mypy 177 # so, later on we have more ignores 178 class ATensor(torch.Tensor): 179 @typing._tp_cache # type: ignore 180 def __class_getitem__(cls, params): 181 raise NotImplementedError() 182 183 class NDArray(torch.Tensor): 184 @typing._tp_cache # type: ignore 185 def __class_getitem__(cls, params): 186 raise NotImplementedError() 187 188 189ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float) # type: ignore[misc, assignment] 190 191NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float) # type: ignore[misc, assignment] 192 193 194def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype: 195 """convert numpy dtype to torch dtype""" 196 if isinstance(dtype, torch.dtype): 197 return dtype 198 else: 199 return torch.from_numpy(np.array(0, dtype=dtype)).dtype 200 201 202DTYPE_LIST: list = [ 203 *[ 204 bool, 205 int, 206 float, 207 ], 208 *[ 209 # ---------- 210 # pytorch 211 # ---------- 212 # floats 213 torch.float, 214 torch.float32, 215 torch.float64, 216 torch.half, 217 torch.double, 218 torch.bfloat16, 219 # complex 220 torch.complex64, 221 torch.complex128, 222 # ints 223 torch.int, 224 torch.int8, 225 torch.int16, 226 torch.int32, 227 torch.int64, 228 torch.long, 229 torch.short, 230 # simplest 231 torch.uint8, 232 torch.bool, 233 ], 234 *[ 235 # ---------- 236 # numpy 237 # ---------- 238 # floats 239 np.float16, 240 np.float32, 241 np.float64, 242 np.half, 243 np.single, 244 np.double, 245 # complex 246 np.complex64, 247 np.complex128, 248 # ints 249 np.int8, 250 np.int16, 251 np.int32, 252 np.int64, 253 np.longlong, 254 np.short, 255 # simplest 256 np.uint8, 257 np.bool_, 258 ], 259] 260"list of all the python, numpy, and torch numerical types I could think of" 261 262if np.version.version < "2.0.0": 263 DTYPE_LIST.extend([np.float_, np.int_]) 264 265DTYPE_MAP: dict = { 266 **{str(x): x for x in DTYPE_LIST}, 267 **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"}, 268} 269"mapping from string 
representations of types to their type" 270 271TORCH_DTYPE_MAP: dict = { 272 key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items() 273} 274"mapping from string representations of types to specifically torch types" 275 276# no idea why we have to do this, smh 277DTYPE_MAP["bool"] = np.bool_ 278TORCH_DTYPE_MAP["bool"] = torch.bool 279 280 281TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = { 282 "Adagrad": torch.optim.Adagrad, 283 "Adam": torch.optim.Adam, 284 "AdamW": torch.optim.AdamW, 285 "SparseAdam": torch.optim.SparseAdam, 286 "Adamax": torch.optim.Adamax, 287 "ASGD": torch.optim.ASGD, 288 "LBFGS": torch.optim.LBFGS, 289 "NAdam": torch.optim.NAdam, 290 "RAdam": torch.optim.RAdam, 291 "RMSprop": torch.optim.RMSprop, 292 "Rprop": torch.optim.Rprop, 293 "SGD": torch.optim.SGD, 294} 295 296 297def pad_tensor( 298 tensor: jaxtyping.Shaped[torch.Tensor, "dim1"], # noqa: F821 299 padded_length: int, 300 pad_value: float = 0.0, 301 rpad: bool = False, 302) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]: # noqa: F821 303 """pad a 1-d tensor on the left with pad_value to length `padded_length` 304 305 set `rpad = True` to pad on the right instead""" 306 307 temp: list[torch.Tensor] = [ 308 torch.full( 309 (padded_length - tensor.shape[0],), 310 pad_value, 311 dtype=tensor.dtype, 312 device=tensor.device, 313 ), 314 tensor, 315 ] 316 317 if rpad: 318 temp.reverse() 319 320 return torch.cat(temp) 321 322 323def lpad_tensor( 324 tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0 325) -> torch.Tensor: 326 """pad a 1-d tensor on the left with pad_value to length `padded_length`""" 327 return pad_tensor(tensor, padded_length, pad_value, rpad=False) 328 329 330def rpad_tensor( 331 tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0 332) -> torch.Tensor: 333 """pad a 1-d tensor on the right with pad_value to length `pad_length`""" 334 return pad_tensor(tensor, pad_length, pad_value, rpad=True) 335 336 337def pad_array( 338 array: jaxtyping.Shaped[np.ndarray, "dim1"], # noqa: F821 339 padded_length: int, 340 pad_value: float = 0.0, 341 rpad: bool = False, 342) -> jaxtyping.Shaped[np.ndarray, "padded_length"]: # noqa: F821 343 """pad a 1-d array on the left with pad_value to length `padded_length` 344 345 set `rpad = True` to pad on the right instead""" 346 347 temp: list[np.ndarray] = [ 348 np.full( 349 (padded_length - array.shape[0],), 350 pad_value, 351 dtype=array.dtype, 352 ), 353 array, 354 ] 355 356 if rpad: 357 temp.reverse() 358 359 return np.concatenate(temp) 360 361 362def lpad_array( 363 array: np.ndarray, padded_length: int, pad_value: float = 0.0 364) -> np.ndarray: 365 """pad a 1-d array on the left with pad_value to length `padded_length`""" 366 return pad_array(array, padded_length, pad_value, rpad=False) 367 368 369def rpad_array( 370 array: np.ndarray, pad_length: int, pad_value: float = 0.0 371) -> np.ndarray: 372 """pad a 1-d array on the right with pad_value to length `pad_length`""" 373 return pad_array(array, pad_length, pad_value, rpad=True) 374 375 376def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]: 377 """given a state dict or cache dict, compute the shapes and put them in a nested dict""" 378 return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()}) 379 380 381def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str: 382 """printable version of get_dict_shapes""" 383 return json.dumps( 384 dotlist_to_nested_dict( 385 { 386 k: str( 387 tuple(v.shape) 388 ) # to 
string, since indent wont play nice with tuples 389 for k, v in d.items() 390 } 391 ), 392 indent=2, 393 ) 394 395 396class StateDictCompareError(AssertionError): 397 """raised when state dicts don't match""" 398 399 pass 400 401 402class StateDictKeysError(StateDictCompareError): 403 """raised when state dict keys don't match""" 404 405 pass 406 407 408class StateDictShapeError(StateDictCompareError): 409 """raised when state dict shapes don't match""" 410 411 pass 412 413 414class StateDictValueError(StateDictCompareError): 415 """raised when state dict values don't match""" 416 417 pass 418 419 420def compare_state_dicts( 421 d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True 422) -> None: 423 """compare two dicts of tensors 424 425 # Parameters: 426 427 - `d1 : dict` 428 - `d2 : dict` 429 - `rtol : float` 430 (defaults to `1e-5`) 431 - `atol : float` 432 (defaults to `1e-8`) 433 - `verbose : bool` 434 (defaults to `True`) 435 436 # Raises: 437 438 - `StateDictKeysError` : keys don't match 439 - `StateDictShapeError` : shapes don't match (but keys do) 440 - `StateDictValueError` : values don't match (but keys and shapes do) 441 """ 442 # check keys match 443 d1_keys: set = set(d1.keys()) 444 d2_keys: set = set(d2.keys()) 445 symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys) 446 keys_diff_1: set = d1_keys - d2_keys 447 keys_diff_2: set = d2_keys - d1_keys 448 # sort sets for easier debugging 449 symmetric_diff = set(sorted(symmetric_diff)) 450 keys_diff_1 = set(sorted(keys_diff_1)) 451 keys_diff_2 = set(sorted(keys_diff_2)) 452 diff_shapes_1: str = ( 453 string_dict_shapes({k: d1[k] for k in keys_diff_1}) 454 if verbose 455 else "(verbose = False)" 456 ) 457 diff_shapes_2: str = ( 458 string_dict_shapes({k: d2[k] for k in keys_diff_2}) 459 if verbose 460 else "(verbose = False)" 461 ) 462 if not len(symmetric_diff) == 0: 463 raise StateDictKeysError( 464 f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}" 465 ) 466 467 # check tensors match 468 shape_failed: list[str] = list() 469 vals_failed: list[str] = list() 470 for k, v1 in d1.items(): 471 v2 = d2[k] 472 # check shapes first 473 if not v1.shape == v2.shape: 474 shape_failed.append(k) 475 else: 476 # if shapes match, check values 477 if not torch.allclose(v1, v2, rtol=rtol, atol=atol): 478 vals_failed.append(k) 479 480 str_shape_failed: str = ( 481 string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else "" 482 ) 483 str_vals_failed: str = ( 484 string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else "" 485 ) 486 487 if not len(shape_failed) == 0: 488 raise StateDictShapeError( 489 f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}" 490 ) 491 if not len(vals_failed) == 0: 492 raise StateDictValueError( 493 f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}" 494 )
dict mapping python, numpy, and torch types to `jaxtyping` types
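For instance, a minimal sketch of using the mapping to build a `jaxtyping` annotation from a concrete dtype (the `LogitsType` name and shape string are purely illustrative):

```python
import torch
import jaxtyping
from muutils.tensor_utils import TYPE_TO_JAX_DTYPE

# look up the jaxtyping wrapper that corresponds to a torch dtype
assert TYPE_TO_JAX_DTYPE[torch.float32] is jaxtyping.Float

# ...and use it to construct an annotation by hand
LogitsType = TYPE_TO_JAX_DTYPE[torch.float32][torch.Tensor, "batch n_vocab"]
```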
````python
def jaxtype_factory(
    name: str,
    array_type: type,
    default_jax_dtype=jaxtyping.Float,
    legacy_mode: ErrorMode = ErrorMode.WARN,
) -> type:
    """usage:
    ```
    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
    x: ATensor["dim1 dim2", np.float32]
    ```
    """
    legacy_mode = ErrorMode.from_any(legacy_mode)

    class _BaseArray:
        """jaxtyping shorthand
        (backwards compatible with older versions of muutils.tensor_utils)

        default_jax_dtype = {default_jax_dtype}
        array_type = {array_type}
        """

        def __new__(cls, *args, **kwargs):
            raise TypeError("Type FArray cannot be instantiated.")

        def __init_subclass__(cls, *args, **kwargs):
            raise TypeError(f"Cannot subclass {cls.__name__}")

        @classmethod
        def param_info(cls, params) -> str:
            """useful for error printing"""
            return "\n".join(
                f"{k} = {v}"
                for k, v in {
                    "cls.__name__": cls.__name__,
                    "cls.__doc__": cls.__doc__,
                    "params": params,
                    "type(params)": type(params),
                }.items()
            )

        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type:  # type: ignore
            # MyTensor["dim1 dim2"]
            if isinstance(params, str):
                return default_jax_dtype[array_type, params]

            elif isinstance(params, tuple):
                if len(params) != 2:
                    raise Exception(
                        f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
                    )

                if isinstance(params[0], str):
                    # MyTensor["dim1 dim2", int]
                    return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]

                elif isinstance(params[0], tuple):
                    legacy_mode.process(
                        f"legacy type annotation was used:\n{cls.param_info(params) = }",
                        except_cls=Exception,
                    )
                    # MyTensor[("dim1", "dim2"), int]
                    shape_anot: list[str] = list()
                    for x in params[0]:
                        if isinstance(x, str):
                            shape_anot.append(x)
                        elif isinstance(x, int):
                            shape_anot.append(str(x))
                        elif isinstance(x, tuple):
                            shape_anot.append("".join(str(y) for y in x))
                        else:
                            raise Exception(
                                f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
                            )

                    return TYPE_TO_JAX_DTYPE[params[1]][
                        array_type, " ".join(shape_anot)
                    ]
            else:
                raise Exception(
                    f"unexpected type for params:\n{cls.param_info(params)}"
                )

    _BaseArray.__name__ = name

    if _BaseArray.__doc__ is None:
        _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }"

    _BaseArray.__doc__ = _BaseArray.__doc__.format(
        default_jax_dtype=repr(default_jax_dtype),
        array_type=repr(array_type),
    )

    return _BaseArray
````
usage:

```python
ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
x: ATensor["dim1 dim2", np.float32]
```
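A slightly fuller sketch of the same idea, annotating functions with a factory-made alias (the function names here are illustrative, not part of the library):

```python
import torch
import jaxtyping
from muutils.tensor_utils import jaxtype_factory

ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)

# string-only form: shape annotation with the default (Float) dtype
def softmax_rows(x: ATensor["batch dim"]) -> ATensor["batch dim"]:
    return torch.softmax(x, dim=-1)

# tuple form: shape string plus a dtype, routed through TYPE_TO_JAX_DTYPE
def count_tokens(ids: ATensor["batch seq", torch.int64]) -> int:
    return int(ids.numel())
```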
```python
def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
    """convert numpy dtype to torch dtype"""
    if isinstance(dtype, torch.dtype):
        return dtype
    else:
        return torch.from_numpy(np.array(0, dtype=dtype)).dtype
```
convert numpy dtype to torch dtype
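A minimal sketch of the conversion (it goes through a zero-dim `np.array`, so anything numpy accepts as a dtype should work):

```python
import numpy as np
import torch
from muutils.tensor_utils import numpy_to_torch_dtype

assert numpy_to_torch_dtype(np.dtype(np.float32)) == torch.float32
assert numpy_to_torch_dtype(np.int64) == torch.int64
# torch dtypes pass through unchanged
assert numpy_to_torch_dtype(torch.bfloat16) == torch.bfloat16
```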
list of all the python, numpy, and torch numerical types I could think of
mapping from string representations of types to their type
mapping from string representations of types to specifically torch types
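The keys are whatever `str()` and `__name__` produce for the entries of `DTYPE_LIST`; a short sketch of what to expect (not an exhaustive listing):

```python
import numpy as np
import torch
from muutils.tensor_utils import DTYPE_MAP, TORCH_DTYPE_MAP

assert DTYPE_MAP["torch.float32"] is torch.float32  # from str(torch.float32)
assert DTYPE_MAP["float32"] is np.float32           # from np.float32.__name__
assert DTYPE_MAP["bool"] is np.bool_                 # patched in explicitly

# the torch map resolves the same keys to torch dtypes
assert TORCH_DTYPE_MAP["float32"] == torch.float32
assert TORCH_DTYPE_MAP["bool"] == torch.bool
```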
```python
def pad_tensor(
    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
    """pad a 1-d tensor on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[torch.Tensor] = [
        torch.full(
            (padded_length - tensor.shape[0],),
            pad_value,
            dtype=tensor.dtype,
            device=tensor.device,
        ),
        tensor,
    ]

    if rpad:
        temp.reverse()

    return torch.cat(temp)
```
pad a 1-d tensor on the left with `pad_value` to length `padded_length`; set `rpad = True` to pad on the right instead
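A short sketch of both padding directions (the values are chosen arbitrarily):

```python
import torch
from muutils.tensor_utils import pad_tensor, lpad_tensor, rpad_tensor

x = torch.tensor([1.0, 2.0, 3.0])

pad_tensor(x, 5)                             # tensor([0., 0., 1., 2., 3.])  left pad, default pad_value=0.0
pad_tensor(x, 5, pad_value=-1.0, rpad=True)  # tensor([ 1.,  2.,  3., -1., -1.])
lpad_tensor(x, 5)                            # same as the first call
rpad_tensor(x, 5)                            # tensor([1., 2., 3., 0., 0.])
```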
```python
def lpad_tensor(
    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
    return pad_tensor(tensor, padded_length, pad_value, rpad=False)
```
pad a 1-d tensor on the left with pad_value to length padded_length
```python
def rpad_tensor(
    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
    return pad_tensor(tensor, pad_length, pad_value, rpad=True)
```
pad a 1-d tensor on the right with pad_value to length pad_length
```python
def pad_array(
    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
    """pad a 1-d array on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[np.ndarray] = [
        np.full(
            (padded_length - array.shape[0],),
            pad_value,
            dtype=array.dtype,
        ),
        array,
    ]

    if rpad:
        temp.reverse()

    return np.concatenate(temp)
```
pad a 1-d array on the left with `pad_value` to length `padded_length`; set `rpad = True` to pad on the right instead
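The numpy variant mirrors the tensor version; a minimal sketch:

```python
import numpy as np
from muutils.tensor_utils import pad_array, rpad_array

a = np.array([1, 2, 3])

pad_array(a, 5)                 # array([0, 0, 1, 2, 3])  left pad by default
rpad_array(a, 5, pad_value=9)   # array([1, 2, 3, 9, 9])
```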
```python
def lpad_array(
    array: np.ndarray, padded_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the left with pad_value to length `padded_length`"""
    return pad_array(array, padded_length, pad_value, rpad=False)
```
pad a 1-d array on the left with pad_value to length padded_length
```python
def rpad_array(
    array: np.ndarray, pad_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the right with pad_value to length `pad_length`"""
    return pad_array(array, pad_length, pad_value, rpad=True)
```
pad a 1-d array on the right with pad_value to length pad_length
```python
def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})
```
given a state dict or cache dict, compute the shapes and put them in a nested dict
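A sketch of how dotted keys get nested (the tensors and key names are stand-ins for real parameters, and the output comment shows the rough shape of the result):

```python
import torch
from muutils.tensor_utils import get_dict_shapes

state = {
    "embed.W_E": torch.zeros(50, 16),
    "blocks.0.attn.W_Q": torch.zeros(4, 16, 8),
}
get_dict_shapes(state)
# roughly: {'embed': {'W_E': (50, 16)}, 'blocks': {'0': {'attn': {'W_Q': (4, 16, 8)}}}}
```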
```python
def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
    """printable version of get_dict_shapes"""
    return json.dumps(
        dotlist_to_nested_dict(
            {
                k: str(
                    tuple(v.shape)
                )  # to string, since indent wont play nice with tuples
                for k, v in d.items()
            }
        ),
        indent=2,
    )
```
printable version of get_dict_shapes
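The string version routes the same nested structure through `json.dumps`; roughly:

```python
import torch
from muutils.tensor_utils import string_dict_shapes

print(string_dict_shapes({"embed.W_E": torch.zeros(50, 16)}))
# {
#   "embed": {
#     "W_E": "(50, 16)"
#   }
# }
```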
```python
class StateDictCompareError(AssertionError):
    """raised when state dicts don't match"""

    pass
```
raised when state dicts don't match
```python
class StateDictKeysError(StateDictCompareError):
    """raised when state dict keys don't match"""

    pass
```
raised when state dict keys don't match
```python
class StateDictShapeError(StateDictCompareError):
    """raised when state dict shapes don't match"""

    pass
```
raised when state dict shapes don't match
```python
class StateDictValueError(StateDictCompareError):
    """raised when state dict values don't match"""

    pass
```
raised when state dict values don't match
```python
def compare_state_dicts(
    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
) -> None:
    """compare two dicts of tensors

    # Parameters:

    - `d1 : dict`
    - `d2 : dict`
    - `rtol : float`
       (defaults to `1e-5`)
    - `atol : float`
       (defaults to `1e-8`)
    - `verbose : bool`
       (defaults to `True`)

    # Raises:

    - `StateDictKeysError` : keys don't match
    - `StateDictShapeError` : shapes don't match (but keys do)
    - `StateDictValueError` : values don't match (but keys and shapes do)
    """
    # check keys match
    d1_keys: set = set(d1.keys())
    d2_keys: set = set(d2.keys())
    symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys)
    keys_diff_1: set = d1_keys - d2_keys
    keys_diff_2: set = d2_keys - d1_keys
    # sort sets for easier debugging
    symmetric_diff = set(sorted(symmetric_diff))
    keys_diff_1 = set(sorted(keys_diff_1))
    keys_diff_2 = set(sorted(keys_diff_2))
    diff_shapes_1: str = (
        string_dict_shapes({k: d1[k] for k in keys_diff_1})
        if verbose
        else "(verbose = False)"
    )
    diff_shapes_2: str = (
        string_dict_shapes({k: d2[k] for k in keys_diff_2})
        if verbose
        else "(verbose = False)"
    )
    if not len(symmetric_diff) == 0:
        raise StateDictKeysError(
            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
        )

    # check tensors match
    shape_failed: list[str] = list()
    vals_failed: list[str] = list()
    for k, v1 in d1.items():
        v2 = d2[k]
        # check shapes first
        if not v1.shape == v2.shape:
            shape_failed.append(k)
        else:
            # if shapes match, check values
            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
                vals_failed.append(k)

    str_shape_failed: str = (
        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
    )
    str_vals_failed: str = (
        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
    )

    if not len(shape_failed) == 0:
        raise StateDictShapeError(
            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
        )
    if not len(vals_failed) == 0:
        raise StateDictValueError(
            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
        )
```
compare two dicts of tensors

Parameters:

- `d1 : dict`
- `d2 : dict`
- `rtol : float` (defaults to `1e-5`)
- `atol : float` (defaults to `1e-8`)
- `verbose : bool` (defaults to `True`)

Raises:

- `StateDictKeysError` : keys don't match
- `StateDictShapeError` : shapes don't match (but keys do)
- `StateDictValueError` : values don't match (but keys and shapes do)
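A minimal sketch of how the failure modes surface (keys are checked first, then shapes, then values; the dict contents are made up):

```python
import torch
from muutils.tensor_utils import compare_state_dicts, StateDictValueError

d1 = {"w": torch.ones(3), "b": torch.zeros(2)}
d2 = {"w": torch.ones(3), "b": torch.zeros(2)}
compare_state_dicts(d1, d2)  # matches: returns None, raises nothing

d2["b"] = torch.full((2,), 0.5)  # same keys and shapes, different values
try:
    compare_state_dicts(d1, d2)
except StateDictValueError as e:
    print(e)  # reports which keys failed the allclose check
```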