Coverage for tests\unit\json_serialize\serializable_dataclass\test_serializable_dataclass.py: 87%

482 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-12 20:43 -0700

1from __future__ import annotations 

2 

3from copy import deepcopy 

4import typing 

5from typing import Any, Dict, Generic, List, Optional, TypeVar, Union 

6 

7import pytest 

8 

9from muutils.errormode import ErrorMode 

10from muutils.json_serialize import ( 

11 SerializableDataclass, 

12 serializable_dataclass, 

13 serializable_field, 

14) 

15 

16from muutils.json_serialize.serializable_dataclass import ( 

17 FieldIsNotInitOrSerializeWarning, 

18 FieldTypeMismatchError, 

19) 

20 

21# pylint: disable=missing-class-docstring, unused-variable 

22 

23 

@serializable_dataclass
class BasicAutofields(SerializableDataclass):
    """Three plain annotated fields with no per-field `serializable_field` options."""

    a: str
    b: int
    c: typing.List[int]

29 

30 

def test_basic_auto_fields():
    """Serialize a plain-annotated dataclass; check equality and diff, including against an equal copy."""
    data = dict(a="hello", b=42, c=[1, 2, 3])
    instance = BasicAutofields(**data)

    # `serialize()` adds a `__format__` key naming the class next to the fields
    data_with_format = data.copy()
    data_with_format["__format__"] = "BasicAutofields(SerializableDataclass)"
    assert instance.serialize() == data_with_format

    assert instance == instance
    assert instance.diff(instance) == {}

    # comparing an object with itself only checks reflexivity; also compare
    # against a distinct instance with equal field values so `__eq__` and
    # `diff` are exercised on two separate objects
    twin = BasicAutofields(a="hello", b=42, c=[1, 2, 3])
    assert twin is not instance
    assert instance == twin
    assert instance.diff(twin) == {}
    assert twin.diff(instance) == {}

39 

40 

def test_basic_diff():
    """`diff` lists only differing fields, as {field: {"self": ..., "other": ...}}."""
    base = BasicAutofields(a="hello", b=42, c=[1, 2, 3])
    other_a = BasicAutofields(a="goodbye", b=42, c=[1, 2, 3])
    other_b = BasicAutofields(a="hello", b=-1, c=[1, 2, 3])
    other_bc = BasicAutofields(a="hello", b=-1, c=[42])

    # exactly one differing field
    assert base.diff(other_a) == {"a": {"self": "hello", "other": "goodbye"}}
    assert base.diff(other_b) == {"b": {"self": 42, "other": -1}}

    # two differing fields at once
    expected_bc = {
        "b": {"self": 42, "other": -1},
        "c": {"self": [1, 2, 3], "other": [42]},
    }
    assert base.diff(other_bc) == expected_bc

    # nothing differs against itself
    assert base.diff(base) == {}

    # "self" / "other" follow the receiver of `.diff`
    expected_ab = {
        "a": {"self": "goodbye", "other": "hello"},
        "b": {"self": 42, "other": -1},
    }
    assert other_a.diff(other_b) == expected_ab

58 

59 

@serializable_dataclass
class SimpleFields(SerializableDataclass):
    """A required `d`, a plain default for `e`, and a factory-built default for `f`."""

    d: str
    e: int = 42
    # list default produced by a factory rather than a shared literal
    f: typing.List[int] = serializable_field(default_factory=list)  # noqa: F821

65 

66 

@serializable_dataclass
class FieldOptions(SerializableDataclass):
    """Exercises per-field options accepted by `serializable_field`."""

    a: str = serializable_field()
    b: str = serializable_field()
    # excluded from __init__, serialization, repr and comparison
    c: str = serializable_field(init=False, serialize=False, repr=False, compare=False)
    # stored upper-cased; `loading_fn` receives the whole serialized dict
    d: str = serializable_field(
        serialization_fn=lambda x: x.upper(),
        loading_fn=lambda x: x["d"].lower(),
    )

75 

76 

@serializable_dataclass(properties_to_serialize=["full_name"])
class WithProperty(SerializableDataclass):
    """Two stored name fields plus a derived `full_name` property that is also serialized."""

    first_name: str
    last_name: str

    @property
    def full_name(self) -> str:
        """First and last name separated by a single space."""
        return f"{self.first_name} {self.last_name}"

85 

86 

class Child(FieldOptions, WithProperty):
    """Plain subclass combining `FieldOptions` and `WithProperty` through multiple inheritance."""

89 

90 

@pytest.fixture
def simple_fields_instance():
    """`SimpleFields` with every field given explicitly."""
    return SimpleFields(
        d="hello",
        e=42,
        f=[1, 2, 3],
    )

94 

95 

