muutils.json_serialize.util
utilities for json_serialize
"""utilities for json_serialize"""

from __future__ import annotations

import dataclasses
import functools
import inspect
import sys
import typing
import warnings
from typing import Any, Callable, Iterable, Union

_NUMPY_WORKING: bool
try:
    # the import is what can actually fail here; without it the `except`
    # branch below is unreachable and the flag is always True
    import numpy  # noqa: F401

    _NUMPY_WORKING = True
except ImportError:
    warnings.warn("numpy not found, cannot serialize numpy arrays!")
    _NUMPY_WORKING = False


JSONitem = Union[bool, int, float, str, list, typing.Dict[str, Any], None]
JSONdict = typing.Dict[str, JSONitem]
Hashableitem = Union[bool, int, float, str, tuple]

# or if python version <3.9
if typing.TYPE_CHECKING or sys.version_info < (3, 9):
    MonoTuple = typing.Sequence
else:

    class MonoTuple:
        """tuple type hint, but for a tuple of any length with all the same type

        `MonoTuple[T]` evaluates to `tuple[T, ...]`. The class itself can be
        neither instantiated nor subclassed.
        """

        __slots__ = ()

        def __new__(cls, *args, **kwargs):
            raise TypeError("Type MonoTuple cannot be instantiated.")

        def __init_subclass__(cls, *args, **kwargs):
            raise TypeError(f"Cannot subclass {cls.__module__}")

        # idk why mypy thinks there is no such function in typing
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            if getattr(params, "__origin__", None) == typing.Union:
                return typing.GenericAlias(tuple, (params, Ellipsis))
            elif isinstance(params, type):
                # the alias must be returned, otherwise `MonoTuple[int]` is `None`
                return typing.GenericAlias(tuple, (params, Ellipsis))
            # test if has len and is iterable
            elif isinstance(params, Iterable):
                if len(params) == 0:
                    return tuple
                elif len(params) == 1:
                    return typing.GenericAlias(tuple, (params[0], Ellipsis))
            else:
                raise TypeError(f"MonoTuple expects 1 type argument, got {params = }")


class UniversalContainer:
    """contains everything -- `x in UniversalContainer()` is always True"""

    def __contains__(self, x: Any) -> bool:
        return True


def isinstance_namedtuple(x: Any) -> bool:
    """checks if `x` is a `namedtuple`

    only direct `namedtuple` classes are recognized: the type of `x` must have
    `tuple` as its single base and a `_fields` tuple made up of strings.

    credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
    """
    t: type = type(x)
    b: tuple = t.__bases__
    if len(b) != 1 or (b[0] is not tuple):
        return False
    f: Any = getattr(t, "_fields", None)
    if not isinstance(f, tuple):
        return False
    return all(isinstance(n, str) for n in f)


