Coverage for maze_dataset\maze\lattice_maze.py: 59%

495 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-23 12:49 -0700

1import typing 

2import warnings 

3from dataclasses import dataclass 

4from itertools import chain 

5 

6import numpy as np 

7from jaxtyping import Bool, Int, Int8, Shaped 

8from muutils.json_serialize.serializable_dataclass import ( 

9 SerializableDataclass, 

10 serializable_dataclass, 

11 serializable_field, 

12) 

13from muutils.misc import isinstance_by_type_name, list_split 

14 

15from maze_dataset.constants import ( 

16 NEIGHBORS_MASK, 

17 SPECIAL_TOKENS, 

18 ConnectionList, 

19 Coord, 

20 CoordArray, 

21 CoordList, 

22 CoordTup, 

23) 

24from maze_dataset.token_utils import ( 

25 TokenizerDeprecationWarning, 

26 connection_list_to_adj_list, 

27 get_adj_list_tokens, 

28 get_origin_tokens, 

29 get_path_tokens, 

30 get_target_tokens, 

31) 

32 

33if typing.TYPE_CHECKING: 

34 from maze_dataset.tokenization import ( 

35 MazeTokenizer, 

36 MazeTokenizerModular, 

37 TokenizationMode, 

38 ) 

39 

# Type aliases used by the rendering helpers in this module.

# a single color: (red, green, blue)
RGB = tuple[int, int, int]
"rgb tuple of values 0-255"

# jaxtyping annotation for an integer image; axes are named "x y rgb"
PixelGrid = Int[np.ndarray, "x y rgb"]
"rgb grid of pixels"
# jaxtyping annotation for a boolean image (True = open, False = wall); axes named "x y"
BinaryPixelGrid = Bool[np.ndarray, "x y"]
"boolean grid of pixels"

47 

48 

class NoValidEndpointException(Exception):
    """Signals that no admissible start and/or end position could be chosen in a maze."""

53 

54 

def _fill_edges_with_walls(connection_list: ConnectionList) -> ConnectionList:
    """Set the outward-facing connections of a 2D connection list to walls (`False`), in place.

    The downward connections (dim 0) of the last row and the rightward connections
    (dim 1) of the last column would lead outside the lattice, so they are cleared.

    # Parameters:
    - `connection_list : ConnectionList`
        boolean array indexed as `[dim, row, col]`; modified in place

    # Returns:
    - `ConnectionList`
        the same array object that was passed in

    # Raises:
    - `NotImplementedError` : if the connection list has more than 2 dimensions.
        The check runs before anything is modified, so the input is left untouched.
    """
    lattice_dim: int = connection_list.shape[0]
    # validate up front so a rejected input is not left partially modified
    if lattice_dim > 2:
        raise NotImplementedError(
            f"only 2d lattices supported. got {lattice_dim=}"
        )
    # last row for down
    if lattice_dim > 0:
        connection_list[0, -1, :] = False
    # last column for right
    if lattice_dim > 1:
        connection_list[1, :, -1] = False
    return connection_list

67 

68 

def color_in_pixel_grid(pixel_grid: PixelGrid, color: RGB) -> bool:
    """Return whether any pixel of `pixel_grid` exactly matches `color`.

    The comparison is vectorized along the last (channel) axis rather than looping
    over every pixel in Python.

    # Parameters:
    - `pixel_grid : PixelGrid`
        RGB image of shape `(x, y, 3)`
    - `color : RGB`
        the color to look for

    # Returns:
    - `bool`
        `True` if at least one pixel equals `color` in every channel
        (always `False` for an empty grid)
    """
    # (x, y) mask of pixels whose every channel matches `color`
    matches = np.all(np.asarray(pixel_grid) == np.asarray(color), axis=-1)
    return bool(np.any(matches))

75 

76 

@dataclass(frozen=True)
class PixelColors:
    """Default RGB colors used when a maze is rendered as a pixel grid."""

    WALL: RGB = (0, 0, 0)  # black
    OPEN: RGB = (255, 255, 255)  # white
    START: RGB = (0, 255, 0)  # green
    END: RGB = (255, 0, 0)  # red
    PATH: RGB = (0, 0, 255)  # blue

86 

87 

@dataclass(frozen=True)
class AsciiChars:
    """Default single-character glyphs used when a maze is rendered as ASCII art."""

    WALL: str = "#"  # wall
    OPEN: str = " "  # open space
    START: str = "S"  # start position
    END: str = "E"  # end position
    PATH: str = "X"  # solution path

97 

98 

# Pair each ascii glyph with the pixel color for the same maze element, matching
# the fields of `AsciiChars` and `PixelColors` by name (order: wall, open, start,
# end, path).
ASCII_PIXEL_PAIRINGS: dict[str, RGB] = {
    getattr(AsciiChars, _element): getattr(PixelColors, _element)
    for _element in ("WALL", "OPEN", "START", "END", "PATH")
}
"map ascii characters to pixel colors"

107 

108 

@serializable_dataclass(
    frozen=True,
    kw_only=True,
    properties_to_serialize=["lattice_dim", "generation_meta"],
)
class LatticeMaze(SerializableDataclass):
    """lattice maze (nodes on a lattice, connections only to neighboring nodes)

    Connection List represents which nodes (N) are connected in each direction.

    First and second elements represent downward and rightward connections,
    respectively.

    Example:
        Connection list:
            [
                [ # down
                    [F T],
                    [F F]
                ],
                [ # right
                    [T F],
                    [T F]
                ]
            ]

        Nodes with connections
            N T N F
            F   T
            N T N F
            F   F

        Graph:
            N - N
                |
            N - N

    Note: the bottom row connections going down, and the
    right-hand connections going right, will always be False.
    """

    # connection_list[d, i, j]: whether node (i, j) is connected to its neighbor in
    # direction d (0 = down, 1 = right); each edge is stored at its upper/left node
    connection_list: ConnectionList
    # metadata from maze generation (e.g. "fully_connected", "visited_cells" are read
    # by get_connected_component); excluded from equality comparison
    generation_meta: dict | None = serializable_field(default=None, compare=False)

    # number of connection directions (first axis of connection_list)
    lattice_dim = property(lambda self: self.connection_list.shape[0])
    # (n_rows, n_cols) of the node grid
    grid_shape = property(lambda self: self.connection_list.shape[1:])
    # total number of open connections (True entries in connection_list)
    n_connections = property(lambda self: self.connection_list.sum())

156 

157 @property 

158 def grid_n(self) -> int: 

159 assert self.grid_shape[0] == self.grid_shape[1], "only square mazes supported" 

160 return self.grid_shape[0] 

161 

162 # ============================================================ 

163 # basic methods 

164 # ============================================================ 

165 

166 def __eq__(self, other: object) -> bool: 

167 return super().__eq__(other) 

168 

169 @staticmethod 

170 def heuristic(a: CoordTup, b: CoordTup) -> float: 

171 """return manhattan distance between two points""" 

172 return np.abs(a[0] - b[0]) + np.abs(a[1] - b[1]) 

173 

174 def __hash__(self) -> int: 

175 return hash(self.connection_list.tobytes()) 

176 

177 def nodes_connected(self, a: Coord, b: Coord, /) -> bool: 

178 """returns whether two nodes are connected""" 

179 delta: Coord = b - a 

180 if np.abs(delta).sum() != 1: 

181 # return false if not even adjacent 

182 return False 

183 else: 

184 # test for wall 

185 dim: int = int(np.argmax(np.abs(delta))) 

186 clist_node: Coord = a if (delta.sum() > 0) else b 

187 return self.connection_list[dim, clist_node[0], clist_node[1]] 

188 

189 def is_valid_path(self, path: CoordArray, empty_is_valid: bool = False) -> bool: 

190 """check if a path is valid""" 

191 # check path is not empty 

192 if len(path) == 0: 