@pytest.fixture
def field_options_instance():
    """`FieldOptions` with `c` left unset, since it is init=False."""
    return FieldOptions(
        a="hello",
        b="world",
        d="case",
    )

99 

100 

@pytest.fixture
def with_property_instance():
    """`WithProperty` whose `full_name` evaluates to "John Doe"."""
    return WithProperty(
        first_name="John",
        last_name="Doe",
    )

104 

105 

def test_simple_fields_serialization(simple_fields_instance):
    """All fields serialize as-is, next to the `__format__` marker."""
    expected = {
        "d": "hello",
        "e": 42,
        "f": [1, 2, 3],
        "__format__": "SimpleFields(SerializableDataclass)",
    }
    assert simple_fields_instance.serialize() == expected

114 

115 

def test_simple_fields_loading(simple_fields_instance):
    """A serialize/load round trip gives an equal instance with no diff in either direction."""
    round_tripped = SimpleFields.load(simple_fields_instance.serialize())

    assert round_tripped == simple_fields_instance
    for left, right in (
        (round_tripped, simple_fields_instance),
        (simple_fields_instance, round_tripped),
    ):
        assert left.diff(right) == {}

124 

125 

def test_field_options_serialization(field_options_instance):
    """`c` is omitted (serialize=False) and `d` goes through its upper-casing serialization_fn."""
    expected = {
        "a": "hello",
        "b": "world",
        "d": "CASE",
        "__format__": "FieldOptions(SerializableDataclass)",
    }
    assert field_options_instance.serialize() == expected

134 

135 

def test_field_options_loading(field_options_instance):
    """Loading `FieldOptions` round-trips and emits `FieldIsNotInitOrSerializeWarning`.

    The warning is not ignored here: `pytest.warns` requires it to be raised.
    It is presumably triggered by `c`, the only field with init=False and
    serialize=False.
    """
    payload = field_options_instance.serialize()
    with pytest.warns(FieldIsNotInitOrSerializeWarning):
        restored = FieldOptions.load(payload)
    assert restored == field_options_instance

142 

143 

def test_with_property_serialization(with_property_instance):
    """Properties named in `properties_to_serialize` appear in the output beside the fields."""
    expected = {
        "first_name": "John",
        "last_name": "Doe",
        "full_name": "John Doe",
        "__format__": "WithProperty(SerializableDataclass)",
    }
    assert with_property_instance.serialize() == expected

152 

153 

def test_with_property_loading(with_property_instance):
    """A serialized property value does not prevent loading back an equal instance."""
    restored = WithProperty.load(with_property_instance.serialize())
    assert restored == with_property_instance

158 

159 

@serializable_dataclass
class Address(SerializableDataclass):
    """Postal address, used as a nested field of `Person`."""

    street: str
    city: str
    zip_code: str

165 

166 

@serializable_dataclass
class Person(SerializableDataclass):
    """Person record that holds a nested `Address` dataclass."""

    name: str
    age: int
    address: Address

172 

173 

@pytest.fixture
def address_instance():
    """A fixed New York `Address`."""
    return Address(
        street="123 Main St",
        city="New York",
        zip_code="10001",
    )

177 

178 

@pytest.fixture
def person_instance(address_instance):
    """A `Person` built around the `address_instance` fixture."""
    return Person(
        name="John Doe",
        age=30,
        address=address_instance,
    )

182 

183 

def test_nested_serialization(person_instance):
    """A nested serializable dataclass is serialized recursively, with its own `__format__`."""
    address_ser = {
        "street": "123 Main St",
        "city": "New York",
        "zip_code": "10001",
        "__format__": "Address(SerializableDataclass)",
    }
    expected_ser = {
        "name": "John Doe",
        "age": 30,
        "address": address_ser,
        "__format__": "Person(SerializableDataclass)",
    }
    assert person_instance.serialize() == expected_ser

197 assert serialized == expected_ser 

198 

199 

def test_nested_loading(person_instance):
    """Loading a `Person` also rebuilds its nested `Address`."""
    restored = Person.load(person_instance.serialize())
    assert restored == person_instance
    assert restored.address == person_instance.address

205 

206 

def test_with_printing():
    """Round-trip a class mixing a custom (de)serializer, a factory default and a property.

    The serialized dict and the loaded instance are also printed (visible with ``pytest -s``).
    """

    @serializable_dataclass(properties_to_serialize=["full_name"])
    class MyClass(SerializableDataclass):
        name: str
        # shifted by +1 when serialized and by -1 when loaded
        age: int = serializable_field(
            serialization_fn=lambda x: x + 1, loading_fn=lambda x: x["age"] - 1
        )
        items: list = serializable_field(default_factory=list)

        @property
        def full_name(self) -> str:
            return f"{self.name} Doe"

    # Usage
    my_instance = MyClass(name="John", age=30, items=["apple", "banana"])
    serialized_data = my_instance.serialize()
    print(serialized_data)

    # assert the serialized values too, so this test can actually fail
    assert serialized_data["name"] == "John"
    assert serialized_data["age"] == 31
    assert serialized_data["items"] == ["apple", "banana"]
    assert serialized_data["full_name"] == "John Doe"

    loaded_instance = MyClass.load(serialized_data)
    print(loaded_instance)

    # loading undoes the +1 shift and restores an equal instance
    assert loaded_instance.age == 30
    assert loaded_instance == my_instance

227 

228 

def test_simple_class_serialization():
    """A dataclass defined inside a test function still serializes and round-trips."""

    @serializable_dataclass
    class SimpleClass(SerializableDataclass):
        a: int
        b: str

    original = SimpleClass(a=42, b="hello")
    as_dict = original.serialize()
    assert as_dict == {
        "a": 42,
        "b": "hello",
        "__format__": "SimpleClass(SerializableDataclass)",
    }

    assert SimpleClass.load(as_dict) == original

245 

246 

def test_error_when_init_and_not_serialize():
    """A field with init=True but serialize=False is rejected when the class is defined."""
    with pytest.raises(ValueError):

        @serializable_dataclass
        class SimpleClass(SerializableDataclass):
            a: int = serializable_field(init=True, serialize=False)

253 

254 

def test_person_serialization():
    """Plain and factory defaults, plus a serialized property, survive a round trip."""

    @serializable_dataclass(properties_to_serialize=["full_name"])
    class FullPerson(SerializableDataclass):
        name: str = serializable_field()
        age: int = serializable_field(default=-1)
        items: typing.List[str] = serializable_field(default_factory=list)

        @property
        def full_name(self) -> str:
            return f"{self.name} Doe"

    # `age` keeps its default of -1
    person = FullPerson(name="John", items=["apple", "banana"])

    expected_ser = {
        "name": "John",
        "age": -1,
        "items": ["apple", "banana"],
        "full_name": "John Doe",
        "__format__": "FullPerson(SerializableDataclass)",
    }
    serialized = person.serialize()
    assert serialized == expected_ser, f"Expected {expected_ser}, got {serialized}"

    assert FullPerson.load(serialized) == person

280 

281 

def test_custom_serialization():
    """Paired serialization/loading functions: doubled on save, floor-halved on load."""

    @serializable_dataclass
    class CustomSerialization(SerializableDataclass):
        data: Any = serializable_field(
            serialization_fn=lambda x: x * 2,
            loading_fn=lambda x: x["data"] // 2,
        )

    custom = CustomSerialization(data=5)
    expected = {
        "data": 10,
        "__format__": "CustomSerialization(SerializableDataclass)",
    }
    serialized = custom.serialize()
    assert serialized == expected

    assert CustomSerialization.load(serialized) == custom

297 assert loaded == custom 

298 

299 

@serializable_dataclass
class Nested_with_Container(SerializableDataclass):
    """Scalar fields plus a list of `BasicAutofields` handled by explicit element-wise functions."""

    val_int: int
    val_str: str
    # each element is (de)serialized on its own; `loading_fn` receives the whole dict
    val_list: typing.List[BasicAutofields] = serializable_field(
        default_factory=list,
        serialization_fn=lambda x: [y.serialize() for y in x],
        loading_fn=lambda x: [BasicAutofields.load(y) for y in x["val_list"]],
    )

309 

310 

def test_nested_with_container():
    """A list of nested dataclasses round-trips through custom serialization/loading functions."""
    members = [
        BasicAutofields(a="a", b=1, c=[1, 2, 3]),
        BasicAutofields(a="b", b=2, c=[4, 5, 6]),
    ]
    instance = Nested_with_Container(val_int=42, val_str="hello", val_list=members)

    member_fmt = "BasicAutofields(SerializableDataclass)"
    expected_ser = {
        "val_int": 42,
        "val_str": "hello",
        "val_list": [
            {"a": "a", "b": 1, "c": [1, 2, 3], "__format__": member_fmt},
            {"a": "b", "b": 2, "c": [4, 5, 6], "__format__": member_fmt},
        ],
        "__format__": "Nested_with_Container(SerializableDataclass)",
    }

    serialized = instance.serialize()
    assert serialized == expected_ser

    assert Nested_with_Container.load(serialized) == instance

347 

348 

