docs for maze-dataset v1.3.1
View Source on GitHub

maze_dataset.dataset.maze_dataset_config

implements MazeDatasetConfig which is used to generate or load a dataset


  1"implements `MazeDatasetConfig` which is used to generate or load a dataset"
  2
  3import hashlib
  4import importlib.metadata
  5import json
  6import typing
  7import warnings
  8from typing import Callable
  9
 10import numpy as np
 11from jaxtyping import Float
 12from muutils.json_serialize import (
 13	serializable_dataclass,
 14	serializable_field,
 15)
 16from muutils.json_serialize.util import (
 17	safe_getsource,
 18	string_as_lines,
 19)
 20from muutils.misc import sanitize_fname, shorten_numerical_to_str
 21
 22from maze_dataset.constants import Coord, CoordTup
 23from maze_dataset.dataset.dataset import (
 24	GPTDatasetConfig,
 25)
 26from maze_dataset.dataset.success_predict_math import cfg_success_predict_fn
 27from maze_dataset.generation.generators import _GENERATORS_PERCOLATED, GENERATORS_MAP
 28
SERIALIZE_MINIMAL_THRESHOLD: int | None = 100
"""If `n_mazes>=SERIALIZE_MINIMAL_THRESHOLD`, then the MazeDataset will use `serialize_minimal`.
Setting to None means that `serialize_minimal` will never be used.
Set to -1 to make calls to `read` use `MazeDataset._load_legacy`. Used for profiling only."""

MAZEDATASETCONFIG_FNAME_HASH_LENGTH: int = 5
"length of the hash, in characters, included in the fname of a `MazeDatasetConfig`"

# jaxtyping shape annotation for the 5-element feature vector used in percolation
# success-rate prediction: [p, grid_n, deadends, endpoints_not_equal, generator_func]
_PercolationSuccessArray = Float[
	np.ndarray,
	"p/grid_n/deadends/endpoints_not_equal/generator_func=5",
]
 41
 42
 43class NoPercolationInConfigError(ValueError):
 44	"""raised when trying to predict the success fraction of a config that doesn't have percolation"""
 45
 46	pass
 47
 48
 49class SuccessChanceTooSmallError(ValueError):
 50	"""raised when the success fraction is below the threshold in `MazeDatasetConfig.success_fraction_compensate`"""
 51
 52	pass
 53
 54
def set_serialize_minimal_threshold(threshold: int | None) -> None:
	"set the global `SERIALIZE_MINIMAL_THRESHOLD` (pass `None` to disable minimal serialization entirely)"
	global SERIALIZE_MINIMAL_THRESHOLD  # noqa: PLW0603
	SERIALIZE_MINIMAL_THRESHOLD = threshold
 59
 60
 61def _load_maze_ctor(maze_ctor_serialized: str | dict) -> Callable:
 62	"get the maze constructor from `GENERATORS_MAP`"
 63	if isinstance(maze_ctor_serialized, dict):
 64		# this is both the new and old version of the serialization
 65		return GENERATORS_MAP[maze_ctor_serialized["__name__"]]
 66	elif isinstance(maze_ctor_serialized, str):
 67		# this is a version I switched to for a while but now we are switching back
 68		warnings.warn(
 69			"you are loading an old model/config in `_load_maze_ctor()`!!! this should not be happening, please report: "
 70			"https://github.com/understanding-search/maze-dataset/issues/new",
 71		)
 72		return GENERATORS_MAP[maze_ctor_serialized]
 73	else:
 74		err_msg: str = f"maze_ctor_serialized is of type {type(maze_ctor_serialized) = }, expected str or dict\n{maze_ctor_serialized = }"
 75		raise TypeError(err_msg)
 76
 77