193 if not empty_is_valid: 

194 return False 

195 else: 

196 return True 

197 

198 # check all coords in bounds of maze 

199 if not np.all((0 <= path) & (path < self.grid_shape)): 

200 return False 

201 

202 # check all nodes connected 

203 for i in range(len(path) - 1): 

204 if not self.nodes_connected(path[i], path[i + 1]): 

205 return False 

206 return True 

207 

208 def coord_degrees(self) -> Int8[np.ndarray, "row col"]: 

209 """ 

210 Returns an array with the connectivity degree of each coord. 

211 I.e., how many neighbors each coord has. 

212 """ 

213 int_conn: Int8[np.ndarray, "lattice_dim=2 row col"] = ( 

214 self.connection_list.astype(np.int8) 

215 ) 

216 degrees: Int8[np.ndarray, "row col"] = np.sum( 

217 int_conn, axis=0 

218 ) # Connections to east and south 

219 degrees[:, 1:] += int_conn[1, :, :-1] # Connections to west 

220 degrees[1:, :] += int_conn[0, :-1, :] # Connections to north 

221 return degrees 

222 

223 def get_coord_neighbors(self, c: Coord | CoordTup) -> CoordArray: 

224 """ 

225 Returns an array of the neighboring, connected coords of `c`. 

226 """ 

227 c = np.array(c) # type: ignore[assignment] 

228 neighbors: list[Coord] = [ 

229 neighbor 

230 for neighbor in (c + NEIGHBORS_MASK) 

231 if ( 

232 (0 <= neighbor[0] < self.grid_shape[0]) # in x bounds 

233 and (0 <= neighbor[1] < self.grid_shape[1]) # in y bounds 

234 and self.nodes_connected(c, neighbor) # connected 

235 ) 

236 ] 

237 

238 output: CoordArray = np.array(neighbors) 

239 if len(neighbors) > 0: 

240 assert output.shape == ( 

241 len(neighbors), 

242 2, 

243 ), ( 

244 f"invalid shape: {output.shape}, expected ({len(neighbors)}, 2))\n{c = }\n{neighbors = }\n{self.as_ascii()}" 

245 ) 

246 return output 

247 

248 def gen_connected_component_from(self, c: Coord) -> CoordArray: 

249 """return the connected component from a given coordinate""" 

250 # Stack for DFS 

251 stack: list[Coord] = [c] 

252 

253 # Set to store visited nodes 

254 visited: set[CoordTup] = set() 

255 

256 while stack: 

257 current_node: Coord = stack.pop() 

258 # this is fine since we know current_node is a coord and thus of length 2 

259 visited.add(tuple(current_node)) # type: ignore[arg-type] 

260 

261 # Get the neighbors of the current node 

262 neighbors = self.get_coord_neighbors(current_node) 

263 

264 # Iterate over neighbors 

265 for neighbor in neighbors: 

266 if tuple(neighbor) not in visited: 

267 stack.append(neighbor) 

268 

269 return np.array(list(visited)) 

270 

271 def find_shortest_path( 

272 self, 

273 c_start: CoordTup | Coord, 

274 c_end: CoordTup | Coord, 

275 ) -> CoordArray: 

276 """find the shortest path between two coordinates, using A*""" 

277 c_start = tuple(c_start) # type: ignore[assignment] 

278 c_end = tuple(c_end) # type: ignore[assignment] 

279 

280 g_score: dict[CoordTup, float] = ( 

281 dict() 

282 ) # cost of cheapest path to node from start currently known 

283 f_score: dict[CoordTup, float] = { 

284 c_start: 0.0 

285 } # estimated total cost of path thru a node: f_score[c] := g_score[c] + heuristic(c, c_end) 

286 

287 # init 

288 g_score[c_start] = 0.0 

289 g_score[c_start] = self.heuristic(c_start, c_end) 

290 

291 closed_vtx: set[CoordTup] = set() # nodes already evaluated 

292 open_vtx: set[CoordTup] = set([c_start]) # nodes to be evaluated 

293 source: dict[CoordTup, CoordTup] = ( 

294 dict() 

295 ) # node immediately preceding each node in the path (currently known shortest path) 

296 

297 while open_vtx: 

298 # get lowest f_score node 

299 c_current: CoordTup = min(open_vtx, key=lambda c: f_score[c]) 

300 # f_current: float = f_score[c_current] 

301 

302 # check if goal is reached 

303 if c_end == c_current: 

304 path: list[CoordTup] = [c_current] 

305 p_current: CoordTup = c_current 

306 while p_current in source: 

307 p_current = source[p_current] 

308 path.append(p_current) 

309 # ---------------------------------------------------------------------- 

310 # this is the only return statement 

311 return np.array(path[::-1]) 

312 # ---------------------------------------------------------------------- 

313 

314 # close current node 

315 closed_vtx.add(c_current) 

316 open_vtx.remove(c_current) 

317 

318 # update g_score of neighbors 

319 _np_neighbor: Coord 

320 for _np_neighbor in self.get_coord_neighbors(c_current): 

321 neighbor: CoordTup = tuple(_np_neighbor) 

322 

323 if neighbor in closed_vtx: 

324 # already checked 

325 continue 

326 g_temp: float = g_score[c_current] + 1 # always 1 for maze neighbors 

327 

328 if neighbor not in open_vtx: 

329 # found new vtx, so add 

330 open_vtx.add(neighbor) 

331 

332 elif g_temp >= g_score[neighbor]: 

333 # if already knew about this one, but current g_score is worse, skip 

334 continue 

335 

336 # store g_score and source 

337 source[neighbor] = c_current 

338 g_score[neighbor] = g_temp 

339 f_score[neighbor] = g_score[neighbor] + self.heuristic(neighbor, c_end) 

340 

341 raise ValueError( 

342 "A solution could not be found!", 

343 f"{c_start = }, {c_end = }", 

344 self.as_ascii(), 

345 ) 

346 

347 def get_nodes(self) -> CoordArray: 

348 """return a list of all nodes in the maze""" 

349 rows: Int[np.ndarray, "x y"] 

350 cols: Int[np.ndarray, "x y"] 

351 rows, cols = np.meshgrid( 

352 range(self.grid_shape[0]), 

353 range(self.grid_shape[1]), 

354 indexing="ij", 

355 ) 

356 nodes: CoordArray = np.vstack((rows.ravel(), cols.ravel())).T 

357 return nodes 

358 

359 def get_connected_component(self) -> CoordArray: 

360 """get the largest (and assumed only nonsingular) connected component of the maze 

361 

362 TODO: other connected components? 

363 """ 

364 if (self.generation_meta is None) or ( 

365 self.generation_meta.get("fully_connected", False) 

366 ): 

367 # for fully connected case, pick any two positions 

368 return self.get_nodes() 

369 else: 

370 # if metadata provided, use visited cells 

371 visited_cells: set[CoordTup] | None = self.generation_meta.get( 

372 "visited_cells", None 

373 ) 

374 if visited_cells is None: 

375 # TODO: dynamically generate visited_cells? 

376 raise ValueError( 

377 f"a maze which is not marked as fully connected must have a visited_cells field in its generation_meta: {self.generation_meta}\n{self}\n{self.as_ascii()}" 

378 ) 

379 else: 

380 visited_cells_np: Int[np.ndarray, "N 2"] = np.array(list(visited_cells)) 

381 return visited_cells_np 

