Coverage for src/pydantic_typer/main.py: 95%
106 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-08-02 17:11 +0200
« prev ^ index » next coverage.py v7.6.0, created at 2024-08-02 17:11 +0200
from __future__ import annotations

import copy
import inspect
from functools import wraps
from typing import Any

import click
import pydantic
from typer import BadParameter, Option, Typer
from typer.main import CommandFunctionType, get_click_param, get_params_from_function, lenient_issubclass
from typer.models import OptionInfo, ParameterInfo
from typer.utils import _split_annotation_from_typer_annotations
from typing_extensions import Annotated

from pydantic_typer.utils import copy_type, deep_update, inspect_signature
# Joins the path of attribute names (root parameter -> nested field) when building default
# option names, e.g. ``--pet.name`` for field ``name`` of a ``pet`` model parameter.
PYDANTIC_FIELD_SEPARATOR: str = "."
def _flatten_pydantic_model(
    model: type[pydantic.BaseModel], ancestors: list[str], ancestor_typer_param: ParameterInfo | None = None
) -> dict[str, inspect.Parameter]:
    """Flatten the fields of a pydantic model class into keyword-only parameters.

    Fields whose type is itself a pydantic model are flattened recursively. Every leaf field becomes an
    ``inspect.Parameter`` named ``_pydantic_<ancestors>_<field>`` (joined with ``_``) and annotated as
    ``Annotated[<field annotation>, <typer parameter info>, <qualifier>]``, where the qualifier is the list
    of names leading from the root parameter to the field (e.g. ``["pet", "name"]``).

    Args:
        model: The pydantic model class whose fields are flattened.
        ancestors: Names leading to ``model``; the first entry is the root function parameter name.
        ancestor_typer_param: Typer ``Option``/``Argument`` info annotated on the root parameter. It is used
            for leaf fields that carry no typer annotation of their own. Nested models do not receive it yet.

    Returns:
        Mapping from generated parameter name to the generated keyword-only parameter.
    """
    pydantic_parameters = {}
    for field_name, field in model.model_fields.items():
        qualifier = [*ancestors, field_name]
        sub_name = f"_pydantic_{'_'.join(qualifier)}"
        if lenient_issubclass(field.annotation, pydantic.BaseModel):
            # TODO: pass ancestor_typer_param
            params = _flatten_pydantic_model(field.annotation, qualifier)  # type: ignore
            pydantic_parameters.update(params)
            continue

        # A field without a default becomes a required parameter (``...``).
        # NOTE(review): fields that only define a ``default_factory`` also end up required here.
        default = field.default if field.default is not pydantic.fields._Unset else ...  # noqa: SLF001
        default_option_name = f"--{PYDANTIC_FIELD_SEPARATOR.join(qualifier)}"

        # Pydantic stores annotations in field.metadata.
        # If the field is already annotated with a typer.Option or typer.Argument, use that.
        existing_typer_params = [meta for meta in field.metadata if isinstance(meta, ParameterInfo)]
        if existing_typer_params:
            # Copy before modifying: this object lives on the model class and is shared by every command and
            # parameter using the model. Assigning param_decls in place would make the first generated name
            # stick for all later uses, e.g. ``--a.name`` reused for a second parameter ``b: Model``.
            typer_param = copy.copy(existing_typer_params[0])
            if isinstance(typer_param, OptionInfo) and not typer_param.param_decls:
                # If the option was not named manually, use the default naming scheme
                typer_param.param_decls = (default_option_name,)
        elif ancestor_typer_param is not None:
            typer_param = ancestor_typer_param
        else:
            typer_param = Option(default_option_name)

        pydantic_parameters[sub_name] = inspect.Parameter(
            sub_name,
            inspect.Parameter.KEYWORD_ONLY,
            # The qualifier (last metadata entry) is read back by the wrapper to rebuild the nested model.
            annotation=Annotated[field.annotation, typer_param, qualifier],
            default=default,
        )
    return pydantic_parameters
def enable_pydantic(callback: CommandFunctionType) -> CommandFunctionType:
    """Expose each pydantic-model parameter of ``callback`` as one keyword-only parameter per model field.

    The returned wrapper's ``__signature__`` lists the non-model parameters of ``callback`` followed by the
    flattened field parameters produced by ``_flatten_pydantic_model``. A typer ``Option``/``Argument``
    annotated on a model parameter itself is forwarded as the default for fields without their own typer
    annotation. When called, the wrapper folds the flattened keyword arguments back into nested dicts,
    instantiates each root model from them, and calls ``callback`` with the reconstructed models.

    Args:
        callback: The command function to wrap.

    Returns:
        The wrapper, typed like ``callback``; it returns whatever ``callback`` returns.
    """
    original_signature = inspect_signature(callback)

    pydantic_parameters = {}
    pydantic_roots = {}
    other_parameters = {}
    for name, parameter in original_signature.parameters.items():
        base_annotation, typer_annotations = _split_annotation_from_typer_annotations(parameter.annotation)
        typer_param = typer_annotations[0] if typer_annotations else None
        if lenient_issubclass(base_annotation, pydantic.BaseModel):
            # Pass the unwrapped model class; parameter.annotation may still be Annotated[Model, ...].
            params = _flatten_pydantic_model(base_annotation, [name], typer_param)
            pydantic_parameters.update(params)
            pydantic_roots[name] = base_annotation
        else:
            other_parameters[name] = parameter

    extended_signature = inspect.Signature(
        [*other_parameters.values(), *pydantic_parameters.values()],
        return_annotation=original_signature.return_annotation,
    )

    @copy_type(callback)
    @wraps(callback)
    def wrapper(*args, **kwargs):  # type: ignore[no-untyped-def]
        converted_kwargs = {}
        raw_pydantic_objects: dict[str, Any] = {}
        for kwarg_name, kwarg_value in kwargs.items():
            if kwarg_name not in pydantic_parameters:
                converted_kwargs[kwarg_name] = kwarg_value
                continue
            # _flatten_pydantic_model always appends the qualifier as the last Annotated metadata entry.
            qualifier = pydantic_parameters[kwarg_name].annotation.__metadata__[-1]
            nested_value = kwarg_value
            for part in reversed(qualifier):
                nested_value = {part: nested_value}
            raw_pydantic_objects = deep_update(raw_pydantic_objects, nested_value)
        for root_name, model in pydantic_roots.items():
            # Also build roots that received no flattened kwargs (e.g. a model without fields); otherwise
            # callback would be called without that argument. An explicitly passed value is kept as-is.
            if root_name in raw_pydantic_objects or root_name not in converted_kwargs:
                converted_kwargs[root_name] = model(**raw_pydantic_objects.get(root_name, {}))
        return callback(*args, **converted_kwargs)

    wrapper.__signature__ = extended_signature  # type: ignore
    # Copy annotations to make forward references work in Python <= 3.9
    wrapper.__annotations__ = {k: v.annotation for k, v in extended_signature.parameters.items()}
    return wrapper
def enable_pydantic_type_validation(callback: CommandFunctionType) -> CommandFunctionType:
    """Let pydantic parse parameters whose types typer cannot handle itself.

    At decoration time, every parameter for which typer's ``get_click_param`` raises ``RuntimeError``
    (typer cannot build a click parameter for its type) is re-declared as ``str`` in the wrapper's
    signature, so typer passes the raw command-line string through. ``click.Context`` parameters are
    left alone.

    At call time, every bound argument whose original annotation pydantic can build a ``TypeAdapter`` for
    is validated and converted against that original annotation. The first pydantic error message is
    reported to the user as ``typer.BadParameter``.

    Args:
        callback: The command function to wrap.

    Returns:
        The wrapper, typed like ``callback``; it returns whatever ``callback`` returns.
    """
    original_signature = inspect_signature(callback)

    # Change the annotation of unsupported types to str to be parsed by pydantic.
    # Adapted from https://github.com/tiangolo/typer/blob/95b767e38a98ee287a7a0e28176284836e1188c2/typer/main.py#L543
    # TODO: it's not ideal to call get_params_from_function and get_click_param here,
    # because it will be called in typer again, but the annotations supported by typer are quite dynamic.
    parameters = get_params_from_function(callback)
    updated_parameters = dict(original_signature.parameters)
    for param_name, param in parameters.items():
        original_parameter = original_signature.parameters[param_name]
        if lenient_issubclass(param.annotation, click.Context):
            continue
        # We don't know whether to use pydantic or typer to parse a param without checking if typer supports it.
        try:
            get_click_param(param)
        except RuntimeError:
            # TODO: don't use raw str, but copy other annotations
            updated_parameters[param_name] = inspect.Parameter(
                param_name, kind=original_parameter.kind, default=original_parameter.default, annotation=str
            )

    new_signature = inspect.Signature(
        parameters=list(updated_parameters.values()), return_annotation=original_signature.return_annotation
    )

    @copy_type(callback)
    @wraps(callback)
    def wrapper(*args, **kwargs):  # type: ignore[no-untyped-def]
        # Bind against the ORIGINAL signature so each value is validated against its real annotation.
        bound_params = original_signature.bind(*args, **kwargs)
        for name, value in bound_params.arguments.items():
            try:
                type_adapter = pydantic.TypeAdapter(original_signature.parameters[name].annotation)
            except pydantic.PydanticSchemaGenerationError:
                # pydantic has no schema for this type; pass the value through unchanged.
                continue
            try:
                bound_params.arguments[name] = type_adapter.validate_python(value)
            except pydantic.ValidationError as e:
                raise BadParameter(message=e.errors()[0]["msg"], param_hint=name) from e
        # Propagate the command's return value, like enable_pydantic's wrapper does.
        return callback(*bound_params.args, **bound_params.kwargs)

    wrapper.__signature__ = new_signature  # type: ignore
    # Copy annotations to make forward references work in Python <= 3.9
    wrapper.__annotations__ = {k: v.annotation for k, v in new_signature.parameters.items()}
    return wrapper
class PydanticTyper(Typer):
    """A ``typer.Typer`` whose commands are wrapped with ``enable_pydantic`` and
    ``enable_pydantic_type_validation`` before being registered."""

    @copy_type(Typer.command)
    def command(self, *args, **kwargs):
        """Same as ``Typer.command``, but the decorated function is wrapped for pydantic support first."""
        original_decorator = super().command(*args, **kwargs)

        def decorator_override(f: CommandFunctionType) -> CommandFunctionType:
            # Order matters: first flatten pydantic-model parameters into per-field parameters, then wrap
            # that with type validation. Type validation rewrites every annotation typer cannot handle
            # (pydantic models included) to ``str``. Applied first, it would hide the models from
            # enable_pydantic.
            wrapped_f = enable_pydantic(f)
            wrapped_f = enable_pydantic_type_validation(wrapped_f)
            return original_decorator(wrapped_f)

        return decorator_override