Coverage for tests\unit\maze_dataset\generation\test_maze_dataset.py: 100% (138 statements)


import copy
from pathlib import Path

import numpy as np
import pytest
from pytest import mark, param
from zanj import ZANJ

from maze_dataset.constants import CoordArray
from maze_dataset.dataset.dataset import (
    register_dataset_filter,
    register_filter_namespace_for_dataset,
)
from maze_dataset.dataset.maze_dataset import (
    MazeDataset,
    MazeDatasetConfig,
    register_maze_filter,
    set_serialize_minimal_threshold,
)
from maze_dataset.generation.generators import GENERATORS_MAP
from maze_dataset.maze import SolvedMaze
from maze_dataset.utils import bool_array_from_string


class TestMazeDatasetConfig:
    pass
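

# shared test configs: (grid_n, n_mazes, maze_ctor_kwargs) triples covering a
# multi-maze case, a single-maze case, and a larger maze with forking disabled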
TEST_CONFIGS = [
    MazeDatasetConfig(
        name="test",
        grid_n=grid_n,
        n_mazes=n_mazes,
        maze_ctor=GENERATORS_MAP["gen_dfs"],
        maze_ctor_kwargs=maze_ctor_kwargs,
    )
    for grid_n, n_mazes, maze_ctor_kwargs in [
        (3, 5, {}),
        (3, 1, {}),
        (5, 5, dict(do_forks=False)),
    ]
]


def test_generate_serial():
    dataset = MazeDataset.generate(TEST_CONFIGS[0], gen_parallel=False)

    assert len(dataset) == 5
    for maze in dataset:
        assert maze.grid_shape == (3, 3)


def test_generate_parallel():
    dataset = MazeDataset.generate(
        TEST_CONFIGS[0], gen_parallel=True, verbose=True, pool_kwargs=dict(processes=2)
    )

    assert len(dataset) == 5
    for maze in dataset:
        assert maze.grid_shape == (3, 3)


def test_data_hash_wip():
    dataset = MazeDataset.generate(TEST_CONFIGS[0])
    # TODO: dataset.data_hash doesn't work right now
    assert dataset


def test_download():
    with pytest.raises(NotImplementedError):
        MazeDataset.download(TEST_CONFIGS[0])


def test_serialize_load():
    dataset = MazeDataset.generate(TEST_CONFIGS[0])
    dataset_copy = MazeDataset.load(dataset.serialize())

    assert dataset.cfg == dataset_copy.cfg
    for maze, maze_copy in zip(dataset, dataset_copy):
        assert maze == maze_copy
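

# roundtrip tests for the minimal serialization path (`_serialize_minimal`),
# which should load back into an identical dataset for every test config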
@mark.parametrize(
    "config",
    [
        param(
            c,
            id=f"{c.grid_n=}; {c.n_mazes=}; {c.maze_ctor_kwargs=}",
        )
        for c in TEST_CONFIGS
    ],
)
def test_serialize_load_minimal(config):
    d = MazeDataset.generate(config, gen_parallel=False)
    assert MazeDataset.load(d._serialize_minimal()) == d


@mark.parametrize(
    "config",
    [
        param(
            c,
            id=f"{c.grid_n=}; {c.n_mazes=}; {c.maze_ctor_kwargs=}",
        )
        for c in TEST_CONFIGS
    ],
)
def test_save_read_minimal(config):
    def save_and_read(d: MazeDataset, p: Path):
        d.save(file_path=p)
        # read as MazeDataset
        roundtrip = MazeDataset.read(p)
        assert roundtrip == d
        # read from zanj
        z = ZANJ()
        roundtrip_zanj = z.read(p)
        assert roundtrip_zanj == d

    d = MazeDataset.generate(config, gen_parallel=False)
    p = Path("tests/_temp/test_maze_dataset/") / (d.cfg.to_fname() + ".zanj")

    # test with full serialization (minimal serialization disabled)
    set_serialize_minimal_threshold(None)
    save_and_read(d, p)

    # test with minimal serialization (threshold 0 applies it at any size)
    set_serialize_minimal_threshold(0)
    save_and_read(d, p)

    # repeat the roundtrip with finer-grained assertions to localize failures
    d.save(file_path=p)
    # read as MazeDataset
    roundtrip = MazeDataset.read(p)
    assert d.cfg.diff(roundtrip.cfg) == dict()
    cfg_diff = roundtrip.cfg.diff(d.cfg)
    assert cfg_diff == {}
    assert roundtrip.cfg == d.cfg
    assert roundtrip.mazes == d.mazes
    assert roundtrip == d
    # read from zanj
    z = ZANJ()
    roundtrip_zanj = z.read(p)
    assert roundtrip_zanj == d
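

# `custom_maze_filter` accepts either a bare lambda or a named predicate with
# extra keyword arguments; both forms should select the same mazes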
def test_custom_maze_filter():
    connection_list = bool_array_from_string(
        """
        F T
        F F

        T F
        T F
        """,
        shape=[2, 2, 2],
    )
    solutions = [
        [[0, 0], [0, 1], [1, 1]],
        [[0, 0], [0, 1]],
        [[0, 0]],
    ]

    def custom_filter_solution_length(maze: SolvedMaze, solution_length: int) -> bool:
        return len(maze.solution) == solution_length

    mazes = [
        SolvedMaze(connection_list=connection_list, solution=solution)
        for solution in solutions
    ]
    dataset = MazeDataset(cfg=TEST_CONFIGS[0], mazes=mazes)

    filtered_lambda = dataset.custom_maze_filter(lambda m: len(m.solution) == 1)
    filtered_func = dataset.custom_maze_filter(
        custom_filter_solution_length, solution_length=1
    )

    assert filtered_lambda.mazes == filtered_func.mazes == [mazes[2]]
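

# filters registered via register_maze_filter / register_dataset_filter are
# exposed through `dataset.filter_by` and accept positional or keyword args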
class TestMazeDatasetFilters:
    config = MazeDatasetConfig(name="test", grid_n=3, n_mazes=5)
    connection_list = bool_array_from_string(
        """
        F T
        F F

        T F
        T F
        """,
        shape=[2, 2, 2],
    )

    def test_filters(self):
        class TestDataset(MazeDataset): ...

        @register_filter_namespace_for_dataset(TestDataset)
        class TestFilters:
            @register_maze_filter
            @staticmethod
            def solution_match(maze: SolvedMaze, solution: CoordArray) -> bool:
                """Test for solution equality"""
                return (maze.solution == solution).all()

            @register_dataset_filter
            @staticmethod
            def drop_nth(dataset: TestDataset, n: int) -> TestDataset:
                """Drop the maze at index `n` from the dataset"""
                return copy.deepcopy(
                    TestDataset(
                        dataset.cfg, [maze for i, maze in enumerate(dataset) if i != n]
                    )
                )

        maze1 = SolvedMaze(
            connection_list=self.connection_list, solution=np.array([[0, 0]])
        )
        maze2 = SolvedMaze(
            connection_list=self.connection_list, solution=np.array([[0, 1]])
        )

        dataset = TestDataset(self.config, [maze1, maze2])

        maze_filter = dataset.filter_by.solution_match(solution=np.array([[0, 0]]))
        maze_filter2 = dataset.filter_by.solution_match(np.array([[0, 0]]))

        dataset_filter = dataset.filter_by.drop_nth(n=0)
        dataset_filter2 = dataset.filter_by.drop_nth(0)

        assert maze_filter.mazes == maze_filter2.mazes == [maze1]
        assert dataset_filter.mazes == dataset_filter2.mazes == [maze2]

    def test_path_length(self):
        long_maze = SolvedMaze(
            connection_list=self.connection_list,
            solution=np.array([[0, 0], [0, 1], [1, 1]]),
        )

        short_maze = SolvedMaze(
            connection_list=self.connection_list, solution=np.array([[0, 0], [0, 1]])
        )

        dataset = MazeDataset(self.config, [long_maze, short_maze])
        path_length_filtered = dataset.filter_by.path_length(3)
        start_end_filtered = dataset.filter_by.start_end_distance(2)

        assert type(path_length_filtered) == type(dataset)  # noqa: E721
        assert path_length_filtered.mazes == [long_maze]
        assert start_end_filtered.mazes == [long_maze]
        assert dataset.mazes == [long_maze, short_maze]

    def test_cut_percentile_shortest(self):
        solutions = [
            [[0, 0], [0, 1], [1, 1]],
            [[0, 0], [0, 1]],
            [[0, 0]],
        ]

        mazes = [
            SolvedMaze(connection_list=self.connection_list, solution=solution)
            for solution in solutions
        ]
        dataset = MazeDataset(cfg=self.config, mazes=mazes)
        # cutting at the 49th percentile drops the shortest of the three solutions
        filtered = dataset.filter_by.cut_percentile_shortest(49.0)

        assert filtered.mazes == mazes[:2]
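

# five ASCII mazes used by the dedup tests: index 2 is an exact duplicate of
# index 0; the others differ in layout and/or solution path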
DUPE_DATASET = [
    """
#####
#  E#
###X#
#SXX#
#####
""",
    """
#####
#SXE#
### #
#   #
#####
""",
    """
#####
#  E#
###X#
#SXX#
#####
""",
    """
#####
# # #
# # #
#EXS#
#####
""",
    """
#####
#SXX#
###X#
#EXX#
#####
""",
]


def _helper_dataset_from_ascii(ascii_mazes: list[str]) -> MazeDataset:
    mazes: list[SolvedMaze] = list()
    for maze_ascii in ascii_mazes:
        mazes.append(SolvedMaze.from_ascii(maze_ascii.strip()))

    return MazeDataset(
        MazeDatasetConfig(
            name="test", grid_n=mazes[0].grid_shape[0], n_mazes=len(mazes)
        ),
        mazes,
    )
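

# `remove_duplicates` applies stricter criteria than `remove_duplicates_fast`:
# on DUPE_DATASET it keeps only the last two mazes, while the fast variant
# removes only the exact duplicate (maze 2)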
def test_remove_duplicates():
    dataset: MazeDataset = _helper_dataset_from_ascii(DUPE_DATASET)
    dataset_deduped: MazeDataset = dataset.filter_by.remove_duplicates()

    assert len(dataset) == 5
    assert dataset_deduped.mazes == [dataset.mazes[3], dataset.mazes[4]]


def test_data_hash():
    dataset: MazeDataset = _helper_dataset_from_ascii(DUPE_DATASET)
    hash_1 = dataset.data_hash()
    hash_2 = dataset.data_hash()

    # hashing the same dataset twice must be deterministic
    assert hash_1 == hash_2


def test_remove_duplicates_fast():
    dataset: MazeDataset = _helper_dataset_from_ascii(DUPE_DATASET)
    dataset_deduped: MazeDataset = dataset.filter_by.remove_duplicates_fast()

    assert len(dataset) == 5
    assert dataset_deduped.mazes == [
        dataset.mazes[0],
        dataset.mazes[1],
        dataset.mazes[3],
        dataset.mazes[4],
    ]