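"""Dataset utilities for elfragmentador.

Provides PeptideDataset and PeptideDataModule, which read CSVs of
pre-encoded peptide sequences, modifications and spectra, and expose them
as PyTorch / PyTorch-Lightning data objects for training.
"""
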
from __future__ import annotations

import logging
import warnings
from argparse import _ArgumentGroup
from collections import namedtuple
from pathlib import Path, PosixPath
from typing import Dict, List, Optional, Union

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from pandas.core.frame import DataFrame
from torch import Tensor
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm

from elfragmentador import constants, spectra

TrainBatch = namedtuple(
    "TrainBatch",
    "encoded_sequence, encoded_mods, charge, nce, encoded_spectra, norm_irt",
)
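# NOTE: torch's default collate function stacks namedtuple fields
# element-wise, so iterating a DataLoader over PeptideDataset yields
# TrainBatch objects whose fields are batched tensors.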



def match_lengths(
    nested_list: Union[List[List[Union[int, float]]], List[List[int]]],
    max_len: int,
    name: str = "items",
    verbose: bool = True,
) -> Tensor:
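    """Zero-pad every sub-list to ``max_len`` and stack into a single tensor.

    Sub-lists that already have ``max_len`` elements are left untouched;
    a summary of how many lengths matched is logged when ``verbose``.

    Example (illustrative):

        >>> match_lengths([[1, 2], [1, 2, 3]], 3, verbose=False)  # doctest: +SKIP
        tensor([[1., 2., 0.],
                [1., 2., 3.]])
    """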

    lengths = [len(x) for x in nested_list]
    unique_lengths = set(lengths)
    match_max = [1 for x in lengths if x == max_len]

    out_message = (
        f"{len(match_max)}/{len(nested_list)} "
        f"{name} actually match the max sequence length of"
        f" {max_len},"
        f" found {unique_lengths}"
    )

    if verbose:
        if len(match_max) == len(nested_list):
            logging.info(out_message)
        else:
            logging.warning(out_message)

    out = [
        x + ([0] * (max_len - len(x))) if len(x) != max_len else x for x in nested_list
    ]
    out = torch.stack([torch.Tensor(x).T for x in out])
    return out



def match_colnames(df: DataFrame) -> Dict[str, Optional[str]]:
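    """Heuristically map logical column names to the DataFrame's columns.

    Returns a dict with keys ``SeqE``, ``ModE``, ``SpecE``, ``Ch``, ``iRT``
    and ``NCE``; each value is the matching column name, or ``None`` when
    no column matches.
    """
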

    def match_col(string1, string2, colnames, match_mode="in", combine_mode=None):
        m = {
            "in": lambda q, t: q in t,
            "startswith": lambda q, t: q.startswith(t) or t.startswith(q),
            "equals": lambda q, t: q == t,
        }
        match_fun = m[match_mode]
        match_indices1 = [i for i, x in enumerate(colnames) if match_fun(string1, x)]

        if string2 is None:
            match_indices = match_indices1
        else:
            match_indices2 = [
                i for i, x in enumerate(colnames) if match_fun(string2, x)
            ]
            if combine_mode == "union":
                match_indices = set(match_indices1).union(set(match_indices2))
            elif combine_mode == "intersect":
                match_indices = set(match_indices1).intersection(set(match_indices2))
            else:
                raise NotImplementedError

        try:
            out_index = list(match_indices)[0]
        except IndexError:
            out_index = None

        return out_index

    colnames = list(df)
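
    # Most of the queries below are substring matches; e.g. "harg" matches
    # both "Charge" and "charge" without any explicit case handling.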

    out = {
        "SeqE": match_col("Encoding", "Seq", colnames, combine_mode="intersect"),
        "ModE": match_col("Encoding", "Mod", colnames, combine_mode="intersect"),
        "SpecE": match_col("Encoding", "Spec", colnames, combine_mode="intersect"),
        "Ch": match_col("harg", None, colnames),
        "iRT": match_col("IRT", "iRT", colnames, combine_mode="union"),
        "NCE": match_col(
            "nce", "NCE", colnames, combine_mode="union", match_mode="startswith"
        ),
    }
    out = {k: (colnames[v] if v is not None else None) for k, v in out.items()}
    logging.info(f">>> Mapped column names to the provided dataset {out}")
    return out



class PeptideDataset(torch.utils.data.Dataset):
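    """In-memory dataset of encoded peptide sequences, mods and spectra.

    Each row of ``df`` is expected to carry stringified list encodings
    (see ``match_colnames`` for how the columns are located).

    Example (illustrative, assumes a correctly formatted CSV):

        >>> ds = PeptideDataset.from_csv("train.csv")  # doctest: +SKIP
        >>> ds[0].charge  # doctest: +SKIP
        tensor([2])
    """
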

    def __init__(
        self,
        df: DataFrame,
        max_spec: int = int(1e6),
        drop_missing_vals: bool = False,
    ) -> None:

        super().__init__()
        logging.info("\n>>> Initializing Dataset")
        if drop_missing_vals:
            former_len = len(df)
            df.dropna(inplace=True)
            logging.warning(
                f"\n>>> {len(df)}/{former_len} rows left after dropping missing values"
            )

        if max_spec < len(df):
            logging.warning(
                "\n>>> Subsampling the dataset down to "
                f"{max_spec} rows, change the 'max_spec' argument if you don't want"
                " this to happen"
            )
            df = df.sample(n=int(max_spec))

        self.df = df  # TODO remove this for memory ...

        name_match = match_colnames(df)

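        # NOTE: the encodings are stored as stringified Python lists in the
        # CSV, so eval() is used to decode them. This is convenient but will
        # execute arbitrary code from the file; ast.literal_eval would be a
        # stricter alternative.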

        seq_encoding_iter = tqdm(
            self.df[name_match["SeqE"]], "Decoding sequence encodings"
        )
        sequence_encodings = [eval(x) for x in seq_encoding_iter]
        sequence_encodings = match_lengths(
            sequence_encodings, constants.MAX_TENSOR_SEQUENCE, "Sequences"
        )
        self.sequence_encodings = sequence_encodings.long()


        if name_match["ModE"] is None:
            logging.warning(
                (
                    "Found no modification encodings;"
                    " assuming all peptides are unmodified."
                    " Please fix the data for future use,"
                    " since this imputation will be removed in the future"
                )
            )
            mod_encodings = [
                [0] * constants.MAX_TENSOR_SEQUENCE for _ in sequence_encodings
            ]
        else:
            mod_encodings_iter = tqdm(
                self.df[name_match["ModE"]], "Decoding modification encodings"
            )
            mod_encodings = [eval(x) for x in mod_encodings_iter]

        mod_encodings = match_lengths(
            mod_encodings, constants.MAX_TENSOR_SEQUENCE, "Mods"
        )
        self.mod_encodings = mod_encodings.long()


        spec_encoding_iter = tqdm(
            self.df[name_match["SpecE"]], "Decoding spectrum encodings"
        )
        spectra_encodings = [eval(x) for x in spec_encoding_iter]
        spectra_encodings = match_lengths(
            spectra_encodings, constants.NUM_FRAG_EMBEDINGS, "Spectra"
        )
        self.spectra_encodings = spectra_encodings.float()
        avg_peaks = torch.sum(spectra_encodings > 0.01, axis=1).float().mean()

        spectra_lengths = len(self.spectra_encodings[0])
        sequence_lengths = len(self.sequence_encodings[0])

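        # iRTs are divided by 100 to bring the regression target to a
        # roughly unit scale.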

        try:
            irts = np.array(self.df[name_match["iRT"]]).astype("float") / 100
            self.norm_irts = torch.from_numpy(irts).float().unsqueeze(1)
            del irts
        except ValueError as e:
            logging.error(self.df[name_match["iRT"]])
            logging.error(e)
            raise e


        if name_match["NCE"] is None:
            nces = (
                torch.Tensor([float("nan")] * len(self.norm_irts)).float().unsqueeze(1)
            )
        else:
            try:
                nces = np.array(self.df[name_match["NCE"]]).astype("float")
                nces = torch.from_numpy(nces).float().unsqueeze(1)
            except ValueError as e:
                logging.error(self.df[name_match["NCE"]])
                logging.error(e)
                raise e

        self.nces = nces


        if torch.any(self.nces.isnan()):
            # TODO decide if here should be the place to impute NCEs ... and warn ...
            warnings.warn(
                (
                    "Found missing values in NCEs, assuming 30."
                    " Please fix the data for future use,"
                    " since this imputation will be removed in the future"
                ),
                FutureWarning,
            )
            self.nces = torch.where(self.nces.isnan(), torch.Tensor([30.0]), self.nces)

        # torch.nan_to_num is available in torch >= 1.8; switch to it once
        # colab migrates to that version:
        # self.nces = torch.nan_to_num(self.nces, nan=30.0)

        charges = np.array(self.df[name_match["Ch"]]).astype("long")
        self.charges = torch.Tensor(charges).long().unsqueeze(1)


        logging.info(
            (
                f"Dataset initialized with {len(df)} entries."
                f" Sequence length: {sequence_lengths};"
                f" Spectra length: {spectra_lengths};"
                f" Average peaks/spec: {avg_peaks}"
            )
        )
        logging.info(">>> Done initializing dataset\n")


    @staticmethod
    def from_sptxt(
        filepath: str,
        max_spec: int = int(1e6),
        filter_df: bool = True,
        *args,
        **kwargs,
    ) -> PeptideDataset:
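        """Build a PeptideDataset from an .sptxt spectral library file.

        Additional positional and keyword arguments are forwarded to
        ``spectra.encode_sptxt``.
        """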

        df = spectra.encode_sptxt(str(filepath), *args, max_spec=max_spec, **kwargs)
        if filter_df:
            df = filter_df_on_sequences(df)

        return PeptideDataset(df)


    @staticmethod
    def from_csv(
        filepath: Union[str, Path], max_spec: int = int(1e6)
    ) -> PeptideDataset:
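        """Build a PeptideDataset from a CSV file of pre-encoded peptides."""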

        df = filter_df_on_sequences(pd.read_csv(str(filepath)))
        return PeptideDataset(df, max_spec=max_spec)


    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, index: int) -> TrainBatch:
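        """Return the encoded tensors for a single row as a TrainBatch."""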

        encoded_sequence = self.sequence_encodings[index]
        encoded_mods = self.mod_encodings[index]
        encoded_spectra = self.spectra_encodings[index]
        norm_irt = self.norm_irts[index]
        charge = self.charges[index]
        nce = self.nces[index]

        out = TrainBatch(
            encoded_sequence=encoded_sequence,
            encoded_mods=encoded_mods,
            charge=charge,
            nce=nce,
            encoded_spectra=encoded_spectra,
            norm_irt=norm_irt,
        )
        return out



def filter_df_on_sequences(df: DataFrame, name: str = "") -> DataFrame:
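    """Drop rows whose encoded sequence exceeds the maximum tensor length.

    ``name`` is only used to label the log messages.
    """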

    name_match = match_colnames(df)
    logging.info(list(df))
    logging.warning(f"Removing large sequences, currently {name}: {len(df)}")

    seq_iterable = tqdm(df[name_match["SeqE"]], desc="Decoding tensor seqs")

    df = (
        df[[len(eval(x)) <= constants.MAX_TENSOR_SEQUENCE for x in seq_iterable]]
        .copy()
        .reset_index(drop=True)
    )

    logging.warning(f"Rows left {name}: {len(df)}")
    return df



class PeptideDataModule(pl.LightningDataModule):
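    """LightningDataModule that loads train/val splits from CSV files.

    All files matching ``*train*.csv`` and ``*val*.csv`` under ``base_dir``
    are concatenated into the train and validation sets, respectively.

    Example (illustrative):

        >>> dm = PeptideDataModule(batch_size=32, base_dir="data/")  # doctest: +SKIP
        >>> dm.setup()  # doctest: +SKIP
        >>> train_loader = dm.train_dataloader()  # doctest: +SKIP
    """
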

    def __init__(
        self,
        batch_size: int = 64,
        base_dir: Union[str, PosixPath] = ".",
        drop_missing_vals: bool = False,
    ) -> None:
        super().__init__()
        self.batch_size = batch_size
        self.drop_missing_vals = drop_missing_vals
        base_dir = Path(base_dir)

        train_path = list(base_dir.glob("*train*.csv"))
        val_path = list(base_dir.glob("*val*.csv"))

        assert len(train_path) > 0, f"No '*train*.csv' files found in '{base_dir}'"
        assert len(val_path) > 0, f"No '*val*.csv' files found in '{base_dir}'"

        train_df = pd.concat([pd.read_csv(str(x)) for x in train_path])
        train_df = filter_df_on_sequences(train_df)
        val_df = pd.concat([pd.read_csv(str(x)) for x in val_path])
        val_df = filter_df_on_sequences(val_df)

        self.train_df = train_df
        self.val_df = val_df


    @staticmethod
    def add_model_specific_args(parser: _ArgumentGroup) -> _ArgumentGroup:
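        """Add the datamodule's CLI arguments to an argparse argument group.

        Illustrative wiring (assumes ``import argparse`` at the call site):

            >>> parser = argparse.ArgumentParser()  # doctest: +SKIP
            >>> group = parser.add_argument_group("data")  # doctest: +SKIP
            >>> _ = PeptideDataModule.add_model_specific_args(group)  # doctest: +SKIP
        """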

        parser.add_argument("--batch_size", type=int, default=64)
        parser.add_argument("--data_dir", type=str, default=".")
        # NOTE: argparse's type=bool treats any non-empty string as True
        # (bool("False") is True), so a store_true flag is used instead.
        parser.add_argument("--drop_missing_vals", action="store_true", default=False)
        return parser


    def setup(self, stage: Optional[str] = None) -> None:
        self.train_dataset = PeptideDataset(
            self.train_df, drop_missing_vals=self.drop_missing_vals
        )
        self.val_dataset = PeptideDataset(
            self.val_df, drop_missing_vals=self.drop_missing_vals
        )


    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset, num_workers=0, batch_size=self.batch_size, shuffle=True
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_dataset, batch_size=self.batch_size)