Coverage for muutils\json_serialize\serializable_dataclass.py: 55%

248 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-08 01:02 -0700

1"""save and load objects to and from json or compatible formats in a recoverable way 

2 

3`d = dataclasses.asdict(my_obj)` will give you a dict, but if some fields are not json-serializable, 

4you will get an error when you call `json.dumps(d)`. This module provides a way around that. 

5 

6Instead, you define your class: 

7 

8```python 

9@serializable_dataclass 

10class MyClass(SerializableDataclass): 

11 a: int 

12 b: str 

13``` 

14 

15and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do: 

16 

17 >>> my_obj = MyClass(a=1, b="q") 

18 >>> s = json.dumps(my_obj.serialize()) 

19 >>> s 

20 '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}' 

21 >>> read_obj = MyClass.load(json.loads(s)) 

22 >>> read_obj == my_obj 

23 True 

24 

25This isn't too impressive on its own, but it gets more useful when you have nested classes, 

26or fields that are not json-serializable by default: 

27 

28```python 

29@serializable_dataclass 

30class NestedClass(SerializableDataclass): 

31 x: str 

32 y: MyClass 

33 act_fun: torch.nn.Module = serializable_field( 

34 default=torch.nn.ReLU(), 

35 serialization_fn=lambda x: str(x), 

36 deserialize_fn=lambda x: getattr(torch.nn, x)(), 

37 ) 

38``` 

39 

40which gives us: 

41 

42 >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid()) 

43 >>> s = json.dumps(nc.serialize()) 

44 >>> s 

45 '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}' 

46 >>> read_nc = NestedClass.load(json.loads(s)) 

47 >>> read_nc == nc 

48 True 

49 

50""" 

51 

52from __future__ import annotations 

53 

54import abc 

55import dataclasses 

56import functools 

57import json 

58import sys 

59import typing 

60import warnings 

61from typing import Any, Optional, Type, TypeVar 

62 

63from muutils.errormode import ErrorMode 

64from muutils.validate_type import validate_type 

65from muutils.json_serialize.serializable_field import ( 

66 SerializableField, 

67 serializable_field, 

68) 

69from muutils.json_serialize.util import array_safe_eq, dc_eq 

70 

71# pylint: disable=bad-mcs-classmethod-argument, too-many-arguments, protected-access 

72 

73 

74def _dataclass_transform_mock( 

75 *, 

76 eq_default: bool = True, 

77 order_default: bool = False, 

78 kw_only_default: bool = False, 

79 frozen_default: bool = False, 

80 field_specifiers: tuple[type[Any] | typing.Callable[..., Any], ...] = (), 

81 **kwargs: Any, 

82) -> typing.Callable: 

83 "mock `typing.dataclass_transform` for python <3.11" 

84 

85 def decorator(cls_or_fn): 

86 cls_or_fn.__dataclass_transform__ = { 

87 "eq_default": eq_default, 

88 "order_default": order_default, 

89 "kw_only_default": kw_only_default, 

90 "frozen_default": frozen_default, 

91 "field_specifiers": field_specifiers, 

92 "kwargs": kwargs, 

93 } 

94 return cls_or_fn 

95 

96 return decorator 

97 

98 

# pick the `dataclass_transform` implementation once, at import time:
# interpreters older than 3.11 get the module's `_dataclass_transform_mock`,
# anything newer gets the real `typing.dataclass_transform`
dataclass_transform: typing.Callable
if sys.version_info < (3, 11):
    dataclass_transform = _dataclass_transform_mock
else:
    dataclass_transform = typing.dataclass_transform


# generic type variable used in the annotations of the helpers in this module
T = TypeVar("T")

107 

108 

class CantGetTypeHintsWarning(UserWarning):
    "warning category for the case where type hints cannot be retrieved"

113 

114 

class ZanjMissingWarning(UserWarning):
    "warning category emitted when [`ZANJ`](https://github.com/mivanit/ZANJ) cannot be imported -- registering a loader for a serializable dataclass will not work without it"

119 

120 

