Coverage for maze_dataset/benchmark/config_sweep.py: 0%
164 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-11 01:43 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-11 01:43 -0600
1"""Benchmarking of how successful maze generation is for various values of percolation"""
3import functools
4import json
5import warnings
6from pathlib import Path
7from typing import Any, Callable, Generic, Sequence, TypeVar
9import matplotlib.pyplot as plt
10import numpy as np
11from jaxtyping import Float
12from muutils.dictmagic import dotlist_to_nested_dict, update_with_nested_dict
13from muutils.json_serialize import (
14 JSONitem,
15 SerializableDataclass,
16 json_serialize,
17 serializable_dataclass,
18 serializable_field,
19)
20from muutils.parallel import run_maybe_parallel
21from zanj import ZANJ
23from maze_dataset import MazeDataset, MazeDatasetConfig
24from maze_dataset.generation import LatticeMazeGenerators
# type of the value an analysis function returns for a single config
SweepReturnType = TypeVar("SweepReturnType")
# type of the individual values tried for the swept parameter
ParamType = TypeVar("ParamType")
# an analysis function: takes a config, returns some result (e.g. `dataset_success_fraction`)
AnalysisFunc = Callable[[MazeDatasetConfig], SweepReturnType]
def dataset_success_fraction(cfg: MazeDatasetConfig) -> float:
    """empirical success fraction of maze generation

    builds a `MazeDataset` from `cfg` entirely in memory (no download, no
    loading from or saving to disk, not verbose) and reports how many mazes
    were actually produced relative to how many were requested.

    for use as an `analyze_func` in `sweep()`

    # Parameters:
     - `cfg : MazeDatasetConfig`
        config to generate a dataset from

    # Returns:
     - `float`
        `len(dataset) / cfg.n_mazes`
    """
    n_generated: int = len(
        MazeDataset.from_config(
            cfg,
            do_download=False,
            load_local=False,
            save_local=False,
            verbose=False,
        ),
    )
    return n_generated / cfg.n_mazes
# registry of analysis functions, keyed by their `__name__`.
# `SweepResult` serializes its `analyze_func` as the function's name and
# resolves it through this mapping when deserializing.
ANALYSIS_FUNCS: dict[str, AnalysisFunc] = {
    "dataset_success_fraction": dataset_success_fraction,
}
def sweep(
    cfg_base: MazeDatasetConfig,
    param_values: list[ParamType],
    param_key: str,
    analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
) -> list[SweepReturnType]:
    """given a base config, parameter values list, key, and analysis function, return the results of the analysis function for each parameter value

    for every value, `cfg_base` is serialized afresh, the entry addressed by
    `param_key` (a dotted key such as `"maze_ctor_kwargs.p"`) is overwritten
    with that value, a new config is loaded from the modified dict, and
    `analyze_func` is applied to it.

    # Parameters:
     - `cfg_base : MazeDatasetConfig`
        base config on which we will modify the value at `param_key` with values from `param_values`
     - `param_values : list[ParamType]`
        list of values to try
     - `param_key : str`
        dotted key of the value to modify in the serialized `cfg_base`
     - `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]`
        function which analyzes the resulting config. originally built for `dataset_success_fraction`

    # Returns:
     - `list[SweepReturnType]`
        output of `analyze_func` for each entry of `param_values`, in the same order
    """

    def _config_with(value: ParamType) -> MazeDatasetConfig:
        # re-serialize every time so that one value never leaks into the next config
        modified: dict = cfg_base.serialize()
        update_with_nested_dict(
            modified,
            dotlist_to_nested_dict({param_key: value}),
        )
        return MazeDatasetConfig.load(modified)

    return [analyze_func(_config_with(value)) for value in param_values]
