Coverage for maze_dataset\dataset\dataset.py: 39%
193 statements
coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
"""`GPTDatasetConfig` and `GPTDataset` are base classes for datasets
they implement some basic functionality: saving/loading, the `from_config` pipeline, and filtering

> [!NOTE]
> these should probably be moved into a different package, so don't rely on them being here
"""

import functools
import json
import typing
import warnings
from pathlib import Path
from typing import Callable, Type

import numpy as np
import torch
from muutils.json_serialize import (
    JSONitem,
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)
from muutils.misc import sanitize_fname, shorten_numerical_to_str, stable_hash
from muutils.mlutils import DEFAULT_SEED, GLOBAL_SEED, set_reproducibility
from muutils.tensor_utils import DTYPE_MAP
from torch.utils.data import Dataset
from zanj import ZANJ


class FilterInfoMismatchError(ValueError):
    """raised when the filter info in a dataset config does not match the filter info in the dataset"""

    pass


def _dtype_serialization_fn(datatype: torch.dtype | np.dtype) -> str:
    """convert a torch or numpy dtype to a string, checking that the conversion is reversible"""
    x_str: str = str(datatype)
    assert x_str in DTYPE_MAP, f"unknown dtype {datatype}"
    assert DTYPE_MAP[x_str] == datatype
    return x_str


def _load_applied_filters(
    filters: list[dict[typing.Literal["name", "args", "kwargs"], str | list | dict]],
) -> list[dict[typing.Literal["name", "args", "kwargs"], str | list | dict]]:
    try:
        return [
            dict(
                name=filter_info["name"],
                # muutils/zanj saves tuples as lists, and this causes problems, so convert back
                args=tuple(filter_info["args"]),
                kwargs=dict(filter_info["kwargs"]),
            )
            for filter_info in filters
        ]
    except Exception as e:
        raise ValueError(f"failed to load applied filters:\n{filters}") from e
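
# illustrative example (not part of the original module): a JSON round trip turns the
# `args` tuple of a filter entry into a list, and `_load_applied_filters` converts it back
# so that later equality checks against freshly-applied filters succeed.
# >>> entry = dict(name="some_filter", args=(50,), kwargs={})
# >>> json.loads(json.dumps(entry))["args"]
# [50]
# >>> _load_applied_filters([json.loads(json.dumps(entry))])[0]["args"]
# (50,)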


@serializable_dataclass(kw_only=True)
class GPTDatasetConfig(SerializableDataclass):
    """base GPTDatasetConfig class"""

    name: str

    # TODO: get rid of all these things as part of migration to tokenizer-free dataset config
    # --------------------------------------------------
    seq_len_min: int = serializable_field(default=1)
    seq_len_max: int = serializable_field(default=512)
    # --------------------------------------------------

    seed: int | None = serializable_field(default=DEFAULT_SEED)
    applied_filters: list[
        dict[typing.Literal["name", "args", "kwargs"], str | list | dict]
    ] = serializable_field(
        default_factory=list,
        deserialize_fn=_load_applied_filters,
        assert_type=False,  # TODO: check the type here once muutils supports checking Callable signatures
    )

    def __post_init__(self):
        assert self.seq_len_min <= self.seq_len_max
        # if seed set to None, then generate a new random seed
        if self.seed is None:
            self.seed = torch.random.seed() % 2**31

        if (DEFAULT_SEED != self.seed) and (GLOBAL_SEED != self.seed):
            warnings.warn(
                f"in GPTDatasetConfig {self.name=}, {self.seed=} is trying to override {GLOBAL_SEED=} which has already been changed elsewhere from {DEFAULT_SEED=}"
            )

        set_reproducibility(self.seed)

    def summary(self) -> dict:
        """return a summary of the config"""
        # serialize the config here to make sure it doesn't error
        self_ser: dict = self.serialize()
        assert self_ser
        return dict(
            name=self.name,
            seq_len_min=self.seq_len_min,
            seq_len_max=self.seq_len_max,
            seed=self.seed,
            applied_filters=self.applied_filters,
        )

    @classmethod
    @property
    def _dataset_class(cls) -> type:
        raise NotImplementedError("this should be implemented by subclasses!")

    def to_fname(self) -> str:
        """convert config to a filename"""
        self_json_str: str = json.dumps(self.serialize())
        self_json_hash: int = int(abs(stable_hash(self_json_str)) % 1e10)
        warnings.warn(
            f"using fallback to_fname() method for {self.__class__.__name__}, this should be implemented by subclasses!"
        )
        return sanitize_fname(
            f"f{self.name}-n{shorten_numerical_to_str(len(self))}-h{self_json_hash}"
        )


def _dataset_config_load(*args, **kwargs) -> "GPTDatasetConfig":
    raise NotImplementedError(
        f"this `load` function should be implemented by subclasses! got: {args=}, {kwargs=}"
    )


def _dataset_config_serialize(self, *args, **kwargs) -> JSONitem:
    raise NotImplementedError(
        f"this `serialize` function should be implemented by subclasses! got: {args=}, {kwargs=}"
    )


GPTDatasetConfig.load = _dataset_config_load
GPTDatasetConfig.serialize = _dataset_config_serialize


class GPTDataset(Dataset):
    """wrapper for a torch dataset with some extra functionality

    (meaning the functionality should be inherited in downstream classes)

    > [!NOTE]
    > `GPTDatasetConfig` should implement a `to_fname` method that returns a unique filename for the config

    # Requires:
    the following methods should be implemented in subclasses:
    - `__init__(self, cfg: GPTDatasetConfig, **kwargs)`
        initialize the dataset from a given config. kwargs are not passed through; the kwargs should take the actual generated or loaded data (a list of objects or sequences, probably)
    - `generate(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
        generate the dataset from a given config. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. how many threads to use for generation)
    - `serialize(self) -> JSONitem`
        serialize the dataset to a ZANJ-serializable object, including:
        - config
        - data in formats specified by `self.save_formats`
    - `load(cls, data: JSONitem) -> GPTDataset`
        load the dataset from a ZANJ-serializable object
    - `download(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
        given a config, try to download a dataset from some source. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. some kind of auth token or source url)
    - `__len__(self) -> int`
        return the length of the dataset, required for `torch.utils.data.Dataset`
    - `__getitem__(self, i: int) -> list[str]`
        return the ith item in the dataset, required for `torch.utils.data.Dataset`
    - `update_self_config(self) -> None`
        update the config of the dataset to match the current state of the dataset, used primarily in filtering and validation
    - decorating the appropriate filter namespace with `register_filter_namespace_for_dataset(your_dataset_class)` if you want to use filters

    # Parameters:
    - `cfg : GPTDatasetConfig`
        config for the dataset, used to generate the dataset
    - `do_generate : bool`
        whether to generate the dataset if it isn't found
        (defaults to `True`)
    - `load_local : bool`
        whether to try finding the dataset locally
        (defaults to `True`)
    - `save_local : bool`
        whether to save the dataset locally if it is generated or downloaded
        (defaults to `True`)
    - `do_download : bool`
        whether to try downloading the dataset
        (defaults to `True`)
    - `local_base_path : Path`
        where to save the dataset
        (defaults to `Path("data/maze_dataset")`)

    # Returns:
    - `GPTDataset`
        the dataset, as you wanted it

    # Implements:
    - `save(self, file_path: str) -> None`
        save the dataset to a file, using ZANJ
    - `read(cls, file_path: str) -> GPTDataset`
        read the dataset from a file, using ZANJ
    - `filter_by(self)`
        returns a namespace class
    - `_filter_namespace(self) -> Class`
        returns a namespace class for filtering the dataset, checking that the requested filter method exists
    - `_apply_filters_from_config(self) -> None`
        apply filters to the dataset, as specified in the config. used in `from_config()` but only when generating
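
    # Usage:
    a minimal sketch (hypothetical `MyMazeDataset` subclass and `MyMazeDatasetConfig`, not defined in this module):

    >>> cfg = MyMazeDatasetConfig(name="example")
    >>> dataset = MyMazeDataset.from_config(
    ...     cfg,
    ...     load_local=True,  # first try reading `data/maze_dataset/<cfg.to_fname()>.zanj`
    ...     do_download=False,  # skip the download step
    ...     do_generate=True,  # otherwise generate from the config and save it locally
    ... )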
    """

    _FILTER_NAMESPACE: type = "this isn't a filter namespace! you have to initialize this by registering with `register_filter_namespace_for_dataset`"  # type: ignore

    @classmethod
    def from_config(
        cls,
        cfg: GPTDatasetConfig,
        do_generate: bool = True,
        load_local: bool = True,
        save_local: bool = True,
        zanj: ZANJ | None = None,
        do_download: bool = True,
        local_base_path: Path = Path("data/maze_dataset"),
        except_on_config_mismatch: bool = True,
        allow_generation_metadata_filter_mismatch: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> "GPTDataset":
        """create a dataset from a config, with the following priority for loading:

        1. load from local
        2. download
        3. generate
        """

        print_log: Callable = print if verbose else lambda *_a, **_kw: None

        local_base_path = Path(local_base_path)
        fname: Path = Path(f"{cfg.to_fname()}.zanj")
        output: GPTDataset | None = None
        did_load_local: bool = False
        if zanj is None:
            zanj = ZANJ()

        print_log(f"trying to get the dataset '{cfg.to_fname()}'")

        if not (load_local or do_download or do_generate):
            raise ValueError(
                "no way to load dataset! you said not to load local, not to download, and not to generate"
            )

        dataset_path: Path = local_base_path / fname

        # try loading
        if load_local:
            if dataset_path.exists():
                print_log(f"loading dataset from {dataset_path.as_posix()}")
                try:
                    output = cls.read(dataset_path, zanj=zanj)
                    did_load_local = True
                    print_log("load successful!")
                except Exception as e:
                    print_log(f"failed to load dataset: {e}")

        if do_download and output is None:
            print_log("seeing if we can download the dataset...")
            try:
                output = cls.download(cfg, **kwargs)
                print_log("download successful!")
            except NotImplementedError:
                print_log("no download found, or download failed")

        if do_generate and output is None:
            print_log("generating dataset...")
            output = cls.generate(cfg, verbose=verbose, **kwargs)
            # apply filters only if we generated the dataset
            output = output._apply_filters_from_config()

        # check and save
        if output is None:
            raise ValueError("failed to load dataset!")

        cfg_diff: dict = cfg.diff(output.cfg, of_serialized=True)
        if cfg_diff:
            if except_on_config_mismatch:
                if allow_generation_metadata_filter_mismatch and (
                    cfg_diff
                    == {
                        "applied_filters": {
                            "self": [],
                            "other": [
                                {
                                    "name": "collect_generation_meta",
                                    "args": (),
                                    "kwargs": {},
                                }
                            ],
                        }
                    }
                ):
                    pass
                else:
                    raise ValueError(f"config mismatch: {cfg_diff = }")
            else:
                warnings.warn(f"config mismatch: {cfg_diff = }")

        if save_local and not did_load_local:
            print_log(f"saving dataset to {dataset_path}")
            output.save(dataset_path, zanj=zanj)

        print_log(
            f"Got dataset {output.cfg.name} with {len(output)} items. {output.cfg.to_fname() = }"
        )
        return output

    def save(self, file_path: Path | str, zanj: ZANJ | None = None):
        if zanj is None:
            zanj = ZANJ()
        zanj.save(self.serialize(), file_path)

    # serialization & loading
    @classmethod
    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "GPTDataset":
        if zanj is None:
            zanj = ZANJ()
        return zanj.read(file_path)
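
    # illustrative sketch (hypothetical `MyMazeDataset` subclass): `save` and `read` are
    # inverses via ZANJ, so a round trip through disk should preserve the dataset.
    # >>> dataset.save("data/maze_dataset/example.zanj")
    # >>> reloaded = MyMazeDataset.read("data/maze_dataset/example.zanj")
    # >>> assert len(reloaded) == len(dataset)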

    def serialize(self) -> JSONitem:
        raise NotImplementedError()

    def data_hash(self) -> int:
        raise NotImplementedError()

    @classmethod
    def load(cls, data: JSONitem) -> "GPTDataset":
        raise NotImplementedError()

    # generating & downloading
    @classmethod
    def generate(cls, cfg: GPTDatasetConfig, **kwargs) -> "GPTDataset":
        raise NotImplementedError()

    @classmethod
    def download(cls, cfg: GPTDatasetConfig, **kwargs) -> "GPTDataset":
        raise NotImplementedError()

    # filtering
    def update_self_config(self):
        """update the config of the dataset to match the actual data, if needed

        for example, adjust number of mazes after filtering
        """
        pass

    class FilterBy:
        """thanks GPT-4"""

        def __init__(self, dataset: "GPTDataset"):
            self.dataset: "GPTDataset" = dataset

        def __getattr__(self, name: str) -> typing.Callable[..., "GPTDataset"]:
            filter_func: DatasetFilterProtocol = getattr(
                self.dataset._FILTER_NAMESPACE, name
            )

            def wrapped_filter_func(*args, **kwargs):
                return filter_func(self.dataset, *args, **kwargs)

            return wrapped_filter_func

    @property
    def filter_by(self) -> "FilterBy":
        return self.FilterBy(self)
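
    # illustrative sketch (hypothetical filter name): `filter_by` forwards attribute access
    # to the registered filter namespace, passing the dataset as the first argument.
    # >>> filtered = dataset.filter_by.path_length(min_length=3)
    # which is equivalent to calling the registered filter directly:
    # >>> filtered = dataset._FILTER_NAMESPACE.path_length(dataset, min_length=3)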

    def _apply_filters_from_config(self) -> "GPTDataset":
        """apply filters to the dataset, as specified in the config. used in `from_config()`"""
        output: GPTDataset = self
        # copy the list, and then clear it in the config. we do this because each time we apply a filter it will update config.applied_filters
        applied_filters_old: list[
            dict[typing.Literal["name", "args", "kwargs"], typing.Any]
        ] = self.cfg.applied_filters
        output.cfg.applied_filters = list()
        # apply the filters
        for filter_info in applied_filters_old:
            filter_name: str = filter_info["name"]
            if filter_name not in output._FILTER_NAMESPACE.__dict__:
                if filter_name.startswith("__custom__:"):
                    raise ValueError(
                        f"the dataset {output.cfg.to_fname()} was filtered using a custom filter: '{filter_name}', which we don't know about. add it to MazeDatasetFilters!"
                    )
                else:
                    raise ValueError(
                        f"the dataset {output.cfg.to_fname()} was filtered using an unknown filter: '{filter_name}'"
                    )
            filter_args: list = filter_info["args"] if "args" in filter_info else list()
            filter_kwargs: dict = (
                filter_info["kwargs"] if "kwargs" in filter_info else dict()
            )
            output = getattr(output.filter_by, filter_name)(
                *filter_args, **filter_kwargs
            )

        # update the config, perform checks
        # TODO: some funny business with manually specified filters here?
        output.update_self_config()
        _check_filter_equality(applied_filters_old, output.cfg.applied_filters)
        return output


def _check_filter_equality(
    filters_old: list[
        dict[typing.Literal["name", "args", "kwargs"], str | list | dict]
    ],
    filters_new: list[
        dict[typing.Literal["name", "args", "kwargs"], str | list | dict]
    ],
) -> None:
    """check that two lists of applied-filter info match, raising `FilterInfoMismatchError` if not"""
    try:
        assert len(filters_old) == len(filters_new)

        for filterinfo_new, filterinfo_old in zip(filters_new, filters_old):
            # basic checks
            assert isinstance(filterinfo_new, dict), "filterinfo_new is not a dict"
            assert isinstance(filterinfo_old, dict), "filterinfo_old is not a dict"
            assert all(key in filterinfo_new for key in ["name", "args", "kwargs"]), (
                "missing keys in filterinfo_new"
            )
            assert all(key in filterinfo_old for key in ["name", "args", "kwargs"]), (
                "missing keys in filterinfo_old"
            )

            # name
            assert filterinfo_new["name"] == filterinfo_old["name"], (
                "filter names don't match"
            )

            # args
            assert len(filterinfo_new["args"]) == len(filterinfo_old["args"]), (
                "filter args of different lengths"
            )
            for arg_new, arg_old in zip(filterinfo_new["args"], filterinfo_old["args"]):
                assert arg_new == arg_old, "filter args don't match"

            # kwargs
            assert len(filterinfo_new["kwargs"]) == len(filterinfo_old["kwargs"]), (
                "filter kwargs of different lengths"
            )
            for key in filterinfo_old["kwargs"]:
                assert key in filterinfo_new["kwargs"], (
                    f"filter kwargs don't match: missing key '{key}'"
                )
                assert filterinfo_new["kwargs"][key] == filterinfo_old["kwargs"][key], (
                    f"filter kwargs don't match: values for key '{key}' don't match"
                )

    except AssertionError as e:
        raise FilterInfoMismatchError(
            f"config mismatch in applied filters: {filters_new} != {filters_old}"
        ) from e


def register_filter_namespace_for_dataset(
    dataset_cls: Type[GPTDataset],
) -> Callable[[Type], Type]:
    """register the namespace class with the given dataset class"""

    def decorator(filter_namespace_cls: Type) -> Type:
        dataset_cls._FILTER_NAMESPACE = filter_namespace_cls
        filter_namespace_cls._BASE_DATASET = dataset_cls

        return filter_namespace_cls

    return decorator
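
# illustrative sketch (hypothetical `MyMazeDataset` and filter, not part of the original module):
# the namespace class is decorated once, after which `MyMazeDataset.filter_by.<name>` resolves
# to the static methods defined on it.
#
# @register_filter_namespace_for_dataset(MyMazeDataset)
# class MyMazeDatasetFilters:
#     @register_dataset_filter
#     @staticmethod
#     def path_length(dataset: MyMazeDataset, min_length: int) -> MyMazeDataset:
#         ...  # return a *new* filtered dataset, not the original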


class DatasetFilterProtocol(typing.Protocol):
    def __call__(
        self,
        dataset: GPTDataset,
        **kwargs,
    ) -> GPTDataset: ...


def register_dataset_filter(
    method: DatasetFilterProtocol,
) -> DatasetFilterProtocol:
    """register a dataset filter, copying the underlying dataset and updating the config

    be sure to return a COPY, not the original

    method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`
    """

    @functools.wraps(method)
    def wrapper(dataset: GPTDataset, *args, **kwargs):
        new_dataset = method(dataset, *args, **kwargs)
        # update the config to record that this filter was applied
        new_dataset.cfg.applied_filters.append(
            dict(name=method.__name__, args=args, kwargs=kwargs)
        )
        new_dataset.update_self_config()
        return new_dataset

    return wrapper
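
# illustrative example (hypothetical filter name and dataset): after a registered filter runs,
# the wrapper appends a record of the call to `cfg.applied_filters`, which is how
# `_apply_filters_from_config` later knows to re-apply it when the dataset is regenerated.
# >>> filtered = dataset.filter_by.path_length(min_length=3)
# >>> filtered.cfg.applied_filters[-1]
# {'name': 'path_length', 'args': (), 'kwargs': {'min_length': 3}}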