muutils.json_serialize
submodule for serializing things to json in a recoverable way
you can throw *any* object into `muutils.json_serialize.json_serialize` and it will return a `JSONitem`, meaning a bool, int, float, str, None, list of `JSONitem`s, or a dict mapping to `JSONitem`.
The goal of this is that if you just want to store something as relatively human-readable JSON, and don't care as much about recovering it, you can throw it into `json_serialize` and it will just work. If you want to do so in a recoverable way, check out [`ZANJ`](https://github.com/mivanit/ZANJ).
it will do so by looking in `DEFAULT_HANDLERS`, which will keep the object as-is if it's already valid JSON, then try to find a `.serialize()` method on the object, and then fall back to a number of special cases. You can add your own handlers by initializing a `JsonSerializer` object and passing a sequence of them as `handlers_pre`.
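For instance, a handler that serializes `complex` numbers might look roughly like the sketch below. This assumes `SerializerHandler` can be constructed with `check`, `serialize_func`, and `uid` keyword arguments; those names match how `JsonSerializer.json_serialize` calls handlers in the source further down, but check the `SerializerHandler` definition in `muutils.json_serialize.json_serialize` for the exact fields.

```python
from muutils.json_serialize.json_serialize import JsonSerializer, SerializerHandler

# hypothetical handler: turn complex numbers into {"re": ..., "im": ...}
complex_handler = SerializerHandler(
    check=lambda self, obj, path: isinstance(obj, complex),
    serialize_func=lambda self, obj, path: {"re": obj.real, "im": obj.imag},
    uid="complex_to_dict",  # identifier reported in serialization error messages
)

serializer = JsonSerializer(handlers_pre=(complex_handler,))
print(serializer.json_serialize({"z": 1 + 2j}))
# {'z': {'re': 1.0, 'im': 2.0}}
```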
additionally, `SerializableDataclass` is a special kind of dataclass where you specify how to serialize each field, and a `.serialize()` method is automatically added to the class. This is done by using the `serializable_dataclass` decorator, inheriting from `SerializableDataclass`, and using `serializable_field` in place of `dataclasses.field` when defining non-standard fields.
This module plays nicely with and is a dependency of the [`ZANJ`](https://github.com/mivanit/ZANJ) library, which extends this to support saving things to disk in a more efficient way than just plain JSON (arrays are saved as `.npy` files, for example), and automatically detects how to load saved objects back into their original classes.
1"""submodule for serializing things to json in a recoverable way 2 3you can throw *any* object into `muutils.json_serialize.json_serialize` 4and it will return a `JSONitem`, meaning a bool, int, float, str, None, list of `JSONitem`s, or a dict mappting to `JSONitem`. 5 6The goal of this is if you want to just be able to store something as relatively human-readable JSON, and don't care as much about recovering it, you can throw it into `json_serialize` and it will just work. If you want to do so in a recoverable way, check out [`ZANJ`](https://github.com/mivanit/ZANJ). 7 8it will do so by looking in `DEFAULT_HANDLERS`, which will keep it as-is if its already valid, then try to find a `.serialize()` method on the object, and then have a bunch of special cases. You can add handlers by initializing a `JsonSerializer` object and passing a sequence of them to `handlers_pre` 9 10additionally, `SerializeableDataclass` is a special kind of dataclass where you specify how to serialize each field, and a `.serialize()` method is automatically added to the class. This is done by using the `serializable_dataclass` decorator, inheriting from `SerializeableDataclass`, and `serializable_field` in place of `dataclasses.field` when defining non-standard fields. 11 12This module plays nicely with and is a dependency of the [`ZANJ`](https://github.com/mivanit/ZANJ) library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes. 13 14""" 15 16from __future__ import annotations 17 18from muutils.json_serialize.array import arr_metadata, load_array 19from muutils.json_serialize.json_serialize import ( 20 BASE_HANDLERS, 21 JsonSerializer, 22 json_serialize, 23) 24from muutils.json_serialize.serializable_dataclass import ( 25 SerializableDataclass, 26 serializable_dataclass, 27 serializable_field, 28) 29from muutils.json_serialize.util import try_catch, JSONitem, dc_eq 30 31__all__ = [ 32 # submodules 33 "array", 34 "json_serialize", 35 "serializable_dataclass", 36 "serializable_field", 37 "util", 38 # imports 39 "arr_metadata", 40 "load_array", 41 "BASE_HANDLERS", 42 "JSONitem", 43 "JsonSerializer", 44 "json_serialize", 45 "try_catch", 46 "JSONitem", 47 "dc_eq", 48 "serializable_dataclass", 49 "serializable_field", 50 "SerializableDataclass", 51]
```python
def json_serialize(obj: Any, path: ObjectPath = tuple()) -> JSONitem:
    """serialize object to json-serializable object with default config"""
    return GLOBAL_JSON_SERIALIZER.json_serialize(obj, path=path)
```
serialize object to json-serializable object with default config
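For json-compatible builtins, the default handlers pass values through unchanged, so the simplest usage looks like this (a minimal sketch; output for more exotic objects depends on `DEFAULT_HANDLERS`):

```python
from muutils.json_serialize import json_serialize

print(json_serialize({"a": [1, 2.5, None], "b": "hello"}))
# {'a': [1, 2.5, None], 'b': 'hello'}
```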
````python
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    **kwargs,
):
    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:
    - `_cls : _type_`
        class to decorate. don't pass this arg, just use this as a decorator
        (defaults to `None`)
    - `init : bool`
        (defaults to `True`)
    - `repr : bool`
        (defaults to `True`)
    - `eq : bool`
        (defaults to `True`)
    - `order : bool`
        (defaults to `False`)
    - `unsafe_hash : bool`
        (defaults to `False`)
    - `frozen : bool`
        (defaults to `False`)
    - `properties_to_serialize : Optional[list[str]]`
        **SerializableDataclass only:** which properties to add to the serialized data dict
        (defaults to `None`)
    - `register_handler : bool`
        **SerializableDataclass only:** if true, register the class with ZANJ for loading
        (defaults to `True`)
    - `on_typecheck_error : ErrorMode`
        **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
    - `on_typecheck_mismatch : ErrorMode`
        **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`

    # Returns:
    - `_type_`
        the decorated class

    # Raises:
    - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
    - `NotSerializableFieldException` : if a field is not a `SerializableField`
    - `FieldSerializationError` : if there is an error serializing a field
    - `AttributeError` : if a property is not found on the class
    - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check: `kw_only` requires python >=3.10, and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError("kw_only is not supported in python <3.10")
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f" but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> Type[T]:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: cls = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        # type is `Callable[[T], dict]`
        cls.serialize = serialize  # type: ignore[attr-defined]
        # type is `Callable[[dict], T]`
        cls.load = load  # type: ignore[attr-defined]
        # type is `Callable[[T, ErrorMode], bool]`
        cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        # type is `Callable[[T, T], bool]`
        if not hasattr(cls, "__eq__"):
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    if _cls is None:
        return wrap
    else:
        return wrap(_cls)
````
decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`

types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

Examines PEP 526 `__annotations__` to determine fields.

If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method is added. If frozen is true, fields may not be assigned to after instance creation.

```python
@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str
```
```python
>>> Myclass(a=1, b="q").serialize()
{'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
```
# Parameters:

- `_cls : _type_`
    class to decorate. don't pass this arg, just use this as a decorator
    (defaults to `None`)
- `init : bool`
    (defaults to `True`)
- `repr : bool`
    (defaults to `True`)
- `eq : bool`
    (defaults to `True`)
- `order : bool`
    (defaults to `False`)
- `unsafe_hash : bool`
    (defaults to `False`)
- `frozen : bool`
    (defaults to `False`)
- `properties_to_serialize : Optional[list[str]]`
    **SerializableDataclass only:** which properties to add to the serialized data dict
    (defaults to `None`)
- `register_handler : bool`
    **SerializableDataclass only:** if true, register the class with ZANJ for loading
    (defaults to `True`)
- `on_typecheck_error : ErrorMode`
    **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
- `on_typecheck_mismatch : ErrorMode`
    **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`

# Returns:

- `_type_`
    the decorated class

# Raises:

- `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
- `NotSerializableFieldException` : if a field is not a `SerializableField`
- `FieldSerializationError` : if there is an error serializing a field
- `AttributeError` : if a property is not found on the class
- `FieldLoadingError` : if there is an error loading a field
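Because `load` validates field types by default (unless told to ignore), mismatched data is caught at load time. A minimal sketch, with the error mode passed explicitly since the default `_DEFAULT_ON_TYPECHECK_MISMATCH` is not shown here; the `ErrorMode` import path is assumed:

```python
from muutils.errormode import ErrorMode  # assumed import path
from muutils.json_serialize import SerializableDataclass, serializable_dataclass


@serializable_dataclass(on_typecheck_mismatch=ErrorMode.EXCEPT)
class Point(SerializableDataclass):
    x: int
    y: int


Point.load({"x": 1, "y": 2})       # ok
Point.load({"x": 1, "y": "oops"})  # raises FieldTypeMismatchError
```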
````python
def serializable_field(
    *_args,
    default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    default_factory: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: Optional[types.MappingProxyType] = None,
    kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    **kwargs: Any,
) -> Any:
    """Create a new `SerializableField`

    ```
    default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    default_factory: Callable[[], Sfield_T]
    | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: types.MappingProxyType | None = None,
    kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    # ----------------------------------------------------------------------
    # new in `SerializableField`, not in `dataclasses.Field`
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    loading_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    ```

    # new Parameters:
    - `serialize`: whether to serialize this field when serializing the class
    - `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize`
    - `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
    - `deserialize_fn`: new alternative to `loading_fn`. takes only the field's value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised.
    - `assert_type`: whether to assert the type of the field when loading. if `False`, will not check the type of the field.
    - `custom_typecheck_fn`: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

    # Gotchas:
    - `loading_fn` takes the dict of the **class**, not the field. if you wanted a `loading_fn` that does nothing, you'd write:

    ```python
    class MyClass:
        my_field: int = serializable_field(
            serialization_fn=lambda x: str(x),
            loading_fn=lambda x: int(x["my_field"]),
        )
    ```

    using `deserialize_fn` instead:

    ```python
    class MyClass:
        my_field: int = serializable_field(
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: int(x),
        )
    ```

    In the above code, `my_field` is an int but will be serialized as a string.

    note that if not using ZANJ, and you have a class inside a container, you MUST provide
    `serialization_fn` and `loading_fn` to serialize and load the container.
    ZANJ will automatically do this for you.

    # TODO: `custom_value_check_fn`: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test
    """
    assert len(_args) == 0, f"unexpected positional arguments: {_args}"
    return SerializableField(
        default=default,
        default_factory=default_factory,
        init=init,
        repr=repr,
        hash=hash,
        compare=compare,
        metadata=metadata,
        kw_only=kw_only,
        serialize=serialize,
        serialization_fn=serialization_fn,
        deserialize_fn=deserialize_fn,
        assert_type=assert_type,
        custom_typecheck_fn=custom_typecheck_fn,
        **kwargs,
    )
````
Create a new `SerializableField`

```
default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
default_factory: Callable[[], Sfield_T]
| dataclasses._MISSING_TYPE = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
metadata: types.MappingProxyType | None = None,
kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
# ----------------------------------------------------------------------
# new in `SerializableField`, not in `dataclasses.Field`
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
loading_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
```

# new Parameters:

- `serialize`: whether to serialize this field when serializing the class
- `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize`
- `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
- `deserialize_fn`: new alternative to `loading_fn`. takes only the field's value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised.
- `assert_type`: whether to assert the type of the field when loading. if `False`, will not check the type of the field.
- `custom_typecheck_fn`: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

# Gotchas:

- `loading_fn` takes the dict of the **class**, not the field. if you wanted a `loading_fn` that does nothing, you'd write:

```python
class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        loading_fn=lambda x: int(x["my_field"]),
    )
```

using `deserialize_fn` instead:

```python
class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: int(x),
    )
```

In the above code, `my_field` is an int but will be serialized as a string.

note that if not using ZANJ, and you have a class inside a container, you MUST provide `serialization_fn` and `loading_fn` to serialize and load the container. ZANJ will automatically do this for you.

# TODO: `custom_value_check_fn`: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test
```python
def arr_metadata(arr) -> dict[str, list[int] | str | int]:
    """get metadata for a numpy array"""
    return {
        "shape": list(arr.shape),
        "dtype": (
            arr.dtype.__name__ if hasattr(arr.dtype, "__name__") else str(arr.dtype)
        ),
        "n_elements": array_n_elements(arr),
    }
```
get metadata for a numpy array
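For example (a small sketch; numpy `dtype` objects don't define `__name__`, so the `str(arr.dtype)` branch is the one that typically fires, and `array_n_elements` is assumed to return the total element count):

```python
import numpy as np

from muutils.json_serialize import arr_metadata

print(arr_metadata(np.zeros((2, 3), dtype=np.float32)))
# {'shape': [2, 3], 'dtype': 'float32', 'n_elements': 6}
```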
```python
def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any:
    """load a json-serialized array, infer the mode if not specified"""
    # return arr if its already a numpy array
    if isinstance(arr, np.ndarray) and array_mode is None:
        return arr

    # try to infer the array_mode
    array_mode_inferred: ArrayMode = infer_array_mode(arr)
    if array_mode is None:
        array_mode = array_mode_inferred
    elif array_mode != array_mode_inferred:
        warnings.warn(
            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
        )

    # actually load the array
    if array_mode == "array_list_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        data = np.array(arr["data"], dtype=arr["dtype"])
        if tuple(arr["shape"]) != tuple(data.shape):
            raise ValueError(f"invalid shape: {arr}")
        return data

    elif array_mode == "array_hex_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])
        return data.reshape(arr["shape"])

    elif array_mode == "array_b64_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])
        return data.reshape(arr["shape"])

    elif array_mode == "list":
        assert isinstance(
            arr, typing.Sequence
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        return np.array(arr)
    elif array_mode == "external":
        # assume ZANJ has taken care of it
        assert isinstance(arr, typing.Mapping)
        if "data" not in arr:
            raise KeyError(
                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
            )
        return arr["data"]
    elif array_mode == "zero_dim":
        assert isinstance(arr, typing.Mapping)
        data = np.array(arr["data"])
        if tuple(arr["shape"]) != tuple(data.shape):
            raise ValueError(f"invalid shape: {arr}")
        return data
    else:
        raise ValueError(f"invalid array_mode: {array_mode}")
```
load a json-serialized array, infer the mode if not specified
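A minimal usage sketch, passing `array_mode` explicitly to stay on a known branch (when it is omitted, `infer_array_mode` picks the mode from the payload; a warning fires if the inferred mode disagrees with the one you pass):

```python
from muutils.json_serialize import load_array

# a plain nested list loads via the "list" branch as a numpy array
arr = load_array([[1, 2], [3, 4]], array_mode="list")
print(arr.shape)  # (2, 2); the dtype is whatever numpy infers
```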
```python
class JsonSerializer:
    """Json serialization class (holds configs)

    # Parameters:
    - `array_mode : ArrayMode`
        how to write arrays
        (defaults to `"array_list_meta"`)
    - `error_mode : ErrorMode`
        what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
        (defaults to `"except"`)
    - `handlers_pre : MonoTuple[SerializerHandler]`
        handlers to use before the default handlers
        (defaults to `tuple()`)
    - `handlers_default : MonoTuple[SerializerHandler]`
        default handlers to use
        (defaults to `DEFAULT_HANDLERS`)
    - `write_only_format : bool`
        changes "__format__" keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading)
        (defaults to `False`)

    # Raises:
    - `ValueError`: on init, if `args` is not empty
    - `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`

    """

    def __init__(
        self,
        *args,
        array_mode: ArrayMode = "array_list_meta",
        error_mode: ErrorMode = ErrorMode.EXCEPT,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
        write_only_format: bool = False,
    ):
        if len(args) > 0:
            raise ValueError(
                f"JsonSerializer takes no positional arguments!\n{args = }"
            )

        self.array_mode: ArrayMode = array_mode
        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
        self.write_only_format: bool = write_only_format
        # join up the handlers
        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
            handlers_default
        )

    def json_serialize(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
    ) -> JSONitem:
        try:
            for handler in self.handlers:
                if handler.check(self, obj, path):
                    output: JSONitem = handler.serialize_func(self, obj, path)
                    if self.write_only_format:
                        if isinstance(output, dict) and "__format__" in output:
                            new_fmt: JSONitem = output.pop("__format__")
                            output["__write_format__"] = new_fmt
                    return output

            raise ValueError(f"no handler found for object with {type(obj) = }")

        except Exception as e:
            if self.error_mode == "except":
                obj_str: str = repr(obj)
                if len(obj_str) > 1000:
                    obj_str = obj_str[:1000] + "..."
                raise SerializationException(
                    f"error serializing at {path = } with last handler: '{handler.uid}'\nfrom: {e}\nobj: {obj_str}"
                ) from e
            elif self.error_mode == "warn":
                warnings.warn(
                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
                )

            return repr(obj)

    def hashify(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
        force: bool = True,
    ) -> Hashableitem:
        """try to turn any object into something hashable"""
        data = self.json_serialize(obj, path=path)

        # recursive hashify, turning dicts and lists into tuples
        return _recursive_hashify(data, force=force)
```
Json serialization class (holds configs)
# Parameters:

- `array_mode : ArrayMode`
    how to write arrays
    (defaults to `"array_list_meta"`)
- `error_mode : ErrorMode`
    what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
    (defaults to `"except"`)
- `handlers_pre : MonoTuple[SerializerHandler]`
    handlers to use before the default handlers
    (defaults to `tuple()`)
- `handlers_default : MonoTuple[SerializerHandler]`
    default handlers to use
    (defaults to `DEFAULT_HANDLERS`)
- `write_only_format : bool`
    changes "__format__" keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading)
    (defaults to `False`)

# Raises:

- `ValueError`: on init, if `args` is not empty
- `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`
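For instance (a minimal sketch; `error_mode` is normalized via `ErrorMode.from_any`, so passing a plain string like `"warn"` is assumed to work):

```python
import numpy as np

from muutils.json_serialize import JsonSerializer

serializer = JsonSerializer(
    array_mode="array_b64_meta",  # store arrays as base64 payloads with metadata
    error_mode="warn",            # on failure, warn and fall back to repr() instead of raising
)
out = serializer.json_serialize({"weights": np.arange(4)})
```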
```python
def __init__(
    self,
    *args,
    array_mode: ArrayMode = "array_list_meta",
    error_mode: ErrorMode = ErrorMode.EXCEPT,
    handlers_pre: MonoTuple[SerializerHandler] = tuple(),
    handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
    write_only_format: bool = False,
):
    if len(args) > 0:
        raise ValueError(
            f"JsonSerializer takes no positional arguments!\n{args = }"
        )

    self.array_mode: ArrayMode = array_mode
    self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
    self.write_only_format: bool = write_only_format
    # join up the handlers
    self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
        handlers_default
    )
```
```python
def json_serialize(
    self,
    obj: Any,
    path: ObjectPath = tuple(),
) -> JSONitem:
    try:
        for handler in self.handlers:
            if handler.check(self, obj, path):
                output: JSONitem = handler.serialize_func(self, obj, path)
                if self.write_only_format:
                    if isinstance(output, dict) and "__format__" in output:
                        new_fmt: JSONitem = output.pop("__format__")
                        output["__write_format__"] = new_fmt
                return output

        raise ValueError(f"no handler found for object with {type(obj) = }")

    except Exception as e:
        if self.error_mode == "except":
            obj_str: str = repr(obj)
            if len(obj_str) > 1000:
                obj_str = obj_str[:1000] + "..."
            raise SerializationException(
                f"error serializing at {path = } with last handler: '{handler.uid}'\nfrom: {e}\nobj: {obj_str}"
            ) from e
        elif self.error_mode == "warn":
            warnings.warn(
                f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
            )

        return repr(obj)
```
```python
def hashify(
    self,
    obj: Any,
    path: ObjectPath = tuple(),
    force: bool = True,
) -> Hashableitem:
    """try to turn any object into something hashable"""
    data = self.json_serialize(obj, path=path)

    # recursive hashify, turning dicts and lists into tuples
    return _recursive_hashify(data, force=force)
```
try to turn any object into something hashable
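Since the serialized output is converted into nested tuples, the result can be hashed, e.g. used as a dict key (a sketch; the exact tuple layout comes from `_recursive_hashify`):

```python
from muutils.json_serialize import JsonSerializer

serializer = JsonSerializer()
h = serializer.hashify({"a": [1, 2], "b": "x"})
print(hash(h))  # works: dicts and lists have been turned into nested tuples
```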
```python
def try_catch(func: Callable):
    """wraps the function to catch exceptions, returns serialized error message on exception

    returned func will return normal result on success, or error message on exception
    """

    @functools.wraps(func)
    def newfunc(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            return f"{e.__class__.__name__}: {e}"

    return newfunc
```
wraps the function to catch exceptions, returns serialized error message on exception
returned func will return normal result on success, or error message on exception
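For example:

```python
from muutils.json_serialize import try_catch


@try_catch
def risky(x: int) -> float:
    return 1 / x


print(risky(2))  # 0.5
print(risky(0))  # 'ZeroDivisionError: division by zero'
```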
````python
def dc_eq(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    checks if two dataclasses which (might) hold numpy arrays are equal

    # Parameters:

    - `dc1`: the first dataclass
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, will throw `TypeError` if the classes are different.
        if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False`.
        if `True`, will return `False` if the classes are different.
        if `False`, will attempt to compare the fields.
    - `except_when_field_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
        if `True`, will throw `AttributeError` if the fields are different.
        (default: `False`)

    # Returns:
    - `bool`: True if the dataclasses are equal, False otherwise

    # Raises:
    - `TypeError`: if the dataclasses are of different classes
    - `AttributeError`: if the dataclasses have different fields

    # TODO: after "except when class mismatch" is False, shouldn't we then go to "field keys match"?
    ```
      [START]
         ▼
      ┌───────────┐  ┌─────────┐
      │dc1 is dc2?├─►│ classes │
      └──┬────────┘No│ match?  │
     ──── │          ├─────────┤
    (True)◄┘Yes      │No       │Yes
     ────            ▼         ▼
         ┌────────────────┐ ┌────────────┐
         │ except when    │ │ fields keys│
         │ class mismatch?│ │ match?     │
         ├───────────┬────┘ ├───────┬────┘
         │Yes        │No    │No     │Yes
         ▼           ▼      ▼       ▼
     ───────────  ┌──────────┐  ┌────────┐
    { raise     } │ except   │  │ field  │
    { TypeError } │ when     │  │ values │
     ───────────  │ field    │  │ match? │
                  │ mismatch?│  ├────┬───┘
                  ├───────┬──┘  │    │Yes
                  │Yes    │No   │No  ▼
                  ▼       ▼     │   ────
     ───────────────     ─────  │  (True)
    { raise         }   (False)◄┘   ────
    { AttributeError}    ─────
     ───────────────
    ```

    """
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            # if the classes don't match, raise an error
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
            )
        if except_when_field_mismatch:
            dc1_fields: set = set([fld.name for fld in dataclasses.fields(dc1)])
            dc2_fields: set = set([fld.name for fld in dataclasses.fields(dc2)])
            fields_match: bool = set(dc1_fields) == set(dc2_fields)
            if not fields_match:
                # if the fields don't match, raise an error
                raise AttributeError(
                    f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
                )
        return False

    return all(
        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))
        for fld in dataclasses.fields(dc1)
        if fld.compare
    )
````
checks if two dataclasses which (might) hold numpy arrays are equal
# Parameters:

- `dc1`: the first dataclass
- `dc2`: the second dataclass
- `except_when_class_mismatch: bool`
    if `True`, will throw `TypeError` if the classes are different.
    if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
    (default: `False`)
- `false_when_class_mismatch: bool`
    only relevant if `except_when_class_mismatch` is `False`.
    if `True`, will return `False` if the classes are different.
    if `False`, will attempt to compare the fields.
- `except_when_field_mismatch: bool`
    only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
    if `True`, will throw `AttributeError` if the fields are different.
    (default: `False`)

# Returns:

- `bool`: True if the dataclasses are equal, False otherwise

# Raises:

- `TypeError`: if the dataclasses are of different classes
- `AttributeError`: if the dataclasses have different fields
# TODO: after "except when class mismatch" is False, shouldn't we then go to "field keys match"?

```
  [START]
     ▼
  ┌───────────┐  ┌─────────┐
  │dc1 is dc2?├─►│ classes │
  └──┬────────┘No│ match?  │
 ──── │          ├─────────┤
(True)◄┘Yes      │No       │Yes
 ────            ▼         ▼
     ┌────────────────┐ ┌────────────┐
     │ except when    │ │ fields keys│
     │ class mismatch?│ │ match?     │
     ├───────────┬────┘ ├───────┬────┘
     │Yes        │No    │No     │Yes
     ▼           ▼      ▼       ▼
 ───────────  ┌──────────┐  ┌────────┐
{ raise     } │ except   │  │ field  │
{ TypeError } │ when     │  │ values │
 ───────────  │ field    │  │ match? │
              │ mismatch?│  ├────┬───┘
              ├───────┬──┘  │    │Yes
              │Yes    │No   │No  ▼
              ▼       ▼     │   ────
 ───────────────     ─────  │  (True)
{ raise         }   (False)◄┘   ────
{ AttributeError}    ─────
 ───────────────
```
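To make the behavior concrete, a short sketch of the happy path (plain `dataclasses` here; `array_safe_eq` is what lets the numpy field compare cleanly):

```python
import dataclasses

import numpy as np

from muutils.json_serialize import dc_eq


@dataclasses.dataclass
class Weights:
    name: str
    values: np.ndarray


a = Weights("w", np.array([1.0, 2.0]))
b = Weights("w", np.array([1.0, 2.0]))

# plain `a == b` would raise: dataclass __eq__ compares field tuples, and
# numpy's elementwise == yields an array with an ambiguous truth value
print(dc_eq(a, b))  # True
```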
````python
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

    >>> my_obj = MyClass(a=1, b="q")
    >>> s = json.dumps(my_obj.serialize())
    >>> s
    '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
    >>> read_obj = MyClass.load(json.loads(s))
    >>> read_obj == my_obj
    True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

    >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
    >>> s = json.dumps(nc.serialize())
    >>> s
    '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
    >>> read_nc = NestedClass.load(json.loads(s))
    >>> read_nc == nc
    True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
        - `other : SerializableDataclass`
            other instance to compare against
        - `of_serialized : bool`
            if true, compare serialized data and not raw values
            (defaults to `False`)

        # Returns:
        - `dict[str, Any]`

        # Raises:
        - `ValueError` : if the instances are not of the same type
        - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        if self == other:
            return diff_result

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
        - `nested_dict : dict[str, Any]`
            nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))
````
Base class for serializable dataclasses
only for linting and type checking, still need to call the `serializable_dataclass` decorator
# Usage:

```python
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
```

and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

```python
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
```
This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
```python
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
```

which gives us:

```python
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
```
```python
def serialize(self) -> dict[str, Any]:
    "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
    raise NotImplementedError(
        f"decorate {self.__class__ = } with `@serializable_dataclass`"
    )
```
returns the class as a dict, implemented by using the `@serializable_dataclass` decorator
```python
@classmethod
def load(cls: Type[T], data: dict[str, Any] | T) -> T:
    "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
    raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")
```
takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator
```python
def validate_fields_types(
    self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return SerializableDataclass__validate_fields_types(
        self, on_typecheck_error=on_typecheck_error
    )
```
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
```python
def validate_field_type(
    self,
    field: "SerializableField|str",
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint"""
    return SerializableDataclass__validate_field_type(
        self, field, on_typecheck_error=on_typecheck_error
    )
```
given a dataclass, check the field matches the type hint
````python
def diff(
    self, other: "SerializableDataclass", of_serialized: bool = False
) -> dict[str, Any]:
    """get a rich and recursive diff between two instances of a serializable dataclass

    ```python
    >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
    {'b': {'self': 2, 'other': 3}}
    >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
    {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
    ```

    # Parameters:
    - `other : SerializableDataclass`
        other instance to compare against
    - `of_serialized : bool`
        if true, compare serialized data and not raw values
        (defaults to `False`)

    # Returns:
    - `dict[str, Any]`

    # Raises:
    - `ValueError` : if the instances are not of the same type
    - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
    """
    # match types
    if type(self) is not type(other):
        raise ValueError(
            f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
        )

    # initialize the diff result
    diff_result: dict = {}

    # if they are the same, return the empty diff
    if self == other:
        return diff_result

    # if we are working with serialized data, serialize the instances
    if of_serialized:
        ser_self: dict = self.serialize()
        ser_other: dict = other.serialize()

    # for each field in the class
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        # skip fields that are not for comparison
        if not field.compare:
            continue

        # get values
        field_name: str = field.name
        self_value = getattr(self, field_name)
        other_value = getattr(other, field_name)

        # if the values are both serializable dataclasses, recurse
        if isinstance(self_value, SerializableDataclass) and isinstance(
            other_value, SerializableDataclass
        ):
            nested_diff: dict = self_value.diff(
                other_value, of_serialized=of_serialized
            )
            if nested_diff:
                diff_result[field_name] = nested_diff
        # only support serializable dataclasses
        elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
            other_value
        ):
            raise ValueError("Non-serializable dataclass is not supported")
        else:
            # get the values of either the serialized or the actual values
            self_value_s = ser_self[field_name] if of_serialized else self_value
            other_value_s = ser_other[field_name] if of_serialized else other_value
            # compare the values
            if not array_safe_eq(self_value_s, other_value_s):
                diff_result[field_name] = {"self": self_value, "other": other_value}

    # return the diff result
    return diff_result
````
get a rich and recursive diff between two instances of a serializable dataclass
```python
>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
```

# Parameters:

- `other : SerializableDataclass`
    other instance to compare against
- `of_serialized : bool`
    if true, compare serialized data and not raw values
    (defaults to `False`)

# Returns:

- `dict[str, Any]`

# Raises:

- `ValueError` : if the instances are not of the same type
- `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
```python
def update_from_nested_dict(self, nested_dict: dict[str, Any]):
    """update the instance from a nested dict, useful for configuration from command line args

    # Parameters:
    - `nested_dict : dict[str, Any]`
        nested dict to update the instance with
    """
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        field_name: str = field.name
        self_value = getattr(self, field_name)

        if field_name in nested_dict:
            if isinstance(self_value, SerializableDataclass):
                self_value.update_from_nested_dict(nested_dict[field_name])
            else:
                setattr(self, field_name, nested_dict[field_name])
```
update the instance from a nested dict, useful for configuration from command line args
# Parameters:

- `nested_dict : dict[str, Any]`
    nested dict to update the instance with
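For example, updating a nested configuration (a small sketch with hypothetical classes):

```python
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)


@serializable_dataclass
class Inner(SerializableDataclass):
    a: int = 1
    b: str = "q"


@serializable_dataclass
class Config(SerializableDataclass):
    x: str = "default"
    inner: Inner = serializable_field(default_factory=Inner)


cfg = Config()
cfg.update_from_nested_dict({"x": "new", "inner": {"a": 42}})
print(cfg.x, cfg.inner.a, cfg.inner.b)  # new 42 q
```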