@serializable_dataclass()
class SweepResult(SerializableDataclass, Generic[ParamType, SweepReturnType]):
    """result of a parameter sweep

    # Fields:
     - `configs : list[MazeDatasetConfig]`
        the base configs that were swept over
     - `param_values : list[ParamType]`
        the values tried for `param_key`
     - `result_values : dict[str, Sequence[SweepReturnType]]`
        maps `cfg.to_fname()` of each config to its results, one per entry of `param_values`
     - `param_key : str`
        key of the parameter that was varied (e.g. `"maze_ctor_kwargs.p"`)
     - `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]`
        the analysis function; serialized as its `__name__` and looked up in
        `ANALYSIS_FUNCS` when deserializing
    """

    configs: list[MazeDatasetConfig] = serializable_field(
        serialization_fn=lambda cfgs: [cfg.serialize() for cfg in cfgs],
        deserialize_fn=lambda cfgs: [MazeDatasetConfig.load(cfg) for cfg in cfgs],
    )
    param_values: list[ParamType] = serializable_field(
        serialization_fn=lambda x: json_serialize(x),
        deserialize_fn=lambda x: x,
        assert_type=False,
    )
    result_values: dict[str, Sequence[SweepReturnType]] = serializable_field(
        serialization_fn=lambda x: json_serialize(x),
        deserialize_fn=lambda x: x,
        assert_type=False,
    )
    param_key: str
    analyze_func: Callable[[MazeDatasetConfig], SweepReturnType] = serializable_field(
        serialization_fn=lambda f: f.__name__,
        deserialize_fn=ANALYSIS_FUNCS.get,
        assert_type=False,
    )

    def summary(self) -> JSONitem:
        "human-readable and json-dumpable short summary of the result"
        return {
            "len(configs)": len(self.configs),
            "len(param_values)": len(self.param_values),
            "len(result_values)": len(self.result_values),
            "param_key": self.param_key,
            "analyze_func": self.analyze_func.__name__,
        }

    def save(self, path: str | Path, z: ZANJ | None = None) -> None:
        "save to a file with zanj (a default `ZANJ()` is used if `z` is `None`)"
        if z is None:
            z = ZANJ()

        z.save(self, path)

    @classmethod
    def read(cls, path: str | Path, z: ZANJ | None = None) -> "SweepResult":
        "read from a file with zanj (a default `ZANJ()` is used if `z` is `None`)"
        if z is None:
            z = ZANJ()

        return z.read(path)

    def configs_by_name(self) -> dict[str, MazeDatasetConfig]:
        "return configs by name"
        return {cfg.name: cfg for cfg in self.configs}

    def configs_by_key(self) -> dict[str, MazeDatasetConfig]:
        "return configs by the key used in `result_values`, which is the filename of the config"
        return {cfg.to_fname(): cfg for cfg in self.configs}

    def configs_shared(self) -> dict[str, Any]:
        """return key: value pairs (of the serialized configs) that are shared across all configs

        returns an empty dict if there are no configs
        """
        if not self.configs:
            return dict()

        # collect the distinct json-encoded values seen for each serialized key.
        # we know that the configs all have the same keys,
        # so this way of doing it is fine
        config_vals: dict[str, set[str]] = dict()
        for cfg in self.configs:
            for k, v in cfg.serialize().items():
                config_vals.setdefault(k, set()).add(json.dumps(v))

        # a key is shared iff exactly one distinct value was seen;
        # report that value in its (non-json-encoded) serialized form
        cfg_ser: dict = self.configs[0].serialize()
        return {k: cfg_ser[k] for k, v in config_vals.items() if len(v) == 1}

    def configs_differing_keys(self) -> set[str]:
        "return keys (dataclass fields of `MazeDatasetConfig`) that differ across configs"
        shared_vals: dict[str, Any] = self.configs_shared()
        return {
            k for k in MazeDatasetConfig.__dataclass_fields__ if k not in shared_vals
        }

    def configs_value_set(self, key: str) -> list[Any]:
        "return a list of the unique values for a given key (uniqueness judged by their json serialization)"
        d: dict[str, Any] = {
            json.dumps(json_serialize(getattr(cfg, key))): getattr(cfg, key)
            for cfg in self.configs
        }

        return list(d.values())

    def get_where(self, key: str, val_check: Callable[[Any], bool]) -> "SweepResult":
        """get a subset of this `SweepResult` where the configs have `key` satisfying `val_check`

        `param_values`, `param_key` and `analyze_func` are carried over unchanged.
        both `configs` and `result_values` of the returned result keep the order of `self.configs`.
        """
        configs_list: list[MazeDatasetConfig] = [
            cfg for cfg in self.configs if val_check(getattr(cfg, key))
        ]
        # iterate the ordered filtered configs (not a set of their keys) so that
        # `result_values` has a deterministic order -- str hashing is randomized per process
        result_values: dict[str, Sequence[SweepReturnType]] = {}
        for cfg in configs_list:
            cfg_key: str = cfg.to_fname()
            result_values[cfg_key] = self.result_values[cfg_key]

        return SweepResult(
            configs=configs_list,
            param_values=self.param_values,
            result_values=result_values,
            param_key=self.param_key,
            analyze_func=self.analyze_func,
        )

    @classmethod
    def analyze(
        cls,
        configs: list[MazeDatasetConfig],
        param_values: list[ParamType],
        param_key: str,
        analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
        parallel: bool | int = False,
        **kwargs,
    ) -> "SweepResult":
        """run `sweep` over `param_values` for each config in `configs` and collect the results

        # Parameters:
         - `configs : list[MazeDatasetConfig]`
            base configs to try
         - `param_values : list[ParamType]`
            values to try for `param_key` on each config
         - `param_key : str`
            key of the value to modify in each config (e.g. `"maze_ctor_kwargs.p"`)
         - `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]`
            function applied to each modified config
         - `parallel : bool | int`
            passed to `run_maybe_parallel`
            (defaults to `False`)
         - `**kwargs`
            passed to `run_maybe_parallel`

        # Returns:
         - `SweepResult`
            with `result_values` mapping `cfg.to_fname()` to a numpy array of
            results, one per entry of `param_values`
        """
        n_pvals: int = len(param_values)

        # one `sweep` per config; `keep_ordered=True` so results line up with `configs`
        result_values_list: list[list[SweepReturnType]] = run_maybe_parallel(
            # TYPING: error: Argument "func" to "run_maybe_parallel" has incompatible type "partial[list[SweepReturnType]]"; expected "Callable[[MazeDatasetConfig], float]" [arg-type]
            func=functools.partial(  # type: ignore[arg-type]
                sweep,
                param_values=param_values,
                param_key=param_key,
                analyze_func=analyze_func,
            ),
            iterable=configs,
            keep_ordered=True,
            parallel=parallel,
            pbar_kwargs=dict(total=len(configs)),
            **kwargs,
        )
        result_values: dict[str, Float[np.ndarray, n_pvals]] = {
            cfg.to_fname(): np.array(res)
            for cfg, res in zip(configs, result_values_list, strict=False)
        }
        return cls(
            configs=configs,
            param_values=param_values,
            # TYPING: error: Argument "result_values" to "SweepResult" has incompatible type "dict[str, ndarray[Any, Any]]"; expected "dict[str, Sequence[SweepReturnType]]" [arg-type]
            result_values=result_values,  # type: ignore[arg-type]
            param_key=param_key,
            analyze_func=analyze_func,
        )

    def plot(
        self,
        save_path: str | None = None,
        cfg_keys: list[str] | None = None,
        cmap_name: str | None = "viridis",
        plot_only: bool = False,
        show: bool = True,
        ax: plt.Axes | None = None,
    ) -> plt.Axes:
        """Plot the results of the sweep: one line per config, `param_values` vs. results

        # Parameters:
         - `save_path : str | None`
            if given, save the figure containing the plotted axes to this path
            (defaults to `None`)
         - `cfg_keys : list[str] | None`
            shared config keys to show in the title; if `None`, all shared values are shown
            (defaults to `None`)
         - `cmap_name : str | None`
            colormap for the lines, which are ordered by grid size
            (defaults to `"viridis"`)
         - `plot_only : bool`
            if `True`, skip axis labels, title, grid, and legend
            (defaults to `False`)
         - `show : bool`
            whether to call `plt.show()`
            (defaults to `True`)
         - `ax : plt.Axes | None`
            axes to plot on; if `None`, a new figure is created
            (defaults to `None`)

        # Returns:
         - `plt.Axes`
            the axes that were plotted on
        """
        # set up figure
        ax_: plt.Axes
        if ax is None:
            _, ax_ = plt.subplots(1, 1, figsize=(22, 10))
        else:
            ax_ = ax

        # maps the keys of `result_values` (config fnames) back to the configs
        cfgs_by_key: dict[str, MazeDatasetConfig] = self.configs_by_key()

        # plot, one line per config, sorted by grid size
        # (stable sort: configs with equal grid size keep their `result_values` order)
        cmap = plt.get_cmap(cmap_name)
        n_cfgs: int = len(self.result_values)
        sorted_results = sorted(
            self.result_values.items(),
            key=lambda item: cfgs_by_key[item[0]].grid_n,
        )
        for i, (cfg_key, result_values) in enumerate(sorted_results):
            ax_.plot(
                # TYPING: error: Argument 1 to "plot" of "Axes" has incompatible type "list[ParamType]" [arg-type]
                self.param_values,  # type: ignore[arg-type]
                # TYPING: error: Argument 2 to "plot" of "Axes" has incompatible type "Sequence[SweepReturnType]" [arg-type]
                result_values,  # type: ignore[arg-type]
                ".-",
                label=cfgs_by_key[cfg_key].name,
                color=cmap((i + 0.5) / (n_cfgs - 0.5)),
            )

        # repr of config
        cfg_shared: dict = self.configs_shared()
        cfg_repr: str
        if cfg_keys is None:
            cfg_repr = str(cfg_shared)
        else:
            cfg_repr = (
                "MazeDatasetConfig("
                + ", ".join(
                    [
                        f"{k}={cfg_shared[k].__name__}"
                        if callable(cfg_shared[k])
                        else f"{k}={cfg_shared[k]}"
                        for k in cfg_keys
                    ],
                )
                + ")"
            )

        # add title and stuff
        if not plot_only:
            ax_.set_xlabel(self.param_key)
            ax_.set_ylabel(self.analyze_func.__name__)
            ax_.set_title(
                f"{self.param_key} vs {self.analyze_func.__name__}\n{cfg_repr}",
            )
            ax_.grid(True)
            ax_.legend(loc="center left")

        # save and show
        if save_path:
            # save the figure that actually holds `ax_`, which need not be the current pyplot figure
            ax_.figure.savefig(save_path)

        if show:
            plt.show()

        return ax_
