Coverage for audoma/openapi.py: 89%

237 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-08-08 06:12 +0000

1import re 

2import typing 

3from copy import deepcopy 

4from inspect import isclass 

5 

6from drf_spectacular.drainage import get_override 

7from drf_spectacular.extensions import OpenApiSerializerExtension 

8from drf_spectacular.openapi import AutoSchema 

9from drf_spectacular.plumbing import ( 

10 ComponentRegistry, 

11 build_array_type, 

12 error, 

13 force_instance, 

14 sanitize_specification_extensions, 

15) 

16from drf_spectacular.types import OpenApiTypes 

17from drf_spectacular.utils import OpenApiResponse 

18from rest_framework.fields import Field 

19from rest_framework.generics import GenericAPIView 

20from rest_framework.permissions import ( 

21 AND, 

22 OR, 

23 BasePermission, 

24 OperandHolder, 

25 SingleOperandHolder, 

26) 

27from rest_framework.relations import RelatedField 

28from rest_framework.serializers import BaseSerializer 

29from rest_framework.views import APIView 

30 

31from django.views import View 

32 

33from audoma.drf.generics import GenericAPIView as AudomaGenericAPIView 

34from audoma.drf.serializers import BulkSerializerMixin 

35from audoma.drf.validators import ExclusiveFieldsValidator 

36from audoma.links import ( 

37 ChoicesOptionsLink, 

38 ChoicesOptionsLinkSchemaGenerator, 

39) 

40from audoma.plumbing import create_choices_enum_description 

41 

42 

class AudomaAutoSchema(AutoSchema):
    """drf-spectacular ``AutoSchema`` extended with audoma-specific behaviour.

    On top of the stock schema generation this class:

    * appends a Markdown description of the view's permission classes,
    * reads per-action serializers/results/errors from the ``_audoma``
      attribute of the view's action function,
    * adds ``x-choices`` extensions for choice fields and choices-options links,
    * wraps request/response schemas in arrays for bulk operations,
    * emits ``oneOf`` request schemas for ``ExclusiveFieldsValidator``.
    """

    choice_link_schema_generator = ChoicesOptionsLinkSchemaGenerator()

    # Recomputed for every operation in ``get_operation``. The class-level
    # default keeps helpers that read ``self.is_bulk`` from raising
    # AttributeError if they are reached outside ``get_operation``.
    is_bulk = False

    def _describe_permission(self, permission_class) -> str:
        """Return a one-line description of a single permission class.

        Uses ``permission_class.get_description(view)`` when available, falls
        back to the class docstring flattened to one line, and finally to a
        fixed placeholder text.
        """
        try:
            return permission_class.get_description(self.view)
        except AttributeError:
            pass
        if permission_class.__doc__:
            return permission_class.__doc__.replace("\n", " ").strip()
        return "(No description for this permission)"

    def _handle_permission(
        self,
        permission_class: typing.Union[
            OperandHolder, SingleOperandHolder, BasePermission
        ],
        operations: list,
        current_operation: typing.Type = AND,
    ) -> dict:
        """Collect permission descriptions and render the permission expression.

        Recursively walks permissions composed with ``&``, ``|`` and ``~``
        (DRF ``OperandHolder`` / ``SingleOperandHolder``). It appends the
        rendered expression tokens to ``operations`` in place.

        Args:
            permission_class: a permission class or a DRF operand holder.
            operations: output list of expression tokens; mutated in place.
            current_operation: operator class of the enclosing expression,
                used to decide whether a nested expression needs parentheses.

        Returns:
            Mapping of permission class name -> description string.
        """
        permissions = {}

        if isinstance(permission_class, OperandHolder):
            operator = permission_class.operator_class
            # Parenthesize a nested expression that binds looser than its
            # context: an OR anywhere except directly inside another OR, and
            # any binary expression under a context that is neither AND nor OR
            # (i.e. under a negation).
            needs_parens = (
                operator == OR and current_operation != OR
            ) or current_operation not in (AND, OR)

            if needs_parens:
                operations.append("(")
            permissions.update(
                self._handle_permission(
                    permission_class.op1_class, operations, operator
                )
            )
            if operator == OR:
                operations.append("|")
            elif operator == AND:
                operations.append(" & ")
            permissions.update(
                self._handle_permission(
                    permission_class.op2_class, operations, operator
                )
            )
            if needs_parens:
                operations.append(" )")
        elif isinstance(permission_class, SingleOperandHolder):
            # DRF builds ``~Permission`` as a SingleOperandHolder; render the
            # negation instead of silently documenting the opposite rule.
            operations.append("~")
            permissions.update(
                self._handle_permission(
                    permission_class.op1_class,
                    operations,
                    permission_class.operator_class,
                )
            )
        else:
            # Store the plain description string (previously a 1-tuple was
            # stored by accident and rendered as "('...',)").
            permissions[permission_class.__name__] = self._describe_permission(
                permission_class
            )
            operations.append(f"`{permission_class.__name__}`")

        return permissions

    def _get_permissions_description(self) -> str:
        """Render the view's ``permission_classes`` as a Markdown section.

        Returns an empty string when no permission classes are declared.
        """
        permissions = {}
        operations = []

        for permission_class in getattr(self.view, "permission_classes", []):
            # Consecutive entries of permission_classes are joined with "&".
            if operations:
                operations.append("&")
            permissions.update(self._handle_permission(permission_class, operations))

        if not permissions:
            return ""

        return (
            "\n\n**Permissions:**\n"
            + " ".join(operations)
            + "\n"
            + "\n".join(
                f"+ `{name}`: *{doc_str}*" for name, doc_str in permissions.items()
            )
        )

    def get_description(self) -> str:
        """Return the default operation description plus the permissions section."""
        description = super().get_description() or ""
        description += self._get_permissions_description()
        return description

    def _extract_action_function(self, view) -> typing.Optional[typing.Callable]:
        """Return the view method named by ``view.action``, or None if unavailable."""
        action = getattr(view, "action", None)
        if not action:
            return None

        return getattr(view, action, None)

    def _parse_action_serializers(self, action_serializers) -> dict:
        """Normalize audoma action results into a dict usable by drf-spectacular.

        * falsy input is returned unchanged,
        * a string becomes ``{"default": OpenApiResponse(description=...)}``,
        * any other non-dict value becomes ``{"default": value}``,
        * for a dict, string values (and string values of nested dicts) are
          replaced with ``OpenApiResponse(description=...)`` in a deep copy,
          so the decorator's data is never modified.
        """
        if not action_serializers:
            return action_serializers

        if isinstance(action_serializers, str):
            return {"default": OpenApiResponse(description=action_serializers)}

        if not isinstance(action_serializers, dict):
            return {"default": action_serializers}

        parsed_action_serializers = deepcopy(action_serializers)

        for method, method_serializers in action_serializers.items():
            if isinstance(method_serializers, str):
                parsed_action_serializers[method] = OpenApiResponse(
                    description=method_serializers
                )
            elif isinstance(method_serializers, dict):
                for code, item in method_serializers.items():
                    if isinstance(item, str):
                        parsed_action_serializers[method][code] = OpenApiResponse(
                            description=item
                        )

        return parsed_action_serializers

    @staticmethod
    def _build_value_schema(value) -> dict:
        """Return a minimal OpenAPI schema describing the JSON type of ``value``.

        Values without an obvious JSON counterpart get an empty schema (any
        value) rather than a made-up type name.
        """
        # bool must be tested before int because bool subclasses int.
        if isinstance(value, bool):
            return {"type": "boolean"}
        if isinstance(value, int):
            return {"type": "integer"}
        if isinstance(value, float):
            return {"type": "number"}
        # str subclasses (e.g. DRF's ErrorDetail) are strings too.
        if isinstance(value, str):
            return {"type": "string"}
        if isinstance(value, (list, tuple)):
            # OpenAPI requires "items" whenever "type" is "array".
            return {"type": "array", "items": {}}
        if isinstance(value, dict):
            return {"type": "object"}
        return {}

    def _parse_action_errors(self, action_errors) -> dict:
        """Build ``{status_code: OpenApiResponse}`` from an iterable of errors.

        Each error may be an exception class (instantiated without arguments)
        or an instance. Its instance attributes become the response schema's
        properties and example.

        Returns an empty dict when ``action_errors`` is falsy.
        """
        if not action_errors:
            return {}

        parsed_errors = {}
        for err in action_errors:
            if isclass(err):
                err = err()

            # Copy so the schema does not alias the exception's live __dict__.
            attributes = dict(vars(err))
            properties = {
                key: self._build_value_schema(value)
                for key, value in attributes.items()
            }

            # NOTE(review): errors sharing a status code overwrite each other.
            parsed_errors[err.status_code] = OpenApiResponse(
                response={
                    "type": "object",
                    "properties": properties,
                    "example": attributes,
                }
            )
        return parsed_errors

    def _extract_audoma_action_operations(
        self, view: View, serializer_type: str
    ) -> dict:
        """
        Extract the audoma action operations from the view.

        For ``serializer_type == "collect"`` this returns the ``collectors``
        declared on the action. Otherwise it returns the parsed ``results``
        merged with the parsed ``errors``. An empty value means nothing was
        declared.
        """
        action_function = self._extract_action_function(view)
        _audoma = getattr(action_function, "_audoma", None)
        if not _audoma:
            return {}

        if serializer_type == "collect":
            return getattr(_audoma, "collectors", None)

        results = self._parse_action_serializers(getattr(_audoma, "results", None))
        errors = self._parse_action_errors(getattr(_audoma, "errors", None))
        if not results:
            return errors

        # ``results`` is a fresh dict here (see _parse_action_serializers), so
        # merging into it does not touch the decorator's data.
        # NOTE(review): if ``results`` is keyed by HTTP method, _get_serializer
        # returns only the entry for the current method and these errors are
        # dropped from the schema — confirm the intended shape.
        results.update(errors)
        return results

    def _get_serializer(  # noqa: C901
        self, serializer_type="collect"
    ) -> typing.Union[BaseSerializer, typing.Type[BaseSerializer]]:
        """Return the serializer for the current request/response direction.

        Audoma action declarations take precedence. After that the method
        falls back to the view's serializer (using the audoma
        ``get_serializer_class(type=...)`` API when available). It returns
        None after reporting an error when no serializer can be determined.
        """
        view = self.view
        method = view.request.method

        action_serializers = self._extract_audoma_action_operations(
            view, serializer_type
        )
        if action_serializers:
            if (
                isinstance(action_serializers, dict)
                and method.lower() in action_serializers
            ):
                return action_serializers[method.lower()]
            return action_serializers

        try:
            if isinstance(view, AudomaGenericAPIView):
                return view.get_serializer_class(type=serializer_type)()
            elif isinstance(view, GenericAPIView):
                # try to circumvent queryset issues with calling get_serializer. if view has NOT
                # overridden get_serializer, its safe to use get_serializer_class.
                if view.__class__.get_serializer == GenericAPIView.get_serializer:
                    return view.get_serializer_class()()
                return view.get_serializer()
            elif isinstance(view, APIView):
                # APIView does not implement the required interface, but be lenient and make
                # good guesses before giving up and emitting a warning.
                if callable(getattr(view, "get_serializer", None)):
                    return view.get_serializer()
                elif callable(getattr(view, "get_serializer_class", None)):
                    return view.get_serializer_class()()
                elif hasattr(view, "serializer_class"):
                    return view.serializer_class
                else:
                    error(
                        "unable to guess serializer. This is graceful "
                        "fallback handling for APIViews. Consider using GenericAPIView as view base "
                        "class, if view is under your control. Ignoring view for now. "
                    )
            else:
                error(
                    "Encountered unknown view base class. Please report this issue. Ignoring for now"
                )
        except Exception as exc:
            error(
                f"exception raised while getting serializer. Hint: "
                f"Is get_serializer_class() returning None or is get_queryset() not working without "
                f"a request? Ignoring the view for now. (Exception: {exc})"
            )

    def get_response_serializers(
        self,
    ) -> typing.Union[BaseSerializer, typing.Type[BaseSerializer]]:
        """overrides this for custom behaviour"""
        return self._get_serializer(serializer_type="result")

    def _get_enum_choices_for_field(self, field):
        """Return ``{"choices": {value: label}}`` for a choice field.

        Prefers ``field.original_choices`` when the field defines it. Accepts
        either a mapping or an iterable of ``(value, label)`` pairs.
        """
        if hasattr(field, "original_choices"):
            choices = field.original_choices
        else:
            choices = field.choices

        if hasattr(choices, "items"):
            choices = choices.items()

        return {"choices": {key: value for key, value in choices}}

    def _get_link_choices_for_field(self, field, serializer):
        """Return choices generated from the serializer's choices-options link.

        Returns None when the serializer declares no link for this field, when
        the generator yields nothing, or when the link targets another field.
        """
        link = serializer.choices_options_links.get(field.field_name, None)
        if not link:
            return None

        if isinstance(link, dict):
            # presume that this dictionary holds ChoicesOptionsLink kwargs.
            # Build a new dict instead of updating it in place: it belongs to
            # the serializer (possibly shared between subclasses) and must not
            # be modified by schema generation.
            link = ChoicesOptionsLink(
                **{
                    **link,
                    "field_name": field.field_name,
                    "serializer_class": type(serializer),
                }
            )

        choices = self.choice_link_schema_generator.generate_schema(link)
        if choices and link.field_name == field.field_name:
            return choices
        return None

    def _map_serializer_field(
        self, field: Field, direction: str, bypass_extensions=False
    ) -> dict:
        """
        Allows to use @extend_schema_field with a `field` dict so that the
        generated schema gets updated instead of being overridden.

        Also adds ``x-choices`` for choice fields and choices-options links.
        For bulk operations it marks the serializer's ``Meta.id_field`` as
        writable and required.
        """
        serializer_type = "collect" if direction == "request" else "result"
        serializer = self._get_serializer(serializer_type=serializer_type)
        serializer = force_instance(serializer)

        if (
            hasattr(serializer, "Meta")
            and self.is_bulk
            and field.field_name == getattr(serializer.Meta, "id_field", None)
        ):
            field.read_only = False
            field.required = True

        has_annotation = (
            hasattr(field, "_spectacular_annotation")
            and "field" in field._spectacular_annotation
            and isinstance(field._spectacular_annotation["field"], dict)
        )
        if has_annotation:
            annotation = field._spectacular_annotation
            # Hide the dict override from drf-spectacular so the field is mapped
            # normally; the dict is merged into the result further down.
            field._spectacular_annotation = {}

        try:
            result = super()._map_serializer_field(
                field, direction, bypass_extensions=bypass_extensions
            )
        finally:
            # Restore the annotation so later mappings of the same field
            # (other direction, other operation) still see the override.
            if has_annotation:
                field._spectacular_annotation = annotation

        if not isinstance(field, RelatedField) and hasattr(field, "choices"):
            result["x-choices"] = self._get_enum_choices_for_field(field)
            result["description"] = create_choices_enum_description(
                result["x-choices"]["choices"], field.field_name
            )

        if hasattr(serializer, "choices_options_links"):
            choices = self._get_link_choices_for_field(field, serializer)
            if choices:
                result["x-choices"] = choices

        if has_annotation:
            result.update(annotation["field"])

        return result

    def _get_request_for_media_type(self, serializer):
        """Wrap the request schema in an array for bulk operations.

        For a ``BulkSerializerMixin`` serializer on the ``create`` action, the
        schema is ``oneOf`` [array of items, single item].
        """
        schema, request_body_required = super()._get_request_for_media_type(serializer)
        if self.is_bulk:
            schema = build_array_type(schema)
        if (
            isinstance(serializer, BulkSerializerMixin)
            and getattr(self.view, "action", None) == "create"
        ):
            schema = {"oneOf": [build_array_type(schema), schema]}
        return schema, request_body_required

    def _get_response_for_code(self, serializer, status_code, media_types=None):
        """Wrap response schemas in arrays for bulk operations.

        For a ``BulkSerializerMixin`` serializer on the ``create`` action, each
        schema becomes ``oneOf`` [array of items, single item]. Responses
        without a body (no ``content`` key) are left untouched.
        """
        schema_resp = super()._get_response_for_code(
            serializer, status_code, media_types
        )
        # Responses without a body (e.g. 204) carry no "content" key.
        content = schema_resp.get("content", {})

        if self.is_bulk:
            for media_type in content:
                schema = content[media_type]["schema"]
                content[media_type]["schema"] = build_array_type(schema)

        if (
            isinstance(serializer, BulkSerializerMixin)
            and getattr(self.view, "action", None) == "create"
        ):
            for media_type in content:
                schema = content[media_type]["schema"]
                content[media_type]["schema"] = {
                    "oneOf": [build_array_type(schema), schema]
                }
        return schema_resp

    def _build_exclusive_fields_schema(
        self, schema: dict, exclusive_fields: typing.List[str]
    ) -> typing.List[dict]:
        """Return one copy of ``schema`` per exclusive field, with that field removed.

        The removed field is also dropped from ``required``. A ``required``
        list that becomes empty is deleted, because OpenAPI 3.0 forbids empty
        ``required`` arrays. Fields missing from the schema are ignored.
        """
        modified_schemas = []
        for field in exclusive_fields:
            new_schema = deepcopy(schema)
            new_schema.get("properties", {}).pop(field, None)
            if "required" in new_schema:
                required = [name for name in new_schema["required"] if name != field]
                if required:
                    new_schema["required"] = required
                else:
                    del new_schema["required"]
            modified_schemas.append(new_schema)
        return modified_schemas

    def _map_serializer(
        self,
        serializer: typing.Union[
            OpenApiTypes, BaseSerializer, typing.Type[BaseSerializer]
        ],
        direction: str,
        bypass_extensions: bool = False,
    ) -> dict:
        """Map a serializer to a schema, adding ``oneOf`` for exclusive fields.

        For request schemas, every ``ExclusiveFieldsValidator`` contributes
        subschemas, and the result is ``{"oneOf": [...]}`` when any exist.
        """
        serializer = force_instance(serializer)
        serializer_extension = OpenApiSerializerExtension.get_match(serializer)

        if serializer_extension and not bypass_extensions:
            schema = serializer_extension.map_serializer(self, direction)
        else:
            schema = self._map_basic_serializer(serializer, direction)

        if hasattr(serializer, "validators") and direction == "request":
            subschemas = []
            for validator in serializer.validators:
                if isinstance(validator, ExclusiveFieldsValidator):
                    subschemas += self._build_exclusive_fields_schema(
                        schema, validator.fields
                    )

            if subschemas:
                schema = {"oneOf": subschemas}

        extensions = get_override(serializer, "extensions", {})
        if extensions:
            schema.update(sanitize_specification_extensions(extensions))

        return self._postprocess_serializer_schema(schema, serializer, direction)

    def get_operation_id(self):
        """override this for custom behaviour"""
        tokenized_path = self._tokenize_path()

        if self.is_bulk:
            tokenized_path.append("bulk")

        # replace dashes as they can be problematic later in code generation
        tokenized_path = [t.replace("-", "_") for t in tokenized_path]

        if self.method == "GET" and self._is_list_view():
            action = "list"
        else:
            action = self.method_mapping[self.method.lower()]

        if not tokenized_path:
            tokenized_path.append("root")

        if re.search(r"<drf_format_suffix\w*:\w+>", self.path_regex):
            tokenized_path.append("formatted")

        return "_".join(tokenized_path + [action])

    def get_operation(
        self, path, path_regex, path_prefix, method, registry: ComponentRegistry
    ):
        """Flag bulk operations, then delegate to drf-spectacular.

        An operation is bulk when the view's action name contains "bulk" and
        the method is POST, PUT or PATCH.
        """
        # ``action`` may exist but be None (e.g. an unmapped viewset method).
        action = getattr(self.view, "action", None) or ""
        self.is_bulk = "bulk" in action and method in ("POST", "PUT", "PATCH")

        return super().get_operation(path, path_regex, path_prefix, method, registry)