Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is dual licensed under the terms of the Apache License, Version 

2# 2.0, and the BSD License. See the LICENSE file in the root of this repository 

3# for complete details. 

4 

5 

6import binascii 

7import os 

8import re 

9import struct 

10import typing 

11from base64 import encodebytes as _base64_encode 

12 

13from cryptography import utils 

14from cryptography.exceptions import UnsupportedAlgorithm 

15from cryptography.hazmat.backends import _get_backend 

16from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed25519, rsa 

17from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes 

18from cryptography.hazmat.primitives.serialization import ( 

19 Encoding, 

20 NoEncryption, 

21 PrivateFormat, 

22 PublicFormat, 

23) 

24 

25try: 

26 from bcrypt import kdf as _bcrypt_kdf 

27 

28 _bcrypt_supported = True 

29except ImportError: 

30 _bcrypt_supported = False 

31 

32 def _bcrypt_kdf( 

33 password: bytes, 

34 salt: bytes, 

35 desired_key_bytes: int, 

36 rounds: int, 

37 ignore_few_rounds: bool = False, 

38 ) -> bytes: 

39 raise UnsupportedAlgorithm("Need bcrypt module") 

40 

41 

# SSH key type names.  They appear as the first string of a key blob and
# as the first word of a one-line public key.
_SSH_ED25519 = b"ssh-ed25519"
_SSH_RSA = b"ssh-rsa"
_SSH_DSA = b"ssh-dss"
_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
# Suffix marking an OpenSSH certificate key type; load_ssh_public_key strips
# it to find the underlying key format.
_CERT_SUFFIX = b"-cert-v01@openssh.com"

49 

# One-line public key: "<key-type> <base64-blob>".  Only the start of the
# line is matched, so any trailing text (e.g. a comment) is ignored.
_SSH_PUBKEY_RC = re.compile(br"\A(\S+)[ \t]+(\S+)")
# Magic prefix of the decoded private key blob (includes the NUL byte).
_SK_MAGIC = b"openssh-key-v1\0"
# Armour lines around the base64-encoded private key blob.
_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
_SK_END = b"-----END OPENSSH PRIVATE KEY-----"
# KDF / cipher names as written in the private key header.
_BCRYPT = b"bcrypt"
_NONE = b"none"
# Cipher and KDF round count used when serializing with a password.
_DEFAULT_CIPHER = b"aes256-ctr"
_DEFAULT_ROUNDS = 16
# Longest password (in bytes) accepted by serialize_ssh_private_key.
_MAX_PASSWORD = 72

# re is the only way to search bytes-like data here; the non-greedy group
# with DOTALL captures the base64 body across line breaks.
_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)

# Padding bytes 1, 2, 3, ... sized for the largest block length (16) used
# by the ciphers below.
_PADDING = memoryview(bytearray(range(1, 1 + 16)))

65 

# Ciphers that are actually used in key wrapping:
# name -> (algorithm class, key length, mode class, IV/block length),
# with lengths in bytes.
_SSH_CIPHERS = {
    b"aes256-ctr": (algorithms.AES, 32, modes.CTR, 16),
    b"aes256-cbc": (algorithms.AES, 32, modes.CBC, 16),
}

# Map cryptography curve name to SSH key type.
_ECDSA_KEY_TYPE = {
    "secp256r1": _ECDSA_NISTP256,
    "secp384r1": _ECDSA_NISTP384,
    "secp521r1": _ECDSA_NISTP521,
}

# Big-endian unsigned 32-bit and 64-bit integer codecs for the wire format.
_U32 = struct.Struct(b">I")
_U64 = struct.Struct(b">Q")

81 

82 

def _ecdsa_key_type(public_key):
    """Return the SSH key type for an EC public key.

    :param public_key: EC public key; callers holding a private key pass
        ``private_key.public_key()``.
    :returns: one of the ``ecdsa-sha2-nistp*`` key type byte strings.
    :raises ValueError: if the key's curve has no SSH key type mapping.
    """
    curve = public_key.curve
    key_type = _ECDSA_KEY_TYPE.get(curve.name)
    if key_type is None:
        raise ValueError(
            "Unsupported curve for ssh private key: %r" % curve.name
        )
    return key_type

