Coverage for maze_dataset\tokenization\maze_tokenizer.py: 78%
733 statements
1"""turning a maze into text: `MazeTokenizerModular` and the legacy `TokenizationMode` enum and `MazeTokenizer` class"""
3import abc
4import base64
5import hashlib
6import random
7import warnings
8from enum import Enum
9from functools import cached_property
10from pathlib import Path
11from typing import (
12 Any,
13 Callable,
14 Iterable,
15 Literal,
16 Mapping,
17 Sequence,
18 TypedDict,
19 TypeVar,
20 overload,
21)
23import numpy as np
24from jaxtyping import Bool, Int, Int64
25from muutils.json_serialize import (
26 SerializableDataclass,
27 serializable_dataclass,
28 serializable_field,
29)
30from muutils.json_serialize.util import _FORMAT_KEY
31from muutils.kappa import Kappa
32from muutils.misc import empty_sequence_if_attr_false, flatten
33from muutils.misc.sequence import WhenMissing
34from zanj.loading import load_item_recursive
36# from maze_dataset import SolvedMaze
37from maze_dataset.constants import (
38 SPECIAL_TOKENS,
39 VOCAB,
40 VOCAB_LIST,
41 VOCAB_TOKEN_TO_INDEX,
42 ConnectionArray,
43 ConnectionList,
44 Coord,
45 CoordTup,
46)
47from maze_dataset.generation import numpy_rng
48from maze_dataset.maze.lattice_maze import LatticeMaze, SolvedMaze
49from maze_dataset.token_utils import (
50 TokenizerPendingDeprecationWarning,
51 _coord_to_strings_indexed,
52 _coord_to_strings_UT,
53 connection_list_to_adj_list,
54 coords_to_strings,
55 get_cardinal_direction,
56 get_relative_direction,
57 is_connection,
58 strings_to_coords,
59 tokens_between,
60)
61from maze_dataset.utils import corner_first_ndindex, lattice_connection_array


class TokenError(ValueError):
    """error for tokenization"""

    pass


class TokenizationMode(Enum):
    """legacy tokenization modes

    > [!CAUTION]
    > Legacy mode of tokenization. Will still be around in future releases, but is no longer recommended for use.
    > Use `MazeTokenizerModular` instead.

    # Abbreviations:
    - `AOTP`: Adjacency list, Origin, Target, Path
    - `UT`: Unique Token (for each coordinate)
    - `CTT`: Coordinate Tuple Tokens (each coordinate is tokenized as a tuple of integers)

    # Modes:
    - `AOTP_UT_rasterized`: the "classic" mode: assigning tokens to each coordinate is done via rasterization
        example: for a 3x3 maze, token order is `(0,0), (0,1), (0,2), (1,0), (1,1), (1,2), (2,0), (2,1), (2,2)`
    - `AOTP_UT_uniform`: new mode, where a 3x3 tokenization scheme and a 5x5 tokenization scheme are compatible
        uses the `corner_first_ndindex` function to order the tokens
    - `AOTP_CTT_indexed`: each coordinate is a tuple of integers
    """

    AOTP_UT_rasterized = "AOTP_UT_rasterized"
    AOTP_UT_uniform = "AOTP_UT_uniform"
    AOTP_CTT_indexed = "AOTP_CTT_indexed"

    def to_legacy_tokenizer(self, max_grid_size: int | None = None) -> "MazeTokenizer":
        return MazeTokenizer(tokenization_mode=self, max_grid_size=max_grid_size)


_NDINDEX_FUNC_MAP: dict[
    TokenizationMode, Callable[[int], Iterable[tuple[int, ...]]]
] = {
    TokenizationMode.AOTP_UT_rasterized: lambda n: list(np.ndindex(n, n)),
    TokenizationMode.AOTP_UT_uniform: lambda n: corner_first_ndindex(n, 2),
}
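
# Example (illustrative sketch, not part of the original source): the two UT
# orderings differ only in how coordinates are enumerated. Rasterized order is
# row-major via `np.ndindex`; `corner_first_ndindex` is designed so that
# orderings for different grid sizes stay compatible (its exact output is an
# assumption here).
#
#   >>> _NDINDEX_FUNC_MAP[TokenizationMode.AOTP_UT_rasterized](2)
#   [(0, 0), (0, 1), (1, 0), (1, 1)]
#   >>> TokenizationMode.AOTP_UT_uniform.to_legacy_tokenizer(max_grid_size=5).name
#   'maze_tokenizer-AOTP_UT_uniform-g5'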


def is_UT(tokenization_mode: TokenizationMode) -> bool:
    return tokenization_mode in (
        TokenizationMode.AOTP_UT_rasterized,
        TokenizationMode.AOTP_UT_uniform,
    )


def get_tokens_up_to_path_start(
    tokens: list[str],
    include_start_coord: bool = True,
    tokenization_mode: TokenizationMode = TokenizationMode.AOTP_UT_uniform,
) -> list[str]:
    warnings.warn(
        "`maze_tokenizer.get_tokens_up_to_path_start` will be deprecated for a `MazeTokenizerModular`-compatible function in a future release.",
        TokenizerPendingDeprecationWarning,
    )
    path_start_idx: int = tokens.index(SPECIAL_TOKENS.PATH_START) + 1
    if include_start_coord:
        if is_UT(tokenization_mode):
            return tokens[: path_start_idx + 1]
        elif tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
            return tokens[: path_start_idx + 5]
        else:
            raise ValueError(f"Invalid tokenization mode: {tokenization_mode}")
    else:
        return tokens[:path_start_idx]
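
# Example (sketch, assuming `SPECIAL_TOKENS.PATH_START` is the literal token
# "<PATH_START>"): in a UT mode, one extra token after <PATH_START> is the
# start coord; in CTT mode, five tokens are kept instead. A pending-deprecation
# warning is emitted on each call.
#
#   >>> toks = ["<PATH_START>", "(0,0)", "(0,1)"]
#   >>> get_tokens_up_to_path_start(toks)
#   ['<PATH_START>', '(0,0)']
#   >>> get_tokens_up_to_path_start(toks, include_start_coord=False)
#   ['<PATH_START>']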


_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE: list[str] = [
    "name",
    "max_grid_size",
    "token_arr",
    "tokenizer_map",
    "vocab_size",
    "padding_token_index",
]


@serializable_dataclass(
    properties_to_serialize=_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE, kw_only=True
)
class MazeTokenizer(SerializableDataclass):
    """LEGACY tokenizer for mazes

    > [!CAUTION]
    > `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended
    > for use, but will remain for compatibility with existing code.

    # Parameters:
    - `tokenization_mode: TokenizationMode`
        mode of tokenization. required.
    - `max_grid_size: int | None`
        maximum grid size. required for actually turning text tokens into numerical tokens, but not for moving between coordinates/mazes and text

    # Properties
    - `name: str`
        auto-generated name of the tokenizer from mode and size

    ## Conditional Properties

    - `node_strings_map: Mapping[CoordTup, str]`
        map from node to string. This returns a `muutils.kappa.Kappa` object which you can use like a dictionary. Returns `None` in the `UT` modes, where each coordinate is already a single unique token.

    the following all return `None` if `max_grid_size` is `None`.
    Prepend `_` to the name to get a guaranteed type, and cause an exception if `max_grid_size` is `None`:

    - `token_arr: list[str]`
        list of tokens, in order of their indices in the vocabulary
    - `tokenizer_map: Mapping[str, int]`
        map from token to index
    - `vocab_size: int`
        size of the vocabulary
    - `padding_token_index: int`
        index of the padding token

    # Methods
    - `coords_to_strings(coords: list[CoordTup]) -> list[str]`
        convert a list of coordinates to a list of tokens. Optionally raise on, skip, or include non-coordinates.
    - `strings_to_coords(strings: list[str]) -> list[CoordTup]`
        convert a list of tokens to a list of coordinates. Optionally raise on, skip, or include non-coordinates.
    """

    # parameters
    # ============================================================

    tokenization_mode: TokenizationMode = serializable_field(
        default=TokenizationMode.AOTP_UT_uniform,
        serialization_fn=lambda x: x.value,
        loading_fn=lambda x: TokenizationMode[x["tokenization_mode"]],
    )

    max_grid_size: int | None = serializable_field(default=None)

    # properties
    # ============================================================

    @property
    def name(self) -> str:
        max_grid_size_str: str = (
            f"-g{self.max_grid_size}" if self.max_grid_size is not None else ""
        )
        return f"maze_tokenizer-{self.tokenization_mode.value}{max_grid_size_str}"

    @cached_property
    def _node_strings_map(self) -> Mapping[CoordTup, list[str]]:
        """map a coordinate to a token"""
        if self.tokenization_mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            return Kappa(_coord_to_strings_UT)
        elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
            return Kappa(_coord_to_strings_indexed)
        else:
            raise ValueError(
                f"Invalid tokenization mode {self.tokenization_mode}",
                f"expected one of {TokenizationMode.__members__}",
            )

    @cached_property
    def node_strings_map(self) -> Mapping[CoordTup, list[str]] | None:
        """map a coordinate to a token; `None` in `UT` modes"""
        if self.tokenization_mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            return None
        else:
            return self._node_strings_map

    # conditional properties (on max_grid_size existing)
    # ------------------------------------------------------------

    @cached_property
    def _token_arr(self) -> list[str]:
        """map from index to token"""
        if self.max_grid_size is None:
            raise ValueError(
                f"max_grid_size must be specified to use token_arr property: {self.max_grid_size = }"
            )

        output: list[str] = list(SPECIAL_TOKENS.values())

        if self.tokenization_mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            output.extend(
                [
                    self._node_strings_map[coord][0]
                    for coord in _NDINDEX_FUNC_MAP[self.tokenization_mode](
                        self.max_grid_size
                    )
                ]
            )
        elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
            # TODO: this is hacky, but we don't want to modify the original SPECIAL_TOKENS since that will break old models
            output.extend(
                [
                    "(",
                    ",",
                    ")",  # new special chars
                    *map(str, range(self.max_grid_size)),  # numbers
                ]
            )
        else:
            raise ValueError(
                f"Invalid tokenization mode {self.tokenization_mode}",
                f"expected one of {TokenizationMode.__members__}",
            )

        return output

    @cached_property
    def token_arr(self) -> list[str] | None:
        if self.max_grid_size is None:
            return None
        return self._token_arr

    @cached_property
    def _tokenizer_map(self) -> dict[str, int]:
        """map from token to index"""
        return {token: i for i, token in enumerate(self._token_arr)}

    @cached_property
    def tokenizer_map(self) -> dict[str, int] | None:
        if self.max_grid_size is None:
            return None
        return self._tokenizer_map

    @property
    def _vocab_size(self) -> int:
        return len(self._token_arr)

    @property
    def vocab_size(self) -> int | None:
        if self.max_grid_size is None:
            return None
        return self._vocab_size

    @property
    def _n_tokens(self) -> int:
        # TODO: deprecate
        return self._vocab_size

    @property
    def n_tokens(self) -> int | None:
        if self.max_grid_size is None:
            return None
        return self._n_tokens

    @cached_property
    def _padding_token_index(self) -> int:
        return self.tokenizer_map[SPECIAL_TOKENS.PADDING]

    @cached_property
    def padding_token_index(self) -> int | None:
        if self.max_grid_size is None:
            return None
        return self._padding_token_index

    # conversion functions
    # ============================================================

    @overload
    def coords_to_strings(
        self,
        coords: list[str | CoordTup],
        when_noncoord: Literal["include", "skip"] = "skip",
    ) -> list[str]: ...
    @overload
    def coords_to_strings(
        self,
        coords: list[CoordTup],
        when_noncoord: Literal["error"] = "error",
    ) -> list[str]: ...
    def coords_to_strings(
        self,
        coords: list[CoordTup],
        when_noncoord: WhenMissing = "skip",
    ) -> list[str]:
        if self.tokenization_mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            return coords_to_strings(
                coords=coords,
                coord_to_strings_func=_coord_to_strings_UT,
                when_noncoord=when_noncoord,
            )
        elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
            return coords_to_strings(
                coords=coords,
                coord_to_strings_func=_coord_to_strings_indexed,
                when_noncoord=when_noncoord,
            )
        else:
            raise ValueError(
                f"Invalid tokenization mode {self.tokenization_mode}",
                f"expected one of {TokenizationMode.__members__}",
            )

    @overload
    def strings_to_coords(
        cls,
        text: str | list[str],
        when_noncoord: Literal["skip"] = "skip",
    ) -> list[CoordTup]: ...
    @overload
    def strings_to_coords(
        cls,
        text: str | list[str],
        when_noncoord: Literal["error"] = "error",
    ) -> list[CoordTup]: ...
    @overload
    def strings_to_coords(
        cls,
        text: str | list[str],
        when_noncoord: Literal["include"] = "include",
    ) -> list[str | CoordTup]: ...
    @classmethod
    def strings_to_coords(
        cls,
        text: str | list[str],
        when_noncoord: WhenMissing = "skip",
    ) -> list[str | CoordTup]:
        return strings_to_coords(text=text, when_noncoord=when_noncoord)
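
    # Example (sketch): round-tripping between coords and UT tokens. The token
    # string forms follow `_coord_to_strings_UT` ("(1,2)"-style) and the literal
    # "<PATH_START>" special token; both are assumptions about constants defined
    # elsewhere in the package.
    #
    #   >>> MazeTokenizer.strings_to_coords(["<PATH_START>", "(1,2)"], when_noncoord="skip")
    #   [(1, 2)]
    #   >>> tok = MazeTokenizer(tokenization_mode=TokenizationMode.AOTP_UT_uniform)
    #   >>> tok.coords_to_strings([(1, 2)])
    #   ['(1,2)']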

    def encode(self, text: str | list[str]) -> list[int]:
        """encode a string or list of strings into a list of tokens"""
        try:
            if isinstance(text, str):
                text = text.split()
            return [self.tokenizer_map[token] for token in text]
        except KeyError as e:
            raise TokenError(
                f"Token {e} not found",
                f"in vocabulary of {self}:",
                f"{self.token_arr}",
            ) from e

    def decode(
        self, tokens: Sequence[int], joined_tokens: bool = False
    ) -> list[str] | str:
        """decode a list of tokens into a string or list of strings"""
        try:
            output: list[str] = [self.token_arr[token] for token in tokens]
        except IndexError as e:
            raise TokenError(
                f"Token index '{e}' not found in vocabulary of length {self.vocab_size}"
            ) from e
        if joined_tokens:
            return " ".join(output)
        else:
            return output
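
    # Example (sketch): `encode`/`decode` round trip. `max_grid_size` is
    # required, since the vocabulary is only materialized once the grid size is
    # known; the exact integer ids depend on `SPECIAL_TOKENS` and the
    # coordinate ordering above.
    #
    #   >>> tok = MazeTokenizer(
    #   ...     tokenization_mode=TokenizationMode.AOTP_UT_uniform, max_grid_size=3
    #   ... )
    #   >>> ids = tok.encode("(0,0) (0,1)")
    #   >>> tok.decode(ids, joined_tokens=True)
    #   '(0,0) (0,1)'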

    # UT-only coordinate stuff
    # ============================================================

    @cached_property
    def coordinate_tokens_coords(self) -> dict[CoordTup, int]:
        if not self.is_UT():
            raise ValueError(
                f"coordinate_tokens_coords is only valid for UT tokenization modes, got {self.tokenization_mode = }"
            )
        if self.max_grid_size is None:
            raise ValueError(
                f"max_grid_size must be specified to use coordinate_tokens: {self.max_grid_size = }"
            )

        raw_converted: list[CoordTup | str] = self.strings_to_coords(
            self.token_arr, when_noncoord="include"
        )

        # filter out non-coordinates
        return {
            coord: i
            for i, coord in enumerate(raw_converted)
            if not isinstance(coord, str)
        }

    @cached_property
    def coordinate_tokens_ids(self) -> dict[str, int]:
        # checks performed in call to `coordinate_tokens_coords`
        output: dict[str, int] = dict()

        for coord, index in self.coordinate_tokens_coords.items():
            _for_key: list[str] = self.coords_to_strings([coord])
            assert len(_for_key) == 1
            output[_for_key[0]] = index

        return output

    # other
    # ============================================================

    def summary(self) -> dict:
        """returns a summary of the tokenization mode"""
        return {
            "tokenization_mode": self.tokenization_mode.value,
            "max_grid_size": self.max_grid_size,
            "vocab_size": self.vocab_size,
        }

    def is_AOTP(self) -> bool:
        """returns true if the tokenization mode is Adjacency list, Origin, Target, Path"""
        return self.tokenization_mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
            TokenizationMode.AOTP_CTT_indexed,
        )

    def is_UT(self) -> bool:
        return is_UT(self.tokenization_mode)

    def clear_cache(self):
        """clears all cached properties"""
        # delete the properties only if they exist
        for name, prop in self.__class__.__dict__.items():
            if isinstance(prop, cached_property):
                # if the property exists, delete it
                try:
                    delattr(self, name)
                except AttributeError:
                    pass


@serializable_dataclass(frozen=True, kw_only=True)
class _TokenizerElement(SerializableDataclass, abc.ABC):
    """Superclass for tokenizer elements.
    Subclasses contain modular functionality for maze tokenization.

    # Development
    > [!TIP]
    > Due to the functionality of `get_all_tokenizers()`, `_TokenizerElement` subclasses
    > may only contain fields of type `utils.FiniteValued`.
    > Implementing a subclass with an `int`- or `float`-typed field, for example, is not supported.
    > In the event that adding such fields is deemed necessary, `get_all_tokenizers()` must be updated.
    """

    @staticmethod
    def _stringify(k: str, v: Any):
        if isinstance(v, bool):
            return f"{k}={str(v)[0]}"
        if isinstance(v, _TokenizerElement):
            return v.name
        if isinstance(v, tuple):
            return f"{k}={''.join(['(', *[str(x) + ', ' for x in v], ')'])}"
        else:
            return f"{k}={v}"

    @property
    def name(self) -> str:
        members_str: str = ", ".join(
            [self._stringify(k, v) for k, v in self.__dict__.items() if k != "_type_"]
        )
        output: str = f"{type(self).__name__}({members_str})"
        if "." in output and output.index("(") > output.index("."):
            return "".join(output.split(".")[1:])
        else:
            return output

    def __str__(self):
        return self.name

    def __init_subclass__(cls, **kwargs):
        """
        Hack: dataclass hashes don't include the class itself in the hash function inputs.
        This causes dataclasses with identical fields but different types to hash identically.
        This hack circumvents this by adding a slightly hidden field to every subclass with a value of `repr(cls)`.
        To maintain compatibility with `all_instances`, the static type of the new field can only have 1 possible value.
        So we type it as a singleton `Literal` type.
        muutils 0.6.1 doesn't support `Literal` type validation, so `assert_type=False`.
        Ignore Pylance complaining about the arg to `Literal` being an expression.
        """
        super().__init_subclass__(**kwargs)
        cls._type_ = serializable_field(
            init=True, repr=False, default=repr(cls), assert_type=False
        )
        cls.__annotations__["_type_"] = Literal[repr(cls)]  # type: ignore

    def __hash__(self):
        "Stable hash to identify unique `MazeTokenizerModular` instances. Uses `name`."
        return int.from_bytes(
            hashlib.blake2b(self.name.encode("utf-8")).digest(),
            byteorder="big",
        )

    @classmethod
    def _level_one_subclass(cls) -> type["_TokenizerElement"]:
        """Returns the immediate subclass of `_TokenizerElement` of which `cls` is an instance."""
        return (
            set(cls.__mro__).intersection(set(_TokenizerElement.__subclasses__())).pop()
        )

    def tokenizer_elements(self, deep: bool = True) -> list["_TokenizerElement"]:
        """
        Returns a list of all `_TokenizerElement` instances contained in the subtree.
        Currently only detects `_TokenizerElement` instances which are either direct attributes of another instance or
        which sit inside a `tuple` without further nesting.

        # Parameters
        - `deep: bool`: Whether to return elements nested arbitrarily deeply or just a single layer.
        """
        if not any(type(el) == tuple for el in self.__dict__.values()):  # noqa: E721
            return list(
                flatten(
                    [
                        [el] + el.tokenizer_elements()
                        for el in self.__dict__.values()
                        if isinstance(el, _TokenizerElement)
                    ]
                )
                if deep
                else filter(
                    lambda x: isinstance(x, _TokenizerElement), self.__dict__.values()
                )
            )
        else:
            non_tuple_elems: list[_TokenizerElement] = list(
                flatten(
                    [
                        [el] + el.tokenizer_elements()
                        for el in self.__dict__.values()
                        if isinstance(el, _TokenizerElement)
                    ]
                    if deep
                    else filter(
                        lambda x: isinstance(x, _TokenizerElement),
                        self.__dict__.values(),
                    )
                )
            )
            tuple_elems: list[_TokenizerElement] = list(
                flatten(
                    [
                        (
                            [
                                [tup_el] + tup_el.tokenizer_elements()
                                for tup_el in el
                                if isinstance(tup_el, _TokenizerElement)
                            ]
                            if deep
                            else filter(lambda x: isinstance(x, _TokenizerElement), el)
                        )
                        for el in self.__dict__.values()
                        if isinstance(el, tuple)
                    ]
                )
            )
            non_tuple_elems.extend(tuple_elems)
            return non_tuple_elems

    def tokenizer_element_tree(self, depth: int = 0, abstract: bool = False) -> str:
        """
        Returns a string representation of the tree of tokenizer elements contained in `self`.

        # Parameters
        - `depth: int`: Current depth in the tree. Used internally for recursion, no need to specify.
        - `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
        """
        name: str = "\t" * depth + (
            type(self).__name__
            if not abstract
            else type(self)._level_one_subclass().__name__
        )
        return (
            name
            + "\n"
            + "".join(
                el.tokenizer_element_tree(depth + 1, abstract)
                for el in self.tokenizer_elements(deep=False)
            )
        )
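
    # Example (sketch): rendering the element tree of an `AdjListCoord` built
    # with its defaults (`Ungrouped`, `ConnectionEdges`, `RandomCoords`, all
    # defined later in this module). Indentation below is illustrative; the
    # method emits tabs.
    #
    #   >>> print(AdjListTokenizers.AdjListCoord().tokenizer_element_tree())
    #   AdjListCoord
    #       Ungrouped
    #       ConnectionEdges
    #       RandomCoords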

    def tokenizer_element_dict(self) -> dict:
        """
        Returns a dictionary representation of the tree of tokenizer elements contained in `self`.
        """
        return {
            type(self).__name__: {
                key: (
                    val.tokenizer_element_dict()
                    if isinstance(val, _TokenizerElement)
                    else (
                        val
                        if not isinstance(val, tuple)
                        else [
                            (
                                el.tokenizer_element_dict()
                                if isinstance(el, _TokenizerElement)
                                else el
                            )
                            for el in val
                        ]
                    )
                )
                for key, val in self.__dict__.items()
                if key != "_type_"
            }
        }

    @classmethod
    @abc.abstractmethod
    def attribute_key(cls) -> str:
        """Returns the binding used in `MazeTokenizerModular` for that type of `_TokenizerElement`."""
        raise NotImplementedError

    def to_tokens(self, *args, **kwargs) -> list[str]:
        """Converts a maze element into a list of tokens.
        Not all `_TokenizerElement` subclasses produce tokens, so this is not an abstract method.
        Those subclasses which do produce tokens should override this method.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def is_valid(self) -> bool:
        """Returns if `self` contains data members capable of producing an overall valid `MazeTokenizerModular`.
        Some `_TokenizerElement` instances may be created which are not useful despite obeying data member type hints.
        `is_valid` allows for more precise detection of invalid `_TokenizerElement`s beyond type hinting alone.
        If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply `return True` for that subclass.

        # Types of Invalidity
        In nontrivial implementations of this method, each conditional clause should contain a comment classifying the reason for invalidity and one of the types below.
        Invalidity types, in ascending order of invalidity:
        - Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study.
          E.g., `_TokenizerElement`s which are strictly worse than some alternative.
        - Duplicate: These tokenizers have identical tokenization behavior as some other valid tokenizers.
        - Untrainable: Training functional models using these tokenizers would be (nearly) impossible.
        - Erroneous: These tokenizers might raise exceptions during use.

        # Development
        `is_valid` is implemented to always return `True` in some abstract classes where all currently possible subclass instances are valid.
        When adding new subclasses or data members, the developer should check if any such blanket statement of validity still holds and update it as necessary.

        ## Nesting
        In general, when implementing this method, there is no need to recursively call `is_valid` on nested `_TokenizerElement`s contained in the class.
        In other words, failures of `is_valid` need not bubble up to the top of the nested `_TokenizerElement` tree.
        `MazeTokenizerModular.is_valid` calls `is_valid` on each of its `_TokenizerElement`s individually, so failure at any level will be detected.

        ## Types of Invalidity
        If it's judged to be useful, the types of invalidity could be implemented with an Enum or similar rather than only living in comments.
        This could be used to create more or less stringent filters on the valid `_TokenizerElement` instances.
        """
        raise NotImplementedError


T = TypeVar("T", bound=_TokenizerElement)


def mark_as_unsupported(is_valid: Callable[[T], bool], *args) -> T:
    """mark a `_TokenizerElement` as unsupported.

    Classes marked with this decorator won't show up in `get_all_tokenizers()` and thus won't be tested.
    The classes marked in release 1.0.0 did work reliably before being marked, but they can't be instantiated since the decorator adds an abstract method.
    The decorator exists to prune the space of tokenizers returned by `all_instances` both for testing and usage.
    Previously, the space was too large, resulting in impractical runtimes.
    These decorators could be removed in future releases to expand the space of possible tokenizers.
    """

    def wrapper(cls):
        cls.is_valid = is_valid
        return cls

    return wrapper


class __TokenizerElementNamespace(abc.ABC):
    """ABC for namespaces

    # Properties
    - `key`: The binding used in `MazeTokenizerModular` for instances of the classes contained within that `__TokenizerElementNamespace`.
    """

    key: str = NotImplementedError


def _load_tokenizer_element(
    data: dict[str, Any], namespace: type[__TokenizerElementNamespace]
) -> _TokenizerElement:
    """Loads a `TokenizerElement` stored via zanj."""
    key: str = namespace.key
    format: str = data[key][_FORMAT_KEY]
    cls_name: str = format.split("(")[0]
    cls: type[_TokenizerElement] = getattr(namespace, cls_name)
    kwargs: dict[str, Any] = {
        k: load_item_recursive(data[key][k], tuple()) for k, v in data[key].items()
    }
    if _FORMAT_KEY in kwargs:
        kwargs.pop(_FORMAT_KEY)
    return cls(**kwargs)
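
# Example (sketch): serialize/load round trip for a tokenizer element. Assumes
# `SerializableDataclass.serialize()` emits a `_FORMAT_KEY` entry shaped like
# "ClassName(...)", which is what the `format.split("(")[0]` lookup above
# expects.
#
#   >>> elem = CoordTokenizers.CTT(pre=True, intra=True, post=False)
#   >>> data = {CoordTokenizers.key: elem.serialize()}
#   >>> _load_tokenizer_element(data, CoordTokenizers) == elem
#   True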


class CoordTokenizers(__TokenizerElementNamespace):
    """Namespace for `_CoordTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "coord_tokenizer"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _CoordTokenizer(_TokenizerElement, abc.ABC):
        """
        Superclass for classes which tokenize singular coords in a maze.
        """

        @abc.abstractmethod
        def to_tokens(self, coord: Coord | CoordTup) -> list[str]:
            pass

        @classmethod
        def attribute_key(cls) -> str:
            return CoordTokenizers.key

        def is_valid(self) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

    @serializable_dataclass(frozen=True, kw_only=True)
    class UT(_CoordTokenizer):
        """Unique token coordinate tokenizer."""

        def to_tokens(self, coord: Coord | CoordTup) -> list[str]:
            return ["".join(["(", str(coord[0]), ",", str(coord[1]), ")"])]

    @serializable_dataclass(frozen=True, kw_only=True)
    class CTT(_CoordTokenizer):
        """Coordinate tuple tokenizer

        # Parameters
        - `pre`: Whether all coords include an integral preceding delimiter token
        - `intra`: Whether all coords include a delimiter token between coordinates
        - `post`: Whether all coords include an integral following delimiter token
        """

        pre: bool = serializable_field(default=True)
        intra: bool = serializable_field(default=True)
        post: bool = serializable_field(default=True)

        # Implement methods
        def to_tokens(self, coord: Coord | CoordTup) -> list[str]:
            return [
                *empty_sequence_if_attr_false([VOCAB.COORD_PRE], self, "pre"),
                str(coord[0]),
                *empty_sequence_if_attr_false([VOCAB.COORD_INTRA], self, "intra"),
                str(coord[1]),
                *empty_sequence_if_attr_false([VOCAB.COORD_POST], self, "post"),
            ]
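
    # Example (sketch): the two coord tokenizers on the coord (1, 2). `UT` emits
    # one fused token; `CTT` emits one token per tuple part, with delimiters
    # controlled by `pre`/`intra`/`post` (shown symbolically via `VOCAB` members).
    #
    #   >>> CoordTokenizers.UT().to_tokens((1, 2))
    #   ['(1,2)']
    #   >>> CoordTokenizers.CTT().to_tokens((1, 2))
    #   [VOCAB.COORD_PRE, '1', VOCAB.COORD_INTRA, '2', VOCAB.COORD_POST]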


class EdgeGroupings(__TokenizerElementNamespace):
    """Namespace for `_EdgeGrouping` subclass hierarchy used by `_AdjListTokenizer`."""

    key = "edge_grouping"

    class _GroupingTokenParams(TypedDict):
        """A uniform private hyperparameter interface used by `AdjListTokenizer`."""

        connection_token_ordinal: Literal[0, 1, 2]
        intra: bool
        grouped: bool

    @serializable_dataclass(frozen=True, kw_only=True)
    class _EdgeGrouping(_TokenizerElement, abc.ABC):
        """Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called an edge grouping."""

        @classmethod
        def attribute_key(cls) -> str:
            return EdgeGroupings.key

        def is_valid(self) -> bool:
            return True

        @abc.abstractmethod
        def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
            """Divides a ConnectionArray into groups of edges.
            Shuffles/sequences within each group if applicable.
            """
            pass

        @abc.abstractmethod
        def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
            """Returns the tokenization hyperparameters necessary for an `AdjListTokenizer` to tokenize.

            These hyperparameters are not used by `_EdgeGrouping` internally.
            They are located in `_EdgeGrouping` rather than in `AdjListTokenizer`
            since the hyperparameter space is a function of the `_EdgeGrouping` subclass.
            This function resolves the `_EdgeGrouping` hyperparameter space, which is non-uniform across subclasses,
            into a uniform private interface used by `AdjListTokenizer`.
            """
            pass

    @serializable_dataclass(frozen=True, kw_only=True)
    class Ungrouped(_EdgeGrouping):
        """No grouping occurs; each edge is tokenized individually.

        # Parameters
        - `connection_token_ordinal`: At which index in the edge tokenization the connector (or wall) token appears.
          Edge tokenizations contain 3 parts: a leading coord, a connector (or wall) token, and either a second coord or cardinal direction tokenization.
        """

        connection_token_ordinal: Literal[0, 1, 2] = serializable_field(
            default=1, assert_type=False
        )

        def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
            return EdgeGroupings._GroupingTokenParams(
                connection_token_ordinal=self.connection_token_ordinal,
                intra=False,
                grouped=False,
            )

        def _group_edges(self, edges: ConnectionList) -> Sequence[ConnectionList]:
            return np.expand_dims(edges, 1)
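
    # Example (sketch): `Ungrouped._group_edges` only reshapes. An edge array of
    # shape (n_edges, 2, 2) becomes (n_edges, 1, 2, 2): each "group" is a single
    # edge, letting downstream code treat grouped and ungrouped cases uniformly.
    #
    #   >>> edges = np.array([[[0, 0], [0, 1]], [[0, 0], [1, 0]]])
    #   >>> EdgeGroupings.Ungrouped()._group_edges(edges).shape
    #   (2, 1, 2, 2)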

    @serializable_dataclass(frozen=True, kw_only=True)
    @mark_as_unsupported(lambda self_: False)
    class ByLeadingCoord(_EdgeGrouping):
        """All edges with the same leading coord are grouped together.

        # Parameters
        - `intra`: Whether all edge groupings include a delimiter token between individual edge representations.
          Note that each edge representation will already always include a connector token (`VOCAB.CONNECTOR`, or possibly `VOCAB.ADJLIST_WALL`).
        - `shuffle_group`: Whether the sequence of edges within the group should be shuffled or appear in a fixed order.
          If false, the fixed order is lexicographical by (row, col).
          In effect, lexicographical sorting sorts edges by their cardinal direction in the sequence NORTH, WEST, EAST, SOUTH, where the directions indicate the position of the trailing coord relative to the leading coord.
        - `connection_token_ordinal`: At which index in the token sequence representing a single edge the connector (or wall) token appears.
          Edge tokenizations contain 2 parts: a connector (or wall) token and a coord or cardinal tokenization.
        """

        intra: bool = serializable_field(default=True)
        shuffle_group: bool = serializable_field(default=True)
        connection_token_ordinal: Literal[0, 1] = serializable_field(
            default=0, assert_type=False
        )

        def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
            return EdgeGroupings._GroupingTokenParams(
                connection_token_ordinal=self.connection_token_ordinal,
                intra=self.intra,
                grouped=True,
            )

        def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
            # Adapted from: https://stackoverflow.com/questions/38013778/is-there-any-numpy-group-by-function
            index_array: Int[np.ndarray, "sort_indices=edges"] = np.lexsort(
                (edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0])
            )
            sorted_edges: ConnectionArray = edges[index_array, ...]
            groups: list[ConnectionArray] = np.split(
                sorted_edges,
                np.unique(sorted_edges[:, 0, :], return_index=True, axis=0)[1][1:],
            )
            if self.shuffle_group:
                [numpy_rng.shuffle(g, axis=0) for g in groups]
            return groups


class EdgePermuters(__TokenizerElementNamespace):
    """Namespace for `_EdgePermuter` subclass hierarchy used by `_AdjListTokenizer`."""

    key = "edge_permuter"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _EdgePermuter(_TokenizerElement, abc.ABC):
        """Specifies how to sequence the two coords that encode a lattice edge."""

        @classmethod
        def attribute_key(cls) -> str:
            return EdgePermuters.key

        def is_valid(self) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

        @staticmethod
        @abc.abstractmethod
        def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
            """
            Executes a permutation.
            Warning: Caller should be aware that `lattice_edges` may be modified in-place depending on the subclass's implementation.

            # Parameters
            - `lattice_edges`: Array of lattice edges.
              The two coords in shape[1] must be adjacent in the lattice.

            # Returns
            - Array of lattice edges with entries along shape[1] systematically permuted.
            - shape[0] of the returned array is NOT guaranteed to match `lattice_edges.shape[0]` (e.g., `BothCoords` doubles it).
            """
            pass

    @serializable_dataclass(frozen=True, kw_only=True)
    class SortedCoords(_EdgePermuter):
        """returns a sorted representation. useful for checking consistency"""

        @staticmethod
        def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
            return lattice_edges[
                np.lexsort(
                    (
                        lattice_edges[:, 1, 1],
                        lattice_edges[:, 1, 0],
                        lattice_edges[:, 0, 1],
                        lattice_edges[:, 0, 0],
                    )
                ),
                ...,
            ]

    @serializable_dataclass(frozen=True, kw_only=True)
    class RandomCoords(_EdgePermuter):
        """Permutes each edge randomly."""

        @staticmethod
        def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
            numpy_rng.permuted(lattice_edges, axis=1, out=lattice_edges)
            return lattice_edges

    @serializable_dataclass(frozen=True, kw_only=True)
    class BothCoords(_EdgePermuter):
        """Includes both possible permutations of every edge in the output.
        Since the input `ConnectionArray` has only 1 instance of each edge,
        a call to `BothCoords._permute` returns an array with double the input's `shape[0]`.
        """

        @staticmethod
        def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
            return np.append(lattice_edges, np.flip(lattice_edges, axis=1), axis=0)
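
    # Example (sketch): `BothCoords` appends a flipped copy of every edge, so
    # each connection appears once per possible leading coord.
    #
    #   >>> edges = np.array([[[0, 0], [0, 1]]])
    #   >>> EdgePermuters.BothCoords._permute(edges).tolist()
    #   [[[0, 0], [0, 1]], [[0, 1], [0, 0]]]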


class EdgeSubsets(__TokenizerElementNamespace):
    """
    Namespace for `_EdgeSubset` subclass hierarchy used by `_AdjListTokenizer`.
    """

    key = "edge_subset"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _EdgeSubset(_TokenizerElement, abc.ABC):
        """
        Component of an `AdjListTokenizers._AdjListTokenizer` which specifies the subset of lattice edges to be tokenized.
        """

        @classmethod
        def attribute_key(cls) -> str:
            return EdgeSubsets.key

        def is_valid(self) -> bool:
            return True

        @abc.abstractmethod
        def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
            """
            Returns the set of lattice edges to be tokenized.
            """
            pass

    @serializable_dataclass(frozen=True, kw_only=True)
    class AllLatticeEdges(_EdgeSubset):
        """
        All 2*n**2 - 2*n edges of the lattice are tokenized.
        If a wall exists on that edge, the edge is tokenized in the same manner, using `VOCAB.ADJLIST_WALL` in place of `VOCAB.CONNECTOR`.
        """

        def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
            return lattice_connection_array(maze.grid_n)

    @serializable_dataclass(frozen=True, kw_only=True)
    class ConnectionEdges(_EdgeSubset):
        """
        Only edges which contain a connection are tokenized.
        Alternatively, only edges which contain a wall are tokenized.

        # Parameters
        - `walls`: Whether wall edges or connection edges are tokenized.
          If true, `VOCAB.ADJLIST_WALL` is used in place of `VOCAB.CONNECTOR`.
        """

        walls: bool = serializable_field(default=False)

        def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
            conn_list: ConnectionList = maze.connection_list
            if self.walls:
                # the complement of the connections gives the walls ...
                conn_list = np.logical_not(conn_list)
                # ... but mask out the last row/column in each direction, whose
                # "edges" would otherwise point outside the lattice
                conn_list[0, -1, :] = False
                conn_list[1, :, -1] = False
            return connection_list_to_adj_list(
                conn_list, shuffle_d0=False, shuffle_d1=False
            )
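
    # Example (sketch): edge counts on a 2x2 lattice, which has 2*n**2 - 2*n = 4
    # edges total. `ConnectionEdges` and `ConnectionEdges(walls=True)` partition
    # the full edge set from `AllLatticeEdges`.
    #
    #   >>> all_edges = EdgeSubsets.AllLatticeEdges()._get_edges(maze)  # maze: 2x2 LatticeMaze
    #   >>> conns = EdgeSubsets.ConnectionEdges()._get_edges(maze)
    #   >>> walls = EdgeSubsets.ConnectionEdges(walls=True)._get_edges(maze)
    #   >>> all_edges.shape[0] == conns.shape[0] + walls.shape[0] == 4
    #   True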


class AdjListTokenizers(__TokenizerElementNamespace):
    """Namespace for `_AdjListTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "adj_list_tokenizer"

    @serializable_dataclass(frozen=True, kw_only=True)
    @mark_as_unsupported(lambda self_: self_.pre is False)
    class _AdjListTokenizer(_TokenizerElement, abc.ABC):
        """
        Specifies how the adjacency list is tokenized.
        Tokenization behavior is decomposed into specification of edge subsets, groupings, and permutations.
        See documentation of `EdgeSubset` and `EdgeGrouping` classes for more details.

        # Parameters
        - `pre`: Whether all edge groupings include a preceding delimiter token
        - `post`: Whether all edge groupings include a following delimiter token
        - `shuffle_d0`: Specifies how to sequence the edge groupings.
          If true, groupings are shuffled randomly. If false, groupings are sorted by the leading coord of each group.
        - `edge_grouping`: Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called an edge grouping.
        - `edge_subset`: Specifies the subset of lattice edges to be tokenized.
        - `edge_permuter`: Specifies, in each edge tokenization, which coord either:
          1. Appears first in the tokenization, for `AdjListCoord`.
          2. Is tokenized directly as a coord, for `AdjListCardinal`.
          - `shuffle`: For each edge, the leading coord is selected randomly.
          - `all`: Each edge appears twice in the tokenization, appearing with both leading coords.
          - `evens`, `odds`: The leading coord is the one belonging to that coord subset. See `EdgeSubsets.ChessboardSublattice` for details.
        """

        pre: bool = serializable_field(default=False, assert_type=False)
        post: bool = serializable_field(default=True)
        shuffle_d0: bool = serializable_field(default=True)
        edge_grouping: EdgeGroupings._EdgeGrouping = serializable_field(
            default=EdgeGroupings.Ungrouped(),
            loading_fn=lambda x: _load_tokenizer_element(x, EdgeGroupings),
        )
        edge_subset: EdgeSubsets._EdgeSubset = serializable_field(
            default=EdgeSubsets.ConnectionEdges(),
            loading_fn=lambda x: _load_tokenizer_element(x, EdgeSubsets),
        )
        edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
            default=EdgePermuters.RandomCoords(),
            loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
        )

        @classmethod
        def attribute_key(cls) -> str:
            return AdjListTokenizers.key

        def is_valid(self) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

        @abc.abstractmethod
        def _tokenization_callables(
            self,
            edges: ConnectionArray,
            is_conn: Bool[np.ndarray, " edges"],
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
            *args,
            **kwargs,
        ):
            """
            Returns a sequence of callables which take an index in `edges` and return parts of that edge tokenization.

            # Returns
            - `[0]`: leading coord tokens
            - `[1]`: connector tokens
            - `[2]`: trailing coord tokens
            """
            pass

        def _tokenize_edge_grouping(
            self,
            edges: ConnectionArray,
            maze: LatticeMaze,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
            group_params: EdgeGroupings._GroupingTokenParams,
        ) -> Sequence[str]:
            """
            Tokenizes a single edge grouping.
            """
            cxn_ord: int = group_params["connection_token_ordinal"]
            is_conn: Bool[np.ndarray, "edges"] = is_connection(
                edges, maze.connection_list
            )
            tokenize_callables = self._tokenization_callables(
                edges, is_conn, coord_tokenizer
            )

            if group_params["grouped"]:
                # If grouped
                callable_permutation: list[int] = [1, 2] if cxn_ord == 0 else [2, 1]
                repeated_callables = [
                    tokenize_callables[i] for i in callable_permutation
                ]
                return flatten(
                    [
                        tokenize_callables[0](0),
                        [
                            [
                                *[
                                    tok_callable(i)
                                    for tok_callable in repeated_callables
                                ],
                                *(
                                    (VOCAB.ADJLIST_INTRA,)
                                    if group_params["intra"]
                                    else ()
                                ),
                            ]
                            for i in range(edges.shape[0])
                        ],
                    ]
                )
            else:
                # If ungrouped
                callable_permutation = [0, 2]
                callable_permutation.insert(cxn_ord, 1)
                tokenize_callables = [
                    tokenize_callables[i] for i in callable_permutation
                ]

                return flatten(
                    [
                        [
                            [
                                *[
                                    tok_callable(i)
                                    for tok_callable in tokenize_callables
                                ],
                                *empty_sequence_if_attr_false(
                                    (VOCAB.ADJLIST_INTRA,), group_params, "intra"
                                ),
                            ]
                            for i in range(edges.shape[0])
                        ]
                    ]
                )

        def to_tokens(
            self, maze: LatticeMaze, coord_tokenizer: CoordTokenizers._CoordTokenizer
        ) -> list[str]:
            # Get the set of edges to be tokenized
            edges: ConnectionArray = self.edge_subset._get_edges(maze)
            # Systematically permute the leading coord of each edge
            edges: ConnectionArray = self.edge_permuter._permute(edges)
            group_params: EdgeGroupings._GroupingTokenParams = (
                self.edge_grouping._token_params()
            )
            # then, we need to group the edges
            groups: Sequence[ConnectionArray] = self.edge_grouping._group_edges(edges)
            # shuffle the groups if specified
            if self.shuffle_d0:
                if isinstance(groups, np.ndarray):
                    numpy_rng.shuffle(groups, axis=0)
                elif isinstance(groups, list):
                    random.shuffle(groups)
                else:
                    raise TypeError(
                        f"`groups` is an unexpected type {type(groups)}. Only types `list` and `np.ndarray` are currently supported."
                    )
            # Tokenize each group with optional delimiters
            tokens: list[str] = list(
                flatten(
                    [
                        [
                            *empty_sequence_if_attr_false(
                                (VOCAB.ADJLIST_PRE,), self, "pre"
                            ),
                            *self._tokenize_edge_grouping(
                                group, maze, coord_tokenizer, group_params
                            ),
                            *empty_sequence_if_attr_false(
                                (VOCAB.ADJACENCY_ENDLINE,), self, "post"
                            ),
                        ]
                        for group in groups
                    ]
                )
            )
            return tokens

    @serializable_dataclass(frozen=True, kw_only=True)
    class AdjListCoord(_AdjListTokenizer):
        """Represents an edge group as tokens for the leading coord followed by coord tokens for the other group members."""

        edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
            default=EdgePermuters.RandomCoords(),
            loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
        )

        def _tokenization_callables(
            self,
            edges: ConnectionArray,
            is_conn: Bool[np.ndarray, " edges"],
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
            *args,
            **kwargs,
        ):
            # Map from `is_conn` to the tokens which represent connections and walls
            conn_token_map: dict[bool, str] = {
                True: VOCAB.CONNECTOR,
                False: VOCAB.ADJLIST_WALL,
            }
            return [
                lambda i: coord_tokenizer.to_tokens(edges[i, 0]),
                lambda i: conn_token_map[is_conn[i]],
                lambda i: coord_tokenizer.to_tokens(edges[i, 1]),
            ]

    @serializable_dataclass(frozen=True, kw_only=True)
    class AdjListCardinal(_AdjListTokenizer):
        """Represents an edge group as coord tokens for the leading coord and cardinal tokens relative to the leading coord for the other group members.

        # Parameters
        - `coord_first`: Whether the leading coord token(s) should come before or after the sequence of cardinal tokens.
        """

        edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
            default=EdgePermuters.BothCoords(),
            loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
        )

        def _tokenization_callables(
            self,
            edges: ConnectionArray,
            is_conn: Bool[np.ndarray, " edges"],
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
            *args,
            **kwargs,
        ):
            # Map from `is_conn` to the tokens which represent connections and walls
            conn_token_map: dict[bool, str] = {
                True: VOCAB.CONNECTOR,
                False: VOCAB.ADJLIST_WALL,
            }
            return [
                lambda i: coord_tokenizer.to_tokens(edges[i, 0]),
                lambda i: conn_token_map[is_conn[i]],
                lambda i: get_cardinal_direction(edges[i]),
            ]


class TargetTokenizers(__TokenizerElementNamespace):
    """Namespace for `_TargetTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "target_tokenizer"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _TargetTokenizer(_TokenizerElement, abc.ABC):
        """Superclass of tokenizers for maze targets."""

        @abc.abstractmethod
        def to_tokens(
            self,
            targets: Sequence[Coord],
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            """Returns tokens representing the target."""
            pass

        @classmethod
        def attribute_key(cls) -> str:
            return TargetTokenizers.key

    @serializable_dataclass(frozen=True, kw_only=True)
    class Unlabeled(_TargetTokenizer):
        """Targets are simply listed as coord tokens.

        # Parameters
        - `post`: Whether all coords include an integral following delimiter token
        """

        post: bool = serializable_field(default=False)

        def to_tokens(
            self,
            targets: Sequence[Coord],
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            return list(
                flatten(
                    [
                        [
                            *coord_tokenizer.to_tokens(target),
                            *empty_sequence_if_attr_false(
                                [VOCAB.TARGET_POST], self, "post"
                            ),
                        ]
                        for target in targets
                    ]
                )
            )

        def is_valid(self) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True


class StepSizes(__TokenizerElementNamespace):
    """Namespace for `_StepSize` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "step_size"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _StepSize(_TokenizerElement, abc.ABC):
        """
        Specifies which coords in `maze.solution` are used to represent the path.
        """

        @classmethod
        def attribute_key(cls) -> str:
            return StepSizes.key

        @abc.abstractmethod  # TODO: make this a static/class method, allowing ForksAndStraightaways to skip object construction at every call
        def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
            """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
            raise NotImplementedError(
                "Subclasses must implement `_StepSize._step_single_indices`."
            )

        def step_start_end_indices(self, maze) -> list[tuple[int, int]]:
            """Returns steps as tuples of starting and ending positions for each step."""
            indices: list[int] = self._step_single_indices(maze)
            return [(start, end) for start, end in zip(indices[:-1], indices[1:])]
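
        # Example (sketch): consecutive index pairs form the steps. If a
        # subclass returns indices [0, 2, 5], the steps are (0, 2) and (2, 5):
        #
        #   >>> list(zip([0, 2, 5][:-1], [0, 2, 5][1:]))
        #   [(0, 2), (2, 5)]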

        def is_valid(self) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

    @serializable_dataclass(frozen=True, kw_only=True)
    class Singles(_StepSize):
        """
        Every coord in `maze.solution` is represented.
        Legacy tokenizers all use this behavior.
        """

        def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
            """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
            return list(range(maze.solution.shape[0]))

    @serializable_dataclass(frozen=True, kw_only=True)
    @mark_as_unsupported(lambda self_: False)
    class Straightaways(_StepSize):
        """
        Only coords where the path turns are represented in the path.
        I.e., the path is represented as a sequence of straightaways,
        specified by the coords at the turns.
        """

        def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
            """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
            last_turn_coord: Coord = maze.solution[0, ...]
            indices: list[int] = [0]
            for i, coord in enumerate(maze.solution):
                if coord[0] != last_turn_coord[0] and coord[1] != last_turn_coord[1]:
                    indices.append(i - 1)
                    last_turn_coord = maze.solution[i - 1, ...]
            indices.append(i)
            return indices

    @serializable_dataclass(frozen=True, kw_only=True)
    class Forks(_StepSize):
        """
        Only coords at forks, where the path has >=2 options for the next step, are included.
        Excludes the option of backtracking.
        The starting and ending coords are always included.
        """

        def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
            """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
            return maze.get_solution_forking_points(always_include_endpoints=True)[0]

    @serializable_dataclass(frozen=True, kw_only=True)
    @mark_as_unsupported(lambda self_: False)
    class ForksAndStraightaways(_StepSize):
        """
        Includes the union of the coords included by `Forks` and `Straightaways`.
        See documentation for those classes for details.
        """

        def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
            """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
            return list(
                np.unique(
                    np.concatenate(
                        (
                            StepSizes.Straightaways()._step_single_indices(maze),
                            StepSizes.Forks()._step_single_indices(maze),
                        )
                    )
                )
            )


class StepTokenizers(__TokenizerElementNamespace):
    """Namespace for `_StepTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "step_tokenizers"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _StepTokenizer(_TokenizerElement, abc.ABC):
        """
        Specifies how a single step (as specified by an instance of `_StepSize`) is tokenized.
        """

        @classmethod
        def attribute_key(cls) -> str:
            return StepTokenizers.key

        @abc.abstractmethod
        def to_tokens(
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            **kwargs,
        ) -> list[str]:
            """Tokenizes a single step in the solution.

            # Parameters
            - `maze`: Maze to be tokenized
            - `start_index`: The index of the Coord in `maze.solution` at which the current step starts
            - `end_index`: The index of the Coord in `maze.solution` at which the current step ends
            """
            raise NotImplementedError(
                "Subclasses must implement `_StepTokenizer.to_tokens`."
            )

        def is_valid(self) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

    @serializable_dataclass(frozen=True, kw_only=True)
    class Coord(_StepTokenizer):
        """
        A direct tokenization of the end position coord represents the step.
        """

        def to_tokens(
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            return coord_tokenizer.to_tokens(maze.solution[end_index, ...])

    @serializable_dataclass(frozen=True, kw_only=True)
    class Cardinal(_StepTokenizer):
        """
        A step is tokenized with a cardinal direction token.
        It is the direction of the step from the starting position along the solution.
        """

        def to_tokens(
            self, maze: SolvedMaze, start_index: int, end_index: int, **kwargs
        ) -> list[str]:
            return [
                get_cardinal_direction(maze.solution[start_index : start_index + 2])
            ]

    @serializable_dataclass(frozen=True, kw_only=True)
    class Relative(_StepTokenizer):
        """Tokenizes a solution step using relative first-person directions (right, left, forward, etc.).
        To simplify the indeterminacy, at the start of a solution the "agent" solving the maze is assumed to be facing NORTH.
        Similarly to `Cardinal`, the direction is that of the step from the starting position.
        """

        def to_tokens(
            self, maze: SolvedMaze, start_index: int, end_index: int, **kwargs
        ) -> list[str]:
            if start_index == 0:
                start = maze.solution[0]
                previous = start + np.array([1, 0])
                return [
                    get_relative_direction(
                        np.concatenate(
                            (
                                np.expand_dims(previous, 0),
                                maze.solution[start_index : start_index + 2],
                            ),
                            axis=0,
                        )
                    )
                ]
            return [
                get_relative_direction(maze.solution[start_index - 1 : start_index + 2])
            ]

    @serializable_dataclass(frozen=True, kw_only=True)
    class Distance(_StepTokenizer):
        """
        A count of the number of individual steps from the starting point to the end point.
        Contains no information about directionality, only the distance traveled in the step.
        `Distance` must be combined with at least one other `_StepTokenizer` in a `StepTokenizerPermutation`.
        This constraint is enforced in `_PathTokenizer.is_valid`.
        """

        def to_tokens(
            self, maze: SolvedMaze, start_index: int, end_index: int, **kwargs
        ) -> list[str]:
            d: int = end_index - start_index
            return [getattr(VOCAB, f"I_{d:03}")]
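
    # Example (sketch): a step spanning solution indices 2 through 5 has length
    # 3, so `Distance` emits the single token bound to `VOCAB.I_003` (assuming
    # the zero-padded `I_{d:03}` attributes exist on `VOCAB`, as the `getattr`
    # above expects).
    #
    #   >>> StepTokenizers.Distance().to_tokens(maze, 2, 5)  # maze: any SolvedMaze
    #   [VOCAB.I_003]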
1562 """
1563 `StepTokenizerPermutation`
1564 A sequence of unique `_StepTokenizer`s.
1565 This type exists mostly just for the clarity and convenience of `_PathTokenizer` code.
1566 """
1567 StepTokenizerPermutation: type = (
1568 tuple[_StepTokenizer]
1569 | tuple[_StepTokenizer, _StepTokenizer]
1570 | tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer]
1571 | tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer, _StepTokenizer]
1572 )


class PathTokenizers(__TokenizerElementNamespace):
    """Namespace for `_PathTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "path_tokenizer"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _PathTokenizer(_TokenizerElement, abc.ABC):
        """Superclass of tokenizers for maze solution paths."""

        @abc.abstractmethod
        def to_tokens(
            self, maze: SolvedMaze, coord_tokenizer: CoordTokenizers._CoordTokenizer
        ) -> list[str]:
            """Returns tokens representing the solution path."""
            pass

        @classmethod
        def attribute_key(cls) -> str:
            return PathTokenizers.key

    @serializable_dataclass(frozen=True, kw_only=True)
    class StepSequence(_PathTokenizer, abc.ABC):
        """Any `_PathTokenizer` where the tokenization may be assembled from token subsequences, each of which represents a step along the path.
        Allows for a sequence of leading and trailing tokens which don't fit the step pattern.

        # Parameters
        - `step_size`: Selects the size of a single step in the sequence
        - `step_tokenizers`: Selects the combination and permutation of tokens
        - `pre`: Whether all steps include an integral preceding delimiter token
        - `intra`: Whether all steps include a delimiter token after each individual `_StepTokenizer` tokenization
        - `post`: Whether all steps include an integral following delimiter token
        """

        step_size: StepSizes._StepSize = serializable_field(
            default=StepSizes.Singles(),
            loading_fn=lambda x: _load_tokenizer_element(x, StepSizes),
        )
        step_tokenizers: StepTokenizers.StepTokenizerPermutation = serializable_field(
            default=(StepTokenizers.Coord(),),
            serialization_fn=lambda x: [y.serialize() for y in x],
            loading_fn=lambda x: tuple(x[StepTokenizers.key]),
        )
        pre: bool = serializable_field(default=False)
        intra: bool = serializable_field(default=False)
        post: bool = serializable_field(default=False)

        def to_tokens(
            self, maze: SolvedMaze, coord_tokenizer: CoordTokenizers._CoordTokenizer
        ) -> list[str]:
            return [
                *self._leading_tokens(maze, coord_tokenizer),
                *flatten(
                    [
                        self._single_step_tokens(maze, start, end, coord_tokenizer)
                        for start, end in self.step_size.step_start_end_indices(maze)
                    ]
                ),
                *self._trailing_tokens(maze, coord_tokenizer),
            ]

        def _single_step_tokens(
            self,
            maze: SolvedMaze,
            i: int,
            j: int,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            """Returns the token sequence representing a single step along the path."""
            step_rep_tokens: list[list[str]] = [
                step_tokenizer.to_tokens(maze, i, j, coord_tokenizer=coord_tokenizer)
                for step_tokenizer in self.step_tokenizers
            ]
            if self.intra:
                # interleave the per-tokenizer outputs with the intra-step delimiter
                step_rep_tokens_and_intra: list[str] = [None] * (
                    len(step_rep_tokens) * 2
                )
                step_rep_tokens_and_intra[::2] = step_rep_tokens
                step_rep_tokens_and_intra[1::2] = [VOCAB.PATH_INTRA] * len(
                    step_rep_tokens
                )
                step_rep_tokens = list(flatten(step_rep_tokens_and_intra))
            all_tokens: list[str] = [
                *empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
                *flatten(step_rep_tokens),
                *empty_sequence_if_attr_false((VOCAB.PATH_POST,), self, "post"),
            ]
            return all_tokens
1663 def _leading_tokens(
1664 self, maze: SolvedMaze, coord_tokenizer: CoordTokenizers._CoordTokenizer
1665 ) -> list[str]:
1666 """Returns tokens preceding those from the sequence from `_single_step_tokens`.
1667 Since the for loop in `to_tokens` iterates `len(path)-1` times, a fencepost problem exists with `StepTokenizers.Coord`.
1668 <PATH_START> should NOT be included.
1669 """
1670 if StepTokenizers.Coord() in self.step_tokenizers:
1671 return [
1672 *empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
1673 *coord_tokenizer.to_tokens(maze.solution[0, ...]),
1674 *empty_sequence_if_attr_false((VOCAB.PATH_INTRA,), self, "intra"),
1675 ]
1676 return []
1678 def _trailing_tokens(
1679 self, maze: SolvedMaze, coord_tokenizer: CoordTokenizers._CoordTokenizer
1680 ) -> list[str]:
1681 """Returns tokens following those from the sequence from `_single_step_tokens`.
1682 <PATH_END> should NOT be included.
1683 """
1684 return []
1686 def is_valid(self) -> bool:
1687 if len(set(self.step_tokenizers)) != len(self.step_tokenizers):
1688 # Uninteresting: repeated elements are not useful
1689 return False
1691 if len(self.step_tokenizers) == 1 and isinstance(
1692 self.step_tokenizers[0], StepTokenizers.Distance
1693 ):
1694 # Untrainable: `Distance` alone cannot encode a path. >=1 `StepTokenizer` which indicates direction/location is required.
1695 return False
1696 else:
1697 return True
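# A hedged construction sketch: a `StepSequence` emitting, for each single
# step, its direction followed by an intra-step delimiter.
# (`StepTokenizers.Cardinal` is assumed here as an illustrative member of the
# hierarchy; `Distance` alone would be rejected by `is_valid` above.)
#   >>> pt = PathTokenizers.StepSequence(
#   ...     step_size=StepSizes.Singles(),
#   ...     step_tokenizers=(StepTokenizers.Cardinal(),),
#   ...     intra=True,
#   ... )
#   >>> pt.is_valid()
#   True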
1700class PromptSequencers(__TokenizerElementNamespace):
1701 """Namespace for `_PromptSequencer` subclass hierarchy used by `MazeTokenizerModular`."""
1703 key = "prompt_sequencer"
1705 @serializable_dataclass(frozen=True, kw_only=True)
1706 class _PromptSequencer(_TokenizerElement, abc.ABC):
1707 """
1708 Sequences token regions into a complete maze tokenization.
1710 # Parameters
1711 - `coord_tokenizer`: Tokenizer element which tokenizes a single `Coord` aka maze position.
1712 - `adj_list_tokenizer`: Tokenizer element which tokenizes the adjacency list of a `LatticeMaze`.
1713 Uses `coord_tokenizer` to tokenize coords if needed in other `TokenizerElement`s.
1714 """
1716 coord_tokenizer: CoordTokenizers._CoordTokenizer = serializable_field(
1717 default=CoordTokenizers.UT(),
1718 loading_fn=lambda x: _load_tokenizer_element(x, CoordTokenizers),
1719 )
1720 adj_list_tokenizer: AdjListTokenizers._AdjListTokenizer = serializable_field(
1721 default=AdjListTokenizers.AdjListCoord(),
1722 loading_fn=lambda x: _load_tokenizer_element(x, AdjListTokenizers),
1723 )
1725 @classmethod
1726 def attribute_key(cls) -> str:
1727 return PromptSequencers.key
1729 @staticmethod
1730 def _trim_if_unsolved_maze(
1731 untrimmed: list[str], is_untargeted: bool = False, is_unsolved: bool = False
1732 ):
1733 """Trims a full `SolvedMaze` prompt if the maze data reflects an unsolved or untargeted maze.
1735 # Development
1736 This implementation should function for `AOTP`, `AOP`, and other concrete classes using any subsequence of AOTP.
1737 It is not located in `token_utils.py` because it may need to be overridden in more exotic `PromptSequencer` subclasses.
1738 """
1739 if is_untargeted:
1740 return tokens_between(
1741 untrimmed,
1742 VOCAB.ADJLIST_START,
1743 VOCAB.ADJLIST_END,
1744 include_start=True,
1745 include_end=True,
1746 )
1747 if is_unsolved:
1748 if VOCAB.TARGET_END in untrimmed:
1749 return tokens_between(
1750 untrimmed,
1751 VOCAB.ADJLIST_START,
1752 VOCAB.TARGET_END,
1753 include_start=True,
1754 include_end=True,
1755 )
1756 else:
1757 return tokens_between(
1758 untrimmed,
1759 VOCAB.ADJLIST_START,
1760 VOCAB.ORIGIN_END,
1761 include_start=True,
1762 include_end=True,
1763 )
1764 return untrimmed
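# A minimal sketch of the untargeted trim above, with special-token strings
# shortened for readability (real values come from `VOCAB`):
#   >>> toks = ["<A_S>", "(0,0) <--> (0,1)", "<A_E>", "<O_S>", "(0,0)", "<O_E>"]
#   >>> tokens_between(toks, "<A_S>", "<A_E>", include_start=True, include_end=True)
#   ['<A_S>', '(0,0) <--> (0,1)', '<A_E>']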
1766 def to_tokens(
1767 self,
1768 maze: LatticeMaze,
1769 *args,
1770 **kwargs,
1771 ) -> list[str]:
1772 """Returns a complete list of tokens for a given set of maze elements."""
1773 untrimmed: list[str] = self._sequence_tokens(
1774 *self._get_prompt_regions(maze)
1775 )
1776 return self._trim_if_unsolved_maze(
1777 untrimmed, not hasattr(maze, "start_pos"), not hasattr(maze, "solution")
1778 )
1780 def _get_prompt_regions(
1781 self,
1782 maze: LatticeMaze,
1783 *args,
1784 **kwargs,
1785 ) -> list[list[str]]:
1786 """Gets the prompt regions of a maze in a fixed sequence.
1788 This method is NOT responsible for including/excluding any prompt regions.
1789 Always return according to the API described under Returns.
1790 This implementation is expected to be suitable for most `PromptSequencer` subclasses.
1791 Subclasses may override this method if needed for special behavior.
1793 # Returns
1794 - [0]: list[str] Adjacency list tokens
1795 - [1]: list[str] Origin tokens
1796 - [2]: list[str] Target tokens
1797 - [3]: list[str] Path tokens
1799 # `None`-valued Args
1800 If one or more of `origin`, `target`, or `path` are `None`, that indicates that an unsolved or untargeted maze is being tokenized.
1801 To ensure unpackability in `_sequence_tokens`, these `None` values are replaced with empty iterables.
1802 """
1803 origin: Coord | None = getattr(maze, "start_pos", None)
1804 target: list[Coord] | None = [
1805 getattr(maze, "end_pos", None)
1806 ] # TargetTokenizer requires target: Sequence[Coord]
1808 return [
1809 (
1810 self.adj_list_tokenizer.to_tokens(
1811 maze, coord_tokenizer=self.coord_tokenizer
1812 )
1813 if hasattr(self, "adj_list_tokenizer")
1814 else []
1815 ),
1816 self.coord_tokenizer.to_tokens(origin) if origin is not None else [],
1817 (
1818 self.target_tokenizer.to_tokens(
1819 target, coord_tokenizer=self.coord_tokenizer
1820 )
1821 if target[0] is not None and hasattr(self, "target_tokenizer")
1822 else []
1823 ),
1824 (
1825 self.path_tokenizer.to_tokens(
1826 maze, coord_tokenizer=self.coord_tokenizer
1827 )
1828 if hasattr(maze, "solution") and hasattr(self, "path_tokenizer")
1829 else []
1830 ),
1831 ]
1833 @abc.abstractmethod
1834 def _sequence_tokens(
1835 self,
1836 adj_list: list[str],
1837 origin: list[str] | None,
1838 target: list[str] | None,
1839 path: list[str] | None,
1840 ) -> list[str]:
1841 """Sequences token regions into a complete prompt.
1842 Includes any boundary tokens in `constants.SPECIAL_TOKENS` such as <ADJLIST_START>, <ORIGIN_END>, etc.
1843 # Parameters
1844 - `adj_list`: Tokens representing the adjacency list
1845 - `origin`: Tokens representing the origin
1846 - `target`: Tokens representing the target
1847 - `path`: Tokens representing the path
1848 """
1849 pass
1851 def is_valid(self) -> bool:
1852 # No invalid instances possible within data member type hint bounds
1853 return True
1855 @serializable_dataclass(frozen=True, kw_only=True)
1856 class AOTP(_PromptSequencer):
1857 """
1858 Sequences a prompt as [adjacency list, origin, target, path].
1860 # Parameters
1861 - `target_tokenizer`: Tokenizer element which tokenizes the target(s) of a `TargetedLatticeMaze`.
1862 Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `TargetTokenizer`.
1863 - `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`.
1864 Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
1866 """
1868 target_tokenizer: TargetTokenizers._TargetTokenizer = serializable_field(
1869 default=TargetTokenizers.Unlabeled(),
1870 loading_fn=lambda x: _load_tokenizer_element(x, TargetTokenizers),
1871 )
1872 path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
1873 default=PathTokenizers.StepSequence(),
1874 loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
1875 )
1877 def _sequence_tokens(
1878 self,
1879 adj_list: list[str],
1880 origin: list[str],
1881 target: list[str],
1882 path: list[str],
1883 ) -> list[str]:
1884 return [
1885 VOCAB.ADJLIST_START,
1886 *adj_list,
1887 VOCAB.ADJLIST_END,
1888 VOCAB.ORIGIN_START,
1889 *origin,
1890 VOCAB.ORIGIN_END,
1891 VOCAB.TARGET_START,
1892 *target,
1893 VOCAB.TARGET_END,
1894 VOCAB.PATH_START,
1895 *path,
1896 VOCAB.PATH_END,
1897 ]
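# The resulting AOTP prompt layout, sketched with delimiter names only:
#   <ADJLIST_START> ...adjacency list... <ADJLIST_END>
#   <ORIGIN_START> ...origin... <ORIGIN_END>
#   <TARGET_START> ...target... <TARGET_END>
#   <PATH_START> ...path... <PATH_END>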
1899 @serializable_dataclass(frozen=True, kw_only=True)
1900 class AOP(_PromptSequencer):
1901 """Sequences a prompt as [adjacency list, origin, path].
1902 Still includes "<TARGET_START>" and "<TARGET_END>" tokens, but no representation of the target itself.
1904 # Parameters
1905 - `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`.
1906 Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
1907 """
1909 path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
1910 default=PathTokenizers.StepSequence(),
1911 loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
1912 )
1914 def _sequence_tokens(
1915 self,
1916 adj_list: list[str],
1917 origin: list[str],
1918 target: list[str],
1919 path: list[str],
1920 ) -> list[str]:
1921 return [
1922 VOCAB.ADJLIST_START,
1923 *adj_list,
1924 VOCAB.ADJLIST_END,
1925 VOCAB.ORIGIN_START,
1926 *origin,
1927 VOCAB.ORIGIN_END,
1928 VOCAB.TARGET_START,
1929 VOCAB.TARGET_END,
1930 VOCAB.PATH_START,
1931 *path,
1932 VOCAB.PATH_END,
1933 ]
1936@serializable_dataclass(
1937 frozen=True,
1938 kw_only=True,
1939 properties_to_serialize=["tokenizer_element_tree_concrete", "name"],
1940)
1941class MazeTokenizerModular(SerializableDataclass):
1942 """Tokenizer for mazes
1944 # Parameters
1945 - `prompt_sequencer`: Tokenizer element which assembles token regions (adjacency list, origin, target, path) into a complete prompt.
1947 # Development
1948 - To ensure backwards compatibility, the default constructor must always return a tokenizer equivalent to the legacy `TokenizationMode.AOTP_UT_uniform`.
1949 - Furthermore, the mapping reflected in `from_legacy` must also be maintained.
1950 - Updates to `MazeTokenizerModular` or the `_TokenizerElement` hierarchy must maintain that behavior.
1951 """
1953 prompt_sequencer: PromptSequencers._PromptSequencer = serializable_field(
1954 default=PromptSequencers.AOTP(),
1955 loading_fn=lambda x: _load_tokenizer_element(x, PromptSequencers),
1956 )
1958 def hash_int(self) -> int:
1959 return int.from_bytes(
1960 hashlib.blake2b(self.name.encode("utf-8")).digest(),
1961 byteorder="big",
1962 )
1964 def __hash__(self):
1965 "Stable hash to identify unique `MazeTokenizerModular` instances. uses name"
1966 return self.hash_int()
1968 def hash_b64(self, n_bytes: int = 8) -> str:
1969 """filename-safe base64 encoding of the hash"""
1970 # Use modulus to ensure the integer fits within n_bytes * 8 bits
1971 hash_mod: int = self.hash_int() % (1 << (n_bytes * 8))
1973 encoded = base64.b64encode(
1974 hash_mod.to_bytes(n_bytes, byteorder="big"), altchars=b"-_"
1975 ).decode()
1977 # Remove any padding equals signs
1978 return encoded.rstrip("=")
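# A standalone sketch of the same truncate-and-encode scheme (the hash value
# is illustrative):
#   >>> import base64
#   >>> h = 123456789123456789
#   >>> hm = h % (1 << (8 * 8))  # truncate to fit in 8 bytes
#   >>> s = base64.b64encode(hm.to_bytes(8, "big"), altchars=b"-_").decode().rstrip("=")
#   >>> len(s)  # 8 bytes -> 12 base64 chars, minus one padding "="
#   11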
1980 # Information Querying Methods
1982 @cached_property
1983 def tokenizer_elements(self) -> list[_TokenizerElement]:
1984 return [self.prompt_sequencer, *self.prompt_sequencer.tokenizer_elements()]
1986 def tokenizer_element_tree(self, abstract: bool = False) -> str:
1987 """
1988 Returns a string representation of the tree of tokenizer elements contained in `self`.
1990 # Parameters
1991 - `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
1992 """
1994 return "\n".join(
1995 [
1996 type(self).__name__,
1997 self.prompt_sequencer.tokenizer_element_tree(
1998 abstract=abstract, depth=1
1999 ),
2000 ]
2001 )
2003 @property
2004 def tokenizer_element_tree_concrete(self):
2005 """
2006 Property wrapper for `tokenizer_element_tree` so that it can be used in `properties_to_serialize`.
2007 """
2008 return self.tokenizer_element_tree()
2010 def tokenizer_element_dict(self) -> dict:
2011 """
2012 Nested dictionary of the internal `TokenizerElement`s.
2013 """
2014 return {type(self).__name__: self.prompt_sequencer.tokenizer_element_dict()}
2016 @property
2017 def name(self) -> str:
2018 """Serializes MazeTokenizer into a key for encoding in zanj"""
2019 return "-".join([type(self).__name__, self.prompt_sequencer.name])
2021 def summary(self) -> dict[str, str]:
2022 """
2023 Single-level dictionary of the internal `TokenizerElement`s.
2024 """
2025 return {
2026 # "prompt_sequencer": self.prompt_sequencer.name,
2027 **{elem.attribute_key(): elem.name for elem in self.tokenizer_elements}
2028 }
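# Usage sketch: the default tokenizer's summary has one entry per element,
# keyed by `attribute_key()` -- e.g. "prompt_sequencer" and "path_tokenizer"
# (assuming, as the element tree suggests, that nested elements are
# enumerated by `tokenizer_elements`):
#   >>> smry = MazeTokenizerModular().summary()
#   >>> "prompt_sequencer" in smry and "path_tokenizer" in smry
#   True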
2030 @staticmethod
2031 def _type_check(obj: Any) -> None:
2032 """Helper method for `has_element`"""
2033 if not (
2034 isinstance(obj, _TokenizerElement)
2035 or (isinstance(obj, type) and issubclass(obj, _TokenizerElement))
2036 ):
2037 raise TypeError(f"{obj} is not a `_TokenizerElement` instance or subclass.")
2039 def _has_element_singular(self, el: type[_TokenizerElement] | _TokenizerElement):
2040 """Helper method for `has_element`"""
2041 self._type_check(el)
2042 if isinstance(el, type):
2043 return any([isinstance(e, el) for e in self.tokenizer_elements])
2044 else:
2045 return el in self.tokenizer_elements
2047 def has_element(
2048 self,
2049 *elements: Sequence[type[_TokenizerElement] | _TokenizerElement],
2050 ) -> bool:
2051 """Returns True if the `MazeTokenizerModular` instance contains ALL of the items specified in `elements`.
2053 Querying with a partial subset of `_TokenizerElement` fields is not currently supported.
2054 To do such a query, assemble multiple calls to `has_element`.
2056 # Parameters
2057 - `elements`: Singleton or iterable of `_TokenizerElement` instances or classes.
2058 If an instance is provided, then comparison is done via instance equality.
2059 If a class is provided, then comparison is done via `isinstance`. I.e., any instance of that class is accepted.
2060 """
2061 if len(elements) == 1 and isinstance(elements[0], Iterable):
2062 elements = elements[0]
2063 return all([self._has_element_singular(e) for e in elements])
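# Usage sketch showing both query styles against the default tokenizer:
#   >>> mtm = MazeTokenizerModular()
#   >>> mtm.has_element(PromptSequencers.AOTP)   # class: matched via isinstance
#   True
#   >>> mtm.has_element(CoordTokenizers.UT())    # instance: matched via equality
#   True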
2065 def is_valid(self):
2066 """
2067 Returns `True` if `self` is a valid tokenizer.
2068 Evaluates the validity of all of `self.tokenizer_elements` according to each one's method.
2069 """
2070 return all([el.is_valid() for el in self.tokenizer_elements])
2072 def is_legacy_equivalent(self) -> bool:
2073 """Returns if `self` has identical stringification behavior as any legacy `MazeTokenizer`."""
2074 return any(
2075 [
2076 self == MazeTokenizerModular.from_legacy(tok_mode)
2077 for tok_mode in TokenizationMode
2078 ]
2079 )
2081 def is_tested_tokenizer(self, do_assert: bool = False) -> bool:
2082 """Returns if the tokenizer is returned by `all_tokenizers.get_all_tokenizers`, the set of tested and reliable tokenizers.
2084 Since evaluating `all_tokenizers.get_all_tokenizers` is expensive,
2085 instead checks for membership of `self`'s hash in `get_all_tokenizer_hashes()`.
2087 If `do_assert` is `True`, raises an `AssertionError` if the tokenizer is not tested.
2088 """
2089 all_tokenizer_hashes: Int64[np.ndarray, " n_tokenizers"] = (
2090 get_all_tokenizer_hashes()
2091 )
2092 hash_index: int = np.searchsorted(all_tokenizer_hashes, hash(self))
2094 in_range: bool = hash_index < len(all_tokenizer_hashes)
2095 hashes_match: bool = in_range and all_tokenizer_hashes[hash_index] == hash(self)
2096 is_valid: bool = self.is_valid()
2098 if do_assert:
2099 assert in_range, (
2100 f"{hash_index = } is invalid, must be at most {len(all_tokenizer_hashes) - 1}"
2101 )
2102 assert hashes_match, (
2103 f"{all_tokenizer_hashes[hash_index] = } != {hash(self) = }"
2104 )
2105 assert is_valid, "self.is_valid returns False"
2106 return True
2107 else:
2108 return in_range and hashes_match and is_valid
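# The sorted-array membership idiom used above, sketched with toy data:
#   >>> import numpy as np
#   >>> hashes = np.array([3, 7, 42], dtype=np.int64)  # must be sorted
#   >>> i = np.searchsorted(hashes, 7)
#   >>> bool(i < len(hashes) and hashes[i] == 7)
#   True
#   >>> i = np.searchsorted(hashes, 8)
#   >>> bool(i < len(hashes) and hashes[i] == 8)
#   False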
2110 def is_AOTP(self) -> bool:
2111 return self.has_element(PromptSequencers.AOTP)
2113 def is_UT(self) -> bool:
2114 return self.has_element(CoordTokenizers.UT)
2116 # Alternate Constructors
2117 # ======================
2119 @classmethod
2120 def from_legacy(
2121 cls, legacy_maze_tokenizer: MazeTokenizer | TokenizationMode
2122 ) -> "MazeTokenizerModular":
2123 """Maps a legacy `MazeTokenizer` or `TokenizationMode` to its equivalent `MazeTokenizerModular` instance."""
2124 if isinstance(legacy_maze_tokenizer, MazeTokenizer):
2125 legacy_maze_tokenizer = legacy_maze_tokenizer.tokenization_mode
2126 return {
2127 TokenizationMode.AOTP_UT_uniform: MazeTokenizerModular(),
2128 TokenizationMode.AOTP_UT_rasterized: MazeTokenizerModular(),
2129 TokenizationMode.AOTP_CTT_indexed: MazeTokenizerModular(
2130 prompt_sequencer=PromptSequencers.AOTP(
2131 coord_tokenizer=CoordTokenizers.CTT()
2132 )
2133 ),
2134 }[legacy_maze_tokenizer]
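# Usage sketch: both legacy UT modes map to the default instance, while the
# CTT mode swaps in a `CoordTokenizers.CTT` coordinate tokenizer:
#   >>> MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_UT_uniform) == MazeTokenizerModular()
#   True
#   >>> MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_CTT_indexed).has_element(CoordTokenizers.CTT)
#   True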
2136 # Simple properties
2137 # =================
2138 @classmethod
2139 def from_tokens(
2140 cls,
2141 tokens: str | list[str],
2142 ) -> "MazeTokenizerModular":
2143 """
2144 Infers most `MazeTokenizerModular` parameters from a full sequence of tokens.
2145 """
2146 raise NotImplementedError(
2147 "Recovering tokenizer objects from MazeTokenizerModular-produced strings is not supported"
2148 )
2150 @property
2151 def token_arr(self) -> list[str] | None:
2152 """map from index to token"""
2153 return VOCAB_LIST
2155 @property
2156 def tokenizer_map(self) -> dict[str, int]:
2157 """map from token to index"""
2158 return VOCAB_TOKEN_TO_INDEX
2160 @property
2161 def vocab_size(self) -> int:
2162 """Number of tokens in the static vocab"""
2163 return len(VOCAB_LIST)
2165 @property
2166 def n_tokens(self) -> int:
2167 raise NameError(
2168 "`MazeTokenizerModular.n_tokens` has been removed. Use `len(maze_dataset.VOCAB_LIST)` instead."
2169 )
2171 @property
2172 def padding_token_index(self) -> int:
2173 return VOCAB_TOKEN_TO_INDEX[VOCAB.PADDING]
2175 # conversion functions
2176 # ============================================================
2178 def to_tokens(
2179 self,
2180 maze: LatticeMaze,
2181 ) -> list[str]:
2182 """Converts maze into a list of tokens."""
2183 return self.prompt_sequencer.to_tokens(maze)
2185 def coords_to_strings(self, coords: list[CoordTup | Coord]) -> list[str]:
2186 return list(
2187 flatten(
2188 [self.prompt_sequencer.coord_tokenizer.to_tokens(c) for c in coords]
2189 )
2190 )
2192 @overload
2193 def strings_to_coords(
2194 cls,
2195 text: str | list[str],
2196 when_noncoord: Literal["skip"] = "skip",
2197 ) -> list[CoordTup]: ...
2198 @overload
2199 def strings_to_coords(
2200 cls,
2201 text: str | list[str],
2202 when_noncoord: Literal["error"] = "error",
2203 ) -> list[CoordTup]: ...
2204 @overload
2205 def strings_to_coords(
2206 cls,
2207 text: str | list[str],
2208 when_noncoord: Literal["include"] = "include",
2209 ) -> list[str | CoordTup]: ...
2210 @classmethod
2211 def strings_to_coords(
2212 cls,
2213 text: str | list[str],
2214 when_noncoord: WhenMissing = "skip",
2215 ) -> list[str | CoordTup]:
2216 warnings.warn(
2217 "`MazeTokenizerModular.strings_to_coords` only supports legacy UT strings.",
2218 TokenizerPendingDeprecationWarning,
2219 )
2220 return strings_to_coords(text=text, when_noncoord=when_noncoord)
2222 @staticmethod
2223 def encode(text: str | list[str]) -> list[int]:
2224 """encode a string or list of strings into a list of tokens"""
2225 try:
2226 if isinstance(text, str):
2227 text = text.split()
2228 return [VOCAB_TOKEN_TO_INDEX[token] for token in text]
2229 except KeyError as e:
2230 raise TokenError(
2231 f"Token {e} not found in `VOCAB`.",
2232 ) from e
2235 @staticmethod
2236 def decode(
2237 token_ids: Sequence[int], joined_tokens: bool = False
2238 ) -> list[str] | str:
2239 """decode a list of tokens into a string or list of strings"""
2240 try:
2241 output: list[str] = [VOCAB_LIST[token_id] for token_id in token_ids]
2242 except IndexError as e:
2243 raise TokenError(f"Token index '{e}' not found in `VOCAB`.") from e
2244 if joined_tokens:
2245 return " ".join(output)
2246 else:
2247 return output
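# Roundtrip sketch using real `VOCAB` members:
#   >>> ids = MazeTokenizerModular.encode([VOCAB.PATH_START, VOCAB.PATH_END])
#   >>> MazeTokenizerModular.decode(ids, joined_tokens=True) == f"{VOCAB.PATH_START} {VOCAB.PATH_END}"
#   True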
2250_ALL_TOKENIZER_HASHES: Int64[np.ndarray, " n_tokenizers"]
2251"private array of all tokenizer hashes"
2252_TOKENIZER_HASHES_PATH: Path = Path(__file__).parent / "MazeTokenizerModular_hashes.npz"
2253"path to where we expect the hashes file -- in the same dir as this file, by default. change with `set_tokenizer_hashes_path`"
2256def set_tokenizer_hashes_path(path: Path):
2257 """set path to tokenizer hashes, and reload the hashes if needed
2259 the hashes are expected to be stored in and read from `_TOKENIZER_HASHES_PATH`,
2260 which by default is `Path(__file__).parent / "MazeTokenizerModular_hashes.npz"`, i.e. in the same directory as this file.
2262 However, this might not always work, so we provide a way to change this.
2263 """
2264 global _TOKENIZER_HASHES_PATH
2265 global _ALL_TOKENIZER_HASHES
2267 path = Path(path)
2268 if path.is_dir():
2269 path = path / "MazeTokenizerModular_hashes.npz"
2271 if not path.is_file():
2272 raise FileNotFoundError(f"could not find maze tokenizer hashes file at: {path}")
2274 if _TOKENIZER_HASHES_PATH.absolute() != path.absolute():
2275 # reload if they aren't equal
2276 _TOKENIZER_HASHES_PATH = path
2277 _ALL_TOKENIZER_HASHES = _load_tokenizer_hashes()
2278 else:
2279 # always set to new path
2280 _TOKENIZER_HASHES_PATH = path
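# Usage sketch (the path is illustrative; the .npz file must already exist,
# otherwise `FileNotFoundError` is raised):
#   >>> set_tokenizer_hashes_path(Path("data/MazeTokenizerModular_hashes.npz"))  # doctest: +SKIP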
2283def _load_tokenizer_hashes() -> Int64[np.ndarray, " n_tokenizers"]:
2284 """Loads the sorted list of `all_tokenizers.get_all_tokenizers()` hashes from disk."""
2285 global _TOKENIZER_HASHES_PATH
2286 try:
2287 path: Path = _TOKENIZER_HASHES_PATH
2288 return np.load(path)["hashes"]
2289 except FileNotFoundError as e:
2290 raise FileNotFoundError(
2291 "Tokenizer hashes cannot be loaded. To fix this, run"
2292 "\n`python -m maze_dataset.tokenization.save_hashes` which will save the hashes to"
2293 "\n`data/MazeTokenizerModular_hashes.npz`"
2294 "\nrelative to the current working directory -- this is where the code looks for them.",
2295 ) from e
2298def get_all_tokenizer_hashes() -> Int64[np.ndarray, " n_tokenizers"]:
2299 global _ALL_TOKENIZER_HASHES
2300 try:
2301 got_tokenizers: bool = len(_ALL_TOKENIZER_HASHES) > 0
2302 if got_tokenizers:
2303 return _ALL_TOKENIZER_HASHES
2304 else:
2305 _ALL_TOKENIZER_HASHES = _load_tokenizer_hashes()
2306 except NameError:
2307 _ALL_TOKENIZER_HASHES = _load_tokenizer_hashes()
2309 return _ALL_TOKENIZER_HASHES