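"""Unit tests for `maze_dataset.tokenization`.

Covers round-trip serialization of `MazeTokenizer`, maze-to-token round trips,
backwards compatibility between `MazeTokenizerModular` and the legacy
`MazeTokenizer`, and validity/behavior checks for individual `_TokenizerElement`
components (coord, adjacency-list, path, and prompt-sequencer tokenizers and
their edge subsets, permuters, and groupings).
"""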

import itertools
import random
import re
from collections import Counter
from itertools import product
from typing import Iterable, Sequence

import frozendict
import numpy as np
from jaxtyping import Int
from muutils.misc import flatten
from muutils.mlutils import GLOBAL_SEED
from pytest import mark, param

from maze_dataset import (
    VOCAB,
    ConnectionArray,
    Coord,
    CoordArray,
    CoordTup,
    LatticeMaze,
    MazeDataset,
    MazeDatasetConfig,
    SolvedMaze,
)
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset.plotting.print_tokens import color_maze_tokens_AOTP
from maze_dataset.testing_utils import (
    ASCII_MAZES,
    LEGACY_AND_EQUIVALENT_TOKENIZERS,
    MANUAL_MAZE,
    MAZE_DATASET,
    MIXED_MAZES,
)
from maze_dataset.token_utils import (
    connection_list_to_adj_list,
    equal_except_adj_list_sequence,
)
from maze_dataset.tokenization import (
    AdjListTokenizers,
    CoordTokenizers,
    EdgeGroupings,
    EdgePermuters,
    EdgeSubsets,
    MazeTokenizer,
    MazeTokenizerModular,
    PathTokenizers,
    PromptSequencers,
    StepSizes,
    StepTokenizers,
    TargetTokenizers,
    TokenizationMode,
    _TokenizerElement,
)
from maze_dataset.utils import all_instances, lattice_max_degrees, manhattan_distance

# Use for test fuzzing when there are too many possible tokenizers
NUM_TOKENIZERS_TO_TEST = 100

@mark.parametrize(
    "tok_mode, max_grid_size",
    list(
        product(
            [
                TokenizationMode.AOTP_UT_rasterized,
                TokenizationMode.AOTP_UT_uniform,
                TokenizationMode.AOTP_CTT_indexed,
            ],
            [None, 3, 100],
        )
    ),
)
def test_tokenizer_serialization(tok_mode: TokenizationMode, max_grid_size: int | None):
    tokenizer: MazeTokenizer = MazeTokenizer(
        tokenization_mode=tok_mode, max_grid_size=max_grid_size
    )

    serialized: dict = tokenizer.serialize()
    print(serialized)
    tokenizer_loaded: MazeTokenizer = MazeTokenizer.load(serialized)

    assert tokenizer == tokenizer_loaded

def test_tokenizer():
    cfg: MazeDatasetConfig = MazeDatasetConfig(
        name="test",
        grid_n=5,
        n_mazes=3,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
    )
    # to create a dataset, just call MazeDataset.from_config
    dataset: MazeDataset = MazeDataset.from_config(
        cfg,
        do_download=False,
        load_local=False,
        do_generate=True,
        save_local=False,
        verbose=True,
        gen_parallel=False,
    )

    for mode in (
        TokenizationMode.AOTP_UT_rasterized,
        TokenizationMode.AOTP_UT_uniform,
        TokenizationMode.AOTP_CTT_indexed,
    ):
        tokenizer: MazeTokenizer = MazeTokenizer(
            tokenization_mode=mode, max_grid_size=100
        )

        assert tokenizer.name == f"maze_tokenizer-{mode.name}-g{100}"

        if mode == TokenizationMode.AOTP_CTT_indexed:
            assert tokenizer.node_strings_map is not None
            assert 100 < tokenizer.vocab_size < 200
        elif mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            assert tokenizer.node_strings_map is None
            assert tokenizer.vocab_size > 10000

        assert isinstance(tokenizer.token_arr, Iterable)
        assert all(isinstance(token, str) for token in tokenizer.token_arr)
        assert len(tokenizer.token_arr) == tokenizer.vocab_size

        print(tokenizer.summary())

        for maze in dataset:
            # clear the cache here so we test if it works fine on the next loop
            tokenizer.clear_cache()

            maze_tok = maze.as_tokens(maze_tokenizer=tokenizer)

            maze_encoded = tokenizer.encode(maze_tok)
            maze_decoded = tokenizer.decode(maze_encoded)

            assert maze_tok == maze_decoded

            # you can view the tokens directly
            print("\nRaw tokens:\n")
            print(" ".join(maze_tok))

            maze_recovered = SolvedMaze.from_tokens(maze_tok, maze_tokenizer=tokenizer)

            assert (maze.connection_list == maze_recovered.connection_list).all()

            # or color and print them in various formats
            print("\nColored tokens, raw html:\n")
            print(color_maze_tokens_AOTP(maze_tok, fmt="html"))
            print("\nColored tokens, raw latex:\n")
            print(color_maze_tokens_AOTP(maze_tok, fmt="latex"))
            print("\nColored tokens, terminal:\n")
            print(color_maze_tokens_AOTP(maze_tok, fmt="terminal"))

@mark.parametrize(
    "maze_ascii, tokenizer, tokens",
    [
        param(
            ASCII_MAZES[maze_ascii_key][1],  # maze_ascii
            tokenizer,  # tok_mode
            ASCII_MAZES[maze_ascii_key][0],  # tokens
            id=f"{tokenizer.name}_{maze_ascii_key}",
        )
        for maze_ascii_key, tokenizer in product(
            ["small_3x3", "big_10x10"],
            LEGACY_AND_EQUIVALENT_TOKENIZERS,
        )
    ],
)
def test_maze_to_tokens_roundtrip(
    maze_ascii: list[str],
    tokenizer: MazeTokenizer | MazeTokenizerModular,
    tokens: str,
):
    if not tokenizer.is_UT():
        # The hardcoded `tokens` assumes a UT tokenizer.
        # Here we modify `tokens` to match what an `AOTP_CTT_indexed` tokenizer would produce.
        tokens = re.sub(r"\(([0-9]),([0-9])\)", r"(\1 , \2)", tokens)
        tokens = re.sub(r"\(([0-9]+ ,)", r"( \1", tokens)
        tokens = re.sub(r"(, [0-9]+)\)", r"\1 )", tokens)
    tokens_original_split: list[str] = tokens.split()

    # join into a single string, and get a maze out
    ascii_str: str = "\n".join(maze_ascii)
    maze: SolvedMaze = SolvedMaze.from_ascii(ascii_str)

    # maze as tokens
    tokens_from_maze: list[str] = maze.as_tokens(tokenizer)

    # maze round trip
    maze_roundtrip: SolvedMaze = SolvedMaze.from_tokens(tokens_from_maze, tokenizer)
    tokens_roundtrip: list[str] = maze_roundtrip.as_tokens(tokenizer)

    # check that the mazes and tokens are all equivalent
    assert maze == maze_roundtrip
    assert equal_except_adj_list_sequence(tokens_original_split, tokens_from_maze)
    assert equal_except_adj_list_sequence(tokens_original_split, tokens_roundtrip)

@mark.parametrize(
    "tok_mode, max_grid_size, result",
    [
        param(
            tok_mode,
            max_grid_size,
            MazeTokenizer(tokenization_mode=tok_mode, max_grid_size=max_grid_size),
            id=f"{tok_mode}-{max_grid_size}",
        )
        for tok_mode, max_grid_size in [
            (TokenizationMode.AOTP_CTT_indexed, None),
            (TokenizationMode.AOTP_UT_rasterized, None),
            (TokenizationMode.AOTP_UT_uniform, None),
            (TokenizationMode.AOTP_CTT_indexed, 5),
        ]
    ],
)
def test_to_legacy_tokenizer(
    tok_mode: TokenizationMode, max_grid_size: int | None, result: MazeTokenizer
):
    assert tok_mode.to_legacy_tokenizer(max_grid_size) == result

# MazeTokenizerModular tests
# ==========================

# Backwards compatibility tests
# =============================
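# These tests check that `MazeTokenizerModular.from_legacy(tok_mode)` reproduces the
# behavior of the corresponding legacy `MazeTokenizer`: identical token output for
# mazes and coordinates (up to adjacency-list ordering) and identical maze recovery
# from tokens.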

@mark.parametrize(
    "maze,legacy_tokenizer",
    [
        param(maze[0], tok_spec, id=f"{tok_spec.value}-maze{maze[1]}")
        for maze, tok_spec in itertools.product(
            [(maze, i) for i, maze in enumerate(MIXED_MAZES)],
            [tok_mode for tok_mode in TokenizationMode],
        )
    ],
)
def test_to_tokens_backwards_compatible(
    maze: SolvedMaze, legacy_tokenizer: TokenizationMode
):
    tokenizer: MazeTokenizerModular = MazeTokenizerModular.from_legacy(legacy_tokenizer)
    toks: list[str] = maze.as_tokens(tokenizer)
    toks2: list[str] = tokenizer.to_tokens(maze)
    toks_legacy: list[str] = maze.as_tokens(legacy_tokenizer)

    try:
        assert equal_except_adj_list_sequence(toks, toks_legacy)
        assert equal_except_adj_list_sequence(toks2, toks_legacy)
    except AssertionError as e:
        raise AssertionError(
            "Tokens from `as_tokens` and `to_tokens` should be equal to tokens from `as_tokens` with the legacy tokenizer.\n"
            f"{len(toks) = }, {len(toks2) = }, {len(toks_legacy) = }\n"
            f"{toks = }\n{toks2 = }\n{toks_legacy = }",
        ) from e

@mark.parametrize(
    "coords, legacy_tok_mode",
    [
        param(
            coords,
            tok_mode,
            id=f"{tok_mode.value}-coords(type={type(coords[0])},len={len(coords)})",
        )
        for tok_mode, coords in itertools.product(
            [tok_mode for tok_mode in TokenizationMode],
            [
                *[[maze.start_pos] for maze in MAZE_DATASET.mazes[:2]],
                [maze.start_pos for maze in MAZE_DATASET.mazes],
                *[[tuple(maze.start_pos)] for maze in MAZE_DATASET.mazes[:2]],
                [tuple(maze.start_pos) for maze in MAZE_DATASET.mazes],
            ],
        )
    ],
)
def test_coords_to_strings_backwards_compatible(
    coords: list[Coord | CoordTup], legacy_tok_mode: TokenizationMode
):
    tokenizer: MazeTokenizerModular = MazeTokenizerModular.from_legacy(legacy_tok_mode)
    legacy_tokenizer = MazeTokenizer(tokenization_mode=legacy_tok_mode)
    strings: list[str] = tokenizer.coords_to_strings(coords)
    strings_legacy: list[str] = legacy_tokenizer.coords_to_strings(coords)
    assert strings == strings_legacy

@mark.parametrize(
    "maze,tok_mode",
    [
        param(maze[0], tok_spec, id=f"{tok_spec.value}-maze{maze[1]}")
        for maze, tok_spec in itertools.product(
            [(maze, i) for i, maze in enumerate(MIXED_MAZES)],
            [tok_mode for tok_mode in TokenizationMode],
        )
    ],
)
def test_from_tokens_backwards_compatible(
    maze: LatticeMaze, tok_mode: TokenizationMode
):
    tokenizer = MazeTokenizerModular.from_legacy(tok_mode)
    toks = maze.as_tokens(tok_mode)
    # Equality test of `as_tokens` output done in a separate unit test
    maze_legacy: LatticeMaze = LatticeMaze.from_tokens(toks, tok_mode)
    maze: LatticeMaze = LatticeMaze.from_tokens(toks, tokenizer)
    assert maze == maze_legacy

# General functionality tests
# ===========================

@mark.parametrize(
    "el, result",
    [
        param(elem, result, id=elem.name)
        for elem, result in [
            (CoordTokenizers.CTT(), True),
            (CoordTokenizers.CTT(intra=True), True),
            (CoordTokenizers.UT(), True),
            (AdjListTokenizers.AdjListCoord(), True),
            (AdjListTokenizers.AdjListCoord(post=True), True),
            (TargetTokenizers.Unlabeled(post=True), True),
            (PathTokenizers.StepSequence(), True),
            (
                PathTokenizers.StepSequence(step_tokenizers=(StepTokenizers.Coord(),)),
                True,
            ),
            (
                PathTokenizers.StepSequence(
                    step_tokenizers=(
                        StepTokenizers.Coord(),
                        StepTokenizers.Coord(),
                    )
                ),
                False,
            ),
            (PromptSequencers.AOP(), True),
            (PromptSequencers.AOP(path_tokenizer=PathTokenizers.StepSequence()), True),
            (
                PromptSequencers.AOP(
                    path_tokenizer=PathTokenizers.StepSequence(
                        step_tokenizers=(StepTokenizers.Coord(),)
                    )
                ),
                True,
            ),
            (
                PromptSequencers.AOP(
                    path_tokenizer=PathTokenizers.StepSequence(
                        step_tokenizers=(
                            StepTokenizers.Coord(),
                            StepTokenizers.Coord(),
                        )
                    )
                ),
                True,
            ),
        ]
    ],
)
def test_tokenizer_element_is_valid(el: _TokenizerElement, result: bool):
    assert el.is_valid() == result

@mark.parametrize(
    "tokenizer, result",
    [
        param(tokenizer, result, id=str(tokenizer))
        for tokenizer, result in [
            (MazeTokenizerModular(), True),
            (MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_CTT_indexed), True),
            (MazeTokenizerModular(prompt_sequencer=PromptSequencers.AOP()), False),
        ]
    ],
)
def test_is_legacy_equivalent(tokenizer: MazeTokenizerModular, result: bool):
    assert tokenizer.is_legacy_equivalent() == result

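# Shared assertion helper for `test_path_tokenizers` below. Given a path tokenizer,
# a solved maze, and the indices of the solution coords ("footprints") that the
# tokenizer is expected to emit, check that the expected coordinate, distance,
# cardinal, and relative step tokens appear in the token output.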
def _helper_test_path_tokenizers(
    pt: PathTokenizers._PathTokenizer,
    maze: SolvedMaze,
    footprint_inds: Sequence[int],
):
    ct: CoordTokenizers._CoordTokenizer = CoordTokenizers.UT()
    path_toks: list[str] = pt.to_tokens(maze, ct)
    path_toks_set: set[str] = set(path_toks)
    footprint_inds: Int[np.ndarray, " footprint_index"] = np.array(footprint_inds)
    footprints: Int[np.ndarray, "footprint_index row_col=2"] = maze.solution[
        footprint_inds
    ]
    if StepTokenizers.Coord() in pt.step_tokenizers:
        non_steps: set[CoordTup] = set(tuple(c) for c in maze.solution) - set(
            tuple(c) for c in footprints
        )
        assert all([ct.to_tokens(coord)[0] in path_toks_set for coord in footprints])
        assert all([ct.to_tokens(coord)[0] not in path_toks_set for coord in non_steps])
    if StepTokenizers.Distance() in pt.step_tokenizers:
        distances: list[int] = footprint_inds[1:] - footprint_inds[:-1]
        assert (
            len(
                Counter(getattr(VOCAB, f"I_{d:03}") for d in distances)
                - Counter(path_toks)
            )
            == 0
        )
    if StepTokenizers.Cardinal() in pt.step_tokenizers:
        c = Counter(path_toks)
        assert (
            c[VOCAB.PATH_NORTH]
            + c[VOCAB.PATH_SOUTH]
            + c[VOCAB.PATH_EAST]
            + c[VOCAB.PATH_WEST]
            == len(footprint_inds) - 1
        )
    if StepTokenizers.Relative() in pt.step_tokenizers:
        c = Counter(path_toks)
        assert (
            c[VOCAB.PATH_LEFT]
            + c[VOCAB.PATH_RIGHT]
            + c[VOCAB.PATH_FORWARD]
            + c[VOCAB.PATH_BACKWARD]
            == len(footprint_inds) - 1
        )

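# Fuzz path tokenizers against the hand-constructed mazes in `ASCII_MAZES`: compute
# the expected footprint indices for the tokenizer's step size, then delegate the
# token-level assertions to `_helper_test_path_tokenizers`.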
@mark.parametrize(
    "pt,manual_maze",
    [
        param(tokenizer, maze_kv[1], id=f"{tokenizer.name}-{maze_kv[0]}")
        for maze_kv, tokenizer in itertools.product(
            ASCII_MAZES.items(),
            random.sample(
                list(
                    all_instances(
                        PathTokenizers._PathTokenizer,
                        {_TokenizerElement: lambda x: x.is_valid()},
                    )
                ),
                NUM_TOKENIZERS_TO_TEST,
            ),
        )
    ],
)
def test_path_tokenizers(pt: PathTokenizers._PathTokenizer, manual_maze: MANUAL_MAZE):
    solved_maze: SolvedMaze = SolvedMaze.from_ascii("\n".join(manual_maze.ascii))
    match type(pt.step_size):
        case StepSizes.Singles:
            footprint_inds = range(solved_maze.solution.shape[0])
        case StepSizes.Straightaways:
            swy_coordtup_set: set[CoordTup] = set(
                tuple(c) for c in manual_maze.straightaway_footprints
            )
            footprint_inds: list[int] = [
                i
                for i, c in enumerate(solved_maze.solution)
                if tuple(c) in swy_coordtup_set
            ]
        case StepSizes.Forks:
            footprint_inds = solved_maze.get_solution_forking_points(
                always_include_endpoints=True
            )[0]
        case StepSizes.ForksAndStraightaways:
            swy_step_inds: list[int] = StepSizes.Straightaways()._step_single_indices(
                solved_maze
            )
            footprint_inds: Int[np.ndarray, " footprint_index"] = np.concatenate(
                (
                    solved_maze.get_solution_forking_points(
                        always_include_endpoints=True
                    )[0],
                    swy_step_inds,
                )
            )
            footprint_inds, _ = np.unique(footprint_inds, axis=0, return_index=True)
    _helper_test_path_tokenizers(
        pt,
        solved_maze,
        footprint_inds,
    )

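# Check `_EdgePermuter` behavior on a fixed edge list: `RandomCoords` must return the
# same array object with the same shape and (possibly after re-permuting) actually
# change the edge ordering, while `BothCoords` must return a new, doubled array
# containing every edge in both coordinate orders.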
@mark.parametrize(
    "ep,maze",
    [
        param(tokenizer, maze, id=f"{tokenizer.name}-maze[{i}]")
        for (i, maze), tokenizer in itertools.product(
            enumerate(MIXED_MAZES[:6]),
            all_instances(
                EdgePermuters._EdgePermuter,
                frozendict.frozendict({_TokenizerElement: lambda x: x.is_valid()}),
            ),
        )
    ],
)
def test_edge_permuters(ep: EdgePermuters._EdgePermuter, maze: LatticeMaze):
    edges: ConnectionArray = connection_list_to_adj_list(
        maze.connection_list, shuffle_d0=False, shuffle_d1=False
    )
    edges_copy: ConnectionArray = connection_list_to_adj_list(
        maze.connection_list, shuffle_d0=False, shuffle_d1=False
    )
    assert np.array_equal(edges, edges_copy)
    old_shape = edges.shape
    permuted: ConnectionArray = ep._permute(edges)
    match ep:
        case EdgePermuters.RandomCoords():
            assert permuted.shape == old_shape
            assert edges is permuted
            i = 0
            while np.array_equal(permuted, edges_copy) and i < 2:
                # Permute again: for small mazes, a random permutation can happen to leave the array unchanged
                permuted: ConnectionArray = ep._permute(permuted)
                i += 1
            assert not np.array_equal(permuted, edges_copy)
        case EdgePermuters.BothCoords():
            new_shape = old_shape[0] * 2, *old_shape[1:]
            n = old_shape[0]
            assert permuted.shape == new_shape
            assert np.array_equal(permuted[:n, ...], edges_copy)
            assert np.array_equal(permuted[:n, 0, :], permuted[n:, 1, :])
            assert np.array_equal(permuted[:n, 1, :], permuted[n:, 0, :])
            assert edges is not permuted

@mark.parametrize(
    "es,maze",
    [
        param(tokenizer, maze, id=f"{tokenizer.name}-maze[{i}]")
        for (i, maze), tokenizer in itertools.product(
            enumerate(MIXED_MAZES[:6]),
            all_instances(
                EdgeSubsets._EdgeSubset,
                frozendict.frozendict({_TokenizerElement: lambda x: x.is_valid()}),
            ),
        )
    ],
)
def test_edge_subsets(es: EdgeSubsets._EdgeSubset, maze: LatticeMaze):
    edges: ConnectionArray = es._get_edges(maze)
    n: int = maze.grid_n
    match type(es):
        case EdgeSubsets.AllLatticeEdges:
            assert_shape: tuple = (2 * n * (n - 1), 2, 2)
        case EdgeSubsets.ConnectionEdges:
            if not es.walls:
                assert_shape: tuple = (np.count_nonzero(maze.connection_list), 2, 2)
            else:
                assert_shape: tuple = (
                    2 * n * (n - 1) - np.count_nonzero(maze.connection_list),
                    2,
                    2,
                )
    assert edges.dtype == np.int8
    assert assert_shape == tuple(edges.shape)
    assert assert_shape == tuple(
        np.unique(edges, axis=0).shape
    )  # All edges are unique (swapping leading/trailing coords is considered different)
    assert np.array_equal(
        manhattan_distance(edges), np.array([1] * assert_shape[0], dtype=np.int8)
    )

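# For every (edge grouping, edge subset) pair: group the maze's edges and check that
# each group shares a single leading coord, that group counts and sizes add up, and
# that the differences between successive trailing coords within a group are offsets
# that are possible between two coords adjacent to a common central coord.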
@mark.parametrize(
    "tok_elem,es,maze",
    [
        param(tok_elem, es, maze, id=f"{tok_elem.name}-{es.name}-maze[{i}]")
        for (i, maze), tok_elem, es in itertools.product(
            enumerate(MIXED_MAZES[:6]),
            all_instances(
                EdgeGroupings._EdgeGrouping,
                frozendict.frozendict(
                    {
                        _TokenizerElement: lambda x: x.is_valid(),
                        # Add a condition to prune the search space; it doesn't affect the functionality being tested
                        EdgeGroupings.ByLeadingCoord: lambda x: x.intra
                        and x.connection_token_ordinal == 1,
                    }
                ),
            ),
            all_instances(
                EdgeSubsets._EdgeSubset,
                frozendict.frozendict({_TokenizerElement: lambda x: x.is_valid()}),
            ),
        )
    ],
)
def test_edge_groupings(
    tok_elem: EdgeGroupings._EdgeGrouping,
    es: EdgeSubsets._EdgeSubset,
    maze: LatticeMaze,
):
    edges: ConnectionArray = es._get_edges(maze)
    # n: int = maze.grid_n
    groups: Sequence[ConnectionArray] = tok_elem._group_edges(edges)

    assert all(
        not np.any(np.diff(g[:, 0], axis=0)) for g in groups
    )  # The leading coord must be the same for all edges within each group
    match type(tok_elem):
        case EdgeGroupings.Ungrouped:
            assert_shape = edges.shape[0], 1, 2, 2
            assert tuple(groups.shape) == assert_shape
        case EdgeGroupings.ByLeadingCoord:
            assert len(groups) == np.unique(edges[:, 0, :], axis=0).shape[0]
            assert sum(g.shape[0] for g in groups) == edges.shape[0]
            # trailing_coords: list[CoordArray] = [g[:, 1, :] for g in groups]
            # vector_diffs are the position-vector differences between successive trailing coords within each group.
            # They are stacked into a single array since we don't care about maintaining group separation
            vector_diffs: CoordArray = np.stack(
                list(flatten([np.diff(g[:, 1, :], axis=0) for g in groups], 1))
            )
            if tok_elem.shuffle_group:
                allowed_diffs = {(1, -1), (1, 1), (0, 2), (2, 0)}
                # The set of all 2D vectors between any 2 coords adjacent to a central coord
                allowed_diffs = allowed_diffs.union(
                    {(-d[0], -d[1]) for d in allowed_diffs}
                )
            else:
                # If vector_diffs are lexicographically sorted, these are the only possible values.
                # Any other value indicates an error in sorting
                allowed_diffs = {(1, -1), (1, 1), (0, 2), (2, 0)}
            assert all(
                tuple(diff) in allowed_diffs for diff in np.unique(vector_diffs, axis=0)
            )

random.seed(GLOBAL_SEED)

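# For a fuzzed sample of adjacency-list tokenizers, derive the expected number of edge
# tokens (`edge_count`) and edge groups (`group_count`) from the tokenizer's edge
# subset, permuter, and grouping, then check those counts against the token output.
# `group_count` is set to None when it cannot be predicted deterministically.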
@mark.parametrize(
    "tok_elem,maze",
    [
        param(tok_elem, maze, id=f"{tok_elem.name}-maze[{i}]")
        for (i, maze), tok_elem in itertools.product(
            enumerate(MAZE_DATASET),
            random.sample(
                list(
                    all_instances(
                        AdjListTokenizers._AdjListTokenizer,
                        {
                            _TokenizerElement: lambda x: x.is_valid(),
                        },
                    )
                ),
                100,
            ),
        )
    ],
)
def test_adjlist_tokenizers(
    tok_elem: AdjListTokenizers._AdjListTokenizer, maze: LatticeMaze
):
    toks: list[str] = tok_elem.to_tokens(maze, CoordTokenizers.UT())
    tok_counter: Counter = Counter(toks)
    n: int = maze.grid_n
    edge_count: int = 1  # To be updated in match/case blocks
    group_count: int | None = 1  # To be updated in match/case blocks

    match tok_elem.edge_subset:
        case EdgeSubsets.AllLatticeEdges():
            edge_count *= n * (n - 1) * 2
        case EdgeSubsets.ConnectionEdges(walls=False):
            edge_count *= np.count_nonzero(maze.connection_list)
        case EdgeSubsets.ConnectionEdges(walls=True):
            edge_count *= n * (n - 1) * 2 - np.count_nonzero(maze.connection_list)
        case _:
            raise NotImplementedError(
                f"`match` case missing for {tok_elem.edge_subset=}"
            )

    match tok_elem.edge_permuter:
        case EdgePermuters.BothCoords():
            edge_count *= 2
            if tok_elem.edge_subset == EdgeSubsets.ConnectionEdges(walls=True):
                group_count *= np.count_nonzero(
                    lattice_max_degrees(n) - maze.coord_degrees() > 0
                )  # All coords with at least 1 adjacent wall, not counting the lattice boundary
            else:
                group_count *= np.count_nonzero(
                    maze.coord_degrees() > 0
                )  # All coords with >0 connections
        case EdgePermuters.RandomCoords() | EdgePermuters.SortedCoords():
            edge_count *= 1
            group_count = None  # Group count is stochastic

    match type(tok_elem.edge_grouping):
        case EdgeGroupings.Ungrouped:
            group_count = edge_count  # Override all above cases
        case EdgeGroupings.ByLeadingCoord:
            if group_count is not None:
                group_count *= 1
            if tok_elem.edge_grouping.intra:
                assert tok_counter[VOCAB.ADJLIST_INTRA] == edge_count
        case _:
            raise NotImplementedError(
                f"`match` case missing for {tok_elem.edge_grouping=}"
            )

    match type(tok_elem):
        case AdjListTokenizers.AdjListCoord:
            pass
        case AdjListTokenizers.AdjListCardinal:
            assert (
                tok_counter[VOCAB.PATH_NORTH]
                + tok_counter[VOCAB.PATH_SOUTH]
                + tok_counter[VOCAB.PATH_EAST]
                + tok_counter[VOCAB.PATH_WEST]
                == edge_count
            )

    if group_count is not None:
        if tok_elem.pre:
            assert tok_counter[VOCAB.ADJLIST_PRE] == group_count
        if tok_elem.post:
            assert tok_counter[VOCAB.ADJACENCY_ENDLINE] == group_count

    assert tok_counter[VOCAB.CONNECTOR] + tok_counter[VOCAB.ADJLIST_WALL] == edge_count

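# Spot-check `is_valid` on a hand-picked set of elements, including configurations
# that are defined but not currently supported and should therefore report invalid.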
@mark.parametrize(
    "tok_elem, valid",
    [
        param(
            tok_elem,
            valid,
            id=f"{repr(tok_elem)}",
        )
        for tok_elem, valid in (
            [
                (StepSizes.ForksAndStraightaways(), False),
                (StepSizes.Straightaways(), False),
                (StepSizes.Forks(), True),
                (AdjListTokenizers.AdjListCoord(), True),
                (AdjListTokenizers.AdjListCoord(pre=True), False),
                (AdjListTokenizers.AdjListCardinal(), True),
                (AdjListTokenizers.AdjListCardinal(pre=True), False),
                (EdgeGroupings.Ungrouped(), True),
                (EdgeGroupings.ByLeadingCoord(), False),
                (EdgeGroupings.ByLeadingCoord(connection_token_ordinal=0), False),
            ]
        )
    ],
)
def test_unsupported_elements(tok_elem: _TokenizerElement, valid: bool):
    assert tok_elem.is_valid() == valid