Coverage for maze_dataset\maze\lattice_maze.py: 59%
495 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
1import typing
2import warnings
3from dataclasses import dataclass
4from itertools import chain
6import numpy as np
7from jaxtyping import Bool, Int, Int8, Shaped
8from muutils.json_serialize.serializable_dataclass import (
9 SerializableDataclass,
10 serializable_dataclass,
11 serializable_field,
12)
13from muutils.misc import isinstance_by_type_name, list_split
15from maze_dataset.constants import (
16 NEIGHBORS_MASK,
17 SPECIAL_TOKENS,
18 ConnectionList,
19 Coord,
20 CoordArray,
21 CoordList,
22 CoordTup,
23)
24from maze_dataset.token_utils import (
25 TokenizerDeprecationWarning,
26 connection_list_to_adj_list,
27 get_adj_list_tokens,
28 get_origin_tokens,
29 get_path_tokens,
30 get_target_tokens,
31)
33if typing.TYPE_CHECKING:
34 from maze_dataset.tokenization import (
35 MazeTokenizer,
36 MazeTokenizerModular,
37 TokenizationMode,
38 )
# a color as three integer channels
RGB = tuple[int, int, int]
"rgb tuple of values 0-255"

# jaxtyping annotation: integer array whose axes are named "x y rgb"
PixelGrid = Int[np.ndarray, "x y rgb"]
"rgb grid of pixels"
# jaxtyping annotation: boolean array whose axes are named "x y"
BinaryPixelGrid = Bool[np.ndarray, "x y"]
"boolean grid of pixels"
class NoValidEndpointException(Exception):
    """Signals that a maze offers no usable start or end position."""
55def _fill_edges_with_walls(connection_list: ConnectionList) -> ConnectionList:
56 """fill the last elements of the connections lists as false for each dim"""
57 for dim in range(connection_list.shape[0]):
58 # last row for down
59 if dim == 0:
60 connection_list[dim, -1, :] = False
61 # last column for right
62 elif dim == 1:
63 connection_list[dim, :, -1] = False
64 else:
65 raise NotImplementedError(f"only 2d lattices supported. got {dim=}")
66 return connection_list
def color_in_pixel_grid(pixel_grid: PixelGrid, color: RGB) -> bool:
    """check whether any pixel of `pixel_grid` has exactly the color `color`

    # Parameters:
    - `pixel_grid : PixelGrid`
        image array whose last axis holds the color channels
    - `color : RGB`
        the color to look for

    # Returns:
    - `bool`
        `True` if at least one pixel matches `color` on every channel
    """
    # compare every pixel at once instead of looping in Python; a pixel matches
    # only if all of its channels match
    matches = np.all(np.asarray(pixel_grid) == color, axis=-1)
    return bool(np.any(matches))
@dataclass(frozen=True)
class PixelColors:
    "standard colors for pixel grids"

    # the field defaults double as class attributes, so the colors are read
    # directly off the class elsewhere in this file (e.g. `PixelColors.WALL`)
    WALL: RGB = (0, 0, 0)  # black
    OPEN: RGB = (255, 255, 255)  # white
    START: RGB = (0, 255, 0)  # green
    END: RGB = (255, 0, 0)  # red
    PATH: RGB = (0, 0, 255)  # blue
@dataclass(frozen=True)
class AsciiChars:
    "standard ascii characters for mazes"

    # one character per role; the roles mirror the field names of `PixelColors`
    WALL: str = "#"
    OPEN: str = " "
    START: str = "S"
    END: str = "E"
    PATH: str = "X"
# pairs each ascii maze character with the pixel color of the same role
# (WALL with WALL, OPEN with OPEN, ...)
ASCII_PIXEL_PAIRINGS: dict[str, RGB] = {
    AsciiChars.WALL: PixelColors.WALL,
    AsciiChars.OPEN: PixelColors.OPEN,
    AsciiChars.START: PixelColors.START,
    AsciiChars.END: PixelColors.END,
    AsciiChars.PATH: PixelColors.PATH,
}
"map ascii characters to pixel colors"
109@serializable_dataclass(
110 frozen=True,
111 kw_only=True,
112 properties_to_serialize=["lattice_dim", "generation_meta"],
113)
114class LatticeMaze(SerializableDataclass):
115 """lattice maze (nodes on a lattice, connections only to neighboring nodes)
117 Connection List represents which nodes (N) are connected in each direction.
119 First and second elements represent downward and rightward connections,
120 respectively.
122 Example:
123 Connection list:
124 [
125 [ # down
126 [F T],
127 [F F]
128 ],
129 [ # right
130 [T F],
131 [T F]
132 ]
133 ]
135 Nodes with connections
136 N T N F
137 F T
138 N T N F
139 F F
141 Graph:
142 N - N
143 |
144 N - N
146 Note: the bottom row connections going down, and the
147 right-hand connections going right, will always be False.
148 """
# boolean connection array: the first axis is the lattice dimension, the
# remaining axes are the grid (see `lattice_dim` and `grid_shape` below)
connection_list: ConnectionList
# optional metadata about how the maze was generated; declared with
# `compare=False` (presumably excluded from equality checks -- TODO confirm
# against `serializable_field`)
generation_meta: dict | None = serializable_field(default=None, compare=False)

# length of the first axis of `connection_list`
lattice_dim = property(lambda self: self.connection_list.shape[0])
# shape of `connection_list` without its first axis
grid_shape = property(lambda self: self.connection_list.shape[1:])
# sum over every entry of `connection_list`
n_connections = property(lambda self: self.connection_list.sum())
@property
def grid_n(self) -> int:
    """side length of the maze; only defined when the grid is square"""
    shape = self.grid_shape
    n_rows = shape[0]
    assert n_rows == shape[1], "only square mazes supported"
    return n_rows
162 # ============================================================
163 # basic methods
164 # ============================================================
def __eq__(self, other: object) -> bool:
    """compare with `other` by deferring to the parent class's `__eq__`"""
    # defined explicitly next to `__hash__` in this class; the comparison
    # rules themselves live in the parent class
    return super().__eq__(other)
169 @staticmethod
170 def heuristic(a: CoordTup, b: CoordTup) -> float:
171 """return manhattan distance between two points"""
172 return np.abs(a[0] - b[0]) + np.abs(a[1] - b[1])
def __hash__(self) -> int:
    """hash the maze by the raw bytes of its `connection_list`"""
    # only `connection_list` contributes to the hash; `generation_meta` is not read here
    return hash(self.connection_list.tobytes())
def nodes_connected(self, a: Coord, b: Coord, /) -> bool:
    """tell whether there is an open connection between nodes `a` and `b`

    Both arguments must support array subtraction (`b - a`). Nodes that are
    not exactly one lattice step apart are never connected.
    """
    step: Coord = b - a

    # guard: anything other than a single unit step cannot be a connection
    if np.abs(step).sum() != 1:
        return False

    # the axis along which the two nodes differ selects the connection layer
    axis: int = int(np.argmax(np.abs(step)))
    # a connection is stored on the node with the smaller index along that axis
    anchor: Coord = a if step.sum() > 0 else b
    return self.connection_list[axis, anchor[0], anchor[1]]
def is_valid_path(self, path: CoordArray, empty_is_valid: bool = False) -> bool:
    """decide whether `path` is a legal walk through this maze

    A path is valid when every coordinate lies inside the grid and every pair
    of consecutive coordinates is connected. An empty path is valid only when
    `empty_is_valid` is set.
    """
    # an empty path is decided purely by the flag
    if len(path) == 0:
        return True if empty_is_valid else False

    # every coordinate has to lie inside the grid
    inside_grid = np.all((0 <= path) & (path < self.grid_shape))
    if not inside_grid:
        return False

    # every consecutive pair has to be joined by an open connection
    return all(
        self.nodes_connected(here, there)
        for here, there in zip(path[:-1], path[1:])
    )
