Coverage for tests/unit/maze_dataset/generation/test_maze_dataset.py: 100%

137 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-11 00:49 -0600

1import copy 

2from pathlib import Path 

3 

4import numpy as np 

5import pytest 

6from zanj import ZANJ 

7 

8from maze_dataset.constants import CoordArray 

9from maze_dataset.dataset.dataset import ( 

10 register_dataset_filter, 

11 register_filter_namespace_for_dataset, 

12) 

13from maze_dataset.dataset.maze_dataset import ( 

14 MazeDataset, 

15 MazeDatasetConfig, 

16 register_maze_filter, 

17 set_serialize_minimal_threshold, 

18) 

19from maze_dataset.generation.generators import GENERATORS_MAP 

20from maze_dataset.maze import SolvedMaze 

21from maze_dataset.utils import bool_array_from_string 

22 

23 

class TestMazeDatasetConfig:
    """Placeholder grouping class for `MazeDatasetConfig` tests (none yet)."""

26 

27 

# (grid_n, n_mazes, maze_ctor_kwargs) combinations exercised by the tests below
_TEST_CONFIG_PARAMS: list[tuple[int, int, dict]] = [
    (3, 5, {}),
    (3, 1, {}),
    (5, 5, dict(do_forks=False)),
]

TEST_CONFIGS = [
    MazeDatasetConfig(
        name="test",
        grid_n=grid_n,
        n_mazes=n_mazes,
        maze_ctor=GENERATORS_MAP["gen_dfs"],
        maze_ctor_kwargs=ctor_kwargs,
    )
    for grid_n, n_mazes, ctor_kwargs in _TEST_CONFIG_PARAMS
]

42 

43 

def test_generate_serial():
    """Serial generation yields the configured number of 3x3 mazes."""
    dataset = MazeDataset.generate(TEST_CONFIGS[0], gen_parallel=False)

    assert len(dataset) == 5
    assert all(maze.grid_shape == (3, 3) for maze in dataset)

50 

51 

def test_generate_parallel():
    """Parallel generation (2 workers, verbose) produces the same shape of result."""
    dataset = MazeDataset.generate(
        TEST_CONFIGS[0],
        gen_parallel=True,
        verbose=True,
        pool_kwargs=dict(processes=2),
    )

    assert len(dataset) == 5
    assert all(maze.grid_shape == (3, 3) for maze in dataset)

63 

64 

def test_data_hash_wip():
    """Smoke test: generation succeeds and yields a truthy dataset."""
    generated = MazeDataset.generate(TEST_CONFIGS[0])
    # TODO: dataset.data_hash doesn't work right now
    assert generated

69 

70 

def test_download():
    """Downloading datasets is unimplemented and must raise."""
    with pytest.raises(NotImplementedError):
        MazeDataset.download(TEST_CONFIGS[0])

74 

75 

def test_serialize_load():
    """Round-trip through `serialize`/`load` preserves the config and every maze."""
    dataset = MazeDataset.generate(TEST_CONFIGS[0])
    dataset_copy = MazeDataset.load(dataset.serialize())

    assert dataset.cfg == dataset_copy.cfg
    # strict=True: with strict=False a truncated roundtrip (fewer mazes than the
    # original) would compare only the common prefix and pass silently
    for maze, maze_copy in zip(dataset, dataset_copy, strict=True):
        assert maze == maze_copy

83 

84 

@pytest.mark.parametrize(
    "config",
    [
        pytest.param(c, id=f"{c.grid_n=}; {c.n_mazes=}; {c.maze_ctor_kwargs=}")
        for c in TEST_CONFIGS
    ],
)
def test_serialize_load_minimal(config):
    """Minimal serialization round-trips to an equal dataset for every config."""
    dataset = MazeDataset.generate(config, gen_parallel=False)
    roundtrip = MazeDataset.load(dataset._serialize_minimal())
    assert roundtrip == dataset

98 

99 

