Coverage for sleapyfaces/utils.py: 58% (102 statements)


import h5py as h5
import numpy as np
import pandas as pd
from scipy.signal import savgol_filter
from scipy.interpolate import interp1d
import json

from typing import Dict, List, Sequence, MutableSequence

def json_loads(json_str: str) -> Dict:
    """
    Cite:
        From: https://sleap.ai/
        By: Talmo Pereira

    Summary:
        A simple wrapper around the JSON decoder we are using.

    Args:
        json_str: JSON string to decode.

    Returns:
        Result of decoding the JSON string.
    """
    return json.loads(json_str)

def json_dumps(d: Dict, filename: str | None = None):
    """
    Cite:
        From: https://sleap.ai/
        By: Talmo Pereira

    Summary:
        A simple wrapper around the JSON encoder we are using.

    Args:
        d: The dict to write.
        filename: The filename to write to. If None, the encoded string is
            returned instead of being written to disk.

    Returns:
        None if a filename is given; otherwise the JSON-encoded string.
    """
    if filename:
        with open(filename, "w") as f:
            json.dump(d, f, ensure_ascii=False)
    else:
        return json.dumps(d)
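
# A minimal usage sketch for json_dumps / json_loads. The dict contents, the
# filename, and the `_demo_*` helper name are illustrative, not part of the
# library API.
def _demo_json_roundtrip():
    config = {"fps": 30, "nodes": ["nose", "ear"]}
    # In-memory round trip: encode to a string, then decode it back.
    assert json_loads(json_dumps(config)) == config
    # With a filename, the dict is written to disk and None is returned.
    json_dumps(config, filename="config_demo.json")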

def save_dt_to_hdf5(hdfstore: pd.HDFStore, dt: pd.DataFrame, path: str):
    """
    Summary:
        Saves a pandas DataFrame to an HDF5 file.

    Args:
        hdfstore (pd.HDFStore): The open HDF5 store to save the data to.
        dt (pd.DataFrame): The DataFrame to save.
        path (str): The group path to save the DataFrame under.

    Returns:
        None
    """
    hdfstore.put(path, dt, format="table", data_columns=True)
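
# A minimal usage sketch for save_dt_to_hdf5 (the filename, data, and
# `_demo_*` helper name are made up). Assumes the optional PyTables
# dependency that pd.HDFStore requires is installed.
def _demo_save_dt_to_hdf5():
    df = pd.DataFrame({"x": [1.0, 2.0], "y": [3.0, 4.0]})
    with pd.HDFStore("trials_demo.h5") as store:
        save_dt_to_hdf5(store, df, "session1/trial0")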

def save_dict_to_hdf5(h5file: h5.File, path: str, dic: dict):
    """
    Cite:
        From: https://sleap.ai/
        By: Talmo Pereira

    Summary:
        Saves a dictionary to an HDF5 file.

        Calls itself recursively if an item in the dictionary is itself a
        dict; all other items must be of a type HDF5 can store
        (`np.ndarray`, `np.int64`, `np.float64`, `str`, `bytes`, `int`,
        `float`, `bool`, or a list of these).

    Args:
        h5file (h5.File): The open HDF5 file object to save the data to.
        path (str): The group path to save the dict under.
        dic (dict): The dict to save.

    Raises:
        ValueError: If the type of an item in the dict cannot be saved.

    Returns:
        None
    """
    for key, item in list(dic.items()):
        print(f"Saving {key}:")
        if item is None:
            h5file[path + key] = ""
        elif isinstance(item, bool):
            # bool is a subclass of int, so handle it before the int case.
            h5file[path + key] = int(item)
        elif isinstance(item, list):
            # Encode strings so the list can be stored as a uniform array.
            items_encoded = [
                it.encode("utf8") if isinstance(it, str) else it for it in item
            ]
            h5file[path + key] = np.asarray(items_encoded)
        elif isinstance(item, str):
            h5file[path + key] = item.encode("utf8")
        elif isinstance(item, (np.ndarray, np.int64, np.float64, bytes, float, int)):
            h5file[path + key] = item
        elif isinstance(item, dict):
            # Recurse into nested dicts, creating a subgroup per key.
            save_dict_to_hdf5(h5file, path + key + "/", item)
        else:
            raise ValueError(f"Cannot save {type(item)} type")
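
# A minimal usage sketch for save_dict_to_hdf5 (the filename, metadata dict,
# and `_demo_*` helper name are hypothetical). Nested dicts would become
# subgroups via the recursive branch above.
def _demo_save_dict_to_hdf5():
    meta = {"experiment": "exp1", "fps": 30, "nodes": ["nose", "ear"]}
    with h5.File("metadata_demo.h5", "w") as f:
        save_dict_to_hdf5(f, "/", meta)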

def fill_missing(Y, kind="linear") -> np.ndarray:
    """
    Cite:
        From: https://sleap.ai/notebooks/Analysis_examples.html
        By: Talmo Pereira

    Summary:
        Fills missing values independently along each dimension after the first.

    Args:
        Y (np.array): Array of any dimensionality with missing values (NaNs).
        kind (str): Interpolation kind.

    Returns:
        Y (np.array): The original array with missing values filled.
    """
    # Store initial shape.
    initial_shape = Y.shape

    # Flatten after the first dim.
    Y = Y.reshape((initial_shape[0], -1))

    # Interpolate along each slice.
    for i in range(Y.shape[-1]):
        y = Y[:, i]

        # Build interpolant from the non-NaN samples.
        x = np.flatnonzero(~np.isnan(y))
        f = interp1d(x, y[x], kind=kind, fill_value=np.nan, bounds_error=False)

        # Fill missing values.
        xq = np.flatnonzero(np.isnan(y))
        y[xq] = f(xq)

        # Fill leading or trailing NaNs with the nearest non-NaN values.
        mask = np.isnan(y)
        y[mask] = np.interp(np.flatnonzero(mask), np.flatnonzero(~mask), y[~mask])

        # Save slice.
        Y[:, i] = y

    # Restore to initial shape.
    Y = Y.reshape(initial_shape)

    return Y
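
# A minimal usage sketch for fill_missing (the array values and the `_demo_*`
# helper name are made up). Each column is interpolated independently.
def _demo_fill_missing():
    tracks = np.array(
        [[0.0, 10.0], [np.nan, 11.0], [2.0, np.nan], [3.0, 13.0]]
    )
    filled = fill_missing(tracks, kind="linear")
    # Interior NaNs are interpolated column by column: 1.0 and 12.0 here.
    assert not np.isnan(filled).any()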

def smooth_diff(node_loc: np.ndarray, win=25, poly=3) -> np.ndarray:
    """
    Cite:
        From: https://sleap.ai/notebooks/Analysis_examples.html
        By: Talmo Pereira

    Summary:
        Computes the velocity of a node by taking the derivative of its
        smoothed position, then taking the norm of the velocity vector at
        each frame.

    Args:
        node_loc (np.array): A [frames, 2] array of node positions.
        win (int): The window length to smooth over.
        poly (int): The order of the polynomial to fit with.

    Returns:
        node_vel (np.array): A [frames,] array of per-frame speeds.
    """
    node_loc_vel = np.zeros_like(node_loc)

    # Smooth each coordinate and take its first derivative in one pass.
    for c in range(node_loc.shape[-1]):
        node_loc_vel[:, c] = savgol_filter(node_loc[:, c], win, poly, deriv=1)

    # Speed is the norm of the velocity vector at each frame.
    node_vel = np.linalg.norm(node_loc_vel, axis=1)

    return node_vel
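
# A minimal usage sketch for smooth_diff (the random-walk data and the
# `_demo_*` helper name are made up). Note that win must exceed poly and must
# not exceed the number of frames.
def _demo_smooth_diff():
    rng = np.random.default_rng(0)
    node_loc = np.cumsum(rng.normal(size=(100, 2)), axis=0)  # [frames, 2]
    speed = smooth_diff(node_loc, win=25, poly=3)
    assert speed.shape == (100,)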

def corr_roll(
    datax: List | pd.Series | np.ndarray, datay: List | pd.Series | np.ndarray, win: int
) -> np.ndarray:
    """
    Cite:
        From: https://sleap.ai/notebooks/Analysis_examples.html
        By: Talmo Pereira

    Summary:
        Computes the rolling correlation between two timeseries.

    Args:
        datax (List, np.array, or pd.Series): The x-dimensional timeseries.
        datay (List, np.array, or pd.Series): The y-dimensional timeseries.
        win (int): The number of frames over which the correlation is computed.

    Returns:
        np.array: The rolling correlation between the two timeseries over time.
    """
    s1 = pd.Series(datax)
    s2 = pd.Series(datay)

    return np.array(s2.rolling(win).corr(s1))
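
# A minimal usage sketch for corr_roll (the series and the `_demo_*` helper
# name are made up). The first win - 1 entries are NaN because the rolling
# window is not yet full there.
def _demo_corr_roll():
    x = np.arange(50, dtype=float)
    y = 2.0 * x + 1.0  # perfectly correlated with x
    rolling = corr_roll(x, y, win=10)
    assert rolling.shape == (50,)
    assert np.allclose(rolling[9:], 1.0)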

def into_trial_format(
    var: pd.DataFrame,
    trial_types: list,
    trial_start_idx: np.ndarray,
    trial_end_idx: np.ndarray,
) -> list[pd.DataFrame]:
    """
    Summary:
        Splits an array or dataframe into individual trials.

        Assumes that the index of the array or dataframe is the frame number.

    Args:
        var (np.array or pd.DataFrame): The array or dataframe to split into trials.
        trial_types (list): The trial type label for each trial.
        trial_start_idx (list[int]): The frame indices where the trials start.
        trial_end_idx (list[int]): The frame indices where the trials end.

    Returns:
        list[pd.DataFrame]: One DataFrame per trial, each with a "trial_type"
        column appended.
    """
    if trial_start_idx.shape != trial_end_idx.shape:
        raise ValueError("trial_start_idx and trial_end_idx must be the same length")
    var_trials = [0] * trial_start_idx.shape[0]
    for trial, (start, end, trial_type) in enumerate(
        zip(trial_start_idx, trial_end_idx, trial_types)
    ):
        # Slice out this trial's frames and reset to a 0-based index.
        var_trials[trial] = pd.DataFrame(var.iloc[start:end, :])
        var_trials[trial] = var_trials[trial].reset_index()
        # Attach the trial type as a column of repeated labels.
        trial_type = [trial_type] * len(var_trials[trial].index)
        trial_type = pd.DataFrame(trial_type, columns=["trial_type"])
        var_trials[trial] = pd.concat(
            [var_trials[trial], trial_type],
            axis=1,
        )
    return var_trials
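
# A minimal usage sketch for into_trial_format (the frame counts, trial
# boundaries, labels, and `_demo_*` helper name are made up).
def _demo_into_trial_format():
    frames = pd.DataFrame({"x": np.arange(100.0), "y": np.arange(100.0)})
    starts = np.array([10, 50])
    ends = np.array([20, 60])
    trials = into_trial_format(frames, ["cs", "us"], starts, ends)
    assert len(trials) == 2
    assert list(trials[0]["trial_type"].unique()) == ["cs"]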

# Create a Gaussian kernel for smoothing.
def gaussian_kernel(window_size: int, sigma=1) -> np.ndarray:
    """
    Summary:
        Creates a one-sided Gaussian kernel for backward smoothing.

    Args:
        window_size (int): How many frames to smooth over.
        sigma (int, optional): The standard deviation of the Gaussian. Defaults to 1.

    Returns:
        np.array: A kernel to smooth with, of shape (window_size,).
    """
    x_vals = np.arange(window_size)
    # Gaussian centered on the middle of the window: exp(-(x - mu)^2 / (2 sigma^2)).
    to_ret = np.exp(-((x_vals - window_size // 2) ** 2) / (2 * sigma**2))
    # Zero out the first half so the smoothing is one-sided (backward in time).
    to_ret[: window_size // 2] = 0
    return to_ret
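
# A minimal usage sketch for gaussian_kernel (the window size, signal, and
# `_demo_*` helper name are made up). Normalizing by the kernel sum is an
# assumption here -- the kernel itself is not normalized.
def _demo_gaussian_kernel():
    kernel = gaussian_kernel(11, sigma=2)
    kernel = kernel / kernel.sum()  # normalize so the output keeps its scale
    signal = np.sin(np.linspace(0, 4 * np.pi, 200))
    smoothed = np.convolve(signal, kernel, mode="same")
    assert smoothed.shape == signal.shape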

def reduce_daq(iterable: list, ms=4000) -> list[float]:
    """
    Summary:
        Reduces rapid-succession TTL pulses to a single pulse.

    Args:
        iterable (list): The list of TTL pulse times (preferably in ms).
        ms (int, optional): The minimum time between distinct pulses.
            Defaults to 4000 ms (4 seconds).

    Returns:
        list[float]: A reduced list of TTL pulse times.
    """
    reduced: list[float] = []
    j: int = 0
    # Keep the first pulse, then keep each pulse that follows the last kept
    # pulse by more than `ms`.
    reduced.append(iterable[j])
    for i in range(len(iterable)):
        if iterable[j] < (iterable[i] - ms):
            j = i
            reduced.append(iterable[j])
    return reduced
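
# A minimal usage sketch for reduce_daq (the pulse times and the `_demo_*`
# helper name are made up). Pulses within 4000 ms of the last kept pulse are
# collapsed into it.
def _demo_reduce_daq():
    pulses = [0, 5, 12, 10000, 10003, 25000]
    assert reduce_daq(pulses, ms=4000) == [0, 10000, 25000]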

def tracks_deconstructor(
    tracks: np.ndarray | pd.DataFrame | List | Sequence | MutableSequence,
    nodes: np.ndarray | pd.DataFrame | List | Sequence | MutableSequence,
) -> pd.DataFrame:
    """Takes the tracks array from a SLEAP analysis file and converts it into a pandas DataFrame.

    Args:
        tracks (np.ndarray | pd.DataFrame | List | Sequence | MutableSequence): The 4D array of tracks from a SLEAP analysis file.
        nodes (np.ndarray | pd.DataFrame | List | Sequence | MutableSequence): The list of nodes from a SLEAP analysis file.

    Returns:
        pd.DataFrame: The tracks DataFrame, with one `<node>_x` and `<node>_y`
        column per node.
    """
    new_tracks = [0] * len(nodes)
    for n, node in enumerate(nodes):
        # Pull this node's x and y coordinates for the first instance.
        new_tracks[n] = pd.DataFrame(
            {
                f"{node.replace(' ', '_')}_x": tracks[:, n, 0, 0],
                f"{node.replace(' ', '_')}_y": tracks[:, n, 1, 0],
            },
            columns=[f"{node.replace(' ', '_')}_x", f"{node.replace(' ', '_')}_y"],
        )
    return pd.concat(new_tracks, axis=1)
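
# A minimal usage sketch for tracks_deconstructor. The (frames, nodes, coords,
# instances) layout follows the indexing above; the array values, node names,
# and `_demo_*` helper name are made up.
def _demo_tracks_deconstructor():
    tracks = np.zeros((100, 2, 2, 1))  # 100 frames, 2 nodes, (x, y), 1 instance
    nodes = ["nose", "left ear"]
    df = tracks_deconstructor(tracks, nodes)
    assert list(df.columns) == ["nose_x", "nose_y", "left_ear_x", "left_ear_y"]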

def flatten_list(_2d_list):
    """Flattens a list that may contain sublists into a single flat list."""
    flat_list = []
    # Iterate through the outer list.
    for element in _2d_list:
        if isinstance(element, list):
            # If the element is itself a list, append each of its items.
            for item in element:
                flat_list.append(item)
        else:
            flat_list.append(element)
    return flat_list
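
# A minimal usage sketch for flatten_list (the inputs and the `_demo_*` helper
# name are made up). Only one level of nesting is flattened.
def _demo_flatten_list():
    assert flatten_list([1, [2, 3], 4]) == [1, 2, 3, 4]
    assert flatten_list([[1, [2]], 3]) == [1, [2], 3]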