Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter/jumpstarter/exporter/session.py: 74%
88 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-06 10:21 +0200
1import logging
2from collections import deque
3from contextlib import AbstractContextManager, asynccontextmanager, contextmanager, suppress
4from dataclasses import dataclass, field
5from logging.handlers import QueueHandler
6from uuid import UUID
8import grpc
9from anyio import Event, TypedAttributeLookupError, sleep
10from anyio.from_thread import start_blocking_portal
11from jumpstarter_protocol import (
12 jumpstarter_pb2,
13 jumpstarter_pb2_grpc,
14 router_pb2_grpc,
15)
17from .logging import LogHandler
18from jumpstarter.common import Metadata, TemporarySocket
19from jumpstarter.common.streams import StreamRequestMetadata
20from jumpstarter.driver import Driver
21from jumpstarter.streams.common import forward_stream
22from jumpstarter.streams.metadata import MetadataStreamAttributes
23from jumpstarter.streams.router import RouterStream
25logger = logging.getLogger(__name__)
@dataclass(kw_only=True)
class Session(
    jumpstarter_pb2_grpc.ExporterServiceServicer,
    router_pb2_grpc.RouterServiceServicer,
    Metadata,
    AbstractContextManager,
):
    """gRPC servicer that exposes a tree of drivers over the exporter and router services.

    The session is a context manager: entering it attaches the session's log
    handler to the root logger and resets the root device; leaving it closes
    the root device and detaches the handler again.

    NOTE(review): ``self.uuid`` and ``self.labels`` used by ``GetReport`` are
    not defined here; presumably they come from ``Metadata`` -- confirm.
    """

    # Root of the driver tree served by this session.
    root_device: Driver
    # Driver lookup by UUID, built in __init__ from root_device.enumerate().
    mapping: dict[UUID, Driver]

    # Bounded buffer (maxlen=32) drained by LogStream; handed to LogHandler.
    _logging_queue: deque = field(init=False)
    # Handler installed on the root logger for the lifetime of the `with` block.
    _logging_handler: QueueHandler = field(init=False)

    def __enter__(self):
        """Attach the log handler to the root logger, reset the root device, return ``self``.

        If the reset fails the handler is detached again before the exception
        propagates, because ``__exit__`` is not invoked when ``__enter__`` raises.
        """
        root_logger = logging.getLogger()
        root_logger.addHandler(self._logging_handler)
        try:
            self.root_device.reset()
        except BaseException:
            root_logger.removeHandler(self._logging_handler)
            raise
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the root device and detach the log handler from the root logger.

        The handler is removed even when ``close()`` raises, so a failing
        driver cannot leave a stale handler on the root logger. Exceptions
        from the ``with`` body are not suppressed (returns ``None``).
        """
        try:
            self.root_device.close()
        finally:
            logging.getLogger().removeHandler(self._logging_handler)

    def __init__(self, *args, root_device, **kwargs):
        """Initialise the session around ``root_device``.

        Args:
            *args: Forwarded unchanged to the parent initialiser.
            root_device: Root driver; its ``enumerate()`` output populates ``mapping``.
            **kwargs: Forwarded unchanged to the parent initialiser.
        """
        super().__init__(*args, **kwargs)

        self.root_device = root_device
        # enumerate() yields 4-tuples; only the first (UUID) and last (driver
        # instance) elements are needed for the lookup table.
        self.mapping = {u: i for (u, _, _, i) in self.root_device.enumerate()}

        self._logging_queue = deque(maxlen=32)
        self._logging_handler = LogHandler(self._logging_queue)

    @asynccontextmanager
    async def serve_port_async(self, port):
        """Serve this session on ``port`` with an insecure grpc.aio server.

        Args:
            port: Address string passed to ``server.add_insecure_port``.

        The server is started before the context body runs and stopped with
        ``grace=None`` when the body exits, whether normally or by exception.
        """
        server = grpc.aio.server()
        server.add_insecure_port(port)

        # This object implements both services, so it is registered twice.
        jumpstarter_pb2_grpc.add_ExporterServiceServicer_to_server(self, server)
        router_pb2_grpc.add_RouterServiceServicer_to_server(self, server)

        await server.start()
        try:
            yield
        finally:
            await server.stop(grace=None)

    @asynccontextmanager
    async def serve_unix_async(self):
        """Serve this session on a temporary unix socket and yield the socket path."""
        with TemporarySocket() as path:
            async with self.serve_port_async(f"unix://{path}"):
                yield path

    @contextmanager
    def serve_unix(self):
        """Synchronous wrapper around ``serve_unix_async``; yields the socket path.

        The async context manager is driven through an anyio blocking portal.
        """
        with start_blocking_portal() as portal:
            with portal.wrap_async_context_manager(self.serve_unix_async()) as path:
                yield path

    def __getitem__(self, key: UUID):
        """Return the driver registered under ``key``.

        Raises:
            KeyError: If no driver with that UUID is in ``mapping``.
        """
        return self.mapping[key]

    async def GetReport(self, request, context):
        """Return a ``GetReportResponse`` with this session's uuid, labels and one report per driver."""
        logger.debug("GetReport()")
        return jumpstarter_pb2.GetReportResponse(
            uuid=str(self.uuid),
            labels=self.labels,
            reports=[
                instance.report(parent=parent, name=name)
                for (_, parent, name, instance) in self.root_device.enumerate()
            ],
        )

    async def DriverCall(self, request, context):
        """Dispatch a unary call to the driver identified by ``request.uuid``."""
        logger.debug("DriverCall(uuid=%s, method=%s)", request.uuid, request.method)
        return await self[UUID(request.uuid)].DriverCall(request, context)

    async def StreamingDriverCall(self, request, context):
        """Dispatch a streaming call to the driver identified by ``request.uuid`` and relay its results."""
        logger.debug("StreamingDriverCall(uuid=%s, method=%s)", request.uuid, request.method)
        async for v in self[UUID(request.uuid)].StreamingDriverCall(request, context):
            yield v

    async def Stream(self, _request_iterator, context):
        """Bridge the RPC to a driver stream until the RPC completes.

        The stream request is reconstructed from the call's invocation
        metadata; the request iterator itself is not read here.
        """
        request = StreamRequestMetadata(**dict(list(context.invocation_metadata()))).request
        logger.debug("Streaming(%s)", request)
        async with self[request.uuid].Stream(request, context) as stream:
            metadata = []
            # Streams without the metadata attribute simply contribute nothing.
            with suppress(TypedAttributeLookupError):
                metadata.extend(stream.extra(MetadataStreamAttributes.metadata).items())
            await context.send_initial_metadata(metadata)

            async with RouterStream(context=context) as remote:
                async with forward_stream(remote, stream):
                    # Keep the forwarding contexts open until the RPC is done.
                    event = Event()
                    context.add_done_callback(lambda _: event.set())
                    await event.wait()

    async def LogStream(self, request, context):
        """Yield buffered log entries from the logging queue indefinitely.

        When the queue is empty the generator sleeps for 0.5 seconds before
        polling again; it contains no exit condition of its own.
        """
        while True:
            try:
                yield self._logging_queue.popleft()
            except IndexError:
                await sleep(0.5)