Coverage for tests\unit\json_serialize\serializable_dataclass\test_serializable_dataclass.py: 87%
482 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-12 20:43 -0700
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-12 20:43 -0700
1from __future__ import annotations
3from copy import deepcopy
4import typing
5from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
7import pytest
9from muutils.errormode import ErrorMode
10from muutils.json_serialize import (
11 SerializableDataclass,
12 serializable_dataclass,
13 serializable_field,
14)
16from muutils.json_serialize.serializable_dataclass import (
17 FieldIsNotInitOrSerializeWarning,
18 FieldTypeMismatchError,
19)
21# pylint: disable=missing-class-docstring, unused-variable
@serializable_dataclass
class BasicAutofields(SerializableDataclass):
    """Smallest fixture class: three required fields, none with explicit field options."""

    a: str
    b: int
    c: typing.List[int]
def test_basic_auto_fields():
    """Serialized output is the field dict plus ``__format__``; an instance equals itself."""
    fields = {"a": "hello", "b": 42, "c": [1, 2, 3]}
    obj = BasicAutofields(**fields)

    # shallow copy of the inputs with the class tag added
    expected = {**fields, "__format__": "BasicAutofields(SerializableDataclass)"}

    assert obj.serialize() == expected
    assert obj == obj
    assert obj.diff(obj) == {}
def test_basic_diff():
    """``diff`` maps each differing field name to its ``self`` and ``other`` values."""
    base = BasicAutofields(a="hello", b=42, c=[1, 2, 3])
    other_a = BasicAutofields(a="goodbye", b=42, c=[1, 2, 3])
    other_b = BasicAutofields(a="hello", b=-1, c=[1, 2, 3])
    other_b_and_c = BasicAutofields(a="hello", b=-1, c=[42])

    # the `b` entry recurs in three of the expected diffs below
    b_entry = {"b": {"self": 42, "other": -1}}

    assert base.diff(other_a) == {"a": {"self": "hello", "other": "goodbye"}}
    assert base.diff(other_b) == b_entry
    assert base.diff(other_b_and_c) == {
        **b_entry,
        "c": {"self": [1, 2, 3], "other": [42]},
    }
    assert base.diff(base) == {}
    assert other_a.diff(other_b) == {
        "a": {"self": "goodbye", "other": "hello"},
        **b_entry,
    }
@serializable_dataclass
class SimpleFields(SerializableDataclass):
    """One required field, one plain default, and one ``default_factory`` field."""

    d: str
    e: int = 42
    # a fresh list per instance rather than a shared mutable default
    f: typing.List[int] = serializable_field(default_factory=list)  # noqa: F821
@serializable_dataclass
class FieldOptions(SerializableDataclass):
    """Exercises the keyword options accepted by ``serializable_field``."""

    a: str = serializable_field()
    b: str = serializable_field()
    # switched off for init, serialization, repr and comparison
    c: str = serializable_field(init=False, serialize=False, repr=False, compare=False)
    # upper-cased when serialized; the loader lower-cases the stored "d" entry
    d: str = serializable_field(
        serialization_fn=lambda value: value.upper(),
        loading_fn=lambda data: data["d"].lower(),
    )
@serializable_dataclass(properties_to_serialize=["full_name"])
class WithProperty(SerializableDataclass):
    """Two name fields plus a ``full_name`` property listed in ``properties_to_serialize``."""

    first_name: str
    last_name: str

    @property
    def full_name(self) -> str:
        """First and last name joined by a single space."""
        return f"{self.first_name} {self.last_name}"
class Child(FieldOptions, WithProperty):
    """Inherits from both ``FieldOptions`` and ``WithProperty``; declares nothing of its own."""
@pytest.fixture
def simple_fields_instance():
    """Provide a ``SimpleFields`` with every field passed explicitly."""
    kwargs = {"d": "hello", "e": 42, "f": [1, 2, 3]}
    return SimpleFields(**kwargs)
@pytest.fixture
def field_options_instance():
    """Provide a ``FieldOptions``; ``c`` is omitted because it is declared ``init=False``."""
    kwargs = {"a": "hello", "b": "world", "d": "case"}
    return FieldOptions(**kwargs)
@pytest.fixture
def with_property_instance():
    """Provide a ``WithProperty`` for John Doe."""
    kwargs = {"first_name": "John", "last_name": "Doe"}
    return WithProperty(**kwargs)
