Coverage for denofo/utils/helpers.py: 59%
197 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-09 15:27 +0200
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-09 15:27 +0200
1import re
2from pydantic import BaseModel
3from enum import Enum
4from types import UnionType, GenericAlias
5from collections import Counter
6from pathlib import Path
7from ast import literal_eval
8from denofo.choices import ThresholdChoices
9from denofo.utils.constants import ENCODE_DICT, DECODE_DICT, REPLACEMENTS
def different_answers(answer: any, prev_answer: any) -> bool:
    """
    Tell whether a newly given answer differs from the previous one.

    When both answers are lists they are compared as multisets: element order
    is ignored, element multiplicity is not. Any other pair of values is
    compared with ``!=``.

    :param answer: The new answer.
    :type answer: any
    :param prev_answer: The previously given answer.
    :type prev_answer: any
    :return: True if the answers differ, False otherwise.
    :rtype: bool
    """
    if not (isinstance(answer, list) and isinstance(prev_answer, list)):
        return answer != prev_answer

    # Multiset comparison (requires hashable elements): a mere reordering of
    # the same choices does not count as a different answer.
    return Counter(answer) != Counter(prev_answer)
def get_model_from_qstack_dict(
    qstack_dict: dict,
    model: BaseModel,
) -> BaseModel | None:
    """
    Get a BaseModel object from a qstack dictionary (which can contain field
    names and values from other models as well).

    Only keys that are fields of ``model`` are used. Entries whose value is
    falsy (None, empty list/string, 0, False) are dropped as well.

    :param qstack_dict: The qstack dictionary. keys = model field names, values = model field values.
    :type qstack_dict: dict
    :param model: The model class to instantiate.
    :type model: BaseModel
    :return: An instance of ``model`` built from the matching fields, or None
        if no field of ``model`` has a truthy value in ``qstack_dict``.
    :rtype: BaseModel | None
    """
    model_fields = model.model_fields.keys()
    matching_fields = {
        key: value
        for key, value in qstack_dict.items()
        if key in model_fields and value
    }
    if not matching_fields:
        return None

    return model(**matching_fields)
def add_extension(file_path: Path, extension: str = "dngf") -> Path:
    """
    Append ``.<extension>`` to a path that has no suffix yet.

    Paths that already carry any suffix are returned unchanged.

    :param file_path: The file path.
    :type file_path: Path
    :param extension: The extension to add, without the leading dot.
    :type extension: str
    :return: The file path, with the extension added if it had none.
    :rtype: Path
    """
    if file_path.suffix:
        return file_path
    return file_path.with_suffix(f".{extension}")
69def infer_format_from_extension(file_path: Path) -> str:
70 """
71 Infer the format of a file based on its extension.
73 :param file_path: The file path.
74 :type file_path: Path
75 :return: The inferred format.
76 :rtype: str
77 """
78 ext_to_format = {
79 ".dngf": "dngf",
80 ".json": "dngf",
81 ".pickle": "pickle",
82 ".fasta": "fasta",
83 ".fa": "fasta",
84 ".fna": "fasta",
85 ".faa": "fasta",
86 ".txt": "shortstr",
87 ".gff": "gff",
88 ".gtf": "gff",
89 }
91 extension = file_path.suffix
93 return ext_to_format.get(extension, None)
def diff_two_lists(
    list1: list,
    list2: list,
    model_name: str,
    field_name: str,
    mode: str = "differences",
) -> list:
    """
    Compare the elements of two lists as sets.

    Every reported element is wrapped in a tuple
    ``(tag, model_name, field_name, element)``.

    :param list1: The first list (elements must be hashable).
    :type list1: list
    :param list2: The second list (elements must be hashable).
    :type list2: list
    :param model_name: Model name placed in every result tuple.
    :type model_name: str
    :param field_name: Field name placed in every result tuple.
    :type field_name: str
    :param mode: 'similarities' reports elements present in both lists (tag
        ``"same"``); 'differences' (default) reports elements only in list1
        (tag ``"1not2"``) followed by elements only in list2 (tag
        ``"2not1"``). Any other value yields an empty list.
    :type mode: str
    :return: The comparison result.
    :rtype: list
    """
    first, second = set(list1), set(list2)

    def tagged(tag: str, elements: set) -> list:
        return [(tag, model_name, field_name, element) for element in elements]

    if mode == "similarities":
        return tagged("same", first & second)
    if mode == "differences":
        return tagged("1not2", first - second) + tagged("2not1", second - first)
    return []
