Coverage for src/configuraptor/helpers.py: 95%
77 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-09 20:12 +0100
1"""
2Contains stand-alone helper functions.
3"""
5import contextlib
6import dataclasses as dc
7import io
8import math
9import types
10import typing
11from collections import ChainMap
12from pathlib import Path
14from typeguard import TypeCheckError
15from typeguard import check_type as _check_type
17try:
18 import annotationlib
19except ImportError: # pragma: no cover
20 annotationlib = None
def camel_to_snake(s: str) -> str:
    """
    Convert CamelCase to snake_case.

    Every uppercase character becomes '_' followed by its lowercase form; leading underscores
    are then stripped from the result (so consecutive capitals, e.g. "HTTP", become "h_t_t_p").

    Source:
    https://stackoverflow.com/questions/1175208/elegant-python-function-to-convert-camelcase-to-snake-case
    """
    pieces = []
    for char in s:
        if char.isupper():
            pieces.append("_" + char.lower())
        else:
            pieces.append(char)
    return "".join(pieces).lstrip("_")
33# def find_pyproject_toml() -> typing.Optional[str]:
34# """
35# Find the project's config toml, looks up until it finds the project root (black's logic).
36# """
37# return black.files.find_pyproject_toml((os.getcwd(),))
40def find_pyproject_toml(start_dir: typing.Optional[Path | str] = None) -> Path | None:
41 """
42 Search for pyproject.toml starting from the current working directory \
43 and moving upwards in the directory tree.
45 Args:
46 start_dir: Starting directory to begin the search.
47 If not provided, uses the current working directory.
49 Returns:
50 Path or None: Path object to the found pyproject.toml file, or None if not found.
51 """
52 start_dir = Path.cwd() if start_dir is None else Path(start_dir).resolve()
54 current_dir = start_dir
56 while str(current_dir) != str(current_dir.root):
57 pyproject_toml = current_dir / "pyproject.toml"
58 if pyproject_toml.is_file():
59 return pyproject_toml
60 current_dir = current_dir.parent
62 # If not found anywhere
63 return None
# Shorthand for "any class object"; used as the parameter annotation of the type-inspection helpers in this module.
Type = typing.Type[typing.Any]
def _cls_annotations(c: type) -> dict[str, type]:  # pragma: no cover
    """
    Return the annotations declared directly on class `c` (inherited ones are excluded; see _all_annotations).

    When `annotationlib` is importable (Python 3.14+), annotations are fetched through it in VALUE
    format with eval_str=True, so string forward references are evaluated right away.
    Otherwise the raw `__annotations__` attribute is returned, or an empty dict when it is absent.
    """
    if not annotationlib:
        return getattr(c, "__annotations__", {})

    resolved = annotationlib.get_annotations(c, format=annotationlib.Format.VALUE, eval_str=True)
    return typing.cast(dict[str, type], resolved)
def _all_annotations(cls: type) -> ChainMap[str, type]:
    """
    Returns a dictionary-like ChainMap with the annotations of cls and of every class in its MRO.

    The maps are stacked in MRO order (cls first), so a key annotated on a subclass shadows the same
    key on a base class. ChainMap iterates its maps from last to first, which yields base-class keys
    before subclass keys, i.e. the natural declaration order.
    """
    mro = getattr(cls, "__mro__", [])
    per_class = [_cls_annotations(klass) for klass in mro]
    return ChainMap(*per_class)
def all_annotations(
    cls: Type, _except: typing.Optional[typing.Iterable[str]] = None
) -> dict[str, type[object]]:
    """
    Wrapper around `_all_annotations` that filters away any keys in _except.

    It also flattens the ChainMap to a regular dict.

    Args:
        cls: class whose own and inherited annotations are collected.
        _except: names to leave out. Any iterable is accepted, including one-shot generators.

    Returns:
        dict mapping attribute name to its annotation.
    """
    # Materialise once: testing `k not in <generator>` per key would consume the generator on the
    # first check and silently stop excluding, and a list would cost O(n) per key.
    excluded = set() if _except is None else set(_except)

    _all = _all_annotations(cls)
    return {k: v for k, v in _all.items() if k not in excluded}
# Generic type variable: links check_type's `expected_type` to the type its TypeGuard return narrows to.
T = typing.TypeVar("T")
def check_type(value: typing.Any, expected_type: typing.Type[T]) -> typing.TypeGuard[T]:
    """
    Return whether `value` matches 'expected_type' (which can be a Union, parameterized generic etc.).

    Thin wrapper around typeguard's check_type: instead of returning the value or raising
    TypeCheckError, the outcome is reported as a boolean. Any other exception propagates unchanged.
    """
    try:
        _check_type(value, expected_type)
    except TypeCheckError:
        return False
    return True
