Coverage for audoma/drf/mixins.py: 47%

107 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-08-08 06:13 +0000

1""" 

2This module overrides the basic mixins provided by Django REST framework. 

3Mixins defined here should be used instead of DRF's default mixins. 

4These mixins let views use the extended `get_serializer` method, and build response data with `get_result_serializer`. 

5 

6Example: 

7 

8 from audoma.drf import mixins 

9 from audoma.drf import viewsets 

10 

11 class ExampleModelViewSet( 

12 mixins.ActionModelMixin, 

13 mixins.CreateModelMixin, 

14 viewsets.GenericViewSet, 

15 ): 

16 serializer_class = ExampleModelSerializer 

17 queryset = ExampleModel.objects.all() 

18 

19""" 

20 

21from typing import ( 

22 Any, 

23 Dict, 

24 List, 

25) 

26 

27from rest_framework import ( 

28 mixins, 

29 serializers, 

30 status, 

31) 

32from rest_framework.request import Request 

33from rest_framework.response import Response 

34from rest_framework.serializers import BaseSerializer 

35from rest_framework.settings import api_settings 

36 

37from django.core.exceptions import ValidationError 

38 

39 

class ActionModelMixin:
    """
    Helpers for custom viewset actions that write or read a single object.

    The methods call ``self.get_serializer``, ``self.get_result_serializer``
    and ``self.get_object``, none of which is defined here; they are
    presumably supplied by the audoma generic view/viewset this mixin is
    combined with (TODO confirm).
    """

    def perform_action(
        self,
        request: Request,
        success_status: int = status.HTTP_200_OK,
        instance: Any = None,
        partial: bool = False,
        **kwargs
    ) -> Response:
        """
        Validate ``request.data``, save it, and respond with the result serializer.

        Args:
            request: incoming request whose ``data`` is validated and saved.
            success_status: HTTP status code of the successful response.
            instance: object to update; when ``None`` the serializer is built
                without an instance (create path).
            partial: allow a partial update; only used when ``instance`` is given.
            **kwargs: accepted for signature compatibility; not used.

        Returns:
            Response carrying the result serializer's data, plus a ``Location``
            header when the data contains the configured URL field.

        Raises:
            rest_framework.exceptions.ValidationError: if the data is invalid
                (raised by ``is_valid(raise_exception=True)``).
        """
        # Compare against None (as retrieve_instance does) instead of relying on
        # truthiness: an empty queryset/list, or an object whose __bool__/__len__
        # is falsy, is still an instance to update, not a request to create.
        if instance is not None:
            serializer = self.get_serializer(
                data=request.data, instance=instance, partial=partial
            )
        else:
            serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        return_serializer = self.get_result_serializer(serializer.instance)
        headers = self.get_success_headers(return_serializer.data)

        return Response(return_serializer.data, status=success_status, headers=headers)

    def retrieve_instance(
        self,
        request: Request,
        instance: Any = None,
        success_status: int = status.HTTP_200_OK,
        **kwargs
    ) -> Response:
        """
        Respond with the result-serializer representation of a single object.

        Args:
            request: the incoming request (not read by this method).
            instance: object to serialize; when ``None`` it is fetched with
                ``self.get_object()``.
            success_status: HTTP status code of the response.
            **kwargs: accepted for signature compatibility; not used.

        Returns:
            Response with the result serializer's data.
        """
        if instance is None:
            instance = self.get_object()
        # get_object() can be overridden to return None (BulkUpdateModelMixin
        # does so when no lookup kwarg is present); refuse to serialize None.
        assert instance is not None
        serializer = self.get_result_serializer(instance)
        return Response(serializer.data, status=success_status)

    def get_success_headers(self, data: dict) -> dict:
        """
        Build response headers containing a ``Location`` entry, if possible.

        The location is taken from the field named by
        ``api_settings.URL_FIELD_NAME``. An empty dict is returned when that
        key is missing (``KeyError``) or ``data`` cannot be indexed by a string
        key, e.g. a list or ``None`` (``TypeError``).
        """
        try:
            return {"Location": str(data[api_settings.URL_FIELD_NAME])}
        except (TypeError, KeyError):
            return {}

80 

81 