def coord_degrees(self) -> Int8[np.ndarray, "row col"]:
    """
    Count, for every coord, how many open connections touch it.

    The result has one entry per grid cell.
    """
    conn: Int8[np.ndarray, "lattice_dim=2 row col"] = self.connection_list.astype(
        np.int8
    )
    # connections stored on the node itself: towards south (layer 0) and east (layer 1)
    degrees: Int8[np.ndarray, "row col"] = conn.sum(axis=0)
    # connections stored on the northern neighbor
    degrees[1:, :] += conn[0, :-1, :]
    # connections stored on the western neighbor
    degrees[:, 1:] += conn[1, :, :-1]
    return degrees
def get_coord_neighbors(self, c: Coord | CoordTup) -> CoordArray:
    """
    Returns an array of the neighboring, connected coords of `c`.

    A candidate (`c` plus each offset in `NEIGHBORS_MASK`) is kept only if it
    lies inside the grid and `nodes_connected` reports a connection to `c`.

    # Returns:
    - `CoordArray`
        array of shape `(n, 2)`; when `c` has no connected neighbors this is
        an empty array of shape `(0, 2)` rather than a shapeless `(0,)` array,
        so callers can always rely on the second axis
    """
    c = np.array(c)  # type: ignore[assignment]
    n_rows, n_cols = self.grid_shape[0], self.grid_shape[1]
    candidates = c + NEIGHBORS_MASK

    neighbors: list[Coord] = [
        candidate
        for candidate in candidates
        if (
            (0 <= candidate[0] < n_rows)  # in x bounds
            and (0 <= candidate[1] < n_cols)  # in y bounds
            and self.nodes_connected(c, candidate)  # connected
        )
    ]

    if not neighbors:
        # `np.array([])` would give shape (0,) and a float dtype; keep the
        # (n, 2) coordinate layout and the candidates' dtype instead
        return np.empty((0, 2), dtype=candidates.dtype)

    output: CoordArray = np.array(neighbors)
    assert output.shape == (
        len(neighbors),
        2,
    ), (
        f"invalid shape: {output.shape}, expected ({len(neighbors)}, 2))\n{c = }\n{neighbors = }\n{self.as_ascii()}"
    )
    return output
def gen_connected_component_from(self, c: Coord) -> CoordArray:
    """return the connected component from a given coordinate

    Runs an iterative depth-first search starting at `c`, following
    `get_coord_neighbors`.

    # Returns:
    - `CoordArray`
        every coordinate reachable from `c` (including `c`), one per row;
        the row order follows set iteration order and carries no meaning
    """
    # stack for DFS
    stack: list[Coord] = [c]
    # coordinates that have already been expanded
    visited: set[CoordTup] = set()

    while stack:
        current_node: Coord = stack.pop()
        # this is fine since we know current_node is a coord and thus of length 2
        current_key: CoordTup = tuple(current_node)  # type: ignore[assignment]

        # a node can be pushed several times before it is first popped;
        # expanding it again would only repeat the neighbor lookup
        if current_key in visited:
            continue
        visited.add(current_key)

        for neighbor in self.get_coord_neighbors(current_node):
            if tuple(neighbor) not in visited:
                stack.append(neighbor)

    return np.array(list(visited))
def find_shortest_path(
    self,
    c_start: CoordTup | Coord,
    c_end: CoordTup | Coord,
) -> CoordArray:
    """find the shortest path between two coordinates, using A*

    # Parameters:
    - `c_start : CoordTup | Coord`
        coordinate the path starts at
    - `c_end : CoordTup | Coord`
        coordinate the path ends at

    # Returns:
    - `CoordArray`
        the coordinates along the path, from `c_start` to `c_end` inclusive

    # Raises:
    - `ValueError` : if `c_end` cannot be reached from `c_start`
    """
    c_start = tuple(c_start)  # type: ignore[assignment]
    c_end = tuple(c_end)  # type: ignore[assignment]

    # cost of the cheapest path from the start to each node currently known
    g_score: dict[CoordTup, float] = {c_start: 0.0}
    # estimated total cost of a path through a node:
    # f_score[c] := g_score[c] + heuristic(c, c_end)
    # (the start's estimate belongs in `f_score`; writing it into `g_score`
    # would shift every path cost by a constant)
    f_score: dict[CoordTup, float] = {c_start: self.heuristic(c_start, c_end)}

    closed_vtx: set[CoordTup] = set()  # nodes already evaluated
    open_vtx: set[CoordTup] = {c_start}  # nodes to be evaluated
    # node immediately preceding each node on the currently known shortest path
    source: dict[CoordTup, CoordTup] = dict()

    while open_vtx:
        # get lowest f_score node
        c_current: CoordTup = min(open_vtx, key=lambda c: f_score[c])

        # check if goal is reached
        if c_end == c_current:
            # walk the predecessor links back to the start
            path: list[CoordTup] = [c_current]
            p_current: CoordTup = c_current
            while p_current in source:
                p_current = source[p_current]
                path.append(p_current)
            # ----------------------------------------------------------------------
            # this is the only return statement
            return np.array(path[::-1])
            # ----------------------------------------------------------------------

        # close current node
        closed_vtx.add(c_current)
        open_vtx.remove(c_current)

        # update g_score of neighbors
        _np_neighbor: Coord
        for _np_neighbor in self.get_coord_neighbors(c_current):
            neighbor: CoordTup = tuple(_np_neighbor)

            if neighbor in closed_vtx:
                # already checked
                continue
            g_temp: float = g_score[c_current] + 1  # always 1 for maze neighbors

            if neighbor not in open_vtx:
                # found new vtx, so add
                open_vtx.add(neighbor)
            elif g_temp >= g_score[neighbor]:
                # if already knew about this one, but current g_score is worse, skip
                continue

            # store g_score and source
            source[neighbor] = c_current
            g_score[neighbor] = g_temp
            f_score[neighbor] = g_score[neighbor] + self.heuristic(neighbor, c_end)

    raise ValueError(
        "A solution could not be found!",
        f"{c_start = }, {c_end = }",
        self.as_ascii(),
    )
def get_nodes(self) -> CoordArray:
    """list every node of the maze as a `(row, col)` pair, in row-major order"""
    n_rows, n_cols = self.grid_shape[0], self.grid_shape[1]
    row_idx: Int[np.ndarray, "x y"]
    col_idx: Int[np.ndarray, "x y"]
    # "ij" indexing keeps the first output aligned with rows, the second with columns
    row_idx, col_idx = np.meshgrid(range(n_rows), range(n_cols), indexing="ij")
    # flatten both index grids and pair them up, one node per output row
    return np.column_stack((row_idx.ravel(), col_idx.ravel()))
def get_connected_component(self) -> CoordArray:
    """get the largest (and assumed only nonsingular) connected component of the maze

    Without metadata, or when the metadata marks the maze as fully connected,
    every node is returned. Otherwise the `visited_cells` entry of
    `generation_meta` is returned as an array.

    TODO: other connected components?

    # Raises:
    - `ValueError` : if the maze is not marked fully connected and has no `visited_cells` metadata
    """
    meta = self.generation_meta

    # fully connected (or unknown): every node belongs to the component
    if (meta is None) or meta.get("fully_connected", False):
        return self.get_nodes()

    # otherwise rely on the cells recorded during generation
    visited_cells: set[CoordTup] | None = meta.get("visited_cells", None)
    if visited_cells is None:
        # TODO: dynamically generate visited_cells?
        raise ValueError(
            f"a maze which is not marked as fully connected must have a visited_cells field in its generation_meta: {self.generation_meta}\n{self}\n{self.as_ascii()}"
        )

    cells_as_array: Int[np.ndarray, "N 2"] = np.array(list(visited_cells))
    return cells_as_array
