# muutils.json_serialize

Submodule for serializing things to JSON in a recoverable way.

You can throw *any* object into `muutils.json_serialize.json_serialize` and it will return a `JSONitem`: a `bool`, `int`, `float`, `str`, `None`, a list of `JSONitem`s, or a dict mapping `str` to `JSONitem`.

The goal is that if you just want to store something as relatively human-readable JSON, and don't care as much about recovering it, you can throw it into `json_serialize` and it will just work. If you want to do so in a recoverable way, check out [`ZANJ`](https://github.com/mivanit/ZANJ).
It does this by looking through `DEFAULT_HANDLERS`: an object is kept as-is if it is already valid JSON, then a `.serialize()` method is looked for on the object, and then a number of special cases are tried. You can add your own handlers by initializing a `JsonSerializer` object and passing a sequence of them to `handlers_pre`.
Additionally, `SerializableDataclass` is a special kind of dataclass where you specify how to serialize each field, and a `.serialize()` method is automatically added to the class. This is done by using the `serializable_dataclass` decorator, inheriting from `SerializableDataclass`, and using `serializable_field` in place of `dataclasses.field` when defining non-standard fields.
This module plays nicely with, and is a dependency of, the [`ZANJ`](https://github.com/mivanit/ZANJ) library, which extends it to support saving things to disk more efficiently than plain JSON (arrays are saved as `.npy` files, for example) and to automatically detect how to load saved objects back into their original classes.
```python
"""submodule for serializing things to json in a recoverable way

you can throw *any* object into `muutils.json_serialize.json_serialize`
and it will return a `JSONitem`, meaning a bool, int, float, str, None, list of `JSONitem`s, or a dict mapping to `JSONitem`.

The goal of this is if you want to just be able to store something as relatively human-readable JSON, and don't care as much about recovering it, you can throw it into `json_serialize` and it will just work. If you want to do so in a recoverable way, check out [`ZANJ`](https://github.com/mivanit/ZANJ).

it will do so by looking in `DEFAULT_HANDLERS`, which will keep it as-is if it's already valid, then try to find a `.serialize()` method on the object, and then have a bunch of special cases. You can add handlers by initializing a `JsonSerializer` object and passing a sequence of them to `handlers_pre`

additionally, `SerializableDataclass` is a special kind of dataclass where you specify how to serialize each field, and a `.serialize()` method is automatically added to the class. This is done by using the `serializable_dataclass` decorator, inheriting from `SerializableDataclass`, and `serializable_field` in place of `dataclasses.field` when defining non-standard fields.

This module plays nicely with and is a dependency of the [`ZANJ`](https://github.com/mivanit/ZANJ) library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes.
"""

from __future__ import annotations

from muutils.json_serialize.array import arr_metadata, load_array
from muutils.json_serialize.json_serialize import (
    BASE_HANDLERS,
    JsonSerializer,
    json_serialize,
)
from muutils.json_serialize.serializable_dataclass import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)
from muutils.json_serialize.util import try_catch, JSONitem, dc_eq

__all__ = [
    # submodules
    "array",
    "json_serialize",
    "serializable_dataclass",
    "serializable_field",
    "util",
    # imports
    "arr_metadata",
    "load_array",
    "BASE_HANDLERS",
    "JSONitem",
    "JsonSerializer",
    "json_serialize",
    "try_catch",
    "dc_eq",
    "serializable_dataclass",
    "serializable_field",
    "SerializableDataclass",
]
```
```python
def json_serialize(obj: Any, path: ObjectPath = tuple()) -> JSONitem:
    """serialize object to json-serializable object with default config"""
    return GLOBAL_JSON_SERIALIZER.json_serialize(obj, path=path)
```
````python
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    methods_no_override: list[str] | None = None,
    **kwargs,
):
    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:

    - `_cls : _type_`
        class to decorate. don't pass this arg, just use this as a decorator
        (defaults to `None`)
    - `init : bool`
        whether to add an `__init__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `True`)
    - `repr : bool`
        whether to add a `__repr__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `True`)
    - `order : bool`
        whether to add rich comparison methods
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
    - `unsafe_hash : bool`
        whether to add a `__hash__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
    - `frozen : bool`
        whether to make the class frozen
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
    - `properties_to_serialize : Optional[list[str]]`
        which properties to add to the serialized data dict
        **SerializableDataclass only**
        (defaults to `None`)
    - `register_handler : bool`
        if true, register the class with ZANJ for loading
        **SerializableDataclass only**
        (defaults to `True`)
    - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
        **SerializableDataclass only**
    - `on_typecheck_mismatch : ErrorMode`
        what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
        **SerializableDataclass only**
    - `methods_no_override : list[str]|None`
        list of methods that should not be overridden by the decorator
        by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
        but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
        **SerializableDataclass only**
        (defaults to `None`)
    - `**kwargs`
        *(passed to dataclasses.dataclass)*

    # Returns:

    - `_type_`
        the decorated class

    # Raises:

    - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
    - `NotSerializableFieldException` : if a field is not a `SerializableField`
    - `FieldSerializationError` : if there is an error serializing a field
    - `AttributeError` : if a property is not found on the class
    - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError(
                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
                    )
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> Type[T]:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: cls = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        _methods_no_override: set[str]
        if methods_no_override is None:
            _methods_no_override = set()
        else:
            _methods_no_override = set(methods_no_override)

        if _methods_no_override - {
            "__eq__",
            "serialize",
            "load",
            "validate_fields_types",
        }:
            warnings.warn(
                f"Unknown methods in `methods_no_override`: {_methods_no_override = }"
            )

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        if "serialize" not in _methods_no_override:
            # type is `Callable[[T], dict]`
            cls.serialize = serialize  # type: ignore[attr-defined]
        if "load" not in _methods_no_override:
            # type is `Callable[[dict], T]`
            cls.load = load  # type: ignore[attr-defined]

        if "validate_field_type" not in _methods_no_override:
            # type is `Callable[[T, ErrorMode], bool]`
            cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        if "__eq__" not in _methods_no_override:
            # type is `Callable[[T, T], bool]`
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    if _cls is None:
        return wrap
    else:
        return wrap(_cls)
````
````python
def serializable_field(
    *_args,
    default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    default_factory: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: Optional[types.MappingProxyType] = None,
    kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    **kwargs: Any,
) -> Any:
    """Create a new `SerializableField`

    ```
    default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    default_factory: Callable[[], Sfield_T]
    | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: types.MappingProxyType | None = None,
    kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    # ----------------------------------------------------------------------
    # new in `SerializableField`, not in `dataclasses.Field`
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    loading_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    ```

    # new Parameters:
    - `serialize`: whether to serialize this field when serializing the class
    - `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize`
    - `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
    - `deserialize_fn`: new alternative to `loading_fn`. takes only the field's value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised.
    - `assert_type`: whether to assert the type of the field when loading. if `False`, will not check the type of the field.
    - `custom_typecheck_fn`: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

    # Gotchas:
    - `loading_fn` takes the dict of the **class**, not the field. if you wanted a `loading_fn` that does nothing, you'd write:

    ```python
    class MyClass:
        my_field: int = serializable_field(
            serialization_fn=lambda x: str(x),
            loading_fn=lambda x: int(x["my_field"]),
        )
    ```

    using `deserialize_fn` instead:

    ```python
    class MyClass:
        my_field: int = serializable_field(
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: int(x),
        )
    ```

    In the above code, `my_field` is an int but will be serialized as a string.

    note that if not using ZANJ, and you have a class inside a container, you MUST provide
    `serialization_fn` and `loading_fn` to serialize and load the container.
    ZANJ will automatically do this for you.

    # TODO: `custom_value_check_fn`: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test
    """
    assert len(_args) == 0, f"unexpected positional arguments: {_args}"
    return SerializableField(
        default=default,
        default_factory=default_factory,
        init=init,
        repr=repr,
        hash=hash,
        compare=compare,
        metadata=metadata,
        kw_only=kw_only,
        serialize=serialize,
        serialization_fn=serialization_fn,
        deserialize_fn=deserialize_fn,
        assert_type=assert_type,
        custom_typecheck_fn=custom_typecheck_fn,
        **kwargs,
    )
````
```python
def arr_metadata(arr) -> dict[str, list[int] | str | int]:
    """get metadata for a numpy array"""
    return {
        "shape": list(arr.shape),
        "dtype": (
            arr.dtype.__name__ if hasattr(arr.dtype, "__name__") else str(arr.dtype)
        ),
        "n_elements": array_n_elements(arr),
    }
```
```python
def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any:
    """load a json-serialized array, infer the mode if not specified"""
    # return arr if its already a numpy array
    if isinstance(arr, np.ndarray) and array_mode is None:
        return arr

    # try to infer the array_mode
    array_mode_inferred: ArrayMode = infer_array_mode(arr)
    if array_mode is None:
        array_mode = array_mode_inferred
    elif array_mode != array_mode_inferred:
        warnings.warn(
            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
        )

    # actually load the array
    if array_mode == "array_list_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"
        data = np.array(arr["data"], dtype=arr["dtype"])  # type: ignore
        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
            raise ValueError(f"invalid shape: {arr}")
        return data

    elif array_mode == "array_hex_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"
        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])  # type: ignore
        return data.reshape(arr["shape"])  # type: ignore

    elif array_mode == "array_b64_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"
        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])  # type: ignore
        return data.reshape(arr["shape"])  # type: ignore

    elif array_mode == "list":
        assert isinstance(
            arr, typing.Sequence
        ), f"invalid list format: {type(arr) = }\n{arr = }"
        return np.array(arr)  # type: ignore
    elif array_mode == "external":
        # assume ZANJ has taken care of it
        assert isinstance(arr, typing.Mapping)
        if "data" not in arr:
            raise KeyError(
                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
            )
        return arr["data"]
    elif array_mode == "zero_dim":
        assert isinstance(arr, typing.Mapping)
        data = np.array(arr["data"])
        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
            raise ValueError(f"invalid shape: {arr}")
        return data
    else:
        raise ValueError(f"invalid array_mode: {array_mode}")
```
```python
class JsonSerializer:
    """Json serialization class (holds configs)

    # Parameters:
    - `array_mode : ArrayMode`
        how to write arrays
        (defaults to `"array_list_meta"`)
    - `error_mode : ErrorMode`
        what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
        (defaults to `"except"`)
    - `handlers_pre : MonoTuple[SerializerHandler]`
        handlers to use before the default handlers
        (defaults to `tuple()`)
    - `handlers_default : MonoTuple[SerializerHandler]`
        default handlers to use
        (defaults to `DEFAULT_HANDLERS`)
    - `write_only_format : bool`
        changes `_FORMAT_KEY` keys in output to `"__write_format__"` (when you want to serialize something in a way that zanj won't try to recover the object when loading)
        (defaults to `False`)

    # Raises:
    - `ValueError`: on init, if `args` is not empty
    - `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`
    """

    def __init__(
        self,
        *args,
        array_mode: ArrayMode = "array_list_meta",
        error_mode: ErrorMode = ErrorMode.EXCEPT,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
        write_only_format: bool = False,
    ):
        if len(args) > 0:
            raise ValueError(
                f"JsonSerializer takes no positional arguments!\n{args = }"
            )

        self.array_mode: ArrayMode = array_mode
        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
        self.write_only_format: bool = write_only_format
        # join up the handlers
        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
            handlers_default
        )

    def json_serialize(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
    ) -> JSONitem:
        try:
            for handler in self.handlers:
                if handler.check(self, obj, path):
                    output: JSONitem = handler.serialize_func(self, obj, path)
                    if self.write_only_format:
                        if isinstance(output, dict) and _FORMAT_KEY in output:
                            new_fmt: JSONitem = output.pop(_FORMAT_KEY)
                            output["__write_format__"] = new_fmt
                    return output

            raise ValueError(f"no handler found for object with {type(obj) = }")

        except Exception as e:
            if self.error_mode == "except":
                obj_str: str = repr(obj)
                if len(obj_str) > 1000:
                    obj_str = obj_str[:1000] + "..."
                raise SerializationException(
                    f"error serializing at {path = } with last handler: '{handler.uid}'\nfrom: {e}\nobj: {obj_str}"
                ) from e
            elif self.error_mode == "warn":
                warnings.warn(
                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
                )

            return repr(obj)

    def hashify(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
        force: bool = True,
    ) -> Hashableitem:
        """try to turn any object into something hashable"""
        data = self.json_serialize(obj, path=path)

        # recursive hashify, turning dicts and lists into tuples
        return _recursive_hashify(data, force=force)
```
Json serialization class (holds configs)

# Parameters:
- `array_mode : ArrayMode`
    how to write arrays
    (defaults to `"array_list_meta"`)
- `error_mode : ErrorMode`
    what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
    (defaults to `"except"`)
- `handlers_pre : MonoTuple[SerializerHandler]`
    handlers to use before the default handlers
    (defaults to `tuple()`)
- `handlers_default : MonoTuple[SerializerHandler]`
    default handlers to use
    (defaults to `DEFAULT_HANDLERS`)
- `write_only_format : bool`
    changes `_FORMAT_KEY` keys in output to `"__write_format__"` (when you want to serialize something in a way that zanj won't try to recover the object when loading)
    (defaults to `False`)

# Raises:
- `ValueError`: on init, if `args` is not empty
- `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT`
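The handler dispatch in `json_serialize` is a simple first-match chain: each handler is asked whether it applies, and the first one that says yes serializes the object. As a minimal standalone sketch of that pattern (note: `MiniHandler` and `mini_serialize` are illustrative stand-ins, not the real `SerializerHandler` API, whose `check` and `serialize_func` also receive the serializer and object path):

```python
from dataclasses import dataclass
from typing import Any, Callable

# hypothetical mini-handler for illustration only
@dataclass
class MiniHandler:
    uid: str
    check: Callable[[Any], bool]
    serialize_func: Callable[[Any], Any]

HANDLERS = (
    # keep already-valid JSON values as-is
    MiniHandler(
        "passthrough",
        lambda o: isinstance(o, (bool, int, float, str, type(None))),
        lambda o: o,
    ),
    # sets are not JSON: write them as sorted lists
    MiniHandler("set-as-list", lambda o: isinstance(o, set), lambda o: sorted(o)),
    # last-resort fallback, mirroring the repr fallback described above
    MiniHandler("fallback-repr", lambda o: True, repr),
)

def mini_serialize(obj: Any) -> Any:
    # first handler whose check passes wins, same order-sensitive
    # dispatch as JsonSerializer (handlers_pre before handlers_default)
    for handler in HANDLERS:
        if handler.check(obj):
            return handler.serialize_func(obj)
```

Putting custom handlers first is exactly what `handlers_pre` does: it lets you override the default behavior for specific types without touching `DEFAULT_HANDLERS`.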
```python
    def __init__(
        self,
        *args,
        array_mode: ArrayMode = "array_list_meta",
        error_mode: ErrorMode = ErrorMode.EXCEPT,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
        write_only_format: bool = False,
    ):
        if len(args) > 0:
            raise ValueError(
                f"JsonSerializer takes no positional arguments!\n{args = }"
            )

        self.array_mode: ArrayMode = array_mode
        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
        self.write_only_format: bool = write_only_format
        # join up the handlers
        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
            handlers_default
        )
```
```python
    def json_serialize(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
    ) -> JSONitem:
        try:
            for handler in self.handlers:
                if handler.check(self, obj, path):
                    output: JSONitem = handler.serialize_func(self, obj, path)
                    if self.write_only_format:
                        if isinstance(output, dict) and _FORMAT_KEY in output:
                            new_fmt: JSONitem = output.pop(_FORMAT_KEY)
                            output["__write_format__"] = new_fmt
                    return output

            raise ValueError(f"no handler found for object with {type(obj) = }")

        except Exception as e:
            if self.error_mode == "except":
                obj_str: str = repr(obj)
                if len(obj_str) > 1000:
                    obj_str = obj_str[:1000] + "..."
                raise SerializationException(
                    f"error serializing at {path = } with last handler: '{handler.uid}'\nfrom: {e}\nobj: {obj_str}"
                ) from e
            elif self.error_mode == "warn":
                warnings.warn(
                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
                )

        return repr(obj)
```
```python
    def hashify(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
        force: bool = True,
    ) -> Hashableitem:
        """try to turn any object into something hashable"""
        data = self.json_serialize(obj, path=path)

        # recursive hashify, turning dicts and lists into tuples
        return _recursive_hashify(data, force=force)
```
try to turn any object into something hashable
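`_recursive_hashify` is private to muutils, so to show the idea here is an illustrative standalone reimplementation (a hypothetical helper, not the library's function): dicts become sorted tuples of key-value pairs, lists become tuples, and any still-unhashable leaf falls back to its `str` when `force=True`.

```python
from typing import Any

def recursive_hashify(data: Any, force: bool = True) -> Any:
    # dicts -> sorted tuples of (key, hashified value) pairs
    if isinstance(data, dict):
        return tuple(
            sorted((k, recursive_hashify(v, force)) for k, v in data.items())
        )
    # lists and tuples -> tuples of hashified elements
    if isinstance(data, (list, tuple)):
        return tuple(recursive_hashify(x, force) for x in data)
    # leaves: keep if hashable, else fall back to str (or raise)
    try:
        hash(data)
        return data
    except TypeError:
        if force:
            return str(data)
        raise
```

Since `hashify` first runs the object through `json_serialize`, the input here is already a tree of dicts, lists, and primitives, which is why this small set of cases suffices.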
```python
def try_catch(func: Callable):
    """wraps the function to catch exceptions, returns serialized error message on exception

    returned func will return normal result on success, or error message on exception
    """

    @functools.wraps(func)
    def newfunc(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            return f"{e.__class__.__name__}: {e}"

    return newfunc
```
wraps the function to catch exceptions, returning a serialized error message on exception

the returned function returns the normal result on success, or the error message string on exception
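As a quick illustration of the decorator, the snippet below reproduces the `try_catch` body from the source above so that it is self-contained, then applies it to a hypothetical `safe_div` function:

```python
import functools
from typing import Callable

def try_catch(func: Callable):
    # same pattern as the try_catch decorator shown above
    @functools.wraps(func)
    def newfunc(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            # serialize the exception into a plain string
            return f"{e.__class__.__name__}: {e}"

    return newfunc

@try_catch
def safe_div(a: float, b: float) -> float:
    return a / b
```

On success `safe_div(6, 3)` returns `2.0`; on failure `safe_div(1, 0)` returns the string `"ZeroDivisionError: division by zero"` instead of raising, which is convenient when serializing objects whose properties may throw.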
````python
def dc_eq(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    checks if two dataclasses which (might) hold numpy arrays are equal

    # Parameters:

    - `dc1`: the first dataclass
    - `dc2`: the second dataclass
    - `except_when_class_mismatch: bool`
        if `True`, will throw `TypeError` if the classes are different.
        if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
        (default: `False`)
    - `false_when_class_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False`.
        if `True`, will return `False` if the classes are different.
        if `False`, will attempt to compare the fields.
    - `except_when_field_mismatch: bool`
        only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
        if `True`, will throw `AttributeError` if the fields are different.
        (default: `False`)

    # Returns:
    - `bool`: True if the dataclasses are equal, False otherwise

    # Raises:
    - `TypeError`: if the dataclasses are of different classes
    - `AttributeError`: if the dataclasses have different fields

    # TODO: after "except when class mismatch" is False, shouldn't we then go to "field keys match"?
    ```
             [START]
                ▼
     ┌───────────┐  ┌─────────┐
     │dc1 is dc2?├─►│ classes │
     └──┬────────┘No│ match?  │
      ────  │       ├─────────┤
     (True)◄┘Yes    │No   │Yes
      ────          ▼     ▼
       ┌────────────────┐ ┌────────────┐
       │ except when    │ │ fields keys│
       │ class mismatch?│ │ match?     │
       ├───────────┬────┘ ├───────┬────┘
       │Yes        │No    │No     │Yes
       ▼           ▼      ▼       ▼
     ───────────  ┌──────────┐  ┌────────┐
    {  raise   }  │ except   │  │ field  │
    { TypeError } │ when     │  │ values │
     ───────────  │ field    │  │ match? │
                  │ mismatch?│  ├────┬───┘
                  ├───────┬──┘  │    │Yes
                  │Yes    │No   │No  ▼
                  ▼       ▼     │   ────
     ───────────────     ─────  │  (True)
    {    raise      }   (False)◄┘   ────
    { AttributeError}    ─────
     ───────────────
    ```

    """
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            # if the classes don't match, raise an error
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
            )
        if except_when_field_mismatch:
            dc1_fields: set = set([fld.name for fld in dataclasses.fields(dc1)])
            dc2_fields: set = set([fld.name for fld in dataclasses.fields(dc2)])
            fields_match: bool = set(dc1_fields) == set(dc2_fields)
            if not fields_match:
                # if the fields don't match, raise an error
                raise AttributeError(
                    f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
                )
        return False

    return all(
        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))
        for fld in dataclasses.fields(dc1)
        if fld.compare
    )
````
checks if two dataclasses which (might) hold numpy arrays are equal

# Parameters:
- `dc1`: the first dataclass
- `dc2`: the second dataclass
- `except_when_class_mismatch: bool`
    if `True`, will throw `TypeError` if the classes are different.
    if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
    (default: `False`)
- `false_when_class_mismatch: bool`
    only relevant if `except_when_class_mismatch` is `False`.
    if `True`, will return `False` if the classes are different.
    if `False`, will attempt to compare the fields.
- `except_when_field_mismatch: bool`
    only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
    if `True`, will throw `AttributeError` if the fields are different.
    (default: `False`)

# Returns:
- `bool`: `True` if the dataclasses are equal, `False` otherwise

# Raises:
- `TypeError`: if the dataclasses are of different classes
- `AttributeError`: if the dataclasses have different fields

# TODO: after "except when class mismatch" is False, shouldn't we then go to "field keys match"?

```
         [START]
            ▼
 ┌───────────┐  ┌─────────┐
 │dc1 is dc2?├─►│ classes │
 └──┬────────┘No│ match?  │
  ────  │       ├─────────┤
 (True)◄┘Yes    │No   │Yes
  ────          ▼     ▼
   ┌────────────────┐ ┌────────────┐
   │ except when    │ │ fields keys│
   │ class mismatch?│ │ match?     │
   ├───────────┬────┘ ├───────┬────┘
   │Yes        │No    │No     │Yes
   ▼           ▼      ▼       ▼
 ───────────  ┌──────────┐  ┌────────┐
{  raise   }  │ except   │  │ field  │
{ TypeError } │ when     │  │ values │
 ───────────  │ field    │  │ match? │
              │ mismatch?│  ├────┬───┘
              ├───────┬──┘  │    │Yes
              │Yes    │No   │No  ▼
              ▼       ▼     │   ────
 ───────────────     ─────  │  (True)
{    raise      }   (False)◄┘   ────
{ AttributeError}    ─────
 ───────────────
```
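To make the decision flow above concrete, here is a trimmed standalone variant (illustrative only: it uses plain `==` instead of `array_safe_eq`, and drops the field-mismatch branch):

```python
import dataclasses

@dataclasses.dataclass
class Point:
    x: int
    y: int

@dataclasses.dataclass
class Pair:
    x: int
    y: int

def simple_dc_eq(dc1, dc2, except_when_class_mismatch: bool = False) -> bool:
    # identity short-circuit: same object, trivially equal
    if dc1 is dc2:
        return True
    # class mismatch: raise or return False depending on the flag
    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            raise TypeError("classes differ")
        return False
    # otherwise compare field values, skipping fields with compare=False
    return all(
        getattr(dc1, f.name) == getattr(dc2, f.name)
        for f in dataclasses.fields(dc1)
        if f.compare
    )
```

So `simple_dc_eq(Point(1, 2), Point(1, 2))` is `True`, while `Point(1, 2)` vs `Pair(1, 2)` is `False` by default even though the field values match, because the classes differ.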
````python
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

    >>> my_obj = MyClass(a=1, b="q")
    >>> s = json.dumps(my_obj.serialize())
    >>> s
    '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
    >>> read_obj = MyClass.load(json.loads(s))
    >>> read_obj == my_obj
    True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

    >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
    >>> s = json.dumps(nc.serialize())
    >>> s
    '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
    >>> read_nc = NestedClass.load(json.loads(s))
    >>> read_nc == nc
    True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> MyClass(a=1, b=2).diff(MyClass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=MyClass(a=1, b=2)).diff(NestedClass(x="q2", y=MyClass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
        - `other : SerializableDataclass`
            other instance to compare against
        - `of_serialized : bool`
            if true, compare serialized data and not raw values
            (defaults to `False`)

        # Returns:
        - `dict[str, Any]`

        # Raises:
        - `ValueError` : if the instances are not of the same type
        - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        try:
            if self == other:
                return diff_result
        except Exception:
            pass

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
        - `nested_dict : dict[str, Any]`
            nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))
````
Base class for serializable dataclasses

only for linting and type checking, still need to call the `serializable_dataclass` decorator

# Usage:

```python
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
```

and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True

This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:

```python
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
```

which gives us:

>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
```python
    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )
```
returns the class as a dict, implemented by using the `@serializable_dataclass` decorator
```python
    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")
```
takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator
```python
    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )
```
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
```python
    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )
```
given a dataclass, check the field matches the type hint
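At its core, field type validation means comparing a field's runtime value against its declared type hint. A naive sketch of that idea using `typing.get_type_hints` and a plain `isinstance` check (note: `check_field_type` is a hypothetical helper for illustration; the real implementation also handles generic types and the `on_typecheck_error` mode):

```python
import dataclasses
import typing

@dataclasses.dataclass
class Config:
    name: str
    count: int

def check_field_type(instance, field_name: str) -> bool:
    # resolve the declared type hints on the instance's class,
    # then check the current value against the hint for this field
    hints = typing.get_type_hints(type(instance))
    return isinstance(getattr(instance, field_name), hints[field_name])
```

This only works for simple concrete types; hints like `list[int]` or `Optional[str]` need more careful handling, which is exactly what the library-level validators provide.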
````python
    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> MyClass(a=1, b=2).diff(MyClass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=MyClass(a=1, b=2)).diff(NestedClass(x="q2", y=MyClass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
        - `other : SerializableDataclass`
            other instance to compare against
        - `of_serialized : bool`
            if true, compare serialized data and not raw values
            (defaults to `False`)

        # Returns:
        - `dict[str, Any]`

        # Raises:
        - `ValueError` : if the instances are not of the same type
        - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        try:
            if self == other:
                return diff_result
        except Exception:
            pass

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result
````
get a rich and recursive diff between two instances of a serializable dataclass

```python
>>> MyClass(a=1, b=2).diff(MyClass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=MyClass(a=1, b=2)).diff(NestedClass(x="q2", y=MyClass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
```

# Parameters:
- `other : SerializableDataclass`
    other instance to compare against
- `of_serialized : bool`
    if true, compare serialized data and not raw values
    (defaults to `False`)

# Returns:
- `dict[str, Any]`

# Raises:
- `ValueError` : if the instances are not of the same type
- `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
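The recursive structure of `diff` can be sketched standalone with plain dataclasses (illustrative only: `Inner`, `Outer`, and `dc_diff` are hypothetical, and this drops the `of_serialized` path and `array_safe_eq`):

```python
import dataclasses

@dataclasses.dataclass
class Inner:
    a: int
    b: int

@dataclasses.dataclass
class Outer:
    x: str
    y: Inner

def dc_diff(d1, d2) -> dict:
    # recursive field-by-field diff with the same output shape as diff():
    # leaves differ -> {"self": ..., "other": ...}; nested dataclasses recurse
    out: dict = {}
    for f in dataclasses.fields(d1):
        v1, v2 = getattr(d1, f.name), getattr(d2, f.name)
        if dataclasses.is_dataclass(v1) and type(v1) is type(v2):
            nested = dc_diff(v1, v2)
            if nested:
                out[f.name] = nested
        elif v1 != v2:
            out[f.name] = {"self": v1, "other": v2}
    return out
```

Fields that compare equal are simply omitted, so an empty dict means the two instances match, mirroring the behavior documented above.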
```python
    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
        - `nested_dict : dict[str, Any]`
            nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])
```
update the instance from a nested dict, useful for configuration from command line args
# Parameters:
- `nested_dict : dict[str, Any]`
    nested dict to update the instance with
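A standalone sketch of the same traversal with plain dataclasses (the `Optim` and `TrainConfig` classes and `update_from_nested` helper are hypothetical, for illustration): keys matching dataclass-valued fields recurse, everything else overwrites the field in place.

```python
import dataclasses

@dataclasses.dataclass
class Optim:
    lr: float = 0.1
    momentum: float = 0.9

@dataclasses.dataclass
class TrainConfig:
    epochs: int = 10
    optim: Optim = dataclasses.field(default_factory=Optim)

def update_from_nested(obj, nested: dict) -> None:
    # same traversal as update_from_nested_dict: recurse into
    # dataclass-valued fields, overwrite leaf fields directly
    for f in dataclasses.fields(obj):
        if f.name in nested:
            value = getattr(obj, f.name)
            if dataclasses.is_dataclass(value) and not isinstance(value, type):
                update_from_nested(value, nested[f.name])
            else:
                setattr(obj, f.name, nested[f.name])

cfg = TrainConfig()
update_from_nested(cfg, {"epochs": 20, "optim": {"lr": 0.01}})
```

Note that `optim.momentum` keeps its default: the nested dict only needs to mention the keys being overridden, which is what makes this pattern convenient for layering command line args over a default config.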