Coverage for tests\unit\maze_dataset\tokenization\test_token_utils.py: 97%
172 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
1import itertools
2from typing import Callable
4import frozendict
5import numpy as np
6import pytest
7from jaxtyping import Int
8from pytest import mark, param
10from maze_dataset import LatticeMaze
11from maze_dataset.constants import VOCAB, Connection, ConnectionArray
12from maze_dataset.generation import numpy_rng
13from maze_dataset.testing_utils import GRID_N, MAZE_DATASET
14from maze_dataset.token_utils import (
15 _coord_to_strings_UT,
16 coords_to_strings,
17 equal_except_adj_list_sequence,
18 get_adj_list_tokens,
19 get_origin_tokens,
20 get_path_tokens,
21 get_relative_direction,
22 get_target_tokens,
23 is_connection,
24 strings_to_coords,
25 tokens_between,
26)
27from maze_dataset.tokenization import (
28 PathTokenizers,
29 StepTokenizers,
30 get_tokens_up_to_path_start,
31)
32from maze_dataset.utils import (
33 FiniteValued,
34 all_instances,
35 lattice_connection_array,
36 manhattan_distance,
37)
# Example maze token sequence in the "unique token" (UT) coordinate format,
# paired with the name of the tokenizer that produces it.
MAZE_TOKENS: tuple[list[str], str] = (
    "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    "AOTP_UT",
)
# The same maze tokenized in the "coordinate tuple tokens" indexed format,
# where each coordinate is split into separate "(", digit, ",", ")" tokens.
MAZE_TOKENS_AOTP_CTT_indexed: tuple[list[str], str] = (
    "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    "AOTP_CTT_indexed",
)
# (token list, tokenizer name) pairs shared by the parametrized tests below.
TEST_TOKEN_LISTS: list[tuple[list[str], str]] = [
    MAZE_TOKENS,
    MAZE_TOKENS_AOTP_CTT_indexed,
]
@mark.parametrize(
    "toks, tokenizer_name",
    [param(*pair, id=pair[1]) for pair in TEST_TOKEN_LISTS],
)
def test_tokens_between(toks: list[str], tokenizer_name: str):
    """Exercise `tokens_between` on maze token sequences and on a generic list."""
    path_interior = tokens_between(toks, "<PATH_START>", "<PATH_END>")
    if tokenizer_name == "AOTP_UT":
        assert path_interior == ["(1,0)", "(1,1)"]
    elif tokenizer_name == "AOTP_CTT_indexed":
        assert path_interior == "( 1 , 0 ) ( 1 , 1 )".split()

    # generic sentence used for the remaining checks
    sentence = "the quick brown fox jumps over the lazy dog".split()

    # normal case: bounds excluded by default
    assert tokens_between(sentence, "quick", "over") == ["brown", "fox", "jumps"]

    # bounds included
    assert (
        tokens_between(sentence, "quick", "over", True, True)
        == "quick brown fox jumps over".split()
    )

    # non-unique start/end with uniqueness enforcement enabled -> error
    with pytest.raises(ValueError):
        tokens_between(sentence, "the", "dog", False, False, True)

    # non-unique start/end with uniqueness enforcement disabled
    assert (
        tokens_between(sentence, "the", "dog", False, False, False)
        == "quick brown fox jumps over the lazy".split()
    )

    # empty token list
    with pytest.raises(ValueError):
        tokens_between([], "start", "end")

    # start and end values are identical
    with pytest.raises(ValueError):
        tokens_between(sentence, "fox", "fox")

    # start/end absent from the token list
    with pytest.raises(ValueError):
        tokens_between(sentence, "start", "end")

    # start occurs after end
    with pytest.raises(AssertionError):
        tokens_between(sentence, "over", "quick")

    # start and end at the extremes of the list, both included
    assert tokens_between(sentence, "the", "dog", True, True) == sentence

    # single-element list where start == end
    with pytest.raises(ValueError):
        tokens_between(["fox"], "fox", "fox", True, True)