# default `(name, endpoint_kwargs)` pairs swept by `full_percolation_analysis`;
# each name is what `endpoint_kwargs_to_name` returns for its kwargs
DEFAULT_ENDPOINT_KWARGS: list[tuple[str, dict]] = [
    # endpoints may be anywhere
    (
        "any",
        {
            "deadend_start": False,
            "deadend_end": False,
            "except_on_no_valid_endpoint": False,
        },
    ),
    # both endpoints at dead ends, allowed to coincide
    (
        "deadends",
        {
            "deadend_start": True,
            "deadend_end": True,
            "endpoints_not_equal": False,
            "except_on_no_valid_endpoint": False,
        },
    ),
    # both endpoints at dead ends, required to differ
    (
        "deadends_unique",
        {
            "deadend_start": True,
            "deadend_end": True,
            "endpoints_not_equal": True,
            "except_on_no_valid_endpoint": False,
        },
    ),
]
def endpoint_kwargs_to_name(ep_kwargs: dict) -> str:
    """convert endpoint kwargs options to a human-readable name

    - `"any"` if neither `deadend_start` nor `deadend_end` is truthy
    - `"deadends_unique"` if either is truthy and `endpoints_not_equal` is truthy
    - `"deadends"` if either is truthy but `endpoints_not_equal` is not
    """
    if not (ep_kwargs.get("deadend_start", False) or ep_kwargs.get("deadend_end", False)):
        return "any"

    return (
        "deadends_unique"
        if ep_kwargs.get("endpoints_not_equal", False)
        else "deadends"
    )
