maze_dataset.benchmark.config_sweep
Benchmarking of how successful maze generation is for various values of percolation
"""Benchmarking of how successful maze generation is for various values of percolation"""

import functools
import json
import warnings
from pathlib import Path
from typing import Any, Callable, Generic, Sequence, TypeVar

import matplotlib.pyplot as plt
import numpy as np
from jaxtyping import Float
from muutils.dictmagic import dotlist_to_nested_dict, update_with_nested_dict
from muutils.json_serialize import (
    JSONitem,
    SerializableDataclass,
    json_serialize,
    serializable_dataclass,
    serializable_field,
)
from muutils.parallel import run_maybe_parallel
from zanj import ZANJ

from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators

SweepReturnType = TypeVar("SweepReturnType")
ParamType = TypeVar("ParamType")
AnalysisFunc = Callable[[MazeDatasetConfig], SweepReturnType]


def dataset_success_fraction(cfg: MazeDatasetConfig) -> float:
    """empirical success fraction of maze generation

    builds a dataset from `cfg` (no download, no local load, no local save, not verbose)
    and returns `len(dataset) / cfg.n_mazes`

    for use as an `analyze_func` in `sweep()`
    """
    dataset: MazeDataset = MazeDataset.from_config(
        cfg,
        do_download=False,
        load_local=False,
        save_local=False,
        verbose=False,
    )

    return len(dataset) / cfg.n_mazes


# registry used to restore `SweepResult.analyze_func` from its serialized `__name__`
ANALYSIS_FUNCS: dict[str, AnalysisFunc] = dict(
    dataset_success_fraction=dataset_success_fraction,
)


def sweep(
    cfg_base: MazeDatasetConfig,
    param_values: list[ParamType],
    param_key: str,
    analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
) -> list[SweepReturnType]:
    """given a base config, parameter values list, key, and analysis function, return the results of the analysis function for each parameter value

    # Parameters:
    - `cfg_base : MazeDatasetConfig`
        base config on which we will modify the value at `param_key` with values from `param_values`
    - `param_values : list[ParamType]`
        list of values to try
    - `param_key : str`
        value to modify in `cfg_base`, as a dotlist key (e.g. `"maze_ctor_kwargs.p"`)
    - `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]`
        function which analyzes the resulting config. originally built for `dataset_success_fraction`

    # Returns:
    - `list[SweepReturnType]`
        output of `analyze_func` for each entry of `param_values`, in the same order
    """
    outputs: list[SweepReturnType] = []

    for p in param_values:
        # re-serialize for every value: the update below modifies the dict in place
        cfg_dict: dict = cfg_base.serialize()
        update_with_nested_dict(
            cfg_dict,
            dotlist_to_nested_dict({param_key: p}),
        )
        cfg_test: MazeDatasetConfig = MazeDatasetConfig.load(cfg_dict)

        outputs.append(analyze_func(cfg_test))

    return outputs


