maze_dataset.dataset
`MazeDatasetConfig`s are used to create a `MazeDataset` via `MazeDataset.from_config(cfg)`
"`MazeDatasetConfig`s are used to create a `MazeDataset` via `MazeDataset.from_config(cfg)`"

from maze_dataset.dataset.collected_dataset import (
    MazeDatasetCollection,
    MazeDatasetCollectionConfig,
)
from maze_dataset.dataset.maze_dataset import MazeDataset
from maze_dataset.dataset.maze_dataset_config import MazeDatasetConfig

__all__ = [
    # submodules
    "collected_dataset",
    "configs",
    "dataset",
    "filters",
    "maze_dataset_config",
    "maze_dataset",
    "rasterized",
    "success_predict_math",
    # dataset classes
    "MazeDataset",
    "MazeDatasetConfig",
    "MazeDatasetCollection",
    "MazeDatasetCollectionConfig",
]
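As a quick orientation, a dataset is typically created by building a `MazeDatasetConfig` and handing it to `MazeDataset.from_config`. The sketch below assumes the generator class is importable as `maze_dataset.generation.LatticeMazeGenerators`; the specific field values are purely illustrative.

```python
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators

# minimal sketch: a small DFS-generated dataset
cfg = MazeDatasetConfig(
    name="demo",
    grid_n=5,    # 5x5 lattice
    n_mazes=4,   # number of mazes to generate
    maze_ctor=LatticeMazeGenerators.gen_dfs,
)
dataset = MazeDataset.from_config(cfg)
maze = dataset[0]  # a SolvedMaze
print(len(dataset))
```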
113class MazeDataset(GPTDataset[MazeDatasetConfig]): 114 """a maze dataset class. This is a collection of solved mazes, and should be initialized via `MazeDataset.from_config`""" 115 116 def __init__( 117 self, 118 cfg: MazeDatasetConfig, 119 mazes: typing.Sequence[SolvedMaze], 120 generation_metadata_collected: dict | None = None, 121 ) -> None: 122 """initialize a maze dataset from a config and a list of solved mazes""" 123 super().__init__() 124 self.cfg: MazeDatasetConfig = cfg 125 self.mazes: list[SolvedMaze] = list(mazes) 126 self.generation_metadata_collected: dict | None = generation_metadata_collected 127 128 # TYPING: error: Return type "MazeDataset" of "from_config" incompatible with return type "T_Dataset" in supertype "GPTDataset" [override] 129 @classmethod 130 def from_config( # type: ignore[override] 131 cls, 132 # TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig" [override] 133 cfg: MazeDatasetConfig, # type: ignore[override] 134 do_generate: bool = True, 135 load_local: bool = True, 136 save_local: bool = True, 137 zanj: ZANJ | None = None, 138 do_download: bool = True, 139 local_base_path: Path = Path("data/maze_dataset"), 140 except_on_config_mismatch: bool = True, 141 allow_generation_metadata_filter_mismatch: bool = True, 142 verbose: bool = False, 143 **kwargs, 144 ) -> "MazeDataset": 145 """create a maze dataset from a config 146 147 priority of loading: 148 1. load from local 149 2. download 150 3. generate 151 152 """ 153 return cast( 154 MazeDataset, 155 super().from_config( 156 cfg=cfg, 157 do_generate=do_generate, 158 load_local=load_local, 159 save_local=save_local, 160 zanj=zanj, 161 do_download=do_download, 162 local_base_path=local_base_path, 163 except_on_config_mismatch=except_on_config_mismatch, 164 allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch, 165 verbose=verbose, 166 **kwargs, 167 ), 168 ) 169 170 def data_hash(self) -> int: 171 """return a hash of the data""" 172 return stable_hash(str(tuple([x.serialize() for x in self.mazes]))) 173 174 def __getitem__(self, i: int) -> SolvedMaze: 175 """get a maze by index""" 176 return self.mazes[i] 177 178 def __iter__(self) -> typing.Iterator[SolvedMaze]: 179 """iterate over the mazes""" 180 return iter(self.mazes) 181 182 def __deepcopy__(self, memo) -> "MazeDataset": # noqa: ANN001 183 """deepcopy the dataset 184 185 FIX: this isnt actually a deepcopy I think? 186 """ 187 return MazeDataset.load(self._serialize_full()) 188 189 # TYPING: get type hints on the tokenizer here 190 @overload 191 def as_tokens( 192 self, 193 maze_tokenizer, # noqa: ANN001 194 limit: int | None = None, 195 join_tokens_individual_maze: Literal[False] = False, 196 ) -> list[list[str]]: ... 197 @overload 198 def as_tokens( 199 self, 200 maze_tokenizer, # noqa: ANN001 201 limit: int | None = None, 202 join_tokens_individual_maze: Literal[True] = True, 203 ) -> list[str]: ... 204 def as_tokens( 205 self, 206 maze_tokenizer, # TODO: MazeTokenizer 207 limit: int | None = None, 208 join_tokens_individual_maze: bool = False, 209 ) -> list[list[str]] | list[str]: 210 """return the dataset as tokens according to the passed `maze_tokenizer` 211 212 the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular` 213 214 if `join_tokens_individual_maze` is True, then the tokens of each maze are 215 joined with a space, and the result is a list of strings. 
216 i.e.: 217 218 >>> dataset.as_tokens(join_tokens_individual_maze=False) 219 [["a", "b", "c"], ["d", "e", "f"]] 220 >>> dataset.as_tokens(join_tokens_individual_maze=True) 221 ["a b c", "d e f"] 222 """ 223 output: list[list[str]] = [ 224 maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit] 225 ] 226 if join_tokens_individual_maze: 227 return [" ".join(tokens) for tokens in output] 228 else: 229 return output 230 231 def __len__(self) -> int: 232 """return the number of mazes in the dataset""" 233 return len(self.mazes) 234 235 def __eq__(self, other: object) -> bool: 236 """compare two datasets""" 237 if not isinstance(other, MazeDataset): 238 raise NotImplementedError( 239 "can only compare with other MazeDataset objects", 240 ) 241 # TODO: compare hashes of data instead of the data itself? 242 return self.cfg == other.cfg and self.mazes == other.mazes 243 244 def assert_equal(self, other: "MazeDataset") -> None: 245 """assert that two datasets are equal""" 246 assert isinstance(other, MazeDataset) 247 assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }" 248 assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }" 249 250 @classmethod 251 def generate( 252 cls, 253 cfg: MazeDatasetConfig, 254 gen_parallel: bool = False, 255 pool_kwargs: dict | None = None, 256 verbose: bool = False, 257 # TODO: what to do when unexpected kwargs are passed? 258 **kwargs, # noqa: ARG003 259 ) -> "MazeDataset": 260 """Generate a maze dataset given a config and some generation parameters""" 261 # Copy the config to avoid modifying the original 262 cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load( 263 json.loads(json.dumps(cfg.serialize())), 264 ) 265 266 if pool_kwargs is None: 267 pool_kwargs = dict() 268 maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes) # type: ignore[assignment] 269 270 solved_mazes: list[SolvedMaze | None] 271 # Configure tqdm for progress bar 272 tqdm_kwargs: dict = dict( 273 total=cfg_cpy.n_mazes, 274 unit="maze", 275 desc="generating & solving mazes", 276 disable=not verbose, 277 ) 278 # TODO: don't use the global unless generating in parallel! 279 if gen_parallel: 280 with multiprocessing.Pool( 281 **pool_kwargs, 282 initializer=_maze_gen_init_worker, 283 initargs=(cfg_cpy,), 284 ) as pool: 285 solved_mazes = list( 286 tqdm.tqdm( 287 pool.imap(_generate_maze_helper, maze_indexes), 288 **tqdm_kwargs, 289 ), 290 ) 291 292 else: 293 _maze_gen_init_worker(cfg_cpy) 294 solved_mazes = list( 295 tqdm.tqdm( 296 map( 297 # TYPING: error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]" [arg-type] 298 # why does it think tolist() returns a string? 
299 _generate_maze_helper, # type: ignore[arg-type] 300 maze_indexes.tolist(), 301 ), 302 **tqdm_kwargs, 303 ), 304 ) 305 306 # Filter out None values explicitly after ensuring all results are collected 307 solved_mazes_: list[SolvedMaze] = [ 308 maze for maze in solved_mazes if maze is not None 309 ] 310 # solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes)) 311 312 # Update the config with the actual number of mazes 313 cfg_cpy.n_mazes = len(solved_mazes_) 314 315 dataset: MazeDataset = cls( 316 cfg=cfg_cpy, 317 mazes=solved_mazes_, 318 ) 319 320 dataset.update_self_config() # Call `update_self_config()` to ensure the dataset's config reflects changes 321 322 np.random.seed(cfg_cpy.seed) # Reset the seed to the value in the config copy 323 324 return dataset 325 326 @classmethod 327 def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset": 328 "(not implemented yet!) download a maze dataset from the internet" 329 raise NotImplementedError("not implemented yet") 330 331 @classmethod 332 def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset": 333 """load from zanj/json""" 334 if data[_FORMAT_KEY] == "MazeDataset:minimal": 335 return cls._load_minimal(data) 336 elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat": 337 return cls._load_minimal_soln_cat(data) 338 elif data[_FORMAT_KEY] == "MazeDataset": 339 if ( 340 SERIALIZE_MINIMAL_THRESHOLD == -1 341 ): # Allow access to `_load_legacy` for profiling 342 return cls._load_legacy(data) 343 return cls._load_full(data) 344 else: 345 err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })" 346 raise KeyError( 347 err_msg, 348 ) 349 350 @classmethod 351 def _load_full(cls, data: JSONdict) -> "MazeDataset": 352 assert data[_FORMAT_KEY] == "MazeDataset" 353 return cls( 354 cfg=MazeDatasetConfig.load(data["cfg"]), # type: ignore[arg-type] 355 mazes=load_item_recursive(data["mazes"], tuple()), 356 generation_metadata_collected=data["generation_metadata_collected"], # type: ignore[arg-type] 357 ) 358 359 @classmethod 360 def _load_minimal(cls, data: JSONdict) -> "MazeDataset": 361 assert data[_FORMAT_KEY] == "MazeDataset:minimal" 362 return cls( 363 cfg=MazeDatasetConfig.load(data["cfg"]), # type: ignore[arg-type] 364 generation_metadata_collected=data["generation_metadata_collected"], # type: ignore[arg-type] 365 mazes=[ 366 SolvedMaze( 367 clist, 368 soln[:slen, ...], 369 ) 370 for clist, slen, soln in zip( 371 load_item_recursive(data["maze_connection_lists"], tuple()), 372 load_item_recursive(data["maze_solution_lengths"], tuple()), 373 load_item_recursive(data["maze_solutions"], tuple()), 374 strict=False, 375 # load_item_recursive(data["maze_endpoints"], tuple()), 376 ) 377 ], 378 ) 379 380 @classmethod 381 def _load_minimal_soln_cat(cls, data: JSONdict) -> "MazeDataset": 382 assert data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat" 383 384 maze_solution_lengths = load_item_recursive( 385 data["maze_solution_lengths"], 386 tuple(), 387 ) 388 maze_solutions_concat = load_item_recursive( 389 data["maze_solutions_concat"], 390 tuple(), 391 ) 392 maze_solutions = np.split( 393 maze_solutions_concat, 394 np.cumsum(maze_solution_lengths)[:-1], 395 axis=0, 396 ) 397 398 return cls( 399 cfg=load_item_recursive(data["cfg"], tuple()), 400 generation_metadata_collected=load_item_recursive( 401 data["generation_metadata_collected"], 402 tuple(), 403 ), 404 mazes=[ 405 SolvedMaze( 406 connection_list=clist, 407 solution=soln, 408 ) 409 for clist, 
soln in zip( 410 load_item_recursive(data["maze_connection_lists"], tuple()), 411 # load_item_recursive(data["maze_endpoints"], tuple()), 412 maze_solutions, 413 strict=False, 414 ) 415 ], 416 ) 417 418 @classmethod 419 def _load_legacy(cls, data: JSONdict) -> "MazeDataset": 420 """Legacy `load` method from <0.5.2. Used exclusively for profiling comparison.""" 421 assert data[_FORMAT_KEY] == "MazeDataset" 422 return cls( 423 **{ 424 key: load_item_recursive(data[key], tuple()) 425 for key in ["cfg", "mazes", "generation_metadata_collected"] 426 }, 427 ) 428 429 def serialize(self) -> JSONdict: 430 """serialize to zanj/json""" 431 if ( 432 SERIALIZE_MINIMAL_THRESHOLD is not None 433 and len(self) >= SERIALIZE_MINIMAL_THRESHOLD 434 ): 435 return self._serialize_minimal() 436 return self._serialize_full() 437 438 def _serialize_full(self) -> JSONdict: 439 return { 440 _FORMAT_KEY: "MazeDataset", 441 "cfg": json_serialize(self.cfg), 442 "fname": self.cfg.to_fname(), 443 "mazes": json_serialize(self.mazes), 444 "generation_metadata_collected": json_serialize( 445 self.generation_metadata_collected, 446 ), 447 } 448 449 def _serialize_minimal(self) -> JSONdict: 450 "alternate serialization where metadata is collected and mazes are stored in concatenated form" 451 filtered_meta: MazeDataset 452 if self.generation_metadata_collected is None: 453 filtered_meta = self.filter_by.collect_generation_meta() 454 else: 455 filtered_meta = self 456 457 max_solution_len: int = max(m.solution.shape[0] for m in filtered_meta.mazes) 458 n_mazes: int = len(filtered_meta.mazes) 459 grid_n: int = filtered_meta.cfg.grid_n 460 461 maze_connection_lists: np.ndarray = np.empty( 462 (n_mazes, 2, grid_n, grid_n), 463 dtype=np.bool_, 464 ) 465 # maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8) 466 maze_solution_lengths: np.ndarray = np.empty((n_mazes,), dtype=np.int32) 467 maze_solutions: np.ndarray = np.empty( 468 (n_mazes, max_solution_len, 2), 469 dtype=np.int8, 470 ) 471 472 for idx, maze in enumerate(filtered_meta.mazes): 473 maze_connection_lists[idx] = maze.connection_list 474 # maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos]) 475 maze_solution_lengths[idx] = maze.solution.shape[0] 476 maze_solutions[idx, : maze.solution.shape[0]] = maze.solution 477 478 return { 479 _FORMAT_KEY: "MazeDataset:minimal", 480 "cfg": json_serialize(filtered_meta.cfg), 481 "fname": filtered_meta.cfg.to_fname(), 482 "generation_metadata_collected": json_serialize( 483 filtered_meta.generation_metadata_collected, 484 ), 485 "maze_connection_lists": maze_connection_lists, # type: ignore[dict-item] 486 # "maze_endpoints": maze_endpoints, 487 "maze_solution_lengths": maze_solution_lengths, # type: ignore[dict-item] 488 "maze_solutions": maze_solutions, # type: ignore[dict-item] 489 } 490 491 def _serialize_minimal_soln_cat(self: "MazeDataset") -> JSONdict: 492 "alternate serialization where metadata is collected, and mazes and their solutions are stored in concatenated form" 493 filtered_meta: MazeDataset 494 if self.generation_metadata_collected is None: 495 filtered_meta = self.filter_by.collect_generation_meta() 496 else: 497 filtered_meta = self 498 499 maze_solution_lengths: np.ndarray = np.array( 500 [m.solution.shape[0] for m in filtered_meta.mazes], 501 dtype=np.int32, 502 ) 503 n_mazes: int = len(filtered_meta.mazes) 504 grid_n: int = filtered_meta.cfg.grid_n 505 total_solution_len: int = np.sum(maze_solution_lengths) 506 507 maze_connection_lists: np.ndarray = np.empty( 508 (n_mazes, 2, grid_n, 
grid_n), 509 dtype=np.bool_, 510 ) 511 maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8) 512 maze_solutions_concat: np.ndarray = np.empty( 513 (total_solution_len, 2), 514 dtype=np.int8, 515 ) 516 517 solutions_running_idx: int = 0 518 for idx, maze in enumerate(filtered_meta.mazes): 519 maze_connection_lists[idx] = maze.connection_list 520 maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos]) 521 soln_len: int = maze.solution.shape[0] 522 maze_solution_lengths[idx] = soln_len 523 maze_solutions_concat[ 524 solutions_running_idx : solutions_running_idx + soln_len 525 ] = maze.solution 526 solutions_running_idx += soln_len 527 528 return { 529 _FORMAT_KEY: "MazeDataset:minimal_soln_cat", 530 "cfg": json_serialize(filtered_meta.cfg), 531 "fname": filtered_meta.cfg.to_fname(), 532 "generation_metadata_collected": json_serialize( 533 filtered_meta.generation_metadata_collected, 534 ), 535 "maze_connection_lists": maze_connection_lists, # type: ignore[dict-item] 536 "maze_endpoints": maze_endpoints, # type: ignore[dict-item] 537 "maze_solution_lengths": maze_solution_lengths, # type: ignore[dict-item] 538 "maze_solutions_concat": maze_solutions_concat, # type: ignore[dict-item] 539 } 540 541 def update_self_config(self) -> None: 542 """update the config to match the current state of the dataset (number of mazes, such as after filtering)""" 543 if self.cfg.n_mazes != len(self.mazes): 544 warnings.warn( 545 f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}", 546 ) 547 self.cfg.n_mazes = len(self.mazes) 548 549 def custom_maze_filter( 550 self, 551 method: typing.Callable[[SolvedMaze], bool], 552 **kwargs, 553 ) -> "MazeDataset": 554 """filter the dataset using a custom method""" 555 output: MazeDataset = MazeDataset( 556 cfg=copy.deepcopy(self.cfg), 557 mazes=[m for m in self.mazes if method(m, **kwargs)], 558 ) 559 output.cfg.applied_filters.append( 560 { 561 "name": f"__custom__:{method.__name__}", 562 "kwargs": kwargs, 563 }, 564 ) 565 output.update_self_config() 566 return output
a maze dataset class. This is a collection of solved mazes, and should be initialized via `MazeDataset.from_config`
    def __init__(
        self,
        cfg: MazeDatasetConfig,
        mazes: typing.Sequence[SolvedMaze],
        generation_metadata_collected: dict | None = None,
    ) -> None:
        """initialize a maze dataset from a config and a list of solved mazes"""
        super().__init__()
        self.cfg: MazeDatasetConfig = cfg
        self.mazes: list[SolvedMaze] = list(mazes)
        self.generation_metadata_collected: dict | None = generation_metadata_collected
initialize a maze dataset from a config and a list of solved mazes
129 @classmethod 130 def from_config( # type: ignore[override] 131 cls, 132 # TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig" [override] 133 cfg: MazeDatasetConfig, # type: ignore[override] 134 do_generate: bool = True, 135 load_local: bool = True, 136 save_local: bool = True, 137 zanj: ZANJ | None = None, 138 do_download: bool = True, 139 local_base_path: Path = Path("data/maze_dataset"), 140 except_on_config_mismatch: bool = True, 141 allow_generation_metadata_filter_mismatch: bool = True, 142 verbose: bool = False, 143 **kwargs, 144 ) -> "MazeDataset": 145 """create a maze dataset from a config 146 147 priority of loading: 148 1. load from local 149 2. download 150 3. generate 151 152 """ 153 return cast( 154 MazeDataset, 155 super().from_config( 156 cfg=cfg, 157 do_generate=do_generate, 158 load_local=load_local, 159 save_local=save_local, 160 zanj=zanj, 161 do_download=do_download, 162 local_base_path=local_base_path, 163 except_on_config_mismatch=except_on_config_mismatch, 164 allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch, 165 verbose=verbose, 166 **kwargs, 167 ), 168 )
create a maze dataset from a config
priority of loading:
- load from local
- download
- generate
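The loading priority can be steered with the keyword flags from the signature above. A sketch (the path is just an example):

```python
from pathlib import Path

# force regeneration even if a matching local file exists,
# and save the result for next time
dataset = MazeDataset.from_config(
    cfg,
    load_local=False,   # skip step 1 (local load)
    do_download=False,  # skip step 2 (download is not implemented anyway)
    do_generate=True,   # fall through to step 3 (generation)
    save_local=True,
    local_base_path=Path("data/maze_dataset"),
    verbose=True,
)
```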
    def data_hash(self) -> int:
        """return a hash of the data"""
        return stable_hash(str(tuple([x.serialize() for x in self.mazes])))
return a hash of the data
204 def as_tokens( 205 self, 206 maze_tokenizer, # TODO: MazeTokenizer 207 limit: int | None = None, 208 join_tokens_individual_maze: bool = False, 209 ) -> list[list[str]] | list[str]: 210 """return the dataset as tokens according to the passed `maze_tokenizer` 211 212 the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular` 213 214 if `join_tokens_individual_maze` is True, then the tokens of each maze are 215 joined with a space, and the result is a list of strings. 216 i.e.: 217 218 >>> dataset.as_tokens(join_tokens_individual_maze=False) 219 [["a", "b", "c"], ["d", "e", "f"]] 220 >>> dataset.as_tokens(join_tokens_individual_maze=True) 221 ["a b c", "d e f"] 222 """ 223 output: list[list[str]] = [ 224 maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit] 225 ] 226 if join_tokens_individual_maze: 227 return [" ".join(tokens) for tokens in output] 228 else: 229 return output
return the dataset as tokens according to the passed `maze_tokenizer`

the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`

if `join_tokens_individual_maze` is True, then the tokens of each maze are joined with a space, and the result is a list of strings, i.e.:

>>> dataset.as_tokens(join_tokens_individual_maze=False)
[["a", "b", "c"], ["d", "e", "f"]]
>>> dataset.as_tokens(join_tokens_individual_maze=True)
["a b c", "d e f"]
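For a concrete call, the sketch below assumes `MazeTokenizerModular` from `maze_dataset.tokenization` with its default constructor; substitute whichever tokenizer you actually use.

```python
from maze_dataset.tokenization import MazeTokenizerModular

tokenizer = MazeTokenizerModular()  # assumed default construction
# first two mazes, each as one whitespace-joined string
token_strings = dataset.as_tokens(
    tokenizer,
    limit=2,
    join_tokens_individual_maze=True,
)
print(token_strings[0][:80])
```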
    def assert_equal(self, other: "MazeDataset") -> None:
        """assert that two datasets are equal"""
        assert isinstance(other, MazeDataset)
        assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }"
        assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"
assert that two datasets are equal
250 @classmethod 251 def generate( 252 cls, 253 cfg: MazeDatasetConfig, 254 gen_parallel: bool = False, 255 pool_kwargs: dict | None = None, 256 verbose: bool = False, 257 # TODO: what to do when unexpected kwargs are passed? 258 **kwargs, # noqa: ARG003 259 ) -> "MazeDataset": 260 """Generate a maze dataset given a config and some generation parameters""" 261 # Copy the config to avoid modifying the original 262 cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load( 263 json.loads(json.dumps(cfg.serialize())), 264 ) 265 266 if pool_kwargs is None: 267 pool_kwargs = dict() 268 maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes) # type: ignore[assignment] 269 270 solved_mazes: list[SolvedMaze | None] 271 # Configure tqdm for progress bar 272 tqdm_kwargs: dict = dict( 273 total=cfg_cpy.n_mazes, 274 unit="maze", 275 desc="generating & solving mazes", 276 disable=not verbose, 277 ) 278 # TODO: don't use the global unless generating in parallel! 279 if gen_parallel: 280 with multiprocessing.Pool( 281 **pool_kwargs, 282 initializer=_maze_gen_init_worker, 283 initargs=(cfg_cpy,), 284 ) as pool: 285 solved_mazes = list( 286 tqdm.tqdm( 287 pool.imap(_generate_maze_helper, maze_indexes), 288 **tqdm_kwargs, 289 ), 290 ) 291 292 else: 293 _maze_gen_init_worker(cfg_cpy) 294 solved_mazes = list( 295 tqdm.tqdm( 296 map( 297 # TYPING: error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]" [arg-type] 298 # why does it think tolist() returns a string? 299 _generate_maze_helper, # type: ignore[arg-type] 300 maze_indexes.tolist(), 301 ), 302 **tqdm_kwargs, 303 ), 304 ) 305 306 # Filter out None values explicitly after ensuring all results are collected 307 solved_mazes_: list[SolvedMaze] = [ 308 maze for maze in solved_mazes if maze is not None 309 ] 310 # solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes)) 311 312 # Update the config with the actual number of mazes 313 cfg_cpy.n_mazes = len(solved_mazes_) 314 315 dataset: MazeDataset = cls( 316 cfg=cfg_cpy, 317 mazes=solved_mazes_, 318 ) 319 320 dataset.update_self_config() # Call `update_self_config()` to ensure the dataset's config reflects changes 321 322 np.random.seed(cfg_cpy.seed) # Reset the seed to the value in the config copy 323 324 return dataset
Generate a maze dataset given a config and some generation parameters
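Direct calls to `generate` are mostly useful when you want explicit control over parallelism; a sketch, with the `processes` value purely illustrative:

```python
# generate in parallel across 4 worker processes
dataset = MazeDataset.generate(
    cfg,
    gen_parallel=True,
    pool_kwargs=dict(processes=4),  # forwarded to multiprocessing.Pool
    verbose=True,                   # show the tqdm progress bar
)
```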
    @classmethod
    def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
        "(not implemented yet!) download a maze dataset from the internet"
        raise NotImplementedError("not implemented yet")
(not implemented yet!) download a maze dataset from the internet
    @classmethod
    def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
        """load from zanj/json"""
        if data[_FORMAT_KEY] == "MazeDataset:minimal":
            return cls._load_minimal(data)
        elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
            return cls._load_minimal_soln_cat(data)
        elif data[_FORMAT_KEY] == "MazeDataset":
            if (
                SERIALIZE_MINIMAL_THRESHOLD == -1
            ):  # Allow access to `_load_legacy` for profiling
                return cls._load_legacy(data)
            return cls._load_full(data)
        else:
            err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
            raise KeyError(
                err_msg,
            )
load from zanj/json
    def serialize(self) -> JSONdict:
        """serialize to zanj/json"""
        if (
            SERIALIZE_MINIMAL_THRESHOLD is not None
            and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
        ):
            return self._serialize_minimal()
        return self._serialize_full()
serialize to zanj/json
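`serialize`/`load` are usually driven through ZANJ rather than called directly. A sketch assuming the `zanj` package's `ZANJ().save` / `ZANJ().read` interface (the filename is arbitrary):

```python
from zanj import ZANJ

path = "data/demo_dataset.zanj"
ZANJ().save(dataset, path)    # writes the serialized form shown above
reloaded = ZANJ().read(path)  # dispatches back through MazeDataset.load
dataset.assert_equal(reloaded)
```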
    def update_self_config(self) -> None:
        """update the config to match the current state of the dataset (number of mazes, such as after filtering)"""
        if self.cfg.n_mazes != len(self.mazes):
            warnings.warn(
                f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}",
            )
            self.cfg.n_mazes = len(self.mazes)
update the config to match the current state of the dataset (number of mazes, such as after filtering)
    def custom_maze_filter(
        self,
        method: typing.Callable[[SolvedMaze], bool],
        **kwargs,
    ) -> "MazeDataset":
        """filter the dataset using a custom method"""
        output: MazeDataset = MazeDataset(
            cfg=copy.deepcopy(self.cfg),
            mazes=[m for m in self.mazes if method(m, **kwargs)],
        )
        output.cfg.applied_filters.append(
            {
                "name": f"__custom__:{method.__name__}",
                "kwargs": kwargs,
            },
        )
        output.update_self_config()
        return output
filter the dataset using a custom method
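For example, keeping only mazes whose solution is at least some length (the helper name and threshold are made up for illustration; a named function is preferable to a lambda, since `method.__name__` is recorded in `applied_filters`):

```python
def solution_at_least(maze, min_len: int = 5) -> bool:
    """keep mazes whose solution path has at least `min_len` steps"""
    return maze.solution.shape[0] >= min_len

filtered = dataset.custom_maze_filter(solution_at_least, min_len=7)
print(len(dataset), "->", len(filtered))
```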
235@serializable_dataclass(kw_only=True, methods_no_override=["serialize"]) 236class MazeDatasetConfig(_MazeDatasetConfig_base): # type: ignore[misc] 237 """config object which is passed to `MazeDataset.from_config` to generate or load a dataset""" 238 239 @property 240 def config_version(self) -> str: 241 """return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config""" 242 return "1.0" 243 244 @property 245 def versions(self) -> dict: 246 """return the versions of the config and the maze_dataset""" 247 return dict( 248 config=self.config_version, 249 maze_dataset=importlib.metadata.version("maze_dataset"), 250 ) 251 252 def serialize(self) -> dict: 253 "serialize the MazeDatasetConfig with all fields and fname" 254 return { 255 **self._serialize_base( 256 applied_filters__skip__collect_generation_meta=False 257 ), 258 "fname": self.to_fname(), 259 "versions": self.versions, 260 } 261 262 def summary(self) -> dict: 263 """return a summary of the config""" 264 # do we run this to make sure it doesn't error? 265 super_summary: dict = super().summary() 266 assert super_summary 267 self_ser: dict = self.serialize() 268 return dict( 269 name=self.name, 270 fname=self.to_fname(), 271 sdc_hash=self.stable_hash_cfg(), 272 seed=self.seed, 273 seq_len_min=self.seq_len_min, 274 seq_len_max=self.seq_len_max, 275 applied_filters=self.applied_filters, 276 grid_n=self_ser["grid_n"], 277 n_mazes=self_ser["n_mazes"], 278 maze_ctor_name=self_ser["maze_ctor"]["__name__"], 279 maze_ctor_kwargs=self_ser["maze_ctor_kwargs"], 280 endpoint_kwargs=self_ser["endpoint_kwargs"], 281 ) 282 283 def _to_ps_array(self) -> _PercolationSuccessArray: 284 """Convert this config to a [p, grid_n, deadends, endpoints_not_equal, generator_func] vector. 
285 286 used in predicting the success rate 287 """ 288 try: 289 assert self.maze_ctor.__name__ in _GENERATORS_PERCOLATED, ( 290 f"generator not supported, must be a percolation generator\n{self.maze_ctor.__name__ = }, {_GENERATORS_PERCOLATED = }" 291 ) 292 assert "p" in self.maze_ctor_kwargs, ( 293 f"maze_ctor_kwargs must have a 'p' (percolation value) key: {self.maze_ctor_kwargs = }" 294 ) 295 assert not self.endpoint_kwargs.get("except_on_no_valid_endpoint", True), ( 296 f"except_on_no_valid_endpoint must be False, or else if any maze fails to generate, the whole dataset will fail: {self.endpoint_kwargs = }" 297 ) 298 except AssertionError as e: 299 err_msg: str = f"invalid config for percolation success prediction: {self.summary() = }" 300 raise NoPercolationInConfigError( 301 err_msg, 302 ) from e 303 304 endpoints_unique_flag: int = int( 305 # we are pretty sure it will be an int or bool here 306 self.endpoint_kwargs.get("endpoints_not_equal", True), # type: ignore[arg-type] 307 ) 308 309 # adjustment for bknutson0 310 if not ( 311 self.endpoint_kwargs.get("deadend_start", False) 312 and self.endpoint_kwargs.get("deadend_end", False) 313 ): 314 # we didnt train on this, but if either endpoint is not required to be in a dead end 315 # then requiring the endpoints to be unique does not really affect the success rate 316 # (except for very small percolation values, pure percolation generation) 317 endpoints_unique_flag = 0 318 319 return np.array( 320 [ 321 float(self.maze_ctor_kwargs["p"]), 322 float(self.grid_n), 323 float( 324 int( 325 self.endpoint_kwargs.get("deadend_start", False) # type: ignore[arg-type] 326 or self.endpoint_kwargs.get("deadend_end", False), 327 ), 328 ), 329 float(endpoints_unique_flag), 330 float(_GENERATORS_PERCOLATED.index(self.maze_ctor.__name__)), 331 ], 332 dtype=np.float64, 333 ) 334 335 @classmethod 336 def _from_ps_array( 337 cls, 338 arr: _PercolationSuccessArray, 339 name: str = "predict", 340 n_mazes: int = 100, 341 **kwargs, 342 ) -> "MazeDatasetConfig": 343 """Reconstruct a config from an array [p, grid_n, deadends, endpoints_not_equal, generator_func] and other config parameters. 344 345 # Returns: 346 - `MazeDatasetConfig` 347 Config corresponding to `arr` 348 """ 349 return cls( 350 name=name, 351 grid_n=int(arr[1]), 352 n_mazes=n_mazes, 353 maze_ctor=GENERATORS_MAP[_GENERATORS_PERCOLATED[int(arr[4])]], 354 maze_ctor_kwargs={"p": float(arr[0])}, 355 endpoint_kwargs=dict( 356 deadend_start=bool(arr[2]), 357 deadend_end=bool(arr[2]), 358 endpoints_not_equal=bool(arr[3]), 359 except_on_no_valid_endpoint=False, 360 ), 361 **kwargs, 362 ) 363 364 def success_fraction_estimate( 365 self, 366 except_if_all_success_expected: bool = False, 367 ) -> float: 368 """Estimate the success fraction of this config. 369 370 only valid when the generator is a percolation generator, 371 and endpoints are enforced to be dead ends 372 373 this estimate comes from `estimate_dataset_fractions.ipynb` and `maze_dataset.benchmarks.sweep_fit` 374 375 # Parameters: 376 - `except_if_all_success_expected : bool` 377 if `True`, don't raise an error if the success fraction is below the threshold. 
378 will always return `1.0` if the config is not expected to fail 379 380 # Returns: 381 - `float` 382 estimated success fraction 383 384 # Raises: 385 - `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `False` 386 """ 387 try: 388 return cfg_success_predict_fn(self) 389 390 except NoPercolationInConfigError as e: 391 if except_if_all_success_expected: 392 raise e # noqa: TRY201 393 return 1.0 394 395 def success_fraction_compensate( 396 self, 397 safety_margin: float = 1.2, 398 except_if_all_success_expected: bool = False, 399 epsilon: float = 1e-2, 400 ) -> "MazeDatasetConfig": 401 """return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction 402 403 # Parameters: 404 - `safety_margin : float` 405 safety margin to apply to the success fraction estimate 406 (defaults to `1.2`, or 20% more mazes than estimated) 407 - `except_if_all_success_expected : bool` 408 if `True`, don't raise an error if the success fraction is below the threshold. 409 this is passed to `MazeDatasetConfig.success_fraction_estimate`. 410 if your config isn't expected to fail, passing this might mean you generate more mazes than needed 411 since `safety_margin` is still applied. 412 (defaults to `False`) 413 - `epsilon : float` 414 raise `SuccessChanceTooSmallError` if the success fraction is below this threshold 415 (defaults to `1e-2`) 416 417 # Returns: 418 - `MazeDatasetConfig` 419 new config with adjusted `n_mazes` 420 421 # Raises: 422 - `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon` 423 """ 424 # compute and check the success fraction 425 success_fraction: float = self.success_fraction_estimate( 426 except_if_all_success_expected=except_if_all_success_expected, 427 ) 428 if success_fraction < epsilon: 429 err_msg: str = ( 430 f"{success_fraction = } is below the threshold of {epsilon = }" 431 ) 432 raise SuccessChanceTooSmallError( 433 err_msg, 434 ) 435 436 # compute the new number of mazes 437 n_mazes: int = self.n_mazes 438 new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1 439 440 # put it in a new config and return 441 cfg_dict: dict = self.serialize() 442 cfg_dict["n_mazes"] = new_n_mazes 443 return MazeDatasetConfig.load(cfg_dict)
config object which is passed to `MazeDataset.from_config` to generate or load a dataset
    @property
    def config_version(self) -> str:
        """return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config"""
        return "1.0"
return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config
    @property
    def versions(self) -> dict:
        """return the versions of the config and the maze_dataset"""
        return dict(
            config=self.config_version,
            maze_dataset=importlib.metadata.version("maze_dataset"),
        )
return the versions of the config and the maze_dataset
    def serialize(self) -> dict:
        "serialize the MazeDatasetConfig with all fields and fname"
        return {
            **self._serialize_base(
                applied_filters__skip__collect_generation_meta=False
            ),
            "fname": self.to_fname(),
            "versions": self.versions,
        }
serialize the MazeDatasetConfig with all fields and fname
    def summary(self) -> dict:
        """return a summary of the config"""
        # do we run this to make sure it doesn't error?
        super_summary: dict = super().summary()
        assert super_summary
        self_ser: dict = self.serialize()
        return dict(
            name=self.name,
            fname=self.to_fname(),
            sdc_hash=self.stable_hash_cfg(),
            seed=self.seed,
            seq_len_min=self.seq_len_min,
            seq_len_max=self.seq_len_max,
            applied_filters=self.applied_filters,
            grid_n=self_ser["grid_n"],
            n_mazes=self_ser["n_mazes"],
            maze_ctor_name=self_ser["maze_ctor"]["__name__"],
            maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
            endpoint_kwargs=self_ser["endpoint_kwargs"],
        )
return a summary of the config
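A quick way to inspect a config; the keys follow the dict built above, and the printed values are whatever your config holds:

```python
from pprint import pprint

pprint(cfg.summary())  # human-readable overview of the config
print(cfg.to_fname())  # filename stem used when saving the dataset locally
```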
364 def success_fraction_estimate( 365 self, 366 except_if_all_success_expected: bool = False, 367 ) -> float: 368 """Estimate the success fraction of this config. 369 370 only valid when the generator is a percolation generator, 371 and endpoints are enforced to be dead ends 372 373 this estimate comes from `estimate_dataset_fractions.ipynb` and `maze_dataset.benchmarks.sweep_fit` 374 375 # Parameters: 376 - `except_if_all_success_expected : bool` 377 if `True`, don't raise an error if the success fraction is below the threshold. 378 will always return `1.0` if the config is not expected to fail 379 380 # Returns: 381 - `float` 382 estimated success fraction 383 384 # Raises: 385 - `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `False` 386 """ 387 try: 388 return cfg_success_predict_fn(self) 389 390 except NoPercolationInConfigError as e: 391 if except_if_all_success_expected: 392 raise e # noqa: TRY201 393 return 1.0
Estimate the success fraction of this config.

only valid when the generator is a percolation generator, and endpoints are enforced to be dead ends

this estimate comes from `estimate_dataset_fractions.ipynb` and `maze_dataset.benchmarks.sweep_fit`

Parameters:
- `except_if_all_success_expected : bool`
  if `True`, don't raise an error if the success fraction is below the threshold.
  will always return `1.0` if the config is not expected to fail

Returns:
- `float`
  estimated success fraction

Raises:
- `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `False`
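A sketch of a config this applies to (percolation generator, dead-end endpoints, `except_on_no_valid_endpoint=False`); the generator name `gen_dfs_percolation` and the specific numbers are assumptions for illustration:

```python
from maze_dataset import MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators

perc_cfg = MazeDatasetConfig(
    name="perc-demo",
    grid_n=8,
    n_mazes=1000,
    maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,
    maze_ctor_kwargs={"p": 0.4},
    endpoint_kwargs=dict(
        deadend_start=True,
        deadend_end=True,
        endpoints_not_equal=True,
        except_on_no_valid_endpoint=False,  # required for the estimate
    ),
)
# fraction of mazes expected to generate successfully
print(perc_cfg.success_fraction_estimate())
```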
395 def success_fraction_compensate( 396 self, 397 safety_margin: float = 1.2, 398 except_if_all_success_expected: bool = False, 399 epsilon: float = 1e-2, 400 ) -> "MazeDatasetConfig": 401 """return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction 402 403 # Parameters: 404 - `safety_margin : float` 405 safety margin to apply to the success fraction estimate 406 (defaults to `1.2`, or 20% more mazes than estimated) 407 - `except_if_all_success_expected : bool` 408 if `True`, don't raise an error if the success fraction is below the threshold. 409 this is passed to `MazeDatasetConfig.success_fraction_estimate`. 410 if your config isn't expected to fail, passing this might mean you generate more mazes than needed 411 since `safety_margin` is still applied. 412 (defaults to `False`) 413 - `epsilon : float` 414 raise `SuccessChanceTooSmallError` if the success fraction is below this threshold 415 (defaults to `1e-2`) 416 417 # Returns: 418 - `MazeDatasetConfig` 419 new config with adjusted `n_mazes` 420 421 # Raises: 422 - `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon` 423 """ 424 # compute and check the success fraction 425 success_fraction: float = self.success_fraction_estimate( 426 except_if_all_success_expected=except_if_all_success_expected, 427 ) 428 if success_fraction < epsilon: 429 err_msg: str = ( 430 f"{success_fraction = } is below the threshold of {epsilon = }" 431 ) 432 raise SuccessChanceTooSmallError( 433 err_msg, 434 ) 435 436 # compute the new number of mazes 437 n_mazes: int = self.n_mazes 438 new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1 439 440 # put it in a new config and return 441 cfg_dict: dict = self.serialize() 442 cfg_dict["n_mazes"] = new_n_mazes 443 return MazeDatasetConfig.load(cfg_dict)
return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction

Parameters:
- `safety_margin : float`
  safety margin to apply to the success fraction estimate
  (defaults to `1.2`, or 20% more mazes than estimated)
- `except_if_all_success_expected : bool`
  if `True`, don't raise an error if the success fraction is below the threshold.
  this is passed to `MazeDatasetConfig.success_fraction_estimate`.
  if your config isn't expected to fail, passing this might mean you generate more mazes than needed
  since `safety_margin` is still applied.
  (defaults to `False`)
- `epsilon : float`
  raise `SuccessChanceTooSmallError` if the success fraction is below this threshold
  (defaults to `1e-2`)

Returns:
- `MazeDatasetConfig`
  new config with adjusted `n_mazes`

Raises:
- `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
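Typical usage is to compensate before generation so that roughly the original `n_mazes` mazes survive; a sketch continuing the percolation config above:

```python
# bump n_mazes so that, after failed generations are dropped,
# roughly the original n_mazes remain (with a 20% safety margin)
compensated_cfg = perc_cfg.success_fraction_compensate(safety_margin=1.2)
print(perc_cfg.n_mazes, "->", compensated_cfg.n_mazes)

dataset = MazeDataset.from_config(compensated_cfg)
```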
777 @classmethod # type: ignore[misc] 778 def load(cls, data: dict[str, Any] | T) -> Type[T]: 779 # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ 780 if isinstance(data, cls): 781 return data 782 783 assert isinstance( 784 data, typing.Mapping 785 ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }" 786 787 cls_type_hints: dict[str, Any] = get_cls_type_hints(cls) 788 789 # initialize dict for keeping what we will pass to the constructor 790 ctor_kwargs: dict[str, Any] = dict() 791 792 # iterate over the fields of the class 793 for field in dataclasses.fields(cls): 794 # check if the field is a SerializableField 795 assert isinstance( 796 field, SerializableField 797 ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new" 798 799 # check if the field is in the data and if it should be initialized 800 if (field.name in data) and field.init: 801 # get the value, we will be processing it 802 value: Any = data[field.name] 803 804 # get the type hint for the field 805 field_type_hint: Any = cls_type_hints.get(field.name, None) 806 807 # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set 808 if field.deserialize_fn: 809 # if it has a deserialization function, use that 810 value = field.deserialize_fn(value) 811 elif field.loading_fn: 812 # if it has a loading function, use that 813 value = field.loading_fn(data) 814 elif ( 815 field_type_hint is not None 816 and hasattr(field_type_hint, "load") 817 and callable(field_type_hint.load) 818 ): 819 # if no loading function but has a type hint with a load method, use that 820 if isinstance(value, dict): 821 value = field_type_hint.load(value) 822 else: 823 raise FieldLoadingError( 824 f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }" 825 ) 826 else: 827 # assume no loading needs to happen, keep `value` as-is 828 pass 829 830 # store the value in the constructor kwargs 831 ctor_kwargs[field.name] = value 832 833 # create a new instance of the class with the constructor kwargs 834 output: cls = cls(**ctor_kwargs) 835 836 # validate the types of the fields if needed 837 if on_typecheck_mismatch != ErrorMode.IGNORE: 838 fields_valid: dict[str, bool] = ( 839 SerializableDataclass__validate_fields_types__dict( 840 output, 841 on_typecheck_error=on_typecheck_error, 842 ) 843 ) 844 845 # if there are any fields that are not valid, raise an error 846 if not all(fields_valid.values()): 847 msg: str = ( 848 f"Type mismatch in fields of {cls.__name__}:\n" 849 + "\n".join( 850 [ 851 f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }" 852 for k, v in fields_valid.items() 853 if not v 854 ] 855 ) 856 ) 857 858 on_typecheck_mismatch.process( 859 msg, except_cls=FieldTypeMismatchError 860 ) 861 862 # return the new instance 863 return output
takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
Inherited Members
- maze_dataset.dataset.maze_dataset_config._MazeDatasetConfig_base
- grid_n
- n_mazes
- maze_ctor
- maze_ctor_kwargs
- endpoint_kwargs
- grid_shape
- grid_shape_np
- max_grid_n
- stable_hash_cfg
- to_fname
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
84class MazeDatasetCollection(GPTDataset): 85 """a collection of maze datasets""" 86 87 def __init__( 88 self, 89 cfg: MazeDatasetCollectionConfig, 90 maze_datasets: list[MazeDataset], 91 generation_metadata_collected: dict | None = None, 92 ) -> None: 93 "initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s" 94 super().__init__() 95 self.cfg: MazeDatasetCollectionConfig = cfg 96 self.maze_datasets: list[MazeDataset] = list(maze_datasets) 97 for c, ds in zip( 98 self.cfg.maze_dataset_configs, 99 self.maze_datasets, 100 strict=False, 101 ): 102 assert c.name == ds.cfg.name 103 assert c == ds.cfg 104 105 self.generation_metadata_collected: dict | None = generation_metadata_collected 106 107 @property 108 def dataset_lengths(self) -> list[int]: 109 """return the lengths of each dataset in the collection""" 110 return [len(dataset) for dataset in self.maze_datasets] 111 112 @property 113 def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]: 114 """return the cumulative lengths of each dataset in the collection""" 115 return np.array(list(itertools.accumulate(self.dataset_lengths))) 116 117 @cached_property 118 def mazes(self) -> list[LatticeMaze]: 119 "single list of all mazes in the collection" 120 return list( 121 itertools.chain.from_iterable( 122 dataset.mazes for dataset in self.maze_datasets 123 ), 124 ) 125 126 def __len__(self) -> int: 127 """return the total number of mazes in the collection""" 128 return sum(len(dataset) for dataset in self.maze_datasets) 129 130 def __getitem__(self, index: int) -> LatticeMaze: 131 "get a maze by index" 132 # find which dataset the index belongs to 133 # we add 1, since np.searchsorted returns the 134 # index of the last element that is strictly less than the target 135 # while we want the index of the last element less than or equal to the target 136 dataset_idx: int = int(np.searchsorted(self.dataset_cum_lengths, index + 1)) 137 index_adjusted: int = index 138 if dataset_idx > 0: 139 # if the index is 0, `dataset_idx - 1` will be -1. 140 # We just want to use the base index 141 index_adjusted -= self.dataset_cum_lengths[dataset_idx - 1] 142 return self.maze_datasets[dataset_idx][index_adjusted] 143 144 @classmethod 145 def generate( 146 cls, 147 cfg: MazeDatasetCollectionConfig, 148 **kwargs, 149 ) -> "MazeDatasetCollection": 150 """generate a dataset collection from a config""" 151 datasets = [ 152 MazeDataset.generate(config, **kwargs) 153 for config in cfg.maze_dataset_configs 154 ] 155 return cls(cfg, datasets) 156 157 @classmethod 158 def download( 159 cls, 160 cfg: MazeDatasetCollectionConfig, 161 **kwargs, 162 ) -> "MazeDatasetCollection": 163 "(not implemented!) 
download a dataset collection from a config" 164 datasets = [ 165 MazeDataset.download(config, **kwargs) 166 for config in cfg.maze_dataset_configs 167 ] 168 return cls(cfg, datasets) 169 170 def serialize(self) -> JSONdict: 171 """serialize the dataset collection""" 172 return { 173 _FORMAT_KEY: "MazeDatasetCollection", 174 "cfg": self.cfg.serialize(), 175 "maze_datasets": [dataset.serialize() for dataset in self.maze_datasets], 176 "generation_metadata_collected": json_serialize( 177 self.generation_metadata_collected, 178 ), 179 } 180 181 @classmethod 182 def load(cls, data: JSONdict) -> "MazeDatasetCollection": 183 """load the dataset collection from the representation created by `serialize`""" 184 assert data[_FORMAT_KEY] == "MazeDatasetCollection" 185 return cls( 186 **{ 187 key: load_item_recursive(data[key], tuple()) 188 for key in ["cfg", "maze_datasets", "generation_metadata_collected"] 189 }, 190 ) 191 192 # TODO: remove duplication with MazeDatasetConfig().as_tokens() somehow? 193 def as_tokens( 194 self, 195 # TODO: MazeTokenizer 196 maze_tokenizer, # noqa: ANN001 197 limit: int | None = None, 198 join_tokens_individual_maze: bool = False, 199 ) -> list[list[str]] | list[str]: 200 """return the dataset as tokens 201 202 if join_tokens_individual_maze is True, then the tokens of each maze are 203 joined with a space, and the result is a list of strings. 204 i.e.: 205 >>> dataset.as_tokens(join_tokens_individual_maze=False) 206 [["a", "b", "c"], ["d", "e", "f"]] 207 >>> dataset.as_tokens(join_tokens_individual_maze=True) 208 ["a b c", "d e f"] 209 """ 210 output: list[list[str]] = [ 211 maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit] 212 ] 213 if join_tokens_individual_maze: 214 return [" ".join(tokens) for tokens in output] 215 else: 216 return output 217 218 def update_self_config(self) -> None: 219 "update the config to match the number of mazes, and update the underlying configs of each dataset" 220 # TODO: why cant we set this directly? its not frozen, and it seems to work in a regular MazeDataset 221 self.cfg.__dict__["n_mazes"] = len(self) 222 for dataset in self.maze_datasets: 223 dataset.update_self_config() 224 225 self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets]
a collection of maze datasets
    def __init__(
        self,
        cfg: MazeDatasetCollectionConfig,
        maze_datasets: list[MazeDataset],
        generation_metadata_collected: dict | None = None,
    ) -> None:
        "initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s"
        super().__init__()
        self.cfg: MazeDatasetCollectionConfig = cfg
        self.maze_datasets: list[MazeDataset] = list(maze_datasets)
        for c, ds in zip(
            self.cfg.maze_dataset_configs,
            self.maze_datasets,
            strict=False,
        ):
            assert c.name == ds.cfg.name
            assert c == ds.cfg

        self.generation_metadata_collected: dict | None = generation_metadata_collected
initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s
    @property
    def dataset_lengths(self) -> list[int]:
        """return the lengths of each dataset in the collection"""
        return [len(dataset) for dataset in self.maze_datasets]
return the lengths of each dataset in the collection
    @property
    def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
        """return the cumulative lengths of each dataset in the collection"""
        return np.array(list(itertools.accumulate(self.dataset_lengths)))
return the cumulative lengths of each dataset in the collection
    @cached_property
    def mazes(self) -> list[LatticeMaze]:
        "single list of all mazes in the collection"
        return list(
            itertools.chain.from_iterable(
                dataset.mazes for dataset in self.maze_datasets
            ),
        )
single list of all mazes in the collection
    @classmethod
    def generate(
        cls,
        cfg: MazeDatasetCollectionConfig,
        **kwargs,
    ) -> "MazeDatasetCollection":
        """generate a dataset collection from a config"""
        datasets = [
            MazeDataset.generate(config, **kwargs)
            for config in cfg.maze_dataset_configs
        ]
        return cls(cfg, datasets)
generate a dataset collection from a config
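A sketch of building a collection from two configs that differ only in grid size; the names and sizes are illustrative, and the assumption is that `MazeDatasetCollectionConfig` only needs a `name` plus the sub-configs here:

```python
from maze_dataset import MazeDatasetConfig
from maze_dataset.dataset.collected_dataset import (
    MazeDatasetCollection,
    MazeDatasetCollectionConfig,
)
from maze_dataset.generation import LatticeMazeGenerators

configs = [
    MazeDatasetConfig(
        name=f"demo-g{g}",
        grid_n=g,
        n_mazes=10,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
    )
    for g in (4, 6)
]
collection_cfg = MazeDatasetCollectionConfig(
    name="demo-collection",
    maze_dataset_configs=configs,
)
collection = MazeDatasetCollection.generate(collection_cfg)
print(len(collection), collection.dataset_lengths)
```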
    @classmethod
    def download(
        cls,
        cfg: MazeDatasetCollectionConfig,
        **kwargs,
    ) -> "MazeDatasetCollection":
        "(not implemented!) download a dataset collection from a config"
        datasets = [
            MazeDataset.download(config, **kwargs)
            for config in cfg.maze_dataset_configs
        ]
        return cls(cfg, datasets)
(not implemented!) download a dataset collection from a config
    def serialize(self) -> JSONdict:
        """serialize the dataset collection"""
        return {
            _FORMAT_KEY: "MazeDatasetCollection",
            "cfg": self.cfg.serialize(),
            "maze_datasets": [dataset.serialize() for dataset in self.maze_datasets],
            "generation_metadata_collected": json_serialize(
                self.generation_metadata_collected,
            ),
        }
serialize the dataset collection
    @classmethod
    def load(cls, data: JSONdict) -> "MazeDatasetCollection":
        """load the dataset collection from the representation created by `serialize`"""
        assert data[_FORMAT_KEY] == "MazeDatasetCollection"
        return cls(
            **{
                key: load_item_recursive(data[key], tuple())
                for key in ["cfg", "maze_datasets", "generation_metadata_collected"]
            },
        )
load the dataset collection from the representation created by serialize
193 def as_tokens( 194 self, 195 # TODO: MazeTokenizer 196 maze_tokenizer, # noqa: ANN001 197 limit: int | None = None, 198 join_tokens_individual_maze: bool = False, 199 ) -> list[list[str]] | list[str]: 200 """return the dataset as tokens 201 202 if join_tokens_individual_maze is True, then the tokens of each maze are 203 joined with a space, and the result is a list of strings. 204 i.e.: 205 >>> dataset.as_tokens(join_tokens_individual_maze=False) 206 [["a", "b", "c"], ["d", "e", "f"]] 207 >>> dataset.as_tokens(join_tokens_individual_maze=True) 208 ["a b c", "d e f"] 209 """ 210 output: list[list[str]] = [ 211 maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit] 212 ] 213 if join_tokens_individual_maze: 214 return [" ".join(tokens) for tokens in output] 215 else: 216 return output
return the dataset as tokens
if join_tokens_individual_maze is True, then the tokens of each maze are joined with a space, and the result is a list of strings. i.e.:
>>> dataset.as_tokens(join_tokens_individual_maze=False)
[["a", "b", "c"], ["d", "e", "f"]]
>>> dataset.as_tokens(join_tokens_individual_maze=True)
["a b c", "d e f"]
    def update_self_config(self) -> None:
        "update the config to match the number of mazes, and update the underlying configs of each dataset"
        # TODO: why cant we set this directly? its not frozen, and it seems to work in a regular MazeDataset
        self.cfg.__dict__["n_mazes"] = len(self)
        for dataset in self.maze_datasets:
            dataset.update_self_config()

        self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets]
update the config to match the number of mazes, and update the underlying configs of each dataset
31@serializable_dataclass(kw_only=True) 32class MazeDatasetCollectionConfig(GPTDatasetConfig): 33 """maze dataset collection configuration, including tokenizers and shuffle""" 34 35 # Attributes without a default cannot follow attributes with one [misc] 36 maze_dataset_configs: list[MazeDatasetConfig] = serializable_field( # type: ignore[misc] 37 serialization_fn=lambda configs: [config.serialize() for config in configs], 38 loading_fn=lambda data: [ 39 MazeDatasetConfig.load(config) for config in data["maze_dataset_configs"] 40 ], 41 ) 42 43 def summary(self) -> dict: 44 """return a summary of the config""" 45 return dict( 46 n_mazes=self.n_mazes, 47 max_grid_n=self.max_grid_n, 48 max_grid_shape=self.max_grid_shape, 49 fname=self.to_fname(), 50 cfg_summaries=[c.summary() for c in self.maze_dataset_configs], 51 ) 52 53 @property 54 def n_mazes(self) -> int: 55 """return the total number of mazes in the collection across all dataset""" 56 return sum(config.n_mazes for config in self.maze_dataset_configs) 57 58 @property 59 def max_grid_n(self) -> int: 60 """return the maximum grid size of the mazes in the collection""" 61 return max(config.grid_n for config in self.maze_dataset_configs) 62 63 @property 64 def max_grid_shape(self) -> CoordTup: 65 """return the maximum grid shape of the mazes in the collection""" 66 return (self.max_grid_n, self.max_grid_n) 67 68 @property 69 def max_grid_shape_np(self) -> Coord: 70 """return the maximum grid shape of the mazes in the collection as a numpy array""" 71 return np.array(self.max_grid_shape, dtype=np.int32) 72 73 def stable_hash_cfg(self) -> int: 74 """return a stable hash of the config""" 75 return stable_hash(json.dumps(self.serialize())) 76 77 def to_fname(self) -> str: 78 """convert config to a filename""" 79 return sanitize_fname( 80 f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}", 81 )
maze dataset collection configuration, including tokenizers and shuffle
    def summary(self) -> dict:
        """return a summary of the config"""
        return dict(
            n_mazes=self.n_mazes,
            max_grid_n=self.max_grid_n,
            max_grid_shape=self.max_grid_shape,
            fname=self.to_fname(),
            cfg_summaries=[c.summary() for c in self.maze_dataset_configs],
        )
return a summary of the config
    @property
    def n_mazes(self) -> int:
        """return the total number of mazes in the collection across all dataset"""
        return sum(config.n_mazes for config in self.maze_dataset_configs)
return the total number of mazes in the collection across all dataset
    @property
    def max_grid_n(self) -> int:
        """return the maximum grid size of the mazes in the collection"""
        return max(config.grid_n for config in self.maze_dataset_configs)
return the maximum grid size of the mazes in the collection
    @property
    def max_grid_shape(self) -> CoordTup:
        """return the maximum grid shape of the mazes in the collection"""
        return (self.max_grid_n, self.max_grid_n)
return the maximum grid shape of the mazes in the collection
    @property
    def max_grid_shape_np(self) -> Coord:
        """return the maximum grid shape of the mazes in the collection as a numpy array"""
        return np.array(self.max_grid_shape, dtype=np.int32)
return the maximum grid shape of the mazes in the collection as a numpy array
    def stable_hash_cfg(self) -> int:
        """return a stable hash of the config"""
        return stable_hash(json.dumps(self.serialize()))
return a stable hash of the config
    def to_fname(self) -> str:
        """convert config to a filename"""
        return sanitize_fname(
            f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}",
        )
convert config to a filename
714 def serialize(self) -> dict[str, Any]: 715 result: dict[str, Any] = { 716 _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)" 717 } 718 # for each field in the class 719 for field in dataclasses.fields(self): # type: ignore[arg-type] 720 # need it to be our special SerializableField 721 if not isinstance(field, SerializableField): 722 raise NotSerializableFieldException( 723 f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, " 724 f"but a {type(field)} " 725 "this state should be inaccessible, please report this bug!" 726 ) 727 728 # try to save it 729 if field.serialize: 730 try: 731 # get the val 732 value = getattr(self, field.name) 733 # if it is a serializable dataclass, serialize it 734 if isinstance(value, SerializableDataclass): 735 value = value.serialize() 736 # if the value has a serialization function, use that 737 if hasattr(value, "serialize") and callable(value.serialize): 738 value = value.serialize() 739 # if the field has a serialization function, use that 740 # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies! 741 elif field.serialization_fn: 742 value = field.serialization_fn(value) 743 744 # store the value in the result 745 result[field.name] = value 746 except Exception as e: 747 raise FieldSerializationError( 748 "\n".join( 749 [ 750 f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}", 751 f"{field = }", 752 f"{value = }", 753 f"{self = }", 754 ] 755 ) 756 ) from e 757 758 # store each property if we can get it 759 for prop in self._properties_to_serialize: 760 if hasattr(cls, prop): 761 value = getattr(self, prop) 762 result[prop] = value 763 else: 764 raise AttributeError( 765 f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}" 766 + f"but it is in {self._properties_to_serialize = }" 767 + f"\n{self = }" 768 ) 769 770 return result
returns the class as a dict, implemented by using the `@serializable_dataclass` decorator
777 @classmethod # type: ignore[misc] 778 def load(cls, data: dict[str, Any] | T) -> Type[T]: 779 # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ 780 if isinstance(data, cls): 781 return data 782 783 assert isinstance( 784 data, typing.Mapping 785 ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }" 786 787 cls_type_hints: dict[str, Any] = get_cls_type_hints(cls) 788 789 # initialize dict for keeping what we will pass to the constructor 790 ctor_kwargs: dict[str, Any] = dict() 791 792 # iterate over the fields of the class 793 for field in dataclasses.fields(cls): 794 # check if the field is a SerializableField 795 assert isinstance( 796 field, SerializableField 797 ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new" 798 799 # check if the field is in the data and if it should be initialized 800 if (field.name in data) and field.init: 801 # get the value, we will be processing it 802 value: Any = data[field.name] 803 804 # get the type hint for the field 805 field_type_hint: Any = cls_type_hints.get(field.name, None) 806 807 # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set 808 if field.deserialize_fn: 809 # if it has a deserialization function, use that 810 value = field.deserialize_fn(value) 811 elif field.loading_fn: 812 # if it has a loading function, use that 813 value = field.loading_fn(data) 814 elif ( 815 field_type_hint is not None 816 and hasattr(field_type_hint, "load") 817 and callable(field_type_hint.load) 818 ): 819 # if no loading function but has a type hint with a load method, use that 820 if isinstance(value, dict): 821 value = field_type_hint.load(value) 822 else: 823 raise FieldLoadingError( 824 f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }" 825 ) 826 else: 827 # assume no loading needs to happen, keep `value` as-is 828 pass 829 830 # store the value in the constructor kwargs 831 ctor_kwargs[field.name] = value 832 833 # create a new instance of the class with the constructor kwargs 834 output: cls = cls(**ctor_kwargs) 835 836 # validate the types of the fields if needed 837 if on_typecheck_mismatch != ErrorMode.IGNORE: 838 fields_valid: dict[str, bool] = ( 839 SerializableDataclass__validate_fields_types__dict( 840 output, 841 on_typecheck_error=on_typecheck_error, 842 ) 843 ) 844 845 # if there are any fields that are not valid, raise an error 846 if not all(fields_valid.values()): 847 msg: str = ( 848 f"Type mismatch in fields of {cls.__name__}:\n" 849 + "\n".join( 850 [ 851 f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }" 852 for k, v in fields_valid.items() 853 if not v 854 ] 855 ) 856 ) 857 858 on_typecheck_mismatch.process( 859 msg, except_cls=FieldTypeMismatchError 860 ) 861 862 # return the new instance 863 return output
takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict