Coverage for tests\unit\maze_dataset\tokenization\test_token_utils.py: 97%

172 statements  

coverage.py v7.6.12, created at 2025-02-23 12:49 -0700

import itertools
from typing import Callable

import frozendict
import numpy as np
import pytest
from jaxtyping import Int
from pytest import mark, param

from maze_dataset import LatticeMaze
from maze_dataset.constants import VOCAB, Connection, ConnectionArray
from maze_dataset.generation import numpy_rng
from maze_dataset.testing_utils import GRID_N, MAZE_DATASET
from maze_dataset.token_utils import (
    _coord_to_strings_UT,
    coords_to_strings,
    equal_except_adj_list_sequence,
    get_adj_list_tokens,
    get_origin_tokens,
    get_path_tokens,
    get_relative_direction,
    get_target_tokens,
    is_connection,
    strings_to_coords,
    tokens_between,
)
from maze_dataset.tokenization import (
    PathTokenizers,
    StepTokenizers,
    get_tokens_up_to_path_start,
)
from maze_dataset.utils import (
    FiniteValued,
    all_instances,
    lattice_connection_array,
    manhattan_distance,
)

MAZE_TOKENS: tuple[list[str], str] = (
    "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    "AOTP_UT",
)
MAZE_TOKENS_AOTP_CTT_indexed: tuple[list[str], str] = (
    "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    "AOTP_CTT_indexed",
)
TEST_TOKEN_LISTS: list[tuple[list[str], str]] = [
    MAZE_TOKENS,
    MAZE_TOKENS_AOTP_CTT_indexed,
]
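
# Naming note (inferred from the token strings above, not from the library
# itself): "UT" tokenizers emit a coordinate as a single unique token such as
# "(1,0)", while "CTT" tokenizers split a coordinate into separate tokens
# "(", "1", ",", "0", ")". "AOTP" reflects the section order of the serialized
# maze: Adjacency list, Origin, Target, Path.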

@mark.parametrize(
    "toks, tokenizer_name",
    [
        param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_tokens_between(toks: list[str], tokenizer_name: str):
    result = tokens_between(toks, "<PATH_START>", "<PATH_END>")
    match tokenizer_name:
        case "AOTP_UT":
            assert result == ["(1,0)", "(1,1)"]
        case "AOTP_CTT_indexed":
            assert result == ["(", "1", ",", "0", ")", "(", "1", ",", "1", ")"]

    # Normal case
    tokens = ["the", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"]
    start_value = "quick"
    end_value = "over"
    assert tokens_between(tokens, start_value, end_value) == ["brown", "fox", "jumps"]

    # Including start and end values
    assert tokens_between(tokens, start_value, end_value, True, True) == [
        "quick",
        "brown",
        "fox",
        "jumps",
        "over",
    ]

    # When start_value or end_value is not unique and except_when_tokens_not_unique is True
    with pytest.raises(ValueError):
        tokens_between(tokens, "the", "dog", False, False, True)

    # When start_value or end_value is not unique and except_when_tokens_not_unique is False
    assert tokens_between(tokens, "the", "dog", False, False, False) == [
        "quick",
        "brown",
        "fox",
        "jumps",
        "over",
        "the",
        "lazy",
    ]

    # Empty tokens list
    with pytest.raises(ValueError):
        tokens_between([], "start", "end")

    # start_value and end_value are the same
    with pytest.raises(ValueError):
        tokens_between(tokens, "fox", "fox")

    # start_value or end_value not in the tokens list
    with pytest.raises(ValueError):
        tokens_between(tokens, "start", "end")

    # start_value comes after end_value in the tokens list
    with pytest.raises(AssertionError):
        tokens_between(tokens, "over", "quick")

    # start_value and end_value are at the beginning and end of the tokens list, respectively
    assert tokens_between(tokens, "the", "dog", True, True) == tokens

    # Single element in the tokens list, which is the same as start_value and end_value
    with pytest.raises(ValueError):
        tokens_between(["fox"], "fox", "fox", True, True)
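
# A quick reference for the `tokens_between` contract as exercised above
# (the positional flag names are inferred from these calls; a sketch, not an
# authoritative signature):
#
#   tokens_between(
#       tokens, start_value, end_value,
#       include_start=False, include_end=False,
#       except_when_tokens_not_unique=False,
#   )
#
# It slices out the tokens between `start_value` and `end_value`, optionally
# including the delimiters themselves, and raises ValueError for empty input,
# missing delimiters, identical start/end values, or (when the final flag is
# set) delimiters that occur more than once.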

@mark.parametrize(
    "toks, tokenizer_name",
    [
        param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_tokens_between_out_of_order(toks: list[str], tokenizer_name: str):
    with pytest.raises(AssertionError):
        tokens_between(toks, "<PATH_END>", "<PATH_START>")

@mark.parametrize(
    "toks, tokenizer_name",
    [
        param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_get_adj_list_tokens(toks: list[str], tokenizer_name: str):
    result = get_adj_list_tokens(toks)
    match tokenizer_name:
        case "AOTP_UT":
            expected = (
                "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ;".split()
            )
        case "AOTP_CTT_indexed":
            expected = "( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ;".split()
    assert result == expected

@mark.parametrize(
    "toks, tokenizer_name",
    [
        param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_get_path_tokens(toks: list[str], tokenizer_name: str):
    result_notrim = get_path_tokens(toks)
    result_trim = get_path_tokens(toks, trim_end=True)
    match tokenizer_name:
        case "AOTP_UT":
            assert result_notrim == ["<PATH_START>", "(1,0)", "(1,1)", "<PATH_END>"]
            assert result_trim == ["(1,0)", "(1,1)"]
        case "AOTP_CTT_indexed":
            assert (
                result_notrim == "<PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split()
            )
            assert result_trim == "( 1 , 0 ) ( 1 , 1 )".split()

@mark.parametrize(
    "toks, tokenizer_name",
    [
        param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_get_origin_tokens(toks: list[str], tokenizer_name: str):
    result = get_origin_tokens(toks)
    match tokenizer_name:
        case "AOTP_UT":
            assert result == ["(1,0)"]
        case "AOTP_CTT_indexed":
            assert result == "( 1 , 0 )".split()

@mark.parametrize(
    "toks, tokenizer_name",
    [
        param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_get_target_tokens(toks: list[str], tokenizer_name: str):
    result = get_target_tokens(toks)
    match tokenizer_name:
        case "AOTP_UT":
            assert result == ["(1,1)"]
        case "AOTP_CTT_indexed":
            assert result == "( 1 , 1 )".split()

@mark.parametrize(
    "toks, tokenizer_name",
    [
        param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in [MAZE_TOKENS]
    ],
)
def test_get_tokens_up_to_path_start_including_start(
    toks: list[str], tokenizer_name: str
):
    # Don't test on `MAZE_TOKENS_AOTP_CTT_indexed`, because this function doesn't
    # support `AOTP_CTT_indexed` when `include_start_coord=True`.
    result = get_tokens_up_to_path_start(toks, include_start_coord=True)
    match tokenizer_name:
        case "AOTP_UT":
            expected = "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0)".split()
        case "AOTP_CTT_indexed":
            expected = "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 )".split()
    assert result == expected

@mark.parametrize(
    "toks, tokenizer_name",
    [
        param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_get_tokens_up_to_path_start_excluding_start(
    toks: list[str], tokenizer_name: str
):
    result = get_tokens_up_to_path_start(toks, include_start_coord=False)
    match tokenizer_name:
        case "AOTP_UT":
            expected = "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START>".split()
        case "AOTP_CTT_indexed":
            expected = "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START>".split()
    assert result == expected

@mark.parametrize(
    "toks, tokenizer_name",
    [
        param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_strings_to_coords(toks: list[str], tokenizer_name: str):
    adj_list = get_adj_list_tokens(toks)
    skipped = strings_to_coords(adj_list, when_noncoord="skip")
    included = strings_to_coords(adj_list, when_noncoord="include")

    assert skipped == [
        (0, 1),
        (1, 1),
        (1, 0),
        (1, 1),
        (0, 1),
        (0, 0),
    ]

    assert included == [
        (0, 1),
        "<-->",
        (1, 1),
        ";",
        (1, 0),
        "<-->",
        (1, 1),
        ";",
        (0, 1),
        "<-->",
        (0, 0),
        ";",
    ]

    with pytest.raises(ValueError):
        strings_to_coords(adj_list, when_noncoord="error")

    assert strings_to_coords("(1,2) <ADJLIST_START> (5,6)") == [(1, 2), (5, 6)]
    assert strings_to_coords("(1,2) <ADJLIST_START> (5,6)", when_noncoord="skip") == [
        (1, 2),
        (5, 6),
    ]
    assert strings_to_coords(
        "(1,2) <ADJLIST_START> (5,6)", when_noncoord="include"
    ) == [(1, 2), "<ADJLIST_START>", (5, 6)]
    with pytest.raises(ValueError):
        strings_to_coords("(1,2) <ADJLIST_START> (5,6)", when_noncoord="error")

@mark.parametrize(
    "toks, tokenizer_name",
    [
        param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_coords_to_strings(toks: list[str], tokenizer_name: str):
    adj_list = get_adj_list_tokens(toks)
    # config = MazeDatasetConfig(name="test", grid_n=2, n_mazes=1)
    coords = strings_to_coords(adj_list, when_noncoord="include")

    skipped = coords_to_strings(
        coords, coord_to_strings_func=_coord_to_strings_UT, when_noncoord="skip"
    )
    included = coords_to_strings(
        coords, coord_to_strings_func=_coord_to_strings_UT, when_noncoord="include"
    )

    assert skipped == [
        "(0,1)",
        "(1,1)",
        "(1,0)",
        "(1,1)",
        "(0,1)",
        "(0,0)",
    ]

    assert included == [
        "(0,1)",
        "<-->",
        "(1,1)",
        ";",
        "(1,0)",
        "<-->",
        "(1,1)",
        ";",
        "(0,1)",
        "<-->",
        "(0,0)",
        ";",
    ]

    with pytest.raises(ValueError):
        coords_to_strings(
            coords, coord_to_strings_func=_coord_to_strings_UT, when_noncoord="error"
        )
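
# `strings_to_coords` and `coords_to_strings` act as inverses over coordinate
# tokens; `when_noncoord` controls everything else: "skip" drops
# non-coordinate tokens, "include" passes them through unchanged, and "error"
# raises ValueError. A round-trip sketch built from the fixtures above
# (assumes the UT coordinate format):
#
#   adj = get_adj_list_tokens(MAZE_TOKENS[0])
#   assert coords_to_strings(
#       strings_to_coords(adj, when_noncoord="include"),
#       coord_to_strings_func=_coord_to_strings_UT,
#       when_noncoord="include",
#   ) == adj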

def test_equal_except_adj_list_sequence():
    assert equal_except_adj_list_sequence(MAZE_TOKENS[0], MAZE_TOKENS[0])
    assert not equal_except_adj_list_sequence(
        MAZE_TOKENS[0], MAZE_TOKENS_AOTP_CTT_indexed[0]
    )
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; (0,1) <--> (1,1) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (1,1) <--> (0,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; (0,1) <--> (1,1) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,1) (1,0) <PATH_END>".split(),
    )
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END> <PATH_END>".split(),
    )
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    with pytest.raises(ValueError):
        equal_except_adj_list_sequence(
            "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
            "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        )
    with pytest.raises(ValueError):
        equal_except_adj_list_sequence(
            "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
            "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        )
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )

    # CTT
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
        "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    )
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
        "<ADJLIST_START> ( 1 , 1 ) <--> ( 0 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    )
    # This inactive test demonstrates the lack of robustness of the function for
    # comparing source `LatticeMaze` objects. See function documentation for details.
    # assert not equal_except_adj_list_sequence(
    #     "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    #     "<ADJLIST_START> ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    # )
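
# Summary of the semantics exercised above: two token sequences compare equal
# if everything outside the <ADJLIST_START> ... <ADJLIST_END> section matches
# exactly, while the adjacency list itself may have its connections reordered
# and the endpoints of any connection swapped. If the delimiters are missing
# from both sequences, ValueError is raised; if only one sequence is
# malformed, the sequences simply compare unequal.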

# @mivanit: this was really difficult to understand
@mark.parametrize(
    "type_, validation_funcs, assertion",
    [
        param(
            type_,
            vfs,
            assertion,
            id=f"{i}-{type_.__name__}",
        )
        for i, (type_, vfs, assertion) in enumerate(
            [
                (
                    # type
                    PathTokenizers._PathTokenizer,
                    # validation_funcs
                    dict(),
                    # assertion
                    lambda x: PathTokenizers.StepSequence(
                        step_tokenizers=(StepTokenizers.Distance(),)
                    )
                    in x,
                ),
                (
                    # type
                    PathTokenizers._PathTokenizer,
                    # validation_funcs
                    {PathTokenizers._PathTokenizer: lambda x: x.is_valid()},
                    # assertion
                    lambda x: PathTokenizers.StepSequence(
                        step_tokenizers=(StepTokenizers.Distance(),)
                    )
                    not in x
                    and PathTokenizers.StepSequence(
                        step_tokenizers=(
                            StepTokenizers.Coord(),
                            StepTokenizers.Coord(),
                        )
                    )
                    not in x,
                ),
            ]
        )
    ],
)
def test_all_instances2(
    type_: FiniteValued,
    validation_funcs: frozendict.frozendict[
        FiniteValued, Callable[[FiniteValued], bool]
    ],
    assertion: Callable[[list[FiniteValued]], bool],
):
    assert assertion(all_instances(type_, validation_funcs))
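
# A reading of the parameters above, to address the note at the decorator:
# `all_instances(type_, validation_funcs)` enumerates every instance of a
# FiniteValued type (here, every concrete path tokenizer with every
# combination of step tokenizers), filtered by `validation_funcs`, a mapping
# from type to predicate. With no validation, the degenerate
# `StepSequence(step_tokenizers=(StepTokenizers.Distance(),))` is included;
# with `is_valid()` as the predicate, it and other invalid combinations (such
# as a repeated `StepTokenizers.Coord()`) are filtered out.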

@mark.parametrize(
    "coords, result",
    [
        param(
            np.array(coords),
            res,
            id=f"{coords}",
        )
        for coords, res in (
            [
                ([[0, 0], [0, 1], [1, 1]], VOCAB.PATH_RIGHT),
                ([[0, 0], [1, 0], [1, 1]], VOCAB.PATH_LEFT),
                ([[0, 0], [0, 1], [0, 2]], VOCAB.PATH_FORWARD),
                ([[0, 0], [0, 1], [0, 0]], VOCAB.PATH_BACKWARD),
                ([[0, 0], [0, 1], [0, 1]], VOCAB.PATH_STAY),
                ([[1, 1], [0, 1], [0, 0]], VOCAB.PATH_LEFT),
                ([[1, 1], [1, 0], [0, 0]], VOCAB.PATH_RIGHT),
                ([[0, 2], [0, 1], [0, 0]], VOCAB.PATH_FORWARD),
                ([[0, 0], [0, 1], [0, 0]], VOCAB.PATH_BACKWARD),
                ([[0, 1], [0, 1], [0, 0]], ValueError),
                ([[0, 1], [1, 1], [0, 0]], ValueError),
                ([[1, 0], [1, 1], [0, 0]], ValueError),
                ([[0, 1], [0, 2], [0, 0]], ValueError),
                ([[0, 1], [0, 0], [0, 0]], VOCAB.PATH_STAY),
                ([[1, 1], [0, 0], [0, 1]], ValueError),
                ([[1, 1], [0, 0], [1, 0]], ValueError),
                ([[0, 2], [0, 0], [0, 1]], ValueError),
                ([[0, 0], [0, 0], [0, 1]], ValueError),
                ([[0, 1], [0, 0], [0, 1]], VOCAB.PATH_BACKWARD),
                ([[-1, 0], [0, 0], [1, 0]], VOCAB.PATH_FORWARD),
                ([[-1, 0], [0, 0], [0, 1]], VOCAB.PATH_LEFT),
                ([[-1, 0], [0, 0], [-1, 0]], VOCAB.PATH_BACKWARD),
                ([[-1, 0], [0, 0], [0, -1]], VOCAB.PATH_RIGHT),
                ([[-1, 0], [0, 0], [1, 0], [2, 0]], ValueError),
                ([[-1, 0], [0, 0]], ValueError),
                ([[-1, 0, 0], [0, 0, 0]], ValueError),
            ]
        )
    ],
)
def test_get_relative_direction(
    coords: Int[np.ndarray, "prev_cur_next=3 axis=2"], result: str | type[Exception]
):
    if isinstance(result, type) and issubclass(result, Exception):
        with pytest.raises(result):
            get_relative_direction(coords)
        return
    assert get_relative_direction(coords) == result
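
# Semantics exercised above: `get_relative_direction` takes exactly three
# consecutive 2D coordinates (previous, current, next) and returns the VOCAB
# token for the next step relative to the current heading: PATH_FORWARD,
# PATH_LEFT, PATH_RIGHT, PATH_BACKWARD, or PATH_STAY (next == current).
# Anything else, such as steps that are not unit lattice moves, more or fewer
# than three coordinates, or coordinates that are not 2D, raises ValueError.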

@mark.parametrize(
    "edges, result",
    [
        param(
            edges,
            res,
            id=f"{edges}",
        )
        for edges, res in (
            [
                (np.array([[0, 0], [0, 1]]), 1),
                (np.array([[1, 0], [0, 1]]), 2),
                (np.array([[-1, 0], [0, 1]]), 2),
                (np.array([[0, 0], [5, 3]]), 8),
                (
                    np.array(
                        [
                            [[0, 0], [0, 1]],
                            [[1, 0], [0, 1]],
                            [[-1, 0], [0, 1]],
                            [[0, 0], [5, 3]],
                        ]
                    ),
                    [1, 2, 2, 8],
                ),
                (np.array([[[0, 0], [5, 3]]]), [8]),
            ]
        )
    ],
)
def test_manhattan_distance(
    edges: ConnectionArray | Connection,
    result: Int[np.ndarray, " edges"] | Int[np.ndarray, ""] | type[Exception],
):
    if isinstance(result, type) and issubclass(result, Exception):
        with pytest.raises(result):
            manhattan_distance(edges)
        return
    assert np.array_equal(manhattan_distance(edges), np.array(result, dtype=np.int8))
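
# As the cases above show, `manhattan_distance` accepts either a single
# connection of shape (2, 2) and returns a scalar distance, or a batch of
# shape (n, 2, 2) and returns an array of n distances (int8, judging by the
# dtype used in the comparison).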

@mark.parametrize(
    "n",
    [param(n) for n in [2, 3, 5, 20]],
)
def test_lattice_connection_array(n):
    edges = lattice_connection_array(n)
    assert tuple(edges.shape) == (2 * n * (n - 1), 2, 2)
    assert np.all(np.sum(edges[:, 1], axis=1) > np.sum(edges[:, 0], axis=1))
    assert tuple(np.unique(edges, axis=0).shape) == (2 * n * (n - 1), 2, 2)
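
# Invariants checked: an n x n lattice has 2 * n * (n - 1) undirected edges
# (n * (n - 1) horizontal plus the same number vertical); each edge is
# oriented so the endpoint with the larger coordinate sum comes second; and
# no edge appears twice.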

@mark.parametrize(
    "edges, maze",
    [
        param(
            edges(),
            maze,
            id=f"edges[{i}]; maze[{j}]",
        )
        for (i, edges), (j, maze) in itertools.product(
            enumerate(
                [
                    lambda: lattice_connection_array(GRID_N),
                    lambda: np.flip(lattice_connection_array(GRID_N), axis=1),
                    lambda: lattice_connection_array(GRID_N - 1),
                    lambda: numpy_rng.choice(
                        lattice_connection_array(GRID_N), 2 * GRID_N, axis=0
                    ),
                    lambda: numpy_rng.choice(
                        lattice_connection_array(GRID_N), 1, axis=0
                    ),
                ]
            ),
            enumerate(MAZE_DATASET.mazes),
        )
    ],
)
def test_is_connection(edges: ConnectionArray, maze: LatticeMaze):
    output = is_connection(edges, maze.connection_list)
    sorted_edges = np.sort(edges, axis=1)
    edge_direction = (
        (sorted_edges[:, 1, :] - sorted_edges[:, 0, :])[:, 0] == 0
    ).astype(np.int8)
    assert np.array_equal(
        output,
        maze.connection_list[
            edge_direction, sorted_edges[:, 0, 0], sorted_edges[:, 0, 1]
        ],
    )
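
# The reference computation above mirrors the `connection_list` layout as
# this test uses it (an inference from the indexing, not library
# documentation): after sorting each edge's endpoints component-wise, an edge
# whose row coordinate is unchanged (axis-0 delta of zero) is a horizontal
# connection and indexes `connection_list[1]`; otherwise it is vertical and
# indexes `connection_list[0]`. The remaining two indices are the row and
# column of the component-wise smaller endpoint.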