Coverage for denofo/utils/helpers.py: 59%

197 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-09 15:27 +0200

1import re 

2from pydantic import BaseModel 

3from enum import Enum 

4from types import UnionType, GenericAlias 

5from collections import Counter 

6from pathlib import Path 

7from ast import literal_eval 

8from denofo.choices import ThresholdChoices 

9from denofo.utils.constants import ENCODE_DICT, DECODE_DICT, REPLACEMENTS 

10 

11 

def different_answers(answer: object, prev_answer: object) -> bool:
    """
    Check if new answer differs from previously given answer.

    Two lists are compared as multisets: element order is ignored, but the
    number of occurrences of each element matters. Lists containing
    unhashable elements (e.g. dicts) are supported as well. Any other pair
    of values is compared with ``!=``.

    :param answer: The answer.
    :type answer: object
    :param prev_answer: The previous answer.
    :type prev_answer: object
    :return: True if the answers are different, False otherwise.
    :rtype: bool
    """
    if not (isinstance(answer, list) and isinstance(prev_answer, list)):
        return answer != prev_answer

    try:
        return Counter(answer) != Counter(prev_answer)
    except TypeError:
        # Unhashable elements cannot be counted; fall back to a quadratic
        # multiset comparison that only relies on ``==``.
        if len(answer) != len(prev_answer):
            return True
        remaining = list(prev_answer)
        for item in answer:
            try:
                remaining.remove(item)
            except ValueError:
                return True
        return False

27 

28 

def get_model_from_qstack_dict(
    qstack_dict: dict,
    model: type[BaseModel],
) -> BaseModel | None:
    """
    Build an instance of ``model`` from a qstack dictionary.

    The qstack dictionary can contain field names and values from other
    models as well; only keys that are fields of ``model`` are used. Entries
    whose value is falsy (``None``, ``""``, ``[]``, ``0``, ``False``, ...)
    are skipped, so those fields fall back to the model's defaults.

    :param qstack_dict: The qstack dictionary. keys = model field names, values = model field values.
    :type qstack_dict: dict
    :param model: The BaseModel class to instantiate.
    :type model: type[BaseModel]
    :return: An instance of ``model`` with the matching field values, or
        None if no matching non-empty value was found.
    :rtype: BaseModel | None
    """
    # NOTE(review): the truthiness filter drops 0/False as well as empty
    # answers; confirm this is intended for numeric/boolean fields.
    matching_fields = {
        name: value
        for name, value in qstack_dict.items()
        if name in model.model_fields and value
    }
    if not matching_fields:
        return None

    return model(**matching_fields)

50 

51 

def add_extension(file_path: Path, extension: str = "dngf") -> Path:
    """
    Add an extension to a file path if it does not already have one.

    The extension may be given with or without a leading dot: both
    ``"json"`` and ``".json"`` turn ``"out"`` into ``"out.json"``. A plain
    ``str`` path is accepted and converted to :class:`pathlib.Path`.

    :param file_path: The file path.
    :type file_path: Path
    :param extension: The extension to add (leading dot optional).
    :type extension: str
    :return: The file path, with the extension added if it had no suffix.
    :rtype: Path
    """
    file_path = Path(file_path)
    if file_path.suffix:
        return file_path

    # Strip a leading dot so callers passing ".json" do not get "out..json".
    return file_path.with_suffix(f".{extension.lstrip('.')}")

67 

68 

69def infer_format_from_extension(file_path: Path) -> str: 

70 """ 

71 Infer the format of a file based on its extension. 

72 

73 :param file_path: The file path. 

74 :type file_path: Path 

75 :return: The inferred format. 

76 :rtype: str 

77 """ 

78 ext_to_format = { 

79 ".dngf": "dngf", 

80 ".json": "dngf", 

81 ".pickle": "pickle", 

82 ".fasta": "fasta", 

83 ".fa": "fasta", 

84 ".fna": "fasta", 

85 ".faa": "fasta", 

86 ".txt": "shortstr", 

87 ".gff": "gff", 

88 ".gtf": "gff", 

89 } 

90 

91 extension = file_path.suffix 

92 

93 return ext_to_format.get(extension, None) 

94 

95 

def diff_two_lists(
    list1: list,
    list2: list,
    model_name: str,
    field_name: str,
    mode: str = "differences",
) -> list:
    """
    Compare two lists, ignoring element order and duplicates.

    In ``"similarities"`` mode, every element present in both lists yields
    ``("same", model_name, field_name, element)``. In ``"differences"`` mode,
    elements only in ``list1`` yield ``("1not2", ...)`` tuples, followed by
    elements only in ``list2`` as ``("2not1", ...)`` tuples. Any other mode
    yields an empty list. Elements must be hashable.

    Entries follow the order of first occurrence in the input lists, so the
    output is deterministic across runs (plain set iteration order is not,
    e.g. for strings under hash randomisation).

    :param list1: The first list.
    :type list1: list
    :param list2: The second list.
    :type list2: list
    :param model_name: Model name to put into every result tuple.
    :type model_name: str
    :param field_name: Field name to put into every result tuple.
    :type field_name: str
    :param mode: 'differences' (default) or 'similarities'.
    :type mode: str
    :return: The comparison result.
    :rtype: list
    """
    # dict.fromkeys de-duplicates while keeping first-occurrence order.
    unique1 = list(dict.fromkeys(list1))
    unique2 = list(dict.fromkeys(list2))
    set1 = set(unique1)
    set2 = set(unique2)

    if mode == "similarities":
        return [
            ("same", model_name, field_name, item) for item in unique1 if item in set2
        ]
    if mode == "differences":
        only1 = [
            ("1not2", model_name, field_name, item)
            for item in unique1
            if item not in set2
        ]
        only2 = [
            ("2not1", model_name, field_name, item)
            for item in unique2
            if item not in set1
        ]
        return only1 + only2
    return []

127 

128 

def _merge_thresholds_with_vals(thres_lst: list, thresval_lst: list) -> list:
    """
    Pair each threshold type with its threshold value.

    The first ``ThresholdChoices.CUSTOM`` entry (if any) is dropped before
    pairing so that it does not consume one of the values. Pairing stops at
    the shorter of the two lists.

    :param thres_lst: Threshold types (not modified).
    :type thres_lst: list
    :param thresval_lst: Threshold values, in the same order as the types.
    :type thresval_lst: list
    :return: ``(threshold_type, threshold_value)`` tuples.
    :rtype: list
    """
    threshold_types = thres_lst.copy()

    if ThresholdChoices.CUSTOM in threshold_types:
        # list.remove only drops the first occurrence.
        threshold_types.remove(ThresholdChoices.CUSTOM)

    return list(zip(threshold_types, thresval_lst))

147 

148 

def _compare_thresholds(
    dngf1: BaseModel,
    dngf2: BaseModel,
    field_name: str,
    value1: list | None,
    value2: list | None,
    mode: str,
) -> list[tuple]:
    """
    Compare the threshold metrics of two models with their values attached.

    ``threshold`` (extended by any ``customThreshold`` names) is paired with
    ``thresholdValue`` before comparing. In "differences" mode, a metric
    with more than one differing entry is reported once as
    ``("diffval", model, field, metric, "<value>, <value>")``. Missing
    (None) lists are treated as empty.
    """
    model_name = dngf1.__class__.__name__
    full_list1 = (value1 or []) + (dngf1.customThreshold or [])
    full_list2 = (value2 or []) + (dngf2.customThreshold or [])
    pairs1 = _merge_thresholds_with_vals(full_list1, dngf1.thresholdValue or [])
    pairs2 = _merge_thresholds_with_vals(full_list2, dngf2.thresholdValue or [])

    comp_new = diff_two_lists(pairs1, pairs2, model_name, field_name, mode=mode)
    if mode != "differences":
        # "same" entries (similarities mode) need no post-processing.
        return comp_new

    # Group entries by metric, keeping first-occurrence order.
    grouped: dict = {}
    for entry in comp_new:
        grouped.setdefault(entry[3][0], []).append(entry)

    result = []
    for threstype, entries in grouped.items():
        if len(entries) > 1:
            thresvals = ", ".join(str(entry[3][1]) for entry in entries)
            result.append(("diffval", model_name, field_name, threstype, thresvals))
        else:
            result.extend(entries)
    return result


