Coverage for src/ssh_agent_client/__init__.py: 100.000%
111 statements
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-14 11:39 +0200
« prev ^ index » next coverage.py v7.6.0, created at 2024-07-14 11:39 +0200
1# SPDX-FileCopyrightText: 2024 Marco Ricci <m@the13thletter.info>
2#
3# SPDX-License-Identifier: MIT
5"""A bare-bones SSH agent client supporting signing and key listing."""
7from __future__ import annotations
9import collections
10import errno
11import os
12import socket
14from collections.abc import Sequence
15from typing_extensions import Any, Self
17from ssh_agent_client import types
# Package metadata.
__version__ = "0.1.0"
__author__ = 'Marco Ricci <m@the13thletter.info>'
# Public API of this module.
__all__ = ('SSHAgentClient',)
# Module-level alias for the `socket` module: `SSHAgentClient.__init__`
# takes a keyword parameter named `socket`, which shadows the module
# inside that method, so the method reaches the module through this name.
_socket = socket
class TrailingDataError(RuntimeError):
    """The result contained trailing data.

    Signals that a response from the SSH agent was longer than its
    structure accounts for, i.e., bytes were left over after parsing.
    """
class SSHAgentClient:
    """A bare-bones SSH agent client supporting signing and key listing.

    The main use case is requesting the agent sign some data, after
    checking that the necessary key is already loaded.

    The main fleshed out methods are `list_keys` and `sign`, which
    implement the `REQUEST_IDENTITIES` and `SIGN_REQUEST` requests.  If
    you *really* wanted to, there is enough infrastructure in place to
    issue other requests as defined in the protocol---it's merely the
    wrapper functions and the protocol numbers table that are missing.

    """

    _connection: socket.socket

    # Upper bound for a single `recv` call.  Reading in bounded chunks
    # avoids asking `recv` for a buffer as large as the response length
    # announced by the agent (an untrusted value of up to 4 GiB) before
    # any of that data has actually arrived.
    _RECV_CHUNK_SIZE = 65536

    def __init__(
        self, /, *, socket: socket.socket | None = None, timeout: int = 125
    ) -> None:
        """Initialize the client.

        Args:
            socket:
                An optional socket, connected to the SSH agent.  If not
                given, we query the `SSH_AUTH_SOCK` environment
                variable to auto-discover the correct socket address.
            timeout:
                A connection timeout for the SSH agent.  Only used if
                the socket is not yet connected.  The default value
                gives ample time for agent connections forwarded via
                SSH on high-latency networks (e.g. Tor).  Once set, the
                timeout also stays in effect for all later reads and
                writes on that socket.

        Raises:
            KeyError:
                The `SSH_AUTH_SOCK` environment was not found.
            OSError:
                There was an error setting up a socket connection to the
                agent.

        """
        # Only a socket created here may be closed on failure; a
        # caller-supplied socket remains the caller's responsibility.
        owns_connection = socket is None
        if socket is not None:
            self._connection = socket
        else:
            self._connection = _socket.socket(family=_socket.AF_UNIX)
        try:
            if not self._is_connected():
                if 'SSH_AUTH_SOCK' not in os.environ:
                    raise KeyError('SSH_AUTH_SOCK environment variable')
                ssh_auth_sock = os.environ['SSH_AUTH_SOCK']
                self._connection.settimeout(timeout)
                self._connection.connect(ssh_auth_sock)
        except BaseException:
            # Do not leak the socket we created if setup fails.
            if owns_connection:
                self._connection.close()
            raise

    def _is_connected(self) -> bool:
        """Return whether the underlying socket already has a peer.

        Raises:
            OSError:
                `getpeername` failed for a reason other than the socket
                not being connected.

        """
        try:
            # Test whether the socket is connected.
            self._connection.getpeername()
        except OSError as e:
            # This condition is hard to test purposefully, so exclude
            # from coverage.
            if e.errno != errno.ENOTCONN:  # pragma: no cover
                raise
            return False
        return True

    def __enter__(self) -> Self:
        """Enter the context manager, returning the client itself.

        The socket connection is closed upon context manager completion.

        """
        self._connection.__enter__()
        return self

    def __exit__(
        self, exc_type: Any, exc_val: Any, exc_tb: Any
    ) -> bool:
        """Close socket connection upon context manager completion."""
        return bool(
            self._connection.__exit__(
                exc_type, exc_val, exc_tb)  # type: ignore[func-returns-value]
        )

    @staticmethod
    def uint32(num: int, /) -> bytes:
        r"""Format the number as a `uint32`, as per the agent protocol.

        Args:
            num: A number.

        Returns:
            The number in SSH agent wire protocol format, i.e. as
            a 32-bit big endian number.

        Raises:
            OverflowError:
                As per [`int.to_bytes`][].

        Examples:
            >>> SSHAgentClient.uint32(16777216)
            b'\x01\x00\x00\x00'

        """
        return int.to_bytes(num, 4, 'big', signed=False)

    @classmethod
    def string(cls, payload: bytes | bytearray, /) -> bytes | bytearray:
        r"""Format the payload as an SSH string, as per the agent protocol.

        Args:
            payload: A byte string.

        Returns:
            The payload, framed in the SSH agent wire protocol format.

        Raises:
            TypeError:
                The payload is not a byte string, or its length does
                not fit into a `uint32`.

        Examples:
            >>> bytes(SSHAgentClient.string(b'ssh-rsa'))
            b'\x00\x00\x00\x07ssh-rsa'

        """
        try:
            ret = bytearray()
            ret.extend(cls.uint32(len(payload)))
            ret.extend(payload)
            return ret
        except (TypeError, ValueError, OverflowError) as e:
            # Normalize "not a byte string" failures (no `len`, not
            # bytes-like, element out of byte range, too long for
            # a uint32 length prefix) into a single exception type.
            raise TypeError('invalid payload type') from e

    @classmethod
    def unstring(cls, bytestring: bytes | bytearray, /) -> bytes | bytearray:
        r"""Unpack an SSH string.

        Args:
            bytestring: A framed byte string.

        Returns:
            The unframed byte string, i.e., the payload.

        Raises:
            ValueError:
                The byte string is not an SSH string.

        Examples:
            >>> bytes(SSHAgentClient.unstring(b'\x00\x00\x00\x07ssh-rsa'))
            b'ssh-rsa'
            >>> bytes(SSHAgentClient.unstring(SSHAgentClient.string(b'ssh-ed25519')))
            b'ssh-ed25519'

        """
        n = len(bytestring)
        if n < 4:
            raise ValueError('malformed SSH byte string')
        elif n != 4 + int.from_bytes(bytestring[:4], 'big', signed=False):
            raise ValueError('malformed SSH byte string')
        return bytestring[4:]

    @classmethod
    def unstring_prefix(
        cls, bytestring: bytes | bytearray, /
    ) -> tuple[bytes | bytearray, bytes | bytearray]:
        r"""Unpack an SSH string at the beginning of the byte string.

        Args:
            bytestring:
                A (general) byte string, beginning with a framed/SSH
                byte string.

        Returns:
            A 2-tuple `(a, b)`, where `a` is the unframed byte
            string/payload at the beginning of input byte string, and
            `b` is the remainder of the input byte string.

        Raises:
            ValueError:
                The byte string does not begin with an SSH string.

        Examples:
            >>> a, b = SSHAgentClient.unstring_prefix(
            ...     b'\x00\x00\x00\x07ssh-rsa____trailing data')
            >>> (bytes(a), bytes(b))
            (b'ssh-rsa', b'____trailing data')
            >>> a, b = SSHAgentClient.unstring_prefix(
            ...     SSHAgentClient.string(b'ssh-ed25519'))
            >>> (bytes(a), bytes(b))
            (b'ssh-ed25519', b'')

        """
        n = len(bytestring)
        if n < 4:
            raise ValueError('malformed SSH byte string')
        m = int.from_bytes(bytestring[:4], 'big', signed=False)
        if m + 4 > n:
            raise ValueError('malformed SSH byte string')
        return (bytestring[4:m + 4], bytestring[m + 4:])

    def _recv_exactly(self, num: int, /) -> bytearray:
        """Read `num` bytes from the agent, stopping early only at EOF.

        A single `recv` call may legitimately return fewer bytes than
        requested (e.g. for large responses, or agent connections
        forwarded over a network), so keep reading until `num` bytes
        have arrived or the peer signals end-of-stream.

        Args:
            num: The number of bytes to read.

        Returns:
            The bytes read.  A result shorter than `num` means the
            connection reached EOF first.

        """
        buf = bytearray()
        while len(buf) < num:
            chunk = self._connection.recv(
                min(num - len(buf), self._RECV_CHUNK_SIZE))
            if not chunk:
                # `recv` returns an empty result on EOF.
                break
            buf.extend(chunk)
        return buf

    def request(
        self, code: int, payload: bytes | bytearray, /
    ) -> tuple[int, bytes | bytearray]:
        """Issue a generic request to the SSH agent.

        Args:
            code:
                The request code.  See the SSH agent protocol for
                protocol numbers to use here (and which protocol numbers
                to expect in a response).
            payload:
                A byte string containing the payload, or "contents", of
                the request.  Request-specific.  `request` will add any
                necessary wire framing around the request code and the
                payload.

        Returns:
            A 2-tuple consisting of the response code and the payload,
            with all wire framing removed.

        Raises:
            EOFError:
                The response from the SSH agent is truncated or missing.

        """
        request_message = bytearray([code])
        request_message.extend(payload)
        self._connection.sendall(self.string(request_message))
        chunk = self._recv_exactly(4)
        if len(chunk) < 4:
            raise EOFError('cannot read response length')
        response_length = int.from_bytes(chunk, 'big', signed=False)
        response = self._recv_exactly(response_length)
        # A zero-length message lacks even the response code byte.
        if response_length == 0 or len(response) < response_length:
            raise EOFError('truncated response from SSH agent')
        return response[0], bytes(response[1:])

    def list_keys(self) -> Sequence[types.KeyCommentPair]:
        """Request a list of keys known to the SSH agent.

        Returns:
            A read-only sequence of key/comment pairs.

        Raises:
            EOFError:
                The response from the SSH agent is truncated or missing.
            TrailingDataError:
                The response from the SSH agent is too long.
            RuntimeError:
                The agent failed to complete the request.

        """
        response_code, response = self.request(
            types.SSH_AGENTC.REQUEST_IDENTITIES.value, b'')
        if response_code != types.SSH_AGENT.IDENTITIES_ANSWER.value:
            raise RuntimeError(
                f'error return from SSH agent: '
                f'{response_code = }, {response = }'
            )
        offset = 0

        def shift(num: int) -> bytes:
            """Consume and return the next `num` bytes of the response."""
            nonlocal offset
            if num > len(response) - offset:
                raise EOFError('truncated response from SSH agent')
            chunk = bytes(response[offset:offset + num])
            offset += num
            return chunk

        key_count = int.from_bytes(shift(4), 'big')
        keys: collections.deque[types.KeyCommentPair] = collections.deque()
        for _ in range(key_count):
            key_size = int.from_bytes(shift(4), 'big')
            key = shift(key_size)
            comment_size = int.from_bytes(shift(4), 'big')
            comment = shift(comment_size)
            # Both `key` and `comment` are not wrapped as SSH strings.
            keys.append(types.KeyCommentPair(key, comment))
        if offset != len(response):
            raise TrailingDataError('overlong response from SSH agent')
        return keys

    def sign(
        self, /, key: bytes | bytearray, payload: bytes | bytearray,
        *, flags: int = 0, check_if_key_loaded: bool = False,
    ) -> bytes | bytearray:
        """Request the SSH agent sign the payload with the key.

        Args:
            key:
                The public SSH key to sign the payload with, in the same
                format as returned by, e.g., the `list_keys` method.
                The corresponding private key must have previously been
                loaded into the agent to successfully issue a signature.
            payload:
                A byte string of data to sign.
            flags:
                Optional flags for the signing request.  Currently
                passed on as-is to the agent.  In real-world usage, this
                could be used, e.g., to request more modern hash
                algorithms when signing with RSA keys.  (No such
                real-world usage is currently implemented.)
            check_if_key_loaded:
                If true, check beforehand (via `list_keys`) if the
                corresponding key has been loaded into the agent.

        Returns:
            The binary signature of the payload under the given key.

        Raises:
            EOFError:
                The response from the SSH agent is truncated or missing.
            TrailingDataError:
                The response from the SSH agent is too long.
            RuntimeError:
                The agent failed to complete the request.
            KeyError:
                `check_if_key_loaded` is true, and the `key` was not
                loaded into the agent.
            ValueError:
                The signature in the agent's response is not a
                well-formed SSH string.

        """
        if check_if_key_loaded:
            loaded_keys = frozenset({pair.key for pair in self.list_keys()})
            if bytes(key) not in loaded_keys:
                raise KeyError('target SSH key not loaded into agent')
        request_data = bytearray(self.string(key))
        request_data.extend(self.string(payload))
        request_data.extend(self.uint32(flags))
        response_code, response = self.request(
            types.SSH_AGENTC.SIGN_REQUEST.value, request_data)
        if response_code != types.SSH_AGENT.SIGN_RESPONSE.value:
            raise RuntimeError(
                f'signing data failed: {response_code = }, {response = }'
            )
        return self.unstring(response)