docs for maze-dataset v1.3.1
View Source on GitHub

maze_dataset.generation.generators

generation functions have signature (grid_shape: Coord, **kwargs) -> LatticeMaze and are methods in LatticeMazeGenerators


  1"""generation functions have signature `(grid_shape: Coord, **kwargs) -> LatticeMaze` and are methods in `LatticeMazeGenerators`"""
  2
  3import random
  4import warnings
  5from typing import Any, Callable
  6
  7import numpy as np
  8from jaxtyping import Bool
  9
 10from maze_dataset.constants import CoordArray, CoordTup
 11from maze_dataset.generation.seed import GLOBAL_SEED
 12from maze_dataset.maze import ConnectionList, Coord, LatticeMaze, SolvedMaze
 13from maze_dataset.maze.lattice_maze import NEIGHBORS_MASK, _fill_edges_with_walls
 14
# module-level RNG setup, executed once at import time:
# - `numpy_rng` is a dedicated numpy `Generator` seeded with `GLOBAL_SEED`
# - the stdlib `random` module's global state is seeded with the same value
# NOTE(review): the generators visible in this module call `np.random.*` (numpy's legacy global state)
# rather than `numpy_rng`; neither statement here seeds that global state -- confirm it is seeded elsewhere
numpy_rng = np.random.default_rng(GLOBAL_SEED)
random.seed(GLOBAL_SEED)
 17
 18
 19def _random_start_coord(
 20	grid_shape: Coord,
 21	start_coord: Coord | CoordTup | None,
 22) -> Coord:
 23	"picking a random start coord within the bounds of `grid_shape` if none is provided"
 24	start_coord_: Coord
 25	if start_coord is None:
 26		start_coord_ = np.random.randint(
 27			0,  # lower bound
 28			np.maximum(grid_shape - 1, 1),  # upper bound (at least 1)
 29			size=len(grid_shape),  # dimensionality
 30		)
 31	else:
 32		start_coord_ = np.array(start_coord)
 33
 34	return start_coord_
 35
 36
 37def get_neighbors_in_bounds(
 38	coord: Coord,
 39	grid_shape: Coord,
 40) -> CoordArray:
 41	"get all neighbors of a coordinate that are within the bounds of the grid"
 42	# get all neighbors
 43	neighbors: CoordArray = coord + NEIGHBORS_MASK
 44
 45	# filter neighbors by being within grid bounds
 46	neighbors_in_bounds: CoordArray = neighbors[
 47		(neighbors >= 0).all(axis=1) & (neighbors < grid_shape).all(axis=1)
 48	]
 49
 50	return neighbors_in_bounds
 51
 52
 53class LatticeMazeGenerators:
 54	"""namespace for lattice maze generation algorithms"""
 55
 56	@staticmethod
 57	def gen_dfs(
 58		grid_shape: Coord | CoordTup,
 59		lattice_dim: int = 2,
 60		accessible_cells: float | None = None,
 61		max_tree_depth: float | None = None,
 62		do_forks: bool = True,
 63		randomized_stack: bool = False,
 64		start_coord: Coord | None = None,
 65	) -> LatticeMaze:
 66		"""generate a lattice maze using depth first search, iterative
 67
 68		# Arguments
 69		- `grid_shape: Coord`: the shape of the grid
 70		- `lattice_dim: int`: the dimension of the lattice
 71			(default: `2`)
 72		- `accessible_cells: int | float |None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**
 73			(default: `None`)
 74		- `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * accessible_cells`. if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**
 75			(default: `None`)
 76		- `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will be have no forks and will be a simple hallway.
 77		- `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.
 78
 79		# algorithm
 80		1. Choose the initial cell, mark it as visited and push it to the stack
 81		2. While the stack is not empty
 82			1. Pop a cell from the stack and make it a current cell
 83			2. If the current cell has any neighbours which have not been visited
 84				1. Push the current cell to the stack
 85				2. Choose one of the unvisited neighbours
 86				3. Remove the wall between the current cell and the chosen cell
 87				4. Mark the chosen cell as visited and push it to the stack
 88		"""
 89		# Default values if no constraints have been passed
 90		grid_shape_: Coord = np.array(grid_shape)
 91		n_total_cells: int = int(np.prod(grid_shape_))
 92
 93		n_accessible_cells: int
 94		if accessible_cells is None:
 95			n_accessible_cells = n_total_cells
 96		elif isinstance(accessible_cells, float):
 97			assert accessible_cells <= 1, (
 98				f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
 99			)