class Custom_class_with_serialization:
    """Plain class (not a `SerializableDataclass`) that provides its own serialize/load pair."""

    def __init__(self, a: int, b: str):
        self.a: int = a
        self.b: str = b

    def serialize(self):
        """Return both attributes as a plain dict."""
        return {"a": self.a, "b": self.b}

    @classmethod
    def load(cls, data):
        """Rebuild an instance from a dict shaped like the output of `serialize`."""
        return cls(data["a"], data["b"])

    def __eq__(self, other):
        # Check the type before reading `other.a` / `other.b`. Without this guard,
        # comparing with an unrelated object (None, an int, a dict) raised
        # AttributeError instead of returning False. This mirrors the isinstance
        # pattern the other custom classes in this file use.
        return (
            isinstance(other, Custom_class_with_serialization)
            and self.a == other.a
            and self.b == other.b
        )

365 

366 

@serializable_dataclass
class nested_custom(SerializableDataclass):
    """Dataclass with a field typed as a plain class that provides serialize/load."""

    value: float
    data1: Custom_class_with_serialization

371 

372 

def test_nested_custom(recwarn):  # `recwarn` absorbs the warnings this round trip emits
    """A field of a non-dataclass custom type serializes via its `serialize` and loads back equal."""
    payload_obj = Custom_class_with_serialization(1, "hello")
    instance = nested_custom(value=42.0, data1=payload_obj)

    expected_ser = {
        "value": 42.0,
        "data1": {"a": 1, "b": "hello"},
        "__format__": "nested_custom(SerializableDataclass)",
    }
    serialized = instance.serialize()
    assert serialized == expected_ser

    assert nested_custom.load(serialized) == instance

386 

387 

def test_deserialize_fn():
    """`deserialize_fn` undoes `serialization_fn` using only the stored field value."""

    @serializable_dataclass
    class DeserializeFn(SerializableDataclass):
        # stored as a string, parsed back to int on load
        data: int = serializable_field(
            serialization_fn=lambda x: str(x),
            deserialize_fn=lambda x: int(x),
        )

    instance = DeserializeFn(data=5)

    serialized = instance.serialize()
    assert serialized == {
        "data": "5",
        "__format__": "DeserializeFn(SerializableDataclass)",
    }

    restored = DeserializeFn.load(serialized)
    assert restored == instance
    assert restored.data == 5

406 

407 

@serializable_dataclass
class DictContainer(SerializableDataclass):
    """A required flat dict plus two dict fields that default to empty dicts."""

    simple_dict: Dict[str, int]
    nested_dict: Dict[str, Dict[str, int]] = serializable_field(default_factory=dict)
    optional_dict: Dict[str, str] = serializable_field(default_factory=dict)

415 

416 

def test_dict_serialization():
    """Flat and nested dict fields serialize unchanged next to `__format__`."""
    container = DictContainer(
        simple_dict={"a": 1, "b": 2},
        nested_dict={"x": {"y": 3, "z": 4}},
        optional_dict={"hello": "world"},
    )

    expected = {
        "__format__": "DictContainer(SerializableDataclass)",
        "simple_dict": {"a": 1, "b": 2},
        "nested_dict": {"x": {"y": 3, "z": 4}},
        "optional_dict": {"hello": "world"},
    }
    assert container.serialize() == expected

434 

435 

def test_dict_loading():
    """Loading a hand-written serialized dict populates every dict field."""
    original_data = {
        "__format__": "DictContainer(SerializableDataclass)",
        "simple_dict": {"a": 1, "b": 2},
        "nested_dict": {"x": {"y": 3, "z": 4}},
        "optional_dict": {"hello": "world"},
    }

    loaded = DictContainer.load(original_data)

    expected_fields = {
        "simple_dict": {"a": 1, "b": 2},
        "nested_dict": {"x": {"y": 3, "z": 4}},
        "optional_dict": {"hello": "world"},
    }
    for field_name, expected_value in expected_fields.items():
        assert getattr(loaded, field_name) == expected_value

449 

450 

def test_dict_equality():
    """Instances are equal exactly when their dict contents are equal."""

    def build(simple_dict):
        # only `simple_dict` varies between the instances under test
        return DictContainer(
            simple_dict=simple_dict,
            nested_dict={"x": {"y": 3, "z": 4}},
            optional_dict={"hello": "world"},
        )

    instance1 = build({"a": 1, "b": 2})
    instance2 = build({"a": 1, "b": 2})
    instance3 = build({"a": 1, "b": 3})  # Different value

    assert instance1 == instance2
    assert instance1 != instance3
    assert instance2 != instance3

474 

475 