# read by `zanj_register_loader_serializable_dataclass` to decide whether to
# attempt the `zanj.loading` import
# NOTE(review): nothing visible in this module ever sets this to `False` -- confirm
# whether that is intended before relying on the flag
_zanj_loading_needs_import: bool = True
"flag to keep track of if we have successfully imported ZANJ"

123 

124 

def zanj_register_loader_serializable_dataclass(cls: typing.Type[T]):
    """Register a loader for the serializable dataclass `cls` with ZANJ

    this allows `ZANJ().read()` to load the class and not just return plain dicts

    if `zanj` cannot be imported, a `ZanjMissingWarning` is emitted and `None`
    is returned. otherwise the `LoaderHandler` that was built and registered is returned.

    # TODO: there is some duplication here with register_loader_handler
    """
    global _zanj_loading_needs_import

    if _zanj_loading_needs_import:
        try:
            from zanj.loading import (  # type: ignore[import]
                LoaderHandler,
                register_loader_handler,
            )
        except ImportError:
            warnings.warn(
                "ZANJ not installed, cannot register serializable dataclass loader. ZANJ can be found at https://github.com/mivanit/ZANJ or installed via `pip install zanj`",
                ZanjMissingWarning,
            )
            return None

    # serialized items carry this prefix in their "__format__" entry
    format_key: str = f"{cls.__name__}(SerializableDataclass)"

    def _is_item_of_cls(json_item, path=None, z=None):  # type: ignore
        # a dict whose "__format__" entry starts with the format key of `cls`
        return (
            isinstance(json_item, dict)
            and "__format__" in json_item
            and json_item["__format__"].startswith(format_key)
        )

    def _load_item(json_item, path=None, z=None):  # type: ignore
        # `path` and `z` are accepted but not used
        return cls.load(json_item)  # type: ignore

    handler: LoaderHandler = LoaderHandler(
        check=_is_item_of_cls,
        load=_load_item,
        uid=format_key,
        source_pckg=cls.__module__,
        desc=f"{format_key} loader via muutils.json_serialize.serializable_dataclass",
    )
    register_loader_handler(handler)

    return handler

164 

165 

# module-wide defaults for the `on_typecheck_mismatch` / `on_typecheck_error`
# parameters of the functions in this module:
# a type mismatch defaults to `ErrorMode.WARN`,
# an exception raised while type checking defaults to `ErrorMode.EXCEPT`
_DEFAULT_ON_TYPECHECK_MISMATCH: ErrorMode = ErrorMode.WARN
_DEFAULT_ON_TYPECHECK_ERROR: ErrorMode = ErrorMode.EXCEPT

168 

169 

class FieldIsNotInitOrSerializeWarning(UserWarning):
    "warning category used when type checking is skipped for a field which is not `init` or not `serialize`"

172 

173 

