Coverage for tests\unit\maze_dataset\dataset\test_rasterized.py: 100%

34 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-23 12:49 -0700

1from itertools import product 

2 

3import numpy as np 

4import pytest 

5 

6from maze_dataset import LatticeMazeGenerators, MazeDatasetConfig 

7from maze_dataset.dataset.maze_dataset import MazeDataset 

8from maze_dataset.dataset.rasterized import ( 

9 RasterizedMazeDataset, 

10 RasterizedMazeDatasetConfig, 

11 make_numpy_collection, 

12) 

13 

# Shared pytest parametrization: argument names plus every True/False
# combination of the three rasterization flags (2**3 = 8 cases).
# NOTE: the name is misspelled ("PARAMTETRIZATION") but is kept as-is because
# the test decorators in this module reference it by this exact name.
_PARAMTETRIZATION = (
    "remove_isolated_cells, extend_pixels, endpoints_as_open",
    [tuple(flags) for flags in product((True, False), repeat=3)],
)

18 

19 

@pytest.mark.parametrize(*_PARAMTETRIZATION)
def test_rasterized_new(remove_isolated_cells, extend_pixels, endpoints_as_open):
    """Build a RasterizedMazeDataset directly from a RasterizedMazeDatasetConfig.

    Runs once per combination of the three rasterization flags. Percolation
    with ``p=0.4`` is used to get some isolated cells, so that the
    ``remove_isolated_cells`` option is actually exercised.
    """
    cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig(
        name="test",
        grid_n=5,
        n_mazes=2,
        maze_ctor=LatticeMazeGenerators.gen_percolation,  # use percolation here to get some isolated cells
        maze_ctor_kwargs=dict(p=0.4),
        remove_isolated_cells=remove_isolated_cells,
        extend_pixels=extend_pixels,
        endpoints_as_open=endpoints_as_open,
    )
    dataset: RasterizedMazeDataset = RasterizedMazeDataset.from_config_augmented(
        cfg, load_local=False
    )

    # Without these checks the test could only fail by raising. Assert the
    # dataset is non-empty (mirroring test_rasterized_from_mazedataset) and
    # holds the two items indexed below, so a short dataset is reported as a
    # clear assertion failure rather than an IndexError.
    assert dataset
    assert len(dataset) >= 2

    print(f"{dataset[0][0].shape = }, {dataset[0][1].shape = }")
    print(f"{dataset[0][1] = }\n{dataset[1][1] = }")

38 

39 

@pytest.mark.parametrize(*_PARAMTETRIZATION)
def test_rasterized_from_mazedataset(
    remove_isolated_cells, extend_pixels, endpoints_as_open
):
    """Convert a plain MazeDataset into a RasterizedMazeDataset.

    A base MazeDataset is generated first. It is then converted via
    ``RasterizedMazeDataset.from_base_MazeDataset``, with the three
    rasterization flags passed through ``added_params``. The test only
    checks that the converted dataset is truthy.
    """
    base_cfg: MazeDatasetConfig = MazeDatasetConfig(
        name="test",
        grid_n=5,
        n_mazes=2,
        # percolation is used here to get some isolated cells
        maze_ctor=LatticeMazeGenerators.gen_percolation,
        maze_ctor_kwargs={"p": 0.4},
    )
    base_dataset: MazeDataset = MazeDataset.from_config(base_cfg, load_local=False)

    raster_flags = {
        "remove_isolated_cells": remove_isolated_cells,
        "extend_pixels": extend_pixels,
        "endpoints_as_open": endpoints_as_open,
    }
    rasterized: RasterizedMazeDataset = RasterizedMazeDataset.from_base_MazeDataset(
        base_dataset,
        added_params=raster_flags,
    )

    assert rasterized

62 

63 

@pytest.mark.parametrize(*_PARAMTETRIZATION)
def test_make_numpy_collection(remove_isolated_cells, extend_pixels, endpoints_as_open):
    """Check the structure of the output of ``make_numpy_collection``.

    The output must be a dict with ``"configs"`` and ``"arrays"`` sub-dicts.
    Each sub-dict must hold one entry per requested grid size, keyed by
    strings. Config values must be RasterizedMazeDatasetConfig instances and
    array values must be numpy arrays.
    """
    grid_sizes = [2, 3]
    expected_count = len(grid_sizes)

    base_cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig(
        name="test",
        grid_n=5,
        n_mazes=2,
        # percolation is used here to get some isolated cells
        maze_ctor=LatticeMazeGenerators.gen_percolation,
        maze_ctor_kwargs={"p": 0.4},
        remove_isolated_cells=remove_isolated_cells,
        extend_pixels=extend_pixels,
        endpoints_as_open=endpoints_as_open,
    )

    collection = make_numpy_collection(
        base_cfg=base_cfg,
        grid_sizes=grid_sizes,
        from_config_kwargs={"load_local": False},
        verbose=True,
    )

    assert isinstance(collection, dict)
    configs = collection["configs"]
    assert isinstance(configs, dict)
    arrays = collection["arrays"]
    assert isinstance(arrays, dict)

    # expect one entry per requested grid size in each sub-dict
    assert len(configs) == expected_count
    assert len(arrays) == expected_count

    for cfg_name, sub_cfg in configs.items():
        assert isinstance(cfg_name, str)
        assert isinstance(sub_cfg, RasterizedMazeDatasetConfig)

    for arr_name, arr in arrays.items():
        assert isinstance(arr_name, str)
        assert isinstance(arr, np.ndarray)