Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter/jumpstarter/exporter/exporter.py: 28%
98 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-06 10:20 +0200
import logging
from collections.abc import Awaitable, Callable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from dataclasses import dataclass, field

import grpc
from anyio import connect_unix, create_memory_object_stream, create_task_group, sleep
from google.protobuf import empty_pb2
from jumpstarter_protocol import (
    jumpstarter_pb2,
    jumpstarter_pb2_grpc,
)

from jumpstarter.common import Metadata
from jumpstarter.common.streams import connect_router_stream
from jumpstarter.config.tls import TLSConfigV1Alpha1
from jumpstarter.driver import Driver
from jumpstarter.exporter.session import Session
# Module-level logger, named after this module, used by all Exporter methods.
logger = logging.getLogger(name=__name__)
@dataclass(kw_only=True)
class Exporter(AbstractAsyncContextManager, Metadata):
    """Expose a device to clients through the Jumpstarter controller.

    The exporter registers itself (its labels plus the device session's
    reports) with the controller. It then follows the controller's Status
    stream to learn which lease it currently belongs to. For every connection
    request received on that lease, it connects the local Unix-socket session
    to the router endpoint named in the request. Leaving the async context
    unregisters the exporter from the controller.
    """

    # Awaited every time a controller connection is needed. The exporter closes
    # each channel it obtains once it is done with it, so the factory should
    # return a fresh channel per call rather than a shared one.
    channel_factory: Callable[[], Awaitable[grpc.aio.Channel]]
    # Builds the root driver served by each new session.
    device_factory: Callable[[], Driver]
    # Lease name last reported by the controller's Status stream; "" means none.
    lease_name: str = field(init=False, default="")
    # TLS settings passed to connect_router_stream for router connections.
    tls: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1)
    # Extra gRPC options passed to connect_router_stream for router connections.
    grpc_options: dict[str, str] = field(default_factory=dict)

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Unregister this exporter from the controller.

        Returns None, so any exception raised inside the ``async with`` block
        still propagates.
        """
        async with await self.channel_factory() as channel:
            controller = jumpstarter_pb2_grpc.ControllerServiceStub(channel)
            logger.info("Unregistering exporter with controller")
            await controller.Unregister(
                jumpstarter_pb2.UnregisterRequest(
                    reason="TODO",  # placeholder: no real reason is reported yet
                )
            )

    async def _forward_controller_stream(self, name, open_stream, tx, retries=5, backoff=3):
        """Forward a controller server stream into *tx*, reconnecting on failure.

        Args:
            name: Label used in log messages (e.g. "Listen", "Status").
            open_stream: Called with a ControllerServiceStub bound to a fresh
                channel; must return the async iterable of messages to forward.
            tx: Memory-object send stream that receives every message. It is
                closed when this coroutine exits, which ends the consumer's
                ``async for`` over the matching receive stream.
            retries: Consecutive failures tolerated before giving up.
            backoff: Seconds to wait before reconnecting after a failure.

        Raises:
            The exception of a failure that occurs after all *retries* have
            been used up. A stream that ends cleanly restores the full budget.
        """
        retries_left = retries
        async with tx:
            while True:
                try:
                    # A new channel per attempt, closed as soon as the attempt
                    # ends, so retries no longer leak channels.
                    async with await self.channel_factory() as channel:
                        controller = jumpstarter_pb2_grpc.ControllerServiceStub(channel)
                        async for message in open_stream(controller):
                            await tx.send(message)
                except Exception as e:
                    if retries_left <= 0:
                        raise
                    retries_left -= 1
                    logger.info(
                        "%s stream interrupted, restarting in %ss, %s retries left: %s",
                        name,
                        backoff,
                        retries_left,
                        e,
                    )
                    await sleep(backoff)
                else:
                    # The stream ended cleanly: restore the retry budget and
                    # reconnect right away (this path applies no backoff).
                    retries_left = retries

    async def __handle(self, path, endpoint, token, tls_config, grpc_options):
        """Serve one client connection through the router.

        Opens a connection to the session's Unix socket at *path* and passes it,
        along with the router *endpoint* and *token*, to connect_router_stream.
        Errors are logged and swallowed. This task runs in a shared task group,
        so one failing connection must not cancel the others.
        """
        try:
            async with await connect_unix(path) as stream:
                async with connect_router_stream(endpoint, token, stream, tls_config, grpc_options):
                    pass
        except Exception as e:
            logger.info("failed to handle connection: %s", e)

    @asynccontextmanager
    async def session(self):
        """Serve a fresh device session and register it with the controller.

        Creates a Session around a new root device from ``device_factory`` and
        serves it on a Unix socket. It then fetches the session's report over
        that socket and registers this exporter's labels and reports with the
        controller. The controller channel is closed when the session ends.

        Yields:
            The path of the Unix socket the session is served on.
        """
        async with await self.channel_factory() as controller_channel:
            controller = jumpstarter_pb2_grpc.ControllerServiceStub(controller_channel)
            with Session(
                uuid=self.uuid,
                labels=self.labels,
                root_device=self.device_factory(),
            ) as session:
                async with session.serve_unix_async() as path:
                    async with grpc.aio.secure_channel(
                        f"unix://{path}", grpc.local_channel_credentials(grpc.LocalConnectionType.UDS)
                    ) as channel:
                        exporter_service = jumpstarter_pb2_grpc.ExporterServiceStub(channel)
                        response = await exporter_service.GetReport(empty_pb2.Empty())
                        logger.info("Registering exporter with controller")
                        await controller.Register(
                            jumpstarter_pb2.RegisterRequest(
                                labels=self.labels,
                                reports=response.reports,
                            )
                        )
                    yield path

    async def handle(self, lease_name, tg):
        """Serve connection requests arriving on *lease_name*.

        Starts a background task in *tg* that forwards the controller's Listen
        stream for the lease, then opens a device session. For each incoming
        request, it spawns a handler in *tg* that connects the session to the
        router. Runs until cancelled, or until the Listen forwarder exits and
        closes its stream.
        """
        logger.info("Listening for incoming connection requests on lease %s", lease_name)

        listen_tx, listen_rx = create_memory_object_stream()

        tg.start_soon(
            self._forward_controller_stream,
            "Listen",
            lambda controller: controller.Listen(jumpstarter_pb2.ListenRequest(lease_name=lease_name)),
            listen_tx,
        )

        async with self.session() as path:
            async for request in listen_rx:
                logger.info("Handling new connection request on lease %s", lease_name)
                tg.start_soon(
                    self.__handle, path, request.router_endpoint, request.router_token, self.tls, self.grpc_options
                )

    async def serve(self):
        """Register with the controller and serve leases as they are reported.

        Performs one initial registration, then follows the controller's Status
        stream:

        - The first time a non-empty lease name is seen, ``handle`` is started
          for it.
        - When the lease name changes away from a previously seen non-empty
          lease, every task (including open client connections) is cancelled
          and this method returns. ``lease_name`` keeps the new value.

        Presumably the caller invokes ``serve`` again to pick up the new lease
        (TODO: confirm against callers).
        """
        # Initial registration, so the controller knows this exporter before
        # any lease is assigned.
        async with self.session():
            pass

        started = False
        status_tx, status_rx = create_memory_object_stream()

        async with create_task_group() as tg:
            tg.start_soon(
                self._forward_controller_stream,
                "Status",
                lambda controller: controller.Status(jumpstarter_pb2.StatusRequest()),
                status_tx,
            )
            async for status in status_rx:
                if self.lease_name != "" and self.lease_name != status.lease_name:
                    self.lease_name = status.lease_name
                    logger.info("Lease status changed, killing existing connections")
                    tg.cancel_scope.cancel()
                    break
                self.lease_name = status.lease_name
                if not started and self.lease_name != "":
                    started = True
                    tg.start_soon(self.handle, self.lease_name, tg)
                if status.leased:
                    logger.info("Currently leased by %s under %s", status.client_name, status.lease_name)
                else:
                    logger.info("Currently not leased")