maze_dataset.dataset.maze_dataset_config
implements `MazeDatasetConfig`, which is used to generate or load a dataset
1"implements `MazeDatasetConfig` which is used to generate or load a dataset" 2 3import hashlib 4import importlib.metadata 5import json 6import typing 7import warnings 8from typing import Callable 9 10import numpy as np 11from jaxtyping import Float 12from muutils.json_serialize import ( 13 serializable_dataclass, 14 serializable_field, 15) 16from muutils.json_serialize.util import ( 17 safe_getsource, 18 string_as_lines, 19) 20from muutils.misc import sanitize_fname, shorten_numerical_to_str 21 22from maze_dataset.constants import Coord, CoordTup 23from maze_dataset.dataset.dataset import ( 24 GPTDatasetConfig, 25) 26from maze_dataset.dataset.success_predict_math import cfg_success_predict_fn 27from maze_dataset.generation.generators import _GENERATORS_PERCOLATED, GENERATORS_MAP 28 29SERIALIZE_MINIMAL_THRESHOLD: int | None = 100 30"""If `n_mazes>=SERIALIZE_MINIMAL_THRESHOLD`, then the MazeDataset will use `serialize_minimal`. 31Setting to None means that `serialize_minimal` will never be used. 32Set to -1 to make calls to `read` use `MazeDataset._load_legacy`. Used for profiling only.""" 33 34MAZEDATASETCONFIG_FNAME_HASH_LENGTH: int = 5 35"length of the has, in characters, of the hash in the fname of a `MazeDatasetConfig`" 36 37_PercolationSuccessArray = Float[ 38 np.ndarray, 39 "p/grid_n/deadends/endpoints_not_equal/generator_func=5", 40] 41 42 43class NoPercolationInConfigError(ValueError): 44 """raised when trying to predict the success fraction of a config that doesn't have percolation""" 45 46 pass 47 48 49class SuccessChanceTooSmallError(ValueError): 50 """raised when the success fraction is below the threshold in `MazeDatasetConfig.success_fraction_compensate`""" 51 52 pass 53 54 55def set_serialize_minimal_threshold(threshold: int | None) -> None: 56 "get the global SERIALIZE_MINIMAL_THRESHOLD" 57 global SERIALIZE_MINIMAL_THRESHOLD # noqa: PLW0603 58 SERIALIZE_MINIMAL_THRESHOLD = threshold 59 60 61def _load_maze_ctor(maze_ctor_serialized: str | dict) -> Callable: 62 "get the maze constructor from `GENERATORS_MAP`" 63 if isinstance(maze_ctor_serialized, dict): 64 # this is both the new and old version of the serialization 65 return GENERATORS_MAP[maze_ctor_serialized["__name__"]] 66 elif isinstance(maze_ctor_serialized, str): 67 # this is a version I switched to for a while but now we are switching back 68 warnings.warn( 69 "you are loading an old model/config in `_load_maze_ctor()`!!! 
this should not be happening, please report: " 70 "https://github.com/understanding-search/maze-dataset/issues/new", 71 ) 72 return GENERATORS_MAP[maze_ctor_serialized] 73 else: 74 err_msg: str = f"maze_ctor_serialized is of type {type(maze_ctor_serialized) = }, expected str or dict\n{maze_ctor_serialized = }" 75 raise TypeError(err_msg) 76 77 78EndpointKwargsType = dict[ 79 typing.Literal[ 80 "allowed_start", 81 "allowed_end", 82 "deadend_start", 83 "deadend_end", 84 "endpoints_not_equal", 85 "except_on_no_valid_endpoint", 86 ], 87 bool | None | list[tuple[int, int]], 88] 89"type hint for `MazeDatasetConfig.endpoint_kwargs`" 90 91 92def _load_endpoint_kwargs(data: dict) -> EndpointKwargsType: 93 if data.get("endpoint_kwargs") is None: 94 return dict() 95 96 else: 97 return { 98 k: ( 99 # bools and Nones are fine 100 v 101 if (isinstance(v, bool) or v is None) 102 # assume its a CoordList 103 else [tuple(x) for x in v] # muutils/zanj saves tuples as lists 104 ) 105 for k, v in data["endpoint_kwargs"].items() 106 } 107 108 109@serializable_dataclass(kw_only=True, properties_to_serialize=["grid_shape"]) 110class _MazeDatasetConfig_base(GPTDatasetConfig): # noqa: N801 111 """base config -- we serialize, dump to json, and hash this to get the fname. all actual variables we want to be hashed are here""" 112 113 # NOTE: type: ignore[misc] is because it tells us non-default attributes aren't allowed after ones with defaults, but everything is kw_only 114 115 grid_n: int = serializable_field() # type: ignore[misc] 116 117 # not comparing n_mazes is done primarily to avoid conflicts which happen during `from_config` when we have applied filters 118 n_mazes: int = serializable_field(compare=False) # type: ignore[misc] 119 120 maze_ctor: Callable = serializable_field( 121 default=GENERATORS_MAP["gen_dfs"], 122 serialization_fn=lambda gen_func: { 123 "__name__": gen_func.__name__, 124 "__module__": gen_func.__module__, 125 # NOTE: this was causing hashing issues on 3.13 vs older versions because somehow, 126 # the `__doc__` variable is different across versions??????? WHY???????? IT TREATS WHITESPACE DIFFERENTLY 127 # so we just uh. strip it all now. 128 # see: 129 # https://github.com/understanding-search/maze-dataset/actions/runs/14028046497/job/39270080746?pr=53 130 # https://github.com/understanding-search/maze-dataset/actions/runs/14028046497/job/39270080742?pr=53 131 # https://www.diffchecker.com/tqIMSevy/ 132 # update: we also need to filter for empty lines. B) 133 "__doc__": [ 134 line.strip() 135 for line in string_as_lines(gen_func.__doc__) 136 if line.strip() 137 ], 138 "source_code": safe_getsource(gen_func), 139 }, 140 loading_fn=lambda data: _load_maze_ctor(data["maze_ctor"]), 141 assert_type=False, # TODO: check the type here once muutils supports checking Callable signatures 142 ) 143 144 maze_ctor_kwargs: dict = serializable_field( 145 default_factory=dict, 146 serialization_fn=lambda kwargs: kwargs, 147 loading_fn=lambda data: ( 148 dict() 149 if data.get("maze_ctor_kwargs", None) 150 is None # this should handle the backwards compatibility 151 else data["maze_ctor_kwargs"] 152 ), 153 ) 154 155 endpoint_kwargs: EndpointKwargsType = serializable_field( 156 default_factory=dict, 157 serialization_fn=lambda kwargs: kwargs, 158 loading_fn=_load_endpoint_kwargs, 159 assert_type=False, 160 ) 161 162 # NOTE: this part is very hacky. 
the way muutils works is that it iterates over the *keys in the serialized data*, 163 # and so we need to save an `None` here or this wont load the `fname` field on load 164 # this is a total mess, and very confusing, and entirely my fault 165 _fname_loaded: str | None = serializable_field( 166 default=None, 167 compare=False, 168 serialization_fn=lambda _: None, 169 loading_fn=lambda data: data.get("fname", None), 170 ) 171 172 @property 173 def grid_shape(self) -> CoordTup: 174 """return the shape of the grid as a tuple""" 175 return (self.grid_n, self.grid_n) 176 177 @property 178 def grid_shape_np(self) -> Coord: 179 """return the shape of the grid as a numpy array""" 180 return np.array(self.grid_shape) 181 182 @property 183 def max_grid_n(self) -> int: 184 """return the maximum of the grid shape""" 185 return max(self.grid_shape) 186 187 def _serialize_base( 188 self, applied_filters__skip__collect_generation_meta: bool = True 189 ) -> dict: 190 """serialize the base config for user in `stable_hash_cfg()` and `to_fname()` 191 192 - note that the _fname_loaded will always be `None` to avoid infinite recursion 193 - note that we **do not** by default include information about metadata collection here, 194 since otherwise loading a dataset that we minified by collecting the metadata would be impossible 195 but for comparing things, we do store it when serializing properly by setting 196 `applied_filters__skip__collect_generation_meta=False` 197 """ 198 serialized: dict = _MazeDatasetConfig_base.serialize(self) 199 if applied_filters__skip__collect_generation_meta: 200 serialized["applied_filters"] = [ 201 x 202 for x in serialized["applied_filters"] 203 if x.get("name", None) != "collect_generation_meta" 204 ] 205 return serialized 206 207 def _stable_str_dump(self) -> str: 208 return json.dumps( 209 self._serialize_base(), 210 sort_keys=True, 211 indent=None, 212 ) 213 214 def stable_hash_cfg(self) -> int: 215 """return a stable hash of the config""" 216 return int.from_bytes( 217 hashlib.md5( # noqa: S324 218 bytes(self._stable_str_dump(), "ascii") 219 ).digest(), 220 "big", 221 ) 222 223 def to_fname(self) -> str: 224 """return a unique identifier (valid as a filename) for this config""" 225 n_mazes_str: str = shorten_numerical_to_str(self.n_mazes) 226 maze_ctor_name: str = self.maze_ctor.__name__.removeprefix("gen_") 227 hash_id: int = self.stable_hash_cfg() % 10**MAZEDATASETCONFIG_FNAME_HASH_LENGTH 228 return sanitize_fname( 229 f"{self.name}-g{self.grid_n}-n{n_mazes_str}-a_{maze_ctor_name}-h{hash_id}", 230 ) 231 232 233# NOTE: type: ignore[misc] is because it tells us non-default attributes aren't allowed after ones with defaults, but everything is kw_only 234@serializable_dataclass(kw_only=True, methods_no_override=["serialize"]) 235class MazeDatasetConfig(_MazeDatasetConfig_base): # type: ignore[misc] 236 """config object which is passed to `MazeDataset.from_config` to generate or load a dataset""" 237 238 @property 239 def config_version(self) -> str: 240 """return the version of the config. 
added in maze_dataset v1.3.0, previous versions had no dataset config""" 241 return "1.0" 242 243 @property 244 def versions(self) -> dict: 245 """return the versions of the config and the maze_dataset""" 246 return dict( 247 config=self.config_version, 248 maze_dataset=importlib.metadata.version("maze_dataset"), 249 ) 250 251 def serialize(self) -> dict: 252 "serialize the MazeDatasetConfig with all fields and fname" 253 return { 254 **self._serialize_base( 255 applied_filters__skip__collect_generation_meta=False 256 ), 257 "fname": self.to_fname(), 258 "versions": self.versions, 259 } 260 261 def summary(self) -> dict: 262 """return a summary of the config""" 263 # do we run this to make sure it doesn't error? 264 super_summary: dict = super().summary() 265 assert super_summary 266 self_ser: dict = self.serialize() 267 return dict( 268 name=self.name, 269 fname=self.to_fname(), 270 sdc_hash=self.stable_hash_cfg(), 271 seed=self.seed, 272 seq_len_min=self.seq_len_min, 273 seq_len_max=self.seq_len_max, 274 applied_filters=self.applied_filters, 275 grid_n=self_ser["grid_n"], 276 n_mazes=self_ser["n_mazes"], 277 maze_ctor_name=self_ser["maze_ctor"]["__name__"], 278 maze_ctor_kwargs=self_ser["maze_ctor_kwargs"], 279 endpoint_kwargs=self_ser["endpoint_kwargs"], 280 ) 281 282 def _to_ps_array(self) -> _PercolationSuccessArray: 283 """Convert this config to a [p, grid_n, deadends, endpoints_not_equal, generator_func] vector. 284 285 used in predicting the success rate 286 """ 287 try: 288 assert self.maze_ctor.__name__ in _GENERATORS_PERCOLATED, ( 289 f"generator not supported, must be a percolation generator\n{self.maze_ctor.__name__ = }, {_GENERATORS_PERCOLATED = }" 290 ) 291 assert "p" in self.maze_ctor_kwargs, ( 292 f"maze_ctor_kwargs must have a 'p' (percolation value) key: {self.maze_ctor_kwargs = }" 293 ) 294 assert not self.endpoint_kwargs.get("except_on_no_valid_endpoint", True), ( 295 f"except_on_no_valid_endpoint must be False, or else if any maze fails to generate, the whole dataset will fail: {self.endpoint_kwargs = }" 296 ) 297 except AssertionError as e: 298 err_msg: str = f"invalid config for percolation success prediction: {self.summary() = }" 299 raise NoPercolationInConfigError( 300 err_msg, 301 ) from e 302 303 endpoints_unique_flag: int = int( 304 # we are pretty sure it will be an int or bool here 305 self.endpoint_kwargs.get("endpoints_not_equal", True), # type: ignore[arg-type] 306 ) 307 308 # adjustment for bknutson0 309 if not ( 310 self.endpoint_kwargs.get("deadend_start", False) 311 and self.endpoint_kwargs.get("deadend_end", False) 312 ): 313 # we didnt train on this, but if either endpoint is not required to be in a dead end 314 # then requiring the endpoints to be unique does not really affect the success rate 315 # (except for very small percolation values, pure percolation generation) 316 endpoints_unique_flag = 0 317 318 return np.array( 319 [ 320 float(self.maze_ctor_kwargs["p"]), 321 float(self.grid_n), 322 float( 323 int( 324 self.endpoint_kwargs.get("deadend_start", False) # type: ignore[arg-type] 325 or self.endpoint_kwargs.get("deadend_end", False), 326 ), 327 ), 328 float(endpoints_unique_flag), 329 float(_GENERATORS_PERCOLATED.index(self.maze_ctor.__name__)), 330 ], 331 dtype=np.float64, 332 ) 333 334 @classmethod 335 def _from_ps_array( 336 cls, 337 arr: _PercolationSuccessArray, 338 name: str = "predict", 339 n_mazes: int = 100, 340 **kwargs, 341 ) -> "MazeDatasetConfig": 342 """Reconstruct a config from an array [p, grid_n, deadends, 
endpoints_not_equal, generator_func] and other config parameters. 343 344 # Returns: 345 - `MazeDatasetConfig` 346 Config corresponding to `arr` 347 """ 348 return cls( 349 name=name, 350 grid_n=int(arr[1]), 351 n_mazes=n_mazes, 352 maze_ctor=GENERATORS_MAP[_GENERATORS_PERCOLATED[int(arr[4])]], 353 maze_ctor_kwargs={"p": float(arr[0])}, 354 endpoint_kwargs=dict( 355 deadend_start=bool(arr[2]), 356 deadend_end=bool(arr[2]), 357 endpoints_not_equal=bool(arr[3]), 358 except_on_no_valid_endpoint=False, 359 ), 360 **kwargs, 361 ) 362 363 def success_fraction_estimate( 364 self, 365 except_if_all_success_expected: bool = False, 366 ) -> float: 367 """Estimate the success fraction of this config. 368 369 only valid when the generator is a percolation generator, 370 and endpoints are enforced to be dead ends 371 372 this estimate comes from `estimate_dataset_fractions.ipynb` and `maze_dataset.benchmarks.sweep_fit` 373 374 # Parameters: 375 - `except_if_all_success_expected : bool` 376 if `True`, don't raise an error if the success fraction is below the threshold. 377 will always return `1.0` if the config is not expected to fail 378 379 # Returns: 380 - `float` 381 estimated success fraction 382 383 # Raises: 384 - `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `False` 385 """ 386 try: 387 return cfg_success_predict_fn(self) 388 389 except NoPercolationInConfigError as e: 390 if except_if_all_success_expected: 391 raise e # noqa: TRY201 392 return 1.0 393 394 def success_fraction_compensate( 395 self, 396 safety_margin: float = 1.2, 397 except_if_all_success_expected: bool = False, 398 epsilon: float = 1e-2, 399 ) -> "MazeDatasetConfig": 400 """return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction 401 402 # Parameters: 403 - `safety_margin : float` 404 safety margin to apply to the success fraction estimate 405 (defaults to `1.2`, or 20% more mazes than estimated) 406 - `except_if_all_success_expected : bool` 407 if `True`, don't raise an error if the success fraction is below the threshold. 408 this is passed to `MazeDatasetConfig.success_fraction_estimate`. 409 if your config isn't expected to fail, passing this might mean you generate more mazes than needed 410 since `safety_margin` is still applied. 411 (defaults to `False`) 412 - `epsilon : float` 413 raise `SuccessChanceTooSmallError` if the success fraction is below this threshold 414 (defaults to `1e-2`) 415 416 # Returns: 417 - `MazeDatasetConfig` 418 new config with adjusted `n_mazes` 419 420 # Raises: 421 - `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon` 422 """ 423 # compute and check the success fraction 424 success_fraction: float = self.success_fraction_estimate( 425 except_if_all_success_expected=except_if_all_success_expected, 426 ) 427 if success_fraction < epsilon: 428 err_msg: str = ( 429 f"{success_fraction = } is below the threshold of {epsilon = }" 430 ) 431 raise SuccessChanceTooSmallError( 432 err_msg, 433 ) 434 435 # compute the new number of mazes 436 n_mazes: int = self.n_mazes 437 new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1 438 439 # put it in a new config and return 440 cfg_dict: dict = self.serialize() 441 cfg_dict["n_mazes"] = new_n_mazes 442 return MazeDatasetConfig.load(cfg_dict)
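To make the above concrete, here is a minimal usage sketch (not part of the source): it builds a config with the `gen_dfs` generator from `GENERATORS_MAP` and round-trips it through `serialize`/`load`. The printed fname is illustrative; the actual hash suffix depends on the config contents.

```python
from maze_dataset.dataset.maze_dataset_config import MazeDatasetConfig
from maze_dataset.generation.generators import GENERATORS_MAP

# build a small config; grid_n and n_mazes are the required fields here,
# plus the `name` inherited from GPTDatasetConfig
cfg = MazeDatasetConfig(
    name="demo",
    grid_n=5,
    n_mazes=10,
    maze_ctor=GENERATORS_MAP["gen_dfs"],
)

# the fname encodes name, grid size, maze count, generator, and a
# MAZEDATASETCONFIG_FNAME_HASH_LENGTH-character hash suffix
print(cfg.to_fname())  # e.g. "demo-g5-n10-a_dfs-h12345" (hash varies)

# serialize() adds "fname" and "versions"; load() reconstructs the config
cfg2 = MazeDatasetConfig.load(cfg.serialize())
assert cfg2.to_fname() == cfg.to_fname()
```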
```python
SERIALIZE_MINIMAL_THRESHOLD: int | None = 100
```

If `n_mazes >= SERIALIZE_MINIMAL_THRESHOLD`, then the `MazeDataset` will use `serialize_minimal`. Setting it to `None` means that `serialize_minimal` will never be used. Set it to `-1` to make calls to `read` use `MazeDataset._load_legacy`; this is used for profiling only.
```python
MAZEDATASETCONFIG_FNAME_HASH_LENGTH: int = 5
```

length, in characters, of the hash in the fname of a `MazeDatasetConfig`
```python
class NoPercolationInConfigError(ValueError):
    """raised when trying to predict the success fraction of a config that doesn't have percolation"""
```
raised when trying to predict the success fraction of a config that doesn't have percolation
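A sketch of when this error surfaces, reusing the `demo` config from the module-level example above (which uses the non-percolation `gen_dfs` generator):

```python
# gen_dfs has no percolation parameter `p`, so prediction is impossible;
# the error only propagates when the caller opts in via the flag
try:
    cfg.success_fraction_estimate(except_if_all_success_expected=True)
except NoPercolationInConfigError:
    print("not a percolation config; every maze is expected to succeed")
```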
```python
class SuccessChanceTooSmallError(ValueError):
    """raised when the success fraction is below the threshold in `MazeDatasetConfig.success_fraction_compensate`"""
```
raised when the success fraction is below the threshold in MazeDatasetConfig.success_fraction_compensate
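A sketch of how this error can surface; `"gen_percolation"` as a key of `GENERATORS_MAP` is an assumption here, and the exact cutoff depends on the fitted success-prediction model:

```python
# with a very small percolation value p, almost no maze has a valid path,
# so the estimated success fraction falls below the default epsilon=1e-2
sparse_cfg = MazeDatasetConfig(
    name="sparse",
    grid_n=10,
    n_mazes=100,
    maze_ctor=GENERATORS_MAP["gen_percolation"],  # assumed generator key
    maze_ctor_kwargs={"p": 0.01},
    endpoint_kwargs={"except_on_no_valid_endpoint": False},
)
try:
    sparse_cfg = sparse_cfg.success_fraction_compensate()
except SuccessChanceTooSmallError:
    print("success chance below epsilon; increase p or relax endpoints")
```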
```python
def set_serialize_minimal_threshold(threshold: int | None) -> None:
    "set the global SERIALIZE_MINIMAL_THRESHOLD"
    global SERIALIZE_MINIMAL_THRESHOLD  # noqa: PLW0603
    SERIALIZE_MINIMAL_THRESHOLD = threshold
```
set the global `SERIALIZE_MINIMAL_THRESHOLD`
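Usage is a one-liner, per the constant's documentation above:

```python
set_serialize_minimal_threshold(None)  # never use serialize_minimal
set_serialize_minimal_threshold(-1)    # profiling only: read() uses _load_legacy
set_serialize_minimal_threshold(100)   # restore the default
```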
```python
EndpointKwargsType = dict[
    typing.Literal[
        "allowed_start",
        "allowed_end",
        "deadend_start",
        "deadend_end",
        "endpoints_not_equal",
        "except_on_no_valid_endpoint",
    ],
    bool | None | list[tuple[int, int]],
]
```

type hint for `MazeDatasetConfig.endpoint_kwargs`
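For illustration, a dict conforming to this type (keys and value types taken from the definition above):

```python
endpoint_kwargs: EndpointKwargsType = {
    "allowed_start": [(0, 0), (0, 1)],   # list of coordinate tuples
    "deadend_end": True,                 # bool flags
    "endpoints_not_equal": True,
    "except_on_no_valid_endpoint": False,
}
```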
```python
@serializable_dataclass(kw_only=True, methods_no_override=["serialize"])
class MazeDatasetConfig(_MazeDatasetConfig_base):
    """config object which is passed to `MazeDataset.from_config` to generate or load a dataset"""
```
config object which is passed to `MazeDataset.from_config` to generate or load a dataset
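A sketch of the intended workflow; the top-level `MazeDataset` import path is an assumption (it is the class whose `from_config` consumes this config):

```python
from maze_dataset import MazeDataset, MazeDatasetConfig  # assumed re-exports
from maze_dataset.generation.generators import GENERATORS_MAP

cfg = MazeDatasetConfig(
    name="train",
    grid_n=7,
    n_mazes=50,
    maze_ctor=GENERATORS_MAP["gen_dfs"],
)
# generates (or loads, if cached) a dataset matching the config
dataset = MazeDataset.from_config(cfg)
```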
```python
@property
def config_version(self) -> str:
    """return the version of the config. added in maze_dataset v1.3.0; previous versions had no dataset config"""
    return "1.0"
```
return the version of the config. added in maze_dataset v1.3.0; previous versions had no dataset config
```python
@property
def versions(self) -> dict:
    """return the versions of the config and the maze_dataset"""
    return dict(
        config=self.config_version,
        maze_dataset=importlib.metadata.version("maze_dataset"),
    )
```
return the versions of the config and the maze_dataset
```python
def serialize(self) -> dict:
    "serialize the MazeDatasetConfig with all fields and fname"
    return {
        **self._serialize_base(
            applied_filters__skip__collect_generation_meta=False
        ),
        "fname": self.to_fname(),
        "versions": self.versions,
    }
```
serialize the MazeDatasetConfig with all fields and fname
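Continuing the construction sketch above, the serialized mapping is a superset of the base fields:

```python
data = cfg.serialize()
# base fields plus the two extras added by this override
assert "fname" in data and "versions" in data
assert data["versions"]["config"] == cfg.config_version
```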
```python
def summary(self) -> dict:
    """return a summary of the config"""
    # run the super summary to make sure it doesn't error
    super_summary: dict = super().summary()
    assert super_summary
    self_ser: dict = self.serialize()
    return dict(
        name=self.name,
        fname=self.to_fname(),
        sdc_hash=self.stable_hash_cfg(),
        seed=self.seed,
        seq_len_min=self.seq_len_min,
        seq_len_max=self.seq_len_max,
        applied_filters=self.applied_filters,
        grid_n=self_ser["grid_n"],
        n_mazes=self_ser["n_mazes"],
        maze_ctor_name=self_ser["maze_ctor"]["__name__"],
        maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
        endpoint_kwargs=self_ser["endpoint_kwargs"],
    )
```
return a summary of the config
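The summary is a flat dict, convenient for logging; a small sketch:

```python
for key, value in cfg.summary().items():
    print(f"{key}: {value}")  # name, fname, sdc_hash, seed, grid_n, ...
```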
```python
def success_fraction_estimate(
    self,
    except_if_all_success_expected: bool = False,
) -> float:
    """Estimate the success fraction of this config.

    only valid when the generator is a percolation generator,
    and endpoints are enforced to be dead ends

    this estimate comes from `estimate_dataset_fractions.ipynb` and `maze_dataset.benchmarks.sweep_fit`

    # Parameters:
    - `except_if_all_success_expected : bool`
        if `True`, raise `NoPercolationInConfigError` when the config has no percolation
        and is therefore expected to always succeed.
        if `False`, such configs simply return `1.0`

    # Returns:
    - `float`
        estimated success fraction

    # Raises:
    - `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `True`
    """
    try:
        return cfg_success_predict_fn(self)

    except NoPercolationInConfigError as e:
        if except_if_all_success_expected:
            raise e  # noqa: TRY201
        return 1.0
```
Estimate the success fraction of this config.

only valid when the generator is a percolation generator, and endpoints are enforced to be dead ends

this estimate comes from `estimate_dataset_fractions.ipynb` and `maze_dataset.benchmarks.sweep_fit`

Parameters:
- `except_if_all_success_expected : bool`
  if `True`, raise `NoPercolationInConfigError` when the config has no percolation and is therefore expected to always succeed. if `False`, such configs simply return `1.0`.

Returns:
- `float`
  estimated success fraction

Raises:
- `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `True`
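A hedged example; `"gen_dfs_percolation"` as a generator key is an assumption, and the returned value comes from the fitted model, so the number printed is illustrative only:

```python
perc_cfg = MazeDatasetConfig(
    name="perc",
    grid_n=8,
    n_mazes=1000,
    maze_ctor=GENERATORS_MAP["gen_dfs_percolation"],  # assumed key
    maze_ctor_kwargs={"p": 0.4},
    endpoint_kwargs={
        "deadend_start": True,
        "deadend_end": True,
        "except_on_no_valid_endpoint": False,  # required, see _to_ps_array
    },
)
frac = perc_cfg.success_fraction_estimate()
print(f"estimated success fraction: {frac:.3f}")  # value is model-dependent
```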
```python
def success_fraction_compensate(
    self,
    safety_margin: float = 1.2,
    except_if_all_success_expected: bool = False,
    epsilon: float = 1e-2,
) -> "MazeDatasetConfig":
    """return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction

    # Parameters:
    - `safety_margin : float`
        safety margin to apply to the success fraction estimate
        (defaults to `1.2`, or 20% more mazes than estimated)
    - `except_if_all_success_expected : bool`
        passed to `MazeDatasetConfig.success_fraction_estimate`.
        if `True`, raise an error when the config has no percolation and is not expected to fail.
        if `False` and your config isn't expected to fail, you might generate more mazes
        than needed, since `safety_margin` is still applied.
        (defaults to `False`)
    - `epsilon : float`
        raise `SuccessChanceTooSmallError` if the success fraction is below this threshold
        (defaults to `1e-2`)

    # Returns:
    - `MazeDatasetConfig`
        new config with adjusted `n_mazes`

    # Raises:
    - `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
    """
    # compute and check the success fraction
    success_fraction: float = self.success_fraction_estimate(
        except_if_all_success_expected=except_if_all_success_expected,
    )
    if success_fraction < epsilon:
        err_msg: str = (
            f"{success_fraction = } is below the threshold of {epsilon = }"
        )
        raise SuccessChanceTooSmallError(err_msg)

    # compute the new number of mazes
    new_n_mazes: int = int((self.n_mazes * safety_margin) / success_fraction) + 1

    # put it in a new config and return
    cfg_dict: dict = self.serialize()
    cfg_dict["n_mazes"] = new_n_mazes
    return MazeDatasetConfig.load(cfg_dict)
```
return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction

Parameters:
- `safety_margin : float`
  safety margin to apply to the success fraction estimate (defaults to `1.2`, or 20% more mazes than estimated)
- `except_if_all_success_expected : bool`
  passed to `MazeDatasetConfig.success_fraction_estimate`. if `True`, raise an error when the config has no percolation and is not expected to fail. if `False` and your config isn't expected to fail, you might generate more mazes than needed, since `safety_margin` is still applied. (defaults to `False`)
- `epsilon : float`
  raise `SuccessChanceTooSmallError` if the success fraction is below this threshold (defaults to `1e-2`)

Returns:
- `MazeDatasetConfig`
  new config with adjusted `n_mazes`

Raises:
- `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
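Continuing the percolation sketch above, a worked instance of the arithmetic (the `0.8` estimate is assumed for illustration):

```python
compensated = perc_cfg.success_fraction_compensate()
# under an assumed estimate of 0.8: int((1000 * 1.2) / 0.8) + 1 == 1501,
# so enough extra mazes that ~1000 survive the expected failures
print(compensated.n_mazes)
```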
```python
@classmethod  # type: ignore[misc]
def load(cls, data: dict[str, Any] | T) -> Type[T]:
    # NOTE: `on_typecheck_error` and `on_typecheck_mismatch` below are closure
    # variables bound by the `serializable_dataclass` decorator in muutils
    # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
    if isinstance(data, cls):
        return data

    assert isinstance(
        data, typing.Mapping
    ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

    cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

    # initialize dict for keeping what we will pass to the constructor
    ctor_kwargs: dict[str, Any] = dict()

    # iterate over the fields of the class
    for field in dataclasses.fields(cls):
        # check if the field is a SerializableField
        assert isinstance(
            field, SerializableField
        ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

        # check if the field is in the data and if it should be initialized
        if (field.name in data) and field.init:
            # get the value, we will be processing it
            value: Any = data[field.name]

            # get the type hint for the field
            field_type_hint: Any = cls_type_hints.get(field.name, None)

            # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
            if field.deserialize_fn:
                # if it has a deserialization function, use that
                value = field.deserialize_fn(value)
            elif field.loading_fn:
                # if it has a loading function, use that
                value = field.loading_fn(data)
            elif (
                field_type_hint is not None
                and hasattr(field_type_hint, "load")
                and callable(field_type_hint.load)
            ):
                # if no loading function but has a type hint with a load method, use that
                if isinstance(value, dict):
                    value = field_type_hint.load(value)
                else:
                    raise FieldLoadingError(
                        f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                    )
            else:
                # assume no loading needs to happen, keep `value` as-is
                pass

            # store the value in the constructor kwargs
            ctor_kwargs[field.name] = value

    # create a new instance of the class with the constructor kwargs
    output: cls = cls(**ctor_kwargs)

    # validate the types of the fields if needed
    if on_typecheck_mismatch != ErrorMode.IGNORE:
        fields_valid: dict[str, bool] = (
            SerializableDataclass__validate_fields_types__dict(
                output,
                on_typecheck_error=on_typecheck_error,
            )
        )

        # if there are any fields that are not valid, raise an error
        if not all(fields_valid.values()):
            msg: str = (
                f"Type mismatch in fields of {cls.__name__}:\n"
                + "\n".join(
                    [
                        f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                        for k, v in fields_valid.items()
                        if not v
                    ]
                )
            )

            on_typecheck_mismatch.process(
                msg, except_cls=FieldTypeMismatchError
            )

    # return the new instance
    return output
```
takes in an appropriately structured dict and returns an instance of the class; implemented by the `@serializable_dataclass` decorator
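A sketch of both accepted input forms, per the implementation above (using the config from the earlier sketches):

```python
# an instance passes through unchanged
assert MazeDatasetConfig.load(cfg) is cfg
# a mapping from serialize() reconstructs an equivalent config
cfg2 = MazeDatasetConfig.load(cfg.serialize())
assert cfg2 == cfg  # n_mazes and _fname_loaded are excluded from comparison
```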
```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
```

validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
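In muutils this function is also exposed as a bound instance method, so on a config it can be called as below (a sketch, assuming the method name matches the function above minus the class prefix):

```python
# returns True only if every field's value matches its type annotation
assert cfg.validate_fields_types()
```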
Inherited Members
- `_MazeDatasetConfig_base`: `grid_n`, `n_mazes`, `maze_ctor`, `maze_ctor_kwargs`, `endpoint_kwargs`, `grid_shape`, `grid_shape_np`, `max_grid_n`, `stable_hash_cfg`, `to_fname`
- `muutils.json_serialize.serializable_dataclass.SerializableDataclass`: `validate_field_type`, `diff`, `update_from_nested_dict`