Coverage for src/edwh_restic_plugin/repositories/__init__.py: 49%
234 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-28 16:28 +0100
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-28 16:28 +0100
1import abc
2import contextlib
3import datetime
4import heapq
5import importlib
6import importlib.util
7import io
8import os
9import re
10import sys
11import typing
12from collections import OrderedDict, defaultdict
13from pathlib import Path
15import invoke
16from invoke import Context
17from invoke.exceptions import AuthFailure
18from termcolor import cprint
19from tqdm import tqdm
20from typing_extensions import NotRequired
22from ..env import DOTENV, check_env, read_dotenv
23from ..helpers import _require_restic, camel_to_snake, fix_tags
# Folder (relative to the current working directory) that Repository.get_scripts
# searches for `<verb>_<target>*` scripts, e.g. `backup_files.sh`.
DEFAULT_BACKUP_FOLDER: Path = Path("captain-hooks")
class SortableMeta(abc.ABCMeta):
    """
    Metaclass whose classes never order before or after anything.

    This lets class objects (not instances) live inside tuples stored in a heapq:
    store `(priority, cls, ...)` so the priority decides the order. When two
    priorities tie, the `<` / `>` comparison between the classes simply yields
    False instead of failing, so the class is only carried along for lookup.
    """

    def __lt__(cls, other: typing.Any) -> bool:
        # "less than" is never true ...
        return False

    # ... and neither is "greater than" - reuse the exact same comparison.
    __gt__ = __lt__
