Coverage for src/configuraptor/helpers.py: 100%

77 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2025-01-09 20:20 +0100

1""" 

2Contains stand-alone helper functions. 

3""" 

4 

5import contextlib 

6import dataclasses as dc 

7import io 

8import math 

9import types 

10import typing 

11from collections import ChainMap 

12from pathlib import Path 

13 

14from typeguard import TypeCheckError 

15from typeguard import check_type as _check_type 

16 

17try: 

18 import annotationlib 

19except ImportError: # pragma: no cover 

20 annotationlib = None 

21 

22 

def camel_to_snake(s: str) -> str:
    """
    Convert CamelCase to snake_case.

    Every uppercase character becomes "_" followed by its lowercase form, so acronyms are split
    per letter ("HTTPServer" -> "h_t_t_p_server"). All leading underscores are stripped from the
    result, including any that were already present in `s`.

    Source:
    https://stackoverflow.com/questions/1175208/elegant-python-function-to-convert-camelcase-to-snake-case
    """
    pieces = []
    for char in s:
        if char.isupper():
            pieces.append("_" + char.lower())
        else:
            pieces.append(char)
    joined = "".join(pieces)
    return joined.lstrip("_")

31 

32 

33# def find_pyproject_toml() -> typing.Optional[str]: 

34# """ 

35# Find the project's config toml, looks up until it finds the project root (black's logic). 

36# """ 

37# return black.files.find_pyproject_toml((os.getcwd(),)) 

38 

39 

40def find_pyproject_toml(start_dir: typing.Optional[Path | str] = None) -> Path | None: 

41 """ 

42 Search for pyproject.toml starting from the current working directory \ 

43 and moving upwards in the directory tree. 

44 

45 Args: 

46 start_dir: Starting directory to begin the search. 

47 If not provided, uses the current working directory. 

48 

49 Returns: 

50 Path or None: Path object to the found pyproject.toml file, or None if not found. 

51 """ 

52 start_dir = Path.cwd() if start_dir is None else Path(start_dir).resolve() 

53 

54 current_dir = start_dir 

55 

56 while str(current_dir) != str(current_dir.root): 

57 pyproject_toml = current_dir / "pyproject.toml" 

58 if pyproject_toml.is_file(): 

59 return pyproject_toml 

60 current_dir = current_dir.parent 

61 

62 # If not found anywhere 

63 return None 

64 

65 

# Alias for `typing.Type[typing.Any]`, used in the annotations of the helpers below.
Type = typing.Type[typing.Any]


def _cls_annotations(c: type) -> dict[str, type]:  # pragma: no cover
    """
    Return the annotations declared directly on class `c`.

    Inherited annotations are excluded; `_all_annotations` combines them across the MRO.
    When `annotationlib` could be imported (Python 3.14+), it is used with `Format.VALUE` and
    `eval_str=True`, so string / forward-reference annotations are resolved immediately.
    """
    if not annotationlib:
        # Read `__annotations__` from the class's own namespace only. A plain
        # `getattr(c, "__annotations__", {})` goes through normal attribute lookup and can
        # return a base class's annotations when `c` declares none itself (notably on
        # older Pythons), which would make inherited fields look as if `c` declared them.
        own = c.__dict__.get("__annotations__")
        return own or {}

    evaluated = annotationlib.get_annotations(c, format=annotationlib.Format.VALUE, eval_str=True)
    return typing.cast(dict[str, type], evaluated)

83 

84 

def _all_annotations(cls: type) -> ChainMap[str, type]:
    """
    Returns a dictionary-like ChainMap that includes annotations for all
    attributes defined in cls or inherited from superclasses.
    """
    # One mapping per class in the MRO, most-derived first. ChainMap lookups search the maps
    # front to back, so an annotation on `cls` overrides a same-named one on a base class.
    # ChainMap *iteration* scans the maps last to first, so keys come out base-class first;
    # no manual reversing is needed. Objects without `__mro__` give an empty ChainMap.
    mro = getattr(cls, "__mro__", [])
    per_class = [_cls_annotations(klass) for klass in mro]
    return ChainMap(*per_class)

93 

94 

def all_annotations(
    cls: Type, _except: typing.Optional[typing.Iterable[str]] = None
) -> dict[str, type[object]]:
    """
    Wrapper around `_all_annotations` that filters away any keys in _except.

    It also flattens the ChainMap to a regular dict.

    Args:
        cls: class whose own and inherited annotations are collected.
        _except: names to leave out. Any iterable is accepted; it is materialized into a set
                 once, so one-shot iterators (e.g. generators) exclude every listed name instead
                 of being consumed by the first membership tests.

    Returns:
        A plain dict mapping attribute name to its annotation, minus the excluded names.
    """
    excluded = set() if _except is None else set(_except)

    _all = _all_annotations(cls)
    return {k: v for k, v in _all.items() if k not in excluded}

106 

107 

# Ties `check_type`'s `expected_type` to the type that `value` is narrowed to.
T = typing.TypeVar("T")


