Coverage for maze_dataset\dataset\collected_dataset.py: 33%
87 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
"""collecting different maze datasets into a single dataset, for greater variety in a training or validation set

> [!CAUTION]
> `MazeDatasetCollection` is not thoroughly tested and is not guaranteed to work.
"""
8import itertools
9import json
10import typing
11from functools import cached_property
13import numpy as np
14from jaxtyping import Int
15from muutils.json_serialize import (
16 json_serialize,
17 serializable_dataclass,
18 serializable_field,
19)
20from muutils.json_serialize.util import JSONdict
21from muutils.json_serialize.util import _FORMAT_KEY
22from muutils.misc import sanitize_fname, shorten_numerical_to_str, stable_hash
23from zanj.loading import LoaderHandler, load_item_recursive, register_loader_handler
25from maze_dataset.constants import Coord, CoordTup
26from maze_dataset.dataset.dataset import GPTDataset, GPTDatasetConfig
27from maze_dataset.dataset.maze_dataset import MazeDataset, MazeDatasetConfig
28from maze_dataset.maze import LatticeMaze
@serializable_dataclass(kw_only=True)
class MazeDatasetCollectionConfig(GPTDatasetConfig):
    """maze dataset collection configuration, including tokenizers and shuffle"""

    # configs of the individual datasets that make up the collection,
    # serialized/loaded element-wise via `MazeDatasetConfig`
    maze_dataset_configs: list[MazeDatasetConfig] = serializable_field(
        serialization_fn=lambda configs: [config.serialize() for config in configs],
        loading_fn=lambda data: [
            MazeDatasetConfig.load(config) for config in data["maze_dataset_configs"]
        ],
    )

    def summary(self) -> list[dict]:
        """return a list of summaries, one per sub-dataset config

        NOTE(review): the previous annotation said `-> dict`, but this method
        has always returned a list of per-config summaries; the annotation is
        fixed here to match the actual behavior.
        """
        return [c.summary() for c in self.maze_dataset_configs]

    @property
    def n_mazes(self) -> int:
        """total number of mazes summed across all sub-dataset configs"""
        return sum(config.n_mazes for config in self.maze_dataset_configs)

    @property
    def max_grid_n(self) -> int:
        """largest grid side length among the sub-dataset configs"""
        return max(config.grid_n for config in self.maze_dataset_configs)

    @property
    def max_grid_shape(self) -> CoordTup:
        """square (rows, cols) shape large enough to hold any maze in the collection"""
        return (self.max_grid_n, self.max_grid_n)

    @property
    def max_grid_shape_np(self) -> Coord:
        """`max_grid_shape` as an int32 numpy array"""
        return np.array(self.max_grid_shape, dtype=np.int32)

    def stable_hash_cfg(self) -> int:
        """stable hash of the fully serialized config (deterministic across runs)"""
        return stable_hash(json.dumps(self.serialize()))

    def to_fname(self) -> str:
        """convert config to a filename"""
        return sanitize_fname(
            f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}"
        )
