docs for muutils v0.8.2
View Source on GitHub

muutils.json_serialize

submodule for serializing things to json in a recoverable way

you can throw any object into muutils.json_serialize.json_serialize and it will return a JSONitem, meaning a bool, int, float, str, None, list of JSONitems, or a dict mapping to JSONitem.

The goal of this is if you want to just be able to store something as relatively human-readable JSON, and don't care as much about recovering it, you can throw it into json_serialize and it will just work. If you want to do so in a recoverable way, check out ZANJ.

it will do so by looking in DEFAULT_HANDLERS, which will keep it as-is if its already valid, then try to find a .serialize() method on the object, and then have a bunch of special cases. You can add handlers by initializing a JsonSerializer object and passing a sequence of them to handlers_pre

additionally, SerializableDataclass is a special kind of dataclass where you specify how to serialize each field, and a .serialize() method is automatically added to the class. This is done by using the serializable_dataclass decorator, inheriting from SerializableDataclass, and serializable_field in place of dataclasses.field when defining non-standard fields.

This module plays nicely with and is a dependency of the ZANJ library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes.


 1"""submodule for serializing things to json in a recoverable way
 2
 3you can throw *any* object into `muutils.json_serialize.json_serialize`
 4and it will return a `JSONitem`, meaning a bool, int, float, str, None, list of `JSONitem`s, or a dict mappting to `JSONitem`.
 5
 6The goal of this is if you want to just be able to store something as relatively human-readable JSON, and don't care as much about recovering it, you can throw it into `json_serialize` and it will just work. If you want to do so in a recoverable way, check out [`ZANJ`](https://github.com/mivanit/ZANJ).
 7
 8it will do so by looking in `DEFAULT_HANDLERS`, which will keep it as-is if its already valid, then try to find a `.serialize()` method on the object, and then have a bunch of special cases. You can add handlers by initializing a `JsonSerializer` object and passing a sequence of them to `handlers_pre`
 9
10additionally, `SerializeableDataclass` is a special kind of dataclass where you specify how to serialize each field, and a `.serialize()` method is automatically added to the class. This is done by using the `serializable_dataclass` decorator, inheriting from `SerializeableDataclass`, and `serializable_field` in place of `dataclasses.field` when defining non-standard fields.
11
12This module plays nicely with and is a dependency of the [`ZANJ`](https://github.com/mivanit/ZANJ) library, which extends this to support saving things to disk in a more efficient way than just plain json (arrays are saved as npy files, for example), and automatically detecting how to load saved objects into their original classes.
13
14"""
15
16from __future__ import annotations
17
18from muutils.json_serialize.array import arr_metadata, load_array
19from muutils.json_serialize.json_serialize import (
20    BASE_HANDLERS,
21    JsonSerializer,
22    json_serialize,
23)
24from muutils.json_serialize.serializable_dataclass import (
25    SerializableDataclass,
26    serializable_dataclass,
27    serializable_field,
28)
29from muutils.json_serialize.util import try_catch, JSONitem, dc_eq
30
31__all__ = [
32    # submodules
33    "array",
34    "json_serialize",
35    "serializable_dataclass",
36    "serializable_field",
37    "util",
38    # imports
39    "arr_metadata",
40    "load_array",
41    "BASE_HANDLERS",
42    "JSONitem",
43    "JsonSerializer",
44    "json_serialize",
45    "try_catch",
46    "JSONitem",
47    "dc_eq",
48    "serializable_dataclass",
49    "serializable_field",
50    "SerializableDataclass",
51]

def json_serialize( obj: Any, path: tuple[typing.Union[str, int], ...] = ()) -> Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]:
def json_serialize(obj: Any, path: ObjectPath = tuple()) -> JSONitem:
    """serialize object to json-serializable object with default config"""
    # thin convenience wrapper: delegate to the module-level serializer singleton
    serializer = GLOBAL_JSON_SERIALIZER
    return serializer.json_serialize(obj, path=path)

serialize object to json-serializable object with default config