def test_simple_fields_serialization(simple_fields_instance):
    """``serialize`` returns every field plus the ``__format__`` tag."""
    expected = dict(d="hello", e=42, f=[1, 2, 3])
    expected["__format__"] = "SimpleFields(SerializableDataclass)"

    assert simple_fields_instance.serialize() == expected
def test_simple_fields_loading(simple_fields_instance):
    """A serialize/load round trip gives an equal instance with an empty diff both ways."""
    original = simple_fields_instance
    payload = original.serialize()

    restored = SimpleFields.load(payload)

    assert restored == original
    assert restored.diff(original) == {}
    assert original.diff(restored) == {}
def test_field_options_serialization(field_options_instance):
    """``c`` is absent from the output and ``d`` has passed through its ``serialization_fn``."""
    expected = dict(a="hello", b="world", d="CASE")
    expected["__format__"] = "FieldOptions(SerializableDataclass)"

    assert field_options_instance.serialize() == expected
def test_field_options_loading(field_options_instance):
    """Loading ``FieldOptions`` warns but still yields an equal instance."""
    payload = field_options_instance.serialize()

    # `pytest.warns` both captures the `FieldIsNotInitOrSerializeWarning` and requires it
    with pytest.warns(FieldIsNotInitOrSerializeWarning):
        restored = FieldOptions.load(payload)

    assert restored == field_options_instance
def test_with_property_serialization(with_property_instance):
    """The ``full_name`` property is serialized alongside the declared fields."""
    expected = dict(first_name="John", last_name="Doe", full_name="John Doe")
    expected["__format__"] = "WithProperty(SerializableDataclass)"

    assert with_property_instance.serialize() == expected
def test_with_property_loading(with_property_instance):
    """Output that includes a serialized property still loads to an equal instance."""
    restored = WithProperty.load(with_property_instance.serialize())
    assert restored == with_property_instance
@serializable_dataclass
class Address(SerializableDataclass):
    """Postal address; used as the nested field of ``Person``."""

    street: str
    city: str
    zip_code: str
@serializable_dataclass
class Person(SerializableDataclass):
    """Person whose ``address`` field is itself a serializable dataclass."""

    name: str
    age: int
    address: Address
@pytest.fixture
def address_instance():
    """Provide a sample ``Address``."""
    kwargs = {"street": "123 Main St", "city": "New York", "zip_code": "10001"}
    return Address(**kwargs)
@pytest.fixture
def person_instance(address_instance):
    """Provide a ``Person`` living at the ``address_instance`` fixture's address."""
    kwargs = {"name": "John Doe", "age": 30, "address": address_instance}
    return Person(**kwargs)
def test_nested_serialization(person_instance):
    """A nested dataclass field serializes to a nested dict carrying its own ``__format__``."""
    expected_address = {
        "street": "123 Main St",
        "city": "New York",
        "zip_code": "10001",
        "__format__": "Address(SerializableDataclass)",
    }
    expected_person = {
        "name": "John Doe",
        "age": 30,
        "address": expected_address,
        "__format__": "Person(SerializableDataclass)",
    }

    assert person_instance.serialize() == expected_person
def test_nested_loading(person_instance):
    """Round-tripping a ``Person`` restores both the outer object and the nested address."""
    restored = Person.load(person_instance.serialize())

    assert restored == person_instance
    assert restored.address == person_instance.address
def test_with_printing():
    """Smoke test: serialize then load a class, printing both results; nothing is asserted."""

    @serializable_dataclass(properties_to_serialize=["full_name"])
    class MyClass(SerializableDataclass):
        name: str
        # serialized one higher; the loader subtracts one again
        age: int = serializable_field(
            serialization_fn=lambda years: years + 1,
            loading_fn=lambda data: data["age"] - 1,
        )
        items: list = serializable_field(default_factory=list)

        @property
        def full_name(self) -> str:
            return f"{self.name} Doe"

    original = MyClass(name="John", age=30, items=["apple", "banana"])
    payload = original.serialize()
    print(payload)

    restored = MyClass.load(payload)
    print(restored)
def test_simple_class_serialization():
    """A class defined inside the test function serializes and round-trips."""

    @serializable_dataclass
    class SimpleClass(SerializableDataclass):
        a: int
        b: str

    original = SimpleClass(a=42, b="hello")
    payload = original.serialize()

    expected = dict(a=42, b="hello")
    expected["__format__"] = "SimpleClass(SerializableDataclass)"
    assert payload == expected

    assert SimpleClass.load(payload) == original