91 

92 

def _ssh_pem_encode(data, prefix=_SK_START + b"\n", suffix=_SK_END + b"\n"):
    """Armour *data* as base64 text between the *prefix* and *suffix* lines.

    ``encodebytes`` output is line-wrapped and ends with a newline, so
    *suffix* starts on its own line.
    """
    encoded = _base64_encode(data)
    return b"".join((prefix, encoded, suffix))

95 

96 

97def _check_block_size(data, block_len): 

98 """Require data to be full blocks""" 

99 if not data or len(data) % block_len != 0: 

100 raise ValueError("Corrupt data: missing padding") 

101 

102 

103def _check_empty(data): 

104 """All data should have been parsed.""" 

105 if data: 

106 raise ValueError("Corrupt data: unparsed data") 

107 

108 

def _init_cipher(ciphername, password, salt, rounds, backend):
    """Derive key and IV from *password* via the bcrypt KDF; return a Cipher.

    Raises ValueError when *password* is missing or empty.
    """
    if not password:
        raise ValueError("Key is password-protected.")

    cipher_cls, key_size, mode_cls, iv_size = _SSH_CIPHERS[ciphername]
    # A single KDF call yields the key bytes immediately followed by the IV.
    # The trailing True is passed as the ignore_few_rounds argument.
    material = _bcrypt_kdf(password, salt, key_size + iv_size, rounds, True)
    key, iv = material[:key_size], material[key_size:]
    return Cipher(cipher_cls(key), mode_cls(iv), backend)

117 

118 

def _get_u32(data):
    """Split a big-endian uint32 off *data*; return (value, remaining)."""
    if len(data) >= 4:
        (value,) = _U32.unpack(data[:4])
        return value, data[4:]
    raise ValueError("Invalid data")

124 

125 

def _get_u64(data):
    """Split a big-endian uint64 off *data*; return (value, remaining)."""
    if len(data) >= 8:
        (value,) = _U64.unpack(data[:8])
        return value, data[8:]
    raise ValueError("Invalid data")

131 

132 

def _get_sshstr(data):
    """Split a u32-length-prefixed string off *data*; return (string, remaining)."""
    length, rest = _get_u32(data)
    if length <= len(rest):
        return rest[:length], rest[length:]
    raise ValueError("Invalid data")

139 

140 

def _get_mpint(data):
    """Parse an SSH mpint off *data*; return (int, remaining).

    Only non-negative values are accepted: a payload whose first byte has
    the top bit set is rejected with ValueError.
    """
    raw, rest = _get_sshstr(data)
    negative = bool(raw) and raw[0] > 0x7F
    if negative:
        raise ValueError("Invalid data")
    return int.from_bytes(raw, "big"), rest

147 

148 

149def _to_mpint(val): 

150 """Storage format for signed bigint.""" 

151 if val < 0: 

152 raise ValueError("negative mpint not allowed") 

153 if not val: 

154 return b"" 

155 nbytes = (val.bit_length() + 8) // 8 

156 return utils.int_to_bytes(val, nbytes) 

157 

158 

159class _FragList(object): 

160 """Build recursive structure without data copy.""" 

161 

162 def __init__(self, init=None): 

163 self.flist = [] 

164 if init: 

165 self.flist.extend(init) 

166 

167 def put_raw(self, val): 

168 """Add plain bytes""" 

169 self.flist.append(val) 

170 

171 def put_u32(self, val): 

172 """Big-endian uint32""" 

173 self.flist.append(_U32.pack(val)) 

174 

175 def put_sshstr(self, val): 

176 """Bytes prefixed with u32 length""" 

177 if isinstance(val, (bytes, memoryview, bytearray)): 

178 self.put_u32(len(val)) 

179 self.flist.append(val) 

180 else: 

181 self.put_u32(val.size()) 

182 self.flist.extend(val.flist) 

183 

184 def put_mpint(self, val): 

185 """Big-endian bigint prefixed with u32 length""" 

186 self.put_sshstr(_to_mpint(val)) 

187 

188 def size(self): 

189 """Current number of bytes""" 

190 return sum(map(len, self.flist)) 

191 

192 def render(self, dstbuf, pos=0): 

193 """Write into bytearray""" 

194 for frag in self.flist: 

195 flen = len(frag) 

196 start, pos = pos, pos + flen 

197 dstbuf[start:pos] = frag 

198 return pos 

199 

200 def tobytes(self): 

201 """Return as bytes""" 

202 buf = memoryview(bytearray(self.size())) 

203 self.render(buf) 

204 return buf.tobytes() 

205 

206 

class _SSHFormatRSA(object):
    """RSA key (de)serialization in the SSH wire format.

    Public:
        mpint e, n
    Private:
        mpint n, e, d, iqmp, p, q
    """

    def get_public(self, data):
        """Parse (e, n) from *data*; return ((e, n), remaining)."""
        e, rest = _get_mpint(data)
        n, rest = _get_mpint(rest)
        return (e, n), rest

    def load_public(self, key_type, data, backend):
        """Build an RSA public key from *data*; return (key, remaining)."""
        (e, n), rest = self.get_public(data)
        key = rsa.RSAPublicNumbers(e, n).public_key(backend)
        return key, rest

    def load_private(self, data, pubfields, backend):
        """Build an RSA private key from *data*; return (key, remaining).

        Raises ValueError if (e, n) disagrees with *pubfields*, the values
        parsed from the key's public section.
        """
        fields = []
        for _ in range(6):
            value, data = _get_mpint(data)
            fields.append(value)
        n, e, d, iqmp, p, q = fields

        if (e, n) != pubfields:
            raise ValueError("Corrupt data: rsa field mismatch")
        # The CRT exponents are not stored in this format; recompute them.
        dmp1 = rsa.rsa_crt_dmp1(d, p)
        dmq1 = rsa.rsa_crt_dmq1(d, q)
        numbers = rsa.RSAPrivateNumbers(
            p, q, d, dmp1, dmq1, iqmp, rsa.RSAPublicNumbers(e, n)
        )
        return numbers.private_key(backend), data

    def encode_public(self, public_key, f_pub):
        """Append the RSA public fields (e, n) to *f_pub*."""
        numbers = public_key.public_numbers()
        for value in (numbers.e, numbers.n):
            f_pub.put_mpint(value)

    def encode_private(self, private_key, f_priv):
        """Append the RSA private fields (n, e, d, iqmp, p, q) to *f_priv*."""
        priv = private_key.private_numbers()
        pub = priv.public_numbers
        for value in (pub.n, pub.e, priv.d, priv.iqmp, priv.p, priv.q):
            f_priv.put_mpint(value)

267 

268 

class _SSHFormatDSA(object):
    """DSA key (de)serialization in the SSH wire format.

    Public:
        mpint p, q, g, y
    Private:
        mpint p, q, g, y, x
    """

    def get_public(self, data):
        """Parse (p, q, g, y) from *data*; return (fields, remaining)."""
        fields = []
        for _ in range(4):
            value, data = _get_mpint(data)
            fields.append(value)
        return tuple(fields), data

    def load_public(self, key_type, data, backend):
        """Build a DSA public key from *data*; return (key, remaining)."""
        (p, q, g, y), rest = self.get_public(data)
        public_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(public_numbers)
        return public_numbers.public_key(backend), rest

    def load_private(self, data, pubfields, backend):
        """Build a DSA private key from *data*; return (key, remaining).

        Raises ValueError if the public part disagrees with *pubfields* or
        p is not 1024 bits long.
        """
        pub, rest = self.get_public(data)
        x, rest = _get_mpint(rest)

        if pub != pubfields:
            raise ValueError("Corrupt data: dsa field mismatch")
        p, q, g, y = pub
        public_numbers = dsa.DSAPublicNumbers(
            y, dsa.DSAParameterNumbers(p, q, g)
        )
        self._validate(public_numbers)
        private_numbers = dsa.DSAPrivateNumbers(x, public_numbers)
        return private_numbers.private_key(backend), rest

    def encode_public(self, public_key, f_pub):
        """Append the DSA public fields (p, q, g, y) to *f_pub*."""
        public_numbers = public_key.public_numbers()
        self._validate(public_numbers)
        params = public_numbers.parameter_numbers
        for value in (params.p, params.q, params.g, public_numbers.y):
            f_pub.put_mpint(value)

    def encode_private(self, private_key, f_priv):
        """Append the DSA public fields followed by the private value x."""
        self.encode_public(private_key.public_key(), f_priv)
        x = private_key.private_numbers().x
        f_priv.put_mpint(x)

    def _validate(self, public_numbers):
        """Reject DSA keys whose modulus p is not exactly 1024 bits."""
        if public_numbers.parameter_numbers.p.bit_length() == 1024:
            return
        raise ValueError("SSH supports only 1024 bit DSA keys")

