Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter/jumpstarter/client/lease.py: 43%

129 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-06 10:20 +0200

1import logging 

2from contextlib import ( 

3 AbstractAsyncContextManager, 

4 AbstractContextManager, 

5 ExitStack, 

6 asynccontextmanager, 

7 contextmanager, 

8) 

9from dataclasses import dataclass, field 

10from datetime import datetime, timedelta 

11from typing import Any 

12 

13from anyio import create_task_group, fail_after, sleep 

14from anyio.from_thread import BlockingPortal 

15from grpc.aio import Channel 

16from jumpstarter_protocol import jumpstarter_pb2, jumpstarter_pb2_grpc 

17 

18from .exceptions import LeaseError 

19from jumpstarter.client import client_from_path 

20from jumpstarter.client.grpc import ClientService 

21from jumpstarter.common import TemporaryUnixListener 

22from jumpstarter.common.condition import condition_false, condition_message, condition_present_and_equal, condition_true 

23from jumpstarter.common.grpc import translate_grpc_exceptions 

24from jumpstarter.common.streams import connect_router_stream 

25from jumpstarter.config.tls import TLSConfigV1Alpha1 

26 

27logger = logging.getLogger(__name__) 

28 

29 

@dataclass(kw_only=True)
class Lease(AbstractContextManager, AbstractAsyncContextManager):
    """Client-side handle for a lease, usable as a sync or async context manager.

    Entering the context requests the lease (creating it first when no
    ``name`` was supplied) and polls until it is ready. Leaving the context
    deletes the lease when ``release`` is true.
    """

    # gRPC channel used to build both the controller stub and the ClientService.
    channel: Channel
    # Requested lease duration, passed to CreateLease.
    duration: timedelta
    # Selector passed to CreateLease.
    selector: str
    # Portal used to run the async methods from synchronous callers.
    portal: BlockingPortal
    # Namespace handed to ClientService.
    namespace: str
    # Name of an existing lease; when unset, a lease is created on request.
    name: str | None = field(default=None)
    # Forwarded to client_from_path() by connect_async().
    allow: list[str]
    # Forwarded to client_from_path() by connect_async().
    unsafe: bool
    release: bool = True  # release on context exit
    # Built from ``channel`` in __post_init__, hence not an init argument.
    controller: jumpstarter_pb2_grpc.ControllerServiceStub = field(init=False)
    # Forwarded to connect_router_stream() by handle_async().
    tls_config: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1)
    # Forwarded to connect_router_stream() by handle_async().
    grpc_options: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Build the gRPC helpers and the synchronous context-manager wrapper."""
        if hasattr(super(), "__post_init__"):
            super().__post_init__()

        self.controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel)
        self.svc = ClientService(channel=self.channel, namespace=self.namespace)
        # __enter__/__exit__ delegate to this wrapper, which drives
        # __aenter__/__aexit__ through the portal.
        self.manager = self.portal.wrap_async_context_manager(self)

    async def _create(self):
        """Create a lease for ``selector``/``duration`` and store its name."""
        logger.debug("Creating lease request for selector %s for duration %s", self.selector, self.duration)
        with translate_grpc_exceptions():
            self.name = (
                await self.svc.CreateLease(
                    selector=self.selector,
                    duration=self.duration,
                )
            ).name
        logger.info("Created lease request for selector %s for duration %s", self.selector, self.duration)

    async def get(self):
        """Fetch the lease named ``self.name`` via ClientService.GetLease."""
        with translate_grpc_exceptions():
            svc = ClientService(channel=self.channel, namespace=self.namespace)
            return await svc.GetLease(name=self.name)

    def request(self):
        """Request a lease, or verifies a lease which was already created.

        Synchronous wrapper that runs :meth:`request_async` through the portal.

        :return: lease
        :rtype: Lease
        :raises LeaseError: if lease is unsatisfiable
        :raises LeaseError: if lease is not pending
        :raises TimeoutError: if lease is not ready after timeout
        """
        return self.portal.call(self.request_async)

    async def request_async(self):
        """Request a lease, or verifies a lease which was already created.

        :return: lease
        :rtype: Lease
        :raises LeaseError: if lease is unsatisfiable
        :raises LeaseError: if lease is not pending
        :raises TimeoutError: if lease is not ready after timeout
        """
        if self.name:
            logger.debug("Using existing lease %s", self.name)
        else:
            await self._create()
        return await self._acquire()

    async def _acquire(self):
        """Acquire a lease.

        Polls the lease once per second until its ``Ready`` condition is
        true, then returns this object.

        :raises LeaseError: if the lease is unsatisfiable, not pending, or released
        :raises TimeoutError: if the lease is not ready within 300 seconds
        """
        with fail_after(300):  # TODO: configurable timeout
            while True:
                logger.debug("Polling Lease %s", self.name)
                result = await self.get()
                # lease ready
                if condition_true(result.conditions, "Ready"):
                    logger.debug("Lease %s acquired", self.name)
                    return self
                # lease unsatisfiable
                if condition_true(result.conditions, "Unsatisfiable"):
                    message = condition_message(result.conditions, "Unsatisfiable")
                    logger.debug("Lease %s cannot be satisfied: %s", self.name, message)
                    raise LeaseError(f"the lease cannot be satisfied: {message}")

                # lease not pending
                if condition_false(result.conditions, "Pending"):
                    raise LeaseError(
                        f"Lease {self.name} is not in pending, but it isn't in Ready or Unsatisfiable state either"
                    )

                # lease released
                if condition_present_and_equal(result.conditions, "Ready", "False", "Released"):
                    raise LeaseError(f"lease {self.name} released")

                await sleep(1)

    async def __aenter__(self):
        """Request (and wait for) the lease; returns this object."""
        return await self.request_async()

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Delete the lease on exit when ``release`` is set."""
        if self.release:
            logger.info("Releasing Lease %s", self.name)
            await self.svc.DeleteLease(
                name=self.name,
            )

    def __enter__(self):
        # wraps the async context manager enter
        return self.manager.__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        # wraps the async context manager exit
        return self.manager.__exit__(exc_type, exc_value, traceback)

    async def handle_async(self, stream):
        """Dial the controller for this lease and attach ``stream`` to the router.

        The router endpoint and token come from the Dial response; the
        router-stream context is entered and left with an empty body.
        """
        logger.debug("Connecting to Lease with name %s", self.name)
        response = await self.controller.Dial(jumpstarter_pb2.DialRequest(lease_name=self.name))
        async with connect_router_stream(
            response.router_endpoint, response.router_token, stream, self.tls_config, self.grpc_options
        ):
            # NOTE(review): presumably the stream is forwarded while this
            # context is open and exit waits for it to finish -- confirm.
            pass

    @asynccontextmanager
    async def serve_unix_async(self):
        """Yield the path produced by a TemporaryUnixListener bound to handle_async."""
        async with TemporaryUnixListener(self.handle_async) as path:
            yield path

    @asynccontextmanager
    async def monitor_async(self, threshold: timedelta = timedelta(minutes=5)):
        """Run a background task that logs when the lease is ending or has ended.

        :param threshold: remaining time below which the lease is reported
            as ending soon
        """

        async def _monitor():
            while True:
                lease = await self.get()
                # TODO: use effective_end_time as the authoritative source for lease end time
                if not lease.effective_begin_time:
                    # no begin time reported yet, poll again shortly
                    await sleep(1)
                    continue

                end_time = lease.effective_begin_time + lease.duration
                # NOTE(review): assumes effective_begin_time is timezone-aware -- confirm.
                remain = end_time - datetime.now(tz=datetime.now().astimezone().tzinfo)
                if remain < timedelta(0):
                    # lease already expired, stopping monitor
                    logger.info("Lease %s ended at %s", self.name, end_time)
                    break
                elif remain < threshold:
                    # lease expiring soon, check again on expected expiration time in case it's extended
                    logger.info("Lease %s ending soon in %s at %s", self.name, remain, end_time)
                    # Sleep only for the time that is left; sleeping for the
                    # whole threshold would overshoot the expiration time.
                    await sleep(remain.total_seconds())
                else:
                    # lease still active, check again in 5 seconds
                    await sleep(5)

        async with create_task_group() as tg:
            tg.start_soon(_monitor)
            try:
                yield
            finally:
                # stop the monitor task once the caller leaves the context
                tg.cancel_scope.cancel()

    @asynccontextmanager
    async def connect_async(self, stack):
        """Yield a client built by client_from_path() over the path from serve_unix_async().

        :param stack: passed through to client_from_path()
        """
        async with self.serve_unix_async() as path:
            async with client_from_path(path, self.portal, stack, allow=self.allow, unsafe=self.unsafe) as client:
                yield client

    @contextmanager
    def connect(self):
        """Synchronous counterpart of :meth:`connect_async`, using a fresh ExitStack."""
        with ExitStack() as stack:
            with self.portal.wrap_async_context_manager(self.connect_async(stack)) as client:
                yield client

    @contextmanager
    def serve_unix(self):
        """Synchronous counterpart of :meth:`serve_unix_async`."""
        with self.portal.wrap_async_context_manager(self.serve_unix_async()) as path:
            yield path

    @contextmanager
    def monitor(self, threshold: timedelta = timedelta(minutes=5)):
        """Synchronous counterpart of :meth:`monitor_async`."""
        with self.portal.wrap_async_context_manager(self.monitor_async(threshold)):
            yield