Coverage for src/par_run/executor.py: 85%

271 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-13 15:32 -0400

1"""Todo""" 

2 

3import asyncio 

4import configparser 

5import enum 

6import multiprocessing as mp 

7import os 

8import queue 

9import subprocess 

10import time 

11from collections import OrderedDict 

12from concurrent.futures import Future, ProcessPoolExecutor 

13from pathlib import Path 

14from queue import Queue 

15from typing import Any, Optional, Protocol, TypeVar, Union 

16 

17import tomlkit 

18from pydantic import BaseModel, ConfigDict, Field 

19 

# Type alias for a generic future: either a concurrent.futures.Future (as returned by
# ProcessPoolExecutor.submit in CommandGroup.run) or an asyncio.Future (as returned by
# loop.run_in_executor in CommandGroup.run_async).
GenFuture = Union[Future, asyncio.Future]

# NOTE(review): not referenced by the code in this module; possibly kept for external
# users — confirm before removing.
ContextT = TypeVar("ContextT")

24 

25 

class ProcessingStrategy(enum.Enum):
    """Enum for processing strategies.

    Selects when CommandGroup hands a command's output to its callbacks.
    """

    # Buffer each command's output lines and replay them (on_start, one on_recv per
    # line, then on_term) only once that command's exit code has arrived.
    ON_COMP = "comp"
    # Call on_start for every command up front, then forward each output line to
    # on_recv as soon as it is received from the queue.
    ON_RECV = "recv"

31 

32 

class CommandStatus(enum.Enum):
    """Lifecycle state of a single command, or the aggregate state of a command group."""

    NOT_STARTED = "Not Started"
    RUNNING = "Running"
    SUCCESS = "Success"
    FAILURE = "Failure"

    def completed(self) -> bool:
        """Tell whether this state is terminal, i.e. SUCCESS or FAILURE."""
        return self is CommandStatus.SUCCESS or self is CommandStatus.FAILURE

44 

45 

class Command(BaseModel):
    """A named shell command plus the bookkeeping for one run of it.

    ``name``, ``cmd``, ``passenv``, ``setenv`` and ``status`` are regular fields.
    The remaining fields are declared with ``exclude=True`` so pydantic leaves them
    out of serialisation (``model_dump`` / ``model_dump_json``).
    """

    # Needed because ``fut`` holds Future objects, which pydantic cannot validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str
    cmd: str
    # NOTE(review): stored but never forwarded when the command is run (run_command
    # only receives ``setenv``) — confirm intended semantics.
    passenv: Optional[list[str]] = Field(default=None)
    # Extra environment variables; run_command layers them over a copy of os.environ.
    setenv: Optional[dict[str, str]] = Field(default=None)
    status: CommandStatus = CommandStatus.NOT_STARTED
    # Output lines buffered until they are replayed to callbacks.
    unflushed: list[str] = Field(default_factory=list, exclude=True)
    num_non_empty_lines: int = Field(default=0, exclude=True)
    ret_code: Optional[int] = Field(default=None, exclude=True)
    fut: Optional[GenFuture] = Field(default=None, exclude=True)
    # time.perf_counter() reading; only differences between readings are meaningful.
    start_time: Optional[float] = Field(default=None, exclude=True)
    # Seconds between set_running() and set_ret_code().
    elapsed: Optional[float] = Field(default=None, exclude=True)

    def incr_line_count(self, line: str) -> None:
        """Count ``line`` towards ``num_non_empty_lines`` unless it is blank/whitespace-only."""
        if line.strip():
            self.num_non_empty_lines += 1

    def append_unflushed(self, line: str) -> None:
        """Buffer an output line for later replay.

        The non-empty line count is maintained separately by :meth:`incr_line_count`.
        """
        self.unflushed.append(line)

    def clear_unflushed(self) -> None:
        """Discard all buffered output lines."""
        self.unflushed.clear()

    def set_ret_code(self, ret_code: int) -> None:
        """Record the exit code and mark the command as finished.

        Also records the time elapsed since :meth:`set_running` (when it was called),
        cancels and drops the tracking future, and sets ``status`` to SUCCESS for
        exit code 0 or FAILURE for anything else.
        """
        if self.start_time is not None:
            self.elapsed = time.perf_counter() - self.start_time
        self.ret_code = ret_code
        if self.fut is not None:
            # The result already arrived via the queue, so nothing needs the future any
            # more; cancel() simply returns False if it can no longer be cancelled.
            self.fut.cancel()
            self.fut = None
        self.status = CommandStatus.SUCCESS if ret_code == 0 else CommandStatus.FAILURE

    def set_running(self) -> None:
        """Mark the command as running and remember when it started."""
        self.start_time = time.perf_counter()
        self.status = CommandStatus.RUNNING