def test_dict_diff():
    """`diff` reports a changed dict field as the whole before/after dicts."""

    def build(**overrides):
        # baseline field values; any keyword replaces the matching field
        fields = {
            "simple_dict": {"a": 1, "b": 2},
            "nested_dict": {"x": {"y": 3, "z": 4}},
            "optional_dict": {"hello": "world"},
        }
        fields.update(overrides)
        return DictContainer(**fields)

    baseline = build()

    cases = [
        # Different simple_dict value
        (
            build(simple_dict={"a": 1, "b": 3}),
            {"simple_dict": {"self": {"a": 1, "b": 2}, "other": {"a": 1, "b": 3}}},
        ),
        # Different nested_dict value
        (
            build(nested_dict={"x": {"y": 3, "z": 5}}),
            {
                "nested_dict": {
                    "self": {"x": {"y": 3, "z": 4}},
                    "other": {"x": {"y": 3, "z": 5}},
                }
            },
        ),
        # Different optional_dict value
        (
            build(optional_dict={"hello": "python"}),
            {
                "optional_dict": {
                    "self": {"hello": "world"},
                    "other": {"hello": "python"},
                }
            },
        ),
    ]
    for other, expected_diff in cases:
        assert baseline.diff(other) == expected_diff

    # identical instance -> empty diff
    assert baseline.diff(baseline) == {}

528 

529 

@serializable_dataclass
class ComplexDictContainer(SerializableDataclass):
    """Dicts with heterogeneous values, list values, and three levels of nesting."""

    mixed_dict: Dict[str, Any]
    list_dict: Dict[str, typing.List[int]]
    multi_nested: Dict[str, Dict[str, Dict[str, int]]]

537 

538 

def test_complex_dict_serialization():
    """Complex dict structures round-trip to an equal instance with an empty diff."""
    original = ComplexDictContainer(
        mixed_dict={"str": "hello", "int": 42, "list": [1, 2, 3]},
        list_dict={"a": [1, 2, 3], "b": [4, 5, 6]},
        multi_nested={"x": {"y": {"z": 1, "w": 2}, "v": {"u": 3, "t": 4}}},
    )

    restored = ComplexDictContainer.load(original.serialize())
    assert restored == original
    assert restored.diff(original) == {}

551 

552 

def test_empty_dicts():
    """Empty dicts survive a round trip and compare equal to another all-empty instance."""

    def all_empty():
        return DictContainer(simple_dict={}, nested_dict={}, optional_dict={})

    data = all_empty()

    loaded = DictContainer.load(data.serialize())
    assert loaded == data
    assert loaded.diff(data) == {}

    # Test equality with another empty instance
    assert data == all_empty()

565 

566 

# Target for dict value-type validation; type mismatches are configured to raise
@serializable_dataclass(on_typecheck_mismatch=ErrorMode.EXCEPT)
class StrictDictContainer(SerializableDataclass):
    """Dict fields whose value types are meant to be checked strictly."""

    int_dict: Dict[str, int]
    str_dict: Dict[str, str]
    float_dict: Dict[str, float]

575 

576 

# TODO: find out why dict value types are not validated (see the skip reason)
@pytest.mark.skip(reason="dict type validation doesnt seem to work")
def test_dict_type_validation():
    """Wrongly typed dict values should raise `FieldTypeMismatchError` on construction."""
    # Valid case
    valid = StrictDictContainer(
        int_dict={"a": 1, "b": 2},
        str_dict={"x": "hello", "y": "world"},
        float_dict={"m": 1.0, "n": 2.5},
    )
    assert valid.validate_fields_types()

    bad_field_sets = [
        # Invalid int_dict
        dict(
            int_dict={"a": "not an int"},
            str_dict={"x": "hello"},
            float_dict={"m": 1.0},
        ),
        # Invalid str_dict
        dict(
            int_dict={"a": 1},
            str_dict={"x": 123},
            float_dict={"m": 1.0},
        ),
    ]
    for bad_fields in bad_field_sets:
        with pytest.raises(FieldTypeMismatchError):
            StrictDictContainer(**bad_fields)

603 ) 

604 

605 

# Dicts whose values, or the dict itself, may be None or a union type
@serializable_dataclass
class OptionalDictContainer(SerializableDataclass):
    """Dicts with Optional / Union value types, plus a dict field that may be None."""

    optional_values: Dict[str, Optional[int]]
    union_values: Dict[str, Union[int, str]]
    nullable_dict: Optional[Dict[str, int]] = None

614 

615 

def test_optional_dict_values():
    """Optional/union dict values and a nullable dict field round-trip unchanged."""
    with_dict = OptionalDictContainer(
        optional_values={"a": 1, "b": None, "c": 3},
        union_values={"x": 1, "y": "string", "z": 42},
        nullable_dict={"m": 1, "n": 2},
    )
    assert OptionalDictContainer.load(with_dict.serialize()) == with_dict

    # Test with None dict
    without_dict = OptionalDictContainer(
        optional_values={"a": None, "b": None},
        union_values={"x": "all strings", "y": "here"},
        nullable_dict=None,
    )
    assert OptionalDictContainer.load(without_dict.serialize()) == without_dict

638 

639 

