# tests/unit/maze_dataset/generation/test_maze_dataset.py
import copy
from pathlib import Path

import numpy as np
import pytest
from pytest import mark, param
from zanj import ZANJ

from maze_dataset.constants import CoordArray
from maze_dataset.dataset.dataset import (
    register_dataset_filter,
    register_filter_namespace_for_dataset,
)
from maze_dataset.dataset.maze_dataset import (
    MazeDataset,
    MazeDatasetConfig,
    register_maze_filter,
    set_serialize_minimal_threshold,
)
from maze_dataset.generation.generators import GENERATORS_MAP
from maze_dataset.maze import SolvedMaze
from maze_dataset.utils import bool_array_from_string


class TestMazeDatasetConfig:
    pass


TEST_CONFIGS = [
    MazeDatasetConfig(
        name="test",
        grid_n=grid_n,
        n_mazes=n_mazes,
        maze_ctor=GENERATORS_MAP["gen_dfs"],
        maze_ctor_kwargs=maze_ctor_kwargs,
    )
    for grid_n, n_mazes, maze_ctor_kwargs in [
        (3, 5, {}),
        (3, 1, {}),
        (5, 5, dict(do_forks=False)),
    ]
]
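# one config per case: a 5-maze 3x3 dataset, a single-maze 3x3 dataset,
# and a 5x5 dataset generated without forks (single-path mazes)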


def test_generate_serial():
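    """serial generation from TEST_CONFIGS[0] should yield 5 mazes of grid shape (3, 3)"""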
    dataset = MazeDataset.generate(TEST_CONFIGS[0], gen_parallel=False)

    assert len(dataset) == 5
    for maze in dataset:
        assert maze.grid_shape == (3, 3)


def test_generate_parallel():
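    """parallel generation with a 2-process pool should also yield 5 mazes of grid shape (3, 3)"""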
    dataset = MazeDataset.generate(
        TEST_CONFIGS[0], gen_parallel=True, verbose=True, pool_kwargs=dict(processes=2)
    )

    assert len(dataset) == 5
    for maze in dataset:
        assert maze.grid_shape == (3, 3)


def test_data_hash_wip():
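    """generation smoke test; data_hash itself is not checked here (see TODO below)"""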
    dataset = MazeDataset.generate(TEST_CONFIGS[0])
    # TODO: dataset.data_hash doesn't work right now
    assert dataset


def test_download():
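    """MazeDataset.download is not implemented and should raise NotImplementedError"""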
    with pytest.raises(NotImplementedError):
        MazeDataset.download(TEST_CONFIGS[0])


def test_serialize_load():
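    """serialize/load round trip should preserve the config and every maze"""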
    dataset = MazeDataset.generate(TEST_CONFIGS[0])
    dataset_copy = MazeDataset.load(dataset.serialize())

    assert dataset.cfg == dataset_copy.cfg
    for maze, maze_copy in zip(dataset, dataset_copy):
        assert maze == maze_copy


@mark.parametrize(
    "config",
    [
        param(
            c,
            id=f"{c.grid_n=}; {c.n_mazes=}; {c.maze_ctor_kwargs=}",
        )
        for c in TEST_CONFIGS
    ],
)
def test_serialize_load_minimal(config):
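    """minimal serialization round trip should reconstruct an equal dataset"""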
    d = MazeDataset.generate(config, gen_parallel=False)
    assert MazeDataset.load(d._serialize_minimal()) == d


@mark.parametrize(
    "config",
    [
        param(
            c,
            id=f"{c.grid_n=}; {c.n_mazes=}; {c.maze_ctor_kwargs=}",
        )
        for c in TEST_CONFIGS
    ],
)
def test_save_read_minimal(config):
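    """save/read round trip should work under both full and minimal serialization"""
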
    def save_and_read(d: MazeDataset, p: Path):
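        """save the dataset to disk, then read it back via MazeDataset.read and via ZANJ"""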
        d.save(file_path=p)
        # read as MazeDataset
        roundtrip = MazeDataset.read(p)
        assert roundtrip == d
        # read from zanj
        z = ZANJ()
        roundtrip_zanj = z.read(p)
        assert roundtrip_zanj == d

    d = MazeDataset.generate(config, gen_parallel=False)
    p = Path("tests/_temp/test_maze_dataset/") / (d.cfg.to_fname() + ".zanj")

    # Test with full serialization
    set_serialize_minimal_threshold(None)
    save_and_read(d, p)

    # Test with minimal serialization
    set_serialize_minimal_threshold(0)
    save_and_read(d, p)
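
    # repeat the minimal-serialization round trip with finer-grained assertions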
    d.save(file_path=p)
    # read as MazeDataset
    roundtrip = MazeDataset.read(p)
    assert d.cfg.diff(roundtrip.cfg) == dict()
    cfg_diff = roundtrip.cfg.diff(d.cfg)
    assert cfg_diff == {}
    assert roundtrip.cfg == d.cfg
    assert roundtrip.mazes == d.mazes
    assert roundtrip == d
    # read from zanj
    z = ZANJ()
    roundtrip_zanj = z.read(p)
    assert roundtrip_zanj == d


def test_custom_maze_filter():
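    """custom_maze_filter should accept either a lambda or a named predicate with kwargs"""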
    connection_list = bool_array_from_string(
        """
        F T
        F F

        T F
        T F
        """,
        shape=[2, 2, 2],
    )
    solutions = [
        [[0, 0], [0, 1], [1, 1]],
        [[0, 0], [0, 1]],
        [[0, 0]],
    ]

    def custom_filter_solution_length(maze: SolvedMaze, solution_length: int) -> bool:
        return len(maze.solution) == solution_length

    mazes = [
        SolvedMaze(connection_list=connection_list, solution=solution)
        for solution in solutions
    ]
    dataset = MazeDataset(cfg=TEST_CONFIGS[0], mazes=mazes)

    filtered_lambda = dataset.custom_maze_filter(lambda m: len(m.solution) == 1)
    filtered_func = dataset.custom_maze_filter(
        custom_filter_solution_length, solution_length=1
    )

    assert filtered_lambda.mazes == filtered_func.mazes == [mazes[2]]


class TestMazeDatasetFilters:
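    """tests for maze-level and dataset-level filters on MazeDataset"""
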
    config = MazeDatasetConfig(name="test", grid_n=3, n_mazes=5)
    connection_list = bool_array_from_string(
        """
        F T
        F F

        T F
        T F
        """,
        shape=[2, 2, 2],
    )

    def test_filters(self):
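        """filters registered via the namespace decorators should be reachable through filter_by"""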
        class TestDataset(MazeDataset): ...

        @register_filter_namespace_for_dataset(TestDataset)
        class TestFilters:
            @register_maze_filter
            @staticmethod
            def solution_match(maze: SolvedMaze, solution: CoordArray) -> bool:
                """Test for solution equality"""
                return (maze.solution == solution).all()

            @register_dataset_filter
            @staticmethod
            def drop_nth(dataset: TestDataset, n: int) -> TestDataset:
                """Drop the maze at index n from the dataset"""
                return copy.deepcopy(
                    TestDataset(
                        dataset.cfg, [maze for i, maze in enumerate(dataset) if i != n]
                    )
                )

        maze1 = SolvedMaze(
            connection_list=self.connection_list, solution=np.array([[0, 0]])
        )
        maze2 = SolvedMaze(
            connection_list=self.connection_list, solution=np.array([[0, 1]])
        )

        dataset = TestDataset(self.config, [maze1, maze2])

        maze_filter = dataset.filter_by.solution_match(solution=np.array([[0, 0]]))
        maze_filter2 = dataset.filter_by.solution_match(np.array([[0, 0]]))

        dataset_filter = dataset.filter_by.drop_nth(n=0)
        dataset_filter2 = dataset.filter_by.drop_nth(0)

        assert maze_filter.mazes == maze_filter2.mazes == [maze1]
        assert dataset_filter.mazes == dataset_filter2.mazes == [maze2]

    def test_path_length(self):
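        """path_length and start_end_distance filters should keep only the longer maze"""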
        long_maze = SolvedMaze(
            connection_list=self.connection_list,
            solution=np.array([[0, 0], [0, 1], [1, 1]]),
        )

        short_maze = SolvedMaze(
            connection_list=self.connection_list, solution=np.array([[0, 0], [0, 1]])
        )

        dataset = MazeDataset(self.config, [long_maze, short_maze])
        path_length_filtered = dataset.filter_by.path_length(3)
        start_end_filtered = dataset.filter_by.start_end_distance(2)

        assert type(path_length_filtered) == type(dataset)  # noqa: E721
        assert path_length_filtered.mazes == [long_maze]
        assert start_end_filtered.mazes == [long_maze]
        assert dataset.mazes == [long_maze, short_maze]

    def test_cut_percentile_shortest(self):
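        """cut_percentile_shortest(49.0) should drop the shortest of the three mazes"""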
        solutions = [
            [[0, 0], [0, 1], [1, 1]],
            [[0, 0], [0, 1]],
            [[0, 0]],
        ]

        mazes = [
            SolvedMaze(connection_list=self.connection_list, solution=solution)
            for solution in solutions
        ]
        dataset = MazeDataset(cfg=self.config, mazes=mazes)
        filtered = dataset.filter_by.cut_percentile_shortest(49.0)

        assert filtered.mazes == mazes[:2]


DUPE_DATASET = [
    """
#####
#  E#
###X#
#SXX#
#####
""",
    """
#####
#SXE#
### #
#   #
#####
""",
    """
#####
#  E#
###X#
#SXX#
#####
""",
    """
#####
# # #
# # #
#EXS#
#####
""",
    """
#####
#SXX#
###X#
#EXX#
#####
""",
]
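# note: the first and third ASCII mazes above are identical. remove_duplicates_fast
# removes only exact duplicates, while remove_duplicates is stricter and here
# keeps only the last two mazes (see the tests below)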


def _helper_dataset_from_ascii(ascii_mazes: list[str]) -> MazeDataset:
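    """build a MazeDataset from a list of ASCII mazes, inferring grid_n from the first maze"""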
    mazes: list[SolvedMaze] = list()
    for maze in ascii_mazes:
        mazes.append(SolvedMaze.from_ascii(maze.strip()))

    return MazeDataset(
        MazeDatasetConfig(
            name="test", grid_n=mazes[0].grid_shape[0], n_mazes=len(mazes)
        ),
        mazes,
    )


def test_remove_duplicates():
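    """remove_duplicates should keep only the last two mazes of DUPE_DATASET"""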
    dataset: MazeDataset = _helper_dataset_from_ascii(DUPE_DATASET)
    dataset_deduped: MazeDataset = dataset.filter_by.remove_duplicates()

    assert len(dataset) == 5
    assert dataset_deduped.mazes == [dataset.mazes[3], dataset.mazes[4]]


def test_data_hash():
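    """data_hash should be deterministic for the same dataset"""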
    dataset: MazeDataset = _helper_dataset_from_ascii(DUPE_DATASET)
    hash_1 = dataset.data_hash()
    hash_2 = dataset.data_hash()

    assert hash_1 == hash_2


def test_remove_duplicates_fast():
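    """remove_duplicates_fast should drop only the exact duplicate (maze 2)"""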
    dataset: MazeDataset = _helper_dataset_from_ascii(DUPE_DATASET)
    dataset_deduped: MazeDataset = dataset.filter_by.remove_duplicates_fast()

    assert len(dataset) == 5
    assert dataset_deduped.mazes == [
        dataset.mazes[0],
        dataset.mazes[1],
        dataset.mazes[3],
        dataset.mazes[4],
    ]