Coverage for src/par_run/cli.py: 90%

219 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-13 15:27 -0400

1"""CLI for running commands in parallel""" 

2 

3import contextlib 

4import enum 

5from collections import OrderedDict 

6from pathlib import Path 

7from typing import Annotated, Optional 

8 

9import rich 

10import typer 

11 

12from .executor import Command, CommandGroup, CommandStatus, ProcessingStrategy, read_commands_toml 

13 

# PID file (relative to the current working directory) in which
# start_web_server records the PID of the background uvicorn process.
PID_FILE: str = ".par-run.uvicorn.pid"

# Root Typer application; sub-commands register via @cli_app.command().
cli_app = typer.Typer()

17 

18 

# Web only functions
def clean_up():
    """Clean up by removing the PID file.

    ``missing_ok=True`` makes this safe when the file has already vanished
    (e.g. removed between a caller's ``is_file()`` check and this call);
    previously that race raised FileNotFoundError.
    """
    Path(PID_FILE).unlink(missing_ok=True)
    typer.echo("Cleaned up PID file.")

24 

25 

def start_web_server(port: int):
    """Start the web server.

    Launches ``uvicorn par_run.web:ws_app --host 0.0.0.0 --port <port>`` as a
    background subprocess (stdout/stderr discarded) and writes its PID to
    PID_FILE. It then polls for up to 3 seconds until the process is seen
    bound to ``port``.

    Exits the program with status 1 if PID_FILE already exists.
    """
    if Path(PID_FILE).is_file():
        typer.echo("UVicorn server is already running.")
        sys.exit(1)
    with Path(PID_FILE).open("w", encoding="utf-8") as pid_file:
        typer.echo(f"Starting UVicorn server on port {port}...")
        uvicorn_command = [
            "uvicorn",
            "par_run.web:ws_app",
            "--host",
            "0.0.0.0",
            "--port",
            str(port),
        ]
        process = subprocess.Popen(uvicorn_command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        pid_file.write(str(process.pid))

    # Wait for UVicorn to start. The PID file was closed (flushed) above, so
    # it is readable by `par-run web status` while we poll.
    wait_seconds = 3
    wait_time_ns = wait_seconds * 10**9
    start_time = time.time_ns()

    while time.time_ns() - start_time < wait_time_ns:
        if process.poll() is not None:
            # uvicorn already exited (e.g. the port is taken); polling further is pointless.
            typer.echo(f"UVicorn server exited early with code {process.returncode}.")
            break
        test_port = None
        # The process can exit, or be unreadable, between poll() and the psutil lookup.
        with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
            test_port = get_process_port(process.pid)
        if port == test_port:
            typer.echo(f"UVicorn server is running on port {port} in {(time.time_ns() - start_time)/10**6:.2f} ms.")
            return
        time.sleep(0.1)  # Poll every 0.1 seconds

    # Report the timeout in seconds (the previous message printed the nanosecond value).
    typer.echo(f"UVicorn server did not respond within {wait_seconds} seconds.")
    typer.echo("run 'par-run web status' to check the status.")

58 

59 

def stop_web_server():
    """Stop the UVicorn server by reading its PID from the PID file and sending a termination signal."""
    pid_path = Path(PID_FILE)
    if pid_path.is_file():
        with pid_path.open() as handle:
            pid = int(handle.read().strip())

        typer.echo(f"Stopping UVicorn server with {pid=:}...")
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            # The recorded process is already gone; the stale PID file is still removed below.
            pass
        clean_up()
    else:
        typer.echo("UVicorn server is not running.")

73 

74 

def get_process_port(pid: int) -> Optional[int]:
    """Return the local port the process ``pid`` is listening on, or None.

    A connection in LISTEN state is preferred. If none is found, this falls
    back to the first connection that has a local address, which was the
    previous behaviour.

    Raises the psutil errors of ``psutil.Process`` (e.g. NoSuchProcess,
    AccessDenied) to the caller.
    """
    process = psutil.Process(pid)
    # psutil >= 6.0 renamed Process.connections() to net_connections() and
    # deprecated the old name; use whichever this psutil version provides.
    get_connections = getattr(process, "net_connections", None) or process.connections
    connections = get_connections()

    for conn in connections:
        # Skip outgoing/accepted sockets so we report the server's own listening port.
        if conn.status == psutil.CONN_LISTEN and conn.laddr:
            return conn.laddr.port
    for conn in connections:
        if conn.laddr:
            return conn.laddr.port
    return None

82 

83 

def list_uvicorn_processes():
    """Check for other UVicorn processes and list them.

    A process matches when "uvicorn" appears in its name (case-insensitive).
    Processes that vanish or deny access while being inspected are skipped
    individually. Previously, a single such error aborted the entire scan.
    """
    # (pid, name) captured at scan time so reporting never queries a dead process.
    uvicorn_processes: list[tuple[int, str]] = []
    for process in psutil.process_iter():
        try:
            process_name = process.name()
        except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
            continue
        if "uvicorn" in process_name.lower():
            uvicorn_processes.append((process.pid, process_name))

    if uvicorn_processes:
        typer.echo("Other UVicorn processes:")
        for pid, name in uvicorn_processes:
            typer.echo(f"PID: {pid}, Name: {name}")
    else:
        typer.echo("No other UVicorn processes found.")

99 

100 

def get_web_server_status():
    """Get the status of the UVicorn server by reading its PID from the PID file.

    - No PID file: reports that and lists any uvicorn processes found.
    - PID file but no such process: reports that and deletes the stale file.
    - Otherwise: reports the PID and, if it can be determined, the port.
    """
    pid_path = Path(PID_FILE)
    if not pid_path.is_file():
        typer.echo("No pid file found. Server likely not running.")
        list_uvicorn_processes()
        return

    with pid_path.open() as pid_file:
        pid = int(pid_file.read().strip())

    if not psutil.pid_exists(pid):
        typer.echo("UVicorn server is not running but pid file exists, deleting it.")
        clean_up()
        return

    port = None
    # The process may exit after pid_exists(), or its sockets may be unreadable
    # (AccessDenied); treat both as "port unknown" rather than crashing.
    with contextlib.suppress(psutil.NoSuchProcess, psutil.AccessDenied):
        port = get_process_port(pid)
    if port:
        typer.echo(f"UVicorn server is running with {pid=}, {port=}")
    else:
        typer.echo(f"UVicorn server is running with {pid=:}, couldn't determine port.")

119 

120 

class WebCommand(enum.Enum):
    """Actions accepted by the ``web`` sub-command; each value is its CLI token."""

    START = "start"
    STOP = "stop"
    RESTART = "restart"
    STATUS = "status"

    def __str__(self) -> str:
        # Render as the bare token (e.g. "start") instead of "WebCommand.START".
        return f"{self.value}"

131 

132 

class CLICommandCBOnComp:
    """Console callbacks used by ``run`` for ProcessingStrategy.ON_COMP.

    Output is printed verbatim, without a command-name prefix.
    """

    def on_start(self, cmd: Command):
        """Print a blue banner naming the command.

        NOTE(review): the text says "Completed command" although this is the
        on_start hook. Presumably the executor calls on_start when a command's
        collected output is about to be shown under ON_COMP; confirm against
        the executor.
        """
        rich.print(f"[blue bold]Completed command {cmd.name}[/]")

    def on_recv(self, _: Command, output: str):
        """Print a piece of command output as-is."""
        rich.print(output)

    def on_term(self, cmd: Command, exit_code: int):
        """Report a finished command: green on SUCCESS, red with the exit code on FAILURE.

        Any other status prints nothing.
        """
        status = cmd.status
        if status == CommandStatus.SUCCESS:
            rich.print(f"[green bold]Command {cmd.name} finished[/]")
        elif status == CommandStatus.FAILURE:
            rich.print(f"[red bold]Command {cmd.name} failed, {exit_code=:}[/]")

146 

147 

class CLICommandCBOnRecv:
    """Console callbacks used by ``run`` for ProcessingStrategy.ON_RECV.

    Every message is prefixed with the command name, so interleaved output
    from concurrently running commands stays attributable.
    """

    def on_start(self, cmd: Command):
        """Announce that ``cmd`` has started."""
        rich.print(f"[blue bold]{cmd.name}: Started[/]")

    def on_recv(self, cmd: Command, output: str):
        """Print a piece of output, prefixed with the command's name."""
        rich.print(f"{cmd.name}: {output}")

    def on_term(self, cmd: Command, exit_code: int):
        """Report a finished command: green on SUCCESS, red with the exit code on FAILURE.

        Any other status prints nothing.
        """
        status = cmd.status
        if status == CommandStatus.SUCCESS:
            rich.print(f"[green bold]{cmd.name}: Finished[/]")
        elif status == CommandStatus.FAILURE:
            rich.print(f"[red bold]{cmd.name}: Failed, {exit_code=:}[/]")

161 

162 

def format_elapsed_time(seconds: float) -> str:
    """Converts a number of seconds into a human-readable time format of HH:MM:SS.xxx

    The value is rounded to whole milliseconds *before* being split into
    hours, minutes and seconds. This way a value such as 59.9999 carries over
    to "00:01:00.000" instead of producing the invalid "00:00:60.000".

    Args:
    ----
        seconds (float): The number of seconds elapsed.

    Returns:
    -------
        str: The formatted time string.

    """
    total_ms = round(seconds * 1000)
    hours, remainder_ms = divmod(total_ms, 3_600_000)
    minutes, remainder_ms = divmod(remainder_ms, 60_000)
    # Seconds keep their millisecond fraction: zero-padded SS.xxx (3 d.p.).
    return f"{hours:02}:{minutes:02}:{remainder_ms / 1000:06.3f}"

181 

182 

def show_commands(groups: list[CommandGroup]):
    """Print each command group, its run parameters, and the commands it contains."""
    for group in groups:
        rich.print(f"[blue bold]Group: {group.name}[/]")
        params_line = (
            f"Params: cont_on_fail={group.cont_on_fail}, serial={group.serial}, "
            f"timeout={group.timeout}, retries={group.retries}"
        )
        rich.print(params_line)
        for command in group.cmds.values():
            rich.print(f"[green bold]{command.name}[/]: {command.cmd}")

191 

192 

def filter_groups(
    group_list: list[CommandGroup], filter_groups: Optional[str], filter_cmds: Optional[str]
) -> list[CommandGroup]:
    """Restrict command groups, and the commands in them, to comma-separated name filters.

    Args:
    ----
        group_list: The command groups to filter.
        filter_groups: Comma-separated group names to keep; None or "" keeps every group.
        filter_cmds: Comma-separated command names to keep; None or "" keeps every command.

    Returns:
    -------
        The groups that pass the filters. When ``filter_cmds`` is given, groups
        left with no matching commands are dropped.

    Note: the command filter reassigns ``grp.cmds`` on the group objects
    themselves, so the caller's groups are modified in place.

    """
    if filter_groups:
        # Parse the filter once instead of re-splitting it for every group.
        wanted_groups = {name.strip() for name in filter_groups.split(",")}
        group_list = [grp for grp in group_list if grp.name in wanted_groups]

    if filter_cmds:
        wanted_cmds = {name.strip() for name in filter_cmds.split(",")}
        for grp in group_list:
            grp.cmds = OrderedDict(
                (cmd_name, cmd) for cmd_name, cmd in grp.cmds.items() if cmd_name in wanted_cmds
            )
        group_list = [grp for grp in group_list if grp.cmds]
    return group_list

210 

211 

def add_table_break(tbl: rich.table.Table, break_ch: str = "-", break_style: Optional[str] = None) -> rich.table.Table:
    """Append a separator row to ``tbl`` and return the same table.

    Each cell is ``break_ch`` repeated to the column's fixed width. Columns
    without a fixed width contribute no cell.
    """
    separator_cells: list[str] = []
    for column in tbl.columns:
        if column.width is None:
            continue
        separator_cells.append(break_ch * int(column.width))
    tbl.add_row(*separator_cells, style=break_style)
    return tbl

219 

220 

def build_results_tbl() -> rich.table.Table:
    """Create the empty "Results" summary table.

    The table has fixed-width, non-wrapping columns: Group, Name, Command,
    Status and Elapsed.
    """
    res_tbl = rich.table.Table(title="Results", show_header=True, header_style="bold blue", box=rich.box.ROUNDED)
    column_specs = (
        ("Group", 10),
        ("Name", 15),
        ("Command", 50),
        ("Status", 6),
        ("Elapsed", 12),
    )
    for heading, width in column_specs:
        res_tbl.add_column(heading, style="bold blue", width=width, no_wrap=True)
    return res_tbl

230 

231 

def add_command_row(tbl: rich.table.Table, cmd: Command, group_name: str) -> rich.table.Table:
    """Append one result row for ``cmd`` to ``tbl`` and return the same table.

    The row is green with a tick on SUCCESS, red with a cross on FAILURE, and
    yellow with an hourglass otherwise. A falsy ``cmd.elapsed`` is shown as a
    placeholder time.
    """
    if cmd.status == CommandStatus.SUCCESS:
        cmd_status, row_style = "✅", "green"
    elif cmd.status == CommandStatus.FAILURE:
        cmd_status, row_style = "❌", "red"
    else:
        cmd_status, row_style = "⏳", "yellow"

    if cmd.elapsed:
        elap_str = format_elapsed_time(cmd.elapsed)
    else:
        elap_str = "XX:XX:XX.xxx"

    tbl.add_row(group_name, cmd.name, cmd.cmd, cmd_status, elap_str, style=row_style)
    return tbl

247 

248 

def fmt_group_name(cmd_group: CommandGroup) -> str:
    """Wrap the group's name in rich colour markup matching its status.

    Green for SUCCESS, red for FAILURE, yellow for anything else.
    """
    status = cmd_group.status
    if status == CommandStatus.SUCCESS:
        return f"[green]{cmd_group.name}[/]"
    if status == CommandStatus.FAILURE:
        return f"[red]{cmd_group.name}[/]"
    return f"[yellow]{cmd_group.name}[/]"

256 

257 

# typer Option definitions used as the parameter defaults of `run`.
style_default = typer.Option(help="Processing strategy", default="comp")
show_default = typer.Option(help="Show available groups and commands", default=False)
pyproj_default = typer.Option(help="The default toml file to use", default=Path("pyproject.toml"))
# Help text typo fixed: "spearated" -> "separated".
groups_default = typer.Option(help="Run a specific group of commands, comma separated", default=None)
cmds_default = typer.Option(help="Run specific commands, comma separated", default=None)

263 

264 

def _run_groups(master_groups: list[CommandGroup], style: ProcessingStrategy) -> int:
    """Run each group in order and return the overall exit code.

    The overall code is the first non-zero group exit code, or 0 if every
    group that ran succeeded. Execution stops after a failing group unless
    that group has ``cont_on_fail`` set.

    Raises typer.BadParameter for an unsupported processing strategy.
    """
    overall_exit_code = 0
    for grp in master_groups:
        if style == ProcessingStrategy.ON_COMP:
            callbacks = CLICommandCBOnComp()
        elif style == ProcessingStrategy.ON_RECV:
            callbacks = CLICommandCBOnRecv()
        else:
            raise typer.BadParameter("Invalid processing strategy")
        grp_exit_code = grp.run(style, callbacks)
        if grp_exit_code != 0:
            # Keep the first failure: a later successful group must not reset
            # the overall status to 0 (previously exit_code was overwritten).
            overall_exit_code = overall_exit_code or grp_exit_code
            if not grp.cont_on_fail:
                break
    return overall_exit_code


def _print_results_table(master_groups: list[CommandGroup]) -> None:
    """Print a table with one row per command.

    Groups are separated by a break row, and the group name appears only on
    the group's first row.
    """
    console = rich.console.Console()
    res_tbl = build_results_tbl()
    for grp_ix, grp in enumerate(master_groups):
        for ix, cmd in enumerate(grp.cmds.values()):
            if ix == 0:
                if grp_ix > 0:
                    add_table_break(res_tbl)
                grp_name = fmt_group_name(grp)
            else:
                grp_name = ""
            add_command_row(res_tbl, cmd, grp_name)
    console.print(res_tbl)


@cli_app.command()
def run(
    style: Annotated[ProcessingStrategy, typer.Option] = style_default,
    show: Annotated[bool, typer.Option] = show_default,
    file: Annotated[Path, typer.Option] = pyproj_default,
    groups: Annotated[Optional[str], typer.Option] = groups_default,
    cmds: Annotated[Optional[str], typer.Option] = cmds_default,
):
    """Run commands in parallel"""
    # The docstring above doubles as the command's --help text in typer; keep it short.
    # Flow: load groups from `file`, optionally just list them (--show), apply
    # the group/command filters, run, print a summary, and exit non-zero if
    # any group that ran failed.
    st_all = time.perf_counter()

    master_groups = read_commands_toml(file)
    if show:
        return show_commands(master_groups)

    master_groups = filter_groups(master_groups, groups, cmds)

    if not master_groups:
        rich.print("[blue]No groups or commands found.[/]")
        raise typer.Exit(0)

    exit_code = _run_groups(master_groups, style)

    # Summarise the results
    _print_results_table(master_groups)
    end_style = "[green bold]" if exit_code == 0 else "[red bold]"
    rich.print(f"\n{end_style}Total elapsed time: {format_elapsed_time(time.perf_counter() - st_all)}[/]")
    raise typer.Exit(exit_code)

315 

316 

# Stdlib modules used by the web helpers above (start_web_server,
# stop_web_server, ...); `time` is also used by `run`.
import os
import signal
import subprocess
import sys
import time

# Web support is optional: the `web` sub-command is registered only if psutil
# imports successfully (start_web_server also needs the `uvicorn` executable).
try:
    import psutil

    rich.print("[blue]Web commands loaded[/]")

    # Help text typo fixed: "interract" -> "interact".
    command_default = typer.Argument(..., help="command to control/interact with the web server")
    port_default = typer.Option(8001, help="Port to run the web server")

    @cli_app.command()
    def web(
        command: WebCommand = command_default,
        port: int = port_default,
    ):
        """Run the web server"""
        if command == WebCommand.START:
            start_web_server(port)
        elif command == WebCommand.STOP:
            stop_web_server()
        elif command == WebCommand.RESTART:
            stop_web_server()
            start_web_server(port)
        elif command == WebCommand.STATUS:
            get_web_server_status()
        else:
            # Defensive: typer already restricts `command` to WebCommand values.
            typer.echo(f"Not a valid command '{command}'", err=True)
            raise typer.Abort()

except ImportError:  # pragma: no cover
    pass  # pragma: no cover

if __name__ == "__main__":
    cli_app()