Coverage for tests/unit/maze_dataset/tokenization/test_token_utils.py: 97%
174 statements
coverage.py v7.6.12, created at 2025-03-11 00:49 -0600

import itertools
from typing import Callable

import frozendict
import numpy as np
import pytest
from jaxtyping import Int

from maze_dataset import LatticeMaze
from maze_dataset.constants import VOCAB, Connection, ConnectionArray
from maze_dataset.generation import numpy_rng
from maze_dataset.testing_utils import GRID_N, MAZE_DATASET
from maze_dataset.token_utils import (
    _coord_to_strings_UT,
    coords_to_strings,
    equal_except_adj_list_sequence,
    get_adj_list_tokens,
    get_origin_tokens,
    get_path_tokens,
    get_relative_direction,
    get_target_tokens,
    is_connection,
    strings_to_coords,
    tokens_between,
)
from maze_dataset.tokenization import (
    PathTokenizers,
    StepTokenizers,
    get_tokens_up_to_path_start,
)
from maze_dataset.utils import (
    FiniteValued,
    all_instances,
    lattice_connection_array,
    manhattan_distance,
)
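
# Two serializations of the same small maze: AOTP_UT encodes each coordinate as
# a single token such as "(0,1)", while AOTP_CTT_indexed splits every
# coordinate into separate "(", "0", ",", "1", ")" tokens.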
MAZE_TOKENS: tuple[list[str], str] = (
    "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    "AOTP_UT",
)
MAZE_TOKENS_AOTP_CTT_indexed: tuple[list[str], str] = (
    "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    "AOTP_CTT_indexed",
)
TEST_TOKEN_LISTS: list[tuple[list[str], str]] = [
    MAZE_TOKENS,
    MAZE_TOKENS_AOTP_CTT_indexed,
]
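

# `tokens_between` extracts the tokens strictly between a start and an end
# token, with flags to include either endpoint. The cases below also cover its
# error handling: empty input, identical or missing delimiters, out-of-order
# delimiters, and non-unique delimiters when `except_when_tokens_not_unique`
# is set.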
@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [
        pytest.param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_tokens_between(toks: list[str], tokenizer_name: str):
    result = tokens_between(toks, "<PATH_START>", "<PATH_END>")
    match tokenizer_name:
        case "AOTP_UT":
            assert result == ["(1,0)", "(1,1)"]
        case "AOTP_CTT_indexed":
            assert result == ["(", "1", ",", "0", ")", "(", "1", ",", "1", ")"]

    # Normal case
    tokens = ["the", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"]
    start_value = "quick"
    end_value = "over"
    assert tokens_between(tokens, start_value, end_value) == ["brown", "fox", "jumps"]

    # Including start and end values
    assert tokens_between(tokens, start_value, end_value, True, True) == [
        "quick",
        "brown",
        "fox",
        "jumps",
        "over",
    ]

    # When start_value or end_value is not unique and except_when_tokens_not_unique is True
    with pytest.raises(ValueError):  # noqa: PT011
        tokens_between(tokens, "the", "dog", False, False, True)

    # When start_value or end_value is not unique and except_when_tokens_not_unique is False
    assert tokens_between(tokens, "the", "dog", False, False, False) == [
        "quick",
        "brown",
        "fox",
        "jumps",
        "over",
        "the",
        "lazy",
    ]

    # Empty tokens list
    with pytest.raises(ValueError):  # noqa: PT011
        tokens_between([], "start", "end")

    # start_value and end_value are the same
    with pytest.raises(ValueError):  # noqa: PT011
        tokens_between(tokens, "fox", "fox")

    # start_value or end_value not in the tokens list
    with pytest.raises(ValueError):  # noqa: PT011
        tokens_between(tokens, "start", "end")

    # start_value comes after end_value in the tokens list
    with pytest.raises(AssertionError):
        tokens_between(tokens, "over", "quick")

    # start_value and end_value are at the beginning and end of the tokens list, respectively
    assert tokens_between(tokens, "the", "dog", True, True) == tokens

    # Single element in the tokens list, which is the same as start_value and end_value
    with pytest.raises(ValueError):  # noqa: PT011
        tokens_between(["fox"], "fox", "fox", True, True)


@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [
        pytest.param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_tokens_between_out_of_order(toks: list[str], tokenizer_name: str):
    assert tokenizer_name
    with pytest.raises(AssertionError):
        tokens_between(toks, "<PATH_END>", "<PATH_START>")
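

# The `get_*_tokens` helpers exercised below each slice one region out of the
# serialized maze, delimited by the matching <..._START> / <..._END> special
# tokens.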
@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [
        pytest.param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_get_adj_list_tokens(toks: list[str], tokenizer_name: str):
    result = get_adj_list_tokens(toks)
    match tokenizer_name:
        case "AOTP_UT":
            expected = (
                "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ;".split()
            )
        case "AOTP_CTT_indexed":
            expected = "( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ;".split()
    assert result == expected


@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [
        pytest.param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_get_path_tokens(toks: list[str], tokenizer_name: str):
    result_notrim = get_path_tokens(toks)
    result_trim = get_path_tokens(toks, trim_end=True)
    match tokenizer_name:
        case "AOTP_UT":
            assert result_notrim == ["<PATH_START>", "(1,0)", "(1,1)", "<PATH_END>"]
            assert result_trim == ["(1,0)", "(1,1)"]
        case "AOTP_CTT_indexed":
            assert (
                result_notrim == "<PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split()
            )
            assert result_trim == "( 1 , 0 ) ( 1 , 1 )".split()


@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [
        pytest.param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_get_origin_tokens(toks: list[str], tokenizer_name: str):
    result = get_origin_tokens(toks)
    match tokenizer_name:
        case "AOTP_UT":
            assert result == ["(1,0)"]
        case "AOTP_CTT_indexed":
            assert result == "( 1 , 0 )".split()


@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [
        pytest.param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_get_target_tokens(toks: list[str], tokenizer_name: str):
    result = get_target_tokens(toks)
    match tokenizer_name:
        case "AOTP_UT":
            assert result == ["(1,1)"]
        case "AOTP_CTT_indexed":
            assert result == "( 1 , 1 )".split()


@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [
        pytest.param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in [MAZE_TOKENS]
    ],
)
def test_get_tokens_up_to_path_start_including_start(
    toks: list[str],
    tokenizer_name: str,
):
    # Don't test on `MAZE_TOKENS_AOTP_CTT_indexed` because this function doesn't support `AOTP_CTT_indexed` when `include_start_coord=True`.
    result = get_tokens_up_to_path_start(toks, include_start_coord=True)
    match tokenizer_name:
        case "AOTP_UT":
            expected = "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0)".split()
        case "AOTP_CTT_indexed":
            expected = "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 )".split()
    assert result == expected


@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [
        pytest.param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_get_tokens_up_to_path_start_excluding_start(
    toks: list[str],
    tokenizer_name: str,
):
    result = get_tokens_up_to_path_start(toks, include_start_coord=False)
    match tokenizer_name:
        case "AOTP_UT":
            expected = "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START>".split()
        case "AOTP_CTT_indexed":
            expected = "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START>".split()
    assert result == expected
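

# `strings_to_coords` and `coords_to_strings` convert between token strings
# and coordinate tuples; `when_noncoord` selects whether non-coordinate tokens
# are skipped, passed through unchanged, or raise a ValueError.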
@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [
        pytest.param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_strings_to_coords(toks: list[str], tokenizer_name: str):
    assert tokenizer_name
    adj_list = get_adj_list_tokens(toks)
    skipped = strings_to_coords(adj_list, when_noncoord="skip")
    included = strings_to_coords(adj_list, when_noncoord="include")

    assert skipped == [
        (0, 1),
        (1, 1),
        (1, 0),
        (1, 1),
        (0, 1),
        (0, 0),
    ]

    assert included == [
        (0, 1),
        "<-->",
        (1, 1),
        ";",
        (1, 0),
        "<-->",
        (1, 1),
        ";",
        (0, 1),
        "<-->",
        (0, 0),
        ";",
    ]

    with pytest.raises(ValueError):  # noqa: PT011
        strings_to_coords(adj_list, when_noncoord="error")

    assert strings_to_coords("(1,2) <ADJLIST_START> (5,6)") == [(1, 2), (5, 6)]
    assert strings_to_coords("(1,2) <ADJLIST_START> (5,6)", when_noncoord="skip") == [
        (1, 2),
        (5, 6),
    ]
    assert strings_to_coords(
        "(1,2) <ADJLIST_START> (5,6)",
        when_noncoord="include",
    ) == [(1, 2), "<ADJLIST_START>", (5, 6)]
    with pytest.raises(ValueError):  # noqa: PT011
        strings_to_coords("(1,2) <ADJLIST_START> (5,6)", when_noncoord="error")


@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [
        pytest.param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_coords_to_strings(toks: list[str], tokenizer_name: str):
    assert tokenizer_name
    adj_list = get_adj_list_tokens(toks)
    # config = MazeDatasetConfig(name="test", grid_n=2, n_mazes=1)
    coords = strings_to_coords(adj_list, when_noncoord="include")

    skipped = coords_to_strings(
        coords,
        coord_to_strings_func=_coord_to_strings_UT,
        when_noncoord="skip",
    )
    included = coords_to_strings(
        coords,
        coord_to_strings_func=_coord_to_strings_UT,
        when_noncoord="include",
    )

    assert skipped == [
        "(0,1)",
        "(1,1)",
        "(1,0)",
        "(1,1)",
        "(0,1)",
        "(0,0)",
    ]

    assert included == [
        "(0,1)",
        "<-->",
        "(1,1)",
        ";",
        "(1,0)",
        "<-->",
        "(1,1)",
        ";",
        "(0,1)",
        "<-->",
        "(0,0)",
        ";",
    ]

    with pytest.raises(ValueError):  # noqa: PT011
        coords_to_strings(
            coords,
            coord_to_strings_func=_coord_to_strings_UT,
            when_noncoord="error",
        )
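

# `equal_except_adj_list_sequence` treats the adjacency list as an unordered
# collection of undirected edges (edge order and endpoint order are ignored),
# requires the rest of the token sequence to match exactly, and raises
# ValueError when the adjacency-list delimiters are missing from both inputs.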
def test_equal_except_adj_list_sequence():
    assert equal_except_adj_list_sequence(MAZE_TOKENS[0], MAZE_TOKENS[0])
    assert not equal_except_adj_list_sequence(
        MAZE_TOKENS[0],
        MAZE_TOKENS_AOTP_CTT_indexed[0],
    )
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; (0,1) <--> (1,1) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (1,1) <--> (0,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; (0,1) <--> (1,1) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,1) (1,0) <PATH_END>".split(),
    )
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END> <PATH_END>".split(),
    )
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    with pytest.raises(ValueError):  # noqa: PT011
        equal_except_adj_list_sequence(
            "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
            "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        )
    with pytest.raises(ValueError):  # noqa: PT011
        equal_except_adj_list_sequence(
            "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
            "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        )
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )

    # CTT
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
        "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    )
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
        "<ADJLIST_START> ( 1 , 1 ) <--> ( 0 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    )
    # This inactive test demonstrates the lack of robustness of the function for comparing source `LatticeMaze` objects.
    # See function documentation for details.
    # assert not equal_except_adj_list_sequence(
    #     "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    #     "<ADJLIST_START> ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split()
    # )


# @mivanit: this was really difficult to understand
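# `all_instances` enumerates every instance of a FiniteValued type, filtered
# by per-type `validation_funcs`. Case 0: with no validation, a path tokenizer
# consisting of a lone Distance step tokenizer is enumerated. Case 1:
# filtering by `is_valid` removes it, along with a double-Coord step sequence.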
@pytest.mark.parametrize(
    ("type_", "validation_funcs", "assertion"),
    [
        pytest.param(
            type_,
            vfs,
            assertion,
            id=f"{i}-{type_.__name__}",
        )
        for i, (type_, vfs, assertion) in enumerate(
            [
                (
                    # type
                    PathTokenizers._PathTokenizer,
                    # validation_funcs
                    dict(),
                    # assertion
                    lambda x: PathTokenizers.StepSequence(
                        step_tokenizers=(StepTokenizers.Distance(),),
                    )
                    in x,
                ),
                (
                    # type
                    PathTokenizers._PathTokenizer,
                    # validation_funcs
                    {PathTokenizers._PathTokenizer: lambda x: x.is_valid()},
                    # assertion
                    lambda x: PathTokenizers.StepSequence(
                        step_tokenizers=(StepTokenizers.Distance(),),
                    )
                    not in x
                    and PathTokenizers.StepSequence(
                        step_tokenizers=(
                            StepTokenizers.Coord(),
                            StepTokenizers.Coord(),
                        ),
                    )
                    not in x,
                ),
            ],
        )
    ],
)
def test_all_instances2(
    type_: FiniteValued,
    validation_funcs: frozendict.frozendict[
        FiniteValued,
        Callable[[FiniteValued], bool],
    ],
    assertion: Callable[[list[FiniteValued]], bool],
):
    assert assertion(all_instances(type_, validation_funcs))
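

# `get_relative_direction` takes three consecutive coordinates
# (prev, cur, next) and classifies the final step relative to the heading
# implied by the first two; repeating the final coordinate yields PATH_STAY,
# while malformed shapes or non-adjacent steps raise ValueError.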
@pytest.mark.parametrize(
    ("coords", "result"),
    [
        pytest.param(
            np.array(coords),
            res,
            id=f"{coords}",
        )
        for coords, res in (
            [
                ([[0, 0], [0, 1], [1, 1]], VOCAB.PATH_RIGHT),
                ([[0, 0], [1, 0], [1, 1]], VOCAB.PATH_LEFT),
                ([[0, 0], [0, 1], [0, 2]], VOCAB.PATH_FORWARD),
                ([[0, 0], [0, 1], [0, 0]], VOCAB.PATH_BACKWARD),
                ([[0, 0], [0, 1], [0, 1]], VOCAB.PATH_STAY),
                ([[1, 1], [0, 1], [0, 0]], VOCAB.PATH_LEFT),
                ([[1, 1], [1, 0], [0, 0]], VOCAB.PATH_RIGHT),
                ([[0, 2], [0, 1], [0, 0]], VOCAB.PATH_FORWARD),
                ([[0, 0], [0, 1], [0, 0]], VOCAB.PATH_BACKWARD),
                ([[0, 1], [0, 1], [0, 0]], ValueError),
                ([[0, 1], [1, 1], [0, 0]], ValueError),
                ([[1, 0], [1, 1], [0, 0]], ValueError),
                ([[0, 1], [0, 2], [0, 0]], ValueError),
                ([[0, 1], [0, 0], [0, 0]], VOCAB.PATH_STAY),
                ([[1, 1], [0, 0], [0, 1]], ValueError),
                ([[1, 1], [0, 0], [1, 0]], ValueError),
                ([[0, 2], [0, 0], [0, 1]], ValueError),
                ([[0, 0], [0, 0], [0, 1]], ValueError),
                ([[0, 1], [0, 0], [0, 1]], VOCAB.PATH_BACKWARD),
                ([[-1, 0], [0, 0], [1, 0]], VOCAB.PATH_FORWARD),
                ([[-1, 0], [0, 0], [0, 1]], VOCAB.PATH_LEFT),
                ([[-1, 0], [0, 0], [-1, 0]], VOCAB.PATH_BACKWARD),
                ([[-1, 0], [0, 0], [0, -1]], VOCAB.PATH_RIGHT),
                ([[-1, 0], [0, 0], [1, 0], [2, 0]], ValueError),
                ([[-1, 0], [0, 0]], ValueError),
                ([[-1, 0, 0], [0, 0, 0]], ValueError),
            ]
        )
    ],
)
def test_get_relative_direction(
    coords: Int[np.ndarray, "prev_cur_next=3 axis=2"],
    result: str | type[Exception],
):
    if isinstance(result, type) and issubclass(result, Exception):
        with pytest.raises(result):
            get_relative_direction(coords)
        return
    assert get_relative_direction(coords) == result
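

# `manhattan_distance` accepts a single edge of shape (2, 2) or a batch of
# shape (n, 2, 2) and returns the L1 distance between the two endpoints of
# each edge, as a scalar or an array respectively.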
@pytest.mark.parametrize(
    ("edges", "result"),
    [
        pytest.param(
            edges,
            res,
            id=f"{edges}",
        )
        for edges, res in (
            [
                (np.array([[0, 0], [0, 1]]), 1),
                (np.array([[1, 0], [0, 1]]), 2),
                (np.array([[-1, 0], [0, 1]]), 2),
                (np.array([[0, 0], [5, 3]]), 8),
                (
                    np.array(
                        [
                            [[0, 0], [0, 1]],
                            [[1, 0], [0, 1]],
                            [[-1, 0], [0, 1]],
                            [[0, 0], [5, 3]],
                        ],
                    ),
                    [1, 2, 2, 8],
                ),
                (np.array([[[0, 0], [5, 3]]]), [8]),
            ]
        )
    ],
)
def test_manhattan_distance(
    edges: ConnectionArray | Connection,
    result: Int[np.ndarray, " edges"] | Int[np.ndarray, ""] | type[Exception],
):
    if isinstance(result, type) and issubclass(result, Exception):
        with pytest.raises(result):
            manhattan_distance(edges)
        return
    assert np.array_equal(manhattan_distance(edges), np.array(result, dtype=np.int8))
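

# `lattice_connection_array(n)` should return all 2 * n * (n - 1) unique edges
# of an n x n lattice, with the endpoint of smaller coordinate sum listed
# first in each edge.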
@pytest.mark.parametrize(
    "n",
    [pytest.param(n) for n in [2, 3, 5, 20]],
)
def test_lattice_connection_array(n):
    edges = lattice_connection_array(n)
    assert tuple(edges.shape) == (2 * n * (n - 1), 2, 2)
    assert np.all(np.sum(edges[:, 1], axis=1) > np.sum(edges[:, 0], axis=1))
    assert tuple(np.unique(edges, axis=0).shape) == (2 * n * (n - 1), 2, 2)
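

# `is_connection` checks each edge against the maze's `connection_list`. The
# expected value is recomputed here by sorting each edge's endpoints and
# setting the direction index to 1 when the endpoints share a row (a
# horizontal step) and 0 otherwise (a vertical step).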
@pytest.mark.parametrize(
    ("edges", "maze"),
    [
        pytest.param(
            edges(),
            maze,
            id=f"edges[{i}]; maze[{j}]",
        )
        for (i, edges), (j, maze) in itertools.product(
            enumerate(
                [
                    lambda: lattice_connection_array(GRID_N),
                    lambda: np.flip(lattice_connection_array(GRID_N), axis=1),
                    lambda: lattice_connection_array(GRID_N - 1),
                    lambda: numpy_rng.choice(
                        lattice_connection_array(GRID_N),
                        2 * GRID_N,
                        axis=0,
                    ),
                    lambda: numpy_rng.choice(
                        lattice_connection_array(GRID_N),
                        1,
                        axis=0,
                    ),
                ],
            ),
            enumerate(MAZE_DATASET.mazes),
        )
    ],
)
def test_is_connection(edges: ConnectionArray, maze: LatticeMaze):
    output = is_connection(edges, maze.connection_list)
    sorted_edges = np.sort(edges, axis=1)
    edge_direction = (
        (sorted_edges[:, 1, :] - sorted_edges[:, 0, :])[:, 0] == 0
    ).astype(np.int8)
    assert np.array_equal(
        output,
        maze.connection_list[
            edge_direction,
            sorted_edges[:, 0, 0],
            sorted_edges[:, 0, 1],
        ],
    )