Coverage for maze_dataset\dataset\rasterized.py: 78%
96 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
1"""a special `RasterizedMazeDataset` that returns 2 images, one for input and one for target, for each maze
3this lets you match the input and target format of the [`easy_2_hard`](https://github.com/aks2203/easy-to-hard) dataset
6see their paper:
8```bibtex
9@misc{schwarzschild2021learn,
10 title={Can You Learn an Algorithm? Generalizing from Easy to Hard Problems with Recurrent Networks},
11 author={Avi Schwarzschild and Eitan Borgnia and Arjun Gupta and Furong Huang and Uzi Vishkin and Micah Goldblum and Tom Goldstein},
12 year={2021},
13 eprint={2106.04537},
14 archivePrefix={arXiv},
15 primaryClass={cs.LG}
16}
17```
18"""
20from pathlib import Path
21import typing
23import numpy as np
24import torch
25from jaxtyping import Float, Int
26from muutils.json_serialize import serializable_dataclass, serializable_field
27from torch.utils.data import Dataset
28from zanj import ZANJ
30from maze_dataset import MazeDataset, MazeDatasetConfig
31from maze_dataset.maze import PixelColors, SolvedMaze
32from maze_dataset.maze.lattice_maze import PixelGrid, _remove_isolated_cells
def _extend_pixels(
    image: Int[np.ndarray, "x y rgb"], n_mult: int = 2, n_bdry: int = 1
) -> Int[np.ndarray, "n_mult*x+2*n_bdry n_mult*y+2*n_bdry rgb"]:
    """Upscale a pixel grid and surround it with a wall-colored border.

    Every pixel of `image` is repeated `n_mult` times along both spatial axes,
    then `n_bdry` pixels filled with the wall value are padded onto each of the
    four sides. The color axis is neither repeated nor padded.

    # Parameters:
    - `image` : pixel grid with axes (x, y, rgb)
    - `n_mult : int` : repeat factor applied to axis 0 and axis 1
    - `n_bdry : int` : border width, in pixels, added on every side

    # Returns:
    array of shape `(n_mult*x + 2*n_bdry, n_mult*y + 2*n_bdry, rgb)`

    # Raises:
    - `AssertionError` : if the channels of `PixelColors.WALL` are not all equal
    """
    # the border is filled with one scalar, so the wall color must have identical channels
    wall_fill: int = PixelColors.WALL[0]
    assert all(channel == wall_fill for channel in PixelColors.WALL), (
        "PixelColors.WALL must be a single value"
    )

    # scale up: first along rows, then along columns
    stretched_rows: np.ndarray = np.repeat(image, n_mult, axis=0)
    upscaled: np.ndarray = np.repeat(stretched_rows, n_mult, axis=1)

    # pad both spatial axes symmetrically, leave the color axis untouched
    border: tuple[int, int] = (n_bdry, n_bdry)
    return np.pad(
        upscaled,
        pad_width=(border, border, (0, 0)),
        mode="constant",
        constant_values=wall_fill,
    )
64_RASTERIZED_CFG_ADDED_PARAMS: list[str] = [
65 "remove_isolated_cells",
66 "extend_pixels",
67 "endpoints_as_open",
68]
def process_maze_rasterized_input_target(
    maze: SolvedMaze,
    remove_isolated_cells: bool = True,
    extend_pixels: bool = True,
    endpoints_as_open: bool = False,
) -> Float[torch.Tensor, "in/tgt=2 x y rgb=3"]:
    """Rasterize a solved maze into a stacked (input, target) image pair.

    The maze is rendered once with endpoints and solution shown, then two
    copies are recolored:

    - input image: solution-path pixels are recolored to open
    - target image: open pixels become wall, then solution-path pixels become
      open; if `endpoints_as_open`, start and end pixels also become open

    # Parameters:
    - `maze : SolvedMaze` : maze to rasterize
    - `remove_isolated_cells : bool` : pass both images through `_remove_isolated_cells`
    - `extend_pixels : bool` : pass both images through `_extend_pixels`
    - `endpoints_as_open : bool` : recolor start/end pixels to open in the target image only

    # Returns:
    tensor whose index 0 is the input image and index 1 is the target image
    """
    # NOTE(review): the return annotation says `Float`, but the tensor is built
    # directly from the pixel arrays without a cast -- confirm the intended dtype

    def _recolor(grid: PixelGrid, old_color, new_color) -> None:
        # in place: every pixel whose channels all equal `old_color` becomes `new_color`
        matches = (grid == old_color).all(axis=-1)
        grid[matches] = new_color

    rendered: PixelGrid = maze.as_pixels(show_endpoints=True, show_solution=True)

    # input image: hide the solution by painting the path as open
    problem_img: PixelGrid = rendered.copy()
    _recolor(problem_img, PixelColors.PATH, PixelColors.OPEN)

    # target image: order matters -- open must become wall *before* path becomes open
    target_img: PixelGrid = rendered.copy()
    _recolor(target_img, PixelColors.OPEN, PixelColors.WALL)
    _recolor(target_img, PixelColors.PATH, PixelColors.OPEN)
    if endpoints_as_open:
        for endpoint_color in (PixelColors.START, PixelColors.END):
            _recolor(target_img, endpoint_color, PixelColors.OPEN)

    # postprocess to match original easy_2_hard dataset
    if remove_isolated_cells:
        problem_img = _remove_isolated_cells(problem_img)
        target_img = _remove_isolated_cells(target_img)

    if extend_pixels:
        problem_img = _extend_pixels(problem_img)
        target_img = _extend_pixels(target_img)

    stacked: np.ndarray = np.array([problem_img, target_img])
    return torch.tensor(stacked)