329 

330 

class _SSHFormatECDSA(object):
    """Format for ECDSA keys.

    Public:
        str curve
        bytes point
    Private:
        str curve
        bytes point
        mpint secret
    """

    def __init__(self, ssh_curve_name, curve):
        # ssh_curve_name: curve identifier as written in the key blob
        # (e.g. b"nistp256"); curve: the matching cryptography curve object.
        self.ssh_curve_name = ssh_curve_name
        self.curve = curve

    def get_public(self, data):
        """Parse the ECDSA public fields; return ((curve_name, point), rest).

        :raises ValueError: on a curve-name mismatch or an empty point.
        :raises NotImplementedError: if the point is not in uncompressed
            (0x04-prefixed) form.
        """
        curve, data = _get_sshstr(data)
        point, data = _get_sshstr(data)
        if curve != self.ssh_curve_name:
            raise ValueError("Curve name mismatch")
        # Without this check, an empty point would raise IndexError on
        # point[0] instead of the ValueError used for other malformed input.
        if not point:
            raise ValueError("Invalid data")
        if point[0] != 4:
            raise NotImplementedError("Need uncompressed point")
        return (curve, point), data

    def load_public(self, key_type, data, backend):
        """Make an ECDSA public key from *data*; return (key, remaining)."""
        (_, point), data = self.get_public(data)
        public_key = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve, point.tobytes()
        )
        return public_key, data

    def load_private(self, data, pubfields, backend):
        """Make an ECDSA private key from *data*; return (key, remaining).

        :raises ValueError: if the embedded public fields disagree with
            *pubfields*, the values parsed from the key's public section.
        """
        (curve_name, point), data = self.get_public(data)
        secret, data = _get_mpint(data)

        if (curve_name, point) != pubfields:
            raise ValueError("Corrupt data: ecdsa field mismatch")
        private_key = ec.derive_private_key(secret, self.curve, backend)
        return private_key, data

    def encode_public(self, public_key, f_pub):
        """Append the curve name and uncompressed point to *f_pub*."""
        point = public_key.public_bytes(
            Encoding.X962, PublicFormat.UncompressedPoint
        )
        f_pub.put_sshstr(self.ssh_curve_name)
        f_pub.put_sshstr(point)

    def encode_private(self, private_key, f_priv):
        """Append the public fields followed by the private scalar."""
        public_key = private_key.public_key()
        private_numbers = private_key.private_numbers()

        self.encode_public(public_key, f_priv)
        f_priv.put_mpint(private_numbers.private_value)

390 

391 

class _SSHFormatEd25519(object):
    """Ed25519 key (de)serialization in the SSH wire format.

    Public:
        bytes point
    Private:
        bytes point
        bytes secret_and_point (32-byte secret followed by the point)
    """

    def get_public(self, data):
        """Parse the public point; return ((point,), remaining)."""
        point, rest = _get_sshstr(data)
        return (point,), rest

    def load_public(self, key_type, data, backend):
        """Build an Ed25519 public key from *data*; return (key, remaining)."""
        (point,), rest = self.get_public(data)
        raw = point.tobytes()
        return ed25519.Ed25519PublicKey.from_public_bytes(raw), rest

    def load_private(self, data, pubfields, backend):
        """Build an Ed25519 private key from *data*; return (key, remaining).

        The private section repeats the public point after the 32-byte
        secret.  Both copies must agree with *pubfields*, otherwise
        ValueError is raised.
        """
        (point,), rest = self.get_public(data)
        keypair, rest = _get_sshstr(rest)

        secret, trailing_point = keypair[:32], keypair[32:]
        consistent = point == trailing_point and (point,) == pubfields
        if not consistent:
            raise ValueError("Corrupt data: ed25519 field mismatch")
        return ed25519.Ed25519PrivateKey.from_private_bytes(secret), rest

    def encode_public(self, public_key, f_pub):
        """Append the raw public key bytes to *f_pub* as an SSH string."""
        f_pub.put_sshstr(
            public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        )

    def encode_private(self, private_key, f_priv):
        """Append the public point, then (raw secret + raw point) as one string."""
        public_key = private_key.public_key()
        secret = private_key.private_bytes(
            Encoding.Raw, PrivateFormat.Raw, NoEncryption()
        )
        point = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)

        self.encode_public(public_key, f_priv)
        f_priv.put_sshstr(_FragList([secret, point]))