382 

    @typing.overload
    def generate_random_path(
        self,
        allowed_start: CoordList | None = None,
        allowed_end: CoordList | None = None,
        deadend_start: bool = False,
        deadend_end: bool = False,
        endpoints_not_equal: bool = False,
        except_on_no_valid_endpoint: typing.Literal[True] = True,
    ) -> CoordArray: ...
    @typing.overload
    def generate_random_path(
        self,
        allowed_start: CoordList | None = None,
        allowed_end: CoordList | None = None,
        deadend_start: bool = False,
        deadend_end: bool = False,
        endpoints_not_equal: bool = False,
        except_on_no_valid_endpoint: typing.Literal[False] = False,
    ) -> typing.Optional[CoordArray]: ...
    def generate_random_path(
        self,
        allowed_start: CoordList | None = None,
        allowed_end: CoordList | None = None,
        deadend_start: bool = False,
        deadend_end: bool = False,
        endpoints_not_equal: bool = False,
        except_on_no_valid_endpoint: bool = True,
    ) -> typing.Optional[CoordArray]:
        """return a path between randomly chosen start and end nodes within the connected component

        Note that setting special conditions on start and end positions might cause the same
        position to be selected as both start and end, unless `endpoints_not_equal` is set.
        Without any special conditions, start and end are always distinct.

        Randomness comes from numpy's global RNG (`np.random`).

        # Parameters:
         - `allowed_start : CoordList | None`
            a list of allowed start positions. If `None`, any position in the connected component is allowed
            (defaults to `None`)
         - `allowed_end : CoordList | None`
            a list of allowed end positions. If `None`, any position in the connected component is allowed
            (defaults to `None`)
         - `deadend_start : bool`
            whether to ***force*** the start position to be a deadend (a node with exactly one connected neighbor)
            (defaults to `False`)
         - `deadend_end : bool`
            whether to ***force*** the end position to be a deadend (a node with exactly one connected neighbor)
            (defaults to `False`)
         - `endpoints_not_equal : bool`
            whether to ensure that the start and end point are not the same
            (defaults to `False`)
         - `except_on_no_valid_endpoint : bool`
            whether to raise an error if no valid start or end positions are found
            if this is `False`, the function might return `None` and this must be handled by the caller
            (defaults to `True`)

        # Returns:
         - `CoordArray`
            a path between the selected start and end positions

        # Raises:
         - `NoValidEndpointException` : if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True`
        """

        # we can't create a "path" in a single-node maze
        # NOTE(review): this also rejects 1xN / Nx1 grids, which are not single-node -- confirm intended
        assert self.grid_shape[0] > 1 and self.grid_shape[1] > 1, (
            f"can't create path in single-node maze: {self.as_ascii()}"
        )

        # get connected component
        connected_component: CoordArray = self.get_connected_component()

        # initialize start and end positions
        positions: Int[np.ndarray, "2 2"]

        # if no special conditions on start and end positions
        # NOTE(review): this tuple comparison assumes list/None arguments; a numpy array
        # passed as `allowed_start`/`allowed_end` would make `==` elementwise and raise
        if (allowed_start, allowed_end, deadend_start, deadend_end) == (
            None,
            None,
            False,
            False,
        ):
            try:
                # sample two distinct nodes (replace=False), so start != end here
                positions = connected_component[  # type: ignore[assignment]
                    np.random.choice(
                        len(connected_component),
                        size=2,
                        replace=False,
                    )
                ]
            except ValueError as e:
                # np.random.choice raises ValueError when fewer than 2 nodes are available
                if except_on_no_valid_endpoint:
                    raise NoValidEndpointException(
                        f"No valid start or end positions found because we could not sample from {connected_component = }"
                    ) from e
                else:
                    return None

            return self.find_shortest_path(positions[0], positions[1])  # type: ignore[index]

        # handle special conditions
        connected_component_set: set[CoordTup] = set(map(tuple, connected_component))
        # copy connected component set
        allowed_start_set: set[CoordTup] = connected_component_set.copy()
        allowed_end_set: set[CoordTup] = connected_component_set.copy()

        # filter by explicitly allowed start and end positions
        # '# type: ignore[assignment]' here because the returned tuple can be of any length
        if allowed_start is not None:
            allowed_start_set = set(map(tuple, allowed_start)) & connected_component_set  # type: ignore[assignment]

        if allowed_end is not None:
            allowed_end_set = set(map(tuple, allowed_end)) & connected_component_set  # type: ignore[assignment]

        # filter by forcing deadends
        if deadend_start:
            allowed_start_set = set(
                filter(
                    lambda x: len(self.get_coord_neighbors(x)) == 1, allowed_start_set
                )
            )

        if deadend_end:
            allowed_end_set = set(
                filter(lambda x: len(self.get_coord_neighbors(x)) == 1, allowed_end_set)
            )

        # check we have valid positions
        if len(allowed_start_set) == 0 or len(allowed_end_set) == 0:
            if except_on_no_valid_endpoint:
                raise NoValidEndpointException(
                    f"No valid start (or end?) positions found: {allowed_start_set = }, {allowed_end_set = }"
                )
            else:
                return None

        # randomly select start and end positions
        try:
            # ignore assignment here since `tuple()` returns a tuple of any length, but we know it will be ok
            start_pos: CoordTup = tuple(  # type: ignore[assignment]
                list(allowed_start_set)[np.random.randint(0, len(allowed_start_set))]
            )
            if endpoints_not_equal:
                # remove start position from end positions
                allowed_end_set.discard(start_pos)
            # np.random.randint(0, 0) raises ValueError if the discard emptied the end set
            end_pos: CoordTup = tuple(  # type: ignore[assignment]
                list(allowed_end_set)[np.random.randint(0, len(allowed_end_set))]
            )
        except ValueError as e:
            if except_on_no_valid_endpoint:
                raise NoValidEndpointException(
                    f"No valid start or end positions found, maybe can't find an endpoint after we removed the start point: {allowed_start_set = }, {allowed_end_set = }"
                ) from e
            else:
                return None

        return self.find_shortest_path(start_pos, end_pos)

538 

539 # ============================================================ 

540 # to and from adjacency list 

541 # ============================================================ 

542 def as_adj_list( 

543 self, shuffle_d0: bool = True, shuffle_d1: bool = True 

544 ) -> Int8[np.ndarray, "conn start_end coord"]: 

545 return connection_list_to_adj_list(self.connection_list, shuffle_d0, shuffle_d1) 

546 

547 @classmethod 

548 def from_adj_list( 

549 cls, 

550 adj_list: Int8[np.ndarray, "conn start_end coord"], 

551 ) -> "LatticeMaze": 

552 """create a LatticeMaze from a list of connections 

553 

554 > [!NOTE] 

555 > This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed. 

556 """ 

557 

558 # this is where it would probably break for rectangular mazes 

559 grid_n: int = adj_list.max() + 1 

560 

561 connection_list: ConnectionList = np.zeros( 

562 (2, grid_n, grid_n), 

563 dtype=np.bool_, 

564 ) 

565 

566 for c_start, c_end in adj_list: 

567 # check that exactly 1 coordinate matches 

568 if (c_start == c_end).sum() != 1: 

569 raise ValueError("invalid connection") 

570 

571 # get the direction 

572 d: int = (c_start != c_end).argmax() 

573 

574 x: int 

575 y: int 

576 # pick whichever has the lesser value in the direction `d` 

577 if c_start[d] < c_end[d]: 

578 x, y = c_start 

579 else: 

580 x, y = c_end 

581 

582 connection_list[d, x, y] = True 

583 

584 return LatticeMaze( 

585 connection_list=connection_list, 

586 ) 

587 

588 def as_adj_list_tokens(self) -> list[str | CoordTup]: 

589 warnings.warn( 

590 "`LatticeMaze.as_adj_list_tokens` will be removed from the public API in a future release.", 

591 TokenizerDeprecationWarning, 

592 ) 

593 return [ 

594 SPECIAL_TOKENS.ADJLIST_START, 

595 *chain.from_iterable( # type: ignore[list-item] 

596 [ 

597 [ 

598 tuple(c_s), 

599 SPECIAL_TOKENS.CONNECTOR, 

600 tuple(c_e), 

601 SPECIAL_TOKENS.ADJACENCY_ENDLINE, 

602 ] 

603 for c_s, c_e in self.as_adj_list() 

604 ] 

605 ), 

606 SPECIAL_TOKENS.ADJLIST_END, 

607 ] 