class CreateModelMixin(mixins.CreateModelMixin):
    """DRF ``CreateModelMixin`` that renders its response with the result serializer."""

    def create(self, request: Request, *args, **kwargs) -> Response:
        """
        Validate and create one object, then respond with HTTP 201.

        The input is validated with ``get_serializer``; the response body is
        produced by ``get_result_serializer`` from the saved instance.
        """
        input_serializer = self.get_serializer(data=request.data)
        input_serializer.is_valid(raise_exception=True)
        self.perform_create(input_serializer)

        output_serializer = self.get_result_serializer(input_serializer.instance)
        location_headers = self.get_success_headers(output_serializer.data)
        return Response(
            output_serializer.data,
            status=status.HTTP_201_CREATED,
            headers=location_headers,
        )

92 

93 

class ListModelMixin(mixins.ListModelMixin):
    """DRF ``ListModelMixin`` that serializes items with the result serializer."""

    def list(self, request: Request, *args, **kwargs) -> Response:
        """
        List the filtered queryset, paginated when a paginator is configured.

        Each item is rendered with ``get_result_serializer(..., many=True)``.
        """
        filtered = self.filter_queryset(self.get_queryset())
        page = self.paginate_queryset(filtered)

        # No paginator (or pagination disabled): return the whole list.
        if page is None:
            everything = self.get_result_serializer(filtered, many=True)
            return Response(everything.data)

        page_serializer = self.get_result_serializer(page, many=True)
        return self.get_paginated_response(page_serializer.data)

    def get_paginated_response(self, data: List[Dict]) -> Response:
        """
        Build the paginated response and attach a ``"message"`` entry to it.

        The message comes from ``self.get_list_message()`` when the view
        defines it (it must be callable), and is ``None`` otherwise.
        """
        response = super().get_paginated_response(data)
        message = None
        if hasattr(self, "get_list_message"):
            assert callable(self.get_list_message)
            message = self.get_list_message()
        response.data["message"] = message
        return response

114 

115 

class RetrieveModelMixin(mixins.RetrieveModelMixin):
    """DRF ``RetrieveModelMixin`` that renders the object with the result serializer."""

    def retrieve(self, request: Request, *args, **kwargs) -> Response:
        """Return the object found by ``get_object`` via ``get_result_serializer``."""
        found = self.get_object()
        return Response(self.get_result_serializer(found).data)

121 

122 

class UpdateModelMixin(mixins.UpdateModelMixin):
    """DRF ``UpdateModelMixin`` that renders its response with the result serializer."""

    def update(self, request: Request, *args, **kwargs) -> Response:
        """
        Update the object found by ``get_object`` with ``request.data``.

        ``kwargs["partial"]`` (default ``False``) selects a partial update.
        The response body comes from ``get_result_serializer``.
        """
        is_partial = kwargs.pop("partial", False)
        target = self.get_object()
        write_serializer = self.get_serializer(
            target, data=request.data, partial=is_partial
        )
        write_serializer.is_valid(raise_exception=True)
        self.perform_update(write_serializer)

        # A queryset using prefetch_related leaves a cache on the instance that
        # would now be stale; clear it before building the response.
        if getattr(target, "_prefetched_objects_cache", None):
            target._prefetched_objects_cache = {}

        read_serializer = self.get_result_serializer(write_serializer.instance)
        return Response(read_serializer.data)

137 

138 

class DestroyModelMixin(mixins.DestroyModelMixin):
    """DRF ``DestroyModelMixin`` that turns Django validation errors into HTTP 400."""

    def destroy(self, request: Request, *args, **kwargs) -> Response:
        """
        Delete the object found by ``get_object`` and respond with HTTP 204.

        Raises:
            rest_framework.serializers.ValidationError: when ``perform_destroy``
                raises ``django.core.exceptions.ValidationError``. The error
                text is placed under the ``"detail"`` key: a single string for
                one message, a list when there are several.
        """
        instance = self.get_object()
        try:
            self.perform_destroy(instance)
        except ValidationError as exc:
            # Django only sets `.message` on single-message errors; errors built
            # from a list or dict have no such attribute, so reading it raised
            # AttributeError (a 500). `.messages` works for every form and also
            # interpolates `params`. Keep the plain-string shape for the common
            # single-message case so existing clients see the same payload.
            messages = exc.messages
            detail = messages[0] if len(messages) == 1 else messages
            raise serializers.ValidationError({"detail": detail}) from exc
        return Response(status=status.HTTP_204_NO_CONTENT)

148 

149 

class BulkCreateModelMixin(CreateModelMixin):
    """
    Create either a single model instance or many instances in one request.

    A JSON list in the request body triggers the bulk path, which uses the
    serializer's ``many=True`` mode. Any other payload is delegated to
    ``CreateModelMixin.create``. Both cases arrive as ``POST`` requests,
    which is why this mixin overrides ``create`` instead of adding a new
    action.
    """

    def create(self, request: Request, *args, **kwargs) -> Response:
        """Dispatch to single or bulk creation depending on the payload type."""
        # Guard clause: a non-list payload is an ordinary single-object create.
        if not isinstance(request.data, list):
            return super().create(request, *args, **kwargs)

        many_serializer = self.get_serializer(data=request.data, many=True)
        many_serializer.is_valid(raise_exception=True)
        self.perform_bulk_create(many_serializer)
        created = self.get_result_serializer(many_serializer.instance, many=True)
        return Response(created.data, status=status.HTTP_201_CREATED)

    def perform_bulk_create(self, serializer: BaseSerializer) -> None:
        """Save a bulk serializer; delegates to ``perform_create`` by default."""
        self.perform_create(serializer)

173 

174 

class BulkUpdateModelMixin(object):
    """
    Update model instances in bulk using the serializer's ``many=True`` mode.

    ``bulk_update`` and ``partial_bulk_update`` are the entry points. The
    update is restricted to the view's filtered queryset.
    """

    def get_object(self) -> Any:
        """
        Return the single object addressed by the URL, or ``None``.

        Delegates to the parent ``get_object`` when the lookup kwarg is present
        in ``self.kwargs``; otherwise returns ``None`` (see the comment below).
        """
        lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
        if lookup_url_kwarg in self.kwargs:
            return super().get_object()
        # If the lookup_url_kwarg is not present
        # get_object() is most likely called as part of options()
        # which by default simply checks for object permissions
        # and raises permission denied if necessary.
        # Here we don't need to check for general permissions
        # and can simply return None since general permissions
        # are checked in initial() which always gets executed
        # before any of the API actions (e.g. create, update, etc)
        return None

    def bulk_update(self, request: Request, *args, **kwargs) -> Response:
        """
        Validate ``request.data`` against the filtered queryset and save it.

        ``kwargs["partial"]`` (default ``False``) selects a partial update.

        Returns:
            HTTP 200 response whose body is produced by
            ``get_result_serializer(..., many=True)`` from the updated objects.
        """
        partial = kwargs.pop("partial", False)
        # restrict the update to the filtered queryset
        serializer = self.get_serializer(
            self.filter_queryset(self.get_queryset()),
            data=request.data,
            many=True,
            partial=partial,
        )
        serializer.is_valid(raise_exception=True)
        self.perform_bulk_update(serializer)
        # Render through the result serializer, as every other mixin in this
        # module (including BulkCreateModelMixin) does, instead of echoing the
        # input serializer. After save(), `serializer.instance` holds whatever
        # the list serializer's update() returned.
        return_serializer = self.get_result_serializer(serializer.instance, many=True)
        return Response(return_serializer.data, status=status.HTTP_200_OK)

    def partial_bulk_update(self, request: Request, *args, **kwargs) -> Response:
        """Run ``bulk_update`` with ``partial=True``."""
        kwargs["partial"] = True
        return self.bulk_update(request, *args, **kwargs)

    def perform_update(self, serializer: BaseSerializer) -> None:
        """Persist a validated serializer by calling ``save()``."""
        serializer.save()

    def perform_bulk_update(self, serializer: BaseSerializer) -> None:
        """Persist a bulk serializer; delegates to ``perform_update`` by default."""
        self.perform_update(serializer)

217 

218 

219# class BulkDestroyModelMixin(object): 

220# """ 

221# Destroy model instances. 

222# """ 

223 

224# def allow_bulk_destroy(self, qs: Any, filtered: Any) -> bool: 

225# """ 

226# Hook to ensure that the bulk destroy should be allowed. 

227# By default this checks that the destroy is only applied to 

228# filtered querysets. 

229# """ 

230# return qs is not filtered 

231# def bulk_destroy(self, request: Request, *args, **kwargs) -> Response: 

232# qs = self.get_queryset() 

233# filtered = self.filter_queryset(qs) 

234# if not self.allow_bulk_destroy(qs, filtered): 

235# return Response(status=status.HTTP_400_BAD_REQUEST) 

236 

237# self.perform_bulk_destroy(filtered) 

238# return Response(status=status.HTTP_204_NO_CONTENT) 

239# def perform_destroy(self, instance: object) -> None: 

240# instance.delete() 

241 

242# def perform_bulk_destroy(self, objects: Any) -> None: 

243# for obj in objects: 

244# self.perform_destroy(obj)