docs for maze-dataset v1.2.0
View Source on GitHub

maze_dataset.tokenization.all_tokenizers

Contains get_all_tokenizers() and supporting limited-use functions.

get_all_tokenizers()

Returns a comprehensive collection of all valid MazeTokenizerModular objects, covering the overwhelming majority of all possible MazeTokenizerModular objects. Tokenizers not contained in get_all_tokenizers() may still be constructible, but they are untested and not guaranteed to work. This collection lives in a separate module because it is expensive to compute and will grow more expensive as features are added to MazeTokenizerModular.

Use Cases

In general, this module is only needed for development of the library and for research that studies many tokenization behaviors at once.

  • Unit testing:
    • Tokenizers to use in unit tests are sampled from get_all_tokenizers()
  • Large-scale tokenizer research:
    • Research training models on many tokenization behaviors can use get_all_tokenizers() as the maximally inclusive collection
    • get_all_tokenizers() may be subsequently filtered using MazeTokenizerModular.has_element (see the sketch after this list)

For other uses, the computational expense can likely be avoided by using:

  • maze_tokenizer.get_all_tokenizer_hashes() for membership checks
  • utils.all_instances for generating smaller subsets of MazeTokenizerModular or _TokenizerElement objects
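
A minimal sketch of the filtering use case above. The exact argument form accepted by MazeTokenizerModular.has_element is an assumption here (a _TokenizerElement instance); check its signature in maze_tokenizer before relying on it.

from maze_dataset.tokenization import CoordTokenizers, MazeTokenizerModular
from maze_dataset.tokenization.all_tokenizers import get_all_tokenizers

# keep only the tokenizers that use the CTT coordinate tokenizer;
# get_all_tokenizers() is expensive on the first call but cached afterwards
ctt_tokenizers: list[MazeTokenizerModular] = [
	tokenizer
	for tokenizer in get_all_tokenizers()
	if tokenizer.has_element(CoordTokenizers.CTT())
]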

EVERY_TEST_TOKENIZERS

A collection of tokenizers that should always be included in unit tests when test fuzzing is used. This collection should be expanded as specific tokenizers become canonical or popular.


  1"""Contains `get_all_tokenizers()` and supporting limited-use functions.
  2
  3# `get_all_tokenizers()`
  4returns a comprehensive collection of all valid `MazeTokenizerModular` objects.
  5This is an overwhelming majority subset of the set of all possible `MazeTokenizerModular` objects.
  6Other tokenizers not contained in `get_all_tokenizers()` may be possible to construct, but they are untested and not guaranteed to work.
  7This collection is in a separate module since it is expensive to compute and will grow more expensive as features are added to `MazeTokenizerModular`.
  8
  9## Use Cases
 10In general, uses for this module are limited to development of the library and specific research studying many tokenization behaviors.
 11- Unit testing:
 12  - Tokenizers to use in unit tests are sampled from `get_all_tokenizers()`
 13- Large-scale tokenizer research:
 14  - Specific research training models on many tokenization behaviors can use `get_all_tokenizers()` as the maximally inclusive collection
 15  - `get_all_tokenizers()` may be subsequently filtered using `MazeTokenizerModular.has_element`
 16For other uses, it's likely that the computational expense can be avoided by using
 17- `maze_tokenizer.get_all_tokenizer_hashes()` for membership checks
 18- `utils.all_instances` for generating smaller subsets of `MazeTokenizerModular` or `_TokenizerElement` objects
 19
 20# `EVERY_TEST_TOKENIZERS`
 21A collection of the tokenizers which should always be included in unit tests when test fuzzing is used.
 22This collection should be expanded as specific tokenizers become canonical or popular.
 23"""
 24
 25import functools
 26import multiprocessing
 27import random
 28from functools import cache
 29from pathlib import Path
 30from typing import Callable
 31
 32import frozendict
 33import numpy as np
 34from jaxtyping import Int64
 35from muutils.spinner import NoOpContextManager, SpinnerContext
 36from tqdm import tqdm
 37
 38from maze_dataset.tokenization import (
 39	CoordTokenizers,
 40	MazeTokenizerModular,
 41	PromptSequencers,
 42	StepTokenizers,
 43	_TokenizerElement,
 44)
 45from maze_dataset.utils import FiniteValued, all_instances
 46
 47# Always include this as the first item in the dict `validation_funcs` whenever using `all_instances` with `MazeTokenizerModular`
 48# TYPING: error: Type variable "maze_dataset.utils.FiniteValued" is unbound  [valid-type]
 49#   note: (Hint: Use "Generic[FiniteValued]" or "Protocol[FiniteValued]" base class to bind "FiniteValued" inside a class)
 50#   note: (Hint: Use "FiniteValued" in function signature to bind "FiniteValued" inside a function)
 51MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS: frozendict.frozendict[
 52	type[FiniteValued],
 53	Callable[[FiniteValued], bool],
 54] = frozendict.frozendict(
 55	{
 56		# TYPING: Item "bool" of the upper bound "bool | IsDataclass | Enum" of type variable "FiniteValued" has no attribute "is_valid"  [union-attr]
 57		_TokenizerElement: lambda x: x.is_valid(),
 58		# Currently no need for `MazeTokenizerModular.is_valid` since that method contains no special cases not already covered by `_TokenizerElement.is_valid`
 59		# MazeTokenizerModular: lambda x: x.is_valid(),
 60		# TYPING: error: No overload variant of "set" matches argument type "FiniteValued"  [call-overload]
 61		#   note: Possible overload variants:
 62		#   note:     def [_T] set(self) -> set[_T]
 63		#   note:     def [_T] set(self, Iterable[_T], /) -> set[_T]
 64		# TYPING: error: Argument 1 to "len" has incompatible type "FiniteValued"; expected "Sized"  [arg-type]
 65		StepTokenizers.StepTokenizerPermutation: lambda x: len(set(x)) == len(x)
 66		and x != (StepTokenizers.Distance(),),
 67	},
 68)
 69
 70
 71@cache
 72def get_all_tokenizers() -> list[MazeTokenizerModular]:
 73	"""Computes a complete list of all valid tokenizers.
 74
 75	Warning: This is an expensive function.
 76	"""
 77	return list(
 78		all_instances(
 79			MazeTokenizerModular,
 80			validation_funcs=MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS,
 81		),
 82	)
 83
 84
 85EVERY_TEST_TOKENIZERS: list[MazeTokenizerModular] = [
 86	MazeTokenizerModular(),
 87	MazeTokenizerModular(
 88		prompt_sequencer=PromptSequencers.AOTP(coord_tokenizer=CoordTokenizers.CTT()),
 89	),
 90	# TODO: add more here as specific tokenizers become canonical and frequently used
 91]
 92
 93
 94@cache
 95def all_tokenizers_set() -> set[MazeTokenizerModular]:
 96	"""Casts `get_all_tokenizers()` to a set."""
 97	return set(get_all_tokenizers())
 98
 99
100@cache
101def _all_tokenizers_except_every_test_tokenizers() -> list[MazeTokenizerModular]:
102	"""Returns all tokenizers from `get_all_tokenizers()` except those in `EVERY_TEST_TOKENIZERS`."""
103	return list(all_tokenizers_set().difference(EVERY_TEST_TOKENIZERS))
104
105
106def sample_all_tokenizers(n: int) -> list[MazeTokenizerModular]:
107	"""Samples `n` tokenizers from `get_all_tokenizers()`."""
108	return random.sample(get_all_tokenizers(), n)
109
110
111def sample_tokenizers_for_test(n: int | None) -> list[MazeTokenizerModular]:
112	"""Returns a sample of size `n` of unique elements from `get_all_tokenizers()`,
113
114	always including every element in `EVERY_TEST_TOKENIZERS`.
115	"""
116	if n is None:
117		return get_all_tokenizers()
118
119	if n < len(EVERY_TEST_TOKENIZERS):
120		err_msg: str = f"`n` must be at least {len(EVERY_TEST_TOKENIZERS) = } such that the sample can contain `EVERY_TEST_TOKENIZERS`."
121		raise ValueError(
122			err_msg,
123		)
124	sample: list[MazeTokenizerModular] = random.sample(
125		_all_tokenizers_except_every_test_tokenizers(),
126		n - len(EVERY_TEST_TOKENIZERS),
127	)
128	sample.extend(EVERY_TEST_TOKENIZERS)
129	return sample
130
131
132def save_hashes(
133	path: Path | None = None,
134	verbose: bool = False,
135	parallelize: bool | int = False,
136) -> Int64[np.ndarray, " tokenizers"]:
137	"""Computes, sorts, and saves the hashes of every member of `get_all_tokenizers()`."""
138	spinner = (
139		functools.partial(SpinnerContext, spinner_chars="square_dot")
140		if verbose
141		else NoOpContextManager
142	)
143
144	# get all tokenizers
145	with spinner(initial_value="getting all tokenizers...", update_interval=2.0):
146		all_tokenizers = get_all_tokenizers()
147
148	# compute hashes
149	hashes_array: Int64[np.ndarray, " tokenizers+dupes"]
150	if parallelize:
151		n_cpus: int = (
152			parallelize if int(parallelize) > 1 else multiprocessing.cpu_count()
153		)
154		with spinner(  # noqa: SIM117
155			initial_value=f"using {n_cpus} processes to compute {len(all_tokenizers)} tokenizer hashes...",
156			update_interval=2.0,
157		):
158			with multiprocessing.Pool(processes=n_cpus) as pool:
159				hashes_list: list[int] = list(pool.map(hash, all_tokenizers))
160
161		with spinner(initial_value="converting hashes to numpy array..."):
162			hashes_array = np.array(hashes_list, dtype=np.int64)
163	else:
164		with spinner(
165			initial_value=f"computing {len(all_tokenizers)} tokenizer hashes...",
166		):
167			hashes_array = np.array(
168				[
169					hash(obj)  # uses stable hash
170					for obj in tqdm(all_tokenizers, disable=not verbose)
171				],
172				dtype=np.int64,
173			)
174
175	# make sure there are no dupes
176	with spinner(initial_value="sorting and checking for hash collisions..."):
177		sorted_hashes, counts = np.unique(hashes_array, return_counts=True)
178		if sorted_hashes.shape[0] != hashes_array.shape[0]:
179			collisions = sorted_hashes[counts > 1]
180			err_msg: str = f"{hashes_array.shape[0] - sorted_hashes.shape[0]} tokenizer hash collisions: {collisions}\nReport error to the developer to increase the hash size or otherwise update the tokenizer hashing algorithm."
181			raise ValueError(
182				err_msg,
183			)
184
185	# save and return
186	with spinner(initial_value="saving hashes...", update_interval=0.5):
187		if path is None:
188			path = Path(__file__).parent / "MazeTokenizerModular_hashes.npz"
189		np.savez_compressed(
190			path,
191			hashes=sorted_hashes,
192		)
193
194	return sorted_hashes

MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS: frozendict.frozendict[type[FiniteValued], Callable[[FiniteValued], bool]] = frozendict.frozendict({_TokenizerElement: <function <lambda>>, StepTokenizers.StepTokenizerPermutation: <function <lambda>>})

The default validation functions for _TokenizerElement and StepTokenizers.StepTokenizerPermutation objects. Always include this as the first item in the validation_funcs dict whenever using all_instances with MazeTokenizerModular.
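
The sketch below mirrors what get_all_tokenizers() does internally (see the source above) and shows the intended pattern when enumerating tokenizers yourself; it is not additional API.

from maze_dataset.tokenization import MazeTokenizerModular
from maze_dataset.tokenization.all_tokenizers import (
	MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS,
)
from maze_dataset.utils import all_instances

# enumerate every MazeTokenizerModular that passes the default validation funcs;
# this is exactly the computation that get_all_tokenizers() caches
all_mtm = list(
	all_instances(
		MazeTokenizerModular,
		validation_funcs=MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS,
	),
)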
@cache
def get_all_tokenizers() -> list[maze_dataset.tokenization.MazeTokenizerModular]:

Computes a complete list of all valid tokenizers.

Warning: This is an expensive function.
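
A minimal usage sketch (not from the library docs): because the function is wrapped in functools.cache, the expensive enumeration runs once per process and repeated calls return the same cached list object.

from maze_dataset.tokenization.all_tokenizers import get_all_tokenizers

tokenizers = get_all_tokenizers()  # slow on the first call in a process
print(f"{len(tokenizers)} tokenizers enumerated")
assert get_all_tokenizers() is tokenizers  # cached: the same list object is returned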

EVERY_TEST_TOKENIZERS: list[maze_dataset.tokenization.MazeTokenizerModular] = [MazeTokenizerModular(prompt_sequencer=PromptSequencers.AOTP(coord_tokenizer=CoordTokenizers.UT(), adj_list_tokenizer=AdjListTokenizers.AdjListCoord(pre=False, post=True, shuffle_d0=True, edge_grouping=EdgeGroupings.Ungrouped(connection_token_ordinal=1), edge_subset=EdgeSubsets.ConnectionEdges(walls=False), edge_permuter=EdgePermuters.RandomCoords()), target_tokenizer=TargetTokenizers.Unlabeled(post=False), path_tokenizer=PathTokenizers.StepSequence(step_size=StepSizes.Singles(), step_tokenizers=(StepTokenizers.Coord(),), pre=False, intra=False, post=False))), MazeTokenizerModular(prompt_sequencer=PromptSequencers.AOTP(coord_tokenizer=CoordTokenizers.CTT(pre=True, intra=True, post=True), adj_list_tokenizer=AdjListTokenizers.AdjListCoord(pre=False, post=True, shuffle_d0=True, edge_grouping=EdgeGroupings.Ungrouped(connection_token_ordinal=1), edge_subset=EdgeSubsets.ConnectionEdges(walls=False), edge_permuter=EdgePermuters.RandomCoords()), target_tokenizer=TargetTokenizers.Unlabeled(post=False), path_tokenizer=PathTokenizers.StepSequence(step_size=StepSizes.Singles(), step_tokenizers=(StepTokenizers.Coord(),), pre=False, intra=False, post=False)))]
@cache
def all_tokenizers_set() -> set[maze_dataset.tokenization.MazeTokenizerModular]:

Casts get_all_tokenizers() to a set.
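
A small sketch of the intended use, assuming (as the module's own EVERY_TEST_TOKENIZERS suggests) that the default MazeTokenizerModular() is among the valid tokenizers:

from maze_dataset.tokenization import MazeTokenizerModular
from maze_dataset.tokenization.all_tokenizers import all_tokenizers_set

# set membership is much faster than scanning the list from get_all_tokenizers()
assert MazeTokenizerModular() in all_tokenizers_set()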

def sample_all_tokenizers(n: int) -> list[maze_dataset.tokenization.MazeTokenizerModular]:

Samples n tokenizers from get_all_tokenizers().
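
For example (a trivial sketch; the sample size of 5 is arbitrary):

from maze_dataset.tokenization.all_tokenizers import sample_all_tokenizers

subset = sample_all_tokenizers(5)  # 5 distinct tokenizers, drawn without replacement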

def sample_tokenizers_for_test(n: int | None) -> list[maze_dataset.tokenization.MazeTokenizerModular]:

Returns a sample of size n of unique elements from get_all_tokenizers(), always including every element in EVERY_TEST_TOKENIZERS.
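
A hedged sketch of how this might be used in a fuzz-style test suite; pytest and the sample size of 100 are assumptions for illustration, not part of the library:

import pytest

from maze_dataset.tokenization import MazeTokenizerModular
from maze_dataset.tokenization.all_tokenizers import sample_tokenizers_for_test

# the sample always contains EVERY_TEST_TOKENIZERS plus randomly chosen extras
@pytest.mark.parametrize("tokenizer", sample_tokenizers_for_test(100))
def test_tokenizer_is_valid(tokenizer: MazeTokenizerModular) -> None:
	# placeholder check; a real test would tokenize a maze and inspect the tokens
	assert tokenizer.is_valid()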

def save_hashes(path: pathlib.Path | None = None, verbose: bool = False, parallelize: bool | int = False) -> jaxtyping.Int64[ndarray, 'tokenizers']:

Computes, sorts, and saves the hashes of every member of get_all_tokenizers().
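
A usage sketch, assuming you want the hash file in the current directory rather than at the package's default path; the "hashes" key matches the np.savez_compressed call in the source above.

from pathlib import Path

import numpy as np

from maze_dataset.tokenization.all_tokenizers import save_hashes

out_path = Path("MazeTokenizerModular_hashes.npz")  # hypothetical output location
sorted_hashes = save_hashes(path=out_path, verbose=True, parallelize=True)

# the array is saved sorted, so membership checks can use np.searchsorted
loaded = np.load(out_path)["hashes"]
assert np.array_equal(loaded, sorted_hashes)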