Coverage for src/par_run/executor.py: 85%
271 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-13 15:32 -0400
1"""Todo"""
3import asyncio
4import configparser
5import enum
6import multiprocessing as mp
7import os
8import queue
9import subprocess
10import time
11from collections import OrderedDict
12from concurrent.futures import Future, ProcessPoolExecutor
13from pathlib import Path
14from queue import Queue
15from typing import Any, Optional, Protocol, TypeVar, Union
17import tomlkit
18from pydantic import BaseModel, ConfigDict, Field
# A command's worker future: a concurrent.futures.Future when submitted with
# ProcessPoolExecutor.submit, or an asyncio.Future when scheduled with
# loop.run_in_executor.
GenFuture = Union[Future, asyncio.Future]

# Generic type variable; no use of it is visible in this section of the module.
ContextT = TypeVar("ContextT")
class ProcessingStrategy(enum.Enum):
    """When command output is handed to the callbacks.

    ON_COMP buffers a command's output and replays it once the command exits;
    ON_RECV forwards every output line as soon as it is received.
    """

    ON_COMP = "comp"  # report on completion
    ON_RECV = "recv"  # report on receipt
class CommandStatus(enum.Enum):
    """Lifecycle state of a command, also reused for a whole command group."""

    NOT_STARTED = "Not Started"
    RUNNING = "Running"
    SUCCESS = "Success"
    FAILURE = "Failure"

    def completed(self) -> bool:
        """Tell whether this state is terminal, i.e. SUCCESS or FAILURE."""
        return self is CommandStatus.SUCCESS or self is CommandStatus.FAILURE