447 

448 

# SSH key type name -> format handler.  Each handler provides get_public,
# load_public, load_private, encode_public and encode_private.
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
    _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
    _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
}

457 

458 

def _lookup_kformat(key_type):
    """Return the format handler registered for *key_type*.

    *key_type* may be any bytes-like object; it is converted to ``bytes``
    for the lookup.  Raises UnsupportedAlgorithm for unknown key types.
    """
    if not isinstance(key_type, bytes):
        key_type = memoryview(key_type).tobytes()
    kformat = _KEY_FORMATS.get(key_type)
    if kformat is None:
        raise UnsupportedAlgorithm("Unsupported key type: %r" % key_type)
    return kformat

466 

467 

# Private key classes that can be loaded from, or serialized to, the OpenSSH
# private key format.
_SSH_PRIVATE_KEY_TYPES = typing.Union[
    ec.EllipticCurvePrivateKey,
    rsa.RSAPrivateKey,
    dsa.DSAPrivateKey,
    ed25519.Ed25519PrivateKey,
]

474 

475 

def load_ssh_private_key(
    data: bytes, password: typing.Optional[bytes], backend=None
) -> _SSH_PRIVATE_KEY_TYPES:
    """Load private key from OpenSSH custom encoding.

    :param data: bytes-like object containing an
        ``-----BEGIN OPENSSH PRIVATE KEY-----`` block.
    :param password: password for an encrypted key.  It is not consulted
        when the key is unencrypted.
    :param backend: optional backend, resolved via ``_get_backend``.
    :returns: the decoded RSA, DSA, ECDSA or Ed25519 private key.
    :raises ValueError: if the data is malformed, or if the key is
        encrypted and no password was given.
    :raises UnsupportedAlgorithm: for an unknown key type, cipher or KDF.
    """
    utils._check_byteslike("data", data)
    backend = _get_backend(backend)
    if password is not None:
        utils._check_bytes("password", password)

    m = _PEM_RC.search(data)
    if not m:
        raise ValueError("Not OpenSSH private key format")
    p1 = m.start(1)
    p2 = m.end(1)
    data = binascii.a2b_base64(memoryview(data)[p1:p2])
    if not data.startswith(_SK_MAGIC):
        raise ValueError("Not OpenSSH private key format")
    data = memoryview(data)[len(_SK_MAGIC) :]

    # parse header
    ciphername, data = _get_sshstr(data)
    kdfname, data = _get_sshstr(data)
    kdfoptions, data = _get_sshstr(data)
    nkeys, data = _get_u32(data)
    if nkeys != 1:
        raise ValueError("Only one key supported")

    # load public key data
    pubdata, data = _get_sshstr(data)
    pub_key_type, pubdata = _get_sshstr(pubdata)
    kformat = _lookup_kformat(pub_key_type)
    pubfields, pubdata = kformat.get_public(pubdata)
    _check_empty(pubdata)

    # load secret data
    edata, data = _get_sshstr(data)
    _check_empty(data)

    if (ciphername, kdfname) != (_NONE, _NONE):
        # The header fields are memoryviews.  Convert them to bytes so they
        # work as dict keys and show their value (not "<memory at ...>")
        # in error messages.
        ciphername = ciphername.tobytes()
        kdfname = kdfname.tobytes()
        if ciphername not in _SSH_CIPHERS:
            raise UnsupportedAlgorithm("Unsupported cipher: %r" % ciphername)
        if kdfname != _BCRYPT:
            raise UnsupportedAlgorithm("Unsupported KDF: %r" % kdfname)
        blklen = _SSH_CIPHERS[ciphername][3]
        _check_block_size(edata, blklen)
        # bcrypt KDF options: salt string followed by a u32 round count.
        salt, kbuf = _get_sshstr(kdfoptions)
        rounds, kbuf = _get_u32(kbuf)
        _check_empty(kbuf)
        ciph = _init_cipher(
            ciphername, password, salt.tobytes(), rounds, backend
        )
        edata = memoryview(ciph.decryptor().update(edata))
    else:
        blklen = 8
        _check_block_size(edata, blklen)

    # The secret section starts with the same check value written twice.
    # A mismatch means corrupt data or, for encrypted keys, typically a
    # wrong password.
    ck1, edata = _get_u32(edata)
    ck2, edata = _get_u32(edata)
    if ck1 != ck2:
        raise ValueError("Corrupt data: broken checksum")

    # load per-key struct
    key_type, edata = _get_sshstr(edata)
    if key_type != pub_key_type:
        raise ValueError("Corrupt data: key type mismatch")
    private_key, edata = kformat.load_private(edata, pubfields, backend)
    # The key comment is parsed for structural validation only.
    _comment, edata = _get_sshstr(edata)

    # yes, SSH does padding check *after* all other parsing is done.
    # need to follow as it writes zero-byte padding too.
    if edata != _PADDING[: len(edata)]:
        raise ValueError("Corrupt data: invalid padding")

    return private_key

