Coverage for src/edwh_restic_plugin/repositories/__init__.py: 49%

234 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-28 16:28 +0100

1import abc 

2import contextlib 

3import datetime 

4import heapq 

5import importlib 

6import importlib.util 

7import io 

8import os 

9import re 

10import sys 

11import typing 

12from collections import OrderedDict, defaultdict 

13from pathlib import Path 

14 

15import invoke 

16from invoke import Context 

17from invoke.exceptions import AuthFailure 

18from termcolor import cprint 

19from tqdm import tqdm 

20from typing_extensions import NotRequired 

21 

22from ..env import DOTENV, check_env, read_dotenv 

23from ..helpers import _require_restic, camel_to_snake, fix_tags 

24 

# Folder (relative to the current working directory) that is globbed for the
# `<verb>_<target>*` hook scripts which contain the actual restic commands.
DEFAULT_BACKUP_FOLDER = Path("captain-hooks")

27 

28 

class SortableMeta(abc.ABCMeta):
    """
    Metaclass that makes class objects (not their instances) comparable.

    Both comparisons always answer ``False``, so no class ever sorts before or
    after another one. That is enough to let a class sit inside a tuple stored
    in a ``heapq`` without raising ``TypeError`` when two tuples tie on their
    first element: put the real priority first and carry the class along
    purely for lookup.
    """

    def __lt__(cls, other: typing.Any) -> bool:
        # a class is never "smaller" than anything else
        return False

    def __gt__(cls, other: typing.Any) -> bool:
        # ... and never "greater" either
        return False

42 

43 