class Command(BaseModel):
    """Holder for a command, its name, and its runtime state.

    The runtime fields (output buffer, counters, return code, future, timings)
    are excluded from serialization.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str
    cmd: str
    passenv: Optional[list[str]] = Field(default=None)
    setenv: Optional[dict[str, str]] = Field(default=None)
    status: CommandStatus = CommandStatus.NOT_STARTED
    # Output lines buffered until they are flushed to callbacks.
    unflushed: list[str] = Field(default_factory=list, exclude=True)
    num_non_empty_lines: int = Field(default=0, exclude=True)
    ret_code: Optional[int] = Field(default=None, exclude=True)
    # Worker future; a non-pydantic type, hence arbitrary_types_allowed above.
    fut: Optional[GenFuture] = Field(default=None, exclude=True)
    # time.perf_counter() value captured by set_running().
    start_time: Optional[float] = Field(default=None, exclude=True)
    # Seconds between set_running() and set_ret_code().
    elapsed: Optional[float] = Field(default=None, exclude=True)

    def incr_line_count(self, line: str) -> None:
        """Increment the non-empty line count if ``line`` has non-whitespace content."""
        if line.strip():
            self.num_non_empty_lines += 1

    def append_unflushed(self, line: str) -> None:
        """Buffer an output line until it is flushed.

        This does not change the line count; use ``incr_line_count`` for that.
        """
        self.unflushed.append(line)

    def clear_unflushed(self) -> None:
        """Clear the unflushed output."""
        self.unflushed.clear()

    def set_ret_code(self, ret_code: int):
        """Record the exit code, the elapsed time, and the final status.

        Also cancels and drops the stored worker future, if any.
        """
        # Compare against None: any recorded perf_counter value is valid.
        if self.start_time is not None:
            self.elapsed = time.perf_counter() - self.start_time
        self.ret_code = ret_code
        if self.fut:
            self.fut.cancel()
            self.fut = None
        if ret_code == 0:
            self.status = CommandStatus.SUCCESS
        else:
            self.status = CommandStatus.FAILURE

    def set_running(self):
        """Mark the command as running and start its timer."""
        self.start_time = time.perf_counter()
        self.status = CommandStatus.RUNNING
class CommandCB(Protocol):
    """Synchronous callbacks notified while a command group runs."""

    def on_start(self, cmd: Command) -> None:
        """Called for ``cmd`` before any of its output is reported."""

    def on_recv(self, cmd: Command, output: str) -> None:
        """Called with one output line of ``cmd``."""

    def on_term(self, cmd: Command, exit_code: int) -> None:
        """Called once ``cmd`` has exited with ``exit_code``."""
class CommandAsyncCB(Protocol):
    """Awaitable callbacks notified while a command group runs asynchronously."""

    async def on_start(self, cmd: Command) -> None:
        """Awaited for ``cmd`` before any of its output is reported."""

    async def on_recv(self, cmd: Command, output: str) -> None:
        """Awaited with one output line of ``cmd``."""

    async def on_term(self, cmd: Command, exit_code: int) -> None:
        """Awaited once ``cmd`` has exited with ``exit_code``."""
class QRetriever:
    """Blocking queue reader that tolerates a bounded number of empty waits."""

    def __init__(self, q: Queue, timeout: int, retries: int):
        self.q = q
        self.timeout = timeout  # seconds to block per attempt
        self.retries = retries  # extra attempts allowed after an empty wait

    def get(self):
        """Return the next queue item.

        Each attempt blocks for up to ``timeout`` seconds; once ``retries``
        extra attempts have also come back empty, TimeoutError is raised.
        """
        attempts_left = self.retries
        while True:
            try:
                return self.q.get(block=True, timeout=self.timeout)
            except queue.Empty:  # noqa: PERF203
                if attempts_left <= 0:
                    raise TimeoutError("Timeout waiting for command output") from None
                attempts_left -= 1

    def __str__(self) -> str:
        return f"QRetriever(timeout={self.timeout}, retries={self.retries})"
class CommandGroup(BaseModel):
    """Holder for a group of commands and the logic to run them.

    Each command runs in a worker process via ``run_command``, which reports back
    through a shared queue: ``(name, line)`` for each output line and a final
    ``(name, exit_code)`` when the command exits.
    """

    name: str
    desc: Optional[str] = None
    cmds: OrderedDict[str, Command] = Field(default_factory=OrderedDict)
    timeout: int = Field(default=30)  # seconds per queue read attempt
    retries: int = Field(default=3)  # extra queue read attempts before TimeoutError
    cont_on_fail: bool = Field(default=False)  # keep running later batches after a failure
    serial: bool = Field(default=False)  # run commands one at a time, in order
    status: CommandStatus = CommandStatus.NOT_STARTED

    def update_status(self, cmds: OrderedDict[str, Command]):
        """Set the group status to SUCCESS if every command in ``cmds`` succeeded, else FAILURE."""
        if all(cmd.status == CommandStatus.SUCCESS for cmd in cmds.values()):
            self.status = CommandStatus.SUCCESS
        else:
            self.status = CommandStatus.FAILURE

    def _cmd_series(self) -> list[OrderedDict[str, Command]]:
        """Split the commands into batches: one per command when serial, else a single batch."""
        if self.serial:
            return [OrderedDict([(k, v)]) for k, v in self.cmds.items()]
        return [self.cmds]

    @staticmethod
    def _parse_q_msg(q_result) -> tuple[str, Optional[int], Optional[str]]:
        """Split a ``(name, payload)`` queue message into ``(name, exit_code, output_line)``.

        An ``int`` payload is an exit code and a ``str`` payload an output line;
        exactly one of the two returned values is not None.
        """
        cmd_name, payload = q_result[0], q_result[1]
        exit_code: Optional[int] = payload if isinstance(payload, int) else None
        output_line: Optional[str] = payload if isinstance(payload, str) else None
        if exit_code is None and output_line is None:
            raise ValueError("Invalid Q message")  # pragma: no cover
        return cmd_name, exit_code, output_line

    async def run_async(
        self,
        strategy: ProcessingStrategy,
        callbacks: CommandAsyncCB,
    ):
        """Async counterpart of ``run``, with the same batching, return value and status handling.

        Returns:
        -------
            int: 0 if every executed command exited with 0, else 1.

        """
        loop = asyncio.get_running_loop()
        manager = mp.Manager()
        pool = ProcessPoolExecutor()
        group_exit_code = 0
        try:
            q = manager.Queue()
            for cmd_entries in self._cmd_series():
                futs = [
                    loop.run_in_executor(pool, run_command, cmd.name, cmd.cmd, cmd.setenv, q)
                    for cmd in cmd_entries.values()
                ]
                for cmd, fut in zip(cmd_entries.values(), futs):
                    cmd.fut = fut
                    cmd.set_running()
                if await self._process_q_async(q, strategy, callbacks, cmd_entries) != 0:
                    group_exit_code = 1
                    if not self.cont_on_fail:
                        break
        finally:
            # Release the manager server process and the worker pool without
            # blocking on a command that may still be running after a timeout.
            pool.shutdown(wait=False, cancel_futures=True)
            manager.shutdown()
        # Status must reflect every command, not only the last serial batch.
        self.update_status(self.cmds)
        return group_exit_code

    def run(self, strategy: ProcessingStrategy, callbacks: CommandCB):
        """Run the commands in worker processes, reporting progress through ``callbacks``.

        When ``serial`` is set, commands run one at a time; unless ``cont_on_fail``
        is set, execution stops after the first failing batch.

        Returns:
        -------
            int: 0 if every executed command exited with 0, else 1.

        Raises:
        ------
            TimeoutError: If no queue message arrives within the retry budget.

        """
        manager = mp.Manager()
        pool = ProcessPoolExecutor()
        group_exit_code = 0
        try:
            q = manager.Queue()
            for cmd_entries in self._cmd_series():
                futs = [pool.submit(run_command, cmd.name, cmd.cmd, cmd.setenv, q) for cmd in cmd_entries.values()]
                for cmd, fut in zip(cmd_entries.values(), futs):
                    cmd.fut = fut
                    cmd.set_running()
                if self._process_q(cmd_entries, q, strategy, callbacks) != 0:
                    group_exit_code = 1
                    if not self.cont_on_fail:
                        break
        finally:
            # Release the manager server process and the worker pool without
            # blocking on a command that may still be running after a timeout.
            pool.shutdown(wait=False, cancel_futures=True)
            manager.shutdown()
        # Status must reflect every command, not only the last serial batch.
        self.update_status(self.cmds)
        return group_exit_code

    def _process_q(  # noqa: PLR0912
        self,
        cmds: OrderedDict[str, Command],
        q: Queue,
        strategy: ProcessingStrategy,
        callbacks: CommandCB,
    ) -> int:
        """Consume queue messages until every command in ``cmds`` has exited.

        Returns 1 if any of them exited non-zero, else 0.
        """
        grp_exit_code = 0

        if strategy == ProcessingStrategy.ON_RECV:
            for cmd in cmds.values():
                callbacks.on_start(cmd)

        q_ret = QRetriever(q, self.timeout, self.retries)
        # Test before blocking so an empty batch returns at once instead of timing out.
        while not all(cmd.status.completed() for cmd in cmds.values()):
            cmd_name, exit_code, output_line = self._parse_q_msg(q_ret.get())
            cmd = self.cmds[cmd_name]
            if strategy == ProcessingStrategy.ON_RECV:
                if output_line is not None:
                    cmd.incr_line_count(output_line)
                    callbacks.on_recv(cmd, output_line)
                else:
                    cmd.set_ret_code(exit_code)
                    callbacks.on_term(cmd, exit_code)
            elif strategy == ProcessingStrategy.ON_COMP:
                if output_line is not None:
                    cmd.incr_line_count(output_line)
                    cmd.append_unflushed(output_line)
                else:
                    callbacks.on_start(cmd)
                    for line in cmd.unflushed:
                        callbacks.on_recv(cmd, line)
                    cmd.clear_unflushed()
                    # Finalize before on_term so the callback sees the same
                    # status/elapsed time as with ON_RECV.
                    cmd.set_ret_code(exit_code)
                    callbacks.on_term(cmd, exit_code)
            if exit_code is not None and exit_code != 0:
                grp_exit_code = 1

        self.update_status(cmds)
        return grp_exit_code

    async def _process_q_async(  # noqa: PLR0912
        self,
        q: Queue,
        strategy: ProcessingStrategy,
        callbacks: CommandAsyncCB,
        cmds: Optional[OrderedDict[str, Command]] = None,
    ) -> int:
        """Async counterpart of ``_process_q``; ``cmds`` defaults to all of the group's commands.

        Returns 1 if any command in ``cmds`` exited non-zero, else 0.
        """
        if cmds is None:
            cmds = self.cmds
        grp_exit_code = 0

        if strategy == ProcessingStrategy.ON_RECV:
            for cmd in cmds.values():
                await callbacks.on_start(cmd)

        q_ret = QRetriever(q, self.timeout, self.retries)
        # Test before blocking so an empty batch returns at once instead of timing out.
        while not all(cmd.status.completed() for cmd in cmds.values()):
            # QRetriever.get blocks; wait in a worker thread so the event loop stays responsive.
            q_result = await asyncio.to_thread(q_ret.get)
            cmd_name, exit_code, output_line = self._parse_q_msg(q_result)
            cmd = self.cmds[cmd_name]
            if strategy == ProcessingStrategy.ON_RECV:
                if output_line is not None:
                    cmd.incr_line_count(output_line)
                    await callbacks.on_recv(cmd, output_line)
                else:
                    cmd.set_ret_code(exit_code)
                    await callbacks.on_term(cmd, exit_code)
            elif strategy == ProcessingStrategy.ON_COMP:
                if output_line is not None:
                    cmd.incr_line_count(output_line)
                    cmd.append_unflushed(output_line)
                else:
                    await callbacks.on_start(cmd)
                    for line in cmd.unflushed:
                        await callbacks.on_recv(cmd, line)
                    cmd.clear_unflushed()
                    # Finalize before on_term so the callback sees the same
                    # status/elapsed time as with ON_RECV.
                    cmd.set_ret_code(exit_code)
                    await callbacks.on_term(cmd, exit_code)
            if exit_code is not None and exit_code != 0:
                grp_exit_code = 1

        self.update_status(cmds)
        return grp_exit_code
def read_commands_ini(filename: Union[str, Path]) -> list[CommandGroup]:
    """Read a commands.ini file and return a list of CommandGroup objects.

    Only sections named ``group.<name>`` are used. Every option in such a section
    becomes a command, mapping the option name to the command line. Value
    interpolation is disabled, so shell commands may contain ``%`` (e.g. ``date +%Y``).

    Args:
    ----
        filename (Union[str, Path]): The filename of the commands.ini file.

    Returns:
    -------
        list[CommandGroup]: A list of CommandGroup objects.

    """
    prefix = "group."
    config = configparser.ConfigParser(interpolation=None)
    # NOTE: ConfigParser.read skips files it cannot open, so a missing file
    # yields an empty list rather than an error.
    config.read(filename)

    command_groups = []
    for section in config.sections():
        if not section.startswith(prefix):
            continue
        # Remove only the leading prefix; str.replace would also strip later
        # occurrences (e.g. "group.a.group.b" -> "a.b").
        group_name = section[len(prefix) :]
        commands = OrderedDict()
        for name, cmd in config.items(section):
            clean_name = name.strip()
            commands[clean_name] = Command(name=clean_name, cmd=cmd.strip())
        command_groups.append(CommandGroup(name=group_name, cmds=commands))

    return command_groups
def write_commands_ini(filename: Union[str, Path], command_groups: list[CommandGroup]):
    """Write a list of CommandGroup objects to a commands.ini file.

    Each group becomes a ``group.<name>`` section whose options map command
    names to command lines. Value interpolation is disabled, so commands
    containing ``%`` are written verbatim. The default BasicInterpolation
    rejects a lone ``%`` on assignment.

    Args:
    ----
        filename (Union[str, Path]): The filename of the commands.ini file.
        command_groups (list[CommandGroup]): A list of CommandGroup objects.

    """
    config = configparser.ConfigParser(interpolation=None)

    for group in command_groups:
        section_name = f"group.{group.name}"
        config[section_name] = {}
        section = config[section_name]
        for command in group.cmds.values():
            section[command.name] = command.cmd

    with Path(filename).open("w", encoding="utf-8") as configfile:
        config.write(configfile)
def _validate_mandatory_keys(data: tomlkit.items.Table, keys: list[str], context: str) -> tuple[Any, ...]:
    """Fetch the values of ``keys`` from ``data``, requiring each one to be set.

    Args:
    ----
        data (tomlkit.items.Table): The table to read from.
        keys (list[str]): The mandatory keys, in the order their values are returned.
        context (str): Description of where ``data`` came from, used in the error message.

    Returns:
    -------
        tuple[Any, ...]: The values, in the same order as ``keys``.

    Raises:
    ------
        ValueError: If a key is absent or its value is falsy (e.g. an empty string).

    """

    def _require(key: str) -> Any:
        value = data.get(key, None)
        if not value:
            raise ValueError(f"{key} is mandatory, not found in {context}")
        return value

    return tuple(_require(key) for key in keys)
def _get_optional_keys(data: tomlkit.items.Table, keys: list[str], default=None) -> tuple[Optional[Any], ...]:
    """Look up each of ``keys`` in ``data``, using ``default`` for absent ones.

    Args:
    ----
        data (tomlkit.items.Table): The table to read from.
        keys (list[str]): The optional keys, in the order their values are returned.
        default: Value used for any key that is not present.

    Returns:
    -------
        tuple[Optional[Any], ...]: The values, in the same order as ``keys``.

    """
    values = [data.get(key, default) for key in keys]
    return tuple(values)
def read_commands_toml(filename: Union[str, Path]) -> list[CommandGroup]:
    """Read the ``tool.par-run`` table of a TOML file into CommandGroup objects.

    Each entry of ``groups`` needs a ``name``. It may also set:

    - ``desc``
    - ``timeout``: default 30, also used when the value is 0/falsy.
    - ``retries``: default 3 when absent.
    - ``cont_on_fail`` and ``serial``: enabled only when the value is the boolean True.

    Each entry of a group's ``commands`` needs ``name`` and ``exec``, and may
    set ``setenv`` and ``passenv``.

    Args:
    ----
        filename (Union[str, Path]): The filename of the commands.toml file.

    Returns:
    -------
        list[CommandGroup]: A list of CommandGroup objects.

    Raises:
    ------
        ValueError: If there is no ``tool.par-run`` data, or a mandatory key is missing.

    """
    with Path(filename).open(encoding="utf-8") as toml_file:
        toml_data = tomlkit.parse(toml_file.read())

    cmd_groups_data = toml_data.get("tool", {}).get("par-run", {})
    if not cmd_groups_data:
        raise ValueError("No par-run data found in toml file")

    command_groups = []
    for group_data in cmd_groups_data.get("groups", []):
        (group_name,) = _validate_mandatory_keys(group_data, ["name"], "top level par-run group")
        group_desc, group_timeout, group_retries = _get_optional_keys(
            group_data,
            ["desc", "timeout", "retries"],
            default=None,
        )
        group_cont_on_fail, group_serial = _get_optional_keys(group_data, ["cont_on_fail", "serial"], default=False)

        commands = OrderedDict()
        for cmd_data in group_data.get("commands", []):
            # `exec_cmd`, not `exec`, to avoid shadowing the builtin.
            name, exec_cmd = _validate_mandatory_keys(cmd_data, ["name", "exec"], f"command group {group_name}")
            setenv, passenv = _get_optional_keys(cmd_data, ["setenv", "passenv"], default=None)
            commands[name] = Command(name=name, cmd=exec_cmd, setenv=setenv, passenv=passenv)

        command_groups.append(
            CommandGroup(
                name=group_name,
                desc=group_desc,
                cmds=commands,
                # A zero timeout would make every queue read fail at once, so fall back.
                timeout=group_timeout or 30,
                # Zero retries is a legitimate setting; only a missing value falls back.
                retries=3 if group_retries is None else group_retries,
                cont_on_fail=group_cont_on_fail is True,
                serial=group_serial is True,
            ),
        )

    return command_groups
def run_command(name: str, command: str, setenv: Optional[dict[str, str]], q: Queue) -> None:
    """Run a shell command and stream its output into a queue.

    Stdout and stderr are merged. Every output line is sent as ``(name, line)``
    with trailing whitespace removed. The exit status is sent last, as
    ``(name, exit_code)`` with an ``int`` exit code.

    Args:
    ----
        name (str): Name used to tag every queue message.
        command (str): Command line, executed through the shell.
        setenv (Optional[dict[str, str]]): Variables added to or overriding the
            inherited environment. None or empty keeps the environment unchanged.
        q (Queue): Queue to put the messages into.

    """
    new_env = None
    if setenv:
        new_env = os.environ.copy()
        new_env.update(setenv)

    # shell=True is deliberate so configured commands can use shell syntax;
    # the command strings come from the par-run config, not from remote input.
    with subprocess.Popen(
        command,
        shell=True,
        env=new_env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        # Undecodable bytes must not crash the worker before the exit code is sent,
        # or the parent would wait until its queue timeout.
        errors="replace",
    ) as process:
        if process.stdout:
            for line in iter(process.stdout.readline, ""):
                # Strip only trailing whitespace to keep indentation (tracebacks, nested output).
                q.put((name, line.rstrip()))
            process.stdout.close()
        process.wait()
        ret_code = process.returncode
        if ret_code is not None:
            q.put((name, int(ret_code)))
        else:
            raise ValueError("Process has no return code")  # pragma: no cover