Coverage for maze_dataset/generation/generators.py: 83%

206 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-24 14:42 -0600

1"""generation functions have signature `(grid_shape: Coord, **kwargs) -> LatticeMaze` and are methods in `LatticeMazeGenerators`""" 

2 

3import random 

4import warnings 

5from typing import Any, Callable 

6 

7import numpy as np 

8from jaxtyping import Bool 

9 

10from maze_dataset.constants import CoordArray, CoordTup 

11from maze_dataset.generation.seed import GLOBAL_SEED 

12from maze_dataset.maze import ConnectionList, Coord, LatticeMaze, SolvedMaze 

13from maze_dataset.maze.lattice_maze import NEIGHBORS_MASK, _fill_edges_with_walls 

14 

# seed the stdlib `random` module and create a numpy `Generator`, both from
# `GLOBAL_SEED`, as an import-time side effect
# NOTE(review): the generators in this file call the global `np.random.*`
# functions rather than `numpy_rng` -- confirm whether that is intended
random.seed(GLOBAL_SEED)
numpy_rng = np.random.default_rng(GLOBAL_SEED)

17 

18 

def _random_start_coord(
    grid_shape: Coord,
    start_coord: Coord | CoordTup | None,
) -> Coord:
    """return `start_coord` as an array, or a random in-bounds coord if it is `None`

    # Arguments
    - `grid_shape: Coord`: number of cells along each axis of the grid
    - `start_coord: Coord | CoordTup | None`: coordinate to use as-is; if `None`,
        one is drawn from the global `np.random` state

    # Returns
    - `Coord`: array with one entry per axis, each in `[0, grid_shape[axis])`
    """
    if start_coord is not None:
        return np.array(start_coord)

    # `np.random.randint` excludes its upper bound, so the bound has to be the
    # axis length itself -- using `length - 1` would make the last row/column
    # unreachable. `maximum(..., 1)` keeps the call valid for a zero-length axis.
    upper_bound: Coord = np.maximum(np.asarray(grid_shape), 1)
    return np.random.randint(
        0,  # lower bound (inclusive)
        upper_bound,  # upper bound (exclusive)
        size=len(upper_bound),  # dimensionality
    )

35 

36 

def get_neighbors_in_bounds(
    coord: Coord,
    grid_shape: Coord,
) -> CoordArray:
    """return the lattice neighbors of `coord` that lie inside a grid of shape `grid_shape`

    candidates are `coord` offset by every row of `NEIGHBORS_MASK`; a candidate
    is kept only if each of its components is `>= 0` and `< grid_shape`
    """
    candidates: CoordArray = NEIGHBORS_MASK + coord
    # one boolean per candidate: True when every component is within bounds
    is_inside = ((candidates >= 0) & (candidates < grid_shape)).all(axis=1)
    return candidates[is_inside]

51 

52 

class LatticeMazeGenerators:
    """namespace for lattice maze generation algorithms

    every generator is a `staticmethod` taking `grid_shape` plus keyword
    arguments and returning a `LatticeMaze`. In the connection lists built here,
    `connection_list[0, i, j]` is the connection from `(i, j)` down to
    `(i + 1, j)` and `connection_list[1, i, j]` is the connection from `(i, j)`
    right to `(i, j + 1)`.
    """

    @staticmethod
    def gen_dfs(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        accessible_cells: int | float | None = None,
        max_tree_depth: int | float | None = None,
        do_forks: bool = True,
        randomized_stack: bool = False,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using depth first search, iterative

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `lattice_dim: int`: the dimension of the lattice
            (default: `2`)
        - `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is in `[0, 1]` and treats it as a proportion of **total cells**
            (default: `None`)
        - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_total_cells`. if a float, asserts it is in `[0, 1]` and treats it as a proportion of the **sum of the grid shape**
            (default: `None`)
        - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
            (default: `True`)
        - `randomized_stack: bool`: if `True`, pop a uniformly random element of the stack instead of the most recent one
            (default: `False`)
        - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

        # Returns
        - `LatticeMaze`: maze whose `generation_meta` records `func_name`, `grid_shape`,
            `start_coord`, `n_accessible_cells`, `max_tree_depth`, `fully_connected`
            and `visited_cells`

        # algorithm
        1. Choose the initial cell, mark it as visited and push it to the stack
        2. While the stack is not empty
            1. Pop a cell from the stack and make it a current cell
            2. If the current cell has any neighbours which have not been visited
                1. Push the current cell to the stack
                2. Choose one of the unvisited neighbours
                3. Remove the wall between the current cell and the chosen cell
                4. Mark the chosen cell as visited and push it to the stack
        """
        grid_shape_: Coord = np.array(grid_shape)
        n_total_cells: int = int(np.prod(grid_shape_))

        # resolve `accessible_cells` to an absolute cell count
        n_accessible_cells: int
        if accessible_cells is None:
            n_accessible_cells = n_total_cells
        elif isinstance(accessible_cells, float):
            assert 0 <= accessible_cells <= 1, (
                f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
            )
            n_accessible_cells = int(accessible_cells * n_total_cells)
        else:
            assert isinstance(accessible_cells, int)
            n_accessible_cells = accessible_cells

        # resolve `max_tree_depth` to an absolute depth
        if max_tree_depth is None:
            # depth is counted from the start coord in two directions, so the
            # loop below compares against `max_tree_depth / 2`; doubling here
            # makes the default non-restrictive
            max_tree_depth = 2 * n_total_cells
        elif isinstance(max_tree_depth, float):
            assert 0 <= max_tree_depth <= 1, (
                f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
            )
            max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

        # use the provided start coord, or choose a random one
        start_coord = _random_start_coord(grid_shape_, start_coord)

        # initialize the maze with no connections
        connection_list: ConnectionList = np.zeros(
            (lattice_dim, grid_shape_[0], grid_shape_[1]),
            dtype=np.bool_,
        )

        # the start coord is both the first visited cell and the initial stack
        visited_cells: set[tuple[int, int]] = {tuple(start_coord)}
        stack: list[Coord] = [start_coord]

        current_tree_depth: int = 1

        # loop until the stack is empty or enough cells have been reached
        while stack and (len(visited_cells) < n_accessible_cells):
            current_coord: Coord
            if randomized_stack:
                current_coord = stack.pop(random.randint(0, len(stack) - 1))
            else:
                current_coord = stack.pop()

            # neighbors (paired with the offset that reaches them) that are
            # inside the grid and not yet visited
            unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
                (neighbor, delta)
                for neighbor, delta in zip(
                    current_coord + NEIGHBORS_MASK,
                    NEIGHBORS_MASK,
                    strict=False,
                )
                if (
                    (tuple(neighbor) not in visited_cells)
                    and (0 <= neighbor[0] < grid_shape_[0])
                    and (0 <= neighbor[1] < grid_shape_[1])
                )
            ]

            # don't extend once max_tree_depth/2 is reached (divide by 2
            # because we can branch in multiple directions)
            if unvisited_neighbors_deltas and (
                current_tree_depth <= max_tree_depth / 2
            ):
                # for a maze without forks, never revisit the current coord
                if do_forks and (len(unvisited_neighbors_deltas) > 1):
                    stack.append(current_coord)

                chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)

                # a connection is stored at whichever endpoint is up/left:
                # positive delta -> stored at the current coord,
                # negative delta -> stored at the neighbor
                dim: int = int(np.argmax(np.abs(delta)))
                clist_node: Coord = (
                    current_coord if (delta.sum() > 0) else chosen_neighbor
                )
                connection_list[dim, clist_node[0], clist_node[1]] = True

                visited_cells.add(tuple(chosen_neighbor))
                stack.append(chosen_neighbor)

                current_tree_depth += 1
            else:
                current_tree_depth -= 1

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_dfs",
                grid_shape=grid_shape_,
                start_coord=start_coord,
                n_accessible_cells=int(n_accessible_cells),
                max_tree_depth=int(max_tree_depth),
                # compare against `n_total_cells`, NOT `n_accessible_cells`: a maze
                # that stopped at its accessible-cell budget is not fully connected,
                # and marking it as such breaks solving
                fully_connected=bool(len(visited_cells) == n_total_cells),
                visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
            ),
        )

    @staticmethod
    def gen_prim(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        accessible_cells: int | float | None = None,
        max_tree_depth: int | float | None = None,
        do_forks: bool = True,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """(broken!) generate a lattice maze using Prim's algorithm

        emits a warning, then delegates to `gen_dfs` with `randomized_stack=True`;
        all other arguments are forwarded unchanged (see `gen_dfs` for their meaning)
        """
        warnings.warn(
            "gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
        )
        return LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            do_forks=do_forks,
            start_coord=start_coord,
            randomized_stack=True,
        )

    @staticmethod
    def gen_wilson(
        grid_shape: Coord | CoordTup,
        **kwargs,
    ) -> LatticeMaze:
        """Generate a lattice maze using Wilson's algorithm.

        # Algorithm
        Wilson's algorithm generates an unbiased (random) maze
        sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
        acyclic and all cells are part of a unique connected space.
        https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm

        # Arguments
        - `grid_shape: Coord | CoordTup`: the shape of the grid
        - `**kwargs`: must be empty; accepted only so the signature matches the other generators
        """
        assert not kwargs, (
            f"gen_wilson does not take any additional arguments, got {kwargs = }"
        )

        grid_shape_: Coord = np.array(grid_shape)

        connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
        visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)

        # seed the maze with a single random visited cell
        seed_coord: Coord = _random_start_coord(grid_shape_, None)
        visited[seed_coord[0], seed_coord[1]] = True

        while not visited.all():
            # start a loop-erased random walk from a random unvisited cell
            unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
            walk_start: Coord = unvisited_coords[
                np.random.choice(unvisited_coords.shape[0])
            ]

            path: list[Coord] = [walk_start]
            current: Coord = walk_start

            # walk until the path hits a cell that is already part of the maze
            while not visited[current[0], current[1]]:
                neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
                next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]

                # index of `next_cell` in the current path, if the walk looped
                loop_exit: int | None = next(
                    (
                        idx
                        for idx, step in enumerate(path)
                        if np.array_equal(next_cell, step)
                    ),
                    None,
                )

                if loop_exit is None:
                    path.append(next_cell)
                    current = next_cell
                else:
                    # erase the loop: keep the path up to and including the
                    # cell where the loop started, and resume from there
                    path = path[: loop_exit + 1]
                    current = path[-1]

            # carve the finished path into the maze
            for c_1, c_2 in zip(path[:-1], path[1:], strict=True):
                delta: Coord = c_2 - c_1
                dim: int = int(np.argmax(np.abs(delta)))

                # a connection is stored at whichever endpoint is up/left
                clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
                connection_list[dim, clist_node[0], clist_node[1]] = True
                # `c_2` is not marked: the final `c_2` is already visited, and
                # every other `c_2` is the `c_1` of the next pair
                visited[c_1[0], c_1[1]] = True

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_wilson",
                grid_shape=grid_shape_,
                fully_connected=True,
            ),
        )

    @staticmethod
    def gen_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using simple percolation

        note that p in the range (0.4, 0.7) gives the most interesting mazes

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `lattice_dim: int`: the dimension of the lattice (default: `2`)
        - `p: float`: the probability of each connection being open, must be in `[0, 1]` (default: `0.4`)
        - `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)

        # Returns
        - `LatticeMaze`: maze whose `generation_meta["visited_cells"]` is the
            connected component containing `start_coord`
        """
        assert 0 <= p <= 1, f"p must be between 0 and 1, got {p}"
        grid_shape_: Coord = np.array(grid_shape)

        # the start coord is drawn before the connections so the order of
        # random draws is fixed
        start_coord = _random_start_coord(grid_shape_, start_coord)

        connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p
        connection_list = _fill_edges_with_walls(connection_list)

        output: LatticeMaze = LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_percolation",
                grid_shape=grid_shape_,
                percolation_p=p,
                start_coord=start_coord,
            ),
        )

        # generation_meta is sometimes None, but not here since we just made it a dict above
        output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return output

    @staticmethod
    def gen_dfs_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        accessible_cells: int | float | None = None,
        max_tree_depth: int | float | None = None,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """dfs and then percolation (adds cycles)

        builds a maze with `gen_dfs`, then ORs its connection list with a random
        percolation mask in which each connection is open with probability `p`

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `p: float`: the probability of each extra connection being open (default: `0.4`)
        - `lattice_dim`, `accessible_cells`, `max_tree_depth`: forwarded to `gen_dfs`
        - `start_coord: Coord | None`: start of the dfs and of the connected
            component stored in `visited_cells` (default: `None` will give a random start)
        """
        grid_shape_: Coord = np.array(grid_shape)
        start_coord = _random_start_coord(grid_shape_, start_coord)

        # generate initial maze via dfs
        maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape_,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            start_coord=start_coord,
        )

        # percolate
        connection_list_perc: np.ndarray = (
            np.random.rand(*maze.connection_list.shape) < p
        )
        connection_list_perc = _fill_edges_with_walls(connection_list_perc)

        # written through `__dict__` rather than by plain attribute assignment
        # NOTE(review): presumably because `LatticeMaze` blocks normal assignment -- confirm
        maze.__dict__["connection_list"] = np.logical_or(
            maze.connection_list,
            connection_list_perc,
        )

        # generation_meta is sometimes None, but not here since gen_dfs always sets it
        maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
        maze.generation_meta["percolation_p"] = p  # type: ignore[index]
        maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return maze

    @staticmethod
    def gen_kruskal(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """Generate a maze using Kruskal's algorithm.

        This function generates a random spanning tree over a grid using Kruskal's algorithm.
        Each cell is treated as a node, and all valid adjacent edges are listed and processed
        in random order. An edge is added (i.e. its passage carved) only if it connects two cells
        that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree)
        without cycles.

        https://en.wikipedia.org/wiki/Kruskal's_algorithm

        # Parameters:
        - `grid_shape : Coord | CoordTup`
            The shape of the maze grid (for example, `(n_rows, n_cols)`).
        - `lattice_dim : int`
            The lattice dimension (default is `2`). Only `2` is supported.
        - `start_coord : Coord | None`
            Optionally, specify a starting coordinate. If `None`, a random coordinate will be chosen.
            It is only recorded in `generation_meta`; it does not influence the maze.

        # Returns:
        - `LatticeMaze`
            A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.

        # Usage:
        ```python
        maze = gen_kruskal((10, 10))
        ```
        """
        assert lattice_dim == 2, (  # noqa: PLR2004
            "Kruskal's algorithm is only implemented for 2D lattices."
        )
        # Convert grid_shape to a tuple of ints
        grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
        n_rows, n_cols = grid_shape_

        # Union-find forest: every cell starts as the root of its own set.
        parent: dict[tuple[int, int], tuple[int, int]] = {
            (i, j): (i, j) for i in range(n_rows) for j in range(n_cols)
        }

        def find(cell: tuple[int, int]) -> tuple[int, int]:
            "return the root of `cell`'s set, halving the path on the way up"
            while parent[cell] != cell:
                parent[cell] = parent[parent[cell]]
                cell = parent[cell]
            return cell

        # List all possible edges as `(cell, neighbor, connection_list index)`.
        # The construction order is kept fixed because `random.shuffle` output
        # depends on the input order.
        # Rightward edges, (i, j) -> (i, j + 1), live in connection_list[1]:
        edges: list[tuple[tuple[int, int], tuple[int, int], int]] = [
            ((i, j), (i, j + 1), 1) for i in range(n_rows) for j in range(n_cols - 1)
        ]
        # Downward edges, (i, j) -> (i + 1, j), live in connection_list[0]:
        edges += [
            ((i, j), (i + 1, j), 0) for i in range(n_rows - 1) for j in range(n_cols)
        ]

        random.shuffle(edges)

        # Initialize connection_list with no connections.
        connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)

        # Process each edge; if it joins two different trees, merge them and
        # carve the passage (always stored at `cell1`, the up/left endpoint).
        for cell1, cell2, direction in edges:
            root1, root2 = find(cell1), find(cell2)
            if root1 != root2:
                parent[root2] = root1
                connection_list[direction, cell1[0], cell1[1]] = True

        if start_coord is None:
            start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]

        generation_meta: dict = dict(
            func_name="gen_kruskal",
            grid_shape=grid_shape_,
            start_coord=start_coord,
            algorithm="kruskal",
            fully_connected=True,
        )
        return LatticeMaze(
            connection_list=connection_list, generation_meta=generation_meta
        )

    @staticmethod
    def gen_recursive_division(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """Generate a maze using the recursive division algorithm.

        This function generates a maze by recursively dividing the grid with walls and carving a single
        passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent
        cells is connected) and then removes connections along a chosen division line, leaving one gap as a passage.
        The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.

        # Parameters:
        - `grid_shape : Coord | CoordTup`
            The shape of the maze grid (e.g., `(n_rows, n_cols)`).
        - `lattice_dim : int`
            The lattice dimension (default is `2`). Only `2` is supported.
        - `start_coord : Coord | None`
            Optionally, specify a starting coordinate. If `None`, a random coordinate is chosen.
            It is only recorded in `generation_meta`; it does not influence the maze.

        # Returns:
        - `LatticeMaze`
            A maze represented by a connection list, generated using recursive division.

        # Usage:
        ```python
        maze = gen_recursive_division((10, 10))
        ```
        """
        assert lattice_dim == 2, (  # noqa: PLR2004
            "Recursive division algorithm is only implemented for 2D lattices."
        )
        # Convert grid_shape to a tuple of ints.
        grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
        n_rows, n_cols = grid_shape_

        # Start from a fully connected grid: every downward connection except
        # out of the last row, every rightward connection except out of the
        # last column.
        connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
        connection_list[0, : n_rows - 1, :] = True
        connection_list[1, :, : n_cols - 1] = True

        def divide(x: int, y: int, width: int, height: int) -> None:
            """Recursively divide the region whose top-left cell is column `x`, row `y`.

            Removes connections along the chosen division line except for one randomly chosen gap.
            Regions narrower than 2 cells in either direction are left untouched.
            """
            if width < 2 or height < 2:  # noqa: PLR2004
                return

            if width > height:
                # Wall between columns `wall_col - 1` and `wall_col`, open at `gap_row`.
                # (`random.randint` bounds are inclusive, so the wall is always
                # strictly inside the region.)
                wall_col = random.randint(x + 1, x + width - 1)
                gap_row = random.randint(y, y + height - 1)
                connection_list[1, y:gap_row, wall_col - 1] = False
                connection_list[1, gap_row + 1 : y + height, wall_col - 1] = False
                # Recurse on the left and right subregions.
                divide(x, y, wall_col - x, height)
                divide(wall_col, y, x + width - wall_col, height)
            else:
                # Wall between rows `wall_row - 1` and `wall_row`, open at `gap_col`.
                wall_row = random.randint(y + 1, y + height - 1)
                gap_col = random.randint(x, x + width - 1)
                connection_list[0, wall_row - 1, x:gap_col] = False
                connection_list[0, wall_row - 1, gap_col + 1 : x + width] = False
                # Recurse on the top and bottom subregions.
                divide(x, y, width, wall_row - y)
                divide(x, wall_row, width, y + height - wall_row)

        # Begin the division on the full grid.
        divide(0, 0, n_cols, n_rows)

        if start_coord is None:
            start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]

        generation_meta: dict = dict(
            func_name="gen_recursive_division",
            grid_shape=grid_shape_,
            start_coord=start_coord,
            algorithm="recursive_division",
            fully_connected=True,
        )
        return LatticeMaze(
            connection_list=connection_list, generation_meta=generation_meta
        )