# keys configure start/end selection during maze generation; values are either
# flags (bool / None) or an explicit list of allowed (row, col) coordinates
EndpointKwargsType = dict[
	typing.Literal[
		"allowed_start",
		"allowed_end",
		"deadend_start",
		"deadend_end",
		"endpoints_not_equal",
		"except_on_no_valid_endpoint",
	],
	bool | None | list[tuple[int, int]],
]
"type hint for `MazeDatasetConfig.endpoint_kwargs`"
 90
 91
 92def _load_endpoint_kwargs(data: dict) -> EndpointKwargsType:
 93	if data.get("endpoint_kwargs") is None:
 94		return dict()
 95
 96	else:
 97		return {
 98			k: (
 99				# bools and Nones are fine
100				v
101				if (isinstance(v, bool) or v is None)
102				# assume its a CoordList
103				else [tuple(x) for x in v]  # muutils/zanj saves tuples as lists
104			)
105			for k, v in data["endpoint_kwargs"].items()
106		}
107
108
@serializable_dataclass(kw_only=True, properties_to_serialize=["grid_shape"])
class _MazeDatasetConfig_base(GPTDatasetConfig):  # noqa: N801
	"""base config -- we serialize, dump to json, and hash this to get the fname. all actual variables we want to be hashed are here

	NOTE(review): the serialization lambdas below feed `stable_hash_cfg()` and thus
	`to_fname()` -- any change to their output silently changes dataset filenames.
	"""

	# NOTE: type: ignore[misc] is because it tells us non-default attributes aren't allowed after ones with defaults, but everything is kw_only

	# side length of the (square) maze grid -- see the `grid_shape` property
	grid_n: int = serializable_field()  # type: ignore[misc]

	# not comparing n_mazes is done primarily to avoid conflicts which happen during `from_config` when we have applied filters
	n_mazes: int = serializable_field(compare=False)  # type: ignore[misc]

	# maze generation function; serialized by name (plus normalized doc/source for
	# inspection) and looked up in `GENERATORS_MAP` on load via `_load_maze_ctor`
	maze_ctor: Callable = serializable_field(
		default=GENERATORS_MAP["gen_dfs"],
		serialization_fn=lambda gen_func: {
			"__name__": gen_func.__name__,
			"__module__": gen_func.__module__,
			# NOTE: this was causing hashing issues on 3.13 vs older versions because somehow,
			# the `__doc__` variable is different across versions??????? WHY???????? IT TREATS WHITESPACE DIFFERENTLY
			# so we just uh. strip it all now.
			# see:
			# https://github.com/understanding-search/maze-dataset/actions/runs/14028046497/job/39270080746?pr=53
			# https://github.com/understanding-search/maze-dataset/actions/runs/14028046497/job/39270080742?pr=53
			# https://www.diffchecker.com/tqIMSevy/
			# update: we also need to filter for empty lines. B)
			"__doc__": [
				line.strip()
				for line in string_as_lines(gen_func.__doc__)
				if line.strip()
			],
			"source_code": safe_getsource(gen_func),
		},
		loading_fn=lambda data: _load_maze_ctor(data["maze_ctor"]),
		assert_type=False,  # TODO: check the type here once muutils supports checking Callable signatures
	)

	# kwargs passed through to `maze_ctor` at generation time (e.g. percolation `p`)
	maze_ctor_kwargs: dict = serializable_field(
		default_factory=dict,
		serialization_fn=lambda kwargs: kwargs,
		loading_fn=lambda data: (
			dict()
			if data.get("maze_ctor_kwargs", None)
			is None  # this should handle the backwards compatibility
			else data["maze_ctor_kwargs"]
		),
	)

	# options controlling endpoint (start/end) selection; see `EndpointKwargsType`
	endpoint_kwargs: EndpointKwargsType = serializable_field(
		default_factory=dict,
		serialization_fn=lambda kwargs: kwargs,
		loading_fn=_load_endpoint_kwargs,
		assert_type=False,
	)

	# NOTE: this part is very hacky. the way muutils works is that it iterates over the *keys in the serialized data*,
	# and so we need to save an `None` here or this wont load the `fname` field on load
	# this is a total mess, and very confusing, and entirely my fault
	_fname_loaded: str | None = serializable_field(
		default=None,
		compare=False,
		serialization_fn=lambda _: None,
		loading_fn=lambda data: data.get("fname", None),
	)

	@property
	def grid_shape(self) -> CoordTup:
		"""return the shape of the grid as a tuple"""
		return (self.grid_n, self.grid_n)

	@property
	def grid_shape_np(self) -> Coord:
		"""return the shape of the grid as a numpy array"""
		return np.array(self.grid_shape)

	@property
	def max_grid_n(self) -> int:
		"""return the maximum of the grid shape"""
		return max(self.grid_shape)

	def _serialize_base(
		self, applied_filters__skip__collect_generation_meta: bool = True
	) -> dict:
		"""serialize the base config for use in `stable_hash_cfg()` and `to_fname()`

		- note that the _fname_loaded will always be `None` to avoid infinite recursion
		- note that we **do not** by default include information about metadata collection here,
		since otherwise loading a dataset that we minified by collecting the metadata would be impossible
		but for comparing things, we do store it when serializing properly by setting
		`applied_filters__skip__collect_generation_meta=False`
		"""
		# explicitly use the base-class serialize so subclass-only fields never affect the hash
		serialized: dict = _MazeDatasetConfig_base.serialize(self)
		if applied_filters__skip__collect_generation_meta:
			serialized["applied_filters"] = [
				x
				for x in serialized["applied_filters"]
				if x.get("name", None) != "collect_generation_meta"
			]
		return serialized

	def _stable_str_dump(self) -> str:
		"""deterministic JSON dump (sorted keys, no indent) of the base serialization, used for hashing"""
		return json.dumps(
			self._serialize_base(),
			sort_keys=True,
			indent=None,
		)

	def stable_hash_cfg(self) -> int:
		"""return a stable hash of the config"""
		# md5 is fine here: used only as a stable fingerprint, not for security
		return int.from_bytes(
			hashlib.md5(  # noqa: S324
				bytes(self._stable_str_dump(), "ascii")
			).digest(),
			"big",
		)

	def to_fname(self) -> str:
		"""return a unique identifier (valid as a filename) for this config"""
		n_mazes_str: str = shorten_numerical_to_str(self.n_mazes)
		maze_ctor_name: str = self.maze_ctor.__name__.removeprefix("gen_")
		# truncate the hash to MAZEDATASETCONFIG_FNAME_HASH_LENGTH decimal digits
		hash_id: int = self.stable_hash_cfg() % 10**MAZEDATASETCONFIG_FNAME_HASH_LENGTH
		return sanitize_fname(
			f"{self.name}-g{self.grid_n}-n{n_mazes_str}-a_{maze_ctor_name}-h{hash_id}",
		)
