docs for muutils v0.8.1
View Source on GitHub

muutils.tensor_utils

utilities for working with tensors and arrays.

notably:

  • TYPE_TO_JAX_DTYPE : a mapping from python, numpy, and torch types to jaxtyping types
  • DTYPE_MAP mapping string representations of types to their type
  • TORCH_DTYPE_MAP mapping string representations of types to torch types
  • compare_state_dicts for comparing two state dicts and giving a detailed error message on whether it was keys, shapes, or values that didn't match
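
A minimal usage sketch of those entry points (the tensors and keys below are made up for illustration):

```python
import numpy as np
import torch

from muutils.tensor_utils import DTYPE_MAP, TORCH_DTYPE_MAP, compare_state_dicts

# resolve dtypes from their string names
assert DTYPE_MAP["float32"] is np.float32
assert TORCH_DTYPE_MAP["float32"] is torch.float32

# compare two state dicts; raises a StateDictCompareError subclass on any mismatch
sd_a = {"layer.weight": torch.zeros(2, 3)}
sd_b = {"layer.weight": torch.zeros(2, 3)}
compare_state_dicts(sd_a, sd_b)  # passes silently when keys, shapes, and values all match
```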

  1"""utilities for working with tensors and arrays.
  2
  3notably:
  4
  5- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types
  6- `DTYPE_MAP` mapping string representations of types to their type
  7- `TORCH_DTYPE_MAP` mapping string representations of types to torch types
  8- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether it was keys, shapes, or values that didn't match
  9
 10"""
 11
 12from __future__ import annotations
 13
 14import json
 15import typing
 16
 17import jaxtyping
 18import numpy as np
 19import torch
 20
 21from muutils.errormode import ErrorMode
 22from muutils.dictmagic import dotlist_to_nested_dict
 23
 24# pylint: disable=missing-class-docstring
 25
 26
 27TYPE_TO_JAX_DTYPE: dict = {
 28    float: jaxtyping.Float,
 29    int: jaxtyping.Int,
 30    jaxtyping.Float: jaxtyping.Float,
 31    jaxtyping.Int: jaxtyping.Int,
 32    # bool
 33    bool: jaxtyping.Bool,
 34    jaxtyping.Bool: jaxtyping.Bool,
 35    np.bool_: jaxtyping.Bool,
 36    torch.bool: jaxtyping.Bool,
 37    # numpy float
 38    np.float16: jaxtyping.Float,
 39    np.float32: jaxtyping.Float,
 40    np.float64: jaxtyping.Float,
 41    np.half: jaxtyping.Float,
 42    np.single: jaxtyping.Float,
 43    np.double: jaxtyping.Float,
 44    # numpy int
 45    np.int8: jaxtyping.Int,
 46    np.int16: jaxtyping.Int,
 47    np.int32: jaxtyping.Int,
 48    np.int64: jaxtyping.Int,
 49    np.longlong: jaxtyping.Int,
 50    np.short: jaxtyping.Int,
 51    np.uint8: jaxtyping.Int,
 52    # torch float
 53    torch.float: jaxtyping.Float,
 54    torch.float16: jaxtyping.Float,
 55    torch.float32: jaxtyping.Float,
 56    torch.float64: jaxtyping.Float,
 57    torch.half: jaxtyping.Float,
 58    torch.double: jaxtyping.Float,
 59    torch.bfloat16: jaxtyping.Float,
 60    # torch int
 61    torch.int: jaxtyping.Int,
 62    torch.int8: jaxtyping.Int,
 63    torch.int16: jaxtyping.Int,
 64    torch.int32: jaxtyping.Int,
 65    torch.int64: jaxtyping.Int,
 66    torch.long: jaxtyping.Int,
 67    torch.short: jaxtyping.Int,
 68}
 69"dict mapping python, numpy, and torch types to `jaxtyping` types"
 70
 71if np.version.version < "2.0.0":
 72    TYPE_TO_JAX_DTYPE[np.float_] = jaxtyping.Float
 73    TYPE_TO_JAX_DTYPE[np.int_] = jaxtyping.Int
 74
 75
 76# TODO: add proper type annotations to this signature
 77def jaxtype_factory(
 78    name: str,
 79    array_type: type,
 80    default_jax_dtype=jaxtyping.Float,
 81    legacy_mode: ErrorMode = ErrorMode.WARN,
 82) -> type:
 83    """usage:
 84    ```
 85    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
 86    x: ATensor["dim1 dim2", np.float32]
 87    ```
 88    """
 89    legacy_mode = ErrorMode.from_any(legacy_mode)
 90
 91    class _BaseArray:
 92        """jaxtyping shorthand
 93        (backwards compatible with older versions of muutils.tensor_utils)
 94
 95        default_jax_dtype = {default_jax_dtype}
 96        array_type = {array_type}
 97        """
 98
 99        def __new__(cls, *args, **kwargs):
100            raise TypeError("Type FArray cannot be instantiated.")
101
102        def __init_subclass__(cls, *args, **kwargs):
103            raise TypeError(f"Cannot subclass {cls.__name__}")
104
105        @classmethod
106        def param_info(cls, params) -> str:
107            """useful for error printing"""
108            return "\n".join(
109                f"{k} = {v}"
110                for k, v in {
111                    "cls.__name__": cls.__name__,
112                    "cls.__doc__": cls.__doc__,
113                    "params": params,
114                    "type(params)": type(params),
115                }.items()
116            )
117
118        @typing._tp_cache  # type: ignore
119        def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type:  # type: ignore
120            # MyTensor["dim1 dim2"]
121            if isinstance(params, str):
122                return default_jax_dtype[array_type, params]
123
124            elif isinstance(params, tuple):
125                if len(params) != 2:
126                    raise Exception(
127                        f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
128                    )
129
130                if isinstance(params[0], str):
131                    # MyTensor["dim1 dim2", int]
132                    return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]
133
134                elif isinstance(params[0], tuple):
135                    legacy_mode.process(
136                        f"legacy type annotation was used:\n{cls.param_info(params) = }",
137                        except_cls=Exception,
138                    )
139                    # MyTensor[("dim1", "dim2"), int]
140                    shape_anot: list[str] = list()
141                    for x in params[0]:
142                        if isinstance(x, str):
143                            shape_anot.append(x)
144                        elif isinstance(x, int):
145                            shape_anot.append(str(x))
146                        elif isinstance(x, tuple):
147                            shape_anot.append("".join(str(y) for y in x))
148                        else:
149                            raise Exception(
150                                f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
151                            )
152
153                    return TYPE_TO_JAX_DTYPE[params[1]][
154                        array_type, " ".join(shape_anot)
155                    ]
156            else:
157                raise Exception(
158                    f"unexpected type for params:\n{cls.param_info(params)}"
159                )
160
161    _BaseArray.__name__ = name
162
163    if _BaseArray.__doc__ is None:
164        _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }"
165
166    _BaseArray.__doc__ = _BaseArray.__doc__.format(
167        default_jax_dtype=repr(default_jax_dtype),
168        array_type=repr(array_type),
169    )
170
171    return _BaseArray
172
173
174if typing.TYPE_CHECKING:
175    # these class definitions are only used here to make pylint happy,
176    # but they make mypy unhappy and there is no way to only run if not mypy
177    # so, later on we have more ignores
178    class ATensor(torch.Tensor):
179        @typing._tp_cache  # type: ignore
180        def __class_getitem__(cls, params):
181            raise NotImplementedError()
182
183    class NDArray(torch.Tensor):
184        @typing._tp_cache  # type: ignore
185        def __class_getitem__(cls, params):
186            raise NotImplementedError()
187
188
189ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)  # type: ignore[misc, assignment]
190
191NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float)  # type: ignore[misc, assignment]
192
193
194def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
195    """convert numpy dtype to torch dtype"""
196    if isinstance(dtype, torch.dtype):
197        return dtype
198    else:
199        return torch.from_numpy(np.array(0, dtype=dtype)).dtype
200
201
202DTYPE_LIST: list = [
203    *[
204        bool,
205        int,
206        float,
207    ],
208    *[
209        # ----------
210        # pytorch
211        # ----------
212        # floats
213        torch.float,
214        torch.float32,
215        torch.float64,
216        torch.half,
217        torch.double,
218        torch.bfloat16,
219        # complex
220        torch.complex64,
221        torch.complex128,
222        # ints
223        torch.int,
224        torch.int8,
225        torch.int16,
226        torch.int32,
227        torch.int64,
228        torch.long,
229        torch.short,
230        # simplest
231        torch.uint8,
232        torch.bool,
233    ],
234    *[
235        # ----------
236        # numpy
237        # ----------
238        # floats
239        np.float16,
240        np.float32,
241        np.float64,
242        np.half,
243        np.single,
244        np.double,
245        # complex
246        np.complex64,
247        np.complex128,
248        # ints
249        np.int8,
250        np.int16,
251        np.int32,
252        np.int64,
253        np.longlong,
254        np.short,
255        # simplest
256        np.uint8,
257        np.bool_,
258    ],
259]
260"list of all the python, numpy, and torch numerical types I could think of"
261
262if np.version.version < "2.0.0":
263    DTYPE_LIST.extend([np.float_, np.int_])
264
265DTYPE_MAP: dict = {
266    **{str(x): x for x in DTYPE_LIST},
267    **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"},
268}
269"mapping from string representations of types to their type"
270
271TORCH_DTYPE_MAP: dict = {
272    key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items()
273}
274"mapping from string representations of types to specifically torch types"
275
276# no idea why we have to do this, smh
277DTYPE_MAP["bool"] = np.bool_
278TORCH_DTYPE_MAP["bool"] = torch.bool
279
280
281TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = {
282    "Adagrad": torch.optim.Adagrad,
283    "Adam": torch.optim.Adam,
284    "AdamW": torch.optim.AdamW,
285    "SparseAdam": torch.optim.SparseAdam,
286    "Adamax": torch.optim.Adamax,
287    "ASGD": torch.optim.ASGD,
288    "LBFGS": torch.optim.LBFGS,
289    "NAdam": torch.optim.NAdam,
290    "RAdam": torch.optim.RAdam,
291    "RMSprop": torch.optim.RMSprop,
292    "Rprop": torch.optim.Rprop,
293    "SGD": torch.optim.SGD,
294}
295
296
297def pad_tensor(
298    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
299    padded_length: int,
300    pad_value: float = 0.0,
301    rpad: bool = False,
302) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
303    """pad a 1-d tensor on the left with pad_value to length `padded_length`
304
305    set `rpad = True` to pad on the right instead"""
306
307    temp: list[torch.Tensor] = [
308        torch.full(
309            (padded_length - tensor.shape[0],),
310            pad_value,
311            dtype=tensor.dtype,
312            device=tensor.device,
313        ),
314        tensor,
315    ]
316
317    if rpad:
318        temp.reverse()
319
320    return torch.cat(temp)
321
322
323def lpad_tensor(
324    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
325) -> torch.Tensor:
326    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
327    return pad_tensor(tensor, padded_length, pad_value, rpad=False)
328
329
330def rpad_tensor(
331    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
332) -> torch.Tensor:
333    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
334    return pad_tensor(tensor, pad_length, pad_value, rpad=True)
335
336
337def pad_array(
338    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
339    padded_length: int,
340    pad_value: float = 0.0,
341    rpad: bool = False,
342) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
343    """pad a 1-d array on the left with pad_value to length `padded_length`
344
345    set `rpad = True` to pad on the right instead"""
346
347    temp: list[np.ndarray] = [
348        np.full(
349            (padded_length - array.shape[0],),
350            pad_value,
351            dtype=array.dtype,
352        ),
353        array,
354    ]
355
356    if rpad:
357        temp.reverse()
358
359    return np.concatenate(temp)
360
361
362def lpad_array(
363    array: np.ndarray, padded_length: int, pad_value: float = 0.0
364) -> np.ndarray:
365    """pad a 1-d array on the left with pad_value to length `padded_length`"""
366    return pad_array(array, padded_length, pad_value, rpad=False)
367
368
369def rpad_array(
370    array: np.ndarray, pad_length: int, pad_value: float = 0.0
371) -> np.ndarray:
372    """pad a 1-d array on the right with pad_value to length `pad_length`"""
373    return pad_array(array, pad_length, pad_value, rpad=True)
374
375
376def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
377    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
378    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})
379
380
381def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
382    """printable version of get_dict_shapes"""
383    return json.dumps(
384        dotlist_to_nested_dict(
385            {
386                k: str(
387                    tuple(v.shape)
388                )  # to string, since indent wont play nice with tuples
389                for k, v in d.items()
390            }
391        ),
392        indent=2,
393    )
394
395
396class StateDictCompareError(AssertionError):
397    """raised when state dicts don't match"""
398
399    pass
400
401
402class StateDictKeysError(StateDictCompareError):
403    """raised when state dict keys don't match"""
404
405    pass
406
407
408class StateDictShapeError(StateDictCompareError):
409    """raised when state dict shapes don't match"""
410
411    pass
412
413
414class StateDictValueError(StateDictCompareError):
415    """raised when state dict values don't match"""
416
417    pass
418
419
420def compare_state_dicts(
421    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
422) -> None:
423    """compare two dicts of tensors
424
425    # Parameters:
426
427     - `d1 : dict`
428     - `d2 : dict`
429     - `rtol : float`
430       (defaults to `1e-5`)
431     - `atol : float`
432       (defaults to `1e-8`)
433     - `verbose : bool`
434       (defaults to `True`)
435
436    # Raises:
437
438     - `StateDictKeysError` : keys don't match
439     - `StateDictShapeError` : shapes don't match (but keys do)
440     - `StateDictValueError` : values don't match (but keys and shapes do)
441    """
442    # check keys match
443    d1_keys: set = set(d1.keys())
444    d2_keys: set = set(d2.keys())
445    symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys)
446    keys_diff_1: set = d1_keys - d2_keys
447    keys_diff_2: set = d2_keys - d1_keys
448    # sort sets for easier debugging
449    symmetric_diff = set(sorted(symmetric_diff))
450    keys_diff_1 = set(sorted(keys_diff_1))
451    keys_diff_2 = set(sorted(keys_diff_2))
452    diff_shapes_1: str = (
453        string_dict_shapes({k: d1[k] for k in keys_diff_1})
454        if verbose
455        else "(verbose = False)"
456    )
457    diff_shapes_2: str = (
458        string_dict_shapes({k: d2[k] for k in keys_diff_2})
459        if verbose
460        else "(verbose = False)"
461    )
462    if not len(symmetric_diff) == 0:
463        raise StateDictKeysError(
464            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
465        )
466
467    # check tensors match
468    shape_failed: list[str] = list()
469    vals_failed: list[str] = list()
470    for k, v1 in d1.items():
471        v2 = d2[k]
472        # check shapes first
473        if not v1.shape == v2.shape:
474            shape_failed.append(k)
475        else:
476            # if shapes match, check values
477            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
478                vals_failed.append(k)
479
480    str_shape_failed: str = (
481        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
482    )
483    str_vals_failed: str = (
484        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
485    )
486
487    if not len(shape_failed) == 0:
488        raise StateDictShapeError(
489            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
490        )
491    if not len(vals_failed) == 0:
492        raise StateDictValueError(
493            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
494        )

TYPE_TO_JAX_DTYPE: dict = {<class 'float'>: <class 'jaxtyping.Float'>, <class 'int'>: <class 'jaxtyping.Int'>, <class 'jaxtyping.Float'>: <class 'jaxtyping.Float'>, <class 'jaxtyping.Int'>: <class 'jaxtyping.Int'>, <class 'bool'>: <class 'jaxtyping.Bool'>, <class 'jaxtyping.Bool'>: <class 'jaxtyping.Bool'>, <class 'numpy.bool_'>: <class 'jaxtyping.Bool'>, torch.bool: <class 'jaxtyping.Bool'>, <class 'numpy.float16'>: <class 'jaxtyping.Float'>, <class 'numpy.float32'>: <class 'jaxtyping.Float'>, <class 'numpy.float64'>: <class 'jaxtyping.Float'>, <class 'numpy.int8'>: <class 'jaxtyping.Int'>, <class 'numpy.int16'>: <class 'jaxtyping.Int'>, <class 'numpy.int32'>: <class 'jaxtyping.Int'>, <class 'numpy.int64'>: <class 'jaxtyping.Int'>, <class 'numpy.uint8'>: <class 'jaxtyping.Int'>, torch.float32: <class 'jaxtyping.Float'>, torch.float16: <class 'jaxtyping.Float'>, torch.float64: <class 'jaxtyping.Float'>, torch.bfloat16: <class 'jaxtyping.Float'>, torch.int32: <class 'jaxtyping.Int'>, torch.int8: <class 'jaxtyping.Int'>, torch.int16: <class 'jaxtyping.Int'>, torch.int64: <class 'jaxtyping.Int'>}

dict mapping python, numpy, and torch types to jaxtyping types
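
For example (a small sketch; all of the key types shown appear in the mapping above):

```python
import jaxtyping
import numpy as np
import torch

from muutils.tensor_utils import TYPE_TO_JAX_DTYPE

# python, numpy, and torch dtypes all resolve to the same jaxtyping annotation class
assert TYPE_TO_JAX_DTYPE[float] is jaxtyping.Float
assert TYPE_TO_JAX_DTYPE[np.float32] is jaxtyping.Float
assert TYPE_TO_JAX_DTYPE[torch.int64] is jaxtyping.Int
```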

def jaxtype_factory( name: str, array_type: type, default_jax_dtype=<class 'jaxtyping.Float'>, legacy_mode: muutils.errormode.ErrorMode = ErrorMode.Warn) -> type:
 78def jaxtype_factory(
 79    name: str,
 80    array_type: type,
 81    default_jax_dtype=jaxtyping.Float,
 82    legacy_mode: ErrorMode = ErrorMode.WARN,
 83) -> type:
 84    """usage:
 85    ```
 86    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
 87    x: ATensor["dim1 dim2", np.float32]
 88    ```
 89    """
 90    legacy_mode = ErrorMode.from_any(legacy_mode)
 91
 92    class _BaseArray:
 93        """jaxtyping shorthand
 94        (backwards compatible with older versions of muutils.tensor_utils)
 95
 96        default_jax_dtype = {default_jax_dtype}
 97        array_type = {array_type}
 98        """
 99
100        def __new__(cls, *args, **kwargs):
101            raise TypeError("Type FArray cannot be instantiated.")
102
103        def __init_subclass__(cls, *args, **kwargs):
104            raise TypeError(f"Cannot subclass {cls.__name__}")
105
106        @classmethod
107        def param_info(cls, params) -> str:
108            """useful for error printing"""
109            return "\n".join(
110                f"{k} = {v}"
111                for k, v in {
112                    "cls.__name__": cls.__name__,
113                    "cls.__doc__": cls.__doc__,
114                    "params": params,
115                    "type(params)": type(params),
116                }.items()
117            )
118
119        @typing._tp_cache  # type: ignore
120        def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type:  # type: ignore
121            # MyTensor["dim1 dim2"]
122            if isinstance(params, str):
123                return default_jax_dtype[array_type, params]
124
125            elif isinstance(params, tuple):
126                if len(params) != 2:
127                    raise Exception(
128                        f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
129                    )
130
131                if isinstance(params[0], str):
132                    # MyTensor["dim1 dim2", int]
133                    return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]
134
135                elif isinstance(params[0], tuple):
136                    legacy_mode.process(
137                        f"legacy type annotation was used:\n{cls.param_info(params) = }",
138                        except_cls=Exception,
139                    )
140                    # MyTensor[("dim1", "dim2"), int]
141                    shape_anot: list[str] = list()
142                    for x in params[0]:
143                        if isinstance(x, str):
144                            shape_anot.append(x)
145                        elif isinstance(x, int):
146                            shape_anot.append(str(x))
147                        elif isinstance(x, tuple):
148                            shape_anot.append("".join(str(y) for y in x))
149                        else:
150                            raise Exception(
151                                f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
152                            )
153
154                    return TYPE_TO_JAX_DTYPE[params[1]][
155                        array_type, " ".join(shape_anot)
156                    ]
157            else:
158                raise Exception(
159                    f"unexpected type for params:\n{cls.param_info(params)}"
160                )
161
162    _BaseArray.__name__ = name
163
164    if _BaseArray.__doc__ is None:
165        _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }"
166
167    _BaseArray.__doc__ = _BaseArray.__doc__.format(
168        default_jax_dtype=repr(default_jax_dtype),
169        array_type=repr(array_type),
170    )
171
172    return _BaseArray

usage:

ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
x: ATensor["dim1 dim2", np.float32]
ATensor = <class 'jaxtype_factory.<locals>._BaseArray'>
NDArray = <class 'jaxtype_factory.<locals>._BaseArray'>
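
A short sketch of using the two factory-produced aliases as annotations (the dimension names are illustrative only; jaxtyping does not enforce them at assignment time):

```python
import numpy as np
import torch

from muutils.tensor_utils import ATensor, NDArray

# shape-only annotation, using the factory's default jaxtyping.Float dtype
x: ATensor["batch dim"] = torch.zeros(4, 8)

# shape plus dtype annotation; the dtype is looked up through TYPE_TO_JAX_DTYPE
y: NDArray["batch dim", np.int64] = np.zeros((4, 8), dtype=np.int64)
```
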
def numpy_to_torch_dtype(dtype: Union[numpy.dtype, torch.dtype]) -> torch.dtype:
195def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
196    """convert numpy dtype to torch dtype"""
197    if isinstance(dtype, torch.dtype):
198        return dtype
199    else:
200        return torch.from_numpy(np.array(0, dtype=dtype)).dtype

convert numpy dtype to torch dtype
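
For example (a small sketch of the accepted input kinds):

```python
import numpy as np
import torch

from muutils.tensor_utils import numpy_to_torch_dtype

assert numpy_to_torch_dtype(np.dtype("float32")) == torch.float32
assert numpy_to_torch_dtype(np.int64) == torch.int64
# torch dtypes are passed through unchanged
assert numpy_to_torch_dtype(torch.bool) is torch.bool
```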

DTYPE_LIST: list = [<class 'bool'>, <class 'int'>, <class 'float'>, torch.float32, torch.float32, torch.float64, torch.float16, torch.float64, torch.bfloat16, torch.complex64, torch.complex128, torch.int32, torch.int8, torch.int16, torch.int32, torch.int64, torch.int64, torch.int16, torch.uint8, torch.bool, <class 'numpy.float16'>, <class 'numpy.float32'>, <class 'numpy.float64'>, <class 'numpy.float16'>, <class 'numpy.float32'>, <class 'numpy.float64'>, <class 'numpy.complex64'>, <class 'numpy.complex128'>, <class 'numpy.int8'>, <class 'numpy.int16'>, <class 'numpy.int32'>, <class 'numpy.int64'>, <class 'numpy.int64'>, <class 'numpy.int16'>, <class 'numpy.uint8'>, <class 'numpy.bool_'>, <class 'numpy.float64'>, <class 'numpy.int32'>]

list of all the python, numpy, and torch numerical types I could think of

DTYPE_MAP: dict = {"<class 'bool'>": <class 'bool'>, "<class 'int'>": <class 'int'>, "<class 'float'>": <class 'float'>, 'torch.float32': torch.float32, 'torch.float64': torch.float64, 'torch.float16': torch.float16, 'torch.bfloat16': torch.bfloat16, 'torch.complex64': torch.complex64, 'torch.complex128': torch.complex128, 'torch.int32': torch.int32, 'torch.int8': torch.int8, 'torch.int16': torch.int16, 'torch.int64': torch.int64, 'torch.uint8': torch.uint8, 'torch.bool': torch.bool, "<class 'numpy.float16'>": <class 'numpy.float16'>, "<class 'numpy.float32'>": <class 'numpy.float32'>, "<class 'numpy.float64'>": <class 'numpy.float64'>, "<class 'numpy.complex64'>": <class 'numpy.complex64'>, "<class 'numpy.complex128'>": <class 'numpy.complex128'>, "<class 'numpy.int8'>": <class 'numpy.int8'>, "<class 'numpy.int16'>": <class 'numpy.int16'>, "<class 'numpy.int32'>": <class 'numpy.int32'>, "<class 'numpy.int64'>": <class 'numpy.int64'>, "<class 'numpy.uint8'>": <class 'numpy.uint8'>, "<class 'numpy.bool_'>": <class 'numpy.bool_'>, 'float16': <class 'numpy.float16'>, 'float32': <class 'numpy.float32'>, 'float64': <class 'numpy.float64'>, 'complex64': <class 'numpy.complex64'>, 'complex128': <class 'numpy.complex128'>, 'int8': <class 'numpy.int8'>, 'int16': <class 'numpy.int16'>, 'int32': <class 'numpy.int32'>, 'int64': <class 'numpy.int64'>, 'uint8': <class 'numpy.uint8'>, 'bool_': <class 'numpy.bool_'>, 'bool': <class 'numpy.bool_'>}

mapping from string representations of types to their type

TORCH_DTYPE_MAP: dict = {"<class 'bool'>": torch.bool, "<class 'int'>": torch.int32, "<class 'float'>": torch.float64, 'torch.float32': torch.float32, 'torch.float64': torch.float64, 'torch.float16': torch.float16, 'torch.bfloat16': torch.bfloat16, 'torch.complex64': torch.complex64, 'torch.complex128': torch.complex128, 'torch.int32': torch.int32, 'torch.int8': torch.int8, 'torch.int16': torch.int16, 'torch.int64': torch.int64, 'torch.uint8': torch.uint8, 'torch.bool': torch.bool, "<class 'numpy.float16'>": torch.float16, "<class 'numpy.float32'>": torch.float32, "<class 'numpy.float64'>": torch.float64, "<class 'numpy.complex64'>": torch.complex64, "<class 'numpy.complex128'>": torch.complex128, "<class 'numpy.int8'>": torch.int8, "<class 'numpy.int16'>": torch.int16, "<class 'numpy.int32'>": torch.int32, "<class 'numpy.int64'>": torch.int64, "<class 'numpy.uint8'>": torch.uint8, "<class 'numpy.bool_'>": torch.bool, 'float16': torch.float16, 'float32': torch.float32, 'float64': torch.float64, 'complex64': torch.complex64, 'complex128': torch.complex128, 'int8': torch.int8, 'int16': torch.int16, 'int32': torch.int32, 'int64': torch.int64, 'uint8': torch.uint8, 'bool_': torch.bool, 'bool': torch.bool}

mapping from string representations of types to specifically torch types
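
A sketch of the key conventions (keys are either the `str()` of a type or, for numpy types, the bare type name):

```python
import numpy as np
import torch

from muutils.tensor_utils import DTYPE_MAP, TORCH_DTYPE_MAP

assert DTYPE_MAP["torch.float32"] is torch.float32
assert DTYPE_MAP["int64"] is np.int64
# TORCH_DTYPE_MAP maps the same keys to the corresponding torch dtypes
assert TORCH_DTYPE_MAP["int64"] is torch.int64
assert TORCH_DTYPE_MAP["bool"] is torch.bool
```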

TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.optimizer.Optimizer]] = {'Adagrad': <class 'torch.optim.adagrad.Adagrad'>, 'Adam': <class 'torch.optim.adam.Adam'>, 'AdamW': <class 'torch.optim.adamw.AdamW'>, 'SparseAdam': <class 'torch.optim.sparse_adam.SparseAdam'>, 'Adamax': <class 'torch.optim.adamax.Adamax'>, 'ASGD': <class 'torch.optim.asgd.ASGD'>, 'LBFGS': <class 'torch.optim.lbfgs.LBFGS'>, 'NAdam': <class 'torch.optim.nadam.NAdam'>, 'RAdam': <class 'torch.optim.radam.RAdam'>, 'RMSprop': <class 'torch.optim.rmsprop.RMSprop'>, 'Rprop': <class 'torch.optim.rprop.Rprop'>, 'SGD': <class 'torch.optim.sgd.SGD'>}
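
A sketch of selecting an optimizer class by name, e.g. when the name comes from a config file (the model and learning rate here are placeholders):

```python
import torch

from muutils.tensor_utils import TORCH_OPTIMIZERS_MAP

model = torch.nn.Linear(4, 2)
optimizer_cls = TORCH_OPTIMIZERS_MAP["AdamW"]
optimizer = optimizer_cls(model.parameters(), lr=1e-3)
```
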
def pad_tensor( tensor: jaxtyping.Shaped[Tensor, 'dim1'], padded_length: int, pad_value: float = 0.0, rpad: bool = False) -> jaxtyping.Shaped[Tensor, 'padded_length']:
298def pad_tensor(
299    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
300    padded_length: int,
301    pad_value: float = 0.0,
302    rpad: bool = False,
303) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
304    """pad a 1-d tensor on the left with pad_value to length `padded_length`
305
306    set `rpad = True` to pad on the right instead"""
307
308    temp: list[torch.Tensor] = [
309        torch.full(
310            (padded_length - tensor.shape[0],),
311            pad_value,
312            dtype=tensor.dtype,
313            device=tensor.device,
314        ),
315        tensor,
316    ]
317
318    if rpad:
319        temp.reverse()
320
321    return torch.cat(temp)

pad a 1-d tensor on the left with pad_value to length padded_length

set rpad = True to pad on the right instead
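
For example (a small sketch with a float tensor):

```python
import torch

from muutils.tensor_utils import pad_tensor

x = torch.tensor([1.0, 2.0, 3.0])

pad_tensor(x, 5)                            # tensor([0., 0., 1., 2., 3.])  left-pad (default)
pad_tensor(x, 5, pad_value=9.0, rpad=True)  # tensor([1., 2., 3., 9., 9.])  right-pad
```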

def lpad_tensor( tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0) -> torch.Tensor:
324def lpad_tensor(
325    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
326) -> torch.Tensor:
327    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
328    return pad_tensor(tensor, padded_length, pad_value, rpad=False)

pad a 1-d tensor on the left with pad_value to length padded_length

def rpad_tensor( tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0) -> torch.Tensor:
331def rpad_tensor(
332    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
333) -> torch.Tensor:
334    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
335    return pad_tensor(tensor, pad_length, pad_value, rpad=True)

pad a 1-d tensor on the right with pad_value to length pad_length

def pad_array( array: jaxtyping.Shaped[ndarray, 'dim1'], padded_length: int, pad_value: float = 0.0, rpad: bool = False) -> jaxtyping.Shaped[ndarray, 'padded_length']:
338def pad_array(
339    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
340    padded_length: int,
341    pad_value: float = 0.0,
342    rpad: bool = False,
343) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
344    """pad a 1-d array on the left with pad_value to length `padded_length`
345
346    set `rpad = True` to pad on the right instead"""
347
348    temp: list[np.ndarray] = [
349        np.full(
350            (padded_length - array.shape[0],),
351            pad_value,
352            dtype=array.dtype,
353        ),
354        array,
355    ]
356
357    if rpad:
358        temp.reverse()
359
360    return np.concatenate(temp)

pad a 1-d array on the left with pad_value to length padded_length

set rpad = True to pad on the right instead
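
For example (the numpy counterpart of `pad_tensor`, sketched the same way):

```python
import numpy as np

from muutils.tensor_utils import pad_array

a = np.array([1.0, 2.0, 3.0])

pad_array(a, 5)                            # array([0., 0., 1., 2., 3.])  left-pad (default)
pad_array(a, 5, pad_value=9.0, rpad=True)  # array([1., 2., 3., 9., 9.])  right-pad
```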

def lpad_array( array: numpy.ndarray, padded_length: int, pad_value: float = 0.0) -> numpy.ndarray:
363def lpad_array(
364    array: np.ndarray, padded_length: int, pad_value: float = 0.0
365) -> np.ndarray:
366    """pad a 1-d array on the left with pad_value to length `padded_length`"""
367    return pad_array(array, padded_length, pad_value, rpad=False)

pad a 1-d array on the left with pad_value to length padded_length

def rpad_array( array: numpy.ndarray, pad_length: int, pad_value: float = 0.0) -> numpy.ndarray:
370def rpad_array(
371    array: np.ndarray, pad_length: int, pad_value: float = 0.0
372) -> np.ndarray:
373    """pad a 1-d array on the right with pad_value to length `pad_length`"""
374    return pad_array(array, pad_length, pad_value, rpad=True)

pad a 1-d array on the right with pad_value to length pad_length

def get_dict_shapes(d: dict[str, torch.Tensor]) -> dict[str, tuple[int, ...]]:
377def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
378    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
379    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})

given a state dict or cache dict, compute the shapes and put them in a nested dict
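
For example (a sketch with made-up dotted keys):

```python
import torch

from muutils.tensor_utils import get_dict_shapes

state_dict = {
    "encoder.weight": torch.zeros(4, 8),
    "encoder.bias": torch.zeros(4),
}
get_dict_shapes(state_dict)
# {'encoder': {'weight': (4, 8), 'bias': (4,)}}
```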

def string_dict_shapes(d: dict[str, torch.Tensor]) -> str:
382def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
383    """printable version of get_dict_shapes"""
384    return json.dumps(
385        dotlist_to_nested_dict(
386            {
387                k: str(
388                    tuple(v.shape)
389                )  # to string, since indent wont play nice with tuples
390                for k, v in d.items()
391            }
392        ),
393        indent=2,
394    )

printable version of get_dict_shapes
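
For example (another sketch with a made-up key, rendered as an indented JSON string):

```python
import torch

from muutils.tensor_utils import string_dict_shapes

print(string_dict_shapes({"head.weight": torch.zeros(2, 3)}))
# {
#   "head": {
#     "weight": "(2, 3)"
#   }
# }
```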

class StateDictCompareError(builtins.AssertionError):
397class StateDictCompareError(AssertionError):
398    """raised when state dicts don't match"""
399
400    pass

raised when state dicts don't match

Inherited Members
builtins.AssertionError
AssertionError
builtins.BaseException
with_traceback
add_note
args
class StateDictKeysError(StateDictCompareError):
403class StateDictKeysError(StateDictCompareError):
404    """raised when state dict keys don't match"""
405
406    pass

raised when state dict keys don't match

Inherited Members
builtins.AssertionError
AssertionError
builtins.BaseException
with_traceback
add_note
args
class StateDictShapeError(StateDictCompareError):
409class StateDictShapeError(StateDictCompareError):
410    """raised when state dict shapes don't match"""
411
412    pass

raised when state dict shapes don't match

Inherited Members
builtins.AssertionError
AssertionError
builtins.BaseException
with_traceback
add_note
args
class StateDictValueError(StateDictCompareError):
415class StateDictValueError(StateDictCompareError):
416    """raised when state dict values don't match"""
417
418    pass

raised when state dict values don't match

Inherited Members
builtins.AssertionError
AssertionError
builtins.BaseException
with_traceback
add_note
args
def compare_state_dicts( d1: dict, d2: dict, rtol: float = 1e-05, atol: float = 1e-08, verbose: bool = True) -> None:
421def compare_state_dicts(
422    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
423) -> None:
424    """compare two dicts of tensors
425
426    # Parameters:
427
428     - `d1 : dict`
429     - `d2 : dict`
430     - `rtol : float`
431       (defaults to `1e-5`)
432     - `atol : float`
433       (defaults to `1e-8`)
434     - `verbose : bool`
435       (defaults to `True`)
436
437    # Raises:
438
439     - `StateDictKeysError` : keys don't match
440     - `StateDictShapeError` : shapes don't match (but keys do)
441     - `StateDictValueError` : values don't match (but keys and shapes do)
442    """
443    # check keys match
444    d1_keys: set = set(d1.keys())
445    d2_keys: set = set(d2.keys())
446    symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys)
447    keys_diff_1: set = d1_keys - d2_keys
448    keys_diff_2: set = d2_keys - d1_keys
449    # sort sets for easier debugging
450    symmetric_diff = set(sorted(symmetric_diff))
451    keys_diff_1 = set(sorted(keys_diff_1))
452    keys_diff_2 = set(sorted(keys_diff_2))
453    diff_shapes_1: str = (
454        string_dict_shapes({k: d1[k] for k in keys_diff_1})
455        if verbose
456        else "(verbose = False)"
457    )
458    diff_shapes_2: str = (
459        string_dict_shapes({k: d2[k] for k in keys_diff_2})
460        if verbose
461        else "(verbose = False)"
462    )
463    if not len(symmetric_diff) == 0:
464        raise StateDictKeysError(
465            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
466        )
467
468    # check tensors match
469    shape_failed: list[str] = list()
470    vals_failed: list[str] = list()
471    for k, v1 in d1.items():
472        v2 = d2[k]
473        # check shapes first
474        if not v1.shape == v2.shape:
475            shape_failed.append(k)
476        else:
477            # if shapes match, check values
478            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
479                vals_failed.append(k)
480
481    str_shape_failed: str = (
482        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
483    )
484    str_vals_failed: str = (
485        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
486    )
487
488    if not len(shape_failed) == 0:
489        raise StateDictShapeError(
490            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
491        )
492    if not len(vals_failed) == 0:
493        raise StateDictValueError(
494            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
495        )

compare two dicts of tensors

Parameters:

  • d1 : dict
  • d2 : dict
  • rtol : float (defaults to 1e-5)
  • atol : float (defaults to 1e-8)
  • verbose : bool (defaults to True)

Raises:

  • StateDictKeysError : keys don't match
  • StateDictShapeError : shapes don't match (but keys do)
  • StateDictValueError : values don't match (but keys and shapes do)
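
A sketch of the failure path (made-up keys; shapes deliberately mismatched):

```python
import torch

from muutils.tensor_utils import StateDictShapeError, compare_state_dicts

d1 = {"w": torch.zeros(2, 3)}
d2 = {"w": torch.zeros(3, 2)}

try:
    compare_state_dicts(d1, d2)
except StateDictShapeError as e:
    print(e)  # reports which keys mismatched, along with their shapes
```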