Coverage for src/ssh_agent_client/__init__.py: 100.000%

111 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-14 11:39 +0200

1# SPDX-FileCopyrightText: 2024 Marco Ricci <m@the13thletter.info> 

2# 

3# SPDX-License-Identifier: MIT 

4 

5"""A bare-bones SSH agent client supporting signing and key listing.""" 

6 

7from __future__ import annotations 

8 

9import collections 

10import errno 

11import os 

12import socket 

13 

14from collections.abc import Sequence 

15from typing_extensions import Any, Self 

16 

17from ssh_agent_client import types 

18 

# Public API. `TrailingDataError` is listed because `SSHAgentClient`
# methods raise it and document it, so callers need to be able to
# import it to catch it.
__all__ = ('SSHAgentClient', 'TrailingDataError')
__author__ = 'Marco Ricci <m@the13thletter.info>'
__version__ = "0.1.0"

# `SSHAgentClient.__init__` takes a keyword parameter named `socket`,
# which shadows the `socket` module inside that method; this alias keeps
# the module reachable there.
_socket = socket

24 

class TrailingDataError(RuntimeError):
    """An SSH agent response contained more data than expected.

    Signals that bytes were left over after the complete response had
    been parsed.
    """

27 

class SSHAgentClient:
    """A bare-bones SSH agent client supporting signing and key listing.

    The main use case is requesting the agent sign some data, after
    checking that the necessary key is already loaded.

    The main fleshed out methods are `list_keys` and `sign`, which
    implement the `REQUEST_IDENTITIES` and `SIGN_REQUEST` requests. If
    you *really* wanted to, there is enough infrastructure in place to
    issue other requests as defined in the protocol---it's merely the
    wrapper functions and the protocol numbers table that are missing.

    """

    # The socket over which all requests to the agent are sent.
    _connection: socket.socket

    def __init__(
        self, /, *, socket: socket.socket | None = None, timeout: int = 125
    ) -> None:
        """Initialize the client.

        If setting up the connection fails, a socket created by this
        constructor is closed again before the error propagates; a
        socket supplied by the caller is left for the caller to close.

        Args:
            socket:
                An optional socket, connected to the SSH agent. If not
                given, we query the `SSH_AUTH_SOCK` environment
                variable to auto-discover the correct socket address.
                A given socket that is not yet connected is connected
                to that address as well.
            timeout:
                A connection timeout for the SSH agent. Only used if
                the socket is not yet connected. The default value
                gives ample time for agent connections forwarded via
                SSH on high-latency networks (e.g. Tor).

        Raises:
            KeyError:
                The `SSH_AUTH_SOCK` environment variable was not found.
            OSError:
                There was an error setting up a socket connection to the
                agent.

        """
        if socket is not None:
            self._connection = socket
            owns_socket = False
        else:
            self._connection = _socket.socket(family=_socket.AF_UNIX)
            owns_socket = True
        try:
            self._connect_if_needed(timeout)
        except BaseException:
            # Do not leak the file descriptor of a socket we created
            # ourselves: the caller never got a handle on it.
            if owns_socket:
                self._connection.close()
            raise

    def _connect_if_needed(self, timeout: int, /) -> None:
        """Connect the socket to `SSH_AUTH_SOCK` unless already connected.

        Args:
            timeout: The timeout to set on the socket before connecting.

        Raises:
            KeyError:
                The `SSH_AUTH_SOCK` environment variable was not found.
            OSError:
                Probing or connecting the socket failed.

        """
        try:
            # Test whether the socket is connected.
            self._connection.getpeername()
        except OSError as e:
            # This condition is hard to test purposefully, so exclude
            # from coverage.
            if e.errno != errno.ENOTCONN:  # pragma: no cover
                raise
            if 'SSH_AUTH_SOCK' not in os.environ:
                raise KeyError('SSH_AUTH_SOCK environment variable')
            ssh_auth_sock = os.environ['SSH_AUTH_SOCK']
            self._connection.settimeout(timeout)
            self._connection.connect(ssh_auth_sock)

    def __enter__(self) -> Self:
        """Enter the runtime context and return this client.

        The underlying socket is closed when the context is exited.
        """
        self._connection.__enter__()
        return self

    def __exit__(
        self, exc_type: Any, exc_val: Any, exc_tb: Any
    ) -> bool:
        """Close socket connection upon context manager completion."""
        # The socket's own `__exit__` closes it and returns None, so
        # this returns False and never suppresses an in-flight exception.
        return bool(
            self._connection.__exit__(
                exc_type, exc_val, exc_tb)  # type: ignore[func-returns-value]
        )

    @staticmethod
    def uint32(num: int, /) -> bytes:
        r"""Format the number as a `uint32`, as per the agent protocol.

        Args:
            num: A number.

        Returns:
            The number in SSH agent wire protocol format, i.e. as
            a 32-bit big endian number.

        Raises:
            OverflowError:
                As per [`int.to_bytes`][].

        Examples:
            >>> SSHAgentClient.uint32(16777216)
            b'\x01\x00\x00\x00'

        """
        return int.to_bytes(num, 4, 'big', signed=False)

    @classmethod
    def string(cls, payload: bytes | bytearray, /) -> bytes | bytearray:
        r"""Format the payload as an SSH string, as per the agent protocol.

        Args:
            payload: A byte string.

        Returns:
            The payload, framed in the SSH agent wire protocol format.

        Raises:
            TypeError:
                The payload could not be framed (e.g. it is not sized or
                not bytes-like).

        Examples:
            >>> bytes(SSHAgentClient.string(b'ssh-rsa'))
            b'\x00\x00\x00\x07ssh-rsa'

        """
        try:
            ret = bytearray()
            ret.extend(cls.uint32(len(payload)))
            ret.extend(payload)
            return ret
        except Exception as e:
            # NOTE(review): this also converts the OverflowError that
            # `uint32` raises for payloads of 4 GiB or more into a
            # TypeError; kept as-is for compatibility.
            raise TypeError('invalid payload type') from e

    @classmethod
    def unstring(cls, bytestring: bytes | bytearray, /) -> bytes | bytearray:
        r"""Unpack an SSH string.

        Args:
            bytestring: A framed byte string.

        Returns:
            The unframed byte string, i.e., the payload.

        Raises:
            ValueError:
                The byte string is not an SSH string.

        Examples:
            >>> bytes(SSHAgentClient.unstring(b'\x00\x00\x00\x07ssh-rsa'))
            b'ssh-rsa'
            >>> bytes(SSHAgentClient.unstring(SSHAgentClient.string(b'ssh-ed25519')))
            b'ssh-ed25519'

        """
        n = len(bytestring)
        if n < 4:
            raise ValueError('malformed SSH byte string')
        elif n != 4 + int.from_bytes(bytestring[:4], 'big', signed=False):
            raise ValueError('malformed SSH byte string')
        return bytestring[4:]

    @classmethod
    def unstring_prefix(
        cls, bytestring: bytes | bytearray, /
    ) -> tuple[bytes | bytearray, bytes | bytearray]:
        r"""Unpack an SSH string at the beginning of the byte string.

        Args:
            bytestring:
                A (general) byte string, beginning with a framed/SSH
                byte string.

        Returns:
            A 2-tuple `(a, b)`, where `a` is the unframed byte
            string/payload at the beginning of input byte string, and
            `b` is the remainder of the input byte string.

        Raises:
            ValueError:
                The byte string does not begin with an SSH string.

        Examples:
            >>> a, b = SSHAgentClient.unstring_prefix(
            ...     b'\x00\x00\x00\x07ssh-rsa____trailing data')
            >>> (bytes(a), bytes(b))
            (b'ssh-rsa', b'____trailing data')
            >>> a, b = SSHAgentClient.unstring_prefix(
            ...     SSHAgentClient.string(b'ssh-ed25519'))
            >>> (bytes(a), bytes(b))
            (b'ssh-ed25519', b'')

        """
        n = len(bytestring)
        if n < 4:
            raise ValueError('malformed SSH byte string')
        m = int.from_bytes(bytestring[:4], 'big', signed=False)
        if m + 4 > n:
            raise ValueError('malformed SSH byte string')
        return (bytestring[4:m + 4], bytestring[m + 4:])

    def _recv_exactly(self, size: int, /) -> bytes:
        """Read exactly `size` bytes from the agent connection.

        A single `recv` call may legitimately return fewer bytes than
        requested (e.g. when a large response arrives in several
        pieces), so keep reading until `size` bytes have arrived or the
        peer signals end-of-file by returning an empty chunk.

        Args:
            size: The number of bytes to read.

        Returns:
            The bytes read. This is shorter than `size` only if the
            connection reached end-of-file first.

        """
        buf = bytearray()
        while len(buf) < size:
            chunk = self._connection.recv(size - len(buf))
            if not chunk:
                break
            buf.extend(chunk)
        return bytes(buf)

    def request(
        self, code: int, payload: bytes | bytearray, /
    ) -> tuple[int, bytes | bytearray]:
        """Issue a generic request to the SSH agent.

        Args:
            code:
                The request code. See the SSH agent protocol for
                protocol numbers to use here (and which protocol numbers
                to expect in a response).
            payload:
                A byte string containing the payload, or "contents", of
                the request. Request-specific. `request` will add any
                necessary wire framing around the request code and the
                payload.

        Returns:
            A 2-tuple consisting of the response code and the payload,
            with all wire framing removed.

        Raises:
            EOFError:
                The response from the SSH agent is truncated or missing.
            ValueError:
                `code` does not fit into a single byte.

        """
        request_message = bytearray([code])
        request_message.extend(payload)
        self._connection.sendall(self.string(request_message))
        chunk = self._recv_exactly(4)
        if len(chunk) < 4:
            raise EOFError('cannot read response length')
        response_length = int.from_bytes(chunk, 'big', signed=False)
        response = self._recv_exactly(response_length)
        # An empty message does not even carry a response code, so it
        # is as unusable as a truncated one.
        if not response or len(response) < response_length:
            raise EOFError('truncated response from SSH agent')
        return response[0], response[1:]

    def list_keys(self) -> Sequence[types.KeyCommentPair]:
        """Request a list of keys known to the SSH agent.

        Returns:
            A read-only sequence of key/comment pairs.

        Raises:
            EOFError:
                The response from the SSH agent is truncated or missing.
            TrailingDataError:
                The response from the SSH agent is too long.
            RuntimeError:
                The agent failed to complete the request.

        """
        response_code, response = self.request(
            types.SSH_AGENTC.REQUEST_IDENTITIES.value, b'')
        if response_code != types.SSH_AGENT.IDENTITIES_ANSWER.value:
            raise RuntimeError(
                f'error return from SSH agent: '
                f'{response_code = }, {response = }'
            )
        pos = 0

        def shift(num: int) -> bytes:
            # Consume and return the next `num` bytes of the response.
            nonlocal pos
            if pos + num > len(response):
                raise EOFError('truncated response from SSH agent')
            piece = bytes(response[pos:pos + num])
            pos += num
            return piece

        key_count = int.from_bytes(shift(4), 'big')
        keys: collections.deque[types.KeyCommentPair] = collections.deque()
        for _ in range(key_count):
            key = shift(int.from_bytes(shift(4), 'big'))
            comment = shift(int.from_bytes(shift(4), 'big'))
            # Both `key` and `comment` are not wrapped as SSH strings.
            keys.append(types.KeyCommentPair(key, comment))
        if pos != len(response):
            raise TrailingDataError('overlong response from SSH agent')
        return keys

    def sign(
        self, /, key: bytes | bytearray, payload: bytes | bytearray,
        *, flags: int = 0, check_if_key_loaded: bool = False,
    ) -> bytes | bytearray:
        """Request the SSH agent sign the payload with the key.

        Args:
            key:
                The public SSH key to sign the payload with, in the same
                format as returned by, e.g., the `list_keys` method.
                The corresponding private key must have previously been
                loaded into the agent to successfully issue a signature.
            payload:
                A byte string of data to sign.
            flags:
                Optional flags for the signing request. Currently
                passed on as-is to the agent. In real-world usage, this
                could be used, e.g., to request more modern hash
                algorithms when signing with RSA keys. (No such
                real-world usage is currently implemented.)
            check_if_key_loaded:
                If true, check beforehand (via `list_keys`) if the
                corresponding key has been loaded into the agent.

        Returns:
            The binary signature of the payload under the given key.

        Raises:
            EOFError:
                The response from the SSH agent is truncated or missing.
            TrailingDataError:
                The key list response from the SSH agent is too long
                (only when `check_if_key_loaded` is true).
            ValueError:
                The signature response from the SSH agent is not
                exactly one well-formed SSH string.
            RuntimeError:
                The agent failed to complete the request.
            KeyError:
                `check_if_key_loaded` is true, and the `key` was not
                loaded into the agent.

        """
        if check_if_key_loaded:
            loaded_keys = frozenset({pair.key for pair in self.list_keys()})
            if bytes(key) not in loaded_keys:
                raise KeyError('target SSH key not loaded into agent')
        request_data = bytearray(self.string(key))
        request_data.extend(self.string(payload))
        request_data.extend(self.uint32(flags))
        response_code, response = self.request(
            types.SSH_AGENTC.SIGN_REQUEST.value, request_data)
        if response_code != types.SSH_AGENT.SIGN_RESPONSE.value:
            raise RuntimeError(
                f'signing data failed: {response_code = }, {response = }'
            )
        return self.unstring(response)