class Repository(abc.ABC, metaclass=SortableMeta):
    """
    Abstract base class for one type of Restic repository.

    Subclasses provide `setup`, `prepare_for_restic` and `uri`; this base class
    supplies the shared behaviour: loading the .env file, exporting the common
    environment variables and running the backup/restore hook scripts and the
    restic maintenance commands through an invoke `Context`.
    """

    ####################
    # IMPLEMENT THESE: #
    ####################

    @abc.abstractmethod
    def setup(self) -> None:
        """Ensure that the settings are in the .env file"""
        # you probably want some `self.check_env(...)` statements here
        # You need at least a <REPO>_NAME and <REPO>_PASSWORD variable,
        # where <REPO> is the name of your Restic repository type.
        raise NotImplementedError("Setup undefined")

    @abc.abstractmethod
    def prepare_for_restic(self, c: Context) -> None:
        """Export the environment variables restic needs for this repository type."""
        # prepare_for_restic implementations should probably start with:
        # env = self.env_config
        # os.environ["RESTIC_REPOSITORY"] = self.uri
        # os.environ["RESTIC_PASSWORD"] = env["<REPO>_PASSWORD"]
        raise NotImplementedError("Prepare for restic undefined")

    @property
    @abc.abstractmethod
    def uri(self) -> str:
        """Return the prefix required for restic to indicate the protocol, for example sftp:hostname:"""
        raise NotImplementedError("Prefix unknown in base class")

    ###########################
    # END OF NOT IMPLEMENTED, #
    # START BASE CLASS:       #
    ###########################

    # NOTE(review): the method name contains a typo ("boilerpalte"); it is kept
    # because subclasses outside this file may override or call it.
    def _add_missing_boilerpalte_restic_vars(self) -> None:
        """
        HOST, URI, RESTIC_REPOSITORY and RESTIC_HOST are usually the same, so if those
        aren't set yet (missing or empty), use a sensible default.
        """
        os.environ["HOST"] = os.environ.get("HOST") or self.hostarg
        os.environ["URI"] = os.environ.get("URI") or self.uri
        os.environ["RESTIC_HOST"] = os.environ.get("RESTIC_HOST") or self.hostarg
        os.environ["RESTIC_REPOSITORY"] = os.environ.get("RESTIC_REPOSITORY") or self.uri

    def prepare_env_for_restic(self, c: Context) -> None:
        """Run the repository-specific preparation, then fill in the common variables."""
        self.prepare_for_restic(c)  # <- abstract method used by all Repositories
        self._add_missing_boilerpalte_restic_vars()  # <- add $HOST and other common variables that could be missing

    def __repr__(self) -> str:
        cls = self.__class__.__name__
        try:
            uri = self.uri
        except Exception:
            # `uri` may depend on settings that are not loaded yet; a repr must never raise
            uri = "?"

        return f"<{cls}({uri})>"

    # _targets: a list of file and directory paths that should be included in the backup.
    _targets = [".env", "./backup"]
    # _excluded: a list of file and directory paths that should be excluded from the backup.
    _excluded = [
        ".git",
        ".idea",
        "backups",
        "*.pyc",
        "*.bak",
        "../",
        "./..",
        "errors",
        "sessions",
        "__pycache__",
    ]

    # path of the .env file this repository reads its settings from
    _env_path: Path
    # key/value pairs loaded from that .env file (kept in sync by `check_env`)
    env_config: dict[str, str]

    def _require_restic(self) -> None:
        """Delegate to the `_require_restic` helper (presumably verifies restic is installed)."""
        _require_restic()

    def __init__(self, env_path: Path = DOTENV) -> None:
        """
        Load the settings from `env_path` and export them into `os.environ`.

        Args:
            - env_path (Path): the .env file to read; it is created when it does not exist yet.
        """
        super().__init__()
        self._require_restic()
        env_path.touch(exist_ok=True)
        print("start repo init", self.__class__.__name__)
        self._env_path = env_path
        self.env_config = env = read_dotenv(env_path)
        os.environ |= env
        self._restichostname = env.get("RESTICHOSTNAME")  # or None if it is not there
        print("end repo init", self)

    def check_env(
        self,
        key: str,
        default: str | None,
        comment: str,
        prefix: str | None = None,
        suffix: str | None = None,
        postfix: str | None = None,
        path: Path | None = None,
    ):
        """
        Call the module-level `check_env` helper for `key` and remember the result.

        All arguments are forwarded unchanged; when `path` is not given, this
        repository's own .env path is used.

        Returns:
            The value returned by the helper, which is also stored in `self.env_config[key]`.
        """
        value = check_env(
            key=key,
            default=default,
            comment=comment,
            prefix=prefix,
            suffix=suffix,
            postfix=postfix,
            path=path or self._env_path,
        )

        # update local variant too:
        self.env_config[key] = value
        return value

    def _restic_self_update(self, c: Context) -> None:
        """Update the restic binary, retrying with sudo when the unprivileged attempt fails."""
        # An invoke Result is truthy when the command succeeded, so test `.ok`
        # explicitly: success means we are done, failure means "try again as root".
        if c.run("restic self-update", hide=True, warn=True).ok:
            # done
            return

        # best effort: not being able to authenticate for sudo is not fatal
        with contextlib.suppress(AuthFailure):
            c.sudo("restic self-update", hide=True, warn=True)

    def configure(self, c: Context) -> None:
        """Configure the backup environment variables and initialise the restic repository."""
        self.prepare_env_for_restic(c)
        print("configure")
        # First, make sure restic is up-to-date
        self._restic_self_update(c)
        # This is the command used to configure the environment variables properly.
        c.run(f"restic init --repository-version 2 -r {self.uri}")

    @property
    def hostarg(self) -> str:
        """Return the host argument for restic command (empty when RESTICHOSTNAME is not set)."""
        return f" --host {self._restichostname} " if self._restichostname else ""

    @property
    def targets(self) -> str:
        """Return the target files and directories for the backup, separated by spaces."""
        return " ".join(self._targets)

    @property
    def excluded(self) -> str:
        """
        Return the excluded files and directories for the backup, joined by ' --exclude '.

        Note that the first entry is therefore NOT preceded by '--exclude'.
        """
        return " --exclude ".join(self._excluded)

    @staticmethod
    def get_snapshot_from(stdout: str) -> str | None:
        """
        Parses the stdout from a Restic command to extract the snapshot ID.

        Args:
            - stdout (str): The stdout output from a Restic command.

        Returns:
            - The ID from the last "snapshot <id> saved" occurrence, or None when there is none.
        """
        snapshots_ids = re.findall(r"snapshot (.*?) saved", stdout)
        return snapshots_ids[-1] if snapshots_ids else None

    @staticmethod
    def get_scripts(target, verb):
        """Retrieves the scripts that contain a restic command and returns them to 'execute_files' to execute them.

        Exits the process with status 255 when no script matches.

        Args:
            - target (str): target is a string that specifies the target of the backup, can be a file, stream,
                directory, or any other object that needs to be backed up.
            - verb (str): is also a string that specifies the action to be performed on the target.
                For example, the verb could be "backup" or "restore". The verb is used in combination with the
                target to search for the backup script files that contain the restic command.
        """
        # get files by verb and target. EXAMPLE backup_files_*.sh
        files = [str(file) for file in DEFAULT_BACKUP_FOLDER.glob(f"{verb}_{target}*")]
        # check if no files are found
        if not files:
            print("no files found with target:", target)
            sys.exit(255)

        return files

    def execute_files(
        self,
        c: Context,
        target: str,
        verb: str,
        verbose: bool,
        message: str | None = None,
        snapshot: str = "latest",
    ):
        """
        Executes the backup scripts retrieved by 'get_scripts' function.

        After all scripts ran, a per-file status overview is printed and, when at least one
        script failed, the process exits with the highest exit status that was seen.

        Args:
            - c (Context): The invoke context used to run the scripts.
            - target (str): The target of the backup.
            - verb (str): The verb associated with the backup.
            - verbose (bool): A flag indicating whether to display verbose output.
            - message (str, optional): The message to be associated with the backup.
                If not provided, the current local time is used. Defaults to None.
            - snapshot (str, optional): The snapshot to be used for the backup. Defaults to "latest".
        """
        self.prepare_env_for_restic(c)

        # set snapshot available in environment for sh files
        os.environ["SNAPSHOT"] = snapshot

        # Here you can make a message that you will see in the snapshots list
        if message is None:
            # If no message is provided, use the current local time as the backup message
            message = f"{datetime.datetime.now()} localtime"

        # set MSG in environment for sh files
        os.environ["MSG"] = message

        # get files by target and verb. see self.get_scripts for more info
        files = self.get_scripts(target, verb)

        snapshots_created = []
        file_codes = []
        # run all backup/restore files
        for file in tqdm(files):
            if verbose:
                print("\033[1m running", file, "\033[0m")

            # run the script with pty=True; when it exits with a non-zero status,
            # keep its result so the output can still be inspected
            print(f"{file} output: " if verbose else "", file=sys.stderr)
            try:
                ran_script: invoke.runners.Result = c.run(file, hide=not verbose, pty=True)
                file_codes.append(0)
            except invoke.exceptions.UnexpectedExit as e:
                ran_script = e.result
                file_codes.append(e.result.exited)

            # None when the script did not report a saved snapshot
            snapshots_created.append(self.get_snapshot_from(ran_script.stdout))

        # send message with backup. see message for more info
        # also if a tag in tags is None it will be removed by fix_tags
        if verb != "restore":
            tags = fix_tags(["message", *snapshots_created])
            c.run(
                f"restic {self.hostarg} -r {self.uri} backup --tag {','.join(tags)} --stdin --stdin-filename message",
                in_stream=io.StringIO(message),
                hide=True,
            )

        print("\n\nfile status codes:")

        for filename, status_code in zip(files, file_codes):
            if status_code == 0:
                cprint(f"[success] {filename}", color="green")
            else:
                cprint(f"[failure ({status_code})] {filename}", color="red")

        # Compute the maximum first and compare afterwards: `x := max(...) > 0`
        # would bind the boolean and always exit with status 1.
        worst_status_code = max(file_codes, default=0)
        if worst_status_code > 0:
            sys.exit(worst_status_code)

    def backup(self, c, verbose: bool, target: str, message: str | None):
        """
        Backs up the specified target.

        Args:
            - c (Context): The invoke context used to run the scripts.
            - verbose (bool): A flag indicating whether to display verbose output.
            - target (str): The target of the backup (e.g. 'files', 'stream'; default is all types).
            - message (str): The message to be associated with the backup.
        """
        self.execute_files(c, target, "backup", verbose, message)

    def restore(self, c, verbose: bool, target: str, snapshot: str = "latest"):
        """
        Restores the specified target using the specified snapshot or the latest if none is given.

        Args:
            - c (Context): The invoke context used to run the scripts.
            - verbose (bool): A flag indicating whether to display verbose output.
            - target (str): The target of the restore.
            - snapshot (str, optional): The snapshot to be used for the restore. Defaults to "latest".
        """
        self.execute_files(c, target, "restore", verbose, snapshot=snapshot)

    def check(self, c):
        """
        Checks the integrity of the backup repository.
        """
        self.prepare_env_for_restic(c)
        c.run(f"restic {self.hostarg} -r {self.uri} check --read-data")

    def snapshot(self, c: Context, tags: list[str] | None = None, n: int = 2, verbose: bool = False):
        """
        a list of all the backups with a message

        Args:
            - tags (list, optional): A list of tags to use for the snapshot. Defaults to None,
                which selects the "files" and "stream" tags.
            - n (int, optional): The number of latest snapshots to show. Defaults to 2.
            - verbose (bool): Show more info about what's happening?

        Returns:
            None. This function only prints the output to the console.
        """
        # choose to see only the files or the stream snapshots
        if tags is None:
            tags = ["files", "stream"]

        self.prepare_env_for_restic(c)
        tags_flag = "--tag " + " --tag ".join(tags) if tags else ""
        command = f"restic {self.hostarg} -r {self.uri} snapshots --latest {n} {tags_flag} -c"
        if verbose:
            print("$", command, file=sys.stderr)

        stdout = c.run(
            command,
            hide=True,
        ).stdout

        if verbose:
            print(stdout, file=sys.stderr)

        # every line that starts with an 8 character id describes one snapshot
        snapshot_ids = re.findall(r"^([0-9a-z]{8})\s", stdout, re.MULTILINE)
        # the last word on such a line is taken as the main tag of that snapshot
        main_tag_per_snapshot = {
            snapshot_id: re.findall(rf"^{snapshot_id}.*?(\w*)$", stdout, re.MULTILINE)
            for snapshot_id in snapshot_ids
        }

        message_snapshot_per_snapshot = defaultdict(list)  # key is source, value is snapshot containing the message
        for snapshot_id, possible_tag_names in main_tag_per_snapshot.items():
            tag_name = possible_tag_names[0]
            if tag_name != "message":
                continue
            # the indented continuation line(s) below a message snapshot hold the
            # id(s) of the snapshot(s) the message belongs to
            for _, is_message_for_snapshot_id in re.findall(rf"\n{snapshot_id}.*(\n\s+(.*)\n)+", stdout):
                message_snapshot_per_snapshot[is_message_for_snapshot_id].append(snapshot_id)

        for described_snapshot, message_snapshots in message_snapshot_per_snapshot.items():
            # print all Restic messages
            command = f"restic {self.hostarg} -r {self.uri} dump {message_snapshots[0]} --tag message message"
            if verbose:
                print("$", command, file=sys.stderr)

            restore_output = c.run(
                command,
                hide=True,
                warn=True,
            ).stdout

            if verbose:
                print(restore_output, file=sys.stderr)

            message = restore_output.strip()
            # The message is free text: use a replacement function so backslashes or
            # group references inside it are inserted literally, and escape the key
            # so it is matched literally as well.
            stdout = re.sub(
                rf"\n{re.escape(described_snapshot)}(.*)\n",
                lambda match: f"\n{described_snapshot}{match.group(1)} : [{message}]\n",
                stdout,
            )

            if verbose:
                print("---\n", file=sys.stderr)

        print(stdout)

    def forget(self, c: Context):
        """
        Apply the retention policy (7 daily, 4 weekly, 12 monthly, 100 yearly) and prune.

        Returns:
            The result of the restic command; a failing command does not raise (warn=True).
        """
        self.prepare_env_for_restic(c)

        return c.run(
            "restic forget --prune --keep-daily 7 --keep-weekly 4 --keep-monthly 12 --keep-yearly 100",
            hide=True,
            warn=True,
        )

    # noop gt, lt etc methods

    def __gt__(self, other):
        return False

    def __lt__(self, other):
        return False

