docs for muutils v0.8.1
View Source on GitHub

muutils.json_serialize.util

utilities for json_serialize


  1"""utilities for json_serialize"""
  2
  3from __future__ import annotations
  4
  5import dataclasses
  6import functools
  7import inspect
  8import sys
  9import typing
 10import warnings
 11from typing import Any, Callable, Iterable, Union
 12
 13_NUMPY_WORKING: bool
 14try:
 15    _NUMPY_WORKING = True
 16except ImportError:
 17    warnings.warn("numpy not found, cannot serialize numpy arrays!")
 18    _NUMPY_WORKING = False
 19
 20
 21BaseType = Union[
 22    bool,
 23    int,
 24    float,
 25    str,
 26    None,
 27]
 28
 29JSONitem = Union[
 30    BaseType,
 31    # mypy doesn't like recursive types, so we just go down a few levels manually
 32    typing.List[Union[BaseType, typing.List[Any], typing.Dict[str, Any]]],
 33    typing.Dict[str, Union[BaseType, typing.List[Any], typing.Dict[str, Any]]],
 34]
 35JSONdict = typing.Dict[str, JSONitem]
 36
 37Hashableitem = Union[bool, int, float, str, tuple]
 38
 39
 40_FORMAT_KEY: str = "__muutils_format__"
 41_REF_KEY: str = "$ref"
 42
 43# or if python version <3.9
 44if typing.TYPE_CHECKING or sys.version_info < (3, 9):
 45    MonoTuple = typing.Sequence
 46else:
 47
 48    class MonoTuple:
 49        """tuple type hint, but for a tuple of any length with all the same type"""
 50
 51        __slots__ = ()
 52
 53        def __new__(cls, *args, **kwargs):
 54            raise TypeError("Type MonoTuple cannot be instantiated.")
 55
 56        def __init_subclass__(cls, *args, **kwargs):
 57            raise TypeError(f"Cannot subclass {cls.__module__}")
 58
 59        # idk why mypy thinks there is no such function in typing
 60        @typing._tp_cache  # type: ignore
 61        def __class_getitem__(cls, params):
 62            if getattr(params, "__origin__", None) == typing.Union:
 63                return typing.GenericAlias(tuple, (params, Ellipsis))
 64            elif isinstance(params, type):
 65                typing.GenericAlias(tuple, (params, Ellipsis))
 66            # test if has len and is iterable
 67            elif isinstance(params, Iterable):
 68                if len(params) == 0:
 69                    return tuple
 70                elif len(params) == 1:
 71                    return typing.GenericAlias(tuple, (params[0], Ellipsis))
 72            else:
 73                raise TypeError(f"MonoTuple expects 1 type argument, got {params = }")
 74
 75
 76class UniversalContainer:
 77    """contains everything -- `x in UniversalContainer()` is always True"""
 78
 79    def __contains__(self, x: Any) -> bool:
 80        return True
 81
 82
 83def isinstance_namedtuple(x: Any) -> bool:
 84    """checks if `x` is a `namedtuple`
 85
 86    credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
 87    """
 88    t: type = type(x)
 89    b: tuple = t.__bases__
 90    if len(b) != 1 or (b[0] is not tuple):
 91        return False
 92    f: Any = getattr(t, "_fields", None)
 93    if not isinstance(f, tuple):
 94        return False
 95    return all(isinstance(n, str) for n in f)
 96
 97
 98def try_catch(func: Callable):
 99    """wraps the function to catch exceptions, returns serialized error message on exception
100
101    returned func will return normal result on success, or error message on exception
102    """
103
104    @functools.wraps(func)
105    def newfunc(*args, **kwargs):
106        try:
107            return func(*args, **kwargs)
108        except Exception as e:
109            return f"{e.__class__.__name__}: {e}"
110
111    return newfunc
112
113
114def _recursive_hashify(obj: Any, force: bool = True) -> Hashableitem:
115    if isinstance(obj, typing.Mapping):
116        return tuple((k, _recursive_hashify(v)) for k, v in obj.items())
117    elif isinstance(obj, (tuple, list, Iterable)):
118        return tuple(_recursive_hashify(v) for v in obj)
119    elif isinstance(obj, (bool, int, float, str)):
120        return obj
121    else:
122        if force:
123            return str(obj)
124        else:
125            raise ValueError(f"cannot hashify:\n{obj}")
126
127
class SerializationException(Exception):
    """plain `Exception` subclass with no added behavior

    presumably raised when serialization fails -- no raise site is visible in this module
    """
130
131
def string_as_lines(s: str | None) -> list[str]:
    """split a string into its lines, so that long strings are easier to read in json

    sort of like how jupyter notebooks do it

    `None` gives an empty list; line endings are not kept in the returned lines
    """
    return [] if s is None else s.splitlines()
141
142
def safe_getsource(func) -> list[str]:
    """return the source code of `func` as a list of lines, without ever raising

    if the source cannot be retrieved, the lines of an error message
    (which includes the exception text) are returned instead
    """
    try:
        source_text = inspect.getsource(func)
        return string_as_lines(source_text)
    except Exception as err:
        failure_message = f"Error: Unable to retrieve source code:\n{err}"
        return string_as_lines(failure_message)
148
149
# credit to https://stackoverflow.com/questions/51743827/how-to-compare-equality-of-dataclasses-holding-numpy-ndarray-boola-b-raises
def array_safe_eq(a: Any, b: Any) -> bool:
    """check if two objects are equal, account for if numpy arrays or torch tensors

    # Parameters:
    - `a`, `b`: the objects to compare

    # Returns:
    - `True` if `a is b`
    - `False` if the two objects are not of exactly the same type
    - numpy arrays / torch tensors: `True` only if the shapes are equal and all elements compare equal
    - pandas dataframes: the result of `a.equals(b)`
    - `str`, `bytes`, `bytearray`: the result of `a == b`
    - other sequences: same length and pairwise-equal items (compared recursively)
    - mappings: same length, same keys, and equal values per key (compared recursively),
        independent of key order
    - anything else: `bool(a == b)`; if that raises `TypeError` or `ValueError`, a warning
        is issued and `NotImplemented` is returned (note that this is not a `bool`)
    """
    if a is b:
        return True

    if type(a) is not type(b):
        return False

    # from here on both objects have exactly the same type, so checks on `a` hold for `b`
    type_name: str = str(type(a))

    if type_name in ("<class 'numpy.ndarray'>", "<class 'torch.Tensor'>"):
        # `==` broadcasts: arrays of different shapes could compare element-wise
        # equal (or raise), so differing shapes are reported as unequal up front
        if a.shape != b.shape:
            return False
        # `.all()` gives a numpy/torch scalar, convert it to a real `bool`
        return bool((a == b).all())

    if type_name == "<class 'pandas.core.frame.DataFrame'>":
        return a.equals(b)

    # strings must not reach the generic `Sequence` branch: the items of a string
    # are strings again, so the item-wise recursion would never terminate
    if isinstance(a, (str, bytes, bytearray)):
        return a == b

    if isinstance(a, typing.Sequence):
        return len(a) == len(b) and all(array_safe_eq(a1, b1) for a1, b1 in zip(a, b))

    if isinstance(a, typing.Mapping):
        # look values up by key rather than zipping the two key views, so that
        # mappings with the same items in a different order compare equal
        return len(a) == len(b) and all(
            key in b and array_safe_eq(value, b[key]) for key, value in a.items()
        )

    try:
        return bool(a == b)
    except (TypeError, ValueError) as e:
        warnings.warn(f"Cannot compare {a} and {b} for equality\n{e}")
        return NotImplemented  # type: ignore[return-value]
