Coverage for tests\test_bool_array.py: 100%
46 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-14 12:57 -0700
1from pathlib import Path
3import numpy as np
4import torch
5from muutils.json_serialize import SerializableDataclass, serializable_dataclass
7from zanj import ZANJ
# Directory where the round-trip tests below write their .zanj files.
# This is a relative path, so it resolves against the current working
# directory (presumably the repository root when pytest is run — confirm).
TEST_DATA_PATH: Path = Path("tests/junk_data")
@serializable_dataclass
class MyClass_list(SerializableDataclass):
    """Fixture dataclass whose array-like fields are plain Python lists.

    Used by ``test_list_bool_array`` to check that lists of ``bool`` values
    survive a ZANJ save/read round trip.
    """

    name: str
    arr_1: list  # set to [True, False, True] in test_list_bool_array
    arr_2: list  # set to the same contents as arr_1 in test_list_bool_array
def test_list_bool_array():
    """Round-trip a dataclass holding lists of bools through ZANJ.

    Besides dataclass equality, every element read back must still be a
    ``bool``. Because ``True == 1`` and ``False == 0`` in Python, a list
    that came back as ``[1, 0, 1]`` would still compare equal to the
    original, so equality alone cannot detect a bool -> int degradation.
    This mirrors the dtype checks done in the numpy and torch variants.
    """
    fname: Path = TEST_DATA_PATH / "test_list_bool_array.zanj"
    c: MyClass_list = MyClass_list(
        name="test",
        arr_1=[True, False, True],
        arr_2=[True, False, True],
    )

    z = ZANJ()

    z.save(c, fname)

    c2: MyClass_list = z.read(fname)

    # Element-type check: equality below would hide ints masquerading as bools.
    for field_name in ("arr_1", "arr_2"):
        loaded = getattr(c2, field_name)
        assert all(
            isinstance(x, bool) for x in loaded
        ), f"{field_name} did not round-trip as bools: {loaded!r}"

    assert c == c2
@serializable_dataclass
class MyClass_np(SerializableDataclass):
    """Fixture dataclass whose array fields are numpy arrays.

    Used by ``test_np_bool_array`` to check that boolean numpy arrays keep
    their ``np.bool_`` dtype across a ZANJ save/read round trip.
    """

    name: str
    arr_1: np.ndarray  # set to np.array([True, False, True]) in the test
    arr_2: np.ndarray  # set to the same contents as arr_1 in the test
def test_np_bool_array():
    """Save/read a dataclass of numpy bool arrays with ZANJ and compare.

    The dtype of each array read back must still be ``np.bool_``, and the
    loaded object must compare equal to the one that was saved.
    """
    out_path: Path = TEST_DATA_PATH / "test_np_bool_array.zanj"
    original: MyClass_np = MyClass_np(
        name="test",
        arr_1=np.array([True, False, True]),
        arr_2=np.array([True, False, True]),
    )

    zanj_handle = ZANJ()
    zanj_handle.save(original, out_path)
    loaded: MyClass_np = zanj_handle.read(out_path)

    # the boolean dtype itself must survive, not merely truthy values
    for attr in ("arr_1", "arr_2"):
        assert getattr(loaded, attr).dtype == np.bool_

    assert original == loaded
@serializable_dataclass
class MyClass_torch(SerializableDataclass):
    """Fixture dataclass whose array fields are torch tensors.

    Used by ``test_torch_bool_array`` to check that boolean tensors keep
    their ``torch.bool`` dtype across a ZANJ save/read round trip.
    """

    name: str
    arr_1: torch.Tensor  # set to torch.tensor([True, False, True]) in the test
    arr_2: torch.Tensor  # set to the same contents as arr_1 in the test
def test_torch_bool_array():
    """Check that torch bool tensors survive a ZANJ save/read cycle.

    Each tensor read back must still have dtype ``torch.bool``, and the
    reloaded dataclass must equal the one that was written.
    """
    out_path: Path = TEST_DATA_PATH / "test_torch_bool_array.zanj"
    original: MyClass_torch = MyClass_torch(
        name="test",
        arr_1=torch.tensor([True, False, True]),
        arr_2=torch.tensor([True, False, True]),
    )

    zanj_handle = ZANJ()
    zanj_handle.save(original, out_path)
    loaded: MyClass_torch = zanj_handle.read(out_path)

    # dtype must survive the round trip, not just the values
    for attr in ("arr_1", "arr_2"):
        assert getattr(loaded, attr).dtype == torch.bool

    assert original == loaded