@mark.parametrize(
    "toks, tokenizer_name",
    [param(*pair, id=pair[1]) for pair in TEST_TOKEN_LISTS],
)
def test_tokens_between_out_of_order(toks: list[str], tokenizer_name: str):
    """Swapped start/end markers must trigger an assertion failure."""
    with pytest.raises(AssertionError):
        tokens_between(toks, "<PATH_END>", "<PATH_START>")
@mark.parametrize(
    "toks, tokenizer_name",
    [param(*pair, id=pair[1]) for pair in TEST_TOKEN_LISTS],
)
def test_get_adj_list_tokens(toks: list[str], tokenizer_name: str):
    """`get_adj_list_tokens` extracts the adjacency-list section (no delimiters)."""
    expected_by_tokenizer = {
        "AOTP_UT": "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ;".split(),
        "AOTP_CTT_indexed": "( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ;".split(),
    }
    assert get_adj_list_tokens(toks) == expected_by_tokenizer[tokenizer_name]
@mark.parametrize(
    "toks, tokenizer_name",
    [param(*pair, id=pair[1]) for pair in TEST_TOKEN_LISTS],
)
def test_get_path_tokens(toks: list[str], tokenizer_name: str):
    """`get_path_tokens` returns the path section, optionally trimming delimiters."""
    with_delims = get_path_tokens(toks)
    without_delims = get_path_tokens(toks, trim_end=True)
    if tokenizer_name == "AOTP_UT":
        assert with_delims == "<PATH_START> (1,0) (1,1) <PATH_END>".split()
        assert without_delims == "(1,0) (1,1)".split()
    elif tokenizer_name == "AOTP_CTT_indexed":
        assert with_delims == "<PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split()
        assert without_delims == "( 1 , 0 ) ( 1 , 1 )".split()
@mark.parametrize(
    "toks, tokenizer_name",
    [param(*pair, id=pair[1]) for pair in TEST_TOKEN_LISTS],
)
def test_get_origin_tokens(toks: list[str], tokenizer_name: str):
    """`get_origin_tokens` extracts the origin coordinate tokens."""
    expected_by_tokenizer = {
        "AOTP_UT": ["(1,0)"],
        "AOTP_CTT_indexed": "( 1 , 0 )".split(),
    }
    assert get_origin_tokens(toks) == expected_by_tokenizer[tokenizer_name]
@mark.parametrize(
    "toks, tokenizer_name",
    [param(*pair, id=pair[1]) for pair in TEST_TOKEN_LISTS],
)
def test_get_target_tokens(toks: list[str], tokenizer_name: str):
    """`get_target_tokens` extracts the target coordinate tokens."""
    expected_by_tokenizer = {
        "AOTP_UT": ["(1,1)"],
        "AOTP_CTT_indexed": "( 1 , 1 )".split(),
    }
    assert get_target_tokens(toks) == expected_by_tokenizer[tokenizer_name]
@mark.parametrize(
    "toks, tokenizer_name",
    [
        param(token_list[0], token_list[1], id=f"{token_list[1]}")
        for token_list in [MAZE_TOKENS]
    ],
)
def test_get_tokens_up_to_path_start_including_start(
    toks: list[str], tokenizer_name: str
):
    """With `include_start_coord=True`, the prefix up to and including the first
    path coordinate is returned.

    Only parametrized over `MAZE_TOKENS` because this function doesn't support
    `AOTP_CTT_indexed` when `include_start_coord=True`.
    """
    # The old match statement carried an unreachable `AOTP_CTT_indexed` arm
    # (and would have left `expected` unbound for unknown names); guard instead.
    assert tokenizer_name == "AOTP_UT"
    result = get_tokens_up_to_path_start(toks, include_start_coord=True)
    expected = "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0)".split()
    assert result == expected