231
232
# NOTE: type: ignore[misc] is because it tells us non-default attributes aren't allowed after ones with defaults, but everything is kw_only
@serializable_dataclass(kw_only=True, methods_no_override=["serialize"])
class MazeDatasetConfig(_MazeDatasetConfig_base):  # type: ignore[misc]
	"""config object which is passed to `MazeDataset.from_config` to generate or load a dataset"""

	@property
	def config_version(self) -> str:
		"""return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config"""
		return "1.0"

	@property
	def versions(self) -> dict:
		"""return the versions of the config and the maze_dataset"""
		return dict(
			config=self.config_version,
			maze_dataset=importlib.metadata.version("maze_dataset"),
		)

	def serialize(self) -> dict:
		"serialize the MazeDatasetConfig with all fields and fname"
		return {
			**self._serialize_base(
				applied_filters__skip__collect_generation_meta=False
			),
			"fname": self.to_fname(),
			"versions": self.versions,
		}

	def summary(self) -> dict:
		"""return a summary of the config"""
		# run the parent summary to make sure it doesn't error
		super_summary: dict = super().summary()
		assert super_summary
		self_ser: dict = self.serialize()
		return dict(
			name=self.name,
			fname=self.to_fname(),
			sdc_hash=self.stable_hash_cfg(),
			seed=self.seed,
			seq_len_min=self.seq_len_min,
			seq_len_max=self.seq_len_max,
			applied_filters=self.applied_filters,
			grid_n=self_ser["grid_n"],
			n_mazes=self_ser["n_mazes"],
			maze_ctor_name=self_ser["maze_ctor"]["__name__"],
			maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
			endpoint_kwargs=self_ser["endpoint_kwargs"],
		)

	def _to_ps_array(self) -> _PercolationSuccessArray:
		"""Convert this config to a [p, grid_n, deadends, endpoints_not_equal, generator_func] vector.

		used in predicting the success rate

		# Raises:
		- `NoPercolationInConfigError` : if the generator is not a percolation generator,
			`maze_ctor_kwargs` has no `'p'` key, or `except_on_no_valid_endpoint` is not disabled
		"""
		# validate with explicit checks rather than `assert`, so invalid configs
		# still raise `NoPercolationInConfigError` under `python -O` (asserts are stripped)
		problem: str | None = None
		if self.maze_ctor.__name__ not in _GENERATORS_PERCOLATED:
			problem = f"generator not supported, must be a percolation generator\n{self.maze_ctor.__name__ = }, {_GENERATORS_PERCOLATED = }"
		elif "p" not in self.maze_ctor_kwargs:
			problem = f"maze_ctor_kwargs must have a 'p' (percolation value) key: {self.maze_ctor_kwargs = }"
		elif self.endpoint_kwargs.get("except_on_no_valid_endpoint", True):
			problem = f"except_on_no_valid_endpoint must be False, or else if any maze fails to generate, the whole dataset will fail: {self.endpoint_kwargs = }"

		if problem is not None:
			err_msg: str = f"invalid config for percolation success prediction: {self.summary() = }\n{problem}"
			raise NoPercolationInConfigError(
				err_msg,
			)

		endpoints_unique_flag: int = int(
			# we are pretty sure it will be an int or bool here
			self.endpoint_kwargs.get("endpoints_not_equal", True),  # type: ignore[arg-type]
		)

		# adjustment for bknutson0
		if not (
			self.endpoint_kwargs.get("deadend_start", False)
			and self.endpoint_kwargs.get("deadend_end", False)
		):
			# we didnt train on this, but if either endpoint is not required to be in a dead end
			# then requiring the endpoints to be unique does not really affect the success rate
			# (except for very small percolation values, pure percolation generation)
			endpoints_unique_flag = 0

		return np.array(
			[
				float(self.maze_ctor_kwargs["p"]),
				float(self.grid_n),
				float(
					int(
						self.endpoint_kwargs.get("deadend_start", False)  # type: ignore[arg-type]
						or self.endpoint_kwargs.get("deadend_end", False),
					),
				),
				float(endpoints_unique_flag),
				float(_GENERATORS_PERCOLATED.index(self.maze_ctor.__name__)),
			],
			dtype=np.float64,
		)

	@classmethod
	def _from_ps_array(
		cls,
		arr: _PercolationSuccessArray,
		name: str = "predict",
		n_mazes: int = 100,
		**kwargs,
	) -> "MazeDatasetConfig":
		"""Reconstruct a config from an array [p, grid_n, deadends, endpoints_not_equal, generator_func] and other config parameters.

		# Returns:
		- `MazeDatasetConfig`
			Config corresponding to `arr`
		"""
		return cls(
			name=name,
			grid_n=int(arr[1]),
			n_mazes=n_mazes,
			maze_ctor=GENERATORS_MAP[_GENERATORS_PERCOLATED[int(arr[4])]],
			maze_ctor_kwargs={"p": float(arr[0])},
			endpoint_kwargs=dict(
				deadend_start=bool(arr[2]),
				deadend_end=bool(arr[2]),
				endpoints_not_equal=bool(arr[3]),
				except_on_no_valid_endpoint=False,
			),
			**kwargs,
		)

	def success_fraction_estimate(
		self,
		except_if_all_success_expected: bool = False,
	) -> float:
		"""Estimate the success fraction of this config.

		only valid when the generator is a percolation generator,
		and endpoints are enforced to be dead ends

		this estimate comes from `estimate_dataset_fractions.ipynb` and `maze_dataset.benchmarks.sweep_fit`

		# Parameters:
		- `except_if_all_success_expected : bool`
			if `True`, re-raise `NoPercolationInConfigError` for configs without percolation
			(i.e. configs expected to always succeed);
			if `False`, such configs simply return `1.0`
			(defaults to `False`)

		# Returns:
		- `float`
			estimated success fraction

		# Raises:
		- `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `True`
		"""
		try:
			return cfg_success_predict_fn(self)

		except NoPercolationInConfigError as e:
			if except_if_all_success_expected:
				raise e  # noqa: TRY201
			return 1.0

	def success_fraction_compensate(
		self,
		safety_margin: float = 1.2,
		except_if_all_success_expected: bool = False,
		epsilon: float = 1e-2,
	) -> "MazeDatasetConfig":
		"""return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction

		# Parameters:
		- `safety_margin : float`
			safety margin to apply to the success fraction estimate
			(defaults to `1.2`, or 20% more mazes than estimated)
		- `except_if_all_success_expected : bool`
			if `True`, raise `NoPercolationInConfigError` when this config has no percolation
			(i.e. is not expected to fail); passed to `MazeDatasetConfig.success_fraction_estimate`.
			if your config isn't expected to fail, leaving this `False` might mean you generate more
			mazes than needed since `safety_margin` is still applied.
			(defaults to `False`)
		- `epsilon : float`
			raise `SuccessChanceTooSmallError` if the success fraction is below this threshold
			(defaults to `1e-2`)

		# Returns:
		- `MazeDatasetConfig`
			new config with adjusted `n_mazes`

		# Raises:
		- `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
		"""
		# compute and check the success fraction
		success_fraction: float = self.success_fraction_estimate(
			except_if_all_success_expected=except_if_all_success_expected,
		)
		if success_fraction < epsilon:
			err_msg: str = (
				f"{success_fraction = } is below the threshold of {epsilon = }"
			)
			raise SuccessChanceTooSmallError(
				err_msg,
			)

		# compute the new number of mazes, rounding up plus one for safety
		n_mazes: int = self.n_mazes
		new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1

		# put it in a new config and return (round-trip through serialize/load
		# so derived fields like fname are recomputed)
		cfg_dict: dict = self.serialize()
		cfg_dict["n_mazes"] = new_n_mazes
		return MazeDatasetConfig.load(cfg_dict)

SERIALIZE_MINIMAL_THRESHOLD: int | None = 100

If n_mazes>=SERIALIZE_MINIMAL_THRESHOLD, then the MazeDataset will use serialize_minimal. Setting to None means that serialize_minimal will never be used. Set to -1 to make calls to read use MazeDataset._load_legacy. Used for profiling only.

MAZEDATASETCONFIG_FNAME_HASH_LENGTH: int = 5

length of the hash, in characters, included in the fname of a MazeDatasetConfig

class NoPercolationInConfigError(builtins.ValueError):
44class NoPercolationInConfigError(ValueError):
45	"""raised when trying to predict the success fraction of a config that doesn't have percolation"""
46
47	pass

raised when trying to predict the success fraction of a config that doesn't have percolation

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class SuccessChanceTooSmallError(builtins.ValueError):
50class SuccessChanceTooSmallError(ValueError):
51	"""raised when the success fraction is below the threshold in `MazeDatasetConfig.success_fraction_compensate`"""
52
53	pass

raised when the success fraction is below the threshold in MazeDatasetConfig.success_fraction_compensate

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
def set_serialize_minimal_threshold(threshold: int | None) -> None:
56def set_serialize_minimal_threshold(threshold: int | None) -> None:
57	"get the global SERIALIZE_MINIMAL_THRESHOLD"
58	global SERIALIZE_MINIMAL_THRESHOLD  # noqa: PLW0603
59	SERIALIZE_MINIMAL_THRESHOLD = threshold

Set the global SERIALIZE_MINIMAL_THRESHOLD.

EndpointKwargsType = dict[typing.Literal['allowed_start', 'allowed_end', 'deadend_start', 'deadend_end', 'endpoints_not_equal', 'except_on_no_valid_endpoint'], bool | None | list[tuple[int, int]]]
@serializable_dataclass(kw_only=True, methods_no_override=['serialize'])
class MazeDatasetConfig(_MazeDatasetConfig_base):
235@serializable_dataclass(kw_only=True, methods_no_override=["serialize"])
236class MazeDatasetConfig(_MazeDatasetConfig_base):  # type: ignore[misc]
237	"""config object which is passed to `MazeDataset.from_config` to generate or load a dataset"""
238
239	@property
240	def config_version(self) -> str:
241		"""return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config"""
242		return "1.0"
243
244	@property
245	def versions(self) -> dict:
246		"""return the versions of the config and the maze_dataset"""
247		return dict(
248			config=self.config_version,
249			maze_dataset=importlib.metadata.version("maze_dataset"),
250		)
251
252	def serialize(self) -> dict:
253		"serialize the MazeDatasetConfig with all fields and fname"
254		return {
255			**self._serialize_base(
256				applied_filters__skip__collect_generation_meta=False
257			),
258			"fname": self.to_fname(),
259			"versions": self.versions,
260		}
261
262	def summary(self) -> dict:
263		"""return a summary of the config"""
264		# do we run this to make sure it doesn't error?
265		super_summary: dict = super().summary()
266		assert super_summary
267		self_ser: dict = self.serialize()
268		return dict(
269			name=self.name,
270			fname=self.to_fname(),
271			sdc_hash=self.stable_hash_cfg(),
272			seed=self.seed,
273			seq_len_min=self.seq_len_min,
274			seq_len_max=self.seq_len_max,
275			applied_filters=self.applied_filters,
276			grid_n=self_ser["grid_n"],
277			n_mazes=self_ser["n_mazes"],
278			maze_ctor_name=self_ser["maze_ctor"]["__name__"],
279			maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
280			endpoint_kwargs=self_ser["endpoint_kwargs"],
281		)
282
283	def _to_ps_array(self) -> _PercolationSuccessArray:
284		"""Convert this config to a [p, grid_n, deadends, endpoints_not_equal, generator_func] vector.
285
286		used in predicting the success rate
287		"""
288		try:
289			assert self.maze_ctor.__name__ in _GENERATORS_PERCOLATED, (
290				f"generator not supported, must be a percolation generator\n{self.maze_ctor.__name__ = }, {_GENERATORS_PERCOLATED = }"
291			)
292			assert "p" in self.maze_ctor_kwargs, (
293				f"maze_ctor_kwargs must have a 'p' (percolation value) key: {self.maze_ctor_kwargs = }"
294			)
295			assert not self.endpoint_kwargs.get("except_on_no_valid_endpoint", True), (
296				f"except_on_no_valid_endpoint must be False, or else if any maze fails to generate, the whole dataset will fail: {self.endpoint_kwargs = }"
297			)
298		except AssertionError as e:
299			err_msg: str = f"invalid config for percolation success prediction: {self.summary() = }"
300			raise NoPercolationInConfigError(
301				err_msg,
302			) from e
303
304		endpoints_unique_flag: int = int(
305			# we are pretty sure it will be an int or bool here
306			self.endpoint_kwargs.get("endpoints_not_equal", True),  # type: ignore[arg-type]
307		)
308
309		# adjustment for bknutson0
310		if not (
311			self.endpoint_kwargs.get("deadend_start", False)
312			and self.endpoint_kwargs.get("deadend_end", False)
313		):
314			# we didnt train on this, but if either endpoint is not required to be in a dead end
315			# then  requiring the endpoints to be unique does not really affect the success rate
316			# (except for very small percolation values, pure percolation generation)
317			endpoints_unique_flag = 0
318
319		return np.array(
320			[
321				float(self.maze_ctor_kwargs["p"]),
322				float(self.grid_n),
323				float(
324					int(
325						self.endpoint_kwargs.get("deadend_start", False)  # type: ignore[arg-type]
326						or self.endpoint_kwargs.get("deadend_end", False),
327					),
328				),
329				float(endpoints_unique_flag),
330				float(_GENERATORS_PERCOLATED.index(self.maze_ctor.__name__)),
331			],
332			dtype=np.float64,
333		)
334
335	@classmethod
336	def _from_ps_array(
337		cls,
338		arr: _PercolationSuccessArray,
339		name: str = "predict",
340		n_mazes: int = 100,
341		**kwargs,
342	) -> "MazeDatasetConfig":
343		"""Reconstruct a config from an array [p, grid_n, deadends, endpoints_not_equal, generator_func] and other config parameters.
344
345		# Returns:
346		- `MazeDatasetConfig`
347			Config corresponding to `arr`
348		"""
349		return cls(
350			name=name,
351			grid_n=int(arr[1]),
352			n_mazes=n_mazes,
353			maze_ctor=GENERATORS_MAP[_GENERATORS_PERCOLATED[int(arr[4])]],
354			maze_ctor_kwargs={"p": float(arr[0])},
355			endpoint_kwargs=dict(
356				deadend_start=bool(arr[2]),
357				deadend_end=bool(arr[2]),
358				endpoints_not_equal=bool(arr[3]),
359				except_on_no_valid_endpoint=False,
360			),
361			**kwargs,
362		)
363
364	def success_fraction_estimate(
365		self,
366		except_if_all_success_expected: bool = False,
367	) -> float:
368		"""Estimate the success fraction of this config.
369
370		only valid when the generator is a percolation generator,
371		and endpoints are enforced to be dead ends
372
373		this estimate comes from `estimate_dataset_fractions.ipynb` and `maze_dataset.benchmarks.sweep_fit`
374
375		# Parameters:
376		- `except_if_all_success_expected : bool`
377			if `True`, don't raise an error if the success fraction is below the threshold.
378			will always return `1.0` if the config is not expected to fail
379
380		# Returns:
381		- `float`
382			estimated success fraction
383
384		# Raises:
385		- `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `False`
386		"""
387		try:
388			return cfg_success_predict_fn(self)
389
390		except NoPercolationInConfigError as e:
391			if except_if_all_success_expected:
392				raise e  # noqa: TRY201
393			return 1.0
394
395	def success_fraction_compensate(
396		self,
397		safety_margin: float = 1.2,
398		except_if_all_success_expected: bool = False,
399		epsilon: float = 1e-2,
400	) -> "MazeDatasetConfig":
401		"""return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction
402
403		# Parameters:
404		- `safety_margin : float`
405			safety margin to apply to the success fraction estimate
406			(defaults to `1.2`, or 20% more mazes than estimated)
407		- `except_if_all_success_expected : bool`
408			if `True`, don't raise an error if the success fraction is below the threshold.
409			this is passed to `MazeDatasetConfig.success_fraction_estimate`.
410			if your config isn't expected to fail, passing this might mean you generate more mazes than needed
411			since `safety_margin` is still applied.
412			(defaults to `False`)
413		- `epsilon : float`
414			raise `SuccessChanceTooSmallError` if the success fraction is below this threshold
415			(defaults to `1e-2`)
416
417		# Returns:
418		- `MazeDatasetConfig`
419			new config with adjusted `n_mazes`
420
421		# Raises:
422		- `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
423		"""
424		# compute and check the success fraction
425		success_fraction: float = self.success_fraction_estimate(
426			except_if_all_success_expected=except_if_all_success_expected,
427		)
428		if success_fraction < epsilon:
429			err_msg: str = (
430				f"{success_fraction = } is below the threshold of {epsilon = }"
431			)
432			raise SuccessChanceTooSmallError(
433				err_msg,
434			)
435
436		# compute the new number of mazes
437		n_mazes: int = self.n_mazes
438		new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1
439
440		# put it in a new config and return
441		cfg_dict: dict = self.serialize()
442		cfg_dict["n_mazes"] = new_n_mazes
443		return MazeDatasetConfig.load(cfg_dict)

config object which is passed to MazeDataset.from_config to generate or load a dataset

MazeDatasetConfig( *, name: str, seq_len_min: int = 1, seq_len_max: int = 512, seed: int | None = 42, applied_filters: list[dict[typing.Literal['name', 'args', 'kwargs'], str | list | tuple | dict]] = <factory>, grid_n: int, n_mazes: int, maze_ctor: Callable = <function LatticeMazeGenerators.gen_dfs>, maze_ctor_kwargs: dict = <factory>, endpoint_kwargs: dict[typing.Literal['allowed_start', 'allowed_end', 'deadend_start', 'deadend_end', 'endpoints_not_equal', 'except_on_no_valid_endpoint'], bool | None | list[tuple[int, int]]] = <factory>, _fname_loaded: str | None = None)
config_version: str
239	@property
240	def config_version(self) -> str:
241		"""return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config"""
242		return "1.0"

return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config

versions: dict
244	@property
245	def versions(self) -> dict:
246		"""return the versions of the config and the maze_dataset"""
247		return dict(
248			config=self.config_version,
249			maze_dataset=importlib.metadata.version("maze_dataset"),
250		)

return the versions of the config and the maze_dataset

def serialize(self) -> dict:
252	def serialize(self) -> dict:
253		"serialize the MazeDatasetConfig with all fields and fname"
254		return {
255			**self._serialize_base(
256				applied_filters__skip__collect_generation_meta=False
257			),
258			"fname": self.to_fname(),
259			"versions": self.versions,
260		}

