Coverage for maze_dataset/benchmark/config_sweep.py: 0%

164 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-11 01:43 -0600

1"""Benchmarking of how successful maze generation is for various values of percolation""" 

2 

3import functools 

4import json 

5import warnings 

6from pathlib import Path 

7from typing import Any, Callable, Generic, Sequence, TypeVar 

8 

9import matplotlib.pyplot as plt 

10import numpy as np 

11from jaxtyping import Float 

12from muutils.dictmagic import dotlist_to_nested_dict, update_with_nested_dict 

13from muutils.json_serialize import ( 

14 JSONitem, 

15 SerializableDataclass, 

16 json_serialize, 

17 serializable_dataclass, 

18 serializable_field, 

19) 

20from muutils.parallel import run_maybe_parallel 

21from zanj import ZANJ 

22 

23from maze_dataset import MazeDataset, MazeDatasetConfig 

24from maze_dataset.generation import LatticeMazeGenerators 

25 

# type variables for the generic sweep machinery:
# SweepReturnType is whatever the analysis function returns per config,
# ParamType is the type of the values being swept over
SweepReturnType = TypeVar("SweepReturnType")
ParamType = TypeVar("ParamType")
# signature expected of any analysis function passed to `sweep()`
AnalysisFunc = Callable[[MazeDatasetConfig], SweepReturnType]

29 

30 

def dataset_success_fraction(cfg: MazeDatasetConfig) -> float:
    """Empirically measure what fraction of requested mazes were generated.

    Builds a fresh dataset from `cfg` — no local loading, downloading, or
    saving — and returns `len(dataset) / cfg.n_mazes`.

    For use as an `analyze_func` in `sweep()`.
    """
    generated: MazeDataset = MazeDataset.from_config(
        cfg,
        do_download=False,
        load_local=False,
        save_local=False,
        verbose=False,
    )
    return len(generated) / cfg.n_mazes

45 

46 

# registry of analysis functions by name; used by `SweepResult` to
# round-trip `analyze_func` through serialization
ANALYSIS_FUNCS: dict[str, AnalysisFunc] = {
    "dataset_success_fraction": dataset_success_fraction,
}

50 

51 

def sweep(
    cfg_base: MazeDatasetConfig,
    param_values: list[ParamType],
    param_key: str,
    analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
) -> list[SweepReturnType]:
    """given a base config, parameter values list, key, and analysis function, return the results of the analysis function for each parameter value

    # Parameters:
    - `cfg_base : MazeDatasetConfig`
        base config on which we will modify the value at `param_key` with values from `param_values`
    - `param_values : list[ParamType]`
        list of values to try
    - `param_key : str`
        dotlist key of the value to modify in `cfg_base`, supporting nested
        keys (e.g. `"maze_ctor_kwargs.p"`)
    - `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]`
        function which analyzes the resulting config. originally built for `dataset_success_fraction`

    # Returns:
    - `list[SweepReturnType]`
        output of `analyze_func` for each modified config, in the same order
        as `param_values`
    """
    outputs: list[SweepReturnType] = []

    for p in param_values:
        # serialize the base config, overwrite the (possibly nested) key,
        # then reload so validation/derived fields are recomputed
        cfg_dict: dict = cfg_base.serialize()
        update_with_nested_dict(
            cfg_dict,
            dotlist_to_nested_dict({param_key: p}),
        )
        cfg_test: MazeDatasetConfig = MazeDatasetConfig.load(cfg_dict)

        outputs.append(analyze_func(cfg_test))

    return outputs

88 

89 

