Coverage for tests\unit\maze_dataset\plotting\test_maze_plot.py: 100%
39 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
1"""mostly taken from `demo_latticemaze.ipynb`"""
3import os
5import matplotlib.pyplot as plt
6import numpy as np
8from maze_dataset.generation import LatticeMazeGenerators
9from maze_dataset.maze import SolvedMaze, TargetedLatticeMaze
10from maze_dataset.plotting import MazePlot
# Output directory for the PNG figures saved by `test_maze_plot`. It is a
# relative path, so it resolves against the working directory pytest runs
# from. Keep the trailing slash: output paths may be built by plain string
# concatenation with a filename.
FIG_SAVE: str = "tests/_temp/figures/"
def test_maze_plot():
    """Smoke-test maze rendering and pixel/ASCII round-tripping.

    Generates an ``N x N`` DFS maze plus its targeted and solved variants.
    For each variant it checks that ``as_pixels``/``from_pixels`` and
    ``as_ascii``/``from_ascii`` round-trip to an equal maze. It then renders
    several ``MazePlot`` configurations and saves each one as a PNG under
    ``FIG_SAVE``. The configurations are:

    - pathless
    - solved maze
    - true plus predicted paths
    - node values, without and with target/preceding tokens
    - multiple paths

    Each figure is closed right after it is saved. Any figures still open
    are closed on exit, even if an assertion fails, so the test does not
    leak matplotlib figures.
    """
    N: int = 10

    def _save_and_close(filename: str) -> None:
        """Save the current matplotlib figure under ``FIG_SAVE``, then close it."""
        plt.savefig(os.path.join(FIG_SAVE, filename))
        plt.close()

    os.makedirs(FIG_SAVE, exist_ok=True)

    maze = LatticeMazeGenerators.gen_dfs(np.array([N, N]))
    tgt_maze: TargetedLatticeMaze = TargetedLatticeMaze.from_lattice_maze(
        maze, (0, 0), (N - 1, N - 1)
    )
    solved_maze: SolvedMaze = SolvedMaze.from_targeted_lattice_maze(tgt_maze)

    try:
        # Show the three mazes' pixel renderings side by side, each titled
        # with its ASCII form. Also check that both serializations
        # round-trip back to an equal maze of the same class.
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax_i, temp_maze in zip(axes, [maze, tgt_maze, solved_maze]):
            pixels = temp_maze.as_pixels()
            ascii_repr = temp_maze.as_ascii()
            ax_i.set_title(ascii_repr, fontfamily="monospace")
            ax_i.imshow(pixels)

            maze_cls = temp_maze.__class__
            assert temp_maze == maze_cls.from_pixels(pixels)
            assert temp_maze == maze_cls.from_ascii(ascii_repr)
        fig.savefig(os.path.join(FIG_SAVE, "pixels_and_ascii.png"))
        plt.close(fig)

        MazePlot(maze).plot()
        _save_and_close("mazeplot-pathless.png")

        true_path = maze.find_shortest_path(c_start=(0, 0), c_end=(3, 3))

        MazePlot(solved_maze).plot()
        _save_and_close("mazeplot-solvedmaze.png")

        # Hand-written "predicted" paths. This test only draws them; it does
        # not check them against the maze's connectivity.
        pred_path1 = [(0, 0), (1, 0), (2, 0), (3, 0), (3, 1), (3, 2), (3, 3)]
        pred_path2 = [
            (0, 0),
            (0, 1),
            (0, 2),
            (0, 3),
            (1, 3),
            (2, 3),
            (2, 2),
            (3, 2),
            (3, 3),
        ]
        (
            MazePlot(maze)
            .add_true_path(true_path)
            .add_predicted_path(pred_path1)
            .add_predicted_path(pred_path2)
            .plot()
        )
        _save_and_close("mazeplot-fakepaths.png")

        node_values = np.random.uniform(size=maze.grid_shape)

        MazePlot(maze).add_node_values(node_values, color_map="Blues").plot()
        _save_and_close("mazeplot-nodevalues.png")

        MazePlot(maze).add_node_values(
            node_values,
            color_map="Blues",
            target_token_coord=np.array([2, 0]),
            # keyword spelled exactly as the MazePlot API expects it
            preceeding_tokens_coords=np.array([[0, 0], [3, 1]]),
        ).plot()
        _save_and_close("mazeplot-nodevalues_target.png")

        pred_paths = [pred_path1, pred_path2]
        MazePlot(maze).add_multiple_paths(pred_paths).plot()
        _save_and_close("mazeplot-multipath.png")

        ascii_maze = MazePlot(maze).to_ascii()
        print(ascii_maze)
    finally:
        # Release any figures still open, e.g. when an assertion above failed
        # before its figure was saved and closed.
        plt.close("all")