@serializable_dataclass()
class SweepResult(SerializableDataclass, Generic[ParamType, SweepReturnType]):
    """result of a parameter sweep

    # Fields:
    - `configs : list[MazeDatasetConfig]`
        configs that were swept over
    - `param_values : list[ParamType]`
        values that were tried for `param_key`
    - `result_values : dict[str, Sequence[SweepReturnType]]`
        results of `analyze_func`, keyed by `cfg.to_fname()` of the corresponding config
    - `param_key : str`
        key that was modified in each config
    - `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]`
        analysis function. serialized as its `__name__` and restored via `ANALYSIS_FUNCS.get`,
        so a function not registered in `ANALYSIS_FUNCS` is loaded back as `None`
    """

    configs: list[MazeDatasetConfig] = serializable_field(
        serialization_fn=lambda cfgs: [cfg.serialize() for cfg in cfgs],
        deserialize_fn=lambda cfgs: [MazeDatasetConfig.load(cfg) for cfg in cfgs],
    )
    param_values: list[ParamType] = serializable_field(
        serialization_fn=lambda x: json_serialize(x),
        deserialize_fn=lambda x: x,
        assert_type=False,
    )
    result_values: dict[str, Sequence[SweepReturnType]] = serializable_field(
        serialization_fn=lambda x: json_serialize(x),
        deserialize_fn=lambda x: x,
        assert_type=False,
    )
    param_key: str
    analyze_func: Callable[[MazeDatasetConfig], SweepReturnType] = serializable_field(
        serialization_fn=lambda f: f.__name__,
        deserialize_fn=ANALYSIS_FUNCS.get,
        assert_type=False,
    )

    def summary(self) -> JSONitem:
        "human-readable and json-dumpable short summary of the result"
        return {
            "len(configs)": len(self.configs),
            "len(param_values)": len(self.param_values),
            "len(result_values)": len(self.result_values),
            "param_key": self.param_key,
            "analyze_func": self.analyze_func.__name__,
        }

    def save(self, path: str | Path, z: ZANJ | None = None) -> None:
        "save to a file with zanj (a fresh `ZANJ()` is used if `z` is not given)"
        if z is None:
            z = ZANJ()

        z.save(self, path)

    @classmethod
    def read(cls, path: str | Path, z: ZANJ | None = None) -> "SweepResult":
        "read from a file with zanj (a fresh `ZANJ()` is used if `z` is not given)"
        if z is None:
            z = ZANJ()

        return z.read(path)

    def configs_by_name(self) -> dict[str, MazeDatasetConfig]:
        "return configs by name (for repeated names, the last config wins)"
        return {cfg.name: cfg for cfg in self.configs}

    def configs_by_key(self) -> dict[str, MazeDatasetConfig]:
        "return configs by the key used in `result_values`, which is the filename of the config"
        return {cfg.to_fname(): cfg for cfg in self.configs}

    def configs_shared(self) -> dict[str, Any]:
        """return key: value pairs that are shared across all configs

        values are compared by their `json.dumps` representation of the serialized config.
        returns an empty dict if there are no configs
        """
        # nothing to compare; also avoids indexing into an empty list below
        if not self.configs:
            return dict()

        # we know that the configs all have the same keys,
        # so this way of doing it is fine
        config_vals: dict[str, set[str]] = dict()
        for cfg in self.configs:
            for k, v in cfg.serialize().items():
                config_vals.setdefault(k, set()).add(json.dumps(v))

        cfg_ser: dict = self.configs[0].serialize()
        return {k: cfg_ser[k] for k, v in config_vals.items() if len(v) == 1}

    def configs_differing_keys(self) -> set[str]:
        "return keys that differ across configs"
        shared_vals: dict[str, Any] = self.configs_shared()
        return {
            k for k in MazeDatasetConfig.__dataclass_fields__ if k not in shared_vals
        }

    def configs_value_set(self, key: str) -> list[Any]:
        "return a list of the unique values for a given key"
        # uniqueness is decided on the json dump of the serialized value,
        # since the values themselves may be unhashable
        d: dict[str, Any] = {
            json.dumps(json_serialize(getattr(cfg, key))): getattr(cfg, key)
            for cfg in self.configs
        }

        return list(d.values())

    def get_where(self, key: str, val_check: Callable[[Any], bool]) -> "SweepResult":
        """get a subset of this `SweepResult` where the configs have `key` satisfying `val_check`

        `param_values`, `param_key` and `analyze_func` are carried over unchanged.
        raises `KeyError` if a selected config has no entry in `result_values`
        """
        configs_list: list[MazeDatasetConfig] = [
            cfg for cfg in self.configs if val_check(getattr(cfg, key))
        ]
        # iterate the config list rather than a set of keys, so that the order of
        # `result_values` follows `configs` and does not depend on string hashing
        result_values: dict[str, Sequence[SweepReturnType]] = dict()
        for cfg in configs_list:
            cfg_key: str = cfg.to_fname()
            result_values[cfg_key] = self.result_values[cfg_key]

        return SweepResult(
            configs=configs_list,
            param_values=self.param_values,
            result_values=result_values,
            param_key=self.param_key,
            analyze_func=self.analyze_func,
        )

    @classmethod
    def analyze(
        cls,
        configs: list[MazeDatasetConfig],
        param_values: list[ParamType],
        param_key: str,
        analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
        parallel: bool | int = False,
        **kwargs,
    ) -> "SweepResult":
        """Run `sweep()` on every config and collect the results

        # Parameters:
        - `configs : list[MazeDatasetConfig]`
            configs to try
        - `param_values : list[ParamType]`
            values to try for `param_key`
        - `param_key : str`
            key to modify in each config
        - `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]`
            function applied to each modified config
        - `parallel : bool | int`
            passed to `run_maybe_parallel` (defaults to `False`)
        - `**kwargs`
            passed to `run_maybe_parallel`

        # Returns:
        - `SweepResult`
            with `result_values` mapping `cfg.to_fname()` to a numpy array of results
        """
        n_pvals: int = len(param_values)

        result_values_list: list[list[SweepReturnType]] = run_maybe_parallel(
            # TYPING: the partial returns a list per config, which the stub for `func` does not expect
            func=functools.partial(  # type: ignore[arg-type]
                sweep,
                param_values=param_values,
                param_key=param_key,
                analyze_func=analyze_func,
            ),
            iterable=configs,
            keep_ordered=True,
            parallel=parallel,
            pbar_kwargs=dict(total=len(configs)),
            **kwargs,
        )
        result_values: dict[str, Float[np.ndarray, n_pvals]] = {
            cfg.to_fname(): np.array(res)
            for cfg, res in zip(configs, result_values_list, strict=False)
        }
        return cls(
            configs=configs,
            param_values=param_values,
            # TYPING: arrays are stored where `Sequence[SweepReturnType]` is declared
            result_values=result_values,  # type: ignore[arg-type]
            param_key=param_key,
            analyze_func=analyze_func,
        )

    def plot(
        self,
        save_path: str | None = None,
        cfg_keys: list[str] | None = None,
        cmap_name: str | None = "viridis",
        plot_only: bool = False,
        show: bool = True,
        ax: plt.Axes | None = None,
    ) -> plt.Axes:
        """Plot the results of percolation analysis

        # Parameters:
        - `save_path : str | None`
            if given, `plt.savefig` is called with this path
        - `cfg_keys : list[str] | None`
            keys of the shared config values to show in the title.
            if `None`, the whole dict of shared values is shown
        - `cmap_name : str | None`
            name of the colormap used to color the curves (defaults to `"viridis"`)
        - `plot_only : bool`
            if `True`, skip axis labels, title, grid and legend
        - `show : bool`
            whether to call `plt.show()` (defaults to `True`)
        - `ax : plt.Axes | None`
            axes to draw on. if `None`, a new 22x10 figure is created

        # Returns:
        - `plt.Axes`
            the axes that were drawn on
        """
        # set up figure
        ax_: plt.Axes
        if ax is None:
            _, ax_ = plt.subplots(1, 1, figsize=(22, 10))
        else:
            ax_ = ax

        # plot
        cmap = plt.get_cmap(cmap_name)
        n_cfgs: int = len(self.result_values)
        # the lookup is identical for every curve, so build it once
        cfgs_by_key: dict[str, MazeDatasetConfig] = self.configs_by_key()
        for i, (ep_cfg_name, result_values) in enumerate(
            sorted(
                self.result_values.items(),
                # HACK: sort by grid size
                #                 |--< name of config
                #                 |    |-----------< gets 'g{n}'
                #                 |    |            |--< gets '{n}'
                #                 |    |            |
                key=lambda x: int(x[0].split("-")[0][1:]),
            ),
        ):
            ax_.plot(
                # TYPING: matplotlib stubs do not accept the generic list / sequence types
                self.param_values,  # type: ignore[arg-type]
                result_values,  # type: ignore[arg-type]
                ".-",
                label=cfgs_by_key[ep_cfg_name].name,
                color=cmap((i + 0.5) / (n_cfgs - 0.5)),
            )

        # repr of config
        cfg_shared: dict = self.configs_shared()
        cfg_repr: str = (
            str(cfg_shared)
            if cfg_keys is None
            else (
                "MazeDatasetConfig("
                + ", ".join(
                    [
                        f"{k}={cfg_shared[k].__name__}"
                        # TYPING: `Callable` is a typing special form, not a class
                        if isinstance(cfg_shared[k], Callable)  # type: ignore[arg-type]
                        else f"{k}={cfg_shared[k]}"
                        for k in cfg_keys
                    ],
                )
                + ")"
            )
        )

        # add title and stuff
        if not plot_only:
            ax_.set_xlabel(self.param_key)
            ax_.set_ylabel(self.analyze_func.__name__)
            ax_.set_title(
                f"{self.param_key} vs {self.analyze_func.__name__}\n{cfg_repr}",
            )
            ax_.grid(True)
            ax_.legend(loc="center left")

        # save and show
        if save_path:
            plt.savefig(save_path)

        if show:
            plt.show()

        return ax_


DEFAULT_ENDPOINT_KWARGS: list[tuple[str, dict]] = [
    (
        "any",
        dict(deadend_start=False, deadend_end=False, except_on_no_valid_endpoint=False),
    ),
    (
        "deadends",
        dict(
            deadend_start=True,
            deadend_end=True,
            endpoints_not_equal=False,
            except_on_no_valid_endpoint=False,
        ),
    ),
    (
        "deadends_unique",
        dict(
            deadend_start=True,
            deadend_end=True,
            endpoints_not_equal=True,
            except_on_no_valid_endpoint=False,
        ),
    ),
]