@mark.parametrize(
    "toks, tokenizer_name",
    [param(*pair, id=pair[1]) for pair in TEST_TOKEN_LISTS],
)
def test_get_tokens_up_to_path_start_excluding_start(
    toks: list[str], tokenizer_name: str
):
    """With `include_start_coord=False`, the prefix ends at `<PATH_START>`."""
    expected_by_tokenizer = {
        "AOTP_UT": "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START>".split(),
        "AOTP_CTT_indexed": "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START>".split(),
    }
    result = get_tokens_up_to_path_start(toks, include_start_coord=False)
    assert result == expected_by_tokenizer[tokenizer_name]
@mark.parametrize(
    "toks, tokenizer_name",
    [param(*pair, id=pair[1]) for pair in TEST_TOKEN_LISTS],
)
def test_strings_to_coords(toks: list[str], tokenizer_name: str):
    """`strings_to_coords` parses coordinate tokens per the `when_noncoord` policy."""
    adj_list = get_adj_list_tokens(toks)

    # "skip": only the parsed coordinate tuples remain
    assert strings_to_coords(adj_list, when_noncoord="skip") == [
        (0, 1),
        (1, 1),
        (1, 0),
        (1, 1),
        (0, 1),
        (0, 0),
    ]

    # "include": non-coordinate tokens are passed through unchanged
    assert strings_to_coords(adj_list, when_noncoord="include") == [
        (0, 1),
        "<-->",
        (1, 1),
        ";",
        (1, 0),
        "<-->",
        (1, 1),
        ";",
        (0, 1),
        "<-->",
        (0, 0),
        ";",
    ]

    # "error": any non-coordinate token raises
    with pytest.raises(ValueError):
        strings_to_coords(adj_list, when_noncoord="error")

    # a raw (unsplit) string input is also accepted
    mixed = "(1,2) <ADJLIST_START> (5,6)"
    assert strings_to_coords(mixed) == [(1, 2), (5, 6)]
    assert strings_to_coords(mixed, when_noncoord="skip") == [(1, 2), (5, 6)]
    assert strings_to_coords(mixed, when_noncoord="include") == [
        (1, 2),
        "<ADJLIST_START>",
        (5, 6),
    ]
    with pytest.raises(ValueError):
        strings_to_coords(mixed, when_noncoord="error")
@mark.parametrize(
    "toks, tokenizer_name",
    [param(*pair, id=pair[1]) for pair in TEST_TOKEN_LISTS],
)
def test_coords_to_strings(toks: list[str], tokenizer_name: str):
    """`coords_to_strings` renders coordinate tuples back to UT-format tokens."""
    adj_list = get_adj_list_tokens(toks)
    # round-trip through `strings_to_coords`, keeping non-coordinate tokens
    coords = strings_to_coords(adj_list, when_noncoord="include")

    # "skip": only coordinate tuples are rendered
    assert coords_to_strings(
        coords, coord_to_strings_func=_coord_to_strings_UT, when_noncoord="skip"
    ) == [
        "(0,1)",
        "(1,1)",
        "(1,0)",
        "(1,1)",
        "(0,1)",
        "(0,0)",
    ]

    # "include": non-coordinate tokens are passed through unchanged
    assert coords_to_strings(
        coords, coord_to_strings_func=_coord_to_strings_UT, when_noncoord="include"
    ) == [
        "(0,1)",
        "<-->",
        "(1,1)",
        ";",
        "(1,0)",
        "<-->",
        "(1,1)",
        ";",
        "(0,1)",
        "<-->",
        "(0,0)",
        ";",
    ]

    # "error": any non-coordinate element raises
    with pytest.raises(ValueError):
        coords_to_strings(
            coords, coord_to_strings_func=_coord_to_strings_UT, when_noncoord="error"
        )