def test_error_when_init_and_not_serialize():
    """Declaring a field with ``init=True, serialize=False`` raises ``ValueError`` at class creation."""
    with pytest.raises(ValueError):

        # the decorator is what is expected to raise, so the class is never used
        @serializable_dataclass
        class SimpleClass(SerializableDataclass):
            a: int = serializable_field(init=True, serialize=False)
def test_person_serialization():
    """Defaults, a default factory and a serialized property together survive a round trip."""

    @serializable_dataclass(properties_to_serialize=["full_name"])
    class FullPerson(SerializableDataclass):
        name: str = serializable_field()
        age: int = serializable_field(default=-1)
        items: typing.List[str] = serializable_field(default_factory=list)

        @property
        def full_name(self) -> str:
            return f"{self.name} Doe"

    # `age` is left out so the default of -1 applies
    person = FullPerson(name="John", items=["apple", "banana"])
    serialized = person.serialize()

    expected_ser = dict(name="John", age=-1, items=["apple", "banana"])
    expected_ser["full_name"] = "John Doe"
    expected_ser["__format__"] = "FullPerson(SerializableDataclass)"
    assert serialized == expected_ser, f"Expected {expected_ser}, got {serialized}"

    restored = FullPerson.load(serialized)
    assert restored == person
def test_custom_serialization():
    """``serialization_fn`` doubles the value on the way out; ``loading_fn`` halves it on the way in."""

    @serializable_dataclass
    class CustomSerialization(SerializableDataclass):
        data: Any = serializable_field(
            serialization_fn=lambda value: value * 2,
            loading_fn=lambda payload: payload["data"] // 2,
        )

    original = CustomSerialization(data=5)
    payload = original.serialize()

    expected = {"data": 10}
    expected["__format__"] = "CustomSerialization(SerializableDataclass)"
    assert payload == expected

    assert CustomSerialization.load(payload) == original
@serializable_dataclass
class Nested_with_Container(SerializableDataclass):
    """Holds a list of ``BasicAutofields`` converted element-by-element by explicit functions."""

    val_int: int
    val_str: str
    # each element is serialized / loaded individually through BasicAutofields
    val_list: typing.List[BasicAutofields] = serializable_field(
        default_factory=list,
        serialization_fn=lambda items: [item.serialize() for item in items],
        loading_fn=lambda data: [
            BasicAutofields.load(entry) for entry in data["val_list"]
        ],
    )
def test_nested_with_container():
    """A list of nested dataclasses serializes to a list of tagged dicts and round-trips."""
    # (a, b, c) per element; `c` is a tuple here so every use builds a fresh list
    rows = [("a", 1, (1, 2, 3)), ("b", 2, (4, 5, 6))]

    instance = Nested_with_Container(
        val_int=42,
        val_str="hello",
        val_list=[BasicAutofields(a=a, b=b, c=list(c)) for a, b, c in rows],
    )

    expected_ser = {
        "val_int": 42,
        "val_str": "hello",
        "val_list": [
            {
                "a": a,
                "b": b,
                "c": list(c),
                "__format__": "BasicAutofields(SerializableDataclass)",
            }
            for a, b, c in rows
        ],
        "__format__": "Nested_with_Container(SerializableDataclass)",
    }

    serialized = instance.serialize()
    assert serialized == expected_ser

    restored = Nested_with_Container.load(serialized)
    assert restored == instance
class Custom_class_with_serialization:
    """Custom class which doesn't inherit from ``SerializableDataclass`` but does serialize.

    Provides the ``serialize`` / ``load`` pair by hand, plus value equality
    over its two attributes.
    """

    def __init__(self, a: int, b: str):
        self.a: int = a
        self.b: str = b

    def serialize(self):
        """Return the two attributes as a plain dict."""
        return {"a": self.a, "b": self.b}

    @classmethod
    def load(cls, data):
        """Build an instance from a dict with ``"a"`` and ``"b"`` keys."""
        return cls(data["a"], data["b"])

    def __eq__(self, other):
        """Compare by ``a`` and ``b``; defer to Python for unrelated types."""
        # Without this guard, comparing against an object lacking `.a` / `.b`
        # raised AttributeError instead of evaluating to "not equal".
        if not isinstance(other, Custom_class_with_serialization):
            return NotImplemented
        return (self.a == other.a) and (self.b == other.b)
@serializable_dataclass
class nested_custom(SerializableDataclass):
    """Dataclass with a field whose type is a plain class offering ``serialize`` / ``load``."""

    value: float
    data1: Custom_class_with_serialization
