Coverage for maze_dataset/generation/generators.py: 83%
206 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-24 14:42 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-24 14:42 -0600
1"""generation functions have signature `(grid_shape: Coord, **kwargs) -> LatticeMaze` and are methods in `LatticeMazeGenerators`"""
3import random
4import warnings
5from typing import Any, Callable
7import numpy as np
8from jaxtyping import Bool
10from maze_dataset.constants import CoordArray, CoordTup
11from maze_dataset.generation.seed import GLOBAL_SEED
12from maze_dataset.maze import ConnectionList, Coord, LatticeMaze, SolvedMaze
13from maze_dataset.maze.lattice_maze import NEIGHBORS_MASK, _fill_edges_with_walls
# seed the stdlib `random` module and build a seeded numpy `Generator`, both from the
# project-wide `GLOBAL_SEED`, at import time.
# NOTE(review): the generators in this file draw numpy randomness from the legacy
# global state (`np.random.randint` / `np.random.choice` / `np.random.rand`), not from
# `numpy_rng`, and that legacy global state is not seeded here.
random.seed(GLOBAL_SEED)
numpy_rng = np.random.default_rng(GLOBAL_SEED)
def _random_start_coord(
    grid_shape: Coord,
    start_coord: Coord | CoordTup | None,
) -> Coord:
    """pick a random start coord within the bounds of `grid_shape` if none is provided

    # Parameters:
    - `grid_shape : Coord`
        shape of the grid, one entry per dimension (a plain tuple is accepted too)
    - `start_coord : Coord | CoordTup | None`
        if not `None`, it is converted to a numpy array and returned as-is

    # Returns:
    - `Coord`
        a coordinate where every component `d` satisfies `0 <= coord[d] < grid_shape[d]`
        (for a zero-size dimension the component is `0`)
    """
    if start_coord is not None:
        return np.array(start_coord)

    grid_shape_: np.ndarray = np.asarray(grid_shape)
    # `np.random.randint` EXCLUDES its upper bound, so the shape itself must be the
    # bound for the last row/column to be reachable. The previous bound
    # `max(shape - 1, 1)` could never select the last index of any dimension.
    # Clamping to at least 1 keeps `randint` valid for zero-size dimensions.
    return np.random.randint(
        0,  # lower bound (inclusive)
        np.maximum(grid_shape_, 1),  # upper bound (exclusive, at least 1)
        size=len(grid_shape_),  # dimensionality
    )
def get_neighbors_in_bounds(
    coord: Coord,
    grid_shape: Coord,
) -> CoordArray:
    """return the lattice neighbors of `coord` that lie inside a grid of shape `grid_shape`

    candidate neighbors are `coord + NEIGHBORS_MASK`; a candidate survives only if
    every component is non-negative and strictly smaller than the matching entry
    of `grid_shape`.
    """
    candidates: CoordArray = coord + NEIGHBORS_MASK
    above_lower: np.ndarray = np.all(candidates >= 0, axis=1)
    below_upper: np.ndarray = np.all(candidates < grid_shape, axis=1)
    return candidates[above_lower & below_upper]
