Coverage for /Users/eugene/Development/robotnikmq/robotnikmq/core.py: 73%

142 statements  

« prev     ^ index     » next       coverage.py v7.3.4, created at 2023-12-26 19:13 -0500

1from contextlib import contextmanager 

2from dataclasses import dataclass 

3from datetime import datetime 

4from json import loads as _from_json 

5from json.decoder import JSONDecodeError 

6from pathlib import Path 

7from random import sample 

8from threading import current_thread 

9from typing import Optional, Callable, Any, Dict, Union, Generator, List, TypedDict 

10from uuid import uuid4 as uuid, UUID 

11 

12from arrow import Arrow, get as to_arrow, now 

13from funcy import first 

14from pika import BlockingConnection 

15from pika.adapters.blocking_connection import BlockingChannel 

16from pika.exceptions import AMQPError, AMQPConnectionError 

17from pydantic import BaseModel # pylint: disable=E0611 

18from tenacity import ( 

19 retry_if_exception_type, 

20 stop_after_attempt, 

21 wait_random, 

22 Retrying, 

23) 

24from typeguard import typechecked 

25 

26from robotnikmq.config import RobotnikConfig, config_of 

27from robotnikmq.error import UnableToConnect, MalformedMessage 

28from robotnikmq.log import log 

29from robotnikmq.utils import to_json as _to_json 

30 

# Type of an optional callback that receives a pika ``AMQPError`` and returns nothing.
AMQPErrorCallback = Union[Callable[[AMQPError], None], None]
# Type of an optional callback that receives a pika ``AMQPConnectionError`` and returns nothing.
ConnErrorCallback = Union[Callable[[AMQPConnectionError], None], None]

33 

34 

@contextmanager
def thread_name(name: Union[str, UUID]) -> Generator[None, None, None]:
    """Temporarily rename the current thread for the duration of a ``with`` block.

    Args:
        name: The new thread name; it is passed through ``str()``, so a
            ``UUID`` is accepted as well.

    The previous name is restored on exit, including when the ``with`` body
    raises (the original version skipped the restore in that case).
    """
    thread = current_thread()
    original = thread.name
    thread.name = str(name)
    try:
        yield
    finally:
        thread.name = original

42 

43 

@typechecked
def jsonable(content: Any) -> bool:
    """Return whether ``content`` can be serialized by the project's ``_to_json``.

    Serialization failures are reported as ``False`` rather than propagated,
    so callers can use this as a plain predicate.

    ``ValueError`` is caught in addition to ``TypeError``/``OverflowError``
    because the standard ``json`` encoder raises it for circular references
    and for out-of-range floats when ``allow_nan=False``.
    NOTE(review): this assumes ``_to_json`` is built on the stdlib ``json``
    encoder — confirm in ``robotnikmq.utils``.
    """
    try:
        _to_json(content)
    except (TypeError, OverflowError, ValueError):
        return False
    return True

51 

52 

@typechecked
def valid_json(string: str) -> bool:
    """Report whether ``string`` parses as a JSON document.

    Returns ``True`` when ``json.loads`` accepts the text and ``False`` when it
    raises ``JSONDecodeError``; the parsed value itself is discarded.
    """
    try:
        _from_json(string)
    except JSONDecodeError:
        return False
    return True

60 

@typechecked
class MessageTypedDict(TypedDict):
    """Shape of the plain dict returned by ``Message.to_dict``."""

    contents: Dict[str, Any]  # the message payload
    routing_key: Optional[str]  # routing key; ``Message.to_dict`` always emits a str
    timestamp: Union[int, float]  # Unix timestamp
    msg_id: str  # message id rendered as a string

67 

@typechecked
@dataclass(frozen=True)
class Message:
    """Immutable message passed through RobotnikMQ.

    Build instances with ``Message.of`` (or ``Message.of_json``), which fill
    in defaults and normalize a pydantic payload to a plain dict. A message
    also behaves like a read-only mapping over its contents: ``msg["key"]``,
    ``"key" in msg``, ``keys()``, ``values()`` and iteration all delegate to
    ``contents``.
    """

    contents: Union[BaseModel, Dict[str, Any]]  # always a dict when built via ``of``
    routing_key: str  # "" when no routing key was supplied to ``of``
    timestamp: Arrow
    msg_id: Union[str, UUID]

    @staticmethod
    def of(
        contents: Union[BaseModel, Dict],
        routing_key: Optional[str] = None,
        timestamp: Union[int, float, datetime, Arrow, None] = None,
        msg_id: Union[str, UUID, None] = None,
    ) -> 'Message':
        """Build a ``Message``, filling in defaults for missing fields.

        Args:
            contents: The payload. A pydantic model is converted with ``.dict()``.
            routing_key: The routing key; ``None`` or ``""`` becomes ``""``.
            timestamp: Any value ``arrow.get`` accepts; ``None`` means "now".
            msg_id: The message id; a falsy value is replaced by a random UUID4.

        Raises:
            ValueError: If the (converted) contents are not JSON-serializable.
        """
        msg_id = msg_id or uuid()
        # Convert before validating so the check covers exactly what is
        # stored on the message and later serialized by ``to_json``.
        contents = contents.dict() if isinstance(contents, BaseModel) else contents
        if not jsonable(contents):
            raise ValueError("Contents of message have to be JSON-serializeable")
        routing_key: str = routing_key or ""
        timestamp: Arrow = to_arrow(timestamp) if timestamp is not None else now()
        return Message(contents, routing_key, timestamp, msg_id)

    def with_routing_key(self, routing_key: Optional[str]) -> 'Message':
        """Return a copy of this message with ``routing_key`` replaced.

        Contents, timestamp and id are carried over; ``None`` becomes ``""``.
        """
        return Message.of(self.contents, routing_key, self.timestamp, self.msg_id)

    def to_dict(self) -> MessageTypedDict:
        """Return a JSON-ready dict representation of this message.

        ``msg_id`` is stringified, and ``timestamp`` is reduced to an integer
        Unix timestamp, so sub-second precision is dropped.
        """
        return {
            "routing_key": self.routing_key,
            "contents": self.contents,
            "msg_id": str(self.msg_id),
            "timestamp": self.timestamp.int_timestamp,
        }

    def to_json(self) -> str:
        """Serialize ``to_dict()`` to a JSON string."""
        return _to_json(self.to_dict())

    @staticmethod
    def of_json(body: str) -> "Message":  # pylint: disable=C0103
        """Parse a message in the format produced by ``to_json``.

        Raises:
            MalformedMessage: If ``body`` is not valid JSON, is not a JSON
                object, or is missing one of the required keys.
        """
        try:
            msg = _from_json(body)
        except JSONDecodeError as exc:
            raise MalformedMessage(body) from exc
        # Valid JSON that is not an object (a list, number, string, ...) would
        # otherwise escape as a TypeError from the key lookups below.
        if not isinstance(msg, dict):
            raise MalformedMessage(body)
        try:
            return Message.of(
                msg["contents"], msg["routing_key"], msg["timestamp"], msg["msg_id"]
            )
        except KeyError as exc:
            raise MalformedMessage(body) from exc

    def __getitem__(self, key: str) -> Any:
        """Return ``contents[key]``."""
        return self.contents[key]

    def keys(self):
        """Return the keys of ``contents``."""
        return self.contents.keys()

    def values(self):
        """Return the values of ``contents``."""
        return self.contents.values()

    def __contains__(self, item: str) -> bool:
        """Return whether ``item`` is a key of ``contents``."""
        return item in self.contents

    def __iter__(self):
        """Iterate over the keys of ``contents``."""
        return iter(self.contents)

    @property
    def route(self) -> str:
        """Alias for ``routing_key``."""
        return self.routing_key

133 

134 

135class Robotnik: 

136 @typechecked 

137 def __init__( 

138 self, 

139 config: Optional[RobotnikConfig] = None, 

140 config_paths: Optional[List[Path]] = None, 

141 ): 

142 config_paths = config_paths or [ 

143 Path.cwd() / "robotnikmq.yaml", 

144 Path.home() / ".config" / "robotnikmq" / "robotnikmq.yaml", 

145 Path("/etc/robotnikmq/robotnikmq.yaml"), 

146 ] 

147 self.config = config or config_of( 

148 first(path for path in config_paths if path.exists()) 

149 ) 

150 self._connection = None 

151 self._channel: Optional[BlockingChannel] = None 

152 self.log = log.bind(rmq_server="") 

153 

154 @typechecked 

155 def _make_connection(self) -> BlockingConnection: 

156 self.log = log.bind(rmq_server="") 

157 for tier in self.config.tiers: 

158 for config in sample(tier, len(tier)): 

159 try: 

160 connection = BlockingConnection(config.conn_params()) 

161 self.log = log.bind(rmq_server=f"{config.host}:{config.port}{config.vhost}") 

162 self.log.success("Connection to {}:{}{} is successful", 

163 config.host, config.port, config.vhost) 

164 return connection 

165 except AMQPConnectionError: 

166 log.exception("Unable to connect to {}:{}{}", config.host, 

167 config.port, config.vhost) 

168 raise UnableToConnect("Cannot connect to any of the configured servers") 

169 

170 @property 

171 def connection(self) -> BlockingConnection: 

172 if self._connection is None or not self._connection.is_open: 

173 for attempt in Retrying( 

174 retry=retry_if_exception_type((UnableToConnect, AMQPError, OSError)), 

175 stop=stop_after_attempt(self.config.connection.attempts), 

176 wait=wait_random( 

177 min=self.config.connection.wait_random_min_seconds, 

178 max=self.config.connection.wait_random_max_seconds, 

179 ), 

180 reraise=True, 

181 ): 

182 with attempt: 

183 self._connection = self._make_connection() 

184 return self._connection 

185 

186 @typechecked 

187 def _open_channel(self) -> BlockingChannel: 

188 _channel = self.connection.channel() 

189 _channel.basic_qos(prefetch_count=1) 

190 return _channel 

191 

192 @property 

193 def channel(self) -> BlockingChannel: 

194 if self._channel is None or not self._channel.is_open: 

195 self._channel = self._open_channel() 

196 return self._channel 

197 

198 @contextmanager 

199 def open_channel(self) -> Generator[BlockingChannel, None, None]: 

200 _ch = self.channel 

201 yield _ch 

202 self.close_channel(_ch) 

203 

204 @typechecked 

205 def close_channel(self, channel: Optional[BlockingChannel] = None) -> None: 

206 channel = channel or self.channel 

207 if channel is not None and channel.is_open: 

208 channel.stop_consuming() 

209 channel.close()