@pytest.mark.parametrize(
    "config",
    [
        pytest.param(
            c,
            id=f"{c.grid_n=}; {c.n_mazes=}; {c.maze_ctor_kwargs=}",
        )
        for c in TEST_CONFIGS
    ],
)
def test_save_read_minimal(config):
    """Save/read round-trips under both full and minimal serialization.

    The dataset is written to a `.zanj` file and read back both via
    `MazeDataset.read` and via a raw `ZANJ().read`; config, mazes, and the
    dataset as a whole must compare equal each time.
    """

    def save_and_read(d: MazeDataset, p: Path):
        # write, then read back as a MazeDataset
        d.save(file_path=p)
        roundtrip = MazeDataset.read(p)
        # diff in both directions gives a precise failure message on mismatch
        assert d.cfg.diff(roundtrip.cfg) == dict()
        assert roundtrip.cfg.diff(d.cfg) == {}
        assert roundtrip.cfg == d.cfg
        assert roundtrip.mazes == d.mazes
        assert roundtrip == d
        # read the same file through zanj directly
        roundtrip_zanj = ZANJ().read(p)
        assert roundtrip_zanj == d

    d = MazeDataset.generate(config, gen_parallel=False)
    p = Path("tests/_temp/test_maze_dataset/") / (d.cfg.to_fname() + ".zanj")

    # Test with full serialization
    set_serialize_minimal_threshold(None)
    save_and_read(d, p)

    # Test with minimal serialization
    # NOTE(review): the threshold is left at 0 after this test returns —
    # confirm no later test depends on the module default.
    set_serialize_minimal_threshold(0)
    save_and_read(d, p)

145 

146 

def test_custom_maze_filter():
    """`custom_maze_filter` accepts both a lambda and a function with kwargs."""
    connection_list = bool_array_from_string(
        """
        F T
        F F

        T F
        T F
        """,
        shape=[2, 2, 2],
    )
    solutions = [
        [[0, 0], [0, 1], [1, 1]],
        [[0, 0], [0, 1]],
        [[0, 0]],
    ]

    def _has_solution_length(maze: SolvedMaze, solution_length: int) -> bool:
        return len(maze.solution) == solution_length

    mazes = [
        SolvedMaze(connection_list=connection_list, solution=s) for s in solutions
    ]
    dataset = MazeDataset(cfg=TEST_CONFIGS[0], mazes=mazes)

    via_lambda = dataset.custom_maze_filter(lambda m: len(m.solution) == 1)
    via_func = dataset.custom_maze_filter(
        _has_solution_length,
        solution_length=1,
    )

    # only the single-cell solution survives, regardless of filter style
    assert via_lambda.mazes == via_func.mazes == [mazes[2]]

180 

181 

class TestMazeDatasetFilters:
    """Exercises built-in dataset filters and the filter-registration decorators."""

    config = MazeDatasetConfig(name="test", grid_n=3, n_mazes=5)
    connection_list = bool_array_from_string(
        """
        F T
        F F

        T F
        T F
        """,
        shape=[2, 2, 2],
    )

    def test_filters(self):
        """Registered maze/dataset filters work positionally and by keyword."""

        class TestDataset(MazeDataset): ...

        @register_filter_namespace_for_dataset(TestDataset)
        class TestFilters:
            @register_maze_filter
            @staticmethod
            def solution_match(maze: SolvedMaze, solution: CoordArray) -> bool:
                """Keep mazes whose solution equals `solution` exactly."""
                return (maze.solution == solution).all()

            @register_dataset_filter
            @staticmethod
            def drop_nth(dataset: TestDataset, n: int) -> TestDataset:
                """Drop the n-th maze from the dataset."""
                return copy.deepcopy(
                    TestDataset(
                        dataset.cfg,
                        [maze for i, maze in enumerate(dataset) if i != n],
                    ),
                )

        maze1 = SolvedMaze(
            connection_list=self.connection_list,
            solution=np.array([[0, 0]]),
        )
        maze2 = SolvedMaze(
            connection_list=self.connection_list,
            solution=np.array([[0, 1]]),
        )
        dataset = TestDataset(self.config, [maze1, maze2])

        # each filter is called once with a keyword and once positionally
        matched_kw = dataset.filter_by.solution_match(solution=np.array([[0, 0]]))
        matched_pos = dataset.filter_by.solution_match(np.array([[0, 0]]))
        dropped_kw = dataset.filter_by.drop_nth(n=0)
        dropped_pos = dataset.filter_by.drop_nth(0)

        assert matched_kw.mazes == matched_pos.mazes == [maze1]
        assert dropped_kw.mazes == dropped_pos.mazes == [maze2]

    def test_path_length(self):
        """`path_length` / `start_end_distance` keep only sufficiently long mazes."""
        long_maze = SolvedMaze(
            connection_list=self.connection_list,
            solution=np.array([[0, 0], [0, 1], [1, 1]]),
        )
        short_maze = SolvedMaze(
            connection_list=self.connection_list,
            solution=np.array([[0, 0], [0, 1]]),
        )
        dataset = MazeDataset(self.config, [long_maze, short_maze])

        path_length_filtered = dataset.filter_by.path_length(3)
        start_end_filtered = dataset.filter_by.start_end_distance(2)

        assert type(path_length_filtered) == type(dataset)  # noqa: E721
        assert path_length_filtered.mazes == [long_maze]
        assert start_end_filtered.mazes == [long_maze]
        # filtering must not mutate the source dataset
        assert dataset.mazes == [long_maze, short_maze]

    def test_cut_percentile_shortest(self):
        """Cutting at the 49th percentile drops only the single shortest maze."""
        solutions = [
            [[0, 0], [0, 1], [1, 1]],
            [[0, 0], [0, 1]],
            [[0, 0]],
        ]
        mazes = [
            SolvedMaze(connection_list=self.connection_list, solution=s)
            for s in solutions
        ]
        dataset = MazeDataset(cfg=self.config, mazes=mazes)

        filtered = dataset.filter_by.cut_percentile_shortest(49.0)
        assert filtered.mazes == mazes[:2]