def is_builtin_type(_type: Type) -> bool:
    """
    Return True when _type lives in the builtins module (or the legacy '__builtin__' name).
    """
    builtin_module_names = ("__builtin__", "builtins")
    return _type.__module__ in builtin_module_names
129# def is_builtin_class_instance(obj: typing.Any) -> bool:
130# return is_builtin_type(obj.__class__)
def is_from_types_or_typing(_type: Type) -> bool:
    """
    Return True when _type is defined in the stdlib 'types' or 'typing' module.

    e.g. types.UnionType or typing.Union
    """
    typing_module_names = ("types", "typing")
    return _type.__module__ in typing_module_names
def is_from_other_toml_supported_module(_type: Type) -> bool:
    """
    Return True when _type is defined in the stdlib 'datetime' or 'math' module.

    Besides builtins, toml also supports values from these modules, so their types are not
    considered custom classes.
    """
    toml_extra_modules = ("datetime", "math")
    return _type.__module__ in toml_extra_modules
def is_parameterized(_type: Type) -> bool:
    """
    Return True when _type is a parameterized (subscripted) type, i.e. it has a typing origin.

    Examples:
        list[str] -> True
        str -> False
    """
    origin = typing.get_origin(_type)
    return origin is not None
def is_custom_class(_type: Type) -> bool:
    """
    Guess whether _type is a custom (user-defined) class rather than a builtin one.

    Only plain classes (whose metaclass is exactly `type`) qualify; classes from builtins,
    datetime/math, and types/typing are excluded. Other logic in this module depends on knowing that.
    """
    if type(_type) is not type:
        return False
    return not (
        is_builtin_type(_type)
        or is_from_other_toml_supported_module(_type)
        or is_from_types_or_typing(_type)
    )
def instance_of_custom_class(var: typing.Any) -> bool:
    """
    Apply `is_custom_class` to the class of `var` (as reported by its `__class__` attribute).
    """
    cls = var.__class__
    return is_custom_class(cls)
def is_optional(_type: Type | typing.Any) -> bool:
    """
    Tries to guess if _type could be optional (i.e. admits None).

    Examples:
        None -> True
        NoneType -> True
        typing.Union[str, None] -> True
        str | None -> True
        list[str | None] -> False
        list[str] -> False
    """
    # Parameterized dict/list types (e.g. list[str]) are never optional; decide before the
    # issubclass checks below, which do not accept parameterized generics.
    if _type and is_parameterized(_type) and typing.get_origin(_type) in (dict, list):
        return False
    if _type is math.nan:
        return False

    try:
        if _type is None:
            return True
        if types.NoneType in typing.get_args(_type):  # union containing NoneType
            return True
        if issubclass(types.NoneType, _type):  # NoneType itself (or one of its superclasses)
            return True
        # _type may be a value rather than a type: check the value's own type
        return issubclass(types.NoneType, type(_type))
    except TypeError:
        # probably some weird input that's not a type
        return False
def dataclass_field(cls: Type, key: str) -> typing.Optional[dc.Field[typing.Any]]:
    """
    Return the dataclass Field named `key` on `cls`, or None if cls is not a dataclass or has no such field.
    """
    try:
        fields = cls.__dataclass_fields__
    except AttributeError:
        return None
    return fields.get(key)
@contextlib.contextmanager
def uncloseable(fd: typing.BinaryIO) -> typing.Generator[typing.BinaryIO, typing.Any, None]:
    """
    Context manager which turns the fd's close operation into a no-op for the duration of the context.

    The original `close` is restored when the context exits, including when the body raises.
    """
    original_close = fd.close
    fd.close = lambda: None  # type: ignore
    try:
        yield fd
    finally:
        # Restore even on error; otherwise the stream could never be closed afterwards.
        fd.close = original_close  # type: ignore
def as_binaryio(file: str | Path | typing.BinaryIO | None, mode: typing.Literal["rb", "wb"] = "rb") -> typing.BinaryIO:
    """
    Convert a number of possible 'file' descriptions into a single BinaryIO interface.

    - str / Path: the path is opened with `mode`.
    - None: a fresh in-memory io.BytesIO is created.
    - io.BytesIO (including the one created for None): rewound to position 0 and wrapped in
      `uncloseable`, so the result must be entered with `with`, and leaving the with-block
      does not close the in-memory buffer.
    - anything else: returned unchanged.
    """
    stream = Path(file) if isinstance(file, str) else file
    if isinstance(stream, Path):
        stream = stream.open(mode)
    elif stream is None:
        stream = io.BytesIO()

    if not isinstance(stream, io.BytesIO):
        return stream

    # rewind so .read() works after .write()
    stream.seek(0)
    # keep the with-statement from closing the in-memory file
    return uncloseable(stream)  # type: ignore