Coverage for audoma/schema.py: 65%
69 statements
coverage.py v6.4.2, created at 2022-08-08 06:12 +0000
1from collections import namedtuple
3from drf_spectacular.contrib.django_filters import DjangoFilterExtension
4from drf_spectacular.extensions import OpenApiFilterExtension
5from rest_framework.filters import SearchFilter
7from audoma.plumbing import create_choices_enum_description
class AudomaDjangoFilterExtension(DjangoFilterExtension):
    """django-filter schema extension that documents choice filters.

    For a filter that declares choices (either in ``filter_field.extra`` or as
    a ``choices`` attribute), the choices description is rendered into the
    filter's ``help_text`` through a ``{choices}`` placeholder, and each
    generated parameter schema gets an ``x-choices`` vendor key.
    """

    # Extension priority; presumably chosen so this extension wins over
    # drf-spectacular's built-in DjangoFilterExtension — verify.
    priority = 5

    def _get_x_choices(self, choices):
        """Build the ``x-choices`` payload from ``(value, label)`` pairs.

        :param choices: iterable of two-item ``(value, label)`` pairs.
        :returns: ``{"choices": {value: label, ...}}``.
        """
        return {"choices": {key: value for key, value in choices}}

    def _render_choices_help_text(self, filter_field, field_name, choices):
        """Write the choices description into ``filter_field.extra["help_text"]``.

        A missing (or ``None``) help text defaults to the bare ``{choices}``
        placeholder, so the description alone is used.
        """
        help_text = filter_field.extra.get("help_text")
        if help_text is None:
            help_text = "{choices}"
        help_text = str(help_text)
        # ``filter_field.extra`` is mutated in place, so this can run again for
        # the same filter (e.g. when several views share a filterset). Only
        # format while the placeholder is still present; re-formatting an
        # already rendered text would raise on any literal braces it contains.
        if "{choices}" in help_text:
            help_text = help_text.format(
                choices=create_choices_enum_description(choices, field_name)
            )
        filter_field.extra["help_text"] = help_text

    def resolve_filter_field(
        self, auto_schema, model, filterset_class, field_name, filter_field
    ):
        """Resolve parameter schemas for ``filter_field``, documenting choices.

        Delegates to ``DjangoFilterExtension.resolve_filter_field`` after the
        help text has been rendered, then adds ``x-choices`` to every returned
        parameter's ``schema`` unless it is already present there.

        :returns: list of parameter dicts produced by the parent extension.
        """
        choices = filter_field.extra.get("choices", None) or getattr(
            filter_field, "choices", None
        )
        if choices:
            self._render_choices_help_text(filter_field, field_name, choices)

        schemas = list(
            super().resolve_filter_field(
                auto_schema, model, filterset_class, field_name, filter_field
            )
        )
        if choices:
            for parameter in schemas:
                param_schema = parameter.get("schema")
                # Look for an existing key in the nested schema — the place
                # it is written — rather than in the parameter dict itself.
                if isinstance(param_schema, dict) and "x-choices" not in param_schema:
                    param_schema["x-choices"] = self._get_x_choices(choices)
        return schemas
class SearchFilterExtension(OpenApiFilterExtension):
    """Schema extension describing which fields DRF's ``SearchFilter`` searches.

    The stock ``search`` parameter description is replaced by a Markdown list
    of the view's ``search_fields``. Related lookups (``author__name``) are
    grouped under their first path segment, and lookup-prefix characters
    (``^``, ``=``, ``@``, ``$``) are annotated with their meaning. A
    non-default description is kept and formatted with a ``{search_fields}``
    placeholder.
    """

    target_class = SearchFilter

    # Leading lookup-prefix characters mapped to the note appended to the
    # field name in the generated description.
    SEARCH_PARAMS = {
        "^": "Starts-with search.",
        "=": "Exact matches.",
        "@": "Full-text search.",
        "$": "Regex search.",
    }

    def _create_search_fields_description(self, fields):
        """Render the nested part of related search fields as Markdown items.

        :param fields: search fields sharing the same first path segment.
        :returns: one string per field that contains ``__``; every segment
            after the first becomes a ``* `segment` `` item, indented one
            level deeper than the previous one. Fields without ``__`` are
            skipped because the caller already lists their name.
        """
        transformed_fields = []
        for field in fields:
            _, *nested = field.split("__")
            if not nested:
                continue
            transformed_fields.append(
                "".join(
                    "\t " * depth + f"* `{partial}` \n"
                    for depth, partial in enumerate(nested, start=1)
                )
            )
        return transformed_fields

    def _preprocess_fields(self, fields):
        """Return a new list of ``fields`` with lookup prefixes made readable.

        A leading prefix character from ``SEARCH_PARAMS`` is stripped and its
        description appended, e.g. ``"^name"`` -> ``"name(Starts-with search.)"``.

        The input is never modified: it is the view's own ``search_fields``,
        which the filter still needs verbatim at request time, and it may be
        a tuple.
        """
        processed = []
        for field in fields:
            description = self.SEARCH_PARAMS.get(field[:1])
            if description is not None:
                field = f"{field[1:]}({description})"
            processed.append(field)
        return processed

    def _get_processed_search_fields(self, fields):
        """Group search fields by their first ``__``-separated segment.

        :returns: ``{head: [nested description, ...]}`` in first-seen order.
        """
        groups = {}
        for field in self._preprocess_fields(fields):
            # Compare whole segments: a plain ``startswith`` would fold
            # ``username`` into the ``user`` group and drop it from the docs.
            groups.setdefault(field.split("__")[0], []).append(field)
        return {
            head: self._create_search_fields_description(members)
            for head, members in groups.items()
        }

    def _get_custom_serach_filter_description(self, view):
        """Build the Markdown "Search by" list for ``view``'s search fields."""
        # A view without ``search_fields`` yields an empty list instead of
        # raising AttributeError.
        search_fields = getattr(view, "search_fields", None) or ()
        parts = ["Search by: \n"]
        for field_name, nested in self._get_processed_search_fields(
            search_fields
        ).items():
            parts.append(f"* `{field_name}` \n" + "".join(nested))
        return "".join(parts)

    def get_schema_operation_parameters(self, schema):
        """Return the ``search`` query parameter with a field-list description.

        :param schema: the schema generator; its ``view`` supplies the search
            fields.
        :returns: a one-item list holding the target filter's first parameter,
            whose description has ``{search_fields}`` filled in with the
            generated list.
        """
        view = schema.view
        result = self.target.get_schema_operation_parameters(view)[0]
        description = result.get("description", "{search_fields}")
        # SearchFilter's stock description is replaced by the generated list;
        # any other description is kept as a template.
        if description == self.target_class.search_description:
            description = "{search_fields}"
        # NOTE(review): a custom description containing other, unescaped
        # braces will make str.format raise.
        result["description"] = description.format(
            search_fields=self._get_custom_serach_filter_description(view)
        )
        return [result]