608 

609 def _as_adj_list_tokens(self) -> list[str | CoordTup]: 

610 return [ 

611 SPECIAL_TOKENS.ADJLIST_START, 

612 *chain.from_iterable( # type: ignore[list-item] 

613 [ 

614 [ 

615 tuple(c_s), 

616 SPECIAL_TOKENS.CONNECTOR, 

617 tuple(c_e), 

618 SPECIAL_TOKENS.ADJACENCY_ENDLINE, 

619 ] 

620 for c_s, c_e in self.as_adj_list() 

621 ] 

622 ), 

623 SPECIAL_TOKENS.ADJLIST_END, 

624 ] 

625 

626 def _as_coords_and_special_AOTP(self) -> list[CoordTup | str]: 

627 """turn the maze into adjacency list, origin, target, and solution -- keep coords as tuples""" 

628 

629 output: list[CoordTup | str] = self._as_adj_list_tokens() 

630 # if getattr(self, "start_pos", None) is not None: 

631 if isinstance(self, TargetedLatticeMaze): 

632 output += self._get_start_pos_tokens() 

633 if isinstance(self, TargetedLatticeMaze): 

634 output += self._get_end_pos_tokens() 

635 if isinstance(self, SolvedMaze): 

636 output += self._get_solution_tokens() 

637 return output 

638 

639 def _as_tokens( 

640 self, maze_tokenizer: "MazeTokenizer | TokenizationMode" 

641 ) -> list[str]: 

642 # type ignores here fine since we check the instance 

643 if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"): 

644 maze_tokenizer = maze_tokenizer.to_legacy_tokenizer() # type: ignore[union-attr] 

645 if ( 

646 isinstance_by_type_name(maze_tokenizer, "MazeTokenizer") 

647 and maze_tokenizer.is_AOTP() # type: ignore[union-attr] 

648 ): 

649 coords_raw: list[CoordTup | str] = self._as_coords_and_special_AOTP() 

650 coords_processed: list[str] = maze_tokenizer.coords_to_strings( # type: ignore[union-attr] 

651 coords=coords_raw, when_noncoord="include" 

652 ) 

653 return coords_processed 

654 else: 

655 raise NotImplementedError(f"Unsupported tokenizer type: {maze_tokenizer}") 

656 

657 def as_tokens( 

658 self, 

659 maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular", 

660 ) -> list[str]: 

661 """serialize maze and solution to tokens""" 

662 if isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular"): 

663 return maze_tokenizer.to_tokens(self) # type: ignore[union-attr] 

664 else: 

665 return self._as_tokens(maze_tokenizer) # type: ignore[union-attr,arg-type] 

666 

    @classmethod
    def _from_tokens_AOTP(
        cls, tokens: list[str], maze_tokenizer: "MazeTokenizer | MazeTokenizerModular"
    ) -> "LatticeMaze":
        """create a LatticeMaze from a list of tokens

        Parses an AOTP token sequence: the adjacency list always, then, if all origin
        and target delimiters are present, the endpoints (yielding a
        `TargetedLatticeMaze`), and then, if the path delimiters are present, the
        solution (yielding a `SolvedMaze`).

        If `tokens` does not start with `ADJLIST_START`, the whole input is treated as
        adjacency-list tokens and a warning is emitted.

        # Raises:
         - `AssertionError` : on malformed edges, endpoints, or a solution without endpoints
        """

        # figure out what input format
        # ========================================
        if tokens[0] == SPECIAL_TOKENS.ADJLIST_START:
            adj_list_tokens = get_adj_list_tokens(tokens)
        else:
            # If we're not getting a "complete" tokenized maze, assume it's just a the adjacency list tokens
            adj_list_tokens = tokens
            warnings.warn(
                "Assuming input is just adjacency list tokens, no special tokens found"
            )

        # process edges for adjacency list
        # ========================================
        edges: list[list[str]] = list_split(
            adj_list_tokens,
            SPECIAL_TOKENS.ADJACENCY_ENDLINE,
        )

        coordinates: list[tuple[CoordTup, CoordTup]] = list()
        for e in edges:
            # skip last endline
            if len(e) != 0:
                # convert to coords, split start and end
                e_coords: list[CoordTup] = maze_tokenizer.strings_to_coords(
                    e,
                    # non-coordinate tokens are *included* on purpose: the CONNECTOR token
                    # must survive so the length-3 / CONNECTOR assertions below hold, and it
                    # is discarded when the edge is appended.
                    # NOTE(review): an older TODO here said this had been changed to a skip;
                    # the code uses "include" -- confirm that is the intended behavior.
                    when_noncoord="include",
                )
                assert len(e_coords) == 3, f"invalid edge: {e = } {e_coords = }"
                assert e_coords[1] == SPECIAL_TOKENS.CONNECTOR, (
                    f"invalid edge: {e = } {e_coords = }"
                )
                coordinates.append((e_coords[0], e_coords[-1]))

        assert all(len(c) == 2 for c in coordinates), (
            f"invalid coordinates: {coordinates = }"
        )
        adj_list: Int8[np.ndarray, "conn start_end coord"] = np.array(coordinates)
        assert tuple(adj_list.shape) == (
            len(coordinates),
            2,
            2,
        ), f"invalid adj_list: {adj_list.shape = } {coordinates = }"

        output_maze: LatticeMaze = cls.from_adj_list(adj_list)

        # add start and end positions
        # ========================================
        is_targeted: bool = False
        if all(
            x in tokens
            for x in (
                SPECIAL_TOKENS.ORIGIN_START,
                SPECIAL_TOKENS.ORIGIN_END,
                SPECIAL_TOKENS.TARGET_START,
                SPECIAL_TOKENS.TARGET_END,
            )
        ):
            start_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords(
                get_origin_tokens(tokens), when_noncoord="error"
            )
            end_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords(
                get_target_tokens(tokens), when_noncoord="error"
            )
            assert len(start_pos_list) == 1, (
                f"invalid start_pos_list: {start_pos_list = }"
            )
            assert len(end_pos_list) == 1, f"invalid end_pos_list: {end_pos_list = }"

            start_pos: CoordTup = start_pos_list[0]
            end_pos: CoordTup = end_pos_list[0]

            output_maze = TargetedLatticeMaze.from_lattice_maze(
                lattice_maze=output_maze,
                start_pos=start_pos,
                end_pos=end_pos,
            )

            is_targeted = True

        # add solution (only valid on top of a targeted maze)
        # ========================================
        if all(
            x in tokens for x in (SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END)
        ):
            assert is_targeted, "maze must be targeted to have a solution"
            solution: list[CoordTup] = maze_tokenizer.strings_to_coords(
                get_path_tokens(tokens, trim_end=True),
                when_noncoord="error",
            )
            output_maze = SolvedMaze.from_targeted_lattice_maze(
                # HACK: I think this is fine, but im not sure
                targeted_lattice_maze=output_maze,  # type: ignore[arg-type]
                solution=solution,
            )

        return output_maze

769 

770 @classmethod 

771 def from_tokens( 

772 cls, 

773 tokens: list[str], 

774 maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular", 

775 ) -> "LatticeMaze": 

776 """ 

777 Constructs a maze from a tokenization. 

778 Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported. 

779 """ 

780 # HACK: type ignores here fine since we check the instance 

781 if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"): 

782 maze_tokenizer = maze_tokenizer.to_legacy_tokenizer() # type: ignore[union-attr] 

783 if ( 

784 isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular") 

785 and not maze_tokenizer.is_legacy_equivalent() # type: ignore[union-attr] 

786 ): 

787 raise NotImplementedError( 

788 f"Only legacy tokenizers and their exact `MazeTokenizerModular` analogs supported, not {maze_tokenizer}." 

789 ) 

