Coverage for src/configuraptor/helpers.py: 100%
77 statements
« prev ^ index » next coverage.py v7.2.7, created at 2025-01-09 20:20 +0100
« prev ^ index » next coverage.py v7.2.7, created at 2025-01-09 20:20 +0100
1"""
2Contains stand-alone helper functions.
3"""
5import contextlib
6import dataclasses as dc
7import io
8import math
9import types
10import typing
11from collections import ChainMap
12from pathlib import Path
14from typeguard import TypeCheckError
15from typeguard import check_type as _check_type
17try:
18 import annotationlib
19except ImportError: # pragma: no cover
20 annotationlib = None
def camel_to_snake(s: str) -> str:
    """
    Turn a CamelCase identifier into its snake_case equivalent.

    Every uppercase character is lowercased and prefixed with an underscore, so
    consecutive capitals are split individually ("HTTPServer" -> "h_t_t_p_server").
    Any underscores at the very start of the result are stripped.

    Source:
        https://stackoverflow.com/questions/1175208/elegant-python-function-to-convert-camelcase-to-snake-case
    """
    pieces = []
    for char in s:
        if char.isupper():
            pieces.append("_" + char.lower())
        else:
            pieces.append(char)

    return "".join(pieces).lstrip("_")
33# def find_pyproject_toml() -> typing.Optional[str]:
34# """
35# Find the project's config toml, looks up until it finds the project root (black's logic).
36# """
37# return black.files.find_pyproject_toml((os.getcwd(),))
40def find_pyproject_toml(start_dir: typing.Optional[Path | str] = None) -> Path | None:
41 """
42 Search for pyproject.toml starting from the current working directory \
43 and moving upwards in the directory tree.
45 Args:
46 start_dir: Starting directory to begin the search.
47 If not provided, uses the current working directory.
49 Returns:
50 Path or None: Path object to the found pyproject.toml file, or None if not found.
51 """
52 start_dir = Path.cwd() if start_dir is None else Path(start_dir).resolve()
54 current_dir = start_dir
56 while str(current_dir) != str(current_dir.root):
57 pyproject_toml = current_dir / "pyproject.toml"
58 if pyproject_toml.is_file():
59 return pyproject_toml
60 current_dir = current_dir.parent
62 # If not found anywhere
63 return None
# Shorthand annotation for "a class object of any kind" (as opposed to an instance of one).
Type = typing.Type[typing.Any]
69def _cls_annotations(c: type) -> dict[str, type]: # pragma: no cover
70 """
71 Functions to get the annotations of a class (excl inherited, use _all_annotations for that).
73 Uses `annotationlib` if available (since 3.14) and if so, resolves forward references immediately.
74 """
75 if annotationlib:
76 return typing.cast(
77 dict[str, type], annotationlib.get_annotations(c, format=annotationlib.Format.VALUE, eval_str=True)
78 )
79 else:
80 # note: idk why but this is not equivalent (the first doesn't work well):
81 # return getattr(c, "__annotations__", {})
82 return c.__dict__.get("__annotations__") or {}
def _all_annotations(cls: type) -> ChainMap[str, type]:
    """
    Build a dictionary-like ChainMap holding the annotations of every \
    attribute defined in cls or inherited from one of its superclasses.
    """
    # The MRO is passed as-is (most-derived class first), so lookups prefer the subclass' annotation.
    # ChainMap iterates its maps from last to first, which means keys come out base class first.
    lineage = getattr(cls, "__mro__", [])
    per_class = [_cls_annotations(klass) for klass in lineage]
    return ChainMap(*per_class)
def all_annotations(cls: Type, _except: typing.Optional[typing.Iterable[str]] = None) -> dict[str, type[object]]:
    """
    Wrapper around `_all_annotations` that filters away any keys in _except.

    It also flattens the ChainMap to a regular dict.

    Args:
        cls: class whose own and inherited annotations are collected.
        _except: names to leave out of the result. May be any iterable (including a one-shot iterator);
            a bare string is treated as a single name rather than as a sequence of characters.

    Returns:
        A plain dict mapping attribute name to annotation, without the excluded names.
    """
    # Materialise the exclusions exactly once: repeated `in` checks against a generator would consume it
    # (so later keys would slip through), and against a string they would do substring matching.
    if _except is None:
        excluded: frozenset[str] = frozenset()
    elif isinstance(_except, str):
        excluded = frozenset((_except,))
    else:
        excluded = frozenset(_except)

    _all = _all_annotations(cls)
    return {key: annotation for key, annotation in _all.items() if key not in excluded}
