Coverage for tests/unit/maze_dataset/generation/test_maze_dataset.py: 100%
137 statements

import copy
from pathlib import Path

import numpy as np
import pytest
from zanj import ZANJ

from maze_dataset.constants import CoordArray
from maze_dataset.dataset.dataset import (
    register_dataset_filter,
    register_filter_namespace_for_dataset,
)
from maze_dataset.dataset.maze_dataset import (
    MazeDataset,
    MazeDatasetConfig,
    register_maze_filter,
    set_serialize_minimal_threshold,
)
from maze_dataset.generation.generators import GENERATORS_MAP
from maze_dataset.maze import SolvedMaze
from maze_dataset.utils import bool_array_from_string


class TestMazeDatasetConfig:
    pass


TEST_CONFIGS = [
    MazeDatasetConfig(
        name="test",
        grid_n=grid_n,
        n_mazes=n_mazes,
        maze_ctor=GENERATORS_MAP["gen_dfs"],
        maze_ctor_kwargs=maze_ctor_kwargs,
    )
    for grid_n, n_mazes, maze_ctor_kwargs in [
        (3, 5, {}),
        (3, 1, {}),
        (5, 5, dict(do_forks=False)),
    ]
]
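
# Illustrative note (assumption, not exercised at this point in the file):
# each MazeDatasetConfig serializes to a deterministic filename stem via
# `to_fname()`, which `test_save_read_minimal` below uses to build its save
# path, e.g.:
#
#   fname: str = TEST_CONFIGS[0].to_fname()
#
# The exact stem format is an implementation detail of MazeDatasetConfig.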


def test_generate_serial():
    dataset = MazeDataset.generate(TEST_CONFIGS[0], gen_parallel=False)

    assert len(dataset) == 5
    for maze in dataset:
        assert maze.grid_shape == (3, 3)


def test_generate_parallel():
    dataset = MazeDataset.generate(
        TEST_CONFIGS[0],
        gen_parallel=True,
        verbose=True,
        pool_kwargs=dict(processes=2),
    )

    assert len(dataset) == 5
    for maze in dataset:
        assert maze.grid_shape == (3, 3)
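
# Note (assumption, inferred from the call above rather than documented
# behavior): `pool_kwargs` appears to be forwarded to the worker-pool
# constructor (e.g. multiprocessing.Pool), so `processes=2` caps the number
# of generator workers, and `verbose=True` enables progress output.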


def test_data_hash_wip():
    dataset = MazeDataset.generate(TEST_CONFIGS[0])
    # TODO: dataset.data_hash doesn't work right now
    assert dataset


def test_download():
    with pytest.raises(NotImplementedError):
        MazeDataset.download(TEST_CONFIGS[0])


def test_serialize_load():
    dataset = MazeDataset.generate(TEST_CONFIGS[0])
    dataset_copy = MazeDataset.load(dataset.serialize())

    assert dataset.cfg == dataset_copy.cfg
    # strict=True so that a length mismatch between the datasets fails loudly
    for maze, maze_copy in zip(dataset, dataset_copy, strict=True):
        assert maze == maze_copy


@pytest.mark.parametrize(
    "config",
    [
        pytest.param(
            c,
            id=f"{c.grid_n=}; {c.n_mazes=}; {c.maze_ctor_kwargs=}",
        )
        for c in TEST_CONFIGS
    ],
)
def test_serialize_load_minimal(config):
    d = MazeDataset.generate(config, gen_parallel=False)
    assert MazeDataset.load(d._serialize_minimal()) == d


@pytest.mark.parametrize(
    "config",
    [
        pytest.param(
            c,
            id=f"{c.grid_n=}; {c.n_mazes=}; {c.maze_ctor_kwargs=}",
        )
        for c in TEST_CONFIGS
    ],
)
def test_save_read_minimal(config):
    def save_and_read(d: MazeDataset, p: Path):
        d.save(file_path=p)
        # read as MazeDataset
        roundtrip = MazeDataset.read(p)
        assert roundtrip == d
        # read from zanj
        z = ZANJ()
        roundtrip_zanj = z.read(p)
        assert roundtrip_zanj == d

    d = MazeDataset.generate(config, gen_parallel=False)
    p = Path("tests/_temp/test_maze_dataset/") / (d.cfg.to_fname() + ".zanj")

    # Test with full serialization
    set_serialize_minimal_threshold(None)
    save_and_read(d, p)

    # Test with minimal serialization
    set_serialize_minimal_threshold(0)
    save_and_read(d, p)

    # Repeat the minimal-serialization roundtrip with granular assertions,
    # so a failure pinpoints whether cfg or mazes diverged.
    d.save(file_path=p)
    # read as MazeDataset
    roundtrip = MazeDataset.read(p)
    assert d.cfg.diff(roundtrip.cfg) == {}
    cfg_diff = roundtrip.cfg.diff(d.cfg)
    assert cfg_diff == {}
    assert roundtrip.cfg == d.cfg
    assert roundtrip.mazes == d.mazes
    assert roundtrip == d
    # read from zanj
    z = ZANJ()
    roundtrip_zanj = z.read(p)
    assert roundtrip_zanj == d
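
# Hedged sketch of the threshold semantics exercised above (inferred from
# this test, not from documentation):
#
#   set_serialize_minimal_threshold(None)  # always use full serialization
#   set_serialize_minimal_threshold(0)     # always use minimal serialization
#   set_serialize_minimal_threshold(k)     # presumably minimal once n_mazes >= k
#
# Either way, the saved file must round-trip through both MazeDataset.read
# and ZANJ().read to an equal dataset.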


def test_custom_maze_filter():
    connection_list = bool_array_from_string(
        """
        F T
        F F

        T F
        T F
        """,
        shape=[2, 2, 2],
    )
    solutions = [
        [[0, 0], [0, 1], [1, 1]],
        [[0, 0], [0, 1]],
        [[0, 0]],
    ]

    def custom_filter_solution_length(maze: SolvedMaze, solution_length: int) -> bool:
        return len(maze.solution) == solution_length

    mazes = [
        SolvedMaze(connection_list=connection_list, solution=solution)
        for solution in solutions
    ]
    dataset = MazeDataset(cfg=TEST_CONFIGS[0], mazes=mazes)

    filtered_lambda = dataset.custom_maze_filter(lambda m: len(m.solution) == 1)
    filtered_func = dataset.custom_maze_filter(
        custom_filter_solution_length,
        solution_length=1,
    )

    assert filtered_lambda.mazes == filtered_func.mazes == [mazes[2]]
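
# Since `custom_maze_filter` returns a new MazeDataset (per the asserts above),
# filters should compose by chaining; a hypothetical sketch, assuming
# `filter_by.remove_duplicates` behaves as in the tests further down:
#
#   short = dataset.custom_maze_filter(lambda m: len(m.solution) <= 2)
#   short_unique = short.filter_by.remove_duplicates()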


class TestMazeDatasetFilters:
    config = MazeDatasetConfig(name="test", grid_n=3, n_mazes=5)
    connection_list = bool_array_from_string(
        """
        F T
        F F

        T F
        T F
        """,
        shape=[2, 2, 2],
    )

    def test_filters(self):
        class TestDataset(MazeDataset): ...

        @register_filter_namespace_for_dataset(TestDataset)
        class TestFilters:
            @register_maze_filter
            @staticmethod
            def solution_match(maze: SolvedMaze, solution: CoordArray) -> bool:
                """Test for solution equality"""
                return (maze.solution == solution).all()

            @register_dataset_filter
            @staticmethod
            def drop_nth(dataset: TestDataset, n: int) -> TestDataset:
                """Drop the n-th maze from the dataset"""
                return copy.deepcopy(
                    TestDataset(
                        dataset.cfg,
                        [maze for i, maze in enumerate(dataset) if i != n],
                    ),
                )

        maze1 = SolvedMaze(
            connection_list=self.connection_list,
            solution=np.array([[0, 0]]),
        )
        maze2 = SolvedMaze(
            connection_list=self.connection_list,
            solution=np.array([[0, 1]]),
        )

        dataset = TestDataset(self.config, [maze1, maze2])

        maze_filter = dataset.filter_by.solution_match(solution=np.array([[0, 0]]))
        maze_filter2 = dataset.filter_by.solution_match(np.array([[0, 0]]))

        dataset_filter = dataset.filter_by.drop_nth(n=0)
        dataset_filter2 = dataset.filter_by.drop_nth(0)

        assert maze_filter.mazes == maze_filter2.mazes == [maze1]
        assert dataset_filter.mazes == dataset_filter2.mazes == [maze2]
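
    # Hedged summary of the registration pattern above: `register_maze_filter`
    # lifts a per-maze predicate into a dataset-level filter,
    # `register_dataset_filter` registers a dataset -> dataset transform, and
    # `register_filter_namespace_for_dataset` exposes both on
    # `dataset.filter_by`, accepting positional or keyword arguments.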

    def test_path_length(self):
        long_maze = SolvedMaze(
            connection_list=self.connection_list,
            solution=np.array([[0, 0], [0, 1], [1, 1]]),
        )

        short_maze = SolvedMaze(
            connection_list=self.connection_list,
            solution=np.array([[0, 0], [0, 1]]),
        )

        dataset = MazeDataset(self.config, [long_maze, short_maze])
        path_length_filtered = dataset.filter_by.path_length(3)
        start_end_filtered = dataset.filter_by.start_end_distance(2)

        assert type(path_length_filtered) == type(dataset)  # noqa: E721
        assert path_length_filtered.mazes == [long_maze]
        assert start_end_filtered.mazes == [long_maze]
        assert dataset.mazes == [long_maze, short_maze]

    def test_cut_percentile_shortest(self):
        solutions = [
            [[0, 0], [0, 1], [1, 1]],
            [[0, 0], [0, 1]],
            [[0, 0]],
        ]

        mazes = [
            SolvedMaze(connection_list=self.connection_list, solution=solution)
            for solution in solutions
        ]
        dataset = MazeDataset(cfg=self.config, mazes=mazes)
        filtered = dataset.filter_by.cut_percentile_shortest(49.0)

        assert filtered.mazes == mazes[:2]
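
        # Hedged reading of the expectation: cut_percentile_shortest(49.0)
        # appears to drop mazes whose solution length falls in the shortest
        # 49th percentile (here, the single-cell solution), keeping the rest.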


DUPE_DATASET = [
    """
#####
#  E#
###X#
#SXX#
#####
""",
    """
#####
#SXE#
### #
#   #
#####
""",
    """
#####
#  E#
###X#
#SXX#
#####
""",
    """
#####
# # #
# # #
#EXS#
#####
""",
    """
#####
#SXX#
###X#
#EXX#
#####
""",
]
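
# Legend for the ASCII mazes above (assumption, inferred from
# `SolvedMaze.from_ascii` and the dedup expectations below): '#' is a wall,
# 'S' the start, 'E' the end, 'X' a cell on the solution path, and ' ' an
# open cell off the path. Mazes 0 and 2 are verbatim duplicates.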


def _helper_dataset_from_ascii(ascii_rep: list[str]) -> MazeDataset:
    # parse each ASCII maze into a SolvedMaze
    mazes: list[SolvedMaze] = [
        SolvedMaze.from_ascii(maze_ascii.strip()) for maze_ascii in ascii_rep
    ]

    return MazeDataset(
        MazeDatasetConfig(
            name="test",
            grid_n=mazes[0].grid_shape[0],
            n_mazes=len(mazes),
        ),
        mazes,
    )


def test_remove_duplicates():
    dataset: MazeDataset = _helper_dataset_from_ascii(DUPE_DATASET)
    dataset_deduped: MazeDataset = dataset.filter_by.remove_duplicates()

    assert len(dataset) == 5
    assert dataset_deduped.mazes == [dataset.mazes[3], dataset.mazes[4]]


def test_data_hash():
    dataset: MazeDataset = _helper_dataset_from_ascii(DUPE_DATASET)
    hash_1 = dataset.data_hash()
    hash_2 = dataset.data_hash()

    assert hash_1 == hash_2


def test_remove_duplicates_fast():
    dataset: MazeDataset = _helper_dataset_from_ascii(DUPE_DATASET)
    dataset_deduped: MazeDataset = dataset.filter_by.remove_duplicates_fast()

    assert len(dataset) == 5
    assert dataset_deduped.mazes == [
        dataset.mazes[0],
        dataset.mazes[1],
        dataset.mazes[3],
        dataset.mazes[4],
    ]
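
# Contrast implied by the two dedup tests above (hedged, inferred from the
# expected outputs rather than documentation): `remove_duplicates_fast`
# drops only exact duplicates (maze 2, a verbatim copy of maze 0), while
# `remove_duplicates` additionally drops mazes judged too similar to others,
# leaving only mazes 3 and 4.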