Coverage for tests\unit\maze_dataset\generation\test_latticemaze.py: 100%
104 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
1import numpy as np
2import pytest
4from maze_dataset.constants import CoordArray
5from maze_dataset.generation.default_generators import DEFAULT_GENERATORS
6from maze_dataset.generation.generators import GENERATORS_MAP
7from maze_dataset.maze import LatticeMaze, PixelColors, SolvedMaze, TargetedLatticeMaze
8from maze_dataset.utils import adj_list_to_nested_set, bool_array_from_string
# Fixtures and tests for the LatticeMaze._from_pixel_grid_* helpers
# (first drafts of these tests were written with help from GPT).
@pytest.fixture
def example_pixel_grid():
    """Black-and-white 5x5 pixel grid for a 2x2 lattice maze.

    In the layout below `#` is a wall pixel and `.` is an open pixel; the
    returned boolean array is True exactly on the open pixels.
    """
    layout = [
        "#####",
        "#...#",
        "###.#",
        "#...#",
        "#####",
    ]
    return np.array([[ch == "." for ch in row] for row in layout], dtype=bool)
@pytest.fixture
def example_rgb_pixel_grid():
    """RGB 5x5 pixel grid (uint8) for a 2x2 maze containing only WALL and OPEN pixels.

    Layout legend: `W` -> PixelColors.WALL, `O` -> PixelColors.OPEN.
    No START, END or PATH pixels are present.
    """
    palette = {"W": PixelColors.WALL, "O": PixelColors.OPEN}
    layout = (
        "WWWWW",
        "WOOOW",
        "WWWWW",
        "WOWOW",
        "WWWWW",
    )
    return np.array(
        [[palette[ch] for ch in row] for row in layout],
        dtype=np.uint8,
    )
def test_from_pixel_grid_bw(example_pixel_grid):
    """A black-and-white pixel grid decodes into the expected 2x2 connection list."""
    conn, shape = LatticeMaze._from_pixel_grid_bw(example_pixel_grid)

    assert isinstance(conn, np.ndarray)
    assert conn.shape == (2, 2, 2)
    # both layers at once; shape was already pinned above, so this is an
    # exact element-wise comparison
    expected_conn = np.array(
        [
            [[False, True], [False, False]],
            [[True, False], [True, False]],
        ]
    )
    assert np.array_equal(conn, expected_conn)
    assert shape == (2, 2)
def test_from_pixel_grid_with_positions(example_rgb_pixel_grid):
    """Decoding an RGB grid with no marker pixels gives connections plus empty position arrays."""
    marker_colors = {
        "start": PixelColors.START,
        "end": PixelColors.END,
        "path": PixelColors.PATH,
    }

    conn, shape, positions = LatticeMaze._from_pixel_grid_with_positions(
        example_rgb_pixel_grid, marker_colors
    )

    assert isinstance(conn, np.ndarray)
    assert conn.shape == (2, 2, 2)
    expected_conn = np.array(
        [
            [[False, False], [False, False]],
            [[True, False], [False, False]],
        ]
    )
    assert np.array_equal(conn, expected_conn)
    assert shape == (2, 2)

    # exactly one entry per requested marker; each is an empty array because
    # the fixture grid contains only WALL and OPEN pixels
    assert isinstance(positions, dict)
    assert set(positions) == set(marker_colors)
    for key in marker_colors:
        assert isinstance(positions[key], np.ndarray)
        assert positions[key].shape == (0,)
def test_find_start_end_points_in_rgb_pixel_grid():
    """START/END pixels in an RGB grid are reported as maze-cell coordinates.

    Coordinates are compared with `np.array_equal` rather than
    `np.all(a == b)`: if the decoder returned an empty (0, 2) array,
    broadcasting would give an empty comparison and `np.all` would be
    vacuously True, hiding a missing START/END marker.
    """
    wall, open_ = PixelColors.WALL, PixelColors.OPEN
    start, end = PixelColors.START, PixelColors.END
    rgb_pixel_grid_with_positions = np.array(
        [
            [wall, wall, wall, wall, wall],
            [wall, start, open_, end, wall],
            [wall, wall, wall, wall, wall],
            [wall, open_, wall, open_, wall],
            [wall, wall, wall, wall, wall],
        ],
        dtype=np.uint8,
    )

    marked_positions = {
        "start": PixelColors.START,
        "end": PixelColors.END,
        "path": PixelColors.PATH,
    }

    # only the decoded marker positions are under test here
    _, _, out_positions = LatticeMaze._from_pixel_grid_with_positions(
        rgb_pixel_grid_with_positions, marked_positions
    )

    assert isinstance(out_positions, dict)
    assert set(out_positions) == set(marked_positions)
    for key in marked_positions:
        assert isinstance(out_positions[key], np.ndarray), key

    # START pixel (1, 1) -> cell (0, 0); END pixel (1, 3) -> cell (0, 1)
    assert np.array_equal(out_positions["start"], np.array([[0, 0]]))
    assert np.array_equal(out_positions["end"], np.array([[0, 1]]))
    assert out_positions["path"].shape == (0,)
@pytest.mark.parametrize("gfunc_name, kwargs", DEFAULT_GENERATORS)
def test_pixels_ascii_roundtrip(gfunc_name, kwargs):
    """Each default generator yields a maze that survives pixel and ascii round-trips."""
    n: int = 5
    side: int = 2 * n + 1
    generate = GENERATORS_MAP[gfunc_name]
    maze: LatticeMaze = generate(np.array([n, n]), **kwargs)

    pixels: np.ndarray = maze.as_pixels()
    ascii_art: str = maze.as_ascii()

    assert maze == LatticeMaze.from_pixels(pixels)
    assert maze == LatticeMaze.from_ascii(ascii_art)

    # an n x n maze renders as a (2n + 1)-wide grid: walls between and around cells
    expected_shape: tuple = (side, side, 3)
    assert pixels.shape == expected_shape, f"{pixels.shape} != {expected_shape}"
    line_lengths_ok = all(len(row) == side for row in ascii_art.splitlines())
    assert line_lengths_ok, f"{ascii_art}"
def _assert_roundtrip_and_size(maze, maze_cls, n: int) -> None:
    """Assert `maze` survives pixel and ascii round-trips through `maze_cls`.

    Also checks that both renderings of an n x n maze have side 2n + 1.
    """
    pixels: np.ndarray = maze.as_pixels()
    ascii_art: str = maze.as_ascii()

    assert maze == maze_cls.from_pixels(pixels)
    assert maze == maze_cls.from_ascii(ascii_art)

    side: int = n * 2 + 1
    expected_shape: tuple = (side, side, 3)
    assert pixels.shape == expected_shape, f"{pixels.shape} != {expected_shape}"
    assert all(side == len(line) for line in ascii_art.splitlines()), f"{ascii_art}"


@pytest.mark.parametrize("gfunc_name, kwargs", DEFAULT_GENERATORS)
def test_targeted_solved_maze(gfunc_name, kwargs):
    """Targeted and solved mazes built from each generator round-trip via pixels and ascii."""
    n: int = 5
    maze_gen_func = GENERATORS_MAP[gfunc_name]
    maze: LatticeMaze = maze_gen_func(np.array([n, n]), **kwargs)
    solution: CoordArray = maze.generate_random_path()

    # the endpoints of a random path become the maze's start and end targets
    tgt_maze: TargetedLatticeMaze = TargetedLatticeMaze.from_lattice_maze(
        maze,
        solution[0],
        solution[-1],
    )
    _assert_roundtrip_and_size(tgt_maze, TargetedLatticeMaze, n)

    solved_maze: SolvedMaze = SolvedMaze.from_targeted_lattice_maze(tgt_maze)
    # checks the solved maze's OWN pixel rendering (the shape check used to
    # re-test the targeted maze's pixels by mistake)
    _assert_roundtrip_and_size(solved_maze, SolvedMaze, n)
def test_as_adj_list():
    """`as_adj_list` returns the edges encoded by a hand-written 2x2 connection list."""
    maze = LatticeMaze(
        connection_list=bool_array_from_string(
            """
            F T
            F F

            T F
            T F
            """,
            shape=[2, 2, 2],
        )
    )

    edges = maze.as_adj_list(shuffle_d0=False, shuffle_d1=False)

    # compared as nested sets so edge order and endpoint order are irrelevant
    expected_edges = [
        [[0, 1], [1, 1]],
        [[0, 0], [0, 1]],
        [[1, 0], [1, 1]],
    ]
    assert adj_list_to_nested_set(expected_edges) == adj_list_to_nested_set(edges)
@pytest.mark.parametrize("gfunc_name, kwargs", DEFAULT_GENERATORS)
def test_get_nodes(gfunc_name, kwargs):
    """`get_nodes` lists every cell of a 3x2 maze in row-major order."""
    n_rows, n_cols = 3, 2
    maze = GENERATORS_MAP[gfunc_name](np.array((n_rows, n_cols)), **kwargs)
    expected_nodes = [[r, c] for r in range(n_rows) for c in range(n_cols)]
    assert maze.get_nodes().tolist() == expected_nodes
@pytest.mark.parametrize("gfunc_name, kwargs", DEFAULT_GENERATORS)
def test_generate_random_path(gfunc_name, kwargs):
    """A random path through a 2x2 maze contains more than one node."""
    maze = GENERATORS_MAP[gfunc_name](np.array((2, 2)), **kwargs)
    path_length = len(maze.generate_random_path())
    # more than one node means start and end are separate entries in the path
    assert path_length > 1
@pytest.mark.parametrize("gfunc_name, kwargs", DEFAULT_GENERATORS)
def test_generate_random_path_size_1(gfunc_name, kwargs):
    """Requesting a random path in a 1x1 maze raises AssertionError."""
    single_cell_maze = GENERATORS_MAP[gfunc_name](np.array((1, 1)), **kwargs)
    with pytest.raises(AssertionError):
        single_cell_maze.generate_random_path()