Coverage for src/sleazy/__init__.py: 100%
113 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-12 17:26 +0200
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-12 17:26 +0200
1# sleazy - cli+easy
2import argparse
3import types
4import typing as t
6from .__about__ import __version__
class TypedDict(t.TypedDict):
    """Base class for sleazy argument schemas (a thin alias of ``typing.TypedDict``)."""


# Type variable standing for "some TypedDict subclass"; used in the
# signatures of parse() and stringify().
D = t.TypeVar("D", bound=TypedDict)
12# internal
13def parse_count_spec(spec: str) -> str | int:
14 """Parse a count specification into argparse nargs format.
16 Only the following are allowed (same as argparse):
17 - exactly 1 (default)
18 - exact integer N
19 - '+' for one or more
20 - '*' for zero or more
21 - '?' for zero or one
23 """
24 if spec in (None, ""):
25 return 1
27 # Exact numeric values
28 if isinstance(spec, int) or spec.isdigit():
29 return int(spec)
31 # Direct argparse-style symbols
32 if spec in ("+", "*", "?"):
33 return spec
35 # unsupported spec
36 raise SyntaxError(f"Unexpected '{spec}'. Please choose from [+, *, ?, n]")
39def strip_optional(tp: t.Type) -> t.Type:
40 """Remove Optional[...] or | None from a type."""
42 # Get the origin (e.g., Union) for both legacy and new union types (PEP 604)
43 origin = t.get_origin(tp)
45 # Handle Union types (both legacy Optional[...] and new | None)
46 if origin is t.Union or isinstance(tp, types.UnionType):
47 args = t.get_args(tp) # __args__ holds the union members
48 # Remove `NoneType` (type(None)) from the union args
49 args = tuple(a for a in args if a is not types.NoneType)
50 if len(args) == 1:
51 return args[0] # If only one type remains, return it directly
52 return t.Union[args] # Otherwise, return the filtered union
54 return tp # Return the type as-is if it's not a Union or Optional
def parse(typeddict_cls: t.Type[D], args: t.Optional[list[str]] = None) -> D:
    """Parse command-line arguments into a dict shaped by ``typeddict_cls``.

    Each field of ``typeddict_cls`` becomes one argparse argument:

    - Fields annotated ``Annotated[T, spec]`` where ``spec`` is a ``str`` or
      ``int`` count (see ``parse_count_spec``) become positional arguments.
      They are added first, in field order.
    - All other fields become ``--field-name`` options (underscores in the
      field name are replaced by dashes). ``bool`` fields are ``store_true``
      flags, and ``list[T]`` fields use ``action="append"``.
    - ``Literal[...]`` fields only accept the literal values. Input is
      converted with the type of the first literal value.

    ``Optional[...]`` / ``| None`` is ignored, both at the top level and
    inside ``Annotated[...]``.

    Args:
        typeddict_cls: The TypedDict class describing the arguments.
        args: Argument strings to parse. ``None`` lets argparse fall back to
            ``sys.argv[1:]``.

    Returns:
        A plain ``dict`` (``vars`` of the argparse namespace) mapping each
        field name to its parsed value. Absent flags are ``False``; absent
        options are ``None``.

    Raises:
        SyntaxError: if a positional count spec is invalid.
    """
    parser = argparse.ArgumentParser()
    type_hints = t.get_type_hints(typeddict_cls, include_extras=True)
    type_hints = {k: strip_optional(v) for k, v in type_hints.items()}

    # First pass: collect the positional fields (those carrying a count spec).
    positional_fields = {}
    for field, hint in type_hints.items():
        if t.get_origin(hint) is not t.Annotated:
            continue
        arg_type, *annotations = t.get_args(hint)
        # Strip Optional inside Annotated[...] too, so the converter passed
        # to argparse is the real type and not a non-callable union.
        arg_type = strip_optional(arg_type)

        is_positional = False
        nargs_value = 1  # Default is required single argument
        for anno in annotations:
            # str/int metadata is a positional count spec (the last one wins)
            if isinstance(anno, str | int):
                is_positional = True
                nargs_value = parse_count_spec(anno)

        if is_positional:
            is_list_type = t.get_origin(arg_type) is list
            positional_fields[field] = (arg_type, nargs_value, is_list_type)

    # Register positional arguments before any option.
    for field, (arg_type, nargs_value, is_list_type) in positional_fields.items():
        if t.get_origin(arg_type) is t.Literal:
            literal_values = t.get_args(arg_type)
            if not literal_values:  # pragma: no cover
                raise TypeError("Plain typing.Literal is not valid as type argument")
            parser.add_argument(
                field,
                type=type(literal_values[0]),
                # nargs=None (argparse's default) yields a scalar instead of a
                # one-element list for an exactly-one count
                nargs=None if nargs_value == 1 else nargs_value,
                default=None,
                choices=literal_values,
            )
        elif is_list_type:
            # For list types, convert each element with the list's item type
            elem_type = t.get_args(arg_type)[0] if t.get_args(arg_type) else str
            parser.add_argument(field, type=elem_type, nargs=nargs_value, default=None)
        elif nargs_value == 1:
            # Exactly one value for a non-list type: omit nargs to get a scalar
            parser.add_argument(field, type=arg_type, default=None)
        else:
            parser.add_argument(field, type=arg_type, nargs=nargs_value, default=None)

    # Second pass: every remaining field becomes a --option.
    for field, hint in type_hints.items():
        if field in positional_fields:
            continue  # already registered as a positional argument

        flag = f"--{field.replace('_', '-')}"
        arg_type = hint
        if t.get_origin(hint) is t.Annotated:
            arg_type = strip_optional(t.get_args(hint)[0])
        is_list_type = t.get_origin(arg_type) is list

        if t.get_origin(arg_type) is t.Literal:
            literal_values = t.get_args(arg_type)
            if not literal_values:  # pragma: no cover
                raise TypeError("Plain typing.Literal is not valid as type argument")
            parser.add_argument(flag, type=type(literal_values[0]), choices=literal_values)
        elif arg_type is bool:
            parser.add_argument(flag, action="store_true")
        elif is_list_type:
            # Repeating the flag collects values, converted with the item type
            elem_type = t.get_args(arg_type)[0] if t.get_args(arg_type) else str
            parser.add_argument(flag, type=elem_type, action="append")
        else:
            parser.add_argument(flag, type=arg_type)

    return vars(parser.parse_args(args))
