Coverage for src/sleazy/__init__.py: 100%

113 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-12 17:26 +0200

1# sleazy - cli+easy 

2import argparse 

3import types 

4import typing as t 

5 

6from .__about__ import __version__ 

7 

class TypedDict(t.TypedDict):
    """Base class for the CLI schemas handed to ``parse`` and ``stringify``.

    Subclass it exactly like ``typing.TypedDict``; the field annotations of the
    subclass describe the command-line arguments.
    """


# "Some TypedDict subclass": lets parse() be annotated as returning the very
# schema type it was given.
D = t.TypeVar("D", bound=TypedDict)

11 

12# internal 

13def parse_count_spec(spec: str) -> str | int: 

14 """Parse a count specification into argparse nargs format. 

15 

16 Only the following are allowed (same as argparse): 

17 - exactly 1 (default) 

18 - exact integer N 

19 - '+' for one or more 

20 - '*' for zero or more 

21 - '?' for zero or one 

22 

23 """ 

24 if spec in (None, ""): 

25 return 1 

26 

27 # Exact numeric values 

28 if isinstance(spec, int) or spec.isdigit(): 

29 return int(spec) 

30 

31 # Direct argparse-style symbols 

32 if spec in ("+", "*", "?"): 

33 return spec 

34 

35 # unsupported spec 

36 raise SyntaxError(f"Unexpected '{spec}'. Please choose from [+, *, ?, n]") 

37 

38 

39def strip_optional(tp: t.Type) -> t.Type: 

40 """Remove Optional[...] or | None from a type.""" 

41 

42 # Get the origin (e.g., Union) for both legacy and new union types (PEP 604) 

43 origin = t.get_origin(tp) 

44 

45 # Handle Union types (both legacy Optional[...] and new | None) 

46 if origin is t.Union or isinstance(tp, types.UnionType): 

47 args = t.get_args(tp) # __args__ holds the union members 

48 # Remove `NoneType` (type(None)) from the union args 

49 args = tuple(a for a in args if a is not types.NoneType) 

50 if len(args) == 1: 

51 return args[0] # If only one type remains, return it directly 

52 return t.Union[args] # Otherwise, return the filtered union 

53 

54 return tp # Return the type as-is if it's not a Union or Optional 

55 

56 

# Typing qualifiers that may wrap a TypedDict field hint. get_type_hints(...,
# include_extras=True) keeps them, but they carry no command-line meaning.
# getattr() skips the ones the running Python version does not provide.
_FIELD_QUALIFIERS = tuple(
    form
    for form in (getattr(t, name, None) for name in ("Required", "NotRequired", "ReadOnly"))
    if form is not None
)


def _unwrap_field_type(tp: t.Any) -> t.Any:
    """Remove TypedDict qualifiers (``Required`` & co.) and ``Optional`` from *tp*."""
    while t.get_origin(tp) in _FIELD_QUALIFIERS:
        tp = t.get_args(tp)[0]
    return strip_optional(tp)


def _resolve_field(hint: t.Any) -> tuple[t.Any, str | int | None]:
    """Split one field hint into ``(arg_type, nargs)``.

    *nargs* is ``None`` for option fields. For positional fields, i.e.
    ``Annotated[T, spec]`` with a ``str``/``int`` *spec*, it is the parsed count
    spec; when several such annotations are present the last one wins.
    """
    hint = _unwrap_field_type(hint)
    nargs = None
    if t.get_origin(hint) is t.Annotated:
        hint, *annotations = t.get_args(hint)
        # Optional may also sit inside Annotated, e.g. Annotated[T | None, "?"].
        hint = _unwrap_field_type(hint)
        for anno in annotations:
            if isinstance(anno, str | int):
                nargs = parse_count_spec(anno)
    return hint, nargs


def _converter(tp: t.Any) -> tuple[t.Any, tuple | None]:
    """Return the argparse ``(type, choices)`` pair for the scalar type *tp*."""
    if t.get_origin(tp) is t.Literal:
        values = t.get_args(tp)
        # argparse converts with a single callable, so the members of a
        # Literal are assumed to share the type of the first one.
        return type(values[0]), values
    return tp, None


def _list_item_type(list_type: t.Any) -> t.Any:
    """Return ``T`` for ``list[T]``; an unparameterised ``List`` yields ``str``."""
    item_args = t.get_args(list_type)
    return item_args[0] if item_args else str


def _add_positional(
    parser: argparse.ArgumentParser, name: str, arg_type: t.Any, nargs: str | int
) -> None:
    """Register field *name* as a positional argument consuming *nargs* values."""
    if t.get_origin(arg_type) is list:
        # List fields keep nargs as given (even 1) so the value is a list.
        conv, choices = _converter(_list_item_type(arg_type))
        parser.add_argument(name, type=conv, nargs=nargs, default=None, choices=choices)
    else:
        conv, choices = _converter(arg_type)
        parser.add_argument(
            name,
            type=conv,
            # nargs=1 would wrap the single value in a one-element list;
            # nargs=None (argparse's default) yields the bare value.
            nargs=None if nargs == 1 else nargs,
            default=None,
            choices=choices,
        )