def test_equal_except_adj_list_sequence():
    """`equal_except_adj_list_sequence` treats two token sequences as equal when
    they differ only in the ordering of adjacency-list connections.

    Each case below notes how the second sequence differs from the first.
    """
    # identical sequences are equal
    assert equal_except_adj_list_sequence(MAZE_TOKENS[0], MAZE_TOKENS[0])
    # different tokenizations (UT vs CTT) of the same maze are NOT equal
    assert not equal_except_adj_list_sequence(
        MAZE_TOKENS[0], MAZE_TOKENS_AOTP_CTT_indexed[0]
    )
    # byte-identical literal sequences are equal
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    # adjacency-list connections reordered -> still equal
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; (0,1) <--> (1,1) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    # endpoints swapped within one connection -> still equal
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (1,1) <--> (0,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    # path section differs (reversed) -> NOT equal
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; (0,1) <--> (1,1) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,1) (1,0) <PATH_END>".split(),
    )
    # second sequence has a duplicated <PATH_END> token -> NOT equal
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END> <PATH_END>".split(),
    )
    # first sequence is missing <ORIGIN_START> -> NOT equal
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    # second sequence is missing <ADJLIST_START> -> NOT equal
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    # BOTH sequences missing <ADJLIST_START> -> ValueError
    with pytest.raises(ValueError):
        equal_except_adj_list_sequence(
            "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
            "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        )
    # BOTH sequences missing <ADJLIST_END> -> ValueError
    with pytest.raises(ValueError):
        equal_except_adj_list_sequence(
            "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
            "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        )
    # only the second sequence missing <ADJLIST_END> -> NOT equal (no error)
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )

    # CTT: same properties hold for the indexed coordinate-tuple format
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
        "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    )
    # CTT: endpoints swapped within one connection -> still equal
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
        "<ADJLIST_START> ( 1 , 1 ) <--> ( 0 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    )
    # This inactive test demonstrates the lack of robustness of the function for comparing source `LatticeMaze` objects.
    # See function documentation for details.
    # assert not equal_except_adj_list_sequence(
    #     "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    #     "<ADJLIST_START> ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split()
    # )
# @mivanit: this was really difficult to understand
# Each parametrized case supplies:
# - `type_`: the type whose instances `all_instances` should enumerate
# - `validation_funcs`: per-type predicates filtering the enumeration
# - `assertion`: a check applied to the resulting list of instances
@mark.parametrize(
    "type_, validation_funcs, assertion",
    [
        param(
            type_,
            vfs,
            assertion,
            id=f"{i}-{type_.__name__}",
        )
        for i, (type_, vfs, assertion) in enumerate(
            [
                (
                    # type
                    PathTokenizers._PathTokenizer,
                    # validation_funcs: empty, so nothing is filtered out
                    dict(),
                    # assertion: with no filtering, the Distance-only tokenizer
                    # appears in the enumeration
                    lambda x: PathTokenizers.StepSequence(
                        step_tokenizers=(StepTokenizers.Distance(),)
                    )
                    in x,
                ),
                (
                    # type
                    PathTokenizers._PathTokenizer,
                    # validation_funcs: keep only tokenizers passing `is_valid()`
                    {PathTokenizers._PathTokenizer: lambda x: x.is_valid()},
                    # assertion: tokenizers failing `is_valid()` (Distance-only,
                    # duplicated Coord) are filtered out of the enumeration
                    lambda x: PathTokenizers.StepSequence(
                        step_tokenizers=(StepTokenizers.Distance(),)
                    )
                    not in x
                    and PathTokenizers.StepSequence(
                        step_tokenizers=(
                            StepTokenizers.Coord(),
                            StepTokenizers.Coord(),
                        )
                    )
                    not in x,
                ),
            ]
        )
    ],
)
def test_all_instances2(
    type_: FiniteValued,
    validation_funcs: frozendict.frozendict[
        FiniteValued, Callable[[FiniteValued], bool]
    ],
    assertion: Callable[[list[FiniteValued]], bool],
):
    """`all_instances(type_, validation_funcs)` must satisfy the case's assertion."""
    assert assertion(all_instances(type_, validation_funcs))
@mark.parametrize(
    "coords, result",
    [
        param(
            np.array(coords),
            res,
            id=f"{coords}",
        )
        for coords, res in (
            # each case: (previous, current, next) coordinate triple -> the
            # expected relative-direction vocab token, or ValueError for
            # inputs the function rejects (zero/non-unit steps, wrong shape)
            [
                ([[0, 0], [0, 1], [1, 1]], VOCAB.PATH_RIGHT),
                ([[0, 0], [1, 0], [1, 1]], VOCAB.PATH_LEFT),
                ([[0, 0], [0, 1], [0, 2]], VOCAB.PATH_FORWARD),
                ([[0, 0], [0, 1], [0, 0]], VOCAB.PATH_BACKWARD),
                ([[0, 0], [0, 1], [0, 1]], VOCAB.PATH_STAY),
                ([[1, 1], [0, 1], [0, 0]], VOCAB.PATH_LEFT),
                ([[1, 1], [1, 0], [0, 0]], VOCAB.PATH_RIGHT),
                ([[0, 2], [0, 1], [0, 0]], VOCAB.PATH_FORWARD),
                ([[0, 0], [0, 1], [0, 0]], VOCAB.PATH_BACKWARD),
                # invalid step geometry -> ValueError
                ([[0, 1], [0, 1], [0, 0]], ValueError),
                ([[0, 1], [1, 1], [0, 0]], ValueError),
                ([[1, 0], [1, 1], [0, 0]], ValueError),
                ([[0, 1], [0, 2], [0, 0]], ValueError),
                ([[0, 1], [0, 0], [0, 0]], VOCAB.PATH_STAY),
                ([[1, 1], [0, 0], [0, 1]], ValueError),
                ([[1, 1], [0, 0], [1, 0]], ValueError),
                ([[0, 2], [0, 0], [0, 1]], ValueError),
                ([[0, 0], [0, 0], [0, 1]], ValueError),
                ([[0, 1], [0, 0], [0, 1]], VOCAB.PATH_BACKWARD),
                # negative coordinates are handled
                ([[-1, 0], [0, 0], [1, 0]], VOCAB.PATH_FORWARD),
                ([[-1, 0], [0, 0], [0, 1]], VOCAB.PATH_LEFT),
                ([[-1, 0], [0, 0], [-1, 0]], VOCAB.PATH_BACKWARD),
                ([[-1, 0], [0, 0], [0, -1]], VOCAB.PATH_RIGHT),
                # wrong shape: must be exactly 3 coordinates of 2 axes each
                ([[-1, 0], [0, 0], [1, 0], [2, 0]], ValueError),
                ([[-1, 0], [0, 0]], ValueError),
                ([[-1, 0, 0], [0, 0, 0]], ValueError),
            ]
        )
    ],
)
def test_get_relative_direction(
    coords: Int[np.ndarray, "prev_cur_next=3 axis=2"], result: str | type[Exception]
):
    """`get_relative_direction` maps a 3-coordinate window to a turn token.

    Cases whose `result` is an exception type assert that it is raised;
    all other cases assert the returned token.
    """
    if isinstance(result, type) and issubclass(result, Exception):
        with pytest.raises(result):
            get_relative_direction(coords)
        return
    assert get_relative_direction(coords) == result
@mark.parametrize(
    "edges, result",
    [
        param(
            edges,
            res,
            id=f"{edges}",
        )
        for edges, res in (
            [
                # single edge, shape (2, 2) -> scalar distance
                (np.array([[0, 0], [0, 1]]), 1),
                (np.array([[1, 0], [0, 1]]), 2),
                (np.array([[-1, 0], [0, 1]]), 2),
                (np.array([[0, 0], [5, 3]]), 8),
                # batch of edges, shape (n, 2, 2) -> array of distances
                (
                    np.array(
                        [
                            [[0, 0], [0, 1]],
                            [[1, 0], [0, 1]],
                            [[-1, 0], [0, 1]],
                            [[0, 0], [5, 3]],
                        ]
                    ),
                    [1, 2, 2, 8],
                ),
                # single-edge batch, shape (1, 2, 2) -> length-1 array
                (np.array([[[0, 0], [5, 3]]]), [8]),
            ]
        )
    ],
)
def test_manhattan_distance(
    edges: ConnectionArray | Connection,
    result: Int[np.ndarray, " edges"] | Int[np.ndarray, ""] | type[Exception],
):
    """`manhattan_distance` computes |dx| + |dy| per edge, for singles and batches.

    No current case expects an exception, but the guard mirrors the other
    parametrized tests so exception-type entries could be added.
    """
    if isinstance(result, type) and issubclass(result, Exception):
        with pytest.raises(result):
            manhattan_distance(edges)
        return
    # np.array_equal compares values only; the int8 cast does not affect equality
    assert np.array_equal(manhattan_distance(edges), np.array(result, dtype=np.int8))
@mark.parametrize(
    "n",
    [param(n) for n in [2, 3, 5, 20]],
)
def test_lattice_connection_array(n):
    """`lattice_connection_array(n)` yields every lattice edge exactly once.

    Checks, for several lattice sizes `n`:
    - shape: an n*n lattice has 2*n*(n-1) edges, each a pair of 2D coords
    - ordering: each edge's second coordinate has the larger coordinate sum
    - uniqueness: deduplicating the edges leaves the shape unchanged

    (Renamed from `test_lattice_connection_arrray` — typo fix.)
    """
    edges = lattice_connection_array(n)
    assert tuple(edges.shape) == (2 * n * (n - 1), 2, 2)
    assert np.all(np.sum(edges[:, 1], axis=1) > np.sum(edges[:, 0], axis=1))
    assert tuple(np.unique(edges, axis=0).shape) == (2 * n * (n - 1), 2, 2)
@mark.parametrize(
    "edges, maze",
    [
        param(
            edges(),  # call the lambda so each case gets a freshly built array
            maze,
            id=f"edges[{i}]; maze[{j}]",
        )
        for (i, edges), (j, maze) in itertools.product(
            enumerate(
                [
                    # all edges of the full GRID_N lattice
                    lambda: lattice_connection_array(GRID_N),
                    # same edges with endpoint order reversed
                    lambda: np.flip(lattice_connection_array(GRID_N), axis=1),
                    # edges of a smaller lattice
                    lambda: lattice_connection_array(GRID_N - 1),
                    # random samples of edges
                    lambda: numpy_rng.choice(
                        lattice_connection_array(GRID_N), 2 * GRID_N, axis=0
                    ),
                    lambda: numpy_rng.choice(
                        lattice_connection_array(GRID_N), 1, axis=0
                    ),
                ]
            ),
            enumerate(MAZE_DATASET.mazes),
        )
    ],
)
def test_is_connection(edges: ConnectionArray, maze: LatticeMaze):
    """`is_connection` must agree with a direct `connection_list` lookup."""
    output = is_connection(edges, maze.connection_list)
    # normalize endpoint order within each edge so lookups are orientation-free
    sorted_edges = np.sort(edges, axis=1)
    # zero row-delta means the edge is horizontal; index 1 selects that layer
    # NOTE(review): assumes connection_list is indexed [direction, row, col]
    # with layer 0 = vertical, layer 1 = horizontal -- confirm in LatticeMaze
    edge_direction = (
        (sorted_edges[:, 1, :] - sorted_edges[:, 0, :])[:, 0] == 0
    ).astype(np.int8)
    assert np.array_equal(
        output,
        maze.connection_list[
            edge_direction, sorted_edges[:, 0, 0], sorted_edges[:, 0, 1]
        ],
    )