Coverage for tests/unit/maze_dataset/dataset/test_collected_dataset.py: 100%

45 statements  

Report generated by coverage.py v7.6.12 on 2025-03-11 00:49 -0600.

1from functools import cached_property 

2 

3import numpy as np 

4 

5from maze_dataset.dataset.collected_dataset import ( 

6 MazeDatasetCollection, 

7 MazeDatasetCollectionConfig, 

8 MazeDatasetConfig, 

9) 

10 

# Specification of the sub-datasets that make up the collection under test.
# The two lists are paired index-by-index: sub-dataset i holds
# DATASET_LENGTHS[i] mazes, each generated with grid_n=DATASET_GRID_SIZES[i].
# Both lists must therefore have the same length.
DATASET_LENGTHS: list[int] = [
    1,
    0,  # an empty sub-dataset
    3,
    2,
    1,
]
DATASET_GRID_SIZES: list[int] = [
    5,
    1,
    3,
    3,
    4,
]

13 

14 

class TestMazeDatasetCollection:
    """Tests for `MazeDatasetCollection` built from several small sub-datasets.

    Sub-dataset ``i`` is configured with ``n_mazes=DATASET_LENGTHS[i]`` and
    ``grid_n=DATASET_GRID_SIZES[i]``; one of the configured lengths is 0, so
    the collection includes an empty sub-dataset.
    """

    @cached_property
    def test_collection(self) -> MazeDatasetCollection:
        """Build the collection exercised by every test in this class.

        Despite the ``test_`` prefix this is a shared fixture-like property,
        not a test. The collection is generated fresh (``do_generate=True``,
        ``load_local=False``, ``do_download=False``) and built with
        ``save_local=True`` under ``local_base_path="data/"``.

        NOTE(review): the cache lives on the instance; pytest normally creates
        one instance per test method, so this likely regenerates the
        collection for each test that uses it -- confirm that is acceptable.
        """
        config = MazeDatasetCollectionConfig(
            name="test_collection",
            maze_dataset_configs=[
                MazeDatasetConfig(
                    n_mazes=n_mazes,
                    grid_n=grid_n,
                    name=f"test_dataset_{n_mazes}_{grid_n}",
                )
                # strict=True: if the two module-level lists ever drift apart
                # in length, fail loudly instead of silently dropping
                # sub-datasets from the collection.
                for n_mazes, grid_n in zip(
                    DATASET_LENGTHS,
                    DATASET_GRID_SIZES,
                    strict=True,
                )
            ],
        )
        return MazeDatasetCollection.from_config(
            config,
            do_generate=True,
            load_local=False,
            do_download=False,
            save_local=True,
            local_base_path="data/",
        )

    def test_dataset_lengths(self):
        """Each sub-dataset reports the number of mazes it was configured with."""
        # array_equal also checks shape, so a wrong number of sub-datasets
        # fails cleanly instead of broadcasting or raising.
        assert np.array_equal(
            self.test_collection.dataset_lengths,
            DATASET_LENGTHS,
        )

    def test_dataset_cum_lengths(self):
        """Cumulative lengths are the running totals of the sub-dataset sizes."""
        # Hard-coded running totals of DATASET_LENGTHS; the repeated 1 comes
        # from the empty second sub-dataset contributing no mazes.
        assert np.array_equal(
            self.test_collection.dataset_cum_lengths,
            [1, 1, 4, 6, 7],
        )

    def test_mazes(self):
        """`mazes` concatenates all sub-datasets in configuration order."""
        assert len(self.test_collection.mazes) == 7
        assert self.test_collection.mazes[0].connection_list.shape == (2, 5, 5)
        assert self.test_collection.mazes[-1].connection_list.shape == (2, 4, 4)

    def test_len(self):
        """`len()` counts mazes across all sub-datasets."""
        assert len(self.test_collection) == 7

    def test_getitem(self):
        """Flat indexing walks sub-datasets in order and agrees with `mazes`."""
        collection = self.test_collection

        # Expected grid size at each flat index: one 5x5 maze, none from the
        # empty 1x1 sub-dataset, 3 + 2 mazes at 3x3, then one 4x4 maze.
        expected_grid_sizes = [5, 3, 3, 3, 3, 3, 4]
        for i, grid_n in enumerate(expected_grid_sizes):
            assert collection[i].connection_list.shape == (2, grid_n, grid_n), i

        mazes = collection.mazes
        assert len(mazes) == sum(DATASET_LENGTHS)
        for i, maze in enumerate(mazes):
            indexed = collection[i]
            assert indexed.connection_list.shape == maze.connection_list.shape, i
            assert np.array_equal(indexed.connection_list, maze.connection_list), i

    def test_download(self):
        # TODO: test downloading after we implement downloading datasets
        pass

    def _assert_same_as_original(self, loaded: MazeDatasetCollection) -> None:
        """Assert that a round-tripped collection matches `test_collection`.

        Not a test itself (leading underscore); shared by the serialization
        and file round-trip tests.
        """
        original = self.test_collection
        assert loaded.mazes == original.mazes
        # The config diff is checked before equality, presumably so a failure
        # reports which config fields differ.
        assert loaded.cfg.diff(original.cfg) == {}
        assert loaded.cfg == original.cfg

    def test_serialize_and_load(self):
        """`serialize()` followed by `load()` reproduces the collection."""
        serialized = self.test_collection.serialize()
        loaded = MazeDatasetCollection.load(serialized)
        self._assert_same_as_original(loaded)

    def test_save_read(self):
        """`save()` to a file followed by `read()` reproduces the collection."""
        path = "tests/_temp/collected_dataset_test_save_read.zanj"
        self.test_collection.save(path)
        loaded = MazeDatasetCollection.read(path)
        self._assert_same_as_original(loaded)