94 

class CommandCB(Protocol):
    """Synchronous callback interface used by ``CommandGroup.run``."""

    def on_start(self, cmd: Command) -> None:
        """Announce ``cmd`` before any of its output is delivered."""

    def on_recv(self, cmd: Command, output: str) -> None:
        """Deliver one line of ``cmd``'s output."""

    def on_term(self, cmd: Command, exit_code: int) -> None:
        """Report that ``cmd`` has finished with ``exit_code``."""

99 

100 

class CommandAsyncCB(Protocol):
    """Asynchronous counterpart of the callback interface, awaited by ``CommandGroup.run_async``."""

    async def on_start(self, cmd: Command) -> None:
        """Announce ``cmd`` before any of its output is delivered."""

    async def on_recv(self, cmd: Command, output: str) -> None:
        """Deliver one line of ``cmd``'s output."""

    async def on_term(self, cmd: Command, exit_code: int) -> None:
        """Report that ``cmd`` has finished with ``exit_code``."""

105 

106 

class QRetriever:
    """Blocking queue reader that tolerates a bounded number of empty waits.

    Each wait lasts up to ``timeout`` seconds. Once ``retries`` additional waits
    (beyond the first) have also come back empty, :class:`TimeoutError` is raised.
    """

    def __init__(self, q: Queue, timeout: int, retries: int):
        self.q = q
        self.timeout = timeout
        self.retries = retries

    def get(self):
        """Return the next item from the queue, retrying on empty waits.

        Raises:
            TimeoutError: When the first wait and all allowed retries time out.
        """
        waits_left = self.retries
        while True:
            try:
                return self.q.get(block=True, timeout=self.timeout)
            except queue.Empty:  # noqa: PERF203
                if waits_left <= 0:
                    raise TimeoutError("Timeout waiting for command output") from None
                waits_left -= 1

    def __str__(self) -> str:
        return f"QRetriever(timeout={self.timeout}, retries={self.retries})"

127 

128 

class CommandGroup(BaseModel):
    """Holder for a group of commands and the settings used to run them.

    Fields:
        name: Group name.
        desc: Optional free-text description.
        cmds: Commands keyed by name; insertion order is the run order when ``serial``.
        timeout: Seconds to wait for each queue message before counting one retry.
        retries: Extra waits allowed before ``TimeoutError`` is raised.
        cont_on_fail: With ``serial``, keep running later commands after one fails.
        serial: Run the commands one at a time, in order, instead of all at once.
        status: Aggregate status, updated when a run finishes.
    """

    name: str
    desc: Optional[str] = None
    cmds: OrderedDict[str, Command] = Field(default_factory=OrderedDict)
    timeout: int = Field(default=30)
    retries: int = Field(default=3)
    cont_on_fail: bool = Field(default=False)
    serial: bool = Field(default=False)
    status: CommandStatus = CommandStatus.NOT_STARTED

    def update_status(self, cmds: OrderedDict[str, Command]) -> None:
        """Set ``status`` to SUCCESS if every command in ``cmds`` succeeded, else FAILURE."""
        if all(cmd.status == CommandStatus.SUCCESS for cmd in cmds.values()):
            self.status = CommandStatus.SUCCESS
        else:
            self.status = CommandStatus.FAILURE

    def _cmd_series(self) -> list[OrderedDict[str, Command]]:
        """Return the batches to run: one single-command batch per command if serial, else one batch."""
        if self.serial:
            return [OrderedDict([(name, cmd)]) for name, cmd in self.cmds.items()]
        return [self.cmds]

    @staticmethod
    def _parse_q_message(q_result: tuple) -> tuple[str, Optional[int], Optional[str]]:
        """Split a ``(cmd_name, payload)`` queue message.

        ``run_command`` sends ``str`` payloads for output lines and a final ``int``
        payload carrying the exit code.

        Returns:
            ``(cmd_name, exit_code, output_line)`` with exactly one of the last two set.

        Raises:
            ValueError: If the payload is neither an ``int`` nor a ``str``.
        """
        cmd_name, payload = q_result[0], q_result[1]
        exit_code: Optional[int] = payload if isinstance(payload, int) else None
        output_line: Optional[str] = payload if isinstance(payload, str) else None
        if exit_code is None and output_line is None:
            raise ValueError("Invalid Q message")  # pragma: no cover
        return cmd_name, exit_code, output_line

    async def run_async(
        self,
        strategy: ProcessingStrategy,
        callbacks: CommandAsyncCB,
    ) -> int:
        """Run the group's commands in worker processes, awaiting ``callbacks``.

        Async counterpart of :meth:`run`, with the same batching (``serial``),
        ``cont_on_fail`` handling and final status update.

        Returns:
            0 if every command that ran exited with 0, otherwise 1.

        Raises:
            TimeoutError: If no queue message arrives within the ``timeout``/``retries`` budget.
        """
        loop = asyncio.get_running_loop()
        group_exit_code = 0
        with mp.Manager() as manager:
            q = manager.Queue()
            pool = ProcessPoolExecutor()
            try:
                for cmd_entries in self._cmd_series():
                    if not cmd_entries:
                        # An empty batch never produces a queue message; skip it instead
                        # of waiting for the timeout.
                        continue
                    for cmd in cmd_entries.values():
                        cmd.fut = loop.run_in_executor(pool, run_command, cmd.name, cmd.cmd, cmd.setenv, q)
                        cmd.set_running()
                    exit_code = await self._process_q_async(q, strategy, callbacks, cmd_entries)
                    if exit_code != 0:
                        group_exit_code = 1
                        if not self.cont_on_fail:
                            break
            finally:
                # Don't block on commands that may still be running (e.g. after a
                # TimeoutError); just stop the pool and drop any pending work.
                pool.shutdown(wait=False, cancel_futures=True)
        self.update_status(self.cmds)
        return group_exit_code

    def run(self, strategy: ProcessingStrategy, callbacks: CommandCB) -> int:
        """Run the group's commands in worker processes, reporting through ``callbacks``.

        With ``serial`` the commands run one at a time in order; otherwise all at once.
        After a failing batch, later batches are skipped unless ``cont_on_fail`` is set.
        ``status`` is updated from all of the group's commands at the end.

        Returns:
            0 if every command that ran exited with 0, otherwise 1.

        Raises:
            TimeoutError: If no queue message arrives within the ``timeout``/``retries`` budget.
        """
        group_exit_code = 0
        with mp.Manager() as manager:
            q = manager.Queue()
            pool = ProcessPoolExecutor()
            try:
                for cmd_entries in self._cmd_series():
                    if not cmd_entries:
                        # An empty batch never produces a queue message; skip it instead
                        # of waiting for the timeout.
                        continue
                    for cmd in cmd_entries.values():
                        cmd.fut = pool.submit(run_command, cmd.name, cmd.cmd, cmd.setenv, q)
                        cmd.set_running()
                    exit_code = self._process_q(cmd_entries, q, strategy, callbacks)
                    if exit_code != 0:
                        group_exit_code = 1
                        if not self.cont_on_fail:
                            break
            finally:
                # Don't block on commands that may still be running (e.g. after a
                # TimeoutError); just stop the pool and drop any pending work.
                pool.shutdown(wait=False, cancel_futures=True)
        # Judge the whole group, not just the last serial batch.
        self.update_status(self.cmds)
        return group_exit_code

    def _process_q(  # noqa: PLR0912
        self,
        cmds: OrderedDict[str, Command],
        q: Queue,
        strategy: ProcessingStrategy,
        callbacks: CommandCB,
    ) -> int:
        """Consume queue messages until every command in ``cmds`` has finished.

        Returns:
            0 if all of ``cmds`` exited with 0, otherwise 1.

        Raises:
            TimeoutError: Propagated from :class:`QRetriever` when the queue stays empty.
        """
        grp_exit_code = 0

        if strategy == ProcessingStrategy.ON_RECV:
            for cmd in cmds.values():
                callbacks.on_start(cmd)

        q_ret = QRetriever(q, self.timeout, self.retries)
        while True:
            cmd_name, exit_code, output_line = self._parse_q_message(q_ret.get())
            cmd = self.cmds[cmd_name]

            if strategy == ProcessingStrategy.ON_RECV:
                if output_line is not None:
                    cmd.incr_line_count(output_line)
                    callbacks.on_recv(cmd, output_line)
                elif exit_code is not None:
                    cmd.set_ret_code(exit_code)
                    callbacks.on_term(cmd, exit_code)
                    if exit_code != 0:
                        grp_exit_code = 1
                else:
                    raise ValueError("Invalid Q message")  # pragma: no cover

            if strategy == ProcessingStrategy.ON_COMP:
                if output_line is not None:
                    cmd.incr_line_count(output_line)
                    cmd.append_unflushed(output_line)
                elif exit_code is not None:
                    callbacks.on_start(cmd)
                    for line in cmd.unflushed:
                        callbacks.on_recv(cmd, line)
                    cmd.clear_unflushed()
                    # Record the result before on_term so callbacks see the final
                    # status and elapsed time, exactly as in the ON_RECV path.
                    cmd.set_ret_code(exit_code)
                    callbacks.on_term(cmd, exit_code)
                    if exit_code != 0:
                        grp_exit_code = 1
                else:
                    raise ValueError("Invalid Q message")  # pragma: no cover

            if all(cmd.status.completed() for cmd in cmds.values()):
                self.update_status(cmds)
                break
        return grp_exit_code

    async def _process_q_async(  # noqa: PLR0912
        self,
        q: Queue,
        strategy: ProcessingStrategy,
        callbacks: CommandAsyncCB,
        cmds: Optional[OrderedDict[str, Command]] = None,
    ) -> int:
        """Async version of :meth:`_process_q`; ``cmds`` defaults to all of the group's commands.

        Returns:
            0 if all of ``cmds`` exited with 0, otherwise 1.

        Raises:
            TimeoutError: Propagated from :class:`QRetriever` when the queue stays empty.
        """
        if cmds is None:
            cmds = self.cmds
        grp_exit_code = 0

        if strategy == ProcessingStrategy.ON_RECV:
            for cmd in cmds.values():
                await callbacks.on_start(cmd)

        q_ret = QRetriever(q, self.timeout, self.retries)
        while True:
            # The queue read blocks for up to timeout * (retries + 1) seconds; run it in
            # a worker thread so the event loop stays responsive meanwhile.
            q_result = await asyncio.to_thread(q_ret.get)
            cmd_name, exit_code, output_line = self._parse_q_message(q_result)
            cmd = self.cmds[cmd_name]

            if strategy == ProcessingStrategy.ON_RECV:
                if output_line is not None:
                    cmd.incr_line_count(output_line)
                    await callbacks.on_recv(cmd, output_line)
                elif exit_code is not None:
                    cmd.set_ret_code(exit_code)
                    await callbacks.on_term(cmd, exit_code)
                    if exit_code != 0:
                        grp_exit_code = 1
                else:
                    raise ValueError("Invalid Q message")  # pragma: no cover

            if strategy == ProcessingStrategy.ON_COMP:
                if output_line is not None:
                    cmd.incr_line_count(output_line)
                    cmd.append_unflushed(output_line)
                elif exit_code is not None:
                    await callbacks.on_start(cmd)
                    for line in cmd.unflushed:
                        await callbacks.on_recv(cmd, line)
                    cmd.clear_unflushed()
                    # Record the result before on_term so callbacks see the final
                    # status and elapsed time, exactly as in the ON_RECV path.
                    cmd.set_ret_code(exit_code)
                    await callbacks.on_term(cmd, exit_code)
                    if exit_code != 0:
                        grp_exit_code = 1
                else:
                    raise ValueError("Invalid Q message")  # pragma: no cover

            if all(cmd.status.completed() for cmd in cmds.values()):
                self.update_status(cmds)
                break
        return grp_exit_code