100
101			n_accessible_cells = int(accessible_cells * n_total_cells)
102		else:
103			assert isinstance(accessible_cells, int)
104			n_accessible_cells = accessible_cells
105
106		if max_tree_depth is None:
107			max_tree_depth = (
108				2 * n_total_cells
109			)  # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
110		elif isinstance(max_tree_depth, float):
111			assert max_tree_depth <= 1, (
112				f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
113			)
114
115			max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))
116
117		# choose a random start coord
118		start_coord = _random_start_coord(grid_shape_, start_coord)
119
120		# initialize the maze with no connections
121		connection_list: ConnectionList = np.zeros(
122			(lattice_dim, grid_shape_[0], grid_shape_[1]),
123			dtype=np.bool_,
124		)
125
126		# initialize the stack with the target coord
127		visited_cells: set[tuple[int, int]] = set()
128		visited_cells.add(tuple(start_coord))  # this wasnt a bug after all lol
129		stack: list[Coord] = [start_coord]
130
131		# initialize tree_depth_counter
132		current_tree_depth: int = 1
133
134		# loop until the stack is empty or n_connected_cells is reached
135		while stack and (len(visited_cells) < n_accessible_cells):
136			# get the current coord from the stack
137			current_coord: Coord
138			if randomized_stack:
139				current_coord = stack.pop(random.randint(0, len(stack) - 1))
140			else:
141				current_coord = stack.pop()
142
143			# filter neighbors by being within grid bounds and being unvisited
144			unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
145				(neighbor, delta)
146				for neighbor, delta in zip(
147					current_coord + NEIGHBORS_MASK,
148					NEIGHBORS_MASK,
149					strict=False,
150				)
151				if (
152					(tuple(neighbor) not in visited_cells)
153					and (0 <= neighbor[0] < grid_shape_[0])
154					and (0 <= neighbor[1] < grid_shape_[1])
155				)
156			]
157
158			# don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
159			if unvisited_neighbors_deltas and (
160				current_tree_depth <= max_tree_depth / 2
161			):
162				# if we want a maze without forks, simply don't add the current coord back to the stack
163				if do_forks and (len(unvisited_neighbors_deltas) > 1):
164					stack.append(current_coord)
165
166				# choose one of the unvisited neighbors
167				chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)
168
169				# add connection
170				dim: int = int(np.argmax(np.abs(delta)))
171				# if positive, down/right from current coord
172				# if negative, up/left from current coord (down/right from neighbor)
173				clist_node: Coord = (
174					current_coord if (delta.sum() > 0) else chosen_neighbor
175				)
176				connection_list[dim, clist_node[0], clist_node[1]] = True
177
178				# add to visited cells and stack
179				visited_cells.add(tuple(chosen_neighbor))
180				stack.append(chosen_neighbor)
181
182				# Update current tree depth
183				current_tree_depth += 1
184			else:
185				current_tree_depth -= 1
186
187		return LatticeMaze(
188			connection_list=connection_list,
189			generation_meta=dict(
190				func_name="gen_dfs",
191				grid_shape=grid_shape_,
192				start_coord=start_coord,
193				n_accessible_cells=int(n_accessible_cells),
194				max_tree_depth=int(max_tree_depth),
195				# oh my god this took so long to track down. its almost 5am and I've spent like 2 hours on this bug
196				# it was checking that len(visited_cells) == n_accessible_cells, but this means that the maze is
197				# treated as fully connected even when it is most certainly not, causing solving the maze to break
198				fully_connected=bool(len(visited_cells) == n_total_cells),
199				visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
200			),
201		)
202
203	@staticmethod
204	def gen_prim(
205		grid_shape: Coord | CoordTup,
206		lattice_dim: int = 2,
207		accessible_cells: float | None = None,
208		max_tree_depth: float | None = None,
209		do_forks: bool = True,
210		start_coord: Coord | None = None,
211	) -> LatticeMaze:
212		"(broken!) generate a lattice maze using Prim's algorithm"
213		warnings.warn(
214			"gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
215		)
216		return LatticeMazeGenerators.gen_dfs(
217			grid_shape=grid_shape,
218			lattice_dim=lattice_dim,
219			accessible_cells=accessible_cells,
220			max_tree_depth=max_tree_depth,
221			do_forks=do_forks,
222			start_coord=start_coord,
223			randomized_stack=True,
224		)
225
226	@staticmethod
227	def gen_wilson(
228		grid_shape: Coord | CoordTup,
229		**kwargs,
230	) -> LatticeMaze:
231		"""Generate a lattice maze using Wilson's algorithm.
232
233		# Algorithm
234		Wilson's algorithm generates an unbiased (random) maze
235		sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
236		acyclic and all cells are part of a unique connected space.
237		https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
238		"""
239		assert not kwargs, (
240			f"gen_wilson does not take any additional arguments, got {kwargs = }"
241		)
242
243		grid_shape_: Coord = np.array(grid_shape)
244
245		# Initialize grid and visited cells
246		connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
247		visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)
248
249		# Choose a random cell and mark it as visited
250		start_coord: Coord = _random_start_coord(grid_shape_, None)
251		visited[start_coord[0], start_coord[1]] = True
252		del start_coord
253
254		while not visited.all():
255			# Perform loop-erased random walk from another random cell
256
257			# Choose walk_start only from unvisited cells
258			unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
259			walk_start: Coord = unvisited_coords[
260				np.random.choice(unvisited_coords.shape[0])
261			]
262
263			# Perform the random walk
264			path: list[Coord] = [walk_start]
265			current: Coord = walk_start
266
267			# exit the loop once the current path hits a visited cell
268			while not visited[current[0], current[1]]:
269				# find a valid neighbor (one always exists on a lattice)
270				neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
271				next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]
272
273				# Check for loop
274				loop_exit: int | None = None
275				for i, p in enumerate(path):
276					if np.array_equal(next_cell, p):
277						loop_exit = i
278						break
279
280				# erase the loop, or continue the walk
281				if loop_exit is not None:
282					# this removes everything after and including the loop start
283					path = path[: loop_exit + 1]
284					# reset current cell to end of path
285					current = path[-1]
286				else:
287					path.append(next_cell)
288					current = next_cell
289
290			# Add the path to the maze
291			for i in range(len(path) - 1):
292				c_1: Coord = path[i]
293				c_2: Coord = path[i + 1]
294
295				# find the dimension of the connection
296				delta: Coord = c_2 - c_1
297				dim: int = int(np.argmax(np.abs(delta)))
298
299				# if positive, down/right from current coord
300				# if negative, up/left from current coord (down/right from neighbor)
301				clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
302				connection_list[dim, clist_node[0], clist_node[1]] = True
303				visited[c_1[0], c_1[1]] = True
304				# we dont add c_2 because the last c_2 will have already been visited
305
306		return LatticeMaze(
307			connection_list=connection_list,
308			generation_meta=dict(
309				func_name="gen_wilson",
310				grid_shape=grid_shape_,
311				fully_connected=True,
312			),
313		)
314
315	@staticmethod
316	def gen_percolation(
317		grid_shape: Coord | CoordTup,
318		p: float = 0.4,
319		lattice_dim: int = 2,
320		start_coord: Coord | None = None,
321	) -> LatticeMaze:
322		"""generate a lattice maze using simple percolation
323
324		note that p in the range (0.4, 0.7) gives the most interesting mazes
325
326		# Arguments
327		- `grid_shape: Coord`: the shape of the grid
328		- `lattice_dim: int`: the dimension of the lattice (default: `2`)
329		- `p: float`: the probability of a cell being accessible (default: `0.5`)
330		- `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
331		"""
332		assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
333		grid_shape_: Coord = np.array(grid_shape)
334
335		start_coord = _random_start_coord(grid_shape_, start_coord)
336
337		connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p
338
339		connection_list = _fill_edges_with_walls(connection_list)
340
341		output: LatticeMaze = LatticeMaze(
342			connection_list=connection_list,
343			generation_meta=dict(
344				func_name="gen_percolation",
345				grid_shape=grid_shape_,
346				percolation_p=p,
347				start_coord=start_coord,
348			),
349		)
350
351		# generation_meta is sometimes None, but not here since we just made it a dict above
352		output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
353			start_coord,
354		)
355
356		return output
357
358	@staticmethod
359	def gen_dfs_percolation(
360		grid_shape: Coord | CoordTup,
361		p: float = 0.4,
362		lattice_dim: int = 2,
363		accessible_cells: int | None = None,
364		max_tree_depth: int | None = None,
365		start_coord: Coord | None = None,
366	) -> LatticeMaze:
367		"""dfs and then percolation (adds cycles)"""
368		grid_shape_: Coord = np.array(grid_shape)
369		start_coord = _random_start_coord(grid_shape_, start_coord)
370
371		# generate initial maze via dfs
372		maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
373			grid_shape=grid_shape_,
374			lattice_dim=lattice_dim,
375			accessible_cells=accessible_cells,
376			max_tree_depth=max_tree_depth,
377			start_coord=start_coord,
378		)
379
380		# percolate
381		connection_list_perc: np.ndarray = (
382			np.random.rand(*maze.connection_list.shape) < p
383		)
384		connection_list_perc = _fill_edges_with_walls(connection_list_perc)
385
386		maze.__dict__["connection_list"] = np.logical_or(
387			maze.connection_list,
388			connection_list_perc,
389		)
390
391		# generation_meta is sometimes None, but not here since we just made it a dict above
392		maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
393		maze.generation_meta["percolation_p"] = p  # type: ignore[index]
394		maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
395			start_coord,
396		)
397
398		return maze
399
400	@staticmethod
401	def gen_kruskal(
402		grid_shape: "Coord | CoordTup",
403		lattice_dim: int = 2,
404		start_coord: "Coord | None" = None,
405	) -> "LatticeMaze":
406		"""Generate a maze using Kruskal's algorithm.
407
408		This function generates a random spanning tree over a grid using Kruskal's algorithm.
409		Each cell is treated as a node, and all valid adjacent edges are listed and processed
410		in random order. An edge is added (i.e. its passage carved) only if it connects two cells
411		that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree)
412		without cycles.
413
414		https://en.wikipedia.org/wiki/Kruskal's_algorithm
415
416		# Parameters:
417		- `grid_shape : Coord | CoordTup`
418			The shape of the maze grid (for example, `(n_rows, n_cols)`).
419		- `lattice_dim : int`
420			The lattice dimension (default is `2`).
421		- `start_coord : Coord | None`
422			Optionally, specify a starting coordinate. If `None`, a random coordinate will be chosen.
423		- `**kwargs`
424			Additional keyword arguments (currently unused).
425
426		# Returns:
427		- `LatticeMaze`
428			A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.
429
430		# Usage:
431		```python
432		maze = gen_kruskal((10, 10))
433		```
434		"""
435		assert lattice_dim == 2, (  # noqa: PLR2004
436			"Kruskal's algorithm is only implemented for 2D lattices."
437		)
438		# Convert grid_shape to a tuple of ints
439		grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
440		n_rows, n_cols = grid_shape_
441
442		# Initialize union-find data structure.
443		parent: dict[tuple[int, int], tuple[int, int]] = {}
444
445		def find(cell: tuple[int, int]) -> tuple[int, int]:
446			while parent[cell] != cell:
447				parent[cell] = parent[parent[cell]]
448				cell = parent[cell]
449			return cell
450
451		def union(cell1: tuple[int, int], cell2: tuple[int, int]) -> None:
452			root1 = find(cell1)
453			root2 = find(cell2)
454			parent[root2] = root1
455
456		# Initialize each cell as its own set.
457		for i in range(n_rows):
458			for j in range(n_cols):
459				parent[(i, j)] = (i, j)
460
461		# List all possible edges.
462		# For vertical edges (i.e. connecting a cell to its right neighbor):
463		edges: list[tuple[tuple[int, int], tuple[int, int], int]] = []
464		for i in range(n_rows):
465			for j in range(n_cols - 1):
466				edges.append(((i, j), (i, j + 1), 1))
467		# For horizontal edges (i.e. connecting a cell to its bottom neighbor):
468		for i in range(n_rows - 1):
469			for j in range(n_cols):
470				edges.append(((i, j), (i + 1, j), 0))
471
472		# Shuffle the list of edges.
473		import random
474
475		random.shuffle(edges)
476
477		# Initialize connection_list with no connections.
478		# connection_list[0] stores downward connections (from cell (i,j) to (i+1,j)).
479		# connection_list[1] stores rightward connections (from cell (i,j) to (i,j+1)).
480		import numpy as np
481
482		connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
483
484		# Process each edge; if it connects two different trees, union them and carve the passage.
485		for cell1, cell2, direction in edges:
486			if find(cell1) != find(cell2):
487				union(cell1, cell2)
488				if direction == 0:
489					# Horizontal edge: connection is stored in connection_list[0] at cell1.
490					connection_list[0, cell1[0], cell1[1]] = True
491				else:
492					# Vertical edge: connection is stored in connection_list[1] at cell1.
493					connection_list[1, cell1[0], cell1[1]] = True
494
495		if start_coord is None:
496			start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]
497
498		generation_meta: dict = dict(
499			func_name="gen_kruskal",
500			grid_shape=grid_shape_,
501			start_coord=start_coord,
502			algorithm="kruskal",
503			fully_connected=True,
504		)
505		return LatticeMaze(
506			connection_list=connection_list, generation_meta=generation_meta
507		)
508
509	@staticmethod
510	def gen_recursive_division(
511		grid_shape: "Coord | CoordTup",
512		lattice_dim: int = 2,
513		start_coord: "Coord | None" = None,
514	) -> "LatticeMaze":
515		"""Generate a maze using the recursive division algorithm.
516
517		This function generates a maze by recursively dividing the grid with walls and carving a single
518		passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent
519		cells is connected) and then removes connections along a chosen division line—leaving one gap as a passage.
520		The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.
521
522		# Parameters:
523		- `grid_shape : Coord | CoordTup`
524			The shape of the maze grid (e.g., `(n_rows, n_cols)`).
525		- `lattice_dim : int`
526			The lattice dimension (default is `2`).
527		- `start_coord : Coord | None`
528			Optionally, specify a starting coordinate. If `None`, a random coordinate is chosen.
529		- `**kwargs`
530			Additional keyword arguments (currently unused).
531
532		# Returns:
533		- `LatticeMaze`
534			A maze represented by a connection list, generated using recursive division.
535
536		# Usage:
537		```python
538		maze = gen_recursive_division((10, 10))
539		```
540		"""
541		assert lattice_dim == 2, (  # noqa: PLR2004
542			"Recursive division algorithm is only implemented for 2D lattices."
543		)
544		# Convert grid_shape to a tuple of ints.
545		grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
546		n_rows, n_cols = grid_shape_
547
548		# Initialize connection_list as a fully connected grid.
549		# For horizontal connections: for each cell (i,j) with i in [0, n_rows-2], set connection to True.
550		# For vertical connections: for each cell (i,j) with j in [0, n_cols-2], set connection to True.
551		connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
552		connection_list[0, : n_rows - 1, :] = True
553		connection_list[1, :, : n_cols - 1] = True
554
555		def divide(x: int, y: int, width: int, height: int) -> None:
556			"""Recursively divide the region starting at (x, y) with the given width and height.
557
558			Removes connections along the chosen division line except for one randomly chosen gap.
559			"""
560			if width < 2 or height < 2:  # noqa: PLR2004
561				return
562
563			if width > height:
564				# Vertical division.
565				wall_col = random.randint(x + 1, x + width - 1)
566				gap_row = random.randint(y, y + height - 1)
567				for row in range(y, y + height):
568					if row == gap_row:
569						continue
570					# Remove the vertical connection between (row, wall_col-1) and (row, wall_col).
571					if wall_col - 1 < n_cols - 1:
572						connection_list[1, row, wall_col - 1] = False
573				# Recurse on the left and right subregions.
574				divide(x, y, wall_col - x, height)
575				divide(wall_col, y, x + width - wall_col, height)
576			else:
577				# Horizontal division.
578				wall_row = random.randint(y + 1, y + height - 1)
579				gap_col = random.randint(x, x + width - 1)
580				for col in range(x, x + width):
581					if col == gap_col:
582						continue
583					# Remove the horizontal connection between (wall_row-1, col) and (wall_row, col).
584					if wall_row - 1 < n_rows - 1:
585						connection_list[0, wall_row - 1, col] = False
586				# Recurse on the top and bottom subregions.
587				divide(x, y, width, wall_row - y)
588				divide(x, wall_row, width, y + height - wall_row)
589
590		# Begin the division on the full grid.
591		divide(0, 0, n_cols, n_rows)
592
593		if start_coord is None:
594			start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]
595
596		generation_meta: dict = dict(
597			func_name="gen_recursive_division",
598			grid_shape=grid_shape_,
599			start_coord=start_coord,
600			algorithm="recursive_division",
601			fully_connected=True,
602		)
603		return LatticeMaze(
604			connection_list=connection_list, generation_meta=generation_meta
605		)
606
607
# can't automatically populate this because it messes with pickling :(
# every entry is written out by hand; the key is the name used to look the generator up
GENERATORS_MAP: dict[str, Callable[[Coord | CoordTup, Any], "LatticeMaze"]] = {
	"gen_dfs": LatticeMazeGenerators.gen_dfs,
	# TYPING: error: Dict entry 1 has incompatible type
	# "str": "Callable[[ndarray[Any, Any] | tuple[int, int], KwArg(Any)], LatticeMaze]";
	# expected "str": "Callable[[ndarray[Any, Any] | tuple[int, int], Any], LatticeMaze]"  [dict-item]
	# gen_wilson takes no kwargs and we check that the kwargs are empty
	# but mypy doesn't like this, `Any` != `KwArg(Any)`
	"gen_wilson": LatticeMazeGenerators.gen_wilson,  # type: ignore[dict-item]
	"gen_percolation": LatticeMazeGenerators.gen_percolation,
	"gen_dfs_percolation": LatticeMazeGenerators.gen_dfs_percolation,
	"gen_prim": LatticeMazeGenerators.gen_prim,
	"gen_kruskal": LatticeMazeGenerators.gen_kruskal,
	"gen_recursive_division": LatticeMazeGenerators.gen_recursive_division,
}
"mapping of generator names to generator functions, useful for loading `MazeDatasetConfig`"

