Coverage for muutils\json_serialize\serializable_dataclass.py: 56%
257 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-02-14 01:33 -0700
« prev ^ index » next coverage.py v7.6.1, created at 2025-02-14 01:33 -0700
1"""save and load objects to and from json or compatible formats in a recoverable way
3`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable,
4you will get an error when you call `json.dumps(d)`. This module provides a way around that.
6Instead, you define your class:
8```python
9@serializable_dataclass
10class MyClass(SerializableDataclass):
11 a: int
12 b: str
13```
15and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
17 >>> my_obj = MyClass(a=1, b="q")
18 >>> s = json.dumps(my_obj.serialize())
19 >>> s
20 '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
21 >>> read_obj = MyClass.load(json.loads(s))
22 >>> read_obj == my_obj
23 True
25This isn't too impressive on its own, but it gets more useful when you have nested classes,
26or fields that are not json-serializable by default:
28```python
29@serializable_dataclass
30class NestedClass(SerializableDataclass):
31 x: str
32 y: MyClass
33 act_fun: torch.nn.Module = serializable_field(
34 default=torch.nn.ReLU(),
35 serialization_fn=lambda x: str(x),
36 deserialize_fn=lambda x: getattr(torch.nn, x)(),
37 )
38```
40which gives us:
42 >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
43 >>> s = json.dumps(nc.serialize())
44 >>> s
45 '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
46 >>> read_nc = NestedClass.load(json.loads(s))
47 >>> read_nc == nc
48 True
50"""
52from __future__ import annotations
54import abc
55import dataclasses
56import functools
57import json
58import sys
59import typing
60import warnings
61from typing import Any, Optional, Type, TypeVar
63from muutils.errormode import ErrorMode
64from muutils.validate_type import validate_type
65from muutils.json_serialize.serializable_field import (
66 SerializableField,
67 serializable_field,
68)
69from muutils.json_serialize.util import _FORMAT_KEY, array_safe_eq, dc_eq
71# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access
74def _dataclass_transform_mock(
75 *,
76 eq_default: bool = True,
77 order_default: bool = False,
78 kw_only_default: bool = False,
79 frozen_default: bool = False,
80 field_specifiers: tuple[type[Any] | typing.Callable[..., Any], ...] = (),
81 **kwargs: Any,
82) -> typing.Callable:
83 "mock `typing.dataclass_transform` for python <3.11"
85 def decorator(cls_or_fn):
86 cls_or_fn.__dataclass_transform__ = {
87 "eq_default": eq_default,
88 "order_default": order_default,
89 "kw_only_default": kw_only_default,
90 "frozen_default": frozen_default,
91 "field_specifiers": field_specifiers,
92 "kwargs": kwargs,
93 }
94 return cls_or_fn
96 return decorator
# `typing.dataclass_transform` is only used on python >= 3.11,
# on older versions fall back to the local mock defined above
if sys.version_info < (3, 11):
    dataclass_transform = _dataclass_transform_mock
else:
    dataclass_transform = typing.dataclass_transform


# generic type variable used in the annotations of the helpers in this module
T = TypeVar("T")
class CantGetTypeHintsWarning(UserWarning):
    "special warning for when we can't get type hints"

    pass
class ZanjMissingWarning(UserWarning):
    "special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work"

    pass
120_zanj_loading_needs_import: bool = True
121"flag to keep track of if we have successfully imported ZANJ"
def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts

    # Parameters:
     - `cls : typing.Type[T]`
       class to register, its `load` classmethod is used to load matching items

    # Returns:
    the `LoaderHandler` that was registered, or `None` if `zanj.loading` could not be imported

    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import]
                LoaderHandler,
                register_loader_handler,
            )
        except ImportError:
            # NOTE: if ZANJ is not installed, then failing to register the loader handler doesnt matter
            # warnings.warn(
            #     "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
            #     ZanjMissingWarning,
            # )
            return

    # items are matched by the prefix of their format key, which `serialize` writes as "<ClassName>(SerializableDataclass)"
    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)

    return lh