790 

791 if isinstance(tokens, str): 

792 tokens = tokens.split() 

793 

794 if maze_tokenizer.is_AOTP(): # type: ignore[union-attr] 

795 return cls._from_tokens_AOTP(tokens, maze_tokenizer) # type: ignore[arg-type] 

796 else: 

797 raise NotImplementedError("only AOTP tokenization is supported") 

798 

799 # ============================================================ 

800 # to and from pixels 

801 # ============================================================ 

802 def _as_pixels_bw(self) -> BinaryPixelGrid: 

803 assert self.lattice_dim == 2, "only 2D mazes are supported" 

804 # Create an empty pixel grid with walls 

805 pixel_grid: Int[np.ndarray, "x y"] = np.full( 

806 (self.grid_shape[0] * 2 + 1, self.grid_shape[1] * 2 + 1), 

807 False, 

808 dtype=np.bool_, 

809 ) 

810 

811 # Set white nodes 

812 pixel_grid[1::2, 1::2] = True 

813 

814 # Set white connections (downward) 

815 for i, row in enumerate(self.connection_list[0]): 

816 for j, connected in enumerate(row): 

817 if connected: 

818 pixel_grid[i * 2 + 2, j * 2 + 1] = True 

819 

820 # Set white connections (rightward) 

821 for i, row in enumerate(self.connection_list[1]): 

822 for j, connected in enumerate(row): 

823 if connected: 

824 pixel_grid[i * 2 + 1, j * 2 + 2] = True 

825 

826 return pixel_grid 

827 

828 def as_pixels( 

829 self, 

830 show_endpoints: bool = True, 

831 show_solution: bool = True, 

832 ) -> PixelGrid: 

833 # HACK: lots of `# type: ignore[attr-defined]` here since its defined for any `LatticeMaze` 

834 # but solution, start_pos, end_pos not always defined 

835 # but its fine since we explicitly check the type 

836 if show_solution and not show_endpoints: 

837 raise ValueError("show_solution=True requires show_endpoints=True") 

838 # convert original bool pixel grid to RGB 

839 pixel_grid_bw: BinaryPixelGrid = self._as_pixels_bw() 

840 pixel_grid: PixelGrid = np.full( 

841 (*pixel_grid_bw.shape, 3), PixelColors.WALL, dtype=np.uint8 

842 ) 

843 pixel_grid[pixel_grid_bw == True] = PixelColors.OPEN # noqa: E712 

844 

845 if self.__class__ == LatticeMaze: 

846 return pixel_grid 

847 

848 # set endpoints for TargetedLatticeMaze 

849 if self.__class__ == TargetedLatticeMaze: 

850 if show_endpoints: 

851 pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = ( # type: ignore[attr-defined] 

852 PixelColors.START 

853 ) 

854 pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = ( # type: ignore[attr-defined] 

855 PixelColors.END 

856 ) 

857 return pixel_grid 

858 

859 # set solution -- we only reach this part if `self.__class__ == SolvedMaze` 

860 if show_solution: 

861 for coord in self.solution: # type: ignore[attr-defined] 

862 pixel_grid[coord[0] * 2 + 1, coord[1] * 2 + 1] = PixelColors.PATH 

863 

864 # set pixels between coords 

865 for index, coord in enumerate(self.solution[:-1]): # type: ignore[attr-defined] 

866 next_coord = self.solution[index + 1] # type: ignore[attr-defined] 

867 # check they are adjacent using norm 

868 assert np.linalg.norm(np.array(coord) - np.array(next_coord)) == 1, ( 

869 f"Coords {coord} and {next_coord} are not adjacent" 

870 ) 

871 # set pixel between them 

872 pixel_grid[ 

873 coord[0] * 2 + 1 + next_coord[0] - coord[0], 

874 coord[1] * 2 + 1 + next_coord[1] - coord[1], 

875 ] = PixelColors.PATH 

876 

877 # set endpoints (again, since path would overwrite them) 

878 pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = ( # type: ignore[attr-defined] 

879 PixelColors.START 

880 ) 

881 pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = ( # type: ignore[attr-defined] 

882 PixelColors.END 

883 ) 

884 

885 return pixel_grid 

886 

887 @classmethod 

888 def _from_pixel_grid_bw( 

889 cls, pixel_grid: BinaryPixelGrid 

890 ) -> tuple[ConnectionList, tuple[int, int]]: 

891 grid_shape: tuple[int, int] = ( 

892 pixel_grid.shape[0] // 2, 

893 pixel_grid.shape[1] // 2, 

894 ) 

895 connection_list: ConnectionList = np.zeros((2, *grid_shape), dtype=np.bool_) 

896 

897 # Extract downward connections 

898 connection_list[0] = pixel_grid[2::2, 1::2] 

899 

900 # Extract rightward connections 

901 connection_list[1] = pixel_grid[1::2, 2::2] 

902 

903 return connection_list, grid_shape 

904 

