Coverage for tests\unit\maze_dataset\generation\test_latticemaze.py: 100%

104 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-23 12:49 -0700

1import numpy as np 

2import pytest 

3 

4from maze_dataset.constants import CoordArray 

5from maze_dataset.generation.default_generators import DEFAULT_GENERATORS 

6from maze_dataset.generation.generators import GENERATORS_MAP 

7from maze_dataset.maze import LatticeMaze, PixelColors, SolvedMaze, TargetedLatticeMaze 

8from maze_dataset.utils import adj_list_to_nested_set, bool_array_from_string 

9 

10 

# the _from_pixel_grid* tests below were originally drafted with help from GPT
@pytest.fixture
def example_pixel_grid():
    """Binary 5x5 pixel grid describing a 2x2 maze.

    The layout is spelled out as a wall mask (1 = wall) and then inverted, so
    in the returned boolean array ``True`` marks open pixels and ``False``
    marks walls.
    """
    wall_mask = np.array(
        [
            [1, 1, 1, 1, 1],
            [1, 0, 0, 0, 1],
            [1, 1, 1, 0, 1],
            [1, 0, 0, 0, 1],
            [1, 1, 1, 1, 1],
        ],
        dtype=bool,
    )
    return np.logical_not(wall_mask)

24 

25 

@pytest.fixture
def example_rgb_pixel_grid():
    """RGB (uint8) 5x5 pixel grid of a 2x2 maze using only wall/open colours.

    The only open pixel between two cells is at pixel (1, 2), joining the two
    cells of the first row; no START/END/PATH colours appear.
    """
    wall, free = PixelColors.WALL, PixelColors.OPEN
    rows = [
        [wall, wall, wall, wall, wall],
        [wall, free, free, free, wall],
        [wall, wall, wall, wall, wall],
        [wall, free, wall, free, wall],
        [wall, wall, wall, wall, wall],
    ]
    return np.array(rows, dtype=np.uint8)

68 

69 

def test_from_pixel_grid_bw(example_pixel_grid):
    """`_from_pixel_grid_bw` recovers the connection list and grid shape."""
    conn, shape = LatticeMaze._from_pixel_grid_bw(example_pixel_grid)

    assert isinstance(conn, np.ndarray)
    assert conn.shape == (2, 2, 2)

    # expected connections along each of the two leading slices of the list
    expected_first = np.array([[False, True], [False, False]])
    expected_second = np.array([[True, False], [True, False]])
    assert np.all(conn[0] == expected_first)
    assert np.all(conn[1] == expected_second)

    assert shape == (2, 2)

78 

79 

def test_from_pixel_grid_with_positions(example_rgb_pixel_grid):
    """`_from_pixel_grid_with_positions` on an RGB grid with no marked pixels.

    The connection list and grid shape must be recovered. Every requested
    marker key must appear in the output, each mapped to an empty array,
    because the fixture contains only wall and open pixels.
    """
    marked_positions = {
        "start": PixelColors.START,
        "end": PixelColors.END,
        "path": PixelColors.PATH,
    }

    (
        connection_list,
        grid_shape,
        out_positions,
    ) = LatticeMaze._from_pixel_grid_with_positions(
        example_rgb_pixel_grid, marked_positions
    )

    assert isinstance(connection_list, np.ndarray)
    assert connection_list.shape == (2, 2, 2)
    assert np.array_equal(
        connection_list[0], np.array([[False, False], [False, False]])
    )
    assert np.array_equal(
        connection_list[1], np.array([[True, False], [False, False]])
    )
    assert grid_shape == (2, 2)

    assert isinstance(out_positions, dict)
    # compare the whole key set so a missing "path" fails here with a clear
    # message instead of as a KeyError when it is indexed below
    assert set(out_positions) == set(marked_positions), out_positions.keys()
    for key in marked_positions:
        positions = out_positions[key]
        assert isinstance(positions, np.ndarray), (key, type(positions))
        assert positions.shape == (0,), (key, positions.shape)