# overload: with `except_on_no_valid_endpoint=True` a path is always returned (or an exception raised)
@typing.overload
def generate_random_path(
    self,
    allowed_start: CoordList | None = None,
    allowed_end: CoordList | None = None,
    deadend_start: bool = False,
    deadend_end: bool = False,
    endpoints_not_equal: bool = False,
    except_on_no_valid_endpoint: typing.Literal[True] = True,
) -> CoordArray: ...
# overload: with `except_on_no_valid_endpoint=False` the result may be `None`
@typing.overload
def generate_random_path(
    self,
    allowed_start: CoordList | None = None,
    allowed_end: CoordList | None = None,
    deadend_start: bool = False,
    deadend_end: bool = False,
    endpoints_not_equal: bool = False,
    except_on_no_valid_endpoint: typing.Literal[False] = False,
) -> typing.Optional[CoordArray]: ...
def generate_random_path(
    self,
    allowed_start: CoordList | None = None,
    allowed_end: CoordList | None = None,
    deadend_start: bool = False,
    deadend_end: bool = False,
    endpoints_not_equal: bool = False,
    except_on_no_valid_endpoint: bool = True,
) -> typing.Optional[CoordArray]:
    """return a path between randomly chosen start and end nodes within the connected component

    Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end.

    # Parameters:
    - `allowed_start : CoordList | None`
        a list of allowed start positions. If `None`, any position in the connected component is allowed
        (defaults to `None`)
    - `allowed_end : CoordList | None`
        a list of allowed end positions. If `None`, any position in the connected component is allowed
        (defaults to `None`)
    - `deadend_start : bool`
        whether to ***force*** the start position to be a deadend
        (defaults to `False`)
    - `deadend_end : bool`
        whether to ***force*** the end position to be a deadend
        (defaults to `False`)
    - `endpoints_not_equal : bool`
        whether to ensure that the start and end point are not the same
        (defaults to `False`)
    - `except_on_no_valid_endpoint : bool`
        whether to raise an error if no valid start or end positions are found
        if this is `False`, the function might return `None` and this must be handled by the caller
        (defaults to `True`)

    # Returns:
    - `CoordArray`
        a path between the selected start and end positions

    # Raises:
    - `NoValidEndpointException` : if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True`
    """

    # we can't create a "path" in a single-node maze
    # (as written, this rejects any maze with a dimension of size 1)
    assert self.grid_shape[0] > 1 and self.grid_shape[1] > 1, (
        f"can't create path in single-node maze: {self.as_ascii()}"
    )

    # get connected component
    connected_component: CoordArray = self.get_connected_component()

    # initialize start and end positions
    positions: Int[np.int8, "2 2"]

    # if no special conditions on start and end positions
    if (allowed_start, allowed_end, deadend_start, deadend_end) == (
        None,
        None,
        False,
        False,
    ):
        try:
            # sampling two indices without replacement: the two endpoints
            # are distinct rows of `connected_component`
            positions = connected_component[  # type: ignore[assignment]
                np.random.choice(
                    len(connected_component),
                    size=2,
                    replace=False,
                )
            ]
        except ValueError as e:
            # the sampling call failed; report it as "no valid endpoint"
            if except_on_no_valid_endpoint:
                raise NoValidEndpointException(
                    f"No valid start or end positions found because we could not sample from {connected_component = }"
                ) from e
            else:
                return None

        return self.find_shortest_path(positions[0], positions[1])  # type: ignore[index]

    # handle special conditions
    connected_component_set: set[CoordTup] = set(map(tuple, connected_component))
    # copy connected component set
    allowed_start_set: set[CoordTup] = connected_component_set.copy()
    allowed_end_set: set[CoordTup] = connected_component_set.copy()

    # filter by explicitly allowed start and end positions
    # '# type: ignore[assignment]' here because the returned tuple can be of any length
    if allowed_start is not None:
        allowed_start_set = set(map(tuple, allowed_start)) & connected_component_set  # type: ignore[assignment]

    if allowed_end is not None:
        allowed_end_set = set(map(tuple, allowed_end)) & connected_component_set  # type: ignore[assignment]

    # filter by forcing deadends
    # (a deadend is a node with exactly one connected neighbor)
    if deadend_start:
        allowed_start_set = set(
            filter(
                lambda x: len(self.get_coord_neighbors(x)) == 1, allowed_start_set
            )
        )

    if deadend_end:
        allowed_end_set = set(
            filter(lambda x: len(self.get_coord_neighbors(x)) == 1, allowed_end_set)
        )

    # check we have valid positions
    if len(allowed_start_set) == 0 or len(allowed_end_set) == 0:
        if except_on_no_valid_endpoint:
            raise NoValidEndpointException(
                f"No valid start (or end?) positions found: {allowed_start_set = }, {allowed_end_set = }"
            )
        else:
            return None

    # randomly select start and end positions
    try:
        # ignore assignment here since `tuple()` returns a tuple of any length, but we know it will be ok
        start_pos: CoordTup = tuple(  # type: ignore[assignment]
            list(allowed_start_set)[np.random.randint(0, len(allowed_start_set))]
        )
        if endpoints_not_equal:
            # remove start position from end positions
            allowed_end_set.discard(start_pos)
        # if the discard above emptied the set, the `randint` call below is
        # what fails and lands in the `except` branch
        end_pos: CoordTup = tuple(  # type: ignore[assignment]
            list(allowed_end_set)[np.random.randint(0, len(allowed_end_set))]
        )
    except ValueError as e:
        if except_on_no_valid_endpoint:
            raise NoValidEndpointException(
                f"No valid start or end positions found, maybe can't find an endpoint after we removed the start point: {allowed_start_set = }, {allowed_end_set = }"
            ) from e
        else:
            return None

    return self.find_shortest_path(start_pos, end_pos)
539 # ============================================================
540 # to and from adjacency list
541 # ============================================================
def as_adj_list(
    self, shuffle_d0: bool = True, shuffle_d1: bool = True
) -> Int8[np.ndarray, "conn start_end coord"]:
    """return the maze's connections as an adjacency-list array

    Thin wrapper: forwards `self.connection_list` and both shuffle flags,
    positionally, to `connection_list_to_adj_list`. What the flags shuffle is
    decided by that helper (presumably the first two axes of the result --
    TODO confirm against `maze_dataset.token_utils`).
    """
    return connection_list_to_adj_list(self.connection_list, shuffle_d0, shuffle_d1)
@classmethod
def from_adj_list(
    cls,
    adj_list: Int8[np.ndarray, "conn start_end coord"],
) -> "LatticeMaze":
    """build a `LatticeMaze` from an array of connections

    Each entry of `adj_list` is a `(start, end)` pair of coordinates. The
    grid side length is taken to be the largest coordinate value plus one.

    > [!NOTE]
    > This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.

    # Raises:
    - `ValueError` : if the two coordinates of a connection do not share exactly one component
    """

    # this is where it would probably break for rectangular mazes
    grid_n: int = adj_list.max() + 1

    connection_list: ConnectionList = np.zeros(
        (2, grid_n, grid_n),
        dtype=np.bool_,
    )

    for edge in adj_list:
        c_start, c_end = edge

        # exactly one coordinate component may be shared by the two ends
        if (c_start == c_end).sum() != 1:
            raise ValueError("invalid connection")

        # the component that differs gives the direction of the connection
        d: int = (c_start != c_end).argmax()

        # the connection is stored on whichever end is smaller along `d`
        owner = c_start if c_start[d] < c_end[d] else c_end
        x: int
        y: int
        x, y = owner

        connection_list[d, x, y] = True

    return LatticeMaze(
        connection_list=connection_list,
    )
