Coverage for audoma/openapi.py: 89%
237 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-08-08 06:12 +0000
« prev ^ index » next coverage.py v6.4.2, created at 2022-08-08 06:12 +0000
1import re
2import typing
3from copy import deepcopy
4from inspect import isclass
6from drf_spectacular.drainage import get_override
7from drf_spectacular.extensions import OpenApiSerializerExtension
8from drf_spectacular.openapi import AutoSchema
9from drf_spectacular.plumbing import (
10 ComponentRegistry,
11 build_array_type,
12 error,
13 force_instance,
14 sanitize_specification_extensions,
15)
16from drf_spectacular.types import OpenApiTypes
17from drf_spectacular.utils import OpenApiResponse
18from rest_framework.fields import Field
19from rest_framework.generics import GenericAPIView
20from rest_framework.permissions import (
21 AND,
22 OR,
23 BasePermission,
24 OperandHolder,
25 SingleOperandHolder,
26)
27from rest_framework.relations import RelatedField
28from rest_framework.serializers import BaseSerializer
29from rest_framework.views import APIView
31from django.views import View
33from audoma.drf.generics import GenericAPIView as AudomaGenericAPIView
34from audoma.drf.serializers import BulkSerializerMixin
35from audoma.drf.validators import ExclusiveFieldsValidator
36from audoma.links import (
37 ChoicesOptionsLink,
38 ChoicesOptionsLinkSchemaGenerator,
39)
40from audoma.plumbing import create_choices_enum_description
43class AudomaAutoSchema(AutoSchema):
44 choice_link_schema_generator = ChoicesOptionsLinkSchemaGenerator()
def _handle_permission(
    self,
    permission_class: typing.Union[
        OperandHolder, SingleOperandHolder, BasePermission
    ],
    operations: list,
    current_operation: typing.Type = AND,
) -> dict:
    """
    Recursively collect descriptions for a (possibly composed) permission.

    An ``OperandHolder`` has both operands visited, joined by an operator
    token; a ``SingleOperandHolder`` has its single operand visited; any
    other object is treated as a plain permission class and described.

    Args:
        permission_class: permission class or operand holder to describe.
        operations: list of expression tokens, extended in place.
        current_operation: operator class of the enclosing holder; an OR
            group is wrapped in parentheses only when it is not already
            nested directly inside another OR.

    Returns:
        dict mapping permission class name to its description text.
    """
    permissions = {}

    if isinstance(permission_class, OperandHolder):
        operator = permission_class.operator_class
        needs_parentheses = operator == OR and current_operation != OR

        if needs_parentheses:
            operations.append("(")
        permissions.update(
            self._handle_permission(
                permission_class.op1_class, operations, operator
            )
        )
        if operator == OR:
            operations.append("|")
        elif operator == AND:
            operations.append(" & ")
        permissions.update(
            self._handle_permission(
                permission_class.op2_class, operations, operator
            )
        )
        if needs_parentheses:
            operations.append(" )")
    elif isinstance(permission_class, SingleOperandHolder):
        # NOTE(review): no token is emitted for the unary operator itself,
        # so the rendered expression does not show it - confirm intended.
        permissions.update(
            self._handle_permission(
                permission_class.op1_class,
                operations,
                permission_class.operator_class,
            )
        )
    else:
        name = permission_class.__name__
        try:
            # Store the description itself. A trailing comma here used to
            # wrap it in a 1-tuple, which rendered as "('text',)".
            permissions[name] = permission_class.get_description(self.view)
        except AttributeError:
            # No usable get_description: fall back to the docstring, then
            # to a fixed placeholder.
            doc = permission_class.__doc__
            if doc:
                permissions[name] = doc.replace("\n", " ").strip()
            else:
                permissions[name] = "(No description for this permission)"
        operations.append(f"`{name}`")

    return permissions
106 def _get_permissions_description(self) -> str:
107 permissions = {}
108 operations = []
110 for permission_class in getattr(self.view, "permission_classes", []):
111 if operations:
112 operations.append("&")
113 permissions.update(self._handle_permission(permission_class, operations))
115 if permissions:
116 return (
117 "\n\n**Permissions:**\n"
118 + " ".join(operations)
119 + "\n"
120 + "\n".join(
121 f"+ `{name}`: *{doc_str}*" for name, doc_str in permissions.items()
122 )
123 )
124 else:
125 return ""
def get_description(self) -> str:
    """
    Return the parent schema description followed by the permissions section.

    A falsy parent description is treated as an empty string.
    """
    base_description = super().get_description()
    if not base_description:
        base_description = ""
    return base_description + self._get_permissions_description()
