Coverage for maze_dataset\tokenization\maze_tokenizer.py: 78%

733 statements  

coverage.py v7.6.12, created at 2025-02-23 12:49 -0700

1"""turning a maze into text: `MazeTokenizerModular` and the legacy `TokenizationMode` enum and `MazeTokenizer` class""" 

2 

3import abc 

4import base64 

5import hashlib 

6import random 

7import warnings 

8from enum import Enum 

9from functools import cached_property 

10from pathlib import Path 

11from typing import ( 

12 Any, 

13 Callable, 

14 Iterable, 

15 Literal, 

16 Mapping, 

17 Sequence, 

18 TypedDict, 

19 TypeVar, 

20 overload, 

21) 

22 

23import numpy as np 

24from jaxtyping import Bool, Int, Int64 

25from muutils.json_serialize import ( 

26 SerializableDataclass, 

27 serializable_dataclass, 

28 serializable_field, 

29) 

30from muutils.json_serialize.util import _FORMAT_KEY 

31from muutils.kappa import Kappa 

32from muutils.misc import empty_sequence_if_attr_false, flatten 

33from muutils.misc.sequence import WhenMissing 

34from zanj.loading import load_item_recursive 

35 

36# from maze_dataset import SolvedMaze 

37from maze_dataset.constants import ( 

38 SPECIAL_TOKENS, 

39 VOCAB, 

40 VOCAB_LIST, 

41 VOCAB_TOKEN_TO_INDEX, 

42 ConnectionArray, 

43 ConnectionList, 

44 Coord, 

45 CoordTup, 

46) 

47from maze_dataset.generation import numpy_rng 

48from maze_dataset.maze.lattice_maze import LatticeMaze, SolvedMaze 

49from maze_dataset.token_utils import ( 

50 TokenizerPendingDeprecationWarning, 

51 _coord_to_strings_indexed, 

52 _coord_to_strings_UT, 

53 connection_list_to_adj_list, 

54 coords_to_strings, 

55 get_cardinal_direction, 

56 get_relative_direction, 

57 is_connection, 

58 strings_to_coords, 

59 tokens_between, 

60) 

61from maze_dataset.utils import corner_first_ndindex, lattice_connection_array 

62 

63 

64class TokenError(ValueError): 

65 """error for tokenization""" 

66 

67 pass 

68 

69 

70class TokenizationMode(Enum): 

71 """legacy tokenization modes 

72 

73 > [!CAUTION] 

74 > Legacy mode of tokenization. It will still be around in future releases, but is no longer recommended for use.

75 > Use `MazeTokenizerModular` instead. 

76 

77 # Abbreviations: 

78 - `AOTP`: Adjacency list, Origin, Target, Path

79 - `UT`: Unique Token (for each coordinate)

80 - `CTT`: Coordinate Tuple Tokens (each coordinate is tokenized as a tuple of integers) 

81 

82 # Modes: 

83 - `AOTP_UT_rasterized`: the "classic" mode: assigning tokens to each coordinate is done via rasterization 

84 example: for a 3x3 maze, token order is `(0,0), (0,1), (0,2), (1,0), (1,1), (1,2), (2,0), (2,1), (2,2)` 

85 - `AOTP_UT_uniform`: new mode, where the 3x3 and 5x5 tokenization schemes are compatible

86 uses `corner_first_ndindex` function to order the tokens 

87 - `AOTP_CTT_indexed`: each coordinate is a tuple of integers 

88 """ 

89 

90 AOTP_UT_rasterized = "AOTP_UT_rasterized" 

91 AOTP_UT_uniform = "AOTP_UT_uniform" 

92 AOTP_CTT_indexed = "AOTP_CTT_indexed" 

93 

94 def to_legacy_tokenizer(self, max_grid_size: int | None = None) -> "MazeTokenizer":

95 return MazeTokenizer(tokenization_mode=self, max_grid_size=max_grid_size) 

96 
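# Editor's note: a minimal usage sketch (not part of the original source), showing how a
# legacy tokenizer is obtained from a mode; the grid size of 5 is an arbitrary example value.
#
#   >>> tok = TokenizationMode.AOTP_UT_uniform.to_legacy_tokenizer(max_grid_size=5)
#   >>> tok.max_grid_size
#   5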

97 

98_NDINDEX_FUNC_MAP: dict[ 

99 TokenizationMode, Callable[[int], Iterable[tuple[int, ...]]] 

100] = { 

101 TokenizationMode.AOTP_UT_rasterized: lambda n: list(np.ndindex(n, n)), 

102 TokenizationMode.AOTP_UT_uniform: lambda n: corner_first_ndindex(n, 2), 

103} 

104 
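# Editor's note (not part of the original source): per the `TokenizationMode` docstring above,
# `corner_first_ndindex` orders coordinates so that the token ordering for a small grid is a
# prefix of the ordering for a larger grid, which is what makes `AOTP_UT_uniform` vocabularies
# compatible across grid sizes; `AOTP_UT_rasterized` uses plain row-major `np.ndindex` order.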

105 

106def is_UT(tokenization_mode: TokenizationMode) -> bool: 

107 return tokenization_mode in ( 

108 TokenizationMode.AOTP_UT_rasterized, 

109 TokenizationMode.AOTP_UT_uniform, 

110 ) 

111 

112 

113def get_tokens_up_to_path_start( 

114 tokens: list[str], 

115 include_start_coord: bool = True, 

116 tokenization_mode: TokenizationMode = TokenizationMode.AOTP_UT_uniform, 

117) -> list[str]: 

118 warnings.warn( 

119 "`maze_tokenizer.get_tokens_up_to_path_start` will be deprecated for a `MazeTokenizerModular`-compatible function in a future release.", 

120 TokenizerPendingDeprecationWarning, 

121 ) 

122 path_start_idx: int = tokens.index(SPECIAL_TOKENS.PATH_START) + 1 

123 if include_start_coord: 

124 if is_UT(tokenization_mode): 

125 return tokens[: path_start_idx + 1] 

126 elif tokenization_mode == TokenizationMode.AOTP_CTT_indexed: 

127 return tokens[: path_start_idx + 5] 

128 else: 

129 raise ValueError(f"Invalid tokenization mode: {tokenization_mode}") 

130 else: 

131 return tokens[:path_start_idx] 

132 
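# Editor's note (not part of the original source): the `+ 5` above reflects that in
# `AOTP_CTT_indexed` mode a single coordinate is spelled with 5 tokens, e.g. "(", "1", ",", "2", ")",
# whereas UT modes spell a coordinate with a single token, hence the `+ 1`.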

133 

134_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE: list[str] = [ 

135 "name", 

136 "max_grid_size", 

137 "token_arr", 

138 "tokenizer_map", 

139 "vocab_size", 

140 "padding_token_index", 

141] 

142 

143 

144@serializable_dataclass( 

145 properties_to_serialize=_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE, kw_only=True 

146) 

147class MazeTokenizer(SerializableDataclass): 

148 """LEGACY Tokenizer for mazes 

149 

150 > [!CAUTION] 

151 > `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended 

152 > for use, but will remain for compatibility with existing code. 

153 

154 # Parameters: 

155 - `tokenization_mode: TokenizationMode` 

156 mode of tokenization. required. 

157 - `max_grid_size: int | None` 

158 maximum grid size. required for actually turning text tokens to numerical tokens, but not for moving between coordinates/mazes and text 

159 

160 # Properties 

161 - `name: str` 

162 auto-generated name of the tokenizer from mode and size 

163 

164 ## Conditional Properties 

165 

166 - `node_strings_map: Mapping[CoordTup, str]` 

167 map from node to string. This returns a `muutils.kappa.Kappa` object which you can use like a dictionary. Returns `None` in `UT` modes (see the `node_strings_map` property)

168 

169 These all return `None` if `max_grid_size` is `None`.

170 Prepend `_` to the name to get a guaranteed type that raises an exception if `max_grid_size` is `None`

171 

172 - `token_arr: list[str]` 

173 list of tokens, in order of their indices in the vocabulary 

174 - `tokenizer_map: Mapping[str, int]` 

175 map from token to index 

176 - `vocab_size: int` 

177 size of the vocabulary 

178 - `padding_token_index: int` 

179 index of the padding token 

180 

181 # Methods 

182 - `coords_to_strings(coords: list[CoordTup]) -> list[str]` 

183 convert a list of coordinates to a list of tokens. Optionally error on, skip, or include non-coordinates

184 - `strings_to_coords(strings: list[str]) -> list[CoordTup]` 

185 convert a list of tokens to a list of coordinates. Optionally error on, skip, or include non-coordinates

186 

187 """ 

188 

189 # parameters 

190 # ============================================================ 

191 

192 tokenization_mode: TokenizationMode = serializable_field( 

193 default=TokenizationMode.AOTP_UT_uniform, 

194 serialization_fn=lambda x: x.value, 

195 loading_fn=lambda x: TokenizationMode[x["tokenization_mode"]], 

196 ) 

197 

198 max_grid_size: int | None = serializable_field(default=None) 

199 

200 # properties 

201 # ============================================================ 

202 

203 @property 

204 def name(self) -> str: 

205 max_grid_size_str: str = ( 

206 f"-g{self.max_grid_size}" if self.max_grid_size is not None else "" 

207 ) 

208 return f"maze_tokenizer-{self.tokenization_mode.value}{max_grid_size_str}" 

209 

210 @cached_property 

211 def _node_strings_map(self) -> Mapping[CoordTup, list[str]]: 

212 """map a coordinate to a token""" 

213 if self.tokenization_mode in ( 

214 TokenizationMode.AOTP_UT_rasterized, 

215 TokenizationMode.AOTP_UT_uniform, 

216 ): 

217 return Kappa(_coord_to_strings_UT) 

218 elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed: 

219 return Kappa(_coord_to_strings_indexed) 

220 else: 

221 raise ValueError( 

222 f"Invalid tokenization mode {self.tokenization_mode}", 

223 f"expected one of {TokenizationMode.__members__}", 

224 ) 

225 

226 @cached_property 

227 def node_strings_map(self) -> Mapping[CoordTup, list[str]] | None: 

228 """map a coordinate to a token""" 

229 if self.tokenization_mode in ( 

230 TokenizationMode.AOTP_UT_rasterized, 

231 TokenizationMode.AOTP_UT_uniform, 

232 ): 

233 return None 

234 else: 

235 return self._node_strings_map 

236 

237 # conditional properties (on max_grid_size existing) 

238 # ------------------------------------------------------------ 

239 

240 @cached_property 

241 def _token_arr(self) -> list[str]: 

242 """map from index to token""" 

243 if self.max_grid_size is None: 

244 raise ValueError( 

245 f"max_grid_size must be specified to use token_arr property: {self.max_grid_size = }" 

246 ) 

247 

248 output: list[str] = list(SPECIAL_TOKENS.values()) 

249 

250 if self.tokenization_mode in ( 

251 TokenizationMode.AOTP_UT_rasterized, 

252 TokenizationMode.AOTP_UT_uniform, 

253 ): 

254 output.extend( 

255 [ 

256 self._node_strings_map[coord][0] 

257 for coord in _NDINDEX_FUNC_MAP[self.tokenization_mode]( 

258 self.max_grid_size 

259 ) 

260 ] 

261 ) 

262 elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed: 

263 # TODO: this is hacky, but we don't want to modify the original SPECIAL_TOKENS since that will break old models 

264 output.extend( 

265 [ 

266 "(", 

267 ",", 

268 ")", # new special chars 

269 *map(str, range(self.max_grid_size)), # numbers 

270 ] 

271 ) 

272 else: 

273 raise ValueError( 

274 f"Invalid tokenization mode {self.tokenization_mode}", 

275 f"expected one of {TokenizationMode.__members__}", 

276 ) 

277 

278 return output 

279 

280 @cached_property 

281 def token_arr(self) -> list[str] | None: 

282 if self.max_grid_size is None: 

283 return None 

284 return self._token_arr 

285 

286 @cached_property 

287 def _tokenizer_map(self) -> dict[str, int]: 

288 """map from token to index""" 

289 return {token: i for i, token in enumerate(self._token_arr)} 

290 

291 @cached_property 

292 def tokenizer_map(self) -> dict[str, int] | None: 

293 if self.max_grid_size is None: 

294 return None 

295 return self._tokenizer_map 

296 

297 @property 

298 def _vocab_size(self) -> int: 

299 return len(self._token_arr) 

300 

301 @property 

302 def vocab_size(self) -> int | None: 

303 if self.max_grid_size is None: 

304 return None 

305 return self._vocab_size 

306 

307 @property 

308 def _n_tokens(self) -> int: 

309 # TODO: deprecate 

310 return self._vocab_size 

311 

312 @property 

313 def n_tokens(self) -> int | None: 

314 if self.max_grid_size is None: 

315 return None 

316 return self._n_tokens 

317 

318 @cached_property 

319 def _padding_token_index(self) -> int: 

320 return self._tokenizer_map[SPECIAL_TOKENS.PADDING]

321 

322 @cached_property 

323 def padding_token_index(self) -> int | None: 

324 if self.max_grid_size is None: 

325 return None 

326 return self._padding_token_index 

327 

328 # conversion functions 

329 # ============================================================ 

330 

331 @overload 

332 def coords_to_strings( 

333 self, 

334 coords: list[str | CoordTup], 

335 when_noncoord: Literal["include", "skip"] = "skip", 

336 ) -> list[str]: ... 

337 @overload 

338 def coords_to_strings( 

339 self, 

340 coords: list[CoordTup], 

341 when_noncoord: Literal["error"] = "error", 

342 ) -> list[str]: ... 

343 def coords_to_strings( 

344 self, 

345 coords: list[CoordTup], 

346 when_noncoord: WhenMissing = "skip", 

347 ) -> list[str]: 

348 if self.tokenization_mode in ( 

349 TokenizationMode.AOTP_UT_rasterized, 

350 TokenizationMode.AOTP_UT_uniform, 

351 ): 

352 return coords_to_strings( 

353 coords=coords, 

354 coord_to_strings_func=_coord_to_strings_UT, 

355 when_noncoord=when_noncoord, 

356 ) 

357 elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed: 

358 return coords_to_strings( 

359 coords=coords, 

360 coord_to_strings_func=_coord_to_strings_indexed, 

361 when_noncoord=when_noncoord, 

362 ) 

363 else: 

364 raise ValueError( 

365 f"Invalid tokenization mode {self.tokenization_mode}", 

366 f"expected one of {TokenizationMode.__members__}", 

367 ) 

368 

369 @overload 

370 def strings_to_coords( 

371 cls, 

372 text: str | list[str], 

373 when_noncoord: Literal["skip"] = "skip", 

374 ) -> list[CoordTup]: ... 

375 @overload 

376 def strings_to_coords( 

377 cls, 

378 text: str | list[str], 

379 when_noncoord: Literal["error"] = "error", 

380 ) -> list[CoordTup]: ... 

381 @overload 

382 def strings_to_coords( 

383 cls, 

384 text: str | list[str], 

385 when_noncoord: Literal["include"] = "include", 

386 ) -> list[str | CoordTup]: ... 

387 @classmethod 

388 def strings_to_coords( 

389 cls, 

390 text: str | list[str], 

391 when_noncoord: WhenMissing = "skip", 

392 ) -> list[str | CoordTup]: 

393 return strings_to_coords(text=text, when_noncoord=when_noncoord) 

394 

395 def encode(self, text: str | list[str]) -> list[int]: 

396 """encode a string or list of strings into a list of tokens""" 

397 try: 

398 if isinstance(text, str): 

399 text = text.split() 

400 return [self.tokenizer_map[token] for token in text] 

401 except KeyError as e: 

402 raise TokenError( 

403 f"Token {e} not found", 

404 f"in vocabulary of {self}:", 

405 f"{self.token_arr}", 

406 ) from e 

407 

408 def decode( 

409 self, tokens: Sequence[int], joined_tokens: bool = False 

410 ) -> list[str] | str: 

411 """decode a list of tokens into a string or list of strings""" 

412 try: 

413 output: list[str] = [self.token_arr[token] for token in tokens] 

414 except IndexError as e: 

415 raise TokenError( 

416 f"Token index '{e}' not found in vocabulary of length {self.vocab_size}" 

417 ) from e 

418 if joined_tokens: 

419 return " ".join(output) 

420 else: 

421 return output 

422 
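# Editor's note: an illustrative sketch (not part of the original source) of the
# encode/decode round trip above; the exact token strings depend on `SPECIAL_TOKENS` and the
# tokenization mode, so only the round-trip property is shown.
#
#   >>> tok = MazeTokenizer(
#   ...     tokenization_mode=TokenizationMode.AOTP_UT_uniform, max_grid_size=3
#   ... )
#   >>> ids = tok.encode(tok.token_arr[:4])   # indices of the first few vocabulary tokens
#   >>> tok.decode(ids) == tok.token_arr[:4]
#   True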

423 # UT-only coordinate stuff 

424 # ============================================================ 

425 

426 @cached_property 

427 def coordinate_tokens_coords(self) -> dict[CoordTup, int]: 

428 print(f"{self.tokenization_mode = }") 

429 if not self.is_UT(): 

430 raise ValueError( 

431 f"coordinate_tokens_coords is only valid for UT tokenization modes, got {self.tokenization_mode = }" 

432 ) 

433 if self.max_grid_size is None: 

434 raise ValueError( 

435 f"max_grid_size must be specified to use coordinate_tokens: {self.max_grid_size = }" 

436 ) 

437 

438 raw_converted: list[CoordTup | str] = self.strings_to_coords( 

439 self.token_arr, when_noncoord="include" 

440 ) 

441 

442 # filter out non-coordinates 

443 return { 

444 coord: i 

445 for i, coord in enumerate(raw_converted) 

446 if not isinstance(coord, str) 

447 } 

448 

449 @cached_property 

450 def coordinate_tokens_ids(self) -> dict[str, int]: 

451 # checks performed in call 

452 output: dict[str, int] = dict() 

453 

454 for coord, index in self.coordinate_tokens_coords.items(): 

455 _for_key: list[str] = self.coords_to_strings([coord]) 

456 assert len(_for_key) == 1 

457 output[_for_key[0]] = index 

458 

459 return output 

460 

461 # other 

462 # ============================================================ 

463 

464 def summary(self) -> dict: 

465 """returns a summary of the tokenization mode""" 

466 return { 

467 "tokenization_mode": self.tokenization_mode.value, 

468 "max_grid_size": self.max_grid_size, 

469 "vocab_size": self.vocab_size, 

470 } 

471 

472 def is_AOTP(self) -> bool: 

473 """returns true if a tokenization mode is Adjacency list, Origin, Target, Path""" 

474 return self.tokenization_mode in ( 

475 TokenizationMode.AOTP_UT_rasterized, 

476 TokenizationMode.AOTP_UT_uniform, 

477 TokenizationMode.AOTP_CTT_indexed, 

478 ) 

479 

480 def is_UT(self) -> bool: 

481 return is_UT(self.tokenization_mode) 

482 

483 def clear_cache(self): 

484 """clears all cached properties""" 

485 # delete the properties only if they exist 

486 for name, prop in self.__class__.__dict__.items(): 

487 if isinstance(prop, cached_property): 

488 # if the property exists, delete it 

489 try: 

490 delattr(self, name) 

491 except AttributeError: 

492 pass 

493 

494 

495@serializable_dataclass(frozen=True, kw_only=True) 

496class _TokenizerElement(SerializableDataclass, abc.ABC): 

497 """Superclass for tokenizer elements. 

498 Subclasses contain modular functionality for maze tokenization. 

499 

500 # Development 

501 > [!TIP] 

502 > Due to the functionality of `get_all_tokenizers()`, `_TokenizerElement` subclasses 

503 > may only contain fields of type `utils.FiniteValued`. 

504 > Implementing a subclass with an `int` or `float`-typed field, for example, is not supported. 

505 > In the event that adding such fields is deemed necessary, `get_all_tokenizers()` must be updated. 

506 

507 """ 

508 

509 @staticmethod 

510 def _stringify(k: str, v: Any): 

511 if isinstance(v, bool): 

512 return f"{k}={str(v)[0]}" 

513 if isinstance(v, _TokenizerElement): 

514 return v.name 

515 if isinstance(v, tuple): 

516 return f"{k}={''.join(['(', *[str(x) + ', ' for x in v], ')'])}" 

517 else: 

518 return f"{k}={v}" 

519 

520 @property 

521 def name(self) -> str: 

522 members_str: str = ", ".join( 

523 [self._stringify(k, v) for k, v in self.__dict__.items() if k != "_type_"] 

524 ) 

525 output: str = f"{type(self).__name__}({members_str})" 

526 if "." in output and output.index("(") > output.index("."): 

527 return "".join(output.split(".")[1:]) 

528 else: 

529 return output 

530 

531 def __str__(self): 

532 return self.name 

533 

534 def __init_subclass__(cls, **kwargs): 

535 """ 

536 Hack: dataclass hashes don't include the class itself in the hash function inputs. 

537 This causes dataclasses with identical fields but different types to hash identically. 

538 This hack circumvents this by adding a slightly hidden field to every subclass with a value of `repr(cls)`. 

539 To maintain compatibility with `all_instances`, the static type of the new field can only have 1 possible value. 

540 So we type it as a singleton `Literal` type. 

541 muutils 0.6.1 doesn't support `Literal` type validation, so `assert_type=False`. 

542 Ignore Pylance complaining about the arg to `Literal` being an expression. 

543 """ 

544 super().__init_subclass__(**kwargs) 

545 cls._type_ = serializable_field( 

546 init=True, repr=False, default=repr(cls), assert_type=False 

547 ) 

548 cls.__annotations__["_type_"] = Literal[repr(cls)] # type: ignore 

549 

550 def __hash__(self): 

551 "Stable hash to identify unique `MazeTokenizerModular` instances. uses name" 

552 return int.from_bytes( 

553 hashlib.blake2b(self.name.encode("utf-8")).digest(), 

554 byteorder="big", 

555 ) 

556 
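# Editor's note (not part of the original source): the hash above is simply the blake2b
# digest of `self.name` read as a big-endian integer, so unlike Python's salted `hash(str)`
# it is stable across interpreter runs, e.g.:
#
#   >>> import hashlib
#   >>> int.from_bytes(hashlib.blake2b(b"UT()").digest(), byteorder="big")  # a large, stable int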

557 @classmethod 

558 def _level_one_subclass(cls) -> type["_TokenizerElement"]: 

559 """Returns the immediate subclass of `_TokenizerElement` of which `cls` is an instance.""" 

560 return ( 

561 set(cls.__mro__).intersection(set(_TokenizerElement.__subclasses__())).pop() 

562 ) 

563 

564 def tokenizer_elements(self, deep: bool = True) -> list["_TokenizerElement"]: 

565 """ 

566 Returns a list of all `_TokenizerElement` instances contained in the subtree. 

567 Currently only detects `_TokenizerElement` instances which are either direct attributes of another instance or 

568 which sit inside a `tuple` without further nesting. 

569 

570 # Parameters 

571 - `deep: bool`: Whether to return elements nested arbitrarily deeply or just a single layer. 

572 """ 

573 if not any(type(el) == tuple for el in self.__dict__.values()): # noqa: E721 

574 return list( 

575 flatten( 

576 [ 

577 [el] + el.tokenizer_elements() 

578 for el in self.__dict__.values() 

579 if isinstance(el, _TokenizerElement) 

580 ] 

581 ) 

582 if deep 

583 else filter( 

584 lambda x: isinstance(x, _TokenizerElement), self.__dict__.values() 

585 ) 

586 ) 

587 else: 

588 non_tuple_elems: list[_TokenizerElement] = list( 

589 flatten( 

590 [ 

591 [el] + el.tokenizer_elements() 

592 for el in self.__dict__.values() 

593 if isinstance(el, _TokenizerElement) 

594 ] 

595 if deep 

596 else filter( 

597 lambda x: isinstance(x, _TokenizerElement), 

598 self.__dict__.values(), 

599 ) 

600 ) 

601 ) 

602 tuple_elems: list[_TokenizerElement] = list( 

603 flatten( 

604 [ 

605 ( 

606 [ 

607 [tup_el] + tup_el.tokenizer_elements() 

608 for tup_el in el 

609 if isinstance(tup_el, _TokenizerElement) 

610 ] 

611 if deep 

612 else filter(lambda x: isinstance(x, _TokenizerElement), el) 

613 ) 

614 for el in self.__dict__.values() 

615 if isinstance(el, tuple) 

616 ] 

617 ) 

618 ) 

619 non_tuple_elems.extend(tuple_elems) 

620 return non_tuple_elems 

621 

622 def tokenizer_element_tree(self, depth: int = 0, abstract: bool = False) -> str: 

623 """ 

624 Returns a string representation of the tree of tokenizer elements contained in `self`. 

625 

626 # Parameters 

627 - `depth: int`: Current depth in the tree. Used internally for recursion, no need to specify. 

628 - `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance. 

629 """ 

630 name: str = "\t" * depth + ( 

631 type(self).__name__ 

632 if not abstract 

633 else type(self)._level_one_subclass().__name__ 

634 ) 

635 return ( 

636 name 

637 + "\n" 

638 + "".join( 

639 el.tokenizer_element_tree(depth + 1, abstract) 

640 for el in self.tokenizer_elements(deep=False) 

641 ) 

642 ) 

643 

644 def tokenizer_element_dict(self) -> dict: 

645 """ 

646 Returns a dictionary representation of the tree of tokenizer elements contained in `self`. 

647 """ 

648 return { 

649 type(self).__name__: { 

650 key: ( 

651 val.tokenizer_element_dict() 

652 if isinstance(val, _TokenizerElement) 

653 else ( 

654 val 

655 if not isinstance(val, tuple) 

656 else [ 

657 ( 

658 el.tokenizer_element_dict() 

659 if isinstance(el, _TokenizerElement) 

660 else el 

661 ) 

662 for el in val 

663 ] 

664 ) 

665 ) 

666 for key, val in self.__dict__.items() 

667 if key != "_type_" 

668 } 

669 } 

670 

671 @classmethod 

672 @abc.abstractmethod 

673 def attribute_key(cls) -> str: 

674 """Returns the binding used in `MazeTokenizerModular` for that type of `_TokenizerElement`.""" 

675 raise NotImplementedError 

676 

677 def to_tokens(self, *args, **kwargs) -> list[str]: 

678 """Converts a maze element into a list of tokens. 

679 Not all `_TokenizerElement` subclasses produce tokens, so this is not an abstract method. 

680 Those subclasses which do produce tokens should override this method. 

681 """ 

682 raise NotImplementedError 

683 

684 @abc.abstractmethod 

685 def is_valid(self) -> bool: 

686 """Returns if `self` contains data members capable of producing an overall valid `MazeTokenizerModular`. 

687 Some `_TokenizerElement` instances may be created which are not useful despite obeying data member type hints. 

688 `is_valid` allows for more precise detection of invalid `_TokenizerElement`s beyond type hinting alone. 

689 If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply `return True` for that subclass. 

690 

691 # Types of Invalidity 

692 In nontrivial implementations of this method, each conditional clause should contain a comment classifying the reason for invalidity and one of the types below. 

693 Invalidity types, in ascending order of invalidity: 

694 - Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study. 

695 E.g., `_TokenizerElement`s which are strictly worse than some alternative. 

696 - Duplicate: These tokenizers have identical tokenization behavior as some other valid tokenizers. 

697 - Untrainable: Training functional models using these tokenizers would be (nearly) impossible. 

698 - Erroneous: These tokenizers might raise exceptions during use. 

699 

700 # Development 

701 `is_valid` is implemented to always return `True` in some abstract classes where all currently possible subclass instances are valid.

702 When adding new subclasses or data members, the developer should check if any such blanket statement of validity still holds and update it as necessary.

703 

704 ## Nesting 

705 In general, when implementing this method, there is no need to recursively call `is_valid` on nested `_TokenizerElement`s contained in the class. 

706 In other words, failures of `is_valid` need not bubble up to the top of the nested `_TokenizerElement` tree. 

707 `MazeTokenizerModular.is_valid` calls `is_valid` on each of its `_TokenizerElement`s individually, so failure at any level will be detected. 

708 

709 ## Types of Invalidity 

710 If it's judged to be useful, the types of invalidity could be implemented with an Enum or similar rather than only living in comments. 

711 This could be used to create more or less stringent filters on the valid `_TokenizerElement` instances. 

712 """ 

713 raise NotImplementedError 

714 

715 

716T = TypeVar("T", bound=_TokenizerElement) 

717 

718 

719def mark_as_unsupported(is_valid: Callable[[T], bool], *args) -> T: 

720 """mark a _TokenizerElement as unsupported. 

721 

722 Classes marked with this decorator won't show up in `get_all_tokenizers()` and thus won't be tested.

723 The classes marked in release 1.0.0 did work reliably before being marked, but they can't be instantiated since the decorator adds an abstract method. 

724 The decorator exists to prune the space of tokenizers returned by `all_instances` both for testing and usage. 

725 Previously, the space was too large, resulting in impractical runtimes. 

726 These decorators could be removed in future releases to expand the space of possible tokenizers. 

727 """ 

728 

729 def wrapper(cls): 

730 cls.is_valid = is_valid 

731 return cls 

732 

733 return wrapper 

734 
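# Editor's note: an illustrative sketch (not part of the original source) of how the
# decorator above is applied later in this module; the class name here is hypothetical.
# Passing `lambda self_: False` makes `is_valid` always return `False`, so the class is
# pruned from `get_all_tokenizers()`.
#
#   @serializable_dataclass(frozen=True, kw_only=True)
#   @mark_as_unsupported(lambda self_: False)
#   class _SomeUnsupportedElement(_TokenizerElement):
#       ...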

735 

736class __TokenizerElementNamespace(abc.ABC): 

737 """ABC for namespaces 

738 

739 # Properties 

740 - key: The binding used in `MazeTokenizerModular` for instances of the classes contained within that `__TokenizerElementNamespace`. 

741 """ 

742 

743 key: str = NotImplementedError 

744 

745 

746def _load_tokenizer_element( 

747 data: dict[str, Any], namespace: type[__TokenizerElementNamespace] 

748) -> _TokenizerElement: 

749 """Loads a `TokenizerElement` stored via zanj.""" 

750 key: str = namespace.key 

751 format: str = data[key][_FORMAT_KEY] 

752 cls_name: str = format.split("(")[0] 

753 cls: type[_TokenizerElement] = getattr(namespace, cls_name) 

754 kwargs: dict[str, Any] = { 

755 k: load_item_recursive(data[key][k], tuple()) for k, v in data[key].items() 

756 } 

757 if _FORMAT_KEY in kwargs: 

758 kwargs.pop(_FORMAT_KEY) 

759 return cls(**kwargs) 

760 
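# Editor's note (not part of the original source): the `_FORMAT_KEY` entry is expected to
# look like "ClassName(...)", so `format.split("(")[0]` recovers the concrete
# `_TokenizerElement` subclass name, which is then looked up as an attribute of the
# namespace class (e.g. `CoordTokenizers`) and instantiated from the remaining keys.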

761 

762class CoordTokenizers(__TokenizerElementNamespace): 

763 """Namespace for `_CoordTokenizer` subclass hierarchy used by `MazeTokenizerModular`.""" 

764 

765 key = "coord_tokenizer" 

766 

767 @serializable_dataclass(frozen=True, kw_only=True) 

768 class _CoordTokenizer(_TokenizerElement, abc.ABC): 

769 """ 

770 Superclass for classes which tokenize singular coords in a maze. 

771 """ 

772 

773 @abc.abstractmethod 

774 def to_tokens(self, coord: Coord | CoordTup) -> list[str]: 

775 pass 

776 

777 @classmethod 

778 def attribute_key(cls) -> str: 

779 return CoordTokenizers.key 

780 

781 def is_valid(self) -> bool: 

782 # No invalid instances possible within data member type hint bounds 

783 return True 

784 

785 @serializable_dataclass(frozen=True, kw_only=True) 

786 class UT(_CoordTokenizer): 

787 """Unique token coordinate tokenizer.""" 

788 

789 def to_tokens(self, coord: Coord | CoordTup) -> list[str]: 

790 return ["".join(["(", str(coord[0]), ",", str(coord[1]), ")"])] 

791 
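# Editor's note: illustrative output (not part of the original source), read directly
# off the `to_tokens` implementation above:
#
#   >>> CoordTokenizers.UT().to_tokens((1, 2))
#   ['(1,2)']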

792 @serializable_dataclass(frozen=True, kw_only=True) 

793 class CTT(_CoordTokenizer): 

794 """Coordinate tuple tokenizer 

795 

796 # Parameters 

797 - `pre`: Whether all coords include an integral preceding delimiter token 

798 - `intra`: Whether all coords include a delimiter token between coordinates 

799 - `post`: Whether all coords include an integral following delimiter token 

800 """ 

801 

802 pre: bool = serializable_field(default=True) 

803 intra: bool = serializable_field(default=True) 

804 post: bool = serializable_field(default=True) 

805 # Implement methods 

806 

807 def to_tokens(self, coord: Coord | CoordTup) -> list[str]: 

808 return [ 

809 *empty_sequence_if_attr_false([VOCAB.COORD_PRE], self, "pre"), 

810 str(coord[0]), 

811 *empty_sequence_if_attr_false([VOCAB.COORD_INTRA], self, "intra"), 

812 str(coord[1]), 

813 *empty_sequence_if_attr_false([VOCAB.COORD_POST], self, "post"), 

814 ] 

815 
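# Editor's note: an illustrative sketch (not part of the original source); the concrete
# delimiter strings live in `VOCAB` (defined elsewhere), so they are shown symbolically:
#
#   >>> CoordTokenizers.CTT().to_tokens((1, 2))            # pre=intra=post=True (defaults)
#   [VOCAB.COORD_PRE, '1', VOCAB.COORD_INTRA, '2', VOCAB.COORD_POST]
#   >>> CoordTokenizers.CTT(pre=False, post=False).to_tokens((1, 2))
#   ['1', VOCAB.COORD_INTRA, '2']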

816 

817class EdgeGroupings(__TokenizerElementNamespace): 

818 """Namespace for `_EdgeGrouping` subclass hierarchy used by `_AdjListTokenizer`.""" 

819 

820 key = "edge_grouping" 

821 

822 class _GroupingTokenParams(TypedDict): 

823 """A uniform private hyperparameter interface used by `AdjListTokenizer`.""" 

824 

825 connection_token_ordinal: Literal[0, 1, 2] 

826 intra: bool 

827 grouped: bool 

828 

829 @serializable_dataclass(frozen=True, kw_only=True) 

830 class _EdgeGrouping(_TokenizerElement, abc.ABC): 

831 """Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called a edge grouping.""" 

832 

833 @classmethod 

834 def attribute_key(cls) -> str: 

835 return EdgeGroupings.key 

836 

837 def is_valid(self) -> bool: 

838 return True 

839 

840 @abc.abstractmethod 

841 def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]: 

842 """Divides a ConnectionArray into groups of edges. 

843 Shuffles/sequences within each group if applicable. 

844 """ 

845 pass 

846 

847 @abc.abstractmethod 

848 def _token_params(self) -> "EdgeGroupings._GroupingTokenParams": 

849 """Returns the tok.nization hyperparameters necessary for an `AdjListTokenizer` to tokenize. 

850 

851 These hyperparameters are not used by `_EdgeGrouping` internally. 

852 They are located in `_EdgeGrouping` rather than in `AdjListTokenizer` 

853 since the hyperparameter space is a function of the `_EdgeGrouping` subclass. 

854 This function resolves the `_EdgeGrouping` hyperparameter space which is non-uniform across subclasses 

855 into a uniform private interface used by `AdjListTokenizer`. 

856 """ 

857 pass 

858 

859 @serializable_dataclass(frozen=True, kw_only=True) 

860 class Ungrouped(_EdgeGrouping): 

861 """No grouping occurs, each edge is tokenized individually. 

862 

863 # Parameters 

864 - `connection_token_ordinal`: At which index in the edge tokenization the connector (or wall) token appears. 

865 Edge tokenizations contain 3 parts: a leading coord, a connector (or wall) token, and either a second coord or cardinal direction tokenization. 

866 """ 

867 

868 connection_token_ordinal: Literal[0, 1, 2] = serializable_field( 

869 default=1, assert_type=False 

870 ) 

871 

872 def _token_params(self) -> "EdgeGroupings._GroupingTokenParams": 

873 return EdgeGroupings._GroupingTokenParams( 

874 connection_token_ordinal=self.connection_token_ordinal, 

875 intra=False, 

876 grouped=False, 

877 ) 

878 

879 def _group_edges(self, edges: ConnectionList) -> Sequence[ConnectionList]: 

880 return np.expand_dims(edges, 1) 

881 

882 @serializable_dataclass(frozen=True, kw_only=True) 

883 @mark_as_unsupported(lambda self_: False) 

884 class ByLeadingCoord(_EdgeGrouping): 

885 """All edges with the same leading coord are grouped together. 

886 

887 # Parameters 

888 - `intra`: Whether all edge groupings include a delimiter token between individual edge representations. 

889 Note that each edge representation will already always include a connector token (`VOCAB.CONNECTOR`, or possibly `VOCAB.ADJLIST_WALL`)

890 - `shuffle_group`: Whether the sequence of edges within the group should be shuffled or appear in a fixed order. 

891 If false, the fixed order is lexicographical by (row, col). 

892 In effect, lexicographical sorting sorts edges by their cardinal direction in the sequence NORTH, WEST, EAST, SOUTH, where the directions indicate the position of the trailing coord relative to the leading coord. 

893 - `connection_token_ordinal`: At which index in token sequence representing a single edge the connector (or wall) token appears. 

894 Edge tokenizations contain 2 parts: a connector (or wall) token and a coord or cardinal tokenization. 

895 """ 

896 

897 intra: bool = serializable_field(default=True) 

898 shuffle_group: bool = serializable_field(default=True) 

899 connection_token_ordinal: Literal[0, 1] = serializable_field( 

900 default=0, assert_type=False 

901 ) 

902 

903 def _token_params(self) -> "EdgeGroupings._GroupingTokenParams": 

904 return EdgeGroupings._GroupingTokenParams( 

905 connection_token_ordinal=self.connection_token_ordinal, 

906 intra=self.intra, 

907 grouped=True, 

908 ) 

909 

910 def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]: 

911 # Adapted from: https://stackoverflow.com/questions/38013778/is-there-any-numpy-group-by-function 

912 index_array: Int[np.ndarray, "sort_indices=edges"] = np.lexsort( 

913 (edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0]) 

914 ) 

915 sorted_edges: ConnectionArray = edges[index_array, ...] 

916 groups: list[ConnectionArray] = np.split( 

917 sorted_edges, 

918 np.unique(sorted_edges[:, 0, :], return_index=True, axis=0)[1][1:], 

919 ) 

920 if self.shuffle_group: 

921 [numpy_rng.shuffle(g, axis=0) for g in groups] 

922 return groups 

923 

924 

925class EdgePermuters(__TokenizerElementNamespace): 

926 """Namespace for `_EdgePermuter` subclass hierarchy used by `_AdjListTokenizer`.""" 

927 

928 key = "edge_permuter" 

929 

930 @serializable_dataclass(frozen=True, kw_only=True) 

931 class _EdgePermuter(_TokenizerElement, abc.ABC): 

932 """Specifies how to sequence the two coords that encode a lattice edge.""" 

933 

934 @classmethod 

935 def attribute_key(cls) -> str: 

936 return EdgePermuters.key 

937 

938 def is_valid(self) -> bool: 

939 # No invalid instances possible within data member type hint bounds 

940 return True 

941 

942 @staticmethod 

943 @abc.abstractmethod 

944 def _permute(lattice_edges: ConnectionArray) -> ConnectionArray: 

945 """ 

946 Executes a permutation. 

947 Warning: Caller should be aware that `lattice_edges` may be modified in-place depending on the subclass's implementation. 

948 

949 # Parameters 

950 - `lattice_edges`: Array of lattice edges. 

951 The two coords in shape[1] must be adjacent in the lattice. 

952 

953 # Returns 

954 - Array of lattice edges with entries along shape[1] systematically permuted. 

955 - shape[0] of the returned array is NOT guaranteed to match `lattice_edges.shape[0]`.

956 """ 

957 pass 

958 

959 @serializable_dataclass(frozen=True, kw_only=True) 

960 class SortedCoords(_EdgePermuter): 

961 """returns a sorted representation. useful for checking consistency""" 

962 

963 @staticmethod 

964 def _permute(lattice_edges: ConnectionArray) -> ConnectionArray: 

965 return lattice_edges[ 

966 np.lexsort( 

967 ( 

968 lattice_edges[:, 1, 1], 

969 lattice_edges[:, 1, 0], 

970 lattice_edges[:, 0, 1], 

971 lattice_edges[:, 0, 0], 

972 ) 

973 ), 

974 ..., 

975 ] 

976 

977 @serializable_dataclass(frozen=True, kw_only=True) 

978 class RandomCoords(_EdgePermuter): 

979 """Permutes each edge randomly.""" 

980 

981 @staticmethod 

982 def _permute(lattice_edges: ConnectionArray) -> ConnectionArray: 

983 numpy_rng.permuted(lattice_edges, axis=1, out=lattice_edges) 

984 return lattice_edges 

985 

986 @serializable_dataclass(frozen=True, kw_only=True) 

987 class BothCoords(_EdgePermuter): 

988 """Includes both possible permutations of every edge in the output. 

989 Since the input `ConnectionArray` contains only 1 instance of each edge,

990 a call to `BothCoords._permute` returns an array with double the `shape[0]`, containing each edge in both orientations.

991 """ 

992 

993 @staticmethod 

994 def _permute(lattice_edges: ConnectionArray) -> ConnectionArray: 

995 return np.append(lattice_edges, np.flip(lattice_edges, axis=1), axis=0) 

996 
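# Editor's note: a small numpy sketch (not part of the original source) of the effect of
# `BothCoords._permute`: every edge (A, B) also appears as (B, A), doubling shape[0].
#
#   >>> import numpy as np
#   >>> edges = np.array([[[0, 0], [0, 1]]])     # a single edge (0,0) -- (0,1)
#   >>> EdgePermuters.BothCoords._permute(edges).shape
#   (2, 2, 2)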

997 

998class EdgeSubsets(__TokenizerElementNamespace): 

999 """ 

1000 Namespace for `_EdgeSubset` subclass hierarchy used by `_AdjListTokenizer`. 

1001 """ 

1002 

1003 key = "edge_subset" 

1004 

1005 @serializable_dataclass(frozen=True, kw_only=True) 

1006 class _EdgeSubset(_TokenizerElement, abc.ABC): 

1007 """ 

1008 Component of an `AdjListTokenizers._AdjListTokenizer` which specifies the subset of lattice edges to be tokenized. 

1009 """ 

1010 

1011 @classmethod 

1012 def attribute_key(cls) -> str: 

1013 return EdgeSubsets.key 

1014 

1015 def is_valid(self) -> bool: 

1016 return True 

1017 

1018 @abc.abstractmethod 

1019 def _get_edges(self, maze: LatticeMaze) -> ConnectionArray: 

1020 """ 

1021 Returns the set of lattice edges to be tokenized. 

1022 """ 

1023 pass 

1024 

1025 @serializable_dataclass(frozen=True, kw_only=True) 

1026 class AllLatticeEdges(_EdgeSubset): 

1027 """ 

1028 All 2n**2-2n edges of the lattice are tokenized. 

1029 If a wall exists on that edge, the edge is tokenized in the same manner, using `VOCAB.ADJLIST_WALL` in place of `VOCAB.CONNECTOR`. 

1030 """ 

1031 

1032 def _get_edges(self, maze: LatticeMaze) -> ConnectionArray: 

1033 return lattice_connection_array(maze.grid_n) 

1034 

1035 @serializable_dataclass(frozen=True, kw_only=True) 

1036 class ConnectionEdges(_EdgeSubset): 

1037 """ 

1038 Only edges which contain a connection are tokenized. 

1039 Alternatively, only edges which contain a wall are tokenized. 

1040 

1041 # Parameters 

1042 - `walls`: Whether wall edges or connection edges are tokenized. 

1043 If true, `VOCAB.ADJLIST_WALL` is used in place of `VOCAB.CONNECTOR`. 

1044 """ 

1045 

1046 walls: bool = serializable_field(default=False) 

1047 

1048 def _get_edges(self, maze: LatticeMaze) -> ConnectionArray: 

1049 conn_list: ConnectionList = maze.connection_list 

1050 if self.walls: 

1051 conn_list = np.logical_not(conn_list) 

1052 conn_list[0, -1, :] = False 

1053 conn_list[1, :, -1] = False 

1054 return connection_list_to_adj_list( 

1055 conn_list, shuffle_d0=False, shuffle_d1=False 

1056 ) 

1057 
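# Editor's note (not part of the original source): when `walls=True`, the connection list
# is inverted so that walls become the selected "edges"; assuming the usual `ConnectionList`
# convention (channel 0 = down, channel 1 = right), the last row of the down channel and the
# last column of the right channel are then zeroed, since those entries would point outside
# the lattice and do not correspond to real edges.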

1058 

1059class AdjListTokenizers(__TokenizerElementNamespace): 

1060 """Namespace for `_AdjListTokenizer` subclass hierarchy used by `MazeTokenizerModular`.""" 

1061 

1062 key = "adj_list_tokenizer" 

1063 

1064 @serializable_dataclass(frozen=True, kw_only=True) 

1065 @mark_as_unsupported(lambda self_: self_.pre is False) 

1066 class _AdjListTokenizer(_TokenizerElement, abc.ABC): 

1067 """ 

1068 Specifies how the adjacency list is tokenized. 

1069 Tokenization behavior is decomposed into specification of edge subsets, groupings, and permutations. 

1070 See documentation of `EdgeSubset` and `EdgeGrouping` classes for more details. 

1071 

1072 # Parameters 

1073 - `pre`: Whether all edge groupings include a preceding delimiter token 

1074 - `post`: Whether all edge groupings include a following delimiter token 

1075 - `shuffle_d0`: Specifies how to sequence the edge groupings. 

1076 If true, groupings are shuffled randomly. If false, groupings are sorted by the leading coord of each group. 

1077 - `edge_grouping`: Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called an edge grouping. 

1078 - `edge_subset`: Specifies the subset of lattice edges to be tokenized. 

1079 - `edge_permuter`: Specifies, in each edge tokenization, which coord either: 

1080 1. Appears first in the tokenization, for `AdjListCoord`. 

1081 2. Is tokenized directly as a coord, for `AdjListCardinal`. 

1082 - `shuffle`: For each edge, the leading coord is selected randomly. 

1083 - `all`: Each edge appears twice in the tokenization, appearing with both leading coords. 

1084 - `evens`, `odds`: The leading coord is the one belonging to that coord subset. See `EdgeSubsets.ChessboardSublattice` for details. 

1085 """ 

1086 

1087 pre: bool = serializable_field(default=False, assert_type=False) 

1088 post: bool = serializable_field(default=True) 

1089 shuffle_d0: bool = serializable_field(default=True) 

1090 edge_grouping: EdgeGroupings._EdgeGrouping = serializable_field( 

1091 default=EdgeGroupings.Ungrouped(), 

1092 loading_fn=lambda x: _load_tokenizer_element(x, EdgeGroupings), 

1093 ) 

1094 edge_subset: EdgeSubsets._EdgeSubset = serializable_field( 

1095 default=EdgeSubsets.ConnectionEdges(), 

1096 loading_fn=lambda x: _load_tokenizer_element(x, EdgeSubsets), 

1097 ) 

1098 edge_permuter: EdgePermuters._EdgePermuter = serializable_field( 

1099 default=EdgePermuters.RandomCoords(), 

1100 loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters), 

1101 ) 

1102 

1103 @classmethod 

1104 def attribute_key(cls) -> str: 

1105 return AdjListTokenizers.key 

1106 

1107 def is_valid(self) -> bool: 

1108 # No invalid instances possible within data member type hint bounds 

1109 return True 

1110 

1111 @abc.abstractmethod 

1112 def _tokenization_callables( 

1113 self, 

1114 edges: ConnectionArray, 

1115 is_conn: Bool[np.ndarray, " edges"], 

1116 coord_tokenizer: CoordTokenizers._CoordTokenizer, 

1117 *args, 

1118 **kwargs, 

1119 ): 

1120 """ 

1121 Returns a sequence of callables which take an index in `edges` and return parts of that edge tokenization. 

1122 

1123 # Returns 

1124 - `[0]`: leading coord tokens 

1125 - `[1]`: connector tokens 

1126 - `[2]`: trailing coord tokens 

1127 """ 

1128 pass 

1129 

1130 def _tokenize_edge_grouping( 

1131 self, 

1132 edges: ConnectionArray, 

1133 maze: LatticeMaze, 

1134 coord_tokenizer: CoordTokenizers._CoordTokenizer, 

1135 group_params: EdgeGroupings._GroupingTokenParams, 

1136 ) -> Sequence[str]: 

1137 """ 

1138 Tokenizes a single edge grouping. 

1139 """ 

1140 cxn_ord: int = group_params["connection_token_ordinal"] 

1141 is_conn: Bool[np.ndarray, "edges"] = is_connection( 

1142 edges, maze.connection_list 

1143 ) 

1144 tokenize_callables = self._tokenization_callables( 

1145 edges, is_conn, coord_tokenizer 

1146 ) 

1147 

1148 if group_params["grouped"]: 

1149 # If grouped 

1150 callable_permutation: list[int] = [1, 2] if cxn_ord == 0 else [2, 1] 

1151 repeated_callables = [ 

1152 tokenize_callables[i] for i in callable_permutation 

1153 ] 

1154 return flatten( 

1155 [ 

1156 tokenize_callables[0](0), 

1157 [ 

1158 [ 

1159 *[ 

1160 tok_callable(i) 

1161 for tok_callable in repeated_callables 

1162 ], 

1163 *( 

1164 (VOCAB.ADJLIST_INTRA,) 

1165 if group_params["intra"] 

1166 else () 

1167 ), 

1168 ] 

1169 for i in range(edges.shape[0]) 

1170 ], 

1171 ] 

1172 ) 

1173 else: 

1174 # If ungrouped 

1175 callable_permutation = [0, 2] 

1176 callable_permutation.insert(cxn_ord, 1) 

1177 tokenize_callables = [ 

1178 tokenize_callables[i] for i in callable_permutation 

1179 ] 

1180 

1181 return flatten( 

1182 [ 

1183 [ 

1184 [ 

1185 *[ 

1186 tok_callable(i) 

1187 for tok_callable in tokenize_callables 

1188 ], 

1189 *empty_sequence_if_attr_false( 

1190 (VOCAB.ADJLIST_INTRA,), group_params, "intra" 

1191 ), 

1192 ] 

1193 for i in range(edges.shape[0]) 

1194 ] 

1195 ] 

1196 ) 

1197 

1198 def to_tokens( 

1199 self, maze: LatticeMaze, coord_tokenizer: CoordTokenizers._CoordTokenizer 

1200 ) -> list[str]: 

1201 # Get the set of edges to be tokenized 

1202 edges: ConnectionArray = self.edge_subset._get_edges(maze) 

1203 # Systematically permute the leading coord of each edge 

1204 edges: ConnectionArray = self.edge_permuter._permute(edges) 

1205 group_params: EdgeGroupings._GroupingTokenParams = ( 

1206 self.edge_grouping._token_params() 

1207 ) 

1208 # then, we need to group the edges 

1209 groups: Sequence[ConnectionArray] = self.edge_grouping._group_edges(edges) 

1210 # shuffle the groups if specified 

1211 if self.shuffle_d0: 

1212 if isinstance(groups, np.ndarray): 

1213 numpy_rng.shuffle(groups, axis=0) 

1214 elif isinstance(groups, list): 

1215 random.shuffle(groups) 

1216 else: 

1217 raise TypeError( 

1218 f"`groups` is an unexpected type {type(groups)}. Only types `list` and `np.ndarray` are currently supported." 

1219 ) 

1220 # Tokenize each group with optional delimiters 

1221 tokens: list[str] = list( 

1222 flatten( 

1223 [ 

1224 [ 

1225 *empty_sequence_if_attr_false( 

1226 (VOCAB.ADJLIST_PRE,), self, "pre" 

1227 ), 

1228 *self._tokenize_edge_grouping( 

1229 group, maze, coord_tokenizer, group_params 

1230 ), 

1231 *empty_sequence_if_attr_false( 

1232 (VOCAB.ADJACENCY_ENDLINE,), self, "post" 

1233 ), 

1234 ] 

1235 for group in groups 

1236 ] 

1237 ) 

1238 ) 

1239 return tokens 

1240 
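# Editor's note (not part of the original source): in summary, `to_tokens` above runs the
# pipeline select edges (`edge_subset`) -> orient each edge (`edge_permuter`) -> group edges
# (`edge_grouping`) -> optionally shuffle groups (`shuffle_d0`) -> tokenize each group,
# wrapping it in optional `pre`/`post` delimiter tokens.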

1241 @serializable_dataclass(frozen=True, kw_only=True) 

1242 class AdjListCoord(_AdjListTokenizer): 

1243 """Represents an edge group as tokens for the leading coord followed by coord tokens for the other group members.""" 

1244 

1245 edge_permuter: EdgePermuters._EdgePermuter = serializable_field( 

1246 default=EdgePermuters.RandomCoords(), 

1247 loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters), 

1248 ) 

1249 

1250 def _tokenization_callables( 

1251 self, 

1252 edges: ConnectionArray, 

1253 is_conn: Bool[np.ndarray, " edges"], 

1254 coord_tokenizer: CoordTokenizers._CoordTokenizer, 

1255 *args, 

1256 **kwargs, 

1257 ): 

1258 # Map from `is_conn` to the tokens which represent connections and walls 

1259 conn_token_map: dict[bool, str] = { 

1260 True: VOCAB.CONNECTOR, 

1261 False: VOCAB.ADJLIST_WALL, 

1262 } 

1263 return [ 

1264 lambda i: coord_tokenizer.to_tokens(edges[i, 0]), 

1265 lambda i: conn_token_map[is_conn[i]], 

1266 lambda i: coord_tokenizer.to_tokens(edges[i, 1]), 

1267 ] 

1268 

1269 @serializable_dataclass(frozen=True, kw_only=True) 

1270 class AdjListCardinal(_AdjListTokenizer): 

1271 """Represents an edge group as coord tokens for the leading coord and cardinal tokens relative to the leading coord for the other group members. 

1272 

1273 # Parameters 

1274 - `coord_first`: Whether the leading coord token(s) should come before or after the sequence of cardinal tokens. 

1275 """ 

1276 

1277 edge_permuter: EdgePermuters._EdgePermuter = serializable_field( 

1278 default=EdgePermuters.BothCoords(), 

1279 loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters), 

1280 ) 

1281 

1282 def _tokenization_callables( 

1283 self, 

1284 edges: ConnectionArray, 

1285 is_conn: Bool[np.ndarray, " edges"], 

1286 coord_tokenizer: CoordTokenizers._CoordTokenizer, 

1287 *args, 

1288 **kwargs, 

1289 ): 

1290 # Map from `is_conn` to the tokens which represent connections and walls 

1291 conn_token_map: dict[bool, str] = { 

1292 True: VOCAB.CONNECTOR, 

1293 False: VOCAB.ADJLIST_WALL, 

1294 } 

1295 return [ 

1296 lambda i: coord_tokenizer.to_tokens(edges[i, 0]), 

1297 lambda i: conn_token_map[is_conn[i]], 

1298 lambda i: get_cardinal_direction(edges[i]), 

1299 ] 

1300 

1301 

1302class TargetTokenizers(__TokenizerElementNamespace): 

1303 """Namespace for `_TargetTokenizer` subclass hierarchy used by `MazeTokenizerModular`.""" 

1304 

1305 key = "target_tokenizer" 

1306 

1307 @serializable_dataclass(frozen=True, kw_only=True) 

1308 class _TargetTokenizer(_TokenizerElement, abc.ABC): 

1309 """Superclass of tokenizers for maze targets.""" 

1310 

1311 @abc.abstractmethod 

1312 def to_tokens( 

1313 self, 

1314 targets: Sequence[Coord], 

1315 coord_tokenizer: CoordTokenizers._CoordTokenizer, 

1316 ) -> list[str]: 

1317 """Returns tokens representing the target.""" 

1318 pass 

1319 

1320 @classmethod 

1321 def attribute_key(cls) -> str: 

1322 return TargetTokenizers.key 

1323 

1324 @serializable_dataclass(frozen=True, kw_only=True) 

1325 class Unlabeled(_TargetTokenizer): 

1326 """Targets are simply listed as coord tokens. 

1327 - `post`: Whether all coords include an integral following delimiter token 

1328 """ 

1329 

1330 post: bool = serializable_field(default=False) 

1331 

1332 def to_tokens( 

1333 self, 

1334 targets: Sequence[Coord], 

1335 coord_tokenizer: CoordTokenizers._CoordTokenizer, 

1336 ) -> list[str]: 

1337 return list( 

1338 flatten( 

1339 [ 

1340 [ 

1341 *coord_tokenizer.to_tokens(target), 

1342 *empty_sequence_if_attr_false( 

1343 [VOCAB.TARGET_POST], self, "post" 

1344 ), 

1345 ] 

1346 for target in targets 

1347 ] 

1348 ) 

1349 ) 

1350 

1351 def is_valid(self) -> bool: 

1352 # No invalid instances possible within data member type hint bounds 

1353 return True 

1354 

1355 

1356class StepSizes(__TokenizerElementNamespace): 

1357 """Namespace for `_StepSize` subclass hierarchy used by `MazeTokenizerModular`.""" 

1358 

1359 key = "step_size" 

1360 

1361 @serializable_dataclass(frozen=True, kw_only=True) 

1362 class _StepSize(_TokenizerElement, abc.ABC): 

1363 """ 

1364 Specifies which coords in `maze.solution` are used to represent the path. 

1365 """ 

1366 

1367 @classmethod 

1368 def attribute_key(cls) -> str: 

1369 return StepSizes.key 

1370 

1371 @abc.abstractmethod # TODO: make this a static/class method, allowing ForksAndStraightaways to skip object construction at every call 

1372 def _step_single_indices(self, maze: SolvedMaze) -> list[int]: 

1373 """Returns the indices of `maze.solution` corresponding to the steps to be tokenized.""" 

1374 raise NotImplementedError( 

1375 "Subclasses must implement `StepSize.step_indices." 

1376 ) 

1377 

1378 def step_start_end_indices(self, maze) -> list[tuple[int, int]]: 

1379 """Returns steps as tuples of starting and ending positions for each step.""" 

1380 indices: list[int] = self._step_single_indices(maze) 

1381 return [(start, end) for start, end in zip(indices[:-1], indices[1:])] 

1382 
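# Editor's note: an illustrative sketch (not part of the original source) of the pairing
# performed by `step_start_end_indices` above:
#
#   >>> indices = [0, 2, 5]                      # as if returned by _step_single_indices
#   >>> list(zip(indices[:-1], indices[1:]))
#   [(0, 2), (2, 5)]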

1383 def is_valid(self) -> bool: 

1384 # No invalid instances possible within data member type hint bounds 

1385 return True 

1386 

1387 @serializable_dataclass(frozen=True, kw_only=True) 

1388 class Singles(_StepSize): 

1389 """ 

1390 Every coord in `maze.solution` is represented. 

1391 Legacy tokenizers all use this behavior. 

1392 """ 

1393 

1394 def _step_single_indices(self, maze: SolvedMaze) -> list[int]: 

1395 """Returns the indices of `maze.solution` corresponding to the steps to be tokenized.""" 

1396 return list(range(maze.solution.shape[0])) 

1397 

1398 @serializable_dataclass(frozen=True, kw_only=True) 

1399 @mark_as_unsupported(lambda self_: False) 

1400 class Straightaways(_StepSize): 

1401 """ 

1402 Only coords where the path turns are represented in the path. 

1403 I.e., the path is represented as a sequence of straightaways, 

1404 specified by the coords at the turns. 

1405 """ 

1406 

1407 def _step_single_indices(self, maze: SolvedMaze) -> list[int]: 

1408 """Returns the indices of `maze.solution` corresponding to the steps to be tokenized.""" 

1409 last_turn_coord: Coord = maze.solution[0, ...] 

1410 indices: list[int] = [0] 

1411 for i, coord in enumerate(maze.solution): 

1412 if coord[0] != last_turn_coord[0] and coord[1] != last_turn_coord[1]: 

1413 indices.append(i - 1) 

1414 last_turn_coord = maze.solution[i - 1, ...] 

1415 indices.append(i) 

1416 return indices 

1417 

1418 @serializable_dataclass(frozen=True, kw_only=True) 

1419 class Forks(_StepSize): 

1420 """ 

1421 Only coords at forks, where the path has >=2 options for the next step, are included.

1422 Excludes the option of backtracking. 

1423 The starting and ending coords are always included. 

1424 """ 

1425 

1426 def _step_single_indices(self, maze: SolvedMaze) -> list[int]: 

1427 """Returns the indices of `maze.solution` corresponding to the steps to be tokenized.""" 

1428 return maze.get_solution_forking_points(always_include_endpoints=True)[0] 

1429 

1430 @serializable_dataclass(frozen=True, kw_only=True) 

1431 @mark_as_unsupported(lambda self_: False) 

1432 class ForksAndStraightaways(_StepSize): 

1433 """ 

1434 Includes the union of the coords included by `Forks` and `Straightaways`. 

1435 See documentation for those classes for details. 

1436 """ 

1437 

1438 def _step_single_indices(self, maze: SolvedMaze) -> list[int]: 

1439 """Returns the indices of `maze.solution` corresponding to the steps to be tokenized.""" 

1440 return list( 

1441 np.unique( 

1442 np.concatenate( 

1443 ( 

1444 StepSizes.Straightaways()._step_single_indices(maze), 

1445 StepSizes.Forks()._step_single_indices(maze), 

1446 ) 

1447 ) 

1448 ) 

1449 ) 

1450 

1451 

1452class StepTokenizers(__TokenizerElementNamespace): 

1453 """Namespace for `_StepTokenizer` subclass hierarchy used by `MazeTokenizerModular`.""" 

1454 

1455 key = "step_tokenizers" 

1456 

1457 @serializable_dataclass(frozen=True, kw_only=True) 

1458 class _StepTokenizer(_TokenizerElement, abc.ABC): 

1459 """ 

1460 Specifies how a single step (as specified by an instance of `_StepSize`) is tokenized. 

1461 """ 

1462 

1463 @classmethod 

1464 def attribute_key(cls) -> str: 

1465 return StepTokenizers.key 

1466 

1467 @abc.abstractmethod 

1468 def to_tokens( 

1469 self, 

1470 maze: SolvedMaze, 

1471 start_index: int, 

1472 end_index: int, 

1473 **kwargs, 

1474 ) -> list[str]: 

1475 """Tokenizes a single step in the solution. 

1476 

1477 # Parameters 

1478 - `maze`: Maze to be tokenized 

1479 - `start_index`: The index of the Coord in `maze.solution` at which the current step starts 

1480 - `end_index`: The index of the Coord in `maze.solution` at which the current step ends 

1481 """ 

1482 raise NotImplementedError( 

1483 "Subclasses must implement `StepTokenizer.to_tokens." 

1484 ) 

1485 

1486 def is_valid(self) -> bool: 

1487 # No invalid instances possible within data member type hint bounds 

1488 return True 

1489 

1490 @serializable_dataclass(frozen=True, kw_only=True) 

1491 class Coord(_StepTokenizer): 

1492 """ 

1493 A direct tokenization of the end position coord represents the step. 

1494 """ 

1495 

1496 def to_tokens( 

1497 self, 

1498 maze: SolvedMaze, 

1499 start_index: int, 

1500 end_index: int, 

1501 coord_tokenizer: CoordTokenizers._CoordTokenizer, 

1502 ) -> list[str]: 

1503 return coord_tokenizer.to_tokens(maze.solution[end_index, ...]) 

1504 

1505 @serializable_dataclass(frozen=True, kw_only=True) 

1506 class Cardinal(_StepTokenizer): 

1507 """ 

1508 A step is tokenized with a cardinal direction token. 

1509 It is the direction of the step from the starting position along the solution. 

1510 """ 

1511 

1512 def to_tokens( 

1513 self, maze: SolvedMaze, start_index: int, end_index: int, **kwargs 

1514 ) -> list[str]: 

1515 return [ 

1516 get_cardinal_direction(maze.solution[start_index : start_index + 2]) 

1517 ] 

1518 

1519 @serializable_dataclass(frozen=True, kw_only=True) 

1520 class Relative(_StepTokenizer): 

1521 """Tokenizes a solution step using relative first-person directions (right, left, forward, etc.). 

1522 To simplify the indeterminacy, at the start of a solution the "agent" solving the maze is assumed to be facing NORTH. 

1523 Similarly to `Cardinal`, the direction is that of the step from the starting position. 

1524 """ 

1525 

1526 def to_tokens( 

1527 self, maze: SolvedMaze, start_index: int, end_index: int, **kwargs 

1528 ) -> list[str]: 

1529 if start_index == 0: 

1530 start = maze.solution[0] 

1531 previous = start + np.array([1, 0]) 

1532 return [ 

1533 get_relative_direction( 

1534 np.concatenate( 

1535 ( 

1536 np.expand_dims(previous, 0), 

1537 maze.solution[start_index : start_index + 2], 

1538 ), 

1539 axis=0, 

1540 ) 

1541 ) 

1542 ] 

1543 return [ 

1544 get_relative_direction(maze.solution[start_index - 1 : start_index + 2]) 

1545 ] 

1546 

1547 @serializable_dataclass(frozen=True, kw_only=True) 

1548 class Distance(_StepTokenizer): 

1549 """ 

1550 A count of the number of individual steps from the starting point to the end point. 

1551 Contains no information about directionality, only the distance traveled in the step. 

1552 `Distance` must be combined with at least one other `_StepTokenizer` in a `StepTokenizerPermutation`. 

1553 This constraint is enforced in `_PathTokenizer.is_valid`. 

1554 """ 

1555 

1556 def to_tokens( 

1557 self, maze: SolvedMaze, start_index: int, end_index: int, **kwargs 

1558 ) -> list[str]: 

1559 d: int = end_index - start_index 

1560 return [getattr(VOCAB, f"I_{d:03}")] 

1561 
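# Editor's note (not part of the original source): illustrative only. A step spanning
# solution indices 3..8 has length 5, so the emitted token is the vocabulary entry named
# "I_005", i.e. `getattr(VOCAB, "I_005")`.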

1562 """ 

1563 `StepTokenizerPermutation` 

1564 A sequence of unique `_StepTokenizer`s. 

1565 This type exists mostly just for the clarity and convenience of `_PathTokenizer` code. 

1566 """ 

1567 StepTokenizerPermutation: type = ( 

1568 tuple[_StepTokenizer] 

1569 | tuple[_StepTokenizer, _StepTokenizer] 

1570 | tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer] 

1571 | tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer, _StepTokenizer] 

1572 ) 
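
Concretely, a permutation is a tuple of one to four distinct step tokenizer instances, in the order their sub-tokens should appear within each step. A minimal sketch of values that fit this alias, assuming `StepTokenizers` is imported from this module (the package may also re-export it):

```python
from maze_dataset.tokenization.maze_tokenizer import StepTokenizers

# fits the alias: 1-4 unique `_StepTokenizer` instances
perm_single = (StepTokenizers.Coord(),)
perm_pair = (StepTokenizers.Cardinal(), StepTokenizers.Distance())

# also fits the alias, but is rejected later by `StepSequence.is_valid`,
# since `Distance` alone carries no direction or location information
perm_rejected = (StepTokenizers.Distance(),)
```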

1573 

1574 

1575class PathTokenizers(__TokenizerElementNamespace): 

1576 """Namespace for `_PathTokenizer` subclass hierarchy used by `MazeTokenizerModular`.""" 

1577 

1578 key = "path_tokenizer" 

1579 

1580 @serializable_dataclass(frozen=True, kw_only=True) 

1581 class _PathTokenizer(_TokenizerElement, abc.ABC): 

1582 """Superclass of tokenizers for maze solution paths.""" 

1583 

1584 @abc.abstractmethod 

1585 def to_tokens( 

1586 self, maze: SolvedMaze, coord_tokenizer: CoordTokenizers._CoordTokenizer 

1587 ) -> list[str]: 

1588 """Returns tokens representing the solution path.""" 

1589 pass 

1590 

1591 @classmethod 

1592 def attribute_key(cls) -> str: 

1593 return PathTokenizers.key 

1594 

1595 @serializable_dataclass(frozen=True, kw_only=True) 

1596 class StepSequence(_PathTokenizer, abc.ABC): 

1597 """Any `PathTokenizer` where the tokenization may be assembled from token subsequences, each of which represents a step along the path. 

1598 Allows for a sequence of leading and trailing tokens which don't fit the step pattern. 

1599 

1600 # Parameters 

1601 - `step_size`: Selects the size of a single step in the sequence 

1602 - `step_tokenizers`: Selects the combination and permutation of tokens 

1603 - `pre`: Whether all steps include an integral preceding delimiter token 

1604 - `intra`: Whether all steps include a delimiter token after each individual `_StepTokenizer` tokenization. 

1605 - `post`: Whether all steps include an integral following delimiter token 

1606 """ 

1607 

1608 step_size: StepSizes._StepSize = serializable_field( 

1609 default=StepSizes.Singles(), 

1610 loading_fn=lambda x: _load_tokenizer_element(x, StepSizes), 

1611 ) 

1612 step_tokenizers: StepTokenizers.StepTokenizerPermutation = serializable_field( 

1613 default=(StepTokenizers.Coord(),), 

1614 serialization_fn=lambda x: [y.serialize() for y in x], 

1615 loading_fn=lambda x: tuple(x[StepTokenizers.key]), 

1616 ) 

1617 pre: bool = serializable_field(default=False) 

1618 intra: bool = serializable_field(default=False) 

1619 post: bool = serializable_field(default=False) 

1620 

1621 def to_tokens( 

1622 self, maze: SolvedMaze, coord_tokenizer: CoordTokenizers._CoordTokenizer 

1623 ) -> list[str]: 

1624 return [ 

1625 *self._leading_tokens(maze, coord_tokenizer), 

1626 *flatten( 

1627 [ 

1628 self._single_step_tokens(maze, start, end, coord_tokenizer) 

1629 for start, end in self.step_size.step_start_end_indices(maze) 

1630 ] 

1631 ), 

1632 *self._trailing_tokens(maze, coord_tokenizer), 

1633 ] 

1634 

1635 def _single_step_tokens( 

1636 self, 

1637 maze: SolvedMaze, 

1638 i: int, 

1639 j: int, 

1640 coord_tokenizer: CoordTokenizers._CoordTokenizer, 

1641 ) -> list[str]: 

1642 """Returns the token sequence representing a single step along the path.""" 

1643 step_rep_tokens: list[list[str]] = [ 

1644 step_tokenizer.to_tokens(maze, i, j, coord_tokenizer=coord_tokenizer) 

1645 for step_tokenizer in self.step_tokenizers 

1646 ] 

1647 if self.intra: 

1648 step_rep_tokens_and_intra: list[str] = [None] * ( 

1649 len(step_rep_tokens) * 2 

1650 ) 

1651 step_rep_tokens_and_intra[::2] = step_rep_tokens 

1652 step_rep_tokens_and_intra[1::2] = [VOCAB.PATH_INTRA] * len( 

1653 step_rep_tokens 

1654 ) 

1655 step_rep_tokens = list(flatten(step_rep_tokens_and_intra)) 

1656 all_tokens: list[str] = [ 

1657 *empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"), 

1658 *flatten(step_rep_tokens), 

1659 *empty_sequence_if_attr_false((VOCAB.PATH_POST,), self, "post"), 

1660 ] 

1661 return all_tokens 

1662 

1663 def _leading_tokens( 

1664 self, maze: SolvedMaze, coord_tokenizer: CoordTokenizers._CoordTokenizer 

1665 ) -> list[str]: 

1666 """Returns tokens preceding those from the sequence from `_single_step_tokens`. 

1667 Since the for loop in `to_tokens` iterates `len(path)-1` times, a fencepost problem exists with `StepTokenizers.Coord`. 

1668 <PATH_START> should NOT be included. 

1669 """ 

1670 if StepTokenizers.Coord() in self.step_tokenizers: 

1671 return [ 

1672 *empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"), 

1673 *coord_tokenizer.to_tokens(maze.solution[0, ...]), 

1674 *empty_sequence_if_attr_false((VOCAB.PATH_INTRA,), self, "intra"), 

1675 ] 

1676 return [] 

1677 

1678 def _trailing_tokens( 

1679 self, maze: SolvedMaze, coord_tokenizer: CoordTokenizers._CoordTokenizer

1680 ) -> list[str]: 

1681 """Returns tokens following those from the sequence from `_single_step_tokens`. 

1682 <PATH_END> should NOT be included. 

1683 """ 

1684 return [] 

1685 

1686 def is_valid(self) -> bool: 

1687 if len(set(self.step_tokenizers)) != len(self.step_tokenizers): 

1688 # Uninteresting: repeated elements are not useful 

1689 return False 

1690 

1691 if len(self.step_tokenizers) == 1 and isinstance( 

1692 self.step_tokenizers[0], StepTokenizers.Distance 

1693 ): 

1694 # Untrainable: `Distance` alone cannot encode a path. >=1 `StepTokenizer` which indicates direction/location is required. 

1695 return False 

1696 else: 

1697 return True 
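
A configuration sketch, assuming direct imports from this module (the public package may re-export these names): a `StepSequence` combining `Cardinal` and `Distance` step tokenizers with intra-step and post-step delimiters.

```python
from maze_dataset.tokenization.maze_tokenizer import (
    PathTokenizers,
    StepSizes,
    StepTokenizers,
)

path_tok = PathTokenizers.StepSequence(
    step_size=StepSizes.Singles(),  # the default: each step is a single move along the solution
    step_tokenizers=(StepTokenizers.Cardinal(), StepTokenizers.Distance()),
    intra=True,  # VOCAB.PATH_INTRA after each step tokenizer's output within a step
    post=True,   # VOCAB.PATH_POST after each step
)
assert path_tok.is_valid()

# a lone `Distance` fits the type alias but is rejected by `is_valid` above
assert not PathTokenizers.StepSequence(
    step_tokenizers=(StepTokenizers.Distance(),),
).is_valid()
```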

1698 

1699 

1700class PromptSequencers(__TokenizerElementNamespace): 

1701 """Namespace for `_PromptSequencer` subclass hierarchy used by `MazeTokenizerModular`.""" 

1702 

1703 key = "prompt_sequencer" 

1704 

1705 @serializable_dataclass(frozen=True, kw_only=True) 

1706 class _PromptSequencer(_TokenizerElement, abc.ABC): 

1707 """ 

1708 Sequences token regions into a complete maze tokenization. 

1709 

1710 # Parameters 

1711 - `coord_tokenizer`: Tokenizer element which tokenizes a single `Coord` aka maze position. 

1712 - `adj_list_tokenizer`: Tokenizer element which tokenizes the adjacency list of a `LatticeMaze`. 

1713 Uses `coord_tokenizer` to tokenize coords if needed in other `TokenizerElement`s. 

1714 """ 

1715 

1716 coord_tokenizer: CoordTokenizers._CoordTokenizer = serializable_field( 

1717 default=CoordTokenizers.UT(), 

1718 loading_fn=lambda x: _load_tokenizer_element(x, CoordTokenizers), 

1719 ) 

1720 adj_list_tokenizer: AdjListTokenizers._AdjListTokenizer = serializable_field( 

1721 default=AdjListTokenizers.AdjListCoord(), 

1722 loading_fn=lambda x: _load_tokenizer_element(x, AdjListTokenizers), 

1723 ) 

1724 

1725 @classmethod 

1726 def attribute_key(cls) -> str: 

1727 return PromptSequencers.key 

1728 

1729 @staticmethod 

1730 def _trim_if_unsolved_maze( 

1731 untrimmed: list[str], is_untargeted: bool = False, is_unsolved: bool = False 

1732 ): 

1733 """Trims a full `SolvedMaze` prompt if the maze data reflects an unsolved or untargeted maze. 

1734 

1735 # Development 

1736 This implementation should function for `AOTP`, `AOP`, and other concrete classes using any subsequence of AOTP. 

1737 It is not located in `token_utils.py` because it may need to be overridden in more exotic `PromptSequencer` subclasses. 

1738 """ 

1739 if is_untargeted: 

1740 return tokens_between( 

1741 untrimmed, 

1742 VOCAB.ADJLIST_START, 

1743 VOCAB.ADJLIST_END, 

1744 include_start=True, 

1745 include_end=True, 

1746 ) 

1747 if is_unsolved: 

1748 if VOCAB.TARGET_END in untrimmed: 

1749 return tokens_between( 

1750 untrimmed, 

1751 VOCAB.ADJLIST_START, 

1752 VOCAB.TARGET_END, 

1753 include_start=True, 

1754 include_end=True, 

1755 ) 

1756 else: 

1757 return tokens_between( 

1758 untrimmed, 

1759 VOCAB.ADJLIST_START, 

1760 VOCAB.ORIGIN_END, 

1761 include_start=True, 

1762 include_end=True, 

1763 ) 

1764 return untrimmed 

1765 

1766 def to_tokens( 

1767 self, 

1768 maze: LatticeMaze, 

1769 *args, 

1770 **kwargs, 

1771 ) -> list[str]: 

1772 """Returns a complete list of tokens for a given set of maze elements.""" 

1773 untrimmed: list[str] = self._sequence_tokens( 

1774 *self._get_prompt_regions(maze) 

1775 ) 

1776 return self._trim_if_unsolved_maze( 

1777 untrimmed, not hasattr(maze, "start_pos"), not hasattr(maze, "solution") 

1778 ) 

1779 

1780 def _get_prompt_regions( 

1781 self, 

1782 maze: LatticeMaze, 

1783 *args, 

1784 **kwargs, 

1785 ) -> list[list[str]]: 

1786 """Gets the prompt regions of a maze in a fixed sequence. 

1787 

1788 This method is NOT responsible for including/excluding any prompt regions. 

1789 Always return according to the API described under Returns. 

1790 This implementation is expected to be suitable for most `PromptSequencer` subclasses. 

1791 Subclasses may override this method if needed for special behavior. 

1792 

1793 # Returns 

1794 - [0]: list[str] Adjacency list tokens 

1795 - [1]: list[str] Origin tokens 

1796 - [2]: list[str] Target tokens 

1797 - [3]: list[str] Path tokens 

1798 

1799 # Missing values

1800 If the maze lacks an origin, target, or solution path, an unsolved or untargeted maze is being tokenized.

1801 To ensure unpackability in `_sequence_tokens`, the corresponding regions are returned as empty iterables.

1802 """ 

1803 origin: Coord | None = getattr(maze, "start_pos", None) 

1804 target: list[Coord | None] = [

1805 getattr(maze, "end_pos", None) 

1806 ] # TargetTokenizer requires target: Sequence[Coord] 

1807 

1808 return [ 

1809 ( 

1810 self.adj_list_tokenizer.to_tokens( 

1811 maze, coord_tokenizer=self.coord_tokenizer 

1812 ) 

1813 if hasattr(self, "adj_list_tokenizer") 

1814 else [] 

1815 ), 

1816 self.coord_tokenizer.to_tokens(origin) if origin is not None else [], 

1817 ( 

1818 self.target_tokenizer.to_tokens( 

1819 target, coord_tokenizer=self.coord_tokenizer 

1820 ) 

1821 if target[0] is not None and hasattr(self, "target_tokenizer") 

1822 else [] 

1823 ), 

1824 ( 

1825 self.path_tokenizer.to_tokens( 

1826 maze, coord_tokenizer=self.coord_tokenizer 

1827 ) 

1828 if hasattr(maze, "solution") and hasattr(self, "path_tokenizer") 

1829 else [] 

1830 ), 

1831 ] 

1832 

1833 @abc.abstractmethod 

1834 def _sequence_tokens( 

1835 self, 

1836 adj_list: list[str], 

1837 origin: list[str] | None, 

1838 target: list[str] | None, 

1839 path: list[str] | None, 

1840 ) -> list[str]: 

1841 """Sequences token regions into a complete prompt. 

1842 Includes any boundary tokens in `constants.SPECIAL_TOKENS` such as <ADJLIST_START>, <ORIGIN_END>, etc. 

1843 # Parameters 

1844 - `adj_list`: Tokens representing the adjacency list 

1845 - `origin`: Tokens representing the origin 

1846 - `target`: Tokens representing the target 

1847 - `path`: Tokens representing the path 

1848 """ 

1849 pass 

1850 

1851 def is_valid(self) -> bool: 

1852 # No invalid instances possible within data member type hint bounds 

1853 return True 

1854 

1855 @serializable_dataclass(frozen=True, kw_only=True) 

1856 class AOTP(_PromptSequencer): 

1857 """ 

1858 Sequences a prompt as [adjacency list, origin, target, path]. 

1859 

1860 # Parameters 

1861 - `target_tokenizer`: Tokenizer element which tokenizes the target(s) of a `TargetedLatticeMaze`. 

1862 Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `TargetTokenizer`. 

1863 - `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`. 

1864 Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`. 

1865 

1866 """ 

1867 

1868 target_tokenizer: TargetTokenizers._TargetTokenizer = serializable_field( 

1869 default=TargetTokenizers.Unlabeled(), 

1870 loading_fn=lambda x: _load_tokenizer_element(x, TargetTokenizers), 

1871 ) 

1872 path_tokenizer: PathTokenizers._PathTokenizer = serializable_field( 

1873 default=PathTokenizers.StepSequence(), 

1874 loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers), 

1875 ) 

1876 

1877 def _sequence_tokens( 

1878 self, 

1879 adj_list: list[str], 

1880 origin: list[str], 

1881 target: list[str], 

1882 path: list[str], 

1883 ) -> list[str]: 

1884 return [ 

1885 VOCAB.ADJLIST_START, 

1886 *adj_list, 

1887 VOCAB.ADJLIST_END, 

1888 VOCAB.ORIGIN_START, 

1889 *origin, 

1890 VOCAB.ORIGIN_END, 

1891 VOCAB.TARGET_START, 

1892 *target, 

1893 VOCAB.TARGET_END, 

1894 VOCAB.PATH_START, 

1895 *path, 

1896 VOCAB.PATH_END, 

1897 ] 
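
A construction sketch, assuming direct imports from this module: an `AOTP` sequencer that tokenizes coordinates as tuples and paths as cardinal-direction steps; any element not passed explicitly falls back to the defaults declared in the fields above.

```python
from maze_dataset.tokenization.maze_tokenizer import (
    CoordTokenizers,
    PathTokenizers,
    PromptSequencers,
    StepTokenizers,
)

sequencer = PromptSequencers.AOTP(
    coord_tokenizer=CoordTokenizers.CTT(),  # coordinate-tuple tokens instead of the default unique tokens
    path_tokenizer=PathTokenizers.StepSequence(
        step_tokenizers=(StepTokenizers.Cardinal(),),
    ),
)
```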

1898 

1899 @serializable_dataclass(frozen=True, kw_only=True) 

1900 class AOP(_PromptSequencer): 

1901 """Sequences a prompt as [adjacency list, origin, path]. 

1902 Still includes "<TARGET_START>" and "<TARGET_END>" tokens, but no representation of the target itself. 

1903 

1904 # Parameters 

1905 - `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`. 

1906 Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`. 

1907 """ 

1908 

1909 path_tokenizer: PathTokenizers._PathTokenizer = serializable_field( 

1910 default=PathTokenizers.StepSequence(), 

1911 loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers), 

1912 ) 

1913 

1914 def _sequence_tokens( 

1915 self, 

1916 adj_list: list[str], 

1917 origin: list[str], 

1918 target: list[str], 

1919 path: list[str], 

1920 ) -> list[str]: 

1921 return [ 

1922 VOCAB.ADJLIST_START, 

1923 *adj_list, 

1924 VOCAB.ADJLIST_END, 

1925 VOCAB.ORIGIN_START, 

1926 *origin, 

1927 VOCAB.ORIGIN_END, 

1928 VOCAB.TARGET_START, 

1929 VOCAB.TARGET_END, 

1930 VOCAB.PATH_START, 

1931 *path, 

1932 VOCAB.PATH_END, 

1933 ] 

1934 

1935 

1936@serializable_dataclass( 

1937 frozen=True, 

1938 kw_only=True, 

1939 properties_to_serialize=["tokenizer_element_tree_concrete", "name"], 

1940) 

1941class MazeTokenizerModular(SerializableDataclass): 

1942 """Tokenizer for mazes 

1943 

1944 # Parameters 

1945 - `prompt_sequencer`: Tokenizer element which assembles token regions (adjacency list, origin, target, path) into a complete prompt. 

1946 

1947 # Development 

1948 - To ensure backwards compatibility, the default constructor must always return a tokenizer equivalent to the legacy `TokenizationMode.AOTP_UT_uniform`.

1949 - Furthermore, the mapping reflected in `from_legacy` must also be maintained. 

1950 - Updates to `MazeTokenizerModular` or the `_TokenizerElement` hierarchy must maintain that behavior. 

1951 """ 

1952 

1953 prompt_sequencer: PromptSequencers._PromptSequencer = serializable_field( 

1954 default=PromptSequencers.AOTP(), 

1955 loading_fn=lambda x: _load_tokenizer_element(x, PromptSequencers), 

1956 ) 

1957 

1958 def hash_int(self) -> int: 

1959 return int.from_bytes( 

1960 hashlib.blake2b(self.name.encode("utf-8")).digest(), 

1961 byteorder="big", 

1962 ) 

1963 

1964 def __hash__(self): 

1965 "Stable hash to identify unique `MazeTokenizerModular` instances. uses name" 

1966 return self.hash_int() 

1967 

1968 def hash_b64(self, n_bytes: int = 8) -> str: 

1969 """filename-safe base64 encoding of the hash""" 

1970 # Use modulus to ensure the integer fits within n_bytes * 8 bits 

1971 hash_mod: int = self.hash_int() % (1 << (n_bytes * 8)) 

1972 

1973 encoded = base64.b64encode( 

1974 hash_mod.to_bytes(n_bytes, byteorder="big"), altchars=b"-_" 

1975 ).decode() 

1976 

1977 # Remove any padding equals signs 

1978 return encoded.rstrip("=") 
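
A usage sketch; the printed values depend on the tokenizer's `name`, so none are reproduced here.

```python
from maze_dataset.tokenization.maze_tokenizer import MazeTokenizerModular

tok = MazeTokenizerModular()  # the default tokenizer
print(tok.name)        # "MazeTokenizerModular-" followed by the prompt sequencer's name
print(tok.hash_int())  # full blake2b digest of `name`, as an integer
print(tok.hash_b64())  # filename-safe base64 of that hash reduced to 8 bytes (the default `n_bytes`)
```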

1979 

1980 # Information Querying Methods 

1981 

1982 @cached_property 

1983 def tokenizer_elements(self) -> list[_TokenizerElement]: 

1984 return [self.prompt_sequencer, *self.prompt_sequencer.tokenizer_elements()] 

1985 

1986 def tokenizer_element_tree(self, abstract: bool = False) -> str: 

1987 """ 

1988 Returns a string representation of the tree of tokenizer elements contained in `self`. 

1989 

1990 # Parameters 

1991 - `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance. 

1992 """ 

1993 

1994 return "\n".join( 

1995 [ 

1996 type(self).__name__, 

1997 self.prompt_sequencer.tokenizer_element_tree( 

1998 abstract=abstract, depth=1 

1999 ), 

2000 ] 

2001 ) 

2002 

2003 @property 

2004 def tokenizer_element_tree_concrete(self): 

2005 """ 

2006 Property wrapper for `tokenizer_element_tree` so that it can be used in `properties_to_serialize`. 

2007 """ 

2008 return self.tokenizer_element_tree() 

2009 

2010 def tokenizer_element_dict(self) -> dict: 

2011 """ 

2012 Nested dictionary of the internal `TokenizerElement`s. 

2013 """ 

2014 return {type(self).__name__: self.prompt_sequencer.tokenizer_element_dict()} 

2015 

2016 @property 

2017 def name(self) -> str: 

2018 """Serializes MazeTokenizer into a key for encoding in zanj""" 

2019 return "-".join([type(self).__name__, self.prompt_sequencer.name]) 

2020 

2021 def summary(self) -> dict[str, str]: 

2022 """ 

2023 Single-level dictionary of the internal `TokenizerElement`s. 

2024 """ 

2025 return { 

2026 # "prompt_sequencer": self.prompt_sequencer.name, 

2027 **{elem.attribute_key(): elem.name for elem in self.tokenizer_elements} 

2028 } 
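
A usage sketch: the summary maps each element's attribute key to its name, which is convenient for logging which components a tokenizer uses (the exact keys and names depend on the configured elements, so none are hard-coded here).

```python
from maze_dataset.tokenization.maze_tokenizer import MazeTokenizerModular

tok = MazeTokenizerModular()
for attribute_key, element_name in tok.summary().items():
    print(f"{attribute_key}: {element_name}")
```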

2029 

2030 @staticmethod 

2031 def _type_check(obj: Any) -> None:

2032 """Helper method for `has_element`""" 

2033 if not ( 

2034 isinstance(obj, _TokenizerElement) 

2035 or (isinstance(obj, type) and issubclass(obj, _TokenizerElement)) 

2036 ): 

2037 raise TypeError(f"{obj} is not a `_TokenizerElement` instance or subclass.") 

2038 

2039 def _has_element_singular(self, el: type[_TokenizerElement] | _TokenizerElement): 

2040 """Helper method for `has_element`""" 

2041 self._type_check(el) 

2042 if isinstance(el, type): 

2043 return any([isinstance(e, el) for e in self.tokenizer_elements]) 

2044 else: 

2045 return el in self.tokenizer_elements 

2046 

2047 def has_element( 

2048 self, 

2049 *elements: Sequence[type[_TokenizerElement] | _TokenizerElement], 

2050 ) -> bool: 

2051 """Returns True if the `MazeTokenizerModular` instance contains ALL of the items specified in `elements`. 

2052 

2053 Querying with a partial subset of `_TokenizerElement` fields is not currently supported. 

2054 To do such a query, assemble multiple calls to `has_elements`. 

2055 

2056 # Parameters 

2057 - `elements`: Singleton or iterable of `_TokenizerElement` instances or classes. 

2058 If an instance is provided, then comparison is done via instance equality. 

2059 If a class is provided, then comparison is done via `isinstance`. I.e., any instance of that class is accepted.

2060 """ 

2061 if len(elements) == 1 and isinstance(elements[0], Iterable): 

2062 elements = elements[0] 

2063 return all([self._has_element_singular(e) for e in elements]) 
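
A usage sketch: classes are matched by `isinstance`, instances by equality; since the default tokenizer is built from `PromptSequencers.AOTP()` with `CoordTokenizers.UT()`, both checks below should hold for it.

```python
from maze_dataset.tokenization.maze_tokenizer import (
    CoordTokenizers,
    MazeTokenizerModular,
    PromptSequencers,
)

tok = MazeTokenizerModular()
assert tok.has_element(CoordTokenizers.UT)  # class: matched via isinstance
assert tok.has_element(PromptSequencers.AOTP, CoordTokenizers.UT())  # class and instance together
```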

2064 

2065 def is_valid(self): 

2066 """ 

2067 Returns `True` if `self` is a valid tokenizer. 

2068 Evaluates the validity of all of `self.tokenizer_elements` according to each one's method. 

2069 """ 

2070 return all([el.is_valid() for el in self.tokenizer_elements]) 

2071 

2072 def is_legacy_equivalent(self) -> bool: 

2073 """Returns if `self` has identical stringification behavior as any legacy `MazeTokenizer`.""" 

2074 return any( 

2075 [ 

2076 self == MazeTokenizerModular.from_legacy(tok_mode) 

2077 for tok_mode in TokenizationMode 

2078 ] 

2079 ) 

2080 

2081 def is_tested_tokenizer(self, do_assert: bool = False) -> bool: 

2082 """Returns if the tokenizer is returned by `all_tokenizers.get_all_tokenizers`, the set of tested and reliable tokenizers. 

2083 

2084 Since evaluating `all_tokenizers.get_all_tokenizers` is expensive, 

2085 instead checks for membership of `self`'s hash in `get_all_tokenizer_hashes()`. 

2086 

2087 if `do_assert` is `True`, raises an `AssertionError` if the tokenizer is not tested. 

2088 """ 

2089 all_tokenizer_hashes: Int64[np.ndarray, " n_tokenizers"] = ( 

2090 get_all_tokenizer_hashes() 

2091 ) 

2092 hash_index: int = np.searchsorted(all_tokenizer_hashes, hash(self)) 

2093 

2094 in_range: bool = hash_index < len(all_tokenizer_hashes) 

2095 hashes_match: bool = in_range and all_tokenizer_hashes[hash_index] == hash(self)

2096 is_valid: bool = self.is_valid() 

2097 

2098 if do_assert: 

2099 assert in_range, ( 

2100 f"{hash_index = } is invalid, must be at most {len(all_tokenizer_hashes) - 1}" 

2101 ) 

2102 assert hashes_match, ( 

2103 f"{all_tokenizer_hashes[hash_index] = } != {hash(self) = }" 

2104 ) 

2105 assert is_valid, "self.is_valid returns False" 

2106 return True 

2107 else: 

2108 return in_range and hashes_match and is_valid 
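
A usage sketch of the non-asserting form, e.g. as a guard before dataset generation; this requires the precomputed hashes file to be loadable (see `_load_tokenizer_hashes` below).

```python
from maze_dataset.tokenization.maze_tokenizer import MazeTokenizerModular

tok = MazeTokenizerModular()
if not tok.is_tested_tokenizer():
    raise ValueError(f"tokenizer '{tok.name}' is not in the tested set")
```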

2109 

2110 def is_AOTP(self) -> bool: 

2111 return self.has_element(PromptSequencers.AOTP) 

2112 

2113 def is_UT(self) -> bool: 

2114 return self.has_element(CoordTokenizers.UT) 

2115 

2116 # Alternate Constructors 

2117 # ====================== 

2118 

2119 @classmethod 

2120 def from_legacy( 

2121 cls, legacy_maze_tokenizer: MazeTokenizer | TokenizationMode 

2122 ) -> "MazeTokenizerModular": 

2123 """Maps a legacy `MazeTokenizer` or `TokenizationMode` to its equivalent `MazeTokenizerModular` instance.""" 

2124 if isinstance(legacy_maze_tokenizer, MazeTokenizer): 

2125 legacy_maze_tokenizer = legacy_maze_tokenizer.tokenization_mode 

2126 return { 

2127 TokenizationMode.AOTP_UT_uniform: MazeTokenizerModular(), 

2128 TokenizationMode.AOTP_UT_rasterized: MazeTokenizerModular(), 

2129 TokenizationMode.AOTP_CTT_indexed: MazeTokenizerModular( 

2130 prompt_sequencer=PromptSequencers.AOTP( 

2131 coord_tokenizer=CoordTokenizers.CTT() 

2132 ) 

2133 ), 

2134 }[legacy_maze_tokenizer] 
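
As the mapping above shows, both legacy `UT` modes collapse onto the default constructor, while `AOTP_CTT_indexed` swaps in the `CTT` coordinate tokenizer. A sketch of the expected equivalences (instances compare field-wise, since these are frozen serializable dataclasses):

```python
from maze_dataset.tokenization.maze_tokenizer import (
    MazeTokenizerModular,
    TokenizationMode,
)

assert MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_UT_uniform) == MazeTokenizerModular()
assert MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_UT_rasterized) == MazeTokenizerModular()
```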

2135 

2136 # Simple properties 

2137 # ================= 

2138 @classmethod 

2139 def from_tokens( 

2140 cls, 

2141 tokens: str | list[str], 

2142 ) -> "MazeTokenizerModular": 

2143 """ 

2144 Infers most `MazeTokenizerModular` parameters from a full sequence of tokens. 

2145 """ 

2146 raise NotImplementedError( 

2147 "Recovering tokenizer objects from MazeTokenizerModular-produced strings is not supported" 

2148 ) 

2149 

2150 @property 

2151 def token_arr(self) -> list[str] | None: 

2152 """map from index to token""" 

2153 return VOCAB_LIST 

2154 

2155 @property 

2156 def tokenizer_map(self) -> dict[str, int]: 

2157 """map from token to index""" 

2158 return VOCAB_TOKEN_TO_INDEX 

2159 

2160 @property 

2161 def vocab_size(self) -> int: 

2162 """Number of tokens in the static vocab""" 

2163 return len(VOCAB_LIST) 

2164 

2165 @property 

2166 def n_tokens(self) -> int: 

2167 raise NameError( 

2168 "`MazeTokenizerModular.n_tokens` has been removed. Use `len(maze_dataset.VOCAB_LIST)` instead." 

2169 ) 

2170 

2171 @property 

2172 def padding_token_index(self) -> int: 

2173 return VOCAB_TOKEN_TO_INDEX[VOCAB.PADDING] 

2174 

2175 # conversion functions 

2176 # ============================================================ 

2177 

2178 def to_tokens( 

2179 self, 

2180 maze: LatticeMaze, 

2181 ) -> list[str]: 

2182 """Converts maze into a list of tokens.""" 

2183 return self.prompt_sequencer.to_tokens(maze) 
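
A minimal sketch: given a `SolvedMaze` obtained elsewhere (for example via the package's dataset-generation API, whose interface is not shown in this file), the default tokenizer turns it into a whitespace-joinable token list. `maze_to_prompt` is a hypothetical helper name.

```python
from maze_dataset.maze.lattice_maze import SolvedMaze
from maze_dataset.tokenization.maze_tokenizer import MazeTokenizerModular

def maze_to_prompt(maze: SolvedMaze, tokenizer: MazeTokenizerModular | None = None) -> str:
    """Join the tokens produced for `maze` into a single prompt string."""
    tokenizer = tokenizer if tokenizer is not None else MazeTokenizerModular()
    return " ".join(tokenizer.to_tokens(maze))
```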

2184 

2185 def coords_to_strings(self, coords: list[CoordTup | Coord]) -> list[str]: 

2186 return list( 

2187 flatten( 

2188 [self.prompt_sequencer.coord_tokenizer.to_tokens(c) for c in coords] 

2189 ) 

2190 ) 

2191 

2192 @overload 

2193 def strings_to_coords( 

2194 cls, 

2195 text: str | list[str], 

2196 when_noncoord: Literal["skip"] = "skip", 

2197 ) -> list[CoordTup]: ... 

2198 @overload 

2199 def strings_to_coords( 

2200 cls, 

2201 text: str | list[str], 

2202 when_noncoord: Literal["error"] = "error", 

2203 ) -> list[CoordTup]: ... 

2204 @overload 

2205 def strings_to_coords( 

2206 cls, 

2207 text: str | list[str], 

2208 when_noncoord: Literal["include"] = "include", 

2209 ) -> list[str | CoordTup]: ... 

2210 @classmethod 

2211 def strings_to_coords( 

2212 cls, 

2213 text: str | list[str], 

2214 when_noncoord: WhenMissing = "skip", 

2215 ) -> list[str | CoordTup]: 

2216 warnings.warn( 

2217 "`MazeTokenizerModular.strings_to_coords` only supports legacy UT strings.", 

2218 TokenizerPendingDeprecationWarning, 

2219 ) 

2220 return strings_to_coords(text=text, when_noncoord=when_noncoord) 

2221 

2222 @staticmethod 

2223 def encode(text: str | list[str]) -> list[int]: 

2224 """encode a string or list of strings into a list of tokens""" 

2225 try: 

2226 if isinstance(text, str): 

2227 text = text.split() 

2228 return [VOCAB_TOKEN_TO_INDEX[token] for token in text] 

2229 except KeyError as e: 

2230 raise TokenError( 

2231 f"Token {e} not found", 

2232 "in `VOCAB`.", 

2233 ) from e 

2234 

2235 @staticmethod 

2236 def decode( 

2237 token_ids: Sequence[int], joined_tokens: bool = False 

2238 ) -> list[str] | str: 

2239 """decode a list of tokens into a string or list of strings""" 

2240 try: 

2241 output: list[str] = [VOCAB_LIST[token_id] for token_id in token_ids] 

2242 except IndexError as e: 

2243 raise TokenError(f"Token index '{e}' not found in `VOCAB`.") from e 

2244 if joined_tokens: 

2245 return " ".join(output) 

2246 else: 

2247 return output 
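
A round-trip sketch: since `encode` and `decode` are lookups into the shared static vocabulary (`VOCAB_TOKEN_TO_INDEX` / `VOCAB_LIST`), decoding the encoded ids should return the original tokens.

```python
from maze_dataset.constants import VOCAB
from maze_dataset.tokenization.maze_tokenizer import MazeTokenizerModular

tokens = [VOCAB.ADJLIST_START, VOCAB.ADJLIST_END, VOCAB.PATH_START, VOCAB.PATH_END]
token_ids = MazeTokenizerModular.encode(tokens)
assert MazeTokenizerModular.decode(token_ids) == tokens
# unknown tokens raise `TokenError` on encode; out-of-range ids raise `TokenError` on decode
```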

2248 

2249 

2250_ALL_TOKENIZER_HASHES: Int64[np.ndarray, " n_tokenizers"] 

2251"private array of all tokenizer hashes" 

2252_TOKENIZER_HASHES_PATH: Path = Path(__file__).parent / "MazeTokenizerModular_hashes.npz" 

2253"path to where we expect the hashes file -- in the same dir as this file, by default. change with `set_tokenizer_hashes_path`" 

2254 

2255 

2256def set_tokenizer_hashes_path(path: Path): 

2257 """set path to tokenizer hashes, and reload the hashes if needed 

2258 

2259 the hashes are expected to be stored in and read from `_TOKENIZER_HASHES_PATH`, 

2260 which by default is `Path(__file__).parent / "MazeTokenizerModular_hashes.npz"`, i.e. in this file's directory.

2261 

2262 However, this might not always work, so we provide a way to change this. 

2263 """ 

2264 global _TOKENIZER_HASHES_PATH 

2265 global _ALL_TOKENIZER_HASHES 

2266 

2267 path = Path(path) 

2268 if path.is_dir(): 

2269 path = path / "MazeTokenizerModular_hashes.npz" 

2270 

2271 if not path.is_file(): 

2272 raise FileNotFoundError(f"could not find maze tokenizer hashes file at: {path}") 

2273 

2274 if _TOKENIZER_HASHES_PATH.absolute() != path.absolute(): 

2275 # reload if they aren't equal 

2276 _TOKENIZER_HASHES_PATH = path 

2277 _ALL_TOKENIZER_HASHES = _load_tokenizer_hashes() 

2278 else: 

2279 # same file; update the path object but skip reloading

2280 _TOKENIZER_HASHES_PATH = path 

2281 

2282 

2283def _load_tokenizer_hashes() -> Int64[np.ndarray, " n_tokenizers"]: 

2284 """Loads the sorted list of `all_tokenizers.get_all_tokenizers()` hashes from disk.""" 

2285 global _TOKENIZER_HASHES_PATH 

2286 try: 

2287 path: Path = _TOKENIZER_HASHES_PATH 

2288 return np.load(path)["hashes"] 

2289 except FileNotFoundError as e: 

2290 raise FileNotFoundError( 

2291 "Tokenizers hashes cannot be loaded. To fix this, run", 

2292 "\n`python -m maze-dataset.tokenization.save_hashes` which will save the hashes to", 

2293 "\n`data/MazeTokenizerModular_hashes.npz`", 

2294 "relative to the current working directory -- this is where the code looks for them.", 

2295 ) from e 

2296 

2297 

2298def get_all_tokenizer_hashes() -> Int64[np.ndarray, " n_tokenizers"]: 

2299 global _ALL_TOKENIZER_HASHES 

2300 try: 

2301 got_tokenizers: bool = len(_ALL_TOKENIZER_HASHES) > 0 

2302 if got_tokenizers: 

2303 return _ALL_TOKENIZER_HASHES 

2304 else: 

2305 _ALL_TOKENIZER_HASHES = _load_tokenizer_hashes() 

2306 except NameError: 

2307 _ALL_TOKENIZER_HASHES = _load_tokenizer_hashes() 

2308 

2309 return _ALL_TOKENIZER_HASHES