def test_nested_custom(recwarn):  # this will send some warnings but whatever
    """A field holding a non-dataclass object with ``serialize`` / ``load`` round-trips."""
    inner = Custom_class_with_serialization(1, "hello")
    instance = nested_custom(value=42.0, data1=inner)

    expected_ser = {
        "value": 42.0,
        "data1": {"a": 1, "b": "hello"},
        "__format__": "nested_custom(SerializableDataclass)",
    }
    serialized = instance.serialize()
    assert serialized == expected_ser

    restored = nested_custom.load(serialized)
    assert restored == instance
def test_deserialize_fn():
    """``deserialize_fn`` is handed the stored value and converts it back to ``int``."""

    @serializable_dataclass
    class DeserializeFn(SerializableDataclass):
        # int -> str when serialized, str -> int when loaded
        data: int = serializable_field(
            serialization_fn=lambda value: str(value),
            deserialize_fn=lambda raw: int(raw),
        )

    original = DeserializeFn(data=5)
    payload = original.serialize()

    expected = {"data": "5"}
    expected["__format__"] = "DeserializeFn(SerializableDataclass)"
    assert payload == expected

    restored = DeserializeFn.load(payload)
    assert restored == original
    assert restored.data == 5
@serializable_dataclass
class DictContainer(SerializableDataclass):
    """Fixture with one required dict field and two dict fields that default to empty."""

    simple_dict: Dict[str, int]
    nested_dict: Dict[str, Dict[str, int]] = serializable_field(default_factory=dict)
    optional_dict: Dict[str, str] = serializable_field(default_factory=dict)
def test_dict_serialization():
    """Dict fields appear unchanged in the serialized output."""
    contents = {
        "simple_dict": {"a": 1, "b": 2},
        "nested_dict": {"x": {"y": 3, "z": 4}},
        "optional_dict": {"hello": "world"},
    }
    container = DictContainer(**contents)

    # written out again rather than reusing `contents`, so the comparison is
    # against independent dict objects
    expected = {
        "__format__": "DictContainer(SerializableDataclass)",
        "simple_dict": {"a": 1, "b": 2},
        "nested_dict": {"x": {"y": 3, "z": 4}},
        "optional_dict": {"hello": "world"},
    }

    assert container.serialize() == expected
def test_dict_loading():
    """Loading from a hand-written payload populates each dict field."""
    payload = {
        "__format__": "DictContainer(SerializableDataclass)",
        "simple_dict": {"a": 1, "b": 2},
        "nested_dict": {"x": {"y": 3, "z": 4}},
        "optional_dict": {"hello": "world"},
    }

    container = DictContainer.load(payload)

    assert container.simple_dict == {"a": 1, "b": 2}
    assert container.nested_dict == {"x": {"y": 3, "z": 4}}
    assert container.optional_dict == {"hello": "world"}
def test_dict_equality():
    """Instances compare equal exactly when their dict contents match."""

    def build(simple_dict):
        # only `simple_dict` varies; the other two fields get fresh identical dicts
        return DictContainer(
            simple_dict=simple_dict,
            nested_dict={"x": {"y": 3, "z": 4}},
            optional_dict={"hello": "world"},
        )

    first = build({"a": 1, "b": 2})
    twin = build({"a": 1, "b": 2})
    changed = build({"a": 1, "b": 3})  # Different value

    assert first == twin
    assert first != changed
    assert twin != changed
def test_dict_diff():
    """``diff`` reports a changed dict field as a whole, with both full values."""

    def build(**overrides):
        # a fresh baseline per call, so no dict object is shared between instances
        kwargs = {
            "simple_dict": {"a": 1, "b": 2},
            "nested_dict": {"x": {"y": 3, "z": 4}},
            "optional_dict": {"hello": "world"},
        }
        kwargs.update(overrides)
        return DictContainer(**kwargs)

    baseline = build()
    simple_changed = build(simple_dict={"a": 1, "b": 3})
    nested_changed = build(nested_dict={"x": {"y": 3, "z": 5}})
    optional_changed = build(optional_dict={"hello": "python"})

    # simple_dict differs
    assert baseline.diff(simple_changed) == {
        "simple_dict": {"self": {"a": 1, "b": 2}, "other": {"a": 1, "b": 3}}
    }

    # nested_dict differs
    assert baseline.diff(nested_changed) == {
        "nested_dict": {
            "self": {"x": {"y": 3, "z": 4}},
            "other": {"x": {"y": 3, "z": 5}},
        }
    }

    # optional_dict differs
    assert baseline.diff(optional_changed) == {
        "optional_dict": {"self": {"hello": "world"}, "other": {"hello": "python"}}
    }

    # an instance never differs from itself
    assert baseline.diff(baseline) == {}