serialize the MazeDatasetConfig with all fields and fname

def summary(self) -> dict:
262	def summary(self) -> dict:
263		"""return a summary of the config"""
264		# do we run this to make sure it doesn't error?
265		super_summary: dict = super().summary()
266		assert super_summary
267		self_ser: dict = self.serialize()
268		return dict(
269			name=self.name,
270			fname=self.to_fname(),
271			sdc_hash=self.stable_hash_cfg(),
272			seed=self.seed,
273			seq_len_min=self.seq_len_min,
274			seq_len_max=self.seq_len_max,
275			applied_filters=self.applied_filters,
276			grid_n=self_ser["grid_n"],
277			n_mazes=self_ser["n_mazes"],
278			maze_ctor_name=self_ser["maze_ctor"]["__name__"],
279			maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
280			endpoint_kwargs=self_ser["endpoint_kwargs"],
281		)

return a summary of the config

def success_fraction_estimate(self, except_if_all_success_expected: bool = False) -> float:
364	def success_fraction_estimate(
365		self,
366		except_if_all_success_expected: bool = False,
367	) -> float:
368		"""Estimate the success fraction of this config.
369
370		only valid when the generator is a percolation generator,
371		and endpoints are enforced to be dead ends
372
373		this estimate comes from `estimate_dataset_fractions.ipynb` and `maze_dataset.benchmarks.sweep_fit`
374
375		# Parameters:
376		- `except_if_all_success_expected : bool`
377			if `True`, don't raise an error if the success fraction is below the threshold.
378			will always return `1.0` if the config is not expected to fail
379
380		# Returns:
381		- `float`
382			estimated success fraction
383
384		# Raises:
385		- `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `False`
386		"""
387		try:
388			return cfg_success_predict_fn(self)
389
390		except NoPercolationInConfigError as e:
391			if except_if_all_success_expected:
392				raise e  # noqa: TRY201
393			return 1.0