def _add_option(parser: argparse.ArgumentParser, name: str, arg_type: t.Any) -> None:
    """Register field *name* as the option ``--name`` (underscores become dashes)."""
    flag = f"--{name.replace('_', '-')}"
    if arg_type is bool:
        parser.add_argument(flag, action="store_true")
    elif t.get_origin(arg_type) is list:
        # Repeating the option collects every occurrence into a list.
        conv, choices = _converter(_list_item_type(arg_type))
        parser.add_argument(flag, type=conv, action="append", choices=choices)
    else:
        conv, choices = _converter(arg_type)
        parser.add_argument(flag, type=conv, choices=choices)


def parse(typeddict_cls: t.Type[D], args: t.Optional[list[str]] = None) -> D:
    """Build an argparse parser from a TypedDict and parse *args* with it.

    Each field of *typeddict_cls* becomes one command-line argument:

    - ``Annotated[T, spec]`` with a ``str``/``int`` count *spec* (see
      ``parse_count_spec``) becomes a positional argument.
    - Every other field becomes an option ``--field-name``: ``bool`` is a
      ``store_true`` flag, ``list[T]`` a repeatable option, ``Literal[...]`` an
      option restricted to those values, anything else an option converted
      with ``type=T``.

    ``Optional[...]`` / ``X | None`` is removed from field types, both outside
    and inside ``Annotated``. ``Required``/``NotRequired``/``ReadOnly``
    qualifiers are ignored, and ``Literal`` element types of lists are
    honoured as choices.

    Args:
        typeddict_cls: the TypedDict class describing the command line.
        args: argument strings to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        A plain ``dict`` (the runtime form of a TypedDict) keyed by field
        name. Absent options are ``None``; absent bool flags are ``False``.

    Raises:
        SyntaxError: if a field carries an invalid count specification.
    """
    parser = argparse.ArgumentParser()
    hints = t.get_type_hints(typeddict_cls, include_extras=True)
    # Resolve (and validate) every field before registering anything.
    fields = {name: _resolve_field(hint) for name, hint in hints.items()}

    # argparse assigns positional values in registration order, so all
    # positional fields are added first, in field order, then the options.
    for name, (arg_type, nargs) in fields.items():
        if nargs is not None:
            _add_positional(parser, name, arg_type, nargs)
    for name, (arg_type, nargs) in fields.items():
        if nargs is None:
            _add_option(parser, name, arg_type)

    return vars(parser.parse_args(args))

163 

164 

def stringify(data: D, typeddict_cls: t.Optional[t.Type[D]] = None) -> list[str]:
    """Convert a TypedDict instance back into a list of command-line arguments.

    Positional fields come first, in field order, emitted as bare values.
    Positional fields are ``Annotated[T, spec]`` with a ``str``/``int`` count
    spec, possibly wrapped in ``Optional[...]`` / ``X | None`` or a
    ``Required``/``NotRequired``/``ReadOnly`` qualifier. A list value is
    expanded into one argument per item, whatever its count spec.

    All other keys of *data* follow as ``--key-name value`` options:
    ``True`` becomes a bare flag, ``False`` and ``None`` are omitted, and a
    list value repeats the flag once per item.

    Args:
        data: the mapping to serialise.
        typeddict_cls: the TypedDict class describing *data*. Defaults to
            ``data.__class__``. Note that TypedDict instances are plain
            ``dict`` objects at runtime, so without this argument no field is
            recognised as positional.

    Returns:
        The argument strings; every value is converted with ``str()``.

    Raises:
        SyntaxError: if a field carries an invalid count specification.
    """
    typeddict_cls = typeddict_cls or data.__class__
    # Qualifiers kept by get_type_hints(include_extras=True) that hide the
    # Annotated[...] carrying the count spec; missing ones are skipped.
    qualifiers = tuple(
        form
        for form in (getattr(t, name, None) for name in ("Required", "NotRequired", "ReadOnly"))
        if form is not None
    )

    positional: list[str] = []
    for name, hint in t.get_type_hints(typeddict_cls, include_extras=True).items():
        while t.get_origin(hint) in qualifiers:
            hint = t.get_args(hint)[0]
        # parse() strips Optional before looking for Annotated; do the same so
        # `Annotated[T, spec] | None` is classified identically here.
        hint = strip_optional(hint)
        if t.get_origin(hint) is not t.Annotated:
            continue
        specs = [anno for anno in t.get_args(hint)[1:] if isinstance(anno, str | int)]
        for spec in specs:
            parse_count_spec(spec)  # reject invalid specs just like parse()
        if specs:
            positional.append(name)

    args: list[str] = []

    # Positional arguments
    for name in positional:
        value = data.get(name)
        if value is None:
            continue
        if isinstance(value, list):
            # parse() hands back a list for "+", "*" and integer counts
            # (including 1 for list fields); each item is its own token.
            args.extend(str(item) for item in value)
        else:
            args.append(str(value))

    # Optional arguments
    positional_names = set(positional)
    for name, value in data.items():
        if name in positional_names or value is None:
            continue
        flag = f"--{name.replace('_', '-')}"
        if isinstance(value, bool):
            if value:  # store_true flags are only emitted when set
                args.append(flag)
        elif isinstance(value, list):
            # One flag occurrence per item, matching action="append".
            for item in value:
                args.extend((flag, str(item)))
        else:
            args.extend((flag, str(value)))

    return args

224 

# Names exported by `from sleazy import *`.
__all__ = ["__version__", "parse", "stringify", "TypedDict"]