190
191
def dc_eq(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    checks if two dataclasses which (might) hold numpy arrays are equal

    # Parameters:

    - `dc1`: the first dataclass
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, will throw `TypeError` if the classes are different.
        if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only relevant if the classes are different and `except_when_class_mismatch` is `False`.
        if `True`, will return `False` if the classes are different.
        if `False`, will compare the field values, provided both objects have the same field names.
        (default: `True`)
    - `except_when_field_mismatch: bool`
        only relevant if the classes are different and `except_when_class_mismatch` is `False`.
        if `True`, will throw `AttributeError` if the field names are different.
        (default: `False`)

    # Returns:
    - `bool`: True if the dataclasses are equal, False otherwise

    # Raises:
    - `TypeError`: if the classes are different and `except_when_class_mismatch` is `True`
    - `AttributeError`: if the classes are different, the field names are different,
        and `except_when_field_mismatch` is `True`

    # Decision procedure:
    1. `dc1 is dc2` -> `True`
    2. classes match -> compare field values (step 5)
    3. classes differ:
        - `except_when_class_mismatch` -> raise `TypeError`
        - field names differ and `except_when_field_mismatch` -> raise `AttributeError`
        - `false_when_class_mismatch`, or field names differ -> `False`
    4. otherwise (classes differ, field names match, `false_when_class_mismatch` is `False`)
        -> compare field values (step 5)
    5. the values of all fields of `dc1` declared with `compare=True` are compared
        using `array_safe_eq`; the result is `True` only if all of them match
    """
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            # if the classes don't match, raise an error
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
            )
        # with the default options a class mismatch is simply `False`; the fields are
        # only inspected when an option asks for it, so that `dataclasses.fields`
        # is never called on the arguments otherwise
        if not except_when_field_mismatch and false_when_class_mismatch:
            return False

        dc1_fields: set = {fld.name for fld in dataclasses.fields(dc1)}
        dc2_fields: set = {fld.name for fld in dataclasses.fields(dc2)}
        fields_match: bool = dc1_fields == dc2_fields
        if except_when_field_mismatch and not fields_match:
            # the field names differ and the caller asked for an error
            raise AttributeError(
                f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
            )
        if false_when_class_mismatch or not fields_match:
            return False
        # classes differ, field names match, and the caller opted into a
        # field-wise comparison: fall through to comparing the values

    return all(
        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))
        for fld in dataclasses.fields(dc1)
        if fld.compare
    )

BaseType = typing.Union[bool, int, float, str, NoneType]
JSONitem = typing.Union[bool, int, float, str, NoneType, typing.List[typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]], typing.Dict[str, typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]]]
JSONdict = typing.Dict[str, typing.Union[bool, int, float, str, NoneType, typing.List[typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]], typing.Dict[str, typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]]]]
Hashableitem = typing.Union[bool, int, float, str, tuple]
class UniversalContainer:
class UniversalContainer:
    """a container that reports every object as a member

    `x in UniversalContainer()` is `True` for any `x`
    """

    def __contains__(self, x: Any) -> bool:
        # the argument is ignored on purpose: everything counts as contained
        return True

contains everything -- x in UniversalContainer() is always True

def isinstance_namedtuple(x: Any) -> bool:
def isinstance_namedtuple(x: Any) -> bool:
    """tell whether `x` looks like an instance of a `namedtuple` class

    the type of `x` must have `tuple` as its one and only direct base class,
    and carry a `_fields` attribute that is a tuple made up solely of strings

    credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
    """
    cls: type = type(x)
    bases: tuple = cls.__bases__
    derives_only_from_tuple: bool = len(bases) == 1 and bases[0] is tuple
    if not derives_only_from_tuple:
        return False
    field_names: Any = getattr(cls, "_fields", None)
    return isinstance(field_names, tuple) and all(
        isinstance(name, str) for name in field_names
    )
def try_catch(func: Callable):
 99def try_catch(func: Callable):
100    """wraps the function to catch exceptions, returns serialized error message on exception
101
102    returned func will return normal result on success, or error message on exception
103    """
104
105    @functools.wraps(func)
106    def newfunc(*args, **kwargs):
107        try:
108            return func(*args, **kwargs)
109        except Exception as e:
110            return f"{e.__class__.__name__}: {e}"
111
112    return newfunc

wraps the function to catch exceptions, returns serialized error message on exception

returned func will return normal result on success, or error message on exception

class SerializationException(builtins.Exception):
class SerializationException(Exception):
    """plain `Exception` subclass with no added behavior

    presumably raised when serialization fails -- no raise site is visible in this module
    """

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
add_note
args
def string_as_lines(s: str | None) -> list[str]:
def string_as_lines(s: str | None) -> list[str]:
    """split a string into its lines, so that long strings are easier to read in json

    sort of like how jupyter notebooks do it

    `None` gives an empty list; line endings are not kept in the returned lines
    """
    return [] if s is None else s.splitlines()

for easier reading of long strings in json, split up by newlines

sort of like how jupyter notebooks do it

def safe_getsource(func) -> list[str]:
def safe_getsource(func) -> list[str]:
    """return the source code of `func` as a list of lines, without ever raising

    if the source cannot be retrieved, the lines of an error message
    (which includes the exception text) are returned instead
    """
    try:
        source_text = inspect.getsource(func)
        return string_as_lines(source_text)
    except Exception as err:
        failure_message = f"Error: Unable to retrieve source code:\n{err}"
        return string_as_lines(failure_message)
def array_safe_eq(a: Any, b: Any) -> bool:
def array_safe_eq(a: Any, b: Any) -> bool:
    """check if two objects are equal, account for if numpy arrays or torch tensors

    # Parameters:
    - `a`, `b`: the objects to compare

    # Returns:
    - `True` if `a is b`
    - `False` if the two objects are not of exactly the same type
    - numpy arrays / torch tensors: `True` only if the shapes are equal and all elements compare equal
    - pandas dataframes: the result of `a.equals(b)`
    - `str`, `bytes`, `bytearray`: the result of `a == b`
    - other sequences: same length and pairwise-equal items (compared recursively)
    - mappings: same length, same keys, and equal values per key (compared recursively),
        independent of key order
    - anything else: `bool(a == b)`; if that raises `TypeError` or `ValueError`, a warning
        is issued and `NotImplemented` is returned (note that this is not a `bool`)
    """
    if a is b:
        return True

    if type(a) is not type(b):
        return False

    # from here on both objects have exactly the same type, so checks on `a` hold for `b`
    type_name: str = str(type(a))

    if type_name in ("<class 'numpy.ndarray'>", "<class 'torch.Tensor'>"):
        # `==` broadcasts: arrays of different shapes could compare element-wise
        # equal (or raise), so differing shapes are reported as unequal up front
        if a.shape != b.shape:
            return False
        # `.all()` gives a numpy/torch scalar, convert it to a real `bool`
        return bool((a == b).all())

    if type_name == "<class 'pandas.core.frame.DataFrame'>":
        return a.equals(b)

    # strings must not reach the generic `Sequence` branch: the items of a string
    # are strings again, so the item-wise recursion would never terminate
    if isinstance(a, (str, bytes, bytearray)):
        return a == b

    if isinstance(a, typing.Sequence):
        return len(a) == len(b) and all(array_safe_eq(a1, b1) for a1, b1 in zip(a, b))

    if isinstance(a, typing.Mapping):
        # look values up by key rather than zipping the two key views, so that
        # mappings with the same items in a different order compare equal
        return len(a) == len(b) and all(
            key in b and array_safe_eq(value, b[key]) for key, value in a.items()
        )

    try:
        return bool(a == b)
    except (TypeError, ValueError) as e:
        warnings.warn(f"Cannot compare {a} and {b} for equality\n{e}")
        return NotImplemented  # type: ignore[return-value]

check if two objects are equal, account for if numpy arrays or torch tensors

def dc_eq( dc1, dc2, except_when_class_mismatch: bool = False, false_when_class_mismatch: bool = True, except_when_field_mismatch: bool = False) -> bool:
def dc_eq(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    checks if two dataclasses which (might) hold numpy arrays are equal

    # Parameters:

    - `dc1`: the first dataclass
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, will throw `TypeError` if the classes are different.
        if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only relevant if the classes are different and `except_when_class_mismatch` is `False`.
        if `True`, will return `False` if the classes are different.
        if `False`, will compare the field values, provided both objects have the same field names.
        (default: `True`)
    - `except_when_field_mismatch: bool`
        only relevant if the classes are different and `except_when_class_mismatch` is `False`.
        if `True`, will throw `AttributeError` if the field names are different.
        (default: `False`)

    # Returns:
    - `bool`: True if the dataclasses are equal, False otherwise

    # Raises:
    - `TypeError`: if the classes are different and `except_when_class_mismatch` is `True`
    - `AttributeError`: if the classes are different, the field names are different,
        and `except_when_field_mismatch` is `True`

    # Decision procedure:
    1. `dc1 is dc2` -> `True`
    2. classes match -> compare field values (step 5)
    3. classes differ:
        - `except_when_class_mismatch` -> raise `TypeError`
        - field names differ and `except_when_field_mismatch` -> raise `AttributeError`
        - `false_when_class_mismatch`, or field names differ -> `False`
    4. otherwise (classes differ, field names match, `false_when_class_mismatch` is `False`)
        -> compare field values (step 5)
    5. the values of all fields of `dc1` declared with `compare=True` are compared
        using `array_safe_eq`; the result is `True` only if all of them match
    """
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            # if the classes don't match, raise an error
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
            )
        # with the default options a class mismatch is simply `False`; the fields are
        # only inspected when an option asks for it, so that `dataclasses.fields`
        # is never called on the arguments otherwise
        if not except_when_field_mismatch and false_when_class_mismatch:
            return False

        dc1_fields: set = {fld.name for fld in dataclasses.fields(dc1)}
        dc2_fields: set = {fld.name for fld in dataclasses.fields(dc2)}
        fields_match: bool = dc1_fields == dc2_fields
        if except_when_field_mismatch and not fields_match:
            # the field names differ and the caller asked for an error
            raise AttributeError(
                f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
            )
        if false_when_class_mismatch or not fields_match:
            return False
        # classes differ, field names match, and the caller opted into a
        # field-wise comparison: fall through to comparing the values

    return all(
        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))
        for fld in dataclasses.fields(dc1)
        if fld.compare
    )