def as_adj_list_tokens(self) -> list[str | CoordTup]:
    """deprecated public alias of `_as_adj_list_tokens`

    Emits a `TokenizerDeprecationWarning` and then returns exactly what
    `_as_adj_list_tokens` returns, so the token layout is defined in one place.
    """
    warnings.warn(
        "`LatticeMaze.as_adj_list_tokens` will be removed from the public API in a future release.",
        TokenizerDeprecationWarning,
    )
    return self._as_adj_list_tokens()
def _as_adj_list_tokens(self) -> list[str | CoordTup]:
    """flatten the adjacency list into a token sequence

    Layout: the adjacency-list start token, then for every connection
    `start coord, connector, end coord, endline`, then the end token.
    Coordinates stay as tuples.
    """
    tokens: list[str | CoordTup] = [SPECIAL_TOKENS.ADJLIST_START]
    for c_s, c_e in self.as_adj_list():
        tokens.extend(
            (  # type: ignore[arg-type]
                tuple(c_s),
                SPECIAL_TOKENS.CONNECTOR,
                tuple(c_e),
                SPECIAL_TOKENS.ADJACENCY_ENDLINE,
            )
        )
    tokens.append(SPECIAL_TOKENS.ADJLIST_END)
    return tokens
def _as_coords_and_special_AOTP(self) -> list[CoordTup | str]:
    """adjacency list, then origin, target and solution tokens where available -- coords stay tuples"""
    tokens: list[CoordTup | str] = self._as_adj_list_tokens()

    # origin and target exist on targeted mazes
    if isinstance(self, TargetedLatticeMaze):
        tokens += self._get_start_pos_tokens()
        tokens += self._get_end_pos_tokens()

    # the solution path exists only on solved mazes
    if isinstance(self, SolvedMaze):
        tokens += self._get_solution_tokens()

    return tokens
def _as_tokens(
    self, maze_tokenizer: "MazeTokenizer | TokenizationMode"
) -> list[str]:
    """tokenize the maze with a legacy tokenizer (or a tokenization mode)

    # Raises:
    - `NotImplementedError` : if the tokenizer is not an AOTP `MazeTokenizer`
    """
    # type ignores here fine since we check the instance
    # a bare mode is first turned into its legacy tokenizer
    if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
        maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]

    supported: bool = (
        isinstance_by_type_name(maze_tokenizer, "MazeTokenizer")
        and maze_tokenizer.is_AOTP()  # type: ignore[union-attr]
    )
    if not supported:
        raise NotImplementedError(f"Unsupported tokenizer type: {maze_tokenizer}")

    raw_tokens: list[CoordTup | str] = self._as_coords_and_special_AOTP()
    string_tokens: list[str] = maze_tokenizer.coords_to_strings(  # type: ignore[union-attr]
        coords=raw_tokens, when_noncoord="include"
    )
    return string_tokens
def as_tokens(
    self,
    maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
) -> list[str]:
    """serialize maze and solution to tokens

    A `MazeTokenizerModular` tokenizes the maze itself; anything else goes
    through the legacy path in `_as_tokens`.
    """
    if not isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular"):
        return self._as_tokens(maze_tokenizer)  # type: ignore[union-attr,arg-type]
    return maze_tokenizer.to_tokens(self)  # type: ignore[union-attr]
@classmethod
def _from_tokens_AOTP(
    cls, tokens: list[str], maze_tokenizer: "MazeTokenizer | MazeTokenizerModular"
) -> "LatticeMaze":
    """create a LatticeMaze from a list of tokens

    The adjacency list is parsed first. If origin and target tokens are all
    present the result is upgraded to a `TargetedLatticeMaze`, and if path
    tokens are present as well, to a `SolvedMaze`.
    """

    # figure out what input format
    # ========================================
    if tokens[0] == SPECIAL_TOKENS.ADJLIST_START:
        adj_list_tokens = get_adj_list_tokens(tokens)
    else:
        # If we're not getting a "complete" tokenized maze, assume it's just the adjacency list tokens
        adj_list_tokens = tokens
        warnings.warn(
            "Assuming input is just adjacency list tokens, no special tokens found"
        )

    # process edges for adjacency list
    # ========================================
    # one sub-list of tokens per connection, split at the endline token
    edges: list[list[str]] = list_split(
        adj_list_tokens,
        SPECIAL_TOKENS.ADJACENCY_ENDLINE,
    )

    coordinates: list[tuple[CoordTup, CoordTup]] = list()
    for e in edges:
        # skip last endline
        if len(e) != 0:
            # convert to coords, split start and end
            e_coords: list[CoordTup] = maze_tokenizer.strings_to_coords(
                e,
                # NOTE(review): an older comment here said this was changed to a
                # "skip", but the value passed is "include" -- and the asserts
                # below rely on the connector token being kept at index 1
                when_noncoord="include",
            )
            # expected layout: [start coord, connector token, end coord]
            assert len(e_coords) == 3, f"invalid edge: {e = } {e_coords = }"
            assert e_coords[1] == SPECIAL_TOKENS.CONNECTOR, (
                f"invalid edge: {e = } {e_coords = }"
            )
            coordinates.append((e_coords[0], e_coords[-1]))

    # every entry is a (start, end) pair, so this checks the pair length
    assert all(len(c) == 2 for c in coordinates), (
        f"invalid coordinates: {coordinates = }"
    )
    adj_list: Int8[np.ndarray, "conn start_end coord"] = np.array(coordinates)
    assert tuple(adj_list.shape) == (
        len(coordinates),
        2,
        2,
    ), f"invalid adj_list: {adj_list.shape = } {coordinates = }"

    output_maze: LatticeMaze = cls.from_adj_list(adj_list)

    # add start and end positions
    # ========================================
    is_targeted: bool = False
    # all four origin/target delimiters have to be present
    if all(
        x in tokens
        for x in (
            SPECIAL_TOKENS.ORIGIN_START,
            SPECIAL_TOKENS.ORIGIN_END,
            SPECIAL_TOKENS.TARGET_START,
            SPECIAL_TOKENS.TARGET_END,
        )
    ):
        start_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords(
            get_origin_tokens(tokens), when_noncoord="error"
        )
        end_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords(
            get_target_tokens(tokens), when_noncoord="error"
        )
        assert len(start_pos_list) == 1, (
            f"invalid start_pos_list: {start_pos_list = }"
        )
        assert len(end_pos_list) == 1, f"invalid end_pos_list: {end_pos_list = }"

        start_pos: CoordTup = start_pos_list[0]
        end_pos: CoordTup = end_pos_list[0]

        output_maze = TargetedLatticeMaze.from_lattice_maze(
            lattice_maze=output_maze,
            start_pos=start_pos,
            end_pos=end_pos,
        )

        is_targeted = True

    # add the solution, which is only allowed on a targeted maze
    if all(
        x in tokens for x in (SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END)
    ):
        assert is_targeted, "maze must be targeted to have a solution"
        solution: list[CoordTup] = maze_tokenizer.strings_to_coords(
            get_path_tokens(tokens, trim_end=True),
            when_noncoord="error",
        )
        output_maze = SolvedMaze.from_targeted_lattice_maze(
            # HACK: I think this is fine, but im not sure
            targeted_lattice_maze=output_maze,  # type: ignore[arg-type]
            solution=solution,
        )

    return output_maze