297 return grp_exit_code 

298 

299 

def read_commands_ini(filename: Union[str, Path]) -> list[CommandGroup]:
    """Read a commands.ini file and return a list of CommandGroup objects.

    Every ``[group.<name>]`` section becomes a CommandGroup named ``<name>``. Each
    ``key = value`` entry in it becomes a Command whose name is the key and whose
    shell command is the value. Other sections are ignored.

    Note that configparser lower-cases option names by default and merges any
    ``[DEFAULT]`` entries into every section.

    Args:
    ----
        filename (Union[str, Path]): The filename of the commands.ini file.

    Returns:
    -------
        list[CommandGroup]: The command groups, in section order. A file that
        cannot be opened yields an empty list, because ``ConfigParser.read``
        silently skips it.

    """
    prefix = "group."
    # Interpolation is disabled so shell syntax such as ``date +%Y`` is read verbatim
    # instead of raising InterpolationSyntaxError.
    config = configparser.ConfigParser(interpolation=None)
    config.read(filename)

    command_groups = []
    for section in config.sections():
        if not section.startswith(prefix):
            continue
        # Strip only the leading prefix; str.replace would also drop later occurrences
        # (e.g. "group.a.group.b" -> "a.b").
        group_name = section.removeprefix(prefix)
        commands = OrderedDict()
        for name, cmd in config.items(section):
            clean_name = name.strip()
            commands[clean_name] = Command(name=clean_name, cmd=cmd.strip())
        command_groups.append(CommandGroup(name=group_name, cmds=commands))

    return command_groups

