muutils.json_serialize.serializable_dataclass
save and load objects to and from json or compatible formats in a recoverable way
`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable, you will get an error when you call `json.dumps(d)`. This module provides a way around that.
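To make the problem concrete, here is a stdlib-only sketch. The `Event` class and its fields are hypothetical. `dataclasses.asdict` builds the dict without complaint, but `json.dumps` rejects the `datetime` value still sitting inside it:

```python
import dataclasses
import datetime
import json


@dataclasses.dataclass
class Event:
    # hypothetical class: `when` holds a value json cannot encode
    name: str
    when: datetime.datetime


d = dataclasses.asdict(Event(name="launch", when=datetime.datetime(2024, 1, 1, 12, 0)))
# asdict succeeds: the dict still holds the raw datetime object

try:
    json.dumps(d)
    error_message = None
except TypeError as e:
    # by default json.dumps only encodes dicts, lists, tuples, strings, numbers, booleans and None
    error_message = str(e)

print(error_message)
```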
Instead, you define your class:
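```python
from muutils.json_serialize.serializable_dataclass import SerializableDataclass, serializable_dataclass

@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
```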
and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
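The `__format__` key is what makes the output recoverable: without it, the JSON for `MyClass` is just a dict with keys `a` and `b`. A generic reader can use the tag to pick the right class, which is what the loader handler registered by `zanj_register_loader_serializable_dataclass` does with its `startswith` check. A stdlib-only sketch of that dispatch, assuming plain dataclasses; `REGISTRY`, `fmt_tag`, `register`, `serialize`, and `load_any` are hypothetical names, not part of the module:

```python
import dataclasses
import json
from typing import Any


@dataclasses.dataclass
class MyClass:
    a: int
    b: str


# hypothetical registry mapping a format tag to the class it identifies
REGISTRY: dict[str, type] = {}


def fmt_tag(cls: type) -> str:
    # same tag format the module writes into `__format__`
    return f"{cls.__name__}(SerializableDataclass)"


def register(cls: type) -> type:
    REGISTRY[fmt_tag(cls)] = cls
    return cls


def serialize(obj: Any) -> dict[str, Any]:
    out: dict[str, Any] = {"__format__": fmt_tag(type(obj))}
    out.update({f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)})
    return out


def load_any(item: Any) -> Any:
    # like the ZANJ handler's `check`: a dict whose `__format__` starts with a known tag
    if isinstance(item, dict) and "__format__" in item:
        for tag, cls in REGISTRY.items():
            if item["__format__"].startswith(tag):
                names = {f.name for f in dataclasses.fields(cls)}
                return cls(**{k: v for k, v in item.items() if k in names})
    return item  # unrecognized: leave it as a plain dict


register(MyClass)
s = json.dumps(serialize(MyClass(a=1, b="q")))
print(s)  # {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}
restored = load_any(json.loads(s))
```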
This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
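```python
import torch
from muutils.json_serialize.serializable_field import serializable_field

@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: type(x).__name__,  # class name, e.g. "Sigmoid"
        deserialize_fn=lambda x: getattr(torch.nn, x)(),  # rebuild from the class name
    )
```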
which gives us:
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
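>>> (read_nc.x, read_nc.y) == (nc.x, nc.y)
True
>>> type(read_nc.act_fun) is type(nc.act_fun)
True
`torch.nn` modules compare by identity, so the loaded `act_fun` is an equivalent but distinct `Sigmoid` instance; compare its type rather than the instance itself.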
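The two per-field hooks above, plus the recursion into `y: MyClass`, can be sketched with plain `dataclasses` by keeping the hooks in `dataclasses.field(metadata=...)`. This is an illustrative stand-in, not how `serializable_field` actually stores them. `Inner`, `Outer`, `to_json_dict`, and `from_json_dict` are hypothetical, and a `datetime.date` field replaces the `torch` module to keep the example dependency-free. As in the module's `load`, a field's `deserialize_fn` takes precedence over loading via the type hint:

```python
import dataclasses
import datetime
import json
from typing import Any, get_type_hints


@dataclasses.dataclass
class Inner:
    a: int
    b: str


@dataclasses.dataclass
class Outer:
    x: str
    y: Inner
    # hooks kept in field metadata: a hypothetical stand-in for `serializable_field`
    when: datetime.date = dataclasses.field(
        default=datetime.date(2000, 1, 1),
        metadata={
            "serialization_fn": lambda d: d.isoformat(),
            "deserialize_fn": datetime.date.fromisoformat,
        },
    )


def to_json_dict(obj: Any) -> dict[str, Any]:
    out: dict[str, Any] = {}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        if dataclasses.is_dataclass(value):
            value = to_json_dict(value)  # nested dataclass: recurse
        elif "serialization_fn" in f.metadata:
            value = f.metadata["serialization_fn"](value)  # per-field hook
        out[f.name] = value
    return out


def from_json_dict(cls: type, data: dict[str, Any]) -> Any:
    hints = get_type_hints(cls)
    kwargs: dict[str, Any] = {}
    for f in dataclasses.fields(cls):
        if f.name not in data:
            continue  # missing key: let the constructor use the field default
        value = data[f.name]
        if "deserialize_fn" in f.metadata:
            value = f.metadata["deserialize_fn"](value)  # per-field hook wins
        elif dataclasses.is_dataclass(hints[f.name]):
            value = from_json_dict(hints[f.name], value)  # load via the type hint
        kwargs[f.name] = value
    return cls(**kwargs)


obj = Outer(x="q", y=Inner(a=1, b="q"), when=datetime.date(2024, 5, 17))
s = json.dumps(to_json_dict(obj))
print(s)  # {"x": "q", "y": {"a": 1, "b": "q"}, "when": "2024-05-17"}
```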
````python
"""save and load objects to and from json or compatible formats in a recoverable way

`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable,
you will get an error when you call `json.dumps(d)`. This module provides a way around that.

Instead, you define your class:

```python
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
```

and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

    >>> my_obj = MyClass(a=1, b="q")
    >>> s = json.dumps(my_obj.serialize())
    >>> s
    '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
    >>> read_obj = MyClass.load(json.loads(s))
    >>> read_obj == my_obj
    True

This isn't too impressive on its own, but it gets more useful when you have nested classes,
or fields that are not json-serializable by default:

```python
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: type(x).__name__,
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
```

which gives us:

    >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
    >>> s = json.dumps(nc.serialize())
    >>> s
    '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
    >>> read_nc = NestedClass.load(json.loads(s))
    >>> (read_nc.x, read_nc.y) == (nc.x, nc.y)
    True
    >>> type(read_nc.act_fun) is type(nc.act_fun)
    True

"""

from __future__ import annotations

import abc
import dataclasses
import functools
import json
import sys
import typing
import warnings
from typing import Any, Optional, Type, TypeVar

from muutils.errormode import ErrorMode
from muutils.validate_type import validate_type
from muutils.json_serialize.serializable_field import (
    SerializableField,
    serializable_field,
)
from muutils.json_serialize.util import array_safe_eq, dc_eq

# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access


def _dataclass_transform_mock(
    *,
    eq_default: bool = True,
    order_default: bool = False,
    kw_only_default: bool = False,
    frozen_default: bool = False,
    field_specifiers: tuple[type[Any] | typing.Callable[..., Any], ...] = (),
    **kwargs: Any,
) -> typing.Callable:
    "mock `typing.dataclass_transform` for python <3.11"

    def decorator(cls_or_fn):
        cls_or_fn.__dataclass_transform__ = {
            "eq_default": eq_default,
            "order_default": order_default,
            "kw_only_default": kw_only_default,
            "frozen_default": frozen_default,
            "field_specifiers": field_specifiers,
            "kwargs": kwargs,
        }
        return cls_or_fn

    return decorator


dataclass_transform: typing.Callable
if sys.version_info < (3, 11):
    dataclass_transform = _dataclass_transform_mock
else:
    dataclass_transform = typing.dataclass_transform


T = TypeVar("T")


class CantGetTypeHintsWarning(UserWarning):
    "special warning for when we can't get type hints"

    pass


class ZanjMissingWarning(UserWarning):
    "special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `zanj_register_loader_serializable_dataclass` will not work"

    pass


_zanj_loading_needs_import: bool = True
"flag to keep track of if we have successfully imported ZANJ"


def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts


    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import]
                LoaderHandler,
                register_loader_handler,
            )
        except ImportError:
            warnings.warn(
                "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
                ZanjMissingWarning,
            )
            return

    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and "__format__" in json_item
            and json_item["__format__"].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)

    return lh


_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN
_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT


class FieldIsNotInitOrSerializeWarning(UserWarning):
    pass


def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint

    this function backs `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
        `SerializableDataclass` instance
     - `field : SerializableField | str`
        field to validate, will get from `self.__dataclass_fields__` if an `str`
     - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
        (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
        if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # get field
    _field: SerializableField
    if isinstance(field, str):
        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
    else:
        _field = field

    # do nothing case
    if not _field.assert_type:
        return True

    # if field is not `init` or not `serialize`, skip but warn
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not _field.init or not _field.serialize:
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # get field type hints
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                + f"{get_cls_type_hints(self.__class__) = }\n"
                + f"Python version is {sys.version_info = }. You can:\n"
                + f"  - disable `assert_type`. Currently: {_field.assert_type = }\n"
                + f"  - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                + "  - use python 3.9.x or higher\n"
                + "  - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # get the value
    value: Any = getattr(self, _field.name)

    # validate the type
    try:
        type_is_valid: bool
        # validate the type with the default type validator
        if _field.custom_typecheck_fn is None:
            type_is_valid = validate_type(value, field_type_hint)
        # validate the type with a custom type validator
        else:
            type_is_valid = _field.custom_typecheck_fn(field_type_hint)

        return type_is_valid

    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False


def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

    returns a dict of field names to bools, where the bool is if the field type is valid
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # if except, bundle the exceptions
    results: dict[str, bool] = dict()
    exceptions: dict[str, Exception] = dict()

    # for each field in the class
    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
    for field in cls_fields:
        try:
            results[field.name] = self.validate_field_type(field, on_typecheck_error)
        except Exception as e:
            results[field.name] = False
            exceptions[field.name] = e

    # figure out what to do with the exceptions
    if len(exceptions) > 0:
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
            + "\n\t"
            + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]),
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so chain from the first exception in the dict
            except_from=list(exceptions.values())[0],
        )

    return results


def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )


@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: type(x).__name__,
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> (read_nc.x, read_nc.y) == (nc.x, nc.y)
        True
        >>> type(read_nc.act_fun) is type(nc.act_fun)
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> MyClass(a=1, b=2).diff(MyClass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=MyClass(a=1, b=2)).diff(NestedClass(x="q2", y=MyClass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
            other instance to compare against
         - `of_serialized : bool`
            if true, compare serialized data and not raw values
            (defaults to `False`)

        # Returns:
         - `dict[str, Any]`
            maps each differing field name to `{'self': ..., 'other': ...}`, or to a nested diff for
            `SerializableDataclass`-valued fields. empty if the instances are equal

        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if a pair of field values are `dataclasses.dataclass` instances but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        if self == other:
            return diff_result

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
         - `nested_dict : dict[str, Any]`
            nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))


# cache this so we don't have to keep getting it
# TODO: are the types hashable? does this even make sense?
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    "cached typing.get_type_hints for a class"
    return typing.get_type_hints(cls)


def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    "helper function to get type hints for a class"
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if len(cls_type_hints) == 0:
            cls_type_hints = typing.get_type_hints(cls)

        if len(cls_type_hints) == 0:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f"  Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f"  {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
            + f"  {e = }"
        ) from e

    return cls_type_hints


class KWOnlyError(NotImplementedError):
    "kw-only dataclasses are not supported in python <3.10"

    pass


class FieldError(ValueError):
    "base class for field errors"

    pass


class NotSerializableFieldException(FieldError):
    "field is not a `SerializableField`"

    pass


class FieldSerializationError(FieldError):
    "error while serializing a field"

    pass


class FieldLoadingError(FieldError):
    "error while loading a field"

    pass


class FieldTypeMismatchError(FieldError, TypeError):
    "error when a field type does not match the type hint"

    pass


@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    **kwargs,
):
    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`

    types will be validated on `load` (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:
     - `_cls : _type_`
        class to decorate. don't pass this arg, just use this as a decorator
        (defaults to `None`)
     - `init : bool`
        (defaults to `True`)
     - `repr : bool`
        (defaults to `True`)
     - `eq : bool`
        (defaults to `True`)
     - `order : bool`
        (defaults to `False`)
     - `unsafe_hash : bool`
        (defaults to `False`)
     - `frozen : bool`
        (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
        **SerializableDataclass only:** which properties to add to the serialized data dict
        (defaults to `None`)
     - `register_handler : bool`
        **SerializableDataclass only:** if true, register the class with ZANJ for loading
        (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
        **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
     - `on_typecheck_mismatch : ErrorMode`
        **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`

    # Returns:
     - `_type_`
        the decorated class

    # Raises:
     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # replace each annotated class attribute with a SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                elif field_name in cls.__dict__:
                    # plain default value (e.g. `x: int = 5`): keep it as the field default
                    field_value = serializable_field(default=field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError("kw_only is not supported in python <3.10")
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)}. "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        elif hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__} "
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isn't a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> T:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: cls = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        # mypy says "Type cannot be declared in assignment to non-self attribute" so that's why I've left the hints in the comments
        # type is `Callable[[T], dict]`
        cls.serialize = serialize  # type: ignore[attr-defined]
        # type is `Callable[[dict], T]`
        cls.load = load  # type: ignore[attr-defined]
        # type is `Callable[[T, ErrorMode], bool]`
        cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        # type is `Callable[[T, T], bool]`
        if not hasattr(cls, "__eq__"):
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    if _cls is None:
        return wrap
    else:
        return wrap(_cls)
````
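`SerializableDataclass.update_from_nested_dict` recurses into `SerializableDataclass`-valued fields and overwrites every other matching field, ignoring keys that are not fields. Values are assigned as given, with no type conversion, so strings from a command-line parser would need converting first. A stdlib-only sketch of the same recursion; `update_nested`, `TrainConfig`, and `OptimConfig` are hypothetical, and unlike the module it recurses into any dataclass instance rather than only `SerializableDataclass` ones:

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class OptimConfig:
    lr: float = 1e-3
    momentum: float = 0.9


@dataclasses.dataclass
class TrainConfig:
    name: str = "run"
    optim: OptimConfig = dataclasses.field(default_factory=OptimConfig)


def update_nested(obj: Any, nested: dict[str, Any]) -> None:
    # recurse into dataclass-valued fields, otherwise overwrite the attribute;
    # keys that are not fields of `obj` are ignored
    for f in dataclasses.fields(obj):
        if f.name not in nested:
            continue
        current = getattr(obj, f.name)
        if dataclasses.is_dataclass(current):
            update_nested(current, nested[f.name])
        else:
            setattr(obj, f.name, nested[f.name])


cfg = TrainConfig()
update_nested(cfg, {"optim": {"lr": 0.01}, "unknown_key": 1})
print(cfg)  # TrainConfig(name='run', optim=OptimConfig(lr=0.01, momentum=0.9))
```

Only the keys present in the nested dict change; `momentum` and `name` keep their defaults.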
```python
def dataclass_transform(
    *,
    eq_default: bool = True,
    order_default: bool = False,
    kw_only_default: bool = False,
    frozen_default: bool = False,
    field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = (),
    **kwargs: Any,
) -> _IdentityCallable:
    """Decorator to mark an object as providing dataclass-like behaviour.

    The decorator can be applied to a function, class, or metaclass.

    Example usage with a decorator function::

        @dataclass_transform()
        def create_model[T](cls: type[T]) -> type[T]:
            ...
            return cls

        @create_model
        class CustomerModel:
            id: int
            name: str

    On a base class::

        @dataclass_transform()
        class ModelBase: ...

        class CustomerModel(ModelBase):
            id: int
            name: str

    On a metaclass::

        @dataclass_transform()
        class ModelMeta(type): ...

        class ModelBase(metaclass=ModelMeta): ...

        class CustomerModel(ModelBase):
            id: int
            name: str

    The ``CustomerModel`` classes defined above will
    be treated by type checkers similarly to classes created with
    ``@dataclasses.dataclass``.
    For example, type checkers will assume these classes have
    ``__init__`` methods that accept ``id`` and ``name``.

    The arguments to this decorator can be used to customize this behavior:
    - ``eq_default`` indicates whether the ``eq`` parameter is assumed to be
        ``True`` or ``False`` if it is omitted by the caller.
    - ``order_default`` indicates whether the ``order`` parameter is
        assumed to be True or False if it is omitted by the caller.
    - ``kw_only_default`` indicates whether the ``kw_only`` parameter is
        assumed to be True or False if it is omitted by the caller.
    - ``frozen_default`` indicates whether the ``frozen`` parameter is
        assumed to be True or False if it is omitted by the caller.
    - ``field_specifiers`` specifies a static list of supported classes
        or functions that describe fields, similar to ``dataclasses.field()``.
    - Arbitrary other keyword arguments are accepted in order to allow for
        possible future extensions.

    At runtime, this decorator records its arguments in the
    ``__dataclass_transform__`` attribute on the decorated object.
    It has no other runtime effect.

    See PEP 681 for more details.
    """
    def decorator(cls_or_fn):
        cls_or_fn.__dataclass_transform__ = {
            "eq_default": eq_default,
            "order_default": order_default,
            "kw_only_default": kw_only_default,
            "frozen_default": frozen_default,
            "field_specifiers": field_specifiers,
            "kwargs": kwargs,
        }
        return cls_or_fn
    return decorator
```
Decorator to mark an object as providing dataclass-like behaviour.
The decorator can be applied to a function, class, or metaclass.
Example usage with a decorator function::

    @dataclass_transform()
    def create_model[T](cls: type[T]) -> type[T]:
        ...
        return cls

    @create_model
    class CustomerModel:
        id: int
        name: str

On a base class::

    @dataclass_transform()
    class ModelBase: ...

    class CustomerModel(ModelBase):
        id: int
        name: str

On a metaclass::

    @dataclass_transform()
    class ModelMeta(type): ...

    class ModelBase(metaclass=ModelMeta): ...

    class CustomerModel(ModelBase):
        id: int
        name: str
The `CustomerModel` classes defined above will be treated by type checkers similarly to classes created with `@dataclasses.dataclass`. For example, type checkers will assume these classes have `__init__` methods that accept `id` and `name`.
The arguments to this decorator can be used to customize this behavior:
- `eq_default` indicates whether the `eq` parameter is assumed to be `True` or `False` if it is omitted by the caller.
- `order_default` indicates whether the `order` parameter is assumed to be True or False if it is omitted by the caller.
- `kw_only_default` indicates whether the `kw_only` parameter is assumed to be True or False if it is omitted by the caller.
- `frozen_default` indicates whether the `frozen` parameter is assumed to be True or False if it is omitted by the caller.
- `field_specifiers` specifies a static list of supported classes or functions that describe fields, similar to `dataclasses.field()`.
- Arbitrary other keyword arguments are accepted in order to allow for possible future extensions.
At runtime, this decorator records its arguments in the `__dataclass_transform__` attribute on the decorated object. It has no other runtime effect.
See PEP 681 for more details.
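The decorator's only runtime effect is to store its arguments on the decorated object. The sketch below makes that effect visible. It imports `typing.dataclass_transform` (available from Python 3.11) and falls back to a minimal stand-in that mirrors the source above on older interpreters. `my_field` and `create_model` are hypothetical names. `my_field` plays the role that `serializable_field` plays in this module.

```python
try:
    from typing import dataclass_transform  # available from Python 3.11
except ImportError:
    # stand-in for older interpreters, mirroring the runtime behaviour of the source above
    def dataclass_transform(
        *, eq_default=True, order_default=False, kw_only_default=False,
        field_specifiers=(), **kwargs
    ):
        def decorator(cls_or_fn):
            cls_or_fn.__dataclass_transform__ = {
                "eq_default": eq_default,
                "order_default": order_default,
                "kw_only_default": kw_only_default,
                "field_specifiers": field_specifiers,
                "kwargs": kwargs,
            }
            return cls_or_fn

        return decorator


def my_field(default=None):
    """hypothetical field specifier"""
    return default


@dataclass_transform(kw_only_default=True, field_specifiers=(my_field,))
def create_model(cls):
    # a real transform would add __init__ etc.; this one returns the class untouched
    return cls


info = create_model.__dataclass_transform__  # the recorded arguments
```

Type checkers read these arguments statically. At runtime the decorated function behaves exactly as written. In this module both `SerializableDataclass` and `serializable_dataclass` are marked with `field_specifiers=(serializable_field, SerializableField)`.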
```python
class CantGetTypeHintsWarning(UserWarning):
    "special warning for when we can't get type hints"

    pass
```
special warning for when we can't get type hints
Inherited Members
- builtins.UserWarning: UserWarning
- builtins.BaseException: with_traceback, add_note, args
```python
class ZanjMissingWarning(UserWarning):
    "special warning for when [`ZANJ`](https://github.com/mivanit/ZANJ) is missing -- `register_loader_serializable_dataclass` will not work"

    pass
```
special warning for when ZANJ (https://github.com/mivanit/ZANJ) is missing; `register_loader_serializable_dataclass` will not work
Inherited Members
- builtins.UserWarning: UserWarning
- builtins.BaseException: with_traceback, add_note, args
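A dedicated warning class lets callers distinguish "optional dependency missing" from other warnings. The registration function below imports ZANJ lazily and emits this warning instead of raising. The sketch reproduces that pattern with the standard library only. `OptionalDependencyWarning` and `try_import` are hypothetical names, and the module name in the example is made up so that the import fails.

```python
import importlib
import warnings


class OptionalDependencyWarning(UserWarning):
    """hypothetical analogue of `ZanjMissingWarning`"""


def try_import(module_name: str) -> bool:
    """Import an optional dependency; warn and return False instead of raising if it is missing."""
    try:
        importlib.import_module(module_name)
    except ImportError:
        warnings.warn(f"{module_name} not installed, skipping", OptionalDependencyWarning)
        return False
    return True


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record the warning rather than deduplicating it
    ok = try_import("surely_not_an_installed_module_xyz")
```

Because the warning is a `UserWarning` subclass, code that never uses ZANJ can silence it selectively with `warnings.filterwarnings("ignore", category=ZanjMissingWarning)`.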
```python
def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a serializable dataclass with the ZANJ import

    this allows `ZANJ().read()` to load the class and not just return plain dicts


    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import]
                LoaderHandler,
                register_loader_handler,
            )
        except ImportError:
            warnings.warn(
                "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
                ZanjMissingWarning,
            )
            return

    _format: str = f"{cls.__name__}(SerializableDataclass)"
    lh: LoaderHandler = LoaderHandler(
        check=lambda json_item, path=None, z=None: (  # type: ignore
            isinstance(json_item, dict)
            and "__format__" in json_item
            and json_item["__format__"].startswith(_format)
        ),
        load=lambda json_item, path=None, z=None: cls.load(json_item),  # type: ignore
        uid=_format,
        source_pckg=cls.__module__,
        desc=f"{_format} loader via muutils.json_serialize.serializable_dataclass",
    )

    register_loader_handler(lh)

    return lh
```
Register a serializable dataclass with the ZANJ import
this allows `ZANJ().read()` to load the class and not just return plain dicts
TODO: there is some duplication here with register_loader_handler
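The registered handler decides whether it owns a JSON item purely from the `"__format__"` tag that `serialize` writes. The sketch below isolates that `check` predicate so it can be tested without ZANJ installed. `make_format_check` is a hypothetical helper, not part of muutils or ZANJ.

```python
from typing import Any, Callable


def make_format_check(cls_name: str) -> Callable[[Any], bool]:
    """Build a predicate equivalent to the `check` lambda passed to LoaderHandler above."""
    fmt = f"{cls_name}(SerializableDataclass)"

    def check(json_item: Any) -> bool:
        return (
            isinstance(json_item, dict)
            and "__format__" in json_item
            and json_item["__format__"].startswith(fmt)
        )

    return check


is_myclass = make_format_check("MyClass")
```

Because the tag includes the opening parenthesis, a class named `MyClassV2` does not match `MyClass`. However, the tag (and the handler `uid`) is built from `cls.__name__` alone, so two classes with the same name in different modules appear to produce identical tags.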
Base class for warnings generated by user code.
Inherited Members
- builtins.UserWarning: UserWarning
- builtins.BaseException: with_traceback, add_note, args
```python
def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """given a dataclass, check the field matches the type hint

    this function is written to `SerializableDataclass.validate_field_type`

    # Parameters:
     - `self : SerializableDataclass`
        `SerializableDataclass` instance
     - `field : SerializableField | str`
        field to validate, will get from `self.__dataclass_fields__` if an `str`
     - `on_typecheck_error : ErrorMode`
        what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False`
        (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
        if the field type is correct. `False` if the field type is incorrect or an exception is thrown and `on_typecheck_error` is `ignore`
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # get field
    _field: SerializableField
    if isinstance(field, str):
        _field = self.__dataclass_fields__[field]  # type: ignore[attr-defined]
    else:
        _field = field

    # do nothing case
    if not _field.assert_type:
        return True

    # if field is not `init` or not `serialize`, skip but warn
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not _field.init or not _field.serialize:
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # get field type hints
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                + f"{get_cls_type_hints(self.__class__) = }\n"
                + f"Python version is {sys.version_info = }. You can:\n"
                + f" - disable `assert_type`. Currently: {_field.assert_type = }\n"
                + f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                + " - use python 3.9.x or higher\n"
                + " - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # get the value
    value: Any = getattr(self, _field.name)

    # validate the type
    try:
        type_is_valid: bool
        # validate the type with the default type validator
        if _field.custom_typecheck_fn is None:
            type_is_valid = validate_type(value, field_type_hint)
        # validate the type with a custom type validator
        else:
            type_is_valid = _field.custom_typecheck_fn(field_type_hint)

        return type_is_valid

    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False
```
given a dataclass, check the field matches the type hint
this function backs `SerializableDataclass.validate_field_type`. Fields with `assert_type` disabled, or that are not `init` or not `serialize`, are skipped and reported as valid (the latter with a warning).
Parameters:
- `self : SerializableDataclass`
  `SerializableDataclass` instance
- `field : SerializableField | str`
  field to validate; if given a `str`, the field is looked up in `self.__dataclass_fields__`
- `on_typecheck_error : ErrorMode`
  what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, the function will return `False` (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)
Returns:
- `bool`
  `True` if the field type is correct. `False` if the field type is incorrect, or if an exception is thrown and `on_typecheck_error` is `ignore`
```python
def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

    returns a dict of field names to bools, where the bool is if the field type is valid
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # if except, bundle the exceptions
    results: dict[str, bool] = dict()
    exceptions: dict[str, Exception] = dict()

    # for each field in the class
    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
    for field in cls_fields:
        try:
            results[field.name] = self.validate_field_type(field, on_typecheck_error)
        except Exception as e:
            results[field.name] = False
            exceptions[field.name] = e

    # figure out what to do with the exceptions
    if len(exceptions) > 0:
        on_typecheck_error.process(
            f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
            + "\n\t"
            + "\n\t".join([f"{k}:\t{v}" for k, v in exceptions.items()]),
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so get a random exception from the dict
            except_from=list(exceptions.values())[0],
        )

    return results
```
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
returns a dict mapping each field name to a bool indicating whether that field's type is valid
```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
```
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
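The two functions above split the work. The `__dict` variant runs every per-field check, records an exception as a `False` entry rather than aborting, and returns the per-field results. The plain variant collapses those results with `all()`. The sketch below shows that collect-then-reduce pattern in isolation. `validate_all`, `is_int`, and `Cfg` are hypothetical names, and the caller decides what to do with the collected exceptions.

```python
from typing import Any, Callable, Dict, List, Tuple


def validate_all(
    obj: Any, names: List[str], check: Callable[[Any, str], bool]
) -> Tuple[Dict[str, bool], Dict[str, Exception]]:
    """Check every field; an exception marks that field False instead of aborting the loop."""
    results: Dict[str, bool] = {}
    errors: Dict[str, Exception] = {}
    for name in names:
        try:
            results[name] = check(obj, name)
        except Exception as e:  # collected so they can be reported together afterwards
            results[name] = False
            errors[name] = e
    return results, errors


def is_int(obj: Any, name: str) -> bool:
    return isinstance(getattr(obj, name), int)


class Cfg:
    a = 1
    b = "x"


results, errors = validate_all(Cfg(), ["a", "b", "missing"], is_int)
all_valid = all(results.values())  # the single bool the non-dict variant returns
```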
````python
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    only for linting and type checking, still need to call `serializable_dataclass` decorator

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: type(x).__name__,
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "returns the class as a dict, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """given a dataclass, check the field matches the type hint"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hashes the json-serialized representation of the class"
        return hash(json.dumps(self.serialize()))

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
            other instance to compare against
         - `of_serialized : bool`
            if true, compare serialized data and not raw values
            (defaults to `False`)

        # Returns:
         - `dict[str, Any]`


        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # match types
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        # initialize the diff result
        diff_result: dict = {}

        # if they are the same, return the empty diff
        if self == other:
            return diff_result

        # if we are working with serialized data, serialize the instances
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # skip fields that are not for comparison
            if not field.compare:
                continue

            # get values
            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # if the values are both serializable dataclasses, recurse
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
            # only support serializable dataclasses
            elif dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")
            else:
                # get the values of either the serialized or the actual values
                self_value_s = ser_self[field_name] if of_serialized else self_value
                other_value_s = ser_other[field_name] if of_serialized else other_value
                # compare the values
                if not array_safe_eq(self_value_s, other_value_s):
                    diff_result[field_name] = {"self": self_value, "other": other_value}

        # return the diff result
        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        # Parameters:
            - `nested_dict : dict[str, Any]`
                nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            self_value = getattr(self, field_name)

            if field_name in nested_dict:
                if isinstance(self_value, SerializableDataclass):
                    self_value.update_from_nested_dict(nested_dict[field_name])
                else:
                    setattr(self, field_name, nested_dict[field_name])

    def __copy__(self) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "deep copy by serializing and loading the instance to json"
        return self.__class__.load(json.loads(json.dumps(self.serialize())))
````
Base class for serializable dataclasses
only for linting and type checking; you still need to apply the `serializable_dataclass` decorator
Usage:
@serializable_dataclass
class MyClass(SerializableDataclass):
    a: int
    b: str
and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:
>>> my_obj = MyClass(a=1, b="q")
>>> s = json.dumps(my_obj.serialize())
>>> s
'{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
>>> read_obj = MyClass.load(json.loads(s))
>>> read_obj == my_obj
True
This isn't too impressive on its own, but it gets more useful when you have nested classes, or fields that are not json-serializable by default:
@serializable_dataclass
class NestedClass(SerializableDataclass):
    x: str
    y: MyClass
    act_fun: torch.nn.Module = serializable_field(
        default=torch.nn.ReLU(),
        serialization_fn=lambda x: type(x).__name__,
        deserialize_fn=lambda x: getattr(torch.nn, x)(),
    )
which gives us:
>>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
>>> s = json.dumps(nc.serialize())
>>> s
'{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
>>> read_nc = NestedClass.load(json.loads(s))
>>> read_nc == nc
True
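Stripped of nesting, custom field functions, and type checking, the wire format in the example above is a `"__format__"` tag followed by one key per field. The hand-written sketch below reproduces that round trip on a plain `dataclasses.dataclass`. It is hypothetical code for illustration and is not how the `@serializable_dataclass` decorator generates `serialize` and `load`.

```python
from __future__ import annotations

import dataclasses
import json
from typing import Any


@dataclasses.dataclass
class MyClass:
    a: int
    b: str

    def serialize(self) -> dict[str, Any]:
        # same layout as the doctest above: a "__format__" tag, then each field
        out: dict[str, Any] = {"__format__": f"{type(self).__name__}(SerializableDataclass)"}
        for f in dataclasses.fields(self):
            out[f.name] = getattr(self, f.name)
        return out

    @classmethod
    def load(cls, data: dict[str, Any]) -> MyClass:
        # drop the tag and any key that is not an init field
        names = {f.name for f in dataclasses.fields(cls) if f.init}
        return cls(**{k: v for k, v in data.items() if k in names})


s = json.dumps(MyClass(a=1, b="q").serialize())
restored = MyClass.load(json.loads(s))
```

The tag does not affect loading here. It exists so that a generic reader, such as the ZANJ handler registered by this module, can tell which class a dict belongs to.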
def serialize(self) -> dict[str, Any]:
returns the class as a dict, implemented by using `@serializable_dataclass` decorator
@classmethod
def load(cls: Type[T], data: dict[str, Any] | T) -> T:
takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator
def validate_fields_types(self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR) -> bool:
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
def validate_field_type(self, field: "SerializableField|str", on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR) -> bool:
given a dataclass, check the field matches the type hint
def diff(self, other: "SerializableDataclass", of_serialized: bool = False) -> dict[str, Any]:
get a rich and recursive diff between two instances of a serializable dataclass
>>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
{'b': {'self': 2, 'other': 3}}
>>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
{'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
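Parameters:
- `other : SerializableDataclass`
  other instance to compare against
- `of_serialized : bool`
  if true, compare serialized data and not raw values (defaults to `False`)
Returns:
- `dict[str, Any]`
  maps each differing field name to `{'self': ..., 'other': ...}`, or to a nested diff for `SerializableDataclass` fields; empty if the instances are equal
Raises:
- `ValueError`: if the instances are not of the same type
- `ValueError`: if a pair of field values are both `dataclasses.dataclass` instances but not `SerializableDataclass`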
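The recursion in `diff` is simple. Nested dataclass fields produce a nested diff, and any other differing field produces a `{"self": ..., "other": ...}` pair. The sketch below applies that logic to plain dataclasses and reproduces the nested example from the docstring. `dc_diff`, `Inner`, and `Outer` are hypothetical names. The sketch compares with `!=` instead of the module's `array_safe_eq`, so it would not handle numpy arrays or tensors correctly.

```python
from __future__ import annotations

import dataclasses
from typing import Any


def dc_diff(a: Any, b: Any) -> dict[str, Any]:
    """Recursive field-by-field diff of two instances of the same dataclass type."""
    if type(a) is not type(b):
        raise ValueError(f"types differ: {type(a)} vs {type(b)}")
    result: dict[str, Any] = {}
    for f in dataclasses.fields(a):
        if not f.compare:
            continue  # like `diff`, skip fields excluded from comparison
        x, y = getattr(a, f.name), getattr(b, f.name)
        if dataclasses.is_dataclass(x) and dataclasses.is_dataclass(y):
            nested = dc_diff(x, y)
            if nested:  # only record nested fields that actually differ
                result[f.name] = nested
        elif x != y:
            result[f.name] = {"self": x, "other": y}
    return result


@dataclasses.dataclass
class Inner:
    a: int
    b: int


@dataclasses.dataclass
class Outer:
    x: str
    y: Inner


d = dc_diff(Outer("q1", Inner(1, 2)), Outer("q2", Inner(1, 3)))
```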
def update_from_nested_dict(self, nested_dict: dict[str, Any]):
update the instance from a nested dict, useful for configuration from command line args
Parameters:
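- `nested_dict : dict[str, Any]`
  nested dict to update the instance with. A key matching a `SerializableDataclass` field is applied recursively; any other matching key overwrites the field; keys that match no field are ignored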
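The update walks the instance's fields rather than the dict's keys. This means unknown keys are silently ignored and nested dataclasses are updated in place. The sketch below applies the same walk to plain dataclasses, using a hypothetical training config. `update_nested`, `Optim`, and `TrainCfg` are illustrative names only. The extra `isinstance(..., dict)` guard is an assumption added here and is not part of the method above.

```python
from __future__ import annotations

import dataclasses
from typing import Any


def update_nested(obj: Any, nested: dict[str, Any]) -> None:
    """Recursively overwrite fields from a nested dict; fields absent from `nested` are kept."""
    for f in dataclasses.fields(obj):
        if f.name not in nested:
            continue
        current = getattr(obj, f.name)
        if dataclasses.is_dataclass(current) and isinstance(nested[f.name], dict):
            update_nested(current, nested[f.name])  # update the nested instance in place
        else:
            setattr(obj, f.name, nested[f.name])


@dataclasses.dataclass
class Optim:
    lr: float = 1e-3
    momentum: float = 0.9


@dataclasses.dataclass
class TrainCfg:
    epochs: int = 10
    optim: Optim = dataclasses.field(default_factory=Optim)


cfg = TrainCfg()
update_nested(cfg, {"epochs": 20, "optim": {"lr": 0.01}, "unused": True})
```

Because values are assigned with `setattr`, this approach (like the method above) would fail on a `frozen` dataclass.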
```python
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    "cached typing.get_type_hints for a class"
    return typing.get_type_hints(cls)
```
cached typing.get_type_hints for a class
```python
def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    "helper function to get type hints for a class"
    cls_type_hints: dict[str, Any]
    try:
        cls_type_hints = get_cls_type_hints_cached(cls)  # type: ignore
        if len(cls_type_hints) == 0:
            cls_type_hints = typing.get_type_hints(cls)

        if len(cls_type_hints) == 0:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            + f" Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            + f" {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
            + f" {e = }"
        ) from e

    return cls_type_hints
```
helper function to get type hints for a class
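Classes are hashable, so `functools.lru_cache` can key directly on them. This is what makes the cached variant above cheap to call repeatedly during loading and validation. The sketch below shows the cache at work. `cached_hints` and `A` are hypothetical names.

```python
import functools
import typing


@functools.lru_cache(typed=True)
def cached_hints(cls: type) -> dict:
    # the class object itself is the cache key
    return typing.get_type_hints(cls)


class A:
    x: int
    y: str


first = cached_hints(A)
second = cached_hints(A)  # served from the cache
info = cached_hints.cache_info()
```

A cache hit returns the very same dict object, so callers should treat the result as read-only. Mutating it would silently change every later lookup for that class. `get_cls_type_hints` also retries with an uncached `typing.get_type_hints` call when the cached result is empty.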
```python
class KWOnlyError(NotImplementedError):
    "kw-only dataclasses are not supported in python <3.10"

    pass
```
kw-only dataclasses are not supported in python <3.10
Inherited Members
- builtins.NotImplementedError: NotImplementedError
- builtins.BaseException: with_traceback, add_note, args
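The `kw_only` argument to `dataclasses.dataclass` only exists from Python 3.10. This is why the decorator checks `sys.version_info < (3, 10)` before forwarding it. The sketch below shows the same version guard on its own. `dataclass_kwargs` and `Plain` are hypothetical names. It raises a plain `NotImplementedError`, the base class of `KWOnlyError`, where the decorator raises `KWOnlyError`.

```python
import dataclasses
import sys


def dataclass_kwargs(kw_only: bool) -> dict:
    """Only forward `kw_only` where `dataclasses.dataclass` accepts it (Python 3.10+)."""
    if sys.version_info >= (3, 10):
        return {"kw_only": kw_only}
    if kw_only:
        # the decorator raises KWOnlyError (a NotImplementedError subclass) in this case
        raise NotImplementedError("kw_only requires Python 3.10 or newer")
    return {}  # False is the default anyway, so the argument can simply be dropped


@dataclasses.dataclass(**dataclass_kwargs(False))
class Plain:
    a: int
    b: str = "q"


obj = Plain(1)
```

Because `KWOnlyError` subclasses `NotImplementedError`, an `except NotImplementedError` handler catches both.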
base class for field errors
Inherited Members
- builtins.ValueError: ValueError
- builtins.BaseException: with_traceback, add_note, args
```python
class NotSerializableFieldException(FieldError):
    "field is not a `SerializableField`"

    pass
```
field is not a SerializableField
Inherited Members
- builtins.ValueError: ValueError
- builtins.BaseException: with_traceback, add_note, args
error while serializing a field
Inherited Members
- builtins.ValueError: ValueError
- builtins.BaseException: with_traceback, add_note, args
error while loading a field
Inherited Members
- builtins.ValueError: ValueError
- builtins.BaseException: with_traceback, add_note, args
```python
class FieldTypeMismatchError(FieldError, TypeError):
    "error when a field type does not match the type hint"

    pass
```
error when a field type does not match the type hint
Inherited Members
- builtins.ValueError: ValueError
- builtins.BaseException: with_traceback, add_note, args
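`FieldTypeMismatchError` inherits from both `FieldError` and `TypeError`, so callers can catch it through any of those bases. The sketch below recreates the hierarchy as local stand-ins to show this. It assumes `FieldError` derives from `ValueError`, as its inherited members above suggest. These stand-in classes are not imported from muutils.

```python
class FieldError(ValueError):
    """stand-in for this module's FieldError (assumed to subclass ValueError)"""


class FieldTypeMismatchError(FieldError, TypeError):
    """stand-in mirroring `class FieldTypeMismatchError(FieldError, TypeError)`"""


caught_by = []
for handler in (TypeError, ValueError, FieldError):
    try:
        raise FieldTypeMismatchError("field 'a': expected int, got str")
    except handler:  # each base class is enough to catch it
        caught_by.append(handler.__name__)
```

This lets code that already guards against `TypeError` for bad types, or `ValueError` for bad data, handle a type mismatch from `load` without importing the module's exception classes.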
@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    **kwargs,
):
    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:
     - `_cls : _type_`
        class to decorate. don't pass this arg, just use this as a decorator
        (defaults to `None`)
     - `init : bool`
        (defaults to `True`)
     - `repr : bool`
        (defaults to `True`)
     - `order : bool`
        (defaults to `False`)
     - `unsafe_hash : bool`
        (defaults to `False`)
     - `frozen : bool`
        (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
        **SerializableDataclass only:** which properties to add to the serialized data dict
        (defaults to `None`)
     - `register_handler : bool`
        **SerializableDataclass only:** if true, register the class with ZANJ for loading
        (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
        **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
     - `on_typecheck_mismatch : ErrorMode`
        **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`

    # Returns:
     - `_type_`
        the decorated class

    # Raises:
     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Modify the __annotations__ dictionary to replace regular fields with SerializableField
        for field_name, field_type in cls.__annotations__.items():
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError("kw_only is not supported in python <3.10")
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
"this state should be inaccessible, please report this bug!" 716 ) 717 718 # try to save it 719 if field.serialize: 720 try: 721 # get the val 722 value = getattr(self, field.name) 723 # if it is a serializable dataclass, serialize it 724 if isinstance(value, SerializableDataclass): 725 value = value.serialize() 726 # if the value has a serialization function, use that 727 if hasattr(value, "serialize") and callable(value.serialize): 728 value = value.serialize() 729 # if the field has a serialization function, use that 730 # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies! 731 elif field.serialization_fn: 732 value = field.serialization_fn(value) 733 734 # store the value in the result 735 result[field.name] = value 736 except Exception as e: 737 raise FieldSerializationError( 738 "\n".join( 739 [ 740 f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}", 741 f"{field = }", 742 f"{value = }", 743 f"{self = }", 744 ] 745 ) 746 ) from e 747 748 # store each property if we can get it 749 for prop in self._properties_to_serialize: 750 if hasattr(cls, prop): 751 value = getattr(self, prop) 752 result[prop] = value 753 else: 754 raise AttributeError( 755 f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}" 756 + f"but it is in {self._properties_to_serialize = }" 757 + f"\n{self = }" 758 ) 759 760 return result 761 762 # ====================================================================== 763 # define `load` func 764 # done locally since it depends on args to the decorator 765 # ====================================================================== 766 # mypy thinks this isnt a classmethod 767 @classmethod # type: ignore[misc] 768 def load(cls, data: dict[str, Any] | T) -> Type[T]: 769 # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ 770 if isinstance(data, 
cls): 771 return data 772 773 assert isinstance( 774 data, typing.Mapping 775 ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }" 776 777 cls_type_hints: dict[str, Any] = get_cls_type_hints(cls) 778 779 # initialize dict for keeping what we will pass to the constructor 780 ctor_kwargs: dict[str, Any] = dict() 781 782 # iterate over the fields of the class 783 for field in dataclasses.fields(cls): 784 # check if the field is a SerializableField 785 assert isinstance( 786 field, SerializableField 787 ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new" 788 789 # check if the field is in the data and if it should be initialized 790 if (field.name in data) and field.init: 791 # get the value, we will be processing it 792 value: Any = data[field.name] 793 794 # get the type hint for the field 795 field_type_hint: Any = cls_type_hints.get(field.name, None) 796 797 # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set 798 if field.deserialize_fn: 799 # if it has a deserialization function, use that 800 value = field.deserialize_fn(value) 801 elif field.loading_fn: 802 # if it has a loading function, use that 803 value = field.loading_fn(data) 804 elif ( 805 field_type_hint is not None 806 and hasattr(field_type_hint, "load") 807 and callable(field_type_hint.load) 808 ): 809 # if no loading function but has a type hint with a load method, use that 810 if isinstance(value, dict): 811 value = field_type_hint.load(value) 812 else: 813 raise FieldLoadingError( 814 f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }" 815 ) 816 else: 817 # assume no loading needs to happen, keep `value` as-is 818 pass 819 820 # store the value in the constructor kwargs 821 ctor_kwargs[field.name] = value 
822 823 # create a new instance of the class with the constructor kwargs 824 output: cls = cls(**ctor_kwargs) 825 826 # validate the types of the fields if needed 827 if on_typecheck_mismatch != ErrorMode.IGNORE: 828 fields_valid: dict[str, bool] = ( 829 SerializableDataclass__validate_fields_types__dict( 830 output, 831 on_typecheck_error=on_typecheck_error, 832 ) 833 ) 834 835 # if there are any fields that are not valid, raise an error 836 if not all(fields_valid.values()): 837 msg: str = ( 838 f"Type mismatch in fields of {cls.__name__}:\n" 839 + "\n".join( 840 [ 841 f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }" 842 for k, v in fields_valid.items() 843 if not v 844 ] 845 ) 846 ) 847 848 on_typecheck_mismatch.process( 849 msg, except_cls=FieldTypeMismatchError 850 ) 851 852 # return the new instance 853 return output 854 855 # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments 856 # type is `Callable[[T], dict]` 857 cls.serialize = serialize # type: ignore[attr-defined] 858 # type is `Callable[[dict], T]` 859 cls.load = load # type: ignore[attr-defined] 860 # type is `Callable[[T, ErrorMode], bool]` 861 cls.validate_fields_types = SerializableDataclass__validate_fields_types # type: ignore[attr-defined] 862 863 # type is `Callable[[T, T], bool]` 864 if not hasattr(cls, "__eq__"): 865 cls.__eq__ = lambda self, other: dc_eq(self, other) # type: ignore[assignment] 866 867 # Register the class with ZANJ 868 if register_handler: 869 zanj_register_loader_serializable_dataclass(cls) 870 871 return cls 872 873 if _cls is None: 874 return wrap 875 else: 876 return wrap(_cls)
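The per-field precedence in `load` above (the field's `deserialize_fn` on the raw value first, then its `loading_fn` on the whole data dict, then a `.load` method on the annotated type, otherwise the raw value unchanged) can be sketched without muutils. This is a minimal illustration under stated assumptions: `FieldSpec`, `load_field` and `Point` are hypothetical names, and the sketch raises a plain `TypeError` where the real code raises `FieldLoadingError`.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class FieldSpec:
    """Hypothetical stand-in for the per-field hooks of `SerializableField`."""

    name: str
    deserialize_fn: Optional[Callable[[Any], Any]] = None  # receives the field's raw value
    loading_fn: Optional[Callable[[dict], Any]] = None  # receives the whole data dict
    type_hint: Any = None  # annotated type; may expose a `.load(dict)` classmethod


def load_field(spec: FieldSpec, data: dict) -> Any:
    """Resolve one constructor kwarg using the same order of checks as `load`."""
    value = data[spec.name]
    if spec.deserialize_fn is not None:
        return spec.deserialize_fn(value)
    if spec.loading_fn is not None:
        return spec.loading_fn(data)
    load = getattr(spec.type_hint, "load", None)
    if callable(load):
        # a nested loadable type must arrive as a dict, as in the original
        if not isinstance(value, dict):
            raise TypeError(f"cannot load {type(value)} into {spec.type_hint}")
        return load(value)
    return value  # no hook applies: keep the raw value


@dataclass
class Point:
    x: int
    y: int

    @classmethod
    def load(cls, data: dict) -> "Point":
        return cls(**data)
```

Because `deserialize_fn` is checked first, it wins even when the annotated type also has a `.load` method.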
Decorator to make a dataclass serializable. The class must also inherit from SerializableDataclass.
Types will be validated (like pydantic) unless on_typecheck_mismatch is set to ErrorMode.IGNORE.
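Plain `dataclasses` do not check field types at construction, so "validated" here means an extra pass after the instance is built. A rough, hypothetical version of that pass compares each field's value with its annotation via `isinstance`; `validate_fields_types_dict` and `Example` are illustrative names, and the library's own check is not reproduced here (this sketch only handles plain classes, not generics like `list[int]`).

```python
import dataclasses
import typing


def validate_fields_types_dict(obj: typing.Any) -> typing.Dict[str, bool]:
    """Hypothetical simplified check: does each field's value match its annotation?"""
    hints = typing.get_type_hints(type(obj))
    result: typing.Dict[str, bool] = {}
    for field in dataclasses.fields(obj):
        expected = hints.get(field.name, typing.Any)
        value = getattr(obj, field.name)
        # unannotated (Any) fields always pass; plain classes use isinstance
        result[field.name] = expected is typing.Any or isinstance(value, expected)
    return result


@dataclasses.dataclass
class Example:
    a: int
    b: str
```

`Example(a="1", b="q")` constructs without complaint, and only the validation pass reports that `a` does not match `int`.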
The behavior of most kwargs matches that of dataclasses.dataclass, but there are some additional kwargs.
Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

Examines PEP 526 __annotations__ to determine fields.

If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method is added. If frozen is true, fields may not be assigned to after instance creation.
@serializable_dataclass(kw_only=True)
class Myclass(SerializableDataclass):
    a: int
    b: str

>>> Myclass(a=1, b="q").serialize()
{'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
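The layout of that output (a `__format__` tag naming the class, one key per field, plus any names listed in `properties_to_serialize`) can be mimicked with plain `dataclasses`. This is a sketch of the layout only, not the library's `serialize`: `toy_serialize`, the `ab` property and the `Tag`/`Holder` classes are hypothetical.

```python
import dataclasses
from typing import Any, Dict, Iterable


def toy_serialize(obj: Any, properties: Iterable[str] = ()) -> Dict[str, Any]:
    """Hypothetical sketch of the serialized layout; not muutils' own `serialize`."""
    result: Dict[str, Any] = {
        "__format__": f"{type(obj).__name__}(SerializableDataclass)"
    }
    for field in dataclasses.fields(obj):
        value = getattr(obj, field.name)
        # a value exposing its own `.serialize()` is replaced by that method's output
        if callable(getattr(value, "serialize", None)):
            value = value.serialize()
        result[field.name] = value
    for prop in properties:
        # like `properties_to_serialize`: computed properties are stored next to fields
        result[prop] = getattr(obj, prop)
    return result


@dataclasses.dataclass
class Myclass:
    a: int
    b: str

    @property
    def ab(self) -> str:
        return f"{self.a}{self.b}"


class Tag:
    def serialize(self) -> str:
        return "tag"


@dataclasses.dataclass
class Holder:
    t: object
```

Properties are not dataclass fields, so they only appear in the output when named explicitly.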
Parameters:
- _cls : _type_
  class to decorate. don't pass this arg, just use this as a decorator (defaults to None)
- init : bool
  (defaults to True)
- repr : bool
  (defaults to True)
- order : bool
  (defaults to False)
- unsafe_hash : bool
  (defaults to False)
- frozen : bool
  (defaults to False)
- properties_to_serialize : Optional[list[str]]
  SerializableDataclass only: which properties to add to the serialized data dict (defaults to None)
- register_handler : bool
  SerializableDataclass only: if true, register the class with ZANJ for loading (defaults to True)
- on_typecheck_error : ErrorMode
  SerializableDataclass only: what to do if type checking throws an exception (except, warn, ignore). If ignore and an exception is thrown, type validation will still return False
- on_typecheck_mismatch : ErrorMode
  SerializableDataclass only: what to do if a type mismatch is found (except, warn, ignore). If ignore, type validation will return True
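The except/warn/ignore choice behind on_typecheck_error and on_typecheck_mismatch amounts to a small dispatch on an enum, mirroring the `on_typecheck_mismatch.process(msg, except_cls=...)` call in the source. The sketch below assumes a hypothetical `Mode` enum in place of muutils' `ErrorMode` (whose full API is not shown here) and uses `TypeError` in place of `FieldTypeMismatchError`; `report_mismatches` is likewise illustrative.

```python
import enum
import warnings
from typing import Dict


class Mode(enum.Enum):
    """Hypothetical stand-in for `ErrorMode`: except, warn, or ignore."""

    EXCEPT = "except"
    WARN = "warn"
    IGNORE = "ignore"

    def process(self, msg: str, except_cls: type = ValueError) -> None:
        if self is Mode.EXCEPT:
            raise except_cls(msg)
        if self is Mode.WARN:
            warnings.warn(msg)
        # Mode.IGNORE: silently continue


def report_mismatches(fields_valid: Dict[str, bool], on_mismatch: Mode) -> None:
    """Collect the failing fields, then let the mode decide what happens."""
    bad = [name for name, ok in fields_valid.items() if not ok]
    if bad:
        on_mismatch.process(f"type mismatch in fields: {bad}", except_cls=TypeError)
```

Keeping the policy in the enum lets the caller pass the exception class, so one mode object can serve several kinds of check.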
Returns:
- _type_
  the decorated class
Raises:
- KWOnlyError : only raised if kw_only is True and the Python version is <3.10, since dataclasses.dataclass does not support kw_only before 3.10
- NotSerializableFieldException : if a field is not a SerializableField
- FieldSerializationError : if there is an error serializing a field
- AttributeError : if a property is not found on the class
- FieldLoadingError : if there is an error loading a field
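The KWOnlyError case exists because the `kw_only` argument was only added to `dataclasses.dataclass` in Python 3.10. A minimal stand-alone version of the guard might look like the following; `strip_kw_only`, the local `KWOnlyError` class and `Point` are hypothetical, and the `version` parameter exists only so the check can be exercised on any interpreter.

```python
import dataclasses
import sys
from typing import Any, Dict, Tuple


class KWOnlyError(NotImplementedError):
    """Hypothetical stand-in for muutils' `KWOnlyError`."""


def strip_kw_only(
    kwargs: Dict[str, Any], version: Tuple[int, ...] = tuple(sys.version_info[:2])
) -> Dict[str, Any]:
    """Reject or drop `kw_only` when `dataclasses.dataclass` cannot accept it (Python < 3.10)."""
    kwargs = dict(kwargs)  # work on a copy so the caller's dict is untouched
    if version < (3, 10) and "kw_only" in kwargs:
        # compare with `is True`: the sentinel `dataclasses.MISSING` is truthy and must be dropped, not rejected
        if kwargs["kw_only"] is True:
            raise KWOnlyError("kw_only is not supported in python <3.10")
        del kwargs["kw_only"]
    return kwargs


class Point:
    x: int
    y: int


# on 3.10+ kw_only=False is passed through; on older versions it is dropped
FrozenPoint = dataclasses.dataclass(Point, **strip_kw_only({"frozen": True, "kw_only": False}))
```

Only an explicit request for keyword-only fields is an error on old interpreters; any other value can be dropped safely because it matches the pre-3.10 default behavior.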