Coverage for audoma/schema.py: 65%

69 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-08-08 06:12 +0000

1from collections import namedtuple 

2 

3from drf_spectacular.contrib.django_filters import DjangoFilterExtension 

4from drf_spectacular.extensions import OpenApiFilterExtension 

5from rest_framework.filters import SearchFilter 

6 

7from audoma.plumbing import create_choices_enum_description 

8 

9 

class AudomaDjangoFilterExtension(DjangoFilterExtension):
    """``DjangoFilterExtension`` that documents the choices of choice filters.

    For every filter that defines ``choices`` (in ``extra`` or as an
    attribute), the generated parameter gets:

    * a help text in which the ``{choices}`` placeholder (the default when no
      help text is set) is replaced by an enum description, and
    * an ``x-choices`` entry in its schema mapping each choice value to its
      label.
    """

    # Extension priority consumed by drf-spectacular; presumably chosen so
    # this extension wins over the stock DjangoFilterExtension — confirm.
    priority = 5

    def _get_x_choices(self, choices: namedtuple):
        """Return the ``x-choices`` payload ``{"choices": {value: label}}``.

        :param choices: iterable of ``(value, label)`` pairs.
        """
        return {"choices": {key: value for key, value in choices}}

    def resolve_filter_field(
        self, auto_schema, model, filterset_class, field_name, filter_field
    ):
        """Resolve ``filter_field`` like the parent, adding choice docs.

        :returns: list of OpenAPI parameter dicts produced by the parent
            extension, with ``x-choices`` added to each parameter's nested
            ``schema`` when the filter has choices and none is set yet.
        """
        choices = filter_field.extra.get("choices", None) or getattr(
            filter_field, "choices", None
        )
        if choices:
            # The parent extension reads the help text from ``extra``, so it
            # is written back there. A literal placeholder substitution
            # (rather than ``str.format``) leaves other braces in user text
            # intact and is a no-op once the placeholder has been replaced,
            # since this ``extra`` dict is reused across schema generations.
            help_text = filter_field.extra.get("help_text", "{choices}")
            if "{choices}" in help_text:
                help_text = help_text.replace(
                    "{choices}",
                    create_choices_enum_description(choices, field_name),
                )
            filter_field.extra["help_text"] = help_text

        schemas = super().resolve_filter_field(
            auto_schema, model, filterset_class, field_name, filter_field
        )

        parsed_schemas = []
        for parameter in schemas:
            parameter_schema = parameter.get("schema")
            # ``x-choices`` lives in the nested parameter schema, so the
            # "already present" check must look there as well.
            if (
                choices
                and isinstance(parameter_schema, dict)
                and "x-choices" not in parameter_schema
            ):
                parameter_schema["x-choices"] = self._get_x_choices(choices)
            parsed_schemas.append(parameter)

        return parsed_schemas

41 

42 

class SearchFilterExtension(OpenApiFilterExtension):
    """Describe ``SearchFilter``'s parameter with the view's search fields.

    DRF's generic search-parameter description is replaced by a Markdown list
    of the view's ``search_fields``. Related lookups (``a__b``) become nested
    bullets, and DRF lookup prefixes are turned into readable notes.
    """

    target_class = SearchFilter

    # DRF SearchFilter lookup prefixes and their human-readable descriptions.
    SEARCH_PARAMS = {
        "^": "Starts-with search.",
        "=": "Exact matches.",
        "@": "Full-text search.",
        "$": "Regex search.",
    }

    def _create_search_fields_description(self, fields):
        """Return nested Markdown bullets for the related parts of ``fields``.

        Each field is split on ``__``. Every part after the first becomes a
        bullet, indented one level deeper than the previous one. Fields
        without ``__`` yield nothing, because the caller already lists their
        root name.
        """
        transformed_fields = []
        for field in fields:
            partials = field.split("__")
            if len(partials) < 2:
                continue
            transformed_fields.append(
                "".join(
                    "\t " * depth + f"* `{partial}` \n"
                    for depth, partial in enumerate(partials[1:], start=1)
                )
            )
        return transformed_fields

    def _preprocess_fields(self, fields):
        """Return a new list of ``fields`` with lookup prefixes made readable.

        A leading DRF lookup prefix is stripped and its description appended,
        e.g. ``"^name"`` -> ``"name(Starts-with search.)"``. The input is not
        modified. It is normally the view's own ``search_fields``, which
        ``SearchFilter`` still needs intact at request time and which may be
        an immutable tuple.
        """
        processed = []
        for field in fields:
            prefix = field[:1]
            if prefix in self.SEARCH_PARAMS:
                field = f"{field[1:]}({self.SEARCH_PARAMS[prefix]})"
            processed.append(field)
        return processed

    def _get_processed_search_fields(self, fields):
        """Group search fields by root name (the text before the first ``__``).

        :returns: insertion-ordered dict mapping each root name to the list of
            nested bullet strings for its related lookups.
        """
        grouped = {}
        for field in self._preprocess_fields(fields):
            # Exact root comparison: a prefix test would wrongly merge e.g.
            # ``name_full`` into ``name`` and drop it from the description.
            group = grouped.setdefault(field.split("__")[0], [])
            if field not in group:
                group.append(field)
        return {
            root: self._create_search_fields_description(group)
            for root, group in grouped.items()
        }

    def _get_custom_serach_filter_description(self, view):
        """Build the Markdown "Search by" list for ``view``'s search fields."""
        # Mirrors DRF's default SearchFilter.get_search_fields lookup; a view
        # without search fields yields just the heading.
        search_fields = getattr(view, "search_fields", None) or ()
        processed_fields = self._get_processed_search_fields(search_fields)
        parts = ["Search by: \n"]
        for field_name, field_description in processed_fields.items():
            parts.append(f"* `{field_name}` \n")
            parts.extend(field_description)
        return "".join(parts)

    def get_schema_operation_parameters(self, schema):
        """Return the filter's parameters with an enriched search description.

        DRF's stock description is replaced by the search-field list. A custom
        description is kept verbatim, except that a ``{search_fields}``
        placeholder in it is replaced by that list.
        """
        view = schema.view
        parameters = self.target.get_schema_operation_parameters(view)
        if not parameters:
            return []

        result = parameters[0]
        description = result.get("description", "{search_fields}")
        # Only SearchFilter's own default text is overwritten outright.
        if description == self.target_class.search_description:
            description = "{search_fields}"
        # Literal substitution, so other braces in custom text do not raise.
        if "{search_fields}" in description:
            description = description.replace(
                "{search_fields}",
                self._get_custom_serach_filter_description(view),
            )
        result["description"] = description
        # Keep any additional parameters a SearchFilter subclass may emit.
        return [result, *parameters[1:]]