Coverage for maze_dataset\dataset\dataset.py: 39%

193 statements  

coverage.py v7.6.12, created at 2025-02-23 12:49 -0700

"""`GPTDatasetConfig` and `GPTDataset` are base classes for datasets

they implement some basic functionality, saving/loading, the `from_config` pipeline, and filtering

> [!NOTE]
> these should probably be moved into a different package, so don't rely on them being here
"""

import functools
import json
import typing
import warnings
from pathlib import Path
from typing import Callable, Type

import numpy as np
import torch
from muutils.json_serialize import (
    JSONitem,
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)
from muutils.misc import sanitize_fname, shorten_numerical_to_str, stable_hash
from muutils.mlutils import DEFAULT_SEED, GLOBAL_SEED, set_reproducibility
from muutils.tensor_utils import DTYPE_MAP
from torch.utils.data import Dataset
from zanj import ZANJ


class FilterInfoMismatchError(ValueError):
    """raised when the filter info in a dataset config does not match the filter info in the dataset"""

    pass


def _dtype_serialization_fn(datatype: torch.dtype | np.dtype) -> str:
    """convert torch dtype to string, while checking that the conversion is reversible"""
    x_str: str = str(datatype)
    assert x_str in DTYPE_MAP, f"unknown dtype {datatype}"
    assert DTYPE_MAP[x_str] == datatype
    return x_str

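# illustrative example (not part of the original source): since str(torch.float32) == "torch.float32",
# _dtype_serialization_fn(torch.float32) would return "torch.float32", provided that string is a key
# of muutils' DTYPE_MAP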

def _load_applied_filters(
    filters: list[dict[typing.Literal["name", "args", "kwargs"], str | list | dict]],
) -> list[dict[typing.Literal["name", "args", "kwargs"], str | list | dict]]:
    try:
        return [
            dict(
                name=filter_info["name"],
                args=tuple(
                    filter_info["args"]
                ),  # muutils/zanj save tuples as lists, and this causes problems
                kwargs=dict(filter_info["kwargs"]),
            )
            for filter_info in filters
        ]
    except Exception as e:
        raise ValueError(f"failed to load applied filters:\n{filters}") from e

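# illustrative example (hypothetical filter record, not from the original source): a record whose
# `args` came back from zanj as a list is normalized so that `args` is a tuple again, e.g.
#   _load_applied_filters([{"name": "some_filter", "args": [5], "kwargs": {}}])
#       == [{"name": "some_filter", "args": (5,), "kwargs": {}}]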

@serializable_dataclass(kw_only=True)
class GPTDatasetConfig(SerializableDataclass):
    """base GPTDatasetConfig class"""

    name: str

    # TODO: get rid of all these things as part of migration to tokenizer-free dataset config
    # --------------------------------------------------
    seq_len_min: int = serializable_field(default=1)
    seq_len_max: int = serializable_field(default=512)
    # --------------------------------------------------

    seed: int | None = serializable_field(default=DEFAULT_SEED)
    applied_filters: list[
        dict[typing.Literal["name", "args", "kwargs"], str | list | dict]
    ] = serializable_field(
        default_factory=list,
        deserialize_fn=_load_applied_filters,
        assert_type=False,  # TODO: check the type here once muutils supports checking Callable signatures
    )

    def __post_init__(self):
        assert self.seq_len_min <= self.seq_len_max
        # if seed set to None, then generate a new random seed
        if self.seed is None:
            self.seed = torch.random.seed() % 2**31

        if (DEFAULT_SEED != self.seed) and (GLOBAL_SEED != self.seed):
            warnings.warn(
                f"in GPTDatasetConfig {self.name=}, {self.seed=} is trying to override {GLOBAL_SEED=} which has already been changed elsewhere from {DEFAULT_SEED=}"
            )

        set_reproducibility(self.seed)

    def summary(self) -> dict:
        """return a summary of the config"""
        # run serialization here just to make sure it doesn't error
        self_ser: dict = self.serialize()
        assert self_ser
        return dict(
            name=self.name,
            seq_len_min=self.seq_len_min,
            seq_len_max=self.seq_len_max,
            seed=self.seed,
            applied_filters=self.applied_filters,
        )

    @classmethod
    @property
    def _dataset_class(cls) -> type:
        raise NotImplementedError("this should be implemented by subclasses!")

    def to_fname(self) -> str:
        """convert config to a filename"""
        self_json_str: str = json.dumps(self.serialize())
        self_json_hash: int = int(abs(stable_hash(self_json_str)) % 1e10)
        warnings.warn(
            f"using fallback to_fname() method for {self.__class__.__name__}, this should be implemented by subclasses!"
        )
        return sanitize_fname(
            f"f{self.name}-n{shorten_numerical_to_str(len(self))}-h{self_json_hash}"
        )


def _dataset_config_load(*args, **kwargs) -> "GPTDatasetConfig":
    raise NotImplementedError(
        f"this `load` function should be implemented by subclasses! got: {args=}, {kwargs=}"
    )


def _dataset_config_serialize(self, *args, **kwargs) -> JSONitem:
    raise NotImplementedError(
        f"this `serialize` function should be implemented by subclasses! got: {args=}, {kwargs=}"
    )


GPTDatasetConfig.load = _dataset_config_load
GPTDatasetConfig.serialize = _dataset_config_serialize
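# note: subclasses decorated with `@serializable_dataclass` get their own generated
# `serialize` and `load`, so these raising stubs should only be reachable if the base
# `GPTDatasetConfig` is used directly (assumption about muutils' decorator, not stated in this file)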


class GPTDataset(Dataset):
    """wrapper for torch dataset with some extra functionality

    (meaning the functionality should be inherited in downstream classes)

    > [!NOTE]
    > `GPTDatasetConfig` should implement a `to_fname` method that returns a unique filename for the config

    # Requires:
    the following methods should be implemented in subclasses (see the usage sketch at the end of this docstring):
    - `__init__(self, cfg: GPTDatasetConfig, **kwargs)`
        initialize the dataset from a given config. kwargs are not passed through, the kwargs should take the actual generated or loaded data (a list of objects or sequences probably)
    - `generate(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
        generate the dataset from a given config. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. how many threads to use for generation)
    - `serialize(self) -> JSONitem`
        serialize the dataset to a ZANJ-serializable object, including:
        - config
        - data in formats specified by `self.save_formats`
    - `load(cls, data: JSONitem) -> GPTDataset`
        load the dataset from a ZANJ-serializable object
    - `download(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
        given a config, try to download a dataset from some source. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. some kind of auth token or source url)
    - `__len__(self) -> int`
        return the length of the dataset, required for `torch.utils.data.Dataset`
    - `__getitem__(self, i: int) -> list[str]`
        return the ith item in the dataset, required for `torch.utils.data.Dataset`
    - `update_self_config(self) -> None`
        update the config of the dataset to match the current state of the dataset, used primarily in filtering and validation
    - decorating the appropriate filter namespace with `register_filter_namespace_for_dataset(your_dataset_class)` if you want to use filters

    # Parameters:
    - `cfg : GPTDatasetConfig`
        config for the dataset, used to generate the dataset
    - `do_generate : bool`
        whether to generate the dataset if it isn't found
        (defaults to `True`)
    - `load_local : bool`
        whether to try finding the dataset locally
        (defaults to `True`)
    - `save_local : bool`
        whether to save the dataset locally if it is generated or downloaded
        (defaults to `True`)
    - `do_download : bool`
        whether to try downloading the dataset
        (defaults to `True`)
    - `local_base_path : Path`
        where to save the dataset
        (defaults to `Path("data/maze_dataset")`)

    # Returns:
    - `GPTDataset`
        the dataset, as you wanted it

    # Implements:
    - `save(self, file_path: str) -> None`
        save the dataset to a file, using ZANJ
    - `read(cls, file_path: str) -> GPTDataset`
        read the dataset from a file, using ZANJ
    - `filter_by(self)`
        returns a namespace class
    - `_filter_namespace(self) -> Class`
        returns a namespace class for filtering the dataset
    - `_apply_filters_from_config(self) -> None`
        apply filters to the dataset, as specified in the config. used in `from_config()` but only when generating

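    # Usage sketch:
    a minimal, illustrative subclass; the names below (`DemoDatasetConfig`, `DemoDataset`, `items`,
    `n_items`) are hypothetical and not part of this module, and a real subclass would also need
    `serialize`/`load`/`download` as described above:

    ```python
    @serializable_dataclass(kw_only=True)
    class DemoDatasetConfig(GPTDatasetConfig):
        n_items: int = serializable_field(default=10)

        def to_fname(self) -> str:
            return sanitize_fname(f"demo-n{self.n_items}")

    class DemoDataset(GPTDataset):
        def __init__(self, cfg: DemoDatasetConfig, items: list[str]) -> None:
            self.cfg: DemoDatasetConfig = cfg
            self.items: list[str] = items

        @classmethod
        def generate(cls, cfg: DemoDatasetConfig, **kwargs) -> "DemoDataset":
            return cls(cfg, items=[f"item_{i}" for i in range(cfg.n_items)])

        def __len__(self) -> int:
            return len(self.items)

        def __getitem__(self, i: int) -> str:
            return self.items[i]

        def update_self_config(self) -> None:
            self.cfg.n_items = len(self.items)

    dataset = DemoDataset.from_config(
        DemoDatasetConfig(name="demo"),
        load_local=False,
        do_download=False,
        save_local=False,
    )
    assert len(dataset) == 10
    ```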

    """

    _FILTER_NAMESPACE: type = "this isn't a filter namespace! you have to initialize this by registering with `register_filter_namespace_for_dataset`"  # type: ignore

    @classmethod
    def from_config(
        cls,
        cfg: GPTDatasetConfig,
        do_generate: bool = True,
        load_local: bool = True,
        save_local: bool = True,
        zanj: ZANJ | None = None,
        do_download: bool = True,
        local_base_path: Path = Path("data/maze_dataset"),
        except_on_config_mismatch: bool = True,
        allow_generation_metadata_filter_mismatch: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> "GPTDataset":
        """get the dataset specified by `cfg`, loading, downloading, or generating it as needed

        priority of loading:
        1. load from local
        2. download
        3. generate

        """

        print_log: Callable = print if verbose else lambda *_a, **_kw: None

        local_base_path = Path(local_base_path)
        fname: Path = Path(f"{cfg.to_fname()}.zanj")
        output: GPTDataset | None = None
        did_load_local: bool = False
        if zanj is None:
            zanj = ZANJ()

        print_log(f"trying to get the dataset '{cfg.to_fname()}'")

        if not (load_local or do_download or do_generate):
            raise ValueError(
                "no way to load dataset! you said not to load local, not to download, and not to generate"
            )

        dataset_path: Path = local_base_path / fname

        # try loading
        if load_local:
            if dataset_path.exists():
                print_log(f"loading dataset from {dataset_path.as_posix()}")
                try:
                    output = cls.read(dataset_path, zanj=zanj)
                    did_load_local = True
                    print_log("load successful!")
                except Exception as e:
                    print_log(f"failed to load dataset: {e}")

        if do_download and output is None:
            print_log("seeing if we can download the dataset...")
            try:
                output = cls.download(cfg, **kwargs)
                print_log("download successful!")
            except NotImplementedError:
                print_log("no download found, or download failed")

        if do_generate and output is None:
            print_log("generating dataset...")
            output = cls.generate(cfg, verbose=verbose, **kwargs)
            # only if we generated it, apply filters
            output = output._apply_filters_from_config()

        # check and save
        if output is None:
            raise ValueError("failed to load dataset!")

        cfg_diff: dict = cfg.diff(output.cfg, of_serialized=True)
        if cfg_diff:
            if except_on_config_mismatch:
                if allow_generation_metadata_filter_mismatch and (
                    cfg_diff
                    == {
                        "applied_filters": {
                            "self": [],
                            "other": [
                                {
                                    "name": "collect_generation_meta",
                                    "args": (),
                                    "kwargs": {},
                                }
                            ],
                        }
                    }
                ):
                    pass
                else:
                    raise ValueError(f"config mismatch: {cfg_diff = }")
            else:
                warnings.warn(f"config mismatch: {cfg_diff = }")

        if save_local and not did_load_local:
            print_log(f"saving dataset to {dataset_path}")
            output.save(dataset_path, zanj=zanj)

        print_log(
            f"Got dataset {output.cfg.name} with {len(output)} items. {output.cfg.to_fname() = }"
        )
        return output

    def save(self, file_path: Path | str, zanj: ZANJ | None = None):
        if zanj is None:
            zanj = ZANJ()
        zanj.save(self.serialize(), file_path)

    # serialization & loading
    @classmethod
    def read(cls, file_path: str, zanj: ZANJ | None = None) -> "GPTDataset":
        if zanj is None:
            zanj = ZANJ()
        return zanj.read(file_path)

    def serialize(self) -> JSONitem:
        raise NotImplementedError()

    def data_hash(self) -> int:
        raise NotImplementedError()

    @classmethod
    def load(cls, data: JSONitem) -> "GPTDataset":
        raise NotImplementedError()

    # generating & downloading
    @classmethod
    def generate(cls, cfg: GPTDatasetConfig, **kwargs) -> "GPTDataset":
        raise NotImplementedError()

    @classmethod
    def download(cls, cfg: GPTDatasetConfig, **kwargs) -> "GPTDataset":
        raise NotImplementedError()

    # filtering
    def update_self_config(self):
        """update the config of the dataset to match the actual data, if needed

        for example, adjust number of mazes after filtering
        """
        pass

    class FilterBy:
        """thanks GPT-4"""
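        # illustrative example (hypothetical filter name, not from the original source): if the
        # registered filter namespace defines `my_filter`, then `dataset.filter_by.my_filter(3, k=1)`
        # resolves through `__getattr__` below to `dataset._FILTER_NAMESPACE.my_filter(dataset, 3, k=1)`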

        def __init__(self, dataset: "GPTDataset"):
            self.dataset: "GPTDataset" = dataset

        def __getattr__(self, name: str) -> typing.Callable[..., "GPTDataset"]:
            filter_func: DatasetFilterProtocol = getattr(
                self.dataset._FILTER_NAMESPACE, name
            )

            def wrapped_filter_func(*args, **kwargs):
                return filter_func(self.dataset, *args, **kwargs)

            return wrapped_filter_func

    @property
    def filter_by(self) -> "FilterBy":
        return self.FilterBy(self)

    def _apply_filters_from_config(self) -> "GPTDataset":
        """apply filters to the dataset, as specified in the config. used in `from_config()`"""
        output: GPTDataset = self
        # copy the list, and then clear it in the config. we do this because each time we apply a filter it will update config.applied_filters
        applied_filters_old: list[
            dict[typing.Literal["name", "args", "kwargs"], typing.Any]
        ] = self.cfg.applied_filters
        output.cfg.applied_filters = list()
        # apply the filters
        for filter_info in applied_filters_old:
            filter_name: str = filter_info["name"]
            if filter_name not in output._FILTER_NAMESPACE.__dict__:
                if filter_name.startswith("__custom__:"):
                    raise ValueError(
                        f"the dataset {output.cfg.to_fname()} was filtered using a custom filter: '{filter_name}', which we don't know about. add it to MazeDatasetFilters!"
                    )
                else:
                    raise ValueError(
                        f"the dataset {output.cfg.to_fname()} was filtered using an unknown filter: '{filter_name}'"
                    )
            filter_args: list = filter_info["args"] if "args" in filter_info else list()
            filter_kwargs: dict = (
                filter_info["kwargs"] if "kwargs" in filter_info else dict()
            )
            output = getattr(output.filter_by, filter_name)(
                *filter_args, **filter_kwargs
            )

        # update the config, perform checks
        # TODO: some funny business with manually specified filters here?
        output.update_self_config()
        _check_filter_equality(applied_filters_old, output.cfg.applied_filters)
        return output


def _check_filter_equality(
    filters_old: list[
        dict[typing.Literal["name", "args", "kwargs"], str | list | dict]
    ],
    filters_new: list[
        dict[typing.Literal["name", "args", "kwargs"], str | list | dict]
    ],
) -> None:
    try:
        assert len(filters_old) == len(filters_new)

        for filterinfo_old, filterinfo_new in zip(filters_old, filters_new):
            # basic checks
            assert isinstance(filterinfo_new, dict), "filterinfo_new is not a dict"
            assert isinstance(filterinfo_old, dict), "filterinfo_old is not a dict"
            assert all(key in filterinfo_new for key in ["name", "args", "kwargs"]), (
                "missing keys in filterinfo_new"
            )
            assert all(key in filterinfo_old for key in ["name", "args", "kwargs"]), (
                "missing keys in filterinfo_old"
            )

            # name
            assert filterinfo_new["name"] == filterinfo_old["name"], (
                "filter names don't match"
            )

            # args
            assert len(filterinfo_new["args"]) == len(filterinfo_old["args"]), (
                "filter args of different lengths"
            )
            for arg_new, arg_old in zip(filterinfo_new["args"], filterinfo_old["args"]):
                assert arg_new == arg_old, "filter args don't match"

            # kwargs
            assert len(filterinfo_new["kwargs"]) == len(filterinfo_old["kwargs"]), (
                "filter kwargs of different lengths"
            )
            for key in filterinfo_old["kwargs"]:
                assert key in filterinfo_new["kwargs"], (
                    f"filter kwargs don't match: missing key '{key}'"
                )
                assert filterinfo_new["kwargs"][key] == filterinfo_old["kwargs"][key], (
                    f"filter kwargs don't match: values for key '{key}' don't match"
                )

    except AssertionError as e:
        raise FilterInfoMismatchError(
            f"config mismatch in applied filters: {filters_new} != {filters_old}"
        ) from e


def register_filter_namespace_for_dataset(
    dataset_cls: Type[GPTDataset],
) -> Callable[[Type], Type]:
    """register the namespace class with the given dataset class"""

    def decorator(filter_namespace_cls: Type) -> Type:
        dataset_cls._FILTER_NAMESPACE = filter_namespace_cls
        filter_namespace_cls._BASE_DATASET = dataset_cls

        return filter_namespace_cls

    return decorator


class DatasetFilterProtocol(typing.Protocol):
    def __call__(
        self,
        dataset: GPTDataset,
        **kwargs,
    ) -> GPTDataset: ...


def register_dataset_filter(
    method: DatasetFilterProtocol,
) -> DatasetFilterProtocol:
    """register a dataset filter, copying the underlying dataset and updating the config

    be sure to return a COPY, not the original?

    method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`
    """

    @functools.wraps(method)
    def wrapper(dataset: GPTDataset, *args, **kwargs):
        new_dataset = method(dataset, *args, **kwargs)
        # update the config
        new_dataset.cfg.applied_filters.append(
            dict(name=method.__name__, args=args, kwargs=kwargs)
        )
        new_dataset.update_self_config()
        return new_dataset

    return wrapper
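

# illustrative sketch of the registration pattern (all names below are hypothetical and not
# part of this module; a real filter should also copy the dataset/config rather than mutate it):
#
#   @register_filter_namespace_for_dataset(DemoDataset)
#   class DemoDatasetFilters:
#       @staticmethod
#       @register_dataset_filter
#       def keep_first_n(dataset: DemoDataset, n: int) -> DemoDataset:
#           return DemoDataset(dataset.cfg, dataset.items[:n])
#
#   filtered = dataset.filter_by.keep_first_n(5)
#   # `filtered.cfg.applied_filters` now ends with
#   # {"name": "keep_first_n", "args": (5,), "kwargs": {}}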