@serializable_dataclass()
class SweepResult(SerializableDataclass, Generic[ParamType, SweepReturnType]):
    """result of a parameter sweep

    Holds the swept configs, the parameter values tried, and per-config
    result sequences keyed by config filename (`cfg.to_fname()`).
    """

    # configs round-trip through their own serialize/load
    configs: list[MazeDatasetConfig] = serializable_field(
        serialization_fn=lambda cfgs: [cfg.serialize() for cfg in cfgs],
        deserialize_fn=lambda cfgs: [MazeDatasetConfig.load(cfg) for cfg in cfgs],
    )
    # values that were swept over `param_key`; left as-is on load
    param_values: list[ParamType] = serializable_field(
        serialization_fn=lambda x: json_serialize(x),
        deserialize_fn=lambda x: x,
        assert_type=False,
    )
    # maps config fname -> sequence of results, one entry per param value
    result_values: dict[str, Sequence[SweepReturnType]] = serializable_field(
        serialization_fn=lambda x: json_serialize(x),
        deserialize_fn=lambda x: x,
        assert_type=False,
    )
    # dotlist key that was modified in each config, e.g. "maze_ctor_kwargs.p"
    param_key: str
    # serialized by function name; must be registered in `ANALYSIS_FUNCS`
    # for deserialization to recover the callable
    analyze_func: Callable[[MazeDatasetConfig], SweepReturnType] = serializable_field(
        serialization_fn=lambda f: f.__name__,
        deserialize_fn=ANALYSIS_FUNCS.get,
        assert_type=False,
    )

    def summary(self) -> JSONitem:
        "human-readable and json-dumpable short summary of the result"
        return {
            "len(configs)": len(self.configs),
            "len(param_values)": len(self.param_values),
            "len(result_values)": len(self.result_values),
            "param_key": self.param_key,
            "analyze_func": self.analyze_func.__name__,
        }

    def save(self, path: str | Path, z: ZANJ | None = None) -> None:
        "save to a file with zanj"
        if z is None:
            z = ZANJ()

        z.save(self, path)

    @classmethod
    def read(cls, path: str | Path, z: ZANJ | None = None) -> "SweepResult":
        "read from a file with zanj"
        if z is None:
            z = ZANJ()

        return z.read(path)

    def configs_by_name(self) -> dict[str, MazeDatasetConfig]:
        "return configs by name"
        return {cfg.name: cfg for cfg in self.configs}

    def configs_by_key(self) -> dict[str, MazeDatasetConfig]:
        "return configs by the key used in `result_values`, which is the filename of the config"
        return {cfg.to_fname(): cfg for cfg in self.configs}

    def configs_shared(self) -> dict[str, Any]:
        "return key: value pairs that are shared across all configs"
        # we know that the configs all have the same keys,
        # so this way of doing it is fine
        config_vals: dict[str, set[Any]] = dict()
        for cfg in self.configs:
            for k, v in cfg.serialize().items():
                # json.dumps makes unhashable values (dicts, lists) storable in a set
                if k not in config_vals:
                    config_vals[k] = set()
                config_vals[k].add(json.dumps(v))

        shared_vals: dict[str, Any] = dict()

        # take the (un-dumped) value from the first config for any key whose
        # serialized form is identical across all configs
        cfg_ser: dict = self.configs[0].serialize()
        for k, v in config_vals.items():
            if len(v) == 1:
                shared_vals[k] = cfg_ser[k]

        return shared_vals

    def configs_differing_keys(self) -> set[str]:
        "return keys that differ across configs"
        shared_vals: dict[str, Any] = self.configs_shared()
        differing_keys: set[str] = set()

        # any dataclass field not in the shared set must differ somewhere
        for k in MazeDatasetConfig.__dataclass_fields__:
            if k not in shared_vals:
                differing_keys.add(k)

        return differing_keys

    def configs_value_set(self, key: str) -> list[Any]:
        "return a list of the unique values for a given key"
        # dedupe by json-serialized form since values may be unhashable
        d: dict[str, Any] = {
            json.dumps(json_serialize(getattr(cfg, key))): getattr(cfg, key)
            for cfg in self.configs
        }

        return list(d.values())

    def get_where(self, key: str, val_check: Callable[[Any], bool]) -> "SweepResult":
        "get a subset of this `Result` where the configs has `key` satisfying `val_check`"
        configs_list: list[MazeDatasetConfig] = [
            cfg for cfg in self.configs if val_check(getattr(cfg, key))
        ]
        configs_keys: set[str] = {cfg.to_fname() for cfg in configs_list}
        result_values: dict[str, Sequence[SweepReturnType]] = {
            k: self.result_values[k] for k in configs_keys
        }

        # NOTE: param_values and param_key are shared with the parent result
        return SweepResult(
            configs=configs_list,
            param_values=self.param_values,
            result_values=result_values,
            param_key=self.param_key,
            analyze_func=self.analyze_func,
        )

    @classmethod
    def analyze(
        cls,
        configs: list[MazeDatasetConfig],
        param_values: list[ParamType],
        param_key: str,
        analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
        parallel: bool | int = False,
        **kwargs,
    ) -> "SweepResult":
        """Analyze success rate of maze generation for different percolation values

        Runs `sweep()` over `param_values` for every config (optionally in
        parallel) and collects the results keyed by config filename.

        # Parameters:
        - `configs : list[MazeDatasetConfig]`
            configs to try
        - `param_values : np.ndarray`
            numpy array of values to try
        - `param_key : str`
            dotlist key to modify in each config
        - `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]`
            analysis function applied to each modified config
        - `parallel : bool | int`
            passed to `run_maybe_parallel` (number of workers or False)
        - `**kwargs`
            passed through to `run_maybe_parallel`

        # Returns:
        - `SweepResult`
        """
        n_pvals: int = len(param_values)

        result_values_list: list[float] = run_maybe_parallel(
            # TYPING: error: Argument "func" to "run_maybe_parallel" has incompatible type "partial[list[SweepReturnType]]"; expected "Callable[[MazeDatasetConfig], float]" [arg-type]
            func=functools.partial(  # type: ignore[arg-type]
                sweep,
                param_values=param_values,
                param_key=param_key,
                analyze_func=analyze_func,
            ),
            iterable=configs,
            keep_ordered=True,
            parallel=parallel,
            pbar_kwargs=dict(total=len(configs)),
            **kwargs,
        )
        # local annotation is not evaluated at runtime, so using the int
        # `n_pvals` as a jaxtyping dim spec here is documentation only
        result_values: dict[str, Float[np.ndarray, n_pvals]] = {
            cfg.to_fname(): np.array(res)
            for cfg, res in zip(configs, result_values_list, strict=False)
        }
        return cls(
            configs=configs,
            param_values=param_values,
            # TYPING: error: Argument "result_values" to "SweepResult" has incompatible type "dict[str, ndarray[Any, Any]]"; expected "dict[str, Sequence[SweepReturnType]]" [arg-type]
            result_values=result_values,  # type: ignore[arg-type]
            param_key=param_key,
            analyze_func=analyze_func,
        )

    def plot(
        self,
        save_path: str | None = None,
        cfg_keys: list[str] | None = None,
        cmap_name: str | None = "viridis",
        plot_only: bool = False,
        show: bool = True,
        ax: plt.Axes | None = None,
    ) -> plt.Axes:
        """Plot the results of percolation analysis

        One line per config (colored along `cmap_name`), x-axis is
        `param_values`, y-axis is the analysis result. Returns the axes
        plotted on (a new figure is created if `ax` is not given).
        """
        # set up figure
        if not ax:
            fig: plt.Figure
            ax_: plt.Axes
            fig, ax_ = plt.subplots(1, 1, figsize=(22, 10))
        else:
            ax_ = ax

        # plot
        cmap = plt.get_cmap(cmap_name)
        n_cfgs: int = len(self.result_values)
        for i, (ep_cfg_name, result_values) in enumerate(
            sorted(
                self.result_values.items(),
                # HACK: sort by grid size, parsed out of the config fname
                # (assumes fnames look like 'g{n}-...' — see name format in
                # `full_percolation_analysis`)
                # |--< name of config
                # | |-----------< gets 'g{n}'
                # | | |--< gets '{n}'
                # | | |
                key=lambda x: int(x[0].split("-")[0][1:]),
            ),
        ):
            ax_.plot(
                # TYPING: error: Argument 1 to "plot" of "Axes" has incompatible type "list[ParamType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str" [arg-type]
                self.param_values,  # type: ignore[arg-type]
                # TYPING: error: Argument 2 to "plot" of "Axes" has incompatible type "Sequence[SweepReturnType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str" [arg-type]
                result_values,  # type: ignore[arg-type]
                ".-",
                label=self.configs_by_key()[ep_cfg_name].name,
                # spread colors across the colormap, offset to avoid extremes
                color=cmap((i + 0.5) / (n_cfgs - 0.5)),
            )

        # repr of config: full shared-values dict, or a compact
        # MazeDatasetConfig(...) string restricted to `cfg_keys`
        cfg_shared: dict = self.configs_shared()
        cfg_repr: str = (
            str(cfg_shared)
            if cfg_keys is None
            else (
                "MazeDatasetConfig("
                + ", ".join(
                    [
                        f"{k}={cfg_shared[k].__name__}"
                        # TYPING: error: Argument 2 to "isinstance" has incompatible type "<typing special form>"; expected "_ClassInfo" [arg-type]
                        if isinstance(cfg_shared[k], Callable)  # type: ignore[arg-type]
                        else f"{k}={cfg_shared[k]}"
                        for k in cfg_keys
                    ],
                )
                + ")"
            )
        )

        # add title and stuff (skipped when composing into an existing plot)
        if not plot_only:
            ax_.set_xlabel(self.param_key)
            ax_.set_ylabel(self.analyze_func.__name__)
            ax_.set_title(
                f"{self.param_key} vs {self.analyze_func.__name__}\n{cfg_repr}",
            )
            ax_.grid(True)
            ax_.legend(loc="center left")

        # save and show
        if save_path:
            plt.savefig(save_path)

        if show:
            plt.show()

        return ax_