417 

418 

class RepositoryRegistration(typing.TypedDict):
    """Settings that accompany a repository class in the registry."""

    # primary name under which the repository class can be looked up
    short_name: str
    # optional extra names that resolve to the same repository class
    aliases: NotRequired[tuple[str, ...]]
    # optional ordering key; the registry treats a missing or negative value as "after everything else"
    priority: NotRequired[int]

423 

424 

class RepositoryRegistrations:
    """
    Registry of repository classes, ordered by priority and addressable by name.

    Entries are stored as `(priority, class, settings)` tuples in a heap. A heap
    list is NOT sorted (only its first element is guaranteed to be the
    smallest), so everything that promises an order goes through
    `_sorted_items`.
    """

    def __init__(self) -> None:
        # _queue is for internal use by heapq only!
        # external api should use .queue !!!
        self._queue: "list[tuple[int, typing.Type[Repository], RepositoryRegistration]]" = []
        # aliases stores a reference for each name to the Repo class
        self._aliases: "dict[str, typing.Type[Repository]]" = {}

    def push(self, repo: "typing.Type[Repository]", settings: "RepositoryRegistration") -> None:
        """
        Add `repo` to the registry under its short name and all of its aliases.

        A missing or negative priority is mapped to a value above `sys.maxsize`,
        so such entries come after every entry with an explicit priority.
        """
        priority = settings.get("priority", -1)
        if priority < 0:
            priority = sys.maxsize - priority  # very high int

        heapq.heappush(self._queue, (priority, repo, settings))
        self._aliases[settings["short_name"]] = repo
        for alias in settings.get("aliases", []):
            self._aliases[alias] = repo

    @property
    def queue(self):
        """Return the raw heap list, importing the sibling modules first when it is still empty."""
        if not self._queue:
            self._find_items()

        return self._queue

    def clear(self) -> None:
        """Forget all registered repositories and names."""
        self._queue = []
        self._aliases = {}

    def get(self, name: str) -> "typing.Type[Repository] | None":
        """Return the class registered under `name` (short name or alias), or None."""
        return self._aliases.get(name)

    def _sorted_items(self) -> list:
        """Return the registered `(priority, class, settings)` tuples ordered by priority."""
        # Sort on the priority only: the sort is stable and never has to compare
        # classes or settings dicts when two entries share a priority.
        return sorted(self.queue, key=lambda entry: entry[0])

    def to_sorted_list(self) -> list:
        """Return the registered classes, lowest priority value first."""
        return list(self)

    def to_ordered_dict(self) -> "OrderedDict[str, typing.Type[Repository]]":
        """Return a mapping of short name to class, lowest priority value first."""
        ordered_dict = OrderedDict()
        for _, item, settings in self._sorted_items():
            ordered_dict[settings["short_name"]] = item
        return ordered_dict

    def __iter__(self) -> "typing.Generator[typing.Type[Repository], None, None]":
        return (item[1] for item in self._sorted_items())

    def __bool__(self) -> bool:
        return bool(self.queue)

    def _find_items(self) -> None:
        """Import every non-dunder module next to this file."""
        # import all registrations in this folder, so @register adds them to _queue
        package_path = Path(__file__).resolve().parent

        for file_path in package_path.glob("*.py"):
            pkg = file_path.stem
            if not pkg.startswith("__"):
                importlib.import_module(f".{pkg}", package=__name__)