def try_catch(func: Callable):
    """wraps the function to catch exceptions, returns serialized error message on exception

    returned func will return normal result on success, or error message on exception
    """

    @functools.wraps(func)
    def newfunc(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            return f"{e.__class__.__name__}: {e}"

    return newfunc


def _recursive_hashify(obj: Any, force: bool = True) -> Hashableitem:
    """recursively convert `obj` into nested tuples of hashable scalars

    # Parameters:
     - `obj`: the object to convert
     - `force: bool`
       if `True`, objects that are neither scalars, mappings nor iterables are
       replaced by `str(obj)`; if `False`, a `ValueError` is raised for them.
       the flag applies at every nesting level.

    # Raises:
     - `ValueError`: if `force` is `False` and an unsupported object is found
    """
    # scalars must be handled first: `str` is itself iterable and yields
    # one-character strings, so treating it as an iterable never terminates
    if isinstance(obj, (bool, int, float, str)):
        return obj
    elif isinstance(obj, typing.Mapping):
        return tuple((k, _recursive_hashify(v, force=force)) for k, v in obj.items())
    elif isinstance(obj, Iterable):
        return tuple(_recursive_hashify(v, force=force) for v in obj)
    elif force:
        return str(obj)
    else:
        raise ValueError(f"cannot hashify:\n{obj}")


class SerializationException(Exception):
    pass


def string_as_lines(s: str | None) -> list[str]:
    """for easier reading of long strings in json, split up by newlines

    sort of like how jupyter notebooks do it
    """
    if s is None:
        return list()
    else:
        return s.splitlines(keepends=False)


def safe_getsource(func) -> list[str]:
    """return the source of `func` as a list of lines, or an error message (also as lines) on failure"""
    try:
        return string_as_lines(inspect.getsource(func))
    except Exception as e:
        return string_as_lines(f"Error: Unable to retrieve source code:\n{e}")


# credit to https://stackoverflow.com/questions/51743827/how-to-compare-equality-of-dataclasses-holding-numpy-ndarray-boola-b-raises
def array_safe_eq(a: Any, b: Any) -> bool:
    """check if two objects are equal, accounting for numpy arrays and torch tensors

    # Returns:
     - `True` if `a is b`
     - `False` if the types of `a` and `b` are not identical
     - for numpy arrays / torch tensors: `True` iff shapes match and all elements are equal
     - for pandas dataframes: the result of `a.equals(b)`
     - for strings and bytes: the result of `a == b`
     - for other sequences: element-wise recursive comparison
     - for mappings: same keys and recursively equal values, regardless of key order
     - otherwise `bool(a == b)`; if that raises `TypeError` or `ValueError`,
       a warning is emitted and `NotImplemented` is returned
    """
    if a is b:
        return True

    if type(a) is not type(b):
        return False

    # the types are identical from here on, so inspecting `a` is sufficient
    type_name: str = str(type(a))

    if type_name in ("<class 'numpy.ndarray'>", "<class 'torch.Tensor'>"):
        # `==` broadcasts, so shapes (3,) and (1,) could compare equal (or the
        # comparison could raise on incompatible shapes) without this check
        if a.shape != b.shape:
            return False
        return bool((a == b).all())

    if type_name == "<class 'pandas.core.frame.DataFrame'>":
        return a.equals(b)

    # strings are sequences of one-character strings; recursing into them
    # never terminates for unequal characters, so compare them directly
    if isinstance(a, (str, bytes)):
        return a == b

    if isinstance(a, typing.Sequence):
        return len(a) == len(b) and all(array_safe_eq(a1, b1) for a1, b1 in zip(a, b))

    if isinstance(a, typing.Mapping):
        # look values up by key instead of zipping the key views, so that
        # insertion order does not affect the result
        return len(a) == len(b) and all(
            k in b and array_safe_eq(v, b[k]) for k, v in a.items()
        )

    try:
        return bool(a == b)
    except (TypeError, ValueError) as e:
        warnings.warn(f"Cannot compare {a} and {b} for equality\n{e}")
        return NotImplemented  # type: ignore[return-value]


def dc_eq(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    checks if two dataclasses which (might) hold numpy arrays are equal

    # Parameters:

    - `dc1`: the first dataclass
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, will throw `TypeError` if the classes are different.
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only relevant if the classes differ and `except_when_class_mismatch` is `False`.
        if `True`, will return `False`.
        if `False`, will compare the field names and, if they match, the field values.
        (default: `True`)
    - `except_when_field_mismatch: bool`
        only relevant if the classes differ and `except_when_class_mismatch` is `False`.
        if `True`, will throw `AttributeError` if the field names are different.
        (default: `False`)

    # Returns:
    - `bool`: True if the dataclasses are equal, False otherwise

    # Raises:
    - `TypeError`: if the classes differ and `except_when_class_mismatch` is `True`
    - `AttributeError`: if the classes differ, the field names differ
        and `except_when_field_mismatch` is `True`

    # Decision order:
    1. `dc1 is dc2` -> `True`
    2. classes differ:
        1. `except_when_class_mismatch` -> raise `TypeError`
        2. field names differ -> raise `AttributeError` if `except_when_field_mismatch`,
           else `False` (names are only inspected when `except_when_field_mismatch`
           is `True` or `false_when_class_mismatch` is `False`)
        3. `false_when_class_mismatch` -> `False`
    3. compare the values of all fields of `dc1` that have `compare=True`
    """
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            # if the classes don't match, raise an error
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
            )
        # field names are only inspected when a flag asks for it, so the default
        # call still returns False without touching `dataclasses.fields`
        if except_when_field_mismatch or not false_when_class_mismatch:
            dc1_fields: set = {fld.name for fld in dataclasses.fields(dc1)}
            dc2_fields: set = {fld.name for fld in dataclasses.fields(dc2)}
            if dc1_fields != dc2_fields:
                if except_when_field_mismatch:
                    raise AttributeError(
                        f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
                    )
                return False
        if false_when_class_mismatch:
            return False
        # classes differ but field names match and the caller opted in:
        # fall through to the value comparison

    return all(
        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))
        for fld in dataclasses.fields(dc1)
        if fld.compare
    )
class UniversalContainer:
    """a container that reports every object as a member

    `x in UniversalContainer()` evaluates to `True` for any `x`
    """

    def __contains__(self, x: Any) -> bool:
        # membership is unconditional: the argument is never inspected
        return True
contains everything -- `x in UniversalContainer()` is always True
def isinstance_namedtuple(x: Any) -> bool:
    """tell whether `x` is an instance of a class created by `namedtuple`

    the check is structural: the type of `x` must have `tuple` as its one and
    only base class, and must carry a `_fields` attribute that is a tuple
    consisting solely of strings.

    credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
    """
    cls: type = type(x)
    bases: tuple = cls.__bases__
    derives_only_from_tuple: bool = len(bases) == 1 and bases[0] is tuple
    if not derives_only_from_tuple:
        return False

    field_names: Any = getattr(cls, "_fields", None)
    if isinstance(field_names, tuple):
        # every declared field name has to be a string
        for name in field_names:
            if not isinstance(name, str):
                return False
        return True
    return False
checks if `x` is a `namedtuple`
credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
def try_catch(func: Callable):
    """decorate `func` so that exceptions are reported as strings instead of propagating

    the returned callable forwards all arguments to `func`. if `func` returns
    normally, its result is passed through unchanged; if it raises any
    `Exception`, the string `"<ExceptionClassName>: <message>"` is returned.
    """

    @functools.wraps(func)
    def guarded(*args, **kwargs):
        try:
            outcome = func(*args, **kwargs)
        except Exception as e:
            # serialize the error rather than letting it escape
            return f"{e.__class__.__name__}: {e}"
        else:
            return outcome

    return guarded
wraps the function to catch exceptions, returns serialized error message on exception
returned func will return normal result on success, or error message on exception
Common base class for all non-exit exceptions.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- add_note
- args
def string_as_lines(s: str | None) -> list[str]:
    """split a string into its lines so that long strings read better in json

    line terminators are dropped from the resulting items, and `None` yields
    an empty list (similar to how jupyter notebooks store multi-line text).
    """
    return [] if s is None else s.splitlines(keepends=False)
for easier reading of long strings in json, split up by newlines
sort of like how jupyter notebooks do it
def array_safe_eq(a: Any, b: Any) -> bool:
    """check if two objects are equal, accounting for numpy arrays and torch tensors

    # Returns:
     - `True` if `a is b`
     - `False` if the types of `a` and `b` are not identical
     - for numpy arrays / torch tensors: `True` iff shapes match and all elements are equal
     - for pandas dataframes: the result of `a.equals(b)`
     - for strings and bytes: the result of `a == b`
     - for other sequences: element-wise recursive comparison
     - for mappings: same keys and recursively equal values, regardless of key order
     - otherwise `bool(a == b)`; if that raises `TypeError` or `ValueError`,
       a warning is emitted and `NotImplemented` is returned
    """
    if a is b:
        return True

    if type(a) is not type(b):
        return False

    # the types are identical from here on, so inspecting `a` is sufficient
    type_name: str = str(type(a))

    if type_name in ("<class 'numpy.ndarray'>", "<class 'torch.Tensor'>"):
        # `==` broadcasts, so shapes (3,) and (1,) could compare equal (or the
        # comparison could raise on incompatible shapes) without this check
        if a.shape != b.shape:
            return False
        return bool((a == b).all())

    if type_name == "<class 'pandas.core.frame.DataFrame'>":
        return a.equals(b)

    # strings are sequences of one-character strings; recursing into them
    # never terminates for unequal characters, so compare them directly
    if isinstance(a, (str, bytes)):
        return a == b

    if isinstance(a, typing.Sequence):
        return len(a) == len(b) and all(array_safe_eq(a1, b1) for a1, b1 in zip(a, b))

    if isinstance(a, typing.Mapping):
        # look values up by key instead of zipping the key views, so that
        # insertion order does not affect the result
        return len(a) == len(b) and all(
            k in b and array_safe_eq(v, b[k]) for k, v in a.items()
        )

    try:
        return bool(a == b)
    except (TypeError, ValueError) as e:
        warnings.warn(f"Cannot compare {a} and {b} for equality\n{e}")
        return NotImplemented  # type: ignore[return-value]
check if two objects are equal, accounting for numpy arrays and torch tensors
def dc_eq(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    checks if two dataclasses which (might) hold numpy arrays are equal

    # Parameters:

    - `dc1`: the first dataclass
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, will throw `TypeError` if the classes are different.
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only relevant if the classes differ and `except_when_class_mismatch` is `False`.
        if `True`, will return `False`.
        if `False`, will compare the field names and, if they match, the field values.
        (default: `True`)
    - `except_when_field_mismatch: bool`
        only relevant if the classes differ and `except_when_class_mismatch` is `False`.
        if `True`, will throw `AttributeError` if the field names are different.
        (default: `False`)

    # Returns:
    - `bool`: True if the dataclasses are equal, False otherwise

    # Raises:
    - `TypeError`: if the classes differ and `except_when_class_mismatch` is `True`
    - `AttributeError`: if the classes differ, the field names differ
        and `except_when_field_mismatch` is `True`

    # Decision order:
    1. `dc1 is dc2` -> `True`
    2. classes differ:
        1. `except_when_class_mismatch` -> raise `TypeError`
        2. field names differ -> raise `AttributeError` if `except_when_field_mismatch`,
           else `False` (names are only inspected when `except_when_field_mismatch`
           is `True` or `false_when_class_mismatch` is `False`)
        3. `false_when_class_mismatch` -> `False`
    3. compare the values of all fields of `dc1` that have `compare=True`
    """
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            # if the classes don't match, raise an error
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
            )
        # field names are only inspected when a flag asks for it, so the default
        # call still returns False without touching `dataclasses.fields`
        if except_when_field_mismatch or not false_when_class_mismatch:
            dc1_fields: set = {fld.name for fld in dataclasses.fields(dc1)}
            dc2_fields: set = {fld.name for fld in dataclasses.fields(dc2)}
            if dc1_fields != dc2_fields:
                if except_when_field_mismatch:
                    raise AttributeError(
                        f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
                    )
                return False
        if false_when_class_mismatch:
            return False
        # classes differ but field names match and the caller opted in:
        # fall through to the value comparison

    return all(
        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))
        for fld in dataclasses.fields(dc1)
        if fld.compare
    )