336 

337 

# default (name, endpoint_kwargs) pairs swept by `full_percolation_analysis`:
# "any" places endpoints anywhere, "deadends" restricts both endpoints to
# dead ends, "deadends_unique" additionally forbids equal endpoints
DEFAULT_ENDPOINT_KWARGS: list[tuple[str, dict]] = [
    (
        "any",
        {
            "deadend_start": False,
            "deadend_end": False,
            "except_on_no_valid_endpoint": False,
        },
    ),
    (
        "deadends",
        {
            "deadend_start": True,
            "deadend_end": True,
            "endpoints_not_equal": False,
            "except_on_no_valid_endpoint": False,
        },
    ),
    (
        "deadends_unique",
        {
            "deadend_start": True,
            "deadend_end": True,
            "endpoints_not_equal": True,
            "except_on_no_valid_endpoint": False,
        },
    ),
]

362 

363 

def endpoint_kwargs_to_name(ep_kwargs: dict) -> str:
    """Convert endpoint kwargs options to a human-readable name.

    Returns `"any"` unless a dead-end constraint is present, in which case
    the name distinguishes whether the endpoints must also differ.
    Missing keys are treated as `False`.
    """
    requires_deadend = ep_kwargs.get("deadend_start", False) or ep_kwargs.get(
        "deadend_end",
        False,
    )
    if not requires_deadend:
        return "any"
    return (
        "deadends_unique" if ep_kwargs.get("endpoints_not_equal", False) else "deadends"
    )