def endpoint_kwargs_to_name(ep_kwargs: dict) -> str:
    """convert endpoint kwargs options to a human-readable name

    returns `"any"` unless `deadend_start` or `deadend_end` is truthy, in which case
    `"deadends_unique"` if `endpoints_not_equal` is truthy, `"deadends"` otherwise
    """
    if ep_kwargs.get("deadend_start", False) or ep_kwargs.get("deadend_end", False):
        if ep_kwargs.get("endpoints_not_equal", False):
            return "deadends_unique"
        else:
            return "deadends"
    else:
        return "any"


def full_percolation_analysis(
    n_mazes: int,
    p_val_count: int,
    grid_sizes: list[int],
    ep_kwargs: list[tuple[str, dict]] | None = None,
    generators: Sequence[Callable] = (
        LatticeMazeGenerators.gen_percolation,
        LatticeMazeGenerators.gen_dfs_percolation,
    ),
    save_dir: Path = Path("../docs/benchmarks/percolation_fractions"),
    parallel: bool | int = False,
    **analyze_kwargs,
) -> SweepResult:
    """run the full analysis of how percolation affects maze generation success

    # Parameters:
    - `n_mazes : int`
        number of mazes requested in each config
    - `p_val_count : int`
        number of evenly spaced values of `p` to try in `[0.0, 1.0]`
    - `grid_sizes : list[int]`
        grid sizes to try
    - `ep_kwargs : list[tuple[str, dict]] | None`
        `(name, endpoint_kwargs)` pairs (defaults to `DEFAULT_ENDPOINT_KWARGS`)
    - `generators : Sequence[Callable]`
        maze generation functions to try
    - `save_dir : Path`
        directory in which the result file is saved
    - `parallel : bool | int`
        passed to `SweepResult.analyze`
    - `**analyze_kwargs`
        passed to `SweepResult.analyze`

    # Returns:
    - `SweepResult`
        the result, which has also been saved under `save_dir`
    """
    if ep_kwargs is None:
        ep_kwargs = DEFAULT_ENDPOINT_KWARGS

    # configs
    configs: list[MazeDatasetConfig] = list()

    # TODO: B007 noqaed because we dont use `ep_kw_name` or `gf_idx`
    for ep_kw_name, ep_kw in ep_kwargs:  # noqa: B007
        for gf_idx, gen_func in enumerate(generators):  # noqa: B007
            configs.extend(
                [
                    MazeDatasetConfig(
                        name=f"g{grid_n}-{gen_func.__name__.removeprefix('gen_').removesuffix('olation')}",
                        grid_n=grid_n,
                        n_mazes=n_mazes,
                        maze_ctor=gen_func,
                        # placeholder, overwritten by the sweep over `maze_ctor_kwargs.p`
                        maze_ctor_kwargs=dict(p=float("nan")),
                        endpoint_kwargs=ep_kw,
                    )
                    for grid_n in grid_sizes
                ],
            )

    # get results
    result: SweepResult = SweepResult.analyze(
        configs=configs,  # type: ignore[misc]
        # TYPING: `tolist()` is typed as a union of nested lists, not `list[Any]`
        param_values=np.linspace(0.0, 1.0, p_val_count).tolist(),  # type: ignore[arg-type]
        param_key="maze_ctor_kwargs.p",
        analyze_func=dataset_success_fraction,
        parallel=parallel,
        **analyze_kwargs,
    )

    # save the result
    results_path: Path = (
        save_dir / f"result-n{n_mazes}-c{len(configs)}-p{p_val_count}.zanj"
    )
    print(f"Saving results to {results_path.as_posix()}")
    result.save(results_path)

    return result


def _is_eq(a, b) -> bool:  # noqa: ANN001
    """check if two objects are equal"""
    return a == b


def plot_grouped(  # noqa: C901
    results: SweepResult,
    predict_fn: Callable[[MazeDatasetConfig], float] | None = None,
    prediction_density: int = 50,
    save_dir: Path | None = None,
    show: bool = True,
    logy: bool = False,
) -> None:
    """Plot grouped sweep percolation value results for each distinct `endpoint_kwargs` in the configs

    with separate colormaps for each maze generator function

    # Parameters:
    - `results : SweepResult`
        The sweep results to plot
    - `predict_fn : Callable[[MazeDatasetConfig], float] | None`
        Optional function that predicts success rate from a config. If provided, will plot predictions as dashed lines.
    - `prediction_density : int`
        Number of points to use for prediction curves (default: 50)
    - `save_dir : Path | None`
        Directory to save plots (defaults to `None`, meaning no saving)
    - `show : bool`
        Whether to display the plots (defaults to `True`)
    - `logy : bool`
        Whether to use a logarithmic y axis (defaults to `False`)

    # Usage:
    ```python
    >>> result = full_percolation_analysis(n_mazes=100, p_val_count=11, grid_sizes=[8,16])
    >>> plot_grouped(result, save_dir=Path("./plots"), show=False)
    ```
    """
    # groups
    endpoint_kwargs_set: list[dict] = results.configs_value_set("endpoint_kwargs")  # type: ignore[assignment]
    # unique names in first-seen order (not a set), so that the assignment of
    # colormaps to generators below is the same on every run
    generator_funcs_names: list[str] = list(
        dict.fromkeys(cfg.maze_ctor.__name__ for cfg in results.configs),
    )

    # if predicting, create denser p values
    if predict_fn is not None:
        p_dense = np.linspace(0.0, 1.0, prediction_density)

    # separate plot for each set of endpoint kwargs
    for ep_kw in endpoint_kwargs_set:
        results_epkw: SweepResult = results.get_where(
            "endpoint_kwargs",
            functools.partial(_is_eq, b=ep_kw),
        )
        shared_keys: set[str] = set(results_epkw.configs_shared().keys())
        # fixed order, so the plot title reads the same on every run
        cfg_keys: list[str] = [
            k for k in ("n_mazes", "endpoint_kwargs") if k in shared_keys
        ]
        _, ax = plt.subplots(1, 1, figsize=(22, 10))
        for gf_idx, gen_func in enumerate(generator_funcs_names):
            results_filtered: SweepResult = results_epkw.get_where(
                "maze_ctor",
                # name bound as a default so the lambda does not capture the loop variable
                lambda x, name=gen_func: x.__name__ == name,
            )
            if len(results_filtered.configs) < 1:
                warnings.warn(
                    f"No results for {gen_func} and {ep_kw}. Skipping.",
                )
                continue

            cmap_name = "Reds" if gf_idx == 0 else "Blues"
            cmap = plt.get_cmap(cmap_name)

            # Plot actual results
            ax = results_filtered.plot(
                cfg_keys=list(cfg_keys),
                ax=ax,
                show=False,
                cmap_name=cmap_name,
            )
            if logy:
                ax.set_yscale("log")

            # Plot predictions if function provided
            if predict_fn is not None:
                n_cfgs: int = len(results_filtered.configs)
                for cfg_idx, cfg in enumerate(results_filtered.configs):
                    predictions = []
                    for p in p_dense:
                        cfg_temp = MazeDatasetConfig.load(cfg.serialize())
                        cfg_temp.maze_ctor_kwargs["p"] = p
                        predictions.append(predict_fn(cfg_temp))

                    # same color formula as `SweepResult.plot`
                    color = cmap((cfg_idx + 0.5) / (n_cfgs - 0.5))

                    # Plot prediction as dashed line
                    ax.plot(p_dense, predictions, "--", color=color, alpha=0.8)

        # save and show
        if save_dir:
            save_path: Path = save_dir / f"ep_{endpoint_kwargs_to_name(ep_kw)}.svg"
            print(f"Saving plot to {save_path.as_posix()}")
            save_path.parent.mkdir(exist_ok=True, parents=True)
            plt.savefig(save_path)

        if show:
            plt.show()