class MazeDatasetCollection(GPTDataset):
    """a collection of maze datasets

    Wraps a list of `MazeDataset`s (one per entry in
    `MazeDatasetCollectionConfig.maze_dataset_configs`) and exposes them as a
    single flat, indexable dataset; indexing is routed to the correct
    sub-dataset via cumulative lengths.
    """

    def __init__(
        self,
        cfg: MazeDatasetCollectionConfig,
        maze_datasets: list[MazeDataset],
        generation_metadata_collected: dict | None = None,
    ) -> None:
        """store config and datasets, checking each dataset matches its config

        # Parameters:
        - `cfg`: collection-level config whose `maze_dataset_configs` must
          match `maze_datasets` pairwise, in order
        - `maze_datasets`: the underlying datasets (copied into a new list)
        - `generation_metadata_collected`: optional pre-collected metadata
        """
        super().__init__()
        self.cfg: MazeDatasetCollectionConfig = cfg
        self.maze_datasets: list[MazeDataset] = list(maze_datasets)
        # each dataset must correspond to its config, pairwise and in order
        # (extra items on either side are silently ignored by `zip`)
        for c, ds in zip(self.cfg.maze_dataset_configs, self.maze_datasets):
            assert c.name == ds.cfg.name
            assert c == ds.cfg

        self.generation_metadata_collected: dict | None = generation_metadata_collected

    @property
    def dataset_lengths(self) -> list[int]:
        """length of each sub-dataset, in order"""
        return [len(dataset) for dataset in self.maze_datasets]

    @property
    def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
        """cumulative sums of the sub-dataset lengths (used by `__getitem__`)"""
        return np.array(list(itertools.accumulate(self.dataset_lengths)))

    @cached_property
    def mazes(self) -> list[LatticeMaze]:
        """all mazes from all sub-datasets, concatenated in dataset order

        cached on first access -- assumes the sub-datasets do not change afterwards
        """
        return list(
            itertools.chain.from_iterable(
                dataset.mazes for dataset in self.maze_datasets
            )
        )

    def __len__(self) -> int:
        """total number of mazes across all sub-datasets"""
        return sum(len(dataset) for dataset in self.maze_datasets)

    def __getitem__(self, index: int):
        """get the item at `index` in the flattened collection"""
        # find which dataset the index belongs to:
        # `searchsorted` (side='left') returns the insertion point, i.e. the
        # number of cumulative-length entries strictly less than `index + 1` --
        # equivalently, the index of the first dataset whose cumulative length
        # exceeds `index`
        dataset_idx: int = np.searchsorted(self.dataset_cum_lengths, index + 1)
        index_adjusted: int = index
        if dataset_idx > 0:
            # subtract the combined length of all preceding datasets;
            # for the first dataset (dataset_idx == 0) the raw index is already correct
            index_adjusted -= self.dataset_cum_lengths[dataset_idx - 1]
        return self.maze_datasets[dataset_idx][index_adjusted]

    @classmethod
    def generate(
        cls, cfg: MazeDatasetCollectionConfig, **kwargs
    ) -> "MazeDatasetCollection":
        """generate each sub-dataset from its config and assemble the collection

        `kwargs` are forwarded unchanged to every `MazeDataset.generate` call
        """
        datasets = [
            MazeDataset.generate(config, **kwargs)
            for config in cfg.maze_dataset_configs
        ]
        return cls(cfg, datasets)

    @classmethod
    def download(
        cls, cfg: MazeDatasetCollectionConfig, **kwargs
    ) -> "MazeDatasetCollection":
        """download each sub-dataset and assemble the collection

        `kwargs` are forwarded unchanged to every `MazeDataset.download` call
        """
        datasets = [
            MazeDataset.download(config, **kwargs)
            for config in cfg.maze_dataset_configs
        ]
        return cls(cfg, datasets)

    def serialize(self) -> JSONdict:
        """serialize to a JSON-compatible dict (format tag, config, datasets, metadata)"""
        return {
            _FORMAT_KEY: "MazeDatasetCollection",
            "cfg": self.cfg.serialize(),
            "maze_datasets": [dataset.serialize() for dataset in self.maze_datasets],
            "generation_metadata_collected": json_serialize(
                self.generation_metadata_collected
            ),
        }

    @classmethod
    def load(cls, data: JSONdict) -> "MazeDatasetCollection":
        """inverse of `serialize` -- reconstruct a collection from a JSON dict

        NOTE(review): raises `KeyError` if any of the three expected keys is
        missing; assumes `data` was produced by `serialize` -- confirm callers
        never pass partial dicts
        """
        assert data[_FORMAT_KEY] == "MazeDatasetCollection"
        return cls(
            **{
                key: load_item_recursive(data[key], tuple())
                for key in ["cfg", "maze_datasets", "generation_metadata_collected"]
            }
        )

    # TODO: remove duplication with MazeDatasetConfig().as_tokens() somehow?
    def as_tokens(
        self,
        maze_tokenizer,  # TODO: MazeTokenizer
        limit: int | None = None,
        join_tokens_individual_maze: bool = False,
    ) -> list[list[str]] | list[str]:
        """return the dataset as tokens

        if join_tokens_individual_maze is True, then the tokens of each maze are
        joined with a space, and the result is a list of strings.
        i.e.:
        >>> dataset.as_tokens(join_tokens_individual_maze=False)
        [["a", "b", "c"], ["d", "e", "f"]]
        >>> dataset.as_tokens(join_tokens_individual_maze=True)
        ["a b c", "d e f"]
        """
        # note: `self.mazes[:limit]` with `limit=None` takes all mazes
        output: list[list[str]] = [
            maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
        ]
        if join_tokens_individual_maze:
            return [" ".join(tokens) for tokens in output]
        else:
            return output

    def update_self_config(self) -> None:
        """sync `cfg.n_mazes` with the actual length, and recurse into sub-datasets"""
        # TODO: why cant we set this directly? its not frozen, and it seems to work in a regular MazeDataset
        # (writing through `__dict__` bypasses whatever the dataclass machinery
        # does on attribute assignment -- confirm against `serializable_dataclass`)
        self.cfg.__dict__["n_mazes"] = len(self)
        for dataset in self.maze_datasets:
            dataset.update_self_config()
# point the config class at the dataset class it builds
# (presumably consumed by the generic `GPTDataset` loading machinery -- TODO confirm)
MazeDatasetCollectionConfig._dataset_class = MazeDatasetCollection

# register with zanj so serialized collections can be loaded back transparently
register_loader_handler(
    LoaderHandler(
        # match any mapping whose format key starts with "MazeDatasetCollection"
        check=lambda json_item, path=None, z=None: (
            isinstance(json_item, typing.Mapping)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith("MazeDatasetCollection")
        ),
        load=lambda json_item, path=None, z=None: MazeDatasetCollection.load(json_item),
        uid="MazeDatasetCollection",
        # fixed stale module path: this module lives at
        # `maze_dataset.dataset.collected_dataset`, not under `maze_dataset.generation`
        source_pckg="maze_dataset.dataset.collected_dataset",
        desc="MazeDatasetCollection",
    )
)