Coverage for src/pydantic_typer/main.py: 99%
64 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-21 12:53 +0200
1from __future__ import annotations
3import inspect
4from functools import wraps
5from typing import Any, Callable
7import pydantic
8from typer import Option
9from typer.main import lenient_issubclass
10from typer.models import OptionInfo, ParameterInfo
11from typer.utils import _split_annotation_from_typer_annotations
12from typing_extensions import Annotated
14from pydantic_typer.utils import deep_update, inspect_signature
# Separator placed between the path components of a (possibly nested) model field
# when building its default CLI option name, e.g. ["config", "host"] -> "--config.host".
PYDANTIC_FIELD_SEPARATOR: str = "."
def _flatten_pydantic_model(
    model: type[pydantic.BaseModel],
    ancestors: list[str],
    ancestor_typer_param: ParameterInfo | None = None,
) -> dict[str, inspect.Parameter]:
    """Flatten a pydantic model class into keyword-only CLI parameters.

    Fields whose type is itself a pydantic model are flattened recursively, so every
    leaf field of ``model`` yields exactly one ``inspect.Parameter``.

    Args:
        model: Pydantic model class whose ``model_fields`` are flattened.
        ancestors: Names leading from the command parameter down to ``model``. They
            prefix both the generated parameter name and the default option name.
        ancestor_typer_param: Typer ``ParameterInfo`` (``typer.Option`` /
            ``typer.Argument``) attached to an enclosing parameter or nested-model
            field. Leaf fields without their own typer annotation use a copy of it.

    Returns:
        Mapping from the generated parameter name (``_pydantic_<a>_<b>...``) to a
        keyword-only ``inspect.Parameter`` annotated as
        ``Annotated[<field type>, <typer ParameterInfo>, <qualifier>]``, where the
        qualifier is the list ``[*ancestors, field_name]``.
    """
    from copy import copy  # local import: only needed by this function

    pydantic_parameters = {}
    for field_name, field in model.model_fields.items():
        qualifier = [*ancestors, field_name]
        sub_name = f"_pydantic_{'_'.join(qualifier)}"
        default_option_name = f"--{PYDANTIC_FIELD_SEPARATOR.join(qualifier)}"
        # Pydantic stores annotations in field.metadata.
        # If the field is already annotated with a typer.Option or typer.Argument, use that;
        # otherwise fall back to the nearest ancestor's typer annotation (if any).
        own_typer_param = next((meta for meta in field.metadata if isinstance(meta, ParameterInfo)), None)
        inherited_typer_param = own_typer_param if own_typer_param is not None else ancestor_typer_param

        if lenient_issubclass(field.annotation, pydantic.BaseModel):
            # Propagate the typer annotation so nested leaf fields honour it as well.
            params = _flatten_pydantic_model(field.annotation, qualifier, inherited_typer_param)  # type: ignore
            pydantic_parameters.update(params)
            continue

        default = (
            field.default if field.default is not pydantic.fields._Unset else ...  # noqa: SLF001
        )
        if inherited_typer_param is None:
            typer_param = Option(default_option_name)
        else:
            # Customise a copy: the original object lives in the model's field metadata
            # (or is shared by every field below an ancestor), so mutating it in place
            # would leak this field's option name into other fields and into other
            # commands that reuse the same model.
            typer_param = copy(inherited_typer_param)
            if isinstance(typer_param, OptionInfo) and not typer_param.param_decls:
                # The option was not named manually: use the default naming scheme.
                typer_param.param_decls = (default_option_name,)
        pydantic_parameters[sub_name] = inspect.Parameter(
            sub_name,
            inspect.Parameter.KEYWORD_ONLY,
            annotation=Annotated[field.annotation, typer_param, qualifier],
            default=default,
        )
    return pydantic_parameters
def enable_pydantic(callback: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap ``callback`` so its pydantic-model parameters are exposed as individual CLI parameters.

    Every parameter whose base annotation (as returned by
    ``_split_annotation_from_typer_annotations``) is a pydantic model is removed from the
    visible signature. It is replaced by the keyword-only parameters produced by
    ``_flatten_pydantic_model``, one per leaf field. The first typer annotation found on
    the model parameter itself is forwarded as the ancestor typer parameter. All other
    parameters are kept unchanged and placed before the flattened ones.

    At call time the wrapper regroups the flattened values into nested dicts, using the
    qualifier stored as the last ``Annotated`` metadata item. It then instantiates each
    model from them and passes the instance to ``callback`` under the original parameter
    name.

    Args:
        callback: The function to wrap (typically a typer command or callback).

    Returns:
        A ``functools.wraps`` wrapper whose ``__signature__`` and ``__annotations__``
        describe the extended (flattened) parameter list.
    """
    original_signature = inspect_signature(callback)

    pydantic_parameters: dict[str, inspect.Parameter] = {}
    pydantic_roots: dict[str, Any] = {}
    other_parameters: dict[str, inspect.Parameter] = {}
    # Model parameters whose flattening produced no CLI parameters (e.g. a model without
    # fields). Their values never arrive through the wrapper's kwargs.
    fieldless_roots: list[str] = []
    for name, parameter in original_signature.parameters.items():
        base_annotation, typer_annotations = _split_annotation_from_typer_annotations(parameter.annotation)
        if not lenient_issubclass(base_annotation, pydantic.BaseModel):
            other_parameters[name] = parameter
            continue
        typer_param = typer_annotations[0] if typer_annotations else None
        # Flatten the bare model class rather than the raw (possibly ``Annotated``) annotation.
        params = _flatten_pydantic_model(base_annotation, [name], typer_param)
        if not params:
            fieldless_roots.append(name)
        pydantic_parameters.update(params)
        pydantic_roots[name] = base_annotation

    extended_signature = inspect.Signature(
        [*other_parameters.values(), *pydantic_parameters.values()],
        return_annotation=original_signature.return_annotation,
    )

    @wraps(callback)
    def wrapper(*args, **kwargs):  # type: ignore[no-untyped-def]
        converted_kwargs = kwargs.copy()
        raw_pydantic_objects: dict[str, Any] = {}
        for kwarg_name, kwarg_value in kwargs.items():
            if kwarg_name not in pydantic_parameters:
                continue
            converted_kwargs.pop(kwarg_name)
            # The qualifier (path from the root parameter to the leaf field) is the last
            # ``Annotated`` metadata item attached by ``_flatten_pydantic_model``.
            qualifier = pydantic_parameters[kwarg_name].annotation.__metadata__[-1]
            nested_value = kwarg_value
            for part in reversed(qualifier):
                nested_value = {part: nested_value}
            raw_pydantic_objects = deep_update(raw_pydantic_objects, nested_value)
        # Models without CLI parameters must still be instantiated and passed on;
        # otherwise the callback would fail with a missing-argument TypeError.
        for root_name in fieldless_roots:
            raw_pydantic_objects.setdefault(root_name, {})
        for root_name, value in raw_pydantic_objects.items():
            converted_kwargs[root_name] = pydantic_roots[root_name](**value)
        return callback(*args, **converted_kwargs)

    wrapper.__signature__ = extended_signature  # type: ignore
    # Copy annotations to make forward references work in Python <= 3.9
    wrapper.__annotations__ = {k: v.annotation for k, v in extended_signature.parameters.items()}
    return wrapper