def dataset_success_fraction(cfg: MazeDatasetConfig) -> float:
    """Measure which fraction of the requested mazes was actually generated.

    A dataset is built from `cfg` with downloading, local loading, local saving
    and verbose output all switched off; the number of mazes in it is then divided
    by `cfg.n_mazes`.

    Meant to be passed as the `analyze_func` of `sweep()`.
    """
    generation_options: dict[str, bool] = dict(
        do_download=False,
        load_local=False,
        save_local=False,
        verbose=False,
    )
    generated: MazeDataset = MazeDataset.from_config(cfg, **generation_options)

    n_generated: int = len(generated)
    return n_generated / cfg.n_mazes
empirical success fraction of maze generation
for use as an analyze_func
in sweep()
def sweep(
    cfg_base: MazeDatasetConfig,
    param_values: list[ParamType],
    param_key: str,
    analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
) -> list[SweepReturnType]:
    """Apply `analyze_func` to one variant of `cfg_base` per entry of `param_values`.

    # Parameters:
    - `cfg_base : MazeDatasetConfig`
        config from which every variant is derived
    - `param_values : list[ParamType]`
        values to write at `param_key`, one variant per value
    - `param_key : str`
        dotlist key of the entry of the serialized config to overwrite
    - `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]`
        function called on each variant, e.g. `dataset_success_fraction`

    # Returns:
    - `list[SweepReturnType]`
        the outputs of `analyze_func`, in the order of `param_values`
    """

    def _variant(value: ParamType) -> MazeDatasetConfig:
        # start from a fresh serialization each time, since the update is in place
        serialized: dict = cfg_base.serialize()
        override: dict = dotlist_to_nested_dict({param_key: value})
        update_with_nested_dict(serialized, override)
        return MazeDatasetConfig.load(serialized)

    return [analyze_func(_variant(value)) for value in param_values]