def compare_two_models(
    dngf1: BaseModel,
    dngf2: BaseModel,
    mode: str = "differences",
) -> list[tuple]:
    """
    Compare two BaseModel objects field by field, recursing into nested models.

    Result tuples start with a tag: "same", "1not2", "2not1" or "diffval",
    followed by the model name, the field name and the value(s).

    :param dngf1: The first BaseModel object.
    :type dngf1: BaseModel
    :param dngf2: The second BaseModel object (expected to be of the same
        model class, since fields are paired positionally).
    :type dngf2: BaseModel
    :param mode: The mode of comparison. Options: 'differences' (default) or 'similarities'.
    :type mode: str
    :return: The comparison result.
    :rtype: list[tuple]
    """
    comparison = []

    for (field1, value1), (field2, value2) in zip(
        dngf1.model_dump().items(), dngf2.model_dump().items()
    ):
        if isinstance(value1, dict) and isinstance(value2, dict):  # nested models
            submodel1 = dngf1.__annotations__[field1]
            submodel2 = dngf2.__annotations__[field2]
            # `X | None` annotations: the model class is the first member.
            if isinstance(submodel1, UnionType):
                submodel1 = submodel1.__args__[0]
            if isinstance(submodel2, UnionType):
                submodel2 = submodel2.__args__[0]
            comparison.extend(
                compare_two_models(
                    submodel1.model_validate(value1),
                    submodel2.model_validate(value2),
                    mode=mode,
                )
            )
        elif field1 == "threshold":
            # Special case: threshold metrics are compared together with
            # customThreshold and thresholdValue.
            comparison.extend(
                _compare_thresholds(dngf1, dngf2, field1, value1, value2, mode)
            )
        elif field1 in ("customThreshold", "thresholdValue"):
            continue  # already handled together with "threshold"
        elif value1 == value2 and mode == "similarities" and value1 is not None:
            comparison.append(("same", dngf1.__class__.__name__, field1, value1))
        elif value1 != value2:
            # NOTE(review): in "similarities" mode, differing non-list values
            # are still reported as "1not2"/"2not1"/"diffval"; confirm callers
            # expect this.
            if value1 and not value2:
                comparison.append(("1not2", dngf1.__class__.__name__, field1, value1))
            elif not value1 and value2:
                comparison.append(("2not1", dngf2.__class__.__name__, field2, value2))
            elif isinstance(value1, list) and isinstance(value2, list):
                comparison.extend(
                    diff_two_lists(
                        value1, value2, dngf1.__class__.__name__, field1, mode=mode
                    )
                )
            else:
                comparison.append(
                    ("diffval", dngf1.__class__.__name__, field1, value1, value2)
                )

    return comparison

247 

248 

def _get_index_from_enum_choice(enum_choice: Enum) -> int:
    """
    Map an Enum member to the number used in the short representation.

    A member named ``CUSTOM`` (when the Enum defines one) maps to 0; any
    other member maps to its 1-based position in ``list(EnumClass)``. That
    position still counts ``CUSTOM`` if it appears earlier in the list.

    :param enum_choice: The Enum member to encode.
    :type enum_choice: Enum
    :return: The encoded index of the Enum member.
    :rtype: int
    """
    enum_cls = type(enum_choice)

    if hasattr(enum_cls, "CUSTOM") and enum_choice == enum_cls.CUSTOM:
        return 0

    members = list(enum_cls)
    return members.index(enum_choice) + 1

268 

269 

def get_short_repr(orig_model: BaseModel) -> str:
    """
    Get a short representation string of a BaseModel object.

    For every field listed in ``ENCODE_DICT`` with a truthy value, the
    field's code is emitted followed by its encoded value:

    * nested models: their own short representation (recursively),
    * Enum members: the number from ``_get_index_from_enum_choice``,
    * lists of Enum members: those numbers concatenated,
    * lists of str/float: ``'[a,b,...]'``,
    * str/int/float scalars: ``'value'``.

    Other fields, falsy values and values of any other type are omitted.

    :param orig_model: The BaseModel object.
    :type orig_model: BaseModel
    :return: The short representation of the BaseModel object.
    :rtype: str
    """
    model_dict = orig_model.model_dump()
    short_repr = ""

    for field, value in model_dict.items():
        if not value or field not in ENCODE_DICT:
            continue

        value_short_repr = ""
        if isinstance(value, dict):  # nested model, dumped as a dict
            submodel = orig_model.__annotations__[field]
            if isinstance(submodel, UnionType):
                submodel = submodel.__args__[0]
            # `model_validate` is the pydantic v2 API (`validate` is deprecated).
            value_short_repr = get_short_repr(submodel.model_validate(value))
        elif isinstance(value, Enum):
            value_short_repr = str(_get_index_from_enum_choice(value))
        elif isinstance(value, list):
            first = value[0]  # non-empty: falsy values were skipped above
            if isinstance(first, Enum):
                value_short_repr = "".join(
                    str(_get_index_from_enum_choice(val)) for val in value
                )
            elif isinstance(first, (str, float)):
                vals = ",".join(map(str, value))
                value_short_repr = f"'[{vals}]'"
        elif isinstance(value, (str, int, float)):
            # NOTE: bool is a subclass of int, so True is encoded here as
            # "'True'" (False never gets this far).
            value_short_repr = f"'{value}'"

        if value_short_repr:
            short_repr += ENCODE_DICT[field] + value_short_repr

    return short_repr

317 

318 

def _process_short_matches(short_matches: list) -> list:
    """
    Process the short matches to translate back to original values.

    Each match is a tuple of four regex groups, of which the first non-empty
    one determines the token kind:

    0. an encoded list ``[...]`` -> list (via ``REPLACEMENTS`` + ``literal_eval``),
    1. a quoted scalar -> int if all decimal digits, else float if parsable,
       else the string itself,
    2. a field code followed by digits -> ``(field_name, [digit, ...])``,
    3. a bare field code -> field name.

    Matches with all groups empty are ignored.

    :param short_matches: The short matches.
    :type short_matches: list
    :return: The processed short matches.
    :rtype: list
    :raises ValueError: if an encoded list does not evaluate to a list.
    """
    list_token = re.compile(r"\[|\]|,")
    shrt_mtchs_trns = []

    for match in short_matches:
        list_str, scalar_str, enum_code, field_code = match[:4]

        if list_str:  # list values
            ematch = list_token.sub(lambda m: REPLACEMENTS[m.group()], list_str)
            # literal_eval only accepts Python literals, unlike eval.
            mlst = literal_eval(ematch)
            if not isinstance(mlst, list):
                raise ValueError(f"Could not convert to list: {list_str}")
            shrt_mtchs_trns.append(mlst)
        elif scalar_str:  # single str/int/float values
            # Exactly one value per token (previously ints were appended
            # twice, once as int and once as float). isdecimal() matches
            # exactly what int() accepts, unlike isdigit() (e.g. "²").
            if scalar_str.isdecimal():
                shrt_mtchs_trns.append(int(scalar_str))
            else:
                try:
                    shrt_mtchs_trns.append(float(scalar_str))
                except ValueError:
                    shrt_mtchs_trns.append(scalar_str)
        elif enum_code:  # enum choices encoded
            field_name = DECODE_DICT[enum_code[0]]
            enum_nums = [int(digit) for digit in enum_code[1:]]
            shrt_mtchs_trns.append((field_name, enum_nums))
        elif field_code:  # single field names encoded
            shrt_mtchs_trns.append(DECODE_DICT[field_code])

    return shrt_mtchs_trns

353 

354 

def _get_enum_choices_from_num(enum_nums: list, enum_class: type[Enum]) -> list[Enum]:
    """
    Get the Enum choices from their respective numbers.

    Inverse of ``_get_index_from_enum_choice``: 0 stands for the ``CUSTOM``
    member, any other number ``n`` for the ``n``-th (1-based) member of
    ``list(enum_class)``.

    :param enum_nums: The list of Enum numbers.
    :type enum_nums: list
    :param enum_class: The Enum class.
    :type enum_class: type[Enum]
    :return: The Enum choices, in the order of ``enum_nums``.
    :rtype: list[Enum]
    :raises AttributeError: if 0 is given but ``enum_class`` has no CUSTOM member.
    :raises IndexError: if a number is negative or larger than the number of members.
    """
    enum_order = list(enum_class)
    enum_choices = []

    for enum_num in enum_nums:
        if enum_num == 0:
            enum_choices.append(enum_class.CUSTOM)
        elif 0 < enum_num <= len(enum_order):
            enum_choices.append(enum_order[enum_num - 1])
        else:
            # Reject instead of letting negative numbers wrap around silently.
            raise IndexError(
                f"Enum number {enum_num} out of range for {enum_class.__name__}"
            )

    return enum_choices

376 

377 

