docs for muutils v0.8.6
View Source on GitHub

muutils.json_serialize.serializable_dataclass

save and load objects to and from json or compatible formats in a recoverable way

d = dataclasses.asdict(my_obj) will give you a dict, but if some fields are not json-serializable, you will get an error when you call json.dumps(d). This module provides a way around that.

Instead, you define your class:

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str

and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:

>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True

This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:

@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )

which gives us:

>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True

  1"""save and load objects to and from json or compatible formats in a recoverable way
  2
  3`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable,
  4you will get an error when you call `json.dumps(d)`. This module provides a way around that.
  5
  6Instead, you define your class:
  7
  8```python
  9@serializable_dataclass
 10class MyClass(SerializableDataclass):
 11    a: int
 12    b: str
 13```
 14
 15and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
 16
 17    >>> my_obj = MyClass(a=1, b="q")
 18    >>> s = json.dumps(my_obj.serialize())
 19    >>> s
 20    '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
 21    >>> read_obj = MyClass.load(json.loads(s))
 22    >>> read_obj == my_obj
 23    True
 24
 25This isn't too impressive on its own, but it gets more useful when you have nested classes,
 26or fields that are not json-serializable by default:
 27
 28```python
 29@serializable_dataclass
 30class NestedClass(SerializableDataclass):
 31    x: str
 32    y: MyClass
 33    act_fun: torch.nn.Module = serializable_field(
 34        default=torch.nn.ReLU(),
 35        serialization_fn=lambda x: str(x),
 36        deserialize_fn=lambda x: getattr(torch.nn, x)(),
 37    )
 38```
 39
 40which gives us:
 41
 42    >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
 43    >>> s = json.dumps(nc.serialize())
 44    >>> s
 45    '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
 46    >>> read_nc = NestedClass.load(json.loads(s))
 47    >>> read_nc == nc
 48    True
 49
 50"""
 51
 52from __future__ import annotations
 53
 54import abc
 55import dataclasses
 56import functools
 57import json
 58import sys
 59import typing
 60import warnings
 61from typing import Any, Optional, Type, TypeVar
 62
 63from muutils.errormode import ErrorMode
 64from muutils.validate_type import validate_type
 65from muutils.json_serialize.serializable_field import (
 66    SerializableField,
 67    serializable_field,
 68)
 69from muutils.json_serialize.util import _FORMAT_KEY, array_safe_eq, dc_eq
 70
 71# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access
 72
 73# this is quite horrible, but unfortunately mypy fails if we try to assign to `dataclass_transform` directly
 74# and every time we try to init a serializable dataclass it says the argument doesnt exist
 75try:
 76    try:
 77        # type ignore here for legacy versions
 78        from typing import dataclass_transform  # type: ignore[attr-defined]
 79    except Exception:
 80        from typing_extensions import dataclass_transform
 81except Exception:
 82    from muutils.json_serialize.dataclass_transform_mock import dataclass_transform
 83
 84T = TypeVar("T")
 85
 86
 87class CantGetTypeHintsWarning(UserWarning):
 88    "special warning for when we can't get type hints"
 89
 90    pass
 91
 92
 93class ZanjMissingWarning(UserWarning):
 94    "special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work"
 95
 96    pass
 97
 98
# read by `zanj_register_loader_serializable_dataclass` to decide whether to attempt the ZANJ import
# NOTE(review): nothing in the visible code ever sets this to `False` -- confirm whether that is intended
_zanj_loading_needs_import: bool = True
"flag to keep track of if we have successfully imported ZANJ"
101
102
def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    builds a `LoaderHandler` for `cls` and hands it to `register_loader_handler`;
    this allows `ZANJ().read()` to load the class and not just return plain dicts

    # Parameters:
     - `cls : typing.Type[T]`
       class to register; its `load` classmethod is what the handler calls

    # Returns:
     - the registered `LoaderHandler`, or `None` if `zanj.loading` cannot be imported

    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import]
                LoaderHandler,
                register_loader_handler,
            )
        except ImportError:
            # NOTE: if ZANJ is not installed, then failing to register the loader handler doesnt matter,
            # so we bail out quietly instead of warning
            return None

    # serialized items of this class carry this tag under `_FORMAT_KEY`
    format_tag: str = f"{cls.__name__}(SerializableDataclass)"

    # the item must be a dict whose format tag starts with ours
    is_our_format = lambda json_item, path=None, z=None: (  # type: ignore  # noqa: E731
        isinstance(json_item, dict)
        and _FORMAT_KEY in json_item
        and json_item[_FORMAT_KEY].startswith(format_tag)
    )
    # loading just defers to the class, `path` and `z` are not needed
    load_item = lambda json_item, path=None, z=None: cls.load(json_item)  # type: ignore  # noqa: E731

    handler: LoaderHandler = LoaderHandler(
        check=is_our_format,
        load=load_item,
        uid=format_tag,
        source_pckg=cls.__module__,
        desc=f"{format_tag} loader via muutils.json_serialize.serializable_dataclass",
    )
    register_loader_handler(handler)

    return handler
143
144
# module-wide defaults for the `on_typecheck_mismatch` / `on_typecheck_error` arguments
# used by `serializable_dataclass` and the `validate_field*` helpers in this module
_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN
_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT
147
148
class FieldIsNotInitOrSerializeWarning(UserWarning):
    "warning for when a field is not `init` or not `serialize`, and so is skipped during type validation"
151
152
def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """check that the value held by one field of a dataclass agrees with its type hint

    this function is written to `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       instance whose field is checked
     - `field : SerializableField | str`
       the field itself, or its name (then looked up in `self.__dataclass_fields__`)
     - `on_typecheck_error : ErrorMode`
       converted via `ErrorMode.from_any`, then used to report a missing type hint
       or an exception raised while validating
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
       `True` when the check is skipped (`assert_type` is falsy, or the field is not `init`/`serialize`),
       otherwise the result of the type check. `False` when a problem was handed to
       `on_typecheck_error.process` and that call did not raise
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # resolve a field name into the field object
    _field: SerializableField = (
        self.__dataclass_fields__[field]  # type: ignore[attr-defined]
        if isinstance(field, str)
        else field
    )

    # type checking is switched off for this field
    if not _field.assert_type:
        return True

    # fields excluded from `init` or `serialize` are skipped, but we warn about it
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not (_field.init and _field.serialize):
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # look up the hint for this field; a missing entry is reported via `on_typecheck_error`
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                f"{get_cls_type_hints(self.__class__) = }\n"
                f"Python version is {sys.version_info = }. You can:\n"
                f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
                f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                "  - use python 3.9.x or higher\n"
                "  - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # current value of the field on this instance
    value: Any = getattr(self, _field.name)

    # run the check; anything it raises is reported via `on_typecheck_error`
    try:
        checker = _field.custom_typecheck_fn
        if checker is None:
            # default validator compares the value against the hint
            return validate_type(value, field_type_hint)
        # NOTE(review): the custom checker receives the type hint, not the value -- confirm this is intended
        return checker(field_type_hint)
    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False
243
244
def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """run `self.validate_field_type` on every field of a `SerializableDataclass`

    exceptions raised for individual fields are collected, and reported together
    through `on_typecheck_error.process` once all fields have been visited

    # Returns:
     - `dict[str, bool]`
       maps each field name to whether its type is valid (`False` if validating it raised)
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    validity: dict[str, bool] = {}
    # per-field exceptions, bundled so they can be reported in one go
    failures: dict[str, Exception] = {}

    all_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
    for fld in all_fields:
        try:
            field_ok: bool = self.validate_field_type(fld, on_typecheck_error)
        except Exception as err:
            failures[fld.name] = err
            field_ok = False
        validity[fld.name] = field_ok

    # report the bundled exceptions, if any
    if failures:
        field_names: list = [f.name for f in all_fields]
        details: str = "\n\t".join(f"{k}:\t{v}" for k, v in failures.items())
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {field_names}"
            + "\n\t"
            + details,
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so chain from the first collected exception
            except_from=next(iter(failures.values())),
        )

    return validity
280
281
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`

    thin wrapper: `True` only if every entry returned by
    `SerializableDataclass__validate_fields_types__dict` is truthy
    """
    per_field: dict[str, bool] = SerializableDataclass__validate_fields_types__dict(
        self, on_typecheck_error=on_typecheck_error
    )
    return all(per_field.values())
292
293
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        # placeholder: this base implementation always raises
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        # placeholder: this base implementation always raises
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields, delegating to `SerializableDataclass__validate_fields_types`"""
        all_valid: bool = SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )
        return all_valid

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """check a single field against its type hint, delegating to `SerializableDataclass__validate_field_type`"""
        field_valid: bool = SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )
        return field_valid

    def __eq__(self, other: Any) -> bool:
        "equality is delegated to `dc_eq`"
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        as_json: str = json.dumps(self.serialize())
        return hash(as_json)

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (the reported `"self"`/`"other"` entries are still the raw attribute values)
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`
           empty if the instances compare equal; otherwise maps field names to either a nested diff
           or a `{"self": ..., "other": ...}` dict. fields with `compare` unset are left out

        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # both sides must be exactly the same class
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        differences: dict = {}

        # shortcut for equal instances. if the comparison itself blows up,
        # that is not fatal: fall through to the field-by-field comparison
        try:
            if self == other:
                return differences
        except Exception:
            pass

        # serialized views are only needed (and only built) when requested
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        for fld in dataclasses.fields(self):  # type: ignore[arg-type]
            # fields excluded from comparison never show up in the diff
            if not fld.compare:
                continue

            name: str = fld.name
            mine = getattr(self, name)
            theirs = getattr(other, name)

            # nested serializable dataclasses: recurse, keep only non-empty diffs
            if isinstance(mine, SerializableDataclass) and isinstance(
                theirs, SerializableDataclass
            ):
                sub_diff: dict = mine.diff(theirs, of_serialized=of_serialized)
                if sub_diff:
                    differences[name] = sub_diff
                continue

            # plain dataclasses cannot be diffed
            if dataclasses.is_dataclass(mine) and dataclasses.is_dataclass(theirs):
                raise ValueError("Non-serializable dataclass is not supported")

            # pick what to compare: serialized entries or the raw values
            if of_serialized:
                lhs, rhs = ser_self[name], ser_other[name]
            else:
                lhs, rhs = mine, theirs

            if not array_safe_eq(lhs, rhs):
                differences[name] = {"self": mine, "other": theirs}

        return differences

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        fields holding a `SerializableDataclass` are updated recursively,
        all other fields present in the dict are overwritten via `setattr`

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for fld in dataclasses.fields(self):  # type: ignore[arg-type]
            name: str = fld.name
            current = getattr(self, name)

            # keys absent from the dict leave the field alone
            if name not in nested_dict:
                continue

            replacement = nested_dict[name]
            if isinstance(current, SerializableDataclass):
                current.update_from_nested_dict(replacement)
            else:
                setattr(self, name, replacement)

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        as_json: str = json.dumps(self.serialize())
        return self.__class__.load(json.loads(as_json))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        # `memo` is not used: the copy goes through a json round trip
        as_json: str = json.dumps(self.serialize())
        return self.__class__.load(json.loads(as_json))
491
492
493# cache this so we don't have to keep getting it
494# TODO: are the types hashable? does this even make sense?
@functools.lru_cache(maxsize=128, typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    """memoized `typing.get_type_hints` for a class

    repeated calls with the same class give back the same cached dict
    """
    hints: dict[str, Any] = typing.get_type_hints(cls)
    return hints
499
500
def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    """helper function to get type hints for a class

    tries the cached lookup first and repeats it uncached if that comes back empty

    # Raises:
     - `TypeError` : if the hints are still empty, or if looking them up raised
       `TypeError`, `NameError` or `ValueError` (chained from the original exception)
    """
    try:
        # an empty cached result triggers one uncached retry
        cls_type_hints: dict[str, Any] = (
            get_cls_type_hints_cached(cls)  # type: ignore
            or typing.get_type_hints(cls)
        )
        if not cls_type_hints:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            f"  Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            f"  {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
            f"  {e = }"
        ) from e

    return cls_type_hints
520
521
class KWOnlyError(NotImplementedError):
    """kw-only dataclasses are not supported in python <3.10

    raised by `serializable_dataclass` when `kw_only=True` is passed on such a python version
    """
526
527
class FieldError(ValueError):
    """base class for field errors"""
532
533
class NotSerializableFieldException(FieldError):
    """field is not a `SerializableField`"""
538
539
class FieldSerializationError(FieldError):
    """error while serializing a field"""
544
545
class FieldLoadingError(FieldError):
    """error while loading a field"""
550
551
class FieldTypeMismatchError(FieldError, TypeError):
    """error when a field type does not match the type hint"""
556
557
558@dataclass_transform(
559    field_specifiers=(serializable_field, SerializableField),
560)
561def serializable_dataclass(
562    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
563    _cls=None,  # type: ignore
564    *,
565    init: bool = True,
566    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
567    eq: bool = True,
568    order: bool = False,
569    unsafe_hash: bool = False,
570    frozen: bool = False,
571    properties_to_serialize: Optional[list[str]] = None,
572    register_handler: bool = True,
573    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
574    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
575    methods_no_override: list[str] | None = None,
576    **kwargs,
577):
578    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**
579
580    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`
581
582    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`
583
584    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
585
586    Examines PEP 526 `__annotations__` to determine fields.
587
588    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.
589
590    ```python
591    @serializable_dataclass(kw_only=True)
592    class Myclass(SerializableDataclass):
593        a: int
594        b: str
595    ```
596    ```python
597    >>> Myclass(a=1, b="q").serialize()
598    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
599    ```
600
601    # Parameters:
602
603    - `_cls : _type_`
604       class to decorate. don't pass this arg, just use this as a decorator
605       (defaults to `None`)
606    - `init : bool`
607       whether to add an `__init__` method
608       *(passed to dataclasses.dataclass)*
609       (defaults to `True`)
610    - `repr : bool`
611       whether to add a `__repr__` method
612       *(passed to dataclasses.dataclass)*
613       (defaults to `True`)
614    - `order : bool`
615       whether to add rich comparison methods
616       *(passed to dataclasses.dataclass)*
617       (defaults to `False`)
618    - `unsafe_hash : bool`
619       whether to add a `__hash__` method
620       *(passed to dataclasses.dataclass)*
621       (defaults to `False`)
622    - `frozen : bool`
623       whether to make the class frozen
624       *(passed to dataclasses.dataclass)*
625       (defaults to `False`)
626    - `properties_to_serialize : Optional[list[str]]`
627       which properties to add to the serialized data dict
628       **SerializableDataclass only**
629       (defaults to `None`)
630    - `register_handler : bool`
631        if true, register the class with ZANJ for loading
632        **SerializableDataclass only**
633        (defaults to `True`)
634    - `on_typecheck_error : ErrorMode`
635        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
636        **SerializableDataclass only**
637    - `on_typecheck_mismatch : ErrorMode`
638        what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
639        **SerializableDataclass only**
640    - `methods_no_override : list[str]|None`
641        list of methods that should not be overridden by the decorator
642        by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
643        but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
644        **SerializableDataclass only**
645        (defaults to `None`)
646    - `**kwargs`
647        *(passed to dataclasses.dataclass)*
648
649    # Returns:
650
651    - `_type_`
652       the decorated class
653
654    # Raises:
655
656    - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.9, since `dataclasses.dataclass` does not support this
657    - `NotSerializableFieldException` : if a field is not a `SerializableField`
658    - `FieldSerializationError` : if there is an error serializing a field
659    - `AttributeError` : if a property is not found on the class
660    - `FieldLoadingError` : if there is an error loading a field
661    """
662    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
663    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
664    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)
665
666    if properties_to_serialize is None:
667        _properties_to_serialize: list = list()
668    else:
669        _properties_to_serialize = properties_to_serialize
670
671    def wrap(cls: Type[T]) -> Type[T]:
672        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
673        for field_name, field_type in cls.__annotations__.items():
674            field_value = getattr(cls, field_name, None)
675            if not isinstance(field_value, SerializableField):
676                if isinstance(field_value, dataclasses.Field):
677                    # Convert the field to a SerializableField while preserving properties
678                    field_value = SerializableField.from_Field(field_value)
679                else:
680                    # Create a new SerializableField
681                    field_value = serializable_field()
682                setattr(cls, field_name, field_value)
683
684        # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy
685        if sys.version_info < (3, 10):
686            if "kw_only" in kwargs:
687                if kwargs["kw_only"] == True:  # noqa: E712
688                    raise KWOnlyError(
689                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
690                    )
691                else:
692                    del kwargs["kw_only"]
693
694        # call `dataclasses.dataclass` to set some stuff up
695        cls = dataclasses.dataclass(  # type: ignore[call-overload]
696            cls,
697            init=init,
698            repr=repr,
699            eq=eq,
700            order=order,
701            unsafe_hash=unsafe_hash,
702            frozen=frozen,
703            **kwargs,
704        )
705
706        # copy these to the class
707        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]
708
709        # ======================================================================
710        # define `serialize` func
711        # done locally since it depends on args to the decorator
712        # ======================================================================
713        def serialize(self) -> dict[str, Any]:
714            result: dict[str, Any] = {
715                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
716            }
717            # for each field in the class
718            for field in dataclasses.fields(self):  # type: ignore[arg-type]
719                # need it to be our special SerializableField
720                if not isinstance(field, SerializableField):
721                    raise NotSerializableFieldException(
722                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
723                        f"but a {type(field)} "
724                        "this state should be inaccessible, please report this bug!"
725                    )
726
727                # try to save it
728                if field.serialize:
729                    try:
730                        # get the val
731                        value = getattr(self, field.name)
732                        # if it is a serializable dataclass, serialize it
733                        if isinstance(value, SerializableDataclass):
734                            value = value.serialize()
735                        # if the value has a serialization function, use that
736                        if hasattr(value, "serialize") and callable(value.serialize):
737                            value = value.serialize()
738                        # if the field has a serialization function, use that
739                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
740                        elif field.serialization_fn:
741                            value = field.serialization_fn(value)
742
743                        # store the value in the result
744                        result[field.name] = value
745                    except Exception as e:
746                        raise FieldSerializationError(
747                            "\n".join(
748                                [
749                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
750                                    f"{field = }",
751                                    f"{value = }",
752                                    f"{self = }",
753                                ]
754                            )
755                        ) from e
756
757            # store each property if we can get it
758            for prop in self._properties_to_serialize:
759                if hasattr(cls, prop):
760                    value = getattr(self, prop)
761                    result[prop] = value
762                else:
763                    raise AttributeError(
764                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
765                        + f"but it is in {self._properties_to_serialize = }"
766                        + f"\n{self = }"
767                    )
768
769            return result
770
771        # ======================================================================
772        # define `load` func
773        # done locally since it depends on args to the decorator
774        # ======================================================================
775        # mypy thinks this isnt a classmethod
776        @classmethod  # type: ignore[misc]
777        def load(cls, data: dict[str, Any] | T) -> Type[T]:
778            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
779            if isinstance(data, cls):
780                return data
781
782            assert isinstance(
783                data, typing.Mapping
784            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
785
786            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
787
788            # initialize dict for keeping what we will pass to the constructor
789            ctor_kwargs: dict[str, Any] = dict()
790
791            # iterate over the fields of the class
792            for field in dataclasses.fields(cls):
793                # check if the field is a SerializableField
794                assert isinstance(
795                    field, SerializableField
796                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
797
798                # check if the field is in the data and if it should be initialized
799                if (field.name in data) and field.init:
800                    # get the value, we will be processing it
801                    value: Any = data[field.name]
802
803                    # get the type hint for the field
804                    field_type_hint: Any = cls_type_hints.get(field.name, None)
805
806                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
807                    if field.deserialize_fn:
808                        # if it has a deserialization function, use that
809                        value = field.deserialize_fn(value)
810                    elif field.loading_fn:
811                        # if it has a loading function, use that
812                        value = field.loading_fn(data)
813                    elif (
814                        field_type_hint is not None
815                        and hasattr(field_type_hint, "load")
816                        and callable(field_type_hint.load)
817                    ):
818                        # if no loading function but has a type hint with a load method, use that
819                        if isinstance(value, dict):
820                            value = field_type_hint.load(value)
821                        else:
822                            raise FieldLoadingError(
823                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
824                            )
825                    else:
826                        # assume no loading needs to happen, keep `value` as-is
827                        pass
828
829                    # store the value in the constructor kwargs
830                    ctor_kwargs[field.name] = value
831
832            # create a new instance of the class with the constructor kwargs
833            output: cls = cls(**ctor_kwargs)
834
835            # validate the types of the fields if needed
836            if on_typecheck_mismatch != ErrorMode.IGNORE:
837                fields_valid: dict[str, bool] = (
838                    SerializableDataclass__validate_fields_types__dict(
839                        output,
840                        on_typecheck_error=on_typecheck_error,
841                    )
842                )
843
844                # if there are any fields that are not valid, raise an error
845                if not all(fields_valid.values()):
846                    msg: str = (
847                        f"Type mismatch in fields of {cls.__name__}:\n"
848                        + "\n".join(
849                            [
850                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
851                                for k, v in fields_valid.items()
852                                if not v
853                            ]
854                        )
855                    )
856
857                    on_typecheck_mismatch.process(
858                        msg, except_cls=FieldTypeMismatchError
859                    )
860
861            # return the new instance
862            return output
863
864        _methods_no_override: set[str]
865        if methods_no_override is None:
866            _methods_no_override = set()
867        else:
868            _methods_no_override = set(methods_no_override)
869
870        if _methods_no_override - {
871            "__eq__",
872            "serialize",
873            "load",
874            "validate_fields_types",
875        }:
876            warnings.warn(
877                f"Unknown methods in `methods_no_override`: {_methods_no_override = }"
878            )
879
880        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
881        if "serialize" not in _methods_no_override:
882            # type is `Callable[[T], dict]`
883            cls.serialize = serialize  # type: ignore[attr-defined]
884        if "load" not in _methods_no_override:
885            # type is `Callable[[dict], T]`
886            cls.load = load  # type: ignore[attr-defined]
887
888        if "validate_field_type" not in _methods_no_override:
889            # type is `Callable[[T, ErrorMode], bool]`
890            cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]
891
892        if "__eq__" not in _methods_no_override:
893            # type is `Callable[[T, T], bool]`
894            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]
895
896        # Register the class with ZANJ
897        if register_handler:
898            zanj_register_loader_serializable_dataclass(cls)
899
900        return cls
901
902    if _cls is None:
903        return wrap
904    else:
905        return wrap(_cls)

class CantGetTypeHintsWarning(builtins.UserWarning):
class CantGetTypeHintsWarning(UserWarning):
    """special warning for when we can't get type hints"""

special warning for when we can't get type hints

Inherited Members
builtins.UserWarning
UserWarning
builtins.BaseException
with_traceback
add_note
args
class ZanjMissingWarning(builtins.UserWarning):
class ZanjMissingWarning(UserWarning):
    """special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work"""

special warning for when ZANJ is missing -- register_loader_serializable_dataclass will not work

Inherited Members
builtins.UserWarning
UserWarning
builtins.BaseException
with_traceback
add_note
args
def zanj_register_loader_serializable_dataclass(cls: Type[~T]):
def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts

    builds a `LoaderHandler` which matches dicts whose `_FORMAT_KEY` entry starts with
    `"<ClassName>(SerializableDataclass)"` and loads them through `cls.load`,
    registers it, and returns it. returns `None` without registering anything
    if `zanj.loading` cannot be imported

    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import]
                LoaderHandler,
                register_loader_handler,
            )
        except ImportError:
            # NOTE: if ZANJ is not installed, then failing to register the loader handler doesnt matter,
            # so this is deliberately silent (no `ZanjMissingWarning` is emitted)
            return

    # format tag written by `serialize()`; also used as the handler's unique id
    format_tag: str = f"{cls.__name__}(SerializableDataclass)"

    handler = LoaderHandler(
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith(format_tag)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=format_tag,
        source_pckg=cls.__module__,
        desc=f"{format_tag} loader via muutils.json_serialize.serializable_dataclass",
    )
    register_loader_handler(handler)

    return handler

Register a serializable dataclass with the ZANJ import

this allows ZANJ().read() to load the class and not just return plain dicts

TODO: there is some duplication here with register_loader_handler

class FieldIsNotInitOrSerializeWarning(builtins.UserWarning):
class FieldIsNotInitOrSerializeWarning(UserWarning):
    """warning emitted when type validation skips a field because it is not `init` or not `serialize`"""

Base class for warnings generated by user code.

Inherited Members
builtins.UserWarning
UserWarning
builtins.BaseException
with_traceback
add_note
args
def SerializableDataclass__validate_field_type( self: SerializableDataclass, field: muutils.json_serialize.serializable_field.SerializableField | str, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """check that the value of a single field matches that field's type hint

    this function backs `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       instance whose field is checked
     - `field : SerializableField | str`
       the field to check; a `str` is looked up in `self.__dataclass_fields__`
     - `on_typecheck_error : ErrorMode`
       how to react when the check itself cannot be performed or throws
       (except, warn, ignore). when no exception is raised for such a failure, `False` is returned
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
       `True` if the type is correct, or if the field is skipped (`assert_type` unset,
       or the field is not both `init` and `serialize`). `False` if the type is
       incorrect or the check failed without raising
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # resolve a field name to the field object
    _field: SerializableField = (
        self.__dataclass_fields__[field]  # type: ignore[attr-defined]
        if isinstance(field, str)
        else field
    )

    # type checking is switched off for this field
    if not _field.assert_type:
        return True

    # fields which are not both `init` and `serialize` are skipped with a warning
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not (_field.init and _field.serialize):
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # look up the type hint; a missing hint is reported through `on_typecheck_error`
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                f"{get_cls_type_hints(self.__class__) = }\n"
                f"Python version is {sys.version_info = }. You can:\n"
                f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
                f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                "  - use python 3.9.x or higher\n"
                "  - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    value: Any = getattr(self, _field.name)

    try:
        # default validator unless the field supplies its own
        if _field.custom_typecheck_fn is None:
            return validate_type(value, field_type_hint)
        # NOTE(review): the custom checker is called with the type hint only,
        # not with the value -- confirm this is the intended contract
        return _field.custom_typecheck_fn(field_type_hint)
    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False

given a dataclass, check the field matches the type hint

this function is written to SerializableDataclass.validate_field_type

Parameters:

  • self : SerializableDataclass SerializableDataclass instance
  • field : SerializableField | str field to validate, will get from self.__dataclass_fields__ if an str
  • on_typecheck_error : ErrorMode what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, the function will return False (defaults to _DEFAULT_ON_TYPECHECK_ERROR)

Returns:

  • bool if the field type is correct. False if the field type is incorrect or an exception is thrown and on_typecheck_error is ignore
def SerializableDataclass__validate_fields_types__dict( self: SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> dict[str, bool]:
def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """validate every field of a `SerializableDataclass` via `self.validate_field_type`

    exceptions raised while checking individual fields are collected, the affected
    fields are marked invalid, and the collected exceptions are then reported in one
    go through `on_typecheck_error`

    # Returns:
     - `dict[str, bool]`
       maps each field name to whether its type is valid
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]

    results: dict[str, bool] = {}
    # bundled so that one failing field does not hide the others
    exceptions: dict[str, Exception] = {}

    for field in cls_fields:
        try:
            field_ok: bool = self.validate_field_type(field, on_typecheck_error)
        except Exception as e:
            field_ok = False
            exceptions[field.name] = e
        results[field.name] = field_ok

    if exceptions:
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
            + "\n\t"
            + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]),
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so chain from the first collected exception
            except_from=next(iter(exceptions.values())),
        )

    return results

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

returns a dict of field names to bools, where the bool is if the field type is valid

def SerializableDataclass__validate_fields_types( self: SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """return `True` only if every field of the `SerializableDataclass` passes type validation

    thin wrapper over `SerializableDataclass__validate_fields_types__dict` which collapses
    the per-field results into a single bool
    """
    per_field: dict[str, bool] = SerializableDataclass__validate_fields_types__dict(
        self, on_typecheck_error=on_typecheck_error
    )
    return all(per_field.values())

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
class SerializableDataclass(abc.ABC):
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        "compares the two instances using `dc_eq`"
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> MyClass(a=1, b=2).diff(MyClass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=MyClass(a=1, b=2)).diff(NestedClass(x="q2", y=MyClass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values.
           fields missing from the serialized data (e.g. `serialize=False`) are skipped.
           the reported `"self"` / `"other"` entries are always the raw values
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`
           maps the name of each differing field to either `{"self": ..., "other": ...}`
           or, for nested serializable dataclasses, to the nested diff.
           empty if no differences were found

        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        # (best effort: if the comparison itself fails, fall through to the field-wise diff)
        try:
            if self == other:
                return diff_result
        except Exception:
            pass

        # if we are working with serialized data, serialize the instances
        ser_self: dict = {}
        ser_other: dict = {}
        if of_serialized:
            ser_self = self.serialize()
            ser_other = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                if of_serialized:
                    # a field with `serialize=False` never appears in the serialized
                    # dicts; indexing them directly used to raise `KeyError`, so such
                    # fields are skipped when diffing serialized data
                    if field_name not in ser_self or field_name not in ser_other:
                        continue
                    self_value_s = ser_self[field_name]
                    other_value_s = ser_other[field_name]
                else:
                    self_value_s = self_value
                    other_value_s = other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]) -> None:
        """update the instance from a nested dict, useful for configuration from command line args

        fields whose current value is a `SerializableDataclass` are updated recursively,
        all other fields present in `nested_dict` are overwritten with `setattr`.
        keys which do not match a field name are ignored

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

Base class for serializable dataclasses

only for linting and type checking, still need to call serializable_dataclass decorator

Usage:

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str

and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:

>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True

This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:

@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: str(x),
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )

which gives us:

>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
def serialize(self) -> dict[str, typing.Any]:
348    def serialize(self) -> dict[str, Any]:
349        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
350        raise NotImplementedError(
351            f"decorate {self.__class__ = } with `@serializable_dataclass`"
352        )

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls: Type[~T], data: Union[dict[str, Any], ~T]) -> ~T:
354    @classmethod
355    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
356        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
357        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
359    def validate_fields_types(
360        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
361    ) -> bool:
362        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
363        return SerializableDataclass__validate_fields_types(
364            self, on_typecheck_error=on_typecheck_error
365        )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

def validate_field_type( self, field: muutils.json_serialize.serializable_field.SerializableField | str, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
367    def validate_field_type(
368        self,
369        field: "SerializableField|str",
370        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
371    ) -> bool:
372        """given a dataclass, check the field matches the type hint"""
373        return SerializableDataclass__validate_field_type(
374            self, field, on_typecheck_error=on_typecheck_error
375        )

given a dataclass, check the field matches the type hint

def diff( self, other: SerializableDataclass, of_serialized: bool = False) -> dict[str, typing.Any]:
384    def diff(
385        self, other: "SerializableDataclass", of_serialized: bool = False
386    ) -> dict[str, Any]:
387        """get a rich and recursive diff between two instances of a serializable dataclass
388
389        ```python
390        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
391        {'b': {'self': 2, 'other': 3}}
392        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
393        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
394        ```
395
396        # Parameters:
397         - `other : SerializableDataclass`
398           other instance to compare against
399         - `of_serialized : bool`
400           if true, compare serialized data and not raw values
401           (defaults to `False`)
402
403        # Returns:
404         - `dict[str, Any]`
405
406
407        # Raises:
408         - `ValueError` : if the instances are not of the same type
409         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
410        """
411        # match types
412        if type(self) is not type(other):
413            raise ValueError(
414                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
415            )
416
417        # initialize the diff result
418        diff_result: dict = {}
419
420        # if they are the same, return the empty diff
421        try:
422            if self == other:
423                return diff_result
424        except Exception:
425            pass
426
427        # if we are working with serialized data, serialize the instances
428        if of_serialized:
429            ser_self: dict = self.serialize()
430            ser_other: dict = other.serialize()
431
432        # for each field in the class
433        for field in dataclasses.fields(self):  # type: ignore[arg-type]
434            # skip fields that are not for comparison
435            if not field.compare:
436                continue
437
438            # get values
439            field_name: str = field.name
440            self_value = getattr(self, field_name)
441            other_value = getattr(other, field_name)
442
443            # if the values are both serializable dataclasses, recurse
444            if isinstance(self_value, SerializableDataclass) and isinstance(
445                other_value, SerializableDataclass
446            ):
447                nested_diff: dict = self_value.diff(
448                    other_value, of_serialized=of_serialized
449                )
450                if nested_diff:
451                    diff_result[field_name] = nested_diff
452            # only support serializable dataclasses
453            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
454                other_value
455            ):
456                raise ValueError("Non-serializable dataclass is not supported")
457            else:
458                # get the values of either the serialized or the actual values
459                self_value_s = ser_self[field_name] if of_serialized else self_value
460                other_value_s = ser_other[field_name] if of_serialized else other_value
461                # compare the values
462                if not array_safe_eq(self_value_s, other_value_s):
463                    diff_result[field_name] = {"self": self_value, "other": other_value}
464
465        # return the diff result
466        return diff_result

get a rich and recursive diff between two instances of a serializable dataclass

>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}

Parameters:

  • other : SerializableDataclass other instance to compare against
  • of_serialized : bool if true, compare serialized data and not raw values (defaults to False)

Returns:

  • dict[str, Any]

Raises:

  • ValueError : if the instances are not of the same type
  • ValueError : if the instances are dataclasses.dataclass but not SerializableDataclass
def update_from_nested_dict(self, nested_dict: dict[str, typing.Any]):
468    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
469        """update the instance from a nested dict, useful for configuration from command line args
470
471        # Parameters:
472            - `nested_dict : dict[str, Any]`
473                nested dict to update the instance with
474        """
475        for field in dataclasses.fields(self):  # type: ignore[arg-type]
476            field_name: str = field.name
477            self_value = getattr(self, field_name)
478
479            if field_name in nested_dict:
480                if isinstance(self_value, SerializableDataclass):
481                    self_value.update_from_nested_dict(nested_dict[field_name])
482                else:
483                    setattr(self, field_name, nested_dict[field_name])

update the instance from a nested dict, useful for configuration from command line args

Parameters:

  • nested_dict : dict[str, Any] nested dict to update the instance with
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[~T]) -> dict[str, typing.Any]:
496@functools.lru_cache(typed=True)
497def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
498    "cached typing.get_type_hints for a class"
499    return typing.get_type_hints(cls)

cached typing.get_type_hints for a class

def get_cls_type_hints(cls: Type[~T]) -> dict[str, typing.Any]:
502def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
503    "helper function to get type hints for a class"
504    cls_type_hints: dict[str, Any]
505    try:
506        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
507        if len(cls_type_hints) == 0:
508            cls_type_hints = typing.get_type_hints(cls)
509
510        if len(cls_type_hints) == 0:
511            raise ValueError(f"empty type hints for {cls.__name__ = }")
512    except (TypeError, NameError, ValueError) as e:
513        raise TypeError(
514            f"Cannot get type hints for {cls = }\n"
515            + f"  Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
516            + f"  {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
517            + f"  {e = }"
518        ) from e
519
520    return cls_type_hints

helper function to get type hints for a class

class KWOnlyError(builtins.NotImplementedError):
523class KWOnlyError(NotImplementedError):
524    "kw-only dataclasses are not supported in python <3.10"
525
526    pass

kw-only dataclasses are not supported in python <3.10

Inherited Members
builtins.NotImplementedError
NotImplementedError
builtins.BaseException
with_traceback
add_note
args
class FieldError(builtins.ValueError):
529class FieldError(ValueError):
530    "base class for field errors"
531
532    pass

base class for field errors

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class NotSerializableFieldException(FieldError):
535class NotSerializableFieldException(FieldError):
536    "field is not a `SerializableField`"
537
538    pass

field is not a SerializableField

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class FieldSerializationError(FieldError):
541class FieldSerializationError(FieldError):
542    "error while serializing a field"
543
544    pass

error while serializing a field

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class FieldLoadingError(FieldError):
547class FieldLoadingError(FieldError):
548    "error while loading a field"
549
550    pass

error while loading a field

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class FieldTypeMismatchError(FieldError, builtins.TypeError):
553class FieldTypeMismatchError(FieldError, TypeError):
554    "error when a field type does not match the type hint"
555
556    pass

error when a field type does not match the type hint

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
@dataclass_transform(field_specifiers=(serializable_field, SerializableField))
def serializable_dataclass( _cls=None, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, properties_to_serialize: Optional[list[str]] = None, register_handler: bool = True, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except, on_typecheck_mismatch: muutils.errormode.ErrorMode = ErrorMode.Warn, methods_no_override: list[str] | None = None, **kwargs):
559@dataclass_transform(
560    field_specifiers=(serializable_field, SerializableField),
561)
562def serializable_dataclass(
563    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
564    _cls=None,  # type: ignore
565    *,
566    init: bool = True,
567    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
568    eq: bool = True,
569    order: bool = False,
570    unsafe_hash: bool = False,
571    frozen: bool = False,
572    properties_to_serialize: Optional[list[str]] = None,
573    register_handler: bool = True,
574    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
575    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
576    methods_no_override: list[str] | None = None,
577    **kwargs,
578):
579    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**
580
581    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`
582
583    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`
584
585    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.
586
587    Examines PEP 526 `__annotations__` to determine fields.
588
589    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.
590
591    ```python
592    @serializable_dataclass(kw_only=True)
593    class Myclass(SerializableDataclass):
594        a: int
595        b: str
596    ```
597    ```python
598    >>> Myclass(a=1, b="q").serialize()
599    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
600    ```
601
602    # Parameters:
603
604    - `_cls : _type_`
605       class to decorate. don't pass this arg, just use this as a decorator
606       (defaults to `None`)
607    - `init : bool`
608       whether to add an `__init__` method
609       *(passed to dataclasses.dataclass)*
610       (defaults to `True`)
611    - `repr : bool`
612       whether to add a `__repr__` method
613       *(passed to dataclasses.dataclass)*
614       (defaults to `True`)
615    - `order : bool`
616       whether to add rich comparison methods
617       *(passed to dataclasses.dataclass)*
618       (defaults to `False`)
619    - `unsafe_hash : bool`
620       whether to add a `__hash__` method
621       *(passed to dataclasses.dataclass)*
622       (defaults to `False`)
623    - `frozen : bool`
624       whether to make the class frozen
625       *(passed to dataclasses.dataclass)*
626       (defaults to `False`)
627    - `properties_to_serialize : Optional[list[str]]`
628       which properties to add to the serialized data dict
629       **SerializableDataclass only**
630       (defaults to `None`)
631    - `register_handler : bool`
632        if true, register the class with ZANJ for loading
633        **SerializableDataclass only**
634        (defaults to `True`)
635    - `on_typecheck_error : ErrorMode`
636        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
637        **SerializableDataclass only**
638    - `on_typecheck_mismatch : ErrorMode`
639        what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
640        **SerializableDataclass only**
641    - `methods_no_override : list[str]|None`
642        list of methods that should not be overridden by the decorator
643        by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
644        but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
645        **SerializableDataclass only**
646        (defaults to `None`)
647    - `**kwargs`
648        *(passed to dataclasses.dataclass)*
649
650    # Returns:
651
652    - `_type_`
653       the decorated class
654
655    # Raises:
656
657    - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
658    - `NotSerializableFieldException` : if a field is not a `SerializableField`
659    - `FieldSerializationError` : if there is an error serializing a field
660    - `AttributeError` : if a property is not found on the class
661    - `FieldLoadingError` : if there is an error loading a field
662    """
663    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
664    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
665    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)
666
667    if properties_to_serialize is None:
668        _properties_to_serialize: list = list()
669    else:
670        _properties_to_serialize = properties_to_serialize
671
672    def wrap(cls: Type[T]) -> Type[T]:
673        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
674        for field_name, field_type in cls.__annotations__.items():
675            field_value = getattr(cls, field_name, None)
676            if not isinstance(field_value, SerializableField):
677                if isinstance(field_value, dataclasses.Field):
678                    # Convert the field to a SerializableField while preserving properties
679                    field_value = SerializableField.from_Field(field_value)
680                else:
681                    # Create a new SerializableField
682                    field_value = serializable_field()
683                setattr(cls, field_name, field_value)
684
685        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
686        if sys.version_info < (3, 10):
687            if "kw_only" in kwargs:
688                if kwargs["kw_only"] == True:  # noqa: E712
689                    raise KWOnlyError(
690                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
691                    )
692                else:
693                    del kwargs["kw_only"]
694
695        # call `dataclasses.dataclass` to set some stuff up
696        cls = dataclasses.dataclass(  # type: ignore[call-overload]
697            cls,
698            init=init,
699            repr=repr,
700            eq=eq,
701            order=order,
702            unsafe_hash=unsafe_hash,
703            frozen=frozen,
704            **kwargs,
705        )
706
707        # copy these to the class
708        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]
709
710        # ======================================================================
711        # define `serialize` func
712        # done locally since it depends on args to the decorator
713        # ======================================================================
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result
771
772        # ======================================================================
773        # define `load` func
774        # done locally since it depends on args to the decorator
775        # ======================================================================
776        # mypy thinks this isnt a classmethod
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output
864
865        _methods_no_override: set[str]
866        if methods_no_override is None:
867            _methods_no_override = set()
868        else:
869            _methods_no_override = set(methods_no_override)
870
871        if _methods_no_override - {
872            "__eq__",
873            "serialize",
874            "load",
875            "validate_fields_types",
876        }:
877            warnings.warn(
878                f"Unknown methods in `methods_no_override`: {_methods_no_override = }"
879            )
880
881        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
882        if "serialize" not in _methods_no_override:
883            # type is `Callable[[T], dict]`
884            cls.serialize = serialize  # type: ignore[attr-defined]
885        if "load" not in _methods_no_override:
886            # type is `Callable[[dict], T]`
887            cls.load = load  # type: ignore[attr-defined]
888
889        if "validate_field_type" not in _methods_no_override:
890            # type is `Callable[[T, ErrorMode], bool]`
891            cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]
892
893        if "__eq__" not in _methods_no_override:
894            # type is `Callable[[T, T], bool]`
895            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]
896
897        # Register the class with ZANJ
898        if register_handler:
899            zanj_register_loader_serializable_dataclass(cls)
900
901        return cls
902
903    if _cls is None:
904        return wrap
905    else:
906        return wrap(_cls)

decorator to make a dataclass serializable. must also make it inherit from SerializableDataclass!!

types will be validated (like pydantic) unless on_typecheck_mismatch is set to ErrorMode.IGNORE

behavior of most kwargs matches that of dataclasses.dataclass, but with some additional kwargs. any kwargs not listed here are passed to dataclasses.dataclass

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

Examines PEP 526 __annotations__ to determine fields.

If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method is added. If frozen is true, fields may not be assigned to after instance creation.

@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str
>>> Myclass(a=1, b="q").serialize()
{_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}

Parameters:

  • _cls : Type[T] | None class to decorate. Don't pass this argument explicitly; just use the function as a decorator (defaults to None)
  • init : bool whether to add an __init__ method (passed to dataclasses.dataclass) (defaults to True)
  • repr : bool whether to add a __repr__ method (passed to dataclasses.dataclass) (defaults to True)
  • order : bool whether to add rich comparison methods (passed to dataclasses.dataclass) (defaults to False)
  • unsafe_hash : bool whether to add a __hash__ method (passed to dataclasses.dataclass) (defaults to False)
  • frozen : bool whether to make the class frozen (passed to dataclasses.dataclass) (defaults to False)
  • properties_to_serialize : Optional[list[str]] which properties to add to the serialized data dict SerializableDataclass only (defaults to None)
  • register_handler : bool if true, register the class with ZANJ for loading SerializableDataclass only (defaults to True)
  • on_typecheck_error : ErrorMode what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, type validation will still return False. SerializableDataclass only (defaults to ErrorMode.Except)
  • on_typecheck_mismatch : ErrorMode what to do if a type mismatch is found (except, warn, ignore). If ignore, type validation will return True. SerializableDataclass only (defaults to ErrorMode.Warn)
  • methods_no_override : list[str]|None list of methods that should not be overridden by the decorator. By default, __eq__, serialize, load, and validate_fields_types are overridden by this function, but you can disable this if you'd rather write your own. dataclasses.dataclass might still overwrite these, and those options take precedence. SerializableDataclass only (defaults to None)
  • **kwargs (passed to dataclasses.dataclass)

Returns:

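  • Type[T] the decorated class (when the decorator is called with arguments instead of a class, a function that decorates the class is returned)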

Raises: