Coverage for tests/unit/maze_dataset/tokenization/test_token_utils.py: 97%

174 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-11 00:49 -0600

1import itertools 

2from typing import Callable 

3 

4import frozendict 

5import numpy as np 

6import pytest 

7from jaxtyping import Int 

8 

9from maze_dataset import LatticeMaze 

10from maze_dataset.constants import VOCAB, Connection, ConnectionArray 

11from maze_dataset.generation import numpy_rng 

12from maze_dataset.testing_utils import GRID_N, MAZE_DATASET 

13from maze_dataset.token_utils import ( 

14 _coord_to_strings_UT, 

15 coords_to_strings, 

16 equal_except_adj_list_sequence, 

17 get_adj_list_tokens, 

18 get_origin_tokens, 

19 get_path_tokens, 

20 get_relative_direction, 

21 get_target_tokens, 

22 is_connection, 

23 strings_to_coords, 

24 tokens_between, 

25) 

26from maze_dataset.tokenization import ( 

27 PathTokenizers, 

28 StepTokenizers, 

29 get_tokens_up_to_path_start, 

30) 

31from maze_dataset.utils import ( 

32 FiniteValued, 

33 all_instances, 

34 lattice_connection_array, 

35 manhattan_distance, 

36) 

37 

# Shared test fixtures: each entry is a ``(token_list, tokenizer_name)`` pair
# consumed by the parametrized tests below via ``TEST_TOKEN_LISTS``.

# A 2x2 maze serialized with unique coordinate tokens, e.g. "(0,1)" is one
# token ("AOTP_UT" = adjacency list / origin / target / path, unique tokens).
MAZE_TOKENS: tuple[list[str], str] = (
	"<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
	"AOTP_UT",
)
# The same maze with each coordinate split into separate "(", "0", ",", "1",
# ")" tokens ("AOTP_CTT_indexed" = coordinate-tuple tokenization).
MAZE_TOKENS_AOTP_CTT_indexed: tuple[list[str], str] = (
	"<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
	"AOTP_CTT_indexed",
)
# Both fixtures together, for tests parametrized over both tokenizations.
TEST_TOKEN_LISTS: list[tuple[list[str], str]] = [
	MAZE_TOKENS,
	MAZE_TOKENS_AOTP_CTT_indexed,
]

50 

51 

52@pytest.mark.parametrize( 

53 ("toks", "tokenizer_name"), 

54 [ 

55 pytest.param( 

56 token_list[0], 

57 token_list[1], 

58 id=f"{token_list[1]}", 

59 ) 

60 for token_list in TEST_TOKEN_LISTS 

61 ], 

62) 

63def test_tokens_between(toks: list[str], tokenizer_name: str): 

64 result = tokens_between(toks, "<PATH_START>", "<PATH_END>") 

65 match tokenizer_name: 

66 case "AOTP_UT": 

67 assert result == ["(1,0)", "(1,1)"] 

68 case "AOTP_CTT_indexed": 

69 assert result == ["(", "1", ",", "0", ")", "(", "1", ",", "1", ")"] 

70 

71 # Normal case 

72 tokens = ["the", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] 

73 start_value = "quick" 

74 end_value = "over" 

75 assert tokens_between(tokens, start_value, end_value) == ["brown", "fox", "jumps"] 

76 

77 # Including start and end values 

78 assert tokens_between(tokens, start_value, end_value, True, True) == [ 

79 "quick", 

80 "brown", 

81 "fox", 

82 "jumps", 

83 "over", 

84 ] 

85 

86 # When start_value or end_value is not unique and except_when_tokens_not_unique is True 

87 with pytest.raises(ValueError): # noqa: PT011 

88 tokens_between(tokens, "the", "dog", False, False, True) 

89 

90 # When start_value or end_value is not unique and except_when_tokens_not_unique is False 

91 assert tokens_between(tokens, "the", "dog", False, False, False) == [ 

92 "quick", 

93 "brown", 

94 "fox", 

95 "jumps", 

96 "over", 

97 "the", 

98 "lazy", 

99 ] 

100 

101 # Empty tokens list 

102 with pytest.raises(ValueError): # noqa: PT011 

103 tokens_between([], "start", "end") 

104 

105 # start_value and end_value are the same 

106 with pytest.raises(ValueError): # noqa: PT011 

