Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter/jumpstarter/driver/base.py: 53%

131 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-06 10:21 +0200

1""" 

2Base classes for drivers and driver clients 

3""" 

4 

5from __future__ import annotations 

6 

7import logging 

8from abc import ABCMeta, abstractmethod 

9from contextlib import asynccontextmanager 

10from dataclasses import field 

11from inspect import isasyncgenfunction, iscoroutinefunction 

12from itertools import chain 

13from typing import Any 

14from uuid import UUID, uuid4 

15 

16import aiohttp 

17from anyio import to_thread 

18from grpc import StatusCode 

19from jumpstarter_protocol import jumpstarter_pb2, jumpstarter_pb2_grpc, router_pb2_grpc 

20from pydantic import TypeAdapter 

21from pydantic.dataclasses import dataclass 

22 

23from .decorators import ( 

24 MARKER_DRIVERCALL, 

25 MARKER_MAGIC, 

26 MARKER_STREAMCALL, 

27 MARKER_STREAMING_DRIVERCALL, 

28) 

29from jumpstarter.common import Metadata 

30from jumpstarter.common.resources import ClientStreamResource, PresignedRequestResource, Resource, ResourceMetadata 

31from jumpstarter.common.serde import decode_value, encode_value 

32from jumpstarter.common.streams import ( 

33 DriverStreamRequest, 

34 ResourceStreamRequest, 

35) 

36from jumpstarter.streams.aiohttp import AiohttpStreamReaderStream 

37from jumpstarter.streams.common import create_memory_stream 

38from jumpstarter.streams.metadata import MetadataStream 

39from jumpstarter.streams.progress import ProgressStream 

40 

41 

42@dataclass(kw_only=True) 

43class Driver( 

44 Metadata, 

45 jumpstarter_pb2_grpc.ExporterServiceServicer, 

46 router_pb2_grpc.RouterServiceServicer, 

47 metaclass=ABCMeta, 

48): 

49 """Base class for drivers 

50 

51 Drivers should at the minimum implement the `client` method. 

52 

53 Regular or streaming driver calls can be marked with the `export` decorator. 

54 Raw stream constructors can be marked with the `exportstream` decorator. 

55 """ 

56 

57 children: dict[str, Driver] = field(default_factory=dict) 

58 

59 resources: dict[UUID, Any] = field(default_factory=dict, init=False) 

60 """Dict of client side resources""" 

61 

62 log_level: str = "INFO" 

63 logger: logging.Logger = field(init=False) 

64 

65 def __post_init__(self): 

66 if hasattr(super(), "__post_init__"): 

67 super().__post_init__() 

68 

69 self.logger = logging.getLogger(self.__class__.__name__) 

70 self.logger.setLevel(self.log_level) 

71 

72 def close(self): 

73 for child in self.children.values(): 

74 child.close() 

75 

76 def reset(self): 

77 for child in self.children.values(): 

78 child.reset() 

79 

80 @classmethod 

81 @abstractmethod 

82 def client(cls) -> str: 

83 """ 

84 Return full import path of the corresponding driver client class 

85 """ 

86 

87 def extra_labels(self) -> dict[str, str]: 

88 return {} 

89 

90 async def DriverCall(self, request, context): 

91 """ 

92 :meta private: 

93 """ 

94 try: 

95 method = await self.__lookup_drivercall(request.method, context, MARKER_DRIVERCALL) 

96 

97 args = [decode_value(arg) for arg in request.args] 

98 

99 if iscoroutinefunction(method): 

100 result = await method(*args) 

101 else: 

102 result = await to_thread.run_sync(method, *args) 

103 

104 return jumpstarter_pb2.DriverCallResponse( 

105 uuid=str(uuid4()), 

106 result=encode_value(result), 

107 ) 

108 except NotImplementedError as e: 

109 await context.abort(StatusCode.UNIMPLEMENTED, str(e)) 

110 except ValueError as e: 

111 await context.abort(StatusCode.INVALID_ARGUMENT, str(e)) 

112 except TimeoutError as e: 

113 await context.abort(StatusCode.DEADLINE_EXCEEDED, str(e)) 

114 except Exception as e: 

115 await context.abort(StatusCode.UNKNOWN, str(e)) 

116 

117 async def StreamingDriverCall(self, request, context): 

118 """ 

119 :meta private: 

120 """ 

121 try: 

122 method = await self.__lookup_drivercall(request.method, context, MARKER_STREAMING_DRIVERCALL) 

123 

124 args = [decode_value(arg) for arg in request.args] 

125 

126 if isasyncgenfunction(method): 

127 async for result in method(*args): 

