Coverage for audoma/drf/mixins.py: 47%
107 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-08-08 06:13 +0000
« prev ^ index » next coverage.py v6.4.2, created at 2022-08-08 06:13 +0000
1"""
This module overrides the basic mixins provided by Django REST framework.
Mixins defined here should be used instead of DRF's default mixins.
These mixins should be used to allow usage of the extended `get_serializer` method.
6Example:
8 from audoma.drf import mixins
9 from audoma.drf import viewsets
11 class ExampleModelViewSet(
12 mixins.ActionModelMixin,
13 mixins.CreateModelMixin,
14 viewsets.GenericViewSet,
15 ):
16 serializer_class = ExampleModelSerializer
17 queryset = ExampleModel.objects.all()
19"""
21from typing import (
22 Any,
23 Dict,
24 List,
25)
27from rest_framework import (
28 mixins,
29 serializers,
30 status,
31)
32from rest_framework.request import Request
33from rest_framework.response import Response
34from rest_framework.serializers import BaseSerializer
35from rest_framework.settings import api_settings
37from django.core.exceptions import ValidationError
class ActionModelMixin:
    """Helpers for custom view actions that write or read a single object.

    Both helpers build their response with ``get_result_serializer``, so the
    response payload may use a different serializer than the one that
    validated the request.
    """

    def perform_action(
        self,
        request: Request,
        success_status: int = status.HTTP_200_OK,
        instance: Any = None,
        partial: bool = False,
        **kwargs
    ) -> Response:
        """Validate ``request.data``, save it and return the result payload.

        :param request: incoming request whose ``data`` is validated.
        :param success_status: HTTP status code used for the response.
        :param instance: object to update. When ``None`` the serializer is
            built without an instance, so ``save()`` creates a new object.
        :param partial: forwarded to the serializer when ``instance`` is given.
        :param kwargs: accepted for call-site compatibility; not used.
        :return: ``Response`` carrying the result-serializer data, plus a
            ``Location`` header when that data contains the URL field.
        """
        # Compare against None explicitly: a real instance that happens to be
        # falsy (e.g. defines ``__len__``/``__bool__``) must be updated, not
        # silently replaced by a newly created object.
        if instance is not None:
            serializer = self.get_serializer(
                data=request.data, instance=instance, partial=partial
            )
        else:
            serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        return_serializer = self.get_result_serializer(serializer.instance)
        headers = self.get_success_headers(return_serializer.data)

        return Response(return_serializer.data, status=success_status, headers=headers)

    def retrieve_instance(
        self,
        request: Request,
        instance: Any = None,
        success_status: int = status.HTTP_200_OK,
        **kwargs
    ) -> Response:
        """Serialize a single object with the result serializer.

        :param request: incoming request (not read here).
        :param instance: object to serialize; looked up via ``get_object()``
            when ``None``.
        :param success_status: HTTP status code used for the response.
        :param kwargs: accepted for call-site compatibility; not used.
        :raises AssertionError: if no instance could be resolved.
        """
        if instance is None:
            instance = self.get_object()
        # ``get_object`` can return None (e.g. ``BulkUpdateModelMixin.get_object``
        # without a lookup kwarg). Raise explicitly instead of using ``assert``
        # so the check survives ``python -O``; the exception type is unchanged.
        if instance is None:
            raise AssertionError("retrieve_instance() could not resolve an instance")
        serializer = self.get_result_serializer(instance)
        return Response(serializer.data, status=success_status)

    def get_success_headers(self, data: dict) -> dict:
        """Return a ``Location`` header built from the URL field, if present.

        ``data`` lacking the ``api_settings.URL_FIELD_NAME`` key, or not being
        subscriptable by a string, yields an empty header dict.
        """
        try:
            return {"Location": str(data[api_settings.URL_FIELD_NAME])}
        except (TypeError, KeyError):
            return {}