@classmethod
def from_tokens(
    cls,
    tokens: list[str],
    maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
) -> "LatticeMaze":
    """
    Builds a maze out of its tokenized form.
    Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported.

    `tokens` may also be a single whitespace-separated string, which is split first.
    """
    # HACK: type ignores here fine since we check the instance
    # a bare mode is first turned into its legacy tokenizer
    if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
        maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]

    # modular tokenizers are accepted only when they mirror a legacy one
    is_modular: bool = isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular")
    if is_modular and not maze_tokenizer.is_legacy_equivalent():  # type: ignore[union-attr]
        raise NotImplementedError(
            f"Only legacy tokenizers and their exact `MazeTokenizerModular` analogs supported, not {maze_tokenizer}."
        )

    if isinstance(tokens, str):
        tokens = tokens.split()

    if not maze_tokenizer.is_AOTP():  # type: ignore[union-attr]
        raise NotImplementedError("only AOTP tokenization is supported")

    return cls._from_tokens_AOTP(tokens, maze_tokenizer)  # type: ignore[arg-type]
799 # ============================================================
800 # to and from pixels
801 # ============================================================
802 def _as_pixels_bw(self) -> BinaryPixelGrid:
803 assert self.lattice_dim == 2, "only 2D mazes are supported"
804 # Create an empty pixel grid with walls
805 pixel_grid: Int[np.ndarray, "x y"] = np.full(
806 (self.grid_shape[0] * 2 + 1, self.grid_shape[1] * 2 + 1),
807 False,
808 dtype=np.bool_,
809 )
811 # Set white nodes
812 pixel_grid[1::2, 1::2] = True
814 # Set white connections (downward)
815 for i, row in enumerate(self.connection_list[0]):
816 for j, connected in enumerate(row):
817 if connected:
818 pixel_grid[i * 2 + 2, j * 2 + 1] = True
820 # Set white connections (rightward)
821 for i, row in enumerate(self.connection_list[1]):
822 for j, connected in enumerate(row):
823 if connected:
824 pixel_grid[i * 2 + 1, j * 2 + 2] = True
826 return pixel_grid
def as_pixels(
    self,
    show_endpoints: bool = True,
    show_solution: bool = True,
) -> PixelGrid:
    """render the maze as an RGB pixel grid

    Walls and open cells are always drawn. Start and end markers are drawn
    for targeted and solved mazes when `show_endpoints` is set; the solution
    path is drawn for solved mazes when `show_solution` is set.

    # Parameters:
    - `show_endpoints : bool`
        whether to mark the start and end positions
        (defaults to `True`)
    - `show_solution : bool`
        whether to mark the solution path; requires `show_endpoints`
        (defaults to `True`)

    # Raises:
    - `ValueError` : if `show_solution` is set while `show_endpoints` is not
    """
    # HACK: lots of `# type: ignore[attr-defined]` here since its defined for any `LatticeMaze`
    # but solution, start_pos, end_pos not always defined
    # but its fine since we explicitly check the type
    if show_solution and not show_endpoints:
        raise ValueError("show_solution=True requires show_endpoints=True")

    # convert original bool pixel grid to RGB
    pixel_grid_bw: BinaryPixelGrid = self._as_pixels_bw()
    pixel_grid: PixelGrid = np.full(
        (*pixel_grid_bw.shape, 3), PixelColors.WALL, dtype=np.uint8
    )
    pixel_grid[pixel_grid_bw] = PixelColors.OPEN

    if self.__class__ == LatticeMaze:
        return pixel_grid

    # set endpoints for TargetedLatticeMaze
    if self.__class__ == TargetedLatticeMaze:
        if show_endpoints:
            pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
                PixelColors.START
            )
            pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
                PixelColors.END
            )
        return pixel_grid

    # set solution -- we only reach this part if `self.__class__ == SolvedMaze`
    if show_solution:
        for coord in self.solution:  # type: ignore[attr-defined]
            pixel_grid[coord[0] * 2 + 1, coord[1] * 2 + 1] = PixelColors.PATH

        # set pixels between coords
        for index, coord in enumerate(self.solution[:-1]):  # type: ignore[attr-defined]
            next_coord = self.solution[index + 1]  # type: ignore[attr-defined]
            # check they are adjacent using norm
            assert np.linalg.norm(np.array(coord) - np.array(next_coord)) == 1, (
                f"Coords {coord} and {next_coord} are not adjacent"
            )
            # the pixel between two nodes is one pixel step from `coord` towards `next_coord`
            pixel_grid[
                coord[0] * 2 + 1 + next_coord[0] - coord[0],
                coord[1] * 2 + 1 + next_coord[1] - coord[1],
            ] = PixelColors.PATH

    # set endpoints (again, since path would overwrite them); honor
    # `show_endpoints` here too, matching the targeted-maze branch above
    if show_endpoints:
        pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
            PixelColors.START
        )
        pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
            PixelColors.END
        )

    return pixel_grid
