Coverage for tests\test_zanj_serializable_dataclass.py: 99%

128 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-14 12:57 -0700

1from __future__ import annotations 

2 

3import sys 

4import typing 

5from pathlib import Path 

6 

7import numpy as np 

8import pandas as pd # type: ignore[import] 

9import torch 

10from muutils.json_serialize import ( 

11 SerializableDataclass, 

12 serializable_dataclass, 

13 serializable_field, 

14) 

15 

16from zanj import ZANJ 

17 

18np.random.seed(0) 

19 

20# pylint: disable=missing-function-docstring,missing-class-docstring 

21 

22TEST_DATA_PATH: Path = Path("tests/junk_data") 

23 

24SUPPORTS_KW_ONLY: bool = bool(sys.version_info >= (3, 10)) 

25 

26 

27@serializable_dataclass 

28class BasicZanj(SerializableDataclass): 

29 a: str 

30 q: int = 42 

31 c: typing.List[int] = serializable_field(default_factory=list) 

32 

33 

34def test_Basic(): 

35 instance = BasicZanj("hello", 42, [1, 2, 3]) 

36 

37 z = ZANJ() 

38 path = TEST_DATA_PATH / "test_BasicZanj.zanj" 

39 z.save(instance, path) 

40 recovered = z.read(path) 

41 assert instance == recovered 

42 

43 

44@serializable_dataclass 

45class Nested(SerializableDataclass): 

46 name: str 

47 basic: BasicZanj 

48 val: float 

49 

50 

51def test_Nested(): 

52 instance = Nested("hello", BasicZanj("hello", 42, [1, 2, 3]), 3.14) 

53 

54 z = ZANJ() 

55 path = TEST_DATA_PATH / "test_Nested.zanj" 

56 z.save(instance, path) 

57 recovered = z.read(path) 

58 assert instance == recovered 

59 

60 

61@serializable_dataclass 

62class Nested_with_container(SerializableDataclass): 

63 name: str 

64 basic: BasicZanj 

65 val: float 

66 container: typing.List[Nested] = serializable_field( 

67 default_factory=list, 

68 serialization_fn=lambda c: [n.serialize() for n in c], 

69 loading_fn=lambda data: [Nested.load(n) for n in data["container"]], 

70 ) 

71 

72 

73def test_Nested_with_container(): 

74 instance = Nested_with_container( 

75 "hello", 

76 basic=BasicZanj("hello", 42, [1, 2, 3]), 

77 val=3.14, 

78 container=[ 

79 Nested("n1", BasicZanj("n1_b", 123, [4, 5, 7]), 2.71), 

80 Nested("n2", BasicZanj("n2_b", 456, [7, 8, 9]), 6.28), 

81 ], 

82 ) 

83 

84 z = ZANJ() 

85 path = TEST_DATA_PATH / "test_Nested_with_container.zanj" 

86 z.save(instance, path) 

87 recovered = z.read(path) 

88 assert instance == recovered 

89 

90 

91@serializable_dataclass 

92class sdc_with_np_array(SerializableDataclass): 

93 name: str 

94 arr1: np.ndarray 

95 arr2: np.ndarray 

96 

97 

98def test_sdc_with_np_array_small(): 

99 instance = sdc_with_np_array("small arrays", np.random.rand(10), np.random.rand(20)) 

100 

101 z = ZANJ() 

102 path = TEST_DATA_PATH / "test_sdc_with_np_array.zanj" 

103 z.save(instance, path) 

104 recovered = z.read(path) 

105 assert instance == recovered 

106 

107 

108def test_sdc_with_np_array(): 

109 instance = sdc_with_np_array( 

110 "bigger arrays", np.random.rand(128, 128), np.random.rand(256, 256) 

111 ) 

112 

113 z = ZANJ() 

114 path = TEST_DATA_PATH / "test_sdc_with_np_array.zanj" 

115 z.save(instance, path) 

116 recovered = z.read(path) 

117 assert instance == recovered 

118 

119 

120@serializable_dataclass 