327 

328 

def write_commands_ini(filename: Union[str, Path], command_groups: list[CommandGroup]):
    """Write a list of CommandGroup objects to a commands.ini file.

    Each group becomes a ``[group.<name>]`` section containing one
    ``<command name> = <command>`` entry per command. configparser lower-cases
    option names by default. The file is overwritten and written as UTF-8.

    Args:
    ----
        filename (Union[str, Path]): The filename of the commands.ini file.
        command_groups (list[CommandGroup]): A list of CommandGroup objects.

    """
    # Interpolation is disabled so commands containing ``%`` (common in shell syntax)
    # are stored verbatim. With the default BasicInterpolation, assigning such a
    # value raises ValueError.
    config = configparser.ConfigParser(interpolation=None)

    for group in command_groups:
        section_name = f"group.{group.name}"
        config[section_name] = {}
        for command in group.cmds.values():
            config[section_name][command.name] = command.cmd

    with Path(filename).open("w", encoding="utf-8") as configfile:
        config.write(configfile)

348 

349 

def _validate_mandatory_keys(data: tomlkit.items.Table, keys: list[str], context: str) -> tuple[Any, ...]:
    """Fetch mandatory keys from a TOML table, failing if any of them is missing.

    A key whose value is falsy (absent, empty string, empty table, 0, false) counts
    as missing.

    Args:
    ----
        data (tomlkit.items.Table): The table to read from.
        keys (list[str]): The mandatory keys, in the order their values are returned.
        context (str): Description of where ``data`` came from, used in the error message.

    Returns:
    -------
        tuple: The value of each key, in ``keys`` order.

    Raises:
    ------
        ValueError: For the first key, in ``keys`` order, whose value is falsy.

    """
    fetched = tuple(data.get(key, None) for key in keys)
    for key, value in zip(keys, fetched):
        if not value:
            raise ValueError(f"{key} is mandatory, not found in {context}")
    return fetched

366 

367 

def _get_optional_keys(data: tomlkit.items.Table, keys: list[str], default=None) -> tuple[Optional[Any], ...]:
    """Fetch optional keys from a TOML table, using ``default`` for absent ones.

    Keys that are present are returned exactly as stored, even when falsy.

    Args:
    ----
        data (tomlkit.items.Table): The table to read from.
        keys (list[str]): The optional keys, in the order their values are returned.
        default: Value used for every key that is not present (``None`` unless given).

    Returns:
    -------
        tuple: One value per entry of ``keys``, in the same order.

    """
    return tuple([data.get(name, default) for name in keys])

379 

380 

