Coverage for maze_dataset\generation\generators.py: 78%

129 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-23 12:49 -0700

1"""generation functions have signature `(grid_shape: Coord, **kwargs) -> LatticeMaze` and are methods in `LatticeMazeGenerators`""" 

2 

3import random 

4import warnings 

5from typing import Any, Callable 

6 

7import numpy as np 

8from jaxtyping import Bool 

9from muutils.mlutils import GLOBAL_SEED 

10 

11from maze_dataset.constants import CoordArray 

12from maze_dataset.maze import ConnectionList, Coord, LatticeMaze, SolvedMaze 

13from maze_dataset.maze.lattice_maze import NEIGHBORS_MASK, _fill_edges_with_walls 

14 

# Seed the module's sources of randomness from the shared GLOBAL_SEED.
# `numpy_rng` is a standalone numpy Generator; creating it does not seed the
# legacy `np.random.*` functions that the generators in this module call.
numpy_rng = np.random.default_rng(GLOBAL_SEED)
# stdlib `random` is what `gen_dfs` uses for `random.randint` / `random.choice`
random.seed(GLOBAL_SEED)

17 

18 

def _random_start_coord(grid_shape: Coord, start_coord: Coord | None) -> Coord:
    """return `start_coord` as an array, or pick a random coord within the bounds of `grid_shape` if none is provided

    # Arguments
    - `grid_shape: Coord`: the shape of the grid. any array-like of ints is accepted (numpy array, tuple, list)
    - `start_coord: Coord | None`: coordinate to use. If `None`, a random coordinate is sampled

    # Returns
    - `Coord`: the start coordinate as a numpy array

    # Notes
    `np.random.randint` excludes its upper bound, so along an axis of length `n > 1`
    the sampled value lies in `[0, n - 2]`; an axis of length 1 always yields 0.
    This range is unchanged from the previous behaviour so that seeded generation stays reproducible.
    """
    if start_coord is not None:
        return np.array(start_coord)

    # callers may hand over a plain tuple/list, on which `grid_shape - 1` would raise a TypeError
    shape_arr = np.asarray(grid_shape)
    return np.random.randint(
        0,  # lower bound
        np.maximum(shape_arr - 1, 1),  # exclusive upper bound (at least 1)
        size=len(shape_arr),  # dimensionality
    )

31 

32 

def get_neighbors_in_bounds(
    coord: Coord,
    grid_shape: Coord,
) -> CoordArray:
    """return the neighbors of `coord` that lie inside a grid of shape `grid_shape`

    candidates are produced by adding `NEIGHBORS_MASK` to `coord`; a candidate is kept
    only when every component is `>= 0` and strictly below the matching entry of `grid_shape`
    """
    candidates: CoordArray = coord + NEIGHBORS_MASK

    # one boolean per candidate row for each side of the bounds check
    not_negative = (candidates >= 0).all(axis=1)
    below_limit = (candidates < grid_shape).all(axis=1)

    return candidates[not_negative & below_limit]

47 

48 

class LatticeMazeGenerators:
    """namespace for lattice maze generation algorithms

    every generator is a static method taking `grid_shape` first and returning a `LatticeMaze`
    """

    @staticmethod
    def gen_dfs(
        grid_shape: Coord,
        lattice_dim: int = 2,
        accessible_cells: int | float | None = None,
        max_tree_depth: int | float | None = None,
        do_forks: bool = True,
        randomized_stack: bool = False,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using depth first search, iterative

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `lattice_dim: int`: the dimension of the lattice
          (default: `2`)
        - `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**
          (default: `None`)
        - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_total_cells` (twice the total number of cells in the grid). if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**
          (default: `None`)
        - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
          (default: `True`)
        - `randomized_stack: bool`: if `True`, the next cell is popped from a random position of the stack instead of from its top
          (default: `False`)
        - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

        # Returns
        - `LatticeMaze`: the maze, with `generation_meta` holding `func_name`, `grid_shape`, `start_coord`,
          `n_accessible_cells`, `max_tree_depth`, `fully_connected` and `visited_cells`

        # Raises
        - `AssertionError`: if a float `accessible_cells` or `max_tree_depth` exceeds 1, or if `accessible_cells` is neither `None`, a float, nor an int

        # algorithm
        1. Choose the initial cell, mark it as visited and push it to the stack
        2. While the stack is not empty
            1. Pop a cell from the stack and make it a current cell
            2. If the current cell has any neighbours which have not been visited
                1. Push the current cell to the stack
                2. Choose one of the unvisited neighbours
                3. Remove the wall between the current cell and the chosen cell
                4. Mark the chosen cell as visited and push it to the stack
        """
        # Default values if no constraints have been passed
        grid_shape = np.array(grid_shape)
        n_total_cells: int = int(np.prod(grid_shape))

        n_accessible_cells: int
        if accessible_cells is None:
            n_accessible_cells = n_total_cells
        elif isinstance(accessible_cells, float):
            assert accessible_cells <= 1, (
                f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
            )

            n_accessible_cells = int(accessible_cells * n_total_cells)
        else:
            assert isinstance(accessible_cells, int)
            n_accessible_cells = accessible_cells

        if max_tree_depth is None:
            # max tree depth is counted from the start coord in two directions:
            # it is multiplied by two here and divided by two in the depth check inside the loop
            max_tree_depth = 2 * n_total_cells
        elif isinstance(max_tree_depth, float):
            assert max_tree_depth <= 1, (
                f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
            )

            max_tree_depth = int(max_tree_depth * np.sum(grid_shape))

        # choose a random start coord (or convert the provided one to an array)
        start_coord = _random_start_coord(grid_shape, start_coord)

        # initialize the maze with no connections
        connection_list: ConnectionList = np.zeros(
            (lattice_dim, grid_shape[0], grid_shape[1]), dtype=np.bool_
        )

        # initialize the visited set and the stack with the start coord
        visited_cells: set[tuple[int, int]] = set()
        visited_cells.add(tuple(start_coord))
        stack: list[Coord] = [start_coord]

        # initialize tree_depth_counter
        current_tree_depth: int = 1

        # loop until the stack is empty or n_accessible_cells is reached
        while stack and (len(visited_cells) < n_accessible_cells):
            # get the current coord from the stack
            current_coord: Coord
            if randomized_stack:
                current_coord = stack.pop(random.randint(0, len(stack) - 1))
            else:
                current_coord = stack.pop()

            # filter neighbors by being within grid bounds and being unvisited
            unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
                (neighbor, delta)
                for neighbor, delta in zip(
                    current_coord + NEIGHBORS_MASK, NEIGHBORS_MASK
                )
                if (
                    (tuple(neighbor) not in visited_cells)
                    and (0 <= neighbor[0] < grid_shape[0])
                    and (0 <= neighbor[1] < grid_shape[1])
                )
            ]

            # don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
            if unvisited_neighbors_deltas and (
                current_tree_depth <= max_tree_depth / 2
            ):
                # if we want a maze without forks, simply don't add the current coord back to the stack
                if do_forks and (len(unvisited_neighbors_deltas) > 1):
                    stack.append(current_coord)

                # choose one of the unvisited neighbors
                chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)

                # add connection
                dim: int = np.argmax(np.abs(delta))
                # if positive, down/right from current coord
                # if negative, up/left from current coord (down/right from neighbor)
                clist_node: Coord = (
                    current_coord if (delta.sum() > 0) else chosen_neighbor
                )
                connection_list[dim, clist_node[0], clist_node[1]] = True

                # add to visited cells and stack
                visited_cells.add(tuple(chosen_neighbor))
                stack.append(chosen_neighbor)

                # Update current tree depth
                current_tree_depth += 1
            else:
                current_tree_depth -= 1

        output = LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_dfs",
                grid_shape=grid_shape,
                start_coord=start_coord,
                n_accessible_cells=int(n_accessible_cells),
                max_tree_depth=int(max_tree_depth),
                # compare against n_total_cells, NOT n_accessible_cells: otherwise a maze restricted to
                # fewer accessible cells would be reported as fully connected when it is not,
                # which breaks solving the maze
                fully_connected=bool(len(visited_cells) == n_total_cells),
                visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
            ),
        )

        return output

    @staticmethod
    def gen_prim(
        grid_shape: Coord,
        lattice_dim: int = 2,
        accessible_cells: int | float | None = None,
        max_tree_depth: int | float | None = None,
        do_forks: bool = True,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """run `gen_dfs` with `randomized_stack=True`, after emitting a warning

        this does not correctly implement prim's algorithm (see the issue linked in the warning).
        all arguments are forwarded unchanged to `gen_dfs`; because the call is delegated,
        the returned maze's `generation_meta["func_name"]` is `"gen_dfs"`.
        """
        warnings.warn(
            "gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
        )
        return LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            do_forks=do_forks,
            start_coord=start_coord,
            randomized_stack=True,
        )

    @staticmethod
    def gen_wilson(
        grid_shape: Coord,
    ) -> LatticeMaze:
        """Generate a lattice maze using Wilson's algorithm.

        # Arguments
        - `grid_shape: Coord`: the shape of the grid. converted to a numpy array on entry, so a tuple or list is accepted

        # Algorithm
        Wilson's algorithm generates an unbiased (random) maze
        sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
        acyclic and all cells are part of a unique connected space.
        https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
        """
        # normalise like the other generators do; `_random_start_coord` does arithmetic on the shape
        grid_shape = np.array(grid_shape)

        # Initialize grid and visited cells
        connection_list: ConnectionList = np.zeros((2, *grid_shape), dtype=np.bool_)
        visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape, dtype=np.bool_)

        # Choose a random cell and mark it as visited
        start_coord: Coord = _random_start_coord(grid_shape, None)
        visited[start_coord[0], start_coord[1]] = True
        del start_coord

        while not visited.all():
            # Perform loop-erased random walk from another random cell

            # Choose walk_start only from unvisited cells
            unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
            walk_start: Coord = unvisited_coords[
                np.random.choice(unvisited_coords.shape[0])
            ]

            # Perform the random walk
            path: list[Coord] = [walk_start]
            current: Coord = walk_start

            # exit the loop once the current path hits a visited cell
            while not visited[current[0], current[1]]:
                # find a valid neighbor (one always exists on a lattice)
                neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape)
                next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]

                # Check for loop
                loop_exit: int | None = None
                for i, p in enumerate(path):
                    if np.array_equal(next_cell, p):
                        loop_exit = i
                        break

                # erase the loop, or continue the walk
                if loop_exit is not None:
                    # keep the path up to and including the cell where the loop started,
                    # dropping everything after it
                    path = path[: loop_exit + 1]
                    # reset current cell to end of path
                    current = path[-1]
                else:
                    path.append(next_cell)
                    current = next_cell

            # Add the path to the maze
            for i in range(len(path) - 1):
                c_1: Coord = path[i]
                c_2: Coord = path[i + 1]

                # find the dimension of the connection
                delta: Coord = c_2 - c_1
                dim: int = np.argmax(np.abs(delta))

                # if positive, down/right from current coord
                # if negative, up/left from current coord (down/right from neighbor)
                clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
                connection_list[dim, clist_node[0], clist_node[1]] = True
                visited[c_1[0], c_1[1]] = True
                # we dont add c_2 because the last c_2 will have already been visited

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_wilson",
                grid_shape=grid_shape,
                fully_connected=True,
            ),
        )

    @staticmethod
    def gen_percolation(
        grid_shape: Coord,
        p: float = 0.4,
        lattice_dim: int = 2,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using simple percolation

        note that p in the range (0.4, 0.7) gives the most interesting mazes

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `lattice_dim: int`: the dimension of the lattice (default: `2`)
        - `p: float`: the probability of each connection being present, drawn independently per entry of the connection list (default: `0.4`)
        - `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)

        # Returns
        - `LatticeMaze`: the maze, with `generation_meta["visited_cells"]` set to the result of
          `gen_connected_component_from(start_coord)`

        # Raises
        - `AssertionError`: if `p` is outside `[0, 1]`
        """
        assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"
        grid_shape = np.array(grid_shape)

        start_coord = _random_start_coord(grid_shape, start_coord)

        # each entry is independently `True` with probability `p`
        connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape) < p

        connection_list = _fill_edges_with_walls(connection_list)

        output: LatticeMaze = LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_percolation",
                grid_shape=grid_shape,
                percolation_p=p,
                start_coord=start_coord,
            ),
        )

        # computed after construction because it needs the maze itself
        output.generation_meta["visited_cells"] = output.gen_connected_component_from(
            start_coord
        )

        return output

    @staticmethod
    def gen_dfs_percolation(
        grid_shape: Coord,
        p: float = 0.4,
        lattice_dim: int = 2,
        accessible_cells: int | None = None,
        max_tree_depth: int | None = None,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """dfs and then percolation (adds cycles)

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `p: float`: the probability of each additional percolation connection being present (default: `0.4`)
        - `lattice_dim: int`: the dimension of the lattice (default: `2`)
        - `accessible_cells: int | None`: forwarded to `gen_dfs`
        - `max_tree_depth: int | None`: forwarded to `gen_dfs`
        - `start_coord: Coord | None`: start of the dfs and of the connected component stored in `visited_cells` (default: `None` will give a random start)

        # Returns
        - `LatticeMaze`: the dfs maze with its connection list OR-ed with a random percolation mask;
          `generation_meta` has `func_name`, `percolation_p` and `visited_cells` overwritten
        """
        grid_shape = np.array(grid_shape)
        start_coord = _random_start_coord(grid_shape, start_coord)

        # generate initial maze via dfs
        maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            start_coord=start_coord,
        )

        # percolate
        connection_list_perc: np.ndarray = (
            np.random.rand(*maze.connection_list.shape) < p
        )
        connection_list_perc = _fill_edges_with_walls(connection_list_perc)

        # written through `__dict__` rather than by normal attribute assignment
        # NOTE(review): presumably because `LatticeMaze` blocks attribute assignment - confirm
        maze.__dict__["connection_list"] = np.logical_or(
            maze.connection_list, connection_list_perc
        )

        maze.generation_meta["func_name"] = "gen_dfs_percolation"
        maze.generation_meta["percolation_p"] = p
        maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(
            start_coord
        )

        return maze

385 

386 

# Written out by hand on purpose: populating this automatically messes with pickling.
GENERATORS_MAP: dict[str, Callable[[Coord, Any], "LatticeMaze"]] = {
    "gen_dfs": LatticeMazeGenerators.gen_dfs,
    "gen_wilson": LatticeMazeGenerators.gen_wilson,
    "gen_percolation": LatticeMazeGenerators.gen_percolation,
    "gen_dfs_percolation": LatticeMazeGenerators.gen_dfs_percolation,
    "gen_prim": LatticeMazeGenerators.gen_prim,
}
"mapping of generator names to generator functions, useful for loading `MazeDatasetConfig`"

396 

397 

def get_maze_with_solution(
    gen_name: str,
    grid_shape: Coord,
    maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
    """helper function to get a maze already with a solution

    looks up `gen_name` in `GENERATORS_MAP`, calls that generator with `grid_shape` and
    `maze_ctor_kwargs` (no extra keyword arguments when `None`), then wraps the maze
    together with the path from `maze.generate_random_path()` in a `SolvedMaze`
    """
    ctor_kwargs: dict = maze_ctor_kwargs if maze_ctor_kwargs is not None else dict()

    generator = GENERATORS_MAP[gen_name]
    generated: LatticeMaze = generator(grid_shape, **ctor_kwargs)

    path: CoordArray = np.array(generated.generate_random_path())
    return SolvedMaze.from_lattice_maze(lattice_maze=generated, solution=path)