Coverage for /Users/ajo/work/jumpstarter/jumpstarter/packages/jumpstarter/jumpstarter/config/client.py: 49%

162 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-06 10:20 +0200

1import os 

2from contextlib import asynccontextmanager, contextmanager 

3from datetime import timedelta 

4from pathlib import Path 

5from typing import ClassVar, Literal, Optional, Self 

6 

7import grpc 

8import yaml 

9from anyio.from_thread import BlockingPortal, start_blocking_portal 

10from pydantic import BaseModel, ConfigDict, Field, ValidationError 

11 

12from .common import CONFIG_PATH, ObjectMeta 

13from .env import JMP_DRIVERS_ALLOW, JMP_ENDPOINT, JMP_LEASE, JMP_NAME, JMP_NAMESPACE, JMP_TOKEN 

14from .grpc import call_credentials 

15from .tls import TLSConfigV1Alpha1 

16from jumpstarter.client.grpc import ClientService 

17from jumpstarter.common.exceptions import FileNotFoundError 

18from jumpstarter.common.grpc import aio_secure_channel, ssl_channel_credentials 

19 

20 

21def _allow_from_env(): 

22 allow = os.environ.get(JMP_DRIVERS_ALLOW) 

23 match allow: 

24 case None: 

25 return [], False 

26 case "UNSAFE": 

27 return [], True 

28 case _: 

29 return allow.split(","), False 

30 

31 

class ClientConfigV1Alpha1Drivers(BaseModel):
    """Driver loading policy of a client config.

    Both fields are forwarded unchanged to ``Lease`` (as ``allow`` and
    ``unsafe``) when a lease is opened.
    """

    # Allowed driver entries. ``default_factory`` must be a zero-argument
    # callable; passing a list instance (``[]``) breaks as soon as the default
    # is needed, e.g. when a config file omits ``drivers.allow``.
    allow: list[str] = Field(default_factory=list)
    # Set when JMP_DRIVERS_ALLOW is "UNSAFE"; presumably lifts the allow-list
    # restriction inside Lease — confirm there.
    unsafe: bool = Field(default=False)

35 

36 

