Coverage for tests/unit/maze_dataset/tokenization/test_tokenizer.py: 87% (253 statements)
import itertools
import random
import re
from collections import Counter
from itertools import product
from typing import Iterable, Sequence

import frozendict
import numpy as np
import pytest
from jaxtyping import Int
from muutils.misc import flatten

from maze_dataset import (
    VOCAB,
    ConnectionArray,
    Coord,
    CoordArray,
    CoordTup,
    LatticeMaze,
    MazeDataset,
    MazeDatasetConfig,
    SolvedMaze,
)
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset.generation.seed import GLOBAL_SEED
from maze_dataset.plotting.print_tokens import color_maze_tokens_AOTP
from maze_dataset.testing_utils import (
    ASCII_MAZES,
    LEGACY_AND_EQUIVALENT_TOKENIZERS,
    MANUAL_MAZE,
    MAZE_DATASET,
    MIXED_MAZES,
)
from maze_dataset.token_utils import (
    connection_list_to_adj_list,
    equal_except_adj_list_sequence,
)
from maze_dataset.tokenization import (
    AdjListTokenizers,
    CoordTokenizers,
    EdgeGroupings,
    EdgePermuters,
    EdgeSubsets,
    MazeTokenizer,
    MazeTokenizerModular,
    PathTokenizers,
    PromptSequencers,
    StepSizes,
    StepTokenizers,
    TargetTokenizers,
    TokenizationMode,
    _TokenizerElement,
)
from maze_dataset.utils import all_instances, lattice_max_degrees, manhattan_distance

# Use for test fuzzing when there are too many possible tokenizers
NUM_TOKENIZERS_TO_TEST = 100
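# Several parametrizations below enumerate every valid tokenizer element via
# `all_instances` and, where that space is too large to test exhaustively, draw a
# random sample of `NUM_TOKENIZERS_TO_TEST` elements from it with `random.sample`.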


@pytest.mark.parametrize(
    ("tok_mode", "max_grid_size"),
    list(
        product(
            [
                TokenizationMode.AOTP_UT_rasterized,
                TokenizationMode.AOTP_UT_uniform,
                TokenizationMode.AOTP_CTT_indexed,
            ],
            [None, 3, 100],
        ),
    ),
)
def test_tokenizer_serialization(tok_mode: TokenizationMode, max_grid_size: int | None):
    tokenizer: MazeTokenizer = MazeTokenizer(
        tokenization_mode=tok_mode,
        max_grid_size=max_grid_size,
    )

    serialized: dict = tokenizer.serialize()
    print(serialized)
    tokenizer_loaded: MazeTokenizer = MazeTokenizer.load(serialized)

    assert tokenizer == tokenizer_loaded


def test_tokenizer():
    cfg: MazeDatasetConfig = MazeDatasetConfig(
        name="test",
        grid_n=5,
        n_mazes=3,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
    )
    # to create a dataset, just call MazeDataset.from_config
    dataset: MazeDataset = MazeDataset.from_config(
        cfg,
        do_download=False,
        load_local=False,
        do_generate=True,
        save_local=False,
        verbose=True,
        gen_parallel=False,
    )

    for mode in (
        TokenizationMode.AOTP_UT_rasterized,
        TokenizationMode.AOTP_UT_uniform,
        TokenizationMode.AOTP_CTT_indexed,
    ):
        tokenizer: MazeTokenizer = MazeTokenizer(
            tokenization_mode=mode,
            max_grid_size=100,
        )

        assert tokenizer.name == f"maze_tokenizer-{mode.name}-g{100}"

        if mode == TokenizationMode.AOTP_CTT_indexed:
            assert tokenizer.node_strings_map is not None
            assert 100 < tokenizer.vocab_size < 200
        elif mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            assert tokenizer.node_strings_map is None
            assert tokenizer.vocab_size > 10000

        assert isinstance(tokenizer.token_arr, Iterable)
        assert all(isinstance(token, str) for token in tokenizer.token_arr)
        assert len(tokenizer.token_arr) == tokenizer.vocab_size

        print(tokenizer.summary())

        for maze in dataset:
            # clear the cache here so we test if it works fine on the next loop
            tokenizer.clear_cache()

            maze_tok = maze.as_tokens(maze_tokenizer=tokenizer)

            maze_encoded = tokenizer.encode(maze_tok)
            maze_decoded = tokenizer.decode(maze_encoded)

            assert maze_tok == maze_decoded

            # you can view the tokens directly
            print("\nRaw tokens:\n")
            print(" ".join(maze_tok))

            maze_recovered = SolvedMaze.from_tokens(maze_tok, maze_tokenizer=tokenizer)

            assert (maze.connection_list == maze_recovered.connection_list).all()

            # or color and print them in various formats
            print("\nColored tokens, raw html:\n")
            print(color_maze_tokens_AOTP(maze_tok, fmt="html"))
            print("\nColored tokens, raw latex:\n")
            print(color_maze_tokens_AOTP(maze_tok, fmt="latex"))
            print("\nColored tokens, terminal:\n")
            print(color_maze_tokens_AOTP(maze_tok, fmt="terminal"))


@pytest.mark.parametrize(
    ("maze_ascii", "tokenizer", "tokens"),
    [
        pytest.param(
            ASCII_MAZES[maze_ascii_key][1],  # maze_ascii
            tokenizer,  # tokenizer
            ASCII_MAZES[maze_ascii_key][0],  # tokens
            id=f"{tokenizer.name}_{maze_ascii_key}",
        )
        for maze_ascii_key, tokenizer in product(
            ["small_3x3", "big_10x10"],
            LEGACY_AND_EQUIVALENT_TOKENIZERS,
        )
    ],
)
def test_maze_to_tokens_roundtrip(
    maze_ascii: list[str],
    tokenizer: MazeTokenizer | MazeTokenizerModular,
    tokens: str,
):
    if not tokenizer.is_UT():
        # The hardcoded `tokens` assumes a UT tokenizer.
        # Here we modify `tokens` to match what an `AOTP_CTT_indexed` tokenizer would produce.
        tokens = re.sub(r"\(([0-9]),([0-9])\)", r"(\1 , \2)", tokens)
        tokens = re.sub(r"\(([0-9]+ ,)", r"( \1", tokens)
        tokens = re.sub(r"(, [0-9]+)\)", r"\1 )", tokens)
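        # net effect: e.g. "(1,2)" becomes "( 1 , 2 )", so the parentheses, digits,
        # and comma separator appear as individual whitespace-delimited tokens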
    tokens_original_split: list[str] = tokens.split()

    # join into a single string, and get a maze out
    ascii_str: str = "\n".join(maze_ascii)
    maze: SolvedMaze = SolvedMaze.from_ascii(ascii_str)

    # maze as tokens
    tokens_from_maze: list[str] = maze.as_tokens(tokenizer)

    # maze round trip
    maze_roundtrip: SolvedMaze = SolvedMaze.from_tokens(tokens_from_maze, tokenizer)
    tokens_roundtrip: list[str] = maze_roundtrip.as_tokens(tokenizer)

    # check that the mazes and tokens are all equivalent
    assert maze == maze_roundtrip
    assert equal_except_adj_list_sequence(tokens_original_split, tokens_from_maze)
    assert equal_except_adj_list_sequence(tokens_original_split, tokens_roundtrip)


@pytest.mark.parametrize(
    ("tok_mode", "max_grid_size", "result"),
    [
        pytest.param(
            tok_mode,
            max_grid_size,
            MazeTokenizer(tokenization_mode=tok_mode, max_grid_size=max_grid_size),
            id=f"{tok_mode}-{max_grid_size}",
        )
        for tok_mode, max_grid_size in [
            (TokenizationMode.AOTP_CTT_indexed, None),
            (TokenizationMode.AOTP_UT_rasterized, None),
            (TokenizationMode.AOTP_UT_uniform, None),
            (TokenizationMode.AOTP_CTT_indexed, 5),
        ]
    ],
)
def test_to_legacy_tokenizer(
    tok_mode: TokenizationMode,
    max_grid_size: int | None,
    result: MazeTokenizer,
):
    assert tok_mode.to_legacy_tokenizer(max_grid_size) == result


# MazeTokenizerModular tests
# ==========================

# Backwards compatibility tests
# =============================


@pytest.mark.parametrize(
    ("maze", "legacy_tokenizer"),
    [
        pytest.param(maze[0], tok_spec, id=f"{tok_spec.value}-maze{maze[1]}")
        for maze, tok_spec in itertools.product(
            [(maze, i) for i, maze in enumerate(MIXED_MAZES)],
            [tok_mode for tok_mode in TokenizationMode],  # noqa: C416
        )
    ],
)
def test_to_tokens_backwards_compatible(
    maze: SolvedMaze,
    legacy_tokenizer: TokenizationMode,
):
    tokenizer: MazeTokenizerModular = MazeTokenizerModular.from_legacy(legacy_tokenizer)
    toks: list[str] = maze.as_tokens(tokenizer)
    toks2: list[str] = tokenizer.to_tokens(maze)
    toks_legacy: list[str] = maze.as_tokens(legacy_tokenizer)

    try:
        assert equal_except_adj_list_sequence(toks, toks_legacy)
        assert equal_except_adj_list_sequence(toks2, toks_legacy)
    except AssertionError as e:
        msg: str = (
            "Tokens from `as_tokens` and `to_tokens` should be equal to tokens from `as_tokens` with the legacy tokenizer.\n"
            f"{len(toks) = }, {len(toks2) = }, {len(toks_legacy) = }\n"
            f"{toks = }\n{toks2 = }\n{toks_legacy = }"
        )
        raise AssertionError(msg) from e


@pytest.mark.parametrize(
    ("coords", "legacy_tok_mode"),
    [
        pytest.param(
            coords,
            tok_mode,
            id=f"{tok_mode.value}-coords(type={type(coords[0])},len={len(coords)})",
        )
        for tok_mode, coords in itertools.product(
            [tok_mode for tok_mode in TokenizationMode],  # noqa: C416
            [
                *[[maze.start_pos] for maze in MAZE_DATASET.mazes[:2]],
                [maze.start_pos for maze in MAZE_DATASET.mazes],
                *[[tuple(maze.start_pos)] for maze in MAZE_DATASET.mazes[:2]],
                [tuple(maze.start_pos) for maze in MAZE_DATASET.mazes],
            ],
        )
    ],
)
def test_coords_to_strings_backwards_compatible(
    coords: list[Coord | CoordTup],
    legacy_tok_mode: TokenizationMode,
):
    tokenizer: MazeTokenizerModular = MazeTokenizerModular.from_legacy(legacy_tok_mode)
    legacy_tokenizer = MazeTokenizer(tokenization_mode=legacy_tok_mode)
    strings: list[str] = tokenizer.coords_to_strings(coords)
    strings_legacy: list[str] = legacy_tokenizer.coords_to_strings(coords)
    assert strings == strings_legacy


@pytest.mark.parametrize(
    ("maze", "tok_mode"),
    [
        pytest.param(maze[0], tok_spec, id=f"{tok_spec.value}-maze{maze[1]}")
        for maze, tok_spec in itertools.product(
            [(maze, i) for i, maze in enumerate(MIXED_MAZES)],
            [tok_mode for tok_mode in TokenizationMode],  # noqa: C416
        )
    ],
)
def test_from_tokens_backwards_compatible(
    maze: LatticeMaze,
    tok_mode: TokenizationMode,
):
    tokenizer = MazeTokenizerModular.from_legacy(tok_mode)
    toks = maze.as_tokens(tok_mode)
    # Equality test of `as_tokens` output done in a separate unit test
    maze_legacy: LatticeMaze = LatticeMaze.from_tokens(toks, tok_mode)
    maze: LatticeMaze = LatticeMaze.from_tokens(toks, tokenizer)
    assert maze == maze_legacy


# General functionality tests
# ===========================


@pytest.mark.parametrize(
    ("el", "result"),
    [
        pytest.param(elem, result, id=elem.name)
        for elem, result in [
            (CoordTokenizers.CTT(), True),
            (CoordTokenizers.CTT(intra=True), True),
            (CoordTokenizers.UT(), True),
            (AdjListTokenizers.AdjListCoord(), True),
            (AdjListTokenizers.AdjListCoord(post=True), True),
            (TargetTokenizers.Unlabeled(post=True), True),
            (PathTokenizers.StepSequence(), True),
            (
                PathTokenizers.StepSequence(step_tokenizers=(StepTokenizers.Coord(),)),
                True,
            ),
            (
                PathTokenizers.StepSequence(
                    step_tokenizers=(
                        StepTokenizers.Coord(),
                        StepTokenizers.Coord(),
                    ),
                ),
                False,
            ),
            (PromptSequencers.AOP(), True),
            (PromptSequencers.AOP(path_tokenizer=PathTokenizers.StepSequence()), True),
            (
                PromptSequencers.AOP(
                    path_tokenizer=PathTokenizers.StepSequence(
                        step_tokenizers=(StepTokenizers.Coord(),),
                    ),
                ),
                True,
            ),
            (
                PromptSequencers.AOP(
                    path_tokenizer=PathTokenizers.StepSequence(
                        step_tokenizers=(
                            StepTokenizers.Coord(),
                            StepTokenizers.Coord(),
                        ),
                    ),
                ),
                True,
            ),
        ]
    ],
)
def test_tokenizer_element_is_valid(el: _TokenizerElement, result: bool):
    assert el.is_valid() == result


@pytest.mark.parametrize(
    ("tokenizer", "result"),
    [
        pytest.param(tokenizer, result, id=str(tokenizer))
        for tokenizer, result in [
            (MazeTokenizerModular(), True),
            (MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_CTT_indexed), True),
            (MazeTokenizerModular(prompt_sequencer=PromptSequencers.AOP()), False),
        ]
    ],
)
def test_is_legacy_equivalent(tokenizer: MazeTokenizerModular, result: bool):
    assert tokenizer.is_legacy_equivalent() == result


def _helper_test_path_tokenizers(
    pt: PathTokenizers._PathTokenizer,
    maze: SolvedMaze,
    footprint_inds: Sequence[int],
):
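    """Check that `pt` emits the expected path tokens for `maze`.

    `footprint_inds` holds the solution indices expected for `pt.step_size`; the
    assertions below depend on which `StepTokenizers` are in `pt.step_tokenizers`.
    """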
    ct: CoordTokenizers._CoordTokenizer = CoordTokenizers.UT()
    path_toks: list[str] = pt.to_tokens(maze, ct)
    path_toks_set: set[str] = set(path_toks)
    footprint_inds: Int[np.ndarray, " footprint_index"] = np.array(footprint_inds)
    footprints: Int[np.ndarray, "footprint_index row_col=2"] = maze.solution[
        footprint_inds
    ]
    if StepTokenizers.Coord() in pt.step_tokenizers:
        non_steps: set[CoordTup] = {tuple(c) for c in maze.solution} - {
            tuple(c) for c in footprints
        }
        assert all(ct.to_tokens(coord)[0] in path_toks_set for coord in footprints)
        assert all(ct.to_tokens(coord)[0] not in path_toks_set for coord in non_steps)
    if StepTokenizers.Distance() in pt.step_tokenizers:
        distances: list[int] = footprint_inds[1:] - footprint_inds[:-1]
        assert (
            len(
                Counter(getattr(VOCAB, f"I_{d:03}") for d in distances)
                - Counter(path_toks),
            )
            == 0
        )
    if StepTokenizers.Cardinal() in pt.step_tokenizers:
        c = Counter(path_toks)
        assert (
            c[VOCAB.PATH_NORTH]
            + c[VOCAB.PATH_SOUTH]
            + c[VOCAB.PATH_EAST]
            + c[VOCAB.PATH_WEST]
            == len(footprint_inds) - 1
        )
    if StepTokenizers.Relative() in pt.step_tokenizers:
        c = Counter(path_toks)
        assert (
            c[VOCAB.PATH_LEFT]
            + c[VOCAB.PATH_RIGHT]
            + c[VOCAB.PATH_FORWARD]
            + c[VOCAB.PATH_BACKWARD]
            == len(footprint_inds) - 1
        )


@pytest.mark.parametrize(
    ("pt", "manual_maze"),
    [
        pytest.param(tokenizer, maze_kv[1], id=f"{tokenizer.name}-{maze_kv[0]}")
        for maze_kv, tokenizer in itertools.product(
            ASCII_MAZES.items(),
            random.sample(
                list(
                    all_instances(
                        PathTokenizers._PathTokenizer,
                        {_TokenizerElement: lambda x: x.is_valid()},
                    ),
                ),
                NUM_TOKENIZERS_TO_TEST,
            ),
        )
    ],
)
def test_path_tokenizers(pt: PathTokenizers._PathTokenizer, manual_maze: MANUAL_MAZE):
    solved_maze: SolvedMaze = SolvedMaze.from_ascii("\n".join(manual_maze.ascii))
    match type(pt.step_size):
        case StepSizes.Singles:
            footprint_inds = range(solved_maze.solution.shape[0])
        case StepSizes.Straightaways:
            swy_coordtup_set: set[CoordTup] = {
                tuple(c) for c in manual_maze.straightaway_footprints
            }
            footprint_inds: list[int] = [
                i
                for i, c in enumerate(solved_maze.solution)
                if tuple(c) in swy_coordtup_set
            ]
        case StepSizes.Forks:
            footprint_inds = solved_maze.get_solution_forking_points(
                always_include_endpoints=True,
            )[0]
        case StepSizes.ForksAndStraightaways:
            swy_step_inds: list[int] = StepSizes.Straightaways()._step_single_indices(
                solved_maze,
            )
            footprint_inds: Int[np.ndarray, " footprint_index"] = np.concatenate(
                (
                    solved_maze.get_solution_forking_points(
                        always_include_endpoints=True,
                    )[0],
                    swy_step_inds,
                ),
            )
            footprint_inds, _ = np.unique(footprint_inds, axis=0, return_index=True)
    _helper_test_path_tokenizers(
        pt,
        solved_maze,
        footprint_inds,
    )


@pytest.mark.parametrize(
    ("ep", "maze"),
    [
        pytest.param(tokenizer, maze, id=f"{tokenizer.name}-maze[{i}]")
        for (i, maze), tokenizer in itertools.product(
            enumerate(MIXED_MAZES[:6]),
            all_instances(
                EdgePermuters._EdgePermuter,
                frozendict.frozendict({_TokenizerElement: lambda x: x.is_valid()}),
            ),
        )
    ],
)
def test_edge_permuters(ep: EdgePermuters._EdgePermuter, maze: LatticeMaze):
    edges: ConnectionArray = connection_list_to_adj_list(
        maze.connection_list,
        shuffle_d0=False,
        shuffle_d1=False,
    )
    edges_copy: ConnectionArray = connection_list_to_adj_list(
        maze.connection_list,
        shuffle_d0=False,
        shuffle_d1=False,
    )
    assert np.array_equal(edges, edges_copy)
    old_shape = edges.shape
    permuted: ConnectionArray = ep._permute(edges)
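    # Expected behavior per the assertions below: RandomCoords permutes `edges` in
    # place (same object and shape), while BothCoords returns a new array containing
    # every edge in both orderings.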
    match ep:
        case EdgePermuters.RandomCoords():
            assert permuted.shape == old_shape
            assert edges is permuted
            i = 0
            while np.array_equal(permuted, edges_copy) and i < 2:
                # Permute again, since for small mazes the random selection may happen to leave the array unchanged
                permuted: ConnectionArray = ep._permute(permuted)
                i += 1
            assert not np.array_equal(permuted, edges_copy)
        case EdgePermuters.BothCoords():
            new_shape = old_shape[0] * 2, *old_shape[1:]
            n = old_shape[0]
            assert permuted.shape == new_shape
            assert np.array_equal(permuted[:n, ...], edges_copy)
            assert np.array_equal(permuted[:n, 0, :], permuted[n:, 1, :])
            assert np.array_equal(permuted[:n, 1, :], permuted[n:, 0, :])
            assert edges is not permuted


@pytest.mark.parametrize(
    ("es", "maze"),
    [
        pytest.param(tokenizer, maze, id=f"{tokenizer.name}-maze[{i}]")
        for (i, maze), tokenizer in itertools.product(
            enumerate(MIXED_MAZES[:6]),
            all_instances(
                EdgeSubsets._EdgeSubset,
                frozendict.frozendict({_TokenizerElement: lambda x: x.is_valid()}),
            ),
        )
    ],
)
def test_edge_subsets(es: EdgeSubsets._EdgeSubset, maze: LatticeMaze):
    edges: ConnectionArray = es._get_edges(maze)
    n: int = maze.grid_n
    match type(es):
        case EdgeSubsets.AllLatticeEdges:
            assert_shape: tuple = (2 * n * (n - 1), 2, 2)
        case EdgeSubsets.ConnectionEdges:
            if not es.walls:
                assert_shape: tuple = (np.count_nonzero(maze.connection_list), 2, 2)
            else:
                assert_shape: tuple = (
                    2 * n * (n - 1) - np.count_nonzero(maze.connection_list),
                    2,
                    2,
                )
    assert edges.dtype == np.int8
    assert assert_shape == tuple(edges.shape)
    assert assert_shape == tuple(
        np.unique(edges, axis=0).shape,
    )  # All edges are unique (swapping leading/trailing coords is considered different)
    assert np.array_equal(
        manhattan_distance(edges),
        np.array([1] * assert_shape[0], dtype=np.int8),
    )


@pytest.mark.parametrize(
    ("tok_elem", "es", "maze"),
    [
        # we do a little accessing private members here
        pytest.param(tok_elem, es, maze, id=f"{tok_elem.name}-{es.name}-maze[{i}]")
        for (i, maze), tok_elem, es in itertools.product(
            enumerate(MIXED_MAZES[:6]),
            all_instances(
                EdgeGroupings._EdgeGrouping,
                frozendict.frozendict(
                    {
                        _TokenizerElement: lambda x: x.is_valid(),
                        # Add a condition to prune the range space that doesn't affect functionality being tested
                        EdgeGroupings.ByLeadingCoord: lambda x: x.intra
                        and x.connection_token_ordinal == 1,
                    },
                ),
            ),
            all_instances(
                EdgeSubsets._EdgeSubset,
                frozendict.frozendict({_TokenizerElement: lambda x: x.is_valid()}),
            ),
        )
    ],
)
def test_edge_groupings(
    tok_elem: EdgeGroupings._EdgeGrouping,
    es: EdgeSubsets._EdgeSubset,
    maze: LatticeMaze,
):
    # we do a little more accessing private members here
    edges: ConnectionArray = es._get_edges(maze)
    # n: int = maze.grid_n
    groups: Sequence[ConnectionArray] = tok_elem._group_edges(edges)

    assert all(
        not np.any(np.diff(g[:, 0], axis=0)) for g in groups
    )  # Asserts that the leading coord is the same for all edges within each group
    match type(tok_elem):
        case EdgeGroupings.Ungrouped:
            assert_shape = edges.shape[0], 1, 2, 2
            assert tuple(groups.shape) == assert_shape
        case EdgeGroupings.ByLeadingCoord:
            assert len(groups) == np.unique(edges[:, 0, :], axis=0).shape[0]
            assert sum(g.shape[0] for g in groups) == edges.shape[0]
            # trailing_coords: list[CoordArray] = [g[:, 1, :] for g in groups]
            # vector_diffs is the position vector difference between the trailing coords of each group
            # These are stacked into a single array since we don't care about maintaining group separation
            vector_diffs: CoordArray = np.stack(
                list(flatten([np.diff(g[:, 1, :], axis=0) for g in groups], 1)),
            )
            if tok_elem.shuffle_group:
                allowed_diffs = {(1, -1), (1, 1), (0, 2), (2, 0)}
                # The set of all 2D vectors between any 2 coords adjacent to a central coord
                allowed_diffs = allowed_diffs.union(
                    {(-d[0], -d[1]) for d in allowed_diffs},
                )
            else:
                # If vector_diffs are lexicographically sorted, these are the only possible values. Any other value indicates an error in sorting
                allowed_diffs = {(1, -1), (1, 1), (0, 2), (2, 0)}
            assert all(
                tuple(diff) in allowed_diffs for diff in np.unique(vector_diffs, axis=0)
            )


random.seed(GLOBAL_SEED)
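# Re-seed so the `random.sample` call in the parametrization below is reproducible.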


@pytest.mark.parametrize(
    ("tok_elem", "maze"),
    [
        pytest.param(tok_elem, maze, id=f"{tok_elem.name}-maze[{i}]")
        for (i, maze), tok_elem in itertools.product(
            enumerate(MAZE_DATASET),
            random.sample(
                list(
                    all_instances(
                        # yes we access a private member
                        AdjListTokenizers._AdjListTokenizer,
                        {
                            _TokenizerElement: lambda x: x.is_valid(),
                        },
                    ),
                ),
                100,
            ),
        )
    ],
)
# too many branches and "too complex" but whatever
def test_adjlist_tokenizers(  # noqa: PLR0912,C901
    tok_elem: AdjListTokenizers._AdjListTokenizer,
    maze: LatticeMaze,
):
    toks: list[str] = tok_elem.to_tokens(maze, CoordTokenizers.UT())
    tok_counter: Counter = Counter(toks)
    n: int = maze.grid_n
    edge_count: int = 1  # To be updated in match/case blocks
    group_count: int = 1  # To be updated in match/case blocks

    match tok_elem.edge_subset:
        case EdgeSubsets.AllLatticeEdges():
            edge_count *= n * (n - 1) * 2
        case EdgeSubsets.ConnectionEdges(walls=False):
            edge_count *= np.count_nonzero(maze.connection_list)
        case EdgeSubsets.ConnectionEdges(walls=True):
            edge_count *= n * (n - 1) * 2 - np.count_nonzero(maze.connection_list)
        case _:
            msg: str = f"`match` case missing for {tok_elem.edge_subset = }"
            raise NotImplementedError(msg)

    match tok_elem.edge_permuter:
        case EdgePermuters.BothCoords():
            edge_count *= 2
            if tok_elem.edge_subset == EdgeSubsets.ConnectionEdges(walls=True):
                group_count *= np.count_nonzero(
                    lattice_max_degrees(n) - maze.coord_degrees() > 0,
                )  # All coords with at least 1 adjacent wall, not counting outer boundaries
            else:
                group_count *= np.count_nonzero(
                    maze.coord_degrees() > 0,
                )  # All coords with >0 connections
        case EdgePermuters.RandomCoords() | EdgePermuters.SortedCoords():
            edge_count *= 1
            group_count = None  # Group count is stochastic

    match type(tok_elem.edge_grouping):
        case EdgeGroupings.Ungrouped:
            group_count = edge_count  # Override all above cases
        case EdgeGroupings.ByLeadingCoord:
            if group_count is not None:
                group_count *= 1
            if tok_elem.edge_grouping.intra:
                assert tok_counter[VOCAB.ADJLIST_INTRA] == edge_count
        case _:
            msg: str = f"`match` case missing for {tok_elem.edge_grouping = }"
            raise NotImplementedError(msg)

    match type(tok_elem):
        case AdjListTokenizers.AdjListCoord:
            pass
        case AdjListTokenizers.AdjListCardinal:
            assert (
                tok_counter[VOCAB.PATH_NORTH]
                + tok_counter[VOCAB.PATH_SOUTH]
                + tok_counter[VOCAB.PATH_EAST]
                + tok_counter[VOCAB.PATH_WEST]
                == edge_count
            )

    if group_count is not None:
        if tok_elem.pre:
            assert tok_counter[VOCAB.ADJLIST_PRE] == group_count
        if tok_elem.post:
            assert tok_counter[VOCAB.ADJACENCY_ENDLINE] == group_count

    assert tok_counter[VOCAB.CONNECTOR] + tok_counter[VOCAB.ADJLIST_WALL] == edge_count


@pytest.mark.parametrize(
    ("tok_elem", "valid"),
    [
        pytest.param(
            tok_elem,
            valid,
            id=f"{tok_elem!r}",
        )
        for tok_elem, valid in (
            [
                (StepSizes.ForksAndStraightaways(), False),
                (StepSizes.Straightaways(), False),
                (StepSizes.Forks(), True),
                (AdjListTokenizers.AdjListCoord(), True),
                (AdjListTokenizers.AdjListCoord(pre=True), False),
                (AdjListTokenizers.AdjListCardinal(), True),
                (AdjListTokenizers.AdjListCardinal(pre=True), False),
                (EdgeGroupings.Ungrouped(), True),
                (EdgeGroupings.ByLeadingCoord(), False),
                (EdgeGroupings.ByLeadingCoord(connection_token_ordinal=0), False),
            ]
        )
    ],
)
def test_unsupported_elements(tok_elem: _TokenizerElement, valid: bool):
    assert tok_elem.is_valid() == valid