@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
def serializable_dataclass( _cls=None, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, properties_to_serialize: Optional[list[str]] = None, register_handler: bool = True, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except, on_typecheck_mismatch: muutils.errormode.ErrorMode = ErrorMode.Warn, methods_no_override: list[str] | None = None, **kwargs):
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    methods_no_override: list[str] | None = None,
    **kwargs,
):
    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:

    - `_cls : _type_`
       class to decorate. don't pass this arg, just use this as a decorator
       (defaults to `None`)
    - `init : bool`
       whether to add an `__init__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `True`)
    - `repr : bool`
       whether to add a `__repr__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `True`)
    - `order : bool`
       whether to add rich comparison methods
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
    - `unsafe_hash : bool`
       whether to add a `__hash__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
    - `frozen : bool`
       whether to make the class frozen
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
    - `properties_to_serialize : Optional[list[str]]`
       which properties to add to the serialized data dict
       **SerializableDataclass only**
       (defaults to `None`)
    - `register_handler : bool`
        if true, register the class with ZANJ for loading
        **SerializableDataclass only**
        (defaults to `True`)
    - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
        **SerializableDataclass only**
    - `on_typecheck_mismatch : ErrorMode`
        what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
        **SerializableDataclass only**
    - `methods_no_override : list[str]|None`
        list of methods that should not be overridden by the decorator
        by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
        but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
        **SerializableDataclass only**
        (defaults to `None`)
    - `**kwargs`
        *(passed to dataclasses.dataclass)*

    # Returns:

    - `_type_`
       the decorated class

    # Raises:

    - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
    - `NotSerializableFieldException` : if a field is not a `SerializableField`
    - `FieldSerializationError` : if there is an error serializing a field
    - `AttributeError` : if a property is not found on the class
    - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError(
                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
                    )
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    # pre-bind `value` so the except clause below can always
                    # reference it, even if the `getattr` itself raises
                    value: Any = None
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> T:
            # NOTE: returns an *instance*, so the annotation is `T`, not `Type[T]`
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: cls = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        _methods_no_override: set[str]
        if methods_no_override is None:
            _methods_no_override = set()
        else:
            _methods_no_override = set(methods_no_override)

        if _methods_no_override - {
            "__eq__",
            "serialize",
            "load",
            "validate_fields_types",
        }:
            warnings.warn(
                f"Unknown methods in `methods_no_override`: {_methods_no_override = }"
            )

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        if "serialize" not in _methods_no_override:
            # type is `Callable[[T], dict]`
            cls.serialize = serialize  # type: ignore[attr-defined]
        if "load" not in _methods_no_override:
            # type is `Callable[[dict], T]`
            cls.load = load  # type: ignore[attr-defined]

        # BUGFIX: previously checked "validate_field_type", which is not a name
        # accepted by `methods_no_override` (see the warning set above), so
        # opting out of this override silently never worked
        if "validate_fields_types" not in _methods_no_override:
            # type is `Callable[[T, ErrorMode], bool]`
            cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        if "__eq__" not in _methods_no_override:
            # type is `Callable[[T, T], bool]`
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    if _cls is None:
        return wrap
    else:
        return wrap(_cls)

decorator to make a dataclass serializable. must also make it inherit from SerializableDataclass!!

types will be validated (like pydantic) unless on_typecheck_mismatch is set to ErrorMode.IGNORE

behavior of most kwargs matches that of dataclasses.dataclass, but with some additional kwargs. any kwargs not listed here are passed to dataclasses.dataclass

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

Examines PEP 526 __annotations__ to determine fields.

If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method function is added. If frozen is true, fields may not be assigned to after instance creation.

@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str
>>> Myclass(a=1, b="q").serialize()
{_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}

Parameters:

  • _cls : _type_ class to decorate. don't pass this arg, just use this as a decorator (defaults to None)
  • init : bool whether to add an __init__ method (passed to dataclasses.dataclass) (defaults to True)
  • repr : bool whether to add a __repr__ method (passed to dataclasses.dataclass) (defaults to True)
  • order : bool whether to add rich comparison methods (passed to dataclasses.dataclass) (defaults to False)
  • unsafe_hash : bool whether to add a __hash__ method (passed to dataclasses.dataclass) (defaults to False)
  • frozen : bool whether to make the class frozen (passed to dataclasses.dataclass) (defaults to False)
  • properties_to_serialize : Optional[list[str]] which properties to add to the serialized data dict SerializableDataclass only (defaults to None)
  • register_handler : bool if true, register the class with ZANJ for loading SerializableDataclass only (defaults to True)
  • on_typecheck_error : ErrorMode what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, type validation will still return false SerializableDataclass only
  • on_typecheck_mismatch : ErrorMode what to do if a type mismatch is found (except, warn, ignore). If ignore, type validation will return True SerializableDataclass only
  • methods_no_override : list[str]|None list of methods that should not be overridden by the decorator by default, __eq__, serialize, load, and validate_fields_types are overridden by this function, but you can disable this if you'd rather write your own. dataclasses.dataclass might still overwrite these, and those options take precedence SerializableDataclass only (defaults to None)
  • **kwargs (passed to dataclasses.dataclass)

Returns:

  • _type_ the decorated class

Raises:

  • KWOnlyError : only raised if kw_only is True and python version is <3.10, since dataclasses.dataclass does not support this
  • NotSerializableFieldException : if a field is not a SerializableField
  • FieldSerializationError : if there is an error serializing a field
  • AttributeError : if a property is not found on the class
  • FieldLoadingError : if there is an error loading a field
def serializable_field( *_args, default: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>, default_factory: Union[Any, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>, init: bool = True, repr: bool = True, hash: Optional[bool] = None, compare: bool = True, metadata: Optional[mappingproxy] = None, kw_only: Union[bool, dataclasses._MISSING_TYPE] = <dataclasses._MISSING_TYPE object>, serialize: bool = True, serialization_fn: Optional[Callable[[Any], Any]] = None, deserialize_fn: Optional[Callable[[Any], Any]] = None, assert_type: bool = True, custom_typecheck_fn: Optional[Callable[[type], bool]] = None, **kwargs: Any) -> Any:
def serializable_field(
    *_args,
    default: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    default_factory: Union[Any, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: Optional[types.MappingProxyType] = None,
    kw_only: Union[bool, dataclasses._MISSING_TYPE] = dataclasses.MISSING,
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    **kwargs: Any,
) -> Any:
    """Create a new `SerializableField`

    ```
    default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    default_factory: Callable[[], Sfield_T]
    | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: Optional[bool] = None,
    compare: bool = True,
    metadata: types.MappingProxyType | None = None,
    kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
    # ----------------------------------------------------------------------
    # new in `SerializableField`, not in `dataclasses.Field`
    serialize: bool = True,
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    loading_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    assert_type: bool = True,
    custom_typecheck_fn: Optional[Callable[[type], bool]] = None,
    ```

    # new Parameters:
    - `serialize`: whether to serialize this field when serializing the class
    - `serialization_fn`: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the `SerializerHandler`s defined in `muutils.json_serialize.json_serialize`
    - `loading_fn`: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
    - `deserialize_fn`: new alternative to `loading_fn`. takes only the field's value, not the whole class. if both `loading_fn` and `deserialize_fn` are provided, an error will be raised.
    - `assert_type`: whether to assert the type of the field when loading. if `False`, will not check the type of the field.
    - `custom_typecheck_fn`: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

    # Gotchas:
    - `loading_fn` takes the dict of the **class**, not the field. if you wanted a `loading_fn` that does nothing, you'd write:

    ```python
    class MyClass:
        my_field: int = serializable_field(
            serialization_fn=lambda x: str(x),
            loading_fn=lambda x: int(x["my_field"])
        )
    ```

    using `deserialize_fn` instead:

    ```python
    class MyClass:
        my_field: int = serializable_field(
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: int(x)
        )
    ```

    In the above code, `my_field` is an int but will be serialized as a string.

    note that if not using ZANJ, and you have a class inside a container, you MUST provide
    `serialization_fn` and `loading_fn` to serialize and load the container.
    ZANJ will automatically do this for you.

    # TODO: `custom_value_check_fn`: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test
    """
    # positional args are rejected: fields must be configured by keyword only
    assert len(_args) == 0, f"unexpected positional arguments: {_args}"
    return SerializableField(
        default=default,
        default_factory=default_factory,
        init=init,
        repr=repr,
        hash=hash,
        compare=compare,
        metadata=metadata,
        kw_only=kw_only,
        serialize=serialize,
        serialization_fn=serialization_fn,
        deserialize_fn=deserialize_fn,
        assert_type=assert_type,
        custom_typecheck_fn=custom_typecheck_fn,
        **kwargs,
    )

Create a new SerializableField

default: Sfield_T | dataclasses._MISSING_TYPE = dataclasses.MISSING,
default_factory: Callable[[], Sfield_T]
| dataclasses._MISSING_TYPE = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash: Optional[bool] = None,
compare: bool = True,
metadata: types.MappingProxyType | None = None,
kw_only: bool | dataclasses._MISSING_TYPE = dataclasses.MISSING,
# ----------------------------------------------------------------------
# new in `SerializableField`, not in `dataclasses.Field`
serialize: bool = True,
serialization_fn: Optional[Callable[[Any], Any]] = None,
loading_fn: Optional[Callable[[Any], Any]] = None,
deserialize_fn: Optional[Callable[[Any], Any]] = None,
assert_type: bool = True,
custom_typecheck_fn: Optional[Callable[[type], bool]] = None,

new Parameters:

  • serialize: whether to serialize this field when serializing the class
  • serialization_fn: function taking the instance of the field and returning a serializable object. If not provided, will iterate through the SerializerHandlers defined in muutils.json_serialize.json_serialize
  • loading_fn: function taking the serialized object and returning the instance of the field. If not provided, will take object as-is.
  • deserialize_fn: new alternative to loading_fn. takes only the field's value, not the whole class. if both loading_fn and deserialize_fn are provided, an error will be raised.
  • assert_type: whether to assert the type of the field when loading. if False, will not check the type of the field.
  • custom_typecheck_fn: function taking the type of the field and returning whether the type itself is valid. if not provided, will use the default type checking.

Gotchas:

  • loading_fn takes the dict of the class, not the field. if you wanted a loading_fn that does nothing, you'd write:
class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        loading_fn=lambda x: int(x["my_field"])
    )

using deserialize_fn instead:

class MyClass:
    my_field: int = serializable_field(
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: int(x)
    )

In the above code, my_field is an int but will be serialized as a string.

note that if not using ZANJ, and you have a class inside a container, you MUST provide serialization_fn and loading_fn to serialize and load the container. ZANJ will automatically do this for you.

TODO: custom_value_check_fn: function taking the value of the field and returning whether the value itself is valid. if not provided, any value is valid as long as it passes the type test

def arr_metadata(arr) -> dict[str, list[int] | str | int]:
49def arr_metadata(arr) -> dict[str, list[int] | str | int]:
50    """get metadata for a numpy array"""
51    return {
52        "shape": list(arr.shape),
53        "dtype": (
54            arr.dtype.__name__ if hasattr(arr.dtype, "__name__") else str(arr.dtype)
55        ),
56        "n_elements": array_n_elements(arr),
57    }

get metadata for a numpy array

def load_array( arr: Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]], array_mode: Optional[Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']] = None) -> Any:
168def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any:
169    """load a json-serialized array, infer the mode if not specified"""
170    # return arr if its already a numpy array
171    if isinstance(arr, np.ndarray) and array_mode is None:
172        return arr
173
174    # try to infer the array_mode
175    array_mode_inferred: ArrayMode = infer_array_mode(arr)
176    if array_mode is None:
177        array_mode = array_mode_inferred
178    elif array_mode != array_mode_inferred:
179        warnings.warn(
180            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
181        )
182
183    # actually load the array
184    if array_mode == "array_list_meta":
185        assert isinstance(
186            arr, typing.Mapping
187        ), f"invalid list format: {type(arr) = }\n{arr = }"
188        data = np.array(arr["data"], dtype=arr["dtype"])  # type: ignore
189        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
190            raise ValueError(f"invalid shape: {arr}")
191        return data
192
193    elif array_mode == "array_hex_meta":
194        assert isinstance(
195            arr, typing.Mapping
196        ), f"invalid list format: {type(arr) = }\n{arr = }"
197        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])  # type: ignore
198        return data.reshape(arr["shape"])  # type: ignore
199
200    elif array_mode == "array_b64_meta":
201        assert isinstance(
202            arr, typing.Mapping
203        ), f"invalid list format: {type(arr) = }\n{arr = }"
204        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])  # type: ignore
205        return data.reshape(arr["shape"])  # type: ignore
206
207    elif array_mode == "list":
208        assert isinstance(
209            arr, typing.Sequence
210        ), f"invalid list format: {type(arr) = }\n{arr = }"
211        return np.array(arr)  # type: ignore
212    elif array_mode == "external":
213        # assume ZANJ has taken care of it
214        assert isinstance(arr, typing.Mapping)
215        if "data" not in arr:
216            raise KeyError(
217                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
218            )
219        return arr["data"]
220    elif array_mode == "zero_dim":
221        assert isinstance(arr, typing.Mapping)
222        data = np.array(arr["data"])
223        if tuple(arr["shape"]) != tuple(data.shape):  # type: ignore
224            raise ValueError(f"invalid shape: {arr}")
225        return data
226    else:
227        raise ValueError(f"invalid array_mode: {array_mode}")

load a json-serialized array, infer the mode if not specified

BASE_HANDLERS = (SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='base types', desc='base types (bool, int, float, str, None)'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dictionaries', desc='dictionaries'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(list, tuple) -> list', desc='lists and tuples as lists'))
JSONitem = typing.Union[bool, int, float, str, NoneType, typing.List[typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]], typing.Dict[str, typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]]]
class JsonSerializer:
236class JsonSerializer:
237    """Json serialization class (holds configs)
238
239    # Parameters:
240    - `array_mode : ArrayMode`
241    how to write arrays
242    (defaults to `"array_list_meta"`)
243    - `error_mode : ErrorMode`
244    what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn")
245    (defaults to `"except"`)
246    - `handlers_pre : MonoTuple[SerializerHandler]`
247    handlers to use before the default handlers
248    (defaults to `tuple()`)
249    - `handlers_default : MonoTuple[SerializerHandler]`
250    default handlers to use
251    (defaults to `DEFAULT_HANDLERS`)
252    - `write_only_format : bool`
253    changes _FORMAT_KEY keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading)
254    (defaults to `False`)
255
256    # Raises:
257    - `ValueError`: on init, if `args` is not empty
258    - `SerializationException`: on `json_serialize()`, if any error occurs when trying to serialize an object and `error_mode` is set to `ErrorMode.EXCEPT"`
259
260    """
261
262    def __init__(
263        self,
264        *args,
265        array_mode: ArrayMode = "array_list_meta",
266        error_mode: ErrorMode = ErrorMode.EXCEPT,
267        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
268        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
269        write_only_format: bool = False,
270    ):
271        if len(args) > 0:
272            raise ValueError(
273                f"JsonSerializer takes no positional arguments!\n{args = }"
274            )
275
276        self.array_mode: ArrayMode = array_mode
277        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
278        self.write_only_format: bool = write_only_format
279        # join up the handlers
280        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
281            handlers_default
282        )
283
284    def json_serialize(
285        self,
286        obj: Any,
287        path: ObjectPath = tuple(),
288    ) -> JSONitem:
289        try:
290            for handler in self.handlers:
291                if handler.check(self, obj, path):
292                    output: JSONitem = handler.serialize_func(self, obj, path)
293                    if self.write_only_format:
294                        if isinstance(output, dict) and _FORMAT_KEY in output:
295                            new_fmt: JSONitem = output.pop(_FORMAT_KEY)
296                            output["__write_format__"] = new_fmt
297                    return output
298
299            raise ValueError(f"no handler found for object with {type(obj) = }")
300
301        except Exception as e:
302            if self.error_mode == "except":
303                obj_str: str = repr(obj)
304                if len(obj_str) > 1000:
305                    obj_str = obj_str[:1000] + "..."
306                raise SerializationException(
307                    f"error serializing at {path = } with last handler: '{handler.uid}'\nfrom: {e}\nobj: {obj_str}"
308                ) from e
309            elif self.error_mode == "warn":
310                warnings.warn(
311                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
312                )
313
314            return repr(obj)
315
316    def hashify(
317        self,
318        obj: Any,
319        path: ObjectPath = tuple(),
320        force: bool = True,
321    ) -> Hashableitem:
322        """try to turn any object into something hashable"""
323        data = self.json_serialize(obj, path=path)
324
325        # recursive hashify, turning dicts and lists into tuples
326        return _recursive_hashify(data, force=force)

Json serialization class (holds configs)

Parameters:

  • array_mode : ArrayMode how to write arrays (defaults to "array_list_meta")
  • error_mode : ErrorMode what to do when we can't serialize an object (will use repr as fallback if "ignore" or "warn") (defaults to "except")
  • handlers_pre : MonoTuple[SerializerHandler] handlers to use before the default handlers (defaults to tuple())
  • handlers_default : MonoTuple[SerializerHandler] default handlers to use (defaults to DEFAULT_HANDLERS)
  • write_only_format : bool changes _FORMAT_KEY keys in output to "__write_format__" (when you want to serialize something in a way that zanj won't try to recover the object when loading) (defaults to False)

Raises:

  • ValueError: on init, if args is not empty
  • SerializationException: on json_serialize(), if any error occurs when trying to serialize an object and error_mode is set to ErrorMode.EXCEPT
JsonSerializer( *args, array_mode: Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim'] = 'array_list_meta', error_mode: muutils.errormode.ErrorMode = ErrorMode.Except, handlers_pre: None = (), handlers_default: None = (SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='base types', desc='base types (bool, int, float, str, None)'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dictionaries', desc='dictionaries'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(list, tuple) -> list', desc='lists and tuples as lists'), SerializerHandler(check=<function <lambda>>, serialize_func=<function _serialize_override_serialize_func>, uid='.serialize override', desc='objects with .serialize method'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='namedtuple -> dict', desc='namedtuples as dicts'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dataclass -> dict', desc='dataclasses as dicts'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='path -> str', desc='Path objects as posix strings'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='obj -> str(obj)', desc='directly serialize objects in `SERIALIZE_DIRECT_AS_STR` to strings'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='numpy.ndarray', desc='numpy arrays'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='torch.Tensor', desc='pytorch tensors'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='pandas.DataFrame', desc='pandas DataFrames'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(set, list, tuple, Iterable) -> list', desc='sets, lists, tuples, and Iterables as 
lists'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='fallback', desc='fallback handler -- serialize object attributes and special functions as strings')), write_only_format: bool = False)
262    def __init__(
263        self,
264        *args,
265        array_mode: ArrayMode = "array_list_meta",
266        error_mode: ErrorMode = ErrorMode.EXCEPT,
267        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
268        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
269        write_only_format: bool = False,
270    ):
271        if len(args) > 0:
272            raise ValueError(
273                f"JsonSerializer takes no positional arguments!\n{args = }"
274            )
275
276        self.array_mode: ArrayMode = array_mode
277        self.error_mode: ErrorMode = ErrorMode.from_any(error_mode)
278        self.write_only_format: bool = write_only_format
279        # join up the handlers
280        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
281            handlers_default
282        )
array_mode: Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']
write_only_format: bool
handlers: None
def json_serialize( self, obj: Any, path: tuple[typing.Union[str, int], ...] = ()) -> Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]:
284    def json_serialize(
285        self,
286        obj: Any,
287        path: ObjectPath = tuple(),
288    ) -> JSONitem:
289        try:
290            for handler in self.handlers:
291                if handler.check(self, obj, path):
292                    output: JSONitem = handler.serialize_func(self, obj, path)
293                    if self.write_only_format:
294                        if isinstance(output, dict) and _FORMAT_KEY in output:
295                            new_fmt: JSONitem = output.pop(_FORMAT_KEY)
296                            output["__write_format__"] = new_fmt
297                    return output
298
299            raise ValueError(f"no handler found for object with {type(obj) = }")
300
301        except Exception as e:
302            if self.error_mode == "except":
303                obj_str: str = repr(obj)
304                if len(obj_str) > 1000:
305                    obj_str = obj_str[:1000] + "..."
306                raise SerializationException(
307                    f"error serializing at {path = } with last handler: '{handler.uid}'\nfrom: {e}\nobj: {obj_str}"
308                ) from e
309            elif self.error_mode == "warn":
310                warnings.warn(
311                    f"error serializing at {path = }, will return as string\n{obj = }\nexception = {e}"
312                )
313
314            return repr(obj)
def hashify( self, obj: Any, path: tuple[typing.Union[str, int], ...] = (), force: bool = True) -> Union[bool, int, float, str, tuple]:
316    def hashify(
317        self,
318        obj: Any,
319        path: ObjectPath = tuple(),
320        force: bool = True,
321    ) -> Hashableitem:
322        """try to turn any object into something hashable"""
323        data = self.json_serialize(obj, path=path)
324
325        # recursive hashify, turning dicts and lists into tuples
326        return _recursive_hashify(data, force=force)

try to turn any object into something hashable

def try_catch(func: Callable):
 99def try_catch(func: Callable):
100    """wraps the function to catch exceptions, returns serialized error message on exception
101
102    returned func will return normal result on success, or error message on exception
103    """
104
105    @functools.wraps(func)
106    def newfunc(*args, **kwargs):
107        try:
108            return func(*args, **kwargs)
109        except Exception as e:
110            return f"{e.__class__.__name__}: {e}"
111
112    return newfunc

wraps the function to catch exceptions, returns serialized error message on exception

returned func will return normal result on success, or error message on exception

def dc_eq( dc1, dc2, except_when_class_mismatch: bool = False, false_when_class_mismatch: bool = True, except_when_field_mismatch: bool = False) -> bool:
193def dc_eq(
194    dc1,
195    dc2,
196    except_when_class_mismatch: bool = False,
197    false_when_class_mismatch: bool = True,
198    except_when_field_mismatch: bool = False,
199) -> bool:
200    """
201    checks if two dataclasses which (might) hold numpy arrays are equal
202
203    # Parameters:
204
205    - `dc1`: the first dataclass
206    - `dc2`: the second dataclass
207    - `except_when_class_mismatch: bool`
208        if `True`, will throw `TypeError` if the classes are different.
209        if not, will return false by default or attempt to compare the fields if `false_when_class_mismatch` is `False`
210        (default: `False`)
211    - `false_when_class_mismatch: bool`
212        only relevant if `except_when_class_mismatch` is `False`.
213        if `True`, will return `False` if the classes are different.
214        if `False`, will attempt to compare the fields.
215    - `except_when_field_mismatch: bool`
216        only relevant if `except_when_class_mismatch` is `False` and `false_when_class_mismatch` is `False`.
217        if `True`, will throw `TypeError` if the fields are different.
218        (default: `True`)
219
220    # Returns:
221    - `bool`: True if the dataclasses are equal, False otherwise
222
223    # Raises:
224    - `TypeError`: if the dataclasses are of different classes
225    - `AttributeError`: if the dataclasses have different fields
226
227    # TODO: after "except when class mismatch" is False, shouldn't we then go to "field keys match"?
228    ```
229              [START]
230
231           ┌───────────┐  ┌─────────┐
232           │dc1 is dc2?├─►│ classes │
233           └──┬────────┘No│ match?  │
234      ────    │           ├─────────┤
235     (True)◄──┘Yes        │No       │Yes
236      ────                ▼         ▼
237          ┌────────────────┐ ┌────────────┐
238          │ except when    │ │ fields keys│
239          │ class mismatch?│ │ match?     │
240          ├───────────┬────┘ ├───────┬────┘
241          │Yes        │No    │No     │Yes
242          ▼           ▼      ▼       ▼
243     ───────────  ┌──────────┐  ┌────────┐
244    { raise     } │ except   │  │ field  │
245    { TypeError } │ when     │  │ values │
246     ───────────  │ field    │  │ match? │
247                  │ mismatch?│  ├────┬───┘
248                  ├───────┬──┘  │    │Yes
249                  │Yes    │No   │No  ▼
250                  ▼       ▼     │   ────
251     ───────────────     ─────  │  (True)
252    { raise         }   (False)◄┘   ────
253    { AttributeError}    ─────
254     ───────────────
255    ```
256
257    """
258    if dc1 is dc2:
259        return True
260
261    if dc1.__class__ is not dc2.__class__:
262        if except_when_class_mismatch:
263            # if the classes don't match, raise an error
264            raise TypeError(
265                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
266            )
267        if except_when_field_mismatch:
268            dc1_fields: set = set([fld.name for fld in dataclasses.fields(dc1)])
269            dc2_fields: set = set([fld.name for fld in dataclasses.fields(dc2)])
270            fields_match: bool = set(dc1_fields) == set(dc2_fields)
271            if not fields_match:
272                # if the fields match, keep going
273                raise AttributeError(
274                    f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
275                )
276        return False
277
278    return all(
279        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))
280        for fld in dataclasses.fields(dc1)
281        if fld.compare
282    )

checks if two dataclasses which (might) hold numpy arrays are equal

Parameters:

  • dc1: the first dataclass
  • dc2: the second dataclass
  • except_when_class_mismatch: bool if True, will throw TypeError if the classes are different. if not, will return false by default or attempt to compare the fields if false_when_class_mismatch is False (default: False)
  • false_when_class_mismatch: bool only relevant if except_when_class_mismatch is False. if True, will return False if the classes are different. if False, will attempt to compare the fields.
  • except_when_field_mismatch: bool only relevant if except_when_class_mismatch is False and false_when_class_mismatch is False. if True, will throw AttributeError if the fields are different. (default: False)

Returns:

  • bool: True if the dataclasses are equal, False otherwise

Raises:

  • TypeError: if the dataclasses are of different classes
  • AttributeError: if the dataclasses have different fields

TODO: after "except when class mismatch" is False, shouldn't we then go to "field keys match"?

          [START]
             ▼
       ┌───────────┐  ┌─────────┐
       │dc1 is dc2?├─►│ classes │
       └──┬────────┘No│ match?  │
  ────    │           ├─────────┤
 (True)◄──┘Yes        │No       │Yes
  ────                ▼         ▼
      ┌────────────────┐ ┌────────────┐
      │ except when    │ │ fields keys│
      │ class mismatch?│ │ match?     │
      ├───────────┬────┘ ├───────┬────┘
      │Yes        │No    │No     │Yes
      ▼           ▼      ▼       ▼
 ───────────  ┌──────────┐  ┌────────┐
{ raise     } │ except   │  │ field  │
{ TypeError } │ when     │  │ values │
 ───────────  │ field    │  │ match? │
              │ mismatch?│  ├────┬───┘
              ├───────┬──┘  │    │Yes
              │Yes    │No   │No  ▼
              ▼       ▼     │   ────
 ───────────────     ─────  │  (True)
{ raise         }   (False)◄┘   ────
{ AttributeError}    ─────
 ───────────────
@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
class SerializableDataclass(abc.ABC):
295@dataclass_transform(
296    field_specifiers=(serializable_field, SerializableField),
297)
298class SerializableDataclass(abc.ABC):
299    """Base class for serializable dataclasses
300
301    only for linting and type checking, still need to call `serializable_dataclass` decorator
302
303    # Usage:
304
305    ```python
306    @serializable_dataclass
307    class MyClass(SerializableDataclass):
308        a: int
309        b: str
310    ```
311
312    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
313
314        >>> my_obj = MyClass(a=1, b="q")
315        >>> s = json.dumps(my_obj.serialize())
316        >>> s
317        '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
318        >>> read_obj = MyClass.load(json.loads(s))
319        >>> read_obj == my_obj
320        True
321
322    This isn't too impressive on its own, but it gets more useful when you have nested classses,
323    or fields that are not json-serializable by default:
324
325    ```python
326    @serializable_dataclass
327    class NestedClass(SerializableDataclass):
328        x: str
329        y: MyClass
330        act_fun: torch.nn.Module = serializable_field(
331            default=torch.nn.ReLU(),
332            serialization_fn=lambda x: str(x),
333            deserialize_fn=lambda x: getattr(torch.nn, x)(),
334        )
335    ```
336
337    which gives us:
338
339        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
340        >>> s = json.dumps(nc.serialize())
341        >>> s
342        '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
343        >>> read_nc = NestedClass.load(json.loads(s))
344        >>> read_nc == nc
345        True
346    """
347
348    def serialize(self) -> dict[str, Any]:
349        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
350        raise NotImplementedError(
351            f"decorate {self.__class__ = } with `@serializable_dataclass`"
352        )
353
354    @classmethod
355    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
356        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
357        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")
358
359    def validate_fields_types(
360        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
361    ) -> bool:
362        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
363        return SerializableDataclass__validate_fields_types(
364            self, on_typecheck_error=on_typecheck_error
365        )
366
367    def validate_field_type(
368        self,
369        field: "SerializableField|str",
370        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
371    ) -> bool:
372        """given a dataclass, check the field matches the type hint"""
373        return SerializableDataclass__validate_field_type(
374            self, field, on_typecheck_error=on_typecheck_error
375        )
376
377    def __eq__(self, other: Any) -> bool:
378        return dc_eq(self, other)
379
380    def __hash__(self) -> int:
381        "hashes the json-serialized representation of the class"
382        return hash(json.dumps(self.serialize()))
383
384    def diff(
385        self, other: "SerializableDataclass", of_serialized: bool = False
386    ) -> dict[str, Any]:
387        """get a rich and recursive diff between two instances of a serializable dataclass
388
389        ```python
390        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
391        {'b': {'self': 2, 'other': 3}}
392        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
393        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
394        ```
395
396        # Parameters:
397         - `other : SerializableDataclass`
398           other instance to compare against
399         - `of_serialized : bool`
400           if true, compare serialized data and not raw values
401           (defaults to `False`)
402
403        # Returns:
404         - `dict[str, Any]`
405
406
407        # Raises:
408         - `ValueError` : if the instances are not of the same type
409         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
410        """
411        # match types
412        if type(self) is not type(other):
413            raise ValueError(
414                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
415            )
416
417        # initialize the diff result
418        diff_result: dict = {}
419
420        # if they are the same, return the empty diff
421        try:
422            if self == other:
423                return diff_result
424        except Exception:
425            pass
426
427        # if we are working with serialized data, serialize the instances
428        if of_serialized:
429            ser_self: dict = self.serialize()
430            ser_other: dict = other.serialize()
431
432        # for each field in the class
433        for field in dataclasses.fields(self):  # type: ignore[arg-type]
434            # skip fields that are not for comparison
435            if not field.compare:
436                continue
437
438            # get values
439            field_name: str = field.name
440            self_value = getattr(self, field_name)
441            other_value = getattr(other, field_name)
442
443            # if the values are both serializable dataclasses, recurse
444            if isinstance(self_value, SerializableDataclass) and isinstance(
445                other_value, SerializableDataclass
446            ):
447                nested_diff: dict = self_value.diff(
448                    other_value, of_serialized=of_serialized
449                )
450                if nested_diff:
451                    diff_result[field_name] = nested_diff
452            # only support serializable dataclasses
453            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
454                other_value
455            ):
456                raise ValueError("Non-serializable dataclass is not supported")
457            else:
458                # get the values of either the serialized or the actual values
459                self_value_s = ser_self[field_name] if of_serialized else self_value
460                other_value_s = ser_other[field_name] if of_serialized else other_value
461                # compare the values
462                if not array_safe_eq(self_value_s, other_value_s):
463                    diff_result[field_name] = {"self": self_value, "other": other_value}
464
465        # return the diff result
466        return diff_result
467
468    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
469        """update the instance from a nested dict, useful for configuration from command line args
470
471        # Parameters:
472            - `nested_dict : dict[str, Any]`
473                nested dict to update the instance with
474        """
475        for field in dataclasses.fields(self):  # type: ignore[arg-type]
476            field_name: str = field.name
477            self_value = getattr(self, field_name)
478
479            if field_name in nested_dict:
480                if isinstance(self_value, SerializableDataclass):
481                    self_value.update_from_nested_dict(nested_dict[field_name])
482                else:
483                    setattr(self, field_name, nested_dict[field_name])
484
485    def __copy__(self) -> "SerializableDataclass":
486        "deep copy by serializing and loading the instance to json"
487        return self.__class__.load(json.loads(json.dumps(self.serialize())))
488
489    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
490        "deep copy by serializing and loading the instance to json"
491        return self.__class__.load(json.loads(json.dumps(self.serialize())))

Base class for serializable dataclasses

only for linting and type checking, still need to call serializable_dataclass decorator

Usage:

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str

and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:

>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True

This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:

@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )

which gives us:

>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
def serialize(self) -> dict[str, typing.Any]:
348    def serialize(self) -> dict[str, Any]:
349        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
350        raise NotImplementedError(
351            f"decorate {self.__class__ = } with `@serializable_dataclass`"
352        )

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls: Type[~T], data: Union[dict[str, Any], ~T]) -> ~T:
354    @classmethod
355    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
356        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
           # placeholder stub -- replaced by the `@serializable_dataclass` decorator;
           # raising here catches the "forgot the decorator" mistake at call time
357        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
359    def validate_fields_types(
360        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
361    ) -> bool:
362        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
           # thin wrapper: delegates to the module-level helper so the same logic can
           # be reused by the decorator machinery; `on_typecheck_error` controls
           # whether a mismatch raises, warns, or is ignored
363        return SerializableDataclass__validate_fields_types(
364            self, on_typecheck_error=on_typecheck_error
365        )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

def validate_field_type( self, field: muutils.json_serialize.serializable_field.SerializableField | str, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
367    def validate_field_type(
368        self,
369        field: "SerializableField|str",
370        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
371    ) -> bool:
372        """given a dataclass, check the field matches the type hint"""
           # thin wrapper over the module-level helper; `field` may be either a
           # SerializableField object or a field name as a string
373        return SerializableDataclass__validate_field_type(
374            self, field, on_typecheck_error=on_typecheck_error
375        )

given a dataclass, check the field matches the type hint

def diff( self, other: SerializableDataclass, of_serialized: bool = False) -> dict[str, typing.Any]:
384    def diff(
385        self, other: "SerializableDataclass", of_serialized: bool = False
386    ) -> dict[str, Any]:
387        """get a rich and recursive diff between two instances of a serializable dataclass
388
389        ```python
390        >>> MyClass(a=1, b=2).diff(MyClass(a=1, b=3))
391        {'b': {'self': 2, 'other': 3}}
392        >>> NestedClass(x="q1", y=MyClass(a=1, b=2)).diff(NestedClass(x="q2", y=MyClass(a=1, b=3)))
393        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
394        ```
395
396        # Parameters:
397         - `other : SerializableDataclass`
398           other instance to compare against
399         - `of_serialized : bool`
400           if true, compare serialized data and not raw values
401           (defaults to `False`)
402
403        # Returns:
404         - `dict[str, Any]`
405
406
407        # Raises:
408         - `ValueError` : if the instances are not of the same type
409         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
410        """
411        # match types
412        if type(self) is not type(other):
413            raise ValueError(
414                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
415            )
416
417        # initialize the diff result
418        diff_result: dict = {}
419
420        # if they are the same, return the empty diff
421        try:
422            if self == other:
423                return diff_result
424        except Exception:
               # equality checks on exotic field values may raise; treat that as
               # "possibly not equal" and fall through to the field-by-field diff
425            pass
426
427        # if we are working with serialized data, serialize the instances
           # (ser_self / ser_other are only bound when `of_serialized` is True;
           # the reads near the bottom are guarded by the same flag)
428        if of_serialized:
429            ser_self: dict = self.serialize()
430            ser_other: dict = other.serialize()
431
432        # for each field in the class
433        for field in dataclasses.fields(self):  # type: ignore[arg-type]
434            # skip fields that are not for comparison
435            if not field.compare:
436                continue
437
438            # get values
439            field_name: str = field.name
440            self_value = getattr(self, field_name)
441            other_value = getattr(other, field_name)
442
443            # if the values are both serializable dataclasses, recurse
444            if isinstance(self_value, SerializableDataclass) and isinstance(
445                other_value, SerializableDataclass
446            ):
447                nested_diff: dict = self_value.diff(
448                    other_value, of_serialized=of_serialized
449                )
                   # empty nested diffs are omitted so equal subtrees leave no trace
450                if nested_diff:
451                    diff_result[field_name] = nested_diff
452            # only support serializable dataclasses
453            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
454                other_value
455            ):
456                raise ValueError("Non-serializable dataclass is not supported")
457            else:
458                # get the values of either the serialized or the actual values
459                self_value_s = ser_self[field_name] if of_serialized else self_value
460                other_value_s = ser_other[field_name] if of_serialized else other_value
461                # compare the values
                   # (array_safe_eq handles values like numpy arrays where `==` does
                   # not return a plain bool; the reported diff always shows the raw
                   # values even when the comparison used serialized data)
462                if not array_safe_eq(self_value_s, other_value_s):
463                    diff_result[field_name] = {"self": self_value, "other": other_value}
464
465        # return the diff result
466        return diff_result

get a rich and recursive diff between two instances of a serializable dataclass

>>> MyClass(a=1, b=2).diff(MyClass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=MyClass(a=1, b=2)).diff(NestedClass(x="q2", y=MyClass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}

Parameters:

  • other : SerializableDataclass — the other instance to compare against
  • of_serialized : bool — if true, compare serialized data rather than raw values (defaults to False)

Returns:

  • dict[str, Any]

Raises:

  • ValueError : if the instances are not of the same type
  • ValueError : if the instances are dataclasses.dataclass but not SerializableDataclass
def update_from_nested_dict(self, nested_dict: dict[str, typing.Any]):
468    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
469        """update the instance from a nested dict, useful for configuration from command line args
470
471        # Parameters:
472            - `nested_dict : dict[str, Any]`
473                nested dict to update the instance with
474        """
475        for field in dataclasses.fields(self):  # type: ignore[arg-type]
476            field_name: str = field.name
477            self_value = getattr(self, field_name)
478
               # keys absent from `nested_dict` leave the corresponding field untouched
479            if field_name in nested_dict:
                   # recurse into nested serializable dataclasses so a partial dict
                   # updates only the keys it names, instead of replacing the whole
                   # sub-object
480                if isinstance(self_value, SerializableDataclass):
481                    self_value.update_from_nested_dict(nested_dict[field_name])
482                else:
483                    setattr(self, field_name, nested_dict[field_name])

update the instance from a nested dict, useful for configuration from command line args

Parameters:

- `nested_dict : dict[str, Any]`
    nested dict to update the instance with