def full_percolation_analysis(
    n_mazes: int,
    p_val_count: int,
    grid_sizes: list[int],
    ep_kwargs: list[tuple[str, dict]] | None = None,
    generators: Sequence[Callable] = (
        LatticeMazeGenerators.gen_percolation,
        LatticeMazeGenerators.gen_dfs_percolation,
    ),
    save_dir: Path = Path("../docs/benchmarks/percolation_fractions"),
    parallel: bool | int = False,
    **analyze_kwargs,
) -> SweepResult:
    """run the full analysis of how percolation affects maze generation success

    builds one config per (endpoint kwargs, generator, grid size) combination,
    sweeps `maze_ctor_kwargs.p` over `p_val_count` evenly spaced values in
    `[0, 1]` measuring `dataset_success_fraction`, then saves the result to
    `save_dir` as a `.zanj` file.

    # Parameters:
     - `n_mazes : int`
        number of mazes requested per config
     - `p_val_count : int`
        number of percolation values to try, evenly spaced in `[0, 1]`
     - `grid_sizes : list[int]`
        grid sizes to try
     - `ep_kwargs : list[tuple[str, dict]] | None`
        `(name, endpoint_kwargs)` pairs to try; the names are not used here
        (defaults to `None`, meaning `DEFAULT_ENDPOINT_KWARGS`)
     - `generators : Sequence[Callable]`
        maze generator functions to try
        (defaults to `gen_percolation` and `gen_dfs_percolation`)
     - `save_dir : Path`
        directory to save the result to; created if it does not exist. a `str` is also accepted
        (defaults to `Path("../docs/benchmarks/percolation_fractions")`)
     - `parallel : bool | int`
        passed to `SweepResult.analyze`
        (defaults to `False`)
     - `**analyze_kwargs`
        passed to `SweepResult.analyze`

    # Returns:
     - `SweepResult`
        the result of the sweep (also saved to disk)
    """
    if ep_kwargs is None:
        ep_kwargs = DEFAULT_ENDPOINT_KWARGS

    # create the output directory before the (potentially long) analysis so
    # that an unwritable location fails fast instead of after all the work
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # configs: one per (endpoint kwargs, generator, grid size), in that nesting order
    configs: list[MazeDatasetConfig] = [
        MazeDatasetConfig(
            name=f"g{grid_n}-{gen_func.__name__.removeprefix('gen_').removesuffix('olation')}",
            grid_n=grid_n,
            n_mazes=n_mazes,
            maze_ctor=gen_func,
            # placeholder: `p` is overwritten for every swept value via `param_key` below
            maze_ctor_kwargs=dict(p=float("nan")),
            endpoint_kwargs=ep_kw,
        )
        for _ep_kw_name, ep_kw in ep_kwargs
        for gen_func in generators
        for grid_n in grid_sizes
    ]

    # get results
    result: SweepResult = SweepResult.analyze(
        configs=configs,  # type: ignore[misc]
        # TYPING: error: Argument "param_values" to "analyze" of "SweepResult" has incompatible type "float | list[float] | list[list[float]] | list[list[list[Any]]]"; expected "list[Any]" [arg-type]
        param_values=np.linspace(0.0, 1.0, p_val_count).tolist(),  # type: ignore[arg-type]
        param_key="maze_ctor_kwargs.p",
        analyze_func=dataset_success_fraction,
        parallel=parallel,
        **analyze_kwargs,
    )

    # save the result
    results_path: Path = (
        save_dir / f"result-n{n_mazes}-c{len(configs)}-p{p_val_count}.zanj"
    )
    print(f"Saving results to {results_path.as_posix()}")
    result.save(results_path)

    return result