121class sdc_with_torch_tensor(SerializableDataclass): 

122 name: str 

123 tensor1: torch.Tensor 

124 tensor2: torch.Tensor 

125 

126 

127def test_sdc_tensor_small(): 

128 instance = sdc_with_torch_tensor("small tensors", torch.rand(8), torch.rand(16)) 

129 

130 z = ZANJ() 

131 path = TEST_DATA_PATH / "test_sdc_tensor_small.zanj" 

132 z.save(instance, path) 

133 recovered = z.read(path) 

134 assert instance == recovered 

135 

136 

137def test_sdc_tensor(): 

138 instance = sdc_with_torch_tensor( 

139 "bigger tensors", torch.rand(128, 128), torch.rand(256, 256) 

140 ) 

141 

142 z = ZANJ() 

143 path = TEST_DATA_PATH / "test_sdc_tensor.zanj" 

144 z.save(instance, path) 

145 recovered = z.read(path) 

146 assert instance == recovered 

147 

148 

149@serializable_dataclass 

150class sdc_with_df(SerializableDataclass): 

151 name: str 

152 iris_data: pd.DataFrame 

153 brain_data: pd.DataFrame 

154 

155 

156def test_sdc_with_df(): 

157 instance = sdc_with_df( 

158 "downloaded_data", 

159 iris_data=pd.read_csv("tests/input_data/iris.csv"), 

160 brain_data=pd.read_csv("tests/input_data/brain_networks.csv"), 

161 ) 

162 

163 z = ZANJ() 

164 path = TEST_DATA_PATH / "test_sdc_with_df.zanj" 

165 z.save(instance, path) 

166 recovered = z.read(path) 

167 assert instance == recovered 

168 

169 

170@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY) 

171class sdc_complicated(SerializableDataclass): 

172 name: str 

173 arr1: np.ndarray 

174 arr2: np.ndarray 

175 iris_data: pd.DataFrame 

176 brain_data: pd.DataFrame 

177 container: typing.List[Nested] 

178 

179 tensor: torch.Tensor 

180 

181 def __eq__(self, value): 

182 return super().__eq__(value) 

183 

184 

185def test_sdc_complicated(): 

186 instance = sdc_complicated( 

187 name="complicated data", 

188 arr1=np.random.rand(128, 128), 

189 arr2=np.random.rand(256, 256), 

190 iris_data=pd.read_csv("tests/input_data/iris.csv"), 

191 brain_data=pd.read_csv("tests/input_data/brain_networks.csv"), 

192 container=[ 

193 Nested( 

194 f"n-{n}", 

195 BasicZanj(f"n-{n}_b", n * 10 + 1, [n + 1, n + 2, n + 10]), 

196 n * np.pi, 

197 ) 

198 for n in range(10) 

199 ], 

200 tensor=torch.rand(512, 512), 

201 ) 

202 

203 z = ZANJ() 

204 path = TEST_DATA_PATH / "test_sdc_complicated.zanj" 

205 z.save(instance, path) 

206 recovered = z.read(path) 

207 assert instance == recovered 

208 

209 

210@serializable_dataclass 

211class sdc_container_explicit(SerializableDataclass): 

212 name: str 

213 container: typing.List[Nested] = serializable_field( 

214 default_factory=list, 

215 serialization_fn=lambda c: [n.serialize() for n in c], 

216 loading_fn=lambda data: [Nested.load(n) for n in data["container"]], 

217 ) 

218 

219 

220def test_sdc_container_explicit(): 

221 instance = sdc_container_explicit( 

222 "container explicit", 

223 container=[ 

224 Nested( 

225 f"n-{n}", 

226 BasicZanj(f"n-{n}_b", n * 10 + 1, [n + 1, n + 2, n + 10]), 

227 n * np.pi, 

228 ) 

229 for n in range(10) 

230 ], 

231 ) 

232 

233 z = ZANJ() 

234 path = TEST_DATA_PATH / "test_sdc_container_explicit.zanj" 

235 z.save(instance, path) 

236 recovered = z.read(path) 

237 assert instance == recovered