# Generic placeholder; `check_type` uses it to tie `expected_type` to the type it narrows `value` to.
T = typing.TypeVar("T")
def check_type(value: typing.Any, expected_type: typing.Type[T]) -> typing.TypeGuard[T]:
    """
    Tell whether `value` matches 'expected_type' (which can be a Union, parameterized generic etc.).

    Thin wrapper around typeguard's check_type: instead of returning the value or
    throwing a TypeCheckError, the outcome is reported as a boolean.
    """
    try:
        _check_type(value, expected_type)
    except TypeCheckError:
        return False

    return True
def is_builtin_type(_type: Type) -> bool:
    """
    Tell whether _type lives in the builtins module (i.e. is one of the builtin types).
    """
    module_name = _type.__module__
    return module_name == "__builtin__" or module_name == "builtins"
131# def is_builtin_class_instance(obj: typing.Any) -> bool:
132# return is_builtin_type(obj.__class__)
def is_from_types_or_typing(_type: Type) -> bool:
    """
    Tell whether _type is defined in one of the stdlib modules `types` or `typing`.

    e.g. types.UnionType or typing.Union
    """
    typing_modules = ("types", "typing")
    return _type.__module__ in typing_modules
def is_from_other_toml_supported_module(_type: Type) -> bool:
    """
    Tell whether _type comes from the stdlib modules 'datetime' or 'math', \
    whose types toml supports in addition to the builtins.
    """
    toml_native_modules = ("datetime", "math")
    return _type.__module__ in toml_native_modules
def is_parameterized(_type: Type) -> bool:
    """
    Tell whether _type is a parameterized type, i.e. whether `typing.get_origin` finds an origin for it.

    Examples:
        list[str] -> True
        str -> False
    """
    origin = typing.get_origin(_type)
    return origin is not None
def is_custom_class(_type: Type) -> bool:
    """
    Best-effort guess whether _type is a custom (user-defined) class rather than a builtin.

    Other logic in this module depends on knowing that.
    """
    if type(_type) is not type:
        # not a plain class object (e.g. an instance, or a class with a custom metaclass)
        return False

    # checked in this order, stopping at the first match:
    return not (
        is_builtin_type(_type)
        or is_from_other_toml_supported_module(_type)
        or is_from_types_or_typing(_type)
    )
def instance_of_custom_class(var: typing.Any) -> bool:
    """
    Apply `is_custom_class` to the class of `var`, which is an instance of a (possibly custom) class.
    """
    klass = var.__class__
    return is_custom_class(klass)
184def is_optional(_type: Type | typing.Any) -> bool:
185 """
186 Tries to guess if _type could be optional.
188 Examples:
189 None -> True
190 NoneType -> True
191 typing.Union[str, None] -> True
192 str | None -> True
193 list[str | None] -> False
194 list[str] -> False
195 """
196 if _type and (is_parameterized(_type) and typing.get_origin(_type) in (dict, list)) or (_type is math.nan):
197 # e.g. list[str]
198 # will crash issubclass to test it first here
199 return False
201 try:
202 return (
203 _type is None
204 or types.NoneType in typing.get_args(_type) # union with Nonetype
205 or issubclass(types.NoneType, _type)
206 or issubclass(types.NoneType, type(_type)) # no type # Nonetype
207 )
208 except TypeError:
209 # probably some weird input that's not a type
210 return False
def dataclass_field(cls: Type, key: str) -> typing.Optional[dc.Field[typing.Any]]:
    """
    Look up the Field info registered under `key` on dataclass `cls`.

    Returns None when `cls` has no `__dataclass_fields__` or when `key` is not one of its fields.
    """
    try:
        registry = cls.__dataclass_fields__
    except AttributeError:
        return None

    return registry.get(key)
@contextlib.contextmanager
def uncloseable(fd: typing.BinaryIO) -> typing.Generator[typing.BinaryIO, typing.Any, None]:
    """
    Context manager which turns the fd's close operation to no-op for the duration of the context.

    Args:
        fd: the stream whose `close` attribute is temporarily replaced.

    Yields:
        The same `fd` object, with `close()` doing nothing.
    """
    close = fd.close
    fd.close = lambda: None  # type: ignore
    try:
        yield fd
    finally:
        # restore even when the with-body raises, otherwise the fd could never be closed again:
        fd.close = close  # type: ignore
232def as_binaryio(file: str | Path | typing.BinaryIO | None, mode: typing.Literal["rb", "wb"] = "rb") -> typing.BinaryIO:
233 """
234 Convert a number of possible 'file' descriptions into a single BinaryIO interface.
235 """
236 if isinstance(file, str):
237 file = Path(file)
238 if isinstance(file, Path):
239 file = file.open(mode)
240 if file is None:
241 file = io.BytesIO()
242 if isinstance(file, io.BytesIO):
243 # so .read() works after .write():
244 file.seek(0)
245 # so the with-statement doesn't close the in-memory file:
246 file = uncloseable(file) # type: ignore
248 return file