class CreateModelMixin(mixins.CreateModelMixin):
    """Create a model instance and answer with the *result* serializer.

    Unlike DRF's stock ``CreateModelMixin``, the response body comes from
    ``get_result_serializer`` rather than from the serializer that validated
    the request payload.
    """

    def create(self, request: Request, *args, **kwargs) -> Response:
        """Validate ``request.data``, persist it and respond with HTTP 201."""
        incoming = self.get_serializer(data=request.data)
        incoming.is_valid(raise_exception=True)
        self.perform_create(incoming)

        outgoing = self.get_result_serializer(incoming.instance)
        location_headers = self.get_success_headers(outgoing.data)
        return Response(
            outgoing.data,
            status=status.HTTP_201_CREATED,
            headers=location_headers,
        )
class ListModelMixin(mixins.ListModelMixin):
    """List objects through the result serializer, with an optional message.

    Paginated responses always carry a ``"message"`` key: the value returned
    by ``get_list_message()`` when the view defines it, ``None`` otherwise.
    """

    def list(self, request: Request, *args, **kwargs) -> Response:
        """Return the filtered queryset, paginated when a paginator applies."""
        queryset = self.filter_queryset(self.get_queryset())

        page = self.paginate_queryset(queryset)
        if page is None:
            # No pagination: serialize the whole filtered queryset.
            everything = self.get_result_serializer(queryset, many=True)
            return Response(everything.data)

        page_serializer = self.get_result_serializer(page, many=True)
        return self.get_paginated_response(page_serializer.data)

    def get_paginated_response(self, data: List[Dict]) -> Response:
        """Wrap the paginator response and attach the list ``"message"``."""
        response = super().get_paginated_response(data)
        message = None
        if hasattr(self, "get_list_message"):
            assert callable(self.get_list_message)
            message = self.get_list_message()
        response.data["message"] = message
        return response
class RetrieveModelMixin(mixins.RetrieveModelMixin):
    """Retrieve a single object, rendered by the result serializer."""

    def retrieve(self, request: Request, *args, **kwargs) -> Response:
        """Look up the object via ``get_object()`` and return its data."""
        result = self.get_result_serializer(self.get_object())
        return Response(result.data)
class UpdateModelMixin(mixins.UpdateModelMixin):
    """Update an object and answer with the result serializer."""

    def update(self, request: Request, *args, **kwargs) -> Response:
        """Validate ``request.data`` against the current object and save it.

        A ``partial`` keyword argument (default ``False``) is consumed from
        ``kwargs`` and forwarded to the serializer.
        """
        is_partial = kwargs.pop("partial", False)
        target = self.get_object()

        writer = self.get_serializer(target, data=request.data, partial=is_partial)
        writer.is_valid(raise_exception=True)
        self.perform_update(writer)

        # A queryset built with ``prefetch_related`` leaves a stale prefetch
        # cache on the instance; reset it so the response reflects the update.
        if getattr(target, "_prefetched_objects_cache", None):
            target._prefetched_objects_cache = {}

        reader = self.get_result_serializer(writer.instance)
        return Response(reader.data)
class DestroyModelMixin(mixins.DestroyModelMixin):
    """Destroy an object, converting Django validation errors for DRF."""

    def destroy(self, request: Request, *args, **kwargs) -> Response:
        """Delete the looked-up object and respond with HTTP 204.

        :raises serializers.ValidationError: when ``perform_destroy`` raises a
            Django ``ValidationError``; its message(s) are exposed under the
            ``"detail"`` key.
        """
        instance = self.get_object()
        try:
            self.perform_destroy(instance)
        except ValidationError as e:
            # ``ValidationError.message`` only exists for single-message errors;
            # errors built from a list or dict expose ``messages`` instead.
            # ``messages`` also interpolates ``params`` into the text.
            # Single-message errors keep their previous shape (one string).
            detail = e.messages[0] if hasattr(e, "message") else e.messages
            raise serializers.ValidationError({"detail": detail}) from e
        return Response(status=status.HTTP_204_NO_CONTENT)
