Coverage for src/pydantic_typer/main.py: 99%

64 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-21 12:53 +0200

from __future__ import annotations

import copy
import inspect
from functools import wraps
from typing import Any, Callable

import pydantic
from typer import Option
from typer.main import lenient_issubclass
from typer.models import OptionInfo, ParameterInfo
from typer.utils import _split_annotation_from_typer_annotations
from typing_extensions import Annotated

from pydantic_typer.utils import deep_update, inspect_signature

15 

# Joins the attribute path of a (nested) model field into its default CLI
# option name, which is then prefixed with "--" (e.g. "--root.child.field").
PYDANTIC_FIELD_SEPARATOR = "."

17 

18 

def _flatten_pydantic_model(
    model: type[pydantic.BaseModel], ancestors: list[str], ancestor_typer_param=None
) -> dict[str, inspect.Parameter]:
    """Flatten a (possibly nested) pydantic model class into keyword-only parameters.

    Every leaf field of ``model`` becomes one ``inspect.Parameter`` named
    ``_pydantic_<ancestors>_<field>``. Its annotation is
    ``Annotated[field_type, typer_param, qualifier]``, where ``qualifier`` is the
    list of attribute names leading from the root parameter to the field.

    Args:
        model: The pydantic model class whose fields are flattened.
        ancestors: Attribute path leading to ``model``; the root parameter name first.
        ancestor_typer_param: Typer parameter info attached to the root parameter.
            It is used for leaf fields that carry no typer annotation of their own.
            It is not yet propagated into nested models.

    Returns:
        Mapping from generated parameter name to ``inspect.Parameter``.
    """
    pydantic_parameters: dict[str, inspect.Parameter] = {}
    for field_name, field in model.model_fields.items():
        qualifier = [*ancestors, field_name]
        sub_name = f"_pydantic_{'_'.join(qualifier)}"
        if lenient_issubclass(field.annotation, pydantic.BaseModel):
            # Nested model: recurse; its leaves are named after the full path.
            # TODO pass ancestor_typer_param
            pydantic_parameters.update(_flatten_pydantic_model(field.annotation, qualifier))  # type: ignore
            continue

        # ``...`` marks the CLI parameter as required. Fields with a
        # ``default_factory`` are optional even though ``field.default`` is unset
        # for them, so evaluate the factory instead of making them required.
        default = ... if field.is_required() else field.get_default(call_default_factory=True)

        option_name = f"--{PYDANTIC_FIELD_SEPARATOR.join(qualifier)}"
        # Pydantic stores annotations in field.metadata.
        # If the field is already annotated with a typer.Option or typer.Argument, use that.
        existing_typer_params = [meta for meta in field.metadata if isinstance(meta, ParameterInfo)]
        if existing_typer_params:
            # Customise a copy: the original object lives in the model's field
            # metadata and is shared by every parameter/command using this model.
            # Mutating it in place would leak the first qualifier's option name
            # into all later uses.
            typer_param = copy.copy(existing_typer_params[0])
            if isinstance(typer_param, OptionInfo) and not typer_param.param_decls:
                # The option was not named manually: use the default naming scheme.
                typer_param.param_decls = (option_name,)
        elif ancestor_typer_param:
            typer_param = ancestor_typer_param
        else:
            typer_param = Option(option_name)
        pydantic_parameters[sub_name] = inspect.Parameter(
            sub_name,
            inspect.Parameter.KEYWORD_ONLY,
            annotation=Annotated[field.annotation, typer_param, qualifier],
            default=default,
        )
    return pydantic_parameters

53 

54 

def enable_pydantic(callback: Callable[..., Any]) -> Callable[..., Any]:
    """Let a Typer command accept pydantic models as parameters.

    Each parameter whose (typer-annotation-stripped) type is a
    ``pydantic.BaseModel`` subclass is replaced in the exposed signature by one
    keyword-only parameter per (nested) model field. When the returned wrapper
    is called, those flat keyword arguments are merged back into nested dicts.
    The dicts are used to instantiate the original models, which are passed to
    ``callback`` under their original parameter names.

    Args:
        callback: The command function to wrap.

    Returns:
        A wrapper exposing the flattened signature (via ``__signature__``).
    """
    original_signature = inspect_signature(callback)

    pydantic_parameters: dict[str, inspect.Parameter] = {}
    pydantic_roots: dict[str, Any] = {}
    other_parameters: dict[str, inspect.Parameter] = {}
    for name, parameter in original_signature.parameters.items():
        base_annotation, typer_annotations = _split_annotation_from_typer_annotations(parameter.annotation)
        if lenient_issubclass(base_annotation, pydantic.BaseModel):
            typer_param = typer_annotations[0] if typer_annotations else None
            # Flatten the bare model class, not a possibly ``Annotated`` wrapper.
            pydantic_parameters.update(_flatten_pydantic_model(base_annotation, [name], typer_param))
            pydantic_roots[name] = base_annotation
        else:
            other_parameters[name] = parameter

    extended_signature = inspect.Signature(
        [*other_parameters.values(), *pydantic_parameters.values()],
        return_annotation=original_signature.return_annotation,
    )

    # Attribute path (root parameter name first) of every flat parameter. The
    # flattener stores it as the last ``Annotated`` metadata item. Read it once
    # here instead of on every call.
    qualifiers = {name: param.annotation.__metadata__[-1] for name, param in pydantic_parameters.items()}

    # Required roots whose model contributes no flat parameter (e.g. a model
    # without fields) never receive CLI kwargs. They are built from defaults so
    # ``callback`` is not called with a missing argument.
    populated_roots = {qualifier[0] for qualifier in qualifiers.values()}
    empty_roots = [
        name
        for name in pydantic_roots
        if name not in populated_roots and original_signature.parameters[name].default is inspect.Parameter.empty
    ]

    @wraps(callback)
    def wrapper(*args, **kwargs):  # type: ignore[no-untyped-def]
        converted_kwargs: dict[str, Any] = {}
        raw_pydantic_objects: dict[str, Any] = {}
        for kwarg_name, kwarg_value in kwargs.items():
            qualifier = qualifiers.get(kwarg_name)
            if qualifier is None:
                # Not a flattened model field: pass through unchanged.
                converted_kwargs[kwarg_name] = kwarg_value
                continue
            # Wrap the leaf value as {root: {...: {field: value}}} and merge it.
            nested = kwarg_value
            for part in reversed(qualifier):
                nested = {part: nested}
            raw_pydantic_objects = deep_update(raw_pydantic_objects, nested)
        for root_name in empty_roots:
            if root_name not in converted_kwargs:
                raw_pydantic_objects.setdefault(root_name, {})
        for root_name, value in raw_pydantic_objects.items():
            converted_kwargs[root_name] = pydantic_roots[root_name](**value)
        return callback(*args, **converted_kwargs)

    wrapper.__signature__ = extended_signature  # type: ignore
    # Copy annotations to make forward references work in Python <= 3.9
    wrapper.__annotations__ = {k: v.annotation for k, v in extended_signature.parameters.items()}
    return wrapper