# Dictionary mutation
def test_dict_mutation():
    """Mutating dicts on one instance leaves a deep copy untouched, and `diff` sees the change."""
    instance1 = DictContainer(
        simple_dict={"a": 1, "b": 2},
        nested_dict={"x": {"y": 3}},
        optional_dict={"hello": "world"},
    )

    instance2 = deepcopy(instance1)

    # mutate instance1 in place, including one level down in nested_dict
    instance1.simple_dict["c"] = 3
    instance1.nested_dict["x"]["z"] = 4
    instance1.optional_dict["new"] = "value"

    # the deep copy still holds the pre-mutation contents
    snapshot = {
        "simple_dict": {"a": 1, "b": 2},
        "nested_dict": {"x": {"y": 3}},
        "optional_dict": {"hello": "world"},
    }
    for field_name, value in snapshot.items():
        assert getattr(instance2, field_name) == value

    # every mutated field shows up in the diff
    changed = instance2.diff(instance1)
    for field_name in snapshot:
        assert field_name in changed

666 

667 

# Dictionary with non-string keys
@serializable_dataclass
class IntKeyDictContainer(SerializableDataclass):
    """Dict keyed by int; keys are stringified on serialize and parsed back on load."""

    int_keys: Dict[int, str] = serializable_field(
        serialization_fn=lambda x: {str(k): v for k, v in x.items()},
        loading_fn=lambda x: {int(k): v for k, v in x["int_keys"].items()},
    )

677 

678 

def test_non_string_dict_keys():
    """Int keys become strings when serialized and ints again after loading."""
    instance = IntKeyDictContainer(int_keys={1: "one", 2: "two", 3: "three"})

    serialized = instance.serialize()
    # Keys should be converted to strings in serialized form
    assert all(isinstance(key, str) for key in serialized["int_keys"])

    loaded = IntKeyDictContainer.load(serialized)
    # Keys should be integers again after loading
    assert all(isinstance(key, int) for key in loaded.int_keys)
    assert loaded == instance

691 

692 

@serializable_dataclass
class RecursiveDictContainer(SerializableDataclass):
    """Single `Dict[str, Any]` field used to hold arbitrarily deep nested data."""

    data: Dict[str, Any]

698 

699 

def test_recursive_dict_structure():
    """Deeply nested dicts, including a list that contains a dict, round-trip unchanged."""
    innermost = {"value": 42, "list": [1, 2, {"nested": "value"}]}
    deep_dict = {"level1": {"level2": {"level3": innermost}}}

    instance = RecursiveDictContainer(data=deep_dict)
    loaded = RecursiveDictContainer.load(instance.serialize())

    assert loaded == instance
    assert loaded.data == deep_dict

714 

715 

# Defined at module level rather than inside a test. The original note suggests
# the field-type validator cannot resolve it when defined locally (unconfirmed).
class CustomSerializable:
    """Minimal non-dataclass type that exposes a `serialize` / `load` pair."""

    def __init__(self, value):
        self.value: Union[str, int] = value

    def serialize(self):
        """Wrap the value in a single-key dict."""
        payload = {"value": self.value}
        return payload

    @classmethod
    def load(cls, data):
        """Inverse of `serialize`."""
        stored = data["value"]
        return cls(stored)

    def __eq__(self, other):
        if not isinstance(other, CustomSerializable):
            return False
        return self.value == other.value

730 

731 

def test_dict_with_custom_objects():
    """Dict values of a custom type with serialize/load round-trip to equal objects."""

    @serializable_dataclass
    class CustomObjectDict(SerializableDataclass):
        objects: Dict[str, CustomSerializable]

    members = {"a": CustomSerializable(42), "b": CustomSerializable("hello")}
    instance = CustomObjectDict(objects=members)

    assert CustomObjectDict.load(instance.serialize()) == instance

746 

747 

def test_empty_optional_dicts():
    """None and {} stay distinct in an Optional dict field; the factory default stays {}."""

    @serializable_dataclass
    class OptionalDictFields(SerializableDataclass):
        required_dict: Dict[str, int]
        optional_dict: Optional[Dict[str, int]] = None
        default_empty: Dict[str, int] = serializable_field(default_factory=dict)

    # explicit None vs explicit empty dict
    instance1 = OptionalDictFields(required_dict={"a": 1}, optional_dict=None)
    instance2 = OptionalDictFields(required_dict={"a": 1}, optional_dict={})

    serialized1, serialized2 = instance1.serialize(), instance2.serialize()
    loaded1, loaded2 = (OptionalDictFields.load(s) for s in (serialized1, serialized2))

    assert loaded1.optional_dict is None
    assert loaded2.optional_dict == {}
    for loaded in (loaded1, loaded2):
        assert loaded.default_empty == {}

773 

774 