def _merge_thresholds_with_vals(thres_lst: list, thresval_lst: list) -> list:
    """
    Pair threshold types with their respective values.

    The first occurrence of ``ThresholdChoices.CUSTOM`` is removed before
    pairing. In ``compare_two_models`` the custom threshold names are appended
    to the list instead. Pairing is positional and stops at the shorter of the
    two lists. The input list is not modified.

    :param thres_lst: The list of threshold types.
    :type thres_lst: list
    :param thresval_lst: The list of threshold values.
    :type thresval_lst: list
    :return: A list of ``(threshold_type, threshold_value)`` tuples.
    :rtype: list
    """
    threshold_types = thres_lst.copy()
    if ThresholdChoices.CUSTOM in threshold_types:
        threshold_types.remove(ThresholdChoices.CUSTOM)
    return list(zip(threshold_types, thresval_lst))
def compare_two_models(
    dngf1: BaseModel,
    dngf2: BaseModel,
    mode: str = "differences",
) -> list[tuple]:
    """
    Compare two BaseModel objects field by field.

    Both objects are expected to be instances of the same model class, because
    their dumped fields are paired positionally. Nested models are compared
    recursively.

    The ``threshold`` field is special-cased. Its threshold types (plus any
    ``customThreshold`` names) are paired with the ``thresholdValue``
    entries. The ``customThreshold`` and ``thresholdValue`` fields are
    therefore not compared separately.

    Result tuples have one of these shapes:

    - ``("same", model_name, field, value)`` (mode 'similarities'),
    - ``("1not2", model_name, field, value)`` / ``("2not1", model_name, field, value)``,
    - ``("diffval", model_name, field, value1, value2)``. For a threshold type
      present in both objects with different values this is
      ``("diffval", model_name, "threshold", threshold_type, joined_values)``,
      where ``joined_values`` is the comma-joined string of both values.

    :param dngf1: The first BaseModel object.
    :type dngf1: BaseModel
    :param dngf2: The second BaseModel object.
    :type dngf2: BaseModel
    :param mode: The mode of comparison. Options: 'differences' (default) or 'similarities'.
    :type mode: str
    :return: The comparison result.
    :rtype: list[tuple]
    """
    comparison = []

    for (field1, value1), (field2, value2) in zip(
        dngf1.model_dump().items(), dngf2.model_dump().items()
    ):
        model_name = dngf1.__class__.__name__

        if isinstance(value1, dict) and isinstance(value2, dict):  # nested models
            submodel1 = dngf1.__annotations__[field1]
            submodel2 = dngf2.__annotations__[field2]
            if isinstance(submodel1, UnionType):  # e.g. ``SubModel | None``
                submodel1 = submodel1.__args__[0]
                submodel2 = submodel2.__args__[0]
            comparison.extend(
                compare_two_models(
                    submodel1.model_validate(value1),
                    submodel2.model_validate(value2),
                    mode=mode,
                )
            )
        elif field1 == "threshold":
            # Pair every threshold type with its value so that a changed value
            # is reported against the threshold type it belongs to.
            # ``or []`` guards against unset (None) lists.
            pairs1 = _merge_thresholds_with_vals(
                (value1 or []) + (dngf1.customThreshold or []),
                dngf1.thresholdValue or [],
            )
            pairs2 = _merge_thresholds_with_vals(
                (value2 or []) + (dngf2.customThreshold or []),
                dngf2.thresholdValue or [],
            )
            comp_new = diff_two_lists(pairs1, pairs2, model_name, field1, mode=mode)

            if mode != "differences":
                # 'similarities': shared (type, value) pairs are reported as-is.
                comparison.extend(comp_new)
                continue

            # A threshold type present on both sides with different values
            # appears once as "1not2" and once as "2not1"; merge those into a
            # single "diffval" entry.
            merged_types = set()
            for entry in comp_new:
                threstype, _ = entry[3]
                if threstype in merged_types:
                    continue
                same_type = [
                    other[3] for other in comp_new if other[3][0] == threstype
                ]
                if len(same_type) > 1:
                    merged_types.add(threstype)
                    thresvals = ", ".join(str(pair[1]) for pair in same_type)
                    comparison.append(
                        ("diffval", model_name, field1, threstype, thresvals)
                    )
                else:
                    comparison.append(entry)
        elif field1 in ("customThreshold", "thresholdValue"):
            continue  # handled together with "threshold" above
        elif value1 == value2 and mode == "similarities" and value1 is not None:
            comparison.append(("same", model_name, field1, value1))
        elif value1 != value2:
            if value1 and not value2:
                comparison.append(("1not2", model_name, field1, value1))
            elif not value1 and value2:
                comparison.append(
                    ("2not1", dngf2.__class__.__name__, field2, value2)
                )
            elif isinstance(value1, list) and isinstance(value2, list):
                comparison.extend(
                    diff_two_lists(value1, value2, model_name, field1, mode=mode)
                )
            else:
                comparison.append(("diffval", model_name, field1, value1, value2))

    return comparison