def SerializableDataclass__validate_field_type(
    self: SerializableDataclass,
    field: SerializableField | str,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """check that the value of one field on a dataclass instance matches its type hint

    `SerializableDataclass.validate_field_type` delegates to this function

    # Parameters:
     - `self : SerializableDataclass`
       instance whose field is checked
     - `field : SerializableField | str`
       the field to check. a `str` is looked up in `self.__dataclass_fields__`
     - `on_typecheck_error : ErrorMode`
       how to react when the check itself fails with an exception (except, warn, ignore).
       when nothing is raised as a result, the function returns `False`
       (defaults to `_DEFAULT_ON_TYPECHECK_ERROR`)

    # Returns:
     - `bool`
       `True` when the value passes the check, or when the check is skipped because
       the field has `assert_type` disabled or is not `init`/`serialize`.
       `False` when the value does not pass, or when the check failed with an
       exception that `on_typecheck_error` did not raise
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    # a field given by name is resolved to the field object
    # NOTE: `_field`, `field_type_hint` and `value` show up in `{name = }` debug
    # f-strings below, so renaming them would change the emitted messages
    _field: SerializableField = (
        self.__dataclass_fields__[field]  # type: ignore[attr-defined]
        if isinstance(field, str)
        else field
    )

    # checking is disabled for this field
    if not _field.assert_type:
        return True

    # fields which are not both `init` and `serialize` are skipped with a warning
    # TODO: how to handle fields which are not `init` or `serialize`?
    if not (_field.init and _field.serialize):
        warnings.warn(
            f"Field '{_field.name}' on class {self.__class__} is not `init` or `serialize`, so will not be type checked",
            FieldIsNotInitOrSerializeWarning,
        )
        return True

    assert isinstance(
        _field, SerializableField
    ), f"Field '{_field.name = }' on class {self.__class__ = } is not a SerializableField, but a {type(_field) = }"

    # look up the type hint of the field on the class
    try:
        field_type_hint: Any = get_cls_type_hints(self.__class__)[_field.name]
    except KeyError as e:
        on_typecheck_error.process(
            (
                f"Cannot get type hints for {self.__class__.__name__}, field {_field.name = } and so cannot validate.\n"
                f"{get_cls_type_hints(self.__class__) = }\n"
                f"Python version is {sys.version_info = }. You can:\n"
                f" - disable `assert_type`. Currently: {_field.assert_type = }\n"
                f" - use hints like `typing.Dict` instead of `dict` in type hints (this is required on python 3.8.x). You had {_field.type = }\n"
                " - use python 3.9.x or higher\n"
                " - specify custom type validation function via `custom_typecheck_fn`\n"
            ),
            except_cls=TypeError,
            except_from=e,
        )
        return False

    # current value of the field on the instance
    value: Any = getattr(self, _field.name)

    try:
        custom_check = _field.custom_typecheck_fn
        if custom_check is None:
            # default validator: checks the value against the hint
            return validate_type(value, field_type_hint)
        # custom validator: called with the type hint
        return custom_check(field_type_hint)
    except Exception as e:
        on_typecheck_error.process(
            "exception while validating type: "
            + f"{_field.name = }, {field_type_hint = }, {type(field_type_hint) = }, {value = }",
            except_cls=ValueError,
            except_from=e,
        )
        return False

264 

265 

def SerializableDataclass__validate_fields_types__dict(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> dict[str, bool]:
    """check every field of a `SerializableDataclass` instance against its type hint

    calls `self.validate_field_type` once per dataclass field. exceptions raised by
    those calls are collected, the affected fields are marked `False`, and the
    collected exceptions are then reported together through `on_typecheck_error`

    # Returns:
     - `dict[str, bool]`
       maps each field name to whether that field passed validation
    """
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)

    results: dict[str, bool] = {}
    # exceptions are bundled so that all of them are reported in one go
    failures: dict[str, Exception] = {}

    cls_fields: typing.Sequence[SerializableField] = dataclasses.fields(self)  # type: ignore[arg-type, assignment]
    for field in cls_fields:
        try:
            field_ok: bool = self.validate_field_type(field, on_typecheck_error)
        except Exception as err:
            field_ok = False
            failures[field.name] = err
        results[field.name] = field_ok

    if failures:
        failure_lines: list[str] = [
            f"{name}:\t{err}" for name, err in failures.items()
        ]
        header: str = f"Exceptions while validating types of fields on {self.__class__.__name__}: {[x.name for x in cls_fields]}"
        on_typecheck_error.process(
            header + "\n\t" + "\n\t".join(failure_lines),
            except_cls=ValueError,
            # HACK: ExceptionGroup not supported in py < 3.11, so only the first collected exception is chained
            except_from=next(iter(failures.values())),
        )

    return results

301 

302 

def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """check every field of a `SerializableDataclass` instance against its type hint

    thin wrapper around `SerializableDataclass__validate_fields_types__dict`:
    returns `True` only if every field was reported as valid
    """
    per_field: dict[str, bool] = SerializableDataclass__validate_fields_types__dict(
        self, on_typecheck_error=on_typecheck_error
    )
    return all(per_field.values())

313 

314 

@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
class SerializableDataclass(abc.ABC):
    """Base class for serializable dataclasses

    exists for linting and type checking only -- `serialize` and `load` raise
    `NotImplementedError` here, the working versions are attached by the
    `serializable_dataclass` decorator, which still has to be applied

    # Usage:

    ```python
    @serializable_dataclass
    class MyClass(SerializableDataclass):
        a: int
        b: str
    ```

    and then you can call `my_obj.serialize()` to get a dict that can be serialized to json. So, you can do:

        >>> my_obj = MyClass(a=1, b="q")
        >>> s = json.dumps(my_obj.serialize())
        >>> s
        '{"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}'
        >>> read_obj = MyClass.load(json.loads(s))
        >>> read_obj == my_obj
        True

    This isn't too impressive on its own, but it gets more useful when you have nested classes,
    or fields that are not json-serializable by default:

    ```python
    @serializable_dataclass
    class NestedClass(SerializableDataclass):
        x: str
        y: MyClass
        act_fun: torch.nn.Module = serializable_field(
            default=torch.nn.ReLU(),
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: getattr(torch.nn, x)(),
        )
    ```

    which gives us:

        >>> nc = NestedClass(x="q", y=MyClass(a=1, b="q"), act_fun=torch.nn.Sigmoid())
        >>> s = json.dumps(nc.serialize())
        >>> s
        '{"__format__": "NestedClass(SerializableDataclass)", "x": "q", "y": {"__format__": "MyClass(SerializableDataclass)", "a": 1, "b": "q"}, "act_fun": "Sigmoid"}'
        >>> read_nc = NestedClass.load(json.loads(s))
        >>> read_nc == nc
        True
    """

    def serialize(self) -> dict[str, Any]:
        "placeholder: the instance as a dict. the real implementation comes from the `@serializable_dataclass` decorator"
        raise NotImplementedError(
            f"decorate {self.__class__ = } with `@serializable_dataclass`"
        )

    @classmethod
    def load(cls: Type[T], data: dict[str, Any] | T) -> T:
        "placeholder: build an instance from an appropriately structured dict. the real implementation comes from the `@serializable_dataclass` decorator"
        raise NotImplementedError(f"decorate {cls = } with `@serializable_dataclass`")

    def validate_fields_types(
        self, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR
    ) -> bool:
        """check all fields against their type hints, delegates to `SerializableDataclass__validate_fields_types`"""
        return SerializableDataclass__validate_fields_types(
            self, on_typecheck_error=on_typecheck_error
        )

    def validate_field_type(
        self,
        field: "SerializableField|str",
        on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    ) -> bool:
        """check a single field against its type hint, delegates to `SerializableDataclass__validate_field_type`"""
        return SerializableDataclass__validate_field_type(
            self, field, on_typecheck_error=on_typecheck_error
        )

    def __eq__(self, other: Any) -> bool:
        "equality is delegated to `dc_eq`"
        return dc_eq(self, other)

    def __hash__(self) -> int:
        "hash of the json string obtained from `self.serialize()`"
        as_json: str = json.dumps(self.serialize())
        return hash(as_json)

    def diff(
        self, other: "SerializableDataclass", of_serialized: bool = False
    ) -> dict[str, Any]:
        """get a rich and recursive diff between two instances of a serializable dataclass

        ```python
        >>> Myclass(a=1, b=2).diff(Myclass(a=1, b=3))
        {'b': {'self': 2, 'other': 3}}
        >>> NestedClass(x="q1", y=Myclass(a=1, b=2)).diff(NestedClass(x="q2", y=Myclass(a=1, b=3)))
        {'x': {'self': 'q1', 'other': 'q2'}, 'y': {'b': {'self': 2, 'other': 3}}}
        ```

        # Parameters:
         - `other : SerializableDataclass`
           other instance to compare against
         - `of_serialized : bool`
           if true, the comparison is made on the serialized data and not the raw values.
           the values stored in the returned diff are the raw attribute values either way
           (defaults to `False`)

        # Returns:
         - `dict[str, Any]`
           empty if the instances compare equal. otherwise, for each differing field,
           either `{"self": ..., "other": ...}` or a nested diff when both values
           are `SerializableDataclass` instances

        # Raises:
         - `ValueError` : if the instances are not of the same type
         - `ValueError` : if the instances are `dataclasses.dataclass` but not `SerializableDataclass`
        """
        # both instances must have exactly the same type
        if type(self) is not type(other):
            raise ValueError(
                f"Instances must be of the same type, but got {type(self) = } and {type(other) = }"
            )

        diff_result: dict = {}

        # equal instances give an empty diff
        if self == other:
            return diff_result

        # serialize once up front when the comparison is on serialized data
        if of_serialized:
            ser_self: dict = self.serialize()
            ser_other: dict = other.serialize()

        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # fields excluded from comparison are ignored
            if not field.compare:
                continue

            field_name: str = field.name
            self_value = getattr(self, field_name)
            other_value = getattr(other, field_name)

            # two serializable dataclasses: recurse, keep only non-empty diffs
            if isinstance(self_value, SerializableDataclass) and isinstance(
                other_value, SerializableDataclass
            ):
                nested_diff: dict = self_value.diff(
                    other_value, of_serialized=of_serialized
                )
                if nested_diff:
                    diff_result[field_name] = nested_diff
                continue

            # plain dataclasses on both sides are not supported
            if dataclasses.is_dataclass(self_value) and dataclasses.is_dataclass(
                other_value
            ):
                raise ValueError("Non-serializable dataclass is not supported")

            # choose what gets compared: serialized data or the raw values
            if of_serialized:
                left, right = ser_self[field_name], ser_other[field_name]
            else:
                left, right = self_value, other_value

            if not array_safe_eq(left, right):
                diff_result[field_name] = {"self": self_value, "other": other_value}

        return diff_result

    def update_from_nested_dict(self, nested_dict: dict[str, Any]):
        """update the instance from a nested dict, useful for configuration from command line args

        fields holding a `SerializableDataclass` are updated recursively,
        all other fields present in the dict are overwritten via `setattr`

        # Parameters:
         - `nested_dict : dict[str, Any]`
           nested dict to update the instance with
        """
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            field_name: str = field.name
            # the attribute is read for every field, whether or not it gets updated
            current_value = getattr(self, field_name)

            if field_name not in nested_dict:
                continue

            new_value = nested_dict[field_name]
            if isinstance(current_value, SerializableDataclass):
                current_value.update_from_nested_dict(new_value)
            else:
                setattr(self, field_name, new_value)

    def __copy__(self) -> "SerializableDataclass":
        "copy made by a round trip through json: serialize, dump, parse, load"
        loader = self.__class__.load
        as_json: str = json.dumps(self.serialize())
        return loader(json.loads(as_json))

    def __deepcopy__(self, memo: dict) -> "SerializableDataclass":
        "copy made by a round trip through json: serialize, dump, parse, load. `memo` is not used"
        loader = self.__class__.load
        as_json: str = json.dumps(self.serialize())
        return loader(json.loads(as_json))

509 

510 

# memoized so the hints of a class are only resolved once
# TODO: are the types hashable? does this even make sense?
@functools.lru_cache(typed=True)
def get_cls_type_hints_cached(cls: Type[T]) -> dict[str, Any]:
    "`typing.get_type_hints` for a class, memoized with `functools.lru_cache`"
    resolved_hints: dict[str, Any] = typing.get_type_hints(cls)
    return resolved_hints

517 

518 

def get_cls_type_hints(cls: Type[T]) -> dict[str, Any]:
    """get the type hints of a class, raising `TypeError` if none can be found

    tries the cached lookup first and falls back to a direct
    `typing.get_type_hints` call if that came back empty. a `TypeError`,
    `NameError` or `ValueError` during the lookup -- including the `ValueError`
    raised here for empty hints -- is re-raised as a `TypeError`
    """
    try:
        # an empty result from the cache triggers one uncached attempt
        cls_type_hints: dict[str, Any] = get_cls_type_hints_cached(cls)  # type: ignore
        if not cls_type_hints:
            cls_type_hints = typing.get_type_hints(cls)

        if not cls_type_hints:
            raise ValueError(f"empty type hints for {cls.__name__ = }")
    except (TypeError, NameError, ValueError) as e:
        # NOTE: `cls` and `e` appear in `{name = }` debug f-strings, keep their names
        raise TypeError(
            f"Cannot get type hints for {cls = }\n"
            f" Python version is {sys.version_info = } (use hints like `typing.Dict` instead of `dict` in type hints on python < 3.9)\n"
            f" {dataclasses.fields(cls) = }\n"  # type: ignore[arg-type]
            f" {e = }"
        ) from e

    return cls_type_hints

538 

539 

class KWOnlyError(NotImplementedError):
    "raised when `kw_only` is requested on a python version whose `dataclasses.dataclass` lacks it (the decorator checks for python <3.10)"

544 

545 

class FieldError(ValueError):
    "common base class of the field-related errors in this module"

550 

551 

class NotSerializableFieldException(FieldError):
    "raised when a dataclass field is not a `SerializableField`"

556 

557 

class FieldSerializationError(FieldError):
    "raised when serializing a field fails"

562 

563 

class FieldLoadingError(FieldError):
    "raised when loading a field fails"

568 

569 

class FieldTypeMismatchError(FieldError, TypeError):
    "raised when the value of a field does not match its type hint"

574 

575 

@dataclass_transform(
    field_specifiers=(serializable_field, SerializableField),
)
def serializable_dataclass(
    # this should be `_cls: Type[T] | None = None,` but mypy doesn't like it
    _cls=None,  # type: ignore
    *,
    init: bool = True,
    repr: bool = True,  # this overrides the actual `repr` builtin, but we have to match the interface of `dataclasses.dataclass`
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    properties_to_serialize: Optional[list[str]] = None,
    register_handler: bool = True,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
    on_typecheck_mismatch: ErrorMode = _DEFAULT_ON_TYPECHECK_MISMATCH,
    **kwargs,
):
    """decorator to make a dataclass serializable. must also make it inherit from `SerializableDataclass`

    types will be validated (like pydantic) unless `on_typecheck_mismatch` is set to `ErrorMode.IGNORE`

    behavior of most kwargs matches that of `dataclasses.dataclass`, but with some additional kwargs

    Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

    Examines PEP 526 `__annotations__` to determine fields.

    If init is true, an `__init__()` method is added to the class. If repr is true, a `__repr__()` method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a `__hash__()` method function is added. If frozen is true, fields may not be assigned to after instance creation.

    ```python
    @serializable_dataclass(kw_only=True)
    class Myclass(SerializableDataclass):
        a: int
        b: str
    ```
    ```python
    >>> Myclass(a=1, b="q").serialize()
    {'__format__': 'Myclass(SerializableDataclass)', 'a': 1, 'b': 'q'}
    ```

    # Parameters:
     - `_cls : _type_`
       class to decorate. don't pass this arg, just use this as a decorator
       (defaults to `None`)
     - `init : bool`
       (defaults to `True`)
     - `repr : bool`
       (defaults to `True`)
     - `eq : bool`
       (defaults to `True`)
     - `order : bool`
       (defaults to `False`)
     - `unsafe_hash : bool`
       (defaults to `False`)
     - `frozen : bool`
       (defaults to `False`)
     - `properties_to_serialize : Optional[list[str]]`
       **SerializableDataclass only:** which properties to add to the serialized data dict
       (defaults to `None`)
     - `register_handler : bool`
       **SerializableDataclass only:** if true, register the class with ZANJ for loading
       (defaults to `True`)
     - `on_typecheck_error : ErrorMode`
       **SerializableDataclass only:** what to do if type checking throws an exception (except, warn, ignore). If `ignore` and an exception is thrown, type validation will still return false
     - `on_typecheck_mismatch : ErrorMode`
       **SerializableDataclass only:** what to do if a type mismatch is found (except, warn, ignore). If `ignore`, type validation will return `True`
     - `**kwargs`
       passed on to `dataclasses.dataclass` (e.g. `kw_only`)

    # Returns:
     - `_type_`
       the decorated class, or the decorator itself when used with arguments

    # Raises:
     - `KWOnlyError` : only raised if `kw_only` is `True` and python version is <3.10, since `dataclasses.dataclass` does not support this
     - `NotSerializableFieldException` : if a field is not a `SerializableField`
     - `FieldSerializationError` : if there is an error serializing a field
     - `AttributeError` : if a property is not found on the class
     - `FieldLoadingError` : if there is an error loading a field
    """
    # -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
    on_typecheck_error = ErrorMode.from_any(on_typecheck_error)
    on_typecheck_mismatch = ErrorMode.from_any(on_typecheck_mismatch)

    if properties_to_serialize is None:
        _properties_to_serialize: list = list()
    else:
        _properties_to_serialize = properties_to_serialize

    def wrap(cls: Type[T]) -> Type[T]:
        # Replace the class attribute of every annotated field which is not yet a
        # `SerializableField` with one, so `dataclasses.dataclass` picks those up
        # NOTE(review): a plain class-level default (e.g. `a: int = 5`) is neither a
        # `SerializableField` nor a `dataclasses.Field`, so it is overwritten by a bare
        # `serializable_field()` here -- presumably dropping the default; confirm intent
        for field_name in cls.__annotations__:
            field_value = getattr(cls, field_name, None)
            if not isinstance(field_value, SerializableField):
                if isinstance(field_value, dataclasses.Field):
                    # Convert the field to a SerializableField while preserving properties
                    field_value = SerializableField.from_Field(field_value)
                else:
                    # Create a new SerializableField
                    field_value = serializable_field()
                setattr(cls, field_name, field_value)

        # special check, kw_only is not supported in python <3.10 and `dataclasses.MISSING` is truthy
        if sys.version_info < (3, 10):
            if "kw_only" in kwargs:
                if kwargs["kw_only"] == True:  # noqa: E712
                    raise KWOnlyError("kw_only is not supported in python <3.10")
                else:
                    del kwargs["kw_only"]

        # call `dataclasses.dataclass` to set some stuff up
        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            init=init,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen,
            **kwargs,
        )

        # copy these to the class
        cls._properties_to_serialize = _properties_to_serialize.copy()  # type: ignore[attr-defined]

        # ======================================================================
        # define `serialize` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        def serialize(self) -> dict[str, Any]:
            result: dict[str, Any] = {
                "__format__": f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    # bind `value` before the `try`: the error message below formats it,
                    # and without this a failing `getattr` would leave it unbound (or
                    # stale from the previous field) and mask the real error
                    value: Any = dataclasses.MISSING
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        # NOTE(review): this is an `if`, not an `elif` -- after a
                        # `SerializableDataclass` value has been turned into a dict above,
                        # the `elif` below can still apply `field.serialization_fn` to
                        # that dict; confirm whether this is intended
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f" but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

        # ======================================================================
        # define `load` func
        # done locally since it depends on args to the decorator
        # ======================================================================
        # mypy thinks this isnt a classmethod
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> T:
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        # (it is handed the whole `data` mapping, not just this field's value)
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

        # mypy says "Type cannot be declared in assignment to non-self attribute" so thats why I've left the hints in the comments
        # type is `Callable[[T], dict]`
        cls.serialize = serialize  # type: ignore[attr-defined]
        # type is `Callable[[dict], T]`
        cls.load = load  # type: ignore[attr-defined]
        # type is `Callable[[T, ErrorMode], bool]`
        cls.validate_fields_types = SerializableDataclass__validate_fields_types  # type: ignore[attr-defined]

        # type is `Callable[[T, T], bool]`
        # NOTE(review): every class inherits `__eq__` from `object`, so this
        # `hasattr` check looks like it can never be false -- confirm before relying on it
        if not hasattr(cls, "__eq__"):
            cls.__eq__ = lambda self, other: dc_eq(self, other)  # type: ignore[assignment]

        # Register the class with ZANJ
        if register_handler:
            zanj_register_loader_serializable_dataclass(cls)

        return cls

    # used as `@serializable_dataclass(...)`: hand back the decorator,
    # used as bare `@serializable_dataclass`: decorate right away
    if _cls is None:
        return wrap
    else:
        return wrap(_cls)