550 

551 

def serialize_ssh_private_key(
    private_key: _SSH_PRIVATE_KEY_TYPES,
    password: typing.Optional[bytes] = None,
) -> bytes:
    """Serialize private key with OpenSSH custom encoding.

    :param private_key: RSA, DSA, ECDSA or Ed25519 private key.
    :param password: if non-empty, the secret section is encrypted with
        aes256-ctr using a key derived by the bcrypt KDF (16 rounds, random
        16-byte salt).  ``None`` or ``b""`` produces an unencrypted key.
    :returns: the key armoured between
        ``-----BEGIN/END OPENSSH PRIVATE KEY-----`` lines.
    :raises ValueError: for an unsupported key type or curve, or for a
        password longer than 72 bytes.
    """
    if password is not None:
        utils._check_bytes("password", password)
    if password and len(password) > _MAX_PASSWORD:
        raise ValueError(
            "Passwords longer than 72 bytes are not supported by "
            "OpenSSH private key format"
        )

    if isinstance(private_key, ec.EllipticCurvePrivateKey):
        key_type = _ecdsa_key_type(private_key.public_key())
    elif isinstance(private_key, rsa.RSAPrivateKey):
        key_type = _SSH_RSA
    elif isinstance(private_key, dsa.DSAPrivateKey):
        key_type = _SSH_DSA
    elif isinstance(private_key, ed25519.Ed25519PrivateKey):
        key_type = _SSH_ED25519
    else:
        raise ValueError("Unsupported key type")
    kformat = _lookup_kformat(key_type)

    # setup parameters
    f_kdfoptions = _FragList()
    if password:
        ciphername = _DEFAULT_CIPHER
        blklen = _SSH_CIPHERS[ciphername][3]
        kdfname = _BCRYPT
        rounds = _DEFAULT_ROUNDS
        salt = os.urandom(16)
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        backend = _get_backend(None)
        ciph = _init_cipher(ciphername, password, salt, rounds, backend)
    else:
        ciphername = kdfname = _NONE
        blklen = 8
        ciph = None
    nkeys = 1
    # Random check value, written twice so the loader can detect corruption
    # or a wrong decryption key.
    checkval = os.urandom(4)
    comment = b""

    # encode public and private parts together
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    # Pad with 1, 2, 3, ... up to a multiple of blklen.  A full block is
    # added when the size is already aligned.
    f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)])

    # top-level structure
    f_main = _FragList()
    f_main.put_raw(_SK_MAGIC)
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(f_kdfoptions)
    f_main.put_u32(nkeys)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # Copy result into a bytearray.  The extra blklen bytes of slack give
    # update_into an output buffer longer than its input.
    slen = f_secrets.size()
    mlen = f_main.size()
    buf = memoryview(bytearray(mlen + blklen))
    f_main.render(buf)
    # The secret section was appended last, so it occupies the final
    # slen bytes of the rendered structure.
    ofs = mlen - slen

    # encrypt in-place
    if ciph is not None:
        ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])

    txt = _ssh_pem_encode(buf[:mlen])
    # Zero the secret region of the scratch buffer now that it is encoded.
    # Ignore the following type because mypy wants
    # Sequence[bytes] but what we're passing is fine.
    # https://github.com/python/mypy/issues/9999
    buf[ofs:mlen] = bytearray(slen)  # type: ignore
    return txt