def read_commands_toml(filename: Union[str, Path]) -> list[CommandGroup]:
    """Read the ``[tool.par-run]`` table of a TOML file and return its command groups.

    Layout read by this function::

        [[tool.par-run.groups]]
        name = "..."            # mandatory
        desc = "..."            # optional
        timeout = 30            # optional, seconds; missing or 0 -> 30
        retries = 3             # optional; missing -> 3 (0 is kept as "no retries")
        cont_on_fail = false    # optional; only a literal true enables it
        serial = false          # optional; only a literal true enables it

        [[tool.par-run.groups.commands]]
        name = "..."            # mandatory
        exec = "..."            # mandatory
        setenv = { KEY = "v" }  # optional
        passenv = ["KEY"]       # optional

    Args:
    ----
        filename (Union[str, Path]): The filename of the commands.toml file.

    Returns:
    -------
        list[CommandGroup]: A list of CommandGroup objects, in file order.

    Raises:
    ------
        ValueError: If there is no (or an empty) ``[tool.par-run]`` table, or a
            mandatory key is missing or empty.

    """
    with Path(filename).open(encoding="utf-8") as toml_file:
        toml_data = tomlkit.parse(toml_file.read())

    cmd_groups_data = toml_data.get("tool", {}).get("par-run", {})
    if not cmd_groups_data:
        raise ValueError("No par-run data found in toml file")

    command_groups = []
    for group_data in cmd_groups_data.get("groups", []):
        (group_name,) = _validate_mandatory_keys(group_data, ["name"], "top level par-run group")
        group_desc, group_timeout, group_retries = _get_optional_keys(
            group_data,
            ["desc", "timeout", "retries"],
            default=None,
        )
        group_cont_on_fail, group_serial = _get_optional_keys(group_data, ["cont_on_fail", "serial"], default=False)

        # A zero timeout would make every queue wait fail immediately, so it also
        # falls back to the default.
        if not group_timeout:
            group_timeout = 30
        # retries = 0 is a meaningful setting (no extra waits); only a missing value
        # gets the default.
        if group_retries is None:
            group_retries = 3

        commands = OrderedDict()
        for cmd_data in group_data.get("commands", []):
            name, exec_cmd = _validate_mandatory_keys(cmd_data, ["name", "exec"], f"command group {group_name}")
            setenv, passenv = _get_optional_keys(cmd_data, ["setenv", "passenv"], default=None)
            commands[name] = Command(name=name, cmd=exec_cmd, setenv=setenv, passenv=passenv)

        command_groups.append(
            CommandGroup(
                name=group_name,
                desc=group_desc,
                cmds=commands,
                timeout=group_timeout,
                retries=group_retries,
                # Only a literal boolean true enables these flags.
                cont_on_fail=group_cont_on_fail is True,
                serial=group_serial is True,
            ),
        )

    return command_groups

436 

437 

def run_command(name: str, command: str, setenv: Optional[dict[str, str]], q: Queue) -> None:
    """Run a shell command and stream its output into a queue.

    Every output line (stdout and stderr merged) is put as ``(name, line)`` with
    trailing whitespace removed. When the process exits, ``(name, exit_code)`` is
    put with ``exit_code`` as an ``int``.

    Args:
    ----
        name (str): Command name used to tag every queue message.
        command (str): Command line, executed through the shell (``shell=True``).
        setenv (Optional[dict[str, str]]): Extra environment variables layered over
            a copy of ``os.environ``. When empty or None, the current environment
            is inherited unchanged.
        q (Queue): Queue to put the output into.

    Raises:
    ------
        ValueError: If the process has no return code after ``wait()`` (not expected).

    """
    new_env = None
    if setenv:
        new_env = os.environ.copy()
        new_env.update(setenv)

    with subprocess.Popen(
        command,
        shell=True,  # command strings come from the par-run config (Command.cmd)
        env=new_env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        # Undecodable bytes become U+FFFD instead of raising UnicodeDecodeError, which
        # would abort this function before the exit code is queued and leave the
        # reader waiting until its timeout.
        errors="replace",
    ) as process:
        if process.stdout:
            for line in iter(process.stdout.readline, ""):
                # Drop only trailing whitespace (including the newline) so the
                # indentation of the command's output is preserved.
                q.put((name, line.rstrip()))
            process.stdout.close()
        process.wait()
    ret_code = process.returncode
    if ret_code is not None:
        q.put((name, int(ret_code)))
    else:
        raise ValueError("Process has no return code")  # pragma: no cover