Coverage for maze_dataset\dataset\rasterized.py: 78%

96 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-23 12:49 -0700

1"""a special `RasterizedMazeDataset` that returns 2 images, one for input and one for target, for each maze 

2 

3this lets you match the input and target format of the [`easy_2_hard`](https://github.com/aks2203/easy-to-hard) dataset 

4 

5 

6see their paper: 

7 

8```bibtex 

9@misc{schwarzschild2021learn, 

10 title={Can You Learn an Algorithm? Generalizing from Easy to Hard Problems with Recurrent Networks}, 

11 author={Avi Schwarzschild and Eitan Borgnia and Arjun Gupta and Furong Huang and Uzi Vishkin and Micah Goldblum and Tom Goldstein}, 

12 year={2021}, 

13 eprint={2106.04537}, 

14 archivePrefix={arXiv}, 

15 primaryClass={cs.LG} 

16} 

17``` 

18""" 

19 

20from pathlib import Path 

21import typing 

22 

23import numpy as np 

24import torch 

25from jaxtyping import Float, Int 

26from muutils.json_serialize import serializable_dataclass, serializable_field 

27from torch.utils.data import Dataset 

28from zanj import ZANJ 

29 

30from maze_dataset import MazeDataset, MazeDatasetConfig 

31from maze_dataset.maze import PixelColors, SolvedMaze 

32from maze_dataset.maze.lattice_maze import PixelGrid, _remove_isolated_cells 

33 

34 

def _extend_pixels(
    image: Int[np.ndarray, "x y rgb"], n_mult: int = 2, n_bdry: int = 1
) -> Int[np.ndarray, "n_mult*x+2*n_bdry n_mult*y+2*n_bdry rgb"]:
    """upscale a pixel grid by an integer factor and wrap it in a wall-colored border

    every pixel of `image` is repeated `n_mult` times along both spatial axes,
    then `n_bdry` rows/columns filled with the wall color are added on each side.
    the last (channel) axis is neither repeated nor padded.

    - `image` pixel grid with a trailing channel axis
    - `n_mult: int` repetition factor applied to axis 0 and axis 1
    - `n_bdry: int` border thickness, in pixels, added on every side

    returns the enlarged, padded array
    """
    wall_channels = PixelColors.WALL
    border_value: int = wall_channels[0]
    # padding uses one scalar for all channels, so the wall color must be uniform
    assert all(channel == border_value for channel in wall_channels), (
        "PixelColors.WALL must be a single value"
    )

    # stretch rows first, then columns
    stretched_rows: np.ndarray = np.repeat(image, n_mult, axis=0)
    upscaled: np.ndarray = np.repeat(stretched_rows, n_mult, axis=1)

    # same border width before and after each spatial axis, none on the channel axis
    edge: tuple[int, int] = (n_bdry, n_bdry)
    return np.pad(
        upscaled,
        pad_width=(edge, edge, (0, 0)),
        mode="constant",
        constant_values=border_value,
    )

62 

63 

64_RASTERIZED_CFG_ADDED_PARAMS: list[str] = [ 

65 "remove_isolated_cells", 

66 "extend_pixels", 

67 "endpoints_as_open", 

68] 

69 

70 

def process_maze_rasterized_input_target(
    maze: SolvedMaze,
    remove_isolated_cells: bool = True,
    extend_pixels: bool = True,
    endpoints_as_open: bool = False,
) -> Float[torch.Tensor, "in/tgt=2 x y rgb=3"]:
    """rasterize a solved maze into an (input, target) pair of images

    the maze is rendered once via `maze.as_pixels(show_endpoints=True, show_solution=True)`,
    and two copies of that rendering are recolored:

    - input image: pixels colored as path are recolored to open
    - target image: open pixels become wall, then path pixels become open;
      if `endpoints_as_open` is set, start and end pixels become open as well

    - `maze: SolvedMaze` maze to rasterize
    - `remove_isolated_cells: bool` pass both images through `_remove_isolated_cells`
    - `extend_pixels: bool` pass both images through `_extend_pixels` (default arguments)
    - `endpoints_as_open: bool` recolor endpoints to open in the target image only

    returns a tensor stacking the input image (index 0) and the target image (index 1)
    """

    def _recolor(grid: PixelGrid, old_color, new_color) -> None:
        # in-place: every pixel whose channels all match `old_color` gets `new_color`
        grid[(grid == old_color).all(axis=-1)] = new_color

    rendered: PixelGrid = maze.as_pixels(show_endpoints=True, show_solution=True)
    input_img: PixelGrid = rendered.copy()
    target_img: PixelGrid = rendered.copy()

    # the input image must not reveal the solution path
    _recolor(input_img, PixelColors.PATH, PixelColors.OPEN)

    # order matters: wipe open cells to wall *before* turning the path into open
    _recolor(target_img, PixelColors.OPEN, PixelColors.WALL)
    _recolor(target_img, PixelColors.PATH, PixelColors.OPEN)
    if endpoints_as_open:
        _recolor(target_img, PixelColors.START, PixelColors.OPEN)
        _recolor(target_img, PixelColors.END, PixelColors.OPEN)

    # postprocess to match original easy_2_hard dataset
    images: list = [input_img, target_img]
    if remove_isolated_cells:
        images = [_remove_isolated_cells(img) for img in images]
    if extend_pixels:
        images = [_extend_pixels(img) for img in images]

    return torch.tensor(np.array(images))