@serializable_dataclass
class ComplexDictContainer(SerializableDataclass):
    """Fixture with heterogeneous, list-valued and triply nested dict fields."""

    mixed_dict: Dict[str, Any]
    list_dict: Dict[str, typing.List[int]]
    multi_nested: Dict[str, Dict[str, Dict[str, int]]]
def test_complex_dict_serialization():
    """Mixed, list-valued and deeply nested dicts survive a serialize/load round trip."""
    original = ComplexDictContainer(
        mixed_dict={"str": "hello", "int": 42, "list": [1, 2, 3]},
        list_dict={"a": [1, 2, 3], "b": [4, 5, 6]},
        multi_nested={"x": {"y": {"z": 1, "w": 2}, "v": {"u": 3, "t": 4}}},
    )

    restored = ComplexDictContainer.load(original.serialize())

    assert restored == original
    assert restored.diff(original) == {}
def test_empty_dicts():
    """Empty dicts round-trip, and two all-empty instances compare equal."""
    empty = DictContainer(simple_dict={}, nested_dict={}, optional_dict={})

    restored = DictContainer.load(empty.serialize())
    assert restored == empty
    assert restored.diff(empty) == {}

    # a second, independently built empty instance
    also_empty = DictContainer(simple_dict={}, nested_dict={}, optional_dict={})
    assert empty == also_empty
567# Test invalid dictionary type validation
@serializable_dataclass(on_typecheck_mismatch=ErrorMode.EXCEPT)
class StrictDictContainer(SerializableDataclass):
    """Dict-typed fields on a class configured with ``on_typecheck_mismatch=ErrorMode.EXCEPT``."""

    int_dict: Dict[str, int]
    str_dict: Dict[str, str]
    float_dict: Dict[str, float]
577# TODO: figure this out
@pytest.mark.skip(reason="dict type validation doesnt seem to work")
def test_dict_type_validation():
    """(Skipped) Wrongly typed dict values should raise ``FieldTypeMismatchError``."""
    # correctly typed values validate
    well_typed = StrictDictContainer(
        int_dict={"a": 1, "b": 2},
        str_dict={"x": "hello", "y": "world"},
        float_dict={"m": 1.0, "n": 2.5},
    )
    assert well_typed.validate_fields_types()

    # a str where int_dict expects an int
    with pytest.raises(FieldTypeMismatchError):
        StrictDictContainer(
            int_dict={"a": "not an int"},  # Type error
            str_dict={"x": "hello"},
            float_dict={"m": 1.0},
        )

    # an int where str_dict expects a str
    with pytest.raises(FieldTypeMismatchError):
        StrictDictContainer(
            int_dict={"a": 1},
            str_dict={"x": 123},  # Type error
            float_dict={"m": 1.0},
        )
606# Test dictionary with optional values
@serializable_dataclass
class OptionalDictContainer(SerializableDataclass):
    """Dicts with ``Optional`` / ``Union`` value types, plus a dict field that may itself be ``None``."""

    optional_values: Dict[str, Optional[int]]
    union_values: Dict[str, Union[int, str]]
    nullable_dict: Optional[Dict[str, int]] = None
def test_optional_dict_values():
    """Dicts with optional/union values round-trip, with and without a ``None`` dict field."""
    cases = [
        # populated nullable_dict, mixed value types
        dict(
            optional_values={"a": 1, "b": None, "c": 3},
            union_values={"x": 1, "y": "string", "z": 42},
            nullable_dict={"m": 1, "n": 2},
        ),
        # Test with None dict
        dict(
            optional_values={"a": None, "b": None},
            union_values={"x": "all strings", "y": "here"},
            nullable_dict=None,
        ),
    ]

    for kwargs in cases:
        original = OptionalDictContainer(**kwargs)
        restored = OptionalDictContainer.load(original.serialize())
        assert restored == original
640# Test dictionary mutation
def test_dict_mutation():
    """Mutating one instance's dicts leaves a deep copy untouched and shows up in ``diff``."""
    mutated = DictContainer(
        simple_dict={"a": 1, "b": 2},
        nested_dict={"x": {"y": 3}},
        optional_dict={"hello": "world"},
    )
    snapshot = deepcopy(mutated)

    # change all three dicts on the source instance only
    mutated.simple_dict["c"] = 3
    mutated.nested_dict["x"]["z"] = 4
    mutated.optional_dict["new"] = "value"

    # the deep copy still holds the original contents
    assert snapshot.simple_dict == {"a": 1, "b": 2}
    assert snapshot.nested_dict == {"x": {"y": 3}}
    assert snapshot.optional_dict == {"hello": "world"}

    # every mutated field is reported
    diff = snapshot.diff(mutated)
    for field_name in ("simple_dict", "nested_dict", "optional_dict"):
        assert field_name in diff
