Coverage for maze_dataset\dataset\collected_dataset.py: 33%

87 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-23 12:49 -0700

1"""collecting different maze datasets into a single dataset, for greater variety in a training or validation set 

2 

3> [!CAUTION] 

4> `MazeDatasetCollection` is not thoroughly tested and is not guaranteed to work. 

5 

6""" 

7 

8import itertools 

9import json 

10import typing 

11from functools import cached_property 

12 

13import numpy as np 

14from jaxtyping import Int 

15from muutils.json_serialize import ( 

16 json_serialize, 

17 serializable_dataclass, 

18 serializable_field, 

19) 

20from muutils.json_serialize.util import JSONdict 

21from muutils.json_serialize.util import _FORMAT_KEY 

22from muutils.misc import sanitize_fname, shorten_numerical_to_str, stable_hash 

23from zanj.loading import LoaderHandler, load_item_recursive, register_loader_handler 

24 

25from maze_dataset.constants import Coord, CoordTup 

26from maze_dataset.dataset.dataset import GPTDataset, GPTDatasetConfig 

27from maze_dataset.dataset.maze_dataset import MazeDataset, MazeDatasetConfig 

28from maze_dataset.maze import LatticeMaze 

29 

30 

@serializable_dataclass(kw_only=True)
class MazeDatasetCollectionConfig(GPTDatasetConfig):
    """maze dataset collection configuration, including tokenizers and shuffle"""

    # configs of the individual datasets that make up this collection;
    # serialized as a list of each config's own serialized form
    maze_dataset_configs: list[MazeDatasetConfig] = serializable_field(
        serialization_fn=lambda configs: [config.serialize() for config in configs],
        loading_fn=lambda data: [
            MazeDatasetConfig.load(config) for config in data["maze_dataset_configs"]
        ],
    )

    def summary(self) -> list[dict]:
        """return a summary of the config as a list of sub-config summaries

        (annotation fixed: this returns a list of dicts, not a dict)
        """
        return [c.summary() for c in self.maze_dataset_configs]

    @property
    def n_mazes(self) -> int:
        """total number of mazes across all sub-configs"""
        return sum(config.n_mazes for config in self.maze_dataset_configs)

    @property
    def max_grid_n(self) -> int:
        """largest grid size among the sub-configs"""
        return max(config.grid_n for config in self.maze_dataset_configs)

    @property
    def max_grid_shape(self) -> CoordTup:
        """square shape tuple `(max_grid_n, max_grid_n)`"""
        return (self.max_grid_n, self.max_grid_n)

    @property
    def max_grid_shape_np(self) -> Coord:
        """`max_grid_shape` as an int32 numpy array"""
        return np.array(self.max_grid_shape, dtype=np.int32)

    def stable_hash_cfg(self) -> int:
        """stable hash of the fully-serialized config, used in filenames"""
        return stable_hash(json.dumps(self.serialize()))

    def to_fname(self) -> str:
        """convert config to a filename"""
        return sanitize_fname(
            f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}"
        )

70 

71 

