Coverage for maze_dataset\plotting\plot_maze.py: 84%
212 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
1"""provides `MazePlot`, which has many tools for plotting mazes with multiple paths, colored nodes, and more"""
3from __future__ import annotations # for type hinting self as return value
5import warnings
6from copy import deepcopy
7from dataclasses import dataclass
9import matplotlib as mpl
10import matplotlib.pyplot as plt
11import numpy as np
12from jaxtyping import Bool, Float
14from maze_dataset.constants import Coord, CoordArray, CoordList
15from maze_dataset.maze import (
16 LatticeMaze,
17 SolvedMaze,
18 TargetedLatticeMaze,
19)
# large negative sentinel value (-1e10); not referenced in the code shown here,
# presumably used by importers of this module -- confirm before removing
LARGE_NEGATIVE_NUMBER: float = -1.0e10
24@dataclass(kw_only=True)
25class PathFormat:
26 """formatting options for path plot"""
28 label: str | None = None
29 fmt: str = "o"
30 color: str | None = None
31 cmap: str | None = None
32 line_width: float | None = None
33 quiver_kwargs: dict | None = None
35 def combine(self, other: PathFormat) -> PathFormat:
36 """combine with other PathFormat object, overwriting attributes with non-None values.
38 returns a modified copy of self.
39 """
40 output: PathFormat = deepcopy(self)
41 for key, value in other.__dict__.items():
42 if key == "path":
43 raise ValueError(
44 f"Cannot overwrite path attribute! {self = }, {other = }"
45 )
46 if value is not None:
47 setattr(output, key, value)
49 return output
@dataclass
class StyledPath(PathFormat):
    """styled path: a `PathFormat` that also carries the coordinates of the path it styles"""

    path: CoordArray
# default formatting applied to raw (non-`StyledPath`) paths, keyed by path kind:
# "true" for the true/solution path, "predicted" for predicted paths
DEFAULT_FORMATS: dict[str, PathFormat] = dict(
    true=PathFormat(
        label="true path",
        fmt="--",
        color="red",
        line_width=2.5,
        quiver_kwargs=None,
    ),
    predicted=PathFormat(
        label=None,
        fmt=":",
        color=None,
        line_width=2,
        quiver_kwargs={"width": 0.015},
    ),
)
def process_path_input(
    path: CoordList | CoordArray | StyledPath,
    _default_key: str,
    path_fmt: PathFormat | None = None,
    **kwargs,
) -> StyledPath:
    """Normalize a path argument into a `StyledPath` with resolved formatting.

    Formatting is applied in layers, later layers overriding earlier ones:
     1. `DEFAULT_FORMATS[_default_key]` -- only when `path` is a raw
        list/tuple/array; a `StyledPath` keeps its own formatting
     2. the non-None attributes of `path_fmt`
     3. any `**kwargs`, set directly as attributes

    # Parameters:
     - `path` : a list/tuple of coords, a coord array, or a `StyledPath`
     - `_default_key` : key into `DEFAULT_FORMATS` for raw paths
     - `path_fmt` : optional extra formatting to combine in
     - `**kwargs` : attribute overrides set on the resulting `StyledPath`

    # Returns:
    a new `StyledPath`. A `StyledPath` passed in is copied, never mutated.

    # Raises:
     - `TypeError` : if `path` is not a list, tuple, numpy array or `StyledPath`
    """
    styled_path: StyledPath
    if isinstance(path, StyledPath):
        # copy, so that the overrides below (and callers that fill in defaults
        # afterwards) never mutate the caller's object
        styled_path = deepcopy(path)
    elif isinstance(path, (np.ndarray, list, tuple)):
        styled_path = StyledPath(path=np.asarray(path))
        # add default formatting
        styled_path = styled_path.combine(DEFAULT_FORMATS[_default_key])
    else:
        raise TypeError(
            f"Expected CoordList, CoordArray or StyledPath, got {type(path)}: {path}"
        )

    # add formatting from path_fmt
    if path_fmt is not None:
        styled_path = styled_path.combine(path_fmt)

    # add formatting from kwargs
    for key, value in kwargs.items():
        setattr(styled_path, key, value)

    return styled_path