given a base config, parameter values list, key, and analysis function, return the results of the analysis function for each parameter value
Parameters:
cfg_base : MazeDatasetConfig
base config on which we will modify the value at param_key
with values from param_values
param_values : list[ParamType]
list of values to try
param_key : str
value to modify in cfg_base
analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]
function which analyzes the resulting config. originally built for dataset_success_fraction
Returns:
list[SweepReturnType]
the output of analyze_func for each entry of param_values, in the same order
@serializable_dataclass()
class SweepResult(SerializableDataclass, Generic[ParamType, SweepReturnType]):
    """result of a parameter sweep

    # Fields:
    - `configs : list[MazeDatasetConfig]`
        configs that were swept over
    - `param_values : list[ParamType]`
        values that were tried for `param_key`
    - `result_values : dict[str, Sequence[SweepReturnType]]`
        results of `analyze_func`, keyed by `cfg.to_fname()` of the corresponding config
    - `param_key : str`
        key that was modified in each config
    - `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]`
        analysis function. serialized as its `__name__` and restored via `ANALYSIS_FUNCS.get`,
        so a function not registered in `ANALYSIS_FUNCS` is loaded back as `None`
    """

    configs: list[MazeDatasetConfig] = serializable_field(
        serialization_fn=lambda cfgs: [cfg.serialize() for cfg in cfgs],
        deserialize_fn=lambda cfgs: [MazeDatasetConfig.load(cfg) for cfg in cfgs],
    )
    param_values: list[ParamType] = serializable_field(
        serialization_fn=lambda x: json_serialize(x),
        deserialize_fn=lambda x: x,
        assert_type=False,
    )
    result_values: dict[str, Sequence[SweepReturnType]] = serializable_field(
        serialization_fn=lambda x: json_serialize(x),
        deserialize_fn=lambda x: x,
        assert_type=False,
    )
    param_key: str
    analyze_func: Callable[[MazeDatasetConfig], SweepReturnType] = serializable_field(
        serialization_fn=lambda f: f.__name__,
        deserialize_fn=ANALYSIS_FUNCS.get,
        assert_type=False,
    )

    def summary(self) -> JSONitem:
        "human-readable and json-dumpable short summary of the result"
        return {
            "len(configs)": len(self.configs),
            "len(param_values)": len(self.param_values),
            "len(result_values)": len(self.result_values),
            "param_key": self.param_key,
            "analyze_func": self.analyze_func.__name__,
        }

    def save(self, path: str | Path, z: ZANJ | None = None) -> None:
        "save to a file with zanj (a fresh `ZANJ()` is used if `z` is not given)"
        if z is None:
            z = ZANJ()

        z.save(self, path)

    @classmethod
    def read(cls, path: str | Path, z: ZANJ | None = None) -> "SweepResult":
        "read from a file with zanj (a fresh `ZANJ()` is used if `z` is not given)"
        if z is None:
            z = ZANJ()

        return z.read(path)

    def configs_by_name(self) -> dict[str, MazeDatasetConfig]:
        "return configs by name (for repeated names, the last config wins)"
        return {cfg.name: cfg for cfg in self.configs}

    def configs_by_key(self) -> dict[str, MazeDatasetConfig]:
        "return configs by the key used in `result_values`, which is the filename of the config"
        return {cfg.to_fname(): cfg for cfg in self.configs}

    def configs_shared(self) -> dict[str, Any]:
        """return key: value pairs that are shared across all configs

        values are compared by their `json.dumps` representation of the serialized config.
        returns an empty dict if there are no configs
        """
        # nothing to compare; also avoids indexing into an empty list below
        if not self.configs:
            return dict()

        # we know that the configs all have the same keys,
        # so this way of doing it is fine
        config_vals: dict[str, set[str]] = dict()
        for cfg in self.configs:
            for k, v in cfg.serialize().items():
                config_vals.setdefault(k, set()).add(json.dumps(v))

        cfg_ser: dict = self.configs[0].serialize()
        return {k: cfg_ser[k] for k, v in config_vals.items() if len(v) == 1}

    def configs_differing_keys(self) -> set[str]:
        "return keys that differ across configs"
        shared_vals: dict[str, Any] = self.configs_shared()
        return {
            k for k in MazeDatasetConfig.__dataclass_fields__ if k not in shared_vals
        }

    def configs_value_set(self, key: str) -> list[Any]:
        "return a list of the unique values for a given key"
        # uniqueness is decided on the json dump of the serialized value,
        # since the values themselves may be unhashable
        d: dict[str, Any] = {
            json.dumps(json_serialize(getattr(cfg, key))): getattr(cfg, key)
            for cfg in self.configs
        }

        return list(d.values())

    def get_where(self, key: str, val_check: Callable[[Any], bool]) -> "SweepResult":
        """get a subset of this `SweepResult` where the configs have `key` satisfying `val_check`

        `param_values`, `param_key` and `analyze_func` are carried over unchanged.
        raises `KeyError` if a selected config has no entry in `result_values`
        """
        configs_list: list[MazeDatasetConfig] = [
            cfg for cfg in self.configs if val_check(getattr(cfg, key))
        ]
        # iterate the config list rather than a set of keys, so that the order of
        # `result_values` follows `configs` and does not depend on string hashing
        result_values: dict[str, Sequence[SweepReturnType]] = dict()
        for cfg in configs_list:
            cfg_key: str = cfg.to_fname()
            result_values[cfg_key] = self.result_values[cfg_key]

        return SweepResult(
            configs=configs_list,
            param_values=self.param_values,
            result_values=result_values,
            param_key=self.param_key,
            analyze_func=self.analyze_func,
        )

    @classmethod
    def analyze(
        cls,
        configs: list[MazeDatasetConfig],
        param_values: list[ParamType],
        param_key: str,
        analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
        parallel: bool | int = False,
        **kwargs,
    ) -> "SweepResult":
        """Run `sweep()` on every config and collect the results

        # Parameters:
        - `configs : list[MazeDatasetConfig]`
            configs to try
        - `param_values : list[ParamType]`
            values to try for `param_key`
        - `param_key : str`
            key to modify in each config
        - `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]`
            function applied to each modified config
        - `parallel : bool | int`
            passed to `run_maybe_parallel` (defaults to `False`)
        - `**kwargs`
            passed to `run_maybe_parallel`

        # Returns:
        - `SweepResult`
            with `result_values` mapping `cfg.to_fname()` to a numpy array of results
        """
        n_pvals: int = len(param_values)

        result_values_list: list[list[SweepReturnType]] = run_maybe_parallel(
            # TYPING: the partial returns a list per config, which the stub for `func` does not expect
            func=functools.partial(  # type: ignore[arg-type]
                sweep,
                param_values=param_values,
                param_key=param_key,
                analyze_func=analyze_func,
            ),
            iterable=configs,
            keep_ordered=True,
            parallel=parallel,
            pbar_kwargs=dict(total=len(configs)),
            **kwargs,
        )
        result_values: dict[str, Float[np.ndarray, n_pvals]] = {
            cfg.to_fname(): np.array(res)
            for cfg, res in zip(configs, result_values_list, strict=False)
        }
        return cls(
            configs=configs,
            param_values=param_values,
            # TYPING: arrays are stored where `Sequence[SweepReturnType]` is declared
            result_values=result_values,  # type: ignore[arg-type]
            param_key=param_key,
            analyze_func=analyze_func,
        )

    def plot(
        self,
        save_path: str | None = None,
        cfg_keys: list[str] | None = None,
        cmap_name: str | None = "viridis",
        plot_only: bool = False,
        show: bool = True,
        ax: plt.Axes | None = None,
    ) -> plt.Axes:
        """Plot the results of percolation analysis

        # Parameters:
        - `save_path : str | None`
            if given, `plt.savefig` is called with this path
        - `cfg_keys : list[str] | None`
            keys of the shared config values to show in the title.
            if `None`, the whole dict of shared values is shown
        - `cmap_name : str | None`
            name of the colormap used to color the curves (defaults to `"viridis"`)
        - `plot_only : bool`
            if `True`, skip axis labels, title, grid and legend
        - `show : bool`
            whether to call `plt.show()` (defaults to `True`)
        - `ax : plt.Axes | None`
            axes to draw on. if `None`, a new 22x10 figure is created

        # Returns:
        - `plt.Axes`
            the axes that were drawn on
        """
        # set up figure
        ax_: plt.Axes
        if ax is None:
            _, ax_ = plt.subplots(1, 1, figsize=(22, 10))
        else:
            ax_ = ax

        # plot
        cmap = plt.get_cmap(cmap_name)
        n_cfgs: int = len(self.result_values)
        # the lookup is identical for every curve, so build it once
        cfgs_by_key: dict[str, MazeDatasetConfig] = self.configs_by_key()
        for i, (ep_cfg_name, result_values) in enumerate(
            sorted(
                self.result_values.items(),
                # HACK: sort by grid size
                #                 |--< name of config
                #                 |    |-----------< gets 'g{n}'
                #                 |    |            |--< gets '{n}'
                #                 |    |            |
                key=lambda x: int(x[0].split("-")[0][1:]),
            ),
        ):
            ax_.plot(
                # TYPING: matplotlib stubs do not accept the generic list / sequence types
                self.param_values,  # type: ignore[arg-type]
                result_values,  # type: ignore[arg-type]
                ".-",
                label=cfgs_by_key[ep_cfg_name].name,
                color=cmap((i + 0.5) / (n_cfgs - 0.5)),
            )

        # repr of config
        cfg_shared: dict = self.configs_shared()
        cfg_repr: str = (
            str(cfg_shared)
            if cfg_keys is None
            else (
                "MazeDatasetConfig("
                + ", ".join(
                    [
                        f"{k}={cfg_shared[k].__name__}"
                        # TYPING: `Callable` is a typing special form, not a class
                        if isinstance(cfg_shared[k], Callable)  # type: ignore[arg-type]
                        else f"{k}={cfg_shared[k]}"
                        for k in cfg_keys
                    ],
                )
                + ")"
            )
        )

        # add title and stuff
        if not plot_only:
            ax_.set_xlabel(self.param_key)
            ax_.set_ylabel(self.analyze_func.__name__)
            ax_.set_title(
                f"{self.param_key} vs {self.analyze_func.__name__}\n{cfg_repr}",
            )
            ax_.grid(True)
            ax_.legend(loc="center left")

        # save and show
        if save_path:
            plt.savefig(save_path)

        if show:
            plt.show()

        return ax_