887 @classmethod
888 def _from_pixel_grid_bw(
889 cls, pixel_grid: BinaryPixelGrid
890 ) -> tuple[ConnectionList, tuple[int, int]]:
891 grid_shape: tuple[int, int] = (
892 pixel_grid.shape[0] // 2,
893 pixel_grid.shape[1] // 2,
894 )
895 connection_list: ConnectionList = np.zeros((2, *grid_shape), dtype=np.bool_)
897 # Extract downward connections
898 connection_list[0] = pixel_grid[2::2, 1::2]
900 # Extract rightward connections
901 connection_list[1] = pixel_grid[1::2, 2::2]
903 return connection_list, grid_shape
@classmethod
def _from_pixel_grid_with_positions(
    cls,
    pixel_grid: PixelGrid | BinaryPixelGrid,
    marked_positions: dict[str, RGB],
) -> tuple[ConnectionList, tuple[int, int], dict[str, CoordArray]]:
    """read connections and color-marked node positions out of a pixel grid

    # Parameters:
    - `pixel_grid : PixelGrid | BinaryPixelGrid`
        the image to read
    - `marked_positions : dict[str, RGB]`
        maps a label to the color whose node positions should be collected

    # Returns:
    - `tuple[ConnectionList, tuple[int, int], dict[str, CoordArray]]`
        connection list, node-grid shape, and for every label the node
        coordinates painted in that label's color
    """
    # everything that is not wall-colored counts as open
    # (type ignore: `np.all` is typed as possibly returning a scalar bool)
    open_mask: BinaryPixelGrid = ~np.all(  # type: ignore[assignment]
        pixel_grid == PixelColors.WALL, axis=-1
    )
    connection_list: ConnectionList
    grid_shape: tuple[int, int]
    connection_list, grid_shape = cls._from_pixel_grid_bw(open_mask)

    # collect the marked positions, one label at a time
    out_positions: dict[str, CoordArray] = dict()
    for key, color in marked_positions.items():
        hits: Int[np.ndarray, "x y"] = np.argwhere(
            np.all(pixel_grid == color, axis=-1)
        )
        # nodes sit at odd/odd pixels; other hits are connection pixels and are dropped
        node_hits: list[CoordTup] = [
            (px[0] // 2, px[1] // 2)
            for px in hits
            if px[0] % 2 == 1 and px[1] % 2 == 1
        ]
        out_positions[key] = np.array(node_hits)

    return connection_list, grid_shape, out_positions
@classmethod
def from_pixels(
    cls,
    pixel_grid: PixelGrid,
) -> "LatticeMaze":
    """build a maze of type `cls` from a pixel grid

    - a 2D (binary) grid is parsed with `_from_pixel_grid_bw` and always yields a plain `LatticeMaze`
    - otherwise the type is detected with `detect_pixels_type`, and `cls` must appear in the
      MRO of the detected type (so a grid can be loaded as its own type or as a parent type)
    - for `TargetedLatticeMaze` exactly one start and one end marker are required
    - for any other `cls`, the unordered path markers are ordered by walking from the
      start to the end through connected neighbors, and `cls` is built from that solution

    # Raises:
    - `ValueError` : if the detected type cannot be cast to `cls`
    - `AssertionError` : if the start/end markers are not unique, if there are no path markers
      and start/end are not adjacent, or if the path markers do not form a single
      unambiguous chain from start to end
    """
    connection_list: ConnectionList
    grid_shape: tuple[int, int]

    # if a binary pixel grid, return regular LatticeMaze
    if len(pixel_grid.shape) == 2:
        connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid)
        return LatticeMaze(connection_list=connection_list)

    # otherwise, detect and check it's valid
    cls_detected: typing.Type[LatticeMaze] = detect_pixels_type(pixel_grid)
    if cls not in cls_detected.__mro__:
        raise ValueError(
            f"Pixel grid cannot be cast to {cls.__name__}, detected type {cls_detected.__name__}"
        )

    (
        connection_list,
        grid_shape,
        marked_pos,
    ) = cls._from_pixel_grid_with_positions(
        pixel_grid=pixel_grid,
        marked_positions=dict(
            start=PixelColors.START, end=PixelColors.END, solution=PixelColors.PATH
        ),
    )
    # if we wanted a LatticeMaze, return it
    if cls == LatticeMaze:
        return LatticeMaze(connection_list=connection_list)

    # otherwise, keep going
    temp_maze: LatticeMaze = LatticeMaze(connection_list=connection_list)

    # start and end pos
    start_pos_arr, end_pos_arr = marked_pos["start"], marked_pos["end"]
    assert start_pos_arr.shape == (
        1,
        2,
    ), (
        f"start_pos_arr {start_pos_arr} has shape {start_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
    )
    assert end_pos_arr.shape == (
        1,
        2,
    ), (
        f"end_pos_arr {end_pos_arr} has shape {end_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
    )

    start_pos: Coord = start_pos_arr[0]
    end_pos: Coord = end_pos_arr[0]

    # return a TargetedLatticeMaze if that's what we wanted
    if cls == TargetedLatticeMaze:
        return TargetedLatticeMaze(
            connection_list=connection_list,
            start_pos=start_pos,
            end_pos=end_pos,
        )

    # raw solution, only contains path elements and not start or end
    solution_raw: CoordArray = marked_pos["solution"]
    if len(solution_raw.shape) == 2:
        assert solution_raw.shape[1] == 2, (
            f"solution {solution_raw} has shape {solution_raw.shape}, expected shape (n, 2)"
        )
    elif solution_raw.shape == (0,):
        # the solution and end should be immediately adjacent
        assert np.sum(np.abs(start_pos - end_pos)) == 1, (
            f"start_pos {start_pos} and end_pos {end_pos} are not adjacent, but no solution was given"
        )

    # order the solution, by creating a list from the start to the end
    # add end pos, since we will iterate over all these starting from the start pos
    end_tup: CoordTup = tuple(end_pos)
    solution_raw_list: list[CoordTup] = [tuple(c) for c in solution_raw] + [end_tup]
    # sets mirror the two lists so that each membership test in the walk below is O(1)
    # instead of a linear scan (the lists are kept for ordering and for the error messages)
    path_coords: set[CoordTup] = set(solution_raw_list)
    # solution starts with start point
    solution: list[CoordTup] = [tuple(start_pos)]
    visited: set[CoordTup] = set(solution)
    while solution[-1] != end_tup:
        # use `get_coord_neighbors` to find connected neighbors
        neighbors: CoordArray = temp_maze.get_coord_neighbors(solution[-1])
        # TODO: make this less ugly
        assert (len(neighbors.shape) == 2) and (neighbors.shape[1] == 2), (
            f"neighbors {neighbors} has shape {neighbors.shape}, expected shape (n, 2)\n{neighbors = }\n{solution = }\n{solution_raw = }\n{temp_maze.as_ascii()}"
        )
        # neighbors = neighbors[:, [1, 0]]
        # filter out neighbors that are not in the raw solution, or were already walked
        neighbors_filtered: CoordArray = np.array(
            [
                coord
                for coord in neighbors
                if (tuple(coord) in path_coords and tuple(coord) not in visited)
            ]
        )
        # assert only one element is left, and then add it to the solution
        assert neighbors_filtered.shape == (
            1,
            2,
        ), (
            f"neighbors_filtered has shape {neighbors_filtered.shape}, expected shape (1, 2)\n{neighbors = }\n{neighbors_filtered = }\n{solution = }\n{solution_raw_list = }\n{temp_maze.as_ascii()}"
        )
        next_coord: CoordTup = tuple(neighbors_filtered[0])
        solution.append(next_coord)
        visited.add(next_coord)

    # assert the solution is complete
    assert solution[0] == tuple(start_pos), (
        f"solution {solution} does not start at start_pos {start_pos}"
    )
    assert solution[-1] == tuple(end_pos), (
        f"solution {solution} does not end at end_pos {end_pos}"
    )

    return cls(
        connection_list=np.array(connection_list),
        solution=np.array(solution),  # type: ignore[call-arg]
    )
1061 # ============================================================
1062 # to and from ASCII
1063 # ============================================================
def _as_ascii_grid(self) -> Shaped[np.ndarray, "x y"]:
    """return a character array with the same shape as the black-and-white pixel grid

    every position holds `AsciiChars.WALL`, except those where `_as_pixels_bw()` is True,
    which hold `AsciiChars.OPEN`
    """
    # boolean raster of the maze, as produced by `_as_pixels_bw()`
    bw_pixels: Bool[np.ndarray, "x y"] = self._as_pixels_bw()
    open_mask = bw_pixels == True  # noqa: E712

    # start from an all-wall canvas and carve out the open positions
    char_grid: Shaped[np.ndarray, "x y"] = np.full(
        bw_pixels.shape, AsciiChars.WALL, dtype=str
    )
    char_grid[open_mask] = AsciiChars.OPEN

    return char_grid
def as_ascii(
    self,
    show_endpoints: bool = True,
    show_solution: bool = True,
) -> str:
    """return the maze as a multi-line ASCII string

    the wall/open layout comes from `_as_ascii_grid()`; start/end markers (if `show_endpoints`)
    and path markers (if `show_solution`) are overlaid wherever the RGB rendering from
    `as_pixels()` has the color paired with that character in `ASCII_PIXEL_PAIRINGS`
    """
    char_grid: Shaped[np.ndarray, "x y"] = self._as_ascii_grid()
    rgb_grid: PixelGrid = self.as_pixels(
        show_endpoints=show_endpoints, show_solution=show_solution
    )

    # collect only the marker characters the caller asked for
    wanted_chars: list = []
    if show_endpoints:
        wanted_chars.extend((AsciiChars.START, AsciiChars.END))
    if show_solution:
        wanted_chars.append(AsciiChars.PATH)

    for marker, color in ASCII_PIXEL_PAIRINGS.items():
        if marker not in wanted_chars:
            continue
        # a pixel matches only if all of its channels equal the paired color
        char_grid[(rgb_grid == color).all(axis=-1)] = marker

    return "\n".join("".join(line) for line in char_grid)
@classmethod
def from_ascii(cls, ascii_str: str) -> "LatticeMaze":
    """parse an ASCII maze string into a maze by converting it to pixels first

    the string is stripped, split on newlines, and every line is stripped again;
    each character listed in `ASCII_PIXEL_PAIRINGS` is painted with its paired color
    (anything else stays all-zero) and the result is handed to `cls.from_pixels`
    """
    rows: list[str] = [line.strip() for line in ascii_str.strip().split("\n")]
    char_grid: Shaped[np.ndarray, "x y"] = np.array(
        [list(row) for row in rows], dtype=str
    )

    # one RGB triple per character, initialised to zeros
    rgb_grid: PixelGrid = np.zeros((*char_grid.shape, 3), dtype=np.uint8)
    for marker, color in ASCII_PIXEL_PAIRINGS.items():
        rgb_grid[char_grid == marker] = color

    return cls.from_pixels(rgb_grid)
