docs for muutils v0.8.7
View Source on GitHub

muutils.tensor_utils

utilities for working with tensors and arrays.

notably:

  • TYPE_TO_JAX_DTYPE : a mapping from python, numpy, and torch types to jaxtyping types
  • DTYPE_MAP : a mapping from string representations of types to their type
  • TORCH_DTYPE_MAP : a mapping from string representations of types to torch types
  • compare_state_dicts : compares two state dicts and gives a detailed error message on whether it was keys, shapes, or values that didn't match, as shown in the sketch below
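
A minimal sketch of how these pieces fit together (assuming `torch` is installed; the toy state dict is only for illustration):

```
import torch

from muutils.tensor_utils import DTYPE_MAP, TORCH_DTYPE_MAP, compare_state_dicts

# resolve dtypes from their string names
np_float = DTYPE_MAP["float32"]           # <class 'numpy.float32'>
torch_float = TORCH_DTYPE_MAP["float32"]  # torch.float32

# compare two matching state dicts; a StateDictCompareError subclass is raised
# if keys, shapes, or values differ
sd = {"layer.weight": torch.ones(2, 3), "layer.bias": torch.zeros(3)}
compare_state_dicts(sd, {k: v.clone() for k, v in sd.items()})
```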

  1"""utilities for working with tensors and arrays.
  2
  3notably:
  4
  5- `TYPE_TO_JAX_DTYPE` : a mapping from python, numpy, and torch types to `jaxtyping` types
  6- `DTYPE_MAP` mapping string representations of types to their type
  7- `TORCH_DTYPE_MAP` mapping string representations of types to torch types
  8- `compare_state_dicts` for comparing two state dicts and giving a detailed error message on whether it was keys, shapes, or values that didn't match
  9
 10"""
 11
 12from __future__ import annotations
 13
 14import json
 15import typing
 16
 17import jaxtyping
 18import numpy as np
 19import torch
 20
 21from muutils.errormode import ErrorMode
 22from muutils.dictmagic import dotlist_to_nested_dict
 23
 24# pylint: disable=missing-class-docstring
 25
 26
 27TYPE_TO_JAX_DTYPE: dict = {
 28    float: jaxtyping.Float,
 29    int: jaxtyping.Int,
 30    jaxtyping.Float: jaxtyping.Float,
 31    jaxtyping.Int: jaxtyping.Int,
 32    # bool
 33    bool: jaxtyping.Bool,
 34    jaxtyping.Bool: jaxtyping.Bool,
 35    np.bool_: jaxtyping.Bool,
 36    torch.bool: jaxtyping.Bool,
 37    # numpy float
 38    np.float16: jaxtyping.Float,
 39    np.float32: jaxtyping.Float,
 40    np.float64: jaxtyping.Float,
 41    np.half: jaxtyping.Float,
 42    np.single: jaxtyping.Float,
 43    np.double: jaxtyping.Float,
 44    # numpy int
 45    np.int8: jaxtyping.Int,
 46    np.int16: jaxtyping.Int,
 47    np.int32: jaxtyping.Int,
 48    np.int64: jaxtyping.Int,
 49    np.longlong: jaxtyping.Int,
 50    np.short: jaxtyping.Int,
 51    np.uint8: jaxtyping.Int,
 52    # torch float
 53    torch.float: jaxtyping.Float,
 54    torch.float16: jaxtyping.Float,
 55    torch.float32: jaxtyping.Float,
 56    torch.float64: jaxtyping.Float,
 57    torch.half: jaxtyping.Float,
 58    torch.double: jaxtyping.Float,
 59    torch.bfloat16: jaxtyping.Float,
 60    # torch int
 61    torch.int: jaxtyping.Int,
 62    torch.int8: jaxtyping.Int,
 63    torch.int16: jaxtyping.Int,
 64    torch.int32: jaxtyping.Int,
 65    torch.int64: jaxtyping.Int,
 66    torch.long: jaxtyping.Int,
 67    torch.short: jaxtyping.Int,
 68}
 69"dict mapping python, numpy, and torch types to `jaxtyping` types"
 70
 71if np.version.version < "2.0.0":
 72    TYPE_TO_JAX_DTYPE[np.float_] = jaxtyping.Float
 73    TYPE_TO_JAX_DTYPE[np.int_] = jaxtyping.Int
 74
 75
 76# TODO: add proper type annotations to this signature
 77# TODO: maybe get rid of this altogether?
 78def jaxtype_factory(
 79    name: str,
 80    array_type: type,
 81    default_jax_dtype=jaxtyping.Float,
 82    legacy_mode: typing.Union[ErrorMode, str] = ErrorMode.WARN,
 83) -> type:
 84    """usage:
 85    ```
 86    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
 87    x: ATensor["dim1 dim2", np.float32]
 88    ```
 89    """
 90    legacy_mode_ = ErrorMode.from_any(legacy_mode)
 91
 92    class _BaseArray:
 93        """jaxtyping shorthand
 94        (backwards compatible with older versions of muutils.tensor_utils)
 95
 96        default_jax_dtype = {default_jax_dtype}
 97        array_type = {array_type}
 98        """
 99
100        def __new__(cls, *args, **kwargs):
101            raise TypeError(f"Type {cls.__name__} cannot be instantiated.")
102
103        def __init_subclass__(cls, *args, **kwargs):
104            raise TypeError(f"Cannot subclass {cls.__name__}")
105
106        @classmethod
107        def param_info(cls, params) -> str:
108            """useful for error printing"""
109            return "\n".join(
110                f"{k} = {v}"
111                for k, v in {
112                    "cls.__name__": cls.__name__,
113                    "cls.__doc__": cls.__doc__,
114                    "params": params,
115                    "type(params)": type(params),
116                }.items()
117            )
118
119        @typing._tp_cache  # type: ignore
120        def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type:  # type: ignore
121            # MyTensor["dim1 dim2"]
122            if isinstance(params, str):
123                return default_jax_dtype[array_type, params]
124
125            elif isinstance(params, tuple):
126                if len(params) != 2:
127                    raise Exception(
128                        f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
129                    )
130
131                if isinstance(params[0], str):
132                    # MyTensor["dim1 dim2", int]
133                    return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]
134
135                elif isinstance(params[0], tuple):
136                    legacy_mode_.process(
137                        f"legacy type annotation was used:\n{cls.param_info(params) = }",
138                        except_cls=Exception,
139                    )
140                    # MyTensor[("dim1", "dim2"), int]
141                    shape_anot: list[str] = list()
142                    for x in params[0]:
143                        if isinstance(x, str):
144                            shape_anot.append(x)
145                        elif isinstance(x, int):
146                            shape_anot.append(str(x))
147                        elif isinstance(x, tuple):
148                            shape_anot.append("".join(str(y) for y in x))
149                        else:
150                            raise Exception(
151                                f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
152                            )
153
154                    return TYPE_TO_JAX_DTYPE[params[1]][
155                        array_type, " ".join(shape_anot)
156                    ]
157            else:
158                raise Exception(
159                    f"unexpected type for params:\n{cls.param_info(params)}"
160                )
161
162    _BaseArray.__name__ = name
163
164    if _BaseArray.__doc__ is None:
165        _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }"
166
167    _BaseArray.__doc__ = _BaseArray.__doc__.format(
168        default_jax_dtype=repr(default_jax_dtype),
169        array_type=repr(array_type),
170    )
171
172    return _BaseArray
173
174
175if typing.TYPE_CHECKING:
176    # these class definitions are only used here to make pylint happy,
177    # but they make mypy unhappy and there is no way to only run if not mypy
178    # so, later on we have more ignores
179    class ATensor(torch.Tensor):
180        @typing._tp_cache  # type: ignore
181        def __class_getitem__(cls, params):
182            raise NotImplementedError()
183
184    class NDArray(torch.Tensor):
185        @typing._tp_cache  # type: ignore
186        def __class_getitem__(cls, params):
187            raise NotImplementedError()
188
189
190ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)  # type: ignore[misc, assignment]
191
192NDArray = jaxtype_factory("NDArray", np.ndarray, jaxtyping.Float)  # type: ignore[misc, assignment]
193
194
195def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
196    """convert numpy dtype to torch dtype"""
197    if isinstance(dtype, torch.dtype):
198        return dtype
199    else:
200        return torch.from_numpy(np.array(0, dtype=dtype)).dtype
201
202
203DTYPE_LIST: list = [
204    *[
205        bool,
206        int,
207        float,
208    ],
209    *[
210        # ----------
211        # pytorch
212        # ----------
213        # floats
214        torch.float,
215        torch.float32,
216        torch.float64,
217        torch.half,
218        torch.double,
219        torch.bfloat16,
220        # complex
221        torch.complex64,
222        torch.complex128,
223        # ints
224        torch.int,
225        torch.int8,
226        torch.int16,
227        torch.int32,
228        torch.int64,
229        torch.long,
230        torch.short,
231        # simplest
232        torch.uint8,
233        torch.bool,
234    ],
235    *[
236        # ----------
237        # numpy
238        # ----------
239        # floats
240        np.float16,
241        np.float32,
242        np.float64,
243        np.half,
244        np.single,
245        np.double,
246        # complex
247        np.complex64,
248        np.complex128,
249        # ints
250        np.int8,
251        np.int16,
252        np.int32,
253        np.int64,
254        np.longlong,
255        np.short,
256        # simplest
257        np.uint8,
258        np.bool_,
259    ],
260]
261"list of all the python, numpy, and torch numerical types I could think of"
262
263if np.version.version < "2.0.0":
264    DTYPE_LIST.extend([np.float_, np.int_])
265
266DTYPE_MAP: dict = {
267    **{str(x): x for x in DTYPE_LIST},
268    **{dtype.__name__: dtype for dtype in DTYPE_LIST if dtype.__module__ == "numpy"},
269}
270"mapping from string representations of types to their type"
271
272TORCH_DTYPE_MAP: dict = {
273    key: numpy_to_torch_dtype(dtype) for key, dtype in DTYPE_MAP.items()
274}
275"mapping from string representations of types to specifically torch types"
276
277# no idea why we have to do this, smh
278DTYPE_MAP["bool"] = np.bool_
279TORCH_DTYPE_MAP["bool"] = torch.bool
280
281
282TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.Optimizer]] = {
283    "Adagrad": torch.optim.Adagrad,
284    "Adam": torch.optim.Adam,
285    "AdamW": torch.optim.AdamW,
286    "SparseAdam": torch.optim.SparseAdam,
287    "Adamax": torch.optim.Adamax,
288    "ASGD": torch.optim.ASGD,
289    "LBFGS": torch.optim.LBFGS,
290    "NAdam": torch.optim.NAdam,
291    "RAdam": torch.optim.RAdam,
292    "RMSprop": torch.optim.RMSprop,
293    "Rprop": torch.optim.Rprop,
294    "SGD": torch.optim.SGD,
295}
296
297
298def pad_tensor(
299    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
300    padded_length: int,
301    pad_value: float = 0.0,
302    rpad: bool = False,
303) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
304    """pad a 1-d tensor on the left with pad_value to length `padded_length`
305
306    set `rpad = True` to pad on the right instead"""
307
308    temp: list[torch.Tensor] = [
309        torch.full(
310            (padded_length - tensor.shape[0],),
311            pad_value,
312            dtype=tensor.dtype,
313            device=tensor.device,
314        ),
315        tensor,
316    ]
317
318    if rpad:
319        temp.reverse()
320
321    return torch.cat(temp)
322
323
324def lpad_tensor(
325    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
326) -> torch.Tensor:
327    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
328    return pad_tensor(tensor, padded_length, pad_value, rpad=False)
329
330
331def rpad_tensor(
332    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
333) -> torch.Tensor:
334    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
335    return pad_tensor(tensor, pad_length, pad_value, rpad=True)
336
337
338def pad_array(
339    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
340    padded_length: int,
341    pad_value: float = 0.0,
342    rpad: bool = False,
343) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
344    """pad a 1-d array on the left with pad_value to length `padded_length`
345
346    set `rpad = True` to pad on the right instead"""
347
348    temp: list[np.ndarray] = [
349        np.full(
350            (padded_length - array.shape[0],),
351            pad_value,
352            dtype=array.dtype,
353        ),
354        array,
355    ]
356
357    if rpad:
358        temp.reverse()
359
360    return np.concatenate(temp)
361
362
363def lpad_array(
364    array: np.ndarray, padded_length: int, pad_value: float = 0.0
365) -> np.ndarray:
366    """pad a 1-d array on the left with pad_value to length `padded_length`"""
367    return pad_array(array, padded_length, pad_value, rpad=False)
368
369
370def rpad_array(
371    array: np.ndarray, pad_length: int, pad_value: float = 0.0
372) -> np.ndarray:
373    """pad a 1-d array on the right with pad_value to length `pad_length`"""
374    return pad_array(array, pad_length, pad_value, rpad=True)
375
376
377def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
378    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
379    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})
380
381
382def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
383    """printable version of get_dict_shapes"""
384    return json.dumps(
385        dotlist_to_nested_dict(
386            {
387                k: str(
388                    tuple(v.shape)
389                )  # to string, since indent won't play nice with tuples
390                for k, v in d.items()
391            }
392        ),
393        indent=2,
394    )
395
396
397class StateDictCompareError(AssertionError):
398    """raised when state dicts don't match"""
399
400    pass
401
402
403class StateDictKeysError(StateDictCompareError):
404    """raised when state dict keys don't match"""
405
406    pass
407
408
409class StateDictShapeError(StateDictCompareError):
410    """raised when state dict shapes don't match"""
411
412    pass
413
414
415class StateDictValueError(StateDictCompareError):
416    """raised when state dict values don't match"""
417
418    pass
419
420
421def compare_state_dicts(
422    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
423) -> None:
424    """compare two dicts of tensors
425
426    # Parameters:
427
428     - `d1 : dict`
429     - `d2 : dict`
430     - `rtol : float`
431       (defaults to `1e-5`)
432     - `atol : float`
433       (defaults to `1e-8`)
434     - `verbose : bool`
435       (defaults to `True`)
436
437    # Raises:
438
439     - `StateDictKeysError` : keys don't match
440     - `StateDictShapeError` : shapes don't match (but keys do)
441     - `StateDictValueError` : values don't match (but keys and shapes do)
442    """
443    # check keys match
444    d1_keys: set = set(d1.keys())
445    d2_keys: set = set(d2.keys())
446    symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys)
447    keys_diff_1: set = d1_keys - d2_keys
448    keys_diff_2: set = d2_keys - d1_keys
449    # sort sets for easier debugging
450    symmetric_diff = set(sorted(symmetric_diff))
451    keys_diff_1 = set(sorted(keys_diff_1))
452    keys_diff_2 = set(sorted(keys_diff_2))
453    diff_shapes_1: str = (
454        string_dict_shapes({k: d1[k] for k in keys_diff_1})
455        if verbose
456        else "(verbose = False)"
457    )
458    diff_shapes_2: str = (
459        string_dict_shapes({k: d2[k] for k in keys_diff_2})
460        if verbose
461        else "(verbose = False)"
462    )
463    if not len(symmetric_diff) == 0:
464        raise StateDictKeysError(
465            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
466        )
467
468    # check tensors match
469    shape_failed: list[str] = list()
470    vals_failed: list[str] = list()
471    for k, v1 in d1.items():
472        v2 = d2[k]
473        # check shapes first
474        if not v1.shape == v2.shape:
475            shape_failed.append(k)
476        else:
477            # if shapes match, check values
478            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
479                vals_failed.append(k)
480
481    str_shape_failed: str = (
482        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
483    )
484    str_vals_failed: str = (
485        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
486    )
487
488    if not len(shape_failed) == 0:
489        raise StateDictShapeError(
490            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
491        )
492    if not len(vals_failed) == 0:
493        raise StateDictValueError(
494            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
495        )

TYPE_TO_JAX_DTYPE: dict = {<class 'float'>: <class 'jaxtyping.Float'>, <class 'int'>: <class 'jaxtyping.Int'>, <class 'jaxtyping.Float'>: <class 'jaxtyping.Float'>, <class 'jaxtyping.Int'>: <class 'jaxtyping.Int'>, <class 'bool'>: <class 'jaxtyping.Bool'>, <class 'jaxtyping.Bool'>: <class 'jaxtyping.Bool'>, <class 'numpy.bool_'>: <class 'jaxtyping.Bool'>, torch.bool: <class 'jaxtyping.Bool'>, <class 'numpy.float16'>: <class 'jaxtyping.Float'>, <class 'numpy.float32'>: <class 'jaxtyping.Float'>, <class 'numpy.float64'>: <class 'jaxtyping.Float'>, <class 'numpy.int8'>: <class 'jaxtyping.Int'>, <class 'numpy.int16'>: <class 'jaxtyping.Int'>, <class 'numpy.int32'>: <class 'jaxtyping.Int'>, <class 'numpy.int64'>: <class 'jaxtyping.Int'>, <class 'numpy.longlong'>: <class 'jaxtyping.Int'>, <class 'numpy.uint8'>: <class 'jaxtyping.Int'>, torch.float32: <class 'jaxtyping.Float'>, torch.float16: <class 'jaxtyping.Float'>, torch.float64: <class 'jaxtyping.Float'>, torch.bfloat16: <class 'jaxtyping.Float'>, torch.int32: <class 'jaxtyping.Int'>, torch.int8: <class 'jaxtyping.Int'>, torch.int16: <class 'jaxtyping.Int'>, torch.int64: <class 'jaxtyping.Int'>}

dict mapping python, numpy, and torch types to jaxtyping types
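
For example (a small illustration of looking up annotation classes, not part of the module itself):

```
import jaxtyping
import numpy as np
import torch

from muutils.tensor_utils import TYPE_TO_JAX_DTYPE

# concrete dtypes resolve to their jaxtyping annotation class
assert TYPE_TO_JAX_DTYPE[np.float32] is jaxtyping.Float
assert TYPE_TO_JAX_DTYPE[torch.int64] is jaxtyping.Int
assert TYPE_TO_JAX_DTYPE[bool] is jaxtyping.Bool
```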

def jaxtype_factory( name: str, array_type: type, default_jax_dtype=<class 'jaxtyping.Float'>, legacy_mode: Union[muutils.errormode.ErrorMode, str] = ErrorMode.Warn) -> type:
 79def jaxtype_factory(
 80    name: str,
 81    array_type: type,
 82    default_jax_dtype=jaxtyping.Float,
 83    legacy_mode: typing.Union[ErrorMode, str] = ErrorMode.WARN,
 84) -> type:
 85    """usage:
 86    ```
 87    ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
 88    x: ATensor["dim1 dim2", np.float32]
 89    ```
 90    """
 91    legacy_mode_ = ErrorMode.from_any(legacy_mode)
 92
 93    class _BaseArray:
 94        """jaxtyping shorthand
 95        (backwards compatible with older versions of muutils.tensor_utils)
 96
 97        default_jax_dtype = {default_jax_dtype}
 98        array_type = {array_type}
 99        """
100
101        def __new__(cls, *args, **kwargs):
102            raise TypeError(f"Type {cls.__name__} cannot be instantiated.")
103
104        def __init_subclass__(cls, *args, **kwargs):
105            raise TypeError(f"Cannot subclass {cls.__name__}")
106
107        @classmethod
108        def param_info(cls, params) -> str:
109            """useful for error printing"""
110            return "\n".join(
111                f"{k} = {v}"
112                for k, v in {
113                    "cls.__name__": cls.__name__,
114                    "cls.__doc__": cls.__doc__,
115                    "params": params,
116                    "type(params)": type(params),
117                }.items()
118            )
119
120        @typing._tp_cache  # type: ignore
121        def __class_getitem__(cls, params: typing.Union[str, tuple]) -> type:  # type: ignore
122            # MyTensor["dim1 dim2"]
123            if isinstance(params, str):
124                return default_jax_dtype[array_type, params]
125
126            elif isinstance(params, tuple):
127                if len(params) != 2:
128                    raise Exception(
129                        f"unexpected type for params, expected tuple of length 2 here:\n{cls.param_info(params)}"
130                    )
131
132                if isinstance(params[0], str):
133                    # MyTensor["dim1 dim2", int]
134                    return TYPE_TO_JAX_DTYPE[params[1]][array_type, params[0]]
135
136                elif isinstance(params[0], tuple):
137                    legacy_mode_.process(
138                        f"legacy type annotation was used:\n{cls.param_info(params) = }",
139                        except_cls=Exception,
140                    )
141                    # MyTensor[("dim1", "dim2"), int]
142                    shape_anot: list[str] = list()
143                    for x in params[0]:
144                        if isinstance(x, str):
145                            shape_anot.append(x)
146                        elif isinstance(x, int):
147                            shape_anot.append(str(x))
148                        elif isinstance(x, tuple):
149                            shape_anot.append("".join(str(y) for y in x))
150                        else:
151                            raise Exception(
152                                f"unexpected type for params, expected first part to be str, int, or tuple:\n{cls.param_info(params)}"
153                            )
154
155                    return TYPE_TO_JAX_DTYPE[params[1]][
156                        array_type, " ".join(shape_anot)
157                    ]
158            else:
159                raise Exception(
160                    f"unexpected type for params:\n{cls.param_info(params)}"
161                )
162
163    _BaseArray.__name__ = name
164
165    if _BaseArray.__doc__ is None:
166        _BaseArray.__doc__ = "{default_jax_dtype = }\n{array_type = }"
167
168    _BaseArray.__doc__ = _BaseArray.__doc__.format(
169        default_jax_dtype=repr(default_jax_dtype),
170        array_type=repr(array_type),
171    )
172
173    return _BaseArray

usage:

ATensor = jaxtype_factory("ATensor", torch.Tensor, jaxtyping.Float)
x: ATensor["dim1 dim2", np.float32]
ATensor = <class 'jaxtype_factory.<locals>._BaseArray'>
NDArray = <class 'jaxtype_factory.<locals>._BaseArray'>
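
A short usage sketch for the generated annotation helpers (the `normalize` and `to_numpy` functions below are hypothetical examples):

```
import numpy as np

from muutils.tensor_utils import ATensor, NDArray

def normalize(x: ATensor["batch dim", np.float32]) -> ATensor["batch dim", np.float32]:
    # ATensor["batch dim", np.float32] expands to jaxtyping.Float[torch.Tensor, "batch dim"]
    return x / x.norm(dim=-1, keepdim=True)

def to_numpy(x: ATensor["batch dim"]) -> NDArray["batch dim"]:
    # a bare shape string uses the factory's default dtype (jaxtyping.Float)
    return x.detach().cpu().numpy()
```
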
def numpy_to_torch_dtype(dtype: Union[numpy.dtype, torch.dtype]) -> torch.dtype:
196def numpy_to_torch_dtype(dtype: typing.Union[np.dtype, torch.dtype]) -> torch.dtype:
197    """convert numpy dtype to torch dtype"""
198    if isinstance(dtype, torch.dtype):
199        return dtype
200    else:
201        return torch.from_numpy(np.array(0, dtype=dtype)).dtype

convert numpy dtype to torch dtype
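
For example (a quick illustration):

```
import numpy as np
import torch

from muutils.tensor_utils import numpy_to_torch_dtype

assert numpy_to_torch_dtype(np.dtype("float32")) == torch.float32
assert numpy_to_torch_dtype(np.int64) == torch.int64
# torch dtypes are passed through unchanged
assert numpy_to_torch_dtype(torch.bfloat16) == torch.bfloat16
```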

DTYPE_LIST: list = [<class 'bool'>, <class 'int'>, <class 'float'>, torch.float32, torch.float32, torch.float64, torch.float16, torch.float64, torch.bfloat16, torch.complex64, torch.complex128, torch.int32, torch.int8, torch.int16, torch.int32, torch.int64, torch.int64, torch.int16, torch.uint8, torch.bool, <class 'numpy.float16'>, <class 'numpy.float32'>, <class 'numpy.float64'>, <class 'numpy.float16'>, <class 'numpy.float32'>, <class 'numpy.float64'>, <class 'numpy.complex64'>, <class 'numpy.complex128'>, <class 'numpy.int8'>, <class 'numpy.int16'>, <class 'numpy.int32'>, <class 'numpy.int64'>, <class 'numpy.longlong'>, <class 'numpy.int16'>, <class 'numpy.uint8'>, <class 'numpy.bool_'>, <class 'numpy.float64'>, <class 'numpy.int64'>]

list of all the python, numpy, and torch numerical types I could think of

DTYPE_MAP: dict = {"<class 'bool'>": <class 'bool'>, "<class 'int'>": <class 'int'>, "<class 'float'>": <class 'float'>, 'torch.float32': torch.float32, 'torch.float64': torch.float64, 'torch.float16': torch.float16, 'torch.bfloat16': torch.bfloat16, 'torch.complex64': torch.complex64, 'torch.complex128': torch.complex128, 'torch.int32': torch.int32, 'torch.int8': torch.int8, 'torch.int16': torch.int16, 'torch.int64': torch.int64, 'torch.uint8': torch.uint8, 'torch.bool': torch.bool, "<class 'numpy.float16'>": <class 'numpy.float16'>, "<class 'numpy.float32'>": <class 'numpy.float32'>, "<class 'numpy.float64'>": <class 'numpy.float64'>, "<class 'numpy.complex64'>": <class 'numpy.complex64'>, "<class 'numpy.complex128'>": <class 'numpy.complex128'>, "<class 'numpy.int8'>": <class 'numpy.int8'>, "<class 'numpy.int16'>": <class 'numpy.int16'>, "<class 'numpy.int32'>": <class 'numpy.int32'>, "<class 'numpy.int64'>": <class 'numpy.int64'>, "<class 'numpy.longlong'>": <class 'numpy.longlong'>, "<class 'numpy.uint8'>": <class 'numpy.uint8'>, "<class 'numpy.bool_'>": <class 'numpy.bool_'>, 'float16': <class 'numpy.float16'>, 'float32': <class 'numpy.float32'>, 'float64': <class 'numpy.float64'>, 'complex64': <class 'numpy.complex64'>, 'complex128': <class 'numpy.complex128'>, 'int8': <class 'numpy.int8'>, 'int16': <class 'numpy.int16'>, 'int32': <class 'numpy.int32'>, 'int64': <class 'numpy.int64'>, 'longlong': <class 'numpy.longlong'>, 'uint8': <class 'numpy.uint8'>, 'bool_': <class 'numpy.bool_'>, 'bool': <class 'numpy.bool_'>}

mapping from string representations of types to their type

TORCH_DTYPE_MAP: dict = {"<class 'bool'>": torch.bool, "<class 'int'>": torch.int64, "<class 'float'>": torch.float64, 'torch.float32': torch.float32, 'torch.float64': torch.float64, 'torch.float16': torch.float16, 'torch.bfloat16': torch.bfloat16, 'torch.complex64': torch.complex64, 'torch.complex128': torch.complex128, 'torch.int32': torch.int32, 'torch.int8': torch.int8, 'torch.int16': torch.int16, 'torch.int64': torch.int64, 'torch.uint8': torch.uint8, 'torch.bool': torch.bool, "<class 'numpy.float16'>": torch.float16, "<class 'numpy.float32'>": torch.float32, "<class 'numpy.float64'>": torch.float64, "<class 'numpy.complex64'>": torch.complex64, "<class 'numpy.complex128'>": torch.complex128, "<class 'numpy.int8'>": torch.int8, "<class 'numpy.int16'>": torch.int16, "<class 'numpy.int32'>": torch.int32, "<class 'numpy.int64'>": torch.int64, "<class 'numpy.longlong'>": torch.int64, "<class 'numpy.uint8'>": torch.uint8, "<class 'numpy.bool_'>": torch.bool, 'float16': torch.float16, 'float32': torch.float32, 'float64': torch.float64, 'complex64': torch.complex64, 'complex128': torch.complex128, 'int8': torch.int8, 'int16': torch.int16, 'int32': torch.int32, 'int64': torch.int64, 'longlong': torch.int64, 'uint8': torch.uint8, 'bool_': torch.bool, 'bool': torch.bool}

mapping from string representations of types to specifically torch types
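
One pattern these maps support is resolving a dtype named in a config (the `cfg` dict here is a made-up example):

```
import numpy as np
import torch

from muutils.tensor_utils import DTYPE_MAP, TORCH_DTYPE_MAP

cfg = {"dtype": "float16"}  # e.g. parsed from a JSON config

x_np = np.zeros(4, dtype=DTYPE_MAP[cfg["dtype"]])
x_torch = torch.zeros(4, dtype=TORCH_DTYPE_MAP[cfg["dtype"]])

assert x_np.dtype == np.float16
assert x_torch.dtype == torch.float16
```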

TORCH_OPTIMIZERS_MAP: dict[str, typing.Type[torch.optim.optimizer.Optimizer]] = {'Adagrad': <class 'torch.optim.adagrad.Adagrad'>, 'Adam': <class 'torch.optim.adam.Adam'>, 'AdamW': <class 'torch.optim.adamw.AdamW'>, 'SparseAdam': <class 'torch.optim.sparse_adam.SparseAdam'>, 'Adamax': <class 'torch.optim.adamax.Adamax'>, 'ASGD': <class 'torch.optim.asgd.ASGD'>, 'LBFGS': <class 'torch.optim.lbfgs.LBFGS'>, 'NAdam': <class 'torch.optim.nadam.NAdam'>, 'RAdam': <class 'torch.optim.radam.RAdam'>, 'RMSprop': <class 'torch.optim.rmsprop.RMSprop'>, 'Rprop': <class 'torch.optim.rprop.Rprop'>, 'SGD': <class 'torch.optim.sgd.SGD'>}
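
A minimal sketch of selecting an optimizer class by name (the one-layer `model` is only a stand-in):

```
import torch

from muutils.tensor_utils import TORCH_OPTIMIZERS_MAP

model = torch.nn.Linear(4, 2)  # stand-in model for illustration
optimizer = TORCH_OPTIMIZERS_MAP["AdamW"](model.parameters(), lr=1e-3)
assert isinstance(optimizer, torch.optim.AdamW)
```
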
def pad_tensor( tensor: jaxtyping.Shaped[Tensor, 'dim1'], padded_length: int, pad_value: float = 0.0, rpad: bool = False) -> jaxtyping.Shaped[Tensor, 'padded_length']:
299def pad_tensor(
300    tensor: jaxtyping.Shaped[torch.Tensor, "dim1"],  # noqa: F821
301    padded_length: int,
302    pad_value: float = 0.0,
303    rpad: bool = False,
304) -> jaxtyping.Shaped[torch.Tensor, "padded_length"]:  # noqa: F821
305    """pad a 1-d tensor on the left with pad_value to length `padded_length`
306
307    set `rpad = True` to pad on the right instead"""
308
309    temp: list[torch.Tensor] = [
310        torch.full(
311            (padded_length - tensor.shape[0],),
312            pad_value,
313            dtype=tensor.dtype,
314            device=tensor.device,
315        ),
316        tensor,
317    ]
318
319    if rpad:
320        temp.reverse()
321
322    return torch.cat(temp)

pad a 1-d tensor on the left with pad_value to length padded_length

set rpad = True to pad on the right instead
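
For example (a small illustration of the default left-padding and the `rpad` flag):

```
import torch

from muutils.tensor_utils import pad_tensor, rpad_tensor

x = torch.tensor([1.0, 2.0, 3.0])

assert pad_tensor(x, 5).tolist() == [0.0, 0.0, 1.0, 2.0, 3.0]             # left pad (default)
assert pad_tensor(x, 5, rpad=True).tolist() == [1.0, 2.0, 3.0, 0.0, 0.0]  # right pad
assert rpad_tensor(x, 5, pad_value=9.0).tolist() == [1.0, 2.0, 3.0, 9.0, 9.0]
```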

def lpad_tensor( tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0) -> torch.Tensor:
325def lpad_tensor(
326    tensor: torch.Tensor, padded_length: int, pad_value: float = 0.0
327) -> torch.Tensor:
328    """pad a 1-d tensor on the left with pad_value to length `padded_length`"""
329    return pad_tensor(tensor, padded_length, pad_value, rpad=False)

pad a 1-d tensor on the left with pad_value to length padded_length

def rpad_tensor( tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0) -> torch.Tensor:
332def rpad_tensor(
333    tensor: torch.Tensor, pad_length: int, pad_value: float = 0.0
334) -> torch.Tensor:
335    """pad a 1-d tensor on the right with pad_value to length `pad_length`"""
336    return pad_tensor(tensor, pad_length, pad_value, rpad=True)

pad a 1-d tensor on the right with pad_value to length pad_length

def pad_array( array: jaxtyping.Shaped[ndarray, 'dim1'], padded_length: int, pad_value: float = 0.0, rpad: bool = False) -> jaxtyping.Shaped[ndarray, 'padded_length']:
339def pad_array(
340    array: jaxtyping.Shaped[np.ndarray, "dim1"],  # noqa: F821
341    padded_length: int,
342    pad_value: float = 0.0,
343    rpad: bool = False,
344) -> jaxtyping.Shaped[np.ndarray, "padded_length"]:  # noqa: F821
345    """pad a 1-d array on the left with pad_value to length `padded_length`
346
347    set `rpad = True` to pad on the right instead"""
348
349    temp: list[np.ndarray] = [
350        np.full(
351            (padded_length - array.shape[0],),
352            pad_value,
353            dtype=array.dtype,
354        ),
355        array,
356    ]
357
358    if rpad:
359        temp.reverse()
360
361    return np.concatenate(temp)

pad a 1-d array on the left with pad_value to length padded_length

set rpad = True to pad on the right instead
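
The numpy variants behave the same way (small illustration):

```
import numpy as np

from muutils.tensor_utils import lpad_array, pad_array

a = np.array([1.0, 2.0, 3.0])

assert pad_array(a, 5, rpad=True).tolist() == [1.0, 2.0, 3.0, 0.0, 0.0]
assert lpad_array(a, 5, pad_value=-1.0).tolist() == [-1.0, -1.0, 1.0, 2.0, 3.0]
```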

def lpad_array( array: numpy.ndarray, padded_length: int, pad_value: float = 0.0) -> numpy.ndarray:
364def lpad_array(
365    array: np.ndarray, padded_length: int, pad_value: float = 0.0
366) -> np.ndarray:
367    """pad a 1-d array on the left with pad_value to length `padded_length`"""
368    return pad_array(array, padded_length, pad_value, rpad=False)

pad a 1-d array on the left with pad_value to length padded_length

def rpad_array( array: numpy.ndarray, pad_length: int, pad_value: float = 0.0) -> numpy.ndarray:
371def rpad_array(
372    array: np.ndarray, pad_length: int, pad_value: float = 0.0
373) -> np.ndarray:
374    """pad a 1-d array on the right with pad_value to length `pad_length`"""
375    return pad_array(array, pad_length, pad_value, rpad=True)

pad a 1-d array on the right with pad_value to length pad_length

def get_dict_shapes(d: dict[str, torch.Tensor]) -> dict[str, tuple[int, ...]]:
378def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
379    """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
380    return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})

given a state dict or cache dict, compute the shapes and put them in a nested dict

def string_dict_shapes(d: dict[str, torch.Tensor]) -> str:
383def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
384    """printable version of get_dict_shapes"""
385    return json.dumps(
386        dotlist_to_nested_dict(
387            {
388                k: str(
389                    tuple(v.shape)
390                )  # to string, since indent won't play nice with tuples
391                for k, v in d.items()
392            }
393        ),
394        indent=2,
395    )

printable version of get_dict_shapes
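
For example, on a made-up two-tensor state dict (assuming `dotlist_to_nested_dict` expands dotted keys as its name suggests):

```
import torch

from muutils.tensor_utils import get_dict_shapes, string_dict_shapes

sd = {
    "encoder.weight": torch.zeros(8, 4),
    "encoder.bias": torch.zeros(8),
}

# dotted keys are expanded into nested dicts of shape tuples
assert get_dict_shapes(sd) == {"encoder": {"weight": (8, 4), "bias": (8,)}}

# same structure, JSON-formatted with shapes rendered as strings
print(string_dict_shapes(sd))
```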

class StateDictCompareError(builtins.AssertionError):
398class StateDictCompareError(AssertionError):
399    """raised when state dicts don't match"""
400
401    pass

raised when state dicts don't match

Inherited Members
builtins.AssertionError
AssertionError
builtins.BaseException
with_traceback
add_note
args
class StateDictKeysError(StateDictCompareError):
404class StateDictKeysError(StateDictCompareError):
405    """raised when state dict keys don't match"""
406
407    pass

raised when state dict keys don't match

Inherited Members
builtins.AssertionError
AssertionError
builtins.BaseException
with_traceback
add_note
args
class StateDictShapeError(StateDictCompareError):
410class StateDictShapeError(StateDictCompareError):
411    """raised when state dict shapes don't match"""
412
413    pass

raised when state dict shapes don't match

Inherited Members
builtins.AssertionError
AssertionError
builtins.BaseException
with_traceback
add_note
args
class StateDictValueError(StateDictCompareError):
416class StateDictValueError(StateDictCompareError):
417    """raised when state dict values don't match"""
418
419    pass

raised when state dict values don't match

Inherited Members
builtins.AssertionError
AssertionError
builtins.BaseException
with_traceback
add_note
args
def compare_state_dicts( d1: dict, d2: dict, rtol: float = 1e-05, atol: float = 1e-08, verbose: bool = True) -> None:
422def compare_state_dicts(
423    d1: dict, d2: dict, rtol: float = 1e-5, atol: float = 1e-8, verbose: bool = True
424) -> None:
425    """compare two dicts of tensors
426
427    # Parameters:
428
429     - `d1 : dict`
430     - `d2 : dict`
431     - `rtol : float`
432       (defaults to `1e-5`)
433     - `atol : float`
434       (defaults to `1e-8`)
435     - `verbose : bool`
436       (defaults to `True`)
437
438    # Raises:
439
440     - `StateDictKeysError` : keys don't match
441     - `StateDictShapeError` : shapes don't match (but keys do)
442     - `StateDictValueError` : values don't match (but keys and shapes do)
443    """
444    # check keys match
445    d1_keys: set = set(d1.keys())
446    d2_keys: set = set(d2.keys())
447    symmetric_diff: set = set.symmetric_difference(d1_keys, d2_keys)
448    keys_diff_1: set = d1_keys - d2_keys
449    keys_diff_2: set = d2_keys - d1_keys
450    # sort sets for easier debugging
451    symmetric_diff = set(sorted(symmetric_diff))
452    keys_diff_1 = set(sorted(keys_diff_1))
453    keys_diff_2 = set(sorted(keys_diff_2))
454    diff_shapes_1: str = (
455        string_dict_shapes({k: d1[k] for k in keys_diff_1})
456        if verbose
457        else "(verbose = False)"
458    )
459    diff_shapes_2: str = (
460        string_dict_shapes({k: d2[k] for k in keys_diff_2})
461        if verbose
462        else "(verbose = False)"
463    )
464    if not len(symmetric_diff) == 0:
465        raise StateDictKeysError(
466            f"state dicts do not match:\n{symmetric_diff = }\n{keys_diff_1 = }\n{keys_diff_2 = }\nd1_shapes = {diff_shapes_1}\nd2_shapes = {diff_shapes_2}"
467        )
468
469    # check tensors match
470    shape_failed: list[str] = list()
471    vals_failed: list[str] = list()
472    for k, v1 in d1.items():
473        v2 = d2[k]
474        # check shapes first
475        if not v1.shape == v2.shape:
476            shape_failed.append(k)
477        else:
478            # if shapes match, check values
479            if not torch.allclose(v1, v2, rtol=rtol, atol=atol):
480                vals_failed.append(k)
481
482    str_shape_failed: str = (
483        string_dict_shapes({k: d1[k] for k in shape_failed}) if verbose else ""
484    )
485    str_vals_failed: str = (
486        string_dict_shapes({k: d1[k] for k in vals_failed}) if verbose else ""
487    )
488
489    if not len(shape_failed) == 0:
490        raise StateDictShapeError(
491            f"{len(shape_failed)} / {len(d1)} state dict elements don't match in shape:\n{shape_failed = }\n{str_shape_failed}"
492        )
493    if not len(vals_failed) == 0:
494        raise StateDictValueError(
495            f"{len(vals_failed)} / {len(d1)} state dict elements don't match in values:\n{vals_failed = }\n{str_vals_failed}"
496        )

compare two dicts of tensors

Parameters:

  • d1 : dict
  • d2 : dict
  • rtol : float (defaults to 1e-5)
  • atol : float (defaults to 1e-8)
  • verbose : bool (defaults to True)

Raises:

  • StateDictKeysError : keys don't match
  • StateDictShapeError : shapes don't match (but keys do)
  • StateDictValueError : values don't match (but keys and shapes do)
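
A minimal sketch of catching a mismatch (toy state dicts for illustration):

```
import torch

from muutils.tensor_utils import StateDictShapeError, compare_state_dicts

d1 = {"w": torch.zeros(2, 3)}
d2 = {"w": torch.zeros(3, 2)}  # same key, different shape

try:
    compare_state_dicts(d1, d2)
except StateDictShapeError as err:
    print(err)  # reports which keys mismatched and (if verbose) their shapes
```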