Coverage for maze_dataset\generation\generators.py: 78%
129 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
1"""generation functions have signature `(grid_shape: Coord, **kwargs) -> LatticeMaze` and are methods in `LatticeMazeGenerators`"""
3import random
4import warnings
5from typing import Any, Callable
7import numpy as np
8from jaxtyping import Bool
9from muutils.mlutils import GLOBAL_SEED
11from maze_dataset.constants import CoordArray
12from maze_dataset.maze import ConnectionList, Coord, LatticeMaze, SolvedMaze
13from maze_dataset.maze.lattice_maze import NEIGHBORS_MASK, _fill_edges_with_walls
# Import-time seeding, both from the same `GLOBAL_SEED`:
# - the stdlib `random` module (used via `random.randint` / `random.choice` below)
# - a dedicated numpy `Generator`, exposed as `numpy_rng`
# NOTE(review): the generators visible in this file draw from the legacy global
# `np.random.*` functions, which are not seeded here -- confirm where (or whether)
# that global state is seeded.
random.seed(GLOBAL_SEED)
numpy_rng = np.random.default_rng(GLOBAL_SEED)
def _random_start_coord(grid_shape: Coord, start_coord: Coord | None) -> Coord:
    """return `start_coord` as an array, or draw a random coord inside `grid_shape` if it is `None`

    # Arguments
     - `grid_shape: Coord`: size of the grid along each axis. any array-like
       (numpy array, tuple, list) is accepted
     - `start_coord: Coord | None`: explicit start coordinate, or `None` to sample one

    # Returns
     - `Coord`: numpy array with one entry per axis of `grid_shape`
    """
    if start_coord is not None:
        return np.array(start_coord)

    # `grid_shape - 1` needs array arithmetic; a plain tuple/list would raise `TypeError`
    shape_arr = np.asarray(grid_shape)

    # NOTE: `np.random.randint` excludes its upper bound, so along an axis of size
    # `n > 1` the draw comes from `[0, n - 2]` and the last index is never chosen.
    # This is kept unchanged on purpose so that seeded generation stays reproducible.
    return np.random.randint(
        0,  # lower bound
        np.maximum(shape_arr - 1, 1),  # upper bound (at least 1)
        size=len(shape_arr),  # dimensionality
    )
def get_neighbors_in_bounds(
    coord: Coord,
    grid_shape: Coord,
) -> CoordArray:
    """return the neighbors of `coord` that lie inside a grid of shape `grid_shape`

    candidates are `coord + NEIGHBORS_MASK`; a candidate is kept only when every
    component is non-negative and strictly below the matching entry of `grid_shape`
    """
    candidates: CoordArray = coord + NEIGHBORS_MASK

    # one boolean per candidate row for each side of the bounds check
    not_below_zero = (candidates >= 0).all(axis=1)
    below_shape = (candidates < grid_shape).all(axis=1)

    return candidates[not_below_zero & below_shape]
