Coverage for tests\unit\maze_dataset\tokenization\test_maze_tokenization.py: 71%
14 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
1from pytest import mark, param
3from maze_dataset import (
4 LatticeMazeGenerators,
5 MazeDataset,
6 MazeDatasetConfig,
7 SolvedMaze,
8)
9from maze_dataset.testing_utils import LEGACY_AND_EQUIVALENT_TOKENIZERS
10from maze_dataset.tokenization import MazeTokenizer, MazeTokenizerModular
@mark.parametrize(
    "tokenizer",
    [
        param(tokenizer, id=tokenizer.name)
        for tokenizer in LEGACY_AND_EQUIVALENT_TOKENIZERS
    ],
)
def test_tokenization_roundtrip(tokenizer: MazeTokenizer | MazeTokenizerModular):
    """Check that tokenizing mazes and decoding the tokens reproduces the mazes.

    A small DFS-generated dataset is tokenized in two ways:

    - all at once, via ``MazeDataset.as_tokens``;
    - maze by maze, via ``maze.as_tokens``.

    Each token sequence is decoded with ``SolvedMaze.from_tokens``. Every
    decoded maze must compare equal to the maze it was produced from.

    The two token sequences for the same maze are deliberately NOT compared
    with each other. The order of the adjacency list in the tokenization is
    random, so equivalent mazes may yield different token orderings.
    """
    dataset: MazeDataset = MazeDataset.from_config(
        MazeDatasetConfig(
            name="test",
            grid_n=5,
            n_mazes=5,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        ),
        allow_generation_metadata_filter_mismatch=True,
    )

    # tokenize the whole dataset in one call ...
    dataset_tokenized: list[list[str]] = dataset.as_tokens(tokenizer)
    # ... and each maze individually
    dataset_tokenized_individual: list[list[str]] = [
        maze.as_tokens(tokenizer) for maze in dataset.mazes
    ]

    # Guard against a vacuous pass and against silent truncation by `zip`
    # below: there must be mazes, and one token sequence per maze from both paths.
    n_mazes: int = len(dataset_tokenized_individual)
    assert n_mazes > 0, "dataset contains no mazes; roundtrip test would be vacuous"
    assert len(dataset_tokenized) == n_mazes, (
        f"dataset.as_tokens returned {len(dataset_tokenized)} token sequences "
        f"for {n_mazes} mazes"
    )

    mazes_roundtrip: list[SolvedMaze] = [
        SolvedMaze.from_tokens(
            tokens=maze_tokens,
            maze_tokenizer=tokenizer,
        )
        for maze_tokens in dataset_tokenized
    ]

    mazes_roundtrip_individual: list[SolvedMaze] = [
        SolvedMaze.from_tokens(
            tokens=maze_tokens,
            maze_tokenizer=tokenizer,
        )
        for maze_tokens in dataset_tokenized_individual
    ]

    # NOTE: the raw token sequences are not compared here because adjacency-list
    # order is random; only the decoded mazes are required to match.
    for maze, maze_rt, maze_rt_indiv in zip(
        dataset.mazes, mazes_roundtrip, mazes_roundtrip_individual
    ):
        assert maze == maze_rt, f"maze: {maze}, maze_rt: {maze_rt}"
        assert maze == maze_rt_indiv, f"maze: {maze}, maze_rt_indiv: {maze_rt_indiv}"