result of a parameter sweep
def summary(self) -> JSONitem:
    "compact overview of the result: container sizes, the swept key, and the analysis function name (json-dumpable)"
    # sizes first, in the order configs / param_values / result_values
    overview: dict = {
        f"len({attr})": len(getattr(self, attr))
        for attr in ("configs", "param_values", "result_values")
    }
    overview["param_key"] = self.param_key
    overview["analyze_func"] = self.analyze_func.__name__
    return overview
human-readable and json-dumpable short summary of the result
def save(self, path: str | Path, z: ZANJ | None = None) -> None:
    "write this result to `path` through `z`, or through a newly created `ZANJ()` when `z` is `None`"
    writer: ZANJ = ZANJ() if z is None else z
    writer.save(self, path)
save to a file with zanj
@classmethod
def read(cls, path: str | Path, z: ZANJ | None = None) -> "SweepResult":
    "load whatever is stored at `path` through `z`, or through a newly created `ZANJ()` when `z` is `None`"
    reader: ZANJ = ZANJ() if z is None else z
    return reader.read(path)
read from a file with zanj
def configs_by_name(self) -> dict[str, MazeDatasetConfig]:
    "map each config's `name` to the config; when a name repeats, the later config replaces the earlier one"
    named: dict[str, MazeDatasetConfig] = {}
    for cfg in self.configs:
        named[cfg.name] = cfg
    return named
return configs by name
def configs_by_key(self) -> dict[str, MazeDatasetConfig]:
    "map each config's `to_fname()` (the key used in `result_values`) to the config"
    fnames: list[str] = [cfg.to_fname() for cfg in self.configs]
    return dict(zip(fnames, self.configs))
return configs by the key used in result_values
, which is the filename of the config
def configs_differing_keys(self) -> set[str]:
    "names of the `MazeDatasetConfig` dataclass fields that are not among the shared config values"
    shared: dict[str, Any] = self.configs_shared()
    return {
        field_name
        for field_name in MazeDatasetConfig.__dataclass_fields__
        if field_name not in shared
    }
return keys that differ across configs
def configs_value_set(self, key: str) -> list[Any]:
    "distinct values of attribute `key` over all configs, in order of first appearance"
    # values are deduplicated on the json dump of their serialized form;
    # for equal dumps the value of the last such config is the one returned
    seen: dict[str, Any] = {}
    for cfg in self.configs:
        fingerprint: str = json.dumps(json_serialize(getattr(cfg, key)))
        seen[fingerprint] = getattr(cfg, key)

    return [*seen.values()]
return a list of the unique values for a given key
def get_where(self, key: str, val_check: Callable[[Any], bool]) -> "SweepResult":
    """get a subset of this `SweepResult` where the configs have `key` satisfying `val_check`

    # Parameters:
    - `key : str`
        name of the config attribute to test
    - `val_check : Callable[[Any], bool]`
        called with `getattr(cfg, key)`; configs for which it is truthy are kept

    # Returns:
    - `SweepResult`
        the kept configs with their entries of `result_values`, in the order of
        `self.configs`. `param_values`, `param_key` and `analyze_func` are carried
        over unchanged

    # Raises:
    - `KeyError` : if a kept config has no entry in `result_values`
    """
    configs_list: list[MazeDatasetConfig] = [
        cfg for cfg in self.configs if val_check(getattr(cfg, key))
    ]
    # iterate the config list rather than a set of keys, so that the order of
    # `result_values` follows `configs` and does not depend on string hashing
    result_values: dict[str, Sequence[SweepReturnType]] = dict()
    for cfg in configs_list:
        cfg_key: str = cfg.to_fname()
        result_values[cfg_key] = self.result_values[cfg_key]

    return SweepResult(
        configs=configs_list,
        param_values=self.param_values,
        result_values=result_values,
        param_key=self.param_key,
        analyze_func=self.analyze_func,
    )
get a subset of this SweepResult
where the configs have key
satisfying val_check
@classmethod
def analyze(
    cls,
    configs: list[MazeDatasetConfig],
    param_values: list[ParamType],
    param_key: str,
    analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
    parallel: bool | int = False,
    **kwargs,
) -> "SweepResult":
    """Sweep `param_key` over `param_values` for every config and bundle the outcome.

    # Parameters:
    - `configs : list[MazeDatasetConfig]`
        base configs, each of which is passed to `sweep()`
    - `param_values : list[ParamType]`
        values to try for `param_key`
    - `param_key : str`
        key that `sweep()` overwrites in each config
    - `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]`
        function that `sweep()` applies to each modified config
    - `parallel : bool | int`
        forwarded to `run_maybe_parallel` (defaults to `False`)
    - `**kwargs`
        forwarded to `run_maybe_parallel`

    # Returns:
    - `SweepResult`
        whose `result_values` maps `cfg.to_fname()` to a numpy array of that config's results
    """
    n_pvals: int = len(param_values)

    # one call of this per config; everything but the config is fixed up front
    sweep_single_config = functools.partial(
        sweep,
        param_values=param_values,
        param_key=param_key,
        analyze_func=analyze_func,
    )
    per_config_outputs: list[list[SweepReturnType]] = run_maybe_parallel(
        func=sweep_single_config,  # type: ignore[arg-type]
        iterable=configs,
        keep_ordered=True,
        parallel=parallel,
        pbar_kwargs=dict(total=len(configs)),
        **kwargs,
    )

    # outputs come back in config order, so they can be paired positionally
    arrays_by_fname: dict[str, Float[np.ndarray, n_pvals]] = {}
    for cfg, outputs in zip(configs, per_config_outputs, strict=False):
        arrays_by_fname[cfg.to_fname()] = np.array(outputs)

    return cls(
        configs=configs,
        param_values=param_values,
        result_values=arrays_by_fname,  # type: ignore[arg-type]
        param_key=param_key,
        analyze_func=analyze_func,
    )