128 yield jumpstarter_pb2.StreamingDriverCallResponse( 

129 uuid=str(uuid4()), 

130 result=encode_value(result), 

131 ) 

132 else: 

133 for result in await to_thread.run_sync(method, *args): 

134 yield jumpstarter_pb2.StreamingDriverCallResponse( 

135 uuid=str(uuid4()), 

136 result=encode_value(result), 

137 ) 

138 except NotImplementedError as e: 

139 await context.abort(StatusCode.UNIMPLEMENTED, str(e)) 

140 except ValueError as e: 

141 await context.abort(StatusCode.INVALID_ARGUMENT, str(e)) 

142 except TimeoutError as e: 

143 await context.abort(StatusCode.DEADLINE_EXCEEDED, str(e)) 

144 except Exception as e: 

145 await context.abort(StatusCode.UNKNOWN, str(e)) 

146 

147 @asynccontextmanager 

148 async def Stream(self, request, context): 

149 """ 

150 :meta private: 

151 """ 

152 match request: 

153 case DriverStreamRequest(method=driver_method): 

154 method = await self.__lookup_drivercall(driver_method, context, MARKER_STREAMCALL) 

155 

156 async with method() as stream: 

157 yield stream 

158 

159 case ResourceStreamRequest(): 

160 remote, resource = create_memory_stream() 

161 

162 resource_uuid = uuid4() 

163 

164 self.resources[resource_uuid] = resource 

165 

166 async with MetadataStream( 

167 stream=remote, 

168 metadata=ResourceMetadata.model_construct( 

169 resource=ClientStreamResource(uuid=resource_uuid) 

170 ).model_dump(mode="json", round_trip=True), 

171 ) as stream: 

172 yield stream 

173 

174 def report(self, *, root=None, parent=None, name=None): 

175 """ 

176 Create DriverInstanceReport 

177 

178 :meta private: 

179 """ 

180 

181 if root is None: 

182 root = self 

183 

184 return jumpstarter_pb2.DriverInstanceReport( 

185 uuid=str(self.uuid), 

186 parent_uuid=str(parent.uuid) if parent else None, 

187 labels=self.labels 

188 | self.extra_labels() 

189 | ({"jumpstarter.dev/client": self.client()}) 

190 | ({"jumpstarter.dev/name": name} if name else {}), 

191 ) 

192 

193 def enumerate(self, *, root=None, parent=None, name=None): 

194 """ 

195 Get list of self and child devices 

196 

197 :meta private: 

198 """ 

199 if root is None: 

200 root = self 

201 

202 return [(self.uuid, parent, name, self)] + list( 

203 chain(*[child.enumerate(root=root, parent=self, name=cname) for (cname, child) in self.children.items()]) 

204 ) 

205 

206 @asynccontextmanager 

207 async def resource(self, handle: str, timeout: int = 300): 

208 handle = TypeAdapter(Resource).validate_python(handle) 

209 match handle: 

210 case ClientStreamResource(uuid=uuid): 

211 async with self.resources[uuid] as stream: 

212 try: 

213 yield stream 

214 finally: 

215 del self.resources[uuid] 

216 case PresignedRequestResource(headers=headers, url=url, method=method): 

217 client_timeout = aiohttp.ClientTimeout(total=timeout) 

218 match method: 

219 case "GET": 

220 async with aiohttp.request( 

221 method, url, headers=headers, raise_for_status=True, timeout=client_timeout 

222 ) as resp: 

223 async with AiohttpStreamReaderStream(reader=resp.content) as stream: 

224 yield ProgressStream(stream=stream, logging=True) 

225 case "PUT": 

226 remote, stream = create_memory_stream() 

227 async with aiohttp.request( 

228 method, url, headers=headers, raise_for_status=True, data=remote, timeout=client_timeout 

229 ) as resp: 

230 async with stream: 

231 yield ProgressStream(stream=stream, logging=True) 

232 case _: 

233 # INVARIANT: method is always one of GET or PUT, see PresignedRequestResource 

234 raise ValueError("unreachable") 

235 

236 async def __lookup_drivercall(self, name, context, marker): 

237 """Lookup drivercall by method name 

238 

239 Methods are checked against magic markers 

240 to avoid accidentally calling non-exported 

241 methods 

242 """ 

243 method = getattr(self, name, None) 

244 

245 if method is None: 

246 await context.abort(StatusCode.NOT_FOUND, f"method {name} not found on driver") 

247 

248 if getattr(method, marker, None) != MARKER_MAGIC: 

249 await context.abort(StatusCode.NOT_FOUND, f"method {name} missing marker {marker}") 

250 

251 return method