Coverage for tests/unit/maze_dataset/dataset/test_collected_dataset.py: 100%
45 statements
coverage.py v7.6.12, created at 2025-03-11 00:49 -0600
1from functools import cached_property
3import numpy as np
5from maze_dataset.dataset.collected_dataset import (
6 MazeDatasetCollection,
7 MazeDatasetCollectionConfig,
8 MazeDatasetConfig,
9)
# Number of mazes in each member dataset of the test collection.
# The 0 entry produces an empty member dataset (n_mazes=0).
DATASET_LENGTHS: list[int] = [
    1,
    0,
    3,
    2,
    1,
]
# grid_n of each member dataset; paired index-wise with DATASET_LENGTHS.
DATASET_GRID_SIZES: list[int] = [
    5,
    1,
    3,
    3,
    4,
]
class TestMazeDatasetCollection:
    """Tests for `MazeDatasetCollection` assembled from several small datasets.

    The member datasets are configured from the module-level
    ``DATASET_LENGTHS`` / ``DATASET_GRID_SIZES`` lists, which include a
    zero-maze member so that indexing across an empty dataset is exercised.
    """

    @cached_property
    def test_collection(self) -> MazeDatasetCollection:
        """Build the collection under test, one member dataset per config.

        Calls ``MazeDatasetCollection.from_config`` with ``do_generate=True``,
        ``load_local=False``, ``do_download=False`` and ``save_local=True``
        (base path ``data/``).

        NOTE(review): pytest normally creates a fresh test-class instance for
        every test method, so this ``cached_property`` most likely caches only
        within a single test, and the collection is regenerated per test.
        """
        config = MazeDatasetCollectionConfig(
            name="test_collection",
            maze_dataset_configs=[
                MazeDatasetConfig(
                    n_mazes=n_mazes,
                    grid_n=grid_n,
                    name=f"test_dataset_{n_mazes}_{grid_n}",
                )
                # strict=True: a length mismatch between the two constant
                # lists must fail loudly instead of silently dropping configs
                for n_mazes, grid_n in zip(
                    DATASET_LENGTHS,
                    DATASET_GRID_SIZES,
                    strict=True,
                )
            ],
        )
        return MazeDatasetCollection.from_config(
            config,
            do_generate=True,
            load_local=False,
            do_download=False,
            save_local=True,
            local_base_path="data/",
        )

    def test_dataset_lengths(self):
        # array_equal also compares shapes, so a length mismatch cannot be
        # masked by broadcasting (e.g. one element vs. many, or empty arrays)
        assert np.array_equal(
            np.asarray(self.test_collection.dataset_lengths),
            np.asarray(DATASET_LENGTHS),
        )

    def test_dataset_cum_lengths(self):
        # cumulative sums of DATASET_LENGTHS = [1, 0, 3, 2, 1]
        assert np.array_equal(
            np.asarray(self.test_collection.dataset_cum_lengths),
            np.array([1, 1, 4, 6, 7]),
        )

    def test_mazes(self):
        assert len(self.test_collection.mazes) == 7
        assert self.test_collection.mazes[0].connection_list.shape == (2, 5, 5)
        assert self.test_collection.mazes[-1].connection_list.shape == (2, 4, 4)

    def test_len(self):
        assert len(self.test_collection) == 7

    def test_getitem(self):
        # Expected grid_n for each global index. The grid_n=1 member dataset
        # is empty, so index 1 already falls into the first grid_n=3 dataset.
        expected_grid_sizes = [5, 3, 3, 3, 3, 3, 4]
        for idx, grid_n in enumerate(expected_grid_sizes):
            assert self.test_collection[idx].connection_list.shape == (
                2,
                grid_n,
                grid_n,
            )

        # indexing the collection must agree with its flattened `mazes` list
        for i in range(sum(DATASET_LENGTHS)):
            assert (
                self.test_collection[i].connection_list.shape
                == self.test_collection.mazes[i].connection_list.shape
            )
            assert np.array_equal(
                self.test_collection[i].connection_list,
                self.test_collection.mazes[i].connection_list,
            )

    def test_download(self):
        # TODO: test downloading after we implement downloading datasets
        pass

    def test_serialize_and_load(self):
        serialized = self.test_collection.serialize()
        loaded = MazeDatasetCollection.load(serialized)
        assert loaded.mazes == self.test_collection.mazes
        assert loaded.cfg.diff(self.test_collection.cfg) == {}
        assert loaded.cfg == self.test_collection.cfg

    def test_save_read(self):
        from pathlib import Path

        save_path = "tests/_temp/collected_dataset_test_save_read.zanj"
        # ensure the temp directory exists before writing into it
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        self.test_collection.save(save_path)
        loaded = MazeDatasetCollection.read(save_path)
        assert loaded.mazes == self.test_collection.mazes
        assert loaded.cfg.diff(self.test_collection.cfg) == {}
        assert loaded.cfg == self.test_collection.cfg