132 def _extract_action_function(self, view) -> typing.Callable:
133 action = getattr(view, "action", None)
134 if not action:
135 return
137 return getattr(view, action, None)
139 def _parse_action_serializers(self, action_serializers) -> dict:
140 if not action_serializers:
141 return action_serializers
143 if isinstance(action_serializers, str):
144 return {"default": OpenApiResponse(description=action_serializers)}
146 if not isinstance(action_serializers, dict):
147 return {"default": action_serializers}
149 parsed_action_serializers = deepcopy(action_serializers)
151 for method, method_serializers in action_serializers.items():
152 if isinstance(method_serializers, str):
153 parsed_action_serializers[method] = OpenApiResponse(
154 description=method_serializers
155 )
156 elif isinstance(method_serializers, dict):
157 for code, item in method_serializers.items():
158 if isinstance(item, str):
159 parsed_action_serializers[method][code] = OpenApiResponse(
160 description=item
161 )
163 return parsed_action_serializers
165 def _parse_action_errors(self, action_errors) -> dict:
166 if not action_errors:
167 return action_errors
169 parsed_errors = {}
170 for err in action_errors:
171 if isclass(err):
172 err = err()
174 # build properties
175 properties = {}
176 for key, value in vars(err).items():
177 properties[key] = {key: {"type": type(value).__name__}}
179 parsed_errors[err.status_code] = OpenApiResponse(
180 response={
181 "type": "object",
182 "properties": properties,
183 "example": vars(err),
184 }
185 )
186 return parsed_errors
188 def _extract_audoma_action_operations(
189 self, view: View, serializer_type: str
190 ) -> dict:
191 """
192 Extracts the audoma action operations from the view
193 """
194 action_function = self._extract_action_function(view)
195 _audoma = getattr(action_function, "_audoma", None)
196 if not _audoma:
197 return {}
199 if serializer_type == "collect":
200 action_serializers = getattr(_audoma, "collectors", None)
201 else:
202 results = self._parse_action_serializers(getattr(_audoma, "results", None))
203 errors = self._parse_action_errors(getattr(_audoma, "errors", {}))
204 if results:
205 action_serializers = results
206 action_serializers.update(errors)
207 else:
208 action_serializers = errors
210 return action_serializers
212 def _get_serializer( # noqa: C901
213 self, serializer_type="collect"
214 ) -> typing.Union[BaseSerializer, typing.Type[BaseSerializer]]:
215 view = self.view
216 method = view.request.method
218 action_serializers = self._extract_audoma_action_operations(
219 view, serializer_type
220 )
221 if action_serializers:
222 if (
223 isinstance(action_serializers, dict)
224 and method.lower() in action_serializers
225 ):
226 return action_serializers[method.lower()]
227 else:
228 return action_serializers
230 try:
231 if isinstance(view, AudomaGenericAPIView):
232 return view.get_serializer_class(type=serializer_type)()
233 elif isinstance(view, GenericAPIView):
234 # try to circumvent queryset issues with calling get_serializer. if view has NOT
235 # overridden get_serializer, its safe to use get_serializer_class.
236 if view.__class__.get_serializer == GenericAPIView.get_serializer:
237 return view.get_serializer_class()()
238 return view.get_serializer()
239 elif isinstance(view, APIView):
240 # APIView does not implement the required interface, but be lenient and make
241 # good guesses before giving up and emitting a warning.
242 if callable(getattr(view, "get_serializer", None)):
243 return view.get_serializer()
244 elif callable(getattr(view, "get_serializer_class", None)):
245 return view.get_serializer_class()()
246 elif hasattr(view, "serializer_class"):
247 return view.serializer_class
248 else:
249 error(
250 "unable to guess serializer. This is graceful "
251 "fallback handling for APIViews. Consider using GenericAPIView as view base "
252 "class, if view is under your control. Ignoring view for now. "
253 )
254 else:
255 error(
256 "Encountered unknown view base class. Please report this issue. Ignoring for now"
257 )
258 except Exception as exc:
259 error(
260 f"exception raised while getting serializer. Hint: "
261 f"Is get_serializer_class() returning None or is get_queryset() not working without "
262 f"a request? Ignoring the view for now. (Exception: {exc})"
263 )
def get_response_serializers(
    self,
) -> typing.Union[BaseSerializer, typing.Type[BaseSerializer]]:
    """Return the serializer(s) resolved for the "result" serializer type."""
    response_serializers = self._get_serializer(serializer_type="result")
    return response_serializers