class Repository(abc.ABC, metaclass=SortableMeta):
    """
    Base class for restic repository backends.

    Subclasses implement `setup`, `prepare_for_restic` and the `uri` property; this class
    supplies the shared workflow (configure, backup, restore, check, snapshot, forget)
    which shells out to `restic` through an invoke `Context`.
    """

    ####################
    # IMPLEMENT THESE: #
    ####################

    @abc.abstractmethod
    def setup(self) -> None:
        """Ensure that the settings are in the .env file"""
        # you probably want some `self.check_env(...)` statements here
        # You need at least a <REPO>_NAME and <REPO>_PASSWORD variable,
        # where <REPO> is the name of your Restic repository type.
        raise NotImplementedError("Setup undefined")

    @abc.abstractmethod
    def prepare_for_restic(self, c: Context) -> None:
        """Export the environment variables restic needs to access this repository."""
        # prepare_for_restic implementations should probably start with:
        # env = self.env_config
        # os.environ["RESTIC_REPOSITORY"] = self.uri
        # os.environ["RESTIC_PASSWORD"] = env["<REPO>_PASSWORD"]
        raise NotImplementedError("Prepare for restic undefined")

    @property
    @abc.abstractmethod
    def uri(self) -> str:
        """Return the prefix required for restic to indicate the protocol, for example sftp:hostname:"""
        raise NotImplementedError("Prefix unknown in base class")

    ###########################
    # END OF NOT IMPLEMENTED, #
    # START BASE CLASS:       #
    ###########################

    def _add_missing_boilerpalte_restic_vars(self):
        """
        Default HOST, URI, RESTIC_HOST and RESTIC_REPOSITORY from `hostarg` / `uri`.

        Only variables that are unset or empty in `os.environ` are filled in.
        (The misspelled method name is kept so existing callers keep working.)
        """
        # NOTE(review): `hostarg` is a CLI fragment (" --host NAME " or ""), not a bare hostname -
        # confirm this is really what HOST / RESTIC_HOST should contain.
        os.environ["HOST"] = os.environ.get("HOST") or self.hostarg
        os.environ["URI"] = os.environ.get("URI") or self.uri
        os.environ["RESTIC_HOST"] = os.environ.get("RESTIC_HOST") or self.hostarg
        os.environ["RESTIC_REPOSITORY"] = os.environ.get("RESTIC_REPOSITORY") or self.uri

    def prepare_env_for_restic(self, c: Context):
        """Run the backend-specific `prepare_for_restic`, then fill in the common restic variables."""
        self.prepare_for_restic(c)  # <- abstract method used by all Repositories
        self._add_missing_boilerpalte_restic_vars()  # <- add $HOST and other common variables that could be missing

    def __repr__(self):
        cls = self.__class__.__name__
        try:
            uri = self.uri
        except Exception:
            # `uri` may raise (the base implementation raises NotImplementedError); repr must not fail
            uri = "?"

        return f"<{cls}({uri})>"

    # _targets: a list of file and directory paths that should be included in the backup.
    _targets = [".env", "./backup"]
    # _excluded: a list of file and directory paths that should be excluded from the backup.
    _excluded = [
        ".git",
        ".idea",
        "backups",
        "*.pyc",
        "*.bak",
        "../",
        "./..",
        "errors",
        "sessions",
        "__pycache__",
    ]

    # path of the dotenv file this repository reads/writes its settings from/to
    _env_path: Path
    # key/value pairs loaded from `_env_path` (kept in sync by `check_env`)
    env_config: dict[str, str]

    def _require_restic(self):
        """Delegate to the helpers' `_require_restic` check (overridable, e.g. in tests)."""
        _require_restic()

    def __init__(self, env_path: Path = DOTENV) -> None:
        """
        Load the dotenv file at `env_path` into `self.env_config` and `os.environ`.

        Args:
        - env_path (Path): dotenv file to read; it is created (empty) if it does not exist.
        """
        super().__init__()
        self._require_restic()
        env_path.touch(exist_ok=True)
        # NOTE(review): looks like leftover debug output on stdout
        print("start repo init", self.__class__.__name__)
        self._env_path = env_path
        self.env_config = env = read_dotenv(env_path)
        os.environ |= env
        self._restichostname = env.get("RESTICHOSTNAME")  # or None if it is not there
        print("end repo init", self)

    def check_env(
        self,
        key: str,
        default: str | None,
        comment: str,
        prefix: str | None = None,
        suffix: str | None = None,
        postfix: str | None = None,
        path: Path | None = None,
    ):
        """
        Ensure `key` is configured in the dotenv file (via `env.check_env`) and return its value.

        `path` defaults to this repository's env file. The returned value is also stored in
        `self.env_config[key]` so the in-memory config stays in sync.
        """
        value = check_env(
            key=key,
            default=default,
            comment=comment,
            prefix=prefix,
            suffix=suffix,
            postfix=postfix,
            path=path or self._env_path,
        )

        # update local variant too:
        self.env_config[key] = value
        return value

    def _restic_self_update(self, c: Context) -> None:
        """Update the restic binary, retrying with sudo when the unprivileged attempt fails."""
        # An invoke Result is truthy exactly when the command succeeded (Result.ok),
        # so a truthy result means the update worked and there is nothing left to do.
        if c.run("restic self-update", hide=True, warn=True):
            return

        # Presumably the restic binary isn't writable by this user: try again as root.
        # A failed sudo authentication is not fatal - the existing restic keeps working.
        with contextlib.suppress(AuthFailure):
            c.sudo("restic self-update", hide=True, warn=True)

    def configure(self, c: Context):
        """Configure the backup environment variables."""
        self.prepare_env_for_restic(c)
        print("configure")
        # First, make sure restic is up-to-date
        self._restic_self_update(c)
        # This is the command used to configure the environment variables properly.
        c.run(f"restic init --repository-version 2 -r {self.uri}")

    @property
    def hostarg(self):
        """Return the host argument for restic command."""
        return f" --host {self._restichostname} " if self._restichostname else ""

    @property
    def targets(self):
        """Return the target files and directories for the backup."""
        return " ".join(self._targets)

    @property
    def excluded(self):
        """Return the excluded files and directories for the backup.
        The entries are joined with " --exclude ", so the caller must supply the first --exclude."""
        return " --exclude ".join(self._excluded)

    @staticmethod
    def get_snapshot_from(stdout: str) -> str | None:
        """
        Parses the stdout from a Restic command to extract the snapshot ID.

        Args:
        - stdout (str): The stdout output from a Restic command.

        Returns:
        - The ID from the last "snapshot <id> saved" line, or None if there is no such line.
        """
        snapshots_ids = re.findall(r"snapshot (.*?) saved", stdout)
        return snapshots_ids[-1] if snapshots_ids else None

    @staticmethod
    def get_scripts(target, verb):
        """Retrieves the scripts that contain a restic command and returns them to 'execute_files' to execute them.

        Args:
        - target (str): target is a string that specifies the target of the backup, can be a file, stream, directory,
        or any other object that needs to be backed up.
        - verb (str): is also a string that specifies the action to be performed on the target.
        For example, the verb could be "backup" or "restore". The verb is used in combination with the target to
        search for the backup script files that contain the restic command.

        Exits the process with status 255 when no matching script exists.
        """
        # get files by verb and target. EXAMPLE backup_files_*.sh
        files = [str(file) for file in DEFAULT_BACKUP_FOLDER.glob(f"{verb}_{target}*")]
        # check if no files are found
        if not files:
            print("no files found with target:", target)
            sys.exit(255)

        return files

    def execute_files(
        self,
        c: Context,
        target: str,
        verb: str,
        verbose: bool,
        message: str | None = None,
        snapshot: str = "latest",
    ):
        """
        Executes the backup/restore scripts retrieved by 'get_scripts'.

        Args:
        - c (Context): invoke context used to run the scripts and restic.
        - target (str): The target of the backup.
        - verb (str): The verb associated with the scripts (e.g. "backup" or "restore").
        - verbose (bool): A flag indicating whether to display verbose output.
        - message (str, optional): The message to be associated with the backup.
            If not provided, the current local time is used. Defaults to None.
        - snapshot (str, optional): The snapshot to be used, exported as $SNAPSHOT. Defaults to "latest".

        When any script exits non-zero, the process exits (sys.exit) with the highest exit code.
        """
        self.prepare_env_for_restic(c)

        # set snapshot available in environment for sh files
        os.environ["SNAPSHOT"] = snapshot

        # Here you can make a message that you will see in the snapshots list
        if message is None:
            # If no message is provided, use the current local time as the backup message
            message = f"{datetime.datetime.now()} localtime"

        # set MSG in environment for sh files
        os.environ["MSG"] = message

        # get files by target and verb. see self.get_scripts for more info
        files = self.get_scripts(target, verb)

        snapshots_created = []
        file_codes = []
        # run all backup/restore files
        for file in tqdm(files):
            if verbose:
                print("\033[1m running", file, "\033[0m")

            print(f"{file} output: " if verbose else "", file=sys.stderr)
            # Run with a pty (output hidden unless verbose). A failing script raises UnexpectedExit,
            # whose `.result` still carries the captured stdout, so a snapshot id can be read either way.
            try:
                ran_script: invoke.runners.Result = c.run(file, hide=not verbose, pty=True)
                file_codes.append(0)
            except invoke.exceptions.UnexpectedExit as e:
                ran_script = e.result
                file_codes.append(e.result.exited)

            snapshots_created.append(self.get_snapshot_from(ran_script.stdout))

        # send message with backup. see message for more info
        # fix_tags is expected to drop tags that are None (scripts that reported no snapshot)
        if verb != "restore":
            tags = fix_tags(["message", *snapshots_created])
            c.run(
                f"restic {self.hostarg} -r {self.uri} backup --tag {','.join(tags)} --stdin --stdin-filename message",
                in_stream=io.StringIO(message),
                hide=True,
            )

        print("\n\nfile status codes:")

        for filename, status_code in zip(files, file_codes):
            if status_code == 0:
                cprint(f"[success] {filename}", color="green")
            else:
                cprint(f"[failure ({status_code})] {filename}", color="red")

        # Propagate the worst script exit code (not just a True/False flag) to the caller.
        worst_status_code = max(file_codes, default=0)
        if worst_status_code > 0:
            sys.exit(worst_status_code)

    def backup(self, c, verbose: bool, target: str, message: str | None):
        """
        Backs up the specified target.

        Args:
        - verbose (bool): A flag indicating whether to display verbose output.
        - target (str): The target of the backup (e.g. 'files', 'stream'; default is all types).
        - message (str): The message to be associated with the backup.
        """
        self.execute_files(c, target, "backup", verbose, message)

    def restore(self, c, verbose: bool, target: str, snapshot: str = "latest"):
        """
        Restores the specified target using the specified snapshot or the latest if None is given.

        Args:
        - verbose (bool): A flag indicating whether to display verbose output.
        - target (str): The target of the restore.
        - snapshot (str, optional): The snapshot to be used for the restore. Defaults to "latest".
        """
        self.execute_files(c, target, "restore", verbose, snapshot=snapshot)

    def check(self, c):
        """
        Checks the integrity of the backup repository (`restic check --read-data`).
        """
        self.prepare_env_for_restic(c)
        c.run(f"restic {self.hostarg} -r {self.uri} check --read-data")

    def snapshot(self, c: Context, tags: list[str] | None = None, n: int = 2, verbose: bool = False):
        """
        a list of all the backups with a message

        Args:
        - tags (list, optional): A list of tags to use for the snapshot. Defaults to ["files", "stream"].
        - n (int, optional): The number of latest snapshots to show. Defaults to 2.
        - verbose (bool): Show more info about what's happening?

        Returns:
            None. This function only prints the output to the console.
        """
        # choose to see only the files or the stream snapshots
        if tags is None:
            tags = ["files", "stream"]

        self.prepare_env_for_restic(c)
        tags_flag = "--tag " + " --tag ".join(tags) if tags else ""
        command = f"restic {self.hostarg} -r {self.uri} snapshots --latest {n} {tags_flag} -c"
        if verbose:
            print("$", command, file=sys.stderr)

        stdout = c.run(
            command,
            hide=True,
        ).stdout

        if verbose:
            print(stdout, file=sys.stderr)

        snapshot_lines = re.findall(r"^([0-9a-z]{8})\s", stdout, re.MULTILINE)
        # last word on each snapshot line, which the next step treats as that snapshot's main tag
        main_tag_per_snapshot = {
            snapshot_id: re.findall(rf"^{snapshot_id}.*?(\w*)$", stdout, re.MULTILINE)
            for snapshot_id in snapshot_lines
        }

        message_snapshot_per_snapshot = defaultdict(list)  # key is source, value is snapshot containing the message
        for snapshot_id, possible_tag_names in main_tag_per_snapshot.items():
            tag_name = possible_tag_names[0]
            if tag_name != "message":
                continue
            for _, is_message_for_snapshot_id in re.findall(rf"\n{snapshot_id}.*(\n\s+(.*)\n)+", stdout):
                message_snapshot_per_snapshot[is_message_for_snapshot_id].append(snapshot_id)

        for snapshot_id, message_snapshots in message_snapshot_per_snapshot.items():
            # print all Restic messages
            command = f"restic {self.hostarg} -r {self.uri} dump {message_snapshots[0]} --tag message message"
            if verbose:
                print("$", command, file=sys.stderr)

            restore_output = c.run(
                command,
                hide=True,
                warn=True,
            ).stdout

            if verbose:
                print(restore_output, file=sys.stderr)

            message = restore_output.strip()
            # Match the key literally and insert the message via a function, so backslashes or
            # group references inside user-supplied text can't break or alter the substitution.
            stdout = re.sub(
                rf"\n{re.escape(snapshot_id)}(.*)\n",
                lambda match, sid=snapshot_id, text=message: f"\n{sid}{match.group(1)} : [{text}]\n",
                stdout,
            )

        if verbose:
            print("---\n", file=sys.stderr)

        print(stdout)

    def forget(self, c: Context):
        """
        Apply the retention policy (7 daily, 4 weekly, 12 monthly, 100 yearly) and prune.

        Returns the invoke Result; warn=True means a failure does not raise.
        """
        self.prepare_env_for_restic(c)

        return c.run(
            "restic forget --prune --keep-daily 7 --keep-weekly 4 --keep-monthly 12 --keep-yearly 100",
            hide=True,
            warn=True,
        )

    # noop gt, lt etc methods: instances never order relative to each other

    def __gt__(self, other):
        return False

    def __lt__(self, other):
        return False
419class RepositoryRegistration(typing.TypedDict):
420 short_name: str
421 aliases: NotRequired[tuple[str, ...]]
422 priority: NotRequired[int]
class RepositoryRegistrations:
    """
    Registry of Repository subclasses, ordered by priority and looked up by short name or alias.

    Lower priority values come first. A negative priority (the `register()` default is -1)
    is mapped to a value above `sys.maxsize`, so such repositories come after every
    repository with an explicit non-negative priority.
    """

    def __init__(self) -> None:
        # _queue is for internal use by heapq only!
        # external api should use .queue !!!
        self._queue: list[tuple[int, typing.Type[Repository], RepositoryRegistration]] = []
        # aliases stores a reference for each name to the Repo class
        self._aliases: dict[str, typing.Type[Repository]] = {}

    def push(self, repo: typing.Type[Repository], settings: RepositoryRegistration):
        """Register `repo` under its short name and aliases, at the priority given in `settings`."""
        priority = settings.get("priority", -1)
        if priority < 0:
            priority = sys.maxsize - priority  # very high int

        heapq.heappush(self._queue, (priority, repo, settings))
        self._aliases[settings["short_name"]] = repo
        for alias in settings.get("aliases", []):
            self._aliases[alias] = repo

    @property
    def queue(self):
        """The raw heap of (priority, repo, settings) entries; repository modules are imported on first use."""
        if not self._queue:
            self._find_items()

        return self._queue

    def clear(self):
        """Forget all registrations and aliases."""
        self._queue = []
        self._aliases = {}

    def get(self, name: str) -> typing.Type[Repository] | None:
        """Return the Repository class registered under `name` (short name or alias), or None."""
        # like `.queue`: make sure the repository modules were imported before looking anything up
        if not self._queue:
            self._find_items()
        return self._aliases.get(name)

    def _sorted_entries(self) -> list[tuple[int, typing.Type[Repository], RepositoryRegistration]]:
        """Return the registered (priority, repo, settings) entries in priority order."""
        # A heapq list only guarantees that entry 0 is the smallest - the rest is NOT sorted,
        # so iterating it directly gives an arbitrary order. Sort on the priority alone
        # (stable for ties) so classes and settings dicts never need to be compared.
        return sorted(self.queue, key=lambda entry: entry[0])

    def to_sorted_list(self):
        """Return the registered Repository classes in priority order."""
        return list(self)

    def to_ordered_dict(self) -> OrderedDict[str, typing.Type[Repository]]:
        """Map each short name to its Repository class, in priority order."""
        ordered_dict = OrderedDict()
        for _, item, settings in self._sorted_entries():
            ordered_dict[settings["short_name"]] = item
        return ordered_dict

    def __iter__(self) -> typing.Generator[typing.Type[Repository], None, None]:
        return (item[1] for item in self._sorted_entries())

    def __bool__(self):
        return bool(self.queue)

    def _find_items(self) -> None:
        """Import every non-dunder module in this package so their @register() calls populate the registry."""
        package_path = Path(__file__).resolve().parent

        for file_path in package_path.glob("*.py"):
            pkg = file_path.stem
            if not pkg.startswith("__"):
                importlib.import_module(f".{pkg}", package=__name__)