103 

104 

@serializable_dataclass
class RasterizedMazeDatasetConfig(MazeDatasetConfig):
    """maze dataset config extended with the rasterization options

    adds three boolean fields on top of `MazeDatasetConfig`:

    - `remove_isolated_cells: bool` whether to set isolated cells to walls (default `True`)
    - `extend_pixels: bool` whether to extend pixels to match easy_2_hard dataset
      (2x2 cells, extra 1 pixel row of wall around maze) (default `True`)
    - `endpoints_as_open: bool` whether to set endpoints to open (default `False`)
    """

    # field order is part of the dataclass interface -- do not reorder
    remove_isolated_cells: bool = serializable_field(default=True)
    extend_pixels: bool = serializable_field(default=True)
    endpoints_as_open: bool = serializable_field(default=False)

116 

117 

class RasterizedMazeDataset(MazeDataset):
    """maze dataset whose items are rasterized (input, target) image pairs

    indexing returns the output of `process_maze_rasterized_input_target` for the
    maze at that index, using the rasterization options stored on `cfg`.
    """

    cfg: RasterizedMazeDatasetConfig

    def __getitem__(self, idx: int) -> Float[torch.Tensor, "in/tgt=2 x y rgb=3"]:
        """rasterize the maze at `idx` into a stacked (input, target) tensor"""
        solved_maze: SolvedMaze = self.mazes[idx]

        return process_maze_rasterized_input_target(
            maze=solved_maze,
            remove_isolated_cells=self.cfg.remove_isolated_cells,
            extend_pixels=self.cfg.extend_pixels,
            endpoints_as_open=self.cfg.endpoints_as_open,
        )

    def get_batch(
        self, idxs: list[int] | None
    ) -> Float[torch.Tensor, "in/tgt=2 item x y rgb=3"]:
        """stack the items at `idxs` into one tensor, inputs first and targets second

        - `idxs: list[int] | None` indices to include; `None` selects every item

        the leading axis of the result has length 2 (index 0: inputs, index 1: targets),
        the second axis runs over the selected items.

        raises `ValueError` if no indices are selected, since an empty batch
        has no image shape to stack
        """
        if idxs is None:
            idxs = list(range(len(self)))

        pairs = [self[i] for i in idxs]
        if not pairs:
            # previously surfaced as an opaque "not enough values to unpack" ValueError
            raise ValueError("cannot build a batch from an empty set of indices")

        inputs: tuple[Float[torch.Tensor, "x y rgb=3"], ...]
        targets: tuple[Float[torch.Tensor, "x y rgb=3"], ...]
        inputs, targets = zip(*pairs)

        return torch.stack([torch.stack(inputs), torch.stack(targets)])

    @classmethod
    def from_config(
        cls,
        cfg: RasterizedMazeDatasetConfig,
        do_generate: bool = True,
        load_local: bool = True,
        save_local: bool = True,
        zanj: ZANJ | None = None,
        do_download: bool = True,
        local_base_path: Path = Path("data/maze_dataset"),
        except_on_config_mismatch: bool = True,
        allow_generation_metadata_filter_mismatch: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> "RasterizedMazeDataset":
        """create a rasterized maze dataset from a config

        priority of loading:
        1. load from local
        2. download
        3. generate

        all arguments are forwarded unchanged to the parent class `from_config`;
        the result is only re-labelled for type checkers via `typing.cast`.
        """
        return typing.cast(
            RasterizedMazeDataset,
            super().from_config(
                cfg=cfg,
                do_generate=do_generate,
                load_local=load_local,
                save_local=save_local,
                zanj=zanj,
                do_download=do_download,
                local_base_path=local_base_path,
                except_on_config_mismatch=except_on_config_mismatch,
                allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
                verbose=verbose,
                **kwargs,
            ),
        )

    @classmethod
    def from_config_augmented(
        cls,
        cfg: RasterizedMazeDatasetConfig,
        **kwargs,
    ) -> Dataset:
        """build the dataset via a plain `MazeDatasetConfig`, then re-attach the rasterization options

        the serialized `cfg` is loaded as a `MazeDatasetConfig` and passed to
        `from_config` (with `**kwargs`); the entries of the serialized `cfg` named in
        `_RASTERIZED_CFG_ADDED_PARAMS` are then handed to `from_base_MazeDataset`.
        """
        _cfg_temp: MazeDatasetConfig = MazeDatasetConfig.load(cfg.serialize())
        return cls.from_base_MazeDataset(
            cls.from_config(cfg=_cfg_temp, **kwargs),
            added_params={
                k: v
                for k, v in cfg.serialize().items()
                if k in _RASTERIZED_CFG_ADDED_PARAMS
            },
        )

    @classmethod
    def from_base_MazeDataset(
        cls,
        base_dataset: MazeDataset,
        added_params: dict | None = None,
    ) -> Dataset:
        """wrap the mazes of an existing `MazeDataset` in a rasterized dataset

        - `base_dataset: MazeDataset` source of the mazes and of the base config
        - `added_params: dict | None` rasterization options merged over the serialized
          base config; when `None`, `remove_isolated_cells=True` and `extend_pixels=True`
          are used

        the returned dataset's `cfg` is a `RasterizedMazeDatasetConfig` loaded from the
        merged dictionary.
        """
        if added_params is None:
            added_params = dict(
                remove_isolated_cells=True,
                extend_pixels=True,
            )
        output: MazeDataset = cls(
            cfg=base_dataset.cfg,
            mazes=base_dataset.mazes,
        )
        # the instance is constructed with the base config, then upgraded in place
        cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig.load(
            {
                **base_dataset.cfg.serialize(),
                **added_params,
            }
        )
        output.cfg = cfg
        return output

    def plot(self, count: int | None = None, show: bool = True) -> tuple | None:
        """show input images (top row) above target images (bottom row)

        - `count: int | None` number of items to plot, starting at index 0;
          `None` (or `0`) plots the whole dataset
        - `show: bool` whether to call `plt.show()` before returning

        returns `(fig, axes)` where `axes` is always a 2D array of shape `(2, count)`,
        or `None` if there is nothing to plot
        """
        import matplotlib.pyplot as plt

        count = count or len(self)
        if count == 0:
            # checked before touching `self[0]`, which would fail on an empty dataset
            print("No mazes to plot for dataset")
            return None

        print(f"{self[0][0].shape = }, {self[0][1].shape = }")

        # squeeze=False keeps `axes` 2D even for a single column, so `axes[row, col]`
        # indexing is valid for every `count`
        fig, axes = plt.subplots(2, count, figsize=(15, 5), squeeze=False)
        for i in range(count):
            # rasterize once per column rather than once per row
            item = self[i]
            for row in (0, 1):
                axes[row, i].imshow(item[row])
                # remove ticks
                axes[row, i].set_xticks([])
                axes[row, i].set_yticks([])

        if show:
            plt.show()

        return fig, axes

250 

251 

def make_numpy_collection(
    base_cfg: RasterizedMazeDatasetConfig,
    grid_sizes: list[int],
    from_config_kwargs: dict | None = None,
    verbose: bool = True,
    key_fmt: str = "{size}x{size}",
) -> dict[
    typing.Literal["configs", "arrays"],
    dict[str, RasterizedMazeDatasetConfig | np.ndarray],
]:
    """build one rasterized dataset per grid size and return configs plus plain arrays

    for every entry of `grid_sizes`, a copy of `base_cfg` (round-tripped through
    `serialize`/`load`) gets its `grid_n` set to that size and is passed to
    `RasterizedMazeDataset.from_config_augmented` along with `from_config_kwargs`.

    - `base_cfg: RasterizedMazeDatasetConfig` template config; it is copied, not modified
    - `grid_sizes: list[int]` grid sizes to generate
    - `from_config_kwargs: dict | None` extra keyword arguments for `from_config_augmented`
    - `verbose: bool` print a progress line for each size
    - `key_fmt: str` format string for the output keys, formatted with `size=<grid size>`

    output is of structure:
    ```
    {
        "configs": {
            "<n>x<n>": RasterizedMazeDatasetConfig,
            ...
        },
        "arrays": {
            "<n>x<n>": np.ndarray,
            ...
        },
    }
    ```
    """
    extra_kwargs: dict = {} if from_config_kwargs is None else from_config_kwargs

    # first pass: build every dataset, keyed by grid size
    by_size: dict[int, RasterizedMazeDataset] = {}
    for size in grid_sizes:
        if verbose:
            print(f"Generating dataset for maze size {size}...")

        sized_cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig.load(
            base_cfg.serialize()
        )
        sized_cfg.grid_n = size

        by_size[size] = RasterizedMazeDataset.from_config_augmented(
            cfg=sized_cfg,
            **extra_kwargs,
        )

    # second pass: collect configs, then convert each full batch to numpy
    configs = {key_fmt.format(size=size): ds.cfg for size, ds in by_size.items()}
    arrays = {
        # get_batch(None) stacks all items into one tensor of shape (2, n, x, y, 3):
        # inputs at index 0 and targets at index 1 of the leading axis
        key_fmt.format(size=size): ds.get_batch(None).cpu().numpy()
        for size, ds in by_size.items()
    }
    return {"configs": configs, "arrays": arrays}