# type ignore here even though they're all frozen
# maybe `SerializableDataclass` itself is not frozen, but that's an ABC
# error: Cannot inherit frozen dataclass from a non-frozen one [misc]
@serializable_dataclass(frozen=True, kw_only=True)
class TargetedLatticeMaze(LatticeMaze):  # type: ignore[misc]
    """A LatticeMaze with a start and end position

    `start_pos` and `end_pos` are converted to numpy arrays in `__post_init__`, and both
    must lie inside `grid_shape` (a `ValueError` is raised otherwise).
    """

    # this jank is so that SolvedMaze can inherit from this class without needing arguments for start_pos and end_pos
    # type ignore here because even though its a kw-only dataclass,
    # mypy doesn't like that non-default arguments are after default arguments
    start_pos: Coord = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )
    end_pos: Coord = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )

    def __post_init__(self) -> None:
        """convert the endpoints to numpy arrays and check they are inside the grid

        # Raises:
        - `ValueError` : if either coordinate of `start_pos` or `end_pos` is negative,
          or is greater than or equal to the corresponding entry of `grid_shape`
        """
        # make things numpy arrays (very jank to override frozen dataclass)
        self.__dict__["start_pos"] = np.array(self.start_pos)
        self.__dict__["end_pos"] = np.array(self.end_pos)
        # NOTE: after the conversion above these are arrays, so the asserts cannot fail;
        # they are kept from the original code
        assert self.start_pos is not None
        assert self.end_pos is not None
        # check that start and end are in bounds
        # negative values are rejected too: numpy indexing would silently wrap them
        # around to the opposite side of the grid
        grid_shape = self.grid_shape
        for name, pos in (("start_pos", self.start_pos), ("end_pos", self.end_pos)):
            if not (0 <= pos[0] < grid_shape[0] and 0 <= pos[1] < grid_shape[1]):
                raise ValueError(
                    f"{name} {pos} is out of bounds for grid shape {grid_shape}"
                )

    def __eq__(self, other: object) -> bool:
        """defer equality to the parent class"""
        return super().__eq__(other)

    def _get_start_pos_tokens(self) -> list[str | CoordTup]:
        """start position as a tuple, wrapped in the `ORIGIN_START`/`ORIGIN_END` special tokens"""
        return [
            SPECIAL_TOKENS.ORIGIN_START,
            tuple(self.start_pos),
            SPECIAL_TOKENS.ORIGIN_END,
        ]

    def get_start_pos_tokens(self) -> list[str | CoordTup]:
        """deprecated public wrapper around `_get_start_pos_tokens`, emits a `TokenizerDeprecationWarning`"""
        warnings.warn(
            "`TargetedLatticeMaze.get_start_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_start_pos_tokens()

    def _get_end_pos_tokens(self) -> list[str | CoordTup]:
        """end position as a tuple, wrapped in the `TARGET_START`/`TARGET_END` special tokens"""
        return [
            SPECIAL_TOKENS.TARGET_START,
            tuple(self.end_pos),
            SPECIAL_TOKENS.TARGET_END,
        ]

    def get_end_pos_tokens(self) -> list[str | CoordTup]:
        """deprecated public wrapper around `_get_end_pos_tokens`, emits a `TokenizerDeprecationWarning`"""
        warnings.warn(
            "`TargetedLatticeMaze.get_end_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_end_pos_tokens()

    @classmethod
    def from_lattice_maze(
        cls,
        lattice_maze: LatticeMaze,
        start_pos: Coord | CoordTup,
        end_pos: Coord | CoordTup,
    ) -> "TargetedLatticeMaze":
        """build a targeted maze from an existing maze, reusing its connection list and generation metadata"""
        return cls(
            connection_list=lattice_maze.connection_list,
            start_pos=np.array(start_pos),
            end_pos=np.array(end_pos),
            generation_meta=lattice_maze.generation_meta,
        )
@serializable_dataclass(frozen=True, kw_only=True)
class SolvedMaze(TargetedLatticeMaze):  # type: ignore[misc]
    """Stores a maze and a solution

    the start and end positions are taken from the first and last coordinates of the solution
    """

    solution: CoordArray = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )

    def __init__(
        self,
        connection_list: ConnectionList,
        solution: CoordArray,
        generation_meta: dict | None = None,
        start_pos: Coord | None = None,
        end_pos: Coord | None = None,
        allow_invalid: bool = False,
    ) -> None:
        """create a solved maze from a connection list and a solution path

        # Parameters:
        - `connection_list` : connections of the underlying lattice maze
        - `solution` : path coordinates, expected to be a non-empty array of shape `(n, 2)`
        - `generation_meta` : passed through to the parent initializer
        - `start_pos`, `end_pos` : optional; if given (and `allow_invalid` is False) they
          must equal the first/last coordinate of `solution`
        - `allow_invalid` : if True, skip the solution validity error and the endpoint checks

        # Raises:
        - `ValueError` : if the solution is `None`, empty, or not of shape `(n, 2)`,
          and `allow_invalid` is False
        - `AssertionError` : if a given `start_pos`/`end_pos` disagrees with the solution
        """
        # figure out the solution
        solution_valid: bool = False
        if solution is not None:
            solution = np.array(solution)
            # note that a path length of 1 here is valid, since the start and end pos could be the same
            # check `ndim` first, so that a 1-D input is reported as invalid
            # instead of raising IndexError on `shape[1]`
            solution_valid = (
                solution.ndim == 2
                and solution.shape[0] > 0
                and solution.shape[1] == 2
            )

        if not solution_valid and not allow_invalid:
            # `solution` may be None here, so don't access `.shape` directly
            solution_shape = getattr(solution, "shape", None)
            raise ValueError(
                f"invalid solution: solution.shape = {solution_shape!r} {solution = } {solution_valid = } {allow_invalid = }",
                f"{connection_list = }",
            )

        # init the TargetedLatticeMaze
        super().__init__(
            connection_list=connection_list,
            generation_meta=generation_meta,
            # TODO: the argument type is stricter than the expected type but it still fails?
            # error: Argument "start_pos" to "__init__" of "TargetedLatticeMaze" has incompatible type
            # "ndarray[tuple[int, ...], dtype[Any]] | None"; expected "ndarray[Any, Any]"  [arg-type]
            start_pos=np.array(solution[0]) if solution_valid else None,  # type: ignore[arg-type]
            end_pos=np.array(solution[-1]) if solution_valid else None,  # type: ignore[arg-type]
        )

        # frozen dataclass, so write through `__dict__`
        self.__dict__["solution"] = solution

        # adjust the endpoints
        if not allow_invalid:
            if start_pos is not None:
                assert np.array_equal(np.array(start_pos), self.start_pos), (
                    f"when trying to create a SolvedMaze, the given start_pos does not match the one in the solution: given={start_pos}, solution={self.start_pos}"
                )
            if end_pos is not None:
                assert np.array_equal(np.array(end_pos), self.end_pos), (
                    f"when trying to create a SolvedMaze, the given end_pos does not match the one in the solution: given={end_pos}, solution={self.end_pos}"
                )
            # TODO: assert the path does not backtrack, walk through walls, etc?

    def __eq__(self, other: object) -> bool:
        """defer equality to the parent class"""
        return super().__eq__(other)

    def __hash__(self) -> int:
        """hash of the raw bytes of the connection list and the solution"""
        return hash((self.connection_list.tobytes(), self.solution.tobytes()))

    def _get_solution_tokens(self) -> list[str | CoordTup]:
        """solution coordinates as tuples, wrapped in the `PATH_START`/`PATH_END` special tokens"""
        return [
            SPECIAL_TOKENS.PATH_START,
            *[tuple(c) for c in self.solution],
            SPECIAL_TOKENS.PATH_END,
        ]

    def get_solution_tokens(self) -> list[str | CoordTup]:
        """deprecated public wrapper around `_get_solution_tokens`, emits a `TokenizerDeprecationWarning`"""
        warnings.warn(
            "`LatticeMaze.get_solution_tokens` is deprecated.",
            TokenizerDeprecationWarning,
        )
        return self._get_solution_tokens()

    # for backwards compatibility
    @property
    def maze(self) -> LatticeMaze:
        """deprecated: a plain `LatticeMaze` with this maze's connection list"""
        warnings.warn(
            "`maze` is deprecated, SolvedMaze now inherits from LatticeMaze.",
            DeprecationWarning,
        )
        return LatticeMaze(connection_list=self.connection_list)

    # type ignore here since we're overriding a method with a different signature
    @classmethod
    def from_lattice_maze(  # type: ignore[override]
        cls, lattice_maze: LatticeMaze, solution: list[CoordTup]
    ) -> "SolvedMaze":
        """build a solved maze from an existing maze and a solution, reusing its connection list and generation metadata"""
        return cls(
            connection_list=lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=lattice_maze.generation_meta,
        )

    @classmethod
    def from_targeted_lattice_maze(
        cls,
        targeted_lattice_maze: TargetedLatticeMaze,
        solution: list[CoordTup] | CoordArray | None = None,
    ) -> "SolvedMaze":
        """solves the given targeted lattice maze and returns a SolvedMaze

        if `solution` is not given, it is computed with `find_shortest_path`
        between the maze's `start_pos` and `end_pos`
        """
        if solution is None:
            solution = targeted_lattice_maze.find_shortest_path(
                targeted_lattice_maze.start_pos,
                targeted_lattice_maze.end_pos,
            )
        return cls(
            connection_list=targeted_lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=targeted_lattice_maze.generation_meta,
        )

    def get_solution_forking_points(
        self,
        always_include_endpoints: bool = False,
    ) -> tuple[list[int], CoordArray]:
        """coordinates and their indices from the solution where a fork is present

        - if the start point is not a dead end, this counts as a fork
        - if the end point is not a dead end, this counts as a fork
        - if `always_include_endpoints` is True, the first and last solution
          coordinates are always included

        returns the solution indices and an array of the matching coordinates
        """
        output_idxs: list[int] = list()
        output_coords: list[CoordTup] = list()

        last_idx: int = self.solution.shape[0] - 1
        for idx, coord in enumerate(self.solution):
            # more than one choice for first coord, or more than 2 for any other
            # since the previous coord doesn't count as a choice
            is_endpoint: bool = idx == 0 or idx == last_idx
            threshold: int = 1 if is_endpoint else 2
            if self.get_coord_neighbors(coord).shape[0] > threshold or (
                is_endpoint and always_include_endpoints
            ):
                output_idxs.append(idx)
                output_coords.append(coord)

        return output_idxs, np.array(output_coords)

    def get_solution_path_following_points(self) -> tuple[list[int], CoordArray]:
        """coordinates from the solution where there is only a single (non-backtracking) point to move to

        returns the complement of `get_solution_forking_points` from the path"""
        forks_idxs, _ = self.get_solution_forking_points()
        # HACK: idk why type ignore here
        # NOTE: the indices are returned as a numpy array (result of `np.delete`), not a list
        return (  # type: ignore[return-value]
            np.delete(np.arange(self.solution.shape[0]), forks_idxs, axis=0),
            np.delete(self.solution, forks_idxs, axis=0),
        )
def detect_pixels_type(data: PixelGrid) -> typing.Type[LatticeMaze]:
    """pick the maze class matching the marker colors found in a pixel grid

    - no start and no end color: `LatticeMaze`
    - a start or end color, but no path color: `TargetedLatticeMaze`
    - a start or end color and a path color: `SolvedMaze`
    """
    # the end color is only looked for when no start color was found
    has_endpoint: bool = color_in_pixel_grid(
        data, PixelColors.START
    ) or color_in_pixel_grid(data, PixelColors.END)
    if not has_endpoint:
        return LatticeMaze

    has_path: bool = color_in_pixel_grid(data, PixelColors.PATH)
    return SolvedMaze if has_path else TargetedLatticeMaze
def _remove_isolated_cells(
    image: Int[np.ndarray, "RGB x y"],
) -> Int[np.ndarray, "RGB x y"]:
    """
    Return a copy of `image` in which every isolated cell is painted as wall.

    A cell is isolated when it is not wall-colored itself but its four direct neighbors
    (up, down, left, right) all are; positions outside the image count as walls.
    The input is not modified.
    """
    # True where all channels (last axis) equal the wall color
    is_wall = np.all(image == PixelColors.WALL, axis=-1)

    # surround the mask with a one-pixel frame of walls so border cells have four neighbors
    framed = np.pad(is_wall, ((1, 1), (1, 1)), mode="constant", constant_values=True)

    # shifted views of the framed mask, one per neighbor direction
    neighbor_is_wall = (
        framed[:-2, 1:-1],  # up
        framed[2:, 1:-1],  # down
        framed[1:-1, :-2],  # left
        framed[1:-1, 2:],  # right
    )
    enclosed = np.logical_and.reduce(neighbor_is_wall)

    # only open (non-wall) cells are affected
    to_fill = enclosed & ~is_wall

    cleaned = image.copy()
    cleaned[to_fill] = PixelColors.WALL
    return cleaned
1396_RIC_PADS: dict = {
1397 "left": ((1, 0), (0, 0)),
1398 "right": ((0, 1), (0, 0)),
1399 "up": ((0, 0), (1, 0)),
1400 "down": ((0, 0), (0, 1)),
1401}
1403# Define slices for each direction
1404_RIC_SLICES: dict = {
1405 "left": (slice(1, None), slice(None, None)),
1406 "right": (slice(None, -1), slice(None, None)),
1407 "up": (slice(None, None), slice(1, None)),
1408 "down": (slice(None, None), slice(None, -1)),
1409}
# TODO: figure out why this function doesn't work, or maybe just get rid of it
1413# def _remove_isolated_cells_old(
1414# image: Int[np.ndarray, "RGB x y"],
1415# ) -> Int[np.ndarray, "RGB x y"]:
1416# """
1417# Removes isolated cells from an image. An isolated cell is a cell that is surrounded by walls on all sides.
1418# """
1419# warnings.warn("this functin doesn't work and I have no idea why!!!")
1420# masks: dict[str, np.ndarray] = {
1421# d: np.all(
1422# np.pad(
1423# image[_RIC_SLICES[d][0], _RIC_SLICES[d][1], :] == PixelColors.WALL,
1424# np.array((*_RIC_PADS[d], (0, 0)), dtype=np.int8),
1425# mode="constant",
1426# constant_values=True,
1427# ),
1428# axis=2,
1429# )
1430# for d in _RIC_SLICES.keys()
1431# }
1433# # Create a mask for non-wall cells
1434# mask_non_wall = np.all(image != PixelColors.WALL, axis=2)
1436# # print(f"{mask_non_wall.shape = }")
1437# # print(f"{ {k: masks[k].shape for k in masks.keys()} = }")
1439# # print(f"{mask_non_wall = }")
1440# # print(f"{masks['down'] = }")
1442# # Combine the masks
1443# mask = mask_non_wall & masks["left"] & masks["right"] & masks["up"] & masks["down"]
1445# # Apply the mask
1446# output_image = np.where(
1447# np.stack([mask] * 3, axis=-1),
1448# PixelColors.WALL,
1449# image,
1450# )
1452# return output_image