Coverage for /Users/eugene/Development/robotnikmq/robotnikmq/rpc_server.py: 44%
180 statements
« prev ^ index » next coverage.py v7.3.4, created at 2023-12-26 19:16 -0500
« prev ^ index » next coverage.py v7.3.4, created at 2023-12-26 19:16 -0500
1from dataclasses import dataclass
2from inspect import signature, Parameter
3from json import loads as _from_json
4from traceback import format_exc
5from socket import gethostname
6from typing import Optional, Callable, Union, Any, Dict, Tuple, List, TypedDict, Type
7from typing import get_type_hints, get_origin, get_args
8from uuid import uuid4 as uuid, UUID
10from pika import BasicProperties
11from pika.exceptions import AMQPError, ChannelError, AMQPConnectionError
12from tenacity import retry, wait_exponential, retry_if_exception_type
13from typeguard import typechecked
15from robotnikmq.config import RobotnikConfig
16from robotnikmq.core import Robotnik, thread_name, valid_json
17from robotnikmq.utils import to_json as _to_json
18from robotnikmq.log import log
@typechecked
def _type_hint_str(typ: Any) -> str:
    """Render a type hint as a short, human-readable string.

    Unions are rendered as ``Union[A,B,...]`` with every member rendered
    recursively. Both ``typing.Union[...]`` and, on interpreters that provide
    ``types.UnionType``, PEP 604 unions (``A | B``) are recognised.

    Any other hint is rendered as its ``__name__``. Hints that do not expose
    a ``__name__`` fall back to ``str(typ)`` instead of raising
    ``AttributeError``.

    Args:
        typ: The type hint to render.

    Returns:
        The textual form of ``typ``.
    """
    import types  # local import: only needed to detect PEP 604 unions

    origin = get_origin(typ)
    # ``types.UnionType`` does not exist before Python 3.10; guard against
    # comparing ``origin`` (which is None for plain classes) with a missing
    # attribute.
    pep604_union = getattr(types, "UnionType", None)
    if origin is Union or (pep604_union is not None and origin is pep604_union):
        members = ",".join(_type_hint_str(member) for member in get_args(typ))
        return f"Union[{members}]"
    name = getattr(typ, "__name__", None)
    return str(name) if name is not None else str(typ)
@typechecked
class RpcErrorTypedDict(TypedDict):
    """Dictionary shape returned by ``RpcError.to_dict``."""

    # Request identifier rendered as a string.
    request_id: str
    # Message kind; ``RpcError.to_dict`` always sets this to "error".
    type: str
    # Error details: nothing, a message string, or a string-keyed mapping.
    details: Union[None, str, Dict[str, Any]]
@typechecked
@dataclass(frozen=True)
class RpcError:
    """Immutable error reply for an RPC request.

    Instances serialise to a mapping with the keys ``request_id``, ``type``
    (always ``"error"``) and ``details``.
    """

    # Request identifier; ``of`` generates a fresh UUID when none is supplied.
    request_id: Union[str, UUID]
    # Error details: nothing, a message string, or a string-keyed mapping.
    details: Union[None, str, Dict[str, Any]]

    @staticmethod
    def of(
        request_id: Union[str, UUID, None] = None,
        details: Union[None, str, Dict[str, Any]] = None,
    ) -> 'RpcError':
        """Build an ``RpcError``, generating a random UUID if no ID is given.

        Args:
            request_id: Identifier to use; a falsy value triggers ``uuid()``.
            details: Error details to attach.

        Returns:
            The new ``RpcError``.
        """
        return RpcError(request_id or uuid(), details)

    def to_json(self) -> str:
        """Return the JSON encoding of ``to_dict()``."""
        return _to_json(self.to_dict())

    def to_dict(self) -> RpcErrorTypedDict:
        """Return the error as a plain dictionary with a stringified ID."""
        return {
            "request_id": str(self.request_id),
            "type": "error",
            "details": self.details,
        }

    @staticmethod
    def from_json(json_str: Union[str, bytes]) -> Optional['RpcError']:
        """Parse an ``RpcError`` from JSON text.

        Args:
            json_str: JSON document as text or bytes (bytes are decoded with
                the default codec of ``bytes.decode``).

        Returns:
            The parsed error, or ``None`` when the input is not valid JSON,
            is not a JSON object, or lacks any of the keys ``request_id``,
            ``type`` and ``details``.
        """
        json_str = json_str if isinstance(json_str, str) else json_str.decode()
        log.debug(json_str)
        if not valid_json(json_str):
            return None
        data = _from_json(json_str)
        # Valid JSON may also be a list, string, number, ...; only an object
        # can carry the expected keys, and ``k in data`` on other values would
        # either raise or perform an unrelated substring/element test.
        if isinstance(data, dict) and all(
            k in data for k in ("request_id", "type", "details")
        ):
            return RpcError.of(request_id=data["request_id"], details=data["details"])
        return None