Analyze success rate of maze generation for different percolation values
Parameters:
- `configs : list[MazeDatasetConfig]`: configs to try
- `param_values : list[ParamType]`: list of values to try
Returns:
- `SweepResult`
def plot(
    self,
    save_path: str | None = None,
    cfg_keys: list[str] | None = None,
    cmap_name: str | None = "viridis",
    plot_only: bool = False,
    show: bool = True,
    ax: plt.Axes | None = None,
) -> plt.Axes:
    """Plot `self.result_values` against `self.param_values`, one line per config

    Lines are drawn in order of increasing grid size (parsed from the keys of
    `self.result_values`) and colored by their position in that order.

    # Parameters:
    - `save_path : str | None`
        if given, `plt.savefig(save_path)` is called after plotting
        (defaults to `None`)
    - `cfg_keys : list[str] | None`
        keys of `self.configs_shared()` to show in the title. if `None`, the whole
        shared dict is shown
        (defaults to `None`)
    - `cmap_name : str | None`
        passed to `plt.get_cmap`
        (defaults to `"viridis"`)
    - `plot_only : bool`
        if `True`, only the lines are drawn: no axis labels, title, grid or legend
        (defaults to `False`)
    - `show : bool`
        whether to call `plt.show()`
        (defaults to `True`)
    - `ax : plt.Axes | None`
        axes to draw on. if `None`, a new 22x10 figure is created
        (defaults to `None`)

    # Returns:
    - `plt.Axes`
        the axes that were drawn on

    # Raises:
    - `ValueError` : if a key of `self.result_values` does not start with `g{n}-`
    - `KeyError` : if `plot_only` is `False` and an element of `cfg_keys` is not a shared config key
    """
    # set up figure
    ax_: plt.Axes
    if ax is None:
        _fig, ax_ = plt.subplots(1, 1, figsize=(22, 10))
    else:
        ax_ = ax

    # plot
    cmap = plt.get_cmap(cmap_name)
    n_cfgs: int = len(self.result_values)
    # looked up once instead of once per plotted line
    configs_by_key = self.configs_by_key()
    for i, (ep_cfg_name, result_values) in enumerate(
        sorted(
            self.result_values.items(),
            # HACK: sort by grid size
            #                 |--< name of config
            #                 |    |-----------< gets 'g{n}'
            #                 |    |            |--< gets '{n}'
            #                 |    |            |
            key=lambda x: int(x[0].split("-")[0][1:]),
        ),
    ):
        ax_.plot(
            # TYPING: error: Argument 1 to "plot" of "Axes" has incompatible type "list[ParamType]"  [arg-type]
            self.param_values,  # type: ignore[arg-type]
            # TYPING: error: Argument 2 to "plot" of "Axes" has incompatible type "Sequence[SweepReturnType]"  [arg-type]
            result_values,  # type: ignore[arg-type]
            ".-",
            label=configs_by_key[ep_cfg_name].name,
            color=cmap((i + 0.5) / (n_cfgs - 0.5)),
        )

    # add title and stuff
    if not plot_only:
        # repr of config -- only needed for the title, so only built here
        cfg_shared: dict = self.configs_shared()
        cfg_repr: str
        if cfg_keys is None:
            cfg_repr = str(cfg_shared)
        else:
            shared_items: list[str] = []
            for k in cfg_keys:
                value = cfg_shared[k]
                if callable(value):
                    # show functions by name; callables without a `__name__`
                    # (e.g. `functools.partial`) fall back to their repr
                    shared_items.append(
                        f"{k}={getattr(value, '__name__', repr(value))}",
                    )
                else:
                    shared_items.append(f"{k}={value}")
            cfg_repr = "MazeDatasetConfig(" + ", ".join(shared_items) + ")"

        ax_.set_xlabel(self.param_key)
        ax_.set_ylabel(self.analyze_func.__name__)
        ax_.set_title(
            f"{self.param_key} vs {self.analyze_func.__name__}\n{cfg_repr}",
        )
        ax_.grid(True)
        ax_.legend(loc="center left")

    # save and show
    if save_path:
        # NOTE: this saves pyplot's *current* figure. when `ax` is passed in, that is
        # presumably the figure containing `ax` -- verify against callers
        plt.savefig(save_path)

    if show:
        plt.show()

    return ax_
Plot the results of percolation analysis
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- serialize
- load
- validate_fields_types
- validate_field_type
- diff
- update_from_nested_dict
def endpoint_kwargs_to_name(ep_kwargs: dict) -> str:
    """map a dict of endpoint kwargs to a short human-readable label

    - `"any"` if neither `deadend_start` nor `deadend_end` is truthy
    - `"deadends_unique"` if one of them is truthy and `endpoints_not_equal` is truthy
    - `"deadends"` if one of them is truthy and `endpoints_not_equal` is not

    missing keys are treated as `False`
    """
    wants_deadends = ep_kwargs.get("deadend_start", False) or ep_kwargs.get(
        "deadend_end",
        False,
    )
    if not wants_deadends:
        return "any"

    must_differ = ep_kwargs.get("endpoints_not_equal", False)
    return "deadends_unique" if must_differ else "deadends"
convert endpoint kwargs options to a human-readable name
def full_percolation_analysis(
    n_mazes: int,
    p_val_count: int,
    grid_sizes: list[int],
    ep_kwargs: list[tuple[str, dict]] | None = None,
    generators: Sequence[Callable] = (
        LatticeMazeGenerators.gen_percolation,
        LatticeMazeGenerators.gen_dfs_percolation,
    ),
    save_dir: Path = Path("../docs/benchmarks/percolation_fractions"),
    parallel: bool | int = False,
    **analyze_kwargs,
) -> SweepResult:
    """run the full analysis of how percolation affects maze generation success

    Builds one config for every combination of endpoint kwargs, generator function
    and grid size, sweeps `maze_ctor_kwargs.p` over `p_val_count` evenly spaced
    values in `[0, 1]` with `dataset_success_fraction` as the analysis function,
    and saves the result under `save_dir`.

    # Parameters:
    - `n_mazes : int`
        `n_mazes` of every generated config
    - `p_val_count : int`
        number of values of `p` to try (`np.linspace(0.0, 1.0, p_val_count)`)
    - `grid_sizes : list[int]`
        `grid_n` values to build configs for
    - `ep_kwargs : list[tuple[str, dict]] | None`
        `(name, endpoint_kwargs)` pairs. only the kwargs dict is used here.
        if `None`, `DEFAULT_ENDPOINT_KWARGS` is used
        (defaults to `None`)
    - `generators : Sequence[Callable]`
        maze generator functions used as `maze_ctor`
    - `save_dir : Path`
        directory the result file is written to; created if it does not exist
    - `parallel : bool | int`
        forwarded to `SweepResult.analyze`
        (defaults to `False`)
    - `**analyze_kwargs`
        forwarded to `SweepResult.analyze`

    # Returns:
    - `SweepResult`
        the result that was saved
    """
    if ep_kwargs is None:
        ep_kwargs = DEFAULT_ENDPOINT_KWARGS

    # configs: same nesting order as before (endpoint kwargs, then generator, then grid size)
    configs: list[MazeDatasetConfig] = [
        MazeDatasetConfig(
            name=f"g{grid_n}-{gen_func.__name__.removeprefix('gen_').removesuffix('olation')}",
            grid_n=grid_n,
            n_mazes=n_mazes,
            maze_ctor=gen_func,
            # placeholder: `p` is the parameter being swept (`param_key` below)
            maze_ctor_kwargs=dict(p=float("nan")),
            endpoint_kwargs=ep_kw,
        )
        for _ep_kw_name, ep_kw in ep_kwargs
        for gen_func in generators
        for grid_n in grid_sizes
    ]

    # get results
    result: SweepResult = SweepResult.analyze(
        configs=configs,  # type: ignore[misc]
        # TYPING: error: Argument "param_values" to "analyze" of "SweepResult" has incompatible type "float | list[float] | list[list[float]] | list[list[list[Any]]]"; expected "list[Any]"  [arg-type]
        param_values=np.linspace(0.0, 1.0, p_val_count).tolist(),  # type: ignore[arg-type]
        param_key="maze_ctor_kwargs.p",
        analyze_func=dataset_success_fraction,
        parallel=parallel,
        **analyze_kwargs,
    )

    # save the result
    results_path: Path = (
        save_dir / f"result-n{n_mazes}-c{len(configs)}-p{p_val_count}.zanj"
    )
    print(f"Saving results to {results_path.as_posix()}")
    # make sure the target directory exists, as `plot_grouped` does for its plots,
    # so a finished sweep is not lost to a missing directory
    results_path.parent.mkdir(parents=True, exist_ok=True)
    result.save(results_path)

    return result