def _decode_choice_numbers(annotation, enum_nums: list):
    """Convert encoded choice numbers into a value matching a field annotation."""
    is_list = False
    if isinstance(annotation, UnionType):  # e.g. `X | None`
        annotation = annotation.__args__[0]
    if isinstance(annotation, GenericAlias):  # e.g. `list[X]`
        annotation = annotation.__args__[0]
        is_list = True
    if isinstance(annotation, GenericAlias):  # e.g. `list[list[X]]`
        annotation = annotation.__args__[0]

    # Unions/generics are unwrapped first: issubclass() rejects them.
    if isinstance(annotation, type) and issubclass(annotation, bool):
        # Booleans are stored as plain digits, not as Enum indices.
        return enum_nums if is_list else enum_nums[0]

    enum_choices = _get_enum_choices_from_num(enum_nums, annotation)
    return enum_choices if is_list else enum_choices[0]


def _fill_model_dict(
    shrt_mtchs_trns: list, model_dict: dict, model: BaseModel
) -> tuple[dict, list]:
    """
    Fill the model dictionary with values from the short representation.

    Tokens are popped from the front of ``shrt_mtchs_trns`` (the list is
    consumed in place). Token kinds:

    * ``(field_name, numbers)`` tuples: Enum/bool encoded fields,
    * ``str`` field names: the next token is the field's value, or, if the
      field is itself a BaseModel, the following tokens fill that nested model.

    Processing stops at the first field name that ``model`` does not
    declare. That token and everything after it are returned so that an
    enclosing model can continue with them. Other token types that are not
    consumed as a value are skipped.

    :param shrt_mtchs_trns: The short matches.
    :type shrt_mtchs_trns: list
    :param model_dict: The model dictionary (filled in place and returned).
    :type model_dict: dict
    :param model: The BaseModel class or instance whose ``__annotations__``
        define the accepted fields.
    :type model: BaseModel
    :return: The model dictionary and the remaining short matches.
    :rtype: tuple[dict, list]
    """
    annotations = model.__annotations__

    while shrt_mtchs_trns:
        elem = shrt_mtchs_trns.pop(0)

        if isinstance(elem, tuple):
            field_name = elem[0]
            if field_name not in annotations:  # belongs to an enclosing model
                return model_dict, [elem] + shrt_mtchs_trns
            model_dict[field_name] = _decode_choice_numbers(
                annotations[field_name], elem[1]
            )

        elif isinstance(elem, str):
            if elem not in annotations:  # belongs to an enclosing model
                return model_dict, [elem] + shrt_mtchs_trns

            val_type = annotations[elem]
            if isinstance(val_type, UnionType):
                val_type = val_type.__args__[0]

            if (
                isinstance(val_type, type)
                and not isinstance(val_type, GenericAlias)
                and issubclass(val_type, BaseModel)
            ):  # nested model: following tokens fill it
                vdict, shrt_mtchs_trns = _fill_model_dict(
                    shrt_mtchs_trns, {}, val_type
                )
                model_dict[elem] = vdict
            else:  # plain value: the next token
                model_dict[elem] = shrt_mtchs_trns.pop(0)

    return model_dict, shrt_mtchs_trns

439 

440 

def get_model_from_short_repr(short_repr: str, model: type[BaseModel]) -> BaseModel:
    """
    Get a BaseModel object from a short representation string.

    The string is tokenized into quoted lists (``'[...]'``), quoted scalars
    (``'...'``), a field code followed by digits (encoded Enum/bool values)
    and bare field codes. The tokens are then translated and assigned to the
    model's fields. Tokens that do not match any field are ignored.

    :param short_repr: The short representation string.
    :type short_repr: str
    :param model: The BaseModel class to instantiate.
    :type model: type[BaseModel]
    :return: The BaseModel object.
    :rtype: BaseModel
    """
    # NOTE(review): quoted values containing "'" cannot be tokenized
    # correctly by this pattern.
    short_matches = re.findall(
        r"'(\[.*?\])'|'(.*?)'|([A-Za-z$]\d+)|([A-Za-z$])", short_repr
    )
    shrt_mtchs_trns = _process_short_matches(short_matches)

    # Pass the class itself: only its field annotations are needed, and
    # instantiating it here would raise for models with required fields.
    model_dict, _ = _fill_model_dict(shrt_mtchs_trns, {}, model)

    return model(**model_dict)