112 

113 

def test_find_start_end_points_in_rgb_pixel_grid():
    """START/END pixels on cell positions are reported as cell coordinates.

    The grid is 5x5 pixels (a 2x2 maze). START sits at pixel (1, 1), which is
    expected at cell (0, 0). END sits at pixel (1, 3), expected at cell
    (0, 1). No PATH pixels are present, so "path" must come back empty.
    """
    wall, free = PixelColors.WALL, PixelColors.OPEN
    start, end = PixelColors.START, PixelColors.END
    rgb_pixel_grid_with_positions = np.array(
        [
            [wall, wall, wall, wall, wall],
            [wall, start, free, end, wall],
            [wall, wall, wall, wall, wall],
            [wall, free, wall, free, wall],
            [wall, wall, wall, wall, wall],
        ],
        dtype=np.uint8,
    )

    marked_positions = {
        "start": PixelColors.START,
        "end": PixelColors.END,
        "path": PixelColors.PATH,
    }

    # only the marked positions are checked here; connections are covered elsewhere
    _, _, out_positions = LatticeMaze._from_pixel_grid_with_positions(
        rgb_pixel_grid_with_positions, marked_positions
    )

    # shown by pytest only when the test fails
    print(f"{out_positions = }")

    assert isinstance(out_positions, dict)
    assert set(out_positions) == set(marked_positions), out_positions.keys()
    for key in marked_positions:
        assert isinstance(out_positions[key], np.ndarray), (
            key,
            type(out_positions[key]),
        )

    # array_equal also compares shapes: `np.all(a == b)` is vacuously True when
    # broadcasting produces an empty array, and it passes for duplicated rows
    assert np.array_equal(out_positions["start"], np.array([[0, 0]])), out_positions[
        "start"
    ]
    assert np.array_equal(out_positions["end"], np.array([[0, 1]])), out_positions[
        "end"
    ]
    assert out_positions["path"].shape == (0,)

184 

185 

@pytest.mark.parametrize("gfunc_name, kwargs", DEFAULT_GENERATORS)
def test_pixels_ascii_roundtrip(gfunc_name, kwargs):
    """Every default generator's maze survives pixel and ascii round-trips.

    Also checks the rendered sizes: 2n+1 along each axis, with 3 channels for
    the pixel image.
    """
    n: int = 5
    side: int = 2 * n + 1
    generate = GENERATORS_MAP[gfunc_name]
    maze: LatticeMaze = generate(np.array([n, n]), **kwargs)

    maze_pixels: np.ndarray = maze.as_pixels()
    maze_ascii: str = maze.as_ascii()

    # both renderings must decode back to an equal maze
    assert maze == LatticeMaze.from_pixels(maze_pixels)
    assert maze == LatticeMaze.from_ascii(maze_ascii)

    expected_shape: tuple = (side, side, 3)
    assert maze_pixels.shape == expected_shape, (
        f"{maze_pixels.shape} != {expected_shape}"
    )
    # every ascii line has exactly `side` characters
    line_lengths = {len(line) for line in maze_ascii.splitlines()}
    assert line_lengths <= {side}, f"{maze_ascii}"

206 

207 