107 tokens_between(tokens, "fox", "fox") 

108 

109 # start_value or end_value not in the tokens list 

110 with pytest.raises(ValueError): # noqa: PT011 

111 tokens_between(tokens, "start", "end") 

112 

113 # start_value comes after end_value in the tokens list 

114 with pytest.raises(AssertionError): 

115 tokens_between(tokens, "over", "quick") 

116 

117 # start_value and end_value are at the beginning and end of the tokens list, respectively 

118 assert tokens_between(tokens, "the", "dog", True, True) == tokens 

119 

120 # Single element in the tokens list, which is the same as start_value and end_value 

121 with pytest.raises(ValueError): # noqa: PT011 

122 tokens_between(["fox"], "fox", "fox", True, True) 

123 

124 

@pytest.mark.parametrize(
	("toks", "tokenizer_name"),
	[
		pytest.param(tokens, name, id=name)
		for tokens, name in TEST_TOKEN_LISTS
	],
)
def test_tokens_between_out_of_order(toks: list[str], tokenizer_name: str):
	"""Passing the end marker as start (and vice versa) must raise."""
	assert tokenizer_name  # parameter exists only to label the test id
	with pytest.raises(AssertionError):
		tokens_between(toks, "<PATH_END>", "<PATH_START>")

140 

141 

142@pytest.mark.parametrize( 

143 ("toks", "tokenizer_name"), 

144 [ 

145 pytest.param( 

146 token_list[0], 

147 token_list[1], 

148 id=f"{token_list[1]}", 

149 ) 

150 for token_list in TEST_TOKEN_LISTS 

151 ], 

152) 

153def test_get_adj_list_tokens(toks: list[str], tokenizer_name: str): 

154 result = get_adj_list_tokens(toks) 

155 match tokenizer_name: 

156 case "AOTP_UT": 

157 expected = ( 

158 "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ;".split() 

159 ) 

160 case "AOTP_CTT_indexed": 

161 expected = "( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ;".split() 

162 assert result == expected 

163 

164 

165@pytest.mark.parametrize( 

166 ("toks", "tokenizer_name"), 

167 [ 

168 pytest.param( 

169 token_list[0], 

170 token_list[1], 

171 id=f"{token_list[1]}", 

172 ) 

173 for token_list in TEST_TOKEN_LISTS 

174 ], 

175) 

176def test_get_path_tokens(toks: list[str], tokenizer_name: str): 

177 result_notrim = get_path_tokens(toks) 

178 result_trim = get_path_tokens(toks, trim_end=True) 

179 match tokenizer_name: 

180 case "AOTP_UT": 

181 assert result_notrim == ["<PATH_START>", "(1,0)", "(1,1)", "<PATH_END>"] 

182 assert result_trim == ["(1,0)", "(1,1)"] 

183 case "AOTP_CTT_indexed": 

184 assert ( 

185 result_notrim == "<PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split() 

186 ) 

187 assert result_trim == "( 1 , 0 ) ( 1 , 1 )".split() 

188 

189 

190@pytest.mark.parametrize( 

191 ("toks", "tokenizer_name"), 

192 [ 

193 pytest.param( 

194 token_list[0], 

195 token_list[1], 

196 id=f"{token_list[1]}", 

197 ) 

198 for token_list in TEST_TOKEN_LISTS 

199 ], 

200) 

201def test_get_origin_tokens(toks: list[str], tokenizer_name: str): 

202 result = get_origin_tokens(toks) 

203 match tokenizer_name: 

204 case "AOTP_UT": 

205 assert result == ["(1,0)"] 

206 case "AOTP_CTT_indexed": 

207 assert result == "( 1 , 0 )".split() 

208 

209 

210@pytest.mark.parametrize( 

211 ("toks", "tokenizer_name"), 

212 [ 

213 pytest.param( 

214 token_list[0], 

215 token_list[1], 

216 id=f"{token_list[1]}", 

217 ) 

218 for token_list in TEST_TOKEN_LISTS 

219 ], 

220) 

221def test_get_target_tokens(toks: list[str], tokenizer_name: str): 

222 result = get_target_tokens(toks) 

223 match tokenizer_name: 

224 case "AOTP_UT": 

225 assert result == ["(1,1)"] 

226 case "AOTP_CTT_indexed": 

227 assert result == "( 1 , 1 )".split() 

228 

229 

230@pytest.mark.parametrize( 

231 ("toks", "tokenizer_name"), 

232 [ 

233 pytest.param( 

234 token_list[0], 

235 token_list[1], 

236 id=f"{token_list[1]}", 

237 ) 

238 for token_list in [MAZE_TOKENS] 

239 ], 

240) 

241def test_get_tokens_up_to_path_start_including_start( 

242 toks: list[str], 

243 tokenizer_name: str, 

244): 

245 # Dont test on `MAZE_TOKENS_AOTP_CTT_indexed` because this function doesn't support `AOTP_CTT_indexed` when `include_start_coord=True`. 

246 result = get_tokens_up_to_path_start(toks, include_start_coord=True) 

247 match tokenizer_name: 

248 case "AOTP_UT": 

249 expected = "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0)".split() 

250 case "AOTP_CTT_indexed": 

251 expected = "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 )".split() 

252 assert result == expected 

253 

254 

255@pytest.mark.parametrize( 

256 ("toks", "tokenizer_name"), 

257 [ 

258 pytest.param( 

259 token_list[0], 

260 token_list[1], 

261 id=f"{token_list[1]}", 

262 ) 

263 for token_list in TEST_TOKEN_LISTS 

264 ], 

265) 

266def test_get_tokens_up_to_path_start_excluding_start( 

267 toks: list[str], 

268 tokenizer_name: str, 

269): 

270 result = get_tokens_up_to_path_start(toks, include_start_coord=False) 

271 match tokenizer_name: 

272 case "AOTP_UT": 

273 expected = "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START>".split() 

274 case "AOTP_CTT_indexed": 

275 expected = "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START>".split() 

276 assert result == expected 

277 

278 

@pytest.mark.parametrize(
	("toks", "tokenizer_name"),
	[
		pytest.param(tokens, name, id=name)
		for tokens, name in TEST_TOKEN_LISTS
	],
)
def test_strings_to_coords(toks: list[str], tokenizer_name: str):
	"""`strings_to_coords` parses coordinate tokens, with `when_noncoord`
	controlling how separators ("<-->" / ";") and special tokens are handled.
	"""
	assert tokenizer_name  # parameter exists only to label the test id
	adj_list = get_adj_list_tokens(toks)

	# "skip" drops the separator tokens entirely
	expected_coords = [(0, 1), (1, 1), (1, 0), (1, 1), (0, 1), (0, 0)]
	assert strings_to_coords(adj_list, when_noncoord="skip") == expected_coords

	# "include" passes separators through unparsed
	expected_interleaved = [
		(0, 1),
		"<-->",
		(1, 1),
		";",
		(1, 0),
		"<-->",
		(1, 1),
		";",
		(0, 1),
		"<-->",
		(0, 0),
		";",
	]
	assert strings_to_coords(adj_list, when_noncoord="include") == expected_interleaved

	# "error" raises on the first non-coordinate token
	with pytest.raises(ValueError):  # noqa: PT011
		strings_to_coords(adj_list, when_noncoord="error")

	# a raw string is accepted in place of a token list
	text = "(1,2) <ADJLIST_START> (5,6)"
	assert strings_to_coords(text) == [(1, 2), (5, 6)]
	assert strings_to_coords(text, when_noncoord="skip") == [
		(1, 2),
		(5, 6),
	]
	assert strings_to_coords(
		text,
		when_noncoord="include",
	) == [(1, 2), "<ADJLIST_START>", (5, 6)]
	with pytest.raises(ValueError):  # noqa: PT011
		strings_to_coords(text, when_noncoord="error")

334 

335 

@pytest.mark.parametrize(
	("toks", "tokenizer_name"),
	[
		pytest.param(tokens, name, id=name)
		for tokens, name in TEST_TOKEN_LISTS
	],
)
def test_coords_to_strings(toks: list[str], tokenizer_name: str):
	"""`coords_to_strings` is the inverse of `strings_to_coords`: it renders
	coordinate tuples with the given formatter, with `when_noncoord`
	controlling treatment of non-coordinate entries.
	"""
	assert tokenizer_name  # parameter exists only to label the test id
	adj_list = get_adj_list_tokens(toks)
	mixed_coords = strings_to_coords(adj_list, when_noncoord="include")

	# "skip" drops the separator entries
	assert coords_to_strings(
		mixed_coords,
		coord_to_strings_func=_coord_to_strings_UT,
		when_noncoord="skip",
	) == ["(0,1)", "(1,1)", "(1,0)", "(1,1)", "(0,1)", "(0,0)"]

	# "include" passes separators through unchanged
	assert coords_to_strings(
		mixed_coords,
		coord_to_strings_func=_coord_to_strings_UT,
		when_noncoord="include",
	) == [
		"(0,1)",
		"<-->",
		"(1,1)",
		";",
		"(1,0)",
		"<-->",
		"(1,1)",
		";",
		"(0,1)",
		"<-->",
		"(0,0)",
		";",
	]

	# "error" raises on the first non-coordinate entry
	with pytest.raises(ValueError):  # noqa: PT011
		coords_to_strings(
			mixed_coords,
			coord_to_strings_func=_coord_to_strings_UT,
			when_noncoord="error",
		)

394 

395 

def test_equal_except_adj_list_sequence():
	"""`equal_except_adj_list_sequence` compares two maze token sequences,
	ignoring the ordering of connections (and of endpoints within a
	connection) inside the adjacency-list section; everything outside that
	section must match exactly. Raises ValueError if either sequence lacks
	the adjacency-list markers.
	"""
	# identical sequences are equal
	assert equal_except_adj_list_sequence(MAZE_TOKENS[0], MAZE_TOKENS[0])
	# different tokenizations of the same maze are NOT equal
	assert not equal_except_adj_list_sequence(
		MAZE_TOKENS[0],
		MAZE_TOKENS_AOTP_CTT_indexed[0],
	)
	# identical literal sequences
	assert equal_except_adj_list_sequence(
		"<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
		"<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
	)
	# reordering the connections within the adjacency list is still equal
	assert equal_except_adj_list_sequence(
		"<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
		"<ADJLIST_START> (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; (0,1) <--> (1,1) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
	)
	# swapping the endpoints of a single connection is still equal
	assert equal_except_adj_list_sequence(
		"<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
		"<ADJLIST_START> (1,1) <--> (0,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
	)
	# a different path (outside the adjacency list) breaks equality
	assert not equal_except_adj_list_sequence(
		"<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
		"<ADJLIST_START> (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; (0,1) <--> (1,1) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,1) (1,0) <PATH_END>".split(),
	)
	# a duplicated trailing <PATH_END> breaks equality
	assert not equal_except_adj_list_sequence(
		"<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
		"<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END> <PATH_END>".split(),
	)
	# one side missing <ORIGIN_START> breaks equality
	assert not equal_except_adj_list_sequence(
		"<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
		"<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
	)
	# only one side missing <ADJLIST_START> breaks equality (no ValueError)
	assert not equal_except_adj_list_sequence(
		"<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
		"(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
	)
	# both sides missing <ADJLIST_START>: the section cannot be located
	with pytest.raises(ValueError):  # noqa: PT011
		equal_except_adj_list_sequence(
			"(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
			"(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
		)
	# both sides missing <ADJLIST_END>: the section cannot be located
	with pytest.raises(ValueError):  # noqa: PT011
		equal_except_adj_list_sequence(
			"<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
			"<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
		)
	# only one side missing <ADJLIST_END> breaks equality (no ValueError)
	assert not equal_except_adj_list_sequence(
		"<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
		"<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
	)

	# CTT
	assert equal_except_adj_list_sequence(
		"<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
		"<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
	)
	# endpoint swap within a connection is also tolerated for CTT tokens
	assert equal_except_adj_list_sequence(
		"<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
		"<ADJLIST_START> ( 1 , 1 ) <--> ( 0 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
	)
	# This inactive test demonstrates the lack of robustness of the function for comparing source `LatticeMaze` objects.
	# See function documentation for details.
	# assert not equal_except_adj_list_sequence(
	# 	"<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
	# 	"<ADJLIST_START> ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split()
	# )

460 

461 

# @mivanit: this was really difficult to understand
# Each parametrized case is (type to enumerate, validation funcs, predicate
# over the resulting instance list).
@pytest.mark.parametrize(
	("type_", "validation_funcs", "assertion"),
	[
		pytest.param(
			type_,
			vfs,
			assertion,
			id=f"{i}-{type_.__name__}",
		)
		for i, (type_, vfs, assertion) in enumerate(
			[
				(
					# type
					PathTokenizers._PathTokenizer,
					# validation_funcs
					# no validation: the enumeration is unfiltered
					dict(),
					# assertion
					# unfiltered output should contain this (invalid) tokenizer
					lambda x: PathTokenizers.StepSequence(
						step_tokenizers=(StepTokenizers.Distance(),),
					)
					in x,
				),
				(
					# type
					PathTokenizers._PathTokenizer,
					# validation_funcs
					# keep only tokenizers whose own `is_valid` check passes
					{PathTokenizers._PathTokenizer: lambda x: x.is_valid()},
					# assertion
					# filtered output should exclude both invalid tokenizers
					lambda x: PathTokenizers.StepSequence(
						step_tokenizers=(StepTokenizers.Distance(),),
					)
					not in x
					and PathTokenizers.StepSequence(
						step_tokenizers=(
							StepTokenizers.Coord(),
							StepTokenizers.Coord(),
						),
					)
					not in x,
				),
			],
		)
	],
)
def test_all_instances2(
	type_: FiniteValued,
	validation_funcs: frozendict.frozendict[
		FiniteValued,
		Callable[[FiniteValued], bool],
	],
	assertion: Callable[[list[FiniteValued]], bool],
):
	"""`all_instances` enumerates every instance of a finite-valued type,
	optionally filtered by per-type validation functions; each case's
	`assertion` checks membership properties of that enumeration.
	"""
	assert assertion(all_instances(type_, validation_funcs))

516 

517 

@pytest.mark.parametrize(
	("coords", "result"),
	[
		pytest.param(np.array(triple), expected, id=f"{triple}")
		for triple, expected in [
			([[0, 0], [0, 1], [1, 1]], VOCAB.PATH_RIGHT),
			([[0, 0], [1, 0], [1, 1]], VOCAB.PATH_LEFT),
			([[0, 0], [0, 1], [0, 2]], VOCAB.PATH_FORWARD),
			([[0, 0], [0, 1], [0, 0]], VOCAB.PATH_BACKWARD),
			([[0, 0], [0, 1], [0, 1]], VOCAB.PATH_STAY),
			([[1, 1], [0, 1], [0, 0]], VOCAB.PATH_LEFT),
			([[1, 1], [1, 0], [0, 0]], VOCAB.PATH_RIGHT),
			([[0, 2], [0, 1], [0, 0]], VOCAB.PATH_FORWARD),
			([[0, 0], [0, 1], [0, 0]], VOCAB.PATH_BACKWARD),
			([[0, 1], [0, 1], [0, 0]], ValueError),
			([[0, 1], [1, 1], [0, 0]], ValueError),
			([[1, 0], [1, 1], [0, 0]], ValueError),
			([[0, 1], [0, 2], [0, 0]], ValueError),
			([[0, 1], [0, 0], [0, 0]], VOCAB.PATH_STAY),
			([[1, 1], [0, 0], [0, 1]], ValueError),
			([[1, 1], [0, 0], [1, 0]], ValueError),
			([[0, 2], [0, 0], [0, 1]], ValueError),
			([[0, 0], [0, 0], [0, 1]], ValueError),
			([[0, 1], [0, 0], [0, 1]], VOCAB.PATH_BACKWARD),
			([[-1, 0], [0, 0], [1, 0]], VOCAB.PATH_FORWARD),
			([[-1, 0], [0, 0], [0, 1]], VOCAB.PATH_LEFT),
			([[-1, 0], [0, 0], [-1, 0]], VOCAB.PATH_BACKWARD),
			([[-1, 0], [0, 0], [0, -1]], VOCAB.PATH_RIGHT),
			# wrong number of coordinates or wrong dimensionality -> error
			([[-1, 0], [0, 0], [1, 0], [2, 0]], ValueError),
			([[-1, 0], [0, 0]], ValueError),
			([[-1, 0, 0], [0, 0, 0]], ValueError),
		]
	],
)
def test_get_relative_direction(
	coords: Int[np.ndarray, "prev_cur_next=3 axis=2"],
	result: str | type[Exception],
):
	"""`get_relative_direction` maps a (previous, current, next) coordinate
	triple to a turn token, raising ValueError on non-adjacent or malformed
	input.
	"""
	expects_error = isinstance(result, type) and issubclass(result, Exception)
	if expects_error:
		with pytest.raises(result):
			get_relative_direction(coords)
	else:
		assert get_relative_direction(coords) == result

567 

568 

@pytest.mark.parametrize(
	("edges", "result"),
	[
		pytest.param(edge_array, expected, id=f"{edge_array}")
		for edge_array, expected in [
			# single edge: scalar distance
			(np.array([[0, 0], [0, 1]]), 1),
			(np.array([[1, 0], [0, 1]]), 2),
			(np.array([[-1, 0], [0, 1]]), 2),
			(np.array([[0, 0], [5, 3]]), 8),
			# batch of edges: vector of distances
			(
				np.array(
					[
						[[0, 0], [0, 1]],
						[[1, 0], [0, 1]],
						[[-1, 0], [0, 1]],
						[[0, 0], [5, 3]],
					],
				),
				[1, 2, 2, 8],
			),
			# batch of size one still returns a vector
			(np.array([[[0, 0], [5, 3]]]), [8]),
		]
	],
)
def test_manhattan_distance(
	edges: ConnectionArray | Connection,
	result: Int[np.ndarray, " edges"] | Int[np.ndarray, ""] | type[Exception],
):
	"""`manhattan_distance` computes |dx| + |dy| per edge, for a single edge
	or a batch of edges.
	"""
	expects_error = isinstance(result, type) and issubclass(result, Exception)
	if expects_error:
		with pytest.raises(result):
			manhattan_distance(edges)
	else:
		assert np.array_equal(
			manhattan_distance(edges),
			np.array(result, dtype=np.int8),
		)

608 

609 

@pytest.mark.parametrize(
	"n",
	[pytest.param(side) for side in [2, 3, 5, 20]],
)
def test_lattice_connection_arrray(n):
	"""`lattice_connection_array` yields every edge of an n x n lattice.

	NOTE(review): "arrray" typo in the test name is retained so existing
	test ids stay stable.
	"""
	edges = lattice_connection_array(n)
	# an n x n lattice has n*(n-1) horizontal plus n*(n-1) vertical edges
	expected_count = 2 * n * (n - 1)
	assert tuple(edges.shape) == (expected_count, 2, 2)
	# endpoints are ordered: second endpoint has the larger coordinate sum
	assert np.all(edges[:, 1].sum(axis=1) > edges[:, 0].sum(axis=1))
	# every edge is unique
	assert tuple(np.unique(edges, axis=0).shape) == (expected_count, 2, 2)

619 

620 

@pytest.mark.parametrize(
	("edges", "maze"),
	[
		pytest.param(
			make_edges(),
			maze,
			id=f"edges[{i}]; maze[{j}]",
		)
		for (i, make_edges), (j, maze) in itertools.product(
			enumerate(
				[
					# every edge of the full lattice
					lambda: lattice_connection_array(GRID_N),
					# same edges with endpoints reversed
					lambda: np.flip(lattice_connection_array(GRID_N), axis=1),
					# edges of a smaller lattice
					lambda: lattice_connection_array(GRID_N - 1),
					# random subsets of edges
					lambda: numpy_rng.choice(
						lattice_connection_array(GRID_N),
						2 * GRID_N,
						axis=0,
					),
					lambda: numpy_rng.choice(
						lattice_connection_array(GRID_N),
						1,
						axis=0,
					),
				],
			),
			enumerate(MAZE_DATASET.mazes),
		)
	],
)
def test_is_connection(edges: ConnectionArray, maze: LatticeMaze):
	"""`is_connection` must agree with a direct lookup into the maze's
	connection list, regardless of endpoint ordering within each edge.
	"""
	result = is_connection(edges, maze.connection_list)
	# canonicalize endpoint order before indexing the connection list
	ordered = np.sort(edges, axis=1)
	deltas = ordered[:, 1, :] - ordered[:, 0, :]
	# 0 -> connection is vertical (down), 1 -> horizontal (right)
	direction = (deltas[:, 0] == 0).astype(np.int8)
	expected = maze.connection_list[
		direction,
		ordered[:, 0, 0],
		ordered[:, 0, 1],
	]
	assert np.array_equal(result, expected)