Coverage for tests\unit\maze_dataset\dataset\test_rasterized.py: 100%
34 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-23 12:49 -0700
1from itertools import product
3import numpy as np
4import pytest
6from maze_dataset import LatticeMazeGenerators, MazeDatasetConfig
7from maze_dataset.dataset.maze_dataset import MazeDataset
8from maze_dataset.dataset.rasterized import (
9 RasterizedMazeDataset,
10 RasterizedMazeDatasetConfig,
11 make_numpy_collection,
12)
# Shared arguments for ``pytest.mark.parametrize``: the comma-separated names
# of the three boolean flags, followed by every True/False combination of them
# (2**3 = 8 cases).  Unpacked with ``*`` at each decorator site.
_PARAMTETRIZATION = (
    "remove_isolated_cells, extend_pixels, endpoints_as_open",
    list(product((True, False), repeat=3)),
)
@pytest.mark.parametrize(*_PARAMTETRIZATION)
def test_rasterized_new(remove_isolated_cells, extend_pixels, endpoints_as_open):
    """Build a rasterized dataset directly from its config and print samples.

    Runs once per combination of the three boolean flags.  There are no
    explicit assertions: the test passes when constructing the dataset and
    indexing into its first two items raises nothing.
    """
    # the three rasterization flags under test, forwarded verbatim to the config
    raster_flags = {
        "remove_isolated_cells": remove_isolated_cells,
        "extend_pixels": extend_pixels,
        "endpoints_as_open": endpoints_as_open,
    }
    raster_cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig(
        name="test",
        grid_n=5,
        n_mazes=2,
        # use percolation here to get some isolated cells
        maze_ctor=LatticeMazeGenerators.gen_percolation,
        maze_ctor_kwargs={"p": 0.4},
        **raster_flags,
    )
    dataset: RasterizedMazeDataset = RasterizedMazeDataset.from_config_augmented(
        raster_cfg,
        load_local=False,
    )

    # NOTE: the local must stay named `dataset` -- the `=` debug specifier
    # echoes the expression text into the printed output.
    print(f"{dataset[0][0].shape = }, {dataset[0][1].shape = }")
    print(f"{dataset[0][1] = }\n{dataset[1][1] = }")
@pytest.mark.parametrize(*_PARAMTETRIZATION)
def test_rasterized_from_mazedataset(
    remove_isolated_cells, extend_pixels, endpoints_as_open
):
    """Convert a plain ``MazeDataset`` into a ``RasterizedMazeDataset``.

    Runs once per combination of the three boolean flags, which are handed to
    ``from_base_MazeDataset`` through ``added_params``.  The only check is that
    the resulting dataset is truthy.
    """
    base_cfg: MazeDatasetConfig = MazeDatasetConfig(
        name="test",
        grid_n=5,
        n_mazes=2,
        # use percolation here to get some isolated cells
        maze_ctor=LatticeMazeGenerators.gen_percolation,
        maze_ctor_kwargs={"p": 0.4},
    )
    dataset_m: MazeDataset = MazeDataset.from_config(base_cfg, load_local=False)

    # rasterization options layered on top of the base dataset
    raster_params = {
        "remove_isolated_cells": remove_isolated_cells,
        "extend_pixels": extend_pixels,
        "endpoints_as_open": endpoints_as_open,
    }
    dataset_r: RasterizedMazeDataset = RasterizedMazeDataset.from_base_MazeDataset(
        dataset_m,
        added_params=raster_params,
    )

    assert dataset_r
@pytest.mark.parametrize(*_PARAMTETRIZATION)
def test_make_numpy_collection(remove_isolated_cells, extend_pixels, endpoints_as_open):
    """Check the structure of what ``make_numpy_collection`` returns.

    Runs once per combination of the three boolean flags.  With two grid sizes
    requested, the result must be a dict whose ``"configs"`` and ``"arrays"``
    entries are each dicts of length 2 with string keys, holding
    ``RasterizedMazeDatasetConfig`` and ``np.ndarray`` values respectively.
    """
    base_cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig(
        name="test",
        grid_n=5,
        n_mazes=2,
        # use percolation here to get some isolated cells
        maze_ctor=LatticeMazeGenerators.gen_percolation,
        maze_ctor_kwargs={"p": 0.4},
        remove_isolated_cells=remove_isolated_cells,
        extend_pixels=extend_pixels,
        endpoints_as_open=endpoints_as_open,
    )

    output = make_numpy_collection(
        base_cfg=base_cfg,
        grid_sizes=[2, 3],
        from_config_kwargs={"load_local": False},
        verbose=True,
    )

    assert isinstance(output, dict)

    # section name -> type every value in that section must have
    expected_value_types = {
        "configs": RasterizedMazeDatasetConfig,
        "arrays": np.ndarray,
    }

    for section in expected_value_types:
        assert isinstance(output[section], dict)

    # two grid sizes were requested, and each section must hold two entries
    for section in expected_value_types:
        assert len(output[section]) == 2

    for section, value_type in expected_value_types.items():
        for key, value in output[section].items():
            assert isinstance(key, str)
            assert isinstance(value, value_type)