maze_dataset.generation
generation functions have signature (grid_shape: Coord, **kwargs) -> LatticeMaze
and are methods in LatticeMazeGenerators
`DEFAULT_GENERATORS` is a list of (generator name, generator kwargs) pairs used in tests and demos
"""maze generation functions and the registry that exposes them

generation functions have signature `(grid_shape: Coord, **kwargs) -> LatticeMaze` and are methods in `LatticeMazeGenerators`

`DEFAULT_GENERATORS` is a list of generator name, generator kwargs pairs used in tests and demos

"""

from maze_dataset.generation.generators import (
    GENERATORS_MAP,
    LatticeMazeGenerators,
    get_maze_with_solution,
    numpy_rng,
)

# public API of the package: the submodule names first, then the names
# re-exported from `maze_dataset.generation.generators` by the import above
__all__ = [
    # submodules
    "default_generators",
    "generators",
    "seed",
    # imports
    "LatticeMazeGenerators",
    "GENERATORS_MAP",
    "get_maze_with_solution",
    "numpy_rng",
]
class LatticeMazeGenerators:
    """namespace for lattice maze generation algorithms

    every generator is a static method that takes `grid_shape` as its first argument and returns
    a `LatticeMaze` whose `generation_meta` dict records at least `func_name` and `grid_shape`
    """

    @staticmethod
    def gen_dfs(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        do_forks: bool = True,
        randomized_stack: bool = False,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using depth first search, iterative

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `lattice_dim: int`: the dimension of the lattice
          (default: `2`)
        - `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is in `[0, 1]` and treats it as a proportion of **total cells**
          (default: `None`)
        - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_total_cells` (twice the total number of cells in the grid). if a float, asserts it is in `[0, 1]` and treats it as a proportion of the **sum of the grid shape**
          (default: `None`)
        - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
          (default: `True`)
        - `randomized_stack: bool`: if `True`, pop a randomly chosen stack entry instead of the most recently pushed one
          (default: `False`)
        - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

        # Returns
        - `LatticeMaze`: maze whose `generation_meta` holds `func_name`, `grid_shape`, `start_coord`,
          `n_accessible_cells`, `max_tree_depth`, `fully_connected` and `visited_cells`

        # algorithm
        1. Choose the initial cell, mark it as visited and push it to the stack
        2. While the stack is not empty
            1. Pop a cell from the stack and make it a current cell
            2. If the current cell has any neighbours which have not been visited
                1. Push the current cell to the stack (only if `do_forks` and more than one unvisited neighbour remains)
                2. Choose one of the unvisited neighbours
                3. Remove the wall between the current cell and the chosen cell
                4. Mark the chosen cell as visited and push it to the stack
        """
        # Default values if no constraints have been passed
        grid_shape_: Coord = np.array(grid_shape)
        n_total_cells: int = int(np.prod(grid_shape_))

        n_accessible_cells: int
        if accessible_cells is None:
            n_accessible_cells = n_total_cells
        elif isinstance(accessible_cells, float):
            # the message promises the range [0, 1], so both ends are enforced;
            # a negative proportion would otherwise silently produce a one-cell maze
            assert 0 <= accessible_cells <= 1, (
                f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
            )

            n_accessible_cells = int(accessible_cells * n_total_cells)
        else:
            assert isinstance(accessible_cells, int)
            n_accessible_cells = accessible_cells

        if max_tree_depth is None:
            max_tree_depth = (
                2 * n_total_cells
            )  # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
        elif isinstance(max_tree_depth, float):
            assert 0 <= max_tree_depth <= 1, (
                f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
            )

            max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

        # choose a random start coord (only if none was given)
        start_coord = _random_start_coord(grid_shape_, start_coord)

        # initialize the maze with no connections
        connection_list: ConnectionList = np.zeros(
            (lattice_dim, grid_shape_[0], grid_shape_[1]),
            dtype=np.bool_,
        )

        # initialize the stack with the start coord
        visited_cells: set[tuple[int, int]] = set()
        visited_cells.add(tuple(start_coord))
        stack: list[Coord] = [start_coord]

        # initialize tree_depth_counter
        current_tree_depth: int = 1

        # loop until the stack is empty or n_accessible_cells is reached
        while stack and (len(visited_cells) < n_accessible_cells):
            # get the current coord from the stack
            current_coord: Coord
            if randomized_stack:
                current_coord = stack.pop(random.randint(0, len(stack) - 1))
            else:
                current_coord = stack.pop()

            # filter neighbors by being within grid bounds and being unvisited
            unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
                (neighbor, delta)
                for neighbor, delta in zip(
                    current_coord + NEIGHBORS_MASK,
                    NEIGHBORS_MASK,
                    strict=False,
                )
                if (
                    (tuple(neighbor) not in visited_cells)
                    and (0 <= neighbor[0] < grid_shape_[0])
                    and (0 <= neighbor[1] < grid_shape_[1])
                )
            ]

            # don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
            if unvisited_neighbors_deltas and (
                current_tree_depth <= max_tree_depth / 2
            ):
                # if we want a maze without forks, simply don't add the current coord back to the stack
                if do_forks and (len(unvisited_neighbors_deltas) > 1):
                    stack.append(current_coord)

                # choose one of the unvisited neighbors
                chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)

                # add connection
                dim: int = int(np.argmax(np.abs(delta)))
                # if positive, down/right from current coord
                # if negative, up/left from current coord (down/right from neighbor)
                clist_node: Coord = (
                    current_coord if (delta.sum() > 0) else chosen_neighbor
                )
                connection_list[dim, clist_node[0], clist_node[1]] = True

                # add to visited cells and stack
                visited_cells.add(tuple(chosen_neighbor))
                stack.append(chosen_neighbor)

                # Update current tree depth
                current_tree_depth += 1
            else:
                current_tree_depth -= 1

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_dfs",
                grid_shape=grid_shape_,
                start_coord=start_coord,
                n_accessible_cells=int(n_accessible_cells),
                max_tree_depth=int(max_tree_depth),
                # compare against the *total* cell count, not `n_accessible_cells`:
                # otherwise a deliberately partial maze is reported as fully connected,
                # which breaks solving the maze
                fully_connected=bool(len(visited_cells) == n_total_cells),
                visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
            ),
        )

    @staticmethod
    def gen_prim(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        do_forks: bool = True,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """(broken!) generate a lattice maze using Prim's algorithm

        this is not a real implementation of Prim's algorithm: it emits a warning and then
        calls `gen_dfs` with `randomized_stack=True`, forwarding every other argument unchanged
        """
        warnings.warn(
            "gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
        )
        return LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            do_forks=do_forks,
            start_coord=start_coord,
            randomized_stack=True,
        )

    @staticmethod
    def gen_wilson(
        grid_shape: Coord | CoordTup,
        **kwargs,
    ) -> LatticeMaze:
        """Generate a lattice maze using Wilson's algorithm.

        # Algorithm
        Wilson's algorithm generates an unbiased (random) maze
        sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
        acyclic and all cells are part of a unique connected space.
        https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `**kwargs`: must be empty, asserted; accepted only so the signature matches the other generators
        """
        assert not kwargs, (
            f"gen_wilson does not take any additional arguments, got {kwargs = }"
        )

        grid_shape_: Coord = np.array(grid_shape)

        # Initialize grid and visited cells
        connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
        visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)

        # Choose a random cell and mark it as visited
        start_coord: Coord = _random_start_coord(grid_shape_, None)
        visited[start_coord[0], start_coord[1]] = True
        del start_coord

        while not visited.all():
            # Perform loop-erased random walk from another random cell

            # Choose walk_start only from unvisited cells
            unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
            walk_start: Coord = unvisited_coords[
                np.random.choice(unvisited_coords.shape[0])
            ]

            # Perform the random walk
            path: list[Coord] = [walk_start]
            current: Coord = walk_start

            # exit the loop once the current path hits a visited cell
            while not visited[current[0], current[1]]:
                # find a valid neighbor (one always exists on a lattice)
                neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
                next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]

                # Check for loop
                loop_exit: int | None = None
                for i, p in enumerate(path):
                    if np.array_equal(next_cell, p):
                        loop_exit = i
                        break

                # erase the loop, or continue the walk
                if loop_exit is not None:
                    # keep the path up to and including the cell where the loop closed,
                    # dropping everything that came after it
                    path = path[: loop_exit + 1]
                    # reset current cell to end of path
                    current = path[-1]
                else:
                    path.append(next_cell)
                    current = next_cell

            # Add the path to the maze
            for i in range(len(path) - 1):
                c_1: Coord = path[i]
                c_2: Coord = path[i + 1]

                # find the dimension of the connection
                delta: Coord = c_2 - c_1
                dim: int = int(np.argmax(np.abs(delta)))

                # if positive, down/right from current coord
                # if negative, up/left from current coord (down/right from neighbor)
                clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
                connection_list[dim, clist_node[0], clist_node[1]] = True
                visited[c_1[0], c_1[1]] = True
                # we dont add c_2 because the last c_2 will have already been visited

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_wilson",
                grid_shape=grid_shape_,
                fully_connected=True,
            ),
        )

    @staticmethod
    def gen_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using simple percolation

        note that p in the range (0.4, 0.7) gives the most interesting mazes

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `lattice_dim: int`: the dimension of the lattice (default: `2`)
        - `p: float`: the probability of each connection being open, must be in `[0, 1]` (default: `0.4`)
        - `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)

        # Returns
        - `LatticeMaze`: maze whose `generation_meta` holds `func_name`, `grid_shape`, `percolation_p`,
          `start_coord` and `visited_cells` (the connected component containing `start_coord`)
        """
        assert 0 <= p <= 1, f"p must be between 0 and 1, got {p}"
        grid_shape_: Coord = np.array(grid_shape)

        start_coord = _random_start_coord(grid_shape_, start_coord)

        # every entry of the connection list is opened independently with probability p
        connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p

        connection_list = _fill_edges_with_walls(connection_list)

        output: LatticeMaze = LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_percolation",
                grid_shape=grid_shape_,
                percolation_p=p,
                start_coord=start_coord,
            ),
        )

        # generation_meta is sometimes None, but not here since we just made it a dict above
        output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return output

    @staticmethod
    def gen_dfs_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        accessible_cells: int | None = None,
        max_tree_depth: int | None = None,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """dfs and then percolation (adds cycles)

        builds a maze with `gen_dfs`, then ORs its connection list with a random percolation mask

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `p: float`: the probability of each extra connection being opened by the percolation step (default: `0.4`)
        - `lattice_dim: int`: the dimension of the lattice (default: `2`)
        - `accessible_cells: int | None`: forwarded to `gen_dfs` (default: `None`)
        - `max_tree_depth: int | None`: forwarded to `gen_dfs` (default: `None`)
        - `start_coord: Coord | None`: start of the dfs and of the recorded connected component (default: `None` will give a random start)
        """
        grid_shape_: Coord = np.array(grid_shape)
        start_coord = _random_start_coord(grid_shape_, start_coord)

        # generate initial maze via dfs
        maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape_,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            start_coord=start_coord,
        )

        # percolate
        connection_list_perc: np.ndarray = (
            np.random.rand(*maze.connection_list.shape) < p
        )
        connection_list_perc = _fill_edges_with_walls(connection_list_perc)

        # NOTE(review): written through `__dict__` instead of plain attribute assignment,
        # presumably because `LatticeMaze` does not allow it -- confirm against the class definition
        maze.__dict__["connection_list"] = np.logical_or(
            maze.connection_list,
            connection_list_perc,
        )

        # generation_meta is sometimes None, but not here since gen_dfs always sets it to a dict
        maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
        maze.generation_meta["percolation_p"] = p  # type: ignore[index]
        maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return maze

    @staticmethod
    def gen_kruskal(
        grid_shape: "Coord | CoordTup",
        lattice_dim: int = 2,
        start_coord: "Coord | None" = None,
    ) -> "LatticeMaze":
        """Generate a maze using Kruskal's algorithm.

        This function generates a random spanning tree over a grid using Kruskal's algorithm.
        Each cell is treated as a node, and all valid adjacent edges are listed and processed
        in random order. An edge is added (i.e. its passage carved) only if it connects two cells
        that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree)
        without cycles.

        https://en.wikipedia.org/wiki/Kruskal's_algorithm

        # Parameters:
        - `grid_shape : Coord | CoordTup`
            The shape of the maze grid (for example, `(n_rows, n_cols)`).
        - `lattice_dim : int`
            The lattice dimension (default is `2`). Only `2` is supported, asserted.
        - `start_coord : Coord | None`
            Optionally, specify a starting coordinate. If `None`, a random coordinate will be chosen.
            It does not influence the generated tree, it is only recorded in `generation_meta`.

        # Returns:
        - `LatticeMaze`
            A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.

        # Usage:
        ```python
        maze = LatticeMazeGenerators.gen_kruskal((10, 10))
        ```
        """
        assert lattice_dim == 2, (  # noqa: PLR2004
            "Kruskal's algorithm is only implemented for 2D lattices."
        )
        # Convert grid_shape to a tuple of ints
        grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
        n_rows, n_cols = grid_shape_

        # Initialize union-find data structure: each cell starts as its own set.
        parent: dict[tuple[int, int], tuple[int, int]] = {
            (i, j): (i, j) for i in range(n_rows) for j in range(n_cols)
        }

        def find(cell: tuple[int, int]) -> tuple[int, int]:
            # walk up to the root, pointing each visited cell at its grandparent on the way
            while parent[cell] != cell:
                parent[cell] = parent[parent[cell]]
                cell = parent[cell]
            return cell

        def union(cell1: tuple[int, int], cell2: tuple[int, int]) -> None:
            root1 = find(cell1)
            root2 = find(cell2)
            parent[root2] = root1

        # List all possible edges as (cell, neighbor, dim), where `dim` is the index into
        # the first axis of connection_list.
        # dim 1: edge from a cell to its right neighbor
        edges: list[tuple[tuple[int, int], tuple[int, int], int]] = [
            ((i, j), (i, j + 1), 1) for i in range(n_rows) for j in range(n_cols - 1)
        ]
        # dim 0: edge from a cell to its bottom neighbor
        edges.extend(
            ((i, j), (i + 1, j), 0) for i in range(n_rows - 1) for j in range(n_cols)
        )

        # Shuffle the list of edges.
        random.shuffle(edges)

        # Initialize connection_list with no connections.
        # connection_list[0] stores downward connections (from cell (i,j) to (i+1,j)).
        # connection_list[1] stores rightward connections (from cell (i,j) to (i,j+1)).
        connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)

        # Process each edge; if it connects two different trees, union them and carve the passage.
        for cell1, cell2, direction in edges:
            if find(cell1) != find(cell2):
                union(cell1, cell2)
                # the passage is stored at cell1, the upper/left end of the edge
                connection_list[direction, cell1[0], cell1[1]] = True

        if start_coord is None:
            start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]

        generation_meta: dict = dict(
            func_name="gen_kruskal",
            grid_shape=grid_shape_,
            start_coord=start_coord,
            algorithm="kruskal",
            fully_connected=True,
        )
        return LatticeMaze(
            connection_list=connection_list, generation_meta=generation_meta
        )

    @staticmethod
    def gen_recursive_division(
        grid_shape: "Coord | CoordTup",
        lattice_dim: int = 2,
        start_coord: "Coord | None" = None,
    ) -> "LatticeMaze":
        """Generate a maze using the recursive division algorithm.

        This function generates a maze by recursively dividing the grid with walls and carving a single
        passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent
        cells is connected) and then removes connections along a chosen division line—leaving one gap as a passage.
        The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.

        # Parameters:
        - `grid_shape : Coord | CoordTup`
            The shape of the maze grid (e.g., `(n_rows, n_cols)`).
        - `lattice_dim : int`
            The lattice dimension (default is `2`). Only `2` is supported, asserted.
        - `start_coord : Coord | None`
            Optionally, specify a starting coordinate. If `None`, a random coordinate is chosen.
            It does not influence the generated maze, it is only recorded in `generation_meta`.

        # Returns:
        - `LatticeMaze`
            A maze represented by a connection list, generated using recursive division.

        # Usage:
        ```python
        maze = LatticeMazeGenerators.gen_recursive_division((10, 10))
        ```
        """
        assert lattice_dim == 2, (  # noqa: PLR2004
            "Recursive division algorithm is only implemented for 2D lattices."
        )
        # Convert grid_shape to a tuple of ints.
        grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
        n_rows, n_cols = grid_shape_

        # Initialize connection_list as a fully connected grid.
        # connection_list[0] (downward): every cell (i,j) with i in [0, n_rows-2] is connected to (i+1,j).
        # connection_list[1] (rightward): every cell (i,j) with j in [0, n_cols-2] is connected to (i,j+1).
        connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
        connection_list[0, : n_rows - 1, :] = True
        connection_list[1, :, : n_cols - 1] = True

        def divide(x: int, y: int, width: int, height: int) -> None:
            """Recursively divide the region starting at column `x`, row `y` with the given width and height.

            Removes connections along the chosen division line except for one randomly chosen gap.
            """
            # a region that is a single cell wide or tall cannot be divided further
            if width < 2 or height < 2:  # noqa: PLR2004
                return

            if width > height:
                # Wall between columns `wall_col - 1` and `wall_col`, open only at `gap_row`.
                # `wall_col <= x + width - 1 <= n_cols - 1`, so `wall_col - 1` always has a right neighbor.
                wall_col = random.randint(x + 1, x + width - 1)
                gap_row = random.randint(y, y + height - 1)
                for row in range(y, y + height):
                    if row == gap_row:
                        continue
                    # Remove the rightward connection between (row, wall_col-1) and (row, wall_col).
                    connection_list[1, row, wall_col - 1] = False
                # Recurse on the left and right subregions.
                divide(x, y, wall_col - x, height)
                divide(wall_col, y, x + width - wall_col, height)
            else:
                # Wall between rows `wall_row - 1` and `wall_row`, open only at `gap_col`.
                # `wall_row <= y + height - 1 <= n_rows - 1`, so `wall_row - 1` always has a bottom neighbor.
                wall_row = random.randint(y + 1, y + height - 1)
                gap_col = random.randint(x, x + width - 1)
                for col in range(x, x + width):
                    if col == gap_col:
                        continue
                    # Remove the downward connection between (wall_row-1, col) and (wall_row, col).
                    connection_list[0, wall_row - 1, col] = False
                # Recurse on the top and bottom subregions.
                divide(x, y, width, wall_row - y)
                divide(x, wall_row, width, y + height - wall_row)

        # Begin the division on the full grid.
        divide(0, 0, n_cols, n_rows)

        if start_coord is None:
            start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]

        generation_meta: dict = dict(
            func_name="gen_recursive_division",
            grid_shape=grid_shape_,
            start_coord=start_coord,
            algorithm="recursive_division",
            fully_connected=True,
        )
        return LatticeMaze(
            connection_list=connection_list, generation_meta=generation_meta
        )
