Coverage for tests\unit\maze_dataset\dataset\test_collected_dataset.py: 100%

45 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-23 12:49 -0700

1from functools import cached_property 

2 

3import numpy as np 

4 

5from maze_dataset.dataset.collected_dataset import ( 

6 MazeDatasetCollection, 

7 MazeDatasetCollectionConfig, 

8 MazeDatasetConfig, 

9) 

10 

# Per-sub-dataset parameters for the collection under test. The two lists are
# consumed pairwise (zipped), so entry i of each list describes sub-dataset i.
DATASET_LENGTHS: list[int] = [
    1,  # sub-dataset 0
    0,  # sub-dataset 1 -- contains no mazes
    3,  # sub-dataset 2
    2,  # sub-dataset 3
    1,  # sub-dataset 4
]
DATASET_GRID_SIZES: list[int] = [
    5,  # sub-dataset 0
    1,  # sub-dataset 1
    3,  # sub-dataset 2
    3,  # sub-dataset 3
    4,  # sub-dataset 4
]

13 

14 

class TestMazeDatasetCollection:
    """Tests for a `MazeDatasetCollection` assembled from several small datasets.

    The collection has one sub-dataset per `(DATASET_LENGTHS[i],
    DATASET_GRID_SIZES[i])` pair, including one sub-dataset with 0 mazes.
    Expected values are derived from those module constants instead of being
    hard-coded, so the assertions stay consistent if the constants change.
    """

    # pytest instantiates the test class anew for every test method, so a
    # per-instance `cached_property` would regenerate (and re-save) the
    # collection once per test. Caching on the class builds it only once.
    _cached_collection: "MazeDatasetCollection | None" = None

    @property
    def test_collection(self) -> MazeDatasetCollection:
        """The collection under test, generated on first access and then reused.

        Generation does not load or download anything; it saves the result
        under ``data/``.
        """
        cls = type(self)
        if cls._cached_collection is None:
            config = MazeDatasetCollectionConfig(
                name="test_collection",
                maze_dataset_configs=[
                    MazeDatasetConfig(
                        n_mazes=n_mazes,
                        grid_n=grid_n,
                        name=f"test_dataset_{n_mazes}_{grid_n}",
                    )
                    for n_mazes, grid_n in zip(DATASET_LENGTHS, DATASET_GRID_SIZES)
                ],
            )
            cls._cached_collection = MazeDatasetCollection.from_config(
                config,
                do_generate=True,
                load_local=False,
                do_download=False,
                save_local=True,
                local_base_path="data/",
            )
        return cls._cached_collection

    @staticmethod
    def _expected_grid_sizes() -> list[int]:
        """Grid size of every maze, in collection order.

        Each sub-dataset contributes `n_mazes` copies of its `grid_n`; the empty
        sub-dataset contributes nothing.
        """
        return [
            grid_n
            for n_mazes, grid_n in zip(DATASET_LENGTHS, DATASET_GRID_SIZES)
            for _ in range(n_mazes)
        ]

    def test_dataset_lengths(self):
        # `np.array_equal` also compares shapes; a bare `==` would broadcast a
        # length-1 array against the expected one and could pass spuriously.
        assert np.array_equal(self.test_collection.dataset_lengths, DATASET_LENGTHS)

    def test_dataset_cum_lengths(self):
        # Running totals of the per-dataset lengths, e.g. [1, 1, 4, 6, 7].
        assert np.array_equal(
            self.test_collection.dataset_cum_lengths,
            np.cumsum(DATASET_LENGTHS),
        )

    def test_mazes(self):
        mazes = self.test_collection.mazes
        assert len(mazes) == sum(DATASET_LENGTHS)
        # Each maze's connection list is (2, grid_n, grid_n) for its sub-dataset.
        assert [maze.connection_list.shape for maze in mazes] == [
            (2, grid_n, grid_n) for grid_n in self._expected_grid_sizes()
        ]

    def test_len(self):
        assert len(self.test_collection) == sum(DATASET_LENGTHS)

    def test_getitem(self):
        collection = self.test_collection
        for i, grid_n in enumerate(self._expected_grid_sizes()):
            maze = collection[i]
            # Indexing must map a flat index onto the right sub-dataset...
            assert maze.connection_list.shape == (2, grid_n, grid_n)
            # ...and agree exactly with the flattened `mazes` list.
            assert np.array_equal(
                maze.connection_list, collection.mazes[i].connection_list
            )

    def test_download(self):
        # TODO: downloading is not exercised yet; placeholder test.
        pass

    def test_serialize_and_load(self):
        serialized = self.test_collection.serialize()
        loaded = MazeDatasetCollection.load(serialized)
        assert loaded.mazes == self.test_collection.mazes
        assert loaded.cfg.diff(self.test_collection.cfg) == {}
        assert loaded.cfg == self.test_collection.cfg

    def test_save_read(self):
        from pathlib import Path

        save_path = "tests/_temp/collected_dataset_test_save_read.zanj"
        # Ensure the scratch directory exists before writing into it.
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        self.test_collection.save(save_path)
        loaded = MazeDatasetCollection.read(save_path)
        assert loaded.mazes == self.test_collection.mazes
        assert loaded.cfg.diff(self.test_collection.cfg) == {}
        assert loaded.cfg == self.test_collection.cfg