checks if two dataclasses which (might) hold numpy arrays are equal

Parameters:

  • dc1: the first dataclass
  • dc2: the second dataclass
  • except_when_class_mismatch: bool if True, will throw TypeError if the classes are different. if not, will return false by default or attempt to compare the fields if false_when_class_mismatch is False (default: False)
  • false_when_class_mismatch: bool only relevant if except_when_class_mismatch is False. if True, will return False if the classes are different. if False, will attempt to compare the fields.
  • except_when_field_mismatch: bool only relevant if the classes are different and except_when_class_mismatch is False. if True, will throw AttributeError if the field names are different. (default: False)

Returns:

  • bool: True if the dataclasses are equal, False otherwise

Raises:

  • TypeError: if the dataclasses are of different classes
  • AttributeError: if the dataclasses have different fields

TODO: after "except when class mismatch" is False, shouldn't we then go to "field keys match"?

          [START]
             ▼
       ┌───────────┐  ┌─────────┐
       │dc1 is dc2?├─►│ classes │
       └──┬────────┘No│ match?  │
  ────    │           ├─────────┤
 (True)◄──┘Yes        │No       │Yes
  ────                ▼         ▼
      ┌────────────────┐ ┌────────────┐
      │ except when    │ │ fields keys│
      │ class mismatch?│ │ match?     │
      ├───────────┬────┘ ├───────┬────┘
      │Yes        │No    │No     │Yes
      ▼           ▼      ▼       ▼
 ───────────  ┌──────────┐  ┌────────┐
{ raise     } │ except   │  │ field  │
{ TypeError } │ when     │  │ values │
 ───────────  │ field    │  │ match? │
              │ mismatch?│  ├────┬───┘
              ├───────┬──┘  │    │Yes
              │Yes    │No   │No  ▼
              ▼       ▼     │   ────
 ───────────────     ─────  │  (True)
{ raise         }   (False)◄┘   ────
{ AttributeError}    ─────
 ───────────────
class MonoTuple:
49    class MonoTuple:
50        """tuple type hint, but for a tuple of any length with all the same type"""
51
52        __slots__ = ()
53
54        def __new__(cls, *args, **kwargs):
55            raise TypeError("Type MonoTuple cannot be instantiated.")
56
57        def __init_subclass__(cls, *args, **kwargs):
58            raise TypeError(f"Cannot subclass {cls.__module__}")
59
60        # idk why mypy thinks there is no such function in typing
61        @typing._tp_cache  # type: ignore
62        def __class_getitem__(cls, params):
63            if getattr(params, "__origin__", None) == typing.Union:
64                return typing.GenericAlias(tuple, (params, Ellipsis))
65            elif isinstance(params, type):
66                typing.GenericAlias(tuple, (params, Ellipsis))
67            # test if has len and is iterable
68            elif isinstance(params, Iterable):
69                if len(params) == 0:
70                    return tuple
71                elif len(params) == 1:
72                    return typing.GenericAlias(tuple, (params[0], Ellipsis))
73            else:
74                raise TypeError(f"MonoTuple expects 1 type argument, got {params = }")

tuple type hint, but for a tuple of any length with all the same type