Coverage for maze_dataset/tokenization/all_tokenizers.py: 0%

59 statements  

1"""Contains `get_all_tokenizers()` and supporting limited-use functions. 

2 

3# `get_all_tokenizers()` 

4returns a comprehensive collection of all valid `MazeTokenizerModular` objects. 

5This is an overwhelming majority subset of the set of all possible `MazeTokenizerModular` objects. 

6Other tokenizers not contained in `get_all_tokenizers()` may be possible to construct, but they are untested and not guaranteed to work. 

7This collection is in a separate module since it is expensive to compute and will grow more expensive as features are added to `MazeTokenizerModular`. 

8 

9## Use Cases 

10In general, uses for this module are limited to development of the library and specific research studying many tokenization behaviors. 

- Unit testing:
    - Tokenizers to use in unit tests are sampled from `get_all_tokenizers()`
- Large-scale tokenizer research:
    - Research training models on many tokenization behaviors can use `get_all_tokenizers()` as the maximally inclusive collection
    - `get_all_tokenizers()` may be subsequently filtered using `MazeTokenizerModular.has_element` (see the sketch below)

For other uses, the computational expense can likely be avoided by using
- `maze_tokenizer.get_all_tokenizer_hashes()` for membership checks
- `utils.all_instances` for generating smaller subsets of `MazeTokenizerModular` or `_TokenizerElement` objects
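
For example, a rough sketch of filtering the full collection down to tokenizers containing a particular element (check whether `MazeTokenizerModular.has_element` expects a class or an instance before relying on the exact call shown here):

```python
from maze_dataset.tokenization import CoordTokenizers
from maze_dataset.tokenization.all_tokenizers import get_all_tokenizers

# keep only tokenizers using the `CTT` coordinate tokenizer
ctt_tokenizers = [
    mt for mt in get_all_tokenizers() if mt.has_element(CoordTokenizers.CTT())
]
```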

# `EVERY_TEST_TOKENIZERS`
A collection of the tokenizers which should always be included in unit tests when test fuzzing is used.
This collection should be expanded as specific tokenizers become canonical or popular.
"""

import functools
import multiprocessing
import random
from functools import cache
from pathlib import Path
from typing import Callable

import frozendict
import numpy as np
from jaxtyping import Int64
from muutils.spinner import NoOpContextManager, SpinnerContext
from tqdm import tqdm

from maze_dataset.tokenization import (
    CoordTokenizers,
    MazeTokenizerModular,
    PromptSequencers,
    StepTokenizers,
    _TokenizerElement,
)
from maze_dataset.utils import FiniteValued, all_instances

# Always include this as the first item in the dict `validation_funcs` whenever using `all_instances` with `MazeTokenizerModular`
# TYPING: error: Type variable "maze_dataset.utils.FiniteValued" is unbound [valid-type]
# note: (Hint: Use "Generic[FiniteValued]" or "Protocol[FiniteValued]" base class to bind "FiniteValued" inside a class)
# note: (Hint: Use "FiniteValued" in function signature to bind "FiniteValued" inside a function)
MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS: frozendict.frozendict[
    type[FiniteValued],
    Callable[[FiniteValued], bool],
] = frozendict.frozendict(
    {
        # TYPING: Item "bool" of the upper bound "bool | IsDataclass | Enum" of type variable "FiniteValued" has no attribute "is_valid" [union-attr]
        _TokenizerElement: lambda x: x.is_valid(),
        # Currently no need for `MazeTokenizerModular.is_valid` since that method contains no special cases not already covered by `_TokenizerElement.is_valid`
        # MazeTokenizerModular: lambda x: x.is_valid(),
        # TYPING: error: No overload variant of "set" matches argument type "FiniteValued" [call-overload]
        # note: Possible overload variants:
        # note:     def [_T] set(self) -> set[_T]
        # note:     def [_T] set(self, Iterable[_T], /) -> set[_T]
        # TYPING: error: Argument 1 to "len" has incompatible type "FiniteValued"; expected "Sized" [arg-type]
        StepTokenizers.StepTokenizerPermutation: lambda x: len(set(x)) == len(x)
        and x != (StepTokenizers.Distance(),),
    },
)
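
# A hedged sketch of extending these defaults to generate a smaller subset with
# `all_instances` (the extra `MazeTokenizerModular` entry and its `has_element`
# call are illustrative, not canonical; the defaults are kept first as required):
#
#     ctt_tokenizers = list(
#         all_instances(
#             MazeTokenizerModular,
#             validation_funcs=frozendict.frozendict(
#                 {
#                     **MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS,
#                     MazeTokenizerModular: lambda x: x.has_element(CoordTokenizers.CTT()),
#                 },
#             ),
#         ),
#     )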


@cache
def get_all_tokenizers() -> list[MazeTokenizerModular]:
    """Computes a complete list of all valid tokenizers.

    Warning: This is an expensive function.
    """
    return list(
        all_instances(
            MazeTokenizerModular,
            validation_funcs=MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS,
        ),
    )


EVERY_TEST_TOKENIZERS: list[MazeTokenizerModular] = [
    MazeTokenizerModular(),
    MazeTokenizerModular(
        prompt_sequencer=PromptSequencers.AOTP(coord_tokenizer=CoordTokenizers.CTT()),
    ),
    # TODO: add more here as specific tokenizers become canonical and frequently used
]


@cache
def all_tokenizers_set() -> set[MazeTokenizerModular]:
    """Casts `get_all_tokenizers()` to a set."""
    return set(get_all_tokenizers())
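
# A sketch of one use of `all_tokenizers_set()`: once `get_all_tokenizers()` has
# already been computed, the cached set allows fast membership checks, e.g.
#
#     assert MazeTokenizerModular() in all_tokenizers_set()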


@cache
def _all_tokenizers_except_every_test_tokenizers() -> list[MazeTokenizerModular]:
    """Returns all members of `get_all_tokenizers()` except those in `EVERY_TEST_TOKENIZERS`."""
    return list(all_tokenizers_set().difference(EVERY_TEST_TOKENIZERS))


def sample_all_tokenizers(n: int) -> list[MazeTokenizerModular]:
    """Samples `n` tokenizers from `get_all_tokenizers()`."""
    return random.sample(get_all_tokenizers(), n)


def sample_tokenizers_for_test(n: int | None) -> list[MazeTokenizerModular]:
    """Returns a sample of size `n` of unique elements from `get_all_tokenizers()`,

    always including every element in `EVERY_TEST_TOKENIZERS`.
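
    A sketch of intended use in a fuzzing test (the test name and body are illustrative):

    ```python
    import pytest

    @pytest.mark.parametrize("tokenizer", sample_tokenizers_for_test(100))
    def test_tokenizer_is_known(tokenizer: MazeTokenizerModular) -> None:
        assert tokenizer in all_tokenizers_set()
    ```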

115 """ 

116 if n is None: 

117 return get_all_tokenizers() 

118 

119 if n < len(EVERY_TEST_TOKENIZERS): 

120 err_msg: str = f"`n` must be at least {len(EVERY_TEST_TOKENIZERS) = } such that the sample can contain `EVERY_TEST_TOKENIZERS`." 

121 raise ValueError( 

122 err_msg, 

123 ) 

124 sample: list[MazeTokenizerModular] = random.sample( 

125 _all_tokenizers_except_every_test_tokenizers(), 

126 n - len(EVERY_TEST_TOKENIZERS), 

127 ) 

128 sample.extend(EVERY_TEST_TOKENIZERS) 

129 return sample 


def save_hashes(
    path: Path | None = None,
    verbose: bool = False,
    parallelize: bool | int = False,
) -> Int64[np.ndarray, " tokenizers"]:
    """Computes, sorts, and saves the hashes of every member of `get_all_tokenizers()`."""
    spinner = (
        functools.partial(SpinnerContext, spinner_chars="square_dot")
        if verbose
        else NoOpContextManager
    )

    # get all tokenizers
    with spinner(initial_value="getting all tokenizers...", update_interval=2.0):
        all_tokenizers = get_all_tokenizers()

    # compute hashes
    hashes_array: Int64[np.ndarray, " tokenizers+dupes"]
    if parallelize:
        n_cpus: int = (
            parallelize if int(parallelize) > 1 else multiprocessing.cpu_count()
        )
        with spinner(  # noqa: SIM117
            initial_value=f"using {n_cpus} processes to compute {len(all_tokenizers)} tokenizer hashes...",
            update_interval=2.0,
        ):
            with multiprocessing.Pool(processes=n_cpus) as pool:
                hashes_list: list[int] = list(pool.map(hash, all_tokenizers))

        with spinner(initial_value="converting hashes to numpy array..."):
            hashes_array = np.array(hashes_list, dtype=np.int64)
    else:
        with spinner(
            initial_value=f"computing {len(all_tokenizers)} tokenizer hashes...",
        ):
            hashes_array = np.array(
                [
                    hash(obj)  # uses stable hash
                    for obj in tqdm(all_tokenizers, disable=not verbose)
                ],
                dtype=np.int64,
            )

    # make sure there are no dupes
    with spinner(initial_value="sorting and checking for hash collisions..."):
        sorted_hashes, counts = np.unique(hashes_array, return_counts=True)
        if sorted_hashes.shape[0] != hashes_array.shape[0]:
            collisions = sorted_hashes[counts > 1]
            err_msg: str = f"{hashes_array.shape[0] - sorted_hashes.shape[0]} tokenizer hash collisions: {collisions}\nReport error to the developer to increase the hash size or otherwise update the tokenizer hashing algorithm."
            raise ValueError(
                err_msg,
            )

    # save and return
    with spinner(initial_value="saving hashes...", update_interval=0.5):
        if path is None:
            path = Path(__file__).parent / "MazeTokenizerModular_hashes.npz"
        np.savez_compressed(
            path,
            hashes=sorted_hashes,
        )

    return sorted_hashes
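

# A sketch of regenerating the saved hash file (e.g. from a REPL or a maintenance
# script), parallelized across all available cores with progress output:
#
#     from maze_dataset.tokenization.all_tokenizers import save_hashes
#
#     hashes = save_hashes(verbose=True, parallelize=True)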