271 def _get_enum_choices_for_field(self, field):
272 if hasattr(field, "original_choices"):
273 choices = field.original_choices
274 else:
275 choices = field.choices
277 if hasattr(choices, "items"):
278 choices = choices.items()
280 return {"choices": {key: value for key, value in choices}}
282 def _get_link_choices_for_field(self, field, serializer):
283 link = serializer.choices_options_links.get(field.field_name, None)
284 if not link:
285 return
287 if isinstance(link, dict):
288 # presume that this dictionary are link kwargs
289 link.update(
290 {"field_name": field.field_name, "serializer_class": type(serializer)}
291 )
292 link = ChoicesOptionsLink(**link)
294 choices = self.choice_link_schema_generator.generate_schema(link)
295 if choices:
296 if link.field_name == field.field_name:
297 return choices
298 return
def _map_serializer_field(
    self, field: Field, direction: str, bypass_extensions=False
) -> dict:
    """
    Map a serializer field to its schema, with audoma additions.

    * A dict given as the ``field`` override (``@extend_schema_field``) is
      merged into the generated schema instead of replacing it.
    * In bulk operations the serializer's ``Meta.id_field`` is made
      writable and required.
    * Fields with choices get ``x-choices`` and a generated description; a
      choices link declared on the serializer overrides ``x-choices``.
    """
    serializer = force_instance(
        self._get_serializer(
            serializer_type="collect" if direction == "request" else "result"
        )
    )

    is_bulk_id_field = (
        hasattr(serializer, "Meta")
        and self.is_bulk
        and field.field_name == getattr(serializer.Meta, "id_field", None)
    )
    if is_bulk_id_field:
        field.read_only = False
        field.required = True

    # Hide a dict-valued "field" override from the parent implementation,
    # so the parent builds the regular schema; it is merged back below.
    # NOTE(review): the cleared annotation is not restored on the field
    # afterwards - confirm a field instance is never mapped twice.
    annotation = None
    if hasattr(field, "_spectacular_annotation"):
        candidate = field._spectacular_annotation
        if "field" in candidate and isinstance(candidate["field"], dict):
            annotation = candidate
            field._spectacular_annotation = {}

    result = super()._map_serializer_field(
        field, direction, bypass_extensions=bypass_extensions
    )

    # the RelatedField test must run first, so ``choices`` is never looked
    # up on related fields
    if not isinstance(field, RelatedField) and hasattr(field, "choices"):
        enum_choices = self._get_enum_choices_for_field(field)
        result["x-choices"] = enum_choices
        result["description"] = create_choices_enum_description(
            enum_choices["choices"], field.field_name
        )

    if hasattr(serializer, "choices_options_links"):
        link_choices = self._get_link_choices_for_field(field, serializer)
        if link_choices:
            result["x-choices"] = link_choices

    if annotation is not None:
        result.update(annotation["field"])

    return result
def _get_request_for_media_type(self, serializer):
    """
    Build the request schema, adapted for bulk operations.

    Returns:
        ``(schema, request_body_required)`` as produced by the parent
        class, with the schema wrapped in an array for bulk operations, or
        turned into "array or single object" for the ``create`` action of a
        ``BulkSerializerMixin`` serializer.
    """
    schema, request_body_required = super()._get_request_for_media_type(serializer)

    if self.is_bulk:
        schema = build_array_type(schema)

    # ``action`` is read with getattr, as get_operation does, so a view
    # without that attribute does not raise AttributeError here.
    if (
        isinstance(serializer, BulkSerializerMixin)
        and getattr(self.view, "action", None) == "create"
    ):
        schema = {"oneOf": [build_array_type(schema), schema]}

    return schema, request_body_required