class MazeDatasetCollection(GPTDataset):
    """a collection of maze datasets

    wraps several `MazeDataset`s and exposes them as a single flat,
    indexable dataset.
    """

    def __init__(
        self,
        cfg: MazeDatasetCollectionConfig,
        maze_datasets: list[MazeDataset],
        generation_metadata_collected: dict | None = None,
    ) -> None:
        super().__init__()
        self.cfg: MazeDatasetCollectionConfig = cfg
        self.maze_datasets: list[MazeDataset] = list(maze_datasets)
        # `strict=True` guards against a silent truncation: a plain zip would
        # skip validation of any datasets/configs beyond the shorter list
        for c, ds in zip(
            self.cfg.maze_dataset_configs, self.maze_datasets, strict=True
        ):
            assert c.name == ds.cfg.name
            assert c == ds.cfg

        self.generation_metadata_collected: dict | None = generation_metadata_collected

    @property
    def dataset_lengths(self) -> list[int]:
        """length of each individual dataset, in order"""
        return [len(dataset) for dataset in self.maze_datasets]

    @property
    def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
        """cumulative dataset lengths, used to map a flat index to a (dataset, local index) pair"""
        return np.array(list(itertools.accumulate(self.dataset_lengths)))

    @cached_property
    def mazes(self) -> list[LatticeMaze]:
        """all mazes from all datasets, concatenated in dataset order"""
        return list(
            itertools.chain.from_iterable(
                dataset.mazes for dataset in self.maze_datasets
            )
        )

    def __len__(self) -> int:
        """total number of mazes across all datasets"""
        return sum(len(dataset) for dataset in self.maze_datasets)

    def __getitem__(self, index: int) -> LatticeMaze:
        """return the maze at flat `index` across the whole collection"""
        # find which dataset the index belongs to
        # we add 1, since np.searchsorted returns the
        # index of the last element that is strictly less than the target
        # while we want the index of the last element less than or equal to the target
        # (cast to plain int so downstream indexing never sees numpy scalars)
        dataset_idx: int = int(np.searchsorted(self.dataset_cum_lengths, index + 1))
        index_adjusted: int = index
        if dataset_idx > 0:
            # if the index is 0, `dataset_idx - 1` would be -1,
            # and we just want to use the base index
            index_adjusted -= int(self.dataset_cum_lengths[dataset_idx - 1])
        return self.maze_datasets[dataset_idx][index_adjusted]

    @classmethod
    def generate(
        cls, cfg: MazeDatasetCollectionConfig, **kwargs
    ) -> "MazeDatasetCollection":
        """generate each sub-dataset from its config, then collect them"""
        datasets = [
            MazeDataset.generate(config, **kwargs)
            for config in cfg.maze_dataset_configs
        ]
        return cls(cfg, datasets)

    @classmethod
    def download(
        cls, cfg: MazeDatasetCollectionConfig, **kwargs
    ) -> "MazeDatasetCollection":
        """download each sub-dataset from its config, then collect them"""
        datasets = [
            MazeDataset.download(config, **kwargs)
            for config in cfg.maze_dataset_configs
        ]
        return cls(cfg, datasets)

    def serialize(self) -> JSONdict:
        """serialize to a JSON-compatible dict, tagged with the format key"""
        return {
            _FORMAT_KEY: "MazeDatasetCollection",
            "cfg": self.cfg.serialize(),
            "maze_datasets": [dataset.serialize() for dataset in self.maze_datasets],
            "generation_metadata_collected": json_serialize(
                self.generation_metadata_collected
            ),
        }

    @classmethod
    def load(cls, data: JSONdict) -> "MazeDatasetCollection":
        """load from a dict produced by `serialize`

        # Raises:
        - `AssertionError` : if the format key does not match
        """
        assert data[_FORMAT_KEY] == "MazeDatasetCollection"
        return cls(
            **{
                key: load_item_recursive(data[key], tuple())
                for key in ["cfg", "maze_datasets", "generation_metadata_collected"]
            }
        )

    # TODO: remove duplication with MazeDatasetConfig().as_tokens() somehow?
    def as_tokens(
        self,
        maze_tokenizer,  # TODO: MazeTokenizer
        limit: int | None = None,
        join_tokens_individual_maze: bool = False,
    ) -> list[list[str]] | list[str]:
        """return the dataset as tokens

        if join_tokens_individual_maze is True, then the tokens of each maze are
        joined with a space, and the result is a list of strings.
        i.e.:
        >>> dataset.as_tokens(join_tokens_individual_maze=False)
        [["a", "b", "c"], ["d", "e", "f"]]
        >>> dataset.as_tokens(join_tokens_individual_maze=True)
        ["a b c", "d e f"]
        """
        output: list[list[str]] = [
            maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
        ]
        if join_tokens_individual_maze:
            return [" ".join(tokens) for tokens in output]
        else:
            return output

    def update_self_config(self) -> None:
        """sync `cfg.n_mazes` with the actual length, and recurse into sub-datasets"""
        # TODO: why cant we set this directly? its not frozen, and it seems to work in a regular MazeDataset
        self.cfg.__dict__["n_mazes"] = len(self)
        for dataset in self.maze_datasets:
            dataset.update_self_config()

192 

193 

# wire the config class to its dataset class so generic GPTDataset loading
# machinery can construct the right dataset type from a config
MazeDatasetCollectionConfig._dataset_class = MazeDatasetCollection
# register with zanj so serialized MazeDatasetCollection objects are
# recognized and deserialized when loading from disk
register_loader_handler(
    LoaderHandler(
        # match any mapping whose format key starts with "MazeDatasetCollection"
        check=lambda json_item, path=None, z=None: (
            isinstance(json_item, typing.Mapping)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith("MazeDatasetCollection")
        ),
        load=lambda json_item, path=None, z=None: MazeDatasetCollection.load(json_item),
        uid="MazeDatasetCollection",
        # NOTE(review): this path does not match this module's apparent location
        # (maze_dataset.dataset.collected_dataset) -- presumably kept as-is for
        # compatibility with previously-saved files; confirm before changing
        source_pckg="maze_dataset.generation.maze_dataset_collection",
        desc="MazeDatasetCollection",
    )
)