Coverage for tests\unit\maze_dataset\tokenization\test_tokenizer.py: 82%
250 statements
coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
import itertools
import random
import re
from collections import Counter
from itertools import product
from typing import Iterable, Sequence

import frozendict
import numpy as np
from jaxtyping import Int
from muutils.misc import flatten
from muutils.mlutils import GLOBAL_SEED
from pytest import mark, param

from maze_dataset import (
    VOCAB,
    ConnectionArray,
    Coord,
    CoordArray,
    CoordTup,
    LatticeMaze,
    MazeDataset,
    MazeDatasetConfig,
    SolvedMaze,
)
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset.plotting.print_tokens import color_maze_tokens_AOTP
from maze_dataset.testing_utils import (
    ASCII_MAZES,
    LEGACY_AND_EQUIVALENT_TOKENIZERS,
    MANUAL_MAZE,
    MAZE_DATASET,
    MIXED_MAZES,
)
from maze_dataset.token_utils import (
    connection_list_to_adj_list,
    equal_except_adj_list_sequence,
)
from maze_dataset.tokenization import (
    AdjListTokenizers,
    CoordTokenizers,
    EdgeGroupings,
    EdgePermuters,
    EdgeSubsets,
    MazeTokenizer,
    MazeTokenizerModular,
    PathTokenizers,
    PromptSequencers,
    StepSizes,
    StepTokenizers,
    TargetTokenizers,
    TokenizationMode,
    _TokenizerElement,
)
from maze_dataset.utils import all_instances, lattice_max_degrees, manhattan_distance
# Use for test fuzzing when there are too many possible tokenizers
NUM_TOKENIZERS_TO_TEST = 100


@mark.parametrize(
    "tok_mode, max_grid_size",
    list(
        product(
            [
                TokenizationMode.AOTP_UT_rasterized,
                TokenizationMode.AOTP_UT_uniform,
                TokenizationMode.AOTP_CTT_indexed,
            ],
            [None, 3, 100],
        )
    ),
)
def test_tokenizer_serialization(tok_mode: TokenizationMode, max_grid_size: int | None):
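    """Round-trip a MazeTokenizer through serialize() and load() and check equality."""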
    tokenizer: MazeTokenizer = MazeTokenizer(
        tokenization_mode=tok_mode, max_grid_size=max_grid_size
    )

    serialized: dict = tokenizer.serialize()
    print(serialized)
    tokenizer_loaded: MazeTokenizer = MazeTokenizer.load(serialized)

    assert tokenizer == tokenizer_loaded


def test_tokenizer():
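    """Build a small DFS dataset and exercise each legacy TokenizationMode.

    Checks tokenizer naming, vocab size, encode/decode round trips, and maze recovery from tokens.
    """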
    cfg: MazeDatasetConfig = MazeDatasetConfig(
        name="test",
        grid_n=5,
        n_mazes=3,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
    )
    # to create a dataset, just call MazeDataset.from_config
    dataset: MazeDataset = MazeDataset.from_config(
        cfg,
        do_download=False,
        load_local=False,
        do_generate=True,
        save_local=False,
        verbose=True,
        gen_parallel=False,
    )

    for mode in (
        TokenizationMode.AOTP_UT_rasterized,
        TokenizationMode.AOTP_UT_uniform,
        TokenizationMode.AOTP_CTT_indexed,
    ):
        tokenizer: MazeTokenizer = MazeTokenizer(
            tokenization_mode=mode, max_grid_size=100
        )

        assert tokenizer.name == f"maze_tokenizer-{mode.name}-g{100}"

        if mode == TokenizationMode.AOTP_CTT_indexed:
            assert tokenizer.node_strings_map is not None
            assert 100 < tokenizer.vocab_size < 200
        elif mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            assert tokenizer.node_strings_map is None
            assert tokenizer.vocab_size > 10000

        assert isinstance(tokenizer.token_arr, Iterable)
        assert all(isinstance(token, str) for token in tokenizer.token_arr)
        assert len(tokenizer.token_arr) == tokenizer.vocab_size

        print(tokenizer.summary())

        for maze in dataset:
            # clear the cache here so we test if it works fine on the next loop
            tokenizer.clear_cache()

            maze_tok = maze.as_tokens(maze_tokenizer=tokenizer)

            maze_encoded = tokenizer.encode(maze_tok)
            maze_decoded = tokenizer.decode(maze_encoded)

            assert maze_tok == maze_decoded

            # you can view the tokens directly
            print("\nRaw tokens:\n")
            print(" ".join(maze_tok))

            maze_recovered = SolvedMaze.from_tokens(maze_tok, maze_tokenizer=tokenizer)

            assert (maze.connection_list == maze_recovered.connection_list).all()

            # or color and print them in various formats
            print("\nColored tokens, raw html:\n")
            print(color_maze_tokens_AOTP(maze_tok, fmt="html"))
            print("\nColored tokens, raw latex:\n")
            print(color_maze_tokens_AOTP(maze_tok, fmt="latex"))
            print("\nColored tokens, terminal:\n")
            print(color_maze_tokens_AOTP(maze_tok, fmt="terminal"))


@mark.parametrize(
    "maze_ascii, tokenizer, tokens",
    [
        param(
            ASCII_MAZES[maze_ascii_key][1],  # maze_ascii
            tokenizer,  # tokenizer
            ASCII_MAZES[maze_ascii_key][0],  # tokens
            id=f"{tokenizer.name}_{maze_ascii_key}",
        )
        for maze_ascii_key, tokenizer in product(
            ["small_3x3", "big_10x10"],
            LEGACY_AND_EQUIVALENT_TOKENIZERS,
        )
    ],
)
def test_maze_to_tokens_roundtrip(
    maze_ascii: list[str],
    tokenizer: MazeTokenizer | MazeTokenizerModular,
    tokens: str,
):
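    """Parse an ASCII maze, tokenize it, and check that tokens and maze survive a round trip."""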
    if not tokenizer.is_UT():
        # The hardcoded `tokens` assumes a UT tokenizer.
        # Here we modify `tokens` to match what an `AOTP_CTT_indexed` tokenizer would produce.
        tokens = re.sub(r"\(([0-9]),([0-9])\)", r"(\1 , \2)", tokens)
        tokens = re.sub(r"\(([0-9]+ ,)", r"( \1", tokens)
        tokens = re.sub(r"(, [0-9]+)\)", r"\1 )", tokens)
    tokens_original_split: list[str] = tokens.split()

    # join into a single string, and get a maze out
    ascii_str: str = "\n".join(maze_ascii)
    maze: SolvedMaze = SolvedMaze.from_ascii(ascii_str)

    # maze as tokens
    tokens_from_maze: list[str] = maze.as_tokens(tokenizer)

    # maze round trip
    maze_roundtrip: SolvedMaze = SolvedMaze.from_tokens(tokens_from_maze, tokenizer)
    tokens_roundtrip: list[str] = maze_roundtrip.as_tokens(tokenizer)

    # check that the mazes and tokens are all equivalent
    assert maze == maze_roundtrip
    assert equal_except_adj_list_sequence(tokens_original_split, tokens_from_maze)
    assert equal_except_adj_list_sequence(tokens_original_split, tokens_roundtrip)


@mark.parametrize(
    "tok_mode, max_grid_size, result",
    [
        param(
            tok_mode,
            max_grid_size,
            MazeTokenizer(tokenization_mode=tok_mode, max_grid_size=max_grid_size),
            id=f"{tok_mode}-{max_grid_size}",
        )
        for tok_mode, max_grid_size in [
            (TokenizationMode.AOTP_CTT_indexed, None),
            (TokenizationMode.AOTP_UT_rasterized, None),
            (TokenizationMode.AOTP_UT_uniform, None),
            (TokenizationMode.AOTP_CTT_indexed, 5),
        ]
    ],
)
def test_to_legacy_tokenizer(
    tok_mode: TokenizationMode, max_grid_size: int | None, result: MazeTokenizer
):
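    """`TokenizationMode.to_legacy_tokenizer` should construct the expected MazeTokenizer."""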
    assert tok_mode.to_legacy_tokenizer(max_grid_size) == result


# MazeTokenizerModular tests
# ==========================

# Backwards compatibility tests
# =============================


@mark.parametrize(
    "maze,legacy_tokenizer",
    [
        param(maze[0], tok_spec, id=f"{tok_spec.value}-maze{maze[1]}")
        for maze, tok_spec in itertools.product(
            [(maze, i) for i, maze in enumerate(MIXED_MAZES)],
            [tok_mode for tok_mode in TokenizationMode],
        )
    ],
)
def test_to_tokens_backwards_compatible(
    maze: SolvedMaze, legacy_tokenizer: TokenizationMode
):
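    """`MazeTokenizerModular.from_legacy` should produce the same tokens as the legacy tokenizer it mirrors."""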
    tokenizer: MazeTokenizerModular = MazeTokenizerModular.from_legacy(legacy_tokenizer)
    toks: list[str] = maze.as_tokens(tokenizer)
    toks2: list[str] = tokenizer.to_tokens(maze)
    toks_legacy: list[str] = maze.as_tokens(legacy_tokenizer)

    try:
        assert equal_except_adj_list_sequence(toks, toks_legacy)
        assert equal_except_adj_list_sequence(toks2, toks_legacy)
    except AssertionError as e:
        raise AssertionError(
            "Tokens from `as_tokens` and `to_tokens` should be equal to tokens from `as_tokens` with the legacy tokenizer.\n"
            f"{len(toks) = }, {len(toks2) = }, {len(toks_legacy) = }\n"
            f"{toks = }\n{toks2 = }\n{toks_legacy = }",
        ) from e


@mark.parametrize(
    "coords, legacy_tok_mode",
    [
        param(
            coords,
            tok_mode,
            id=f"{tok_mode.value}-coords(type={type(coords[0])},len={len(coords)})",
        )
        for tok_mode, coords in itertools.product(
            [tok_mode for tok_mode in TokenizationMode],
            [
                *[[maze.start_pos] for maze in MAZE_DATASET.mazes[:2]],
                [maze.start_pos for maze in MAZE_DATASET.mazes],
                *[[tuple(maze.start_pos)] for maze in MAZE_DATASET.mazes[:2]],
                [tuple(maze.start_pos) for maze in MAZE_DATASET.mazes],
            ],
        )
    ],
)
def test_coords_to_strings_backwards_compatible(
    coords: list[Coord | CoordTup], legacy_tok_mode: TokenizationMode
):
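    """`coords_to_strings` output should match between the modular tokenizer and its legacy equivalent."""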
    tokenizer: MazeTokenizerModular = MazeTokenizerModular.from_legacy(legacy_tok_mode)
    legacy_tokenizer = MazeTokenizer(tokenization_mode=legacy_tok_mode)
    strings: list[str] = tokenizer.coords_to_strings(coords)
    strings_legacy: list[str] = legacy_tokenizer.coords_to_strings(coords)
    assert strings == strings_legacy


@mark.parametrize(
    "maze,tok_mode",
    [
        param(maze[0], tok_spec, id=f"{tok_spec.value}-maze{maze[1]}")
        for maze, tok_spec in itertools.product(
            [(maze, i) for i, maze in enumerate(MIXED_MAZES)],
            [tok_mode for tok_mode in TokenizationMode],
        )
    ],
)
def test_from_tokens_backwards_compatible(
    maze: LatticeMaze, tok_mode: TokenizationMode
):
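    """Mazes reconstructed via `from_tokens` should match between the modular tokenizer and its legacy equivalent."""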
    tokenizer = MazeTokenizerModular.from_legacy(tok_mode)
    toks = maze.as_tokens(tok_mode)
    # Equality test of `as_tokens` output done in a separate unit test
    maze_legacy: LatticeMaze = LatticeMaze.from_tokens(toks, tok_mode)
    maze_modular: LatticeMaze = LatticeMaze.from_tokens(toks, tokenizer)
    assert maze_modular == maze_legacy


# General functionality tests
# ===========================


@mark.parametrize(
    "el, result",
    [
        param(elem, result, id=elem.name)
        for elem, result in [
            (CoordTokenizers.CTT(), True),
            (CoordTokenizers.CTT(intra=True), True),
            (CoordTokenizers.UT(), True),
            (AdjListTokenizers.AdjListCoord(), True),
            (AdjListTokenizers.AdjListCoord(post=True), True),
            (TargetTokenizers.Unlabeled(post=True), True),
            (PathTokenizers.StepSequence(), True),
            (
                PathTokenizers.StepSequence(step_tokenizers=(StepTokenizers.Coord(),)),
                True,
            ),
            (
                PathTokenizers.StepSequence(
                    step_tokenizers=(
                        StepTokenizers.Coord(),
                        StepTokenizers.Coord(),
                    )
                ),
                False,
            ),
            (PromptSequencers.AOP(), True),
            (PromptSequencers.AOP(path_tokenizer=PathTokenizers.StepSequence()), True),
            (
                PromptSequencers.AOP(
                    path_tokenizer=PathTokenizers.StepSequence(
                        step_tokenizers=(StepTokenizers.Coord(),)
                    )
                ),
                True,
            ),
            (
                PromptSequencers.AOP(
                    path_tokenizer=PathTokenizers.StepSequence(
                        step_tokenizers=(
                            StepTokenizers.Coord(),
                            StepTokenizers.Coord(),
                        )
                    )
                ),
                True,
            ),
        ]
    ],
)
def test_tokenizer_element_is_valid(el: _TokenizerElement, result: bool):
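    """Each `_TokenizerElement` above should report the expected `is_valid()` result."""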
    assert el.is_valid() == result


@mark.parametrize(
    "tokenizer, result",
    [
        param(tokenizer, result, id=str(tokenizer))
        for tokenizer, result in [
            (MazeTokenizerModular(), True),
            (MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_CTT_indexed), True),
            (MazeTokenizerModular(prompt_sequencer=PromptSequencers.AOP()), False),
        ]
    ],
)
def test_is_legacy_equivalent(tokenizer: MazeTokenizerModular, result: bool):
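    """Only tokenizers expressible as a legacy MazeTokenizer should report `is_legacy_equivalent()`."""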
    assert tokenizer.is_legacy_equivalent() == result


def _helper_test_path_tokenizers(
    pt: PathTokenizers._PathTokenizer,
    maze: SolvedMaze,
    footprint_inds: Sequence[int],
):
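    """Check that `pt` emits the tokens implied by each of its step tokenizers for the given footprint indices."""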
    ct: CoordTokenizers._CoordTokenizer = CoordTokenizers.UT()
    path_toks: list[str] = pt.to_tokens(maze, ct)
    path_toks_set: set[str] = set(path_toks)
    footprint_inds: Int[np.ndarray, " footprint_index"] = np.array(footprint_inds)
    footprints: Int[np.ndarray, "footprint_index row_col=2"] = maze.solution[
        footprint_inds
    ]
    if StepTokenizers.Coord() in pt.step_tokenizers:
        non_steps: set[CoordTup] = set(tuple(c) for c in maze.solution) - set(
            tuple(c) for c in footprints
        )
        assert all([ct.to_tokens(coord)[0] in path_toks_set for coord in footprints])
        assert all([ct.to_tokens(coord)[0] not in path_toks_set for coord in non_steps])
    if StepTokenizers.Distance() in pt.step_tokenizers:
        distances: list[int] = footprint_inds[1:] - footprint_inds[:-1]
        assert (
            len(
                Counter(getattr(VOCAB, f"I_{d:03}") for d in distances)
                - Counter(path_toks)
            )
            == 0
        )
    if StepTokenizers.Cardinal() in pt.step_tokenizers:
        c = Counter(path_toks)
        assert (
            c[VOCAB.PATH_NORTH]
            + c[VOCAB.PATH_SOUTH]
            + c[VOCAB.PATH_EAST]
            + c[VOCAB.PATH_WEST]
            == len(footprint_inds) - 1
        )
    if StepTokenizers.Relative() in pt.step_tokenizers:
        c = Counter(path_toks)
        assert (
            c[VOCAB.PATH_LEFT]
            + c[VOCAB.PATH_RIGHT]
            + c[VOCAB.PATH_FORWARD]
            + c[VOCAB.PATH_BACKWARD]
            == len(footprint_inds) - 1
        )


@mark.parametrize(
    "pt,manual_maze",
    [
        param(tokenizer, maze_kv[1], id=f"{tokenizer.name}-{maze_kv[0]}")
        for maze_kv, tokenizer in itertools.product(
            ASCII_MAZES.items(),
            random.sample(
                list(
                    all_instances(
                        PathTokenizers._PathTokenizer,
                        {_TokenizerElement: lambda x: x.is_valid()},
                    )
                ),
                NUM_TOKENIZERS_TO_TEST,
            ),
        )
    ],
)
def test_path_tokenizers(pt: PathTokenizers._PathTokenizer, manual_maze: MANUAL_MAZE):
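    """Compute the footprint indices implied by the tokenizer's step size, then delegate checks to `_helper_test_path_tokenizers`."""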
    solved_maze: SolvedMaze = SolvedMaze.from_ascii("\n".join(manual_maze.ascii))
    match type(pt.step_size):
        case StepSizes.Singles:
            footprint_inds = range(solved_maze.solution.shape[0])
        case StepSizes.Straightaways:
            swy_coordtup_set: set[CoordTup] = set(
                tuple(c) for c in manual_maze.straightaway_footprints
            )
            footprint_inds: list[int] = [
                i
                for i, c in enumerate(solved_maze.solution)
                if tuple(c) in swy_coordtup_set
            ]
        case StepSizes.Forks:
            footprint_inds = solved_maze.get_solution_forking_points(
                always_include_endpoints=True
            )[0]
        case StepSizes.ForksAndStraightaways:
            swy_step_inds: list[int] = StepSizes.Straightaways()._step_single_indices(
                solved_maze
            )
            footprint_inds: Int[np.ndarray, " footprint_index"] = np.concatenate(
                (
                    solved_maze.get_solution_forking_points(
                        always_include_endpoints=True
                    )[0],
                    swy_step_inds,
                )
            )
            footprint_inds, _ = np.unique(footprint_inds, axis=0, return_index=True)
    _helper_test_path_tokenizers(
        pt,
        solved_maze,
        footprint_inds,
    )


@mark.parametrize(
    "ep,maze",
    [
        param(tokenizer, maze, id=f"{tokenizer.name}-maze[{i}]")
        for (i, maze), tokenizer in itertools.product(
            enumerate(MIXED_MAZES[:6]),
            all_instances(
                EdgePermuters._EdgePermuter,
                frozendict.frozendict({_TokenizerElement: lambda x: x.is_valid()}),
            ),
        )
    ],
)
def test_edge_permuters(ep: EdgePermuters._EdgePermuter, maze: LatticeMaze):
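    """Check shape and content invariants of each `_EdgePermuter` against an unshuffled edge list."""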
    edges: ConnectionArray = connection_list_to_adj_list(
        maze.connection_list, shuffle_d0=False, shuffle_d1=False
    )
    edges_copy: ConnectionArray = connection_list_to_adj_list(
        maze.connection_list, shuffle_d0=False, shuffle_d1=False
    )
    assert np.array_equal(edges, edges_copy)
    old_shape = edges.shape
    permuted: ConnectionArray = ep._permute(edges)
    match ep:
        case EdgePermuters.RandomCoords():
            assert permuted.shape == old_shape
            assert edges is permuted
            i = 0
            while np.array_equal(permuted, edges_copy) and i < 2:
                # Permute again in case the random selection for a small maze happened not to change anything
                permuted: ConnectionArray = ep._permute(permuted)
                i += 1
            assert not np.array_equal(permuted, edges_copy)
        case EdgePermuters.BothCoords():
            new_shape = old_shape[0] * 2, *old_shape[1:]
            n = old_shape[0]
            assert permuted.shape == new_shape
            assert np.array_equal(permuted[:n, ...], edges_copy)
            assert np.array_equal(permuted[:n, 0, :], permuted[n:, 1, :])
            assert np.array_equal(permuted[:n, 1, :], permuted[n:, 0, :])
            assert edges is not permuted


@mark.parametrize(
    "es,maze",
    [
        param(tokenizer, maze, id=f"{tokenizer.name}-maze[{i}]")
        for (i, maze), tokenizer in itertools.product(
            enumerate(MIXED_MAZES[:6]),
            all_instances(
                EdgeSubsets._EdgeSubset,
                frozendict.frozendict({_TokenizerElement: lambda x: x.is_valid()}),
            ),
        )
    ],
)
def test_edge_subsets(es: EdgeSubsets._EdgeSubset, maze: LatticeMaze):
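    """Each `_EdgeSubset` should return the expected number of unique, unit-manhattan-distance edges."""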
    edges: ConnectionArray = es._get_edges(maze)
    n: int = maze.grid_n
    match type(es):
        case EdgeSubsets.AllLatticeEdges:
            assert_shape: tuple = (2 * n * (n - 1), 2, 2)
        case EdgeSubsets.ConnectionEdges:
            if not es.walls:
                assert_shape: tuple = (np.count_nonzero(maze.connection_list), 2, 2)
            else:
                assert_shape: tuple = (
                    2 * n * (n - 1) - np.count_nonzero(maze.connection_list),
                    2,
                    2,
                )
    assert edges.dtype == np.int8
    assert assert_shape == tuple(edges.shape)
    assert assert_shape == tuple(
        np.unique(edges, axis=0).shape
    )  # All edges are unique (swapping leading/trailing coords is considered different)
    assert np.array_equal(
        manhattan_distance(edges), np.array([1] * assert_shape[0], dtype=np.int8)
    )


@mark.parametrize(
    "tok_elem,es,maze",
    [
        param(tok_elem, es, maze, id=f"{tok_elem.name}-{es.name}-maze[{i}]")
        for (i, maze), tok_elem, es in itertools.product(
            enumerate(MIXED_MAZES[:6]),
            all_instances(
                EdgeGroupings._EdgeGrouping,
                frozendict.frozendict(
                    {
                        _TokenizerElement: lambda x: x.is_valid(),
                        # Extra condition to prune the search space; it doesn't affect the functionality being tested
                        EdgeGroupings.ByLeadingCoord: lambda x: x.intra
                        and x.connection_token_ordinal == 1,
                    }
                ),
            ),
            all_instances(
                EdgeSubsets._EdgeSubset,
                frozendict.frozendict({_TokenizerElement: lambda x: x.is_valid()}),
            ),
        )
    ],
)
def test_edge_groupings(
    tok_elem: EdgeGroupings._EdgeGrouping,
    es: EdgeSubsets._EdgeSubset,
    maze: LatticeMaze,
):
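    """Each `_EdgeGrouping` should partition the edge list into groups whose edges share a leading coord, with the expected shapes."""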
    edges: ConnectionArray = es._get_edges(maze)
    # n: int = maze.grid_n
    groups: Sequence[ConnectionArray] = tok_elem._group_edges(edges)

    assert all(
        not np.any(np.diff(g[:, 0], axis=0)) for g in groups
    )  # Asserts that the leading coord is the same for all edges within each group
    match type(tok_elem):
        case EdgeGroupings.Ungrouped:
            assert_shape = edges.shape[0], 1, 2, 2
            assert tuple(groups.shape) == assert_shape
        case EdgeGroupings.ByLeadingCoord:
            assert len(groups) == np.unique(edges[:, 0, :], axis=0).shape[0]
            assert sum(g.shape[0] for g in groups) == edges.shape[0]
            # trailing_coords: list[CoordArray] = [g[:, 1, :] for g in groups]
            # vector_diffs is the position vector difference between the trailing coords of each group
            # These are stacked into a single array since we don't care about maintaining group separation
            vector_diffs: CoordArray = np.stack(
                list(flatten([np.diff(g[:, 1, :], axis=0) for g in groups], 1))
            )
            if tok_elem.shuffle_group:
                allowed_diffs = {(1, -1), (1, 1), (0, 2), (2, 0)}
                # The set of all 2D vectors between any 2 coords adjacent to a central coord
                allowed_diffs = allowed_diffs.union(
                    {(-d[0], -d[1]) for d in allowed_diffs}
                )
            else:
                # If vector_diffs are lexicographically sorted, these are the only possible values. Any other value indicates an error in sorting
                allowed_diffs = {(1, -1), (1, 1), (0, 2), (2, 0)}
            assert all(
                tuple(diff) in allowed_diffs for diff in np.unique(vector_diffs, axis=0)
            )


random.seed(GLOBAL_SEED)


@mark.parametrize(
    "tok_elem,maze",
    [
        param(tok_elem, maze, id=f"{tok_elem.name}-maze[{i}]")
        for (i, maze), tok_elem in itertools.product(
            enumerate(MAZE_DATASET),
            random.sample(
                list(
                    all_instances(
                        AdjListTokenizers._AdjListTokenizer,
                        {
                            _TokenizerElement: lambda x: x.is_valid(),
                        },
                    )
                ),
                100,
            ),
        )
    ],
)
def test_adjlist_tokenizers(
    tok_elem: AdjListTokenizers._AdjListTokenizer, maze: LatticeMaze
):
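    """Count the tokens emitted by each adjacency-list tokenizer and check them against the maze's expected edge and group counts."""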
    toks: list[str] = tok_elem.to_tokens(maze, CoordTokenizers.UT())
    tok_counter: Counter = Counter(toks)
    n: int = maze.grid_n
    edge_count: int = 1  # To be updated in match/case blocks
    group_count: int = 1  # To be updated in match/case blocks

    match tok_elem.edge_subset:
        case EdgeSubsets.AllLatticeEdges():
            edge_count *= n * (n - 1) * 2
        case EdgeSubsets.ConnectionEdges(walls=False):
            edge_count *= np.count_nonzero(maze.connection_list)
        case EdgeSubsets.ConnectionEdges(walls=True):
            edge_count *= n * (n - 1) * 2 - np.count_nonzero(maze.connection_list)
        case _:
            raise NotImplementedError(
                f"`match` case missing for {tok_elem.edge_subset=}"
            )

    match tok_elem.edge_permuter:
        case EdgePermuters.BothCoords():
            edge_count *= 2
            if tok_elem.edge_subset == EdgeSubsets.ConnectionEdges(walls=True):
                group_count *= np.count_nonzero(
                    lattice_max_degrees(n) - maze.coord_degrees() > 0
                )  # All coords with at least 1 adjacent wall, not counting outer boundaries
            else:
                group_count *= np.count_nonzero(
                    maze.coord_degrees() > 0
                )  # All coords with >0 connections
        case EdgePermuters.RandomCoords() | EdgePermuters.SortedCoords():
            edge_count *= 1
            group_count = None  # Group count is stochastic

    match type(tok_elem.edge_grouping):
        case EdgeGroupings.Ungrouped:
            group_count = edge_count  # Override all above cases
        case EdgeGroupings.ByLeadingCoord:
            if group_count is not None:
                group_count *= 1
            if tok_elem.edge_grouping.intra:
                assert tok_counter[VOCAB.ADJLIST_INTRA] == edge_count
        case _:
            raise NotImplementedError(
                f"`match` case missing for {tok_elem.edge_grouping=}"
            )

    match type(tok_elem):
        case AdjListTokenizers.AdjListCoord:
            pass
        case AdjListTokenizers.AdjListCardinal:
            assert (
                tok_counter[VOCAB.PATH_NORTH]
                + tok_counter[VOCAB.PATH_SOUTH]
                + tok_counter[VOCAB.PATH_EAST]
                + tok_counter[VOCAB.PATH_WEST]
                == edge_count
            )

    if group_count is not None:
        if tok_elem.pre:
            assert tok_counter[VOCAB.ADJLIST_PRE] == group_count
        if tok_elem.post:
            assert tok_counter[VOCAB.ADJACENCY_ENDLINE] == group_count

    assert tok_counter[VOCAB.CONNECTOR] + tok_counter[VOCAB.ADJLIST_WALL] == edge_count


@mark.parametrize(
    "tok_elem, valid",
    [
        param(
            tok_elem,
            valid,
            id=f"{repr(tok_elem)}",
        )
        for tok_elem, valid in (
            [
                (StepSizes.ForksAndStraightaways(), False),
                (StepSizes.Straightaways(), False),
                (StepSizes.Forks(), True),
                (AdjListTokenizers.AdjListCoord(), True),
                (AdjListTokenizers.AdjListCoord(pre=True), False),
                (AdjListTokenizers.AdjListCardinal(), True),
                (AdjListTokenizers.AdjListCardinal(pre=True), False),
                (EdgeGroupings.Ungrouped(), True),
                (EdgeGroupings.ByLeadingCoord(), False),
                (EdgeGroupings.ByLeadingCoord(connection_token_ordinal=0), False),
            ]
        )
    ],
)
def test_unsupported_elements(tok_elem: _TokenizerElement, valid: bool):
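    """Spot-check `is_valid()` on elements whose configurations are intentionally unsupported, alongside supported counterparts."""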
    assert tok_elem.is_valid() == valid