run the full analysis of how percolation affects maze generation success
def plot_grouped(  # noqa: C901
    results: SweepResult,
    predict_fn: Callable[[MazeDatasetConfig], float] | None = None,
    prediction_density: int = 50,
    save_dir: Path | None = None,
    show: bool = True,
    logy: bool = False,
) -> None:
    """Plot grouped sweep percolation value results for each distinct `endpoint_kwargs` in the configs

    with separate colormaps for each maze generator function

    # Parameters:
    - `results : SweepResult`
        The sweep results to plot
    - `predict_fn : Callable[[MazeDatasetConfig], float] | None`
        Optional function that predicts success rate from a config. If provided, will plot predictions as dashed lines.
    - `prediction_density : int`
        Number of points to use for prediction curves (default: 50)
    - `save_dir : Path | None`
        Directory to save plots (defaults to `None`, meaning no saving)
    - `show : bool`
        Whether to display the plots (defaults to `True`)
    - `logy : bool`
        Whether to use a log scale on the y axis (defaults to `False`)

    # Usage:
    ```python
    >>> result = full_percolation_analysis(n_mazes=100, p_val_count=11, grid_sizes=[8,16])
    >>> plot_grouped(result, save_dir=Path("./plots"), show=False)
    ```
    """
    # groups
    endpoint_kwargs_set: list[dict] = results.configs_value_set("endpoint_kwargs")  # type: ignore[assignment]
    # unique generator names in order of first appearance in the configs.
    # a `set` would give a different order from run to run, which changed
    # which generator got which colormap below
    generator_funcs_names: list[str] = list(
        dict.fromkeys(cfg.maze_ctor.__name__ for cfg in results.configs),
    )

    # if predicting, create denser p values
    if predict_fn is not None:
        p_dense = np.linspace(0.0, 1.0, prediction_density)

    # separate plot for each set of endpoint kwargs
    for ep_kw in endpoint_kwargs_set:
        results_epkw: SweepResult = results.get_where(
            "endpoint_kwargs",
            functools.partial(_is_eq, b=ep_kw),
        )
        shared_keys: set[str] = set(results_epkw.configs_shared().keys())
        cfg_keys: set[str] = shared_keys.intersection({"n_mazes", "endpoint_kwargs"})
        fig, ax = plt.subplots(1, 1, figsize=(22, 10))
        for gf_idx, gen_func in enumerate(generator_funcs_names):
            results_filtered: SweepResult = results_epkw.get_where(
                "maze_ctor",
                # bind the current name as a default so the lambda does not
                # depend on the loop variable
                lambda x, name=gen_func: x.__name__ == name,
            )
            if len(results_filtered.configs) < 1:
                warnings.warn(
                    f"No results for {gen_func} and {ep_kw}. Skipping.",
                )
                continue

            # first generator in red, every other one in blue
            cmap_name = "Reds" if gf_idx == 0 else "Blues"
            cmap = plt.get_cmap(cmap_name)

            # Plot actual results
            ax = results_filtered.plot(
                # sorted so the title lists the keys in a stable order
                cfg_keys=sorted(cfg_keys),
                ax=ax,
                show=False,
                cmap_name=cmap_name,
            )
            if logy:
                ax.set_yscale("log")

            # Plot predictions if function provided
            if predict_fn is not None:
                # `SweepResult.plot` colors each curve by its position after sorting
                # on the grid size parsed from the config key. use the same order
                # here so that a prediction gets the color of its data curve even
                # when the configs are not stored in order of grid size
                cfgs_by_grid_size: list[MazeDatasetConfig] = sorted(
                    results_filtered.configs,
                    key=lambda c: int(c.to_fname().split("-")[0][1:]),
                )
                n_cfgs: int = len(cfgs_by_grid_size)
                for cfg_idx, cfg in enumerate(cfgs_by_grid_size):
                    predictions = []
                    for p in p_dense:
                        # fresh copy for every p, so the stored config is never modified
                        cfg_temp = MazeDatasetConfig.load(cfg.serialize())
                        cfg_temp.maze_ctor_kwargs["p"] = p
                        predictions.append(predict_fn(cfg_temp))

                    # Get the same color as the actual data
                    color = cmap((cfg_idx + 0.5) / (n_cfgs - 0.5))

                    # Plot prediction as dashed line
                    ax.plot(p_dense, predictions, "--", color=color, alpha=0.8)

        # save and show
        if save_dir:
            save_path: Path = save_dir / f"ep_{endpoint_kwargs_to_name(ep_kw)}.svg"
            print(f"Saving plot to {save_path.as_posix()}")
            save_path.parent.mkdir(exist_ok=True, parents=True)
            # save the figure created for this group explicitly
            fig.savefig(save_path)

        if show:
            plt.show()
Plot grouped sweep percolation value results for each distinct endpoint_kwargs
in the configs
with separate colormaps for each maze generator function
Parameters:
- `results : SweepResult`: The sweep results to plot
- `predict_fn : Callable[[MazeDatasetConfig], float] | None`: Optional function that predicts success rate from a config. If provided, will plot predictions as dashed lines.
- `prediction_density : int`: Number of points to use for prediction curves (default: 50)
- `save_dir : Path | None`: Directory to save plots (defaults to `None`, meaning no saving)
- `show : bool`: Whether to display the plots (defaults to `True`)
Usage:
>>> result = full_percolation_analysis(n_mazes=100, p_val_count=11, grid_sizes=[8,16])
>>> plot_grouped(result, save_dir=Path("./plots"), show=False)