Coverage for tests\test_zanj_serializable_dataclass.py: 99% (128 statements)
coverage.py v7.6.12, created at 2025-02-14 12:57 -0700
from __future__ import annotations

import sys
import typing
from pathlib import Path

import numpy as np
import pandas as pd  # type: ignore[import]
import torch
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

from zanj import ZANJ

np.random.seed(0)

# pylint: disable=missing-function-docstring,missing-class-docstring

TEST_DATA_PATH: Path = Path("tests/junk_data")

SUPPORTS_KW_ONLY: bool = bool(sys.version_info >= (3, 10))
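# `dataclass(kw_only=...)` only exists on Python 3.10+, so this flag gates the
# keyword-only variant used by `sdc_complicated` further down.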


@serializable_dataclass
class BasicZanj(SerializableDataclass):
    a: str
    q: int = 42
    c: typing.List[int] = serializable_field(default_factory=list)


def test_Basic():
    instance = BasicZanj("hello", 42, [1, 2, 3])

    z = ZANJ()
    path = TEST_DATA_PATH / "test_BasicZanj.zanj"
    z.save(instance, path)
    recovered = z.read(path)
    assert instance == recovered
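
# NOTE: every test in this module repeats the same round-trip: build an instance,
# save it with ZANJ, read it back, and compare with `==`. A shared helper
# (hypothetical, not part of the original module) could look like:
#
#     def _roundtrip(instance: SerializableDataclass, fname: str) -> None:
#         path: Path = TEST_DATA_PATH / fname
#         ZANJ().save(instance, path)
#         assert instance == ZANJ().read(path)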


@serializable_dataclass
class Nested(SerializableDataclass):
    name: str
    basic: BasicZanj
    val: float


def test_Nested():
    instance = Nested("hello", BasicZanj("hello", 42, [1, 2, 3]), 3.14)

    z = ZANJ()
    path = TEST_DATA_PATH / "test_Nested.zanj"
    z.save(instance, path)
    recovered = z.read(path)
    assert instance == recovered


@serializable_dataclass
class Nested_with_container(SerializableDataclass):
    name: str
    basic: BasicZanj
    val: float
    container: typing.List[Nested] = serializable_field(
        default_factory=list,
        serialization_fn=lambda c: [n.serialize() for n in c],
        loading_fn=lambda data: [Nested.load(n) for n in data["container"]],
    )
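
# Note the asymmetry above: `serialization_fn` is given the field value itself
# (the list of Nested instances), while `loading_fn` is given the whole
# serialized dict, which is why it indexes `data["container"]` instead of
# taking the list directly.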


def test_Nested_with_container():
    instance = Nested_with_container(
        "hello",
        basic=BasicZanj("hello", 42, [1, 2, 3]),
        val=3.14,
        container=[
            Nested("n1", BasicZanj("n1_b", 123, [4, 5, 7]), 2.71),
            Nested("n2", BasicZanj("n2_b", 456, [7, 8, 9]), 6.28),
        ],
    )

    z = ZANJ()
    path = TEST_DATA_PATH / "test_Nested_with_container.zanj"
    z.save(instance, path)
    recovered = z.read(path)
    assert instance == recovered


@serializable_dataclass
class sdc_with_np_array(SerializableDataclass):
    name: str
    arr1: np.ndarray
    arr2: np.ndarray


def test_sdc_with_np_array_small():
    instance = sdc_with_np_array("small arrays", np.random.rand(10), np.random.rand(20))

    z = ZANJ()
    path = TEST_DATA_PATH / "test_sdc_with_np_array.zanj"
    z.save(instance, path)
    recovered = z.read(path)
    assert instance == recovered


def test_sdc_with_np_array():
    instance = sdc_with_np_array(
        "bigger arrays", np.random.rand(128, 128), np.random.rand(256, 256)
    )

    z = ZANJ()
    path = TEST_DATA_PATH / "test_sdc_with_np_array.zanj"
    z.save(instance, path)
    recovered = z.read(path)
    assert instance == recovered
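
# The small/bigger test pairs (here and for torch tensors below) presumably
# exercise both storage paths: ZANJ can keep small arrays inline in the JSON
# payload and store larger ones as external files inside the archive, with the
# exact cutoff being a ZANJ configuration detail.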


@serializable_dataclass
class sdc_with_torch_tensor(SerializableDataclass):
    name: str
    tensor1: torch.Tensor
    tensor2: torch.Tensor


def test_sdc_tensor_small():
    instance = sdc_with_torch_tensor("small tensors", torch.rand(8), torch.rand(16))

    z = ZANJ()
    path = TEST_DATA_PATH / "test_sdc_tensor_small.zanj"
    z.save(instance, path)
    recovered = z.read(path)
    assert instance == recovered


def test_sdc_tensor():
    instance = sdc_with_torch_tensor(
        "bigger tensors", torch.rand(128, 128), torch.rand(256, 256)
    )

    z = ZANJ()
    path = TEST_DATA_PATH / "test_sdc_tensor.zanj"
    z.save(instance, path)
    recovered = z.read(path)
    assert instance == recovered


@serializable_dataclass
class sdc_with_df(SerializableDataclass):
    name: str
    iris_data: pd.DataFrame
    brain_data: pd.DataFrame


def test_sdc_with_df():
    instance = sdc_with_df(
        "downloaded_data",
        iris_data=pd.read_csv("tests/input_data/iris.csv"),
        brain_data=pd.read_csv("tests/input_data/brain_networks.csv"),
    )

    z = ZANJ()
    path = TEST_DATA_PATH / "test_sdc_with_df.zanj"
    z.save(instance, path)
    recovered = z.read(path)
    assert instance == recovered
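
# The two CSVs under tests/input_data/ appear to be the standard seaborn sample
# datasets (iris and brain_networks), read locally rather than downloaded at
# test time.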


@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class sdc_complicated(SerializableDataclass):
    name: str
    arr1: np.ndarray
    arr2: np.ndarray
    iris_data: pd.DataFrame
    brain_data: pd.DataFrame
    container: typing.List[Nested]

    tensor: torch.Tensor

    def __eq__(self, value):
        return super().__eq__(value)
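
# The explicit `__eq__` above just defers to `SerializableDataclass.__eq__`;
# presumably this keeps the parent's field-aware comparison (which copes with
# numpy arrays and DataFrames) rather than a freshly generated dataclass
# `__eq__`.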


def test_sdc_complicated():
    instance = sdc_complicated(
        name="complicated data",
        arr1=np.random.rand(128, 128),
        arr2=np.random.rand(256, 256),
        iris_data=pd.read_csv("tests/input_data/iris.csv"),
        brain_data=pd.read_csv("tests/input_data/brain_networks.csv"),
        container=[
            Nested(
                f"n-{n}",
                BasicZanj(f"n-{n}_b", n * 10 + 1, [n + 1, n + 2, n + 10]),
                n * np.pi,
            )
            for n in range(10)
        ],
        tensor=torch.rand(512, 512),
    )

    z = ZANJ()
    path = TEST_DATA_PATH / "test_sdc_complicated.zanj"
    z.save(instance, path)
    recovered = z.read(path)
    assert instance == recovered


@serializable_dataclass
class sdc_container_explicit(SerializableDataclass):
    name: str
    container: typing.List[Nested] = serializable_field(
        default_factory=list,
        serialization_fn=lambda c: [n.serialize() for n in c],
        loading_fn=lambda data: [Nested.load(n) for n in data["container"]],
    )


def test_sdc_container_explicit():
    instance = sdc_container_explicit(
        "container explicit",
        container=[
            Nested(
                f"n-{n}",
                BasicZanj(f"n-{n}_b", n * 10 + 1, [n + 1, n + 2, n + 10]),
                n * np.pi,
            )
            for n in range(10)
        ],
    )

    z = ZANJ()
    path = TEST_DATA_PATH / "test_sdc_container_explicit.zanj"
    z.save(instance, path)
    recovered = z.read(path)
    assert instance == recovered
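
# These are ordinary pytest tests, e.g. `pytest tests/test_zanj_serializable_dataclass.py -v`;
# each one writes its archive under TEST_DATA_PATH ("tests/junk_data").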