Coverage for src/pydantic_typer/main.py: 99%

116 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-08-04 21:00 +0200

1from __future__ import annotations 

2 

3import inspect 

4from functools import wraps 

5from typing import Any, Callable 

6 

7import click 

8import pydantic 

9from typer import BadParameter, Option 

10from typer import Typer as TyperBase 

11from typer.main import CommandFunctionType, get_click_param, get_params_from_function, lenient_issubclass 

12from typer.models import OptionInfo, ParameterInfo 

13from typer.utils import ( 

14 AnnotatedParamWithDefaultValueError, 

15 DefaultFactoryAndDefaultValueError, 

16 MixedAnnotatedAndDefaultStyleError, 

17 MultipleTyperAnnotationsError, 

18 _split_annotation_from_typer_annotations, 

19) 

20from typing_extensions import Annotated 

21 

22from pydantic_typer.utils import copy_type, deep_update, inspect_signature 

23 

# Separator used to join the path of field names of a (possibly nested) pydantic
# model into a CLI option name, e.g. ["person", "address", "city"] becomes
# "--person.address.city".
PYDANTIC_FIELD_SEPARATOR = "."

25 

26 

def _flatten_pydantic_model(
    model: type[pydantic.BaseModel], ancestors: list[str], ancestor_typer_param=None
) -> dict[str, inspect.Parameter]:
    """Flatten the fields of a pydantic model class into keyword-only parameters.

    Nested model fields are flattened recursively, so every leaf field of *model*
    becomes exactly one ``inspect.Parameter``.

    Args:
        model: The pydantic model *class* whose ``model_fields`` are flattened.
        ancestors: Names leading to *model*, starting with the command parameter
            name, e.g. ``["person"]`` or ``["person", "address"]``.
        ancestor_typer_param: Optional typer ``Option``/``Argument`` attached to the
            root model parameter. It is used for leaf fields that carry no typer
            annotation of their own. It is not yet propagated into nested models.

    Returns:
        A mapping from the generated parameter name (``_pydantic_<a>_<b>...``) to a
        keyword-only ``inspect.Parameter``. Each parameter is annotated as
        ``Annotated[field_type, typer_param, qualifier]``, where *qualifier* is the
        list of names leading to the field. The qualifier is the last metadata entry.
    """
    import copy

    pydantic_parameters = {}
    for field_name, field in model.model_fields.items():
        qualifier = [*ancestors, field_name]
        sub_name = f"_pydantic_{'_'.join(qualifier)}"
        if lenient_issubclass(field.annotation, pydantic.BaseModel):
            # TODO: pass ancestor_typer_param
            params = _flatten_pydantic_model(field.annotation, qualifier)  # type: ignore
            pydantic_parameters.update(params)
            continue

        # ``...`` marks the CLI parameter as required.
        # NOTE(review): fields that only define a default_factory appear to have an
        # unset ``default`` too, so they become required here. Confirm this is intended.
        default = field.default if field.default is not pydantic.fields._Unset else ...  # noqa: SLF001

        # Pydantic stores annotations in field.metadata.
        # If the field is already annotated with a typer.Option or typer.Argument, use that.
        existing_typer_params = [meta for meta in field.metadata if isinstance(meta, ParameterInfo)]
        if existing_typer_params:
            typer_param = existing_typer_params[0]
            if isinstance(typer_param, OptionInfo) and not typer_param.param_decls:
                # The option was not named manually, so use the default naming scheme.
                # Rename a copy: this OptionInfo lives in the model's class-level field
                # metadata. Mutating it would leak this option name into every other
                # parameter or command that uses the same model.
                typer_param = copy.copy(typer_param)
                typer_param.param_decls = (f"--{PYDANTIC_FIELD_SEPARATOR.join(qualifier)}",)
        elif ancestor_typer_param:
            typer_param = ancestor_typer_param
        else:
            typer_param = Option(f"--{PYDANTIC_FIELD_SEPARATOR.join(qualifier)}")

        pydantic_parameters[sub_name] = inspect.Parameter(
            sub_name,
            inspect.Parameter.KEYWORD_ONLY,
            annotation=Annotated[field.annotation, typer_param, qualifier],
            default=default,
        )
    return pydantic_parameters