# names must be keys of `GENERATORS_MAP`
_GENERATORS_PERCOLATED: list[str] = [
	"gen_percolation",
	"gen_dfs_percolation",
]
"""list of generator names that generate percolated mazes
we use this to figure out the expected success rate, since depending on the endpoint kwargs this might fail
this variable is primarily used in `MazeDatasetConfig._to_ps_array` and `MazeDatasetConfig._from_ps_array`
"""
633
634
def get_maze_with_solution(
	gen_name: str,
	grid_shape: Coord | CoordTup,
	maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
	"""helper function to get a maze already with a solution

	# Parameters:
	- `gen_name : str`
		key into `GENERATORS_MAP` selecting the generator function
	- `grid_shape : Coord | CoordTup`
		passed to the generator as its first positional argument
	- `maze_ctor_kwargs : dict | None`
		keyword arguments forwarded to the generator; `None` means no extra arguments

	# Returns:
	- `SolvedMaze`
		the generated maze together with a path obtained from its `generate_random_path()`

	# Raises:
	- `KeyError`
		if `gen_name` is not a key of `GENERATORS_MAP`
	"""
	generator_kwargs: dict = dict() if maze_ctor_kwargs is None else maze_ctor_kwargs
	generator = GENERATORS_MAP[gen_name]
	# TYPING: error: Too few arguments  [call-arg]
	# not sure why this is happening -- doesnt recognize the kwargs?
	generated: LatticeMaze = generator(grid_shape, **generator_kwargs)  # type: ignore[call-arg]
	random_path: CoordArray = np.array(generated.generate_random_path())
	return SolvedMaze.from_lattice_maze(lattice_maze=generated, solution=random_path)

numpy_rng = Generator(PCG64) at 0x70E3AB64D460
def get_neighbors_in_bounds( coord: jaxtyping.Int8[ndarray, 'row_col=2'], grid_shape: jaxtyping.Int8[ndarray, 'row_col=2']) -> jaxtyping.Int8[ndarray, 'coord row_col=2']:
def get_neighbors_in_bounds(
	coord: Coord,
	grid_shape: Coord,
) -> CoordArray:
	"""get all neighbors of a coordinate that are within the bounds of the grid

	# Parameters:
	- `coord : Coord`
		the coordinate whose neighbors are requested
	- `grid_shape : Coord`
		size of the grid along each axis; valid indices are `0 <= index < grid_shape`

	# Returns:
	- `CoordArray`
		the rows of `coord + NEIGHBORS_MASK` whose every component is inside the grid
	"""
	# every candidate neighbor, one per row of `NEIGHBORS_MASK`
	candidates: CoordArray = coord + NEIGHBORS_MASK

	# a candidate is kept only if all of its components are non-negative and below the grid size
	is_inside = ((candidates >= 0) & (candidates < grid_shape)).all(axis=1)
	return candidates[is_inside]

get all neighbors of a coordinate that are within the bounds of the grid

class LatticeMazeGenerators:
 54class LatticeMazeGenerators:
 55	"""namespace for lattice maze generation algorithms"""
 56
 57	@staticmethod
 58	def gen_dfs(
 59		grid_shape: Coord | CoordTup,
 60		lattice_dim: int = 2,
 61		accessible_cells: float | None = None,
 62		max_tree_depth: float | None = None,
 63		do_forks: bool = True,
 64		randomized_stack: bool = False,
 65		start_coord: Coord | None = None,
 66	) -> LatticeMaze:
 67		"""generate a lattice maze using depth first search, iterative
 68
 69		# Arguments
 70		- `grid_shape: Coord`: the shape of the grid
 71		- `lattice_dim: int`: the dimension of the lattice
 72			(default: `2`)
 73		- `accessible_cells: int | float |None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**
 74			(default: `None`)
 75		- `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * accessible_cells`. if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**
 76			(default: `None`)
 77		- `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will be have no forks and will be a simple hallway.
 78		- `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.
 79
 80		# algorithm
 81		1. Choose the initial cell, mark it as visited and push it to the stack
 82		2. While the stack is not empty
 83			1. Pop a cell from the stack and make it a current cell
 84			2. If the current cell has any neighbours which have not been visited
 85				1. Push the current cell to the stack
 86				2. Choose one of the unvisited neighbours
 87				3. Remove the wall between the current cell and the chosen cell
 88				4. Mark the chosen cell as visited and push it to the stack
 89		"""
 90		# Default values if no constraints have been passed
 91		grid_shape_: Coord = np.array(grid_shape)
 92		n_total_cells: int = int(np.prod(grid_shape_))
 93
 94		n_accessible_cells: int
 95		if accessible_cells is None:
 96			n_accessible_cells = n_total_cells
 97		elif isinstance(accessible_cells, float):
 98			assert accessible_cells <= 1, (
 99				f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
100			)
101
102			n_accessible_cells = int(accessible_cells * n_total_cells)
103		else:
104			assert isinstance(accessible_cells, int)
105			n_accessible_cells = accessible_cells
106
107		if max_tree_depth is None:
108			max_tree_depth = (
109				2 * n_total_cells
110			)  # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
111		elif isinstance(max_tree_depth, float):
112			assert max_tree_depth <= 1, (
113				f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
114			)
115
116			max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))
117
118		# choose a random start coord
119		start_coord = _random_start_coord(grid_shape_, start_coord)
120
121		# initialize the maze with no connections
122		connection_list: ConnectionList = np.zeros(
123			(lattice_dim, grid_shape_[0], grid_shape_[1]),
124			dtype=np.bool_,
125		)
126
127		# initialize the stack with the target coord
128		visited_cells: set[tuple[int, int]] = set()
129		visited_cells.add(tuple(start_coord))  # this wasnt a bug after all lol
130		stack: list[Coord] = [start_coord]
131
132		# initialize tree_depth_counter
133		current_tree_depth: int = 1
134
135		# loop until the stack is empty or n_connected_cells is reached
136		while stack and (len(visited_cells) < n_accessible_cells):
137			# get the current coord from the stack
138			current_coord: Coord
139			if randomized_stack:
140				current_coord = stack.pop(random.randint(0, len(stack) - 1))
141			else:
142				current_coord = stack.pop()
143
144			# filter neighbors by being within grid bounds and being unvisited
145			unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
146				(neighbor, delta)
147				for neighbor, delta in zip(
148					current_coord + NEIGHBORS_MASK,
149					NEIGHBORS_MASK,
150					strict=False,
151				)
152				if (
153					(tuple(neighbor) not in visited_cells)
154					and (0 <= neighbor[0] < grid_shape_[0])
155					and (0 <= neighbor[1] < grid_shape_[1])
156				)
157			]
158
159			# don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
160			if unvisited_neighbors_deltas and (
161				current_tree_depth <= max_tree_depth / 2
162			):
163				# if we want a maze without forks, simply don't add the current coord back to the stack
164				if do_forks and (len(unvisited_neighbors_deltas) > 1):
165					stack.append(current_coord)
166
167				# choose one of the unvisited neighbors
168				chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)
169
170				# add connection
171				dim: int = int(np.argmax(np.abs(delta)))
172				# if positive, down/right from current coord
173				# if negative, up/left from current coord (down/right from neighbor)
174				clist_node: Coord = (
175					current_coord if (delta.sum() > 0) else chosen_neighbor
176				)
177				connection_list[dim, clist_node[0], clist_node[1]] = True
178
179				# add to visited cells and stack
180				visited_cells.add(tuple(chosen_neighbor))
181				stack.append(chosen_neighbor)
182
183				# Update current tree depth
184				current_tree_depth += 1
185			else:
186				current_tree_depth -= 1
187
188		return LatticeMaze(
189			connection_list=connection_list,
190			generation_meta=dict(
191				func_name="gen_dfs",
192				grid_shape=grid_shape_,
193				start_coord=start_coord,
194				n_accessible_cells=int(n_accessible_cells),
195				max_tree_depth=int(max_tree_depth),
196				# oh my god this took so long to track down. its almost 5am and I've spent like 2 hours on this bug
197				# it was checking that len(visited_cells) == n_accessible_cells, but this means that the maze is
198				# treated as fully connected even when it is most certainly not, causing solving the maze to break
199				fully_connected=bool(len(visited_cells) == n_total_cells),
200				visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
201			),
202		)
203
204	@staticmethod
205	def gen_prim(
206		grid_shape: Coord | CoordTup,
207		lattice_dim: int = 2,
208		accessible_cells: float | None = None,
209		max_tree_depth: float | None = None,
210		do_forks: bool = True,
211		start_coord: Coord | None = None,
212	) -> LatticeMaze:
213		"(broken!) generate a lattice maze using Prim's algorithm"
214		warnings.warn(
215			"gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
216		)
217		return LatticeMazeGenerators.gen_dfs(
218			grid_shape=grid_shape,
219			lattice_dim=lattice_dim,
220			accessible_cells=accessible_cells,
221			max_tree_depth=max_tree_depth,
222			do_forks=do_forks,
223			start_coord=start_coord,
224			randomized_stack=True,
225		)
226
227	@staticmethod
228	def gen_wilson(
229		grid_shape: Coord | CoordTup,
230		**kwargs,
231	) -> LatticeMaze:
232		"""Generate a lattice maze using Wilson's algorithm.
233
234		# Algorithm
235		Wilson's algorithm generates an unbiased (random) maze
236		sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
237		acyclic and all cells are part of a unique connected space.
238		https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
239		"""
240		assert not kwargs, (
241			f"gen_wilson does not take any additional arguments, got {kwargs = }"
242		)
243
244		grid_shape_: Coord = np.array(grid_shape)
245
246		# Initialize grid and visited cells
247		connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
248		visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)
249
250		# Choose a random cell and mark it as visited
251		start_coord: Coord = _random_start_coord(grid_shape_, None)
252		visited[start_coord[0], start_coord[1]] = True
253		del start_coord
254
255		while not visited.all():
256			# Perform loop-erased random walk from another random cell
257
258			# Choose walk_start only from unvisited cells
259			unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
260			walk_start: Coord = unvisited_coords[
261				np.random.choice(unvisited_coords.shape[0])
262			]
263
264			# Perform the random walk
265			path: list[Coord] = [walk_start]
266			current: Coord = walk_start
267
268			# exit the loop once the current path hits a visited cell
269			while not visited[current[0], current[1]]:
270				# find a valid neighbor (one always exists on a lattice)
271				neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
272				next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]
273
274				# Check for loop
275				loop_exit: int | None = None
276				for i, p in enumerate(path):
277					if np.array_equal(next_cell, p):
278						loop_exit = i
279						break
280
281				# erase the loop, or continue the walk
282				if loop_exit is not None:
283					# this removes everything after and including the loop start
284					path = path[: loop_exit + 1]
285					# reset current cell to end of path
286					current = path[-1]
287				else:
288					path.append(next_cell)
289					current = next_cell
290
291			# Add the path to the maze
292			for i in range(len(path) - 1):
293				c_1: Coord = path[i]
294				c_2: Coord = path[i + 1]
295
296				# find the dimension of the connection
297				delta: Coord = c_2 - c_1
298				dim: int = int(np.argmax(np.abs(delta)))
299
300				# if positive, down/right from current coord
301				# if negative, up/left from current coord (down/right from neighbor)
302				clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
303				connection_list[dim, clist_node[0], clist_node[1]] = True
304				visited[c_1[0], c_1[1]] = True
305				# we dont add c_2 because the last c_2 will have already been visited
306
307		return LatticeMaze(
308			connection_list=connection_list,
309			generation_meta=dict(
310				func_name="gen_wilson",
311				grid_shape=grid_shape_,
312				fully_connected=True,
313			),
314		)
315
316	@staticmethod
317	def gen_percolation(
318		grid_shape: Coord | CoordTup,
319		p: float = 0.4,
320		lattice_dim: int = 2,
321		start_coord: Coord | None = None,
322	) -> LatticeMaze:
323		"""generate a lattice maze using simple percolation
324
325		note that p in the range (0.4, 0.7) gives the most interesting mazes
326
327		# Arguments
328		- `grid_shape: Coord`: the shape of the grid
329		- `lattice_dim: int`: the dimension of the lattice (default: `2`)
330		- `p: float`: the probability of a cell being accessible (default: `0.5`)
331		- `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
332		"""
333		assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
334		grid_shape_: Coord = np.array(grid_shape)
335
336		start_coord = _random_start_coord(grid_shape_, start_coord)
337
338		connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p
339
340		connection_list = _fill_edges_with_walls(connection_list)
341
342		output: LatticeMaze = LatticeMaze(
343			connection_list=connection_list,
344			generation_meta=dict(
345				func_name="gen_percolation",
346				grid_shape=grid_shape_,
347				percolation_p=p,
348				start_coord=start_coord,
349			),
350		)
351
352		# generation_meta is sometimes None, but not here since we just made it a dict above
353		output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
354			start_coord,
355		)
356
357		return output
358
359	@staticmethod
360	def gen_dfs_percolation(
361		grid_shape: Coord | CoordTup,
362		p: float = 0.4,
363		lattice_dim: int = 2,
364		accessible_cells: int | None = None,
365		max_tree_depth: int | None = None,
366		start_coord: Coord | None = None,
367	) -> LatticeMaze:
368		"""dfs and then percolation (adds cycles)"""
369		grid_shape_: Coord = np.array(grid_shape)
370		start_coord = _random_start_coord(grid_shape_, start_coord)
371
372		# generate initial maze via dfs
373		maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
374			grid_shape=grid_shape_,
375			lattice_dim=lattice_dim,
376			accessible_cells=accessible_cells,
377			max_tree_depth=max_tree_depth,
378			start_coord=start_coord,
379		)
380
381		# percolate
382		connection_list_perc: np.ndarray = (
383			np.random.rand(*maze.connection_list.shape) < p
384		)
385		connection_list_perc = _fill_edges_with_walls(connection_list_perc)
386
387		maze.__dict__["connection_list"] = np.logical_or(
388			maze.connection_list,
389			connection_list_perc,
390		)
391
392		# generation_meta is sometimes None, but not here since we just made it a dict above
393		maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
394		maze.generation_meta["percolation_p"] = p  # type: ignore[index]
395		maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
396			start_coord,
397		)
398
399		return maze
400
401	@staticmethod
402	def gen_kruskal(
403		grid_shape: "Coord | CoordTup",
404		lattice_dim: int = 2,
405		start_coord: "Coord | None" = None,
406	) -> "LatticeMaze":
407		"""Generate a maze using Kruskal's algorithm.
408
409		This function generates a random spanning tree over a grid using Kruskal's algorithm.
410		Each cell is treated as a node, and all valid adjacent edges are listed and processed
411		in random order. An edge is added (i.e. its passage carved) only if it connects two cells
412		that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree)
413		without cycles.
414
415		https://en.wikipedia.org/wiki/Kruskal's_algorithm
416
417		# Parameters:
418		- `grid_shape : Coord | CoordTup`
419			The shape of the maze grid (for example, `(n_rows, n_cols)`).
420		- `lattice_dim : int`
421			The lattice dimension (default is `2`).
422		- `start_coord : Coord | None`
423			Optionally, specify a starting coordinate. If `None`, a random coordinate will be chosen.
424		- `**kwargs`
425			Additional keyword arguments (currently unused).
426
427		# Returns:
428		- `LatticeMaze`
429			A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.
430
431		# Usage:
432		```python
433		maze = gen_kruskal((10, 10))
434		```
435		"""
436		assert lattice_dim == 2, (  # noqa: PLR2004
437			"Kruskal's algorithm is only implemented for 2D lattices."
438		)
439		# Convert grid_shape to a tuple of ints
440		grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
441		n_rows, n_cols = grid_shape_
442
443		# Initialize union-find data structure.
444		parent: dict[tuple[int, int], tuple[int, int]] = {}
445
446		def find(cell: tuple[int, int]) -> tuple[int, int]:
447			while parent[cell] != cell:
448				parent[cell] = parent[parent[cell]]
449				cell = parent[cell]
450			return cell
451
452		def union(cell1: tuple[int, int], cell2: tuple[int, int]) -> None:
453			root1 = find(cell1)
454			root2 = find(cell2)
455			parent[root2] = root1
456
457		# Initialize each cell as its own set.
458		for i in range(n_rows):
459			for j in range(n_cols):
460				parent[(i, j)] = (i, j)
461
462		# List all possible edges.
463		# For vertical edges (i.e. connecting a cell to its right neighbor):
464		edges: list[tuple[tuple[int, int], tuple[int, int], int]] = []
465		for i in range(n_rows):
466			for j in range(n_cols - 1):
467				edges.append(((i, j), (i, j + 1), 1))
468		# For horizontal edges (i.e. connecting a cell to its bottom neighbor):
469		for i in range(n_rows - 1):
470			for j in range(n_cols):
471				edges.append(((i, j), (i + 1, j), 0))
472
473		# Shuffle the list of edges.
474		import random
475
476		random.shuffle(edges)
477
478		# Initialize connection_list with no connections.
479		# connection_list[0] stores downward connections (from cell (i,j) to (i+1,j)).
480		# connection_list[1] stores rightward connections (from cell (i,j) to (i,j+1)).
481		import numpy as np
482
483		connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
484
485		# Process each edge; if it connects two different trees, union them and carve the passage.
486		for cell1, cell2, direction in edges:
487			if find(cell1) != find(cell2):
488				union(cell1, cell2)
489				if direction == 0:
490					# Horizontal edge: connection is stored in connection_list[0] at cell1.
491					connection_list[0, cell1[0], cell1[1]] = True
492				else:
493					# Vertical edge: connection is stored in connection_list[1] at cell1.
494					connection_list[1, cell1[0], cell1[1]] = True
495
496		if start_coord is None:
497			start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]
498
499		generation_meta: dict = dict(
500			func_name="gen_kruskal",
501			grid_shape=grid_shape_,
502			start_coord=start_coord,
503			algorithm="kruskal",
504			fully_connected=True,
505		)
506		return LatticeMaze(
507			connection_list=connection_list, generation_meta=generation_meta
508		)
509
510	@staticmethod
511	def gen_recursive_division(
512		grid_shape: "Coord | CoordTup",
513		lattice_dim: int = 2,
514		start_coord: "Coord | None" = None,
515	) -> "LatticeMaze":
516		"""Generate a maze using the recursive division algorithm.
517
518		This function generates a maze by recursively dividing the grid with walls and carving a single
519		passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent
520		cells is connected) and then removes connections along a chosen division line—leaving one gap as a passage.
521		The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.
522
523		# Parameters:
524		- `grid_shape : Coord | CoordTup`
525			The shape of the maze grid (e.g., `(n_rows, n_cols)`).
526		- `lattice_dim : int`
527			The lattice dimension (default is `2`).
528		- `start_coord : Coord | None`
529			Optionally, specify a starting coordinate. If `None`, a random coordinate is chosen.
530		- `**kwargs`
531			Additional keyword arguments (currently unused).
532
533		# Returns:
534		- `LatticeMaze`
535			A maze represented by a connection list, generated using recursive division.
536
537		# Usage:
538		```python
539		maze = gen_recursive_division((10, 10))
540		```
541		"""
542		assert lattice_dim == 2, (  # noqa: PLR2004
543			"Recursive division algorithm is only implemented for 2D lattices."
544		)
545		# Convert grid_shape to a tuple of ints.
546		grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
547		n_rows, n_cols = grid_shape_
548
549		# Initialize connection_list as a fully connected grid.
550		# For horizontal connections: for each cell (i,j) with i in [0, n_rows-2], set connection to True.
551		# For vertical connections: for each cell (i,j) with j in [0, n_cols-2], set connection to True.
552		connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
553		connection_list[0, : n_rows - 1, :] = True
554		connection_list[1, :, : n_cols - 1] = True
555
556		def divide(x: int, y: int, width: int, height: int) -> None:
557			"""Recursively divide the region starting at (x, y) with the given width and height.
558
559			Removes connections along the chosen division line except for one randomly chosen gap.
560			"""
561			if width < 2 or height < 2:  # noqa: PLR2004
562				return
563
564			if width > height:
565				# Vertical division.
566				wall_col = random.randint(x + 1, x + width - 1)
567				gap_row = random.randint(y, y + height - 1)
568				for row in range(y, y + height):
569					if row == gap_row:
570						continue
571					# Remove the vertical connection between (row, wall_col-1) and (row, wall_col).
572					if wall_col - 1 < n_cols - 1:
573						connection_list[1, row, wall_col - 1] = False
574				# Recurse on the left and right subregions.
575				divide(x, y, wall_col - x, height)
576				divide(wall_col, y, x + width - wall_col, height)
577			else:
578				# Horizontal division.
579				wall_row = random.randint(y + 1, y + height - 1)
580				gap_col = random.randint(x, x + width - 1)
581				for col in range(x, x + width):
582					if col == gap_col:
583						continue
584					# Remove the horizontal connection between (wall_row-1, col) and (wall_row, col).
585					if wall_row - 1 < n_rows - 1:
586						connection_list[0, wall_row - 1, col] = False
587				# Recurse on the top and bottom subregions.
588				divide(x, y, width, wall_row - y)
589				divide(x, wall_row, width, y + height - wall_row)
590
591		# Begin the division on the full grid.
592		divide(0, 0, n_cols, n_rows)
593
594		if start_coord is None:
595			start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]
596
597		generation_meta: dict = dict(
598			func_name="gen_recursive_division",
599			grid_shape=grid_shape_,
600			start_coord=start_coord,
601			algorithm="recursive_division",
602			fully_connected=True,
603		)
604		return LatticeMaze(
605			connection_list=connection_list, generation_meta=generation_meta
606		)