635 

636 

# Public key classes that can be loaded from, or serialized to, the OpenSSH
# one-line public key format.
_SSH_PUBLIC_KEY_TYPES = typing.Union[
    ec.EllipticCurvePublicKey,
    rsa.RSAPublicKey,
    dsa.DSAPublicKey,
    ed25519.Ed25519PublicKey,
]

643 

644 

def load_ssh_public_key(data: bytes, backend=None) -> _SSH_PUBLIC_KEY_TYPES:
    """Load a public key from the OpenSSH one-line format.

    The line looks like ``<key-type> <base64-blob> [comment]``.  Certificate
    lines, whose key type ends in ``-cert-v01@openssh.com``, are accepted
    too.  For those, the embedded public key is returned and the remaining
    certificate fields are read only to check that the blob is well formed.
    """
    backend = _get_backend(backend)
    utils._check_byteslike("data", data)

    match = _SSH_PUBKEY_RC.match(data)
    if not match:
        raise ValueError("Invalid line format")
    orig_key_type, key_body = match.group(1), match.group(2)

    with_cert = orig_key_type[-len(_CERT_SUFFIX) :] == _CERT_SUFFIX
    if with_cert:
        key_type = orig_key_type[: -len(_CERT_SUFFIX)]
    else:
        key_type = orig_key_type
    kformat = _lookup_kformat(key_type)

    try:
        blob = memoryview(binascii.a2b_base64(key_body))
    except (TypeError, binascii.Error):
        raise ValueError("Invalid key format")

    inner_key_type, rest = _get_sshstr(blob)
    if inner_key_type != orig_key_type:
        raise ValueError("Invalid key format")
    if with_cert:
        _nonce, rest = _get_sshstr(rest)
    public_key, rest = kformat.load_public(key_type, rest, backend)
    if with_cert:
        # Certificate trailer, read in wire order; the values are discarded.
        cert_field_readers = (
            _get_u64,  # serial
            _get_u32,  # certificate type
            _get_sshstr,  # key id
            _get_sshstr,  # valid principals
            _get_u64,  # valid after
            _get_u64,  # valid before
            _get_sshstr,  # critical options
            _get_sshstr,  # extensions
            _get_sshstr,  # reserved
            _get_sshstr,  # signature key
            _get_sshstr,  # signature
        )
        for read_field in cert_field_readers:
            _, rest = read_field(rest)
    _check_empty(rest)
    return public_key

686 

687 

def serialize_ssh_public_key(public_key: _SSH_PUBLIC_KEY_TYPES) -> bytes:
    """Return *public_key* as an OpenSSH one-line ``<key-type> <base64>`` entry.

    Raises ValueError for unsupported key classes or EC curves.
    """
    if isinstance(public_key, ec.EllipticCurvePublicKey):
        key_type = _ecdsa_key_type(public_key)
    else:
        for key_cls, type_name in (
            (rsa.RSAPublicKey, _SSH_RSA),
            (dsa.DSAPublicKey, _SSH_DSA),
            (ed25519.Ed25519PublicKey, _SSH_ED25519),
        ):
            if isinstance(public_key, key_cls):
                key_type = type_name
                break
        else:
            raise ValueError("Unsupported key type")
    kformat = _lookup_kformat(key_type)

    blob = _FragList()
    blob.put_sshstr(key_type)
    kformat.encode_public(public_key, blob)

    encoded = binascii.b2a_base64(blob.tobytes()).strip()
    return key_type + b" " + encoded