muutils.json_serialize.serializable_dataclass
save and load objects to and from json or compatible formats in a recoverable way
`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable, you will get an error when you call `json.dumps(d)`. This module provides a way around that.
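To make the problem concrete, here is a small stdlib-only sketch. The `Event` class and its `datetime.date` field are illustrative choices, not part of muutils: `dataclasses.asdict` succeeds, but `json.dumps` raises a `TypeError` because `date` has no default JSON encoding.

```python
import dataclasses
import json
from datetime import date


@dataclasses.dataclass
class Event:
    name: str
    when: date  # `date` objects are not json-serializable by default


ev = Event(name="launch", when=date(2024, 1, 2))
d = dataclasses.asdict(ev)  # converting to a dict works fine

try:
    json.dumps(d)  # ...but encoding that dict fails on the `date` value
    failed = False
except TypeError as e:
    failed = True
    print(e)
```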
Instead, you define your class:
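```python
from muutils.json_serialize.serializable_dataclass import SerializableDataclass, serializable_dataclass

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
```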
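and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do the following (in the printed output, `_FORMAT_KEY` stands for the format-key constant imported from `muutils.json_serialize.util`):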
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
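The round trip above hinges on tagging the serialized dict with the class name under a format key. The following is a minimal stdlib-only sketch of that idea, not the muutils implementation: `FORMAT_KEY = "__format__"` is an assumed stand-in for `_FORMAT_KEY`, and the hand-written `serialize`/`load` methods only roughly imitate what `@serializable_dataclass` generates for simple fields.

```python
import dataclasses
import json
from typing import Any

FORMAT_KEY = "__format__"  # assumption: placeholder for muutils' `_FORMAT_KEY`


@dataclasses.dataclass
class MyClass:
    a: int
    b: str

    def serialize(self) -> dict[str, Any]:
        # tag the dict with the class name, then copy each field value
        out: dict[str, Any] = {
            FORMAT_KEY: f"{type(self).__name__}(SerializableDataclass)"
        }
        for f in dataclasses.fields(self):
            out[f.name] = getattr(self, f.name)
        return out

    @classmethod
    def load(cls, data: dict[str, Any]) -> "MyClass":
        # pass only `init` fields to the constructor; the format tag is ignored
        names = {f.name for f in dataclasses.fields(cls) if f.init}
        return cls(**{k: v for k, v in data.items() if k in names})


my_obj = MyClass(a=1, b="q")
s = json.dumps(my_obj.serialize())
print(s)  # {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}
assert MyClass.load(json.loads(s)) == my_obj
```

Filtering on `f.init` when loading mirrors how the decorator only passes `init` fields that appear in the data to the constructor, so extra keys such as the format tag are simply skipped.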
This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
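```python
import torch
from muutils.json_serialize.serializable_field import serializable_field

@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        # store only the class name, e.g. "Sigmoid"
        serialization_fn=lambda x: type(x).__name__,
        # rebuild a fresh module from that class name
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
```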
which gives us:
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
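Nested dataclasses and per-field hooks can be sketched with the standard library alone. In this hypothetical example, `hooked_field` stores a serialize/deserialize pair in `dataclasses.field` metadata, loosely mirroring `serializable_field(serialization_fn=..., deserialize_fn=...)`. The helpers `to_json` and `from_json`, and the `fractions.Fraction` field (used instead of `torch.nn.Module` so the sketch has no dependencies), are all assumptions, not muutils APIs.

```python
import dataclasses
import json
import typing
from fractions import Fraction
from typing import Any


def hooked_field(*, serialize, deserialize, **kwargs):
    """hypothetical analogue of `serializable_field`: stash hooks in field metadata"""
    return dataclasses.field(
        metadata={"serialize": serialize, "deserialize": deserialize}, **kwargs
    )


def to_json(obj: Any) -> Any:
    """recursively convert a dataclass instance to plain data, applying serialize hooks
    (lists/dicts of dataclasses are not handled in this sketch)"""
    if not dataclasses.is_dataclass(obj):
        return obj
    out = {}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        hook = f.metadata.get("serialize")
        out[f.name] = hook(value) if hook else to_json(value)
    return out


def from_json(cls: type, data: dict) -> Any:
    """rebuild `cls` from `data`, applying deserialize hooks and recursing into
    fields whose resolved type hint is itself a dataclass"""
    hints = typing.get_type_hints(cls)
    kwargs = {}
    for f in dataclasses.fields(cls):
        if f.name not in data:
            continue  # fall back to the field's default
        value = data[f.name]
        hook = f.metadata.get("deserialize")
        if hook:
            value = hook(value)
        elif dataclasses.is_dataclass(hints[f.name]) and isinstance(value, dict):
            value = from_json(hints[f.name], value)
        kwargs[f.name] = value
    return cls(**kwargs)


@dataclasses.dataclass
class Inner:
    a: int
    b: str


@dataclasses.dataclass
class Outer:
    x: str
    y: Inner
    ratio: Fraction = hooked_field(
        default=Fraction(1, 2),
        serialize=str,  # Fraction(1, 3) -> "1/3"
        deserialize=Fraction,  # "1/3" -> Fraction(1, 3)
    )


o = Outer(x="q", y=Inner(a=1, b="q"), ratio=Fraction(1, 3))
s = json.dumps(to_json(o))
print(s)  # {"x": "q", "y": {"a": 1, "b": "q"}, "ratio": "1/3"}
```

Nested dataclasses are found through the resolved type hints here; muutils' loader similarly looks up each field's type hint and calls its `load` method when it has one.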
class CantGetTypeHintsWarning(UserWarning):
    "special warning for when we can't get type hints"

    pass
special warning for when we can't get type hints
Inherited Members
- builtins.UserWarning: UserWarning
- builtins.BaseException: with_traceback, add_note, args
class ZanjMissingWarning(UserWarning):
    "special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work"

    pass
special warning for when ZANJ (https://github.com/mivanit/ZANJ) is missing, in which case register_loader_serializable_dataclass will not work
Inherited Members
- builtins.UserWarning: UserWarning
- builtins.BaseException: with_traceback, add_note, args
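Because CantGetTypeHintsWarning and ZanjMissingWarning both subclass UserWarning, the standard `warnings` machinery can capture or silence them by category. The sketch below does not import muutils. It defines a stand-in class with the same base and shows both capturing it and ignoring only that category.

```python
import warnings


class CantGetTypeHintsWarning(UserWarning):
    "stand-in mirroring the muutils class: a plain UserWarning subclass"


def emit() -> None:
    warnings.warn("could not get type hints", CantGetTypeHintsWarning)


# capture: record every warning raised inside the block
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    emit()

# silence only this category; other UserWarnings are still recorded
with warnings.catch_warnings(record=True) as silenced:
    warnings.simplefilter("always")
    warnings.filterwarnings("ignore", category=CantGetTypeHintsWarning)
    emit()
    warnings.warn("unrelated", UserWarning)
```

The filter in the second block applies to the subclass only. A base `UserWarning` therefore still passes through.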
def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts


    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import]
                LoaderHandler,
                register_loader_handler,
            )
        except ImportError:
            # NOTE: if ZANJ is not installed, then failing to register the loader handler doesnt matter
            # warnings.warn(
            #     "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
            #     ZanjMissingWarning,
            # )
            return

    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)

    return lh
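Register a serializable dataclass with ZANJ's loader registry
this allows ZANJ().read() to load the class instead of just returning plain dicts. If ZANJ is not installed, the import fails and the function returns without registering anything.
TODO: there is some duplication here with register_loader_handler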
Base class for warnings generated by user code.
Inherited Members
- builtins.UserWarning: UserWarning
- builtins.BaseException: with_traceback, add_note, args
def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint

    this function is written to `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
        `SerializableDataclass` instance
     - `field : SerializableField | str`
        field to validate, will get from `self.__dataclass_fields__` if an `str`
     - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
        (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
        if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # get field
    _field: SerializableField
    if isinstance(field, str):
        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
    else:
        _field = field

    # do nothing case
    if not _field.assert_type:
        return True

    # if field is not `init` or not `serialize`, skip but warn
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not _field.init or not _field.serialize:
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # get field type hints
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                + f"{get_cls_type_hints(self.__class__) = }\n"
                + f"Python version is {sys.version_info = }. You can:\n"
                + f" - disable `assert_type`. Currently: {_field.assert_type = }\n"
                + f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                + " - use python 3.9.x or higher\n"
                + " - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # get the value
    value: Any = getattr(self, _field.name)

    # validate the type
    try:
        type_is_valid: bool
        # validate the type with the default type validator
        if _field.custom_typecheck_fn is None:
            type_is_valid = validate_type(value, field_type_hint)
        # validate the type with a custom type validator
        else:
            type_is_valid = _field.custom_typecheck_fn(field_type_hint)

        return type_is_valid

    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False
given a dataclass, check that the value of one field matches its type hint
this function implements SerializableDataclass.validate_field_type
Parameters:
- `self : SerializableDataclass`
SerializableDataclass instance
- `field : SerializableField | str`
field to validate; if a str, it is looked up in self.__dataclass_fields__
- `on_typecheck_error : ErrorMode`
what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, the function will return False (defaults to _DEFAULT_ON_TYPECHECK_ERROR)
Returns:
- `bool`
True if the field type is correct, or if the field is skipped (assert_type is False, or the field is not init or not serialize). False if the field type is incorrect, or if an exception is thrown and on_typecheck_error does not re-raise it
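The following simplified sketch shows the per-field check. It resolves the field's type hint with `typing.get_type_hints` and then tests the current value against that hint. For simplicity, the stand-in handles only plain classes, through `isinstance`. The function documented above instead delegates to muutils's `validate_type`, which also covers generics, and it routes errors through `ErrorMode`. `check_field` is a hypothetical name.

```python
import dataclasses
import typing
from typing import Any


def check_field(obj: Any, field_name: str) -> bool:
    """Hypothetical simplified check: does getattr(obj, field_name) match its hint?"""
    hint = typing.get_type_hints(type(obj))[field_name]
    value = getattr(obj, field_name)
    # only plain classes are handled; generics such as list[int] need a full validator
    if typing.get_origin(hint) is not None or not isinstance(hint, type):
        raise TypeError(f"hint not supported by this sketch: {hint!r}")
    return isinstance(value, hint)


@dataclasses.dataclass
class Point:
    x: int
    label: str


p = Point(x=1, label=2)  # dataclasses do not check types at construction time
print(check_field(p, "x"), check_field(p, "label"))  # True False
```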
def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

    returns a dict of field names to bools, where the bool is if the field type is valid
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # if except, bundle the exceptions
    results: dict[str, bool] = dict()
    exceptions: dict[str, Exception] = dict()

    # for each field in the class
    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
    for field in cls_fields:
        try:
            results[field.name] = self.validate_field_type(field, on_typecheck_error)
        except Exception as e:
            results[field.name] = False
            exceptions[field.name] = e

    # figure out what to do with the exceptions
    if len(exceptions) > 0:
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
            + "\n\t"
            + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]),
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so get a random exception from the dict
            except_from=list(exceptions.values())[0],
        )

    return results
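validate the types of all the fields on a SerializableDataclass by calling SerializableDataclass__validate_field_type for each field
returns a dict mapping field names to bools, where each bool says whether that field's type is valid. Exceptions raised while checking individual fields are collected first and then handled once, according to on_typecheck_error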
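This bundling strategy can also be sketched on its own. The sketch runs every per-field check and records one bool per field. It collects any exceptions rather than stopping at the first one. `validate_all` and `demo_check` are hypothetical. The source above can chain only one of the collected exceptions as the cause, because `ExceptionGroup` requires Python 3.11+.

```python
from typing import Callable, Dict, List, Tuple


def validate_all(
    field_names: List[str], check: Callable[[str], bool]
) -> Tuple[Dict[str, bool], Dict[str, Exception]]:
    """Run `check` on every field; a check that raises counts as invalid and is recorded."""
    results: Dict[str, bool] = {}
    exceptions: Dict[str, Exception] = {}
    for name in field_names:
        try:
            results[name] = check(name)
        except Exception as e:  # keep going so every field gets a result
            results[name] = False
            exceptions[name] = e
    return results, exceptions


def demo_check(name: str) -> bool:
    if name == "broken":
        raise RuntimeError("cannot resolve type hint")
    return name != "wrong"


results, errors = validate_all(["ok", "wrong", "broken"], demo_check)
print(results)  # {'ok': True, 'wrong': False, 'broken': False}
print(all(results.values()))  # False
print(list(errors))  # ['broken']
```

The `all(results.values())` line shows how the dict reduces to the single bool returned by SerializableDataclass__validate_fields_types.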
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
validate the types of all the fields on a SerializableDataclass by calling SerializableDataclass__validate_field_type for each field; returns True only if every field is valid
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classses,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
            other instance to compare against
         - `of_serialized : bool`
            if true, compare serialized data and not raw values
            (defaults to `False`)

        # Returns:
         - `dict[str, Any]`


        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        try:
            if self == other:
                return diff_result
        except Exception:
            pass

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))
Base class for serializable dataclasses
only for linting and type checking; you still need to apply the serializable_dataclass decorator
Usage:
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
and then you can call my_obj.serialize() to get a dict that can be serialized to json. So, you can do:
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        # store just the class name, e.g. "Sigmoid" (str(x) would give "Sigmoid()")
        serialization_fn=lambda x: type(x).__name__,
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
which gives us:
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{_FORMAT_KEY: "NestedClass(SerializableDataclass)", "x": "q", "y": {_FORMAT_KEY: "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
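The following stdlib-only sketch shows the same mechanism without torch or muutils. Per-field serialization and deserialization functions live in dataclass field metadata, each dict carries a format tag, and conversion recurses into nested dataclasses. `to_json_dict`, `from_json_dict`, the `"ser"`/`"de"` metadata keys, and the `"__format__"` key are all hypothetical and are not the muutils API.

```python
import dataclasses
import json
import typing
from datetime import date
from typing import Any, Dict

FORMAT_KEY = "__format__"  # hypothetical stand-in for muutils's _FORMAT_KEY


def to_json_dict(obj: Any) -> Dict[str, Any]:
    out: Dict[str, Any] = {FORMAT_KEY: f"{type(obj).__name__}(SerializableDataclass)"}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        if dataclasses.is_dataclass(value):
            value = to_json_dict(value)  # recurse into nested dataclasses
        elif "ser" in f.metadata:
            value = f.metadata["ser"](value)  # per-field serialization function
        out[f.name] = value
    return out


def from_json_dict(cls: type, data: Dict[str, Any]) -> Any:
    hints = typing.get_type_hints(cls)
    kwargs: Dict[str, Any] = {}
    for f in dataclasses.fields(cls):
        if f.name not in data:
            continue
        value = data[f.name]
        if "de" in f.metadata:
            value = f.metadata["de"](value)  # per-field deserialization function
        elif dataclasses.is_dataclass(hints[f.name]):
            value = from_json_dict(hints[f.name], value)  # nested class
        kwargs[f.name] = value
    return cls(**kwargs)


@dataclasses.dataclass
class Inner:
    a: int
    b: str


@dataclasses.dataclass
class Outer:
    x: str
    y: Inner
    when: date = dataclasses.field(
        default=date(2000, 1, 1),
        metadata={"ser": date.isoformat, "de": date.fromisoformat},
    )


obj = Outer(x="q", y=Inner(a=1, b="q"), when=date(2024, 5, 17))
s = json.dumps(to_json_dict(obj))
# {"__format__": "Outer(SerializableDataclass)", "x": "q", "y": {"__format__": "Inner(SerializableDataclass)", "a": 1, "b": "q"}, "when": "2024-05-17"}
restored = from_json_dict(Outer, json.loads(s))
```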
returns the instance as a dict; the implementation is added by the @serializable_dataclass decorator, and calling this base-class version raises NotImplementedError
takes in an appropriately structured dict and returns an instance of the class; the implementation is added by the @serializable_dataclass decorator, and calling this base-class version raises NotImplementedError
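For each field present in the input with init set, the decorator-generated load picks exactly one conversion, in this order:

1. The field's deserialize_fn, applied to the value.
2. Otherwise, its loading_fn, applied to the whole data dict.
3. Otherwise, the type hint's own load, applied only to dict values.
4. Otherwise, the value unchanged.

The standalone sketch below shows that precedence. `FieldSpec`, `convert`, and `Box` are hypothetical stand-ins. The sketch raises a plain ValueError where the source raises FieldLoadingError.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


@dataclass
class FieldSpec:
    """Hypothetical stand-in for the relevant SerializableField attributes."""

    name: str
    deserialize_fn: Optional[Callable[[Any], Any]] = None
    loading_fn: Optional[Callable[[Dict[str, Any]], Any]] = None


def convert(spec: FieldSpec, data: Dict[str, Any], type_hint: Any = None) -> Any:
    value = data[spec.name]
    if spec.deserialize_fn:
        return spec.deserialize_fn(value)  # sees only this field's value
    if spec.loading_fn:
        return spec.loading_fn(data)  # sees the whole serialized dict
    if type_hint is not None and callable(getattr(type_hint, "load", None)):
        if not isinstance(value, dict):
            raise ValueError(f"expected a dict to load into {type_hint}")
        return type_hint.load(value)
    return value  # no conversion needed


class Box:
    def __init__(self, v: int) -> None:
        self.v = v

    @classmethod
    def load(cls, d: Dict[str, Any]) -> "Box":
        return cls(d["v"])


data = {"n": "7", "total": 3, "box": {"v": 5}, "raw": [1, 2]}
n = convert(FieldSpec("n", deserialize_fn=int), data)  # 7
total = convert(FieldSpec("total", loading_fn=lambda d: d["total"] * 2), data)  # 6
box = convert(FieldSpec("box"), data, type_hint=Box)  # Box with v == 5
raw = convert(FieldSpec("raw"), data)  # [1, 2], unchanged
```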
validate the types of all the fields on a SerializableDataclass by calling SerializableDataclass__validate_field_type for each field; returns True only if every field is valid
check that the value of a single field (given as a SerializableField or as a field name) matches its type hint
get a rich and recursive diff between two instances of a serializable dataclass
>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
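Parameters:
- `other : SerializableDataclass`
other instance to compare against
- `of_serialized : bool`
if true, compare serialized data and not raw values (defaults to False)
Returns:
- `dict[str, Any]`
maps each differing field name to {"self": ..., "other": ...}, or to a nested diff dict when both values are SerializableDataclass instances; empty if the instances are equal
Raises:
- `ValueError`: if the instances are not of the same type
- `ValueError`: if a pair of field values are both dataclasses.dataclass instances but not both SerializableDataclass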
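The recursion in diff can be sketched with plain dataclasses. The sketch skips fields with compare=False, recurses when both values are dataclasses, and otherwise records a {"self": ..., "other": ...} pair. `simple_diff` is a hypothetical simplification. It uses `!=` where the source uses array_safe_eq. It also recurses into any dataclass instead of raising for those that are not SerializableDataclass.

```python
import dataclasses
from typing import Any, Dict


def simple_diff(a: Any, b: Any) -> Dict[str, Any]:
    if type(a) is not type(b):
        raise ValueError(f"Instances must be of the same type: {type(a)} vs {type(b)}")
    result: Dict[str, Any] = {}
    for f in dataclasses.fields(a):
        if not f.compare:
            continue  # fields excluded from comparison are ignored
        va, vb = getattr(a, f.name), getattr(b, f.name)
        if dataclasses.is_dataclass(va) and dataclasses.is_dataclass(vb):
            nested = simple_diff(va, vb)  # recurse, keep only non-empty diffs
            if nested:
                result[f.name] = nested
        elif va != vb:
            result[f.name] = {"self": va, "other": vb}
    return result


@dataclasses.dataclass
class Inner:
    a: int
    b: int


@dataclasses.dataclass
class Outer:
    x: str
    y: Inner


d = simple_diff(Outer("q1", Inner(1, 2)), Outer("q2", Inner(1, 3)))
print(d)  # {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
```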
update the instance from a nested dict, useful for configuration from command line args
Parameters:
- `nested_dict : dict[str, Any]`
nested dict to update the instance with
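The update rule works field by field. For each field whose name appears in the dict, it recurses if the current value is itself a (serializable) dataclass and otherwise overwrites the value. Keys that are not field names are ignored. The sketch below uses plain dataclasses and a hypothetical `update_nested` helper. It shows how a nested override dict, for example one parsed from command-line args, reaches a nested config. As in the source, values are assigned as-is, with no type conversion.

```python
import dataclasses
from typing import Any, Dict


def update_nested(obj: Any, nested: Dict[str, Any]) -> None:
    for f in dataclasses.fields(obj):
        if f.name not in nested:
            continue
        current = getattr(obj, f.name)
        if dataclasses.is_dataclass(current):
            update_nested(current, nested[f.name])  # recurse into the sub-config
        else:
            setattr(obj, f.name, nested[f.name])  # overwrite the leaf value


@dataclasses.dataclass
class Optim:
    lr: float = 1e-3
    momentum: float = 0.9


@dataclasses.dataclass
class Config:
    name: str = "run"
    optim: Optim = dataclasses.field(default_factory=Optim)


cfg = Config()
update_nested(cfg, {"optim": {"lr": 0.01}, "unknown_key": 1})
print(cfg)  # Config(name='run', optim=Optim(lr=0.01, momentum=0.9))
```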
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    "cached typing.get_type_hints for a class"
    return typing.get_type_hints(cls)
cached typing.get_type_hints for a class
def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    "helper function to get type hints for a class"
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if len(cls_type_hints) == 0:
            cls_type_hints = typing.get_type_hints(cls)

        if len(cls_type_hints) == 0:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f" Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f" {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
            + f" {e = }"
        ) from e

    return cls_type_hints
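helper function to get type hints for a class: tries the cached lookup first, falls back to an uncached typing.get_type_hints call if that returns nothing, and raises TypeError if the hints cannot be resolved or are empty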
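get_cls_type_hints_cached wraps typing.get_type_hints in functools.lru_cache(typed=True). This works because classes are hashable, so the class object itself serves as the cache key. Below is a minimal sketch of the same caching. `cached_hints` and the class `A` are hypothetical names.

```python
import functools
import typing
from typing import Any, Dict


@functools.lru_cache(typed=True)
def cached_hints(cls: type) -> Dict[str, Any]:
    return typing.get_type_hints(cls)


class A:
    x: int
    y: "str"  # string annotations are resolved by get_type_hints


first = cached_hints(A)
second = cached_hints(A)  # served from the cache
info = cached_hints.cache_info()
```

Because the cache returns the same dict object to every caller, callers should treat the result as read-only.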
class KWOnlyError(NotImplementedError):
    "kw-only dataclasses are not supported in python <3.9"

    pass
kw-only dataclasses are not supported in python < 3.10
Inherited Members
- builtins.NotImplementedError: NotImplementedError
- builtins.BaseException: with_traceback, add_note, args
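dataclasses.dataclass accepts kw_only only from Python 3.10 onward. On older versions, serializable_dataclass therefore raises KWOnlyError for kw_only=True and silently drops any other kw_only value before forwarding the remaining kwargs. The sketch below shows that guard. It uses a hypothetical `strip_kw_only` helper and raises plain NotImplementedError, which is KWOnlyError's base class.

```python
import sys
from typing import Any, Dict, Tuple


def strip_kw_only(
    kwargs: Dict[str, Any], version: Tuple[int, int] = tuple(sys.version_info[:2])
) -> Dict[str, Any]:
    kwargs = dict(kwargs)  # do not mutate the caller's dict
    if version < (3, 10) and "kw_only" in kwargs:
        if kwargs["kw_only"] == True:  # noqa: E712 (explicit comparison, as in the source)
            raise NotImplementedError("kw_only is not supported in python < 3.10")
        del kwargs["kw_only"]  # any other value is simply dropped
    return kwargs


print(strip_kw_only({"kw_only": False, "order": True}, version=(3, 9)))  # {'order': True}
print(strip_kw_only({"kw_only": True}, version=(3, 12)))  # {'kw_only': True}
```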
base class for field errors
Inherited Members
- builtins.ValueError: ValueError
- builtins.BaseException: with_traceback, add_note, args
class NotSerializableFieldException(FieldError):
    "field is not a `SerializableField`"

    pass
field is not a SerializableField
Inherited Members
- builtins.ValueError: ValueError
- builtins.BaseException: with_traceback, add_note, args
error while serializing a field
Inherited Members
- builtins.ValueError: ValueError
- builtins.BaseException: with_traceback, add_note, args
error while loading a field
Inherited Members
- builtins.ValueError: ValueError
- builtins.BaseException: with_traceback, add_note, args
class FieldTypeMismatchError(FieldError, TypeError):
    "error when a field type does not match the type hint"

    pass
error when a field type does not match the type hint
Inherited Members
- builtins.ValueError: ValueError
- builtins.BaseException: with_traceback, add_note, args
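According to the inherited-members listings, FieldError derives from ValueError. FieldTypeMismatchError additionally derives from TypeError, so calling code can catch it as either. The stand-in classes below reproduce only that hierarchy and do not import muutils.

```python
from typing import List


class FieldError(ValueError):
    "stand-in: base class for field errors"


class FieldTypeMismatchError(FieldError, TypeError):
    "stand-in: error when a field type does not match the type hint"


def raise_mismatch() -> None:
    raise FieldTypeMismatchError("expected int, got str")


caught_as: List[str] = []
for exc_type in (TypeError, ValueError, FieldError):
    try:
        raise_mismatch()
    except exc_type:  # each of these base classes catches the error
        caught_as.append(exc_type.__name__)

print(caught_as)  # ['TypeError', 'ValueError', 'FieldError']
```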
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    methods_no_override: list[str] | None = None,
    **kwargs,
):
    """decorator to make a dataclass serializable. **must also make it inherit from `SerializableDataclass`!!**

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs. any kwargs not listed here are passed to `dataclasses.dataclass`

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:

    - `_cls : _type_`
        class to decorate. don't pass this arg, just use this as a decorator
        (defaults to `None`)
    - `init : bool`
        whether to add an `__init__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `True`)
    - `repr : bool`
        whether to add a `__repr__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `True`)
    - `order : bool`
        whether to add rich comparison methods
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
    - `unsafe_hash : bool`
        whether to add a `__hash__` method
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
    - `frozen : bool`
        whether to make the class frozen
        *(passed to dataclasses.dataclass)*
        (defaults to `False`)
    - `properties_to_serialize : Optional[list[str]]`
        which properties to add to the serialized data dict
        **SerializableDataclass only**
        (defaults to `None`)
    - `register_handler : bool`
        if true, register the class with ZANJ for loading
        **SerializableDataclass only**
        (defaults to `True`)
    - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
        **SerializableDataclass only**
    - `on_typecheck_mismatch : ErrorMode`
        what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
        **SerializableDataclass only**
    - `methods_no_override : list[str]|None`
        list of methods that should not be overridden by the decorator
        by default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function,
        but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence
        **SerializableDataclass only**
        (defaults to `None`)
    - `**kwargs`
        *(passed to dataclasses.dataclass)*

    # Returns:

    - `_type_`
        the decorated class

    # Raises:

    - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.9, since `dataclasses.dataclass` does not support this
    - `NotSerializableFieldException` : if a field is not a `SerializableField`
    - `FieldSerializationError` : if there is an error serializing a field
    - `AttributeError` : if a property is not found on the class
    - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.9 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError(
                        "kw_only is not supported in python < 3.10, but if you pass a `False` value, it will be ignored"
                    )
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> Type[T]:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the
constructor kwargs 834 output: cls = cls(**ctor_kwargs) 835 836 # validate the types of the fields if needed 837 if on_typecheck_mismatch != ErrorMode.IGNORE: 838 fields_valid: dict[str, bool] = ( 839 SerializableDataclass__validate_fields_types__dict( 840 output, 841 on_typecheck_error=on_typecheck_error, 842 ) 843 ) 844 845 # if there are any fields that are not valid, raise an error 846 if not all(fields_valid.values()): 847 msg: str = ( 848 f"Type mismatch in fields of {cls.__name__}:\n" 849 + "\n".join( 850 [ 851 f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }" 852 for k, v in fields_valid.items() 853 if not v 854 ] 855 ) 856 ) 857 858 on_typecheck_mismatch.process( 859 msg, except_cls=FieldTypeMismatchError 860 ) 861 862 # return the new instance 863 return output 864 865 _methods_no_override: set[str] 866 if methods_no_override is None: 867 _methods_no_override = set() 868 else: 869 _methods_no_override = set(methods_no_override) 870 871 if _methods_no_override - { 872 "__eq__", 873 "serialize", 874 "load", 875 "validate_fields_types", 876 }: 877 warnings.warn( 878 f"Unknown methods in `methods_no_override`: {_methods_no_override = }" 879 ) 880 881 # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments 882 if "serialize" not in _methods_no_override: 883 # type is `Callable[[T], dict]` 884 cls.serialize = serialize # type: ignore[attr-defined] 885 if "load" not in _methods_no_override: 886 # type is `Callable[[dict], T]` 887 cls.load = load # type: ignore[attr-defined] 888 889 if "validate_field_type" not in _methods_no_override: 890 # type is `Callable[[T, ErrorMode], bool]` 891 cls.validate_fields_types = SerializableDataclass__validate_fields_types # type: ignore[attr-defined] 892 893 if "__eq__" not in _methods_no_override: 894 # type is `Callable[[T, T], bool]` 895 cls.__eq__ = lambda self, other: dc_eq(self, other) # type: 
ignore[assignment] 896 897 # Register the class with ZANJ 898 if register_handler: 899 zanj_register_loader_serializable_dataclass(cls) 900 901 return cls 902 903 if _cls is None: 904 return wrap 905 else: 906 return wrap(_cls)
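The generated `serialize` and `load` in the listing above follow a fixed precedence: when saving, a value's own `.serialize()` wins over the field's `serialization_fn`; when loading, the field's `deserialize_fn` wins over a type hint that has a `.load` method. The sketch below reproduces only that precedence with the standard library, so it runs without muutils. `ToySerializable`, `toy_field`, and `FORMAT_KEY = "__format__"` are hypothetical stand-ins for `SerializableDataclass`, `serializable_field`, and `_FORMAT_KEY` (whose real value may differ); `loading_fn`, properties, type validation, and ZANJ registration are deliberately left out.

```python
import dataclasses
import json
from typing import Any, Callable, Optional

FORMAT_KEY = "__format__"  # assumption: hypothetical stand-in for muutils' _FORMAT_KEY


def toy_field(
    serialization_fn: Optional[Callable[[Any], Any]] = None,
    deserialize_fn: Optional[Callable[[Any], Any]] = None,
    **kwargs: Any,
) -> Any:
    """Hypothetical stand-in for `serializable_field`: the hooks live in field metadata."""
    meta = {"serialization_fn": serialization_fn, "deserialize_fn": deserialize_fn}
    return dataclasses.field(metadata=meta, **kwargs)


class ToySerializable:
    """Hypothetical base class; subclasses must also be dataclasses."""

    def serialize(self) -> dict[str, Any]:
        result: dict[str, Any] = {
            FORMAT_KEY: f"{type(self).__name__}(SerializableDataclass)"
        }
        for f in dataclasses.fields(self):
            value = getattr(self, f.name)
            fn = f.metadata.get("serialization_fn")
            if hasattr(value, "serialize") and callable(value.serialize):
                value = value.serialize()  # a value that serializes itself wins
            elif fn is not None:
                value = fn(value)  # otherwise fall back to the field-level hook
            result[f.name] = value
        return result

    @classmethod
    def load(cls, data: dict[str, Any]) -> Any:
        kwargs: dict[str, Any] = {}
        for f in dataclasses.fields(cls):
            if f.name not in data or not f.init:
                continue  # absent keys fall back to the dataclass default
            value = data[f.name]
            fn = f.metadata.get("deserialize_fn")
            if fn is not None:
                value = fn(value)  # the field-level hook is tried first
            elif isinstance(value, dict) and callable(getattr(f.type, "load", None)):
                value = f.type.load(value)  # then a type hint that can load itself
            kwargs[f.name] = value
        return cls(**kwargs)


@dataclasses.dataclass
class Inner(ToySerializable):
    a: int
    b: str


@dataclasses.dataclass
class Outer(ToySerializable):
    x: str
    y: Inner
    # sets are not JSON-serializable, so convert to a sorted list and back
    tags: set = toy_field(serialization_fn=sorted, deserialize_fn=set, default_factory=set)


outer = Outer(x="q", y=Inner(a=1, b="q"), tags={"b", "a"})
s = json.dumps(outer.serialize())
print(Outer.load(json.loads(s)) == outer)  # True
```

Keeping the hooks in `dataclasses.field(metadata=...)` is just a convenient way to make the sketch standalone; the listing instead relies on a dedicated `SerializableField` subclass.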
Decorator to make a dataclass serializable. The class must also inherit from `SerializableDataclass`!!

Types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`.

The behavior of most kwargs matches that of `dataclasses.dataclass`, but there are some additional kwargs. Any kwargs not listed here are passed to `dataclasses.dataclass`.
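To make the validation idea concrete, here is a minimal standard-library sketch of checking each field of a dataclass instance against its type hint and then reacting according to an except/warn/ignore mode. `ToyErrorMode` and `check_field_types` are hypothetical names, not the muutils `ErrorMode` or `validate_type` API, and the sketch only handles plain classes such as `int` or `str`; generic hints like `list[str]` are treated as valid rather than inspected.

```python
import dataclasses
import enum
import typing
import warnings
from typing import Any


class ToyErrorMode(enum.Enum):
    """Hypothetical stand-in for muutils' ErrorMode (except / warn / ignore)."""

    EXCEPT = "except"
    WARN = "warn"
    IGNORE = "ignore"


def check_field_types(
    obj: Any, on_mismatch: ToyErrorMode = ToyErrorMode.EXCEPT
) -> dict[str, bool]:
    """Check each dataclass field value against its type hint (plain classes only)."""
    hints = typing.get_type_hints(type(obj))
    valid: dict[str, bool] = {}
    for f in dataclasses.fields(obj):
        hint = hints.get(f.name)
        if hint is None or typing.get_origin(hint) is not None or not isinstance(hint, type):
            valid[f.name] = True  # generics such as list[str] are out of scope here
        else:
            valid[f.name] = isinstance(getattr(obj, f.name), hint)
    bad = [name for name, ok in valid.items() if not ok]
    if bad and on_mismatch is not ToyErrorMode.IGNORE:
        msg = f"Type mismatch in fields of {type(obj).__name__}: {bad}"
        if on_mismatch is ToyErrorMode.EXCEPT:
            raise TypeError(msg)
        warnings.warn(msg)
    return valid


@dataclasses.dataclass
class Point:
    x: int
    label: str
    tags: list[str] = dataclasses.field(default_factory=list)


print(check_field_types(Point(1, "a")))  # every field passes
print(check_field_types(Point("1", "a"), ToyErrorMode.IGNORE))  # x is reported as False
```

Unlike this sketch, the listing skips the check entirely when the mismatch mode is `IGNORE`, rather than computing the results and discarding them.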
Returns the same class as was passed in, with dunder methods added based on the fields defined in the class. Examines PEP 526 `__annotations__` to determine fields. If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method is added. If frozen is true, fields may not be assigned to after instance creation.
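These flags are plain `dataclasses.dataclass` behavior, which `serializable_dataclass` forwards in its `dataclasses.dataclass(...)` call. The standard-library-only example below (independent of muutils) shows what `order`, `frozen`, and the `__hash__` implied by `eq=True` plus `frozen=True` add to a class.

```python
import dataclasses


@dataclasses.dataclass(order=True, frozen=True)
class Version:
    major: int
    minor: int


v1, v2 = Version(1, 2), Version(1, 10)

# order=True: instances compare like the tuples of their fields, (1, 2) < (1, 10)
print(v1 < v2)  # True

# eq=True (the default) together with frozen=True also generates __hash__
print(hash(v1) == hash(Version(1, 2)))  # True

# frozen=True: assigning to a field after creation raises FrozenInstanceError
try:
    v1.major = 3  # type: ignore[misc]
except dataclasses.FrozenInstanceError as e:
    print(type(e).__name__)  # FrozenInstanceError
```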
```python
@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str
```

    >>> Myclass(a=1, b="q").serialize()
    {_FORMAT_KEY: 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
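The example above calls the decorator with keyword arguments, while the module docstring applies it bare as `@serializable_dataclass`. Both forms work because of the `if _cls is None: return wrap` branch at the end of the listing. Below is a minimal sketch of that pattern; `toy_decorator` is a hypothetical name, and it only forwards its arguments to `dataclasses.dataclass`.

```python
import dataclasses
from typing import Any, Optional


def toy_decorator(_cls: Optional[type] = None, *, frozen: bool = False, **kwargs: Any) -> Any:
    """Hypothetical decorator that works both bare and with keyword arguments."""

    def wrap(cls: type) -> type:
        # the keyword arguments are captured from the enclosing call
        return dataclasses.dataclass(cls, frozen=frozen, **kwargs)

    if _cls is None:
        # @toy_decorator(...): only keywords were given, so return `wrap` for Python to apply
        return wrap
    # @toy_decorator: Python passed the class itself as `_cls`
    return wrap(_cls)


@toy_decorator
class Bare:
    a: int


@toy_decorator(frozen=True, order=True)
class WithArgs:
    a: int


print(Bare(1), WithArgs(1) < WithArgs(2))  # Bare(a=1) True
```

Making every option keyword-only (the bare `*`) is what keeps the two call forms unambiguous: the only thing that can arrive positionally is the class being decorated.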
Parameters:

- `_cls : _type_`
  class to decorate. Don't pass this arg; just use this as a decorator (defaults to `None`)
- `init : bool`
  whether to add an `__init__` method (passed to `dataclasses.dataclass`) (defaults to `True`)
- `repr : bool`
  whether to add a `__repr__` method (passed to `dataclasses.dataclass`) (defaults to `True`)
- `order : bool`
  whether to add rich comparison methods (passed to `dataclasses.dataclass`) (defaults to `False`)
- `unsafe_hash : bool`
  whether to add a `__hash__` method (passed to `dataclasses.dataclass`) (defaults to `False`)
- `frozen : bool`
  whether to make the class frozen (passed to `dataclasses.dataclass`) (defaults to `False`)
- `properties_to_serialize : Optional[list[str]]`
  which properties to add to the serialized data dict. SerializableDataclass only (defaults to `None`)
- `register_handler : bool`
  if true, register the class with ZANJ for loading. SerializableDataclass only (defaults to `True`)
- `on_typecheck_error : ErrorMode`
  what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return `False`. SerializableDataclass only
- `on_typecheck_mismatch : ErrorMode`
  what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`. SerializableDataclass only
- `methods_no_override : list[str]|None`
  list of methods that should not be overridden by the decorator. By default, `__eq__`, `serialize`, `load`, and `validate_fields_types` are overridden by this function, but you can disable this if you'd rather write your own. `dataclasses.dataclass` might still overwrite these, and those options take precedence. SerializableDataclass only (defaults to `None`)
- `**kwargs`
  (passed to `dataclasses.dataclass`)
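The `methods_no_override` option amounts to two steps, both visible near the end of the listing: warn about names outside the known set, then attach each generated method only if its name was not opted out. A standalone sketch of that logic follows. `attach_methods` and `OVERRIDABLE` are hypothetical names, and unlike the listing's warning (which prints the whole set), this one reports only the unknown names.

```python
import warnings
from typing import Any, Callable, Iterable, Optional

# the method names the decorator is documented to override
OVERRIDABLE = {"__eq__", "serialize", "load", "validate_fields_types"}


def attach_methods(
    cls: type,
    methods: dict[str, Callable[..., Any]],
    methods_no_override: Optional[Iterable[str]] = None,
) -> type:
    """Hypothetical helper: set each generated method unless the user opted out."""
    skip = set(methods_no_override or ())
    unknown = skip - OVERRIDABLE
    if unknown:
        # warn so that a misspelled name does not go unnoticed
        warnings.warn(f"Unknown methods in `methods_no_override`: {sorted(unknown)}")
    for name, fn in methods.items():
        # the opt-out check uses exactly the same spellings as OVERRIDABLE
        if name not in skip:
            setattr(cls, name, fn)
    return cls


class Custom:
    def serialize(self) -> dict[str, Any]:
        return {"hand": "written"}


generated = {
    "serialize": lambda self: {"generated": True},
    "load": classmethod(lambda cls, data: cls()),
}
attach_methods(Custom, generated, methods_no_override=["serialize"])
print(Custom().serialize())  # {'hand': 'written'}: the user's method survives
```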
Returns:

- `_type_`
  the decorated class
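Raises:

- `KWOnlyError` : only raised if `kw_only` is `True` and the Python version is < 3.10, since `dataclasses.dataclass` does not support `kw_only` before 3.10
- `NotSerializableFieldException` : if a field is not a `SerializableField`
- `FieldSerializationError` : if there is an error serializing a field
- `AttributeError` : if a property is not found on the class
- `FieldLoadingError` : if there is an error loading a field
- `FieldTypeMismatchError` : when loading, if a field's value does not match its type hint and `on_typecheck_mismatch` is `except`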
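`KWOnlyError` exists because `kw_only` was only added to `dataclasses.dataclass` in Python 3.10, and the listing's guard checks `sys.version_info < (3, 10)`. Below is a standalone sketch of an equivalent guard; `dataclass_compat` and `KWOnlyUnsupported` are hypothetical names. It copies `kwargs` instead of deleting from them, so a dict captured by a decorator closure is never mutated. Here `is True` behaves like the listing's `== True`: neither treats the truthy `dataclasses.MISSING` sentinel as a request for keyword-only fields.

```python
import dataclasses
import sys
from typing import Any


class KWOnlyUnsupported(Exception):
    """Hypothetical stand-in for muutils' KWOnlyError."""


def dataclass_compat(cls: type, **kwargs: Any) -> type:
    """Apply dataclasses.dataclass, refusing kw_only=True where it does not exist."""
    if sys.version_info < (3, 10) and "kw_only" in kwargs:
        if kwargs["kw_only"] is True:
            raise KWOnlyUnsupported("kw_only=True requires Python >= 3.10")
        # older dataclasses.dataclass rejects the keyword entirely, so drop a False value
        kwargs = {k: v for k, v in kwargs.items() if k != "kw_only"}
    return dataclasses.dataclass(cls, **kwargs)


class Config:
    lr: float
    epochs: int = 10


# kw_only=False is accepted on every version; on < 3.10 it is simply dropped
Config = dataclass_compat(Config, kw_only=False)
print(Config(0.1))  # Config(lr=0.1, epochs=10)
```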