Coverage for maze_dataset\plotting\plot_maze.py: 84%

212 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-23 12:49 -0700

1"""provides `MazePlot`, which has many tools for plotting mazes with multiple paths, colored nodes, and more""" 

2 

3from __future__ import annotations # for type hinting self as return value 

4 

5import warnings 

6from copy import deepcopy 

7from dataclasses import dataclass 

8 

9import matplotlib as mpl 

10import matplotlib.pyplot as plt 

11import numpy as np 

12from jaxtyping import Bool, Float 

13 

14from maze_dataset.constants import Coord, CoordArray, CoordList 

15from maze_dataset.maze import ( 

16 LatticeMaze, 

17 SolvedMaze, 

18 TargetedLatticeMaze, 

19) 

20 

# a "very negative" float sentinel
# NOTE(review): not referenced by the code visible in this module; presumably
# kept for external users -- confirm before removing
LARGE_NEGATIVE_NUMBER: float = -1.0e10

22 

23 

24@dataclass(kw_only=True) 

25class PathFormat: 

26 """formatting options for path plot""" 

27 

28 label: str | None = None 

29 fmt: str = "o" 

30 color: str | None = None 

31 cmap: str | None = None 

32 line_width: float | None = None 

33 quiver_kwargs: dict | None = None 

34 

35 def combine(self, other: PathFormat) -> PathFormat: 

36 """combine with other PathFormat object, overwriting attributes with non-None values. 

37 

38 returns a modified copy of self. 

39 """ 

40 output: PathFormat = deepcopy(self) 

41 for key, value in other.__dict__.items(): 

42 if key == "path": 

43 raise ValueError( 

44 f"Cannot overwrite path attribute! {self = }, {other = }" 

45 ) 

46 if value is not None: 

47 setattr(output, key, value) 

48 

49 return output 

50 

51 

# styled path: coordinates plus their formatting
@dataclass()
class StyledPath(PathFormat):
    """A sequence of coordinates bundled with the `PathFormat` used to draw it.

    `path` is the only field without a default. The formatting fields inherited
    from `PathFormat` are keyword-only, so `path` may still be given positionally.
    """

    path: CoordArray

56 

57 

# default styling per kind of path; `process_path_input` looks these up by its
# `_default_key` ("true" from `add_true_path`, "predicted" from `add_predicted_path`)
DEFAULT_FORMATS: dict[str, PathFormat] = dict(
    true=PathFormat(
        fmt="--",
        color="red",
        line_width=2.5,
        label="true path",
        quiver_kwargs=None,
    ),
    predicted=PathFormat(
        fmt=":",
        color=None,
        line_width=2,
        label=None,
        quiver_kwargs={"width": 0.015},
    ),
)

74 

75 

def process_path_input(
    path: CoordList | CoordArray | StyledPath,
    _default_key: str,
    path_fmt: PathFormat | None = None,
    **kwargs,
) -> StyledPath:
    """convert `path` into a fresh `StyledPath` with all formatting applied

    Formatting is layered in this order, later layers overriding earlier ones:

    1. `DEFAULT_FORMATS[_default_key]` -- applied only to raw coordinate input;
       a `StyledPath` input keeps its own formatting
    2. the non-None attributes of `path_fmt`, if given (see `PathFormat.combine`)
    3. every entry of `kwargs`, set as an attribute verbatim

    # Parameters:
     - `path : CoordList | CoordArray | StyledPath`
        a styled path, a numpy array of coordinates, or a list/tuple of coordinates
     - `_default_key : str`
        key into `DEFAULT_FORMATS`, e.g. `"true"` or `"predicted"`
     - `path_fmt : PathFormat | None`
        optional formatting overrides
     - `**kwargs`
        attribute overrides, applied last

    # Returns:
     - `StyledPath`
        never the caller's object: a `StyledPath` input is deep-copied, so
        setting attributes on the result (here, or in callers such as
        `MazePlot.add_predicted_path`) does not modify the input

    # Raises:
     - `TypeError` : if `path` is not one of the accepted types
    """
    styled_path: StyledPath
    if isinstance(path, StyledPath):
        # copy: the kwargs loop below and callers set attributes on the result
        styled_path = deepcopy(path)
    elif isinstance(path, (np.ndarray, list, tuple)):
        # wrap the raw coordinates, then add default formatting
        styled_path = StyledPath(path=np.asarray(path))
        styled_path = styled_path.combine(DEFAULT_FORMATS[_default_key])
    else:
        raise TypeError(
            f"Expected CoordList, CoordArray or StyledPath, got {type(path)}: {path}"
        )

    # add formatting from path_fmt
    if path_fmt is not None:
        styled_path = styled_path.combine(path_fmt)

    # add formatting from kwargs
    for key, value in kwargs.items():
        setattr(styled_path, key, value)

    return styled_path