@classmethod
def _from_pixel_grid_with_positions(
    cls,
    pixel_grid: PixelGrid | BinaryPixelGrid,
    marked_positions: dict[str, RGB],
) -> tuple[ConnectionList, tuple[int, int], dict[str, CoordArray]]:
    """decode connectivity and colour-marked cells from an RGB pixel grid

    every pixel that is not exactly `PixelColors.WALL` is treated as open when
    building the connection list. For each `(name, color)` in
    `marked_positions`, the cells whose centre pixel (odd row, odd column)
    has exactly that color are collected; matching pixels at even positions
    (the gaps between cells) are ignored.

    # Parameters:
     - `pixel_grid : PixelGrid | BinaryPixelGrid`
        the pixel grid to decode (the colour comparison reduces over the last axis)
     - `marked_positions : dict[str, RGB]`
        names mapped to the colors to search for

    # Returns:
     - `tuple[ConnectionList, tuple[int, int], dict[str, CoordArray]]`
        connection list, grid shape, and for each name an array of
        `(row, col)` cell coordinates -- shape `(k, 2)`, or shape `(0,)` when
        no cell of that color was found
    """
    # open == "anything that is not a wall pixel"
    # error: Incompatible types in assignment (expression has type
    # "numpy.bool[builtins.bool] | ndarray[tuple[int, ...], dtype[numpy.bool[builtins.bool]]]",
    # variable has type "ndarray[Any, Any]")  [assignment]
    open_mask: BinaryPixelGrid = ~np.all(  # type: ignore[assignment]
        pixel_grid == PixelColors.WALL, axis=-1
    )
    connection_list: ConnectionList
    grid_shape: tuple[int, int]
    connection_list, grid_shape = cls._from_pixel_grid_bw(open_mask)

    out_positions: dict[str, CoordArray] = {}
    for name, rgb in marked_positions.items():
        # (n, 2) array of pixel indices whose color matches exactly
        hits = np.argwhere(np.all(pixel_grid == rgb, axis=-1))
        # keep only cell-centre pixels and map them back to cell coordinates
        out_positions[name] = np.array(
            [
                (p[0] // 2, p[1] // 2)
                for p in hits
                if p[0] % 2 == 1 and p[1] % 2 == 1
            ]
        )

    return connection_list, grid_shape, out_positions

937 

@classmethod
def from_pixels(
    cls,
    pixel_grid: PixelGrid,
) -> "LatticeMaze":
    """create a maze from a pixel grid

    - a 2-d (binary) grid always yields a plain `LatticeMaze`, whatever `cls` is
    - an RGB grid is classified with `detect_pixels_type`; `cls` must be the
      detected type or one of its base classes
    - for `TargetedLatticeMaze`, exactly one start and one end cell must be marked
    - otherwise the path-colored cells are ordered into a solution by walking
      through connected neighbors from the start cell to the end cell

    # Parameters:
     - `pixel_grid : PixelGrid`
        RGB pixel grid, or a binary grid with only two axes

    # Returns:
     - `LatticeMaze`
        an instance of `cls` (or a plain `LatticeMaze` as described above)

    # Raises:
     - `ValueError` : if the detected maze type is not `cls` or a subclass of it
     - `AssertionError` : if the endpoints are missing or duplicated, or the
       path cells do not form a single unambiguous path from start to end
    """
    connection_list: ConnectionList
    grid_shape: tuple[int, int]

    # if a binary pixel grid, return regular LatticeMaze
    if len(pixel_grid.shape) == 2:
        connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid)
        return LatticeMaze(connection_list=connection_list)

    # otherwise, detect and check it's valid
    cls_detected: typing.Type[LatticeMaze] = detect_pixels_type(pixel_grid)
    if cls not in cls_detected.__mro__:
        raise ValueError(
            f"Pixel grid cannot be cast to {cls.__name__}, detected type {cls_detected.__name__}"
        )

    (
        connection_list,
        grid_shape,
        marked_pos,
    ) = cls._from_pixel_grid_with_positions(
        pixel_grid=pixel_grid,
        marked_positions=dict(
            start=PixelColors.START, end=PixelColors.END, solution=PixelColors.PATH
        ),
    )
    # if we wanted a LatticeMaze, return it
    if cls == LatticeMaze:
        return LatticeMaze(connection_list=connection_list)

    # otherwise, keep going
    temp_maze: LatticeMaze = LatticeMaze(connection_list=connection_list)

    # start and end pos -- each must be exactly one cell
    start_pos_arr, end_pos_arr = marked_pos["start"], marked_pos["end"]
    assert start_pos_arr.shape == (
        1,
        2,
    ), (
        f"start_pos_arr {start_pos_arr} has shape {start_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
    )
    assert end_pos_arr.shape == (
        1,
        2,
    ), (
        f"end_pos_arr {end_pos_arr} has shape {end_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
    )

    start_pos: Coord = start_pos_arr[0]
    end_pos: Coord = end_pos_arr[0]

    # return a TargetedLatticeMaze if that's what we wanted
    if cls == TargetedLatticeMaze:
        return TargetedLatticeMaze(
            connection_list=connection_list,
            start_pos=start_pos,
            end_pos=end_pos,
        )

    # raw solution, only contains path elements and not start or end
    solution_raw: CoordArray = marked_pos["solution"]
    if len(solution_raw.shape) == 2:
        assert solution_raw.shape[1] == 2, (
            f"solution {solution_raw} has shape {solution_raw.shape}, expected shape (n, 2)"
        )
    elif solution_raw.shape == (0,):
        # the solution and end should be immediately adjacent
        assert np.sum(np.abs(start_pos - end_pos)) == 1, (
            f"start_pos {start_pos} and end_pos {end_pos} are not adjacent, but no solution was given"
        )

    start_tup: CoordTup = tuple(start_pos)
    end_tup: CoordTup = tuple(end_pos)

    # order the solution, by creating a list from the start to the end
    # add end pos, since we will iterate over all these starting from the start pos
    solution_raw_list: list[CoordTup] = [tuple(c) for c in solution_raw] + [end_tup]
    # set views of the same data for O(1) membership tests; the lists are kept
    # for ordering and for the error messages below
    allowed: set[CoordTup] = set(solution_raw_list)
    # solution starts with start point
    solution: list[CoordTup] = [start_tup]
    visited: set[CoordTup] = {start_tup}
    while solution[-1] != end_tup:
        # use `get_coord_neighbors` to find connected neighbors
        neighbors: CoordArray = temp_maze.get_coord_neighbors(solution[-1])
        assert (len(neighbors.shape) == 2) and (neighbors.shape[1] == 2), (
            f"neighbors {neighbors} has shape {neighbors.shape}, expected shape (n, 2)\n{neighbors = }\n{solution = }\n{solution_raw = }\n{temp_maze.as_ascii()}"
        )
        # keep only neighbors that are path cells we have not visited yet
        neighbors_filtered: CoordArray = np.array(
            [
                coord
                for coord in neighbors
                if tuple(coord) in allowed and tuple(coord) not in visited
            ]
        )
        # exactly one way forward, otherwise the path is ambiguous or broken
        assert neighbors_filtered.shape == (
            1,
            2,
        ), (
            f"neighbors_filtered has shape {neighbors_filtered.shape}, expected shape (1, 2)\n{neighbors = }\n{neighbors_filtered = }\n{solution = }\n{solution_raw_list = }\n{temp_maze.as_ascii()}"
        )
        next_coord: CoordTup = tuple(neighbors_filtered[0])
        solution.append(next_coord)
        visited.add(next_coord)

    # assert the solution is complete
    assert solution[0] == start_tup, (
        f"solution {solution} does not start at start_pos {start_pos}"
    )
    assert solution[-1] == end_tup, (
        f"solution {solution} does not end at end_pos {end_pos}"
    )

    return cls(
        connection_list=np.array(connection_list),
        solution=np.array(solution),  # type: ignore[call-arg]
    )

1060 

# ============================================================
# to and from ASCII
# ============================================================
def _as_ascii_grid(self) -> Shaped[np.ndarray, "x y"]:
    """character grid of the bare maze

    pixels from `_as_pixels_bw()` that are `True` become `AsciiChars.OPEN`,
    all others `AsciiChars.WALL`. Endpoints and solution are not drawn here.
    """
    # black-and-white rendering: True marks an open (non-wall) pixel
    open_pixels: Bool[np.ndarray, "x y"] = self._as_pixels_bw()

    # start from all walls, then carve out the open pixels
    chars: Shaped[np.ndarray, "x y"] = np.full(
        open_pixels.shape, AsciiChars.WALL, dtype=str
    )
    chars[np.equal(open_pixels, True)] = AsciiChars.OPEN
    return chars

1075 

def as_ascii(
    self,
    show_endpoints: bool = True,
    show_solution: bool = True,
) -> str:
    """return an ASCII grid of the maze

    # Parameters:
     - `show_endpoints : bool`
        paint the start/end characters over the grid (defaults to `True`)
     - `show_solution : bool`
        paint the path characters over the grid (defaults to `True`)

    # Returns:
     - `str`
        one line of characters per pixel row, joined with newlines
    """
    grid: Shaped[np.ndarray, "x y"] = self._as_ascii_grid()
    pixels: PixelGrid = self.as_pixels(
        show_endpoints=show_endpoints, show_solution=show_solution
    )

    # overlay characters that are allowed to replace wall/open characters
    overlay_chars: list = []
    if show_endpoints:
        overlay_chars.extend([AsciiChars.START, AsciiChars.END])
    if show_solution:
        overlay_chars.append(AsciiChars.PATH)

    for char, color in ASCII_PIXEL_PAIRINGS.items():
        if char not in overlay_chars:
            continue
        # wherever the rendered pixel has this exact color, draw the character
        grid[np.all(pixels == color, axis=-1)] = char

    return "\n".join(map("".join, grid))

1098 

@classmethod
def from_ascii(cls, ascii_str: str) -> "LatticeMaze":
    """parse an ASCII drawing into a maze

    each character is mapped to a pixel color via `ASCII_PIXEL_PAIRINGS` (the
    same table `as_ascii` uses), and the result is decoded with `from_pixels`.
    Characters not in the table become `(0, 0, 0)` pixels.

    NOTE(review): every line is `.strip()`-ed, so leading/trailing whitespace
    characters in a row are dropped -- if `AsciiChars.OPEN` is whitespace,
    rows must begin and end with non-whitespace (e.g. walls). Rows must all
    have the same length for the character array to be 2-d.
    """
    rows: list[list[str]] = [
        list(line.strip()) for line in ascii_str.strip().split("\n")
    ]
    char_grid: Shaped[np.ndarray, "x y"] = np.array(rows, dtype=str)

    rgb: PixelGrid = np.zeros((*char_grid.shape, 3), dtype=np.uint8)
    for char, color in ASCII_PIXEL_PAIRINGS.items():
        rgb[char_grid == char] = color

    return cls.from_pixels(rgb)