namespace for lattice maze generation algorithms
@staticmethod
def gen_dfs(
    grid_shape: Coord | CoordTup,
    lattice_dim: int = 2,
    accessible_cells: float | None = None,
    max_tree_depth: float | None = None,
    do_forks: bool = True,
    randomized_stack: bool = False,
    start_coord: Coord | None = None,
) -> LatticeMaze:
    """generate a lattice maze using depth first search, iterative

    # Arguments
    - `grid_shape: Coord`: the shape of the grid
    - `lattice_dim: int`: the dimension of the lattice
      (default: `2`)
    - `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is in `[0, 1]` and treats it as a proportion of **total cells**
      (default: `None`)
    - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_total_cells` (twice the total number of cells in the grid). if a float, asserts it is in `[0, 1]` and treats it as a proportion of the **sum of the grid shape**
      (default: `None`)
    - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
      (default: `True`)
    - `randomized_stack: bool`: if `True`, pop a randomly chosen stack entry instead of the most recently pushed one
      (default: `False`)
    - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

    # Returns
    - `LatticeMaze`: maze whose `generation_meta` holds `func_name`, `grid_shape`, `start_coord`,
      `n_accessible_cells`, `max_tree_depth`, `fully_connected` and `visited_cells`

    # algorithm
    1. Choose the initial cell, mark it as visited and push it to the stack
    2. While the stack is not empty
        1. Pop a cell from the stack and make it a current cell
        2. If the current cell has any neighbours which have not been visited
            1. Push the current cell to the stack (only if `do_forks` and more than one unvisited neighbour remains)
            2. Choose one of the unvisited neighbours
            3. Remove the wall between the current cell and the chosen cell
            4. Mark the chosen cell as visited and push it to the stack
    """
    # Default values if no constraints have been passed
    grid_shape_: Coord = np.array(grid_shape)
    n_total_cells: int = int(np.prod(grid_shape_))

    n_accessible_cells: int
    if accessible_cells is None:
        n_accessible_cells = n_total_cells
    elif isinstance(accessible_cells, float):
        # the message promises the range [0, 1], so both ends are enforced;
        # a negative proportion would otherwise silently produce a one-cell maze
        assert 0 <= accessible_cells <= 1, (
            f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
        )

        n_accessible_cells = int(accessible_cells * n_total_cells)
    else:
        assert isinstance(accessible_cells, int)
        n_accessible_cells = accessible_cells

    if max_tree_depth is None:
        max_tree_depth = (
            2 * n_total_cells
        )  # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
    elif isinstance(max_tree_depth, float):
        assert 0 <= max_tree_depth <= 1, (
            f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
        )

        max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

    # choose a random start coord (only if none was given)
    start_coord = _random_start_coord(grid_shape_, start_coord)

    # initialize the maze with no connections
    connection_list: ConnectionList = np.zeros(
        (lattice_dim, grid_shape_[0], grid_shape_[1]),
        dtype=np.bool_,
    )

    # initialize the stack with the start coord
    visited_cells: set[tuple[int, int]] = set()
    visited_cells.add(tuple(start_coord))
    stack: list[Coord] = [start_coord]

    # initialize tree_depth_counter
    current_tree_depth: int = 1

    # loop until the stack is empty or n_accessible_cells is reached
    while stack and (len(visited_cells) < n_accessible_cells):
        # get the current coord from the stack
        current_coord: Coord
        if randomized_stack:
            current_coord = stack.pop(random.randint(0, len(stack) - 1))
        else:
            current_coord = stack.pop()

        # filter neighbors by being within grid bounds and being unvisited
        unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
            (neighbor, delta)
            for neighbor, delta in zip(
                current_coord + NEIGHBORS_MASK,
                NEIGHBORS_MASK,
                strict=False,
            )
            if (
                (tuple(neighbor) not in visited_cells)
                and (0 <= neighbor[0] < grid_shape_[0])
                and (0 <= neighbor[1] < grid_shape_[1])
            )
        ]

        # don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
        if unvisited_neighbors_deltas and (
            current_tree_depth <= max_tree_depth / 2
        ):
            # if we want a maze without forks, simply don't add the current coord back to the stack
            if do_forks and (len(unvisited_neighbors_deltas) > 1):
                stack.append(current_coord)

            # choose one of the unvisited neighbors
            chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)

            # add connection
            dim: int = int(np.argmax(np.abs(delta)))
            # if positive, down/right from current coord
            # if negative, up/left from current coord (down/right from neighbor)
            clist_node: Coord = (
                current_coord if (delta.sum() > 0) else chosen_neighbor
            )
            connection_list[dim, clist_node[0], clist_node[1]] = True

            # add to visited cells and stack
            visited_cells.add(tuple(chosen_neighbor))
            stack.append(chosen_neighbor)

            # Update current tree depth
            current_tree_depth += 1
        else:
            current_tree_depth -= 1

    return LatticeMaze(
        connection_list=connection_list,
        generation_meta=dict(
            func_name="gen_dfs",
            grid_shape=grid_shape_,
            start_coord=start_coord,
            n_accessible_cells=int(n_accessible_cells),
            max_tree_depth=int(max_tree_depth),
            # compare against the *total* cell count, not `n_accessible_cells`:
            # otherwise a deliberately partial maze is reported as fully connected,
            # which breaks solving the maze
            fully_connected=bool(len(visited_cells) == n_total_cells),
            visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
        ),
    )