class LatticeMazeGenerators:
    """namespace for lattice maze generation algorithms

    every generator is a static method taking `grid_shape` as its first argument,
    plus generator-specific keyword arguments, and returning a `LatticeMaze`
    """

    @staticmethod
    def gen_dfs(
        grid_shape: Coord,
        lattice_dim: int = 2,
        accessible_cells: int | float | None = None,
        max_tree_depth: int | float | None = None,
        do_forks: bool = True,
        randomized_stack: bool = False,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using depth first search, iterative

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `lattice_dim: int`: the dimension of the lattice
          (default: `2`)
        - `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**
          (default: `None`)
        - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_total_cells` (twice the total number of cells in the grid). if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**
          (default: `None`)
        - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
          (default: `True`)
        - `randomized_stack: bool`: if `True`, pop a uniformly random element of the stack instead of the most recently pushed one (this is how `gen_prim` calls this function)
          (default: `False`)
        - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

        # Returns
        - `LatticeMaze` whose `generation_meta` holds `func_name`, `grid_shape`, `start_coord`,
          `n_accessible_cells`, `max_tree_depth`, `fully_connected` and `visited_cells`

        # algorithm
        1. Choose the initial cell, mark it as visited and push it to the stack
        2. While the stack is not empty
            1. Pop a cell from the stack and make it a current cell
            2. If the current cell has any neighbours which have not been visited
                1. Push the current cell to the stack
                2. Choose one of the unvisited neighbours
                3. Remove the wall between the current cell and the chosen cell
                4. Mark the chosen cell as visited and push it to the stack
        """
        # Default values if no constraints have been passed
        grid_shape: Coord = np.array(grid_shape)
        n_total_cells: int = int(np.prod(grid_shape))

        n_accessible_cells: int
        if accessible_cells is None:
            n_accessible_cells = n_total_cells
        elif isinstance(accessible_cells, float):
            assert accessible_cells <= 1, (
                f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
            )

            n_accessible_cells = int(accessible_cells * n_total_cells)
        else:
            assert isinstance(accessible_cells, int)
            n_accessible_cells = accessible_cells

        if max_tree_depth is None:
            # We define max tree depth counting from the start coord in two directions.
            # Therefore we divide by two in the if clause for neighboring sites later
            # and multiply by two here.
            max_tree_depth = 2 * n_total_cells
        elif isinstance(max_tree_depth, float):
            assert max_tree_depth <= 1, (
                f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
            )

            max_tree_depth = int(max_tree_depth * np.sum(grid_shape))

        # choose a random start coord (only random if none was given)
        start_coord = _random_start_coord(grid_shape, start_coord)

        # initialize the maze with no connections
        connection_list: ConnectionList = np.zeros(
            (lattice_dim, grid_shape[0], grid_shape[1]),
            dtype=np.bool_,
        )

        # initialize the visited set and the stack with the start coord
        visited_cells: set[tuple[int, int]] = set()
        visited_cells.add(tuple(start_coord))
        stack: list[Coord] = [start_coord]

        # initialize tree_depth_counter
        current_tree_depth: int = 1

        # loop until the stack is empty or n_accessible_cells is reached
        while stack and (len(visited_cells) < n_accessible_cells):
            # get the current coord from the stack
            current_coord: Coord
            if randomized_stack:
                current_coord = stack.pop(random.randint(0, len(stack) - 1))
            else:
                current_coord = stack.pop()

            # filter neighbors by being within grid bounds and being unvisited
            unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
                (neighbor, delta)
                for neighbor, delta in zip(
                    current_coord + NEIGHBORS_MASK,
                    NEIGHBORS_MASK,
                )
                if (
                    (tuple(neighbor) not in visited_cells)
                    and (0 <= neighbor[0] < grid_shape[0])
                    and (0 <= neighbor[1] < grid_shape[1])
                )
            ]

            # don't continue if max_tree_depth/2 is already reached
            # (divide by 2 because we can branch to multiple directions)
            if unvisited_neighbors_deltas and (
                current_tree_depth <= max_tree_depth / 2
            ):
                # if we want a maze without forks, simply don't add the current coord back to the stack
                if do_forks and (len(unvisited_neighbors_deltas) > 1):
                    stack.append(current_coord)

                # choose one of the unvisited neighbors
                chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)

                # add connection
                dim: int = np.argmax(np.abs(delta))
                # if positive, down/right from current coord
                # if negative, up/left from current coord (down/right from neighbor)
                clist_node: Coord = (
                    current_coord if (delta.sum() > 0) else chosen_neighbor
                )
                connection_list[dim, clist_node[0], clist_node[1]] = True

                # add to visited cells and stack
                visited_cells.add(tuple(chosen_neighbor))
                stack.append(chosen_neighbor)

                # Update current tree depth
                current_tree_depth += 1
            else:
                current_tree_depth -= 1

        output = LatticeMaze(
            connection_list=connection_list,
            generation_meta={
                "func_name": "gen_dfs",
                "grid_shape": grid_shape,
                "start_coord": start_coord,
                "n_accessible_cells": int(n_accessible_cells),
                "max_tree_depth": int(max_tree_depth),
                # `fully_connected` must compare against the *total* number of cells:
                # comparing against `n_accessible_cells` would mark a maze as fully
                # connected even when it is not, which breaks solving the maze
                "fully_connected": bool(len(visited_cells) == n_total_cells),
                "visited_cells": {
                    tuple(int(x) for x in coord) for coord in visited_cells
                },
            },
        )

        return output

    @staticmethod
    def gen_prim(
        grid_shape: Coord,
        lattice_dim: int = 2,
        accessible_cells: int | float | None = None,
        max_tree_depth: int | float | None = None,
        do_forks: bool = True,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """`gen_dfs` with `randomized_stack=True`

        emits a warning on every call: this does not correctly implement prim's algorithm.
        all arguments are forwarded unchanged to `gen_dfs`, so the returned maze
        has `func_name="gen_dfs"` in its `generation_meta`.
        """
        warnings.warn(
            "gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
        )
        return LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            do_forks=do_forks,
            start_coord=start_coord,
            randomized_stack=True,
        )

    @staticmethod
    def gen_wilson(
        grid_shape: Coord,
    ) -> LatticeMaze:
        """Generate a lattice maze using Wilson's algorithm.

        # Arguments
        - `grid_shape: Coord`: the shape of the (2D) grid; any array-like is accepted

        # Algorithm
        Wilson's algorithm generates an unbiased (random) maze
        sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
        acyclic and all cells are part of a unique connected space.
        https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
        """
        # normalize like the other generators do: `_random_start_coord` does
        # arithmetic on the shape, which fails on a plain tuple
        grid_shape: Coord = np.array(grid_shape)

        # Initialize grid and visited cells
        connection_list: ConnectionList = np.zeros((2, *grid_shape), dtype=np.bool_)
        visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape, dtype=np.bool_)

        # Choose a random cell and mark it as visited
        start_coord: Coord = _random_start_coord(grid_shape, None)
        visited[start_coord[0], start_coord[1]] = True
        del start_coord

        while not visited.all():
            # Perform loop-erased random walk from another random cell

            # Choose walk_start only from unvisited cells
            unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
            walk_start: Coord = unvisited_coords[
                np.random.choice(unvisited_coords.shape[0])
            ]

            # Perform the random walk
            path: list[Coord] = [walk_start]
            current: Coord = walk_start

            # exit the loop once the current path hits a visited cell
            while not visited[current[0], current[1]]:
                # find a valid neighbor (one always exists on a lattice)
                neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape)
                next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]

                # Check for loop: index of the first path entry equal to `next_cell`
                loop_exit: int | None = None
                for i, p in enumerate(path):
                    if np.array_equal(next_cell, p):
                        loop_exit = i
                        break

                # erase the loop, or continue the walk
                if loop_exit is not None:
                    # keep the path up to and including the cell where the loop
                    # started, dropping everything after it
                    path = path[: loop_exit + 1]
                    # reset current cell to end of path
                    current = path[-1]
                else:
                    path.append(next_cell)
                    current = next_cell

            # Add the path to the maze, one consecutive pair of cells at a time
            for c_1, c_2 in zip(path[:-1], path[1:]):
                # find the dimension of the connection
                delta: Coord = c_2 - c_1
                dim: int = np.argmax(np.abs(delta))

                # if positive, down/right from current coord
                # if negative, up/left from current coord (down/right from neighbor)
                clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
                connection_list[dim, clist_node[0], clist_node[1]] = True
                visited[c_1[0], c_1[1]] = True
                # we dont add c_2 because the last c_2 will have already been visited

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta={
                "func_name": "gen_wilson",
                "grid_shape": grid_shape,
                "fully_connected": True,
            },
        )

    @staticmethod
    def gen_percolation(
        grid_shape: Coord,
        p: float = 0.4,
        lattice_dim: int = 2,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using simple percolation

        note that p in the range (0.4, 0.7) gives the most interesting mazes

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `lattice_dim: int`: the dimension of the lattice (default: `2`)
        - `p: float`: the probability that each entry of the connection list is set,
          before the result is passed through `_fill_edges_with_walls` (default: `0.4`)
        - `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)

        # Returns
        - `LatticeMaze` whose `generation_meta["visited_cells"]` is the result of
          `gen_connected_component_from(start_coord)`
        """
        assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"
        grid_shape: Coord = np.array(grid_shape)

        start_coord = _random_start_coord(grid_shape, start_coord)

        # each entry is independently True with probability `p`
        connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape) < p

        connection_list = _fill_edges_with_walls(connection_list)

        output: LatticeMaze = LatticeMaze(
            connection_list=connection_list,
            generation_meta={
                "func_name": "gen_percolation",
                "grid_shape": grid_shape,
                "percolation_p": p,
                "start_coord": start_coord,
            },
        )

        output.generation_meta["visited_cells"] = output.gen_connected_component_from(
            start_coord
        )

        return output

    @staticmethod
    def gen_dfs_percolation(
        grid_shape: Coord,
        p: float = 0.4,
        lattice_dim: int = 2,
        accessible_cells: int | None = None,
        max_tree_depth: int | None = None,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """dfs and then percolation (adds cycles)

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `p: float`: the probability that each entry of the percolation connection list is set (default: `0.4`)
        - `lattice_dim: int`: forwarded to `gen_dfs` (default: `2`)
        - `accessible_cells: int | None`: forwarded to `gen_dfs` (default: `None`)
        - `max_tree_depth: int | None`: forwarded to `gen_dfs` (default: `None`)
        - `start_coord: Coord | None`: start of the dfs and of the connected component
          stored in `generation_meta["visited_cells"]` (default: `None` will give a random start)
        """
        grid_shape: Coord = np.array(grid_shape)
        start_coord = _random_start_coord(grid_shape, start_coord)

        # generate initial maze via dfs
        maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            start_coord=start_coord,
        )

        # percolate
        connection_list_perc: np.ndarray = (
            np.random.rand(*maze.connection_list.shape) < p
        )
        connection_list_perc = _fill_edges_with_walls(connection_list_perc)

        # union of the dfs connections and the percolation connections.
        # NOTE(review): written through `__dict__` instead of plain attribute
        # assignment -- presumably `LatticeMaze` blocks normal assignment; confirm
        maze.__dict__["connection_list"] = np.logical_or(
            maze.connection_list,
            connection_list_perc,
        )

        maze.generation_meta["func_name"] = "gen_dfs_percolation"
        maze.generation_meta["percolation_p"] = p
        maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(
            start_coord
        )

        return maze