@serializable_dataclass
class RasterizedMazeDatasetConfig(MazeDatasetConfig):
    """config for a `RasterizedMazeDataset`: a `MazeDatasetConfig` plus rasterization options

    the three added fields are forwarded by `RasterizedMazeDataset.__getitem__`
    to `process_maze_rasterized_input_target`:

    - `remove_isolated_cells: bool` whether to set isolated cells to walls
    - `extend_pixels: bool` whether to extend pixels to match easy_2_hard dataset (2x2 cells, extra 1 pixel row of wall around maze)
    - `endpoints_as_open: bool` whether to set endpoints to open (applied to the target image only)
    """

    # run both images through `_remove_isolated_cells`
    remove_isolated_cells: bool = serializable_field(default=True)
    # run both images through `_extend_pixels`
    extend_pixels: bool = serializable_field(default=True)
    # recolor start/end pixels of the target image to open
    endpoints_as_open: bool = serializable_field(default=False)
class RasterizedMazeDataset(MazeDataset):
    """a `MazeDataset` whose items are rasterized (input, target) image pairs

    indexing returns the output of `process_maze_rasterized_input_target` for
    the maze at that index, using the rasterization options stored in `cfg`
    """

    cfg: RasterizedMazeDatasetConfig

    def __getitem__(self, idx: int) -> Float[torch.Tensor, "in/tgt=2 x y rgb=3"]:
        """return the stacked (input, target) images for the maze at `idx`"""
        # get the solved maze
        solved_maze: SolvedMaze = self.mazes[idx]

        return process_maze_rasterized_input_target(
            maze=solved_maze,
            remove_isolated_cells=self.cfg.remove_isolated_cells,
            extend_pixels=self.cfg.extend_pixels,
            endpoints_as_open=self.cfg.endpoints_as_open,
        )

    def get_batch(
        self, idxs: list[int] | None = None
    ) -> Float[torch.Tensor, "in/tgt=2 item x y rgb=3"]:
        """stack the items at `idxs` into one tensor, input/target axis first

        # Parameters:
        - `idxs : list[int] | None` : indices to include; `None` means every item

        # Returns:
        tensor where index 0 along the first axis holds all inputs and index 1
        holds all targets

        # Raises:
        - `ValueError` : if there are no items to stack
        """
        if idxs is None:
            idxs = list(range(len(self)))

        pairs: list = [self[i] for i in idxs]
        if not pairs:
            # previously surfaced as an opaque tuple-unpacking `ValueError`
            raise ValueError("cannot build a batch from zero items")

        inputs: tuple[Float[torch.Tensor, "x y rgb=3"], ...]
        targets: tuple[Float[torch.Tensor, "x y rgb=3"], ...]
        # each item is indexable as (input, target); transpose into two sequences
        inputs, targets = zip(*pairs)

        return torch.stack([torch.stack(inputs), torch.stack(targets)])

    @classmethod
    def from_config(
        cls,
        cfg: RasterizedMazeDatasetConfig,
        do_generate: bool = True,
        load_local: bool = True,
        save_local: bool = True,
        zanj: ZANJ | None = None,
        do_download: bool = True,
        local_base_path: Path = Path("data/maze_dataset"),
        except_on_config_mismatch: bool = True,
        allow_generation_metadata_filter_mismatch: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> "RasterizedMazeDataset":
        """create a rasterized maze dataset from a config

        thin wrapper: all arguments are forwarded unchanged to the parent
        `from_config`, and the result is cast to `RasterizedMazeDataset`

        priority of loading:
        1. load from local
        2. download
        3. generate

        """
        return typing.cast(
            RasterizedMazeDataset,
            super().from_config(
                cfg=cfg,
                do_generate=do_generate,
                load_local=load_local,
                save_local=save_local,
                zanj=zanj,
                do_download=do_download,
                local_base_path=local_base_path,
                except_on_config_mismatch=except_on_config_mismatch,
                allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
                verbose=verbose,
                **kwargs,
            ),
        )

    @classmethod
    def from_config_augmented(
        cls,
        cfg: RasterizedMazeDatasetConfig,
        **kwargs,
    ) -> Dataset:
        """build the dataset from the base part of `cfg`, then re-attach the rasterization options

        `cfg` is round-tripped through `serialize()` / `MazeDatasetConfig.load`
        to obtain a plain `MazeDatasetConfig`, which is passed to `from_config`
        together with `kwargs`. the fields named in
        `_RASTERIZED_CFG_ADDED_PARAMS` are then taken from `cfg` and applied
        via `from_base_MazeDataset`.
        """
        _cfg_temp: MazeDatasetConfig = MazeDatasetConfig.load(cfg.serialize())
        return cls.from_base_MazeDataset(
            cls.from_config(cfg=_cfg_temp, **kwargs),
            added_params={
                k: v
                for k, v in cfg.serialize().items()
                if k in _RASTERIZED_CFG_ADDED_PARAMS
            },
        )

    @classmethod
    def from_base_MazeDataset(
        cls,
        base_dataset: MazeDataset,
        added_params: dict | None = None,
    ) -> Dataset:
        """wrap the mazes of an existing `MazeDataset` in a rasterized dataset

        # Parameters:
        - `base_dataset : MazeDataset` : dataset whose `mazes` and `cfg` are reused
        - `added_params : dict | None` : rasterization options merged over the
          serialized base config; defaults to
          `remove_isolated_cells=True, extend_pixels=True`

        # Returns:
        an instance of `cls` whose `cfg` is a `RasterizedMazeDatasetConfig`
        """
        if added_params is None:
            added_params = dict(
                remove_isolated_cells=True,
                extend_pixels=True,
            )
        output: MazeDataset = cls(
            cfg=base_dataset.cfg,
            mazes=base_dataset.mazes,
        )
        # replace the base config with one that also carries the rasterization options
        cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig.load(
            {
                **base_dataset.cfg.serialize(),
                **added_params,
            }
        )
        output.cfg = cfg
        return output

    def plot(self, count: int | None = None, show: bool = True) -> tuple | None:
        """show the first `count` items: inputs on the top row, targets on the bottom row

        # Parameters:
        - `count : int | None` : number of items to plot; `None` (or 0) means all of them
        - `show : bool` : whether to call `plt.show()`

        # Returns:
        `(fig, axes)` with `axes` always a 2D array of shape `(2, count)`,
        or `None` if the dataset is empty
        """
        import matplotlib.pyplot as plt

        count = count or len(self)
        # check for emptiness *before* touching `self[0]`, which would raise on an empty dataset
        if count == 0:
            print("No mazes to plot for dataset")
            return None

        print(f"{self[0][0].shape = }, {self[0][1].shape = }")
        # `squeeze=False` keeps `axes` 2D even when `count == 1`, so `axes[row, i]` always works
        fig, axes = plt.subplots(2, count, figsize=(15, 5), squeeze=False)
        for i in range(count):
            pair = self[i]
            for row in (0, 1):
                axes[row, i].imshow(pair[row])
                # remove ticks
                axes[row, i].set_xticks([])
                axes[row, i].set_yticks([])

        if show:
            plt.show()

        return fig, axes
def make_numpy_collection(
    base_cfg: RasterizedMazeDatasetConfig,
    grid_sizes: list[int],
    from_config_kwargs: dict | None = None,
    verbose: bool = True,
    key_fmt: str = "{size}x{size}",
) -> dict[
    typing.Literal["configs", "arrays"],
    dict[str, RasterizedMazeDatasetConfig | np.ndarray],
]:
    """build one rasterized dataset per grid size and return configs and numpy arrays

    for each entry of `grid_sizes`, a copy of `base_cfg` with `grid_n` set to
    that size is passed to `RasterizedMazeDataset.from_config_augmented`

    # Parameters:
    - `base_cfg : RasterizedMazeDatasetConfig` : template config; copied, not modified
    - `grid_sizes : list[int]` : grid sizes to build datasets for
    - `from_config_kwargs : dict | None` : extra keyword arguments for `from_config_augmented`
    - `verbose : bool` : print a progress line per size
    - `key_fmt : str` : format string for the output keys, formatted with `size=<grid size>`

    # Returns:
    output is of structure:
    ```
    {
        "configs": {
            "<n>x<n>": RasterizedMazeDatasetConfig,
            ...
        },
        "arrays": {
            "<n>x<n>": np.ndarray,
            ...
        },
    }
    ```
    """
    extra_kwargs: dict = {} if from_config_kwargs is None else from_config_kwargs

    by_size: dict[int, RasterizedMazeDataset] = {}
    for size in grid_sizes:
        if verbose:
            print(f"Generating dataset for maze size {size}...")

        # round-trip through serialization so `base_cfg` itself is left untouched
        sized_cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig.load(
            base_cfg.serialize()
        )
        sized_cfg.grid_n = size

        by_size[size] = RasterizedMazeDataset.from_config_augmented(
            cfg=sized_cfg,
            **extra_kwargs,
        )

    configs = {key_fmt.format(size=size): ds.cfg for size, ds in by_size.items()}
    # `get_batch(None)` stacks all inputs, then all targets, so each array has
    # the input/target axis first: (2, n, x, y, 3)
    arrays = {
        key_fmt.format(size=size): ds.get_batch(None).cpu().numpy()
        for size, ds in by_size.items()
    }
    return {"configs": configs, "arrays": arrays}