@pytest.mark.parametrize("gfunc_name, kwargs", DEFAULT_GENERATORS)
def test_targeted_solved_maze(gfunc_name, kwargs):
    """Targeted and solved mazes round-trip through pixels and ascii.

    For every default generator, a random path through the generated maze
    supplies the start (first entry) and end (last entry) of a
    ``TargetedLatticeMaze``, which is then converted to a ``SolvedMaze``. Both
    must decode back to equal objects and render at the expected sizes.
    """
    n: int = 5
    side: int = n * 2 + 1
    expected_shape: tuple = (side, side, 3)

    maze_gen_func = GENERATORS_MAP[gfunc_name]
    maze: LatticeMaze = maze_gen_func(np.array([n, n]), **kwargs)
    solution: CoordArray = maze.generate_random_path()
    tgt_maze: TargetedLatticeMaze = TargetedLatticeMaze.from_lattice_maze(
        maze,
        solution[0],
        solution[-1],
    )

    tgt_maze_pixels: np.ndarray = tgt_maze.as_pixels()
    tgt_maze_ascii: str = tgt_maze.as_ascii()

    assert tgt_maze == TargetedLatticeMaze.from_pixels(tgt_maze_pixels)
    assert tgt_maze == TargetedLatticeMaze.from_ascii(tgt_maze_ascii)

    assert tgt_maze_pixels.shape == expected_shape, (
        f"{tgt_maze_pixels.shape} != {expected_shape}"
    )
    assert all(side == len(line) for line in tgt_maze_ascii.splitlines()), (
        f"{tgt_maze_ascii}"
    )

    solved_maze: SolvedMaze = SolvedMaze.from_targeted_lattice_maze(tgt_maze)

    solved_maze_pixels: np.ndarray = solved_maze.as_pixels()
    solved_maze_ascii: str = solved_maze.as_ascii()

    assert solved_maze == SolvedMaze.from_pixels(solved_maze_pixels)
    assert solved_maze == SolvedMaze.from_ascii(solved_maze_ascii)

    # check the solved maze's own rendering, not the targeted maze's again
    assert solved_maze_pixels.shape == expected_shape, (
        f"{solved_maze_pixels.shape} != {expected_shape}"
    )
    assert all(side == len(line) for line in solved_maze_ascii.splitlines()), (
        f"{solved_maze_ascii}"
    )

249 

250 

def test_as_adj_list():
    """`as_adj_list` on a hand-written 2x2 connection list yields its 3 edges.

    Both sides go through `adj_list_to_nested_set` before comparison
    (presumably so the ordering of edges/endpoints is irrelevant).
    """
    connections = bool_array_from_string(
        """
        F T
        F F

        T F
        T F
        """,
        shape=[2, 2, 2],
    )
    maze = LatticeMaze(connection_list=connections)

    expected_edges = [
        [[0, 1], [1, 1]],
        [[0, 0], [0, 1]],
        [[1, 0], [1, 1]],
    ]
    actual_edges = maze.as_adj_list(shuffle_d0=False, shuffle_d1=False)

    assert adj_list_to_nested_set(expected_edges) == adj_list_to_nested_set(
        actual_edges
    )

270 

271 

@pytest.mark.parametrize("gfunc_name, kwargs", DEFAULT_GENERATORS)
def test_get_nodes(gfunc_name, kwargs):
    """`get_nodes` lists every cell of a 3x2 maze in row-major order."""
    rows, cols = 3, 2
    maze = GENERATORS_MAP[gfunc_name](np.array((rows, cols)), **kwargs)
    expected = [[r, c] for r in range(rows) for c in range(cols)]
    assert maze.get_nodes().tolist() == expected

280 

281 

@pytest.mark.parametrize("gfunc_name, kwargs", DEFAULT_GENERATORS)
def test_generate_random_path(gfunc_name, kwargs):
    """A random path through a 2x2 maze contains at least two entries."""
    maze = GENERATORS_MAP[gfunc_name](np.array((2, 2)), **kwargs)
    random_path = maze.generate_random_path()

    # at least two entries: the path is not a single node serving as both
    # start and end
    assert len(random_path) >= 2

290 

291 

@pytest.mark.parametrize("gfunc_name, kwargs", DEFAULT_GENERATORS)
def test_generate_random_path_size_1(gfunc_name, kwargs):
    """On a 1x1 maze, `generate_random_path` is expected to raise AssertionError.

    With a single node there is no distinct start/end pair to choose.
    """
    single_cell_shape = np.array((1, 1))
    maze = GENERATORS_MAP[gfunc_name](single_cell_shape, **kwargs)
    with pytest.raises(AssertionError):
        maze.generate_random_path()

297 maze.generate_random_path()