373 

374 

def full_percolation_analysis(
    n_mazes: int,
    p_val_count: int,
    grid_sizes: list[int],
    ep_kwargs: list[tuple[str, dict]] | None = None,
    generators: Sequence[Callable] = (
        LatticeMazeGenerators.gen_percolation,
        LatticeMazeGenerators.gen_dfs_percolation,
    ),
    save_dir: Path = Path("../docs/benchmarks/percolation_fractions"),
    parallel: bool | int = False,
    **analyze_kwargs,
) -> SweepResult:
    """run the full analysis of how percolation affects maze generation success

    # Parameters:
    - `n_mazes : int`
        number of mazes to attempt per config
    - `p_val_count : int`
        number of percolation values to sweep, evenly spaced over [0, 1]
    - `grid_sizes : list[int]`
        grid sizes to build configs for
    - `ep_kwargs : list[tuple[str, dict]] | None`
        (name, endpoint_kwargs) pairs to try (defaults to `DEFAULT_ENDPOINT_KWARGS`)
    - `generators : Sequence[Callable]`
        maze generator functions to try
    - `save_dir : Path`
        directory the result `.zanj` file is written to
    - `parallel : bool | int`
        passed through to `SweepResult.analyze`
    - `**analyze_kwargs`
        passed through to `SweepResult.analyze`

    # Returns:
    - `SweepResult`
        the analyzed result, which is also saved to `save_dir`
    """
    if ep_kwargs is None:
        ep_kwargs = DEFAULT_ENDPOINT_KWARGS

    # build one config per (endpoint_kwargs, generator, grid size) combination;
    # `p` is a NaN placeholder since it is the parameter being swept.
    # names are not used here (the config name encodes the generator),
    # so iterate without binding unused loop variables
    configs: list[MazeDatasetConfig] = list()
    for _, ep_kw in ep_kwargs:
        for gen_func in generators:
            configs.extend(
                [
                    MazeDatasetConfig(
                        name=f"g{grid_n}-{gen_func.__name__.removeprefix('gen_').removesuffix('olation')}",
                        grid_n=grid_n,
                        n_mazes=n_mazes,
                        maze_ctor=gen_func,
                        maze_ctor_kwargs=dict(p=float("nan")),
                        endpoint_kwargs=ep_kw,
                    )
                    for grid_n in grid_sizes
                ],
            )

    # get results
    result: SweepResult = SweepResult.analyze(
        configs=configs,  # type: ignore[misc]
        # TYPING: error: Argument "param_values" to "analyze" of "SweepResult" has incompatible type "float | list[float] | list[list[float]] | list[list[list[Any]]]"; expected "list[Any]" [arg-type]
        param_values=np.linspace(0.0, 1.0, p_val_count).tolist(),  # type: ignore[arg-type]
        param_key="maze_ctor_kwargs.p",
        analyze_func=dataset_success_fraction,
        parallel=parallel,
        **analyze_kwargs,
    )

    # save the result
    results_path: Path = (
        save_dir / f"result-n{n_mazes}-c{len(configs)}-p{p_val_count}.zanj"
    )
    print(f"Saving results to {results_path.as_posix()}")
    result.save(results_path)

    return result