class LatticeMazeGenerators:
    """namespace for lattice maze generation algorithms

    every generator is a `@staticmethod` whose first argument is `grid_shape` and
    which returns a `LatticeMaze`
    """

    @staticmethod
    def gen_dfs(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        accessible_cells: int | float | None = None,
        max_tree_depth: int | float | None = None,
        do_forks: bool = True,
        randomized_stack: bool = False,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using depth first search, iterative

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `lattice_dim: int`: the dimension of the lattice
          (default: `2`)
        - `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**
          (default: `None`)
        - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_total_cells` (the total number of cells in the grid). if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**
          (default: `None`)
        - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
          (default: `True`)
        - `randomized_stack: bool`: if `True`, pop a uniformly random element of the stack instead of the top one (this is what `gen_prim` uses)
          (default: `False`)
        - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

        # Returns
        - `LatticeMaze` whose `generation_meta` records the parameters used, the set of
          `visited_cells`, and `fully_connected` (`True` only if every cell of the grid was visited)

        # algorithm
        1. Choose the initial cell, mark it as visited and push it to the stack
        2. While the stack is not empty
            1. Pop a cell from the stack and make it a current cell
            2. If the current cell has any neighbours which have not been visited
                1. Push the current cell to the stack
                2. Choose one of the unvisited neighbours
                3. Remove the wall between the current cell and the chosen cell
                4. Mark the chosen cell as visited and push it to the stack
        """
        # Default values if no constraints have been passed
        grid_shape_: Coord = np.array(grid_shape)
        n_total_cells: int = int(np.prod(grid_shape_))

        n_accessible_cells: int
        if accessible_cells is None:
            n_accessible_cells = n_total_cells
        elif isinstance(accessible_cells, (float, np.floating)):
            assert accessible_cells <= 1, (
                f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
            )
            n_accessible_cells = int(accessible_cells * n_total_cells)
        else:
            # numpy integers are accepted as counts as well as python ints
            assert isinstance(accessible_cells, (int, np.integer))
            n_accessible_cells = int(accessible_cells)

        if max_tree_depth is None:
            # max tree depth is counted from the start coord in two directions, so it is
            # multiplied by two here and divided by two in the neighbor check below
            max_tree_depth = 2 * n_total_cells
        elif isinstance(max_tree_depth, (float, np.floating)):
            assert max_tree_depth <= 1, (
                f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
            )
            max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

        # choose a random start coord (or use the given one)
        start_coord = _random_start_coord(grid_shape_, start_coord)

        # initialize the maze with no connections
        connection_list: ConnectionList = np.zeros(
            (lattice_dim, grid_shape_[0], grid_shape_[1]),
            dtype=np.bool_,
        )

        # initialize the visited set and the stack with the start coord
        visited_cells: set[tuple[int, int]] = set()
        visited_cells.add(tuple(start_coord))
        stack: list[Coord] = [start_coord]

        # initialize tree_depth_counter
        current_tree_depth: int = 1

        # loop until the stack is empty or n_accessible_cells is reached
        while stack and (len(visited_cells) < n_accessible_cells):
            # get the current coord from the stack
            current_coord: Coord
            if randomized_stack:
                current_coord = stack.pop(random.randint(0, len(stack) - 1))
            else:
                current_coord = stack.pop()

            # filter neighbors by being within grid bounds and being unvisited
            unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
                (neighbor, delta)
                for neighbor, delta in zip(
                    current_coord + NEIGHBORS_MASK,
                    NEIGHBORS_MASK,
                    strict=False,
                )
                if (
                    (tuple(neighbor) not in visited_cells)
                    and (0 <= neighbor[0] < grid_shape_[0])
                    and (0 <= neighbor[1] < grid_shape_[1])
                )
            ]

            # don't continue if max_tree_depth/2 is already reached
            # (divide by 2 because we can branch to multiple directions)
            if unvisited_neighbors_deltas and (
                current_tree_depth <= max_tree_depth / 2
            ):
                # if we want a maze without forks, simply don't add the current coord back to the stack
                if do_forks and (len(unvisited_neighbors_deltas) > 1):
                    stack.append(current_coord)

                # choose one of the unvisited neighbors
                chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)

                # add connection
                dim: int = int(np.argmax(np.abs(delta)))
                # if positive, down/right from current coord
                # if negative, up/left from current coord (down/right from neighbor)
                clist_node: Coord = (
                    current_coord if (delta.sum() > 0) else chosen_neighbor
                )
                connection_list[dim, clist_node[0], clist_node[1]] = True

                # add to visited cells and stack
                visited_cells.add(tuple(chosen_neighbor))
                stack.append(chosen_neighbor)

                # Update current tree depth
                current_tree_depth += 1
            else:
                current_tree_depth -= 1

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_dfs",
                grid_shape=grid_shape_,
                start_coord=start_coord,
                n_accessible_cells=int(n_accessible_cells),
                max_tree_depth=int(max_tree_depth),
                # NOTE: compare against n_total_cells, NOT n_accessible_cells. A maze that
                # reached its accessible-cell budget is still not fully connected unless it
                # covers the whole grid, and flagging it as such breaks maze solving.
                fully_connected=bool(len(visited_cells) == n_total_cells),
                visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
            ),
        )

    @staticmethod
    def gen_prim(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        accessible_cells: int | float | None = None,
        max_tree_depth: int | float | None = None,
        do_forks: bool = True,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """(broken!) generate a lattice maze using Prim's algorithm

        NOTE: this is not a real implementation of Prim's algorithm. It emits a warning
        and delegates to `gen_dfs` with `randomized_stack=True`, forwarding every argument
        unchanged (see `gen_dfs` for their meaning). The returned maze's
        `generation_meta["func_name"]` is therefore `"gen_dfs"`.
        """
        warnings.warn(
            "gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
        )
        return LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            do_forks=do_forks,
            start_coord=start_coord,
            randomized_stack=True,
        )

    @staticmethod
    def gen_wilson(
        grid_shape: Coord | CoordTup,
        **kwargs,
    ) -> LatticeMaze:
        """Generate a lattice maze using Wilson's algorithm.

        # Algorithm
        Wilson's algorithm generates an unbiased (random) maze
        sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
        acyclic and all cells are part of a unique connected space.
        https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm

        # Parameters:
        - `grid_shape : Coord | CoordTup`
            shape of the (2D) grid
        - `**kwargs`
            must be empty

        # Raises:
        - `AssertionError` if any keyword arguments are passed
        """
        assert not kwargs, (
            f"gen_wilson does not take any additional arguments, got {kwargs = }"
        )

        grid_shape_: Coord = np.array(grid_shape)

        # Initialize grid and visited cells
        connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
        visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)

        # Choose a random cell and mark it as visited
        start_coord: Coord = _random_start_coord(grid_shape_, None)
        visited[start_coord[0], start_coord[1]] = True
        del start_coord

        while not visited.all():
            # Perform loop-erased random walk from another random cell

            # Choose walk_start only from unvisited cells
            unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
            walk_start: Coord = unvisited_coords[
                np.random.choice(unvisited_coords.shape[0])
            ]

            # Perform the random walk
            path: list[Coord] = [walk_start]
            current: Coord = walk_start

            # exit the loop once the current path hits a visited cell
            while not visited[current[0], current[1]]:
                # find a valid neighbor (one always exists on a lattice)
                neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
                next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]

                # Check for loop: index of next_cell in the current path, if present
                loop_exit: int | None = None
                for i, p in enumerate(path):
                    if np.array_equal(next_cell, p):
                        loop_exit = i
                        break

                # erase the loop, or continue the walk
                if loop_exit is not None:
                    # keep the path up to and including the revisited cell
                    path = path[: loop_exit + 1]
                    # reset current cell to end of path
                    current = path[-1]
                else:
                    path.append(next_cell)
                    current = next_cell

            # Add the path to the maze, one consecutive pair of cells at a time
            for c_1, c_2 in zip(path[:-1], path[1:], strict=True):
                # find the dimension of the connection
                delta: Coord = c_2 - c_1
                dim: int = int(np.argmax(np.abs(delta)))

                # if positive, down/right from current coord
                # if negative, up/left from current coord (down/right from neighbor)
                clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
                connection_list[dim, clist_node[0], clist_node[1]] = True
                visited[c_1[0], c_1[1]] = True
                # c_2 is not marked: the last c_2 of the path was already visited

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_wilson",
                grid_shape=grid_shape_,
                fully_connected=True,
            ),
        )

    @staticmethod
    def gen_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using simple percolation

        every entry of the connection list is independently set open with probability `p`;
        the result is then passed through `_fill_edges_with_walls` (presumably closing
        connections that would leave the grid -- confirm in `lattice_maze`).
        note that p in the range (0.4, 0.7) gives the most interesting mazes

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `p: float`: the probability of each connection being open, in `[0, 1]` (default: `0.4`)
        - `lattice_dim: int`: the dimension of the lattice (default: `2`)
        - `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)

        # Returns
        - `LatticeMaze` whose `generation_meta["visited_cells"]` is the result of
          `gen_connected_component_from(start_coord)`

        # Raises
        - `AssertionError` if `p` is outside `[0, 1]`
        """
        assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
        grid_shape_: Coord = np.array(grid_shape)

        start_coord = _random_start_coord(grid_shape_, start_coord)

        connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p

        connection_list = _fill_edges_with_walls(connection_list)

        output: LatticeMaze = LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_percolation",
                grid_shape=grid_shape_,
                percolation_p=p,
                start_coord=start_coord,
            ),
        )

        # generation_meta is sometimes None, but not here since we just made it a dict above
        output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return output

    @staticmethod
    def gen_dfs_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        accessible_cells: int | float | None = None,
        max_tree_depth: int | float | None = None,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """dfs and then percolation (adds cycles)

        first generates a maze with `gen_dfs` (forwarding `lattice_dim`, `accessible_cells`,
        `max_tree_depth` and `start_coord`, see `gen_dfs`), then ORs its connection list
        with a random one in which each connection is open with probability `p`.

        # Returns
        - `LatticeMaze` with `generation_meta["func_name"] == "gen_dfs_percolation"`,
          `percolation_p`, and `visited_cells` recomputed via `gen_connected_component_from`
        """
        grid_shape_: Coord = np.array(grid_shape)
        start_coord = _random_start_coord(grid_shape_, start_coord)

        # generate initial maze via dfs
        maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape_,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            start_coord=start_coord,
        )

        # percolate
        connection_list_perc: np.ndarray = (
            np.random.rand(*maze.connection_list.shape) < p
        )
        connection_list_perc = _fill_edges_with_walls(connection_list_perc)

        # written through `__dict__` because the maze attributes are not directly assignable here
        maze.__dict__["connection_list"] = np.logical_or(
            maze.connection_list,
            connection_list_perc,
        )

        # generation_meta is sometimes None, but not here since gen_dfs always sets a dict
        maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
        maze.generation_meta["percolation_p"] = p  # type: ignore[index]
        maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return maze

    @staticmethod
    def gen_kruskal(
        grid_shape: "Coord | CoordTup",
        lattice_dim: int = 2,
        start_coord: "Coord | None" = None,
    ) -> "LatticeMaze":
        """Generate a maze using Kruskal's algorithm.

        This function generates a random spanning tree over a grid using Kruskal's algorithm.
        Each cell is treated as a node, and all valid adjacent edges are listed and processed
        in random order. An edge is added (i.e. its passage carved) only if it connects two cells
        that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree)
        without cycles.

        https://en.wikipedia.org/wiki/Kruskal's_algorithm

        # Parameters:
        - `grid_shape : Coord | CoordTup`
            The shape of the maze grid (for example, `(n_rows, n_cols)`).
        - `lattice_dim : int`
            The lattice dimension; only `2` is supported (default is `2`).
        - `start_coord : Coord | None`
            Only recorded in `generation_meta`; it does not affect the generated maze.
            If `None`, a random coordinate is chosen.

        # Returns:
        - `LatticeMaze`
            A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.

        # Raises:
        - `AssertionError` if `lattice_dim != 2`

        # Usage:
        ```python
        maze = gen_kruskal((10, 10))
        ```
        """
        assert lattice_dim == 2, (  # noqa: PLR2004
            "Kruskal's algorithm is only implemented for 2D lattices."
        )
        # Convert grid_shape to a tuple of ints
        grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
        n_rows, n_cols = grid_shape_

        # union-find forest: every cell starts as the root of its own set
        parent: dict[tuple[int, int], tuple[int, int]] = {
            (i, j): (i, j) for i in range(n_rows) for j in range(n_cols)
        }

        def find(cell: tuple[int, int]) -> tuple[int, int]:
            """return the root of `cell`'s set, halving the path on the way"""
            while parent[cell] != cell:
                parent[cell] = parent[parent[cell]]
                cell = parent[cell]
            return cell

        # List all possible edges as (top/left cell, other cell, connection_list axis).
        # edges to the right neighbor are stored on axis 1 of connection_list:
        edges: list[tuple[tuple[int, int], tuple[int, int], int]] = []
        for i in range(n_rows):
            for j in range(n_cols - 1):
                edges.append(((i, j), (i, j + 1), 1))
        # edges to the bottom neighbor are stored on axis 0 of connection_list:
        for i in range(n_rows - 1):
            for j in range(n_cols):
                edges.append(((i, j), (i + 1, j), 0))

        # Shuffle the list of edges (uses the module-level `random`).
        random.shuffle(edges)

        # Initialize connection_list with no connections.
        # connection_list[0] stores downward connections (from cell (i,j) to (i+1,j)).
        # connection_list[1] stores rightward connections (from cell (i,j) to (i,j+1)).
        connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)

        # Process each edge; if it connects two different trees, union them and carve the passage.
        for cell1, cell2, direction in edges:
            root1 = find(cell1)
            root2 = find(cell2)
            if root1 != root2:
                parent[root2] = root1
                # `direction` is exactly the connection_list axis, and the connection is
                # always stored at the top/left cell of the pair (cell1)
                connection_list[direction, cell1[0], cell1[1]] = True

        if start_coord is None:
            start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]

        generation_meta: dict = dict(
            func_name="gen_kruskal",
            grid_shape=grid_shape_,
            start_coord=start_coord,
            algorithm="kruskal",
            fully_connected=True,
        )
        return LatticeMaze(
            connection_list=connection_list, generation_meta=generation_meta
        )

    @staticmethod
    def gen_recursive_division(
        grid_shape: "Coord | CoordTup",
        lattice_dim: int = 2,
        start_coord: "Coord | None" = None,
    ) -> "LatticeMaze":
        """Generate a maze using the recursive division algorithm.

        This function generates a maze by recursively dividing the grid with walls and carving a single
        passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent
        cells is connected) and then removes connections along a chosen division line, leaving one gap as a passage.
        The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.

        # Parameters:
        - `grid_shape : Coord | CoordTup`
            The shape of the maze grid (e.g., `(n_rows, n_cols)`).
        - `lattice_dim : int`
            The lattice dimension; only `2` is supported (default is `2`).
        - `start_coord : Coord | None`
            Only recorded in `generation_meta`; it does not affect the generated maze.
            If `None`, a random coordinate is chosen.

        # Returns:
        - `LatticeMaze`
            A maze represented by a connection list, generated using recursive division.

        # Raises:
        - `AssertionError` if `lattice_dim != 2`

        # Usage:
        ```python
        maze = gen_recursive_division((10, 10))
        ```
        """
        assert lattice_dim == 2, (  # noqa: PLR2004
            "Recursive division algorithm is only implemented for 2D lattices."
        )
        # Convert grid_shape to a tuple of ints.
        grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
        n_rows, n_cols = grid_shape_

        # Initialize connection_list as a fully connected grid.
        # Downward connections (axis 0): every cell except those in the last row.
        # Rightward connections (axis 1): every cell except those in the last column.
        connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
        connection_list[0, : n_rows - 1, :] = True
        connection_list[1, :, : n_cols - 1] = True

        def divide(x: int, y: int, width: int, height: int) -> None:
            """Recursively divide the region whose top-left cell is (row=y, col=x),
            spanning `width` columns and `height` rows.

            Removes connections along the chosen division line except for one randomly chosen gap.
            Regions narrower or shorter than 2 cells are left as-is (they are already corridors).
            NOTE: recursion depth grows with the number of divisions along one branch; very
            large grids could in principle approach Python's recursion limit.
            """
            if width < 2 or height < 2:  # noqa: PLR2004
                return

            if width > height:
                # Vertical division.
                wall_col = random.randint(x + 1, x + width - 1)
                gap_row = random.randint(y, y + height - 1)
                for row in range(y, y + height):
                    if row == gap_row:
                        continue
                    # Remove the rightward connection between (row, wall_col-1) and (row, wall_col).
                    if wall_col - 1 < n_cols - 1:
                        connection_list[1, row, wall_col - 1] = False
                # Recurse on the left and right subregions.
                divide(x, y, wall_col - x, height)
                divide(wall_col, y, x + width - wall_col, height)
            else:
                # Horizontal division.
                wall_row = random.randint(y + 1, y + height - 1)
                gap_col = random.randint(x, x + width - 1)
                for col in range(x, x + width):
                    if col == gap_col:
                        continue
                    # Remove the downward connection between (wall_row-1, col) and (wall_row, col).
                    if wall_row - 1 < n_rows - 1:
                        connection_list[0, wall_row - 1, col] = False
                # Recurse on the top and bottom subregions.
                divide(x, y, width, wall_row - y)
                divide(x, wall_row, width, y + height - wall_row)

        # Begin the division on the full grid.
        divide(0, 0, n_cols, n_rows)

        if start_coord is None:
            start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]

        generation_meta: dict = dict(
            func_name="gen_recursive_division",
            grid_shape=grid_shape_,
            start_coord=start_coord,
            algorithm="recursive_division",
            fully_connected=True,
        )
        return LatticeMaze(
            connection_list=connection_list, generation_meta=generation_meta
        )
