Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter/jumpstarter/client/core.py: 53%
91 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-06 10:21 +0200
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-06 10:21 +0200
1"""
Base class for async driver clients and the errors raised by driver calls
3"""
5import logging
6from contextlib import asynccontextmanager
7from dataclasses import dataclass, field
9from anyio import create_task_group
10from google.protobuf import empty_pb2
11from grpc import StatusCode
12from grpc.aio import AioRpcError, Channel
13from jumpstarter_protocol import jumpstarter_pb2, jumpstarter_pb2_grpc, router_pb2_grpc
15from jumpstarter.common import Metadata
16from jumpstarter.common.exceptions import JumpstarterException
17from jumpstarter.common.resources import ResourceMetadata
18from jumpstarter.common.serde import decode_value, encode_value
19from jumpstarter.common.streams import (
20 DriverStreamRequest,
21 ResourceStreamRequest,
22 StreamRequestMetadata,
23)
24from jumpstarter.streams.common import forward_stream
25from jumpstarter.streams.metadata import MetadataStream, MetadataStreamAttributes
26from jumpstarter.streams.progress import ProgressStream
27from jumpstarter.streams.router import RouterStream
class DriverError(JumpstarterException):
    """Error raised when a call to a driver method fails.

    Serves as the common base of the more specific driver-call errors.
    """
class DriverMethodNotImplemented(DriverError, NotImplementedError):
    """Error raised when the called driver method is not implemented.

    Also a ``NotImplementedError``, so handlers for that built-in catch it too.
    """
class DriverInvalidArgument(DriverError, ValueError):
    """Error raised when a driver method rejects the arguments it was given.

    Also a ``ValueError``, so handlers for that built-in catch it too.
    """
48@dataclass(kw_only=True)
49class AsyncDriverClient(
50 Metadata,
51 jumpstarter_pb2_grpc.ExporterServiceStub,
52 router_pb2_grpc.RouterServiceStub,
53):
54 """
55 Async driver client base class
57 Backing implementation of blocking driver client.
58 """
60 channel: Channel
62 log_level: str = "INFO"
63 logger: logging.Logger = field(init=False)
65 def __post_init__(self):
66 if hasattr(super(), "__post_init__"):
67 super().__post_init__()
68 jumpstarter_pb2_grpc.ExporterServiceStub.__init__(self, self.channel)
69 router_pb2_grpc.RouterServiceStub.__init__(self, self.channel)
70 self.logger = logging.getLogger(self.__class__.__name__)
71 self.logger.setLevel(self.log_level)
73 # add default handler
74 if not self.logger.handlers:
75 handler = logging.StreamHandler()
76 handler.setFormatter(logging.Formatter("%(name)s - %(levelname)s - %(message)s"))
77 self.logger.addHandler(handler)
79 async def call_async(self, method, *args):
80 """Make DriverCall by method name and arguments"""
82 request = jumpstarter_pb2.DriverCallRequest(
83 uuid=str(self.uuid),
84 method=method,
85 args=[encode_value(arg) for arg in args],
86 )
88 try:
89 response = await self.DriverCall(request)
90 except AioRpcError as e:
91 match e.code():
92 case StatusCode.UNIMPLEMENTED:
93 raise DriverMethodNotImplemented(e.details()) from None
94 case StatusCode.INVALID_ARGUMENT:
95 raise DriverInvalidArgument(e.details()) from None
96 case StatusCode.UNKNOWN:
97 raise DriverError(e.details()) from None
98 case _:
99 raise DriverError(e.details()) from e
101 return decode_value(response.result)
103 async def streamingcall_async(self, method, *args):
104 """Make StreamingDriverCall by method name and arguments"""
106 request = jumpstarter_pb2.StreamingDriverCallRequest(
107 uuid=str(self.uuid),
108 method=method,
109 args=[encode_value(arg) for arg in args],
110 )
112 try:
113 async for response in self.StreamingDriverCall(request):
114 yield decode_value(response.result)
115 except AioRpcError as e:
116 match e.code():
117 case StatusCode.UNIMPLEMENTED:
118 raise DriverMethodNotImplemented(e.details()) from None
119 case StatusCode.INVALID_ARGUMENT:
120 raise DriverInvalidArgument(e.details()) from None
121 case StatusCode.UNKNOWN:
122 raise DriverError(e.details()) from None
123 case _:
124 raise DriverError(e.details()) from e
126 @asynccontextmanager
127 async def stream_async(self, method):
128 context = self.Stream(
129 metadata=StreamRequestMetadata.model_construct(request=DriverStreamRequest(uuid=self.uuid, method=method))
130 .model_dump(mode="json", round_trip=True)
131 .items(),
132 )
133 metadata = dict(list(await context.initial_metadata()))
134 async with MetadataStream(stream=RouterStream(context=context), metadata=metadata) as stream:
135 yield stream
137 @asynccontextmanager
138 async def resource_async(
139 self,
140 stream,
141 ):
142 context = self.Stream(
143 metadata=StreamRequestMetadata.model_construct(request=ResourceStreamRequest(uuid=self.uuid))
144 .model_dump(mode="json", round_trip=True)
145 .items(),
146 )
147 metadata = dict(list(await context.initial_metadata()))
148 async with MetadataStream(stream=RouterStream(context=context), metadata=metadata) as rstream:
149 async with forward_stream(ProgressStream(stream=stream), rstream):
150 yield ResourceMetadata(**rstream.extra(MetadataStreamAttributes.metadata)).resource.model_dump(
151 mode="json"
152 )
154 def __log(self, level: int, msg: str):
155 self.logger.log(level, msg)
157 @asynccontextmanager
158 async def log_stream_async(self):
159 async def log_stream():
160 async for response in self.LogStream(empty_pb2.Empty()):
161 self.__log(logging.getLevelName(response.severity), response.message)
163 async with create_task_group() as tg:
164 tg.start_soon(log_stream)
165 try:
166 yield
167 finally:
168 tg.cancel_scope.cancel()