Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter/jumpstarter/client/core.py: 53%

91 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-06 10:21 +0200

"""
Async driver client base class and the errors raised by driver calls
"""

4 

5import logging 

6from contextlib import asynccontextmanager 

7from dataclasses import dataclass, field 

8 

9from anyio import create_task_group 

10from google.protobuf import empty_pb2 

11from grpc import StatusCode 

12from grpc.aio import AioRpcError, Channel 

13from jumpstarter_protocol import jumpstarter_pb2, jumpstarter_pb2_grpc, router_pb2_grpc 

14 

15from jumpstarter.common import Metadata 

16from jumpstarter.common.exceptions import JumpstarterException 

17from jumpstarter.common.resources import ResourceMetadata 

18from jumpstarter.common.serde import decode_value, encode_value 

19from jumpstarter.common.streams import ( 

20 DriverStreamRequest, 

21 ResourceStreamRequest, 

22 StreamRequestMetadata, 

23) 

24from jumpstarter.streams.common import forward_stream 

25from jumpstarter.streams.metadata import MetadataStream, MetadataStreamAttributes 

26from jumpstarter.streams.progress import ProgressStream 

27from jumpstarter.streams.router import RouterStream 

28 

29 

class DriverError(JumpstarterException):
    """Signals that a driver call completed with an error."""

34 

35 

class DriverMethodNotImplemented(DriverError, NotImplementedError):
    """Signals that the requested driver method has no implementation.

    Also derives from ``NotImplementedError`` so generic handlers for that
    built-in type catch it as well.
    """

40 

41 

class DriverInvalidArgument(DriverError, ValueError):
    """Signals that a driver method rejected the arguments it was called with.

    Also derives from ``ValueError`` so generic handlers for that built-in
    type catch it as well.
    """

46 

47 

48@dataclass(kw_only=True) 

49class AsyncDriverClient( 

50 Metadata, 

51 jumpstarter_pb2_grpc.ExporterServiceStub, 

52 router_pb2_grpc.RouterServiceStub, 

53): 

54 """ 

55 Async driver client base class 

56 

57 Backing implementation of blocking driver client. 

58 """ 

59 

60 channel: Channel 

61 

62 log_level: str = "INFO" 

63 logger: logging.Logger = field(init=False) 

64 

65 def __post_init__(self): 

66 if hasattr(super(), "__post_init__"): 

67 super().__post_init__() 

68 jumpstarter_pb2_grpc.ExporterServiceStub.__init__(self, self.channel) 

69 router_pb2_grpc.RouterServiceStub.__init__(self, self.channel) 

70 self.logger = logging.getLogger(self.__class__.__name__) 

71 self.logger.setLevel(self.log_level) 

72 

73 # add default handler 

74 if not self.logger.handlers: 

75 handler = logging.StreamHandler() 

76 handler.setFormatter(logging.Formatter("%(name)s - %(levelname)s - %(message)s")) 

77 self.logger.addHandler(handler) 

78 

79 async def call_async(self, method, *args): 

80 """Make DriverCall by method name and arguments""" 

81 

82 request = jumpstarter_pb2.DriverCallRequest( 

83 uuid=str(self.uuid), 

84 method=method, 

85 args=[encode_value(arg) for arg in args], 

86 ) 

87 

88 try: 

89 response = await self.DriverCall(request) 

90 except AioRpcError as e: 

91 match e.code(): 

92 case StatusCode.UNIMPLEMENTED: 

93 raise DriverMethodNotImplemented(e.details()) from None 

94 case StatusCode.INVALID_ARGUMENT: 

95 raise DriverInvalidArgument(e.details()) from None 

96 case StatusCode.UNKNOWN: 

97 raise DriverError(e.details()) from None 

98 case _: 

99 raise DriverError(e.details()) from e 

100 

101 return decode_value(response.result) 

102 

103 async def streamingcall_async(self, method, *args): 

104 """Make StreamingDriverCall by method name and arguments""" 

105 

106 request = jumpstarter_pb2.StreamingDriverCallRequest( 

107 uuid=str(self.uuid), 

108 method=method, 

109 args=[encode_value(arg) for arg in args], 

110 ) 

111 

112 try: 

113 async for response in self.StreamingDriverCall(request): 

114 yield decode_value(response.result) 

115 except AioRpcError as e: 

116 match e.code(): 

117 case StatusCode.UNIMPLEMENTED: 

118 raise DriverMethodNotImplemented(e.details()) from None 

119 case StatusCode.INVALID_ARGUMENT: 

120 raise DriverInvalidArgument(e.details()) from None 

121 case StatusCode.UNKNOWN: 

122 raise DriverError(e.details()) from None 

123 case _: 

124 raise DriverError(e.details()) from e 

125 

126 @asynccontextmanager 

127 async def stream_async(self, method): 

128 context = self.Stream( 

129 metadata=StreamRequestMetadata.model_construct(request=DriverStreamRequest(uuid=self.uuid, method=method)) 

130 .model_dump(mode="json", round_trip=True) 

131 .items(), 

132 ) 

133 metadata = dict(list(await context.initial_metadata())) 

134 async with MetadataStream(stream=RouterStream(context=context), metadata=metadata) as stream: 

135 yield stream 

136 

137 @asynccontextmanager 

138 async def resource_async( 

139 self, 

140 stream, 

141 ): 

142 context = self.Stream( 

143 metadata=StreamRequestMetadata.model_construct(request=ResourceStreamRequest(uuid=self.uuid)) 

144 .model_dump(mode="json", round_trip=True) 

145 .items(), 

146 ) 

147 metadata = dict(list(await context.initial_metadata())) 

148 async with MetadataStream(stream=RouterStream(context=context), metadata=metadata) as rstream: 

149 async with forward_stream(ProgressStream(stream=stream), rstream): 

150 yield ResourceMetadata(**rstream.extra(MetadataStreamAttributes.metadata)).resource.model_dump( 

151 mode="json" 

152 ) 

153 

154 def __log(self, level: int, msg: str): 

155 self.logger.log(level, msg) 

156 

157 @asynccontextmanager 

158 async def log_stream_async(self): 

159 async def log_stream(): 

160 async for response in self.LogStream(empty_pb2.Empty()): 

161 self.__log(logging.getLevelName(response.severity), response.message) 

162 

163 async with create_task_group() as tg: 

164 tg.start_soon(log_stream) 

165 try: 

166 yield 

167 finally: 

168 tg.cancel_scope.cancel()