1"""Contains `get_all_tokenizers()` and supporting limited-use functions.
3# `get_all_tokenizers()`
4returns a comprehensive collection of all valid `MazeTokenizerModular` objects.
5This is an overwhelming majority subset of the set of all possible `MazeTokenizerModular` objects.
6Other tokenizers not contained in `get_all_tokenizers()` may be possible to construct, but they are untested and not guaranteed to work.
7This collection is in a separate module since it is expensive to compute and will grow more expensive as features are added to `MazeTokenizerModular`.
9## Use Cases
10In general, uses for this module are limited to development of the library and specific research studying many tokenization behaviors.
11- Unit testing:
12 - Tokenizers to use in unit tests are sampled from `get_all_tokenizers()`
13- Large-scale tokenizer research:
14 - Specific research training models on many tokenization behaviors can use `get_all_tokenizers()` as the maximally inclusive collection
15 - `get_all_tokenizers()` may be subsequently filtered using `MazeTokenizerModular.has_element`
16For other uses, it's likely that the computational expense can be avoided by using
17- `maze_tokenizer.get_all_tokenizer_hashes()` for membership checks
18- `utils.all_instances` for generating smaller subsets of `MazeTokenizerModular` or `_TokenizerElement` objects
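
A sketch of the unit-testing and research use cases (illustrative; assumes `MazeTokenizerModular.has_element` accepts a `_TokenizerElement` instance):

```python
from maze_dataset.tokenization import CoordTokenizers
from maze_dataset.tokenization.all_tokenizers import (
    get_all_tokenizers,
    sample_tokenizers_for_test,
)

# unit-testing use case: fuzz over a manageable random subset
test_tokenizers = sample_tokenizers_for_test(100)

# research use case: start from the full collection, then filter by tokenizer components
ctt_tokenizers = [
    mt for mt in get_all_tokenizers() if mt.has_element(CoordTokenizers.CTT())
]
```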

# `EVERY_TEST_TOKENIZERS`
A collection of the tokenizers which should always be included in unit tests when test fuzzing is used.
This collection should be expanded as specific tokenizers become canonical or popular.
"""

import functools
import multiprocessing
import random
from functools import cache
from pathlib import Path
from typing import Callable

import frozendict
import numpy as np
from jaxtyping import Int64
from muutils.spinner import NoOpContextManager, SpinnerContext
from tqdm import tqdm

from maze_dataset.tokenization import (
    CoordTokenizers,
    MazeTokenizerModular,
    PromptSequencers,
    StepTokenizers,
    _TokenizerElement,
)
from maze_dataset.utils import FiniteValued, all_instances

# Always include this as the first item in the dict `validation_funcs` whenever using `all_instances` with `MazeTokenizerModular`
MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS: frozendict.frozendict[
    type[FiniteValued], Callable[[FiniteValued], bool]
] = frozendict.frozendict(
    {
        _TokenizerElement: lambda x: x.is_valid(),
        # Currently no need for `MazeTokenizerModular.is_valid` since that method contains no special cases not already covered by `_TokenizerElement.is_valid`
        # MazeTokenizerModular: lambda x: x.is_valid(),
        StepTokenizers.StepTokenizerPermutation: lambda x: len(set(x)) == len(x)
        and x != (StepTokenizers.Distance(),),
    }
)
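
# For example (an illustrative sketch, not executed here; assumes `all_instances`
# also accepts `_TokenizerElement` subtypes such as
# `StepTokenizers.StepTokenizerPermutation` as a target type): the same defaults
# can be reused to enumerate smaller subsets than full `MazeTokenizerModular` objects.
#
#   step_permutations = all_instances(
#       StepTokenizers.StepTokenizerPermutation,
#       validation_funcs=MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS,
#   )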


@cache
def get_all_tokenizers() -> list[MazeTokenizerModular]:
    """
    Computes a complete list of all valid tokenizers.
    Warning: This is an expensive function.
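
    Because of the `@cache` decorator, only the first call is expensive; subsequent
    calls return the same cached list object (illustrative):

    ```python
    toks = get_all_tokenizers()  # slow on the first call
    toks_again = get_all_tokenizers()  # instant; returns the cached object
    assert toks is toks_again
    ```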
66 """
67 return list(
68 all_instances(
69 MazeTokenizerModular,
70 validation_funcs=MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS,
71 )
72 )


EVERY_TEST_TOKENIZERS: list[MazeTokenizerModular] = [
    MazeTokenizerModular(),
    MazeTokenizerModular(
        prompt_sequencer=PromptSequencers.AOTP(coord_tokenizer=CoordTokenizers.CTT())
    ),
    # TODO: add more here as specific tokenizers become canonical and frequently used
]


@cache
def all_tokenizers_set() -> set[MazeTokenizerModular]:
    """Casts `get_all_tokenizers()` to a set."""
    return set(get_all_tokenizers())


@cache
def _all_tokenizers_except_every_test_tokenizers() -> list[MazeTokenizerModular]:
    """Returns all tokenizers in `get_all_tokenizers()` except those in `EVERY_TEST_TOKENIZERS`."""
    return list(all_tokenizers_set().difference(EVERY_TEST_TOKENIZERS))


def sample_all_tokenizers(n: int) -> list[MazeTokenizerModular]:
    """Samples `n` tokenizers from `get_all_tokenizers()`."""
    return random.sample(get_all_tokenizers(), n)


def sample_tokenizers_for_test(n: int | None) -> list[MazeTokenizerModular]:
    """Returns a sample of size `n` of unique elements from `get_all_tokenizers()`,
    always including every element in `EVERY_TEST_TOKENIZERS`.
    If `n` is `None`, returns all of `get_all_tokenizers()`.
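
    A usage sketch (illustrative):

    ```python
    toks = sample_tokenizers_for_test(50)
    assert len(toks) == 50
    assert all(t in toks for t in EVERY_TEST_TOKENIZERS)
    ```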
104 """
105 if n is None:
106 return get_all_tokenizers()
108 if n < len(EVERY_TEST_TOKENIZERS):
109 raise ValueError(
110 f"`n` must be at least {len(EVERY_TEST_TOKENIZERS) = } such that the sample can contain `EVERY_TEST_TOKENIZERS`."
111 )
112 sample: list[MazeTokenizerModular] = random.sample(
113 _all_tokenizers_except_every_test_tokenizers(), n - len(EVERY_TEST_TOKENIZERS)
114 )
115 sample.extend(EVERY_TEST_TOKENIZERS)
116 return sample


def save_hashes(
    path: Path | None = None,
    verbose: bool = False,
    parallelize: bool | int = False,
) -> Int64[np.ndarray, " tokenizers"]:
124 """Computes, sorts, and saves the hashes of every member of `get_all_tokenizers()`."""
    spinner = (
        functools.partial(SpinnerContext, spinner_chars="square_dot")
        if verbose
        else NoOpContextManager
    )

    # get all tokenizers
    with spinner(initial_value="getting all tokenizers...", update_interval=2.0):
        all_tokenizers = get_all_tokenizers()

    # compute hashes
    if parallelize:
        n_cpus: int = (
            parallelize if int(parallelize) > 1 else multiprocessing.cpu_count()
        )
        with spinner(
            initial_value=f"using {n_cpus} processes to compute {len(all_tokenizers)} tokenizer hashes...",
            update_interval=2.0,
        ):
            with multiprocessing.Pool(processes=n_cpus) as pool:
                hashes_list: list[int] = list(pool.map(hash, all_tokenizers))

        with spinner(initial_value="converting hashes to numpy array..."):
            hashes_array: "Int64[np.ndarray, ' tokenizers+dupes']" = np.array(
                hashes_list, dtype=np.int64
            )
    else:
        with spinner(
            initial_value=f"computing {len(all_tokenizers)} tokenizer hashes..."
        ):
            hashes_array: "Int64[np.ndarray, ' tokenizers+dupes']" = np.array(
                [
                    hash(obj)  # uses stable hash
                    for obj in tqdm(all_tokenizers, disable=not verbose)
                ],
                dtype=np.int64,
            )

    # make sure there are no dupes
    with spinner(initial_value="sorting and checking for hash collisions..."):
        sorted_hashes, counts = np.unique(hashes_array, return_counts=True)
        if sorted_hashes.shape[0] != hashes_array.shape[0]:
            collisions = sorted_hashes[counts > 1]
            raise ValueError(
                f"{hashes_array.shape[0] - sorted_hashes.shape[0]} tokenizer hash collisions: {collisions}\nReport error to the developer to increase the hash size or otherwise update the tokenizer hashing algorithm."
            )

    # save and return
    with spinner(initial_value="saving hashes...", update_interval=0.5):
        if path is None:
            path = Path(__file__).parent / "MazeTokenizerModular_hashes.npz"
        np.savez_compressed(
            path,
            hashes=sorted_hashes,
        )

    return sorted_hashes