def register(
    short_name: str = None,
    aliases: tuple[str, ...] = (),
    priority: int = -1,
    # **settings: Unpack[RepositoryRegistration] # <- not really supported yet!
) -> typing.Callable[[typing.Type[Repository]], typing.Type[Repository]]:
    """
    Class-decorator factory that adds a Repository subclass to the module-level `registrations`.

    Must be used with parentheses: `@register()` or `@register("name", aliases=(...), priority=N)`.

    Args:
    - short_name: lookup name; when empty, derived from the class name in snake_case
      with a trailing "_repository" removed.
    - aliases: extra lookup names for the same class.
    - priority: ordering key (lower first; negative values sort last).

    Raises:
    - SyntaxError: when used bare as `@register` (the class arrives as `short_name`).
    - TypeError: (from the decorator) when the decorated object is not a Repository subclass.
    """
    if isinstance(short_name, type):
        raise SyntaxError("Please call @register() with parentheses!")

    def wraps(cls: typing.Type[Repository]) -> typing.Type[Repository]:
        is_repository_class = isinstance(cls, type) and issubclass(cls, Repository)
        if not is_repository_class:
            raise TypeError(f"Decorated class {cls} must be a subclass of Repository!")

        name = short_name
        if not name:
            name = camel_to_snake(cls.__name__).removesuffix("_repository")

        registrations.push(
            cls,
            RepositoryRegistration(short_name=name, aliases=aliases, priority=priority),
        )
        return cls

    return wraps


# module-level registry that `register()` pushes every decorated class into
registrations = RepositoryRegistrations()