"""`MazeDatasetConfig` is where you decide what your dataset should look like, then pass it to `MazeDataset.from_config` to generate or load the dataset.

see [demo_dataset notebook](../../notebooks/demo_dataset)
"""

import copy
import functools
import json
import multiprocessing
import typing
import warnings
from collections import Counter, defaultdict
from pathlib import Path
from typing import Callable, Optional, cast

import numpy as np
import tqdm
from jaxtyping import Int
from muutils.json_serialize import (
    json_serialize,
    serializable_dataclass,
    serializable_field,
)
from muutils.json_serialize.util import (
    _FORMAT_KEY,
    JSONdict,
    safe_getsource,
    string_as_lines,
)
from muutils.misc import sanitize_fname, shorten_numerical_to_str, stable_hash
from zanj import ZANJ
from zanj.loading import LoaderHandler, load_item_recursive, register_loader_handler

from maze_dataset.constants import Coord, CoordArray, CoordTup
from maze_dataset.dataset.dataset import (
    DatasetFilterProtocol,
    GPTDataset,
    GPTDatasetConfig,
    register_dataset_filter,
    register_filter_namespace_for_dataset,
)
from maze_dataset.generation.generators import GENERATORS_MAP
from maze_dataset.maze import LatticeMaze, SolvedMaze

# If `n_mazes >= SERIALIZE_MINIMAL_THRESHOLD`, then the `MazeDataset` will use `serialize_minimal`.
# Setting this to `None` means that `serialize_minimal` will never be used.
# Set it to -1 to make calls to `read` use `MazeDataset._load_legacy`. Used for profiling only.
SERIALIZE_MINIMAL_THRESHOLD: int | None = 100


def set_serialize_minimal_threshold(threshold: int | None) -> None:
    "set the global `SERIALIZE_MINIMAL_THRESHOLD`"
    global SERIALIZE_MINIMAL_THRESHOLD
    SERIALIZE_MINIMAL_THRESHOLD = threshold
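
# Example (illustrative sketch, not part of the original file): disable minimal
# serialization entirely, e.g. when debugging round-trips of small test datasets:
#     set_serialize_minimal_threshold(None)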


def _load_maze_ctor(maze_ctor_serialized: str | dict) -> Callable:
    "get the maze constructor from `GENERATORS_MAP`"
    if isinstance(maze_ctor_serialized, dict):
        # this is both the new and old version of the serialization
        return GENERATORS_MAP[maze_ctor_serialized["__name__"]]
    elif isinstance(maze_ctor_serialized, str):
        # this is a format we used for a while, but have since reverted
        warnings.warn(
            "you are loading an old model/config in `_load_maze_ctor()`! this should not be happening, please report: "
            + "https://github.com/understanding-search/maze-dataset/issues/new"
        )
        return GENERATORS_MAP[maze_ctor_serialized]
    else:
        raise ValueError(
            f"maze_ctor_serialized is of type {type(maze_ctor_serialized)}, expected str or dict"
        )


EndpointKwargsType = dict[
    typing.Literal[
        "allowed_start",
        "allowed_end",
        "deadend_start",
        "deadend_end",
        "endpoints_not_equal",
        "except_on_no_valid_endpoint",
    ],
    bool | None | list[tuple[int, int]],
]
"type hint for `MazeDatasetConfig.endpoint_kwargs`"
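
# Example (illustrative sketch) of an `endpoint_kwargs` value: start only from
# dead ends, end only at one of two given coordinates, and return `None` instead
# of raising when no valid endpoint exists:
#     endpoint_kwargs = dict(
#         deadend_start=True,
#         allowed_end=[(1, 1), (2, 2)],
#         except_on_no_valid_endpoint=False,
#     )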


def _load_endpoint_kwargs(data: dict) -> EndpointKwargsType:
    if data.get("endpoint_kwargs", None) is None:
        return dict()
    else:
        return {
            k: (
                # bools and Nones are fine
                v
                if (isinstance(v, bool) or v is None)
                # otherwise, assume it's a coord list
                else [tuple(x) for x in v]  # muutils/zanj saves tuples as lists
            )
            for k, v in data["endpoint_kwargs"].items()
        }


@serializable_dataclass(kw_only=True, properties_to_serialize=["grid_shape"])
class MazeDatasetConfig(GPTDatasetConfig):
    """config object which is passed to `MazeDataset.from_config` to generate or load a dataset"""

    grid_n: int = serializable_field()

    # `n_mazes` is excluded from comparison primarily to avoid conflicts which arise during `from_config` when filters have been applied
    n_mazes: int = serializable_field(compare=False)

    maze_ctor: Callable = serializable_field(
        default=GENERATORS_MAP["gen_dfs"],
        serialization_fn=lambda gen_func: {
            "__name__": gen_func.__name__,
            "__module__": gen_func.__module__,
            "__doc__": string_as_lines(gen_func.__doc__),
            "source_code": safe_getsource(gen_func),
        },
        loading_fn=lambda data: _load_maze_ctor(data["maze_ctor"]),
        assert_type=False,  # TODO: check the type here once muutils supports checking Callable signatures
    )

    maze_ctor_kwargs: dict = serializable_field(
        default_factory=dict,
        serialization_fn=lambda kwargs: kwargs,
        loading_fn=lambda data: (
            dict()
            if data.get("maze_ctor_kwargs", None) is None  # this handles backwards compatibility
            else data["maze_ctor_kwargs"]
        ),
    )

    endpoint_kwargs: EndpointKwargsType = serializable_field(
        default_factory=dict,
        serialization_fn=lambda kwargs: kwargs,
        loading_fn=_load_endpoint_kwargs,
        assert_type=False,
    )

    @property
    def grid_shape(self) -> CoordTup:
        "return the shape of the grid as a tuple"
        return (self.grid_n, self.grid_n)

    @property
    def grid_shape_np(self) -> Coord:
        "return the shape of the grid as a numpy array"
        return np.array(self.grid_shape)

    @property
    def max_grid_n(self) -> int:
        "return the maximum dimension of the grid shape"
        return max(self.grid_shape)

    def stable_hash_cfg(self) -> int:
        "return a stable hash of the serialized config"
        return stable_hash(json.dumps(self.serialize()))

    def to_fname(self) -> str:
        "return a unique filename for the config"
        return sanitize_fname(
            f"{self.name}-g{self.grid_n}-n{shorten_numerical_to_str(self.n_mazes)}-a_{self.maze_ctor.__name__.removeprefix('gen_')}-h{self.stable_hash_cfg() % 10**5}"
        )

    def summary(self) -> dict:
        """return a summary of the config"""
        # run the parent summary to check that it doesn't error
        super_summary: dict = super().summary()
        assert super_summary
        self_ser: dict = self.serialize()
        return dict(
            name=self.name,
            fname=self.to_fname(),
            sdc_hash=self.stable_hash_cfg(),
            seed=self.seed,
            seq_len_min=self.seq_len_min,
            seq_len_max=self.seq_len_max,
            applied_filters=self.applied_filters,
            grid_n=self_ser["grid_n"],
            n_mazes=self_ser["n_mazes"],
            maze_ctor_name=self_ser["maze_ctor"]["__name__"],
            maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
            endpoint_kwargs=self_ser["endpoint_kwargs"],
        )
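
# Example (illustrative sketch; `name`, `seed`, and the sequence-length fields
# come from `GPTDatasetConfig`):
#     cfg = MazeDatasetConfig(name="demo", grid_n=5, n_mazes=100)
#     cfg.to_fname()  # something like 'demo-g5-n100-a_dfs-h12345'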


def _generate_maze_helper(index: int) -> Optional[SolvedMaze]:
    """Helper function for generating mazes in parallel.

    > [!CAUTION]
    > don't use this unless generating in parallel!
    """
    # TODO: don't use this unless generating in parallel!
    maze: LatticeMaze = _GLOBAL_WORKER_CONFIG.maze_ctor(
        grid_shape=_GLOBAL_WORKER_CONFIG.grid_shape_np,
        **_GLOBAL_WORKER_CONFIG.maze_ctor_kwargs,
    )

    endpoint_kwargs: EndpointKwargsType = _GLOBAL_WORKER_CONFIG.endpoint_kwargs.copy()

    # generate the solution
    solution: Optional[CoordArray] = maze.generate_random_path(**endpoint_kwargs)

    # validate the solution
    if (
        solution is None
        or len(solution) == 0
        or not isinstance(solution, np.ndarray)
        or len(solution.shape) != 2
    ):
        return None  # return None if the solution is invalid

    return SolvedMaze.from_lattice_maze(
        lattice_maze=maze,
        solution=solution,
    )


def _maze_gen_init_worker(config: MazeDatasetConfig):
    """special worker helper

    > [!CAUTION]
    > this makes generation depend both on whether parallelism is used, and on the number of processes. this is bad!
    """
    # TODO: fix the above caveat
    global _GLOBAL_WORKER_CONFIG
    _GLOBAL_WORKER_CONFIG = config

    process_id: tuple[int, ...] = multiprocessing.current_process()._identity
    if len(process_id) == 0:
        # no multiprocessing, seed was already set
        pass
    elif len(process_id) == 1:
        # multiprocessing: adjust seed based on process id
        # only set the numpy seed, since we do not use other random generators
        np.random.seed(_GLOBAL_WORKER_CONFIG.seed + process_id[0])
    else:
        raise ValueError(
            f"unexpected process id: {process_id}\n{multiprocessing.Process()}"
        )


class MazeDataset(GPTDataset):
    """a maze dataset class. This is a collection of solved mazes, and should be initialized via `MazeDataset.from_config`"""

    def __init__(
        self,
        cfg: MazeDatasetConfig,
        mazes: typing.Sequence[SolvedMaze],
        generation_metadata_collected: dict | None = None,
    ) -> None:
        super().__init__()
        self.cfg: MazeDatasetConfig = cfg
        self.mazes: list[SolvedMaze] = list(mazes)
        self.generation_metadata_collected: dict | None = generation_metadata_collected

    @classmethod
    def from_config(
        cls,
        cfg: MazeDatasetConfig,
        do_generate: bool = True,
        load_local: bool = True,
        save_local: bool = True,
        zanj: ZANJ | None = None,
        do_download: bool = True,
        local_base_path: Path = Path("data/maze_dataset"),
        except_on_config_mismatch: bool = True,
        allow_generation_metadata_filter_mismatch: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> "MazeDataset":
        """create a maze dataset from a config

        priority of loading:
        1. load from local
        2. download
        3. generate
        """
        return cast(
            MazeDataset,
            super().from_config(
                cfg=cfg,
                do_generate=do_generate,
                load_local=load_local,
                save_local=save_local,
                zanj=zanj,
                do_download=do_download,
                local_base_path=local_base_path,
                except_on_config_mismatch=except_on_config_mismatch,
                allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
                verbose=verbose,
                **kwargs,
            ),
        )
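
    # Example (illustrative sketch): generate a fresh dataset without touching
    # the local cache:
    #     dataset = MazeDataset.from_config(cfg, load_local=False, save_local=False)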

    def data_hash(self) -> int:
        "return a stable hash of the serialized mazes"
        return stable_hash(str(tuple([x.serialize() for x in self.mazes])))

    def __getitem__(self, i: int) -> SolvedMaze:
        "get a maze by index"
        return self.mazes[i]

    def __deepcopy__(self, memo) -> "MazeDataset":
        "deepcopy by serializing and re-loading"
        return MazeDataset.load(self._serialize_full())

    def as_tokens(
        self,
        maze_tokenizer,  # TODO: MazeTokenizer
        limit: int | None = None,
        join_tokens_individual_maze: bool = False,
    ) -> list[list[str]] | list[str]:
        """return the dataset as tokens according to the passed `maze_tokenizer`

        the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`

        if `join_tokens_individual_maze` is True, then the tokens of each maze are
        joined with a space, and the result is a list of strings, i.e.:

        >>> dataset.as_tokens(join_tokens_individual_maze=False)
        [["a", "b", "c"], ["d", "e", "f"]]
        >>> dataset.as_tokens(join_tokens_individual_maze=True)
        ["a b c", "d e f"]
        """
        output: list[list[str]] = [
            maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
        ]
        if join_tokens_individual_maze:
            return [" ".join(tokens) for tokens in output]
        else:
            return output

    def __len__(self) -> int:
        "return the number of mazes in the dataset"
        return len(self.mazes)

    def __eq__(self, other: typing.Any) -> bool:
        "compare datasets by config and mazes"
        if not isinstance(other, MazeDataset):
            return NotImplemented
        # TODO: compare hashes of data instead of the data itself?
        return self.cfg == other.cfg and self.mazes == other.mazes

    @classmethod
    def generate(
        cls,
        cfg: MazeDatasetConfig,
        gen_parallel: bool = False,
        pool_kwargs: dict | None = None,
        verbose: bool = False,
    ) -> "MazeDataset":
        """Generate a maze dataset given a config and some generation parameters"""
        # copy the config to avoid modifying the original
        cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
            json.loads(json.dumps(cfg.serialize()))
        )

        if pool_kwargs is None:
            pool_kwargs = dict()
        maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]

        solved_mazes: list[SolvedMaze | None]
        # configure tqdm for the progress bar
        tqdm_kwargs: dict = dict(
            total=cfg_cpy.n_mazes,
            unit="maze",
            desc="generating & solving mazes",
            disable=not verbose,
        )
        # TODO: don't use the global unless generating in parallel!
        if gen_parallel:
            with multiprocessing.Pool(
                **pool_kwargs,
                initializer=_maze_gen_init_worker,
                initargs=(cfg_cpy,),
            ) as pool:
                solved_mazes = list(
                    tqdm.tqdm(
                        pool.imap(_generate_maze_helper, maze_indexes),
                        **tqdm_kwargs,
                    )
                )
        else:
            _maze_gen_init_worker(cfg_cpy)
            solved_mazes = list(
                tqdm.tqdm(
                    map(
                        _generate_maze_helper,
                        maze_indexes.tolist(),
                    ),
                    **tqdm_kwargs,
                )
            )

        # filter out `None` values explicitly after ensuring all results are collected
        solved_mazes_: list[SolvedMaze] = [
            maze for maze in solved_mazes if maze is not None
        ]

        # update the config with the actual number of mazes
        cfg_cpy.n_mazes = len(solved_mazes_)

        dataset: MazeDataset = cls(
            cfg=cfg_cpy,
            mazes=solved_mazes_,
        )

        dataset.update_self_config()  # ensure the dataset's config reflects changes

        np.random.seed(cfg_cpy.seed)  # reset the seed to the value in the config copy

        return dataset
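
    # Example (illustrative sketch): parallel generation with an explicit pool size
    # (`pool_kwargs` is forwarded to `multiprocessing.Pool`):
    #     dataset = MazeDataset.generate(cfg, gen_parallel=True, pool_kwargs=dict(processes=4))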

    @classmethod
    def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
        "(not implemented) download a dataset matching the config"
        raise NotImplementedError("not implemented yet")

    @classmethod
    def load(cls, data: JSONdict) -> "MazeDataset":
        """load from zanj/json"""
        if data[_FORMAT_KEY] == "MazeDataset:minimal":
            return cls._load_minimal(data)
        elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
            return cls._load_minimal_soln_cat(data)
        elif data[_FORMAT_KEY] == "MazeDataset":
            if SERIALIZE_MINIMAL_THRESHOLD == -1:
                # allow access to `_load_legacy` for profiling
                return cls._load_legacy(data)
            return cls._load_full(data)
        else:
            raise KeyError(
                f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
            )

    @classmethod
    def _load_full(cls, data: JSONdict) -> "MazeDataset":
        assert data[_FORMAT_KEY] == "MazeDataset"
        return cls(
            cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
            mazes=load_item_recursive(data["mazes"], tuple()),
            generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
        )

    @classmethod
    def _load_minimal(cls, data: JSONdict) -> "MazeDataset":
        assert data[_FORMAT_KEY] == "MazeDataset:minimal"
        return cls(
            cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
            generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
            mazes=[
                SolvedMaze(
                    clist,
                    soln[:slen, ...],
                )
                for clist, slen, soln in zip(
                    load_item_recursive(data["maze_connection_lists"], tuple()),
                    load_item_recursive(data["maze_solution_lengths"], tuple()),
                    load_item_recursive(data["maze_solutions"], tuple()),
                    # load_item_recursive(data["maze_endpoints"], tuple()),
                )
            ],
        )

    @classmethod
    def _load_minimal_soln_cat(cls, data: JSONdict) -> "MazeDataset":
        assert data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat"

        maze_solution_lengths = load_item_recursive(
            data["maze_solution_lengths"], tuple()
        )
        maze_solutions_concat = load_item_recursive(
            data["maze_solutions_concat"], tuple()
        )
        maze_solutions = np.split(
            maze_solutions_concat, np.cumsum(maze_solution_lengths)[:-1], axis=0
        )

        return cls(
            cfg=load_item_recursive(data["cfg"], tuple()),
            generation_metadata_collected=load_item_recursive(
                data["generation_metadata_collected"], tuple()
            ),
            mazes=[
                SolvedMaze(
                    connection_list=clist,
                    solution=soln,
                )
                for clist, soln in zip(
                    load_item_recursive(data["maze_connection_lists"], tuple()),
                    # load_item_recursive(data["maze_endpoints"], tuple()),
                    maze_solutions,
                )
            ],
        )

    @classmethod
    def _load_legacy(cls, data: JSONdict) -> "MazeDataset":
        """Legacy `load` method from <0.5.2. Used exclusively for profiling comparison."""
        assert data[_FORMAT_KEY] == "MazeDataset"
        return cls(
            **{
                key: load_item_recursive(data[key], tuple())
                for key in ["cfg", "mazes", "generation_metadata_collected"]
            }
        )

    def serialize(self) -> JSONdict:
        """serialize to zanj/json"""
        if (
            SERIALIZE_MINIMAL_THRESHOLD is not None
            and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
        ):
            return self._serialize_minimal()
        return self._serialize_full()
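
    # Example (illustrative sketch): a dataset with `n_mazes >= SERIALIZE_MINIMAL_THRESHOLD`
    # round-trips through the minimal format:
    #     data = dataset.serialize()  # data[_FORMAT_KEY] == "MazeDataset:minimal"
    #     restored = MazeDataset.load(data)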

    def _serialize_full(self) -> JSONdict:
        return {
            _FORMAT_KEY: "MazeDataset",
            "cfg": json_serialize(self.cfg),
            "mazes": json_serialize(self.mazes),
            "generation_metadata_collected": json_serialize(
                self.generation_metadata_collected
            ),
        }

    def _serialize_minimal(self) -> JSONdict:
        "alternate serialization where metadata is collected and mazes are stored in concatenated form"
        filtered_meta: "MazeDataset"
        if self.generation_metadata_collected is None:
            filtered_meta = self.filter_by.collect_generation_meta()
        else:
            filtered_meta = self

        max_solution_len: int = max(m.solution.shape[0] for m in filtered_meta.mazes)
        n_mazes: int = len(filtered_meta.mazes)
        grid_n: int = filtered_meta.cfg.grid_n

        maze_connection_lists: np.ndarray = np.empty(
            (n_mazes, 2, grid_n, grid_n), dtype=np.bool_
        )
        # maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
        maze_solution_lengths: np.ndarray = np.empty((n_mazes,), dtype=np.int32)
        maze_solutions: np.ndarray = np.empty(
            (n_mazes, max_solution_len, 2), dtype=np.int8
        )

        for idx, maze in enumerate(filtered_meta.mazes):
            maze_connection_lists[idx] = maze.connection_list
            # maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
            maze_solution_lengths[idx] = maze.solution.shape[0]
            maze_solutions[idx, : maze.solution.shape[0]] = maze.solution

        return {
            _FORMAT_KEY: "MazeDataset:minimal",
            "cfg": json_serialize(filtered_meta.cfg),
            "generation_metadata_collected": json_serialize(
                filtered_meta.generation_metadata_collected
            ),
            "maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
            # "maze_endpoints": maze_endpoints,
            "maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
            "maze_solutions": maze_solutions,  # type: ignore[dict-item]
        }

    def _serialize_minimal_soln_cat(self) -> JSONdict:
        "alternate serialization where metadata is collected, and mazes and their solutions are stored in concatenated form"
        filtered_meta: "MazeDataset"
        if self.generation_metadata_collected is None:
            filtered_meta = self.filter_by.collect_generation_meta()
        else:
            filtered_meta = self

        maze_solution_lengths: np.ndarray = np.array(
            [m.solution.shape[0] for m in filtered_meta.mazes],
            dtype=np.int32,
        )
        n_mazes: int = len(filtered_meta.mazes)
        grid_n: int = filtered_meta.cfg.grid_n
        total_solution_len: int = np.sum(maze_solution_lengths)

        maze_connection_lists: np.ndarray = np.empty(
            (n_mazes, 2, grid_n, grid_n), dtype=np.bool_
        )
        maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
        maze_solutions_concat: np.ndarray = np.empty(
            (total_solution_len, 2), dtype=np.int8
        )

        solutions_running_idx: int = 0
        for idx, maze in enumerate(filtered_meta.mazes):
            maze_connection_lists[idx] = maze.connection_list
            maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
            soln_len: int = maze.solution.shape[0]
            maze_solution_lengths[idx] = soln_len
            maze_solutions_concat[
                solutions_running_idx : solutions_running_idx + soln_len
            ] = maze.solution
            solutions_running_idx += soln_len

        return {
            _FORMAT_KEY: "MazeDataset:minimal_soln_cat",
            "cfg": json_serialize(filtered_meta.cfg),
            "generation_metadata_collected": json_serialize(
                filtered_meta.generation_metadata_collected
            ),
            "maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
            "maze_endpoints": maze_endpoints,  # type: ignore[dict-item]
            "maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
            "maze_solutions_concat": maze_solutions_concat,  # type: ignore[dict-item]
        }

    def update_self_config(self):
        """update the config to match the current state of the dataset (e.g. number of mazes after filtering)"""
        self.cfg.n_mazes = len(self.mazes)

    def custom_maze_filter(
        self,
        method: typing.Callable[[SolvedMaze], bool],
        **kwargs,
    ) -> "MazeDataset":
        """filter the dataset using a custom method"""
        output: MazeDataset = MazeDataset(
            cfg=copy.deepcopy(self.cfg),
            mazes=[m for m in self.mazes if method(m, **kwargs)],
        )
        output.cfg.applied_filters.append(
            {
                "name": f"__custom__:{method.__name__}",
                "kwargs": kwargs,
            }
        )
        output.update_self_config()
        return output
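
    # Example (illustrative sketch): keep only mazes whose solution has at least
    # `min_len` steps (the kwargs are forwarded to the predicate):
    #     long_only = dataset.custom_maze_filter(
    #         lambda maze, min_len: len(maze.solution) >= min_len,
    #         min_len=5,
    #     )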


# register things with zanj
MazeDatasetConfig._dataset_class = property(lambda self: MazeDataset)  # type: ignore[method-assign]
register_loader_handler(
    LoaderHandler(
        check=lambda json_item, path=None, z=None: (
            isinstance(json_item, typing.Mapping)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith("MazeDataset")
        ),
        load=lambda json_item, path=None, z=None: MazeDataset.load(json_item),
        uid="MazeDataset",
        source_pckg="maze_dataset.generation.maze_dataset",
        desc="MazeDataset",
    )
)


def register_maze_filter(
    method: typing.Callable[[SolvedMaze, typing.Any], bool],
) -> DatasetFilterProtocol:
    """register a maze filter, casting it to operate over the whole list of mazes

    `method` should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`

    this is a more restricted version of `register_dataset_filter` that removes the boilerplate needed to operate over the arrays
    """

    @functools.wraps(method)
    def wrapper(dataset: MazeDataset, *args, **kwargs):
        # copy and filter
        new_dataset: MazeDataset = copy.deepcopy(
            MazeDataset(
                cfg=dataset.cfg,
                mazes=[m for m in dataset.mazes if method(m, *args, **kwargs)],
            )
        )
        # update the config
        new_dataset.cfg.applied_filters.append(
            dict(name=method.__name__, args=args, kwargs=kwargs)
        )
        new_dataset.update_self_config()
        return new_dataset

    return wrapper


@register_filter_namespace_for_dataset(MazeDataset)
class MazeDatasetFilters:
    "namespace for filters for `MazeDataset`s"

    @register_maze_filter
    @staticmethod
    def path_length(maze: SolvedMaze, min_length: int) -> bool:
        """filter out mazes with a solution length less than `min_length`"""
        return len(maze.solution) >= min_length

    @register_maze_filter
    @staticmethod
    def start_end_distance(maze: SolvedMaze, min_distance: int) -> bool:
        """filter out mazes where the start and end pos are less than `min_distance` apart in Manhattan distance (ignoring walls)"""
        return np.linalg.norm(maze.start_pos - maze.end_pos, 1) >= min_distance
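
    # Example (illustrative sketch): registered filters are invoked through the
    # `filter_by` namespace on a dataset:
    #     filtered = dataset.filter_by.path_length(min_length=3)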

    @register_dataset_filter
    @staticmethod
    def cut_percentile_shortest(
        dataset: MazeDataset,
        percentile: float = 10.0,
    ) -> MazeDataset:
        """cut the shortest `percentile` of mazes from the dataset

        `percentile` ranges from 0 to 100, not 0 to 1, as this is what `np.percentile` expects
        """
        lengths: np.ndarray = np.array([len(m.solution) for m in dataset])
        cutoff: int = int(np.percentile(lengths, percentile))

        filtered_mazes: list[SolvedMaze] = [
            m for m in dataset if len(m.solution) > cutoff
        ]
        new_dataset: MazeDataset = MazeDataset(cfg=dataset.cfg, mazes=filtered_mazes)

        return copy.deepcopy(new_dataset)

    @register_dataset_filter
    @staticmethod
    def truncate_count(
        dataset: MazeDataset,
        max_count: int,
    ) -> MazeDataset:
        """truncate the dataset to contain at most `max_count` mazes"""
        new_dataset: MazeDataset = MazeDataset(
            cfg=dataset.cfg, mazes=dataset.mazes[:max_count]
        )
        return copy.deepcopy(new_dataset)
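
    # Example (illustrative sketch): filters return new datasets, so they can be chained:
    #     trimmed = dataset.filter_by.cut_percentile_shortest(percentile=10.0)
    #     trimmed = trimmed.filter_by.truncate_count(max_count=1000)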

    @register_dataset_filter
    @staticmethod
    def remove_duplicates(
        dataset: MazeDataset,
        minimum_difference_connection_list: int | None = 1,
        minimum_difference_solution: int | None = 1,
        _max_dataset_len_threshold: int = 1000,
    ) -> MazeDataset:
        """remove duplicates from a dataset, keeping the **LAST** unique maze

        set either minimum difference to `None` to disable that check

        if you want to exclude mazes which have more overlap, increase the minimum difference

        Gotchas:
        - if two mazes are of different sizes, they will never be considered duplicates
        - if two solutions are of different lengths, they will never be considered duplicates
          TODO: check for overlap?
        """
        if len(dataset) > _max_dataset_len_threshold:
            raise ValueError(
                "this method is currently very slow for large datasets, consider using `remove_duplicates_fast` instead\n",
                "if you know what you're doing, change `_max_dataset_len_threshold`",
            )

        unique_mazes: list[SolvedMaze] = list()

        maze_a: SolvedMaze
        maze_b: SolvedMaze
        for i, maze_a in enumerate(dataset.mazes):
            a_unique: bool = True
            for maze_b in dataset.mazes[i + 1 :]:
                # after all that nesting, more nesting to perform checks
                if (minimum_difference_connection_list is not None) and (
                    maze_a.connection_list.shape == maze_b.connection_list.shape
                ):
                    if (
                        np.sum(maze_a.connection_list != maze_b.connection_list)
                        <= minimum_difference_connection_list
                    ):
                        a_unique = False
                        break

                if (minimum_difference_solution is not None) and (
                    maze_a.solution.shape == maze_b.solution.shape
                ):
                    if (
                        np.sum(maze_a.solution != maze_b.solution)
                        <= minimum_difference_solution
                    ):
                        a_unique = False
                        break

            if a_unique:
                unique_mazes.append(maze_a)

        return copy.deepcopy(
            MazeDataset(
                cfg=dataset.cfg,
                mazes=unique_mazes,
                generation_metadata_collected=dataset.generation_metadata_collected,
            )
        )

    @register_dataset_filter
    @staticmethod
    def remove_duplicates_fast(dataset: MazeDataset) -> MazeDataset:
        """remove exact duplicates from a dataset, keeping the first occurrence"""
        unique_mazes = list(dict.fromkeys(dataset.mazes))
        return copy.deepcopy(
            MazeDataset(
                cfg=dataset.cfg,
                mazes=unique_mazes,
                generation_metadata_collected=dataset.generation_metadata_collected,
            )
        )

    @register_dataset_filter
    @staticmethod
    def strip_generation_meta(dataset: MazeDataset) -> MazeDataset:
        """strip the generation meta from the dataset"""
        new_dataset: MazeDataset = copy.deepcopy(dataset)
        for maze in new_dataset:
            # hacky, because it's a frozen dataclass
            maze.__dict__["generation_meta"] = None
        return new_dataset

    @register_dataset_filter
    @staticmethod
    def collect_generation_meta(
        dataset: MazeDataset,
        clear_in_mazes: bool = True,
        inplace: bool = True,
        allow_fail: bool = False,
    ) -> MazeDataset:
        """collect the generation metadata from each maze into `dataset.generation_metadata_collected`"""
        # if the generation meta is already collected, don't collect it again; do nothing
        if dataset.generation_metadata_collected is not None:
            return dataset
        else:
            assert dataset[0].generation_meta is not None, (
                "generation meta is not collected and original is not present"
            )

        new_dataset: MazeDataset
        if inplace:
            new_dataset = dataset
        else:
            new_dataset = copy.deepcopy(dataset)

        gen_meta_lists: dict[bool | int | float | str | CoordTup, Counter] = (
            defaultdict(Counter)
        )
        for maze in new_dataset:
            if maze.generation_meta is None:
                if allow_fail:
                    break
                else:
                    raise ValueError(
                        "generation meta is not present in a maze, cannot collect generation meta"
                    )
            for key, value in maze.generation_meta.items():
                if isinstance(value, (bool, int, float, str)):
                    gen_meta_lists[key][value] += 1

                elif isinstance(value, set):
                    # special case for visited_cells
                    gen_meta_lists[key].update(value)

                elif isinstance(value, (list, np.ndarray)):
                    if isinstance(value, list):
                        try:
                            value = np.array(value)
                        except ValueError:
                            raise ValueError(
                                f"Cannot collect generation meta for {key} as it is a list of type '{str(type(value[0])) = }'",
                                "expected either a basic type (bool, int, float, str), a numpy coord, or a numpy array of coords",
                            )

                    if (len(value.shape) == 1) and (value.shape[0] == maze.lattice_dim):
                        # assume it's a single coordinate
                        gen_meta_lists[key][tuple(value)] += 1
                    elif (len(value.shape) == 2) and (
                        value.shape[1] == maze.lattice_dim
                    ):
                        # assume it's a list of coordinates
                        gen_meta_lists[key].update([tuple(v) for v in value])
                    else:
                        raise ValueError(
                            f"Cannot collect generation meta for {key} as it is an ndarray of shape {value.shape}",
                            "expected either a coord of shape (2,) or a list of coords of shape (n, 2)",
                        )
                else:
                    raise ValueError(
                        f"Cannot collect generation meta for {key} as it is of type '{str(type(value))}'",
                        "expected either a basic type (bool, int, float, str), a numpy coord, or a numpy array of coords",
                    )

            # clear the data
            if clear_in_mazes:
                # hacky, because it's a frozen dataclass
                maze.__dict__["generation_meta"] = None

        new_dataset.generation_metadata_collected = {
            key: dict(value) for key, value in gen_meta_lists.items()
        }

        return new_dataset

        # the code below is for doing some smarter collecting and type checking. Probably will delete.
        """
        collect either the type at the field, or the shape of the field if it is an array

        metadata_types: dict[str, set[type, tuple]] = dict()
        for maze in new_dataset:
            for key, value in maze.generation_meta.items():
                if key not in metadata_types:
                    metadata_types[key] = set()

                if isinstance(value, np.ndarray):
                    metadata_types[key].add(value.shape)
                else:
                    metadata_types[key].add(type(value))

        # figure out what to do for each field
        metadata_actions: dict[str, typing.Callable] = dict()
        for key, key_type in metadata_types.items():
            if all(isinstance(kt, tuple) for kt in key_type):
                if all(kt == (2,) for kt in key_type):
                    # it's all coords, do a statcounter on those coords
                    metadata_actions[key] = lambda vals: Counter(tuple(x) for x in vals)
                elif all(
                    (len(kt) == 2) and (kt[1] == 2)
                    for kt in key_type
                ):
                    # it's a list of coords, do a statcounter on those coords
                    metadata_actions[key] = lambda vals: Counter(
                        tuple(x) for x in np.concatenate(vals)
                    )
                else:
                    # it's a list of something else, do a counter on those
                    # TODO: throw an exception here?
                    metadata_actions[key] = Counter

            elif all(kt in (bool, int, float) for kt in key_type):
                # statcounter for numeric types
                metadata_actions[key] = StatCounter
            elif all(kt == str for kt in key_type):
                # counter for string types
                metadata_actions[key] = Counter
            else:
                # counter for everything else
                # TODO: throw an exception here?
                metadata_actions[key] = Counter
        """