Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter/jumpstarter/common/grpc.py: 25%

67 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-06 10:20 +0200

1import asyncio 

2import base64 

3import os 

4import socket 

5import ssl 

6from contextlib import contextmanager 

7from typing import Any, Sequence, Tuple 

8from urllib.parse import urlparse 

9 

10import grpc 

11from anyio import fail_after 

12 

13from jumpstarter.common.exceptions import ConfigurationError, ConnectionError 

14 

15 

async def ssl_channel_credentials(target: str, tls_config, timeout=5):
    """Build gRPC channel credentials for connecting to *target*.

    Three modes are checked in order:

    * insecure (``tls_config.insecure`` is truthy, or the environment variable
      ``JUMPSTARTER_GRPC_INSECURE`` equals ``"1"``): open one TLS connection to
      the server with certificate verification disabled, capture the chain
      the server presents, and use that chain as the trusted root certificates;
    * custom CA (``tls_config.ca`` is not ``""``): base64-decode it and use it
      as the root certificate;
    * otherwise: ``grpc.ssl_channel_credentials()`` with gRPC's defaults.

    Args:
        target: ``host[:port]`` of the server; in insecure mode the port
            defaults to 443 when absent.
        tls_config: object exposing ``insecure`` and ``ca`` attributes.
        timeout: seconds allowed for the insecure-mode probe connection.

    Returns:
        The ``grpc.ChannelCredentials`` built for the selected mode.

    Raises:
        ConfigurationError: *target* cannot be parsed or has no host
            (insecure mode only).
        ConnectionError: the insecure-mode probe fails to resolve, connect,
            complete the TLS handshake, or finish within *timeout*.
    """
    configure_grpc_env()
    if tls_config.insecure or os.getenv("JUMPSTARTER_GRPC_INSECURE") == "1":
        try:
            parsed = urlparse(f"//{target}")
            # ``.port`` raises ValueError for a non-numeric or out-of-range port.
            port = parsed.port or 443
        except ValueError as e:
            raise ConfigurationError(f"Failed parsing {target}") from e
        hostname = parsed.hostname
        if not hostname:
            # No host to connect to: report it as a configuration problem
            # instead of letting asyncio fail with an unrelated error.
            raise ConfigurationError(f"Failed parsing {target}")

        try:
            with fail_after(timeout):
                ssl_context = ssl.create_default_context()
                # Verification is deliberately disabled: this probe only
                # captures the chain the server presents so it can be trusted.
                ssl_context.check_hostname = False
                ssl_context.verify_mode = ssl.CERT_NONE
                _, writer = await asyncio.open_connection(hostname, port, ssl=ssl_context)
                try:
                    # NOTE: relies on the private ``_sslobj`` attribute; its
                    # get_unverified_chain() may return None when no chain was sent.
                    chain = writer.get_extra_info("ssl_object")._sslobj.get_unverified_chain() or []
                    # public_bytes() yields PEM text (str), hence the .encode() below.
                    root_certificates = "".join(cert.public_bytes() for cert in chain)
                finally:
                    # The probe connection is no longer needed: close it rather
                    # than leaking the socket. wait_closed() is intentionally not
                    # awaited so a slow TLS shutdown cannot consume the timeout.
                    writer.close()
        except socket.gaierror as e:
            raise ConnectionError(f"Failed resolving {hostname}") from e
        except ConnectionRefusedError as e:
            raise ConnectionError(f"Failed connecting to {hostname}:{port}") from e
        except TimeoutError as e:
            raise ConnectionError(f"Timeout connecting to {hostname}:{port}") from e
        except OSError as e:
            # Any other socket/TLS failure (e.g. ssl.SSLError during the handshake).
            raise ConnectionError(f"Failed connecting to {hostname}:{port}") from e

        return grpc.ssl_channel_credentials(root_certificates=root_certificates.encode())

    elif tls_config.ca != "":
        ca_certificate = base64.b64decode(tls_config.ca)
        return grpc.ssl_channel_credentials(ca_certificate)
    else:
        return grpc.ssl_channel_credentials()

47 

48 

def aio_secure_channel(target: str, credentials: grpc.ChannelCredentials, grpc_options: dict[str, Any] | None):
    """Open an asyncio gRPC secure channel to *target* using *credentials*.

    The channel gets the project's default options, with any entries in
    *grpc_options* (which may be None) overriding them.
    """
    merged_options = _override_default_grpc_options(grpc_options)
    channel = grpc.aio.secure_channel(target, credentials, options=merged_options)
    return channel

55 

56 

57def _override_default_grpc_options(grpc_options: dict[str, str | int] | None) -> Sequence[Tuple[str, Any]]: 

58 defaults = ( 

59 ("grpc.lb_policy_name", "round_robin"), 

60 # we keep a low keepalive time to avoid idle timeouts on cloud load balancers 

61 ("grpc.keepalive_time_ms", 20000), 

62 ("grpc.keepalive_timeout_ms", 5000), 

63 ("grpc.http2.max_pings_without_data", 0), 

64 ("grpc.keepalive_permit_without_calls", 1), 

65 ) 

66 options = dict(defaults) 

67 options.update(grpc_options or {}) 

68 return tuple(options.items()) 

69 

70 

def configure_grpc_env():
    """Quiet gRPC's native logging unless the user already configured it.

    Sets ``GRPC_VERBOSITY=ERROR`` and ``GLOG_minloglevel=2`` in ``os.environ``,
    but only for variables not already present, so explicit settings win.
    """
    # disable informative logs by default, i.e.:
    # WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
    # I0000 00:00:1739970744.889307 61962 ssl_transport_security.cc:1665] Handshake failed ...
    for var_name, quiet_value in (("GRPC_VERBOSITY", "ERROR"), ("GLOG_minloglevel", "2")):
        os.environ.setdefault(var_name, quiet_value)

79 

80 

@contextmanager
def translate_grpc_exceptions():
    """Translate grpc exceptions to JumpstarterExceptions.

    Use as ``with translate_grpc_exceptions(): ...`` around gRPC calls.

    Raises:
        ConnectionError: for any ``grpc.RpcError``. For an ``AioRpcError``
            with status ``UNAVAILABLE`` or ``UNKNOWN``, the gRPC details go
            into the message and the original exception is dropped from the
            chain (``from None``). Otherwise the original is chained as the cause.
        ConfigurationError: for any ``ValueError`` raised inside the block.

    Any other exception propagates unchanged.
    """
    try:
        yield
    except grpc.aio.AioRpcError as e:
        code = e.code()
        if code == grpc.StatusCode.UNAVAILABLE:
            # tls or other connection errors
            raise ConnectionError(f"grpc error: {e.details()}") from None
        if code == grpc.StatusCode.UNKNOWN:
            # an error returned from our functions
            raise ConnectionError(f"grpc controller responded: {e.details()}") from None
        raise ConnectionError("grpc error") from e
    except grpc.RpcError as e:
        # AioRpcError is handled above; this covers other RpcError types.
        raise ConnectionError("grpc error") from e
    except ValueError as e:
        raise ConfigurationError("grpc error") from e