Coverage for tests\unit\maze_dataset\tokenization\test_maze_tokenization.py: 71%

14 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-23 12:49 -0700

1from pytest import mark, param 

2 

3from maze_dataset import ( 

4 LatticeMazeGenerators, 

5 MazeDataset, 

6 MazeDatasetConfig, 

7 SolvedMaze, 

8) 

9from maze_dataset.testing_utils import LEGACY_AND_EQUIVALENT_TOKENIZERS 

10from maze_dataset.tokenization import MazeTokenizer, MazeTokenizerModular 

11 

12 

@mark.parametrize(
    "tokenizer",
    [
        param(tokenizer, id=tokenizer.name)
        for tokenizer in LEGACY_AND_EQUIVALENT_TOKENIZERS
    ],
)
def test_tokenization_roundtrip(tokenizer: MazeTokenizer | MazeTokenizerModular):
    """Check that mazes are unchanged by a tokenize -> `from_tokens` round trip.

    A small DFS-generated dataset (5 mazes on a 5x5 grid) is tokenized two
    ways: once through `MazeDataset.as_tokens` and once maze-by-maze through
    each maze's own `as_tokens`. Both token sets are parsed back with
    `SolvedMaze.from_tokens` and every result must equal the source maze.

    Parametrized over every tokenizer in `LEGACY_AND_EQUIVALENT_TOKENIZERS`.
    """
    dataset: MazeDataset = MazeDataset.from_config(
        MazeDatasetConfig(
            name="test",
            grid_n=5,
            n_mazes=5,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        ),
        allow_generation_metadata_filter_mismatch=True,
    )

    # tokenize the dataset as a whole, and each maze individually
    dataset_tokenized: list[list[str]] = dataset.as_tokens(tokenizer)
    dataset_tokenized_individual: list[list[str]] = [
        maze.as_tokens(tokenizer) for maze in dataset.mazes
    ]

    # NOTE: the token sequences themselves are not compared against each other
    # because the order in the adjacency list is random; only the decoded
    # mazes are checked below.
    # TODO: find an order-independent way to compare the tokenizations directly

    mazes_roundtrip: list[SolvedMaze] = [
        SolvedMaze.from_tokens(
            tokens=maze_tokens,
            maze_tokenizer=tokenizer,
        )
        for maze_tokens in dataset_tokenized
    ]

    mazes_roundtrip_individual: list[SolvedMaze] = [
        SolvedMaze.from_tokens(
            tokens=maze_tokens,
            maze_tokenizer=tokenizer,
        )
        for maze_tokens in dataset_tokenized_individual
    ]

    # test roundtrip
    # `strict=True` makes a length mismatch (e.g. `as_tokens` dropping a maze)
    # fail loudly instead of silently skipping the unmatched mazes.
    for maze, maze_rt, maze_rt_indiv in zip(
        dataset.mazes,
        mazes_roundtrip,
        mazes_roundtrip_individual,
        strict=True,
    ):
        assert maze == maze_rt, f"maze: {maze}, maze_rt: {maze_rt}"
        assert maze == maze_rt_indiv, f"maze: {maze}, maze_rt_indiv: {maze_rt_indiv}"