generate a lattice maze using depth first search, iterative
Arguments
- `grid_shape: Coord`: the shape of the grid
- `lattice_dim: int`: the dimension of the lattice (default: `2`)
- `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. If a float, asserts it is <= 1 and treats it as a proportion of total cells (default: `None`)
- `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to twice the total number of cells in the grid. If a float, asserts it is <= 1 and treats it as a proportion of the sum of the grid shape (default: `None`)
- `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
- `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.
algorithm
1. Choose the initial cell, mark it as visited and push it to the stack
2. While the stack is not empty
    1. Pop a cell from the stack and make it a current cell
    2. If the current cell has any neighbours which have not been visited
        1. Push the current cell to the stack
        2. Choose one of the unvisited neighbours
        3. Remove the wall between the current cell and the chosen cell
        4. Mark the chosen cell as visited and push it to the stack
@staticmethod
def gen_prim(
    grid_shape: Coord | CoordTup,
    lattice_dim: int = 2,
    accessible_cells: float | None = None,
    max_tree_depth: float | None = None,
    do_forks: bool = True,
    start_coord: Coord | None = None,
) -> LatticeMaze:
    """(broken!) generate a lattice maze using Prim's algorithm

    this is not a real implementation of Prim's algorithm: it emits a warning and then calls
    `LatticeMazeGenerators.gen_dfs` with `randomized_stack=True`, forwarding every other argument unchanged
    """
    warnings.warn(
        "gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
    )
    # everything the caller passed goes straight through to the dfs generator;
    # the only difference is that the stack is popped at a random position
    dfs_kwargs: dict = {
        "grid_shape": grid_shape,
        "lattice_dim": lattice_dim,
        "accessible_cells": accessible_cells,
        "max_tree_depth": max_tree_depth,
        "do_forks": do_forks,
        "start_coord": start_coord,
        "randomized_stack": True,
    }
    return LatticeMazeGenerators.gen_dfs(**dfs_kwargs)
(broken!) generate a lattice maze using Prim's algorithm
@staticmethod
def gen_wilson(
    grid_shape: Coord | CoordTup,
    **kwargs,
) -> LatticeMaze:
    """Generate a lattice maze using Wilson's algorithm.

    # Algorithm
    Wilson's algorithm generates an unbiased (random) maze
    sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
    acyclic and all cells are part of a unique connected space.
    https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm

    # Arguments
    - `grid_shape: Coord`: the shape of the grid
    - `**kwargs`: must be empty, asserted; accepted only so the signature matches the other generators
    """
    assert not kwargs, (
        f"gen_wilson does not take any additional arguments, got {kwargs = }"
    )

    grid_shape_: Coord = np.array(grid_shape)

    # empty maze, plus a mask of the cells that already belong to it
    connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
    in_maze: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)

    # the maze starts out as a single random cell
    seed_cell: Coord = _random_start_coord(grid_shape_, None)
    in_maze[seed_cell[0], seed_cell[1]] = True

    while not in_maze.all():
        # start a loop-erased random walk from a random cell that is not in the maze yet
        candidates: CoordArray = np.column_stack(np.where(~in_maze))
        walk_start: Coord = candidates[np.random.choice(candidates.shape[0])]

        walk: list[Coord] = [walk_start]
        head: Coord = walk_start

        # keep walking until the head of the walk lands on a cell of the maze
        while not in_maze[head[0], head[1]]:
            # a neighbor inside the grid always exists on a lattice
            options: CoordArray = get_neighbors_in_bounds(head, grid_shape_)
            step: Coord = options[np.random.choice(options.shape[0])]

            # index of the first walk cell equal to `step`, if the walk would cross itself
            revisit_idx: int | None = next(
                (idx for idx, seen in enumerate(walk) if np.array_equal(step, seen)),
                None,
            )

            if revisit_idx is None:
                walk.append(step)
            else:
                # erase the loop: keep the walk up to and including the revisited cell
                walk = walk[: revisit_idx + 1]
            head = walk[-1]

        # carve the finished walk into the maze, one consecutive pair at a time
        for src, dst in zip(walk, walk[1:], strict=False):
            step_delta: Coord = dst - src
            axis: int = int(np.argmax(np.abs(step_delta)))

            # a connection is stored at the upper/left cell of the pair:
            # positive delta -> `src` is that cell, negative delta -> `dst` is
            anchor: Coord = src if (step_delta.sum() > 0) else dst
            connection_list[axis, anchor[0], anchor[1]] = True
            # the final `dst` is not marked here since it was already part of the maze
            in_maze[src[0], src[1]] = True

    return LatticeMaze(
        connection_list=connection_list,
        generation_meta=dict(
            func_name="gen_wilson",
            grid_shape=grid_shape_,
            fully_connected=True,
        ),
    )
Generate a lattice maze using Wilson's algorithm.
Algorithm
Wilson's algorithm generates an unbiased (random) maze sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is acyclic and all cells are part of a unique connected space. https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
@staticmethod
def gen_percolation(
    grid_shape: Coord | CoordTup,
    p: float = 0.4,
    lattice_dim: int = 2,
    start_coord: Coord | None = None,
) -> LatticeMaze:
    """generate a lattice maze using simple percolation

    every entry of the connection list is opened independently with probability `p`, so the
    result may contain cycles and is in general *not* fully connected. the connected component
    containing `start_coord` is stored in `generation_meta["visited_cells"]`.

    note that p in the range (0.4, 0.7) gives the most interesting mazes

    # Arguments
    - `grid_shape: Coord`: the shape of the grid
    - `p: float`: the probability of each individual connection being open, must lie in `[0, 1]`
        (default: `0.4`)
    - `lattice_dim: int`: the dimension of the lattice, used as the leading axis of the connection list
        (default: `2`)
    - `start_coord: Coord | None`: the starting coordinate for the connected component
        (default: `None` will give a random start)

    # Returns
    - `LatticeMaze`: maze whose `generation_meta` holds `func_name`, `grid_shape`, `percolation_p`,
        `start_coord` and `visited_cells`

    # Raises
    - `AssertionError`: if `p` is outside `[0, 1]`
    """
    assert 0 <= p <= 1, f"p must be between 0 and 1, got {p}"
    grid_shape_: Coord = np.array(grid_shape)

    # resolve the start first so the order of random draws stays: start coord, then connections
    start_coord = _random_start_coord(grid_shape_, start_coord)

    # open each connection independently with probability p
    connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p

    # NOTE(review): presumably closes the connections that would lead off the grid -- confirm
    # against `_fill_edges_with_walls`
    connection_list = _fill_edges_with_walls(connection_list)

    output: LatticeMaze = LatticeMaze(
        connection_list=connection_list,
        generation_meta=dict(
            func_name="gen_percolation",
            grid_shape=grid_shape_,
            percolation_p=p,
            start_coord=start_coord,
        ),
    )

    # generation_meta is sometimes None, but not here since we just made it a dict above
    output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
        start_coord,
    )

    return output
generate a lattice maze using simple percolation
note that p in the range (0.4, 0.7) gives the most interesting mazes
Arguments
- grid_shape: Coord: the shape of the grid
- lattice_dim: int: the dimension of the lattice (default: 2)
- p: float: the probability of a cell being accessible (default: 0.4)
- start_coord: Coord | None: the starting coordinate for the connected component (default: None will give a random start)
@staticmethod
def gen_dfs_percolation(
    grid_shape: Coord | CoordTup,
    p: float = 0.4,
    lattice_dim: int = 2,
    accessible_cells: int | None = None,
    max_tree_depth: int | None = None,
    start_coord: Coord | None = None,
) -> LatticeMaze:
    """dfs and then percolation (adds cycles)

    a maze is first generated with `gen_dfs`, then an independent percolation connection list
    (each entry open with probability `p`) is OR-ed on top of it, so every dfs connection is kept
    and extra ones may appear.

    # Arguments
    - `grid_shape: Coord`: the shape of the grid
    - `p: float`: the probability of each extra percolation connection being open (default: `0.4`)
    - `lattice_dim: int`: forwarded to `gen_dfs` (default: `2`)
    - `accessible_cells: int | None`: forwarded to `gen_dfs` (default: `None`)
    - `max_tree_depth: int | None`: forwarded to `gen_dfs` (default: `None`)
    - `start_coord: Coord | None`: start of the dfs and of the recorded connected component
        (default: `None` will give a random start)

    # Returns
    - `LatticeMaze`: the dfs maze with the merged connections; its `generation_meta` gets
        `func_name="gen_dfs_percolation"`, `percolation_p` and `visited_cells`
    """
    shape: Coord = np.array(grid_shape)
    start_coord = _random_start_coord(shape, start_coord)

    # step 1: plain dfs maze, started from the coordinate resolved above
    dfs_maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
        grid_shape=shape,
        lattice_dim=lattice_dim,
        accessible_cells=accessible_cells,
        max_tree_depth=max_tree_depth,
        start_coord=start_coord,
    )

    # step 2: independent percolation layer of the same shape as the dfs connection list
    extra_connections: np.ndarray = _fill_edges_with_walls(
        np.random.rand(*dfs_maze.connection_list.shape) < p,
    )

    # step 3: union of both layers, written through `__dict__`
    # NOTE(review): presumably done this way because plain attribute assignment on
    # `LatticeMaze` is blocked -- confirm against the class definition
    dfs_maze.__dict__["connection_list"] = np.logical_or(
        dfs_maze.connection_list,
        extra_connections,
    )

    # generation_meta is expected to be a dict here (not None), hence the type ignores
    dfs_maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
    dfs_maze.generation_meta["percolation_p"] = p  # type: ignore[index]
    dfs_maze.generation_meta["visited_cells"] = dfs_maze.gen_connected_component_from(  # type: ignore[index]
        start_coord,
    )

    return dfs_maze
dfs and then percolation (adds cycles)
@staticmethod
def gen_kruskal(
    grid_shape: "Coord | CoordTup",
    lattice_dim: int = 2,
    start_coord: "Coord | None" = None,
) -> "LatticeMaze":
    """Generate a maze using Kruskal's algorithm.

    This function generates a random spanning tree over a grid using Kruskal's algorithm.
    Each cell is treated as a node, and all valid adjacent edges are listed and processed
    in random order. An edge is added (i.e. its passage carved) only if it connects two cells
    that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree)
    without cycles.

    https://en.wikipedia.org/wiki/Kruskal's_algorithm

    Note that two random sources are involved: the edge order comes from the standard library
    `random` module, while a missing `start_coord` is drawn from `np.random`.

    # Parameters:
    - `grid_shape : Coord | CoordTup`
        The shape of the maze grid (for example, `(n_rows, n_cols)`).
    - `lattice_dim : int`
        The lattice dimension (default is `2`). Only `2` is supported.
    - `start_coord : Coord | None`
        Optionally, specify a starting coordinate. If `None`, a random coordinate will be chosen.
        It is only recorded in `generation_meta`; it does not influence the spanning tree.

    # Returns:
    - `LatticeMaze`
        A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.

    # Raises:
    - `AssertionError`
        If `lattice_dim` is not `2`.

    # Usage:
    ```python
    maze = gen_kruskal((10, 10))
    ```
    """
    assert lattice_dim == 2, (  # noqa: PLR2004
        "Kruskal's algorithm is only implemented for 2D lattices."
    )
    # Convert grid_shape to a tuple of ints
    grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
    n_rows, n_cols = grid_shape_

    # Union-find forest: every cell starts out as the root of its own set.
    parent: dict[tuple[int, int], tuple[int, int]] = {
        (i, j): (i, j) for i in range(n_rows) for j in range(n_cols)
    }

    def find(cell: tuple[int, int]) -> tuple[int, int]:
        """Return the root of the set containing `cell`, halving the path on the way up."""
        while parent[cell] != cell:
            parent[cell] = parent[parent[cell]]
            cell = parent[cell]
        return cell

    # List all possible edges as (cell, neighbor, axis), where `axis` is the index into
    # `connection_list` under which the edge is stored.
    # Rightward edges first (cell -> right neighbor, axis 1), then downward edges
    # (cell -> bottom neighbor, axis 0); this order is kept so that a given `random` state
    # shuffles to the same maze as before.
    edges: list[tuple[tuple[int, int], tuple[int, int], int]] = [
        ((i, j), (i, j + 1), 1) for i in range(n_rows) for j in range(n_cols - 1)
    ]
    edges += [
        ((i, j), (i + 1, j), 0) for i in range(n_rows - 1) for j in range(n_cols)
    ]

    random.shuffle(edges)

    # connection_list[0] stores downward connections (from cell (i,j) to (i+1,j)).
    # connection_list[1] stores rightward connections (from cell (i,j) to (i,j+1)).
    connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)

    # Process each edge; if it joins two different trees, merge them and carve the passage.
    for cell1, cell2, axis in edges:
        root1 = find(cell1)
        root2 = find(cell2)
        if root1 != root2:
            parent[root2] = root1
            # the connection is always stored at `cell1`, the upper/left end of the edge
            connection_list[axis, cell1[0], cell1[1]] = True

    if start_coord is None:
        start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]

    generation_meta: dict = dict(
        func_name="gen_kruskal",
        grid_shape=grid_shape_,
        start_coord=start_coord,
        algorithm="kruskal",
        fully_connected=True,
    )
    return LatticeMaze(
        connection_list=connection_list, generation_meta=generation_meta
    )
