Coverage for src/extratools_gittools/repo.py: 0%

79 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2025-04-10 19:59 -0700

1from __future__ import annotations 

2 

3from datetime import UTC, datetime, timedelta 

4import os 

5from collections.abc import Sequence 

6from io import BytesIO 

7from pathlib import Path 

8from typing import Any 

9 

10import sh 

11 

12from .status import get_status 

13 

14 

15class Repo: 

16 def __init__( 

17 self, path: Path | str, 

18 *, 

19 user_name: str, 

20 user_email: str, 

21 ) -> None: 

22 self.__path: Path = Path(path).expanduser() 

23 

24 self.__git = sh.bake( 

25 _cwd=self.__path, 

26 _env={ 

27 "GIT_AUTHOR_NAME": user_name, 

28 "GIT_AUTHOR_EMAIL": user_email, 

29 "GIT_COMMITTER_NAME": user_name, 

30 "GIT_COMMITTER_EMAIL": user_email, 

31 } | os.environ, 

32 ).git 

33 

34 if not (self.__path / ".git").is_dir(): 

35 msg = "Specified path must be part of a Git repo." 

36 raise ValueError(msg) 

37 

38 @staticmethod 

39 def init( 

40 path: Path | str, 

41 *, 

42 exist_ok: bool = True, 

43 **kwargs: Any, 

44 ) -> Repo: 

45 repo_path: Path = Path(path).expanduser() 

46 

47 repo_path.mkdir(parents=True, exist_ok=True) 

48 

49 if (repo_path / ".git").exists(): 

50 if not exist_ok: 

51 msg = "Specified path is already a Git repo." 

52 raise RuntimeError(msg) 

53 else: 

54 sh.git( 

55 "init", 

56 _cwd=repo_path, 

57 ) 

58 

59 return Repo(repo_path, **kwargs) 

60 

61 def is_clean(self) -> bool: 

62 status: dict[str, Any] | None = get_status(str(self.__path)) 

63 if not status: 

64 msg = "Cannot get status of Git repo." 

65 raise RuntimeError(msg) 

66 

67 return not (status["files"]["staged"] or status["files"]["unstaged"]) 

68 

69 def stage(self, *files: str) -> None: 

70 args: list[str] = ["--", *files] if files else ["."] 

71 

72 self.__git( 

73 "add", *args, 

74 ) 

75 

76 def reset(self) -> None: 

77 self.__git( 

78 "reset", 

79 ) 

80 

81 def commit(self, message: str, *, stage_all: bool = True, background: bool = False) -> None: 

82 args: list[str] = ["--all"] if stage_all else [] 

83 

84 self.__git( 

85 "commit", *args, f"--message={message}", 

86 _bg=background, 

87 ) 

88 

89 def pull(self, *, rebase: bool = True, background: bool = False) -> None: 

90 if not self.is_clean(): 

91 msg = "Repo is not clean." 

92 raise RuntimeError(msg) 

93 

94 args: list[str] = ["--rebase=true"] if rebase else [] 

95 

96 self.__git( 

97 "pull", *args, 

98 _bg=background, 

99 ) 

100 

101 def push(self, *, background: bool = False) -> None: 

102 if not self.is_clean(): 

103 msg = "Repo is not clean." 

104 raise RuntimeError(msg) 

105 

106 self.__git( 

107 "push", 

108 _bg=background, 

109 ) 

110 

111 def list_commits( 

112 self, 

113 relative_path: Path | str | None = None, 

114 *, 

115 max_count: int | None = None, 

116 before: datetime | timedelta | None = None, 

117 ) -> Sequence[str]: 

118 args: list[str] = [] 

119 

120 if before: 

121 if isinstance(before, timedelta): 

122 before = datetime.now(UTC) - before 

123 

124 args.append(f"--before={before.isoformat()}") 

125 

126 if max_count: 

127 args.append(f"--max-count={max_count}") 

128 

129 if relative_path: 

130 args.append(str(relative_path)) 

131 

132 output: str = self.__git( 

133 "log", "--oneline", "--reverse", *args, 

134 _tty_out=False, 

135 ) 

136 

137 return [ 

138 line.split(' ')[0] 

139 for line in output.strip().splitlines() 

140 ] 

141 

142 def get_blob( 

143 self, 

144 relative_path: Path | str, 

145 *, 

146 version: str | int | datetime | timedelta | None = None, 

147 ) -> bytes: 

148 blob_path: Path = self.__path / relative_path 

149 

150 try: 

151 if version is None: 

152 return blob_path.read_bytes() 

153 

154 if isinstance(version, int): 

155 commits: Sequence[str] = self.list_commits( 

156 relative_path, 

157 max_count=(-version if version < 0 else None), 

158 ) 

159 

160 version = commits[version] 

161 elif isinstance(version, (datetime, timedelta)): 

162 commits: Sequence[str] = self.list_commits( 

163 relative_path, 

164 max_count=1, 

165 before=version, 

166 ) 

167 

168 version = commits[0] 

169 

170 bio = BytesIO() 

171 self.__git( 

172 "show", f"{version}:{relative_path}", 

173 _out=bio, 

174 _tty_out=False, 

175 ) 

176 return bio.getvalue() 

177 except Exception as e: 

178 raise FileNotFoundError from e