1112 

1113 

# type ignore here even though they're all frozen
# maybe `SerializableDataclass` itself is not frozen, but that's an ABC
# error: Cannot inherit frozen dataclass from a non-frozen one  [misc]
@serializable_dataclass(frozen=True, kw_only=True)
class TargetedLatticeMaze(LatticeMaze):  # type: ignore[misc]
    """A LatticeMaze with a start and end position

    `start_pos` and `end_pos` are converted to numpy arrays on construction and
    must satisfy `0 <= pos[k] < grid_shape[k]` on both axes.
    """

    # this jank is so that SolvedMaze can inherit from this class without needing arguments for start_pos and end_pos
    # type ignore here because even though its a kw-only dataclass,
    # mypy doesn't like that non-default arguments are after default arguments
    start_pos: Coord = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )
    end_pos: Coord = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )

    def __post_init__(self) -> None:
        """validate the endpoints and store them as numpy arrays

        # Raises:
         - `ValueError` : if either endpoint is `None` or lies outside the grid
        """
        # check for missing endpoints *before* wrapping in numpy:
        # `np.array(None)` is a 0-d object array, so a later `is not None`
        # check could never fail and indexing it would raise an opaque IndexError
        if self.start_pos is None or self.end_pos is None:
            raise ValueError(
                f"start_pos and end_pos must not be None, got {self.start_pos = } {self.end_pos = }"
            )
        # make things numpy arrays (very jank to override frozen dataclass)
        self.__dict__["start_pos"] = np.array(self.start_pos)
        self.__dict__["end_pos"] = np.array(self.end_pos)
        # check that start and end are in bounds; negative indices are rejected
        # too, since numpy indexing would otherwise silently wrap them around
        for field_name, pos in (
            ("start_pos", self.start_pos),
            ("end_pos", self.end_pos),
        ):
            in_bounds: bool = (0 <= pos[0] < self.grid_shape[0]) and (
                0 <= pos[1] < self.grid_shape[1]
            )
            if not in_bounds:
                raise ValueError(
                    f"{field_name} {pos} is out of bounds for grid shape {self.grid_shape}"
                )

    def __eq__(self, other: object) -> bool:
        return super().__eq__(other)

    def _get_start_pos_tokens(self) -> list[str | CoordTup]:
        """start position wrapped in the origin special tokens"""
        return [
            SPECIAL_TOKENS.ORIGIN_START,
            tuple(self.start_pos),
            SPECIAL_TOKENS.ORIGIN_END,
        ]

    def get_start_pos_tokens(self) -> list[str | CoordTup]:
        warnings.warn(
            "`TargetedLatticeMaze.get_start_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_start_pos_tokens()

    def _get_end_pos_tokens(self) -> list[str | CoordTup]:
        """end position wrapped in the target special tokens"""
        return [
            SPECIAL_TOKENS.TARGET_START,
            tuple(self.end_pos),
            SPECIAL_TOKENS.TARGET_END,
        ]

    def get_end_pos_tokens(self) -> list[str | CoordTup]:
        warnings.warn(
            "`TargetedLatticeMaze.get_end_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_end_pos_tokens()

    @classmethod
    def from_lattice_maze(
        cls,
        lattice_maze: LatticeMaze,
        start_pos: Coord | CoordTup,
        end_pos: Coord | CoordTup,
    ) -> "TargetedLatticeMaze":
        """add endpoints to an existing `LatticeMaze`, keeping its connections and generation metadata"""
        return cls(
            connection_list=lattice_maze.connection_list,
            start_pos=np.array(start_pos),
            end_pos=np.array(end_pos),
            generation_meta=lattice_maze.generation_meta,
        )

1197 

1198 

@serializable_dataclass(frozen=True, kw_only=True)
class SolvedMaze(TargetedLatticeMaze):  # type: ignore[misc]
    """Stores a maze and a solution

    for a valid solution, `start_pos` and `end_pos` are taken from the first
    and last coordinates of `solution`.
    """

    solution: CoordArray = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )

    def __init__(
        self,
        connection_list: ConnectionList,
        solution: CoordArray,
        generation_meta: dict | None = None,
        start_pos: Coord | None = None,
        end_pos: Coord | None = None,
        allow_invalid: bool = False,
    ) -> None:
        """create a solved maze from a connection list and a solution path

        # Parameters:
         - `connection_list : ConnectionList`
            passed through to `TargetedLatticeMaze`
         - `solution : CoordArray`
            the path; it is valid when it is a non-empty 2-d array of shape `(n, 2)`
         - `generation_meta : dict | None`
            passed through to `TargetedLatticeMaze`
         - `start_pos : Coord | None`
            if given and `allow_invalid` is False, must equal `solution[0]`.
            if the solution is invalid and `allow_invalid` is True, it is used
            as the start position instead
         - `end_pos : Coord | None`
            same as `start_pos`, for `solution[-1]`
         - `allow_invalid : bool`
            do not raise on an invalid solution, and skip the endpoint checks

        # Raises:
         - `ValueError` : if the solution is invalid and `allow_invalid` is False
         - `AssertionError` : if a given `start_pos`/`end_pos` disagrees with the solution
        """
        # figure out the solution
        solution_valid: bool = False
        if solution is not None:
            solution = np.array(solution)
            # note that a path length of 1 here is valid, since the start and end pos could be the same
            # `ndim` is checked first so that a 1-d array is reported as invalid
            # instead of raising IndexError on `shape[1]`
            if (
                solution.ndim == 2
                and (solution.shape[0] > 0)
                and (solution.shape[1] == 2)
            ):
                solution_valid = True

        if not solution_valid and not allow_invalid:
            # `getattr` so that a `None` solution yields a ValueError rather than an AttributeError
            raise ValueError(
                f"invalid solution: solution.shape = {getattr(solution, 'shape', None)!r} {solution = } {solution_valid = } {allow_invalid = }",
                f"{connection_list = }",
            )

        # endpoints come from the solution when it is valid; otherwise (only
        # reachable with allow_invalid=True) fall back to the explicitly given ones
        if solution_valid:
            init_start = np.array(solution[0])
            init_end = np.array(solution[-1])
        else:
            init_start = None if start_pos is None else np.array(start_pos)
            init_end = None if end_pos is None else np.array(end_pos)

        # init the TargetedLatticeMaze
        super().__init__(
            connection_list=connection_list,
            generation_meta=generation_meta,
            # TODO: the argument type is stricter than the expected type but it still fails?
            # error: Argument "start_pos" to "__init__" of "TargetedLatticeMaze" has incompatible type
            # "ndarray[tuple[int, ...], dtype[Any]] | None"; expected "ndarray[Any, Any]"  [arg-type]
            start_pos=init_start,  # type: ignore[arg-type]
            end_pos=init_end,  # type: ignore[arg-type]
        )

        # frozen dataclass: bypass __setattr__
        self.__dict__["solution"] = solution

        # adjust the endpoints
        if not allow_invalid:
            if start_pos is not None:
                assert np.array_equal(np.array(start_pos), self.start_pos), (
                    f"when trying to create a SolvedMaze, the given start_pos does not match the one in the solution: given={start_pos}, solution={self.start_pos}"
                )
            if end_pos is not None:
                assert np.array_equal(np.array(end_pos), self.end_pos), (
                    f"when trying to create a SolvedMaze, the given end_pos does not match the one in the solution: given={end_pos}, solution={self.end_pos}"
                )
            # TODO: assert the path does not backtrack, walk through walls, etc?

    def __eq__(self, other: object) -> bool:
        return super().__eq__(other)

    def __hash__(self) -> int:
        # hash the raw bytes of the two arrays that define the maze
        return hash((self.connection_list.tobytes(), self.solution.tobytes()))

    def _get_solution_tokens(self) -> list[str | CoordTup]:
        """solution coordinates wrapped in the path special tokens"""
        return [
            SPECIAL_TOKENS.PATH_START,
            *[tuple(c) for c in self.solution],
            SPECIAL_TOKENS.PATH_END,
        ]

    def get_solution_tokens(self) -> list[str | CoordTup]:
        warnings.warn(
            "`LatticeMaze.get_solution_tokens` is deprecated.",
            TokenizerDeprecationWarning,
        )
        return self._get_solution_tokens()

    # for backwards compatibility
    @property
    def maze(self) -> LatticeMaze:
        warnings.warn(
            "`maze` is deprecated, SolvedMaze now inherits from LatticeMaze.",
            DeprecationWarning,
        )
        return LatticeMaze(connection_list=self.connection_list)

    # type ignore here since we're overriding a method with a different signature
    @classmethod
    def from_lattice_maze(  # type: ignore[override]
        cls, lattice_maze: LatticeMaze, solution: list[CoordTup]
    ) -> "SolvedMaze":
        """attach a solution to an existing `LatticeMaze`, keeping its generation metadata"""
        return cls(
            connection_list=lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=lattice_maze.generation_meta,
        )

    @classmethod
    def from_targeted_lattice_maze(
        cls,
        targeted_lattice_maze: TargetedLatticeMaze,
        solution: list[CoordTup] | CoordArray | None = None,
    ) -> "SolvedMaze":
        """solves the given targeted lattice maze and returns a SolvedMaze

        if `solution` is None, it is computed with `find_shortest_path` between
        the maze's start and end positions.
        """
        if solution is None:
            solution = targeted_lattice_maze.find_shortest_path(
                targeted_lattice_maze.start_pos,
                targeted_lattice_maze.end_pos,
            )
        return cls(
            connection_list=targeted_lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=targeted_lattice_maze.generation_meta,
        )

    def get_solution_forking_points(
        self,
        always_include_endpoints: bool = False,
    ) -> tuple[list[int], CoordArray]:
        """coordinates and their indices from the solution where a fork is present

        - if the start point is not a dead end, this counts as a fork
        - if the end point is not a dead end, this counts as a fork
        - with `always_include_endpoints`, both endpoints are always included
        """
        output_idxs: list[int] = list()
        output_coords: list[CoordTup] = list()

        for idx, coord in enumerate(self.solution):
            # more than one choice for first coord, or more than 2 for any other
            # since the previous coord doesn't count as a choice
            is_endpoint: bool = idx == 0 or idx == self.solution.shape[0] - 1
            threshold: int = 1 if is_endpoint else 2
            if self.get_coord_neighbors(coord).shape[0] > threshold or (
                is_endpoint and always_include_endpoints
            ):
                output_idxs.append(idx)
                output_coords.append(coord)

        return output_idxs, np.array(output_coords)

    def get_solution_path_following_points(self) -> tuple[list[int], CoordArray]:
        """coordinates from the solution where there is only a single (non-backtracking) point to move to

        returns the complement of `get_solution_forking_points` from the path
        (note: the index part is returned as a numpy array, not a list)
        """
        forks_idxs, _ = self.get_solution_forking_points()
        # HACK: idk why type ignore here
        return (  # type: ignore[return-value]
            np.delete(np.arange(self.solution.shape[0]), forks_idxs, axis=0),
            np.delete(self.solution, forks_idxs, axis=0),
        )

1348 

1349 

def detect_pixels_type(data: PixelGrid) -> typing.Type[LatticeMaze]:
    """Detects the type of pixels data by checking for the presence of start and end pixels

    - neither a start nor an end pixel -> `LatticeMaze`
    - a start or end pixel, but no path pixel -> `TargetedLatticeMaze`
    - a start or end pixel and at least one path pixel -> `SolvedMaze`
    """
    has_endpoint = color_in_pixel_grid(data, PixelColors.START) or color_in_pixel_grid(
        data, PixelColors.END
    )
    if not has_endpoint:
        return LatticeMaze
    if color_in_pixel_grid(data, PixelColors.PATH):
        return SolvedMaze
    return TargetedLatticeMaze

1361 

1362 

def _remove_isolated_cells(
    image: Int[np.ndarray, "RGB x y"],
) -> Int[np.ndarray, "RGB x y"]:
    """
    Removes isolated cells from an image. An isolated cell is a cell that is surrounded by walls on all sides.

    - the colour channel is the *last* axis (colours are compared with
      `axis=-1`), despite the `"RGB x y"` annotation
    - pixels outside the image count as walls
    - returns a modified copy; the input image is not changed
    """
    # True wherever the pixel is exactly the wall colour
    is_wall = np.all(image == PixelColors.WALL, axis=-1)

    # border of walls so that edge pixels also have four neighbours
    bordered = np.pad(is_wall, ((1, 1), (1, 1)), mode="constant", constant_values=True)

    up = bordered[:-2, 1:-1]
    down = bordered[2:, 1:-1]
    left = bordered[1:-1, :-2]
    right = bordered[1:-1, 2:]

    # open pixels whose four neighbours are all walls
    to_fill = up & down & left & right & ~is_wall

    result = image.copy()
    result[to_fill] = PixelColors.WALL
    return result

1394 

1395 

# per-direction `np.pad` widths, as ((before_rows, after_rows), (before_cols, after_cols));
# within this part of the file only the commented-out `_remove_isolated_cells_old` refers to them
_RIC_PADS: dict = dict(
    left=((1, 0), (0, 0)),
    right=((0, 1), (0, 0)),
    up=((0, 0), (1, 0)),
    down=((0, 0), (0, 1)),
)

# Define slices for each direction
_RIC_SLICES: dict = dict(
    left=(slice(1, None), slice(None, None)),
    right=(slice(None, -1), slice(None, None)),
    up=(slice(None, None), slice(1, None)),
    down=(slice(None, None), slice(None, -1)),
)

1410 

1411 

1412# TODO: figure out why this function doesnt work, or maybe just get rid of it 

1413# def _remove_isolated_cells_old( 

1414# image: Int[np.ndarray, "RGB x y"], 

1415# ) -> Int[np.ndarray, "RGB x y"]: 

1416# """ 

1417# Removes isolated cells from an image. An isolated cell is a cell that is surrounded by walls on all sides. 

1418# """ 

1419# warnings.warn("this functin doesn't work and I have no idea why!!!") 

1420# masks: dict[str, np.ndarray] = { 

1421# d: np.all( 

1422# np.pad( 

1423# image[_RIC_SLICES[d][0], _RIC_SLICES[d][1], :] == PixelColors.WALL, 

1424# np.array((*_RIC_PADS[d], (0, 0)), dtype=np.int8), 

1425# mode="constant", 

1426# constant_values=True, 

1427# ), 

1428# axis=2, 

1429# ) 

1430# for d in _RIC_SLICES.keys() 

1431# } 

1432 

1433# # Create a mask for non-wall cells 

1434# mask_non_wall = np.all(image != PixelColors.WALL, axis=2) 

1435 

1436# # print(f"{mask_non_wall.shape = }") 

1437# # print(f"{ {k: masks[k].shape for k in masks.keys()} = }") 

1438 

1439# # print(f"{mask_non_wall = }") 

1440# # print(f"{masks['down'] = }") 

1441 

1442# # Combine the masks 

1443# mask = mask_non_wall & masks["left"] & masks["right"] & masks["up"] & masks["down"] 

1444 

1445# # Apply the mask 

1446# output_image = np.where( 

1447# np.stack([mask] * 3, axis=-1), 

1448# PixelColors.WALL, 

1449# image, 

1450# ) 

1451 

1452# return output_image