# Inheritance hierarchy: BaseClass -> ChildClass -> GrandchildClass
@serializable_dataclass(
    on_typecheck_error=ErrorMode.EXCEPT, on_typecheck_mismatch=ErrorMode.EXCEPT
)
class BaseClass(SerializableDataclass):
    """Root of the inheritance tests; type-check errors and mismatches are configured to raise."""

    base_field: str
    shared_field: int = 0

784 

785 

@serializable_dataclass
class ChildClass(BaseClass):
    """Adds `child_field` and changes the default of the inherited `shared_field`."""

    child_field: float
    shared_field: int = 1  # Override base class field

792 

793 

@serializable_dataclass
class GrandchildClass(ChildClass):
    """Third level of the hierarchy; adds a boolean field."""

    grandchild_field: bool

799 

800 

def test_inheritance():
    """Fields from every ancestor serialize and load; a base class loads as itself."""
    instance = GrandchildClass(
        base_field="base", shared_field=42, child_field=3.14, grandchild_field=True
    )

    serialized = instance.serialize()
    expected_values = {
        "base_field": "base",
        "shared_field": 42,
        "child_field": 3.14,
    }
    for key, value in expected_values.items():
        assert serialized[key] == value
    assert serialized["grandchild_field"] is True

    assert GrandchildClass.load(serialized) == instance

    # loading through the parent class yields a parent instance, not a subclass
    base_loaded = BaseClass.load({"base_field": "test", "shared_field": 1})
    assert isinstance(base_loaded, BaseClass)
    assert not isinstance(base_loaded, ChildClass)

820 

821 

@pytest.mark.skip(
    reason="Not implemented yet, generic types not supported and throw a `TypeHintNotImplementedError`"
)
def test_generic_types():
    """A generic serializable dataclass should round-trip for different type arguments."""

    T = TypeVar("T")

    @serializable_dataclass(on_typecheck_mismatch=ErrorMode.EXCEPT)
    class GenericContainer(SerializableDataclass, Generic[T]):
        """Container parameterized by a single TypeVar."""

        value: T
        values: List[T]

    cases = [
        (int, 42, [1, 2, 3]),
        (str, "hello", ["a", "b", "c"]),
    ]
    for type_arg, value, values in cases:
        container = GenericContainer[type_arg](value=value, values=values)
        loaded = GenericContainer[type_arg].load(container.serialize())
        assert loaded == container

848 

849 

