Coverage for sleapyfaces/project.py: 65%

63 statements  

« prev     ^ index     » next       coverage.py v7.0.2, created at 2023-01-03 12:07 -0800

1import os 

2from sleapyfaces.structs import CustomColumn, File, FileConstructor 

3from sleapyfaces.experiment import Experiment 

4from sleapyfaces.normalize import mean_center, z_score, pca 

5from dataclasses import dataclass 

6import pandas as pd 

7 

8 

class Project:
    """Base class for a project.

    Builds a :class:`FileConstructor` and :class:`Experiment` for each
    per-week data folder under ``base`` (either supplied via ``iterator``
    or discovered automatically).

    Args:
        DAQFile (str): The naming convention for the DAQ files
            (e.g. "*_events.csv" or "DAQOutput.csv")
        BehFile (str): The naming convention for the experimental structure
            files (e.g. "*_config.json" or "BehMetadata.json")
        SLEAPFile (str): The naming convention for the SLEAP files
            (e.g. "*_sleap.h5" or "SLEAP.h5")
        VideoFile (str): The naming convention for the video files
            (e.g. "*.mp4" or "video.avi")
        base (str): Base path of the project (e.g. "/specialk_cs/2p/raw/CSE009")
        iterator (dict[str, str] | None): Iterator for the project files, with
            keys as the label and values as the folder name
            (e.g. {"week 1": "20211105", "week 2": "20211112"}). When omitted
            or empty, the subdirectories of ``base`` are discovered and labeled
            "week 1", "week 2", ... in sorted order.
        get_glob (bool): Whether to use glob to find the files
            NOTE: if get_glob is True, make sure to include the file extension
            in the naming convention
    """

    def __init__(
        self,
        DAQFile: str,
        BehFile: str,
        SLEAPFile: str,
        VideoFile: str,
        base: str,
        iterator: dict[str, str] | None = None,
        get_glob: bool = False,
    ):
        self.base = base
        self.DAQFile = DAQFile
        self.BehFile = BehFile
        self.SLEAPFile = SLEAPFile
        self.VideoFile = VideoFile
        self.get_glob = get_glob
        # BUG FIX: the original signature used a mutable default argument
        # (iterator={}) and mutated it in place, so auto-discovered weeks
        # leaked between Project instances. Build a fresh dict instead.
        if not iterator:
            weeks = sorted(
                entry
                for entry in os.listdir(self.base)
                if os.path.isdir(os.path.join(self.base, entry))
            )
            iterator = {f"week {i+1}": week for i, week in enumerate(weeks)}
        self.iterator = iterator
        self.files = []
        self.exprs = []
        for name, folder in self.iterator.items():
            week_path = os.path.join(self.base, folder)
            # One File per data stream, all rooted at this week's folder.
            files = FileConstructor(
                File(week_path, self.DAQFile, self.get_glob),
                File(week_path, self.SLEAPFile, self.get_glob),
                File(week_path, self.BehFile, self.get_glob),
                File(week_path, self.VideoFile, self.get_glob),
            )
            self.files.append(files)
            self.exprs.append(Experiment(name, files))

    def buildColumns(self, columns: list, values: list):
        """Builds the custom columns for the project and builds the data for each experiment

        Args:
            columns (list[str]): the column titles
            values (list[any]): the data for each column

        Initializes attributes:
            custom_columns (list[CustomColumn]): list of custom columns
            all_data (pd.DataFrame): the data for all experiments concatenated together
        """
        self.custom_columns = [
            CustomColumn(col, val) for col, val in zip(columns, values)
        ]
        exprs_list = []
        names_list = []
        for expr in self.exprs:
            expr.buildData(self.custom_columns)
            exprs_list.append(expr.sleap.tracks)
            names_list.append(expr.name)
        self.all_data = pd.concat(exprs_list, keys=names_list)

    def buildTrials(
        self,
        TrackedData: list[str],
        Reduced: list[bool],
        start_buffer: int = 10000,
        end_buffer: int = 13000,
    ):
        """Parses the data from each experiment into its individual trials

        Args:
            TrackedData (list[str]): The title of the columns from the DAQ data to be tracked
            Reduced (list[bool]): The corresponding boolean for whether the DAQ data is to be reduced (`True`) or not (`False`)
            start_buffer (int, optional): The time in milliseconds before the trial start to capture. Defaults to 10000.
            end_buffer (int, optional): The time in milliseconds after the trial start to capture. Defaults to 13000.

        Initializes attributes:
            exprs[i].trials (pd.DataFrame): the data frame containing the concatenated trial data for each experiment
            exprs[i].trialData (list[pd.DataFrame]): the list of data frames containing the trial data for each trial for each experiment
        """
        for expr in self.exprs:
            expr.buildTrials(TrackedData, Reduced, start_buffer, end_buffer)

    def meanCenter(self):
        """Recursively mean centers the data for each trial for each experiment

        Initializes attributes:
            all_data (pd.DataFrame): the mean centered data for all trials and experiments concatenated together
        """
        mean_all = []
        for expr in self.exprs:
            # Mean center each trial individually.
            # BUG FIX: the original indexed trialData with the experiment
            # index instead of the trial index, so every slot received the
            # same trial's data.
            centered = [
                mean_center(trial, expr.sleap.track_names)
                for trial in expr.trialData
            ]
            combined = pd.concat(centered, axis=0, keys=range(len(centered)))
            # Then mean center across the whole experiment.
            mean_all.append(mean_center(combined, expr.sleap.track_names))
        self.all_data = pd.concat(mean_all, keys=list(self.iterator.keys()))

    def zScore(self):
        """Z scores the mean centered data for each experiment

        Updates attributes:
            all_data (pd.DataFrame): the z-scored data for all experiments concatenated together
        """
        self.all_data = z_score(self.all_data, self.exprs[0].sleap.track_names)

    def normalize(self):
        """Runs the mean centering and z scoring functions

        Updates attributes:
            all_data (pd.DataFrame): the fully normalized data for all experiments concatenated together
        """
        analyze_all = [expr.normalizeTrials() for expr in self.exprs]
        analyze_all = pd.concat(analyze_all, keys=list(self.iterator.keys()))
        self.all_data = z_score(analyze_all, self.exprs[0].sleap.track_names)

    def visualize(self):
        """Reduces `all_data` to 2 and 3 dimensions using PCA

        Initializes attributes:
            pcas (dict[str, pd.DataFrame]): a dictionary containing the 2 and 3 dimensional PCA data for each experiment (the keys are 'pca2d', 'pca3d')
        """
        self.pcas = pca(self.all_data, self.exprs[0].sleap.track_names)