Coverage for maze_dataset/dataset/maze_dataset_config.py: 24%
118 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-24 14:35 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-24 14:35 -0600
1"implements `MazeDatasetConfig` which is used to generate or load a dataset"
3import hashlib
4import importlib.metadata
5import json
6import typing
7import warnings
8from typing import Callable
10import numpy as np
11from jaxtyping import Float
12from muutils.json_serialize import (
13 serializable_dataclass,
14 serializable_field,
15)
16from muutils.json_serialize.util import (
17 safe_getsource,
18 string_as_lines,
19)
20from muutils.misc import sanitize_fname, shorten_numerical_to_str
22from maze_dataset.constants import Coord, CoordTup
23from maze_dataset.dataset.dataset import (
24 GPTDatasetConfig,
25)
26from maze_dataset.dataset.success_predict_math import cfg_success_predict_fn
27from maze_dataset.generation.generators import _GENERATORS_PERCOLATED, GENERATORS_MAP
SERIALIZE_MINIMAL_THRESHOLD: int | None = 100
"""If `n_mazes>=SERIALIZE_MINIMAL_THRESHOLD`, then the MazeDataset will use `serialize_minimal`.
Setting to None means that `serialize_minimal` will never be used.
Set to -1 to make calls to `read` use `MazeDataset._load_legacy`. Used for profiling only."""

MAZEDATASETCONFIG_FNAME_HASH_LENGTH: int = 5
"length, in characters, of the hash in the fname of a `MazeDatasetConfig`"

# 5-element float vector used by the percolation success-rate predictor;
# component order is p, grid_n, deadends, endpoints_not_equal, generator_func
# (see `MazeDatasetConfig._to_ps_array` / `_from_ps_array`)
_PercolationSuccessArray = Float[
    np.ndarray,
    "p/grid_n/deadends/endpoints_not_equal/generator_func=5",
]
class NoPercolationInConfigError(ValueError):
    """raised when trying to predict the success fraction of a config that doesn't have percolation

    (raised by `MazeDatasetConfig._to_ps_array`, caught in `success_fraction_estimate`)
    """

    # no `pass` needed -- the docstring alone is a valid class body
class SuccessChanceTooSmallError(ValueError):
    """raised when the success fraction is below the threshold in `MazeDatasetConfig.success_fraction_compensate`"""

    # no `pass` needed -- the docstring alone is a valid class body
def set_serialize_minimal_threshold(threshold: int | None) -> None:
    "set the global `SERIALIZE_MINIMAL_THRESHOLD` (pass `None` to disable `serialize_minimal`)"
    global SERIALIZE_MINIMAL_THRESHOLD  # noqa: PLW0603
    SERIALIZE_MINIMAL_THRESHOLD = threshold
def _load_maze_ctor(maze_ctor_serialized: str | dict) -> Callable:
    "get the maze constructor from `GENERATORS_MAP`"
    # dict form: both the current and the original serialization format
    if isinstance(maze_ctor_serialized, dict):
        return GENERATORS_MAP[maze_ctor_serialized["__name__"]]

    # bare-string form: used briefly in older versions; warn so old configs get reported
    if isinstance(maze_ctor_serialized, str):
        warnings.warn(
            "you are loading an old model/config in `_load_maze_ctor()`!!! this should not be happening, please report: "
            "https://github.com/understanding-search/maze-dataset/issues/new",
        )
        return GENERATORS_MAP[maze_ctor_serialized]

    # anything else is a malformed serialization
    err_msg: str = f"maze_ctor_serialized is of type {type(maze_ctor_serialized) = }, expected str or dict\n{maze_ctor_serialized = }"
    raise TypeError(err_msg)
# keys are flags or coordinate lists controlling endpoint (start/end) selection;
# values are bools/None for the flags, or a list of (row, col) coords for the
# `allowed_start`/`allowed_end` keys
EndpointKwargsType = dict[
    typing.Literal[
        "allowed_start",
        "allowed_end",
        "deadend_start",
        "deadend_end",
        "endpoints_not_equal",
        "except_on_no_valid_endpoint",
    ],
    bool | None | list[tuple[int, int]],
]
"type hint for `MazeDatasetConfig.endpoint_kwargs`"
def _load_endpoint_kwargs(data: dict) -> EndpointKwargsType:
    "deserialize `endpoint_kwargs`, restoring coordinate tuples that were saved as lists"
    raw = data.get("endpoint_kwargs")
    if raw is None:
        # missing or explicitly-None key means no endpoint kwargs
        return dict()

    loaded: EndpointKwargsType = dict()
    for key, value in raw.items():
        if isinstance(value, bool) or value is None:
            # bools and Nones pass through unchanged
            loaded[key] = value
        else:
            # assume its a CoordList; muutils/zanj saves tuples as lists
            loaded[key] = [tuple(coord) for coord in value]
    return loaded
@serializable_dataclass(kw_only=True, properties_to_serialize=["grid_shape"])
class _MazeDatasetConfig_base(GPTDatasetConfig):  # noqa: N801
    """base config -- we serialize, dump to json, and hash this to get the fname. all actual variables we want to be hashed are here"""

    # NOTE: type: ignore[misc] is because it tells us non-default attributes aren't allowed after ones with defaults, but everything is kw_only

    # side length of the (square) maze grid
    grid_n: int = serializable_field()  # type: ignore[misc]

    # not comparing n_mazes is done primarily to avoid conflicts which happen during `from_config` when we have applied filters
    n_mazes: int = serializable_field(compare=False)  # type: ignore[misc]

    # maze generation function; serialized by name/module/doc/source so the
    # resulting dict (and hence the config hash) is stable across environments
    maze_ctor: Callable = serializable_field(
        default=GENERATORS_MAP["gen_dfs"],
        serialization_fn=lambda gen_func: {
            "__name__": gen_func.__name__,
            "__module__": gen_func.__module__,
            # NOTE: this was causing hashing issues on 3.13 vs older versions because somehow,
            # the `__doc__` variable is different across versions??????? WHY???????? IT TREATS WHITESPACE DIFFERENTLY
            # so we just uh. strip it all now.
            # see:
            # https://github.com/understanding-search/maze-dataset/actions/runs/14028046497/job/39270080746?pr=53
            # https://github.com/understanding-search/maze-dataset/actions/runs/14028046497/job/39270080742?pr=53
            # https://www.diffchecker.com/tqIMSevy/
            # update: we also need to filter for empty lines. B)
            "__doc__": [
                line.strip()
                for line in string_as_lines(gen_func.__doc__)
                if line.strip()
            ],
            "source_code": safe_getsource(gen_func),
        },
        loading_fn=lambda data: _load_maze_ctor(data["maze_ctor"]),
        assert_type=False,  # TODO: check the type here once muutils supports checking Callable signatures
    )

    # kwargs forwarded to `maze_ctor` at generation time
    maze_ctor_kwargs: dict = serializable_field(
        default_factory=dict,
        serialization_fn=lambda kwargs: kwargs,
        loading_fn=lambda data: (
            dict()
            if data.get("maze_ctor_kwargs", None)
            is None  # this should handle the backwards compatibility
            else data["maze_ctor_kwargs"]
        ),
    )

    # constraints on start/end point selection -- see `EndpointKwargsType`
    endpoint_kwargs: EndpointKwargsType = serializable_field(
        default_factory=dict,
        serialization_fn=lambda kwargs: kwargs,
        loading_fn=_load_endpoint_kwargs,
        assert_type=False,
    )

    # NOTE: this part is very hacky. the way muutils works is that it iterates over the *keys in the serialized data*,
    # and so we need to save an `None` here or this wont load the `fname` field on load
    # this is a total mess, and very confusing, and entirely my fault
    _fname_loaded: str | None = serializable_field(
        default=None,
        compare=False,
        serialization_fn=lambda _: None,
        loading_fn=lambda data: data.get("fname", None),
    )

    @property
    def grid_shape(self) -> CoordTup:
        """return the shape of the grid as a tuple"""
        return (self.grid_n, self.grid_n)

    @property
    def grid_shape_np(self) -> Coord:
        """return the shape of the grid as a numpy array"""
        return np.array(self.grid_shape)

    @property
    def max_grid_n(self) -> int:
        """return the maximum of the grid shape"""
        return max(self.grid_shape)

    def _serialize_base(
        self, applied_filters__skip__collect_generation_meta: bool = True
    ) -> dict:
        """serialize the base config for use in `stable_hash_cfg()` and `to_fname()`

        - note that the _fname_loaded will always be `None` to avoid infinite recursion
        - note that we **do not** by default include information about metadata collection here,
          since otherwise loading a dataset that we minified by collecting the metadata would be impossible
          but for comparing things, we do store it when serializing properly by setting
          `applied_filters__skip__collect_generation_meta=False`
        """
        # call `serialize` explicitly on this base class so that a subclass
        # override (e.g. `MazeDatasetConfig.serialize`, which itself calls this
        # method) cannot recurse back into us
        serialized: dict = _MazeDatasetConfig_base.serialize(self)
        if applied_filters__skip__collect_generation_meta:
            serialized["applied_filters"] = [
                x
                for x in serialized["applied_filters"]
                if x.get("name", None) != "collect_generation_meta"
            ]
        return serialized

    def _stable_str_dump(self) -> str:
        # deterministic json dump: sorted keys, no indentation variation,
        # so the same config always hashes the same way
        return json.dumps(
            self._serialize_base(),
            sort_keys=True,
            indent=None,
        )

    def stable_hash_cfg(self) -> int:
        """return a stable hash of the config"""
        # md5 is fine here -- this is a filename fingerprint, not a security hash
        return int.from_bytes(
            hashlib.md5(  # noqa: S324
                bytes(self._stable_str_dump(), "ascii")
            ).digest(),
            "big",
        )

    def to_fname(self) -> str:
        """return a unique identifier (valid as a filename) for this config"""
        n_mazes_str: str = shorten_numerical_to_str(self.n_mazes)
        maze_ctor_name: str = self.maze_ctor.__name__.removeprefix("gen_")
        # keep only `MAZEDATASETCONFIG_FNAME_HASH_LENGTH` decimal digits of the stable hash
        hash_id: int = self.stable_hash_cfg() % 10**MAZEDATASETCONFIG_FNAME_HASH_LENGTH
        return sanitize_fname(
            f"{self.name}-g{self.grid_n}-n{n_mazes_str}-a_{maze_ctor_name}-h{hash_id}",
        )
# NOTE: type: ignore[misc] is because it tells us non-default attributes aren't allowed after ones with defaults, but everything is kw_only
@serializable_dataclass(kw_only=True, methods_no_override=["serialize"])
class MazeDatasetConfig(_MazeDatasetConfig_base):  # type: ignore[misc]
    """config object which is passed to `MazeDataset.from_config` to generate or load a dataset"""

    @property
    def config_version(self) -> str:
        """return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config"""
        return "1.0"

    @property
    def versions(self) -> dict:
        """return the versions of the config and the maze_dataset"""
        return dict(
            config=self.config_version,
            maze_dataset=importlib.metadata.version("maze_dataset"),
        )

    def serialize(self) -> dict:
        "serialize the MazeDatasetConfig with all fields and fname"
        # `methods_no_override=["serialize"]` on the decorator preserves this override;
        # `_serialize_base` calls the base-class `serialize` directly, so no recursion
        return {
            **self._serialize_base(
                applied_filters__skip__collect_generation_meta=False
            ),
            "fname": self.to_fname(),
            "versions": self.versions,
        }

    def summary(self) -> dict:
        """return a summary of the config"""
        # do we run this to make sure it doesn't error?
        super_summary: dict = super().summary()
        assert super_summary
        self_ser: dict = self.serialize()
        return dict(
            name=self.name,
            fname=self.to_fname(),
            sdc_hash=self.stable_hash_cfg(),
            seed=self.seed,
            seq_len_min=self.seq_len_min,
            seq_len_max=self.seq_len_max,
            applied_filters=self.applied_filters,
            grid_n=self_ser["grid_n"],
            n_mazes=self_ser["n_mazes"],
            maze_ctor_name=self_ser["maze_ctor"]["__name__"],
            maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
            endpoint_kwargs=self_ser["endpoint_kwargs"],
        )

    def _to_ps_array(self) -> _PercolationSuccessArray:
        """Convert this config to a [p, grid_n, deadends, endpoints_not_equal, generator_func] vector.

        used in predicting the success rate

        # Raises:
        - `NoPercolationInConfigError` : if the generator is not in `_GENERATORS_PERCOLATED`,
            `maze_ctor_kwargs` has no `"p"` key, or `except_on_no_valid_endpoint` is not falsy
        """
        try:
            assert self.maze_ctor.__name__ in _GENERATORS_PERCOLATED, (
                f"generator not supported, must be a percolation generator\n{self.maze_ctor.__name__ = }, {_GENERATORS_PERCOLATED = }"
            )
            assert "p" in self.maze_ctor_kwargs, (
                f"maze_ctor_kwargs must have a 'p' (percolation value) key: {self.maze_ctor_kwargs = }"
            )
            assert not self.endpoint_kwargs.get("except_on_no_valid_endpoint", True), (
                f"except_on_no_valid_endpoint must be False, or else if any maze fails to generate, the whole dataset will fail: {self.endpoint_kwargs = }"
            )
        except AssertionError as e:
            err_msg: str = f"invalid config for percolation success prediction: {self.summary() = }"
            raise NoPercolationInConfigError(
                err_msg,
            ) from e

        endpoints_unique_flag: int = int(
            # we are pretty sure it will be an int or bool here
            self.endpoint_kwargs.get("endpoints_not_equal", True),  # type: ignore[arg-type]
        )

        # adjustment for bknutson0
        if not (
            self.endpoint_kwargs.get("deadend_start", False)
            and self.endpoint_kwargs.get("deadend_end", False)
        ):
            # we didnt train on this, but if either endpoint is not required to be in a dead end
            # then requiring the endpoints to be unique does not really affect the success rate
            # (except for very small percolation values, pure percolation generation)
            endpoints_unique_flag = 0

        # component order must match `_PercolationSuccessArray` and `_from_ps_array`
        return np.array(
            [
                float(self.maze_ctor_kwargs["p"]),
                float(self.grid_n),
                float(
                    int(
                        self.endpoint_kwargs.get("deadend_start", False)  # type: ignore[arg-type]
                        or self.endpoint_kwargs.get("deadend_end", False),
                    ),
                ),
                float(endpoints_unique_flag),
                float(_GENERATORS_PERCOLATED.index(self.maze_ctor.__name__)),
            ],
            dtype=np.float64,
        )

    @classmethod
    def _from_ps_array(
        cls,
        arr: _PercolationSuccessArray,
        name: str = "predict",
        n_mazes: int = 100,
        **kwargs,
    ) -> "MazeDatasetConfig":
        """Reconstruct a config from an array [p, grid_n, deadends, endpoints_not_equal, generator_func] and other config parameters.

        # Returns:
        - `MazeDatasetConfig`
            Config corresponding to `arr`
        """
        # NOTE: the single `deadends` component (arr[2]) is applied to both
        # `deadend_start` and `deadend_end`, mirroring the collapse in `_to_ps_array`
        return cls(
            name=name,
            grid_n=int(arr[1]),
            n_mazes=n_mazes,
            maze_ctor=GENERATORS_MAP[_GENERATORS_PERCOLATED[int(arr[4])]],
            maze_ctor_kwargs={"p": float(arr[0])},
            endpoint_kwargs=dict(
                deadend_start=bool(arr[2]),
                deadend_end=bool(arr[2]),
                endpoints_not_equal=bool(arr[3]),
                except_on_no_valid_endpoint=False,
            ),
            **kwargs,
        )

    def success_fraction_estimate(
        self,
        except_if_all_success_expected: bool = False,
    ) -> float:
        """Estimate the success fraction of this config.

        only valid when the generator is a percolation generator,
        and endpoints are enforced to be dead ends

        this estimate comes from `estimate_dataset_fractions.ipynb` and `maze_dataset.benchmarks.sweep_fit`

        # Parameters:
        - `except_if_all_success_expected : bool`
            if `True`, re-raise `NoPercolationInConfigError` when the config has no
            percolation (i.e. every maze is expected to generate successfully).
            if `False` (the default), such configs simply return `1.0`

        # Returns:
        - `float`
            estimated success fraction

        # Raises:
        - `NoPercolationInConfigError` : if the config is not expected to fail
            (no percolation) and `except_if_all_success_expected` is `True`
        """
        try:
            return cfg_success_predict_fn(self)
        except NoPercolationInConfigError as e:
            # no percolation means every maze should succeed
            if except_if_all_success_expected:
                raise e  # noqa: TRY201
            return 1.0

    def success_fraction_compensate(
        self,
        safety_margin: float = 1.2,
        except_if_all_success_expected: bool = False,
        epsilon: float = 1e-2,
    ) -> "MazeDatasetConfig":
        """return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction

        # Parameters:
        - `safety_margin : float`
            safety margin to apply to the success fraction estimate
            (defaults to `1.2`, or 20% more mazes than estimated)
        - `except_if_all_success_expected : bool`
            if `True`, raise `NoPercolationInConfigError` when the config is not
            expected to fail (no percolation).
            this is passed to `MazeDatasetConfig.success_fraction_estimate`.
            if `False` and your config isn't expected to fail, the estimate is `1.0`,
            so you still generate `safety_margin` times more mazes than requested.
            (defaults to `False`)
        - `epsilon : float`
            raise `SuccessChanceTooSmallError` if the success fraction is below this threshold
            (defaults to `1e-2`)

        # Returns:
        - `MazeDatasetConfig`
            new config with adjusted `n_mazes`

        # Raises:
        - `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
        """
        # compute and check the success fraction
        success_fraction: float = self.success_fraction_estimate(
            except_if_all_success_expected=except_if_all_success_expected,
        )
        if success_fraction < epsilon:
            err_msg: str = (
                f"{success_fraction = } is below the threshold of {epsilon = }"
            )
            raise SuccessChanceTooSmallError(
                err_msg,
            )

        # compute the new number of mazes
        # scale up by the margin, divide by expected success rate, round up
        n_mazes: int = self.n_mazes
        new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1

        # put it in a new config and return
        # (round-trip through serialize/load so `fname` and derived fields stay consistent)
        cfg_dict: dict = self.serialize()
        cfg_dict["n_mazes"] = new_n_mazes
        return MazeDatasetConfig.load(cfg_dict)