class BulkCreateModelMixin(CreateModelMixin):
    """
    Create one object, or many at once when the request body is a list.

    A list payload is validated with a ``many=True`` serializer and answered
    through ``get_result_serializer``. Any other payload falls through to the
    single-object ``CreateModelMixin.create``. Both cases share one ``create``
    entry point because both arrive as ``POST`` requests.
    """

    def create(self, request: Request, *args, **kwargs) -> Response:
        """Dispatch to single or bulk creation based on the payload type."""
        if not isinstance(request.data, list):
            return super().create(request, *args, **kwargs)

        many_serializer = self.get_serializer(data=request.data, many=True)
        many_serializer.is_valid(raise_exception=True)
        self.perform_bulk_create(many_serializer)
        created = self.get_result_serializer(many_serializer.instance, many=True)
        return Response(created.data, status=status.HTTP_201_CREATED)

    def perform_bulk_create(self, serializer: BaseSerializer) -> None:
        """Persist a ``many=True`` serializer; delegates to ``perform_create``."""
        self.perform_create(serializer)
class BulkUpdateModelMixin(object):
    """
    Update many model instances at once using ``many=True`` serializers.

    Like the other write paths in this module, responses are rendered with
    ``get_result_serializer``.
    """

    def get_object(self) -> Any:
        """Return the looked-up object, or ``None`` without a lookup kwarg."""
        lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
        if lookup_url_kwarg in self.kwargs:
            return super().get_object()
        # If the lookup_url_kwarg is not present
        # get_object() is most likely called as part of options()
        # which by default simply checks for object permissions
        # and raises permission denied if necessary.
        # Here we don't need to check for general permissions
        # and can simply return None since general permissions
        # are checked in initial() which always gets executed
        # before any of the API actions (e.g. create, update, etc)
        return None

    def bulk_update(self, request: Request, *args, **kwargs) -> Response:
        """Update the filtered queryset from a list payload.

        A ``partial`` keyword argument (default ``False``) is consumed from
        ``kwargs`` and forwarded to the serializer.
        """
        partial = kwargs.pop("partial", False)
        # Restrict the update to the filtered queryset.
        serializer = self.get_serializer(
            self.filter_queryset(self.get_queryset()),
            data=request.data,
            many=True,
            partial=partial,
        )
        serializer.is_valid(raise_exception=True)
        self.perform_bulk_update(serializer)
        # Answer through the result serializer, consistent with create,
        # update and bulk create in this module.
        return_serializer = self.get_result_serializer(serializer.instance, many=True)
        return Response(return_serializer.data, status=status.HTTP_200_OK)

    def partial_bulk_update(self, request: Request, *args, **kwargs) -> Response:
        """Run ``bulk_update`` with ``partial=True``."""
        kwargs["partial"] = True
        return self.bulk_update(request, *args, **kwargs)

    def perform_update(self, serializer: BaseSerializer) -> None:
        """Save the serializer."""
        serializer.save()

    def perform_bulk_update(self, serializer: BaseSerializer) -> None:
        """Persist a ``many=True`` serializer; delegates to ``perform_update``."""
        self.perform_update(serializer)
219# class BulkDestroyModelMixin(object):
220# """
221# Destroy model instances.
222# """
224# def allow_bulk_destroy(self, qs: Any, filtered: Any) -> bool:
225# """
226# Hook to ensure that the bulk destroy should be allowed.
227# By default this checks that the destroy is only applied to
228# filtered querysets.
229# """
230# return qs is not filtered
231# def bulk_destroy(self, request: Request, *args, **kwargs) -> Response:
232# qs = self.get_queryset()
233# filtered = self.filter_queryset(qs)
234# if not self.allow_bulk_destroy(qs, filtered):
235# return Response(status=status.HTTP_400_BAD_REQUEST)
237# self.perform_bulk_destroy(filtered)
238# return Response(status=status.HTTP_204_NO_CONTENT)
239# def perform_destroy(self, instance: object) -> None:
240# instance.delete()
242# def perform_bulk_destroy(self, objects: Any) -> None:
243# for obj in objects:
244# self.perform_destroy(obj)