# Value type with no serialization support of its own; it is (de)serialized
# through per-field functions in CustomSerializationContainer
class CustomObject:
    """Opaque single-value holder compared by value."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value if isinstance(other, CustomObject) else False

857 

858 

@serializable_dataclass
class CustomSerializationContainer(SerializableDataclass):
    """Fields whose stored form differs from their in-memory form."""

    # stored as the bare wrapped value, re-wrapped in `CustomObject` on load
    custom_obj: CustomObject = serializable_field(
        serialization_fn=lambda x: x.value,
        loading_fn=lambda x: CustomObject(x["custom_obj"]),
    )
    # stored doubled, floor-halved on load
    transform_field: int = serializable_field(
        serialization_fn=lambda x: x * 2,
        loading_fn=lambda x: x["transform_field"] // 2,
    )

870 

871 

def test_custom_serialization_2():
    """Custom serialization and loading functions are applied in both directions."""
    instance = CustomSerializationContainer(
        custom_obj=CustomObject(42), transform_field=10
    )

    serialized = instance.serialize()
    assert serialized["custom_obj"] == 42
    assert serialized["transform_field"] == 20

    restored = CustomSerializationContainer.load(serialized)
    assert restored == instance
    assert restored.transform_field == 10

885 

886 

887# @pytest.mark.skip(reason="Not implemented yet, waiting on `custom_value_check_fn`") 

888# def test_value_validation(): 

889# """Test field validation""" 

890# @serializable_dataclass 

891# class ValidationContainer(SerializableDataclass): 

892# """Test validation and error handling""" 

893# positive_int: int = serializable_field( 

894# custom_value_check_fn=lambda x: x > 0 

895# ) 

896# email: str = serializable_field( 

897# custom_value_check_fn=lambda x: '@' in x 

898# ) 

899 

900# # Valid case 

901# valid = ValidationContainer(positive_int=42, email="test@example.com") 

902# assert valid.validate_fields_types() 

903 

904# # what will this do? 

905# maybe_valid = ValidationContainer(positive_int=4.2, email="test@example.com") 

906# assert maybe_valid.validate_fields_types() 

907 

908# maybe_valid_2 = ValidationContainer(positive_int=42, email=["test", "@", "example", ".com"]) 

909# assert maybe_valid_2.validate_fields_types() 

910 

911# # Invalid positive_int 

912# with pytest.raises(ValueError): 

913# ValidationContainer(positive_int=-1, email="test@example.com") 

914 

915# # Invalid email 

916# with pytest.raises(ValueError): 

917# ValidationContainer(positive_int=42, email="invalid") 

918 

919 

def test_init_true_serialize_false():
    """Defining a class with an init=True, serialize=False field raises ValueError."""
    with pytest.raises(ValueError):

        @serializable_dataclass
        class MetadataContainer(SerializableDataclass):
            """Test field metadata and options"""

            # presumably the field that triggers the ValueError (init=True, serialize=False)
            hidden: str = serializable_field(serialize=False, init=True)
            readonly: int = serializable_field(init=True, frozen=True)
            computed: float = serializable_field(init=False, serialize=True)

            def __post_init__(self):
                object.__setattr__(self, "computed", self.readonly * 2.0)

933 

934 

# Several properties serialized alongside the stored fields
@serializable_dataclass(properties_to_serialize=["full_name", "age_in_months"])
class PropertyContainer(SerializableDataclass):
    """Stored name/age fields plus two derived properties included in serialization."""

    first_name: str
    last_name: str
    age_years: int

    @property
    def full_name(self) -> str:
        """First and last name separated by a space."""
        return f"{self.first_name} {self.last_name}"

    @property
    def age_in_months(self) -> int:
        """Age in years multiplied by 12."""
        return self.age_years * 12

951 

952 

def test_property_serialization():
    """Serialized properties hold their computed values and do not block loading."""
    instance = PropertyContainer(first_name="John", last_name="Doe", age_years=30)

    serialized = instance.serialize()
    for prop_name, expected in (("full_name", "John Doe"), ("age_in_months", 360)):
        assert serialized[prop_name] == expected

    assert PropertyContainer.load(serialized) == instance

963 

964 

# TODO: would be nice to support, but not a major issue
@pytest.mark.skip(reason="Not implemented yet")
def test_edge_cases():
    """A serializable dataclass holding an optional reference to its own type, plus None/union defaults."""

    @serializable_dataclass
    class EdgeCaseContainer(SerializableDataclass):
        """Test edge cases and corner cases"""

        empty_list: List[Any] = serializable_field(default_factory=list)
        optional_value: Optional[int] = serializable_field(default=None)
        union_field: Union[str, int, None] = serializable_field(default=None)
        recursive_ref: Optional["EdgeCaseContainer"] = serializable_field(default=None)

    # self-referencing structure, one level deep
    instance = EdgeCaseContainer(recursive_ref=EdgeCaseContainer())
    assert EdgeCaseContainer.load(instance.serialize()) == instance

    # defaults of a bare instance
    empty = EdgeCaseContainer()
    assert empty.empty_list == []
    assert empty.optional_value is None
    assert empty.union_field is None

    # the union field should keep each member type through a round trip
    for union_value in ("string", 42):
        instance.union_field = union_value
        reloaded = EdgeCaseContainer.load(instance.serialize())
        assert reloaded.union_field == union_value

1003 

1004 

# Error handling for malformed data
def test_error_handling():
    """Missing fields and wrongly typed values are reported."""
    # missing required field
    with pytest.raises(TypeError):
        BaseClass.load({})

    # direct construction with wrong types is expected to succeed here;
    # only `validate_fields_types` flags the mismatch
    mistyped = BaseClass(base_field=42, shared_field="invalid")
    assert not mistyped.validate_fields_types()

    with pytest.raises(FieldTypeMismatchError):
        BaseClass.load(
            {
                "base_field": 42,  # Should be str
                "shared_field": "invalid",  # Should be int
            }
        )

    # Invalid format string -- disabled; presumably `load` does not yet reject
    # a mismatched `__format__`:
    # with pytest.raises(ValueError):
    #     BaseClass.load({
    #         "__format__": "InvalidClass(SerializableDataclass)",
    #         "base_field": "test",
    #         "shared_field": 0
    #     })

1030 

1031 

# Cyclic references
# TODO: make .serialize() fail on cyclic references! see https://github.com/mivanit/muutils/issues/62
# NOTE(review): this TODO asks for serialize() to fail on cycles, while the test
# body below expects it to succeed; reconcile once the issue is resolved.
@pytest.mark.skip(reason="Not implemented yet")
def test_cyclic_references():
    """Two nodes that point at each other should serialize without infinite recursion."""

    @serializable_dataclass
    class Node(SerializableDataclass):
        value: str
        next: Optional["Node"] = serializable_field(default=None)

    # build the cycle: one -> two -> one
    first, second = Node("one"), Node("two")
    first.next = second
    second.next = first

    loaded = Node.load(first.serialize())
    assert loaded.value == "one"
    assert loaded.next.value == "two"

1053 assert loaded.next.value == "two"