# default reactions of the type validation helpers and of the `serializable_dataclass` decorator:
# a field value not matching its type hint warns, an exception during the check itself is raised
_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN
_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT
class FieldIsNotInitOrSerializeWarning(UserWarning):
    "warning for when type validation of a field is skipped because the field is not `init` or not `serialize`"

    pass
def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint

    this function is written to `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
       `SerializableDataclass` instance
     - `field : SerializableField | str`
        field to validate, will get from `self.__dataclass_fields__` if an `str`
     - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
        if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # get field
    _field: SerializableField
    if isinstance(field, str):
        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
    else:
        _field = field

    # do nothing case
    if not _field.assert_type:
        return True

    # if field is not `init` or not `serialize`, skip but warn
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not _field.init or not _field.serialize:
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # get field type hints
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        # the class has type hints, but none for this field
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                + f"{get_cls_type_hints(self.__class__) = }\n"
                + f"Python version is {sys.version_info = }. You can:\n"
                + f" - disable `assert_type`. Currently: {_field.assert_type = }\n"
                + f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                + " - use python 3.9.x or higher\n"
                + " - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        # only reached if `on_typecheck_error` did not raise
        return False

    # get the value
    value: Any = getattr(self, _field.name)

    # validate the type
    try:
        type_is_valid: bool
        # validate the type with the default type validator
        if _field.custom_typecheck_fn is None:
            type_is_valid = validate_type(value, field_type_hint)
        # validate the type with a custom type validator
        else:
            # NOTE(review): the custom checker is called with the type hint only, `value` is not passed -- confirm this is intended
            type_is_valid = _field.custom_typecheck_fn(field_type_hint)

        return type_is_valid

    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        # only reached if `on_typecheck_error` did not raise
        return False
def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

    returns a dict of field names to bools, where the bool is if the field type is valid

    # Parameters:
     - `self : SerializableDataclass`
       instance whose fields are validated
     - `on_typecheck_error : ErrorMode`
       passed to the per-field validation, and also used to process the bundle of exceptions caught from it
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # if except, bundle the exceptions
    results: dict[str, bool] = dict()
    exceptions: dict[str, Exception] = dict()

    # for each field in the class
    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
    for field in cls_fields:
        try:
            results[field.name] = self.validate_field_type(field, on_typecheck_error)
        except Exception as e:
            # a field whose validation raised counts as invalid
            results[field.name] = False
            exceptions[field.name] = e

    # figure out what to do with the exceptions
    if exceptions:
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
            + "\n\t"
            + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]),
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so chain from the first exception that was caught
            except_from=next(iter(exceptions.values())),
        )

    return results
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

    returns `True` only if every field was reported as valid by `SerializableDataclass__validate_fields_types__dict`
    """
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, compare serialized data and not raw values
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`
           maps the name of each differing field to `{"self": ..., "other": ...}` holding the raw values,
           or to a nested diff if both values are `SerializableDataclass` instances. empty if nothing differs

        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        try:
            if self == other:
                return diff_result
        except Exception:
            # best effort shortcut only: if the comparison itself fails, fall through to the field-by-field diff
            pass

        # if we are working with serialized data, serialize the instances
        # (the dicts stay empty and unused otherwise)
        ser_self: dict = {}
        ser_other: dict = {}
        if of_serialized:
            ser_self = self.serialize()
            ser_other = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values, but always report the raw (non-serialized) ones
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        fields whose current value is a `SerializableDataclass` are updated recursively,
        all other fields present in `nested_dict` are overwritten via `setattr`

        # Parameters:
         - `nested_dict : dict[str, Any]`
            nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))
# cache this so we don't have to keep getting it
# TODO: are the types hashable? does this even make sense?
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    "cached typing.get_type_hints for a class"
    return typing.get_type_hints(cls)