namespace for lattice maze generation algorithms

@staticmethod
def gen_dfs( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], lattice_dim: int = 2, accessible_cells: float | None = None, max_tree_depth: float | None = None, do_forks: bool = True, randomized_stack: bool = False, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
	@staticmethod
	def gen_dfs(
		grid_shape: Coord | CoordTup,
		lattice_dim: int = 2,
		accessible_cells: int | float | None = None,
		max_tree_depth: int | float | None = None,
		do_forks: bool = True,
		randomized_stack: bool = False,
		start_coord: Coord | None = None,
	) -> LatticeMaze:
		"""generate a lattice maze using depth first search, iterative

		# Arguments
		- `grid_shape: Coord`: the shape of the grid
		- `lattice_dim: int`: the dimension of the lattice
			(default: `2`)
		- `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**
			(default: `None`)
		- `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_total_cells` (twice the total number of cells in the grid). if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**
			(default: `None`)
		- `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
			(default: `True`)
		- `randomized_stack: bool`: if `True`, pop a uniformly random element of the stack instead of the most recently pushed one
			(default: `False`)
		- `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

		# Returns
		- `LatticeMaze`: `generation_meta` holds `func_name`, `grid_shape`, `start_coord`, `n_accessible_cells`, `max_tree_depth`, `fully_connected` and `visited_cells`

		# Raises
		- `AssertionError`: if a float `accessible_cells` or `max_tree_depth` is greater than 1, or if `accessible_cells` is neither `None`, a float, nor an int

		# algorithm
		1. Choose the initial cell, mark it as visited and push it to the stack
		2. While the stack is not empty
			1. Pop a cell from the stack and make it a current cell
			2. If the current cell has any neighbours which have not been visited
				1. Push the current cell to the stack
				2. Choose one of the unvisited neighbours
				3. Remove the wall between the current cell and the chosen cell
				4. Mark the chosen cell as visited and push it to the stack
		"""
		# Default values if no constraints have been passed
		grid_shape_: Coord = np.array(grid_shape)
		n_total_cells: int = int(np.prod(grid_shape_))

		n_accessible_cells: int
		if accessible_cells is None:
			n_accessible_cells = n_total_cells
		elif isinstance(accessible_cells, float):
			assert accessible_cells <= 1, (
				f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
			)

			n_accessible_cells = int(accessible_cells * n_total_cells)
		else:
			assert isinstance(accessible_cells, int)
			n_accessible_cells = accessible_cells

		if max_tree_depth is None:
			max_tree_depth = (
				2 * n_total_cells
			)  # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
		elif isinstance(max_tree_depth, float):
			assert max_tree_depth <= 1, (
				f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
			)

			max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

		# choose a random start coord
		start_coord = _random_start_coord(grid_shape_, start_coord)

		# initialize the maze with no connections
		connection_list: ConnectionList = np.zeros(
			(lattice_dim, grid_shape_[0], grid_shape_[1]),
			dtype=np.bool_,
		)

		# initialize the stack with the start coord
		visited_cells: set[tuple[int, int]] = set()
		visited_cells.add(tuple(start_coord))
		stack: list[Coord] = [start_coord]

		# initialize tree_depth_counter
		current_tree_depth: int = 1

		# loop until the stack is empty or n_accessible_cells is reached
		while stack and (len(visited_cells) < n_accessible_cells):
			# get the current coord from the stack
			current_coord: Coord
			if randomized_stack:
				current_coord = stack.pop(random.randint(0, len(stack) - 1))
			else:
				current_coord = stack.pop()

			# filter neighbors by being within grid bounds and being unvisited
			unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
				(neighbor, delta)
				for neighbor, delta in zip(
					current_coord + NEIGHBORS_MASK,
					NEIGHBORS_MASK,
					strict=False,
				)
				if (
					(tuple(neighbor) not in visited_cells)
					and (0 <= neighbor[0] < grid_shape_[0])
					and (0 <= neighbor[1] < grid_shape_[1])
				)
			]

			# don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
			if unvisited_neighbors_deltas and (
				current_tree_depth <= max_tree_depth / 2
			):
				# if we want a maze without forks, simply don't add the current coord back to the stack
				if do_forks and (len(unvisited_neighbors_deltas) > 1):
					stack.append(current_coord)

				# choose one of the unvisited neighbors
				chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)

				# add connection
				dim: int = int(np.argmax(np.abs(delta)))
				# if positive, down/right from current coord
				# if negative, up/left from current coord (down/right from neighbor)
				clist_node: Coord = (
					current_coord if (delta.sum() > 0) else chosen_neighbor
				)
				connection_list[dim, clist_node[0], clist_node[1]] = True

				# add to visited cells and stack
				visited_cells.add(tuple(chosen_neighbor))
				stack.append(chosen_neighbor)

				# Update current tree depth
				current_tree_depth += 1
			else:
				current_tree_depth -= 1

		return LatticeMaze(
			connection_list=connection_list,
			generation_meta=dict(
				func_name="gen_dfs",
				grid_shape=grid_shape_,
				start_coord=start_coord,
				n_accessible_cells=int(n_accessible_cells),
				max_tree_depth=int(max_tree_depth),
				# NOTE: compare against `n_total_cells`, not `n_accessible_cells`:
				# checking `len(visited_cells) == n_accessible_cells` would mark a maze as fully connected
				# even when it is not, which breaks solving the maze
				fully_connected=bool(len(visited_cells) == n_total_cells),
				visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
			),
		)