61 

62 

def enable_pydantic(callback: CommandFunctionType) -> CommandFunctionType:
    """Allow *callback* to take pydantic models as command parameters.

    Every parameter whose base type (after stripping typer ``Annotated`` metadata)
    is a pydantic model is replaced by one keyword-only parameter per leaf field
    (see ``_flatten_pydantic_model``). Other parameters are kept unchanged.

    At call time the wrapper regroups the flattened values by their field path,
    instantiates the original model classes, and calls *callback* with them.
    It returns the callback's result.

    Raises (at call time):
        BadParameter: if instantiating one of the models fails pydantic validation.
    """
    original_signature = inspect_signature(callback)

    pydantic_parameters: dict[str, inspect.Parameter] = {}
    pydantic_roots: dict[str, Any] = {}
    other_parameters: dict[str, inspect.Parameter] = {}
    for name, parameter in original_signature.parameters.items():
        base_annotation, typer_annotations = _split_annotation_from_typer_annotations(parameter.annotation)
        typer_param = typer_annotations[0] if typer_annotations else None
        if lenient_issubclass(base_annotation, pydantic.BaseModel):
            # Flatten the bare model class, not the (possibly Annotated) annotation.
            params = _flatten_pydantic_model(base_annotation, [name], typer_param)
            pydantic_parameters.update(params)
            pydantic_roots[name] = base_annotation
        else:
            other_parameters[name] = parameter

    extended_signature = inspect.Signature(
        [*other_parameters.values(), *pydantic_parameters.values()],
        return_annotation=original_signature.return_annotation,
    )

    @copy_type(callback)
    @wraps(callback)
    def wrapper(*args, **kwargs):  # type: ignore[no-untyped-def]
        converted_kwargs = kwargs.copy()
        # Seed every root with an empty dict. This way a model that produced no
        # flattened parameters (e.g. one without fields) is still instantiated and
        # passed to the callback.
        raw_pydantic_objects: dict[str, Any] = {root_name: {} for root_name in pydantic_roots}
        for kwarg_name, kwarg_value in kwargs.items():
            if kwarg_name not in pydantic_parameters:
                continue
            converted_kwargs.pop(kwarg_name)
            annotation = pydantic_parameters[kwarg_name].annotation
            # The qualifier (path of field names) is the last Annotated metadata entry.
            qualifier = annotation.__metadata__[-1]
            # Turn ["person", "address", "city"] into {"person": {"address": {"city": value}}}.
            nested_value = kwarg_value
            for part in reversed(qualifier):
                nested_value = {part: nested_value}
            raw_pydantic_objects = deep_update(raw_pydantic_objects, nested_value)
        for root_name, value in raw_pydantic_objects.items():
            try:
                converted_kwargs[root_name] = pydantic_roots[root_name](**value)
            except pydantic.ValidationError as e:
                # Report model-level failures (field constraints, validators) as a CLI
                # error instead of a traceback.
                raise BadParameter(message=e.errors()[0]["msg"], param_hint=root_name) from e
        return callback(*args, **converted_kwargs)

    wrapper.__signature__ = extended_signature  # type: ignore
    # Copy annotations to make forward references work in Python <= 3.9
    wrapper.__annotations__ = {k: v.annotation for k, v in extended_signature.parameters.items()}
    return wrapper

106 

107 