checks if two dataclasses which (might) hold numpy arrays are equal
Parameters:
- `dc1`: the first dataclass
- `dc2`: the second dataclass
- `except_when_class_mismatch: bool`
  if `True`, will throw `TypeError` if the classes are different. if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False` (default: `False`)
- `false_when_class_mismatch: bool`
  only relevant if `except_when_class_mismatch` is `False`. if `True`, will return `False` if the classes are different. if `False`, will attempt to compare the fields. (default: `True`)
- `except_when_field_mismatch: bool`
  only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`. if `True`, will throw `AttributeError` if the fields are different. (default: `False`)
Returns:
- `bool`: True if the dataclasses are equal, False otherwise
Raises:
- `TypeError`: if the dataclasses are of different classes
- `AttributeError`: if the dataclasses have different fields
TODO: after "except when class mismatch" is False, shouldn't we then go to "field keys match"?
              [START]
                 ▼
           ┌───────────┐  ┌─────────┐
           │dc1 is dc2?├─►│ classes │
           └──┬────────┘No│ match?  │
      ────    │           ├─────────┤
     (True)◄──┘Yes        │No       │Yes
      ────                ▼         ▼
          ┌────────────────┐ ┌────────────┐
          │ except when    │ │ fields keys│
          │ class mismatch?│ │ match?     │
          ├───────────┬────┘ ├───────┬────┘
          │Yes        │No    │No     │Yes
          ▼           ▼      ▼       ▼
     ───────────  ┌──────────┐  ┌────────┐
    { raise     } │ except   │  │ field  │
    { TypeError } │ when     │  │ values │
     ───────────  │ field    │  │ match? │
                  │ mismatch?│  ├────┬───┘
                  ├───────┬──┘  │    │Yes
                  │Yes    │No   │No  ▼
                  ▼       ▼     │   ────
     ───────────────     ─────  │  (True)
    { raise         }   (False)◄┘   ────
    { AttributeError}    ─────
     ───────────────
