Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter/jumpstarter/exporter/session.py: 74%

88 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-06 10:21 +0200

1import logging 

2from collections import deque 

3from contextlib import AbstractContextManager, asynccontextmanager, contextmanager, suppress 

4from dataclasses import dataclass, field 

5from logging.handlers import QueueHandler 

6from uuid import UUID 

7 

8import grpc 

9from anyio import Event, TypedAttributeLookupError, sleep 

10from anyio.from_thread import start_blocking_portal 

11from jumpstarter_protocol import ( 

12 jumpstarter_pb2, 

13 jumpstarter_pb2_grpc, 

14 router_pb2_grpc, 

15) 

16 

17from .logging import LogHandler 

18from jumpstarter.common import Metadata, TemporarySocket 

19from jumpstarter.common.streams import StreamRequestMetadata 

20from jumpstarter.driver import Driver 

21from jumpstarter.streams.common import forward_stream 

22from jumpstarter.streams.metadata import MetadataStreamAttributes 

23from jumpstarter.streams.router import RouterStream 

24 

# Module-scoped logger, named after this module via __name__.
logger = logging.getLogger(name=__name__)

26 

27 

@dataclass(kw_only=True)
class Session(
    jumpstarter_pb2_grpc.ExporterServiceServicer,
    router_pb2_grpc.RouterServiceServicer,
    Metadata,
    AbstractContextManager,
):
    """gRPC servicer that exposes a tree of drivers rooted at ``root_device``.

    A single instance is registered as both the exporter-service and the
    router-service servicer (see ``serve_port_async``). Per-driver requests
    are dispatched through ``mapping`` (driver UUID -> driver instance).

    Use it as a context manager. Entering attaches a log handler to the root
    logger and resets the root device. Exiting closes the root device and
    detaches that handler again.
    """

    root_device: Driver
    mapping: dict[UUID, Driver]

    # Bounded buffer (maxlen=32, set in __init__) drained by LogStream.
    # Items are yielded to gRPC clients as-is, so presumably LogHandler
    # stores ready-made response messages here -- confirm in .logging.
    _logging_queue: deque = field(init=False)
    _logging_handler: QueueHandler = field(init=False)

    def __enter__(self):
        """Attach the session log handler and reset the root device.

        If ``reset()`` raises, detach the handler again before re-raising.
        ``__exit__`` is not called when ``__enter__`` fails, so nothing else
        would remove it from the root logger.
        """
        root_logger = logging.getLogger()
        root_logger.addHandler(self._logging_handler)
        try:
            self.root_device.reset()
        except BaseException:
            root_logger.removeHandler(self._logging_handler)
            raise
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the root device and detach the session log handler.

        The handler is removed even if ``close()`` raises, so a failing
        driver cannot leave a stale handler on the root logger. Returns
        ``None``, so exceptions from the ``with`` body are not suppressed.
        """
        try:
            self.root_device.close()
        finally:
            logging.getLogger().removeHandler(self._logging_handler)

    def __init__(self, *args, root_device, **kwargs):
        """Build a session around ``root_device``.

        This explicit ``__init__`` takes precedence over the one
        ``@dataclass`` would generate. Extra positional and keyword arguments
        are forwarded along the MRO. Presumably ``Metadata`` consumes
        ``uuid``/``labels`` there, since ``GetReport`` reads them -- confirm.

        :param root_device: root driver; every driver returned by its
            ``enumerate()`` becomes addressable by UUID.
        """
        super().__init__(*args, **kwargs)

        self.root_device = root_device
        # enumerate() yields (uuid, parent, name, instance) tuples.
        self.mapping = {u: i for (u, _, _, i) in self.root_device.enumerate()}

        self._logging_queue = deque(maxlen=32)
        self._logging_handler = LogHandler(self._logging_queue)

    @asynccontextmanager
    async def serve_port_async(self, port):
        """Serve this session over gRPC on ``port`` for the duration of the context.

        :param port: address accepted by ``add_insecure_port``
            (e.g. ``"unix://..."``); no TLS is configured.
        """
        server = grpc.aio.server()
        server.add_insecure_port(port)

        jumpstarter_pb2_grpc.add_ExporterServiceServicer_to_server(self, server)
        router_pb2_grpc.add_RouterServiceServicer_to_server(self, server)

        await server.start()
        try:
            yield
        finally:
            # No grace period: in-flight RPCs are not waited for.
            await server.stop(grace=None)

    @asynccontextmanager
    async def serve_unix_async(self):
        """Serve on a temporary Unix socket and yield its filesystem path."""
        with TemporarySocket() as path:
            async with self.serve_port_async(f"unix://{path}"):
                yield path

    @contextmanager
    def serve_unix(self):
        """Blocking counterpart of ``serve_unix_async``.

        Runs the async server in a background event loop (an anyio blocking
        portal) and yields the socket path to synchronous callers.
        """
        with start_blocking_portal() as portal:
            with portal.wrap_async_context_manager(self.serve_unix_async()) as path:
                yield path

    def __getitem__(self, key: UUID):
        """Return the driver registered under ``key``; raises ``KeyError`` if unknown."""
        return self.mapping[key]

    async def GetReport(self, request, context):
        """Describe this exporter: its uuid, labels and a report per driver."""
        logger.debug("GetReport()")
        return jumpstarter_pb2.GetReportResponse(
            uuid=str(self.uuid),
            labels=self.labels,
            reports=[
                instance.report(parent=parent, name=name)
                for (_, parent, name, instance) in self.root_device.enumerate()
            ],
        )

    async def DriverCall(self, request, context):
        """Forward a unary driver call to the driver named by ``request.uuid``."""
        logger.debug("DriverCall(uuid=%s, method=%s)", request.uuid, request.method)
        return await self[UUID(request.uuid)].DriverCall(request, context)

    async def StreamingDriverCall(self, request, context):
        """Forward a server-streaming driver call, re-yielding each response."""
        logger.debug("StreamingDriverCall(uuid=%s, method=%s)", request.uuid, request.method)
        async for v in self[UUID(request.uuid)].StreamingDriverCall(request, context):
            yield v

    async def Stream(self, _request_iterator, context):
        """Bridge a bidirectional gRPC stream to a driver-provided byte stream.

        The target driver and stream parameters come from the call's
        invocation metadata, not from the request iterator. Any metadata the
        driver stream exposes is sent back as initial metadata. Data is then
        forwarded both ways until the RPC terminates.
        """
        # list() first: dict() then iterates (key, value) pairs instead of
        # using the mapping protocol if the metadata object exposes keys().
        request = StreamRequestMetadata(**dict(list(context.invocation_metadata()))).request
        logger.debug("Streaming(%s)", request)
        async with self[request.uuid].Stream(request, context) as stream:
            metadata = []
            # Streams without a metadata attribute simply send none.
            with suppress(TypedAttributeLookupError):
                metadata.extend(stream.extra(MetadataStreamAttributes.metadata).items())
            await context.send_initial_metadata(metadata)

            async with RouterStream(context=context) as remote:
                async with forward_stream(remote, stream):
                    # Keep forwarding until gRPC reports the RPC is done.
                    event = Event()
                    context.add_done_callback(lambda _: event.set())
                    await event.wait()

    async def LogStream(self, request, context):
        """Stream buffered log entries to the client, polling every 0.5 s when idle.

        All LogStream callers share one queue, so each entry is delivered to
        only one of them.
        """
        while True:
            try:
                entry = self._logging_queue.popleft()
            except IndexError:
                await sleep(0.5)
            else:
                yield entry