Generate a maze using Kruskal's algorithm.
This function generates a random spanning tree over a grid using Kruskal's algorithm. Each cell is treated as a node, and all valid adjacent edges are listed and processed in random order. An edge is added (i.e. its passage carved) only if it connects two cells that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree) without cycles.
https://en.wikipedia.org/wiki/Kruskal's_algorithm
Parameters:
- grid_shape : Coord | CoordTup — The shape of the maze grid (for example, (n_rows, n_cols)).
- lattice_dim : int — The lattice dimension (default is 2).
- start_coord : Coord | None — Optionally, specify a starting coordinate. If None, a random coordinate will be chosen.
- **kwargs — Additional keyword arguments (currently unused; note that the signature shown above does not accept extra keyword arguments).
Returns:
LatticeMaze
A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.
Usage:
maze = gen_kruskal((10, 10))
@staticmethod
def gen_recursive_division(
    grid_shape: "Coord | CoordTup",
    lattice_dim: int = 2,
    start_coord: "Coord | None" = None,
) -> "LatticeMaze":
    """Generate a maze using the recursive division algorithm.

    This function generates a maze by recursively dividing the grid with walls and carving a single
    passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent
    cells is connected) and then removes connections along a chosen division line, leaving one gap as a passage.
    The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.

    Note that two random sources are involved: wall and gap positions come from the standard library
    `random` module, while a missing `start_coord` is drawn from `np.random`.

    # Parameters:
    - `grid_shape : Coord | CoordTup`
        The shape of the maze grid (e.g., `(n_rows, n_cols)`).
    - `lattice_dim : int`
        The lattice dimension (default is `2`). Only `2` is supported.
    - `start_coord : Coord | None`
        Optionally, specify a starting coordinate. If `None`, a random coordinate is chosen.
        It is only recorded in `generation_meta`; it does not influence the divisions.

    # Returns:
    - `LatticeMaze`
        A maze represented by a connection list, generated using recursive division.

    # Raises:
    - `AssertionError`
        If `lattice_dim` is not `2`.

    # Usage:
    ```python
    maze = gen_recursive_division((10, 10))
    ```
    """
    assert lattice_dim == 2, (  # noqa: PLR2004
        "Recursive division algorithm is only implemented for 2D lattices."
    )
    # Convert grid_shape to a tuple of ints.
    grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
    n_rows, n_cols = grid_shape_

    # Start from a fully connected grid.
    # connection_list[0, i, j] links (i, j) to the cell below it, so it exists for i in [0, n_rows-2].
    # connection_list[1, i, j] links (i, j) to the cell on its right, so it exists for j in [0, n_cols-2].
    connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
    connection_list[0, : n_rows - 1, :] = True
    connection_list[1, :, : n_cols - 1] = True

    def divide(x: int, y: int, width: int, height: int) -> None:
        """Recursively divide the region whose top-left cell is row `y`, column `x`.

        `width` counts columns and `height` counts rows. Connections crossing the chosen
        division line are removed, except at one randomly chosen gap.
        """
        # a region that is a single row or column is already a corridor: nothing to divide
        if width < 2 or height < 2:  # noqa: PLR2004
            return

        if width > height:
            # Wall between columns `wall_col - 1` and `wall_col`. Both lie inside the region,
            # so `wall_col - 1` is always a valid index for a rightward connection.
            wall_col = random.randint(x + 1, x + width - 1)
            gap_row = random.randint(y, y + height - 1)
            for row in range(y, y + height):
                if row != gap_row:
                    # cut the rightward connection (row, wall_col-1) -> (row, wall_col)
                    connection_list[1, row, wall_col - 1] = False
            # Recurse on the left and right subregions.
            divide(x, y, wall_col - x, height)
            divide(wall_col, y, x + width - wall_col, height)
        else:
            # Wall between rows `wall_row - 1` and `wall_row`. Both lie inside the region,
            # so `wall_row - 1` is always a valid index for a downward connection.
            wall_row = random.randint(y + 1, y + height - 1)
            gap_col = random.randint(x, x + width - 1)
            for col in range(x, x + width):
                if col != gap_col:
                    # cut the downward connection (wall_row-1, col) -> (wall_row, col)
                    connection_list[0, wall_row - 1, col] = False
            # Recurse on the top and bottom subregions.
            divide(x, y, width, wall_row - y)
            divide(x, wall_row, width, y + height - wall_row)

    # Begin the division on the full grid.
    divide(0, 0, n_cols, n_rows)

    if start_coord is None:
        start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]

    generation_meta: dict = dict(
        func_name="gen_recursive_division",
        grid_shape=grid_shape_,
        start_coord=start_coord,
        algorithm="recursive_division",
        fully_connected=True,
    )
    return LatticeMaze(
        connection_list=connection_list, generation_meta=generation_meta
    )
Generate a maze using the recursive division algorithm.
This function generates a maze by recursively dividing the grid with walls and carving a single passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent cells is connected) and then removes connections along a chosen division line—leaving one gap as a passage. The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.
Parameters:
- grid_shape : Coord | CoordTup — The shape of the maze grid (e.g., (n_rows, n_cols)).
- lattice_dim : int — The lattice dimension (default is 2).
- start_coord : Coord | None — Optionally, specify a starting coordinate. If None, a random coordinate is chosen.
- **kwargs — Additional keyword arguments (currently unused; note that the signature shown above does not accept extra keyword arguments).
Returns:
LatticeMaze
A maze represented by a connection list, generated using recursive division.
Usage:
maze = gen_recursive_division((10, 10))
def get_maze_with_solution(
    gen_name: str,
    grid_shape: Coord | CoordTup,
    maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
    """helper function to get a maze already with a solution

    # Arguments
    - `gen_name: str`: key into `GENERATORS_MAP` selecting the generation function
    - `grid_shape: Coord | CoordTup`: passed to the generator as its first positional argument
    - `maze_ctor_kwargs: dict | None`: extra keyword arguments for the generator
        (default: `None`, meaning no extra arguments)

    # Returns
    - `SolvedMaze`: the generated maze together with the path from `generate_random_path()`
    """
    generator_kwargs: dict = {} if maze_ctor_kwargs is None else maze_ctor_kwargs
    generator = GENERATORS_MAP[gen_name]
    # TYPING: error: Too few arguments  [call-arg]
    # not sure why this is happening -- doesnt recognize the kwargs?
    lattice: LatticeMaze = generator(grid_shape, **generator_kwargs)  # type: ignore[call-arg]
    random_path: CoordArray = np.array(lattice.generate_random_path())
    return SolvedMaze.from_lattice_maze(lattice_maze=lattice, solution=random_path)
helper function to get a maze already with a solution