class MonoTuple:
    """tuple type hint, but for a tuple of any length with all the same type

    `MonoTuple[T]` evaluates to `tuple[T, ...]`. Accepted subscripts:

    - a `typing.Union[...]`: gives `tuple[Union[...], ...]`
    - a single type `T`: gives `tuple[T, ...]`
    - an empty iterable: gives plain `tuple`
    - an iterable with exactly one item `T`: gives `tuple[T, ...]`

    The class itself can be neither instantiated nor subclassed; both raise
    `TypeError`.
    """

    __slots__ = ()

    def __new__(cls, *args, **kwargs):
        raise TypeError("Type MonoTuple cannot be instantiated.")

    def __init_subclass__(cls, *args, **kwargs):
        raise TypeError(f"Cannot subclass {cls.__module__}")

    # idk why mypy thinks there is no such function in typing
    @typing._tp_cache  # type: ignore
    def __class_getitem__(cls, params):
        if getattr(params, "__origin__", None) == typing.Union:
            return typing.GenericAlias(tuple, (params, Ellipsis))
        elif isinstance(params, type):
            # the alias must be returned, otherwise `MonoTuple[int]` is `None`
            return typing.GenericAlias(tuple, (params, Ellipsis))
        # test if has len and is iterable
        elif isinstance(params, Iterable):
            if len(params) == 0:
                return tuple
            elif len(params) == 1:
                return typing.GenericAlias(tuple, (params[0], Ellipsis))
        else:
            raise TypeError(f"MonoTuple expects 1 type argument, got {params = }")
tuple type hint, but for a tuple of any length with all the same type