Estimate the success fraction of this config.

only valid when the generator is a percolation generator, and endpoints are enforced to be dead ends

this estimate comes from estimate_dataset_fractions.ipynb and maze_dataset.benchmarks.sweep_fit

Parameters:

  • except_if_all_success_expected : bool if True, re-raise NoPercolationInConfigError for configs without percolation (i.e. configs expected to always succeed); if False, such configs simply return 1.0

Returns:

  • float estimated success fraction

Raises:

def success_fraction_compensate( self, safety_margin: float = 1.2, except_if_all_success_expected: bool = False, epsilon: float = 0.01) -> MazeDatasetConfig:
395	def success_fraction_compensate(
396		self,
397		safety_margin: float = 1.2,
398		except_if_all_success_expected: bool = False,
399		epsilon: float = 1e-2,
400	) -> "MazeDatasetConfig":
401		"""return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction
402
403		# Parameters:
404		- `safety_margin : float`
405			safety margin to apply to the success fraction estimate
406			(defaults to `1.2`, or 20% more mazes than estimated)
407		- `except_if_all_success_expected : bool`
408			if `True`, don't raise an error if the success fraction is below the threshold.
409			this is passed to `MazeDatasetConfig.success_fraction_estimate`.
410			if your config isn't expected to fail, passing this might mean you generate more mazes than needed
411			since `safety_margin` is still applied.
412			(defaults to `False`)
413		- `epsilon : float`
414			raise `SuccessChanceTooSmallError` if the success fraction is below this threshold
415			(defaults to `1e-2`)
416
417		# Returns:
418		- `MazeDatasetConfig`
419			new config with adjusted `n_mazes`
420
421		# Raises:
422		- `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
423		"""
424		# compute and check the success fraction
425		success_fraction: float = self.success_fraction_estimate(
426			except_if_all_success_expected=except_if_all_success_expected,
427		)
428		if success_fraction < epsilon:
429			err_msg: str = (
430				f"{success_fraction = } is below the threshold of {epsilon = }"
431			)
432			raise SuccessChanceTooSmallError(
433				err_msg,
434			)
435
436		# compute the new number of mazes
437		n_mazes: int = self.n_mazes
438		new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1
439
440		# put it in a new config and return
441		cfg_dict: dict = self.serialize()
442		cfg_dict["n_mazes"] = new_n_mazes
443		return MazeDatasetConfig.load(cfg_dict)

return a new MazeDatasetConfig like this one with n_mazes adjusted to compensate for the success fraction

Parameters:

  • safety_margin : float safety margin to apply to the success fraction estimate (defaults to 1.2, or 20% more mazes than estimated)
  • except_if_all_success_expected : bool if True, don't raise an error if the success fraction is below the threshold. this is passed to MazeDatasetConfig.success_fraction_estimate. if your config isn't expected to fail, passing this might mean you generate more mazes than needed since safety_margin is still applied. (defaults to False)
  • epsilon : float raise SuccessChanceTooSmallError if the success fraction is below this threshold (defaults to 1e-2)

Returns:

Raises:

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes in an appropriately structured dict and returns an instance of the class; this method is implemented by the @serializable_dataclass decorator.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

Validate the types of all the fields on a SerializableDataclass. Calls SerializableDataclass__validate_field_type for each field.

Inherited Members
_MazeDatasetConfig_base
grid_n
n_mazes
maze_ctor
maze_ctor_kwargs
endpoint_kwargs
grid_shape
grid_shape_np
max_grid_n
stable_hash_cfg
to_fname
maze_dataset.dataset.dataset.GPTDatasetConfig
name
seq_len_min
seq_len_max
seed
applied_filters
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict