Coverage for src/extratools_cloud/aws/sqs.py: 0%

68 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-17 04:51 -0700

1import json 

2from collections.abc import Iterable 

3from os import getenv 

4from typing import Any, cast 

5from uuid import uuid4 

6 

7import boto3 

8from boto3.resources.base import ServiceResource 

9from extratools_core.crudl import CRUDLDict 

10from extratools_core.json import JsonDict 

11from toolz.itertoolz import partition_all 

12 

13from .helpers import ClientErrorHandler 

14 

15STAGE: str = getenv("STAGE", "local") 

16 

17 

18default_service_resource: ServiceResource = boto3.resource( 

19 "sqs", 

20 endpoint_url=( 

21 "http://localhost:4566" if STAGE == "local" 

22 else None 

23 ), 

24) 

25 

26type Queue = Any 

27 

28FIFO_QUEUE_NAME_SUFFIX = ".fifo" 

29 

30 

def get_resource_dict(
    *,
    service_resource: ServiceResource | None = None,
    queue_name_prefix: str | None = None,
) -> CRUDLDict[str, Queue]:
    """Expose SQS queues as a dict-like CRUDL mapping keyed by queue name.

    Args:
        service_resource: boto3 SQS service resource to operate on. Falls back
            to the module-level ``default_service_resource`` when omitted or
            falsy.
        queue_name_prefix: When given, every queue name used as a key must
            start with it (``ValueError`` otherwise). Listing is also
            restricted to queues whose names start with it.

    Returns:
        A ``CRUDLDict`` whose values are boto3 ``Queue`` resources. The create
        and update callbacks take SQS queue attributes as ``dict[str, str]``.
        Reading, updating or deleting a queue that does not exist goes through
        ``read_func``. That function maps the ``QueueDoesNotExist`` client error
        to ``KeyError`` (via ``ClientErrorHandler``).
    """
    service_resource = service_resource or default_service_resource

    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sqs/service-resource/index.html

    def check_queue_name(queue_name: str) -> None:
        # Keep every operation inside this mapping's namespace, so the dict
        # never touches queues that list_func would not return.
        if queue_name_prefix and not queue_name.startswith(queue_name_prefix):
            raise ValueError(
                f"Queue name {queue_name!r} does not start with "
                f"prefix {queue_name_prefix!r}",
            )

    def create_func(queue_name: str | None, attributes: dict[str, str]) -> None:
        if queue_name is None:
            raise ValueError("A queue name is required to create an SQS queue")

        check_queue_name(queue_name)

        # Only FIFO queues get an explicit FifoQueue attribute. Omitting it is
        # the documented way to create a standard queue, and an explicit
        # "false" has been reported to be rejected by AWS. Caller-supplied
        # attributes are merged last, so they can still override this.
        fifo_attributes: dict[str, str] = (
            {"FifoQueue": "true"}
            if queue_name.endswith(FIFO_QUEUE_NAME_SUFFIX)
            else {}
        )

        service_resource.create_queue(
            QueueName=queue_name,
            Attributes={
                **fifo_attributes,
                **attributes,
            },
        )

    @ClientErrorHandler(
        "QueueDoesNotExist",
        KeyError,
    )
    def read_func(queue_name: str) -> Queue:
        check_queue_name(queue_name)

        return service_resource.get_queue_by_name(
            QueueName=queue_name,
        )

    def update_func(queue_name: str, attributes: dict[str, str]) -> None:
        # Resolve via read_func so that a missing queue raises KeyError,
        # consistent with reads, rather than a raw ClientError.
        read_func(queue_name).set_attributes(
            Attributes=dict(attributes),
        )

    def delete_func(queue_name: str) -> None:
        # Same lookup as update_func: name check plus missing-queue -> KeyError.
        read_func(queue_name).delete()

    def list_func(_: None) -> Iterable[tuple[str, Queue]]:
        queues = (
            service_resource.queues.filter(
                QueueNamePrefix=queue_name_prefix,
            )
            if queue_name_prefix
            else service_resource.queues.all()
        )
        for queue in queues:
            # The queue name is the last path segment of the queue URL.
            queue_name = cast("str", queue.url).rsplit("/", maxsplit=1)[-1]
            yield queue_name, queue

    return CRUDLDict[str, Queue](
        create_func=create_func,
        read_func=read_func,
        update_func=update_func,
        delete_func=delete_func,
        list_func=list_func,
    )

105 

106 

# Entries sent per `queue.send_messages` call by `send_messages`.
# Presumably matches the SQS SendMessageBatch per-request limit; confirm before raising.
MESSAGE_BATCH_SIZE: int = 10

108 

109 

def send_messages(
    queue: Queue,
    messages: Iterable[JsonDict],
    group: str | None = None,
) -> Iterable[JsonDict]:
    """Send JSON-serialisable messages to *queue* in batches.

    Messages are split into chunks of ``MESSAGE_BATCH_SIZE``. Each chunk is sent
    with a single ``queue.send_messages(Entries=...)`` call. Every entry gets an
    Id of the form ``"<uuid4>_<index>"``, which is unique within the call. For
    FIFO queues (queue URL ends with ``FIFO_QUEUE_NAME_SUFFIX``) each entry also
    carries *group* as ``MessageGroupId`` and its Id as
    ``MessageDeduplicationId``. For other queues *group* is ignored.

    Validation is eager but sending is lazy: nothing is sent until the returned
    iterator is consumed.

    Args:
        queue: boto3 SQS ``Queue`` resource to send to.
        messages: Message payloads; each is encoded with ``json.dumps``.
        group: Message group id. Required (non-empty) for FIFO queues.

    Returns:
        An iterator that yields, for each batch in turn, the response's
        ``"Successful"`` entries followed by its ``"Failed"`` entries.

    Raises:
        ValueError: At call time, if *queue* is FIFO and *group* is empty or
            None.
    """
    fifo: bool = queue.url.endswith(FIFO_QUEUE_NAME_SUFFIX)
    if fifo and not group:
        # Raised here, outside the generator below, so the caller sees it
        # immediately rather than on first iteration.
        raise ValueError(f"A message group is required for FIFO queue {queue.url!r}")

    def send_batches() -> Iterable[JsonDict]:
        # One UUID per call, suffixed with the message index, gives entry Ids
        # that stay unique across every batch of this call.
        batch_id = str(uuid4())

        numbered_messages = (
            (f"{batch_id}_{i}", message_data)
            for i, message_data in enumerate(messages)
        )

        for message_batch in partition_all(MESSAGE_BATCH_SIZE, numbered_messages):
            response: JsonDict = queue.send_messages(Entries=[
                dict(
                    Id=message_id,
                    MessageBody=json.dumps(message_data),
                ) | (
                    dict(
                        MessageDeduplicationId=message_id,
                        MessageGroupId=group,
                    )
                    if fifo else {}
                )
                for message_id, message_data in message_batch
            ])

            yield from response.get("Successful", [])
            yield from response.get("Failed", [])

    return send_batches()

144 

145 

class FifoRouter:
    """Route messages to FIFO SQS queues by message group.

    Only queues named ``<queue_name_prefix><base name>.fifo`` are managed. On
    construction, every existing FIFO queue under the prefix is registered as
    the route for the group equal to its base name. ``register_queue`` adds more
    routes. Messages for a group without a route go to the default queue.
    """

    def __init__(
        self,
        *,
        service_resource: ServiceResource | None = None,
        queue_name_prefix: str,
        default_queue_base_name: str,
    ) -> None:
        """Load the default queue and seed the group -> queue routing table.

        Args:
            service_resource: boto3 SQS service resource; ``None`` selects the
                module default (see ``get_resource_dict``).
            queue_name_prefix: Prefix shared by every queue this router manages.
            default_queue_base_name: Base name of the fallback FIFO queue. Its
                full name is ``prefix + base + ".fifo"``. It is looked up
                immediately, presumably raising ``KeyError`` if it does not
                exist (via the resource dict's read path).
        """
        self.__resource_dict: CRUDLDict[str, Queue] = get_resource_dict(
            service_resource=service_resource,
            queue_name_prefix=queue_name_prefix,
        )

        self.__default_queue: Queue = self.__resource_dict[
            queue_name_prefix + default_queue_base_name + FIFO_QUEUE_NAME_SUFFIX
        ]

        self.__queue_name_prefix = queue_name_prefix

        # Key each existing FIFO queue by its base name. Standard queues that
        # share the prefix are skipped: they cannot carry message groups, and
        # blindly slicing ".fifo" off their names would produce bogus keys.
        self.__queues: dict[str, Queue] = {
            queue_name.removeprefix(queue_name_prefix).removesuffix(
                FIFO_QUEUE_NAME_SUFFIX,
            ): queue
            for queue_name, queue in self.__resource_dict.items()
            if queue_name.endswith(FIFO_QUEUE_NAME_SUFFIX)
        }

    def register_queue(
        self,
        queue_base_name: str,
        groups: Iterable[str] | None = None,
        *,
        create: bool = True,
    ) -> None:
        """Route *groups* (default: just *queue_base_name*) to a FIFO queue.

        Args:
            queue_base_name: Base name of the queue. The full name is
                ``prefix + base + ".fifo"``.
            groups: Message groups to route to this queue. ``None`` or an empty
                iterable registers the queue under its own base name.
            create: Whether to create the queue when it does not exist yet.

        Raises:
            KeyError: The queue does not exist and *create* is false.
        """
        queue_name = self.__queue_name_prefix + queue_base_name + FIFO_QUEUE_NAME_SUFFIX

        if queue_name not in self.__resource_dict:
            if not create:
                raise KeyError(queue_name)
            # Assigning an empty attribute dict creates the queue; the ".fifo"
            # suffix in the name is what makes it a FIFO queue.
            self.__resource_dict[queue_name] = {}

        queue = self.__resource_dict[queue_name]

        # Materialise first, so that an empty iterator behaves like an empty
        # list: both fall back to the base name.
        group_names = list(groups) if groups is not None else []
        self.__queues.update(dict.fromkeys(group_names or [queue_base_name], queue))

    def send_messages(
        self,
        messages: Iterable[JsonDict],
        group: str,
    ) -> Iterable[JsonDict]:
        """Send *messages* to the queue routed for *group*.

        Falls back to the default queue when *group* has no route. *group* is
        also used as the FIFO message group id. The result is whatever the
        module-level ``send_messages`` returns: per-batch "Successful" then
        "Failed" entries.
        """
        # Resolve the route now, rather than on first iteration, so a later
        # register_queue call cannot redirect an already-issued send.
        queue = self.__queues.get(group, self.__default_queue)
        return send_messages(queue, messages, group)