Coverage for tests\unit\maze_dataset\plotting\test_maze_plot.py: 100%

39 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-23 12:49 -0700

1"""mostly taken from `demo_latticemaze.ipynb`""" 

2 

3import os 

4 

5import matplotlib.pyplot as plt 

6import numpy as np 

7 

8from maze_dataset.generation import LatticeMazeGenerators 

9from maze_dataset.maze import SolvedMaze, TargetedLatticeMaze 

10from maze_dataset.plotting import MazePlot 

11 

# Output directory for every PNG this module writes. It is a relative path,
# so it resolves against the current working directory of the pytest run.
FIG_SAVE: str = "tests/_temp/figures/"

13 

14 

def test_maze_plot():
    """Exercise maze rendering and the ``MazePlot`` API, saving each figure.

    Builds an ``N x N`` maze with ``LatticeMazeGenerators.gen_dfs``. From it,
    derives a ``TargetedLatticeMaze`` (start ``(0, 0)``, end ``(N - 1, N - 1)``)
    and a ``SolvedMaze``. Asserts that each of the three mazes survives an
    ``as_pixels``/``from_pixels`` and an ``as_ascii``/``from_ascii`` round-trip.
    Then drives the ``MazePlot`` builder methods (true/predicted paths, node
    values, multiple paths, ASCII export).

    Every figure is written as a PNG under ``FIG_SAVE`` and closed right
    afterwards. This keeps the test from accumulating open matplotlib figures
    (matplotlib warns once more than ``figure.max_open_warning`` are alive).
    """
    N: int = 10

    os.makedirs(FIG_SAVE, exist_ok=True)

    def _save_and_close(filename: str) -> None:
        # Write the current figure into FIG_SAVE, then close every open figure
        # so the next plot starts from a clean state instead of drawing onto,
        # or leaking, a previous one.
        plt.savefig(os.path.join(FIG_SAVE, filename))
        plt.close("all")

    maze = LatticeMazeGenerators.gen_dfs(np.array([N, N]))
    tgt_maze: TargetedLatticeMaze = TargetedLatticeMaze.from_lattice_maze(
        maze, (0, 0), (N - 1, N - 1)
    )
    solved_maze: SolvedMaze = SolvedMaze.from_targeted_lattice_maze(tgt_maze)

    # Side-by-side pixel renderings, each titled with the maze's ASCII form.
    _fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    for ax_i, temp_maze in zip(ax, [maze, tgt_maze, solved_maze]):
        ax_i.set_title(temp_maze.as_ascii(), fontfamily="monospace")
        ax_i.imshow(temp_maze.as_pixels())

        # Both serialisations must round-trip back to an equal object when
        # parsed by the same concrete class.
        maze_cls = type(temp_maze)
        assert temp_maze == maze_cls.from_pixels(temp_maze.as_pixels())
        assert temp_maze == maze_cls.from_ascii(temp_maze.as_ascii())

    _save_and_close("pixels_and_ascii.png")

    MazePlot(maze).plot()
    _save_and_close("mazeplot-pathless.png")

    true_path = maze.find_shortest_path(c_start=(0, 0), c_end=(3, 3))

    MazePlot(solved_maze).plot()
    _save_and_close("mazeplot-solvedmaze.png")

    # Hand-written "predicted" paths. They are fixed coordinates, not derived
    # from the generated maze, so they need not be valid moves in it (hence
    # the "fakepaths" file name).
    pred_path1 = [(0, 0), (1, 0), (2, 0), (3, 0), (3, 1), (3, 2), (3, 3)]
    pred_path2 = [
        (0, 0),
        (0, 1),
        (0, 2),
        (0, 3),
        (1, 3),
        (2, 3),
        (2, 2),
        (3, 2),
        (3, 3),
    ]
    (
        MazePlot(maze)
        .add_true_path(true_path)
        .add_predicted_path(pred_path1)
        .add_predicted_path(pred_path2)
        .plot()
    )
    _save_and_close("mazeplot-fakepaths.png")

    # Seeded generator so the node-value heatmaps are reproducible run to run.
    rng = np.random.default_rng(seed=0)
    node_values = rng.uniform(size=maze.grid_shape)

    MazePlot(maze).add_node_values(node_values, color_map="Blues").plot()
    _save_and_close("mazeplot-nodevalues.png")

    # Keyword spelling ("preceeding") matches the parameter name accepted by
    # `add_node_values`; do not "fix" it here.
    MazePlot(maze).add_node_values(
        node_values,
        color_map="Blues",
        target_token_coord=np.array([2, 0]),
        preceeding_tokens_coords=np.array([[0, 0], [3, 1]]),
    ).plot()
    _save_and_close("mazeplot-nodevalues_target.png")

    pred_paths = [pred_path1, pred_path2]
    MazePlot(maze).add_multiple_paths(pred_paths).plot()
    _save_and_close("mazeplot-multipath.png")

    ascii_maze = MazePlot(maze).to_ascii()
    print(ascii_maze)