class ClientConfigV1Alpha1(BaseModel):
    """A Jumpstarter client configuration (``kind: ClientConfig``).

    Configs are stored as ``<alias>.yaml`` files under ``CLIENT_CONFIGS_PATH``
    or built from environment variables with :meth:`from_env`. The
    synchronous helpers (``lease``, ``get_exporter``, ...) run their
    ``*_async`` counterparts inside a temporary anyio blocking portal.
    """

    CLIENT_CONFIGS_PATH: ClassVar[Path] = CONFIG_PATH / "clients"

    # Local bookkeeping only: excluded from the YAML written by save()/dump_yaml()
    # and recomputed from the file location by from_file()/save().
    alias: str = Field(default="default")
    path: Path | None = Field(default=None)

    apiVersion: Literal["jumpstarter.dev/v1alpha1"] = Field(default="jumpstarter.dev/v1alpha1")
    kind: Literal["ClientConfig"] = Field(default="ClientConfig")

    metadata: ObjectMeta

    endpoint: str
    tls: TLSConfigV1Alpha1 = Field(default_factory=TLSConfigV1Alpha1)
    token: str
    grpcOptions: dict[str, str | int] | None = Field(default_factory=dict)

    drivers: ClientConfigV1Alpha1Drivers

    async def channel(self):
        """Open a secure gRPC channel to ``endpoint``.

        TLS channel credentials derived from ``tls`` are combined with per-call
        credentials built from ``metadata`` and ``token``; ``grpcOptions`` are
        passed through to the channel.
        """
        credentials = grpc.composite_channel_credentials(
            await ssl_channel_credentials(self.endpoint, self.tls),
            call_credentials("Client", self.metadata, self.token),
        )

        return aio_secure_channel(self.endpoint, credentials, self.grpcOptions)

    async def _client_service(self) -> ClientService:
        """Build a ClientService on a fresh channel scoped to this config's namespace.

        NOTE(review): the channel is never explicitly closed by the callers in
        this class; if ``aio_secure_channel`` returns a ``grpc.aio.Channel``,
        consider closing it after each call — confirm.
        """
        return ClientService(channel=await self.channel(), namespace=self.metadata.namespace)

    @contextmanager
    def lease(
        self,
        selector: str | None = None,
        lease_name: str | None = None,
        duration: timedelta = timedelta(minutes=30),
    ):
        """Synchronous wrapper around :meth:`lease_async`; yields the Lease."""
        with start_blocking_portal() as portal:
            with portal.wrap_async_context_manager(self.lease_async(selector, lease_name, duration, portal)) as lease:
                yield lease

    def get_exporter(self, name: str):
        """Synchronous wrapper around :meth:`get_exporter_async`."""
        with start_blocking_portal() as portal:
            return portal.call(self.get_exporter_async, name)

    def list_exporters(
        self,
        page_size: int | None = None,
        page_token: str | None = None,
        filter: str | None = None,
    ):
        """Synchronous wrapper around :meth:`list_exporters_async`."""
        with start_blocking_portal() as portal:
            return portal.call(self.list_exporters_async, page_size, page_token, filter)

    def list_leases(self, filter: str):
        """Synchronous wrapper around :meth:`list_leases_async`."""
        with start_blocking_portal() as portal:
            return portal.call(self.list_leases_async, filter)

    def create_lease(
        self,
        selector: str,
        duration: timedelta,
    ):
        """Synchronous wrapper around :meth:`create_lease_async`."""
        with start_blocking_portal() as portal:
            return portal.call(self.create_lease_async, selector, duration)

    def delete_lease(
        self,
        name: str,
    ):
        """Synchronous wrapper around :meth:`delete_lease_async`."""
        with start_blocking_portal() as portal:
            return portal.call(self.delete_lease_async, name)

    def update_lease(self, name: str, duration: timedelta):
        """Synchronous wrapper around :meth:`update_lease_async`."""
        with start_blocking_portal() as portal:
            return portal.call(self.update_lease_async, name, duration)

    async def get_exporter_async(self, name: str):
        """Fetch a single exporter by name via ``ClientService.GetExporter``."""
        svc = await self._client_service()
        return await svc.GetExporter(name=name)

    async def list_exporters_async(
        self,
        page_size: int | None = None,
        page_token: str | None = None,
        filter: str | None = None,
    ):
        """List exporters via ``ClientService.ListExporters`` (paging/filter passed through)."""
        svc = await self._client_service()
        return await svc.ListExporters(page_size=page_size, page_token=page_token, filter=filter)

    async def create_lease_async(
        self,
        selector: str,
        duration: timedelta,
    ):
        """Create a lease for exporters matching *selector* via ``ClientService.CreateLease``."""
        svc = await self._client_service()
        return await svc.CreateLease(
            selector=selector,
            duration=duration,
        )

    async def delete_lease_async(self, name: str):
        """Delete the lease *name* via ``ClientService.DeleteLease``; returns None."""
        svc = await self._client_service()
        await svc.DeleteLease(
            name=name,
        )

    async def list_leases_async(self, filter: str):
        """List leases via ``ClientService.ListLeases`` with the given *filter*."""
        svc = await self._client_service()
        return await svc.ListLeases(filter=filter)

    async def update_lease_async(self, name: str, duration: timedelta):
        """Change the duration of lease *name* via ``ClientService.UpdateLease``."""
        svc = await self._client_service()
        return await svc.UpdateLease(name=name, duration=duration)

    @asynccontextmanager
    async def lease_async(
        self,
        selector: str | None,
        lease_name: str | None,
        duration: timedelta,
        portal: BlockingPortal,
    ):
        """Open a Lease and yield it for the duration of the context.

        If *lease_name* is empty, the name is taken from the environment
        variable named by ``JMP_LEASE``. Only when neither provides a name is
        the lease marked for release on exit (``release=True``); an existing,
        named lease is left in place.
        """
        from jumpstarter.client import Lease

        # if no lease_name provided, check if it is set in the environment
        lease_name = lease_name or os.environ.get(JMP_LEASE, "")
        # when no lease name is provided, release the lease on exit
        release_lease = lease_name == ""

        async with Lease(
            channel=await self.channel(),
            namespace=self.metadata.namespace,
            name=lease_name,
            selector=selector,
            duration=duration,
            portal=portal,
            allow=self.drivers.allow,
            unsafe=self.drivers.unsafe,
            release=release_lease,
            tls_config=self.tls,
            grpc_options=self.grpcOptions,
        ) as lease:
            yield lease

    @classmethod
    def from_file(cls, path: os.PathLike):
        """Load a config from the YAML file at *path*.

        ``alias`` is set to the file name minus its final extension
        (``lab.dev.yaml`` -> ``lab.dev``), which is exactly what ``_get_path``
        maps back to the same file; ``path`` records where it was loaded from.

        Raises:
            pydantic.ValidationError: if the YAML does not match the schema.
        """
        # YAML is UTF-8 by spec; don't depend on the platform default encoding.
        with open(path, encoding="utf-8") as f:
            v = cls.model_validate(yaml.safe_load(f))
        v.path = Path(path)
        v.alias = v.path.stem
        return v

    @classmethod
    def ensure_exists(cls):
        """Check if the clients config dir exists, otherwise create it."""
        os.makedirs(cls.CLIENT_CONFIGS_PATH, exist_ok=True)

    @classmethod
    def try_from_env(cls):
        """Like :meth:`from_env`, but return None if the environment is incomplete/invalid."""
        try:
            return cls.from_env()
        except ValidationError:
            return None

    @classmethod
    def from_env(cls):
        """Build a config from the JMP_* environment variables.

        Raises:
            pydantic.ValidationError: when required values (e.g. endpoint,
                token) are missing, since unset variables arrive as None.
        """
        allow, unsafe = _allow_from_env()
        return cls(
            metadata=ObjectMeta(
                namespace=os.environ.get(JMP_NAMESPACE),
                name=os.environ.get(JMP_NAME),
            ),
            endpoint=os.environ.get(JMP_ENDPOINT),
            token=os.environ.get(JMP_TOKEN),
            drivers=ClientConfigV1Alpha1Drivers(
                allow=allow,
                unsafe=unsafe,
            ),
        )

    @classmethod
    def _get_path(cls, alias: str) -> Path:
        """Get the regular path of a client config given an alias.

        ``.yaml`` is appended rather than substituted via ``with_suffix`` so
        that dotted aliases (``lab.dev``) map to ``lab.dev.yaml`` instead of
        colliding on ``lab.yaml``.
        """
        return cls.CLIENT_CONFIGS_PATH / f"{alias}.yaml"

    @classmethod
    def load(cls, alias: str) -> Self:
        """Load a client config by alias.

        Raises:
            FileNotFoundError: (jumpstarter's) if no config file exists for *alias*.
        """
        path = cls._get_path(alias)
        if not path.exists():
            raise FileNotFoundError(f"Client config '{path}' does not exist.")
        return cls.from_file(path)

    @classmethod
    def save(cls, config: Self, path: Optional[os.PathLike] = None) -> Path:
        """Save *config* as YAML and return the path written.

        Without *path*, the file goes to ``<CLIENT_CONFIGS_PATH>/<alias>.yaml``
        (creating the directory if needed). ``config.path`` is updated in both
        cases; ``path`` and ``alias`` are not written to the file.
        """
        if path is None:
            # Ensure the clients dir exists
            cls.ensure_exists()
            # Set the config path before saving
            config.path = cls._get_path(config.alias)
        else:
            config.path = Path(path)
        with config.path.open(mode="w") as f:
            yaml.safe_dump(config.model_dump(mode="json", exclude={"path", "alias"}), f, sort_keys=False)
        return config.path

    @classmethod
    def dump_yaml(cls, config: Self) -> str:
        """Return *config* as YAML in field order, without ``path``/``alias``."""
        return yaml.safe_dump(config.model_dump(mode="json", exclude={"path", "alias"}), sort_keys=False)

    @classmethod
    def exists(cls, alias: str) -> bool:
        """Check if a client config exists by alias."""
        return cls._get_path(alias).exists()

    @classmethod
    def list(cls) -> list[Self]:
        """List the available client configs (every ``*.yaml`` file in the configs dir)."""
        if not cls.CLIENT_CONFIGS_PATH.exists():
            # Return an empty list if the dir does not exist
            return []

        # Only accept YAML files in the list
        return [
            cls.from_file(cls.CLIENT_CONFIGS_PATH / file)
            for file in os.listdir(cls.CLIENT_CONFIGS_PATH)
            if file.endswith(".yaml")
        ]

    @classmethod
    def delete(cls, alias: str) -> Path:
        """Delete a client config by alias and return the removed path.

        Raises:
            FileNotFoundError: (jumpstarter's) if no config file exists for *alias*.
        """
        path = cls._get_path(alias)
        if not path.exists():
            raise FileNotFoundError(f"Client config '{path}' does not exist.")
        path.unlink()
        return path

275 

276 

277class ClientConfigListV1Alpha1(BaseModel): 

278 api_version: Literal["jumpstarter.dev/v1alpha1"] = Field(alias="apiVersion", default="jumpstarter.dev/v1alpha1") 

279 current_config: Optional[str] = Field(alias="currentConfig") 

280 items: list[ClientConfigV1Alpha1] 

281 kind: Literal["ClientConfigList"] = Field(default="ClientConfigList") 

282 

283 def dump_json(self): 

284 return self.model_dump_json(indent=4, by_alias=True) 

285 

286 def dump_yaml(self): 

287 return yaml.safe_dump(self.model_dump(mode="json", by_alias=True), indent=2) 

288 

289 model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)