generate a lattice maze using depth first search, iterative

Arguments

  • grid_shape: Coord: the shape of the grid
  • lattice_dim: int: the dimension of the lattice (default: 2)
  • accessible_cells: int | float | None: the number of accessible cells in the maze. If None, defaults to the total number of cells in the grid. If a float, asserts it is <= 1 and treats it as a proportion of total cells (default: None)
  • max_tree_depth: int | float | None: the maximum depth of the tree. If None, defaults to 2 * n_total_cells (twice the total number of cells in the grid). If a float, asserts it is <= 1 and treats it as a proportion of the sum of the grid shape (default: None)
  • do_forks: bool: whether to allow forks in the maze. If False, the maze will have no forks and will be a simple hallway.
  • start_coord: Coord | None: the starting coordinate of the generation algorithm. If None, defaults to a random coordinate.

algorithm

  1. Choose the initial cell, mark it as visited and push it to the stack
  2. While the stack is not empty
    1. Pop a cell from the stack and make it a current cell
    2. If the current cell has any neighbours which have not been visited
      1. Push the current cell to the stack
      2. Choose one of the unvisited neighbours
      3. Remove the wall between the current cell and the chosen cell
      4. Mark the chosen cell as visited and push it to the stack
@staticmethod
def gen_prim( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], lattice_dim: int = 2, accessible_cells: float | None = None, max_tree_depth: float | None = None, do_forks: bool = True, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
204	@staticmethod
205	def gen_prim(
206		grid_shape: Coord | CoordTup,
207		lattice_dim: int = 2,
208		accessible_cells: float | None = None,
209		max_tree_depth: float | None = None,
210		do_forks: bool = True,
211		start_coord: Coord | None = None,
212	) -> LatticeMaze:
213		"(broken!) generate a lattice maze using Prim's algorithm"
214		warnings.warn(
215			"gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
216		)
217		return LatticeMazeGenerators.gen_dfs(
218			grid_shape=grid_shape,
219			lattice_dim=lattice_dim,
220			accessible_cells=accessible_cells,
221			max_tree_depth=max_tree_depth,
222			do_forks=do_forks,
223			start_coord=start_coord,
224			randomized_stack=True,
225		)

(broken!) generate a lattice maze using Prim's algorithm

@staticmethod
def gen_wilson( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], **kwargs) -> maze_dataset.LatticeMaze:
227	@staticmethod
228	def gen_wilson(
229		grid_shape: Coord | CoordTup,
230		**kwargs,
231	) -> LatticeMaze:
232		"""Generate a lattice maze using Wilson's algorithm.
233
234		# Algorithm
235		Wilson's algorithm generates an unbiased (random) maze
236		sampled uniformly from all spanning trees of the grid, using loop-erased random walks. The generated maze is
237		acyclic and all cells are part of a single connected component.
238		https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
239		"""
240		assert not kwargs, (
241			f"gen_wilson does not take any additional arguments, got {kwargs = }"
242		)
243
244		grid_shape_: Coord = np.array(grid_shape)
245
246		# Initialize grid and visited cells
247		connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
248		visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)
249
250		# Choose a random cell and mark it as visited
251		start_coord: Coord = _random_start_coord(grid_shape_, None)
252		visited[start_coord[0], start_coord[1]] = True
253		del start_coord
254
255		while not visited.all():
256			# Perform loop-erased random walk from another random cell
257
258			# Choose walk_start only from unvisited cells
259			unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
260			walk_start: Coord = unvisited_coords[
261				np.random.choice(unvisited_coords.shape[0])
262			]
263
264			# Perform the random walk
265			path: list[Coord] = [walk_start]
266			current: Coord = walk_start
267
268			# exit the loop once the current path hits a visited cell
269			while not visited[current[0], current[1]]:
270				# find a valid neighbor (one always exists on a lattice)
271				neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
272				next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]
273
274				# Check for loop
275				loop_exit: int | None = None
276				for i, p in enumerate(path):
277					if np.array_equal(next_cell, p):
278						loop_exit = i
279						break
280
281				# erase the loop, or continue the walk
282				if loop_exit is not None:
283					# this removes everything after and including the loop start
284					path = path[: loop_exit + 1]
285					# reset current cell to end of path
286					current = path[-1]
287				else:
288					path.append(next_cell)
289					current = next_cell
290
291			# Add the path to the maze
292			for i in range(len(path) - 1):
293				c_1: Coord = path[i]
294				c_2: Coord = path[i + 1]
295
296				# find the dimension of the connection
297				delta: Coord = c_2 - c_1
298				dim: int = int(np.argmax(np.abs(delta)))
299
300				# if positive, down/right from current coord
301				# if negative, up/left from current coord (down/right from neighbor)
302				clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
303				connection_list[dim, clist_node[0], clist_node[1]] = True
304				visited[c_1[0], c_1[1]] = True
305				# we dont add c_2 because the last c_2 will have already been visited
306
307		return LatticeMaze(
308			connection_list=connection_list,
309			generation_meta=dict(
310				func_name="gen_wilson",
311				grid_shape=grid_shape_,
312				fully_connected=True,
313			),
314		)

Generate a lattice maze using Wilson's algorithm.

Algorithm

Wilson's algorithm generates an unbiased (random) maze sampled uniformly from all spanning trees of the grid, using loop-erased random walks. The generated maze is acyclic and all cells are part of a single connected component. https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm

@staticmethod
def gen_percolation( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], p: float = 0.4, lattice_dim: int = 2, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
316	@staticmethod
317	def gen_percolation(
318		grid_shape: Coord | CoordTup,
319		p: float = 0.4,
320		lattice_dim: int = 2,
321		start_coord: Coord | None = None,
322	) -> LatticeMaze:
323		"""generate a lattice maze using simple percolation
324
325		note that p in the range (0.4, 0.7) gives the most interesting mazes
326
327		# Arguments
328		- `grid_shape: Coord`: the shape of the grid
329		- `lattice_dim: int`: the dimension of the lattice (default: `2`)
330		- `p: float`: the probability that each connection between adjacent cells is open (default: `0.4`)
331		- `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
332		"""
333		assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
334		grid_shape_: Coord = np.array(grid_shape)
335
336		start_coord = _random_start_coord(grid_shape_, start_coord)
337
338		connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p
339
340		connection_list = _fill_edges_with_walls(connection_list)
341
342		output: LatticeMaze = LatticeMaze(
343			connection_list=connection_list,
344			generation_meta=dict(
345				func_name="gen_percolation",
346				grid_shape=grid_shape_,
347				percolation_p=p,
348				start_coord=start_coord,
349			),
350		)
351
352		# generation_meta is sometimes None, but not here since we just made it a dict above
353		output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
354			start_coord,
355		)
356
357		return output

generate a lattice maze using simple percolation

note that p in the range (0.4, 0.7) gives the most interesting mazes

Arguments

  • grid_shape: Coord: the shape of the grid
  • lattice_dim: int: the dimension of the lattice (default: 2)
  • p: float: the probability that each connection between adjacent cells is open (default: 0.4)
  • start_coord: Coord | None: the starting coordinate for the connected component (default: None will give a random start)
@staticmethod
def gen_dfs_percolation( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], p: float = 0.4, lattice_dim: int = 2, accessible_cells: int | None = None, max_tree_depth: int | None = None, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
359	@staticmethod
360	def gen_dfs_percolation(
361		grid_shape: Coord | CoordTup,
362		p: float = 0.4,
363		lattice_dim: int = 2,
364		accessible_cells: int | None = None,
365		max_tree_depth: int | None = None,
366		start_coord: Coord | None = None,
367	) -> LatticeMaze:
368		"""dfs and then percolation (adds cycles)"""
369		grid_shape_: Coord = np.array(grid_shape)
370		start_coord = _random_start_coord(grid_shape_, start_coord)
371
372		# generate initial maze via dfs
373		maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
374			grid_shape=grid_shape_,
375			lattice_dim=lattice_dim,
376			accessible_cells=accessible_cells,
377			max_tree_depth=max_tree_depth,
378			start_coord=start_coord,
379		)
380
381		# percolate
382		connection_list_perc: np.ndarray = (
383			np.random.rand(*maze.connection_list.shape) < p
384		)
385		connection_list_perc = _fill_edges_with_walls(connection_list_perc)
386
387		maze.__dict__["connection_list"] = np.logical_or(
388			maze.connection_list,
389			connection_list_perc,
390		)
391
392		# generation_meta is sometimes None, but not here since we just made it a dict above
393		maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
394		maze.generation_meta["percolation_p"] = p  # type: ignore[index]
395		maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
396			start_coord,
397		)
398
399		return maze

generate a maze with dfs, then add extra connections by percolation with probability p (this can add cycles)

@staticmethod
def gen_kruskal( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], lattice_dim: int = 2, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
401	@staticmethod
402	def gen_kruskal(
403		grid_shape: "Coord | CoordTup",
404		lattice_dim: int = 2,
405		start_coord: "Coord | None" = None,
406	) -> "LatticeMaze":
407		"""Generate a maze using Kruskal's algorithm.
408
409		This function generates a random spanning tree over a grid using Kruskal's algorithm.
410		Each cell is treated as a node, and all valid adjacent edges are listed and processed
411		in random order. An edge is added (i.e. its passage carved) only if it connects two cells
412		that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree)
413		without cycles.
414
415		https://en.wikipedia.org/wiki/Kruskal's_algorithm
416
417		# Parameters:
418		- `grid_shape : Coord | CoordTup`
419			The shape of the maze grid (for example, `(n_rows, n_cols)`).
420		- `lattice_dim : int`
421			The lattice dimension (default is `2`).
422		- `start_coord : Coord | None`
423			Optionally, specify a starting coordinate. If `None`, a random coordinate will be chosen.
424		- `**kwargs`
425			Not accepted: the current signature takes no additional keyword arguments.
426
427		# Returns:
428		- `LatticeMaze`
429			A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.
430
431		# Usage:
432		```python
433		maze = gen_kruskal((10, 10))
434		```
435		"""
436		assert lattice_dim == 2, (  # noqa: PLR2004
437			"Kruskal's algorithm is only implemented for 2D lattices."
438		)
439		# Convert grid_shape to a tuple of ints
440		grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
441		n_rows, n_cols = grid_shape_
442
443		# Initialize union-find data structure.
444		parent: dict[tuple[int, int], tuple[int, int]] = {}
445
446		def find(cell: tuple[int, int]) -> tuple[int, int]:
447			while parent[cell] != cell:
448				parent[cell] = parent[parent[cell]]
449				cell = parent[cell]
450			return cell
451
452		def union(cell1: tuple[int, int], cell2: tuple[int, int]) -> None:
453			root1 = find(cell1)
454			root2 = find(cell2)
455			parent[root2] = root1
456
457		# Initialize each cell as its own set.
458		for i in range(n_rows):
459			for j in range(n_cols):
460				parent[(i, j)] = (i, j)
461
462		# List all possible edges.
463		# For vertical edges (i.e. connecting a cell to its right neighbor):
464		edges: list[tuple[tuple[int, int], tuple[int, int], int]] = []
465		for i in range(n_rows):
466			for j in range(n_cols - 1):
467				edges.append(((i, j), (i, j + 1), 1))
468		# For horizontal edges (i.e. connecting a cell to its bottom neighbor):
469		for i in range(n_rows - 1):
470			for j in range(n_cols):
471				edges.append(((i, j), (i + 1, j), 0))
472
473		# Shuffle the list of edges.
474		import random
475
476		random.shuffle(edges)
477
478		# Initialize connection_list with no connections.
479		# connection_list[0] stores downward connections (from cell (i,j) to (i+1,j)).
480		# connection_list[1] stores rightward connections (from cell (i,j) to (i,j+1)).
481		import numpy as np
482
483		connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
484
485		# Process each edge; if it connects two different trees, union them and carve the passage.
486		for cell1, cell2, direction in edges:
487			if find(cell1) != find(cell2):
488				union(cell1, cell2)
489				if direction == 0:
490					# Horizontal edge: connection is stored in connection_list[0] at cell1.
491					connection_list[0, cell1[0], cell1[1]] = True
492				else:
493					# Vertical edge: connection is stored in connection_list[1] at cell1.
494					connection_list[1, cell1[0], cell1[1]] = True
495
496		if start_coord is None:
497			start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]
498
499		generation_meta: dict = dict(
500			func_name="gen_kruskal",
501			grid_shape=grid_shape_,
502			start_coord=start_coord,
503			algorithm="kruskal",
504			fully_connected=True,
505		)
506		return LatticeMaze(
507			connection_list=connection_list, generation_meta=generation_meta
508		)