431 

432 

433def _is_eq(a, b) -> bool: # noqa: ANN001 

434 """check if two objects are equal""" 

435 return a == b 

436 

437 

def plot_grouped(  # noqa: C901
    results: SweepResult,
    predict_fn: Callable[[MazeDatasetConfig], float] | None = None,
    prediction_density: int = 50,
    save_dir: Path | None = None,
    show: bool = True,
    logy: bool = False,
) -> None:
    """Plot grouped sweep percolation value results for each distinct `endpoint_kwargs` in the configs

    with separate colormaps for each maze generator function

    # Parameters:
    - `results : SweepResult`
        The sweep results to plot
    - `predict_fn : Callable[[MazeDatasetConfig], float] | None`
        Optional function that predicts success rate from a config. If provided, will plot predictions as dashed lines.
    - `prediction_density : int`
        Number of points to use for prediction curves (default: 50)
    - `save_dir : Path | None`
        Directory to save plots (defaults to `None`, meaning no saving)
    - `show : bool`
        Whether to display the plots (defaults to `True`)
    - `logy : bool`
        Whether to use a logarithmic y-axis (defaults to `False`)

    # Usage:
    ```python
    >>> result = full_analysis(n_mazes=100, p_val_count=11, grid_sizes=[8,16])
    >>> plot_grouped(result, save_dir=Path("./plots"), show=False)
    ```
    """
    # groups: unique endpoint_kwargs dicts and generator function names
    endpoint_kwargs_set: list[dict] = results.configs_value_set("endpoint_kwargs")  # type: ignore[assignment]
    generator_funcs_names: list[str] = list(
        {cfg.maze_ctor.__name__ for cfg in results.configs},
    )

    # if predicting, create denser p values
    if predict_fn is not None:
        p_dense = np.linspace(0.0, 1.0, prediction_density)

    # separate plot for each set of endpoint kwargs
    for ep_kw in endpoint_kwargs_set:
        # partial of a module-level function instead of a lambda so the
        # predicate is well-behaved (no late binding on `ep_kw`)
        results_epkw: SweepResult = results.get_where(
            "endpoint_kwargs",
            functools.partial(_is_eq, b=ep_kw),
            # lambda x: x == ep_kw,
        )
        shared_keys: set[str] = set(results_epkw.configs_shared().keys())
        cfg_keys: set[str] = shared_keys.intersection({"n_mazes", "endpoint_kwargs"})
        fig, ax = plt.subplots(1, 1, figsize=(22, 10))
        for gf_idx, gen_func in enumerate(generator_funcs_names):
            results_filtered: SweepResult = results_epkw.get_where(
                "maze_ctor",
                # HACK: big hassle to do this without a lambda, is it really that bad?
                # (B023 late binding is harmless here: the lambda is consumed
                # immediately by `get_where`, within the same iteration)
                lambda x: x.__name__ == gen_func,  # noqa: B023
            )
            if len(results_filtered.configs) < 1:
                warnings.warn(
                    f"No results for {gen_func} and {ep_kw}. Skipping.",
                )
                continue

            # first generator gets reds, all others blues
            cmap_name = "Reds" if gf_idx == 0 else "Blues"
            cmap = plt.get_cmap(cmap_name)

            # Plot actual results
            ax = results_filtered.plot(
                cfg_keys=list(cfg_keys),
                ax=ax,
                show=False,
                cmap_name=cmap_name,
            )
            if logy:
                ax.set_yscale("log")

            # Plot predictions if function provided
            if predict_fn is not None:
                for cfg_idx, cfg in enumerate(results_filtered.configs):
                    predictions = []
                    for p in p_dense:
                        # copy the config via serialize/load, then override p
                        cfg_temp = MazeDatasetConfig.load(cfg.serialize())
                        cfg_temp.maze_ctor_kwargs["p"] = p
                        predictions.append(predict_fn(cfg_temp))

                    # Get the same color as the actual data
                    # (same spread formula as `SweepResult.plot`)
                    n_cfgs: int = len(results_filtered.configs)
                    color = cmap((cfg_idx + 0.5) / (n_cfgs - 0.5))

                    # Plot prediction as dashed line
                    ax.plot(p_dense, predictions, "--", color=color, alpha=0.8)

        # save and show
        if save_dir:
            save_path: Path = save_dir / f"ep_{endpoint_kwargs_to_name(ep_kw)}.svg"
            print(f"Saving plot to {save_path.as_posix()}")
            save_path.parent.mkdir(exist_ok=True, parents=True)
            plt.savefig(save_path)

        if show:
            plt.show()