def enable_pydantic_type_validation(callback: CommandFunctionType) -> CommandFunctionType:
    """Let pydantic parse and validate parameter types that typer does not support.

    Typer's ``get_click_param`` raises ``RuntimeError`` for a parameter whose type it
    cannot handle. For each such parameter, the exposed signature annotates it as
    ``str`` instead. Any ``typer.Option``/``typer.Argument`` attached through
    ``Annotated`` is kept.

    At call time, every bound argument whose original annotation pydantic can build
    a schema for is validated with ``pydantic.TypeAdapter``. The wrapper returns
    whatever *callback* returns.

    Raises (at call time):
        BadParameter: when pydantic validation of an argument fails.
    """
    original_signature = inspect_signature(callback)

    # Change the annotation of unsupported types to str to be parsed by pydantic.
    # Adapted from https://github.com/tiangolo/typer/blob/95b767e38a98ee287a7a0e28176284836e1188c2/typer/main.py#L543
    # TODO: it's not ideal to call get_params_from_function and get_click_param here,
    # because it will be called in typer again, but the annotations supported by typer are quite dynamic.
    try:
        parameters = get_params_from_function(callback)
    except (
        AnnotatedParamWithDefaultValueError,
        DefaultFactoryAndDefaultValueError,
        MixedAnnotatedAndDefaultStyleError,
        MultipleTyperAnnotationsError,
    ):
        # We can't raise now. Typer will raise in the right moment.
        parameters = {}

    updated_parameters = dict(original_signature.parameters)
    for param_name, param in parameters.items():
        if lenient_issubclass(param.annotation, click.Context):
            # click.Context should not be modified
            continue
        # We don't know whether to use pydantic or typer to parse a param without checking if typer supports it.
        try:
            get_click_param(param)
        except click.ClickException:
            # We can't raise now. Typer will raise in the right moment.
            pass
        except RuntimeError:
            original_parameter = original_signature.parameters[param_name]
            # Keep an Annotated typer.Option/typer.Argument (at most one; more would have
            # raised above). Dropping it would lose help text and names, and could turn
            # an Option into an Argument.
            _, typer_annotations = _split_annotation_from_typer_annotations(original_parameter.annotation)
            str_annotation = Annotated[(str, *typer_annotations)] if typer_annotations else str
            updated_parameters[param_name] = original_parameter.replace(annotation=str_annotation)

    new_signature = inspect.Signature(
        parameters=list(updated_parameters.values()), return_annotation=original_signature.return_annotation
    )

    @copy_type(callback)
    @wraps(callback)
    def wrapper(*args, **kwargs):  # type: ignore[no-untyped-def]
        bound_params = original_signature.bind(*args, **kwargs)
        for name, value in bound_params.arguments.items():
            try:
                type_adapter = pydantic.TypeAdapter(original_signature.parameters[name].annotation)
            except pydantic.PydanticSchemaGenerationError:
                # Types pydantic cannot describe (e.g. click.Context) are passed through.
                continue
            try:
                bound_params.arguments[name] = type_adapter.validate_python(value)
            except pydantic.ValidationError as e:
                raise BadParameter(message=e.errors()[0]["msg"], param_hint=name) from e
        # Propagate the command's return value (e.g. for standalone_mode=False callers).
        return callback(*bound_params.args, **bound_params.kwargs)

    wrapper.__signature__ = new_signature  # type: ignore
    # Copy annotations to make forward references work in Python <= 3.9
    wrapper.__annotations__ = {k: v.annotation for k, v in new_signature.parameters.items()}
    return wrapper

168 

169 

class Typer(TyperBase):
    """``typer.Typer`` subclass whose ``command`` decorator adds pydantic support.

    Each decorated function is wrapped with ``enable_pydantic`` and then with
    ``enable_pydantic_type_validation``. The result is registered through the
    regular typer ``command`` decorator.
    """

    @copy_type(TyperBase.command)
    def command(self, *args, **kwargs):
        register_command = super().command(*args, **kwargs)

        def decorator_override(f: CommandFunctionType) -> CommandFunctionType:
            # enable_pydantic is applied first, so the type-validation step wraps
            # the function whose model parameters have already been flattened.
            with_models = enable_pydantic(f)
            with_validation = enable_pydantic_type_validation(with_models)
            return register_command(with_validation)

        return decorator_override

181 

182 

def run(function: Callable[..., Any]) -> None:
    """Run *function* as a single-command CLI with pydantic support.

    Builds a ``Typer`` app with ``add_completion=False``, registers *function* as
    its command, and invokes the app.
    """
    cli = Typer(add_completion=False)
    register = cli.command()
    register(function)
    cli()