107 

108 

class MazePlot:
    """Class for displaying mazes and paths

    Methods that add content (`add_true_path`, `add_predicted_path`,
    `add_multiple_paths`, `add_node_values`, `mark_coords`) return `self`,
    so calls can be chained before calling `plot()`.
    """

    DEFAULT_PREDICTED_PATH_COLORS: list[str] = [
        "tab:orange",
        "tab:olive",
        "sienna",
        "mediumseagreen",
        "tab:purple",
        "slategrey",
    ]

    def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
        """
        UNIT_LENGTH: Set ratio between node size and wall thickness in image.
        Wall thickness is fixed to 1px
        A "unit" consists of a single node and the right and lower connection/wall.
        Example: ul = 14 yields 13:1 ratio between node size and wall thickness

        If `maze` is a `SolvedMaze`, its solution is added as the true path; if it
        is a `TargetedLatticeMaze`, it is solved first and that solution is added.
        """
        self.unit_length: int = unit_length
        self.maze: LatticeMaze = maze
        self.true_path: StyledPath | None = None
        self.predicted_paths: list[StyledPath] = []
        self.node_values: Float[np.ndarray, "grid_n grid_n"] | None = None
        self.custom_node_value_flag: bool = False
        self.node_color_map: str = "Blues"
        # NOTE(review): these two are never updated by this class;
        # `add_node_values` records coordinates in `marked_coords` instead
        self.target_token_coord: Coord | None = None
        self.preceding_tokens_coords: CoordArray | None = None
        self.colormap_center: float | None = None
        # initialised here (not only in `add_node_values`) so that every
        # attribute `_plot_maze` reads always exists
        self.colormap_max: float | None = None
        self.hide_colorbar: bool = False
        self.cbar_ax = None
        self.marked_coords: list[tuple[Coord, dict]] = list()

        # marker styles used by `add_node_values`
        self.marker_kwargs_current: dict = dict(
            marker="s",
            color="green",
            ms=12,
        )
        self.marker_kwargs_next: dict = dict(
            marker="P",
            color="green",
            ms=12,
        )

        if isinstance(maze, SolvedMaze):
            self.add_true_path(maze.solution)
        elif isinstance(maze, TargetedLatticeMaze):
            self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)

    @property
    def solved_maze(self) -> SolvedMaze:
        """`SolvedMaze` built from `self.maze` with the true path as its solution.

        # Raises:
         - `ValueError` : if no true path has been added
        """
        if self.true_path is None:
            raise ValueError(
                "Cannot return SolvedMaze object without true path. Add true path with add_true_path method."
            )
        return SolvedMaze.from_lattice_maze(
            lattice_maze=self.maze,
            solution=self.true_path.path,
        )

    def add_true_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        """Set (replacing any previous one) the true path, styled via the
        `"true"` default format, then `path_fmt`, then `kwargs`."""
        self.true_path = process_path_input(
            path=path,
            _default_key="true",
            path_fmt=path_fmt,
            **kwargs,
        )

        return self

    def add_predicted_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        """
        Receive a predicted path and formatting preferences and append it to `predicted_paths`.

        If no label/color is given, the default label is numbered and the default
        color cycles through `DEFAULT_PREDICTED_PATH_COLORS`, both based on the
        number of predicted paths already stored.
        """
        styled_path: StyledPath = process_path_input(
            path=path,
            _default_key="predicted",
            path_fmt=path_fmt,
            **kwargs,
        )

        # set default label and color if not specified
        if styled_path.label is None:
            styled_path.label = f"predicted path {len(self.predicted_paths) + 1}"

        if styled_path.color is None:
            color_num: int = len(self.predicted_paths) % len(
                self.DEFAULT_PREDICTED_PATH_COLORS
            )
            styled_path.color = self.DEFAULT_PREDICTED_PATH_COLORS[color_num]

        self.predicted_paths.append(styled_path)
        return self

    def add_multiple_paths(
        self, path_list: list[CoordList | CoordArray | StyledPath]
    ) -> MazePlot:
        """
        Add several predicted paths at once.

        Each element of `path_list` is passed to `add_predicted_path` with no
        extra formatting. Elements may be plain coordinate lists/arrays (which
        get the default predicted-path styling) or `StyledPath` objects (which
        keep their own styling).
        """
        for path in path_list:
            self.add_predicted_path(path)
        return self

    def add_node_values(
        self,
        node_values: Float[np.ndarray, "grid_n grid_n"],
        color_map: str = "Blues",
        target_token_coord: Coord | None = None,
        preceeding_tokens_coords: CoordArray | None = None,
        colormap_center: float | None = None,
        colormap_max: float | None = None,
        hide_colorbar: bool = False,
    ) -> MazePlot:
        """Color the maze nodes by `node_values` instead of plain white.

        # Parameters:
         - `node_values` : array with the same shape as `self.maze.grid_shape`;
           must not contain nan (checked when plotting)
         - `color_map` : name of the matplotlib colormap to use
         - `target_token_coord` : if given, marked with `marker_kwargs_next`
         - `preceeding_tokens_coords` : if given, each marked with `marker_kwargs_current`
         - `colormap_center` : if given, the color scale is a two-slope norm centered here
         - `colormap_max` : if given, used as the top of the color scale
           (e.g. for a consistent colorbar across several plots)
         - `hide_colorbar` : skip drawing the colorbar
        """
        assert node_values.shape == self.maze.grid_shape, (
            "Please pass node values of the same sape as LatticeMaze.grid_shape"
        )
        # assert np.min(node_values) >= 0, "Please pass non-negative node values only."

        self.node_values = node_values
        # Set flag for choosing cmap while plotting maze
        self.custom_node_value_flag = True
        self.node_color_map = color_map
        self.colormap_center = colormap_center
        self.colormap_max = colormap_max
        self.hide_colorbar = hide_colorbar

        if target_token_coord is not None:
            self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
        if preceeding_tokens_coords is not None:
            for coord in preceeding_tokens_coords:
                self.marked_coords.append((coord, self.marker_kwargs_current))
        return self

    def plot(
        self,
        dpi: int = 100,
        title: str = "",
        fig_ax: tuple | None = None,
        plain: bool = False,
    ) -> MazePlot:
        """Plot the maze, the paths and the marked coordinates.

        # Parameters:
         - `dpi` : resolution of the new figure (ignored if `fig_ax` is given)
         - `title` : axes title (not drawn when `plain`)
         - `fig_ax` : existing `(figure, axes)` pair to draw into; a new figure is
           created if `None`
         - `plain` : if true, skip ticks, axis labels and title
        """

        # set up figure
        if fig_ax is None:
            self.fig = plt.figure(dpi=dpi)
            self.ax = self.fig.add_subplot(1, 1, 1)
        else:
            self.fig, self.ax = fig_ax

        # plot maze
        self._plot_maze()

        # Plot labels
        if not plain:
            # x is the horizontal axis (columns), y the vertical one (rows);
            # use the matching grid dimension so non-square mazes are labelled correctly
            row_ticks = np.arange(self.maze.grid_shape[0])
            col_ticks = np.arange(self.maze.grid_shape[1])
            self.ax.set_xticks(self.unit_length * (col_ticks + 0.5), col_ticks)
            self.ax.set_yticks(self.unit_length * (row_ticks + 0.5), row_ticks)
            self.ax.set_xlabel("col")
            self.ax.set_ylabel("row")
            self.ax.set_title(title)

        # plot paths
        if self.true_path is not None:
            self._plot_path(self.true_path)
        for path in self.predicted_paths:
            self._plot_path(path)

        # plot markers
        for coord, kwargs in self.marked_coords:
            self._place_marked_coords([coord], **kwargs)

        return self

    def _rowcol_to_coord(self, point: Coord) -> np.ndarray:
        """Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis."""
        point = np.array([point[1], point[0]])
        return self.unit_length * (point + 0.5)

    def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
        """Queue `coords` to be drawn as markers by `plot()`.

        `kwargs` are passed to `Axes.plot` and override the defaults
        `marker="+"`, `color="blue"`.
        """
        kwargs = {
            **dict(marker="+", color="blue"),
            **kwargs,
        }
        for coord in coords:
            self.marked_coords.append((coord, kwargs))

        return self

    def _place_marked_coords(
        self, coords: CoordArray | list[Coord], **kwargs
    ) -> MazePlot:
        """Draw markers at `coords` (row, col) on `self.ax`."""
        coords_tp = np.array([self._rowcol_to_coord(coord) for coord in coords])
        self.ax.plot(coords_tp[:, 0], coords_tp[:, 1], **kwargs)

        return self

    def _plot_maze(self) -> None:
        """
        Draw the maze image onto `self.ax`.

        - without node values: grayscale image on the fixed range [-1, 1]
        - with node values (see `add_node_values`): colormap `self.node_color_map`,
          with nan pixels (which `_lattice_maze_to_img` writes for walls between
          unconnected nodes) drawn black, an optional two-slope norm around
          `self.colormap_center`, and a colorbar unless `self.hide_colorbar`
        """
        img = self._lattice_maze_to_img()

        # if no node_values have been passed (no colormap)
        if self.custom_node_value_flag is False:
            self.ax.imshow(img, cmap="gray", vmin=-1, vmax=1)

        else:
            assert self.node_values is not None, "Please pass node values."
            assert not np.isnan(self.node_values).any(), (
                "Please pass node values, they cannot be nan."
            )

            vals_min: float = np.nanmin(self.node_values)
            vals_max: float = np.nanmax(self.node_values)
            # if both are negative or both are positive, set max/min to 0
            if vals_max < 0.0:
                vals_max = 0.0
            elif vals_min > 0.0:
                vals_min = 0.0

            # adjust vals_max, in case you need consistent colorbar across multiple plots.
            # compare against None explicitly so that `colormap_max=0.0` is respected
            if self.colormap_max is not None:
                vals_max = self.colormap_max

            # create colormap
            cmap = mpl.colormaps[self.node_color_map]
            # TODO: this is a hack, we make the walls black (while still allowing negative values) by setting the nan color to black
            cmap.set_bad(color="black")

            if self.colormap_center is not None:
                if not (vals_min < self.colormap_center < vals_max):
                    # TwoSlopeNorm needs vmin < vcenter < vmax; nudge an
                    # endpoint that coincides with the center
                    if vals_min == self.colormap_center:
                        vals_min -= 1e-10
                    elif vals_max == self.colormap_center:
                        vals_max += 1e-10
                    else:
                        raise ValueError(
                            f"Please pass colormap_center value between {vals_min} and {vals_max}"
                        )

                norm = mpl.colors.TwoSlopeNorm(
                    vmin=vals_min,
                    vcenter=self.colormap_center,
                    vmax=vals_max,
                )
                _plotted = self.ax.imshow(img, cmap=cmap, norm=norm)
            else:
                _plotted = self.ax.imshow(img, cmap=cmap, vmin=vals_min, vmax=vals_max)

            # Add colorbar based on the condition of self.hide_colorbar
            if not self.hide_colorbar:
                ticks = np.linspace(vals_min, vals_max, 5)

                # make sure 0 and the colormap center get their own tick
                if (vals_min < 0.0 < vals_max) and (0.0 not in ticks):
                    ticks = np.insert(ticks, np.searchsorted(ticks, 0.0), 0.0)

                if (
                    self.colormap_center is not None
                    and self.colormap_center not in ticks
                    and vals_min < self.colormap_center < vals_max
                ):
                    ticks = np.insert(
                        ticks,
                        np.searchsorted(ticks, self.colormap_center),
                        self.colormap_center,
                    )

                # reuse the colorbar axes from a previous call, if any
                cbar = plt.colorbar(
                    _plotted,
                    ticks=ticks,
                    ax=self.ax,
                    cax=self.cbar_ax,
                )
                self.cbar_ax = cbar.ax

        # make the boundaries of the image thicker (walls look weird without this)
        for axis in ["top", "bottom", "left", "right"]:
            self.ax.spines[axis].set_linewidth(2)

    def _lattice_maze_to_img(
        self,
        connection_val_scale: float = 0.93,
    ) -> Float[np.ndarray, "row col"]:
        """
        Build an image to visualise the maze.
        Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul.
        - Nodes have area: (ul-1) * (ul-1) and value 1 by default
            - take node_value if passed via .add_node_values()
        - Walls have area: 1 * (ul-1) and value -1
        - Connections have area: 1 * (ul-1); color and value 0.93 by default
            - take node_value if passed via .add_node_values()

        Axes definition:
        (0,0)     col
        ----|----------->
            |
        row |
            |
            v

        Returns a float array of shape (ul * n_rows + 1, ul * n_cols + 1),
        where (n_rows, n_cols) is `self.maze.grid_shape`.
        """

        # TODO: this is a hack, but if you add 1 always then non-node valued plots have their walls disappear. if you dont add 1, you get ugly colors between nodes when they are colored
        node_bdry_hack: int
        connection_list_processed: Float[np.ndarray, "dim row col"]
        # Set node and connection values
        if self.node_values is None:
            scaled_node_values = np.ones(self.maze.grid_shape)
            connection_values = scaled_node_values * connection_val_scale
            node_bdry_hack = 0
            # TODO: hack
            # invert connection list, so that connections (not walls) get drawn below
            connection_list_processed = np.logical_not(self.maze.connection_list)
        else:
            # TODO: hack
            # nodes are enlarged by one pixel (covering the connection pixels);
            # the pixels drawn below are then the walls, set to nan
            scaled_node_values = self.node_values
            connection_values = np.full_like(scaled_node_values, np.nan)
            node_bdry_hack = 1
            connection_list_processed = self.maze.connection_list

        # Create background image (all pixels set to -1, walls everywhere)
        img: Float[np.ndarray, "row col"] = -np.ones(
            (
                self.maze.grid_shape[0] * self.unit_length + 1,
                self.maze.grid_shape[1] * self.unit_length + 1,
            ),
            dtype=float,
        )

        # Draw nodes and connections by iterating through lattice
        for row in range(self.maze.grid_shape[0]):
            for col in range(self.maze.grid_shape[1]):
                # Draw node
                img[
                    row * self.unit_length + 1 : (row + 1) * self.unit_length
                    + node_bdry_hack,
                    col * self.unit_length + 1 : (col + 1) * self.unit_length
                    + node_bdry_hack,
                ] = scaled_node_values[row, col]

                # Down connection
                if not connection_list_processed[0, row, col]:
                    img[
                        (row + 1) * self.unit_length,
                        col * self.unit_length + 1 : (col + 1) * self.unit_length,
                    ] = connection_values[row, col]

                # Right connection
                if not connection_list_processed[1, row, col]:
                    img[
                        row * self.unit_length + 1 : (row + 1) * self.unit_length,
                        (col + 1) * self.unit_length,
                    ] = connection_values[row, col]

        return img

    def _plot_path(self, path_format: StyledPath) -> None:
        """Draw one styled path: arrows if `quiver_kwargs` is set, otherwise a
        line, plus an "o" marker at the start and an "x" marker at the end."""
        if len(path_format.path) == 0:
            warnings.warn(f"Empty path, skipping plotting\n{path_format = }")
            return
        p_transformed = np.array(
            [self._rowcol_to_coord(coord) for coord in path_format.path]
        )
        if path_format.quiver_kwargs is not None:
            try:
                x: np.ndarray = p_transformed[:, 0]
                y: np.ndarray = p_transformed[:, 1]
            except Exception as e:
                raise ValueError(
                    f"Error in plotting quiver path:\n{path_format = }\n{p_transformed = }\n{e}"
                ) from e

            # Generate colors from the colormap
            if path_format.cmap is not None:
                n = len(x) - 1  # Number of arrows
                cmap = plt.get_cmap(path_format.cmap)
                colors = [cmap(i / n) for i in range(n)]
            else:
                colors = path_format.color

            # one arrow per consecutive pair of points, in data coordinates
            self.ax.quiver(
                x[:-1],
                y[:-1],
                x[1:] - x[:-1],
                y[1:] - y[:-1],
                scale_units="xy",
                angles="xy",
                scale=1,
                color=colors,
                **path_format.quiver_kwargs,
            )
        else:
            self.ax.plot(
                p_transformed[:, 0],
                p_transformed[:, 1],
                path_format.fmt,
                lw=path_format.line_width,
                color=path_format.color,
                label=path_format.label,
            )
        # mark endpoints
        self.ax.plot(
            [p_transformed[0][0]],
            [p_transformed[0][1]],
            "o",
            color=path_format.color,
            ms=10,
        )
        self.ax.plot(
            [p_transformed[-1][0]],
            [p_transformed[-1][1]],
            "x",
            color=path_format.color,
            ms=10,
        )

    def to_ascii(
        self,
        show_endpoints: bool = True,
        show_solution: bool = True,
    ) -> str:
        """Render the maze as ASCII art.

        Uses `solved_maze` when a true path is set; otherwise renders the bare
        maze, in which case `show_solution` is ignored.
        """
        if self.true_path is not None:
            return self.solved_maze.as_ascii(
                show_endpoints=show_endpoints, show_solution=show_solution
            )
        return self.maze.as_ascii(show_endpoints=show_endpoints)