def _get_index_from_enum_choice(enum_choice: Enum) -> int:
    """
    Get the numeric code of an Enum choice used in short representations.

    A member named ``CUSTOM`` (if the Enum defines one) is encoded as 0. Every
    other member is encoded as its 1-based position in definition order, and
    CUSTOM's own position still counts.

    :param enum_choice: The Enum choice.
    :type enum_choice: Enum
    :return: The index of the Enum choice.
    :rtype: int
    """
    choice_cls = enum_choice.__class__
    if hasattr(choice_cls, "CUSTOM") and enum_choice == choice_cls.CUSTOM:
        return 0
    return list(choice_cls).index(enum_choice) + 1
def get_short_repr(orig_model: BaseModel) -> str:
    """
    Get a short representation string of a BaseModel object.

    Each field that has a truthy value and an entry in ``ENCODE_DICT`` is
    emitted as its ``ENCODE_DICT`` code followed by the encoded value:

    - nested model: the short representation of the nested model (recursive),
    - Enum member: its code from ``_get_index_from_enum_choice``,
    - list of Enum members: the concatenated codes of all members,
    - list of str/float: ``'[v1,v2,...]'``,
    - str/int/float: ``'value'``. This includes bool, because bool is a
      subclass of int; ``True`` is emitted as ``'True'``.

    Fields with falsy values, fields missing from ``ENCODE_DICT`` and values of
    any other type are skipped.

    :param orig_model: The BaseModel object.
    :type orig_model: BaseModel
    :return: The short representation of the BaseModel object.
    :rtype: str
    """
    parts = []
    for field, value in orig_model.model_dump().items():
        if not value or field not in ENCODE_DICT:
            continue

        encoded = ""
        if isinstance(value, dict):  # nested model
            submodel = orig_model.__annotations__[field]
            if isinstance(submodel, UnionType):
                submodel = submodel.__args__[0]
            # pydantic v2 API (as in compare_two_models); ``validate`` is the
            # deprecated v1 name.
            encoded = get_short_repr(submodel.model_validate(value))
        elif isinstance(value, Enum):
            encoded = str(_get_index_from_enum_choice(value))
        elif isinstance(value, list):
            first = value[0]  # non-empty: empty lists were skipped above
            if isinstance(first, Enum):
                encoded = "".join(
                    str(_get_index_from_enum_choice(val)) for val in value
                )
            elif isinstance(first, (str, float)):
                encoded = "'[" + ",".join(map(str, value)) + "]'"
        elif isinstance(value, (str, int, float)):
            # A separate bool branch after this check could never be reached
            # (bool subclasses int), so bools are deliberately encoded here.
            encoded = f"'{value}'"

        if encoded:
            parts.append(ENCODE_DICT[field] + encoded)

    return "".join(parts)
def _process_short_matches(short_matches: list) -> list:
    """
    Process the short matches to translate back to original values.

    Each match is a 4-tuple of capture groups, as produced by ``re.findall``
    in ``get_model_from_short_repr``. The first non-empty group decides how it
    is translated:

    0. quoted list ``[...]``: delimiters are mapped through ``REPLACEMENTS``,
       then the string is parsed with ``ast.literal_eval`` into a list;
    1. quoted scalar: int if it consists only of decimal digits, else float if
       it parses as one, else the string itself;
    2. encoded Enum/bool field (code character + digits): becomes
       ``(field_name, [digit, ...])``;
    3. bare field code: becomes the field name.

    Exactly one item is appended per match with a non-empty group.

    :param short_matches: The short matches.
    :type short_matches: list
    :return: The processed short matches.
    :rtype: list
    :raises ValueError: If a quoted list does not evaluate to a list.
    """
    delimiter_pattern = re.compile(r"\[|\]|,")
    shrt_mtchs_trns = []
    for list_str, scalar_str, enum_code, field_code in short_matches:
        if list_str:  # list values
            escaped = delimiter_pattern.sub(
                lambda m: REPLACEMENTS[m.group()], list_str
            )
            parsed = literal_eval(escaped)
            if not isinstance(parsed, list):
                raise ValueError(f"Could not convert to list: {list_str}")
            shrt_mtchs_trns.append(parsed)
        elif scalar_str:  # single str/int/float values
            # One value per match: an integer must not additionally be
            # appended as a float. isdecimal (not isdigit) because int()
            # rejects digit characters such as "²".
            if scalar_str.isdecimal():
                shrt_mtchs_trns.append(int(scalar_str))
            else:
                try:
                    shrt_mtchs_trns.append(float(scalar_str))
                except ValueError:
                    shrt_mtchs_trns.append(scalar_str)
        elif enum_code:  # enum choices encoded
            field_name = DECODE_DICT[enum_code[0]]
            enum_nums = [int(digit) for digit in enum_code[1:]]
            shrt_mtchs_trns.append((field_name, enum_nums))
        elif field_code:  # single field names encoded
            shrt_mtchs_trns.append(DECODE_DICT[field_code])

    return shrt_mtchs_trns