Generate a maze using Kruskal's algorithm.

This function generates a random spanning tree over a grid using Kruskal's algorithm. Each cell is treated as a node, and all valid adjacent edges are listed and processed in random order. An edge is added (i.e. its passage carved) only if it connects two cells that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree) without cycles.

https://en.wikipedia.org/wiki/Kruskal's_algorithm

Parameters:

  • grid_shape : Coord | CoordTup The shape of the maze grid (for example, (n_rows, n_cols)).
  • lattice_dim : int The lattice dimension (default is 2).
  • start_coord : Coord | None Optionally, specify a starting coordinate. If None, a random coordinate will be chosen.
  • **kwargs Not accepted: the current signature takes no additional keyword arguments.

Returns:

  • LatticeMaze A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.

Usage:

maze = gen_kruskal((10, 10))
@staticmethod
def gen_recursive_division( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], lattice_dim: int = 2, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
510	@staticmethod
511	def gen_recursive_division(
512		grid_shape: "Coord | CoordTup",
513		lattice_dim: int = 2,
514		start_coord: "Coord | None" = None,
515	) -> "LatticeMaze":
516		"""Generate a maze using the recursive division algorithm.
517
518		This function generates a maze by recursively dividing the grid with walls and carving a single
519		passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent
520		cells is connected) and then removes connections along a chosen division line—leaving one gap as a passage.
521		The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.
522
523		# Parameters:
524		- `grid_shape : Coord | CoordTup`
525			The shape of the maze grid (e.g., `(n_rows, n_cols)`).
526		- `lattice_dim : int`
527			The lattice dimension (default is `2`).
528		- `start_coord : Coord | None`
529			Optionally, specify a starting coordinate. If `None`, a random coordinate is chosen.
530		- `**kwargs`
531			Not accepted: the current signature takes no additional keyword arguments.
532
533		# Returns:
534		- `LatticeMaze`
535			A maze represented by a connection list, generated using recursive division.
536
537		# Usage:
538		```python
539		maze = gen_recursive_division((10, 10))
540		```
541		"""
542		assert lattice_dim == 2, (  # noqa: PLR2004
543			"Recursive division algorithm is only implemented for 2D lattices."
544		)
545		# Convert grid_shape to a tuple of ints.
546		grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
547		n_rows, n_cols = grid_shape_
548
549		# Initialize connection_list as a fully connected grid.
550		# For horizontal connections: for each cell (i,j) with i in [0, n_rows-2], set connection to True.
551		# For vertical connections: for each cell (i,j) with j in [0, n_cols-2], set connection to True.
552		connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
553		connection_list[0, : n_rows - 1, :] = True
554		connection_list[1, :, : n_cols - 1] = True
555
556		def divide(x: int, y: int, width: int, height: int) -> None:
557			"""Recursively divide the region starting at (x, y) with the given width and height.
558
559			Removes connections along the chosen division line except for one randomly chosen gap.
560			"""
561			if width < 2 or height < 2:  # noqa: PLR2004
562				return
563
564			if width > height:
565				# Vertical division.
566				wall_col = random.randint(x + 1, x + width - 1)
567				gap_row = random.randint(y, y + height - 1)
568				for row in range(y, y + height):
569					if row == gap_row:
570						continue
571					# Remove the vertical connection between (row, wall_col-1) and (row, wall_col).
572					if wall_col - 1 < n_cols - 1:
573						connection_list[1, row, wall_col - 1] = False
574				# Recurse on the left and right subregions.
575				divide(x, y, wall_col - x, height)
576				divide(wall_col, y, x + width - wall_col, height)
577			else:
578				# Horizontal division.
579				wall_row = random.randint(y + 1, y + height - 1)
580				gap_col = random.randint(x, x + width - 1)
581				for col in range(x, x + width):
582					if col == gap_col:
583						continue
584					# Remove the horizontal connection between (wall_row-1, col) and (wall_row, col).
585					if wall_row - 1 < n_rows - 1:
586						connection_list[0, wall_row - 1, col] = False
587				# Recurse on the top and bottom subregions.
588				divide(x, y, width, wall_row - y)
589				divide(x, wall_row, width, y + height - wall_row)
590
591		# Begin the division on the full grid.
592		divide(0, 0, n_cols, n_rows)
593
594		if start_coord is None:
595			start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]
596
597		generation_meta: dict = dict(
598			func_name="gen_recursive_division",
599			grid_shape=grid_shape_,
600			start_coord=start_coord,
601			algorithm="recursive_division",
602			fully_connected=True,
603		)
604		return LatticeMaze(
605			connection_list=connection_list, generation_meta=generation_meta
606		)

Generate a maze using the recursive division algorithm.

This function generates a maze by recursively dividing the grid with walls and carving a single passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent cells is connected) and then removes connections along a chosen division line—leaving one gap as a passage. The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.

Parameters:

  • grid_shape : Coord | CoordTup The shape of the maze grid (e.g., (n_rows, n_cols)).
  • lattice_dim : int The lattice dimension (default is 2).
  • start_coord : Coord | None Optionally, specify a starting coordinate. If None, a random coordinate is chosen.
  • **kwargs Not accepted: the current signature takes no additional keyword arguments.

Returns:

  • LatticeMaze A maze represented by a connection list, generated using recursive division.

Usage:

maze = gen_recursive_division((10, 10))
GENERATORS_MAP: dict[str, typing.Callable[[jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], typing.Any], maze_dataset.LatticeMaze]] = {'gen_dfs': <function LatticeMazeGenerators.gen_dfs>, 'gen_wilson': <function LatticeMazeGenerators.gen_wilson>, 'gen_percolation': <function LatticeMazeGenerators.gen_percolation>, 'gen_dfs_percolation': <function LatticeMazeGenerators.gen_dfs_percolation>, 'gen_prim': <function LatticeMazeGenerators.gen_prim>, 'gen_kruskal': <function LatticeMazeGenerators.gen_kruskal>, 'gen_recursive_division': <function LatticeMazeGenerators.gen_recursive_division>}

mapping of generator names to generator functions, useful for loading MazeDatasetConfig

def get_maze_with_solution( gen_name: str, grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], maze_ctor_kwargs: dict | None = None) -> maze_dataset.SolvedMaze:
636def get_maze_with_solution(
637	gen_name: str,
638	grid_shape: Coord | CoordTup,
639	maze_ctor_kwargs: dict | None = None,
640) -> SolvedMaze:
641	"helper function to get a maze already with a solution"
642	if maze_ctor_kwargs is None:
643		maze_ctor_kwargs = dict()
644	# TYPING: error: Too few arguments  [call-arg]
645	# not sure why this is happening -- doesnt recognize the kwargs?
646	maze: LatticeMaze = GENERATORS_MAP[gen_name](grid_shape, **maze_ctor_kwargs)  # type: ignore[call-arg]
647	solution: CoordArray = np.array(maze.generate_random_path())
648	return SolvedMaze.from_lattice_maze(lattice_maze=maze, solution=solution)

helper function to get a maze already with a solution