Coverage for src/pydantic_typer/main.py: 95%

106 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-08-02 17:11 +0200

1from __future__ import annotations 

2 

3import inspect 

4from functools import wraps 

5from typing import Any 

6 

7import click 

8import pydantic 

9from typer import BadParameter, Option, Typer 

10from typer.main import CommandFunctionType, get_click_param, get_params_from_function, lenient_issubclass 

11from typer.models import OptionInfo, ParameterInfo 

12from typer.utils import _split_annotation_from_typer_annotations 

13from typing_extensions import Annotated 

14 

15from pydantic_typer.utils import copy_type, deep_update, inspect_signature 

16 

# Joins the names on the path from a root parameter down to a nested field to
# build the generated CLI option name, e.g. ``--config.db.port``.
PYDANTIC_FIELD_SEPARATOR = "."


def _with_default_option_name(typer_param: ParameterInfo, option_name: str) -> ParameterInfo:
    """Return ``typer_param`` named ``option_name`` if it is an unnamed ``OptionInfo``.

    An option counts as already named when it has ``param_decls`` or a
    ``default`` other than ``...``: when ``typer.Option`` is used inside
    ``Annotated[...]``, typer reinterprets its first positional argument (stored
    as ``default``) as an option name. Anything that is not an ``OptionInfo``
    (e.g. an ``ArgumentInfo``) is returned unchanged.

    The input is never mutated. A shallow copy is named instead, because the
    same object can be shared, e.g. it lives in a model's field metadata and the
    model may be used by several parameters or commands.
    """
    import copy

    if not isinstance(typer_param, OptionInfo) or typer_param.param_decls or typer_param.default is not ...:
        return typer_param
    named_param = copy.copy(typer_param)
    named_param.param_decls = (option_name,)
    return named_param


def _flatten_pydantic_model(
    model: type[pydantic.BaseModel],
    ancestors: list[str],
    ancestor_typer_param: ParameterInfo | None = None,
) -> dict[str, inspect.Parameter]:
    """Flatten the fields of a pydantic model class into keyword-only parameters.

    Every leaf field becomes one ``inspect.Parameter`` named
    ``_pydantic_<ancestors>_<field>`` (joined with ``_``) and annotated with
    ``Annotated[<field type>, <typer parameter info>, <qualifier>]``. The
    qualifier is the list of names from the root parameter down to the field; it
    is stored as the last metadata item so the command wrapper can rebuild the
    nested input for the model. Fields whose type is itself a pydantic model are
    flattened recursively.

    The typer parameter info is, in order of preference: the first
    ``typer.Option``/``typer.Argument`` found in the field's metadata,
    ``ancestor_typer_param``, or a new ``typer.Option``. Unnamed options are
    named ``--<ancestors>.<field>`` (joined with ``PYDANTIC_FIELD_SEPARATOR``).

    Args:
        model: The pydantic model class whose fields are flattened.
        ancestors: Names leading to ``model``, starting with the root parameter.
        ancestor_typer_param: Typer parameter info attached to the root parameter,
            used for fields that carry no typer annotation of their own.

    Returns:
        Mapping from generated parameter name to the generated parameter.
    """
    pydantic_parameters: dict[str, inspect.Parameter] = {}
    for field_name, field in model.model_fields.items():
        qualifier = [*ancestors, field_name]
        if lenient_issubclass(field.annotation, pydantic.BaseModel):
            # TODO: pass ancestor_typer_param
            pydantic_parameters.update(_flatten_pydantic_model(field.annotation, qualifier))  # type: ignore
            continue

        sub_name = f"_pydantic_{'_'.join(qualifier)}"
        option_name = f"--{PYDANTIC_FIELD_SEPARATOR.join(qualifier)}"

        if field.default_factory is not None and not getattr(field, "default_factory_takes_data", False):
            # Evaluate the factory once so the option is optional and shows its
            # default. Factories that need the other validated fields (newer
            # pydantic) cannot be called here and fall through to "required".
            default = field.get_default(call_default_factory=True)
        elif field.default is not pydantic.fields._Unset:  # noqa: SLF001
            default = field.default
        else:
            # ``...`` as the default makes typer treat the option as required.
            default = ...

        # Pydantic moves ``Annotated`` metadata of a field into ``field.metadata``.
        # If the field is already annotated with a typer.Option/typer.Argument, use it.
        existing_typer_params = [meta for meta in field.metadata if isinstance(meta, ParameterInfo)]
        if existing_typer_params:
            typer_param = _with_default_option_name(existing_typer_params[0], option_name)
        elif ancestor_typer_param is not None:
            typer_param = _with_default_option_name(ancestor_typer_param, option_name)
        else:
            typer_param = Option(option_name)

        pydantic_parameters[sub_name] = inspect.Parameter(
            sub_name,
            inspect.Parameter.KEYWORD_ONLY,
            annotation=Annotated[field.annotation, typer_param, qualifier],
            default=default,
        )
    return pydantic_parameters

54 

55 

def enable_pydantic(callback: CommandFunctionType) -> CommandFunctionType:
    """Wrap ``callback`` so that its pydantic-model parameters become CLI options.

    Every parameter whose base annotation (after stripping typer's
    ``Annotated`` metadata) is a ``pydantic.BaseModel`` subclass is removed from
    the exposed signature. It is replaced by the keyword-only parameters produced
    by ``_flatten_pydantic_model``. A typer parameter attached to such a model
    parameter (``Annotated[Model, typer.Option(...)]``) is passed on to
    ``_flatten_pydantic_model``. All other parameters are kept unchanged and
    listed first.

    When called, the wrapper nests the flattened values back into dictionaries
    following each parameter's qualifier. It then instantiates one model per
    root parameter and calls ``callback`` with them.

    Args:
        callback: The command function to wrap.

    Returns:
        The wrapper. Its ``__signature__`` and ``__annotations__`` describe the
        flattened parameters, and it returns whatever ``callback`` returns.
        Instantiating a model may raise ``pydantic.ValidationError``.
    """
    original_signature = inspect_signature(callback)

    pydantic_parameters: dict[str, inspect.Parameter] = {}
    pydantic_roots: dict[str, type[pydantic.BaseModel]] = {}
    other_parameters: dict[str, inspect.Parameter] = {}
    for name, parameter in original_signature.parameters.items():
        base_annotation, typer_annotations = _split_annotation_from_typer_annotations(parameter.annotation)
        if not lenient_issubclass(base_annotation, pydantic.BaseModel):
            other_parameters[name] = parameter
            continue
        typer_param = typer_annotations[0] if typer_annotations else None
        # Flatten the bare model class, not the ``Annotated[...]`` alias around it.
        pydantic_parameters.update(_flatten_pydantic_model(base_annotation, [name], typer_param))
        pydantic_roots[name] = base_annotation

    extended_signature = inspect.Signature(
        [*other_parameters.values(), *pydantic_parameters.values()],
        return_annotation=original_signature.return_annotation,
    )

    @copy_type(callback)
    @wraps(callback)
    def wrapper(*args, **kwargs):  # type: ignore[no-untyped-def]
        converted_kwargs = {}
        raw_pydantic_objects: dict[str, Any] = {}
        for kwarg_name, kwarg_value in kwargs.items():
            flat_parameter = pydantic_parameters.get(kwarg_name)
            if flat_parameter is None:
                converted_kwargs[kwarg_name] = kwarg_value
                continue
            # The qualifier (path from the root parameter to the field) is the last
            # Annotated metadata item. Index from the end: typing merges nested
            # Annotated types, which can put extra metadata in front of it.
            qualifier = flat_parameter.annotation.__metadata__[-1]
            nested_value = kwarg_value
            for part in reversed(qualifier):
                nested_value = {part: nested_value}
            raw_pydantic_objects = deep_update(raw_pydantic_objects, nested_value)
        for root_name, value in raw_pydantic_objects.items():
            converted_kwargs[root_name] = pydantic_roots[root_name](**value)
        return callback(*args, **converted_kwargs)

    wrapper.__signature__ = extended_signature  # type: ignore
    # Copy annotations to make forward references work in Python <= 3.9
    wrapper.__annotations__ = {k: v.annotation for k, v in extended_signature.parameters.items()}
    return wrapper

99 

100 

def enable_pydantic_type_validation(callback: CommandFunctionType) -> CommandFunctionType:
    """Wrap ``callback`` so that pydantic parses and validates its arguments.

    Parameters whose type typer cannot turn into a click parameter
    (``get_click_param`` raises ``RuntimeError``) are re-exposed to typer as
    ``str``. Any ``typer.Option``/``typer.Argument`` attached through
    ``Annotated[...]`` is kept, so option names and help texts survive. At call
    time each bound argument is validated with a ``pydantic.TypeAdapter`` built
    from its *original* annotation. Arguments whose annotation pydantic cannot
    build a schema for are passed through unchanged.

    Args:
        callback: The command function to wrap.

    Returns:
        The wrapper. Its ``__signature__``/``__annotations__`` carry the updated
        annotations, and it returns whatever ``callback`` returns. When pydantic
        rejects an argument, the wrapper raises ``typer.BadParameter`` with the
        first validation error's message.
    """
    original_signature = inspect_signature(callback)

    # Change the annotation of unsupported types to str to be parsed by pydantic.
    # Adapted from https://github.com/tiangolo/typer/blob/95b767e38a98ee287a7a0e28176284836e1188c2/typer/main.py#L543
    # TODO: it's not ideal to call get_params_from_function and get_click_param here,
    # because it will be called in typer again, but the annotations supported by typer are quite dynamic.
    parameters = get_params_from_function(callback)
    updated_parameters = dict(original_signature.parameters)
    for param_name, param in parameters.items():
        if lenient_issubclass(param.annotation, click.Context):
            continue
        # We don't know whether to use pydantic or typer to parse a param without checking if typer supports it.
        try:
            get_click_param(param)
        except RuntimeError:
            original_parameter = original_signature.parameters[param_name]
            # Only the type changes; the typer annotations (option names, help, ...) are kept.
            _, typer_annotations = _split_annotation_from_typer_annotations(original_parameter.annotation)
            new_annotation = Annotated[(str, *typer_annotations)] if typer_annotations else str
            updated_parameters[param_name] = original_parameter.replace(annotation=new_annotation)

    new_signature = inspect.Signature(
        parameters=list(updated_parameters.values()), return_annotation=original_signature.return_annotation
    )

    @copy_type(callback)
    @wraps(callback)
    def wrapper(*args, **kwargs):  # type: ignore[no-untyped-def]
        bound_params = original_signature.bind(*args, **kwargs)
        for name, value in bound_params.arguments.items():
            try:
                type_adapter = pydantic.TypeAdapter(original_signature.parameters[name].annotation)
            except pydantic.PydanticSchemaGenerationError:
                # pydantic cannot handle this type (e.g. arbitrary classes such as
                # click.Context): keep the value typer produced.
                continue
            try:
                bound_params.arguments[name] = type_adapter.validate_python(value)
            except pydantic.ValidationError as e:
                raise BadParameter(message=e.errors()[0]["msg"], param_hint=name) from e
        return callback(*bound_params.args, **bound_params.kwargs)

    wrapper.__signature__ = new_signature  # type: ignore
    # Copy annotations to make forward references work in Python <= 3.9
    wrapper.__annotations__ = {k: v.annotation for k, v in new_signature.parameters.items()}
    return wrapper


class PydanticTyper(Typer):
    """``typer.Typer`` subclass whose commands accept pydantic models and pydantic-parsed types."""

    @copy_type(Typer.command)
    def command(self, *args, **kwargs):
        """Return a decorator like ``Typer.command`` that first adds pydantic support.

        The decorated function is wrapped with ``enable_pydantic`` and then with
        ``enable_pydantic_type_validation``. The result is handed to the decorator
        returned by ``Typer.command(*args, **kwargs)``.
        """
        original_decorator = super().command(*args, **kwargs)

        def decorator_override(f: CommandFunctionType) -> CommandFunctionType:
            # Order matters: model parameters must be flattened first. Applied to
            # the raw function, enable_pydantic_type_validation would re-annotate
            # every type typer cannot handle -- including the pydantic models
            # themselves -- as str, hiding them from enable_pydantic.
            with_models = enable_pydantic(f)
            return original_decorator(enable_pydantic_type_validation(with_models))

        return decorator_override