Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter/jumpstarter/exporter/exporter.py: 28%

98 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-06 10:20 +0200

import logging
from collections.abc import Awaitable, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from dataclasses import dataclass, field

import grpc
from anyio import CancelScope, connect_unix, create_memory_object_stream, create_task_group, sleep
from google.protobuf import empty_pb2
from jumpstarter_protocol import (
    jumpstarter_pb2,
    jumpstarter_pb2_grpc,
)

from jumpstarter.common import Metadata
from jumpstarter.common.streams import connect_router_stream
from jumpstarter.config.tls import TLSConfigV1Alpha1
from jumpstarter.driver import Driver
from jumpstarter.exporter.session import Session

19 

# Logger for this module; all exporter log output in this file goes through it.
logger = logging.getLogger(name=__name__)

21 

22 

@dataclass(kw_only=True)
class Exporter(AbstractAsyncContextManager, Metadata):
    """Expose a local device through the Jumpstarter controller.

    The exporter does the following:

    - Registers the device with the controller. Registration includes the
      reports returned by the session's local ``ExporterService``.
    - Follows the controller's ``Status`` stream to learn which lease is active.
    - While a lease is active, forwards each request from the ``Listen`` stream
      to a per-session Unix socket.
    - Unregisters from the controller when the ``async with`` context exits.
    """

    # Coroutine factory returning a channel to the controller. Every use below
    # awaits it and closes the channel afterwards, so it presumably returns a
    # fresh channel per call (TODO confirm against callers).
    channel_factory: Callable[[], Awaitable[grpc.aio.Channel]]
    # Builds the root driver served in each session.
    device_factory: Callable[[], Driver]
    # Lease reported by the most recent controller status; "" means no lease.
    lease_name: str = field(init=False, default="")
    # TLS settings and gRPC options passed to connect_router_stream for router connections.
    tls: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1)
    grpc_options: dict[str, str] = field(default_factory=dict)

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Unregister this exporter from the controller on context exit.

        Returns ``None``, so exceptions raised in the ``async with`` body propagate.
        """
        async with self._controller() as controller:
            logger.info("Unregistering exporter with controller")
            await controller.Unregister(
                jumpstarter_pb2.UnregisterRequest(
                    reason="TODO",
                )
            )

    @asynccontextmanager
    async def _controller(self):
        """Yield a ControllerService stub on a new channel; close the channel on exit."""
        channel = await self.channel_factory()
        try:
            yield jumpstarter_pb2_grpc.ControllerServiceStub(channel)
        finally:
            # Shield the close so the channel is still released when the
            # enclosing scope is being cancelled (lease change, shutdown).
            with CancelScope(shield=True):
                await channel.close()

    async def _retry_stream(self, name, open_stream, send, retries, backoff):
        """Forward messages from a server-streaming controller RPC to ``send``, reconnecting on errors.

        The method loops forever; it only leaves by raising or being cancelled.

        Args:
            name: Label used in the interruption log message (e.g. "Listen").
            open_stream: Called with a controller stub; returns the async
                iterable of messages for one connection attempt.
            send: Coroutine function awaited with every received message.
            retries: Consecutive failures tolerated. The budget is restored
                whenever a stream ends without an error.
            backoff: Seconds to sleep before reconnecting after a failure.

        Raises:
            Exception: The last failure, once the retry budget is exhausted.
        """
        retries_left = retries
        while True:
            try:
                async with self._controller() as controller:
                    async for message in open_stream(controller):
                        await send(message)
            except Exception as e:
                if retries_left > 0:
                    retries_left -= 1
                    logger.info(
                        "%s stream interrupted, restarting in %ss, %s retries left: %s",
                        name,
                        backoff,
                        retries_left,
                        e,
                    )
                    await sleep(backoff)
                else:
                    raise
            else:
                retries_left = retries

    async def __handle(self, path, endpoint, token, tls_config, grpc_options):
        """Bridge one router connection to the session's Unix socket at ``path``.

        Any ``Exception`` is logged and swallowed. This task runs inside the
        shared task group, so a raised error would otherwise cancel its siblings.
        """
        try:
            async with await connect_unix(path) as stream:
                async with connect_router_stream(endpoint, token, stream, tls_config, grpc_options):
                    pass
        except Exception as e:
            logger.info("failed to handle connection: %s", e)

    @asynccontextmanager
    async def session(self):
        """Serve a new device session on a Unix socket and register it with the controller.

        Builds the root device via ``device_factory`` and serves it on a Unix
        socket. It then queries the local ``ExporterService`` for reports and
        registers them, with this exporter's labels, with the controller. The
        controller channel stays open for the whole session and is closed on exit.

        Yields:
            The Unix socket path the session is served on.
        """
        async with self._controller() as controller:
            with Session(
                uuid=self.uuid,
                labels=self.labels,
                root_device=self.device_factory(),
            ) as session:
                async with session.serve_unix_async() as path:
                    async with grpc.aio.secure_channel(
                        f"unix://{path}", grpc.local_channel_credentials(grpc.LocalConnectionType.UDS)
                    ) as channel:
                        response = await jumpstarter_pb2_grpc.ExporterServiceStub(channel).GetReport(
                            empty_pb2.Empty()
                        )
                        logger.info("Registering exporter with controller")
                        await controller.Register(
                            jumpstarter_pb2.RegisterRequest(
                                labels=self.labels,
                                reports=response.reports,
                            )
                        )
                    yield path

    async def handle(self, lease_name, tg):
        """Accept connection requests for ``lease_name`` until cancelled.

        Starts a task in ``tg`` that follows the controller's ``Listen`` stream.
        Every request received is then bridged, in its own ``tg`` task, from
        its router endpoint to a session opened for this lease.
        """
        logger.info("Listening for incoming connection requests on lease %s", lease_name)

        listen_tx, listen_rx = create_memory_object_stream()

        async def listen(retries=5, backoff=3):
            await self._retry_stream(
                "Listen",
                lambda controller: controller.Listen(jumpstarter_pb2.ListenRequest(lease_name=lease_name)),
                listen_tx.send,
                retries,
                backoff,
            )

        tg.start_soon(listen)

        async with self.session() as path:
            async for request in listen_rx:
                logger.info("Handling new connection request on lease %s", lease_name)
                tg.start_soon(
                    self.__handle, path, request.router_endpoint, request.router_token, self.tls, self.grpc_options
                )

    async def serve(self):  # noqa: C901
        """Register with the controller, then react to lease status updates.

        When a lease first appears, ``handle`` is started for it. When the
        lease changes away from a previously seen non-empty lease, all tasks
        (including open connections) are cancelled and this method returns.
        NOTE(review): callers presumably invoke ``serve`` again; confirm.
        """
        # initial registration
        async with self.session():
            pass
        started = False
        status_tx, status_rx = create_memory_object_stream()

        async def watch_status(retries=5, backoff=3):
            await self._retry_stream(
                "Status",
                lambda controller: controller.Status(jumpstarter_pb2.StatusRequest()),
                status_tx.send,
                retries,
                backoff,
            )

        async with create_task_group() as tg:
            tg.start_soon(watch_status)
            async for status in status_rx:
                if self.lease_name != "" and self.lease_name != status.lease_name:
                    self.lease_name = status.lease_name
                    logger.info("Lease status changed, killing existing connections")
                    tg.cancel_scope.cancel()
                    break
                self.lease_name = status.lease_name
                if not started and self.lease_name != "":
                    started = True
                    tg.start_soon(self.handle, self.lease_name, tg)
                if status.leased:
                    logger.info("Currently leased by %s under %s", status.client_name, status.lease_name)
                else:
                    logger.info("Currently not leased")