class MazePlot:
    """Class for displaying mazes and paths

    Collects a maze, an optional true path, any number of predicted paths,
    optional per-node values (drawn with a colormap) and marked coordinates,
    then draws everything onto a matplotlib axis via `plot()`. The `add_*`
    and `mark_coords` methods return `self` so calls can be chained.
    """

    DEFAULT_PREDICTED_PATH_COLORS: list[str] = [
        "tab:orange",
        "tab:olive",
        "sienna",
        "mediumseagreen",
        "tab:purple",
        "slategrey",
    ]

    def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
        """
        UNIT_LENGTH: Set ratio between node size and wall thickness in image.
        Wall thickness is fixed to 1px
        A "unit" consists of a single node and the right and lower connection/wall.
        Example: ul = 14 yields 13:1 ratio between node size and wall thickness

        If `maze` is a `SolvedMaze`, its solution is added as the true path;
        otherwise, if it is a `TargetedLatticeMaze`, it is solved and that
        solution is added as the true path.
        """
        self.unit_length: int = unit_length
        self.maze: LatticeMaze = maze
        self.true_path: StyledPath | None = None
        self.predicted_paths: list[StyledPath] = []
        self.node_values: Float[np.ndarray, "grid_n grid_n"] | None = None
        self.custom_node_value_flag: bool = False
        self.node_color_map: str = "Blues"
        self.target_token_coord: Coord | None = None
        self.preceding_tokens_coords: CoordArray | None = None
        self.colormap_center: float | None = None
        # set by `add_node_values`; initialized here so the attributes always exist
        self.colormap_max: float | None = None
        self.hide_colorbar: bool = False
        self.cbar_ax = None
        self.marked_coords: list[tuple[Coord, dict]] = list()

        # marker styles used by `add_node_values` for the preceding / target coords
        self.marker_kwargs_current: dict = dict(
            marker="s",
            color="green",
            ms=12,
        )
        self.marker_kwargs_next: dict = dict(
            marker="P",
            color="green",
            ms=12,
        )

        if isinstance(maze, SolvedMaze):
            self.add_true_path(maze.solution)
        elif isinstance(maze, TargetedLatticeMaze):
            self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)

    @property
    def solved_maze(self) -> SolvedMaze:
        """`SolvedMaze` built from `self.maze` and the true path.

        # Raises:
         - `ValueError` : if no true path has been added
        """
        if self.true_path is None:
            raise ValueError(
                "Cannot return SolvedMaze object without true path. Add true path with add_true_path method."
            )
        return SolvedMaze.from_lattice_maze(
            lattice_maze=self.maze,
            solution=self.true_path.path,
        )

    def add_true_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        """Set (replacing any previous one) the true path.

        Raw paths get the `"true"` default formatting; `path_fmt` and `**kwargs`
        override it. Returns `self`.
        """
        self.true_path = process_path_input(
            path=path,
            _default_key="true",
            path_fmt=path_fmt,
            **kwargs,
        )

        return self

    def add_predicted_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        """
        Receive a predicted path and formatting preferences and append it to
        `self.predicted_paths`.

        If no label/color ends up specified, the label defaults to
        `predicted path {k}` and the color is cycled from
        `DEFAULT_PREDICTED_PATH_COLORS`, both based on how many predicted paths
        are already stored. Returns `self`.
        """
        styled_path: StyledPath = process_path_input(
            path=path,
            _default_key="predicted",
            path_fmt=path_fmt,
            **kwargs,
        )

        # set default label and color if not specified
        if styled_path.label is None:
            styled_path.label = f"predicted path {len(self.predicted_paths) + 1}"

        if styled_path.color is None:
            color_num: int = len(self.predicted_paths) % len(
                self.DEFAULT_PREDICTED_PATH_COLORS
            )
            styled_path.color = self.DEFAULT_PREDICTED_PATH_COLORS[color_num]

        self.predicted_paths.append(styled_path)
        return self

    def add_multiple_paths(
        self, path_list: list[CoordList | CoordArray | StyledPath]
    ) -> MazePlot:
        """
        Add several predicted paths at once.

        Each element of `path_list` is passed to `add_predicted_path` without
        extra formatting. Elements may be coordinate lists or arrays (which get
        the default predicted-path formatting) or `StyledPath`s (which keep
        their own formatting). Returns `self`.
        """
        for path in path_list:
            self.add_predicted_path(path)
        return self

    def add_node_values(
        self,
        node_values: Float[np.ndarray, "grid_n grid_n"],
        color_map: str = "Blues",
        target_token_coord: Coord | None = None,
        preceeding_tokens_coords: CoordArray = None,
        colormap_center: float | None = None,
        colormap_max: float | None = None,
        hide_colorbar: bool = False,
    ) -> MazePlot:
        """Attach per-node values, drawn with a colormap instead of the plain maze.

        # Parameters:
         - `node_values` : array with shape equal to `self.maze.grid_shape`
         - `color_map` : name of a matplotlib colormap
         - `target_token_coord` : if given, marked with `marker_kwargs_next`
         - `preceeding_tokens_coords` : if given, each coord is marked with
           `marker_kwargs_current`
         - `colormap_center` : if given, a two-slope norm centered here is used
         - `colormap_max` : if given, overrides the colormap's upper limit
           (e.g. for a consistent colorbar across multiple plots)
         - `hide_colorbar` : if True, no colorbar is drawn

        # Returns: `self`
        """
        assert node_values.shape == self.maze.grid_shape, (
            "Please pass node values of the same sape as LatticeMaze.grid_shape"
        )

        self.node_values = node_values
        # Set flag for choosing cmap while plotting maze
        self.custom_node_value_flag = True
        self.node_color_map = color_map
        self.colormap_center = colormap_center
        self.colormap_max = colormap_max
        self.hide_colorbar = hide_colorbar

        if target_token_coord is not None:
            self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
        if preceeding_tokens_coords is not None:
            for coord in preceeding_tokens_coords:
                self.marked_coords.append((coord, self.marker_kwargs_current))
        return self

    def plot(
        self,
        dpi: int = 100,
        title: str = "",
        fig_ax: tuple | None = None,
        plain: bool = False,
    ) -> MazePlot:
        """Plot the maze and paths.

        # Parameters:
         - `dpi` : resolution of a newly created figure (unused if `fig_ax` is given)
         - `title` : axis title (only set when `plain` is False)
         - `fig_ax` : optional `(figure, axes)` pair to draw onto instead of a
           new figure
         - `plain` : if True, skip tick labels, axis labels and the title

        # Returns: `self`, with `self.fig` and `self.ax` set
        """
        # set up figure
        if fig_ax is None:
            self.fig = plt.figure(dpi=dpi)
            self.ax = self.fig.add_subplot(1, 1, 1)
        else:
            self.fig, self.ax = fig_ax

        # plot maze
        self._plot_maze()

        # Plot labels
        if not plain:
            tick_arr = np.arange(self.maze.grid_shape[0])
            self.ax.set_xticks(self.unit_length * (tick_arr + 0.5), tick_arr)
            self.ax.set_yticks(self.unit_length * (tick_arr + 0.5), tick_arr)
            self.ax.set_xlabel("col")
            self.ax.set_ylabel("row")
            self.ax.set_title(title)

        # plot paths
        if self.true_path is not None:
            self._plot_path(self.true_path)
        for path in self.predicted_paths:
            self._plot_path(path)

        # plot markers
        for coord, kwargs in self.marked_coords:
            self._place_marked_coords([coord], **kwargs)

        return self

    def _rowcol_to_coord(self, point: Coord) -> np.ndarray:
        """Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis.

        The result is the pixel center of that node in the image built by
        `_lattice_maze_to_img`.
        """
        point = np.array([point[1], point[0]])
        return self.unit_length * (point + 0.5)

    def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
        """Queue `coords` to be drawn as markers by `plot()`.

        `kwargs` are forwarded to `ax.plot`; they default to `marker="+"`,
        `color="blue"`. Returns `self`.
        """
        kwargs = {
            **dict(marker="+", color="blue"),
            **kwargs,
        }
        for coord in coords:
            self.marked_coords.append((coord, kwargs))

        return self

    def _place_marked_coords(
        self, coords: CoordArray | list[Coord], **kwargs
    ) -> MazePlot:
        """Draw markers at `coords` onto `self.ax`, forwarding `kwargs` to `ax.plot`."""
        coords_tp = np.array([self._rowcol_to_coord(coord) for coord in coords])
        self.ax.plot(coords_tp[:, 0], coords_tp[:, 1], **kwargs)

        return self

    @staticmethod
    def _colormap_limits(
        node_values: np.ndarray,
        colormap_center: float | None,
        colormap_max: float | None,
    ) -> tuple[float, float]:
        """Compute `(vmin, vmax)` for the node-value colormap.

        - the range always includes 0: all-negative values get `vmax = 0`,
          all-positive values get `vmin = 0`
        - `colormap_max`, if not None (an explicit `0.0` included), replaces `vmax`
        - if `colormap_center` equals a limit, that limit is nudged by 1e-10 so
          the center lies strictly inside the range

        # Raises:
         - `ValueError` : if `colormap_center` lies outside `[vmin, vmax]`
        """
        vals_min: float = np.nanmin(node_values)
        vals_max: float = np.nanmax(node_values)
        # if both are negative or both are positive, set max/min to 0
        if vals_max < 0.0:
            vals_max = 0.0
        elif vals_min > 0.0:
            vals_min = 0.0

        # adjust vals_max, in case you need consistent colorbar across multiple plots.
        # `is not None` rather than truthiness, so an explicit 0.0 is respected
        if colormap_max is not None:
            vals_max = colormap_max

        if colormap_center is not None:
            if not (vals_min < colormap_center < vals_max):
                if vals_min == colormap_center:
                    vals_min -= 1e-10
                elif vals_max == colormap_center:
                    vals_max += 1e-10
                else:
                    raise ValueError(
                        f"Please pass colormap_center value between {vals_min} and {vals_max}"
                    )

        return vals_min, vals_max

    @staticmethod
    def _colorbar_ticks(
        vals_min: float,
        vals_max: float,
        colormap_center: float | None,
    ) -> np.ndarray:
        """Five evenly spaced colorbar ticks, plus 0 and `colormap_center` when
        they lie strictly inside the range and are not already ticks."""
        ticks = np.linspace(vals_min, vals_max, 5)

        if (vals_min < 0.0 < vals_max) and (0.0 not in ticks):
            ticks = np.insert(ticks, np.searchsorted(ticks, 0.0), 0.0)

        if (
            colormap_center is not None
            and colormap_center not in ticks
            and vals_min < colormap_center < vals_max
        ):
            ticks = np.insert(
                ticks,
                np.searchsorted(ticks, colormap_center),
                colormap_center,
            )

        return ticks

    def _plot_maze(self) -> None:
        """
        Define Colormap and plot maze.

        - without node values: grayscale image with `vmin=-1, vmax=1`, so walls
          (-1) are black, nodes (1) white and connections (0.93) near-white
        - with node values: `self.node_color_map` colormap; walls are NaN in
          the image and are drawn black via the colormap's "bad" color. A
          colorbar is added unless `self.hide_colorbar` is set.
        """
        img = self._lattice_maze_to_img()

        # if no node_values have been passed (no colormap)
        if self.custom_node_value_flag is False:
            self.ax.imshow(img, cmap="gray", vmin=-1, vmax=1)

        else:
            assert self.node_values is not None, "Please pass node values."
            assert not np.isnan(self.node_values).any(), (
                "Please pass node values, they cannot be nan."
            )

            # create colormap
            cmap = mpl.colormaps[self.node_color_map]
            # TODO: this is a hack, we make the walls black (while still allowing negative values) by setting the nan color to black
            cmap.set_bad(color="black")

            vals_min, vals_max = self._colormap_limits(
                self.node_values, self.colormap_center, self.colormap_max
            )

            if self.colormap_center is not None:
                norm = mpl.colors.TwoSlopeNorm(
                    vmin=vals_min,
                    vcenter=self.colormap_center,
                    vmax=vals_max,
                )
                _plotted = self.ax.imshow(img, cmap=cmap, norm=norm)
            else:
                _plotted = self.ax.imshow(img, cmap=cmap, vmin=vals_min, vmax=vals_max)

            # Add colorbar based on the condition of self.hide_colorbar
            if not self.hide_colorbar:
                cbar = plt.colorbar(
                    _plotted,
                    ticks=self._colorbar_ticks(
                        vals_min, vals_max, self.colormap_center
                    ),
                    ax=self.ax,
                    cax=self.cbar_ax,
                )
                # reuse the same colorbar axis if the plot is redrawn
                self.cbar_ax = cbar.ax

        # make the boundaries of the image thicker (walls look weird without this)
        for axis in ["top", "bottom", "left", "right"]:
            self.ax.spines[axis].set_linewidth(2)

    def _lattice_maze_to_img(
        self,
        connection_val_scale: float = 0.93,
    ) -> Float[np.ndarray, "row col"]:
        """
        Build an image to visualise the maze.
        Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul.
        - Without node values (see `.add_node_values()`):
            - Nodes have area (ul-1) * (ul-1) and value 1
            - Walls have value -1 (the background)
            - Connections have area 1 * (ul-1) and value `connection_val_scale` (0.93 by default)
        - With node values:
            - Nodes are extended by one pixel down/right, so open connections take
              the node's value
            - Missing connections (walls) between nodes are set to NaN

        Axes definition:
        (0,0)     col
        ----|----------->
            |
        row |
            |
            v

        Returns a float matrix of shape (ul * n_rows + 1, ul * n_cols + 1).
        """
        # TODO: this is a hack, but if you add 1 always then non-node valued plots have their walls dissapear. if you dont add 1, you get ugly colors between nodes when they are colored
        node_bdry_hack: int
        connection_list_processed: Bool[np.ndarray, "dim row col"]
        # Set node and connection values
        if self.node_values is None:
            scaled_node_values = np.ones(self.maze.grid_shape)
            connection_values = scaled_node_values * connection_val_scale
            node_bdry_hack = 0
            # TODO: hack
            # invert connection list, so `not processed` below means "connected"
            # and the connection pixel is drawn with the connection value
            connection_list_processed = np.logical_not(self.maze.connection_list)
        else:
            # TODO: hack
            scaled_node_values = self.node_values
            # float NaN array regardless of node_values' dtype: `full_like` on an
            # integer array would cast NaN to a garbage integer instead of NaN
            connection_values = np.full(
                np.shape(scaled_node_values), np.nan, dtype=float
            )
            node_bdry_hack = 1
            # not inverted: `not processed` below means "not connected", and that
            # pixel (otherwise covered by the extended node) is set to NaN (a wall)
            connection_list_processed = self.maze.connection_list

        # Create background image (all pixels set to -1, walls everywhere)
        img: Float[np.ndarray, "row col"] = -np.ones(
            (
                self.maze.grid_shape[0] * self.unit_length + 1,
                self.maze.grid_shape[1] * self.unit_length + 1,
            ),
            dtype=float,
        )

        # Draw nodes and connections by iterating through lattice
        for row in range(self.maze.grid_shape[0]):
            for col in range(self.maze.grid_shape[1]):
                # Draw node
                img[
                    row * self.unit_length + 1 : (row + 1) * self.unit_length
                    + node_bdry_hack,
                    col * self.unit_length + 1 : (col + 1) * self.unit_length
                    + node_bdry_hack,
                ] = scaled_node_values[row, col]

                # Down connection
                if not connection_list_processed[0, row, col]:
                    img[
                        (row + 1) * self.unit_length,
                        col * self.unit_length + 1 : (col + 1) * self.unit_length,
                    ] = connection_values[row, col]

                # Right connection
                if not connection_list_processed[1, row, col]:
                    img[
                        row * self.unit_length + 1 : (row + 1) * self.unit_length,
                        (col + 1) * self.unit_length,
                    ] = connection_values[row, col]

        return img

    def _plot_path(self, path_format: StyledPath) -> None:
        """Draw one styled path onto `self.ax`.

        Drawn as arrows (`ax.quiver`) when `quiver_kwargs` is set, otherwise as
        a line with `fmt`. The start is marked with "o" and the end with "x".
        Empty paths are skipped with a warning.

        # Raises:
         - `ValueError` : if the transformed path cannot be split into x/y columns
        """
        if len(path_format.path) == 0:
            warnings.warn(f"Empty path, skipping plotting\n{path_format = }")
            return
        p_transformed = np.array(
            [self._rowcol_to_coord(coord) for coord in path_format.path]
        )
        if path_format.quiver_kwargs is not None:
            try:
                x: np.ndarray = p_transformed[:, 0]
                y: np.ndarray = p_transformed[:, 1]
            except Exception as e:
                raise ValueError(
                    f"Error in plotting quiver path:\n{path_format = }\n{p_transformed = }\n{e}"
                ) from e

            # Generate colors from the colormap
            if path_format.cmap is not None:
                n = len(x) - 1  # Number of arrows
                cmap = plt.get_cmap(path_format.cmap)
                colors = [cmap(i / n) for i in range(n)]
            else:
                colors = path_format.color

            # one arrow per consecutive pair of points, in data (xy) units
            self.ax.quiver(
                x[:-1],
                y[:-1],
                x[1:] - x[:-1],
                y[1:] - y[:-1],
                scale_units="xy",
                angles="xy",
                scale=1,
                color=colors,
                **path_format.quiver_kwargs,
            )
        else:
            self.ax.plot(
                p_transformed[:, 0],
                p_transformed[:, 1],
                path_format.fmt,
                lw=path_format.line_width,
                color=path_format.color,
                label=path_format.label,
            )
        # mark endpoints
        self.ax.plot(
            [p_transformed[0][0]],
            [p_transformed[0][1]],
            "o",
            color=path_format.color,
            ms=10,
        )
        self.ax.plot(
            [p_transformed[-1][0]],
            [p_transformed[-1][1]],
            "x",
            color=path_format.color,
            ms=10,
        )

    def to_ascii(
        self,
        show_endpoints: bool = True,
        show_solution: bool = True,
    ) -> str:
        """ASCII rendering of the maze; includes the true path if one was added."""
        if self.true_path is not None:
            return self.solved_maze.as_ascii(
                show_endpoints=show_endpoints, show_solution=show_solution
            )
        return self.maze.as_ascii(show_endpoints=show_endpoints)