def stringify(data: D, typeddict_cls: t.Optional[t.Type[D]] = None) -> list[str]:
    """Convert a TypedDict instance to a list of command-line arguments.

    Positional arguments come first (in field order), followed by optional
    arguments:

    - Positional fields are those annotated ``Annotated[T, spec]`` with a
      ``str``/``int`` count spec (``Optional`` around it is ignored, as in
      ``parse``). A list value is emitted as one argument per item. Any
      other value is emitted as a single ``str(value)``.
    - For the remaining keys of ``data``: ``True`` becomes a bare
      ``--field-name`` flag and ``False`` is omitted. A list becomes one
      ``--field-name item`` pair per item. Anything else becomes
      ``--field-name str(value)``.
    - ``None`` values are always omitted.

    Args:
        data: The mapping to convert.
        typeddict_cls: The TypedDict class describing ``data``. Defaults to
            ``data.__class__``. Calling a TypedDict class produces a plain
            ``dict``, so without this argument no field is recognised as
            positional and every value is emitted as an option.

    Returns:
        The argument strings, intended to be fed back to
        ``parse(typeddict_cls, args)``.

    Raises:
        SyntaxError: if a positional count spec is invalid.
    """
    args: list[str] = []
    typeddict_cls = typeddict_cls or data.__class__
    type_hints = t.get_type_hints(typeddict_cls, include_extras=True)
    # Strip Optional the same way parse() does, so Optional[Annotated[T, "?"]]
    # is positional here too.
    type_hints = {k: strip_optional(v) for k, v in type_hints.items()}

    # Collect the positional fields: Annotated metadata holding a count spec.
    positional_fields: list[str] = []
    for field, hint in type_hints.items():
        if t.get_origin(hint) is not t.Annotated:
            continue
        _, *annotations = t.get_args(hint)
        is_positional = False
        for anno in annotations:
            if isinstance(anno, str | int):
                # Validate the spec exactly like parse() does; the count itself
                # is not needed because list values are always expanded.
                parse_count_spec(anno)
                is_positional = True
        if is_positional:
            positional_fields.append(field)

    # Emit positional arguments.
    for field in positional_fields:
        value = data.get(field)
        if value is None:
            continue
        if isinstance(value, list):
            # One argument per item, whatever the count (N, '+', '*', ...).
            # A list must never be emitted as a single "[1, 2]" string.
            args.extend(str(item) for item in value)
        else:
            args.append(str(value))

    # Emit optional arguments.
    positional_names = set(positional_fields)
    for field, value in data.items():
        if field in positional_names or value is None:
            continue

        flag = f"--{field.replace('_', '-')}"
        if isinstance(value, bool):
            if value:  # Only add the flag if True
                args.append(flag)
        elif isinstance(value, list):
            # Repeat the flag once per item (matches action="append")
            for item in value:
                args.extend((flag, str(item)))
        else:
            args.extend((flag, str(value)))

    return args
# Names exported by ``from sleazy import *``.
__all__ = ["__version__", "parse", "stringify", "TypedDict"]