481 

482 

def register(
    short_name: str = None,
    aliases: tuple[str, ...] = (),
    priority: int = -1,
    # **settings: Unpack[RepositoryRegistration] # <- not really supported yet!
) -> typing.Callable[[typing.Type[Repository]], typing.Type[Repository]]:
    """
    Class decorator factory that adds a Repository subclass to `registrations`.

    Args:
        - short_name (str, optional): name to register the class under. When empty, it is derived from
            the class name: converted with `camel_to_snake` and stripped of a trailing "_repository".
        - aliases (tuple, optional): extra names that should resolve to the same class.
        - priority (int, optional): ordering key handed to the registry. Defaults to -1.

    Raises:
        - SyntaxError: when used as a bare `@register` (the class itself arrives as `short_name`).
        - TypeError: (from the returned decorator) when the decorated object is not a Repository subclass.
    """
    # `@register` without parentheses passes the decorated class as first argument
    if isinstance(short_name, type):
        raise SyntaxError("Please call @register() with parentheses!")

    def decorator(cls: typing.Type[Repository]) -> typing.Type[Repository]:
        if not isinstance(cls, type) or not issubclass(cls, Repository):
            raise TypeError(f"Decorated class {cls} must be a subclass of Repository!")

        registered_name = short_name or camel_to_snake(cls.__name__).removesuffix("_repository")
        settings: RepositoryRegistration = {
            "short_name": registered_name,
            "aliases": aliases,
            "priority": priority,
        }
        registrations.push(cls, settings)

        # hand the class back untouched so the decorator is transparent
        return cls

    return decorator

506 

507 

# Module-level registry instance that the `register` decorator pushes classes into.
registrations = RepositoryRegistrations()