272 

273 

# Five 2x2 ASCII mazes for the duplicate-removal tests below; the maze at
# index 2 is character-identical to index 0 (see test_remove_duplicates_fast).
# NOTE(review): interior runs of spaces were reconstructed from a
# whitespace-mangled dump — each maze row should be exactly 5 chars wide;
# verify against the original file.
DUPE_DATASET = [
    """
#####
#  E#
###X#
#SXX#
#####
""",
    """
#####
#SXE#
### #
#   #
#####
""",
    """
#####
#  E#
###X#
#SXX#
#####
""",
    """
#####
# # #
# # #
#EXS#
#####
""",
    """
#####
#SXX#
###X#
#EXX#
#####
""",
]

311 

312 

def _helper_dataset_from_ascii(ascii_rep: list[str]) -> MazeDataset:
    """Build a `MazeDataset` from a sequence of ASCII maze strings.

    The config's `grid_n` is taken from the first maze, so all mazes are
    expected to share one grid size. (Annotation fixed: the body iterates
    the argument and calls `.strip()` on each element, so it takes a list
    of strings, not a single `str`.)
    """
    # PERF401 (resolves the TODO): build the list with a comprehension
    mazes: list[SolvedMaze] = [
        SolvedMaze.from_ascii(maze_ascii.strip()) for maze_ascii in ascii_rep
    ]

    return MazeDataset(
        MazeDatasetConfig(
            name="test",
            grid_n=mazes[0].grid_shape[0],
            n_mazes=len(mazes),
        ),
        mazes,
    )

327 

328 

def test_remove_duplicates():
    """Similarity-based dedup keeps only the last two mazes of `DUPE_DATASET`."""
    dataset: MazeDataset = _helper_dataset_from_ascii(DUPE_DATASET)
    deduped: MazeDataset = dataset.filter_by.remove_duplicates()

    # the source dataset is left untouched
    assert len(dataset) == 5
    assert deduped.mazes == [dataset.mazes[3], dataset.mazes[4]]

335 

336 

def test_data_hash():
    """`data_hash` is deterministic for an unchanged dataset."""
    dataset: MazeDataset = _helper_dataset_from_ascii(DUPE_DATASET)
    assert dataset.data_hash() == dataset.data_hash()

343 

344 

def test_remove_duplicates_fast():
    """Exact-match dedup removes only maze 2 (a byte-for-byte copy of maze 0)."""
    dataset: MazeDataset = _helper_dataset_from_ascii(DUPE_DATASET)
    deduped: MazeDataset = dataset.filter_by.remove_duplicates_fast()

    assert len(dataset) == 5
    assert deduped.mazes == [dataset.mazes[i] for i in (0, 1, 3, 4)]