606 

607 

# kept as an explicit literal: populating this automatically messes with pickling
GENERATORS_MAP: dict[str, Callable[[Coord | CoordTup, Any], "LatticeMaze"]] = {
    "gen_dfs": LatticeMazeGenerators.gen_dfs,
    # TYPING: `gen_wilson` accepts `**kwargs` only to assert that they are empty,
    # so its type is `Callable[[...], KwArg(Any)]` rather than the
    # `Callable[[...], Any]` declared for this dict; mypy reports that mismatch
    # as [dict-item], hence the ignore
    "gen_wilson": LatticeMazeGenerators.gen_wilson,  # type: ignore[dict-item]
    "gen_percolation": LatticeMazeGenerators.gen_percolation,
    "gen_dfs_percolation": LatticeMazeGenerators.gen_dfs_percolation,
    "gen_prim": LatticeMazeGenerators.gen_prim,
    "gen_kruskal": LatticeMazeGenerators.gen_kruskal,
    "gen_recursive_division": LatticeMazeGenerators.gen_recursive_division,
}
"mapping of generator names to generator functions, useful for loading `MazeDatasetConfig`"

624 

_GENERATORS_PERCOLATED: list[str] = ["gen_percolation", "gen_dfs_percolation"]
"""list of generator names that generate percolated mazes
we use this to figure out the expected success rate, since depending on the endpoint kwargs this might fail
this variable is primarily used in `MazeDatasetConfig._to_ps_array` and `MazeDatasetConfig._from_ps_array`
"""

633 

634 

def get_maze_with_solution(
    gen_name: str,
    grid_shape: Coord | CoordTup,
    maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
    """helper function to get a maze already with a solution

    # Arguments
    - `gen_name: str`: key into `GENERATORS_MAP` selecting the generator
    - `grid_shape: Coord | CoordTup`: passed to the generator as its first argument
    - `maze_ctor_kwargs: dict | None`: extra keyword arguments for the generator
        (default: `None`, meaning no extra arguments)

    # Returns
    - `SolvedMaze`: the generated maze with a random path from
        `maze.generate_random_path()` attached as its solution
    """
    generator_kwargs: dict = {} if maze_ctor_kwargs is None else maze_ctor_kwargs
    generator = GENERATORS_MAP[gen_name]
    # TYPING: error: Too few arguments  [call-arg]
    # not sure why this is happening -- doesnt recognize the kwargs?
    maze: LatticeMaze = generator(grid_shape, **generator_kwargs)  # type: ignore[call-arg]
    random_path = maze.generate_random_path()
    solution: CoordArray = np.array(random_path)
    return SolvedMaze.from_lattice_maze(lattice_maze=maze, solution=solution)