Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter/jumpstarter/driver/base.py: 53%
131 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-06 10:21 +0200
1"""
2Base classes for drivers and driver clients
3"""
5from __future__ import annotations
7import logging
8from abc import ABCMeta, abstractmethod
9from contextlib import asynccontextmanager
10from dataclasses import field
11from inspect import isasyncgenfunction, iscoroutinefunction
12from itertools import chain
13from typing import Any
14from uuid import UUID, uuid4
16import aiohttp
17from anyio import to_thread
18from grpc import StatusCode
19from jumpstarter_protocol import jumpstarter_pb2, jumpstarter_pb2_grpc, router_pb2_grpc
20from pydantic import TypeAdapter
21from pydantic.dataclasses import dataclass
23from .decorators import (
24 MARKER_DRIVERCALL,
25 MARKER_MAGIC,
26 MARKER_STREAMCALL,
27 MARKER_STREAMING_DRIVERCALL,
28)
29from jumpstarter.common import Metadata
30from jumpstarter.common.resources import ClientStreamResource, PresignedRequestResource, Resource, ResourceMetadata
31from jumpstarter.common.serde import decode_value, encode_value
32from jumpstarter.common.streams import (
33 DriverStreamRequest,
34 ResourceStreamRequest,
35)
36from jumpstarter.streams.aiohttp import AiohttpStreamReaderStream
37from jumpstarter.streams.common import create_memory_stream
38from jumpstarter.streams.metadata import MetadataStream
39from jumpstarter.streams.progress import ProgressStream
42@dataclass(kw_only=True)
43class Driver(
44 Metadata,
45 jumpstarter_pb2_grpc.ExporterServiceServicer,
46 router_pb2_grpc.RouterServiceServicer,
47 metaclass=ABCMeta,
48):
49 """Base class for drivers
51 Drivers should at the minimum implement the `client` method.
53 Regular or streaming driver calls can be marked with the `export` decorator.
54 Raw stream constructors can be marked with the `exportstream` decorator.
55 """
57 children: dict[str, Driver] = field(default_factory=dict)
59 resources: dict[UUID, Any] = field(default_factory=dict, init=False)
60 """Dict of client side resources"""
62 log_level: str = "INFO"
63 logger: logging.Logger = field(init=False)
65 def __post_init__(self):
66 if hasattr(super(), "__post_init__"):
67 super().__post_init__()
69 self.logger = logging.getLogger(self.__class__.__name__)
70 self.logger.setLevel(self.log_level)
72 def close(self):
73 for child in self.children.values():
74 child.close()
76 def reset(self):
77 for child in self.children.values():
78 child.reset()
80 @classmethod
81 @abstractmethod
82 def client(cls) -> str:
83 """
84 Return full import path of the corresponding driver client class
85 """
87 def extra_labels(self) -> dict[str, str]:
88 return {}
90 async def DriverCall(self, request, context):
91 """
92 :meta private:
93 """
94 try:
95 method = await self.__lookup_drivercall(request.method, context, MARKER_DRIVERCALL)
97 args = [decode_value(arg) for arg in request.args]
99 if iscoroutinefunction(method):
100 result = await method(*args)
101 else:
102 result = await to_thread.run_sync(method, *args)
104 return jumpstarter_pb2.DriverCallResponse(
105 uuid=str(uuid4()),
106 result=encode_value(result),
107 )
108 except NotImplementedError as e:
109 await context.abort(StatusCode.UNIMPLEMENTED, str(e))
110 except ValueError as e:
111 await context.abort(StatusCode.INVALID_ARGUMENT, str(e))
112 except TimeoutError as e:
113 await context.abort(StatusCode.DEADLINE_EXCEEDED, str(e))
114 except Exception as e:
115 await context.abort(StatusCode.UNKNOWN, str(e))
117 async def StreamingDriverCall(self, request, context):
118 """
119 :meta private:
120 """
121 try:
122 method = await self.__lookup_drivercall(request.method, context, MARKER_STREAMING_DRIVERCALL)
124 args = [decode_value(arg) for arg in request.args]
126 if isasyncgenfunction(method):
127 async for result in method(*args):
128 yield jumpstarter_pb2.StreamingDriverCallResponse(
129 uuid=str(uuid4()),
130 result=encode_value(result),
131 )
132 else:
133 for result in await to_thread.run_sync(method, *args):
134 yield jumpstarter_pb2.StreamingDriverCallResponse(
135 uuid=str(uuid4()),
136 result=encode_value(result),
137 )
138 except NotImplementedError as e:
139 await context.abort(StatusCode.UNIMPLEMENTED, str(e))
140 except ValueError as e:
141 await context.abort(StatusCode.INVALID_ARGUMENT, str(e))
142 except TimeoutError as e:
143 await context.abort(StatusCode.DEADLINE_EXCEEDED, str(e))
144 except Exception as e:
145 await context.abort(StatusCode.UNKNOWN, str(e))
147 @asynccontextmanager
148 async def Stream(self, request, context):
149 """
150 :meta private:
151 """
152 match request:
153 case DriverStreamRequest(method=driver_method):
154 method = await self.__lookup_drivercall(driver_method, context, MARKER_STREAMCALL)
156 async with method() as stream:
157 yield stream
159 case ResourceStreamRequest():
160 remote, resource = create_memory_stream()
162 resource_uuid = uuid4()
164 self.resources[resource_uuid] = resource
166 async with MetadataStream(
167 stream=remote,
168 metadata=ResourceMetadata.model_construct(
169 resource=ClientStreamResource(uuid=resource_uuid)
170 ).model_dump(mode="json", round_trip=True),
171 ) as stream:
172 yield stream
174 def report(self, *, root=None, parent=None, name=None):
175 """
176 Create DriverInstanceReport
178 :meta private:
179 """
181 if root is None:
182 root = self
184 return jumpstarter_pb2.DriverInstanceReport(
185 uuid=str(self.uuid),
186 parent_uuid=str(parent.uuid) if parent else None,
187 labels=self.labels
188 | self.extra_labels()
189 | ({"jumpstarter.dev/client": self.client()})
190 | ({"jumpstarter.dev/name": name} if name else {}),
191 )
193 def enumerate(self, *, root=None, parent=None, name=None):
194 """
195 Get list of self and child devices
197 :meta private:
198 """
199 if root is None:
200 root = self
202 return [(self.uuid, parent, name, self)] + list(
203 chain(*[child.enumerate(root=root, parent=self, name=cname) for (cname, child) in self.children.items()])
204 )
206 @asynccontextmanager
207 async def resource(self, handle: str, timeout: int = 300):
208 handle = TypeAdapter(Resource).validate_python(handle)
209 match handle:
210 case ClientStreamResource(uuid=uuid):
211 async with self.resources[uuid] as stream:
212 try:
213 yield stream
214 finally:
215 del self.resources[uuid]
216 case PresignedRequestResource(headers=headers, url=url, method=method):
217 client_timeout = aiohttp.ClientTimeout(total=timeout)
218 match method:
219 case "GET":
220 async with aiohttp.request(
221 method, url, headers=headers, raise_for_status=True, timeout=client_timeout
222 ) as resp:
223 async with AiohttpStreamReaderStream(reader=resp.content) as stream:
224 yield ProgressStream(stream=stream, logging=True)
225 case "PUT":
226 remote, stream = create_memory_stream()
227 async with aiohttp.request(
228 method, url, headers=headers, raise_for_status=True, data=remote, timeout=client_timeout
229 ) as resp:
230 async with stream:
231 yield ProgressStream(stream=stream, logging=True)
232 case _:
233 # INVARIANT: method is always one of GET or PUT, see PresignedRequestResource
234 raise ValueError("unreachable")
236 async def __lookup_drivercall(self, name, context, marker):
237 """Lookup drivercall by method name
239 Methods are checked against magic markers
240 to avoid accidentally calling non-exported
241 methods
242 """
243 method = getattr(self, name, None)
245 if method is None:
246 await context.abort(StatusCode.NOT_FOUND, f"method {name} not found on driver")
248 if getattr(method, marker, None) != MARKER_MAGIC:
249 await context.abort(StatusCode.NOT_FOUND, f"method {name} missing marker {marker}")
251 return method