# entries are written out by hand: populating this automatically
# messes with pickling :(
GENERATORS_MAP: dict[str, Callable[[Coord, Any], "LatticeMaze"]] = dict(
    gen_dfs=LatticeMazeGenerators.gen_dfs,
    gen_wilson=LatticeMazeGenerators.gen_wilson,
    gen_percolation=LatticeMazeGenerators.gen_percolation,
    gen_dfs_percolation=LatticeMazeGenerators.gen_dfs_percolation,
    gen_prim=LatticeMazeGenerators.gen_prim,
)
"mapping of generator names to generator functions, useful for loading `MazeDatasetConfig`"
def get_maze_with_solution(
    gen_name: str,
    grid_shape: Coord,
    maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
    """generate a maze with the generator named `gen_name` and attach a random solution

    # Arguments
     - `gen_name: str`: key into `GENERATORS_MAP` (an unknown name raises `KeyError`)
     - `grid_shape: Coord`: passed to the generator as its first positional argument
     - `maze_ctor_kwargs: dict | None`: extra keyword arguments for the generator,
       `None` means no extra arguments

    # Returns
     - `SolvedMaze` built from the generated maze and the path returned by its
       `generate_random_path()`
    """
    ctor_kwargs: dict = {} if maze_ctor_kwargs is None else maze_ctor_kwargs
    generator = GENERATORS_MAP[gen_name]
    generated: LatticeMaze = generator(grid_shape, **ctor_kwargs)
    random_path: CoordArray = np.array(generated.generate_random_path())
    return SolvedMaze.from_lattice_maze(lattice_maze=generated, solution=random_path)