muutils.tensor_utils
utilities for working with tensors and arrays.

notably:

- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types
- `DTYPE_MAP` : mapping string representations of types to their type
- `TORCH_DTYPE_MAP` : mapping string representations of types to torch types
- `compare_state_dicts` : for comparing two state dicts and giving a detailed error message on whether it was keys, shapes, or values that didn't match
1"""utilities for working with tensors and arrays. 2 3notably: 4 5- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types 6- `DTYPE_MAP` mapping string representations of types to their type 7- `TORCH_DTYPE_MAP` mapping string representations of types to torch types 8- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether if was keys, shapes, or values that didn't match 9 10""" 11 12from __future__ import annotations 13 14import json 15import typing 16 17import jaxtyping 18import numpy as np 19import torch 20 21from muutils.errormode import ErrorMode 22from muutils.dictmagic import dotlist_to_nested_dict 23 24# pylint: disable=missing-class-docstring 25 26 27TYPE_TO_JAX_DTYPE: dict = { 28 float: jaxtyping.Float, 29 int: jaxtyping.Int, 30 jaxtyping.Float: jaxtyping.Float, 31 jaxtyping.Int: jaxtyping.Int, 32 # bool 33 bool: jaxtyping.Bool, 34 jaxtyping.Bool: jaxtyping.Bool, 35 np.bool_: jaxtyping.Bool, 36 torch.bool: jaxtyping.Bool, 37 # numpy float 38 np.float_: jaxtyping.Float, 39 np.float16: jaxtyping.Float, 40 np.float32: jaxtyping.Float, 41 np.float64: jaxtyping.Float, 42 np.half: jaxtyping.Float, 43 np.single: jaxtyping.Float, 44 np.double: jaxtyping.Float, 45 # numpy int 46 np.int_: jaxtyping.Int, 47 np.int8: jaxtyping.Int, 48 np.int16: jaxtyping.Int, 49 np.int32: jaxtyping.Int, 50 np.int64: jaxtyping.Int, 51 np.longlong: jaxtyping.Int, 52 np.short: jaxtyping.Int, 53 np.uint8: jaxtyping.Int, 54 # torch float 55 torch.float: jaxtyping.Float, 56 torch.float16: jaxtyping.Float, 57 torch.float32: jaxtyping.Float, 58 torch.float64: jaxtyping.Float, 59 torch.half: jaxtyping.Float, 60 torch.double: jaxtyping.Float, 61 torch.bfloat16: jaxtyping.Float, 62 # torch int 63 torch.int: jaxtyping.Int, 64 torch.int8: jaxtyping.Int, 65 torch.int16: jaxtyping.Int, 66 torch.int32: jaxtyping.Int, 67 torch.int64: jaxtyping.Int, 68 torch.long: jaxtyping.Int, 69 torch.short: jaxtyping.Int, 70} 71"dict mapping python, numpy, and torch types to `jaxtyping` types" 72 73 74# TODO: add proper type annotations to this signature 75def jaxtype_factory( 76 name: str, 77 array_type: type, 78 default_jax_dtype=jaxtyping.Float, 79 legacy_mode: ErrorMode = ErrorMode.WARN, 80) -> type: 81 """usage: 82 ``` 83 ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float) 84 x: ATensor["dim1 dim2", np.float32] 85 ``` 86 """ 87 legacy_mode = ErrorMode.from_any(legacy_mode) 88 89 class _BaseArray: 90 """jaxtyping shorthand 91 (backwards compatible with older versions of muutils.tensor_utils) 92 93 default_jax_dtype = {default_jax_dtype} 94 array_type = {array_type} 95 """ 96 97 def __new__(cls, *args, **kwargs): 98 raise TypeError("Type FArray cannot be instantiated.") 99 100 def __init_subclass__(cls, *args, **kwargs): 101 raise TypeError(f"Cannot subclass {cls.__name__}") 102 103 @classmethod 104 def param_info(cls, params) -> str: 105 """useful for error printing""" 106 return "\n".join( 107 f"{k} = {v}" 108 for k, v in { 109 "cls.__name__": cls.__name__, 110 "cls.__doc__": cls.__doc__, 111 "params": params, 112 "type(params)": type(params), 113 }.items() 114 ) 115 116 @typing._tp_cache # type: ignore 117 def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type: 118 # MyTensor["dim1 dim2"] 119 if isinstance(params, str): 120 return default_jax_dtype[array_type, params] 121 122 elif isinstance(params, tuple): 123 if len(params) != 2: 124 raise Exception( 125 f"unexpected type for params, expected tuple of length 2 
here:\n{cls.param_info(params)}" 126 ) 127 128 if isinstance(params[0], str): 129 # MyTensor["dim1 dim2", int] 130 return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]] 131 132 elif isinstance(params[0], tuple): 133 legacy_mode.process( 134 f"legacy type annotation was used:\n{cls.param_info(params) = }", 135 except_cls=Exception, 136 ) 137 # MyTensor[("dim1", "dim2"), int] 138 shape_anot: list[str] = list() 139 for x in params[0]: 140 if isinstance(x, str): 141 shape_anot.append(x) 142 elif isinstance(x, int): 143 shape_anot.append(str(x)) 144 elif isinstance(x, tuple): 145 shape_anot.append("".join(str(y) for y in x)) 146 else: 147 raise Exception( 148 f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}" 149 ) 150 151 return TYPE_TO_JAX_DTYPE[params[1]][ 152 array_type, " ".join(shape_anot) 153 ] 154 else: 155 raise Exception( 156 f"unexpected type for params:\n{cls.param_info(params)}" 157 ) 158 159 _BaseArray.__name__ = name 160 161 if _BaseArray.__doc__ is None: 162 _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }" 163 164 _BaseArray.__doc__ = _BaseArray.__doc__.format( 165 default_jax_dtype=repr(default_jax_dtype), 166 array_type=repr(array_type), 167 ) 168 169 return _BaseArray 170 171 172if typing.TYPE_CHECKING: 173 # these class definitions are only used here to make pylint happy, 174 # but they make mypy unhappy and there is no way to only run if not mypy 175 # so, later on we have more ignores 176 class ATensor(torch.Tensor): 177 @typing._tp_cache # type: ignore 178 def __class_getitem__(cls, params): 179 raise NotImplementedError() 180 181 class NDArray(torch.Tensor): 182 @typing._tp_cache # type: ignore 183 def __class_getitem__(cls, params): 184 raise NotImplementedError() 185 186 187ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float) # type: ignore[misc, assignment] 188 189NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float) # type: ignore[misc, assignment] 190 191 192def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype: 193 """convert numpy dtype to torch dtype""" 194 if isinstance(dtype, torch.dtype): 195 return dtype 196 else: 197 return torch.from_numpy(np.array(0, dtype=dtype)).dtype 198 199 200DTYPE_LIST: list = [ 201 *[ 202 bool, 203 int, 204 float, 205 ], 206 *[ 207 # ---------- 208 # pytorch 209 # ---------- 210 # floats 211 torch.float, 212 torch.float32, 213 torch.float64, 214 torch.half, 215 torch.double, 216 torch.bfloat16, 217 # complex 218 torch.complex64, 219 torch.complex128, 220 # ints 221 torch.int, 222 torch.int8, 223 torch.int16, 224 torch.int32, 225 torch.int64, 226 torch.long, 227 torch.short, 228 # simplest 229 torch.uint8, 230 torch.bool, 231 ], 232 *[ 233 # ---------- 234 # numpy 235 # ---------- 236 # floats 237 np.float_, 238 np.float16, 239 np.float32, 240 np.float64, 241 np.half, 242 np.single, 243 np.double, 244 # complex 245 np.complex64, 246 np.complex128, 247 # ints 248 np.int8, 249 np.int16, 250 np.int32, 251 np.int64, 252 np.int_, 253 np.longlong, 254 np.short, 255 # simplest 256 np.uint8, 257 np.bool_, 258 ], 259] 260"list of all the python, numpy, and torch numerical types I could think of" 261 262DTYPE_MAP: dict = { 263 **{str(x): x for x in DTYPE_LIST}, 264 **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"}, 265} 266"mapping from string representations of types to their type" 267 268TORCH_DTYPE_MAP: dict = { 269 key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items() 
270} 271"mapping from string representations of types to specifically torch types" 272 273# no idea why we have to do this, smh 274DTYPE_MAP["bool"] = np.bool_ 275TORCH_DTYPE_MAP["bool"] = torch.bool 276 277 278TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = { 279 "Adagrad": torch.optim.Adagrad, 280 "Adam": torch.optim.Adam, 281 "AdamW": torch.optim.AdamW, 282 "SparseAdam": torch.optim.SparseAdam, 283 "Adamax": torch.optim.Adamax, 284 "ASGD": torch.optim.ASGD, 285 "LBFGS": torch.optim.LBFGS, 286 "NAdam": torch.optim.NAdam, 287 "RAdam": torch.optim.RAdam, 288 "RMSprop": torch.optim.RMSprop, 289 "Rprop": torch.optim.Rprop, 290 "SGD": torch.optim.SGD, 291} 292 293 294def pad_tensor( 295 tensor: jaxtyping.Shaped[torch.Tensor, "dim1"], # noqa: F821 296 padded_length: int, 297 pad_value: float = 0.0, 298 rpad: bool = False, 299) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]: # noqa: F821 300 """pad a 1-d tensor on the left with pad_value to length `padded_length` 301 302 set `rpad = True` to pad on the right instead""" 303 304 temp: list[torch.Tensor] = [ 305 torch.full( 306 (padded_length - tensor.shape[0],), 307 pad_value, 308 dtype=tensor.dtype, 309 device=tensor.device, 310 ), 311 tensor, 312 ] 313 314 if rpad: 315 temp.reverse() 316 317 return torch.cat(temp) 318 319 320def lpad_tensor( 321 tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0 322) -> torch.Tensor: 323 """pad a 1-d tensor on the left with pad_value to length `padded_length`""" 324 return pad_tensor(tensor, padded_length, pad_value, rpad=False) 325 326 327def rpad_tensor( 328 tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0 329) -> torch.Tensor: 330 """pad a 1-d tensor on the right with pad_value to length `pad_length`""" 331 return pad_tensor(tensor, pad_length, pad_value, rpad=True) 332 333 334def pad_array( 335 array: jaxtyping.Shaped[np.ndarray, "dim1"], # noqa: F821 336 padded_length: int, 337 pad_value: float = 0.0, 338 rpad: bool = False, 339) -> jaxtyping.Shaped[np.ndarray, "padded_length"]: # noqa: F821 340 """pad a 1-d array on the left with pad_value to length `padded_length` 341 342 set `rpad = True` to pad on the right instead""" 343 344 temp: list[np.ndarray] = [ 345 np.full( 346 (padded_length - array.shape[0],), 347 pad_value, 348 dtype=array.dtype, 349 ), 350 array, 351 ] 352 353 if rpad: 354 temp.reverse() 355 356 return np.concatenate(temp) 357 358 359def lpad_array( 360 array: np.ndarray, padded_length: int, pad_value: float = 0.0 361) -> np.ndarray: 362 """pad a 1-d array on the left with pad_value to length `padded_length`""" 363 return pad_array(array, padded_length, pad_value, rpad=False) 364 365 366def rpad_array( 367 array: np.ndarray, pad_length: int, pad_value: float = 0.0 368) -> np.ndarray: 369 """pad a 1-d array on the right with pad_value to length `pad_length`""" 370 return pad_array(array, pad_length, pad_value, rpad=True) 371 372 373def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]: 374 """given a state dict or cache dict, compute the shapes and put them in a nested dict""" 375 return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()}) 376 377 378def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str: 379 """printable version of get_dict_shapes""" 380 return json.dumps( 381 dotlist_to_nested_dict( 382 { 383 k: str( 384 tuple(v.shape) 385 ) # to string, since indent wont play nice with tuples 386 for k, v in d.items() 387 } 388 ), 389 indent=2, 390 ) 391 392 393class 
StateDictCompareError(AssertionError): 394 """raised when state dicts don't match""" 395 396 pass 397 398 399class StateDictKeysError(StateDictCompareError): 400 """raised when state dict keys don't match""" 401 402 pass 403 404 405class StateDictShapeError(StateDictCompareError): 406 """raised when state dict shapes don't match""" 407 408 pass 409 410 411class StateDictValueError(StateDictCompareError): 412 """raised when state dict values don't match""" 413 414 pass 415 416 417def compare_state_dicts( 418 d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True 419) -> None: 420 """compare two dicts of tensors 421 422 # Parameters: 423 424 - `d1 : dict` 425 - `d2 : dict` 426 - `rtol : float` 427 (defaults to `1e-5`) 428 - `atol : float` 429 (defaults to `1e-8`) 430 - `verbose : bool` 431 (defaults to `True`) 432 433 # Raises: 434 435 - `StateDictKeysError` : keys don't match 436 - `StateDictShapeError` : shapes don't match (but keys do) 437 - `StateDictValueError` : values don't match (but keys and shapes do) 438 """ 439 # check keys match 440 d1_keys: set = set(d1.keys()) 441 d2_keys: set = set(d2.keys()) 442 symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys) 443 keys_diff_1: set = d1_keys - d2_keys 444 keys_diff_2: set = d2_keys - d1_keys 445 # sort sets for easier debugging 446 symmetric_diff = set(sorted(symmetric_diff)) 447 keys_diff_1 = set(sorted(keys_diff_1)) 448 keys_diff_2 = set(sorted(keys_diff_2)) 449 diff_shapes_1: str = ( 450 string_dict_shapes({k: d1[k] for k in keys_diff_1}) 451 if verbose 452 else "(verbose = False)" 453 ) 454 diff_shapes_2: str = ( 455 string_dict_shapes({k: d2[k] for k in keys_diff_2}) 456 if verbose 457 else "(verbose = False)" 458 ) 459 if not len(symmetric_diff) == 0: 460 raise StateDictKeysError( 461 f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}" 462 ) 463 464 # check tensors match 465 shape_failed: list[str] = list() 466 vals_failed: list[str] = list() 467 for k, v1 in d1.items(): 468 v2 = d2[k] 469 # check shapes first 470 if not v1.shape == v2.shape: 471 shape_failed.append(k) 472 else: 473 # if shapes match, check values 474 if not torch.allclose(v1, v2, rtol=rtol, atol=atol): 475 vals_failed.append(k) 476 477 str_shape_failed: str = ( 478 string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else "" 479 ) 480 str_vals_failed: str = ( 481 string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else "" 482 ) 483 484 if not len(shape_failed) == 0: 485 raise StateDictShapeError( 486 f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}" 487 ) 488 if not len(vals_failed) == 0: 489 raise StateDictValueError( 490 f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}" 491 )
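As a rough sketch of how `TORCH_OPTIMIZERS_MAP` can be used, the snippet below resolves an optimizer class from a name string; the tiny `torch.nn.Linear` model and the learning rate are made up purely for illustration.

```python
import torch

from muutils.tensor_utils import TORCH_OPTIMIZERS_MAP

# hypothetical model and hyperparameters, just to show the lookup
model = torch.nn.Linear(4, 2)
optimizer_cls = TORCH_OPTIMIZERS_MAP["AdamW"]  # plain class lookup by name
optimizer = optimizer_cls(model.parameters(), lr=1e-3)
print(type(optimizer).__name__)  # AdamW
```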
`TYPE_TO_JAX_DTYPE` : dict mapping python, numpy, and torch types to `jaxtyping` types
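A minimal sketch of indexing the mapping directly: given a concrete dtype it yields the matching `jaxtyping` wrapper, which can then be subscripted as usual (the `"batch dim"` shape string is arbitrary).

```python
import numpy as np
import torch

from muutils.tensor_utils import TYPE_TO_JAX_DTYPE

wrapper = TYPE_TO_JAX_DTYPE[np.float32]          # jaxtyping.Float
annotation = wrapper[torch.Tensor, "batch dim"]  # Float[Tensor, 'batch dim']
print(annotation)
```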
````python
def jaxtype_factory(
    name: str,
    array_type: type,
    default_jax_dtype=jaxtyping.Float,
    legacy_mode: ErrorMode = ErrorMode.WARN,
) -> type:
    """usage:
    ```
    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
    x: ATensor["dim1 dim2", np.float32]
    ```
    """
    legacy_mode = ErrorMode.from_any(legacy_mode)

    class _BaseArray:
        """jaxtyping shorthand
        (backwards compatible with older versions of muutils.tensor_utils)

        default_jax_dtype = {default_jax_dtype}
        array_type = {array_type}
        """

        def __new__(cls, *args, **kwargs):
            raise TypeError("Type FArray cannot be instantiated.")

        def __init_subclass__(cls, *args, **kwargs):
            raise TypeError(f"Cannot subclass {cls.__name__}")

        @classmethod
        def param_info(cls, params) -> str:
            """useful for error printing"""
            return "\n".join(
                f"{k} = {v}"
                for k, v in {
                    "cls.__name__": cls.__name__,
                    "cls.__doc__": cls.__doc__,
                    "params": params,
                    "type(params)": type(params),
                }.items()
            )

        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type:
            # MyTensor["dim1 dim2"]
            if isinstance(params, str):
                return default_jax_dtype[array_type, params]

            elif isinstance(params, tuple):
                if len(params) != 2:
                    raise Exception(
                        f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
                    )

                if isinstance(params[0], str):
                    # MyTensor["dim1 dim2", int]
                    return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]

                elif isinstance(params[0], tuple):
                    legacy_mode.process(
                        f"legacy type annotation was used:\n{cls.param_info(params) = }",
                        except_cls=Exception,
                    )
                    # MyTensor[("dim1", "dim2"), int]
                    shape_anot: list[str] = list()
                    for x in params[0]:
                        if isinstance(x, str):
                            shape_anot.append(x)
                        elif isinstance(x, int):
                            shape_anot.append(str(x))
                        elif isinstance(x, tuple):
                            shape_anot.append("".join(str(y) for y in x))
                        else:
                            raise Exception(
                                f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
                            )

                    return TYPE_TO_JAX_DTYPE[params[1]][
                        array_type, " ".join(shape_anot)
                    ]
            else:
                raise Exception(
                    f"unexpected type for params:\n{cls.param_info(params)}"
                )

    _BaseArray.__name__ = name

    if _BaseArray.__doc__ is None:
        _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }"

    _BaseArray.__doc__ = _BaseArray.__doc__.format(
        default_jax_dtype=repr(default_jax_dtype),
        array_type=repr(array_type),
    )

    return _BaseArray
````
usage:

```python
ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
x: ATensor["dim1 dim2", np.float32]
```
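Beyond the usage shown above, here is a sketch of defining an alias for numpy arrays; the alias name `MyArray` and the `normalize` function are hypothetical, and the `(shape, dtype)` form goes through `TYPE_TO_JAX_DTYPE` to pick the wrapper.

```python
import jaxtyping
import numpy as np

from muutils.tensor_utils import jaxtype_factory

# hypothetical alias for annotating numpy arrays
MyArray = jaxtype_factory("MyArray", np.ndarray, jaxtyping.Float)

def normalize(x: MyArray["batch dim", np.float32]) -> MyArray["batch dim", np.float32]:
    # annotations only; no runtime checking is implied here
    return x / x.sum()
```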
```python
def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
    """convert numpy dtype to torch dtype"""
    if isinstance(dtype, torch.dtype):
        return dtype
    else:
        return torch.from_numpy(np.array(0, dtype=dtype)).dtype
```
convert numpy dtype to torch dtype
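For illustration, a few conversions the function should handle, relying only on the behavior visible in the source above:

```python
import numpy as np
import torch

from muutils.tensor_utils import numpy_to_torch_dtype

assert numpy_to_torch_dtype(np.dtype("float32")) == torch.float32
assert numpy_to_torch_dtype(np.dtype("int64")) == torch.int64
# torch dtypes are passed through unchanged
assert numpy_to_torch_dtype(torch.float16) == torch.float16
```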
`DTYPE_LIST` : list of all the python, numpy, and torch numerical types I could think of

`DTYPE_MAP` : mapping from string representations of types to their type

`TORCH_DTYPE_MAP` : mapping from string representations of types to specifically torch types
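A short sketch of the intended lookup pattern, e.g. when a dtype arrives as a string from a config file (the config scenario is just an assumption):

```python
import numpy as np
import torch

from muutils.tensor_utils import DTYPE_MAP, TORCH_DTYPE_MAP

assert DTYPE_MAP["float32"] is np.float32           # numpy types are keyed by __name__
assert DTYPE_MAP["torch.float32"] == torch.float32  # torch types are keyed by str(dtype)
assert TORCH_DTYPE_MAP["float32"] == torch.float32  # always resolves to a torch dtype
assert TORCH_DTYPE_MAP["bool"] == torch.bool        # "bool" is explicitly remapped in the module source
```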
```python
def pad_tensor(
    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
    """pad a 1-d tensor on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[torch.Tensor] = [
        torch.full(
            (padded_length - tensor.shape[0],),
            pad_value,
            dtype=tensor.dtype,
            device=tensor.device,
        ),
        tensor,
    ]

    if rpad:
        temp.reverse()

    return torch.cat(temp)
```
pad a 1-d tensor on the left with pad_value to length `padded_length`

set `rpad = True` to pad on the right instead
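A minimal sketch (the example tensor and target length are arbitrary):

```python
import torch

from muutils.tensor_utils import pad_tensor

x = torch.tensor([1.0, 2.0, 3.0])
print(pad_tensor(x, 5))             # tensor([0., 0., 1., 2., 3.])
print(pad_tensor(x, 5, rpad=True))  # tensor([1., 2., 3., 0., 0.])
```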
```python
def lpad_tensor(
    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
    return pad_tensor(tensor, padded_length, pad_value, rpad=False)
```
pad a 1-d tensor on the left with pad_value to length padded_length
```python
def rpad_tensor(
    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
) -> torch.Tensor:
    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
    return pad_tensor(tensor, pad_length, pad_value, rpad=True)
```
pad a 1-d tensor on the right with pad_value to length pad_length
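The two wrappers are equivalent to calling `pad_tensor` with `rpad=False` / `rpad=True`; a quick sketch with an arbitrary tensor:

```python
import torch

from muutils.tensor_utils import lpad_tensor, rpad_tensor

x = torch.tensor([1.0, 2.0, 3.0])
assert torch.equal(lpad_tensor(x, 4), torch.tensor([0.0, 1.0, 2.0, 3.0]))
assert torch.equal(rpad_tensor(x, 4), torch.tensor([1.0, 2.0, 3.0, 0.0]))
```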
```python
def pad_array(
    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
    padded_length: int,
    pad_value: float = 0.0,
    rpad: bool = False,
) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
    """pad a 1-d array on the left with pad_value to length `padded_length`

    set `rpad = True` to pad on the right instead"""

    temp: list[np.ndarray] = [
        np.full(
            (padded_length - array.shape[0],),
            pad_value,
            dtype=array.dtype,
        ),
        array,
    ]

    if rpad:
        temp.reverse()

    return np.concatenate(temp)
```
pad a 1-d array on the left with pad_value to length `padded_length`

set `rpad = True` to pad on the right instead
```python
def lpad_array(
    array: np.ndarray, padded_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the left with pad_value to length `padded_length`"""
    return pad_array(array, padded_length, pad_value, rpad=False)
```
pad a 1-d array on the left with pad_value to length padded_length
```python
def rpad_array(
    array: np.ndarray, pad_length: int, pad_value: float = 0.0
) -> np.ndarray:
    """pad a 1-d array on the right with pad_value to length `pad_length`"""
    return pad_array(array, pad_length, pad_value, rpad=True)
```
pad a 1-d array on the right with pad_value to length pad_length
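The array versions mirror the tensor versions, operating on numpy arrays; a brief sketch with arbitrary values:

```python
import numpy as np

from muutils.tensor_utils import lpad_array, pad_array, rpad_array

a = np.array([1.0, 2.0, 3.0])
print(pad_array(a, 5))  # [0. 0. 1. 2. 3.]
assert np.array_equal(lpad_array(a, 4), np.array([0.0, 1.0, 2.0, 3.0]))
assert np.array_equal(rpad_array(a, 4), np.array([1.0, 2.0, 3.0, 0.0]))
```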
```python
def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})
```
given a state dict or cache dict, compute the shapes and put them in a nested dict
```python
def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
    """printable version of get_dict_shapes"""
    return json.dumps(
        dotlist_to_nested_dict(
            {
                k: str(
                    tuple(v.shape)
                )  # to string, since indent wont play nice with tuples
                for k, v in d.items()
            }
        ),
        indent=2,
    )
```
printable version of get_dict_shapes
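A sketch of both helpers on a toy state dict (the `layer1.*` keys are made up); dotted keys become nested dicts via `dotlist_to_nested_dict`:

```python
import torch

from muutils.tensor_utils import get_dict_shapes, string_dict_shapes

sd = {
    "layer1.weight": torch.zeros(4, 3),
    "layer1.bias": torch.zeros(4),
}
print(get_dict_shapes(sd))     # {'layer1': {'weight': (4, 3), 'bias': (4,)}}
print(string_dict_shapes(sd))  # same structure as indented JSON, shapes as strings
```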
```python
class StateDictCompareError(AssertionError):
    """raised when state dicts don't match"""

    pass
```
raised when state dicts don't match
```python
class StateDictKeysError(StateDictCompareError):
    """raised when state dict keys don't match"""

    pass
```
raised when state dict keys don't match
```python
class StateDictShapeError(StateDictCompareError):
    """raised when state dict shapes don't match"""

    pass
```
raised when state dict shapes don't match
```python
class StateDictValueError(StateDictCompareError):
    """raised when state dict values don't match"""

    pass
```
raised when state dict values don't match
```python
def compare_state_dicts(
    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
) -> None:
    """compare two dicts of tensors

    # Parameters:

    - `d1 : dict`
    - `d2 : dict`
    - `rtol : float`
        (defaults to `1e-5`)
    - `atol : float`
        (defaults to `1e-8`)
    - `verbose : bool`
        (defaults to `True`)

    # Raises:

    - `StateDictKeysError` : keys don't match
    - `StateDictShapeError` : shapes don't match (but keys do)
    - `StateDictValueError` : values don't match (but keys and shapes do)
    """
    # check keys match
    d1_keys: set = set(d1.keys())
    d2_keys: set = set(d2.keys())
    symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys)
    keys_diff_1: set = d1_keys - d2_keys
    keys_diff_2: set = d2_keys - d1_keys
    # sort sets for easier debugging
    symmetric_diff = set(sorted(symmetric_diff))
    keys_diff_1 = set(sorted(keys_diff_1))
    keys_diff_2 = set(sorted(keys_diff_2))
    diff_shapes_1: str = (
        string_dict_shapes({k: d1[k] for k in keys_diff_1})
        if verbose
        else "(verbose = False)"
    )
    diff_shapes_2: str = (
        string_dict_shapes({k: d2[k] for k in keys_diff_2})
        if verbose
        else "(verbose = False)"
    )
    if not len(symmetric_diff) == 0:
        raise StateDictKeysError(
            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
        )

    # check tensors match
    shape_failed: list[str] = list()
    vals_failed: list[str] = list()
    for k, v1 in d1.items():
        v2 = d2[k]
        # check shapes first
        if not v1.shape == v2.shape:
            shape_failed.append(k)
        else:
            # if shapes match, check values
            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
                vals_failed.append(k)

    str_shape_failed: str = (
        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
    )
    str_vals_failed: str = (
        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
    )

    if not len(shape_failed) == 0:
        raise StateDictShapeError(
            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
        )
    if not len(vals_failed) == 0:
        raise StateDictValueError(
            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
        )
```
compare two dicts of tensors

Parameters:

- `d1 : dict`
- `d2 : dict`
- `rtol : float` (defaults to `1e-5`)
- `atol : float` (defaults to `1e-8`)
- `verbose : bool` (defaults to `True`)

Raises:

- `StateDictKeysError` : keys don't match
- `StateDictShapeError` : shapes don't match (but keys do)
- `StateDictValueError` : values don't match (but keys and shapes do)
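A sketch of how the failure modes surface (the dicts and key names are made up); since all three exceptions subclass `StateDictCompareError`, catching the base class covers every kind of mismatch:

```python
import torch

from muutils.tensor_utils import (
    StateDictCompareError,
    StateDictShapeError,
    compare_state_dicts,
)

d1 = {"w": torch.ones(2, 2), "b": torch.zeros(2)}
d2 = {"w": torch.ones(2, 2), "b": torch.zeros(2)}
compare_state_dicts(d1, d2)  # identical: returns None

d2["b"] = torch.zeros(3)  # same keys, different shape
try:
    compare_state_dicts(d1, d2)
except StateDictShapeError as e:
    print(f"shape mismatch:\n{e}")

try:
    compare_state_dicts(d1, {"w": torch.ones(2, 2)})  # missing key -> StateDictKeysError
except StateDictCompareError as e:
    print(f"state dicts differ:\n{e}")
```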