668# Test dictionary key types
@serializable_dataclass
class IntKeyDictContainer(SerializableDataclass):
    """Dict with ``int`` keys, converted to and from ``str`` keys by explicit functions."""

    int_keys: Dict[int, str] = serializable_field(
        serialization_fn=lambda mapping: {
            str(key): val for key, val in mapping.items()
        },
        loading_fn=lambda data: {
            int(key): val for key, val in data["int_keys"].items()
        },
    )
def test_non_string_dict_keys():
    """Int keys become strings when serialized and are integers again after loading."""
    original = IntKeyDictContainer(int_keys={1: "one", 2: "two", 3: "three"})

    payload = original.serialize()
    serialized_keys = payload["int_keys"].keys()
    assert all(isinstance(key, str) for key in serialized_keys)

    restored = IntKeyDictContainer.load(payload)
    restored_keys = restored.int_keys.keys()
    assert all(isinstance(key, int) for key in restored_keys)
    assert restored == original
@serializable_dataclass
class RecursiveDictContainer(SerializableDataclass):
    """Single ``Dict[str, Any]`` field, used to hold arbitrarily nested data."""

    data: Dict[str, Any]
def test_recursive_dict_structure():
    """A dict nested several levels deep, with a list containing a dict, round-trips intact."""
    innermost = {"value": 42, "list": [1, 2, {"nested": "value"}]}
    deep_dict = {"level1": {"level2": {"level3": innermost}}}

    original = RecursiveDictContainer(data=deep_dict)
    restored = RecursiveDictContainer.load(original.serialize())

    assert restored == original
    assert restored.data == deep_dict
716# need to define this outside, otherwise the validator cant see it?
class CustomSerializable:
    """Plain value wrapper with hand-written ``serialize`` / ``load`` and value equality."""

    def __init__(self, value):
        self.value: Union[str, int] = value

    def serialize(self):
        """Return the wrapped value under the ``"value"`` key."""
        return dict(value=self.value)

    @classmethod
    def load(cls, data):
        """Rebuild an instance from a dict with a ``"value"`` key."""
        stored = data["value"]
        return cls(stored)

    def __eq__(self, other):
        # anything that is not a CustomSerializable is simply unequal
        if not isinstance(other, CustomSerializable):
            return False
        return self.value == other.value
def test_dict_with_custom_objects():
    """Dict values that are plain objects with ``serialize`` / ``load`` round-trip."""

    @serializable_dataclass
    class CustomObjectDict(SerializableDataclass):
        objects: Dict[str, CustomSerializable]

    contents = {"a": CustomSerializable(42), "b": CustomSerializable("hello")}
    original = CustomObjectDict(objects=contents)

    restored = CustomObjectDict.load(original.serialize())
    assert restored == original
def test_empty_optional_dicts():
    """``None`` and ``{}`` stay distinct through a round trip of an optional dict field."""

    @serializable_dataclass
    class OptionalDictFields(SerializableDataclass):
        required_dict: Dict[str, int]
        optional_dict: Optional[Dict[str, int]] = None
        default_empty: Dict[str, int] = serializable_field(default_factory=dict)

    with_none = OptionalDictFields(required_dict={"a": 1}, optional_dict=None)
    with_empty = OptionalDictFields(required_dict={"a": 1}, optional_dict={})

    # serialize both before loading either, as in the original ordering
    payload_none = with_none.serialize()
    payload_empty = with_empty.serialize()

    restored_none = OptionalDictFields.load(payload_none)
    restored_empty = OptionalDictFields.load(payload_empty)

    assert restored_none.optional_dict is None
    assert restored_empty.optional_dict == {}
    assert restored_none.default_empty == {}
    assert restored_empty.default_empty == {}
775# Test inheritance hierarchies
@serializable_dataclass(
    on_typecheck_error=ErrorMode.EXCEPT,
    on_typecheck_mismatch=ErrorMode.EXCEPT,
)
class BaseClass(SerializableDataclass):
    """Root of the inheritance fixtures; both typecheck modes are set to ``ErrorMode.EXCEPT``."""

    base_field: str
    shared_field: int = 0
@serializable_dataclass
class ChildClass(BaseClass):
    """Extends ``BaseClass`` with ``child_field`` and redeclares ``shared_field``."""

    child_field: float
    shared_field: int = 1  # replaces the default of 0 declared on BaseClass
@serializable_dataclass
class GrandchildClass(ChildClass):
    """Third level of the hierarchy; adds a single boolean field."""

    grandchild_field: bool
def test_inheritance():
    """Fields from every level of the hierarchy serialize, and each class loads as itself."""
    original = GrandchildClass(
        base_field="base",
        shared_field=42,
        child_field=3.14,
        grandchild_field=True,
    )

    payload = original.serialize()
    assert payload["base_field"] == "base"
    assert payload["shared_field"] == 42
    assert payload["child_field"] == 3.14
    assert payload["grandchild_field"] is True

    assert GrandchildClass.load(payload) == original

    # loading through the parent yields the parent type, not a subclass
    as_base = BaseClass.load({"base_field": "test", "shared_field": 1})
    assert isinstance(as_base, BaseClass)
    assert not isinstance(as_base, ChildClass)
@pytest.mark.skip(
    reason="Not implemented yet, generic types not supported and throw a `TypeHintNotImplementedError`"
)
def test_generic_types():
    """(Skipped) A ``Generic[T]`` serializable dataclass should round-trip per type parameter."""

    T = TypeVar("T")

    @serializable_dataclass(on_typecheck_mismatch=ErrorMode.EXCEPT)
    class GenericContainer(SerializableDataclass, Generic[T]):
        """Container parameterized over its element type."""

        value: T
        values: List[T]

    # parameterized with int
    int_container = GenericContainer[int](value=42, values=[1, 2, 3])
    int_payload = int_container.serialize()
    int_restored = GenericContainer[int].load(int_payload)
    assert int_restored == int_container

    # parameterized with str
    str_container = GenericContainer[str](value="hello", values=["a", "b", "c"])
    str_payload = str_container.serialize()
    str_restored = GenericContainer[str].load(str_payload)
    assert str_restored == str_container
850# Test custom serialization/deserialization
class CustomObject:
    """Opaque value holder with value equality and no serialization methods of its own."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # only another CustomObject can be equal
        if not isinstance(other, CustomObject):
            return False
        return self.value == other.value
@serializable_dataclass
class CustomSerializationContainer(SerializableDataclass):
    """Both fields are converted by explicit serialization / loading functions."""

    # stored as the wrapped value, rewrapped in a CustomObject when loaded
    custom_obj: CustomObject = serializable_field(
        serialization_fn=lambda obj: obj.value,
        loading_fn=lambda data: CustomObject(data["custom_obj"]),
    )
    # doubled when serialized, halved when loaded
    transform_field: int = serializable_field(
        serialization_fn=lambda number: number * 2,
        loading_fn=lambda data: data["transform_field"] // 2,
    )
def test_custom_serialization_2():
    """Serialized values are the transformed ones; loading restores the originals."""
    original = CustomSerializationContainer(
        custom_obj=CustomObject(42),
        transform_field=10,
    )

    payload = original.serialize()
    assert payload["custom_obj"] == 42
    assert payload["transform_field"] == 20

    restored = CustomSerializationContainer.load(payload)
    assert restored == original
    assert restored.transform_field == 10
887# @pytest.mark.skip(reason="Not implemented yet, waiting on `custom_value_check_fn`")
888# def test_value_validation():
889# """Test field validation"""
890# @serializable_dataclass
891# class ValidationContainer(SerializableDataclass):
892# """Test validation and error handling"""
893# positive_int: int = serializable_field(
894# custom_value_check_fn=lambda x: x > 0
895# )
896# email: str = serializable_field(
897# custom_value_check_fn=lambda x: '@' in x
898# )
900# # Valid case
901# valid = ValidationContainer(positive_int=42, email="test@example.com")
902# assert valid.validate_fields_types()
904# # what will this do?
905# maybe_valid = ValidationContainer(positive_int=4.2, email="test@example.com")
906# assert maybe_valid.validate_fields_types()
908# maybe_valid_2 = ValidationContainer(positive_int=42, email=["test", "@", "example", ".com"])
909# assert maybe_valid_2.validate_fields_types()
911# # Invalid positive_int
912# with pytest.raises(ValueError):
913# ValidationContainer(positive_int=-1, email="test@example.com")
915# # Invalid email
916# with pytest.raises(ValueError):
917# ValidationContainer(positive_int=42, email="invalid")
def test_init_true_serialize_false():
    """Class creation raises ``ValueError`` for a class with a ``serialize=False, init=True`` field."""
    with pytest.raises(ValueError):

        @serializable_dataclass
        class MetadataContainer(SerializableDataclass):
            """Mixes several field option combinations, including the rejected one."""

            hidden: str = serializable_field(serialize=False, init=True)
            readonly: int = serializable_field(init=True, frozen=True)
            computed: float = serializable_field(init=False, serialize=True)

            def __post_init__(self):
                # assigned via object.__setattr__ rather than plain attribute assignment
                object.__setattr__(self, "computed", self.readonly * 2.0)
935# Test property serialization
@serializable_dataclass(properties_to_serialize=["full_name", "age_in_months"])
class PropertyContainer(SerializableDataclass):
    """Three stored fields and two derived properties, both listed for serialization."""

    first_name: str
    last_name: str
    age_years: int

    @property
    def full_name(self) -> str:
        """First and last name joined by a single space."""
        return f"{self.first_name} {self.last_name}"

    @property
    def age_in_months(self) -> int:
        """``age_years`` multiplied by twelve."""
        return self.age_years * 12
def test_property_serialization():
    """Both listed properties appear in the output, and the output still loads."""
    original = PropertyContainer(first_name="John", last_name="Doe", age_years=30)

    payload = original.serialize()
    assert payload["full_name"] == "John Doe"
    assert payload["age_in_months"] == 360

    restored = PropertyContainer.load(payload)
    assert restored == original
965# TODO: this would be nice to fix, but not a massive issue
@pytest.mark.skip(reason="Not implemented yet")
def test_edge_cases():
    """(Skipped) A serializable dataclass holding an instance of its own class."""

    @serializable_dataclass
    class EdgeCaseContainer(SerializableDataclass):
        """Every field has a default, so the class can be built with no arguments."""

        empty_list: List[Any] = serializable_field(default_factory=list)
        optional_value: Optional[int] = serializable_field(default=None)
        union_field: Union[str, int, None] = serializable_field(default=None)
        recursive_ref: Optional["EdgeCaseContainer"] = serializable_field(default=None)

    # one level of self-reference
    inner = EdgeCaseContainer()
    outer = EdgeCaseContainer(recursive_ref=inner)

    payload = outer.serialize()
    restored = EdgeCaseContainer.load(payload)
    assert restored == outer

    # defaults on a bare instance
    blank = EdgeCaseContainer()
    assert blank.empty_list == []
    assert blank.optional_value is None
    assert blank.union_field is None

    # union field holding a str
    outer.union_field = "string"
    payload = outer.serialize()
    restored = EdgeCaseContainer.load(payload)
    assert restored.union_field == "string"

    # union field holding an int
    outer.union_field = 42
    payload = outer.serialize()
    restored = EdgeCaseContainer.load(payload)
    assert restored.union_field == 42
1005# Test error handling for malformed data
def test_error_handling():
    """Loading malformed data raises; direct construction with wrong types only fails validation."""
    # no fields at all: the required `base_field` is missing
    with pytest.raises(TypeError):
        BaseClass.load({})

    # the constructor accepts wrongly typed values; validation then reports them
    mistyped = BaseClass(base_field=42, shared_field="invalid")
    assert not mistyped.validate_fields_types()

    # the same wrong types passed through `load` raise
    bad_payload = {
        "base_field": 42,  # Should be str
        "shared_field": "invalid",  # Should be int
    }
    with pytest.raises(FieldTypeMismatchError):
        BaseClass.load(bad_payload)

    # Invalid format string
    # with pytest.raises(ValueError):
    #     BaseClass.load({
    #         "__format__": "InvalidClass(SerializableDataclass)",
    #         "base_field": "test",
    #         "shared_field": 0
    #     })
1032# Test for memory leaks and cyclic references
1033# TODO: make .serialize() fail on cyclic references! see https://github.com/mivanit/muutils/issues/62
@pytest.mark.skip(reason="Not implemented yet")
def test_cyclic_references():
    """(Skipped) Serializing two nodes that point at each other should not recurse forever."""

    @serializable_dataclass
    class Node(SerializableDataclass):
        value: str
        next: Optional["Node"] = serializable_field(default=None)

    # build a two-node cycle: first -> second -> first
    first = Node("one")
    second = Node("two")
    first.next = second
    second.next = first

    # must terminate, and the first two hops must be recoverable
    payload = first.serialize()
    restored = Node.load(payload)
    assert restored.value == "one"
    assert restored.next.value == "two"