433def _is_eq(a, b) -> bool: # noqa: ANN001
434 """check if two objects are equal"""
435 return a == b
def plot_grouped(  # noqa: C901
    results: SweepResult,
    predict_fn: Callable[[MazeDatasetConfig], float] | None = None,
    prediction_density: int = 50,
    save_dir: Path | None = None,
    show: bool = True,
    logy: bool = False,
) -> None:
    """Plot grouped sweep percolation value results for each distinct `endpoint_kwargs` in the configs

    with separate colormaps for each maze generator function. Generators get
    colormaps in order of first appearance in `results.configs`
    ("Reds", then "Blues", then further colormaps, cycling if needed).

    # Parameters:
     - `results : SweepResult`
        The sweep results to plot
     - `predict_fn : Callable[[MazeDatasetConfig], float] | None`
        Optional function that predicts success rate from a config. If provided, will plot predictions as dashed lines.
     - `prediction_density : int`
        Number of points to use for prediction curves (default: 50)
     - `save_dir : Path | None`
        Directory to save plots (defaults to `None`, meaning no saving). A `str` is also accepted.
     - `show : bool`
        Whether to display the plots (defaults to `True`)
     - `logy : bool`
        Whether to use a log scale on the y axis (defaults to `False`)

    # Usage:
    ```python
    >>> result = full_percolation_analysis(n_mazes=100, p_val_count=11, grid_sizes=[8,16])
    >>> plot_grouped(result, save_dir=Path("./plots"), show=False)
    ```
    """
    # colormaps assigned to generators by index; the first two match the historical Reds/Blues
    cmap_names: tuple[str, ...] = (
        "Reds",
        "Blues",
        "Greens",
        "Purples",
        "Oranges",
        "Greys",
    )

    # groups
    endpoint_kwargs_set: list[dict] = results.configs_value_set("endpoint_kwargs")  # type: ignore[assignment]
    # `dict.fromkeys` dedupes while keeping first-appearance order; a `set` of str
    # would give a per-process random order (hash randomization) and random colors
    generator_funcs_names: list[str] = list(
        dict.fromkeys(cfg.maze_ctor.__name__ for cfg in results.configs),
    )

    # if predicting, create denser p values
    if predict_fn is not None:
        p_dense = np.linspace(0.0, 1.0, prediction_density)

    # separate plot for each set of endpoint kwargs
    for ep_kw in endpoint_kwargs_set:
        results_epkw: SweepResult = results.get_where(
            "endpoint_kwargs",
            functools.partial(_is_eq, b=ep_kw),
        )
        shared_keys: set[str] = set(results_epkw.configs_shared().keys())
        cfg_keys: set[str] = shared_keys.intersection({"n_mazes", "endpoint_kwargs"})
        fig, ax = plt.subplots(1, 1, figsize=(22, 10))
        for gf_idx, gen_func in enumerate(generator_funcs_names):
            results_filtered: SweepResult = results_epkw.get_where(
                "maze_ctor",
                # bind `gen_func` as a default argument to avoid late binding of the loop variable
                lambda x, name=gen_func: x.__name__ == name,
            )
            if len(results_filtered.configs) < 1:
                warnings.warn(
                    f"No results for {gen_func} and {ep_kw}. Skipping.",
                )
                continue

            cmap_name: str = cmap_names[gf_idx % len(cmap_names)]
            cmap = plt.get_cmap(cmap_name)

            # Plot actual results
            ax = results_filtered.plot(
                cfg_keys=list(cfg_keys),
                ax=ax,
                show=False,
                cmap_name=cmap_name,
            )
            if logy:
                ax.set_yscale("log")

            # Plot predictions if function provided
            if predict_fn is not None:
                # `SweepResult.plot` colors its lines in grid-size order, so iterate
                # the configs in that same order to give each prediction its data's color
                cfgs_sorted: list[MazeDatasetConfig] = sorted(
                    results_filtered.configs,
                    key=lambda c: c.grid_n,
                )
                n_cfgs: int = len(cfgs_sorted)
                for cfg_idx, cfg in enumerate(cfgs_sorted):
                    predictions: list[float] = []
                    for p in p_dense:
                        # fresh copy of the config with only `p` changed
                        cfg_temp = MazeDatasetConfig.load(cfg.serialize())
                        cfg_temp.maze_ctor_kwargs["p"] = p
                        predictions.append(predict_fn(cfg_temp))

                    # Get the same color as the actual data
                    color = cmap((cfg_idx + 0.5) / (n_cfgs - 0.5))

                    # Plot prediction as dashed line
                    ax.plot(p_dense, predictions, "--", color=color, alpha=0.8)

        # save and show
        if save_dir:
            save_path: Path = Path(save_dir) / f"ep_{endpoint_kwargs_to_name(ep_kw)}.svg"
            print(f"Saving plot to {save_path.as_posix()}")
            save_path.parent.mkdir(exist_ok=True, parents=True)
            # save this group's figure explicitly rather than whatever pyplot considers current
            fig.savefig(save_path)

        if show:
            plt.show()