docs for maze-dataset v1.3.1

maze_dataset.dataset

MazeDatasetConfigs are used to create a MazeDataset via MazeDataset.from_config(cfg)
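For example, a dataset can be created by building a `MazeDatasetConfig` and passing it to `MazeDataset.from_config` (a minimal sketch, following the package README; `gen_dfs` is one of the available generators):

    from maze_dataset import MazeDataset, MazeDatasetConfig
    from maze_dataset.generation import LatticeMazeGenerators

    # configure a small dataset of depth-first-search mazes
    cfg: MazeDatasetConfig = MazeDatasetConfig(
        name="example",
        grid_n=5,  # 5x5 lattice
        n_mazes=4,  # number of mazes to generate
        maze_ctor=LatticeMazeGenerators.gen_dfs,
    )

    # loads from disk if a matching dataset exists, otherwise generates and saves
    dataset: MazeDataset = MazeDataset.from_config(cfg)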


 1"`MazeDatasetConfig`s are used to create a `MazeDataset` via `MazeDataset.from_config(cfg)`"
 2
 3from maze_dataset.dataset.collected_dataset import (
 4	MazeDatasetCollection,
 5	MazeDatasetCollectionConfig,
 6)
 7from maze_dataset.dataset.maze_dataset import MazeDataset
 8from maze_dataset.dataset.maze_dataset_config import MazeDatasetConfig
 9
10__all__ = [
11	# submodules
12	"collected_dataset",
13	"configs",
14	"dataset",
15	"filters",
16	"maze_dataset_config",
17	"maze_dataset",
18	"rasterized",
19	"success_predict_math",
20	# dataset classes
21	"MazeDataset",
22	"MazeDatasetConfig",
23	"MazeDatasetCollection",
24	"MazeDatasetCollectionConfig",
25]

113class MazeDataset(GPTDataset[MazeDatasetConfig]):
114	"""a maze dataset class. This is a collection of solved mazes, and should be initialized via `MazeDataset.from_config`"""
115
116	def __init__(
117		self,
118		cfg: MazeDatasetConfig,
119		mazes: typing.Sequence[SolvedMaze],
120		generation_metadata_collected: dict | None = None,
121	) -> None:
122		"""initialize a maze dataset from a config and a list of solved mazes"""
123		super().__init__()
124		self.cfg: MazeDatasetConfig = cfg
125		self.mazes: list[SolvedMaze] = list(mazes)
126		self.generation_metadata_collected: dict | None = generation_metadata_collected
127
128	# TYPING: error: Return type "MazeDataset" of "from_config" incompatible with return type "T_Dataset" in supertype "GPTDataset"  [override]
129	@classmethod
130	def from_config(  # type: ignore[override]
131		cls,
132		# TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig"  [override]
133		cfg: MazeDatasetConfig,  # type: ignore[override]
134		do_generate: bool = True,
135		load_local: bool = True,
136		save_local: bool = True,
137		zanj: ZANJ | None = None,
138		do_download: bool = True,
139		local_base_path: Path = Path("data/maze_dataset"),
140		except_on_config_mismatch: bool = True,
141		allow_generation_metadata_filter_mismatch: bool = True,
142		verbose: bool = False,
143		**kwargs,
144	) -> "MazeDataset":
145		"""create a maze dataset from a config
146
147		priority of loading:
148		1. load from local
149		2. download
150		3. generate
151
152		"""
153		return cast(
154			MazeDataset,
155			super().from_config(
156				cfg=cfg,
157				do_generate=do_generate,
158				load_local=load_local,
159				save_local=save_local,
160				zanj=zanj,
161				do_download=do_download,
162				local_base_path=local_base_path,
163				except_on_config_mismatch=except_on_config_mismatch,
164				allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
165				verbose=verbose,
166				**kwargs,
167			),
168		)
169
170	def data_hash(self) -> int:
171		"""return a hash of the data"""
172		return stable_hash(str(tuple([x.serialize() for x in self.mazes])))
173
174	def __getitem__(self, i: int) -> SolvedMaze:
175		"""get a maze by index"""
176		return self.mazes[i]
177
178	def __iter__(self) -> typing.Iterator[SolvedMaze]:
179		"""iterate over the mazes"""
180		return iter(self.mazes)
181
182	def __deepcopy__(self, memo) -> "MazeDataset":  # noqa: ANN001
183		"""deepcopy the dataset
184
185		FIX: this isn't actually a deepcopy, I think?
186		"""
187		return MazeDataset.load(self._serialize_full())
188
189	# TYPING: get type hints on the tokenizer here
190	@overload
191	def as_tokens(
192		self,
193		maze_tokenizer,  # noqa: ANN001
194		limit: int | None = None,
195		join_tokens_individual_maze: Literal[False] = False,
196	) -> list[list[str]]: ...
197	@overload
198	def as_tokens(
199		self,
200		maze_tokenizer,  # noqa: ANN001
201		limit: int | None = None,
202		join_tokens_individual_maze: Literal[True] = True,
203	) -> list[str]: ...
204	def as_tokens(
205		self,
206		maze_tokenizer,  # TODO: MazeTokenizer
207		limit: int | None = None,
208		join_tokens_individual_maze: bool = False,
209	) -> list[list[str]] | list[str]:
210		"""return the dataset as tokens according to the passed `maze_tokenizer`
211
212		the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`
213
214		if `join_tokens_individual_maze` is True, then the tokens of each maze are
215		joined with a space, and the result is a list of strings.
216		i.e.:
217
218			>>> dataset.as_tokens(join_tokens_individual_maze=False)
219			[["a", "b", "c"], ["d", "e", "f"]]
220			>>> dataset.as_tokens(join_tokens_individual_maze=True)
221			["a b c", "d e f"]
222		"""
223		output: list[list[str]] = [
224			maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
225		]
226		if join_tokens_individual_maze:
227			return [" ".join(tokens) for tokens in output]
228		else:
229			return output
230
231	def __len__(self) -> int:
232		"""return the number of mazes in the dataset"""
233		return len(self.mazes)
234
235	def __eq__(self, other: object) -> bool:
236		"""compare two datasets"""
237		if not isinstance(other, MazeDataset):
238			raise NotImplementedError(
239				"can only compare with other MazeDataset objects",
240			)
241		# TODO: compare hashes of data instead of the data itself?
242		return self.cfg == other.cfg and self.mazes == other.mazes
243
244	def assert_equal(self, other: "MazeDataset") -> None:
245		"""assert that two datasets are equal"""
246		assert isinstance(other, MazeDataset)
247		assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }"
248		assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"
249
250	@classmethod
251	def generate(
252		cls,
253		cfg: MazeDatasetConfig,
254		gen_parallel: bool = False,
255		pool_kwargs: dict | None = None,
256		verbose: bool = False,
257		# TODO: what to do when unexpected kwargs are passed?
258		**kwargs,  # noqa: ARG003
259	) -> "MazeDataset":
260		"""Generate a maze dataset given a config and some generation parameters"""
261		# Copy the config to avoid modifying the original
262		cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
263			json.loads(json.dumps(cfg.serialize())),
264		)
265
266		if pool_kwargs is None:
267			pool_kwargs = dict()
268		maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]
269
270		solved_mazes: list[SolvedMaze | None]
271		# Configure tqdm for progress bar
272		tqdm_kwargs: dict = dict(
273			total=cfg_cpy.n_mazes,
274			unit="maze",
275			desc="generating & solving mazes",
276			disable=not verbose,
277		)
278		# TODO: don't use the global unless generating in parallel!
279		if gen_parallel:
280			with multiprocessing.Pool(
281				**pool_kwargs,
282				initializer=_maze_gen_init_worker,
283				initargs=(cfg_cpy,),
284			) as pool:
285				solved_mazes = list(
286					tqdm.tqdm(
287						pool.imap(_generate_maze_helper, maze_indexes),
288						**tqdm_kwargs,
289					),
290				)
291
292		else:
293			_maze_gen_init_worker(cfg_cpy)
294			solved_mazes = list(
295				tqdm.tqdm(
296					map(
297						# TYPING:  error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]"  [arg-type]
298						# why does it think tolist() returns a string?
299						_generate_maze_helper,  # type: ignore[arg-type]
300						maze_indexes.tolist(),
301					),
302					**tqdm_kwargs,
303				),
304			)
305
306		# Filter out None values explicitly after ensuring all results are collected
307		solved_mazes_: list[SolvedMaze] = [
308			maze for maze in solved_mazes if maze is not None
309		]
310		# solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes))
311
312		# Update the config with the actual number of mazes
313		cfg_cpy.n_mazes = len(solved_mazes_)
314
315		dataset: MazeDataset = cls(
316			cfg=cfg_cpy,
317			mazes=solved_mazes_,
318		)
319
320		dataset.update_self_config()  # Call `update_self_config()` to ensure the dataset's config reflects changes
321
322		np.random.seed(cfg_cpy.seed)  # Reset the seed to the value in the config copy
323
324		return dataset
325
326	@classmethod
327	def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
328		"(not implemented yet!) download a maze dataset from the internet"
329		raise NotImplementedError("not implemented yet")
330
331	@classmethod
332	def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
333		"""load from zanj/json"""
334		if data[_FORMAT_KEY] == "MazeDataset:minimal":
335			return cls._load_minimal(data)
336		elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
337			return cls._load_minimal_soln_cat(data)
338		elif data[_FORMAT_KEY] == "MazeDataset":
339			if (
340				SERIALIZE_MINIMAL_THRESHOLD == -1
341			):  # Allow access to `_load_legacy` for profiling
342				return cls._load_legacy(data)
343			return cls._load_full(data)
344		else:
345			err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
346			raise KeyError(
347				err_msg,
348			)
349
350	@classmethod
351	def _load_full(cls, data: JSONdict) -> "MazeDataset":
352		assert data[_FORMAT_KEY] == "MazeDataset"
353		return cls(
354			cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
355			mazes=load_item_recursive(data["mazes"], tuple()),
356			generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
357		)
358
359	@classmethod
360	def _load_minimal(cls, data: JSONdict) -> "MazeDataset":
361		assert data[_FORMAT_KEY] == "MazeDataset:minimal"
362		return cls(
363			cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
364			generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
365			mazes=[
366				SolvedMaze(
367					clist,
368					soln[:slen, ...],
369				)
370				for clist, slen, soln in zip(
371					load_item_recursive(data["maze_connection_lists"], tuple()),
372					load_item_recursive(data["maze_solution_lengths"], tuple()),
373					load_item_recursive(data["maze_solutions"], tuple()),
374					strict=False,
375					# load_item_recursive(data["maze_endpoints"], tuple()),
376				)
377			],
378		)
379
380	@classmethod
381	def _load_minimal_soln_cat(cls, data: JSONdict) -> "MazeDataset":
382		assert data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat"
383
384		maze_solution_lengths = load_item_recursive(
385			data["maze_solution_lengths"],
386			tuple(),
387		)
388		maze_solutions_concat = load_item_recursive(
389			data["maze_solutions_concat"],
390			tuple(),
391		)
392		maze_solutions = np.split(
393			maze_solutions_concat,
394			np.cumsum(maze_solution_lengths)[:-1],
395			axis=0,
396		)
397
398		return cls(
399			cfg=load_item_recursive(data["cfg"], tuple()),
400			generation_metadata_collected=load_item_recursive(
401				data["generation_metadata_collected"],
402				tuple(),
403			),
404			mazes=[
405				SolvedMaze(
406					connection_list=clist,
407					solution=soln,
408				)
409				for clist, soln in zip(
410					load_item_recursive(data["maze_connection_lists"], tuple()),
411					# load_item_recursive(data["maze_endpoints"], tuple()),
412					maze_solutions,
413					strict=False,
414				)
415			],
416		)
417
418	@classmethod
419	def _load_legacy(cls, data: JSONdict) -> "MazeDataset":
420		"""Legacy `load` method from <0.5.2. Used exclusively for profiling comparison."""
421		assert data[_FORMAT_KEY] == "MazeDataset"
422		return cls(
423			**{
424				key: load_item_recursive(data[key], tuple())
425				for key in ["cfg", "mazes", "generation_metadata_collected"]
426			},
427		)
428
429	def serialize(self) -> JSONdict:
430		"""serialize to zanj/json"""
431		if (
432			SERIALIZE_MINIMAL_THRESHOLD is not None
433			and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
434		):
435			return self._serialize_minimal()
436		return self._serialize_full()
437
438	def _serialize_full(self) -> JSONdict:
439		return {
440			_FORMAT_KEY: "MazeDataset",
441			"cfg": json_serialize(self.cfg),
442			"fname": self.cfg.to_fname(),
443			"mazes": json_serialize(self.mazes),
444			"generation_metadata_collected": json_serialize(
445				self.generation_metadata_collected,
446			),
447		}
448
449	def _serialize_minimal(self) -> JSONdict:
450		"alternate serialization where metadata is collected and mazes are stored in concatenated form"
451		filtered_meta: MazeDataset
452		if self.generation_metadata_collected is None:
453			filtered_meta = self.filter_by.collect_generation_meta()
454		else:
455			filtered_meta = self
456
457		max_solution_len: int = max(m.solution.shape[0] for m in filtered_meta.mazes)
458		n_mazes: int = len(filtered_meta.mazes)
459		grid_n: int = filtered_meta.cfg.grid_n
460
461		maze_connection_lists: np.ndarray = np.empty(
462			(n_mazes, 2, grid_n, grid_n),
463			dtype=np.bool_,
464		)
465		# maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
466		maze_solution_lengths: np.ndarray = np.empty((n_mazes,), dtype=np.int32)
467		maze_solutions: np.ndarray = np.empty(
468			(n_mazes, max_solution_len, 2),
469			dtype=np.int8,
470		)
471
472		for idx, maze in enumerate(filtered_meta.mazes):
473			maze_connection_lists[idx] = maze.connection_list
474			# maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
475			maze_solution_lengths[idx] = maze.solution.shape[0]
476			maze_solutions[idx, : maze.solution.shape[0]] = maze.solution
477
478		return {
479			_FORMAT_KEY: "MazeDataset:minimal",
480			"cfg": json_serialize(filtered_meta.cfg),
481			"fname": filtered_meta.cfg.to_fname(),
482			"generation_metadata_collected": json_serialize(
483				filtered_meta.generation_metadata_collected,
484			),
485			"maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
486			# "maze_endpoints": maze_endpoints,
487			"maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
488			"maze_solutions": maze_solutions,  # type: ignore[dict-item]
489		}
490
491	def _serialize_minimal_soln_cat(self: "MazeDataset") -> JSONdict:
492		"alternate serialization where metadata is collected, and mazes and their solutions are stored in concatenated form"
493		filtered_meta: MazeDataset
494		if self.generation_metadata_collected is None:
495			filtered_meta = self.filter_by.collect_generation_meta()
496		else:
497			filtered_meta = self
498
499		maze_solution_lengths: np.ndarray = np.array(
500			[m.solution.shape[0] for m in filtered_meta.mazes],
501			dtype=np.int32,
502		)
503		n_mazes: int = len(filtered_meta.mazes)
504		grid_n: int = filtered_meta.cfg.grid_n
505		total_solution_len: int = np.sum(maze_solution_lengths)
506
507		maze_connection_lists: np.ndarray = np.empty(
508			(n_mazes, 2, grid_n, grid_n),
509			dtype=np.bool_,
510		)
511		maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
512		maze_solutions_concat: np.ndarray = np.empty(
513			(total_solution_len, 2),
514			dtype=np.int8,
515		)
516
517		solutions_running_idx: int = 0
518		for idx, maze in enumerate(filtered_meta.mazes):
519			maze_connection_lists[idx] = maze.connection_list
520			maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
521			soln_len: int = maze.solution.shape[0]
522			maze_solution_lengths[idx] = soln_len
523			maze_solutions_concat[
524				solutions_running_idx : solutions_running_idx + soln_len
525			] = maze.solution
526			solutions_running_idx += soln_len
527
528		return {
529			_FORMAT_KEY: "MazeDataset:minimal_soln_cat",
530			"cfg": json_serialize(filtered_meta.cfg),
531			"fname": filtered_meta.cfg.to_fname(),
532			"generation_metadata_collected": json_serialize(
533				filtered_meta.generation_metadata_collected,
534			),
535			"maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
536			"maze_endpoints": maze_endpoints,  # type: ignore[dict-item]
537			"maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
538			"maze_solutions_concat": maze_solutions_concat,  # type: ignore[dict-item]
539		}
540
541	def update_self_config(self) -> None:
542		"""update the config to match the current state of the dataset (number of mazes, such as after filtering)"""
543		if self.cfg.n_mazes != len(self.mazes):
544			warnings.warn(
545				f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}",
546			)
547			self.cfg.n_mazes = len(self.mazes)
548
549	def custom_maze_filter(
550		self,
551		method: typing.Callable[[SolvedMaze], bool],
552		**kwargs,
553	) -> "MazeDataset":
554		"""filter the dataset using a custom method"""
555		output: MazeDataset = MazeDataset(
556			cfg=copy.deepcopy(self.cfg),
557			mazes=[m for m in self.mazes if method(m, **kwargs)],
558		)
559		output.cfg.applied_filters.append(
560			{
561				"name": f"__custom__:{method.__name__}",
562				"kwargs": kwargs,
563			},
564		)
565		output.update_self_config()
566		return output

a maze dataset class. This is a collection of solved mazes, and should be initialized via MazeDataset.from_config

MazeDataset( cfg: MazeDatasetConfig, mazes: Sequence[maze_dataset.SolvedMaze], generation_metadata_collected: dict | None = None)
116	def __init__(
117		self,
118		cfg: MazeDatasetConfig,
119		mazes: typing.Sequence[SolvedMaze],
120		generation_metadata_collected: dict | None = None,
121	) -> None:
122		"""initialize a maze dataset from a config and a list of solved mazes"""
123		super().__init__()
124		self.cfg: MazeDatasetConfig = cfg
125		self.mazes: list[SolvedMaze] = list(mazes)
126		self.generation_metadata_collected: dict | None = generation_metadata_collected

initialize a maze dataset from a config and a list of solved mazes

generation_metadata_collected: dict | None
@classmethod
def from_config( cls, cfg: MazeDatasetConfig, do_generate: bool = True, load_local: bool = True, save_local: bool = True, zanj: zanj.zanj.ZANJ | None = None, do_download: bool = True, local_base_path: pathlib.Path = PosixPath('data/maze_dataset'), except_on_config_mismatch: bool = True, allow_generation_metadata_filter_mismatch: bool = True, verbose: bool = False, **kwargs) -> MazeDataset:
129	@classmethod
130	def from_config(  # type: ignore[override]
131		cls,
132		# TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig"  [override]
133		cfg: MazeDatasetConfig,  # type: ignore[override]
134		do_generate: bool = True,
135		load_local: bool = True,
136		save_local: bool = True,
137		zanj: ZANJ | None = None,
138		do_download: bool = True,
139		local_base_path: Path = Path("data/maze_dataset"),
140		except_on_config_mismatch: bool = True,
141		allow_generation_metadata_filter_mismatch: bool = True,
142		verbose: bool = False,
143		**kwargs,
144	) -> "MazeDataset":
145		"""create a maze dataset from a config
146
147		priority of loading:
148		1. load from local
149		2. download
150		3. generate
151
152		"""
153		return cast(
154			MazeDataset,
155			super().from_config(
156				cfg=cfg,
157				do_generate=do_generate,
158				load_local=load_local,
159				save_local=save_local,
160				zanj=zanj,
161				do_download=do_download,
162				local_base_path=local_base_path,
163				except_on_config_mismatch=except_on_config_mismatch,
164				allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
165				verbose=verbose,
166				**kwargs,
167			),
168		)

create a maze dataset from a config

priority of loading:

  1. load from local
  2. download
  3. generate
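
Each step can be toggled individually. A sketch, reusing the `cfg` from the module-level example above and the default `local_base_path` from the signature:

    from pathlib import Path

    dataset = MazeDataset.from_config(
        cfg,
        load_local=False,  # skip step 1: ignore any existing local file
        do_download=False,  # skip step 2 (download is not implemented yet anyway)
        do_generate=True,  # step 3: fall through to generation
        save_local=True,  # write the generated dataset to local_base_path
        local_base_path=Path("data/maze_dataset"),
    )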
def data_hash(self) -> int:
170	def data_hash(self) -> int:
171		"""return a hash of the data"""
172		return stable_hash(str(tuple([x.serialize() for x in self.mazes])))

return a hash of the data

def as_tokens( self, maze_tokenizer, limit: int | None = None, join_tokens_individual_maze: bool = False) -> list[list[str]] | list[str]:
204	def as_tokens(
205		self,
206		maze_tokenizer,  # TODO: MazeTokenizer
207		limit: int | None = None,
208		join_tokens_individual_maze: bool = False,
209	) -> list[list[str]] | list[str]:
210		"""return the dataset as tokens according to the passed `maze_tokenizer`
211
212		the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`
213
214		if `join_tokens_individual_maze` is True, then the tokens of each maze are
215		joined with a space, and the result is a list of strings.
216		i.e.:
217
218			>>> dataset.as_tokens(join_tokens_individual_maze=False)
219			[["a", "b", "c"], ["d", "e", "f"]]
220			>>> dataset.as_tokens(join_tokens_individual_maze=True)
221			["a b c", "d e f"]
222		"""
223		output: list[list[str]] = [
224			maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
225		]
226		if join_tokens_individual_maze:
227			return [" ".join(tokens) for tokens in output]
228		else:
229			return output

return the dataset as tokens according to the passed maze_tokenizer

the maze_tokenizer should be either a MazeTokenizer or a MazeTokenizerModular

if join_tokens_individual_maze is True, then the tokens of each maze are joined with a space, and the result is a list of strings. i.e.:

    >>> dataset.as_tokens(join_tokens_individual_maze=False)
    [["a", "b", "c"], ["d", "e", "f"]]
    >>> dataset.as_tokens(join_tokens_individual_maze=True)
    ["a b c", "d e f"]
def assert_equal(self, other: MazeDataset) -> None:
244	def assert_equal(self, other: "MazeDataset") -> None:
245		"""assert that two datasets are equal"""
246		assert isinstance(other, MazeDataset)
247		assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }"
248		assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"

assert that two datasets are equal

@classmethod
def generate( cls, cfg: MazeDatasetConfig, gen_parallel: bool = False, pool_kwargs: dict | None = None, verbose: bool = False, **kwargs) -> MazeDataset:
250	@classmethod
251	def generate(
252		cls,
253		cfg: MazeDatasetConfig,
254		gen_parallel: bool = False,
255		pool_kwargs: dict | None = None,
256		verbose: bool = False,
257		# TODO: what to do when unexpected kwargs are passed?
258		**kwargs,  # noqa: ARG003
259	) -> "MazeDataset":
260		"""Generate a maze dataset given a config and some generation parameters"""
261		# Copy the config to avoid modifying the original
262		cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
263			json.loads(json.dumps(cfg.serialize())),
264		)
265
266		if pool_kwargs is None:
267			pool_kwargs = dict()
268		maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]
269
270		solved_mazes: list[SolvedMaze | None]
271		# Configure tqdm for progress bar
272		tqdm_kwargs: dict = dict(
273			total=cfg_cpy.n_mazes,
274			unit="maze",
275			desc="generating & solving mazes",
276			disable=not verbose,
277		)
278		# TODO: don't use the global unless generating in parallel!
279		if gen_parallel:
280			with multiprocessing.Pool(
281				**pool_kwargs,
282				initializer=_maze_gen_init_worker,
283				initargs=(cfg_cpy,),
284			) as pool:
285				solved_mazes = list(
286					tqdm.tqdm(
287						pool.imap(_generate_maze_helper, maze_indexes),
288						**tqdm_kwargs,
289					),
290				)
291
292		else:
293			_maze_gen_init_worker(cfg_cpy)
294			solved_mazes = list(
295				tqdm.tqdm(
296					map(
297						# TYPING:  error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]"  [arg-type]
298						# why does it think tolist() returns a string?
299						_generate_maze_helper,  # type: ignore[arg-type]
300						maze_indexes.tolist(),
301					),
302					**tqdm_kwargs,
303				),
304			)
305
306		# Filter out None values explicitly after ensuring all results are collected
307		solved_mazes_: list[SolvedMaze] = [
308			maze for maze in solved_mazes if maze is not None
309		]
310		# solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes))
311
312		# Update the config with the actual number of mazes
313		cfg_cpy.n_mazes = len(solved_mazes_)
314
315		dataset: MazeDataset = cls(
316			cfg=cfg_cpy,
317			mazes=solved_mazes_,
318		)
319
320		dataset.update_self_config()  # Call `update_self_config()` to ensure the dataset's config reflects changes
321
322		np.random.seed(cfg_cpy.seed)  # Reset the seed to the value in the config copy
323
324		return dataset

Generate a maze dataset given a config and some generation parameters
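
A sketch of calling `generate` directly, bypassing the load/download logic of `from_config`; `pool_kwargs` is forwarded to `multiprocessing.Pool`:

    from maze_dataset import MazeDataset, MazeDatasetConfig
    from maze_dataset.generation import LatticeMazeGenerators

    cfg = MazeDatasetConfig(
        name="parallel-example",
        grid_n=8,
        n_mazes=1000,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
    )
    dataset = MazeDataset.generate(
        cfg,
        gen_parallel=True,  # generate across worker processes
        pool_kwargs=dict(processes=4),  # forwarded to multiprocessing.Pool
        verbose=True,  # show the tqdm progress bar
    )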

@classmethod
def download( cls, cfg: MazeDatasetConfig, **kwargs) -> MazeDataset:
326	@classmethod
327	def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
328		"(not implemented yet!) download a maze dataset from the internet"
329		raise NotImplementedError("not implemented yet")

(not implemented yet!) download a maze dataset from the internet

@classmethod
def load( cls: type[MazeDataset], data: Dict[str, Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]) -> MazeDataset:
331	@classmethod
332	def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
333		"""load from zanj/json"""
334		if data[_FORMAT_KEY] == "MazeDataset:minimal":
335			return cls._load_minimal(data)
336		elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
337			return cls._load_minimal_soln_cat(data)
338		elif data[_FORMAT_KEY] == "MazeDataset":
339			if (
340				SERIALIZE_MINIMAL_THRESHOLD == -1
341			):  # Allow access to `_load_legacy` for profiling
342				return cls._load_legacy(data)
343			return cls._load_full(data)
344		else:
345			err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
346			raise KeyError(
347				err_msg,
348			)

load from zanj/json

def serialize( self) -> Dict[str, Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]:
429	def serialize(self) -> JSONdict:
430		"""serialize to zanj/json"""
431		if (
432			SERIALIZE_MINIMAL_THRESHOLD is not None
433			and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
434		):
435			return self._serialize_minimal()
436		return self._serialize_full()

serialize to zanj/json
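
Serialization round-trips through ZANJ. A minimal sketch, assuming the `ZANJ().save` / `ZANJ().read` interface from the `zanj` package:

    from zanj import ZANJ
    from maze_dataset import MazeDataset

    z = ZANJ()
    z.save(dataset, "my_dataset.zanj")  # calls dataset.serialize() under the hood
    loaded: MazeDataset = z.read("my_dataset.zanj")  # dispatches to MazeDataset.load
    assert loaded == dataset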

def update_self_config(self) -> None:
541	def update_self_config(self) -> None:
542		"""update the config to match the current state of the dataset (number of mazes, such as after filtering)"""
543		if self.cfg.n_mazes != len(self.mazes):
544			warnings.warn(
545				f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}",
546			)
547			self.cfg.n_mazes = len(self.mazes)

update the config to match the current state of the dataset (number of mazes, such as after filtering)

def custom_maze_filter( self, method: Callable[[maze_dataset.SolvedMaze], bool], **kwargs) -> MazeDataset:
549	def custom_maze_filter(
550		self,
551		method: typing.Callable[[SolvedMaze], bool],
552		**kwargs,
553	) -> "MazeDataset":
554		"""filter the dataset using a custom method"""
555		output: MazeDataset = MazeDataset(
556			cfg=copy.deepcopy(self.cfg),
557			mazes=[m for m in self.mazes if method(m, **kwargs)],
558		)
559		output.cfg.applied_filters.append(
560			{
561				"name": f"__custom__:{method.__name__}",
562				"kwargs": kwargs,
563			},
564		)
565		output.update_self_config()
566		return output

filter the dataset using a custom method
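
A sketch of a custom filter; extra keyword arguments are forwarded to the predicate and recorded in the config's `applied_filters`:

    from maze_dataset import SolvedMaze

    def min_path_length(maze: SolvedMaze, min_len: int = 5) -> bool:
        "keep mazes whose solution path is at least `min_len` steps long"
        return maze.solution.shape[0] >= min_len

    filtered = dataset.custom_maze_filter(min_path_length, min_len=5)

    # built-in filters live under `filter_by`, e.g.:
    # filtered = dataset.filter_by.path_length(min_length=5)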

@serializable_dataclass(kw_only=True, methods_no_override=['serialize'])
class MazeDatasetConfig(maze_dataset.dataset.maze_dataset_config._MazeDatasetConfig_base):
235@serializable_dataclass(kw_only=True, methods_no_override=["serialize"])
236class MazeDatasetConfig(_MazeDatasetConfig_base):  # type: ignore[misc]
237	"""config object which is passed to `MazeDataset.from_config` to generate or load a dataset"""
238
239	@property
240	def config_version(self) -> str:
241		"""return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config"""
242		return "1.0"
243
244	@property
245	def versions(self) -> dict:
246		"""return the versions of the config and the maze_dataset"""
247		return dict(
248			config=self.config_version,
249			maze_dataset=importlib.metadata.version("maze_dataset"),
250		)
251
252	def serialize(self) -> dict:
253		"serialize the MazeDatasetConfig with all fields and fname"
254		return {
255			**self._serialize_base(
256				applied_filters__skip__collect_generation_meta=False
257			),
258			"fname": self.to_fname(),
259			"versions": self.versions,
260		}
261
262	def summary(self) -> dict:
263		"""return a summary of the config"""
264		# do we run this to make sure it doesn't error?
265		super_summary: dict = super().summary()
266		assert super_summary
267		self_ser: dict = self.serialize()
268		return dict(
269			name=self.name,
270			fname=self.to_fname(),
271			sdc_hash=self.stable_hash_cfg(),
272			seed=self.seed,
273			seq_len_min=self.seq_len_min,
274			seq_len_max=self.seq_len_max,
275			applied_filters=self.applied_filters,
276			grid_n=self_ser["grid_n"],
277			n_mazes=self_ser["n_mazes"],
278			maze_ctor_name=self_ser["maze_ctor"]["__name__"],
279			maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
280			endpoint_kwargs=self_ser["endpoint_kwargs"],
281		)
282
283	def _to_ps_array(self) -> _PercolationSuccessArray:
284		"""Convert this config to a [p, grid_n, deadends, endpoints_not_equal, generator_func] vector.
285
286		used in predicting the success rate
287		"""
288		try:
289			assert self.maze_ctor.__name__ in _GENERATORS_PERCOLATED, (
290				f"generator not supported, must be a percolation generator\n{self.maze_ctor.__name__ = }, {_GENERATORS_PERCOLATED = }"
291			)
292			assert "p" in self.maze_ctor_kwargs, (
293				f"maze_ctor_kwargs must have a 'p' (percolation value) key: {self.maze_ctor_kwargs = }"
294			)
295			assert not self.endpoint_kwargs.get("except_on_no_valid_endpoint", True), (
296				f"except_on_no_valid_endpoint must be False, or else if any maze fails to generate, the whole dataset will fail: {self.endpoint_kwargs = }"
297			)
298		except AssertionError as e:
299			err_msg: str = f"invalid config for percolation success prediction: {self.summary() = }"
300			raise NoPercolationInConfigError(
301				err_msg,
302			) from e
303
304		endpoints_unique_flag: int = int(
305			# we are pretty sure it will be an int or bool here
306			self.endpoint_kwargs.get("endpoints_not_equal", True),  # type: ignore[arg-type]
307		)
308
309		# adjustment for bknutson0
310		if not (
311			self.endpoint_kwargs.get("deadend_start", False)
312			and self.endpoint_kwargs.get("deadend_end", False)
313		):
314			# we didn't train on this, but if either endpoint is not required to be in a dead end
315			# then requiring the endpoints to be unique does not really affect the success rate
316			# (except for very small percolation values, pure percolation generation)
317			endpoints_unique_flag = 0
318
319		return np.array(
320			[
321				float(self.maze_ctor_kwargs["p"]),
322				float(self.grid_n),
323				float(
324					int(
325						self.endpoint_kwargs.get("deadend_start", False)  # type: ignore[arg-type]
326						or self.endpoint_kwargs.get("deadend_end", False),
327					),
328				),
329				float(endpoints_unique_flag),
330				float(_GENERATORS_PERCOLATED.index(self.maze_ctor.__name__)),
331			],
332			dtype=np.float64,
333		)
334
335	@classmethod
336	def _from_ps_array(
337		cls,
338		arr: _PercolationSuccessArray,
339		name: str = "predict",
340		n_mazes: int = 100,
341		**kwargs,
342	) -> "MazeDatasetConfig":
343		"""Reconstruct a config from an array [p, grid_n, deadends, endpoints_not_equal, generator_func] and other config parameters.
344
345		# Returns:
346		- `MazeDatasetConfig`
347			Config corresponding to `arr`
348		"""
349		return cls(
350			name=name,
351			grid_n=int(arr[1]),
352			n_mazes=n_mazes,
353			maze_ctor=GENERATORS_MAP[_GENERATORS_PERCOLATED[int(arr[4])]],
354			maze_ctor_kwargs={"p": float(arr[0])},
355			endpoint_kwargs=dict(
356				deadend_start=bool(arr[2]),
357				deadend_end=bool(arr[2]),
358				endpoints_not_equal=bool(arr[3]),
359				except_on_no_valid_endpoint=False,
360			),
361			**kwargs,
362		)
363
364	def success_fraction_estimate(
365		self,
366		except_if_all_success_expected: bool = False,
367	) -> float:
368		"""Estimate the success fraction of this config.
369
370		only valid when the generator is a percolation generator,
371		and endpoints are enforced to be dead ends
372
373		this estimate comes from `estimate_dataset_fractions.ipynb` and `maze_dataset.benchmarks.sweep_fit`
374
375		# Parameters:
376		- `except_if_all_success_expected : bool`
377			if `True`, raise `NoPercolationInConfigError` when the config is not expected to fail
378			(i.e. is not a percolation config); if `False`, return `1.0` in that case
379
380		# Returns:
381		- `float`
382			estimated success fraction
383
384		# Raises:
385		- `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `True`
386		"""
387		try:
388			return cfg_success_predict_fn(self)
389
390		except NoPercolationInConfigError as e:
391			if except_if_all_success_expected:
392				raise e  # noqa: TRY201
393			return 1.0
394
395	def success_fraction_compensate(
396		self,
397		safety_margin: float = 1.2,
398		except_if_all_success_expected: bool = False,
399		epsilon: float = 1e-2,
400	) -> "MazeDatasetConfig":
401		"""return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction
402
403		# Parameters:
404		- `safety_margin : float`
405			safety margin to apply to the success fraction estimate
406			(defaults to `1.2`, or 20% more mazes than estimated)
407		- `except_if_all_success_expected : bool`
408			if `True`, raise an error when the config is not expected to fail;
409			this is passed to `MazeDatasetConfig.success_fraction_estimate`.
410			if `False` and your config isn't expected to fail, the estimate is `1.0`,
411			so `safety_margin` still applies and you may generate more mazes than needed.
412			(defaults to `False`)
413		- `epsilon : float`
414			raise `SuccessChanceTooSmallError` if the success fraction is below this threshold
415			(defaults to `1e-2`)
416
417		# Returns:
418		- `MazeDatasetConfig`
419			new config with adjusted `n_mazes`
420
421		# Raises:
422		- `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
423		"""
424		# compute and check the success fraction
425		success_fraction: float = self.success_fraction_estimate(
426			except_if_all_success_expected=except_if_all_success_expected,
427		)
428		if success_fraction < epsilon:
429			err_msg: str = (
430				f"{success_fraction = } is below the threshold of {epsilon = }"
431			)
432			raise SuccessChanceTooSmallError(
433				err_msg,
434			)
435
436		# compute the new number of mazes
437		n_mazes: int = self.n_mazes
438		new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1
439
440		# put it in a new config and return
441		cfg_dict: dict = self.serialize()
442		cfg_dict["n_mazes"] = new_n_mazes
443		return MazeDatasetConfig.load(cfg_dict)

config object which is passed to MazeDataset.from_config to generate or load a dataset

MazeDatasetConfig( *, name: str, seq_len_min: int = 1, seq_len_max: int = 512, seed: int | None = 42, applied_filters: list[dict[typing.Literal['name', 'args', 'kwargs'], str | list | tuple | dict]] = <factory>, grid_n: int, n_mazes: int, maze_ctor: Callable = <function LatticeMazeGenerators.gen_dfs>, maze_ctor_kwargs: dict = <factory>, endpoint_kwargs: dict[typing.Literal['allowed_start', 'allowed_end', 'deadend_start', 'deadend_end', 'endpoints_not_equal', 'except_on_no_valid_endpoint'], bool | None | list[tuple[int, int]]] = <factory>, _fname_loaded: str | None = None)
config_version: str
239	@property
240	def config_version(self) -> str:
241		"""return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config"""
242		return "1.0"

return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config

versions: dict
244	@property
245	def versions(self) -> dict:
246		"""return the versions of the config and the maze_dataset"""
247		return dict(
248			config=self.config_version,
249			maze_dataset=importlib.metadata.version("maze_dataset"),
250		)

return the versions of the config and the maze_dataset

def serialize(self) -> dict:
252	def serialize(self) -> dict:
253		"serialize the MazeDatasetConfig with all fields and fname"
254		return {
255			**self._serialize_base(
256				applied_filters__skip__collect_generation_meta=False
257			),
258			"fname": self.to_fname(),
259			"versions": self.versions,
260		}

serialize the MazeDatasetConfig with all fields and fname

def summary(self) -> dict:
262	def summary(self) -> dict:
263		"""return a summary of the config"""
264		# do we run this to make sure it doesn't error?
265		super_summary: dict = super().summary()
266		assert super_summary
267		self_ser: dict = self.serialize()
268		return dict(
269			name=self.name,
270			fname=self.to_fname(),
271			sdc_hash=self.stable_hash_cfg(),
272			seed=self.seed,
273			seq_len_min=self.seq_len_min,
274			seq_len_max=self.seq_len_max,
275			applied_filters=self.applied_filters,
276			grid_n=self_ser["grid_n"],
277			n_mazes=self_ser["n_mazes"],
278			maze_ctor_name=self_ser["maze_ctor"]["__name__"],
279			maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
280			endpoint_kwargs=self_ser["endpoint_kwargs"],
281		)

return a summary of the config

def success_fraction_estimate(self, except_if_all_success_expected: bool = False) -> float:
364	def success_fraction_estimate(
365		self,
366		except_if_all_success_expected: bool = False,
367	) -> float:
368		"""Estimate the success fraction of this config.
369
370		only valid when the generator is a percolation generator,
371		and endpoints are enforced to be dead ends
372
373		this estimate comes from `estimate_dataset_fractions.ipynb` and `maze_dataset.benchmarks.sweep_fit`
374
375		# Parameters:
376		- `except_if_all_success_expected : bool`
377			if `True`, raise `NoPercolationInConfigError` when the config is not expected to fail
378			(i.e. is not a percolation config); if `False`, return `1.0` in that case
379
380		# Returns:
381		- `float`
382			estimated success fraction
383
384		# Raises:
385		- `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `True`
386		"""
387		try:
388			return cfg_success_predict_fn(self)
389
390		except NoPercolationInConfigError as e:
391			if except_if_all_success_expected:
392				raise e  # noqa: TRY201
393			return 1.0

Estimate the success fraction of this config.

only valid when the generator is a percolation generator, and endpoints are enforced to be dead ends

this estimate comes from estimate_dataset_fractions.ipynb and maze_dataset.benchmarks.sweep_fit

Parameters:

  • except_if_all_success_expected : bool if True, raise NoPercolationInConfigError when the config is not expected to fail (i.e. is not a percolation config); if False, return 1.0 in that case

Returns:

  • float estimated success fraction

Raises:

  • NoPercolationInConfigError : if the config is not expected to fail, and except_if_all_success_expected is True
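
A sketch, assuming `gen_percolation` as the percolation generator; `except_on_no_valid_endpoint=False` is required, since otherwise a single failed maze would abort the whole dataset:

    from maze_dataset import MazeDatasetConfig
    from maze_dataset.generation import LatticeMazeGenerators

    cfg = MazeDatasetConfig(
        name="percolation-example",
        grid_n=8,
        n_mazes=100,
        maze_ctor=LatticeMazeGenerators.gen_percolation,
        maze_ctor_kwargs=dict(p=0.4),  # percolation probability
        endpoint_kwargs=dict(
            deadend_start=True,
            deadend_end=True,
            except_on_no_valid_endpoint=False,  # required for the estimate
        ),
    )
    frac: float = cfg.success_fraction_estimate()
    print(f"expected fraction of mazes that generate successfully: {frac:.2f}")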
def success_fraction_compensate( self, safety_margin: float = 1.2, except_if_all_success_expected: bool = False, epsilon: float = 0.01) -> MazeDatasetConfig:
395	def success_fraction_compensate(
396		self,
397		safety_margin: float = 1.2,
398		except_if_all_success_expected: bool = False,
399		epsilon: float = 1e-2,
400	) -> "MazeDatasetConfig":
401		"""return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction
402
403		# Parameters:
404		- `safety_margin : float`
405			safety margin to apply to the success fraction estimate
406			(defaults to `1.2`, or 20% more mazes than estimated)
407		- `except_if_all_success_expected : bool`
408			if `True`, raise an error when the config is not expected to fail;
409			this is passed to `MazeDatasetConfig.success_fraction_estimate`.
410			if `False` and your config isn't expected to fail, the estimate is `1.0`,
411			so `safety_margin` still applies and you may generate more mazes than needed.
412			(defaults to `False`)
413		- `epsilon : float`
414			raise `SuccessChanceTooSmallError` if the success fraction is below this threshold
415			(defaults to `1e-2`)
416
417		# Returns:
418		- `MazeDatasetConfig`
419			new config with adjusted `n_mazes`
420
421		# Raises:
422		- `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
423		"""
424		# compute and check the success fraction
425		success_fraction: float = self.success_fraction_estimate(
426			except_if_all_success_expected=except_if_all_success_expected,
427		)
428		if success_fraction < epsilon:
429			err_msg: str = (
430				f"{success_fraction = } is below the threshold of {epsilon = }"
431			)
432			raise SuccessChanceTooSmallError(
433				err_msg,
434			)
435
436		# compute the new number of mazes
437		n_mazes: int = self.n_mazes
438		new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1
439
440		# put it in a new config and return
441		cfg_dict: dict = self.serialize()
442		cfg_dict["n_mazes"] = new_n_mazes
443		return MazeDatasetConfig.load(cfg_dict)

return a new MazeDatasetConfig like this one with n_mazes adjusted to compensate for the success fraction

Parameters:

  • safety_margin : float safety margin to apply to the success fraction estimate (defaults to 1.2, or 20% more mazes than estimated)
  • except_if_all_success_expected : bool if True, raise an error when the config is not expected to fail; this is passed to MazeDatasetConfig.success_fraction_estimate. if False and your config isn't expected to fail, the estimate is 1.0, so safety_margin still applies and you may generate more mazes than needed. (defaults to False)
  • epsilon : float raise SuccessChanceTooSmallError if the success fraction is below this threshold (defaults to 1e-2)

Returns:

  • MazeDatasetConfig new config with adjusted n_mazes

Raises:

  • SuccessChanceTooSmallError : if the computed success fraction is below epsilon
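
Continuing the percolation config from the `success_fraction_estimate` example above (a sketch):

    from maze_dataset import MazeDataset

    # ask for ~100 usable mazes, compensating for expected generation failures
    cfg_adjusted: MazeDatasetConfig = cfg.success_fraction_compensate(safety_margin=1.2)
    print(cfg_adjusted.n_mazes)  # > 100 whenever the estimated success fraction is < 1

    dataset = MazeDataset.from_config(cfg_adjusted)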
@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

class MazeDatasetCollection(typing.Generic[~T_DatasetConfig]):
 84class MazeDatasetCollection(GPTDataset):
 85	"""a collection of maze datasets"""
 86
 87	def __init__(
 88		self,
 89		cfg: MazeDatasetCollectionConfig,
 90		maze_datasets: list[MazeDataset],
 91		generation_metadata_collected: dict | None = None,
 92	) -> None:
 93		"initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s"
 94		super().__init__()
 95		self.cfg: MazeDatasetCollectionConfig = cfg
 96		self.maze_datasets: list[MazeDataset] = list(maze_datasets)
 97		for c, ds in zip(
 98			self.cfg.maze_dataset_configs,
 99			self.maze_datasets,
100			strict=False,
101		):
102			assert c.name == ds.cfg.name
103			assert c == ds.cfg
104
105		self.generation_metadata_collected: dict | None = generation_metadata_collected
106
107	@property
108	def dataset_lengths(self) -> list[int]:
109		"""return the lengths of each dataset in the collection"""
110		return [len(dataset) for dataset in self.maze_datasets]
111
112	@property
113	def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
114		"""return the cumulative lengths of each dataset in the collection"""
115		return np.array(list(itertools.accumulate(self.dataset_lengths)))
116
117	@cached_property
118	def mazes(self) -> list[LatticeMaze]:
119		"single list of all mazes in the collection"
120		return list(
121			itertools.chain.from_iterable(
122				dataset.mazes for dataset in self.maze_datasets
123			),
124		)
125
126	def __len__(self) -> int:
127		"""return the total number of mazes in the collection"""
128		return sum(len(dataset) for dataset in self.maze_datasets)
129
130	def __getitem__(self, index: int) -> LatticeMaze:
131		"get a maze by index"
132		# find which dataset the index belongs to
133		# we add 1, since np.searchsorted returns the
134		# index of the last element that is strictly less than the target
135		# while we want the index of the last element less than or equal to the target
136		dataset_idx: int = int(np.searchsorted(self.dataset_cum_lengths, index + 1))
137		index_adjusted: int = index
138		if dataset_idx > 0:
139			# if the index is 0, `dataset_idx - 1` will be -1.
140			# We just want to use the base index
141			index_adjusted -= self.dataset_cum_lengths[dataset_idx - 1]
142		return self.maze_datasets[dataset_idx][index_adjusted]
143
144	@classmethod
145	def generate(
146		cls,
147		cfg: MazeDatasetCollectionConfig,
148		**kwargs,
149	) -> "MazeDatasetCollection":
150		"""generate a dataset collection from a config"""
151		datasets = [
152			MazeDataset.generate(config, **kwargs)
153			for config in cfg.maze_dataset_configs
154		]
155		return cls(cfg, datasets)
156
157	@classmethod
158	def download(
159		cls,
160		cfg: MazeDatasetCollectionConfig,
161		**kwargs,
162	) -> "MazeDatasetCollection":
163		"(not implemented!) download a dataset collection from a config"
164		datasets = [
165			MazeDataset.download(config, **kwargs)
166			for config in cfg.maze_dataset_configs
167		]
168		return cls(cfg, datasets)
169
170	def serialize(self) -> JSONdict:
171		"""serialize the dataset collection"""
172		return {
173			_FORMAT_KEY: "MazeDatasetCollection",
174			"cfg": self.cfg.serialize(),
175			"maze_datasets": [dataset.serialize() for dataset in self.maze_datasets],
176			"generation_metadata_collected": json_serialize(
177				self.generation_metadata_collected,
178			),
179		}
180
181	@classmethod
182	def load(cls, data: JSONdict) -> "MazeDatasetCollection":
183		"""load the dataset collection from the representation created by `serialize`"""
184		assert data[_FORMAT_KEY] == "MazeDatasetCollection"
185		return cls(
186			**{
187				key: load_item_recursive(data[key], tuple())
188				for key in ["cfg", "maze_datasets", "generation_metadata_collected"]
189			},
190		)
191
192	# TODO: remove duplication with MazeDatasetConfig().as_tokens() somehow?
193	def as_tokens(
194		self,
195		# TODO: MazeTokenizer
196		maze_tokenizer,  # noqa: ANN001
197		limit: int | None = None,
198		join_tokens_individual_maze: bool = False,
199	) -> list[list[str]] | list[str]:
200		"""return the dataset as tokens
201
202		if join_tokens_individual_maze is True, then the tokens of each maze are
203		joined with a space, and the result is a list of strings.
204		i.e.:
205		>>> dataset.as_tokens(join_tokens_individual_maze=False)
206		[["a", "b", "c"], ["d", "e", "f"]]
207		>>> dataset.as_tokens(join_tokens_individual_maze=True)
208		["a b c", "d e f"]
209		"""
210		output: list[list[str]] = [
211			maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
212		]
213		if join_tokens_individual_maze:
214			return [" ".join(tokens) for tokens in output]
215		else:
216			return output
217
218	def update_self_config(self) -> None:
219		"update the config to match the number of mazes, and update the underlying configs of each dataset"
220		# TODO: why can't we set this directly? it's not frozen, and it seems to work in a regular MazeDataset
221		self.cfg.__dict__["n_mazes"] = len(self)
222		for dataset in self.maze_datasets:
223			dataset.update_self_config()
224
225		self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets]

a collection of maze datasets

MazeDatasetCollection( cfg: MazeDatasetCollectionConfig, maze_datasets: list[MazeDataset], generation_metadata_collected: dict | None = None)
 87	def __init__(
 88		self,
 89		cfg: MazeDatasetCollectionConfig,
 90		maze_datasets: list[MazeDataset],
 91		generation_metadata_collected: dict | None = None,
 92	) -> None:
 93		"initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s"
 94		super().__init__()
 95		self.cfg: MazeDatasetCollectionConfig = cfg
 96		self.maze_datasets: list[MazeDataset] = list(maze_datasets)
 97		for c, ds in zip(
 98			self.cfg.maze_dataset_configs,
 99			self.maze_datasets,
100			strict=False,
101		):
102			assert c.name == ds.cfg.name
103			assert c == ds.cfg
104
105		self.generation_metadata_collected: dict | None = generation_metadata_collected

initialize the dataset collection from a MazeDatasetCollectionConfig and a list of MazeDatasets

maze_datasets: list[MazeDataset]
generation_metadata_collected: dict | None
dataset_lengths: list[int]
107	@property
108	def dataset_lengths(self) -> list[int]:
109		"""return the lengths of each dataset in the collection"""
110		return [len(dataset) for dataset in self.maze_datasets]

return the lengths of each dataset in the collection

dataset_cum_lengths: jaxtyping.Int[ndarray, 'indices']
112	@property
113	def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
114		"""return the cumulative lengths of each dataset in the collection"""
115		return np.array(list(itertools.accumulate(self.dataset_lengths)))

return the cumulative lengths of each dataset in the collection

117	@cached_property
118	def mazes(self) -> list[LatticeMaze]:
119		"single list of all mazes in the collection"
120		return list(
121			itertools.chain.from_iterable(
122				dataset.mazes for dataset in self.maze_datasets
123			),
124		)

single list of all mazes in the collection

@classmethod
def generate( cls, cfg: MazeDatasetCollectionConfig, **kwargs) -> MazeDatasetCollection:
144	@classmethod
145	def generate(
146		cls,
147		cfg: MazeDatasetCollectionConfig,
148		**kwargs,
149	) -> "MazeDatasetCollection":
150		"""generate a dataset collection from a config"""
151		datasets = [
152			MazeDataset.generate(config, **kwargs)
153			for config in cfg.maze_dataset_configs
154		]
155		return cls(cfg, datasets)

generate a dataset collection from a config
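
A sketch of generating a collection from several sub-configs, here one per grid size:

    from maze_dataset import MazeDatasetConfig
    from maze_dataset.dataset.collected_dataset import (
        MazeDatasetCollection,
        MazeDatasetCollectionConfig,
    )
    from maze_dataset.generation import LatticeMazeGenerators

    sub_configs = [
        MazeDatasetConfig(
            name="multi",
            grid_n=grid_n,
            n_mazes=10,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        )
        for grid_n in (4, 5, 6)
    ]
    collection_cfg = MazeDatasetCollectionConfig(
        name="multi-size",
        maze_dataset_configs=sub_configs,
    )
    collection = MazeDatasetCollection.generate(collection_cfg)
    assert len(collection) == 30  # 3 datasets x 10 mazes each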

@classmethod
def download( cls, cfg: MazeDatasetCollectionConfig, **kwargs) -> MazeDatasetCollection:
157	@classmethod
158	def download(
159		cls,
160		cfg: MazeDatasetCollectionConfig,
161		**kwargs,
162	) -> "MazeDatasetCollection":
163		"(not implemented!) download a dataset collection from a config"
164		datasets = [
165			MazeDataset.download(config, **kwargs)
166			for config in cfg.maze_dataset_configs
167		]
168		return cls(cfg, datasets)

(not implemented!) download a dataset collection from a config

def serialize( self) -> Dict[str, Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]:
170	def serialize(self) -> JSONdict:
171		"""serialize the dataset collection"""
172		return {
173			_FORMAT_KEY: "MazeDatasetCollection",
174			"cfg": self.cfg.serialize(),
175			"maze_datasets": [dataset.serialize() for dataset in self.maze_datasets],
176			"generation_metadata_collected": json_serialize(
177				self.generation_metadata_collected,
178			),
179		}

serialize the dataset collection

@classmethod
def load( cls, data: Dict[str, Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]) -> MazeDatasetCollection:
181	@classmethod
182	def load(cls, data: JSONdict) -> "MazeDatasetCollection":
183		"""load the dataset collection from the representation created by `serialize`"""
184		assert data[_FORMAT_KEY] == "MazeDatasetCollection"
185		return cls(
186			**{
187				key: load_item_recursive(data[key], tuple())
188				for key in ["cfg", "maze_datasets", "generation_metadata_collected"]
189			},
190		)

load the dataset collection from the representation created by serialize
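
serialize and load are designed to be inverses, so a round trip through the JSON-serializable dict should reconstruct the collection. A sketch, reusing the collection from the generate example above:

data = collection.serialize()
restored = MazeDatasetCollection.load(data)
assert restored.cfg == collection.cfg
assert len(restored) == len(collection)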

def as_tokens( self, maze_tokenizer, limit: int | None = None, join_tokens_individual_maze: bool = False) -> list[list[str]] | list[str]:
193	def as_tokens(
194		self,
195		# TODO: MazeTokenizer
196		maze_tokenizer,  # noqa: ANN001
197		limit: int | None = None,
198		join_tokens_individual_maze: bool = False,
199	) -> list[list[str]] | list[str]:
200		"""return the dataset as tokens
201
202		if join_tokens_individual_maze is True, then the tokens of each maze are
203		joined with a space, and the result is a list of strings.
204		i.e.:
205		>>> dataset.as_tokens(maze_tokenizer, join_tokens_individual_maze=False)
206		[["a", "b", "c"], ["d", "e", "f"]]
207		>>> dataset.as_tokens(maze_tokenizer, join_tokens_individual_maze=True)
208		["a b c", "d e f"]
209		"""
210		output: list[list[str]] = [
211			maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
212		]
213		if join_tokens_individual_maze:
214			return [" ".join(tokens) for tokens in output]
215		else:
216			return output

return the dataset as tokens

if join_tokens_individual_maze is True, then the tokens of each maze are joined with a space, and the result is a list of strings. i.e.:

>>> dataset.as_tokens(maze_tokenizer, join_tokens_individual_maze=False)
[["a", "b", "c"], ["d", "e", "f"]]
>>> dataset.as_tokens(maze_tokenizer, join_tokens_individual_maze=True)
["a b c", "d e f"]
def update_self_config(self) -> None:
218	def update_self_config(self) -> None:
219		"update the config to match the number of mazes, and update the underlying configs of each dataset"
220		# TODO: why can't we set this directly? it's not frozen, and it seems to work in a regular MazeDataset
221		self.cfg.__dict__["n_mazes"] = len(self)
222		for dataset in self.maze_datasets:
223			dataset.update_self_config()
224
225		self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets]

update the config to match the number of mazes, and update the underlying configs of each dataset

@serializable_dataclass(kw_only=True)
class MazeDatasetCollectionConfig(maze_dataset.dataset.dataset.GPTDatasetConfig):
31@serializable_dataclass(kw_only=True)
32class MazeDatasetCollectionConfig(GPTDatasetConfig):
33	"""maze dataset collection configuration, including tokenizers and shuffle"""
34
35	# Attributes without a default cannot follow attributes with one  [misc]
36	maze_dataset_configs: list[MazeDatasetConfig] = serializable_field(  # type: ignore[misc]
37		serialization_fn=lambda configs: [config.serialize() for config in configs],
38		loading_fn=lambda data: [
39			MazeDatasetConfig.load(config) for config in data["maze_dataset_configs"]
40		],
41	)
42
43	def summary(self) -> dict:
44		"""return a summary of the config"""
45		return dict(
46			n_mazes=self.n_mazes,
47			max_grid_n=self.max_grid_n,
48			max_grid_shape=self.max_grid_shape,
49			fname=self.to_fname(),
50			cfg_summaries=[c.summary() for c in self.maze_dataset_configs],
51		)
52
53	@property
54	def n_mazes(self) -> int:
 55		"""return the total number of mazes in the collection across all datasets"""
56		return sum(config.n_mazes for config in self.maze_dataset_configs)
57
58	@property
59	def max_grid_n(self) -> int:
60		"""return the maximum grid size of the mazes in the collection"""
61		return max(config.grid_n for config in self.maze_dataset_configs)
62
63	@property
64	def max_grid_shape(self) -> CoordTup:
65		"""return the maximum grid shape of the mazes in the collection"""
66		return (self.max_grid_n, self.max_grid_n)
67
68	@property
69	def max_grid_shape_np(self) -> Coord:
70		"""return the maximum grid shape of the mazes in the collection as a numpy array"""
71		return np.array(self.max_grid_shape, dtype=np.int32)
72
73	def stable_hash_cfg(self) -> int:
74		"""return a stable hash of the config"""
75		return stable_hash(json.dumps(self.serialize()))
76
77	def to_fname(self) -> str:
78		"""convert config to a filename"""
79		return sanitize_fname(
80			f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}",
81		)

maze dataset collection configuration, including tokenizers and shuffle

MazeDatasetCollectionConfig( *, name: str, seq_len_min: int = 1, seq_len_max: int = 512, seed: int | None = 42, applied_filters: list[dict[typing.Literal['name', 'args', 'kwargs'], str | list | tuple | dict]] = <factory>, maze_dataset_configs: list[MazeDatasetConfig])
maze_dataset_configs: list[MazeDatasetConfig]
def summary(self) -> dict:
43	def summary(self) -> dict:
44		"""return a summary of the config"""
45		return dict(
46			n_mazes=self.n_mazes,
47			max_grid_n=self.max_grid_n,
48			max_grid_shape=self.max_grid_shape,
49			fname=self.to_fname(),
50			cfg_summaries=[c.summary() for c in self.maze_dataset_configs],
51		)

return a summary of the config

n_mazes: int
53	@property
54	def n_mazes(self) -> int:
 55		"""return the total number of mazes in the collection across all datasets"""
56		return sum(config.n_mazes for config in self.maze_dataset_configs)

return the total number of mazes in the collection across all datasets

max_grid_n: int
58	@property
59	def max_grid_n(self) -> int:
60		"""return the maximum grid size of the mazes in the collection"""
61		return max(config.grid_n for config in self.maze_dataset_configs)

return the maximum grid size of the mazes in the collection

max_grid_shape: tuple[int, int]
63	@property
64	def max_grid_shape(self) -> CoordTup:
65		"""return the maximum grid shape of the mazes in the collection"""
66		return (self.max_grid_n, self.max_grid_n)

return the maximum grid shape of the mazes in the collection

max_grid_shape_np: jaxtyping.Int8[ndarray, 'row_col=2']
68	@property
69	def max_grid_shape_np(self) -> Coord:
70		"""return the maximum grid shape of the mazes in the collection as a numpy array"""
71		return np.array(self.max_grid_shape, dtype=np.int32)

return the maximum grid shape of the mazes in the collection as a numpy array

def stable_hash_cfg(self) -> int:
73	def stable_hash_cfg(self) -> int:
74		"""return a stable hash of the config"""
75		return stable_hash(json.dumps(self.serialize()))

return a stable hash of the config

def to_fname(self) -> str:
77	def to_fname(self) -> str:
78		"""convert config to a filename"""
79		return sanitize_fname(
80			f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}",
81		)

convert config to a filename
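
For instance, a collection config named "demo" with 20 mazes in total yields a filename of the shape below; the h-suffix is stable_hash_cfg() modulo 10**5, so the digits shown here are hypothetical:

collected-demo-n20-h12345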

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the class as a dict, implemented by using @serializable_dataclass decorator
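
A minimal standalone sketch of the decorator in use, assuming the top-level exports of muutils.json_serialize:

from muutils.json_serialize import (
	SerializableDataclass,
	serializable_dataclass,
	serializable_field,
)

@serializable_dataclass(kw_only=True)
class Point(SerializableDataclass):
	x: int
	y: int = serializable_field(default=0)

# yields {'x': 1, 'y': 2} plus a format key identifying
# the class, as in the source above
data: dict = Point(x=1, y=2).serialize()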

@classmethod
def load(cls, data: dict[str, Any] | T) -> Type[T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator
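
Continuing the hypothetical Point example from serialize above, load inverts serialize:

point = Point(x=1, y=2)
restored = Point.load(point.serialize())
assert restored == point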

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field
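
Direct use is a one-liner on any SerializableDataclass instance, e.g. a config from this module. This sketch assumes ErrorMode.WARN is a valid member of muutils.errormode.ErrorMode alongside the default shown in the signature:

from muutils.errormode import ErrorMode

# True iff every field's value matches its annotated type;
# WARN reports mismatches without raising
ok: bool = collection_cfg.validate_fields_types(on_typecheck_error=ErrorMode.WARN)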

Inherited Members
maze_dataset.dataset.dataset.GPTDatasetConfig
name
seq_len_min
seq_len_max
seed
applied_filters
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict