Coverage for maze_dataset\dataset\maze_dataset.py: 45%

305 statements  

coverage.py v7.6.12, created at 2025-02-23 12:49 -0700

1"""`MazeDatasetConfig` is where you decide what your dataset should look like, then pass it to `MazeDataset.from_config` to generate or load the dataset. 

2 

3see [demo_dataset notebook](../../notebooks/demo_dataset) 

4 

5""" 

6 

7import copy 

8import functools 

9import json 

10import multiprocessing 

11from pathlib import Path 

12import typing 

13import warnings 

14from collections import Counter, defaultdict 

15from typing import Callable, Optional, cast 

16 

17import numpy as np 

18import tqdm 

19from jaxtyping import Int 

20from muutils.json_serialize import ( 

21 json_serialize, 

22 serializable_dataclass, 

23 serializable_field, 

24) 

25from muutils.json_serialize.util import JSONdict 

26from muutils.json_serialize.util import _FORMAT_KEY 

27from muutils.json_serialize.util import safe_getsource, string_as_lines 

28from muutils.misc import sanitize_fname, shorten_numerical_to_str, stable_hash 

29from zanj import ZANJ 

30from zanj.loading import LoaderHandler, load_item_recursive, register_loader_handler 

31 

32from maze_dataset.constants import Coord, CoordArray, CoordTup 

33from maze_dataset.dataset.dataset import ( 

34 DatasetFilterProtocol, 

35 GPTDataset, 

36 GPTDatasetConfig, 

37 register_dataset_filter, 

38 register_filter_namespace_for_dataset, 

39) 

40from maze_dataset.generation.generators import GENERATORS_MAP 

41from maze_dataset.maze import LatticeMaze, SolvedMaze 

42 

43# If `n_mazes>=SERIALIZE_MINIMAL_THRESHOLD`, then the MazeDataset will use `serialize_minimal`. 

44# Setting to None means that `serialize_minimal` will never be used. 

45# Set to -1 to make calls to `read` use `MazeDataset._load_legacy`. Used for profiling only. 

46SERIALIZE_MINIMAL_THRESHOLD: int | None = 100 

47 

48 

49def set_serialize_minimal_threshold(threshold: int | None) -> None: 

50 global SERIALIZE_MINIMAL_THRESHOLD 

51 SERIALIZE_MINIMAL_THRESHOLD = threshold 
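
# a sketch of adjusting this module-level switch (the calls are illustrative):
# set_serialize_minimal_threshold(None)  # disable `_serialize_minimal` entirely
# set_serialize_minimal_threshold(0)  # use `_serialize_minimal` for every dataset
# set_serialize_minimal_threshold(-1)  # profiling only: `read` goes through `MazeDataset._load_legacy`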


def _load_maze_ctor(maze_ctor_serialized: str | dict) -> Callable:
    "get the maze constructor from `GENERATORS_MAP`"
    if isinstance(maze_ctor_serialized, dict):
        # the dict format is used by both the current and the original serialization
        return GENERATORS_MAP[maze_ctor_serialized["__name__"]]
    elif isinstance(maze_ctor_serialized, str):
        # the bare-string format was used briefly and is now deprecated
        warnings.warn(
            "you are loading an old model/config in `_load_maze_ctor()`! this should not be happening, please report: "
            + "https://github.com/understanding-search/maze-dataset/issues/new"
        )
        return GENERATORS_MAP[maze_ctor_serialized]
    else:
        raise ValueError(
            f"maze_ctor_serialized is of type {type(maze_ctor_serialized)}, expected str or dict"
        )


EndpointKwargsType = dict[
    typing.Literal[
        "allowed_start",
        "allowed_end",
        "deadend_start",
        "deadend_end",
        "endpoints_not_equal",
        "except_on_no_valid_endpoint",
    ],
    bool | None | list[tuple[int, int]],
]
"type hint for `MazeDatasetConfig.endpoint_kwargs`"


def _load_endpoint_kwargs(data: dict) -> EndpointKwargsType:
    "load `endpoint_kwargs` from serialized data, converting lists back to tuples"
    if data.get("endpoint_kwargs", None) is None:
        return dict()

    else:
        return {
            k: (
                # bools and Nones are fine
                v
                if (isinstance(v, bool) or v is None)
                # assume it's a CoordList
                else [tuple(x) for x in v]  # muutils/zanj saves tuples as lists
            )
            for k, v in data["endpoint_kwargs"].items()
        }


@serializable_dataclass(kw_only=True, properties_to_serialize=["grid_shape"])
class MazeDatasetConfig(GPTDatasetConfig):
    """config object which is passed to `MazeDataset.from_config` to generate or load a dataset"""

    grid_n: int = serializable_field()

    # `n_mazes` is excluded from comparison primarily to avoid conflicts which happen during `from_config` when filters have been applied
    n_mazes: int = serializable_field(compare=False)

    maze_ctor: Callable = serializable_field(
        default=GENERATORS_MAP["gen_dfs"],
        serialization_fn=lambda gen_func: {
            "__name__": gen_func.__name__,
            "__module__": gen_func.__module__,
            "__doc__": string_as_lines(gen_func.__doc__),
            "source_code": safe_getsource(gen_func),
        },
        loading_fn=lambda data: _load_maze_ctor(data["maze_ctor"]),
        assert_type=False,  # TODO: check the type here once muutils supports checking Callable signatures
    )

    maze_ctor_kwargs: dict = serializable_field(
        default_factory=dict,
        serialization_fn=lambda kwargs: kwargs,
        loading_fn=lambda data: (
            dict()
            if data.get("maze_ctor_kwargs", None)
            is None  # this handles backwards compatibility
            else data["maze_ctor_kwargs"]
        ),
    )

    endpoint_kwargs: EndpointKwargsType = serializable_field(
        default_factory=dict,
        serialization_fn=lambda kwargs: kwargs,
        loading_fn=_load_endpoint_kwargs,
        assert_type=False,
    )

    @property
    def grid_shape(self) -> CoordTup:
        return (self.grid_n, self.grid_n)

    @property
    def grid_shape_np(self) -> Coord:
        return np.array(self.grid_shape)

    @property
    def max_grid_n(self) -> int:
        return max(self.grid_shape)

    def stable_hash_cfg(self) -> int:
        return stable_hash(json.dumps(self.serialize()))

    def to_fname(self) -> str:
        return sanitize_fname(
            f"{self.name}-g{self.grid_n}-n{shorten_numerical_to_str(self.n_mazes)}-a_{self.maze_ctor.__name__.removeprefix('gen_')}-h{self.stable_hash_cfg() % 10**5}"
        )

    def summary(self) -> dict:
        """return a summary of the config"""
        # run the parent summary to check that it does not error
        super_summary: dict = super().summary()
        assert super_summary
        self_ser: dict = self.serialize()
        return dict(
            name=self.name,
            fname=self.to_fname(),
            sdc_hash=self.stable_hash_cfg(),
            seed=self.seed,
            seq_len_min=self.seq_len_min,
            seq_len_max=self.seq_len_max,
            applied_filters=self.applied_filters,
            grid_n=self_ser["grid_n"],
            n_mazes=self_ser["n_mazes"],
            maze_ctor_name=self_ser["maze_ctor"]["__name__"],
            maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
            endpoint_kwargs=self_ser["endpoint_kwargs"],
        )
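
# a minimal construction sketch; `name`, `seed`, and the `seq_len_*` fields are
# assumed to be inherited from `GPTDatasetConfig`, and the hash suffix varies:
# cfg: MazeDatasetConfig = MazeDatasetConfig(
#     name="demo",
#     grid_n=5,
#     n_mazes=16,
#     maze_ctor=GENERATORS_MAP["gen_dfs"],
# )
# print(cfg.to_fname())  # e.g. "demo-g5-n16-a_dfs-h12345"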


def _generate_maze_helper(index: int) -> Optional[SolvedMaze]:
    """Helper function for generating mazes in parallel.

    > [!CAUTION]
    > don't use this unless generating in parallel!
    """
    # TODO: don't use this unless generating in parallel!
    maze: LatticeMaze = _GLOBAL_WORKER_CONFIG.maze_ctor(
        grid_shape=_GLOBAL_WORKER_CONFIG.grid_shape_np,
        **_GLOBAL_WORKER_CONFIG.maze_ctor_kwargs,
    )

    endpoint_kwargs: EndpointKwargsType = _GLOBAL_WORKER_CONFIG.endpoint_kwargs.copy()

    # Generate the solution
    solution: Optional[CoordArray] = maze.generate_random_path(**endpoint_kwargs)

    # Validate the solution
    if (
        solution is None
        or len(solution) == 0
        or not isinstance(solution, np.ndarray)
        or len(solution.shape) != 2
    ):
        return None  # Return None if the solution is invalid

    return SolvedMaze.from_lattice_maze(
        lattice_maze=maze,
        solution=solution,
    )


def _maze_gen_init_worker(config: MazeDatasetConfig):
    """special worker helper

    > [!CAUTION]
    > this makes the generation depend both on whether parallelism is used, and on the number of processes. this is bad!

    """
    # TODO
    global _GLOBAL_WORKER_CONFIG
    _GLOBAL_WORKER_CONFIG = config

    process_id: tuple[int, ...] = multiprocessing.current_process()._identity
    if len(process_id) == 0:
        # no multiprocessing, seed was already set
        pass
    elif len(process_id) == 1:
        # multiprocessing: adjust the seed based on the process id
        # only set the numpy seed, since we do not use other random generators
        np.random.seed(_GLOBAL_WORKER_CONFIG.seed + process_id[0])
    else:
        raise ValueError(
            f"unexpected process id: {process_id}\n{multiprocessing.Process()}"
        )


class MazeDataset(GPTDataset):
    """a maze dataset class. This is a collection of solved mazes, and should be initialized via `MazeDataset.from_config`"""

    def __init__(
        self,
        cfg: MazeDatasetConfig,
        mazes: typing.Sequence[SolvedMaze],
        generation_metadata_collected: dict | None = None,
    ) -> None:
        super().__init__()
        self.cfg: MazeDatasetConfig = cfg
        self.mazes: list[SolvedMaze] = list(mazes)
        self.generation_metadata_collected: dict | None = generation_metadata_collected

    @classmethod
    def from_config(
        cls,
        cfg: MazeDatasetConfig,
        do_generate: bool = True,
        load_local: bool = True,
        save_local: bool = True,
        zanj: ZANJ | None = None,
        do_download: bool = True,
        local_base_path: Path = Path("data/maze_dataset"),
        except_on_config_mismatch: bool = True,
        allow_generation_metadata_filter_mismatch: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> "MazeDataset":
        """create a maze dataset from a config

        priority of loading:
        1. load from local
        2. download
        3. generate

        """
        return cast(
            MazeDataset,
            super().from_config(
                cfg=cfg,
                do_generate=do_generate,
                load_local=load_local,
                save_local=save_local,
                zanj=zanj,
                do_download=do_download,
                local_base_path=local_base_path,
                except_on_config_mismatch=except_on_config_mismatch,
                allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
                verbose=verbose,
                **kwargs,
            ),
        )
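
    # a usage sketch, assuming `cfg` is a `MazeDatasetConfig`: this loads from
    # `data/maze_dataset` if a matching file exists, otherwise generates and saves.
    # dataset: MazeDataset = MazeDataset.from_config(cfg, verbose=True)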


    def data_hash(self) -> int:
        return stable_hash(str(tuple([x.serialize() for x in self.mazes])))

    def __getitem__(self, i: int) -> SolvedMaze:
        return self.mazes[i]

    def __deepcopy__(self, memo) -> "MazeDataset":
        return MazeDataset.load(self._serialize_full())

    def as_tokens(
        self,
        maze_tokenizer,  # TODO: MazeTokenizer
        limit: int | None = None,
        join_tokens_individual_maze: bool = False,
    ) -> list[list[str]] | list[str]:
        """return the dataset as tokens according to the passed `maze_tokenizer`

        the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`

        if `join_tokens_individual_maze` is True, then the tokens of each maze are
        joined with a space, and the result is a list of strings.
        i.e.:

        >>> dataset.as_tokens(join_tokens_individual_maze=False)
        [["a", "b", "c"], ["d", "e", "f"]]
        >>> dataset.as_tokens(join_tokens_individual_maze=True)
        ["a b c", "d e f"]
        """
        output: list[list[str]] = [
            maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
        ]
        if join_tokens_individual_maze:
            return [" ".join(tokens) for tokens in output]
        else:
            return output

    def __len__(self) -> int:
        return len(self.mazes)

    def __eq__(self, other: typing.Any) -> bool:
        if not isinstance(other, MazeDataset):
            return NotImplemented
        # TODO: compare hashes of data instead of the data itself?
        return self.cfg == other.cfg and self.mazes == other.mazes

    @classmethod
    def generate(
        cls,
        cfg: MazeDatasetConfig,
        gen_parallel: bool = False,
        pool_kwargs: dict | None = None,
        verbose: bool = False,
    ) -> "MazeDataset":
        """Generate a maze dataset given a config and some generation parameters"""

        # Copy the config to avoid modifying the original
        cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
            json.loads(json.dumps(cfg.serialize()))
        )

        if pool_kwargs is None:
            pool_kwargs = dict()
        maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]

        solved_mazes: list[SolvedMaze | None]
        # Configure tqdm for the progress bar
        tqdm_kwargs: dict = dict(
            total=cfg_cpy.n_mazes,
            unit="maze",
            desc="generating & solving mazes",
            disable=not verbose,
        )
        # TODO: don't use the global unless generating in parallel!
        if gen_parallel:
            with multiprocessing.Pool(
                **pool_kwargs,
                initializer=_maze_gen_init_worker,
                initargs=(cfg_cpy,),
            ) as pool:
                solved_mazes = list(
                    tqdm.tqdm(
                        pool.imap(_generate_maze_helper, maze_indexes), **tqdm_kwargs
                    )
                )

        else:
            _maze_gen_init_worker(cfg_cpy)
            solved_mazes = list(
                tqdm.tqdm(
                    map(
                        _generate_maze_helper,
                        maze_indexes.tolist(),
                    ),
                    **tqdm_kwargs,
                )
            )

        # Filter out None values explicitly after ensuring all results are collected
        solved_mazes_: list[SolvedMaze] = [
            maze for maze in solved_mazes if maze is not None
        ]
        # solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes))

        # Update the config with the actual number of mazes
        cfg_cpy.n_mazes = len(solved_mazes_)

        dataset: MazeDataset = cls(
            cfg=cfg_cpy,
            mazes=solved_mazes_,
        )

        dataset.update_self_config()  # Call `update_self_config()` to ensure the dataset's config reflects changes

        np.random.seed(cfg_cpy.seed)  # Reset the seed to the value in the config copy

        return dataset
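
    # a usage sketch for generating directly, skipping the load/download logic of
    # `from_config`; `processes=4` is an arbitrary value forwarded to `multiprocessing.Pool`:
    # dataset: MazeDataset = MazeDataset.generate(
    #     cfg, gen_parallel=True, pool_kwargs=dict(processes=4), verbose=True
    # )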


    @classmethod
    def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
        raise NotImplementedError("not implemented yet")

    @classmethod
    def load(cls, data: JSONdict) -> "MazeDataset":
        """load from zanj/json"""
        if data[_FORMAT_KEY] == "MazeDataset:minimal":
            return cls._load_minimal(data)
        elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
            return cls._load_minimal_soln_cat(data)
        elif data[_FORMAT_KEY] == "MazeDataset":
            if (
                SERIALIZE_MINIMAL_THRESHOLD == -1
            ):  # Allow access to `_load_legacy` for profiling
                return cls._load_legacy(data)
            return cls._load_full(data)
        else:
            raise KeyError(
                f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
            )

    @classmethod
    def _load_full(cls, data: JSONdict) -> "MazeDataset":
        assert data[_FORMAT_KEY] == "MazeDataset"
        return cls(
            cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
            mazes=load_item_recursive(data["mazes"], tuple()),
            generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
        )

    @classmethod
    def _load_minimal(cls, data: JSONdict) -> "MazeDataset":
        assert data[_FORMAT_KEY] == "MazeDataset:minimal"
        return cls(
            cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
            generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
            mazes=[
                SolvedMaze(
                    clist,
                    soln[:slen, ...],
                )
                for clist, slen, soln in zip(
                    load_item_recursive(data["maze_connection_lists"], tuple()),
                    load_item_recursive(data["maze_solution_lengths"], tuple()),
                    load_item_recursive(data["maze_solutions"], tuple()),
                    # load_item_recursive(data["maze_endpoints"], tuple()),
                )
            ],
        )

    @classmethod
    def _load_minimal_soln_cat(cls, data: JSONdict) -> "MazeDataset":
        assert data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat"

        maze_solution_lengths = load_item_recursive(
            data["maze_solution_lengths"], tuple()
        )
        maze_solutions_concat = load_item_recursive(
            data["maze_solutions_concat"], tuple()
        )
        maze_solutions = np.split(
            maze_solutions_concat, np.cumsum(maze_solution_lengths)[:-1], axis=0
        )
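        # e.g. for lengths [2, 3, 4], cumsum(...)[:-1] == [2, 5], so the (9, 2)
        # concatenated array splits into per-maze arrays of shape (2, 2), (3, 2), (4, 2)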


        return cls(
            cfg=load_item_recursive(data["cfg"], tuple()),
            generation_metadata_collected=load_item_recursive(
                data["generation_metadata_collected"], tuple()
            ),
            mazes=[
                SolvedMaze(
                    connection_list=clist,
                    solution=soln,
                )
                for clist, soln in zip(
                    load_item_recursive(data["maze_connection_lists"], tuple()),
                    # load_item_recursive(data["maze_endpoints"], tuple()),
                    maze_solutions,
                )
            ],
        )

    @classmethod
    def _load_legacy(cls, data: JSONdict) -> "MazeDataset":
        """Legacy `load` method from <0.5.2. Used exclusively for profiling comparison."""
        assert data[_FORMAT_KEY] == "MazeDataset"
        return cls(
            **{
                key: load_item_recursive(data[key], tuple())
                for key in ["cfg", "mazes", "generation_metadata_collected"]
            }
        )

    def serialize(self) -> JSONdict:
        """serialize to zanj/json"""
        if (
            SERIALIZE_MINIMAL_THRESHOLD is not None
            and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
        ):
            return self._serialize_minimal()
        return self._serialize_full()

    def _serialize_full(self) -> JSONdict:
        return {
            _FORMAT_KEY: "MazeDataset",
            "cfg": json_serialize(self.cfg),
            "mazes": json_serialize(self.mazes),
            "generation_metadata_collected": json_serialize(
                self.generation_metadata_collected
            ),
        }

    def _serialize_minimal(self) -> JSONdict:
        "alternate serialization where metadata is collected and mazes are stored in concatenated form"
        filtered_meta: "MazeDataset"
        if self.generation_metadata_collected is None:
            filtered_meta = self.filter_by.collect_generation_meta()
        else:
            filtered_meta = self

        max_solution_len: int = max(m.solution.shape[0] for m in filtered_meta.mazes)
        n_mazes: int = len(filtered_meta.mazes)
        grid_n: int = filtered_meta.cfg.grid_n

        maze_connection_lists: np.ndarray = np.empty(
            (n_mazes, 2, grid_n, grid_n), dtype=np.bool_
        )
        # maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
        maze_solution_lengths: np.ndarray = np.empty((n_mazes,), dtype=np.int32)
        maze_solutions: np.ndarray = np.empty(
            (n_mazes, max_solution_len, 2), dtype=np.int8
        )

        for idx, maze in enumerate(filtered_meta.mazes):
            maze_connection_lists[idx] = maze.connection_list
            # maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
            maze_solution_lengths[idx] = maze.solution.shape[0]
            maze_solutions[idx, : maze.solution.shape[0]] = maze.solution

        return {
            _FORMAT_KEY: "MazeDataset:minimal",
            "cfg": json_serialize(filtered_meta.cfg),
            "generation_metadata_collected": json_serialize(
                filtered_meta.generation_metadata_collected
            ),
            "maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
            # "maze_endpoints": maze_endpoints,
            "maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
            "maze_solutions": maze_solutions,  # type: ignore[dict-item]
        }

    def _serialize_minimal_soln_cat(self) -> JSONdict:
        "alternate serialization where metadata is collected, and mazes and their solutions are stored in concatenated form"
        filtered_meta: "MazeDataset"
        if self.generation_metadata_collected is None:
            filtered_meta = self.filter_by.collect_generation_meta()
        else:
            filtered_meta = self

        maze_solution_lengths: np.ndarray = np.array(
            [m.solution.shape[0] for m in filtered_meta.mazes],
            dtype=np.int32,
        )
        n_mazes: int = len(filtered_meta.mazes)
        grid_n: int = filtered_meta.cfg.grid_n
        total_solution_len: int = np.sum(maze_solution_lengths)

        maze_connection_lists: np.ndarray = np.empty(
            (n_mazes, 2, grid_n, grid_n), dtype=np.bool_
        )
        maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
        maze_solutions_concat: np.ndarray = np.empty(
            (total_solution_len, 2), dtype=np.int8
        )

        solutions_running_idx: int = 0
        for idx, maze in enumerate(filtered_meta.mazes):
            maze_connection_lists[idx] = maze.connection_list
            maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
            soln_len: int = maze.solution.shape[0]
            maze_solution_lengths[idx] = soln_len
            maze_solutions_concat[
                solutions_running_idx : solutions_running_idx + soln_len
            ] = maze.solution
            solutions_running_idx += soln_len

        return {
            _FORMAT_KEY: "MazeDataset:minimal_soln_cat",
            "cfg": json_serialize(filtered_meta.cfg),
            "generation_metadata_collected": json_serialize(
                filtered_meta.generation_metadata_collected
            ),
            "maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
            "maze_endpoints": maze_endpoints,  # type: ignore[dict-item]
            "maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
            "maze_solutions_concat": maze_solutions_concat,  # type: ignore[dict-item]
        }

    def update_self_config(self) -> None:
        """update the config to match the current state of the dataset (e.g. the number of mazes after filtering)"""
        self.cfg.n_mazes = len(self.mazes)

    def custom_maze_filter(
        self,
        method: typing.Callable[[SolvedMaze], bool],
        **kwargs,
    ) -> "MazeDataset":
        """filter the dataset using a custom method"""
        output: MazeDataset = MazeDataset(
            cfg=copy.deepcopy(self.cfg),
            mazes=[m for m in self.mazes if method(m, **kwargs)],
        )
        output.cfg.applied_filters.append(
            {
                "name": f"__custom__:{method.__name__}",
                "kwargs": kwargs,
            }
        )
        output.update_self_config()
        return output
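
    # a usage sketch with an arbitrary predicate (the threshold is illustrative):
    # dataset_long: MazeDataset = dataset.custom_maze_filter(
    #     lambda m, min_len: len(m.solution) >= min_len, min_len=5
    # )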



# register things with zanj
MazeDatasetConfig._dataset_class = property(lambda self: MazeDataset)  # type: ignore[method-assign]
register_loader_handler(
    LoaderHandler(
        check=lambda json_item, path=None, z=None: (
            isinstance(json_item, typing.Mapping)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith("MazeDataset")
        ),
        load=lambda json_item, path=None, z=None: MazeDataset.load(json_item),
        uid="MazeDataset",
        source_pckg="maze_dataset.generation.maze_dataset",
        desc="MazeDataset",
    )
)


def register_maze_filter(
    method: typing.Callable[[SolvedMaze, typing.Any], bool],
) -> DatasetFilterProtocol:
    """register a maze filter, casting it to operate over the whole list of mazes

    method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`

    this is a more restricted version of `register_dataset_filter` that removes the need for boilerplate when operating over the arrays
    """

    @functools.wraps(method)
    def wrapper(dataset: MazeDataset, *args, **kwargs):
        # copy and filter
        new_dataset: MazeDataset = copy.deepcopy(
            MazeDataset(
                cfg=dataset.cfg,
                mazes=[m for m in dataset.mazes if method(m, *args, **kwargs)],
            )
        )
        # update the config
        new_dataset.cfg.applied_filters.append(
            dict(name=method.__name__, args=args, kwargs=kwargs)
        )
        new_dataset.update_self_config()
        return new_dataset

    return wrapper


@register_filter_namespace_for_dataset(MazeDataset)
class MazeDatasetFilters:
    "namespace for filters for `MazeDataset`s"

    @register_maze_filter
    @staticmethod
    def path_length(maze: SolvedMaze, min_length: int) -> bool:
        """filter out mazes with a solution length less than `min_length`"""
        return len(maze.solution) >= min_length

    @register_maze_filter
    @staticmethod
    def start_end_distance(maze: SolvedMaze, min_distance: int) -> bool:
        """filter out mazes where the start and end pos are less than `min_distance` apart in manhattan distance (ignoring walls)"""
        return np.linalg.norm(maze.start_pos - maze.end_pos, 1) >= min_distance
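
    # filters in this namespace are reached through the `filter_by` attribute that
    # `register_filter_namespace_for_dataset` attaches; a sketch with illustrative thresholds:
    # dataset_filtered: MazeDataset = dataset.filter_by.path_length(min_length=3)
    # dataset_spread: MazeDataset = dataset_filtered.filter_by.start_end_distance(min_distance=2)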


    @register_dataset_filter
    @staticmethod
    def cut_percentile_shortest(
        dataset: MazeDataset,
        percentile: float = 10.0,
    ) -> MazeDataset:
        """cut the shortest `percentile` of mazes from the dataset

        `percentile` is 1-100, not 0-1, as this is what `np.percentile` expects
        """
        lengths: np.ndarray = np.array([len(m.solution) for m in dataset])
        cutoff: int = int(np.percentile(lengths, percentile))

        filtered_mazes: list[SolvedMaze] = [
            m for m in dataset if len(m.solution) > cutoff
        ]
        new_dataset: MazeDataset = MazeDataset(cfg=dataset.cfg, mazes=filtered_mazes)

        return copy.deepcopy(new_dataset)

    @register_dataset_filter
    @staticmethod
    def truncate_count(
        dataset: MazeDataset,
        max_count: int,
    ) -> MazeDataset:
        """truncate the dataset to be at most `max_count` mazes"""
        new_dataset: MazeDataset = MazeDataset(
            cfg=dataset.cfg, mazes=dataset.mazes[:max_count]
        )
        return copy.deepcopy(new_dataset)

    @register_dataset_filter
    @staticmethod
    def remove_duplicates(
        dataset: MazeDataset,
        minimum_difference_connection_list: int | None = 1,
        minimum_difference_solution: int | None = 1,
        _max_dataset_len_threshold: int = 1000,
    ) -> MazeDataset:
        """remove duplicates from a dataset, keeping the **LAST** unique maze

        set either minimum difference to `None` to disable that check

        if you want to avoid mazes which have more overlap, set the minimum difference to be greater

        Gotchas:
        - if two mazes are of different sizes, they will never be considered duplicates
        - if two solutions are of different lengths, they will never be considered duplicates
        TODO: check for overlap?
        """
        if len(dataset) > _max_dataset_len_threshold:
            raise ValueError(
                "this method is currently very slow for large datasets, consider using `remove_duplicates_fast` instead\n",
                "if you know what you're doing, change `_max_dataset_len_threshold`",
            )

        unique_mazes: list[SolvedMaze] = list()

        maze_a: SolvedMaze
        maze_b: SolvedMaze
        for i, maze_a in enumerate(dataset.mazes):
            a_unique: bool = True
            for maze_b in dataset.mazes[i + 1 :]:
                # compare against every later maze
                if (minimum_difference_connection_list is not None) and (
                    maze_a.connection_list.shape == maze_b.connection_list.shape
                ):
                    if (
                        np.sum(maze_a.connection_list != maze_b.connection_list)
                        <= minimum_difference_connection_list
                    ):
                        a_unique = False
                        break

                if (minimum_difference_solution is not None) and (
                    maze_a.solution.shape == maze_b.solution.shape
                ):
                    if (
                        np.sum(maze_a.solution != maze_b.solution)
                        <= minimum_difference_solution
                    ):
                        a_unique = False
                        break

            if a_unique:
                unique_mazes.append(maze_a)

        return copy.deepcopy(
            MazeDataset(
                cfg=dataset.cfg,
                mazes=unique_mazes,
                generation_metadata_collected=dataset.generation_metadata_collected,
            )
        )


    @register_dataset_filter
    @staticmethod
    def remove_duplicates_fast(dataset: MazeDataset) -> MazeDataset:
        """remove duplicates from a dataset"""

        unique_mazes = list(dict.fromkeys(dataset.mazes))
        return copy.deepcopy(
            MazeDataset(
                cfg=dataset.cfg,
                mazes=unique_mazes,
                generation_metadata_collected=dataset.generation_metadata_collected,
            )
        )

    @register_dataset_filter
    @staticmethod
    def strip_generation_meta(dataset: MazeDataset) -> MazeDataset:
        """strip the generation meta from the dataset"""
        new_dataset: MazeDataset = copy.deepcopy(dataset)
        for maze in new_dataset:
            # hacky because it's a frozen dataclass
            maze.__dict__["generation_meta"] = None
        return new_dataset

    @register_dataset_filter
    @staticmethod
    def collect_generation_meta(
        dataset: MazeDataset,
        clear_in_mazes: bool = True,
        inplace: bool = True,
        allow_fail: bool = False,
    ) -> MazeDataset:
        """collect the generation metadata of each maze into `dataset.generation_metadata_collected`"""
        # if the generation meta is already collected, don't collect it again, do nothing
        if dataset.generation_metadata_collected is not None:
            return dataset
        else:
            assert dataset[0].generation_meta is not None, (
                "generation meta is not collected and original is not present"
            )

        new_dataset: MazeDataset
        if inplace:
            new_dataset = dataset
        else:
            new_dataset = copy.deepcopy(dataset)

        # values counted per key may be bools, ints, floats, strs, or coord tuples
        gen_meta_lists: dict[str, Counter] = defaultdict(Counter)
        for maze in new_dataset:
            if maze.generation_meta is None:
                if allow_fail:
                    break
                else:
                    raise ValueError(
                        "generation meta is not present in a maze, cannot collect generation meta"
                    )
            for key, value in maze.generation_meta.items():
                if isinstance(value, (bool, int, float, str)):
                    gen_meta_lists[key][value] += 1

                elif isinstance(value, set):
                    # special case for visited_cells
                    gen_meta_lists[key].update(value)

                elif isinstance(value, (list, np.ndarray)):
                    if isinstance(value, list):
                        try:
                            value = np.array(value)
                        except ValueError:
                            raise ValueError(
                                f"Cannot collect generation meta for {key} as it is a list of type '{str(type(value[0])) = }'",
                                "expected either a basic type (bool, int, float, str), a numpy coord, or a numpy array of coords",
                            )

                    if (len(value.shape) == 1) and (value.shape[0] == maze.lattice_dim):
                        # assume it's a single coordinate
                        gen_meta_lists[key][tuple(value)] += 1
                    elif (len(value.shape) == 2) and (
                        value.shape[1] == maze.lattice_dim
                    ):
                        # assume it's a list of coordinates
                        gen_meta_lists[key].update([tuple(v) for v in value])
                    else:
                        raise ValueError(
                            f"Cannot collect generation meta for {key} as it is an ndarray of shape {value.shape}",
                            "expected either a coord of shape (2,) or a list of coords of shape (n, 2)",
                        )
                else:
                    raise ValueError(
                        f"Cannot collect generation meta for {key} as it is of type '{str(type(value))}'",
                        "expected either a basic type (bool, int, float, str), a numpy coord, or a numpy array of coords",
                    )

            # clear the data
            if clear_in_mazes:
                # hacky because it's a frozen dataclass
                maze.__dict__["generation_meta"] = None

        new_dataset.generation_metadata_collected = {
            key: dict(value) for key, value in gen_meta_lists.items()
        }

        return new_dataset

    # the code below is for doing some smarter collecting and type checking. Probably will delete.
    """
    collect either the type at the field, or the shape of the field if it is an array
    metadata_types: dict[str, set[type, tuple]] = dict()
    for maze in new_dataset:
        for key, value in maze.generation_meta.items():
            if key not in metadata_types:
                metadata_types[key] = set()

            if isinstance(value, np.ndarray):
                metadata_types[key].add(value.shape)
            else:
                metadata_types[key].add(type(value))

    # figure out what to do for each field
    metadata_actions: dict[str, typing.Callable] = dict()
    for key, key_type in metadata_types.items():
        if all(isinstance(kt, tuple) for kt in key_type):
            if all(kt == (2,) for kt in key_type):
                # its all coords, do a statcounter on those coords
                metadata_actions[key] = lambda vals: Counter(tuple(x) for x in vals)
            elif all(
                (len(kt) == 2) and (kt[1] == 2)
                for kt in key_type
            ):
                # its a list of coords, do a statcounter on those coords
                metadata_actions[key] = lambda vals: Counter(
                    tuple(x) for x in np.concatenate(vals)
                )
            else:
                # its a list of something else, do a counter on those
                # TODO: throw except here?
                metadata_actions[key] = Counter

        elif all(kt in (bool, int, float) for kt in key_type):
            # statcounter for numeric types
            metadata_actions[key] = StatCounter
        elif all(kt == str for kt in key_type):
            # counter for string types
            metadata_actions[key] = Counter
        else:
            # counter for everything else
            # TODO: throw except here?
            metadata_actions[key] = Counter
    """