def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    """helper function to get type hints for a class

    tries the cached lookup first, and retries with an uncached `typing.get_type_hints` if that came back empty

    # Raises:
     - `TypeError` : if getting the hints raises a `TypeError` or `NameError`, or if no hints are found at all
    """
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if not cls_type_hints:
            cls_type_hints = typing.get_type_hints(cls)

        if not cls_type_hints:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        # all failures, including the empty case above, are reported as a `TypeError`
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f" Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f" {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
            + f" {e = }"
        ) from e

    return cls_type_hints
class KWOnlyError(NotImplementedError):
    "kw-only dataclasses are not supported in python <3.10"

    pass
class FieldError(ValueError):
    "base class for field errors"

    pass
class NotSerializableFieldException(FieldError):
    "field is not a `SerializableField`"

    pass
class FieldSerializationError(FieldError):
    "error while serializing a field"

    pass
class FieldLoadingError(FieldError):
    "error while loading a field"

    pass
class FieldTypeMismatchError(FieldError, TypeError):
    "error when a field type does not match the type hint"

    pass
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    methods_no_override: list[str] | None = None,
    **kwargs,
):
    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:

     - `_cls : _type_`
       class to decorate. don't pass this arg, just use this as a decorator
       (defaults to `None`)
     - `init : bool`
       whether to add an `__init__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `True`)
     - `repr : bool`
       whether to add a `__repr__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `True`)
     - `eq : bool`
       *(passed to dataclasses.dataclass)*
       (defaults to `True`)
     - `order : bool`
       whether to add rich comparison methods
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
     - `unsafe_hash : bool`
       whether to add a `__hash__` method
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
     - `frozen : bool`
       whether to make the class frozen
       *(passed to dataclasses.dataclass)*
       (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
       which properties to add to the serialized data dict
       **SerializableDataclass only**
       (defaults to `None`)
     - `register_handler : bool`
        if true, register the class with ZANJ for loading
        **SerializableDataclass only**
       (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
        **SerializableDataclass only**
     - `on_typecheck_mismatch : ErrorMode`
        what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
        **SerializableDataclass only**
     - `methods_no_override : list[str]|None`
        list of methods that should not be overridden by the decorator
        by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
        but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
        **SerializableDataclass only**
       (defaults to `None`)
     - `**kwargs`
       *(passed to dataclasses.dataclass)*

    # Returns:

     - `_type_`
       the decorated class

    # Raises:

     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Replace the class attribute of every annotated field with a `SerializableField`
        # NOTE(review): a plain class-level default (e.g. `a: int = 5`) takes the `else` branch below and is
        # replaced by a bare `serializable_field()`, which looks like it drops the default -- confirm this is intended
        for field_name in cls.__annotations__:
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError(
                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
                    )
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    # bound before the `try` so the error message below can always be formatted,
                    # even if reading the attribute itself is what failed
                    value: Any = None
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f" but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> T:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        # (it is given the whole `data` mapping, not just this field's value)
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: T = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        _methods_no_override: set[str]
        if methods_no_override is None:
            _methods_no_override = set()
        else:
            _methods_no_override = set(methods_no_override)

        if _methods_no_override - {
            "__eq__",
            "serialize",
            "load",
            "validate_fields_types",
        }:
            warnings.warn(
                f"Unknown methods in `methods_no_override`: {_methods_no_override = }"
            )

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        if "serialize" not in _methods_no_override:
            # type is `Callable[[T], dict]`
            cls.serialize = serialize  # type: ignore[attr-defined]
        if "load" not in _methods_no_override:
            # type is `Callable[[dict], T]`
            cls.load = load  # type: ignore[attr-defined]

        # `validate_fields_types` is the documented name of the method assigned here;
        # the misspelled `validate_field_type` is still honored since it is what used to be checked
        if not _methods_no_override & {"validate_fields_types", "validate_field_type"}:
            # type is `Callable[[T, ErrorMode], bool]`
            cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        if "__eq__" not in _methods_no_override:
            # type is `Callable[[T, T], bool]`
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    # support both `@serializable_dataclass` and `@serializable_dataclass(...)`
    if _cls is None:
        return wrap
    else:
        return wrap(_cls)