def _get_response_for_code(self, serializer, status_code, media_types=None):
    """
    Build the response object for ``status_code``, adapted for bulk operations.

    For bulk operations every media type schema is wrapped in an array. For
    the ``create`` action of a ``BulkSerializerMixin`` serializer every
    schema becomes "array or single object".
    """
    schema_resp = super()._get_response_for_code(
        serializer, status_code, media_types
    )

    # A response without a body has no "content" entry (presumably e.g.
    # 204 responses), so indexing it directly would raise KeyError.
    content = schema_resp.get("content", {})

    if self.is_bulk:
        for media in content.values():
            media["schema"] = build_array_type(media["schema"])

    # ``action`` is read with getattr, as get_operation does, so a view
    # without that attribute does not raise AttributeError here.
    if (
        isinstance(serializer, BulkSerializerMixin)
        and getattr(self.view, "action", None) == "create"
    ):
        for media in content.values():
            schema = media["schema"]
            media["schema"] = {"oneOf": [build_array_type(schema), schema]}

    return schema_resp
376 def _build_exclusive_fields_schema(
377 self, schema: dict, exclusive_fields: typing.List[str]
378 ) -> typing.List[dict]:
379 modified_schemas = []
380 for field in exclusive_fields:
381 new_schema = deepcopy(schema)
382 new_schema["properties"].pop(field)
383 modified_schemas.append(new_schema)
384 return modified_schemas
def _map_serializer(
    self,
    serializer: typing.Union[
        OpenApiTypes, BaseSerializer, typing.Type[BaseSerializer]
    ],
    direction: str,
    bypass_extensions: bool = False,
) -> dict:
    """
    Map a serializer to its schema.

    The schema comes from a matching serializer extension (unless
    ``bypass_extensions`` is set) or from the basic serializer mapping. For
    the "request" direction, every ``ExclusiveFieldsValidator`` on the
    serializer contributes schema variants that replace the schema with a
    ``oneOf``. Specification extensions set as an override on the
    serializer are merged in before post-processing.
    """
    serializer = force_instance(serializer)
    extension = OpenApiSerializerExtension.get_match(serializer)

    if not extension or bypass_extensions:
        schema = self._map_basic_serializer(serializer, direction)
    else:
        schema = extension.map_serializer(self, direction)

    if hasattr(serializer, "validators") and direction == "request":
        exclusive_variants = []
        for validator in serializer.validators:
            if not isinstance(validator, ExclusiveFieldsValidator):
                continue
            exclusive_variants.extend(
                self._build_exclusive_fields_schema(schema, validator.fields)
            )

        if exclusive_variants:
            schema = {"oneOf": exclusive_variants}

    extensions = get_override(serializer, "extensions", {})
    if extensions:
        schema.update(sanitize_specification_extensions(extensions))

    return self._postprocess_serializer_schema(schema, serializer, direction)
def get_operation_id(self):
    """
    Build the operation id from the tokenized path and the action name.

    The id is the underscore-joined path tokens (plus "bulk" for bulk
    operations, "root" for an empty path and "formatted" for format-suffix
    routes) followed by the action. Override this for custom behaviour.
    """
    raw_tokens = self._tokenize_path()

    if self.is_bulk:
        raw_tokens.append("bulk")

    # replace dashes as they can be problematic later in code generation
    parts = [token.replace("-", "_") for token in raw_tokens]

    is_listing = self.method == "GET" and self._is_list_view()
    action = "list" if is_listing else self.method_mapping[self.method.lower()]

    if not parts:
        parts.append("root")

    if re.search(r"<drf_format_suffix\w*:\w+>", self.path_regex):
        parts.append("formatted")

    parts.append(action)
    return "_".join(parts)
def get_operation(
    self, path, path_regex, path_prefix, method, registry: ComponentRegistry
):
    """
    Flag bulk operations, then delegate to the parent implementation.

    ``self.is_bulk`` is set to True when the view's action name contains
    "bulk" and ``method`` is POST, PUT or PATCH, otherwise to False. It is
    set before delegating so the other schema hooks can read it.
    """
    # ``action`` may be missing or None; a None value previously raised
    # TypeError in the ``in`` test because the getattr default only covers
    # a missing attribute.
    action = getattr(self.view, "action", None) or ""
    self.is_bulk = "bulk" in action and method in ("POST", "PUT", "PATCH")

    return super().get_operation(path, path_regex, path_prefix, method, registry)