Coverage for tests\unit\maze_dataset\dataset\test_collected_dataset.py: 100%
45 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
1from functools import cached_property
3import numpy as np
5from maze_dataset.dataset.collected_dataset import (
6 MazeDatasetCollection,
7 MazeDatasetCollectionConfig,
8 MazeDatasetConfig,
9)
# One (n_mazes, grid_n) pair per member dataset of the collection under test.
# Note that the second member is empty (0 mazes).
_DATASET_SPECS: list[tuple[int, int]] = [
    (1, 5),
    (0, 1),
    (3, 3),
    (2, 3),
    (1, 4),
]
DATASET_LENGTHS: list[int] = [n_mazes for n_mazes, _ in _DATASET_SPECS]
DATASET_GRID_SIZES: list[int] = [grid_n for _, grid_n in _DATASET_SPECS]
class TestMazeDatasetCollection:
    """Tests for `MazeDatasetCollection` built from several small datasets.

    The collection has one member dataset per entry of `DATASET_LENGTHS` /
    `DATASET_GRID_SIZES`, including an empty (0-maze) member.
    """

    @cached_property
    def test_collection(self) -> MazeDatasetCollection:
        """Build the collection under test, cached on this test-class instance.

        pytest creates a fresh instance of the test class for every test
        method, so this cache only avoids rebuilding within a single test.
        """
        config = MazeDatasetCollectionConfig(
            name="test_collection",
            maze_dataset_configs=[
                MazeDatasetConfig(
                    n_mazes=n_mazes,
                    grid_n=grid_n,
                    name=f"test_dataset_{n_mazes}_{grid_n}",
                )
                for n_mazes, grid_n in zip(DATASET_LENGTHS, DATASET_GRID_SIZES)
            ],
        )
        # Generate fresh data rather than loading or downloading it, and save
        # the result under "data/".
        return MazeDatasetCollection.from_config(
            config,
            do_generate=True,
            load_local=False,
            do_download=False,
            save_local=True,
            local_base_path="data/",
        )

    def test_dataset_lengths(self):
        """Per-member lengths match the configured `DATASET_LENGTHS`."""
        # array_equal also requires equal shapes, unlike `np.all(a == b)`,
        # which can broadcast or raise on a length mismatch.
        assert np.array_equal(self.test_collection.dataset_lengths, DATASET_LENGTHS)

    def test_dataset_cum_lengths(self):
        """Cumulative lengths are the running sum of `DATASET_LENGTHS`."""
        expected = np.cumsum(DATASET_LENGTHS)  # [1, 1, 4, 6, 7]
        assert np.array_equal(self.test_collection.dataset_cum_lengths, expected)

    def test_mazes(self):
        """`mazes` concatenates every member dataset in order."""
        mazes = self.test_collection.mazes
        assert len(mazes) == sum(DATASET_LENGTHS)
        # First maze comes from the 5x5 dataset, last from the 4x4 dataset.
        assert mazes[0].connection_list.shape == (2, 5, 5)
        assert mazes[-1].connection_list.shape == (2, 4, 4)

    def test_len(self):
        """`len()` of the collection equals the total number of mazes."""
        assert len(self.test_collection) == sum(DATASET_LENGTHS)

    def test_getitem(self):
        """Flat indexing walks across member datasets, skipping empty ones."""
        collection = self.test_collection
        # Expand each member's grid size once per maze it contributes;
        # for the module constants this is [5, 3, 3, 3, 3, 3, 4].
        expected_grid_ns = [
            grid_n
            for n_mazes, grid_n in zip(DATASET_LENGTHS, DATASET_GRID_SIZES)
            for _ in range(n_mazes)
        ]
        assert len(expected_grid_ns) == len(collection)

        for i, grid_n in enumerate(expected_grid_ns):
            indexed = collection[i].connection_list
            listed = collection.mazes[i].connection_list
            assert indexed.shape == (2, grid_n, grid_n)
            # `collection[i]` must agree with `collection.mazes[i]`.
            assert np.array_equal(indexed, listed)

    def test_download(self):
        """Placeholder: downloading collections is not tested yet."""
        import unittest

        # TODO: implement. Skip explicitly rather than reporting a vacuous pass.
        raise unittest.SkipTest("test_download is not implemented yet")

    def test_serialize_and_load(self):
        """Round-trip through `serialize()` / `load()` preserves mazes and config."""
        serialized = self.test_collection.serialize()
        loaded = MazeDatasetCollection.load(serialized)
        assert loaded.mazes == self.test_collection.mazes
        assert loaded.cfg.diff(self.test_collection.cfg) == {}
        assert loaded.cfg == self.test_collection.cfg

    def test_save_read(self):
        """Round-trip through a file via `save()` / `read()`."""
        from pathlib import Path

        fname = "tests/_temp/collected_dataset_test_save_read.zanj"
        # Make sure the scratch directory exists before writing into it.
        Path(fname).parent.mkdir(parents=True, exist_ok=True)

        self.test_collection.save(fname)
        loaded = MazeDatasetCollection.read(fname)
        assert loaded.mazes == self.test_collection.mazes
        assert loaded.cfg.diff(self.test_collection.cfg) == {}
        assert loaded.cfg == self.test_collection.cfg