@typechecked
class RpcResponseTypedDict(TypedDict):
    """Dictionary shape returned by ``RpcResponse.to_dict``."""

    # Request identifier rendered as a string.
    request_id: str
    # Message kind; ``RpcResponse.to_dict`` always sets this to "response".
    type: str
    # Payload of the response: nothing, a scalar, a string-keyed mapping, or a
    # list of such mappings.
    data: Union[None, str, int, float, Dict[str, Any], List[Dict[str, Any]]]
@typechecked
@dataclass(frozen=True)
class RpcResponse:
    """Immutable successful reply for an RPC request.

    Instances serialise to a mapping with the keys ``request_id``, ``type``
    (always ``"response"``) and ``data``.
    """

    # Request identifier; ``of`` generates a fresh UUID when none is supplied.
    request_id: Union[str, UUID]
    # Payload: nothing, a scalar, a string-keyed mapping, or a list of mappings.
    data: Union[None, str, int, float, Dict[str, Any], List[Dict[str, Any]]]

    @staticmethod
    def of(
        request_id: Union[str, UUID, None] = None,
        data: Union[None, str, int, float, Dict[str, Any], List[Dict[str, Any]]] = None,
    ) -> 'RpcResponse':
        """Build an ``RpcResponse``, generating a random UUID if no ID is given.

        Args:
            request_id: Identifier to use; a falsy value triggers ``uuid()``.
            data: Payload to attach.

        Returns:
            The new ``RpcResponse``.
        """
        return RpcResponse(request_id or uuid(), data)

    def to_dict(self) -> RpcResponseTypedDict:
        """Return the response as a plain dictionary with a stringified ID."""
        return {
            "request_id": str(self.request_id),
            "type": "response",
            "data": self.data,
        }

    def to_json(self) -> str:
        """Return the JSON encoding of ``to_dict()``."""
        return _to_json(self.to_dict())

    @staticmethod
    def from_json(json_str: Union[str, bytes]) -> Optional["RpcResponse"]:
        """Parse an ``RpcResponse`` from JSON text.

        Args:
            json_str: JSON document as text or bytes (bytes are decoded with
                the default codec of ``bytes.decode``).

        Returns:
            The parsed response, or ``None`` when the input is not valid JSON,
            is not a JSON object, or lacks any of the keys ``request_id``,
            ``type`` and ``data``.
        """
        json_str = json_str if isinstance(json_str, str) else json_str.decode()
        if not valid_json(json_str):
            return None
        data = _from_json(json_str)
        # Valid JSON may also be a list, string, number, ...; only an object
        # can carry the expected keys, and ``k in data`` on other values would
        # either raise or perform an unrelated substring/element test.
        if isinstance(data, dict) and all(
            k in data for k in ("request_id", "type", "data")
        ):
            return RpcResponse.of(request_id=data["request_id"], data=data["data"])
        return None
class RpcServer(Robotnik):
    """RPC server that exposes Python callables on message queues.

    Each registered callable is bound to a queue. Incoming message bodies are
    parsed as JSON keyword arguments, validated against the callable's type
    hints, and the result (``RpcResponse``) or failure (``RpcError``) is
    published to the request's ``reply_to`` routing key. Optionally a
    companion "docs" queue answers with the callable's signature and
    docstring.
    """

    @typechecked
    def __init__(
        self,
        config: Optional[RobotnikConfig] = None,
        meta_queue_prefix: Optional[str] = None,
        docs_queue_suffix: Optional[str] = None,
        only_once: bool = False,
    ):
        """Create a server with no registered callbacks.

        Args:
            config: Forwarded unchanged to ``Robotnik.__init__``.
            meta_queue_prefix: Stored on the instance; defaults to the local
                host name.
            docs_queue_suffix: Appended to a queue name to form its docs queue
                name; defaults to ``".__doc__"``.
            only_once: When true, consuming is stopped after a request has
                been handled.
        """
        super().__init__(config=config)
        self._callbacks: Dict[str, Callable] = {}
        self.meta_queue_prefix = meta_queue_prefix or gethostname()
        self.docs_queue_suffix = docs_queue_suffix or ".__doc__"
        # Typically used for testing, implies server should stop after 1 response
        self.only_once = only_once

    @typechecked
    def _register_docs(self, queue: str, callback: Callable) -> None:
        """Declare and consume the docs queue that describes ``callback``.

        Args:
            queue: Name of the RPC queue being documented; must already be a
                key of ``self._callbacks`` by the time a docs request arrives.
            callback: The callable whose docstring is reported.
        """
        docs_queue = queue + self.docs_queue_suffix
        self.channel.queue_declare(queue=docs_queue, exclusive=False)

        @typechecked
        def docs_callback(_, method, props: BasicProperties, __) -> None:
            # The request body is ignored; every docs request gets the same
            # description of the registered callable.
            req_id = props.correlation_id or uuid()
            response = RpcResponse.of(
                req_id,
                data={
                    "rpc_queue": queue,
                    "inputs": self._get_input_type_strings(queue),
                    "returns": self._get_return_type_str(queue),
                    "description": callback.__doc__,
                },
            )
            self.channel.basic_publish(
                exchange="",
                routing_key=props.reply_to or "",
                properties=BasicProperties(correlation_id=props.correlation_id),
                body=response.to_json(),
            )
            self.channel.basic_ack(delivery_tag=method.delivery_tag)

        self.channel.basic_consume(
            queue=docs_queue,
            on_message_callback=docs_callback,
            auto_ack=False,
        )

    @typechecked
    def _get_defaults(self, queue: str) -> Dict:
        """Return ``{parameter name: default}`` for the queue's callback.

        Parameters without a default value are omitted.
        """
        params = signature(self._callbacks[queue]).parameters
        return {
            name: param.default
            for name, param in params.items()
            if param.default is not Parameter.empty
        }

    @typechecked
    def _get_input_types(self, queue: str) -> Dict:
        """Return the callback's parameter type hints, without ``return``."""
        return {
            k: v
            for k, v in get_type_hints(self._callbacks[queue]).items()
            if k != "return"
        }

    @typechecked
    def _get_input_type_strings(self, queue: str) -> Dict:
        """Return the callback's parameter type hints rendered as strings."""
        return {
            k: _type_hint_str(v) for k, v in self._get_input_types(queue).items()
        }

    @typechecked
    def _get_return_type_str(self, queue: str) -> Any:
        """Return the callback's return type hint rendered as a string.

        A callback without a return annotation is reported as ``"Any"``
        rather than raising ``KeyError`` from inside the docs consumer.
        """
        hints = get_type_hints(self._callbacks[queue])
        if "return" not in hints:
            return "Any"
        return _type_hint_str(hints["return"])

    @typechecked
    @staticmethod
    def _is_union(arg_type: Any) -> bool:
        """Tell whether ``arg_type`` is a union, in either spelling.

        Recognises ``typing.Union[...]`` and, on interpreters that provide
        ``types.UnionType``, PEP 604 unions written as ``A | B``.
        """
        import types  # local import: only needed to detect PEP 604 unions

        origin = get_origin(arg_type)
        if origin is Union:
            return True
        # ``types.UnionType`` is absent before Python 3.10; ``origin`` is None
        # for plain classes, so the missing attribute must not compare equal.
        pep604_union = getattr(types, "UnionType", None)
        return pep604_union is not None and origin is pep604_union

    @typechecked
    @staticmethod
    def _is_optional(arg_type: Any) -> bool:
        """Tell whether ``arg_type`` is a union that includes ``None``."""
        return RpcServer._is_union(arg_type) and type(None) in get_args(arg_type)

    @typechecked
    @staticmethod
    def _valid_arg(arg_value: Any, arg_type: Any) -> bool:
        """Check a decoded JSON value against a type hint.

        Supported hints: ``Any``; unions (any member may match, and optional
        unions additionally accept ``None`` or ``{}``); ``Dict[K, V]`` (keys
        and values are checked recursively); other subscripted generics whose
        origin is a class (only the container class is checked); and plain
        classes (``isinstance``).

        Args:
            arg_value: The value supplied by the caller.
            arg_type: The type hint of the target parameter.

        Returns:
            True when ``arg_value`` is acceptable for ``arg_type``.
        """
        if arg_type is Any:
            return True
        if RpcServer._is_union(arg_type):
            members = get_args(arg_type)
            if type(None) in members and (
                arg_value is None or arg_value == {}
            ):  # Optional
                return True
            return any(RpcServer._valid_arg(arg_value, typ) for typ in members)
        origin = get_origin(arg_type)
        if origin is dict:
            # A non-dict value has no keys()/values(); reject it instead of
            # raising AttributeError.
            if not isinstance(arg_value, dict):
                return False
            type_args = get_args(arg_type)
            if len(type_args) != 2:
                # Bare ``Dict`` carries no key/value constraints.
                return True
            key_type, val_type = type_args
            return all(
                RpcServer._valid_arg(key, key_type) for key in arg_value
            ) and all(RpcServer._valid_arg(val, val_type) for val in arg_value.values())
        if isinstance(origin, type):
            # isinstance() rejects subscripted generics such as List[str], so
            # only the container class is checked for them.
            return isinstance(arg_value, origin)
        return isinstance(arg_value, arg_type)

    def _valid_inputs(self, queue: str, inputs: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
        """Validate request arguments against the callback's type hints.

        Caller-supplied values override the callback's defaults before
        checking. Only parameters that have a type hint are examined.

        Args:
            queue: Queue whose callback is being invoked.
            inputs: Keyword arguments decoded from the request body.

        Returns:
            ``(True, None)`` when all hinted parameters are satisfied,
            otherwise ``(False, message)`` describing the first problem found.
        """
        inputs_with_defaults = {**self._get_defaults(queue), **inputs}
        for arg_name, arg_type in self._get_input_types(queue).items():
            if arg_name not in inputs_with_defaults:
                if not self._is_optional(arg_type):
                    return False, f"Missing required argument {arg_name}"
                continue
            if not self._valid_arg(inputs_with_defaults[arg_name], arg_type):
                return False, f"Invalid type for {arg_name}"
        return True, None

    def _build_response(
        self, queue: str, callback: Callable, req_id: Union[str, UUID], body: bytes
    ) -> str:
        """Execute one request and return the reply serialised as JSON.

        Args:
            queue: Queue the request arrived on.
            callback: Callable registered for ``queue``.
            req_id: Identifier echoed in the reply.
            body: Raw request body, expected to be a JSON object of keyword
                arguments.

        Returns:
            JSON text of an ``RpcResponse`` on success or an ``RpcError`` on
            invalid input or when the callback raises.

        Raises:
            AMQPError, ChannelError: Re-raised unchanged so that the caller
                can handle broker communication failures.
        """
        try:
            payload = body.decode()
            if not valid_json(payload):
                return RpcError.of(
                    req_id, "Input could not be decoded as JSON"
                ).to_json()
            input_args = _from_json(payload)
            self.log.debug(f"Input JSON is valid: {input_args}")
            if not isinstance(input_args, dict):
                # Valid JSON that is not an object cannot be mapped onto
                # keyword arguments; answer explicitly instead of failing
                # inside the validation code.
                self.log.debug("Invalid input")
                return RpcError.of(
                    req_id, "Input must be a JSON object of keyword arguments"
                ).to_json()
            valid_inputs, msg = self._valid_inputs(queue, input_args)
            if not valid_inputs:
                self.log.debug("Invalid input")
                return RpcError.of(req_id, msg).to_json()
            if not input_args:
                self.log.debug(f"Executing: {callback}")
                return RpcResponse.of(req_id, callback()).to_json()
            self.log.debug(f"Executing: {callback} with inputs: {input_args}")
            return RpcResponse.of(req_id, callback(**input_args)).to_json()
        except (AMQPError, ChannelError):
            raise  # we want this kind of exception to be caught further down
        except Exception:  # pylint: disable=W0703
            self.log.error(
                "An error has occurred during the execution of the RPC method"
            )
            for line in format_exc().split("\n"):
                self.log.error(line)
            return RpcError.of(
                request_id=req_id,
                details=f"There was an error while processing the request, "
                f"please refer to server log with request ID: {req_id}",
            ).to_json()

    @typechecked
    def register_rpc(
        self, queue: str, callback: Callable, register_docs: bool = True
    ) -> None:
        """Expose ``callback`` as an RPC endpoint on ``queue``.

        Args:
            queue: Name of the queue to declare and consume from.
            callback: Callable invoked with the request's keyword arguments.
            register_docs: When true, also register the companion docs queue.
        """
        self.channel.queue_declare(queue=queue, exclusive=False)
        self._callbacks[queue] = callback
        if register_docs:
            self._register_docs(queue, callback)
        # TODO: servers should have an exclusive Queue for information about themselves

        @typechecked
        def meta_callback(_, method, props: BasicProperties, body: bytes):
            req_id = props.correlation_id or uuid()
            with thread_name(req_id):
                self.log.debug("Request received")
                try:
                    response = self._build_response(queue, callback, req_id, body)
                    self.log.debug(f"Response: {response}")
                    self.channel.basic_publish(
                        exchange="",
                        routing_key=props.reply_to or "",
                        properties=BasicProperties(correlation_id=props.correlation_id),
                        body=response,
                    )
                    self.channel.basic_ack(delivery_tag=method.delivery_tag)
                    self.log.debug("Response sent and ack-ed")
                except (AMQPError, ChannelError):
                    # The message is neither answered nor acknowledged here;
                    # the failure is only logged.
                    self.log.error(
                        f"A RabbitMQ communication error has occurred while processing "
                        f"Request ID: {req_id}"
                    )
                    for line in format_exc().split("\n"):
                        self.log.error(line)
                # NOTE(review): the nesting level of this check is assumed;
                # confirm it runs after every request, not only on errors.
                if self.only_once:
                    self.channel.stop_consuming()

        self.channel.basic_consume(
            queue=queue, on_message_callback=meta_callback, auto_ack=False
        )

    @retry(
        retry=retry_if_exception_type((AMQPConnectionError, OSError)),
        wait=wait_exponential(multiplier=1, min=3, max=30),
    )
    @typechecked
    def run(self) -> None:
        """Consume requests until interrupted.

        The call is retried with exponential back-off (``min=3``, ``max=30``)
        whenever it raises ``AMQPConnectionError`` or ``OSError``. A
        ``KeyboardInterrupt`` stops consuming and logs the shutdown.
        """
        try:
            self.channel.start_consuming()
        except KeyboardInterrupt:
            self.channel.stop_consuming()
            # NOTE(review): assumed to belong to the interrupt handler;
            # confirm the intended nesting of this log line.
            self.log.info("Shutting down server")