# can't automatically populate this because it messes with pickling :(
# The value type is `Callable[..., LatticeMaze]` because the generators take different
# keyword arguments, and `gen_wilson` takes `**kwargs`. A fixed parameter list such as
# `[Coord | CoordTup, Any]` rejects `gen_wilson`'s signature. It also makes calls that
# pass only `grid_shape` plus keyword arguments fail type checking.
GENERATORS_MAP: dict[str, Callable[..., LatticeMaze]] = {
    "gen_dfs": LatticeMazeGenerators.gen_dfs,
    "gen_wilson": LatticeMazeGenerators.gen_wilson,
    "gen_percolation": LatticeMazeGenerators.gen_percolation,
    "gen_dfs_percolation": LatticeMazeGenerators.gen_dfs_percolation,
    "gen_prim": LatticeMazeGenerators.gen_prim,
    "gen_kruskal": LatticeMazeGenerators.gen_kruskal,
    "gen_recursive_division": LatticeMazeGenerators.gen_recursive_division,
}
"mapping of generator names to generator functions, useful for loading `MazeDatasetConfig`"
# names of the generators whose mazes are produced (at least partly) by random percolation
_GENERATORS_PERCOLATED: list[str] = ["gen_percolation", "gen_dfs_percolation"]
"""list of generator names that generate percolated mazes
we use this to figure out the expected success rate, since depending on the endpoint kwargs this might fail
this variable is primarily used in `MazeDatasetConfig._to_ps_array` and `MazeDatasetConfig._from_ps_array`
"""
def get_maze_with_solution(
    gen_name: str,
    grid_shape: Coord | CoordTup,
    maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
    """generate a maze with `GENERATORS_MAP[gen_name]` and attach a random solution path

    # Parameters:
    - `gen_name : str`
        key into `GENERATORS_MAP` (an unknown name raises `KeyError`)
    - `grid_shape : Coord | CoordTup`
        passed as the first positional argument to the generator
    - `maze_ctor_kwargs : dict | None`
        extra keyword arguments for the generator; `None` means no extra arguments

    # Returns:
    - `SolvedMaze`
        built from the generated maze and the path returned by `maze.generate_random_path()`
    """
    ctor_kwargs: dict = {} if maze_ctor_kwargs is None else maze_ctor_kwargs
    generator = GENERATORS_MAP[gen_name]
    # TYPING: error: Too few arguments [call-arg]
    # not sure why this is happening -- doesnt recognize the kwargs?
    maze: LatticeMaze = generator(grid_shape, **ctor_kwargs)  # type: ignore[call-arg]
    return SolvedMaze.from_lattice_maze(
        lattice_maze=maze,
        solution=np.array(maze.generate_random_path()),
    )