def check_type(value: typing.Any, expected_type: typing.Type[T]) -> typing.TypeGuard[T]:
    """
    Boolean front-end for typeguard's `check_type`.

    `expected_type` may be anything typeguard understands (a Union, a parameterized generic, ...).
    typeguard signals a mismatch by raising `TypeCheckError`; that is translated to `False` here,
    and a successful check yields `True`.

    NOTE(review): typeguard's default collection check strategy may only inspect the first
    element of a collection; confirm that is acceptable for callers.
    """
    try:
        _check_type(value, expected_type)
    except TypeCheckError:
        return False
    else:
        return True

122 

123 

def is_builtin_type(_type: Type) -> bool:
    """
    Report whether `_type` lives in the builtins module.

    "__builtin__" (the Python 2 spelling of that module's name) is accepted as well.
    """
    module_name = _type.__module__
    return module_name in ("builtins", "__builtin__")

129 

130 

131# def is_builtin_class_instance(obj: typing.Any) -> bool: 

132# return is_builtin_type(obj.__class__) 

133 

134 

def is_from_types_or_typing(_type: Type) -> bool:
    """
    Report whether `_type` is defined in the stdlib `typing` or `types` module.

    e.g. types.UnionType or typing.Union
    """
    typing_modules = ("typing", "types")
    return _type.__module__ in typing_modules

142 

143 

def is_from_other_toml_supported_module(_type: Type) -> bool:
    """
    Report whether `_type` comes from the stdlib `datetime` or `math` module.

    Besides builtins, toml also supports types from these two modules.
    """
    toml_extra_modules = ("datetime", "math")
    return _type.__module__ in toml_extra_modules

150 

151 

def is_parameterized(_type: Type) -> bool:
    """
    Report whether `_type` is a parameterized (subscripted) type.

    Decided by `typing.get_origin`: anything with an origin counts.

    Examples:
        list[str] -> True
        str -> False
    """
    origin = typing.get_origin(_type)
    return origin is not None

161 

162 

def is_custom_class(_type: Type) -> bool:
    """
    Heuristically decide whether `_type` is a user-defined class rather than a builtin one.

    A class only counts as custom when its metaclass is exactly `type`. Classes with another
    metaclass, such as Enum or ABC subclasses, are therefore not custom. It must also not come
    from builtins, the toml-supported datetime/math modules, or types/typing.
    Other logic in this module depends on knowing that.
    """
    if type(_type) is not type:
        return False
    # evaluated left to right and short-circuiting, like the chained `and not ...` form
    non_custom_checks = (is_builtin_type, is_from_other_toml_supported_module, is_from_types_or_typing)
    return not any(check(_type) for check in non_custom_checks)

175 

176 

def instance_of_custom_class(var: typing.Any) -> bool:
    """
    Ask `is_custom_class` about the class of `var`.

    In effect, this reports whether `var` is an instance of a user-defined class.
    """
    klass = var.__class__
    return is_custom_class(klass)

182 

183 

184def is_optional(_type: Type | typing.Any) -> bool: 

185 """ 

186 Tries to guess if _type could be optional. 

187 

188 Examples: 

189 None -> True 

190 NoneType -> True 

191 typing.Union[str, None] -> True 

192 str | None -> True 

193 list[str | None] -> False 

194 list[str] -> False 

195 """ 

196 if _type and (is_parameterized(_type) and typing.get_origin(_type) in (dict, list)) or (_type is math.nan): 

197 # e.g. list[str] 

198 # will crash issubclass to test it first here 

199 return False 

200 

201 try: 

202 return ( 

203 _type is None 

204 or types.NoneType in typing.get_args(_type) # union with Nonetype 

205 or issubclass(types.NoneType, _type) 

206 or issubclass(types.NoneType, type(_type)) # no type # Nonetype 

207 ) 

208 except TypeError: 

209 # probably some weird input that's not a type 

210 return False 

211 

212 

def dataclass_field(cls: Type, key: str) -> typing.Optional[dc.Field[typing.Any]]:
    """
    Look up the `dataclasses.Field` describing attribute `key` on the dataclass `cls`.

    Returns None when `cls` has no `__dataclass_fields__` (not a dataclass) or no field named `key`.
    """
    return getattr(cls, "__dataclass_fields__", {}).get(key)

219 

220 

@contextlib.contextmanager
def uncloseable(fd: typing.BinaryIO) -> typing.Generator[typing.BinaryIO, typing.Any, None]:
    """
    Context manager which turns the fd's close operation to no-op for the duration of the context.

    The original `close` is restored on exit, also when the body of the `with` block raises,
    so the file object does not stay permanently uncloseable after an error.
    """
    close = fd.close
    fd.close = lambda: None  # type: ignore
    try:
        yield fd
    finally:
        fd.close = close  # type: ignore

230 

231 

232def as_binaryio(file: str | Path | typing.BinaryIO | None, mode: typing.Literal["rb", "wb"] = "rb") -> typing.BinaryIO: 

233 """ 

234 Convert a number of possible 'file' descriptions into a single BinaryIO interface. 

235 """ 

236 if isinstance(file, str): 

237 file = Path(file) 

238 if isinstance(file, Path): 

239 file = file.open(mode) 

240 if file is None: 

241 file = io.BytesIO() 

242 if isinstance(file, io.BytesIO): 

243 # so .read() works after .write(): 

244 file.seek(0) 

245 # so the with-statement doesn't close the in-memory file: 

246 file = uncloseable(file) # type: ignore 

247 

248 return file