def _get_enum_choices_from_num(enum_nums: list, enum_class: Enum) -> list[Enum]:
    """
    Get the Enum choices from their respective numbers.

    This is the inverse of ``_get_index_from_enum_choice``. 0 maps to
    ``enum_class.CUSTOM``, and any other number ``n`` maps to the ``n``-th
    member (1-based) in definition order.

    :param enum_nums: The list of Enum numbers.
    :type enum_nums: list
    :param enum_class: The Enum class.
    :type enum_class: Enum
    :return: The Enum choices, one per number and in the same order.
    :rtype: list[Enum]
    """
    enum_order = list(enum_class)
    return [
        enum_class.CUSTOM if enum_num == 0 else enum_order[enum_num - 1]
        for enum_num in enum_nums
    ]
378def _fill_model_dict(
379 shrt_mtchs_trns: list, model_dict: dict, model: BaseModel
380) -> tuple[dict, list]:
381 """
382 Fill the model dictionary with values from the short representation.
384 :param shrt_mtchs_trns: The short matches.
385 :type shrt_mtchs_trns: list
386 :param model_dict: The model dictionary.
387 :type model_dict: dict
388 :param model: The BaseModel object.
389 :type model: BaseModel
390 :return: The model dictionary and the remaining short matches.
391 :rtype: tuple[dict, list]
392 """
393 while shrt_mtchs_trns:
394 elem = shrt_mtchs_trns.pop(0)
396 if isinstance(elem, tuple):
397 enum_choices = None
398 field_name = elem[0]
400 if field_name not in model.__annotations__:
401 return model_dict, [elem] + shrt_mtchs_trns
402 enum_class = model.__annotations__[field_name]
404 if not isinstance(enum_class, UnionType) and issubclass(enum_class, bool):
405 enum_choices = elem[1][0]
406 elif not isinstance(enum_class, UnionType) and issubclass(enum_class, Enum):
407 enum_choices = _get_enum_choices_from_num(elem[1], enum_class)
408 enum_choices = enum_choices[0]
409 elif isinstance(enum_class, GenericAlias) or isinstance(
410 enum_class, UnionType
411 ):
412 enum_class = enum_class.__args__[0]
413 if isinstance(enum_class, GenericAlias):
414 enum_class = enum_class.__args__[0]
415 if not enum_choices:
416 enum_choices = _get_enum_choices_from_num(elem[1], enum_class)
418 model_dict[field_name] = enum_choices
420 elif isinstance(elem, str):
421 if elem in model.__annotations__:
422 val_type = model.__annotations__[elem]
423 if isinstance(val_type, UnionType):
424 val_type = val_type.__args__[0]
426 if not isinstance(val_type, GenericAlias) and issubclass(
427 val_type, BaseModel
428 ): # val is BaseModel
429 vdict, shrt_mtchs_trns = _fill_model_dict(
430 shrt_mtchs_trns, {}, val_type
431 )
432 model_dict[elem] = vdict
433 else: # val is not a BaseModel
434 model_dict[elem] = shrt_mtchs_trns.pop(0)
435 else: # field name not in model
436 return model_dict, [elem] + shrt_mtchs_trns
438 return model_dict, shrt_mtchs_trns
def get_model_from_short_repr(short_repr: str, model: BaseModel) -> BaseModel:
    """
    Get a BaseModel object from a short representation string.

    This is the counterpart of ``get_short_repr``. The string is tokenised
    into four kinds of token:

    - quoted lists (``'[...]'``),
    - quoted scalars (``'...'``),
    - encoded Enum/bool fields (a letter or ``$`` followed by digits),
    - bare field codes (a single letter or ``$``).

    The tokens are then decoded and assembled into field values. Tokens left
    over after the top-level model has been filled are ignored.

    :param short_repr: The short representation string.
    :type short_repr: str
    :param model: The model class to build.
    :type model: BaseModel
    :return: The BaseModel object.
    :rtype: BaseModel
    """
    short_matches = re.findall(
        r"'(\[.*?\])'|'(.*?)'|([A-Za-z$]\d+)|([A-Za-z$])", short_repr
    )
    shrt_mtchs_trns = _process_short_matches(short_matches)

    # Only the field annotations are read, so pass the class itself, as the
    # nested recursion in _fill_model_dict does. Instantiating ``model()``
    # here would fail for models with required fields.
    model_dict, _ = _fill_model_dict(shrt_mtchs_trns, {}, model)

    return model(**model_dict)