Coverage for tests/unit/maze_dataset/tokenization/test_tokenizer.py: 87%

253 statements  

coverage.py v7.6.12, created at 2025-03-20 17:51 -0600

import itertools
import random
import re
from collections import Counter
from itertools import product
from typing import Iterable, Sequence

import frozendict
import numpy as np
import pytest
from jaxtyping import Int
from muutils.misc import flatten

from maze_dataset import (
    VOCAB,
    ConnectionArray,
    Coord,
    CoordArray,
    CoordTup,
    LatticeMaze,
    MazeDataset,
    MazeDatasetConfig,
    SolvedMaze,
)
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset.generation.seed import GLOBAL_SEED
from maze_dataset.plotting.print_tokens import color_maze_tokens_AOTP
from maze_dataset.testing_utils import (
    ASCII_MAZES,
    LEGACY_AND_EQUIVALENT_TOKENIZERS,
    MANUAL_MAZE,
    MAZE_DATASET,
    MIXED_MAZES,
)
from maze_dataset.token_utils import (
    connection_list_to_adj_list,
    equal_except_adj_list_sequence,
)
from maze_dataset.tokenization import (
    AdjListTokenizers,
    CoordTokenizers,
    EdgeGroupings,
    EdgePermuters,
    EdgeSubsets,
    MazeTokenizer,
    MazeTokenizerModular,
    PathTokenizers,
    PromptSequencers,
    StepSizes,
    StepTokenizers,
    TargetTokenizers,
    TokenizationMode,
    _TokenizerElement,
)
from maze_dataset.utils import all_instances, lattice_max_degrees, manhattan_distance

# Use for test fuzzing when there are too many possible tokenizers
NUM_TOKENIZERS_TO_TEST = 100
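# Overview note (inferred from the imports and tests in this module): the tests
# below cover both the legacy `MazeTokenizer` (selected via `TokenizationMode`)
# and the newer `MazeTokenizerModular`, which is assembled from
# `_TokenizerElement` components (prompt sequencers, coord / adjacency-list /
# target / path tokenizers and their sub-elements).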

@pytest.mark.parametrize(
    ("tok_mode", "max_grid_size"),
    list(
        product(
            [
                TokenizationMode.AOTP_UT_rasterized,
                TokenizationMode.AOTP_UT_uniform,
                TokenizationMode.AOTP_CTT_indexed,
            ],
            [None, 3, 100],
        ),
    ),
)
def test_tokenizer_serialization(tok_mode: TokenizationMode, max_grid_size: int | None):
    tokenizer: MazeTokenizer = MazeTokenizer(
        tokenization_mode=tok_mode,
        max_grid_size=max_grid_size,
    )

    serialized: dict = tokenizer.serialize()
    print(serialized)
    tokenizer_loaded: MazeTokenizer = MazeTokenizer.load(serialized)

    assert tokenizer == tokenizer_loaded

def test_tokenizer():
    cfg: MazeDatasetConfig = MazeDatasetConfig(
        name="test",
        grid_n=5,
        n_mazes=3,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
    )
    # to create a dataset, just call MazeDataset.from_config
    dataset: MazeDataset = MazeDataset.from_config(
        cfg,
        do_download=False,
        load_local=False,
        do_generate=True,
        save_local=False,
        verbose=True,
        gen_parallel=False,
    )

    for mode in (
        TokenizationMode.AOTP_UT_rasterized,
        TokenizationMode.AOTP_UT_uniform,
        TokenizationMode.AOTP_CTT_indexed,
    ):
        tokenizer: MazeTokenizer = MazeTokenizer(
            tokenization_mode=mode,
            max_grid_size=100,
        )

        assert tokenizer.name == f"maze_tokenizer-{mode.name}-g{100}"

        if mode == TokenizationMode.AOTP_CTT_indexed:
            assert tokenizer.node_strings_map is not None
            assert 100 < tokenizer.vocab_size < 200
        elif mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            assert tokenizer.node_strings_map is None
            assert tokenizer.vocab_size > 10000

        assert isinstance(tokenizer.token_arr, Iterable)
        assert all(isinstance(token, str) for token in tokenizer.token_arr)
        assert len(tokenizer.token_arr) == tokenizer.vocab_size

        print(tokenizer.summary())

        for maze in dataset:
            # clear the cache here so we test if it works fine on the next loop
            tokenizer.clear_cache()

            maze_tok = maze.as_tokens(maze_tokenizer=tokenizer)

            maze_encoded = tokenizer.encode(maze_tok)
            maze_decoded = tokenizer.decode(maze_encoded)

            assert maze_tok == maze_decoded

            # you can view the tokens directly
            print("\nRaw tokens:\n")
            print(" ".join(maze_tok))

            maze_recovered = SolvedMaze.from_tokens(maze_tok, maze_tokenizer=tokenizer)

            assert (maze.connection_list == maze_recovered.connection_list).all()

            # or color and print them in various formats
            print("\nColored tokens, raw html:\n")
            print(color_maze_tokens_AOTP(maze_tok, fmt="html"))
            print("\nColored tokens, raw latex:\n")
            print(color_maze_tokens_AOTP(maze_tok, fmt="latex"))
            print("\nColored tokens, terminal:\n")
            print(color_maze_tokens_AOTP(maze_tok, fmt="terminal"))

@pytest.mark.parametrize(
    ("maze_ascii", "tokenizer", "tokens"),
    [
        pytest.param(
            ASCII_MAZES[maze_ascii_key][1],  # maze_ascii
            tokenizer,  # tokenizer
            ASCII_MAZES[maze_ascii_key][0],  # tokens
            id=f"{tokenizer.name}_{maze_ascii_key}",
        )
        for maze_ascii_key, tokenizer in product(
            ["small_3x3", "big_10x10"],
            LEGACY_AND_EQUIVALENT_TOKENIZERS,
        )
    ],
)
def test_maze_to_tokens_roundtrip(
    maze_ascii: list[str],
    tokenizer: MazeTokenizer | MazeTokenizerModular,
    tokens: str,
):
    if not tokenizer.is_UT():
        # The hardcoded `tokens` assumes a UT tokenizer.
        # Here we modify `tokens` to match what an `AOTP_CTT_indexed` tokenizer would produce.
        tokens = re.sub(r"\(([0-9]),([0-9])\)", r"(\1 , \2)", tokens)
        tokens = re.sub(r"\(([0-9]+ ,)", r"( \1", tokens)
        tokens = re.sub(r"(, [0-9]+)\)", r"\1 )", tokens)
    tokens_original_split: list[str] = tokens.split()

    # join into a single string, and get a maze out
    ascii_str: str = "\n".join(maze_ascii)
    maze: SolvedMaze = SolvedMaze.from_ascii(ascii_str)

    # maze as tokens
    tokens_from_maze: list[str] = maze.as_tokens(tokenizer)

    # maze round trip
    maze_roundtrip: SolvedMaze = SolvedMaze.from_tokens(tokens_from_maze, tokenizer)
    tokens_roundtrip: list[str] = maze_roundtrip.as_tokens(tokenizer)

    # check that the mazes and tokens are all equivalent
    assert maze == maze_roundtrip
    assert equal_except_adj_list_sequence(tokens_original_split, tokens_from_maze)
    assert equal_except_adj_list_sequence(tokens_original_split, tokens_roundtrip)
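# For reference, a sketch of the adjustment made in `test_maze_to_tokens_roundtrip`
# above (illustrative only; the exact token strings come from the tokenizers
# themselves): a UT tokenizer emits a coordinate as a single token such as
# "(1,2)", while the regex rewrite turns it into "( 1 , 2 )", i.e. separate
# parenthesis, digit, and comma tokens delimited by whitespace, as an
# `AOTP_CTT_indexed` tokenizer would produce.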

@pytest.mark.parametrize(
    ("tok_mode", "max_grid_size", "result"),
    [
        pytest.param(
            tok_mode,
            max_grid_size,
            MazeTokenizer(tokenization_mode=tok_mode, max_grid_size=max_grid_size),
            id=f"{tok_mode}-{max_grid_size}",
        )
        for tok_mode, max_grid_size in [
            (TokenizationMode.AOTP_CTT_indexed, None),
            (TokenizationMode.AOTP_UT_rasterized, None),
            (TokenizationMode.AOTP_UT_uniform, None),
            (TokenizationMode.AOTP_CTT_indexed, 5),
        ]
    ],
)
def test_to_legacy_tokenizer(
    tok_mode: TokenizationMode,
    max_grid_size: int | None,
    result: MazeTokenizer,
):
    assert tok_mode.to_legacy_tokenizer(max_grid_size) == result


# MazeTokenizerModular tests
# =====================

# Backwards compatibility tests
# =============================
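# Note on the comparisons in the tests below: `equal_except_adj_list_sequence`
# is used instead of plain `==` because, as its name suggests, the
# adjacency-list section of the token sequence has no fixed edge order, so the
# modular and legacy tokenizers are only required to agree up to a reordering
# of that section.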

@pytest.mark.parametrize(
    ("maze", "legacy_tokenizer"),
    [
        pytest.param(maze[0], tok_spec, id=f"{tok_spec.value}-maze{maze[1]}")
        for maze, tok_spec in itertools.product(
            [(maze, i) for i, maze in enumerate(MIXED_MAZES)],
            [tok_mode for tok_mode in TokenizationMode],  # noqa: C416
        )
    ],
)
def test_to_tokens_backwards_compatible(
    maze: SolvedMaze,
    legacy_tokenizer: TokenizationMode,
):
    tokenizer: MazeTokenizerModular = MazeTokenizerModular.from_legacy(legacy_tokenizer)
    toks: list[str] = maze.as_tokens(tokenizer)
    toks2: list[str] = tokenizer.to_tokens(maze)
    toks_legacy: list[str] = maze.as_tokens(legacy_tokenizer)

    try:
        assert equal_except_adj_list_sequence(toks, toks_legacy)
        assert equal_except_adj_list_sequence(toks2, toks_legacy)
    except AssertionError as e:
        msg: str = (
            "Tokens from `as_tokens` and `to_tokens` should be equal to tokens from `as_tokens` with the legacy tokenizer.\n"
            f"{len(toks) = }, {len(toks2) = }, {len(toks_legacy) = }\n"
            f"{toks = }\n{toks2 = }\n{toks_legacy = }"
        )
        raise AssertionError(msg) from e

@pytest.mark.parametrize(
    ("coords", "legacy_tok_mode"),
    [
        pytest.param(
            coords,
            tok_mode,
            id=f"{tok_mode.value}-coords(type={type(coords[0])},len={len(coords)})",
        )
        for tok_mode, coords in itertools.product(
            [tok_mode for tok_mode in TokenizationMode],  # noqa: C416
            [
                *[[maze.start_pos] for maze in MAZE_DATASET.mazes[:2]],
                [maze.start_pos for maze in MAZE_DATASET.mazes],
                *[[tuple(maze.start_pos)] for maze in MAZE_DATASET.mazes[:2]],
                [tuple(maze.start_pos) for maze in MAZE_DATASET.mazes],
            ],
        )
    ],
)
def test_coords_to_strings_backwards_compatible(
    coords: list[Coord | CoordTup],
    legacy_tok_mode: TokenizationMode,
):
    tokenizer: MazeTokenizerModular = MazeTokenizerModular.from_legacy(legacy_tok_mode)
    legacy_tokenizer = MazeTokenizer(tokenization_mode=legacy_tok_mode)
    strings: list[str] = tokenizer.coords_to_strings(coords)
    strings_legacy: list[str] = legacy_tokenizer.coords_to_strings(coords)
    assert strings == strings_legacy

@pytest.mark.parametrize(
    ("maze", "tok_mode"),
    [
        pytest.param(maze[0], tok_spec, id=f"{tok_spec.value}-maze{maze[1]}")
        for maze, tok_spec in itertools.product(
            [(maze, i) for i, maze in enumerate(MIXED_MAZES)],
            [tok_mode for tok_mode in TokenizationMode],  # noqa: C416
        )
    ],
)
def test_from_tokens_backwards_compatible(
    maze: LatticeMaze,
    tok_mode: TokenizationMode,
):
    tokenizer = MazeTokenizerModular.from_legacy(tok_mode)
    toks = maze.as_tokens(tok_mode)
    # Equality test of `as_tokens` output done in a separate unit test
    maze_legacy: LatticeMaze = LatticeMaze.from_tokens(toks, tok_mode)
    maze: LatticeMaze = LatticeMaze.from_tokens(toks, tokenizer)
    assert maze == maze_legacy


# General functionality tests
# ===========================

@pytest.mark.parametrize(
    ("el", "result"),
    [
        pytest.param(elem, result, id=elem.name)
        for elem, result in [
            (CoordTokenizers.CTT(), True),
            (CoordTokenizers.CTT(intra=True), True),
            (CoordTokenizers.UT(), True),
            (AdjListTokenizers.AdjListCoord(), True),
            (AdjListTokenizers.AdjListCoord(post=True), True),
            (TargetTokenizers.Unlabeled(post=True), True),
            (PathTokenizers.StepSequence(), True),
            (
                PathTokenizers.StepSequence(step_tokenizers=(StepTokenizers.Coord(),)),
                True,
            ),
            (
                PathTokenizers.StepSequence(
                    step_tokenizers=(
                        StepTokenizers.Coord(),
                        StepTokenizers.Coord(),
                    ),
                ),
                False,
            ),
            (PromptSequencers.AOP(), True),
            (PromptSequencers.AOP(path_tokenizer=PathTokenizers.StepSequence()), True),
            (
                PromptSequencers.AOP(
                    path_tokenizer=PathTokenizers.StepSequence(
                        step_tokenizers=(StepTokenizers.Coord(),),
                    ),
                ),
                True,
            ),
            (
                PromptSequencers.AOP(
                    path_tokenizer=PathTokenizers.StepSequence(
                        step_tokenizers=(
                            StepTokenizers.Coord(),
                            StepTokenizers.Coord(),
                        ),
                    ),
                ),
                True,
            ),
        ]
    ],
)
def test_tokenizer_element_is_valid(el: _TokenizerElement, result: bool):
    assert el.is_valid() == result

@pytest.mark.parametrize(
    ("tokenizer", "result"),
    [
        pytest.param(tokenizer, result, id=str(tokenizer))
        for tokenizer, result in [
            (MazeTokenizerModular(), True),
            (MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_CTT_indexed), True),
            (MazeTokenizerModular(prompt_sequencer=PromptSequencers.AOP()), False),
        ]
    ],
)
def test_is_legacy_equivalent(tokenizer: MazeTokenizerModular, result: bool):
    assert tokenizer.is_legacy_equivalent() == result
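# A descriptive note on the helper below (inferred from its body): given a path
# tokenizer `pt`, a solved maze, and `footprint_inds` (the indices into
# `maze.solution` that the tokenizer's step size is expected to emit as
# "footprints"), it checks that each `StepTokenizer` configured on `pt`
# contributes the expected tokens: coordinate tokens appear for exactly the
# footprint coords, and distance / cardinal / relative step tokens appear in
# the expected counts.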

def _helper_test_path_tokenizers(
    pt: PathTokenizers._PathTokenizer,
    maze: SolvedMaze,
    footprint_inds: Sequence[int],
):
    ct: CoordTokenizers._CoordTokenizer = CoordTokenizers.UT()
    path_toks: list[str] = pt.to_tokens(maze, ct)
    path_toks_set: set[str] = set(path_toks)
    footprint_inds: Int[np.ndarray, " footprint_index"] = np.array(footprint_inds)
    footprints: Int[np.ndarray, "footprint_index row_col=2"] = maze.solution[
        footprint_inds
    ]
    if StepTokenizers.Coord() in pt.step_tokenizers:
        non_steps: set[CoordTup] = {tuple(c) for c in maze.solution} - {
            tuple(c) for c in footprints
        }
        assert all(ct.to_tokens(coord)[0] in path_toks_set for coord in footprints)
        assert all(ct.to_tokens(coord)[0] not in path_toks_set for coord in non_steps)
    if StepTokenizers.Distance() in pt.step_tokenizers:
        distances: list[int] = footprint_inds[1:] - footprint_inds[:-1]
        assert (
            len(
                Counter(getattr(VOCAB, f"I_{d:03}") for d in distances)
                - Counter(path_toks),
            )
            == 0
        )
    if StepTokenizers.Cardinal() in pt.step_tokenizers:
        c = Counter(path_toks)
        assert (
            c[VOCAB.PATH_NORTH]
            + c[VOCAB.PATH_SOUTH]
            + c[VOCAB.PATH_EAST]
            + c[VOCAB.PATH_WEST]
            == len(footprint_inds) - 1
        )
    if StepTokenizers.Relative() in pt.step_tokenizers:
        c = Counter(path_toks)
        assert (
            c[VOCAB.PATH_LEFT]
            + c[VOCAB.PATH_RIGHT]
            + c[VOCAB.PATH_FORWARD]
            + c[VOCAB.PATH_BACKWARD]
            == len(footprint_inds) - 1
        )
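# The fuzzed tests below enumerate tokenizer components with `all_instances`,
# keeping only those passing the `is_valid()` predicate, and (where the space
# is large) take a random sample. A minimal sketch of the pattern, using the
# same names that appear in the parametrize blocks below:
#
#   candidates = list(all_instances(
#       PathTokenizers._PathTokenizer,
#       {_TokenizerElement: lambda x: x.is_valid()},
#   ))
#   sampled = random.sample(candidates, NUM_TOKENIZERS_TO_TEST)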

@pytest.mark.parametrize(
    ("pt", "manual_maze"),
    [
        pytest.param(tokenizer, maze_kv[1], id=f"{tokenizer.name}-{maze_kv[0]}")
        for maze_kv, tokenizer in itertools.product(
            ASCII_MAZES.items(),
            random.sample(
                list(
                    all_instances(
                        PathTokenizers._PathTokenizer,
                        {_TokenizerElement: lambda x: x.is_valid()},
                    ),
                ),
                NUM_TOKENIZERS_TO_TEST,
            ),
        )
    ],
)
def test_path_tokenizers(pt: PathTokenizers._PathTokenizer, manual_maze: MANUAL_MAZE):
    solved_maze: SolvedMaze = SolvedMaze.from_ascii("\n".join(manual_maze.ascii))
    match type(pt.step_size):
        case StepSizes.Singles:
            footprint_inds = range(solved_maze.solution.shape[0])
        case StepSizes.Straightaways:
            swy_coordtup_set: set[CoordTup] = {
                tuple(c) for c in manual_maze.straightaway_footprints
            }
            footprint_inds: list[int] = [
                i
                for i, c in enumerate(solved_maze.solution)
                if tuple(c) in swy_coordtup_set
            ]
        case StepSizes.Forks:
            footprint_inds = solved_maze.get_solution_forking_points(
                always_include_endpoints=True,
            )[0]
        case StepSizes.ForksAndStraightaways:
            swy_step_inds: list[int] = StepSizes.Straightaways()._step_single_indices(
                solved_maze,
            )
            footprint_inds: Int[np.ndarray, " footprint_index"] = np.concatenate(
                (
                    solved_maze.get_solution_forking_points(
                        always_include_endpoints=True,
                    )[0],
                    swy_step_inds,
                ),
            )
            footprint_inds, _ = np.unique(footprint_inds, axis=0, return_index=True)
    _helper_test_path_tokenizers(
        pt,
        solved_maze,
        footprint_inds,
    )

@pytest.mark.parametrize(
    ("ep", "maze"),
    [
        pytest.param(tokenizer, maze, id=f"{tokenizer.name}-maze[{i}]")
        for (i, maze), tokenizer in itertools.product(
            enumerate(MIXED_MAZES[:6]),
            all_instances(
                EdgePermuters._EdgePermuter,
                frozendict.frozendict({_TokenizerElement: lambda x: x.is_valid()}),
            ),
        )
    ],
)
def test_edge_permuters(ep: EdgePermuters._EdgePermuter, maze: LatticeMaze):
    edges: ConnectionArray = connection_list_to_adj_list(
        maze.connection_list,
        shuffle_d0=False,
        shuffle_d1=False,
    )
    edges_copy: ConnectionArray = connection_list_to_adj_list(
        maze.connection_list,
        shuffle_d0=False,
        shuffle_d1=False,
    )
    assert np.array_equal(edges, edges_copy)
    old_shape = edges.shape
    permuted: ConnectionArray = ep._permute(edges)
    match ep:
        case EdgePermuters.RandomCoords():
            assert permuted.shape == old_shape
            assert edges is permuted
            i = 0
            while np.array_equal(permuted, edges_copy) and i < 2:
                # Permute again, since for small mazes the random selection may happen to change nothing
                permuted: ConnectionArray = ep._permute(permuted)
                i += 1
            assert not np.array_equal(permuted, edges_copy)
        case EdgePermuters.BothCoords():
            new_shape = old_shape[0] * 2, *old_shape[1:]
            n = old_shape[0]
            assert permuted.shape == new_shape
            assert np.array_equal(permuted[:n, ...], edges_copy)
            assert np.array_equal(permuted[:n, 0, :], permuted[n:, 1, :])
            assert np.array_equal(permuted[:n, 1, :], permuted[n:, 0, :])
            assert edges is not permuted
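# Edge-count arithmetic used in the shape assertions below: an n x n lattice
# has n * (n - 1) horizontal plus n * (n - 1) vertical edges, i.e.
# 2 * n * (n - 1) lattice edges in total; connection edges and wall edges
# partition this set.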

@pytest.mark.parametrize(
    ("es", "maze"),
    [
        pytest.param(tokenizer, maze, id=f"{tokenizer.name}-maze[{i}]")
        for (i, maze), tokenizer in itertools.product(
            enumerate(MIXED_MAZES[:6]),
            all_instances(
                EdgeSubsets._EdgeSubset,
                frozendict.frozendict({_TokenizerElement: lambda x: x.is_valid()}),
            ),
        )
    ],
)
def test_edge_subsets(es: EdgeSubsets._EdgeSubset, maze: LatticeMaze):
    edges: ConnectionArray = es._get_edges(maze)
    n: int = maze.grid_n
    match type(es):
        case EdgeSubsets.AllLatticeEdges:
            assert_shape: tuple = (2 * n * (n - 1), 2, 2)
        case EdgeSubsets.ConnectionEdges:
            if not es.walls:
                assert_shape: tuple = (np.count_nonzero(maze.connection_list), 2, 2)
            else:
                assert_shape: tuple = (
                    2 * n * (n - 1) - np.count_nonzero(maze.connection_list),
                    2,
                    2,
                )
    assert edges.dtype == np.int8
    assert assert_shape == tuple(edges.shape)
    assert assert_shape == tuple(
        np.unique(edges, axis=0).shape,
    )  # All edges are unique (swapping leading/trailing coords is considered different)
    assert np.array_equal(
        manhattan_distance(edges),
        np.array([1] * assert_shape[0], dtype=np.int8),
    )

@pytest.mark.parametrize(
    ("tok_elem", "es", "maze"),
    [
        # we do a little accessing private members here
        pytest.param(tok_elem, es, maze, id=f"{tok_elem.name}-{es.name}-maze[{i}]")
        for (i, maze), tok_elem, es in itertools.product(
            enumerate(MIXED_MAZES[:6]),
            all_instances(
                EdgeGroupings._EdgeGrouping,
                frozendict.frozendict(
                    {
                        _TokenizerElement: lambda x: x.is_valid(),
                        # Add a condition to prune the range space that doesn't affect functionality being tested
                        EdgeGroupings.ByLeadingCoord: lambda x: x.intra
                        and x.connection_token_ordinal == 1,
                    },
                ),
            ),
            all_instances(
                EdgeSubsets._EdgeSubset,
                frozendict.frozendict({_TokenizerElement: lambda x: x.is_valid()}),
            ),
        )
    ],
)
def test_edge_groupings(
    tok_elem: EdgeGroupings._EdgeGrouping,
    es: EdgeSubsets._EdgeSubset,
    maze: LatticeMaze,
):
    # we do a little more accessing private members here
    edges: ConnectionArray = es._get_edges(maze)
    # n: int = maze.grid_n
    groups: Sequence[ConnectionArray] = tok_elem._group_edges(edges)

    assert all(
        not np.any(np.diff(g[:, 0], axis=0)) for g in groups
    )  # Asserts that the leading coord is the same for all edges within each group
    match type(tok_elem):
        case EdgeGroupings.Ungrouped:
            assert_shape = edges.shape[0], 1, 2, 2
            assert tuple(groups.shape) == assert_shape
        case EdgeGroupings.ByLeadingCoord:
            assert len(groups) == np.unique(edges[:, 0, :], axis=0).shape[0]
            assert sum(g.shape[0] for g in groups) == edges.shape[0]
            # trailing_coords: list[CoordArray] = [g[:, 1, :] for g in groups]
            # vector_diffs is the position vector difference between the trailing coords of each group
            # These are stacked into a single array since we don't care about maintaining group separation
            vector_diffs: CoordArray = np.stack(
                list(flatten([np.diff(g[:, 1, :], axis=0) for g in groups], 1)),
            )
            if tok_elem.shuffle_group:
                allowed_diffs = {(1, -1), (1, 1), (0, 2), (2, 0)}
                # The set of all 2D vectors between any 2 coords adjacent to a central coord
                allowed_diffs = allowed_diffs.union(
                    {(-d[0], -d[1]) for d in allowed_diffs},
                )
            else:
                # If vector_diffs are lexicographically sorted, these are the only possible values.
                # Any other value indicates an error in sorting.
                allowed_diffs = {(1, -1), (1, 1), (0, 2), (2, 0)}
            assert all(
                tuple(diff) in allowed_diffs for diff in np.unique(vector_diffs, axis=0)
            )
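# A reading aid for the accounting in `test_adjlist_tokenizers` below:
# `edge_count` is the number of edges the adjacency-list tokenizer is expected
# to emit (doubled for the BothCoords permutation), and `group_count` is the
# expected number of adjacency-list groups (e.g. distinct leading coords);
# both depend on the EdgeSubset, EdgePermuter, and EdgeGrouping in use.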

random.seed(GLOBAL_SEED)


@pytest.mark.parametrize(
    ("tok_elem", "maze"),
    [
        pytest.param(tok_elem, maze, id=f"{tok_elem.name}-maze[{i}]")
        for (i, maze), tok_elem in itertools.product(
            enumerate(MAZE_DATASET),
            random.sample(
                list(
                    all_instances(
                        # yes we access a private member
                        AdjListTokenizers._AdjListTokenizer,
                        {
                            _TokenizerElement: lambda x: x.is_valid(),
                        },
                    ),
                ),
                100,
            ),
        )
    ],
)
# too many branches and "too complex" but whatever
def test_adjlist_tokenizers(  # noqa: PLR0912,C901
    tok_elem: AdjListTokenizers._AdjListTokenizer,
    maze: LatticeMaze,
):
    toks: list[str] = tok_elem.to_tokens(maze, CoordTokenizers.UT())
    tok_counter: Counter = Counter(toks)
    n: int = maze.grid_n
    edge_count: int = 1  # To be updated in match/case blocks
    group_count: int = 1  # To be updated in match/case blocks

    match tok_elem.edge_subset:
        case EdgeSubsets.AllLatticeEdges():
            edge_count *= n * (n - 1) * 2
        case EdgeSubsets.ConnectionEdges(walls=False):
            edge_count *= np.count_nonzero(maze.connection_list)
        case EdgeSubsets.ConnectionEdges(walls=True):
            edge_count *= n * (n - 1) * 2 - np.count_nonzero(maze.connection_list)
        case _:
            msg: str = f"`match` case missing for {tok_elem.edge_subset = }"
            raise NotImplementedError(msg)

    match tok_elem.edge_permuter:
        case EdgePermuters.BothCoords():
            edge_count *= 2
            if tok_elem.edge_subset == EdgeSubsets.ConnectionEdges(walls=True):
                group_count *= np.count_nonzero(
                    lattice_max_degrees(n) - maze.coord_degrees() > 0,
                )  # All coords with at least 1 adjacent wall, not counting outer boundaries
            else:
                group_count *= np.count_nonzero(
                    maze.coord_degrees() > 0,
                )  # All coords with >0 connections
        case EdgePermuters.RandomCoords() | EdgePermuters.SortedCoords():
            edge_count *= 1
            group_count = None  # Group count is stochastic

    match type(tok_elem.edge_grouping):
        case EdgeGroupings.Ungrouped:
            group_count = edge_count  # Override all above cases
        case EdgeGroupings.ByLeadingCoord:
            if group_count is not None:
                group_count *= 1
            if tok_elem.edge_grouping.intra:
                assert tok_counter[VOCAB.ADJLIST_INTRA] == edge_count
        case _:
            msg: str = f"`match` case missing for {tok_elem.edge_grouping = }"
            raise NotImplementedError(msg)

    match type(tok_elem):
        case AdjListTokenizers.AdjListCoord:
            pass
        case AdjListTokenizers.AdjListCardinal:
            assert (
                tok_counter[VOCAB.PATH_NORTH]
                + tok_counter[VOCAB.PATH_SOUTH]
                + tok_counter[VOCAB.PATH_EAST]
                + tok_counter[VOCAB.PATH_WEST]
                == edge_count
            )

    if group_count is not None:
        if tok_elem.pre:
            assert tok_counter[VOCAB.ADJLIST_PRE] == group_count
        if tok_elem.post:
            assert tok_counter[VOCAB.ADJACENCY_ENDLINE] == group_count

    assert tok_counter[VOCAB.CONNECTOR] + tok_counter[VOCAB.ADJLIST_WALL] == edge_count

@pytest.mark.parametrize(
    ("tok_elem", "valid"),
    [
        pytest.param(
            tok_elem,
            valid,
            id=f"{tok_elem!r}",
        )
        for tok_elem, valid in (
            [
                (StepSizes.ForksAndStraightaways(), False),
                (StepSizes.Straightaways(), False),
                (StepSizes.Forks(), True),
                (AdjListTokenizers.AdjListCoord(), True),
                (AdjListTokenizers.AdjListCoord(pre=True), False),
                (AdjListTokenizers.AdjListCardinal(), True),
                (AdjListTokenizers.AdjListCardinal(pre=True), False),
                (EdgeGroupings.Ungrouped(), True),
                (EdgeGroupings.ByLeadingCoord(), False),
                (EdgeGroupings.ByLeadingCoord(connection_token_ordinal=0), False),
            ]
        )
    ],
)
def test_unsupported_elements(tok_elem: _TokenizerElement, valid: bool):
    assert tok_elem.is_valid() == valid