Coverage for src/meshadmin/server/networks/services.py: 95%

201 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-10 16:08 +0200

1import hashlib 

2import ipaddress 

3import json 

4import os 

5from datetime import datetime 

6from pathlib import Path 

7from typing import Optional 

8 

9import structlog 

10import yaml 

11from django.contrib.auth import get_user_model 

12from django.utils.timezone import now 

13from jwcrypto.jwk import JWK 

14from jwcrypto.jwt import JWT 

15 

16from meshadmin.common.utils import create_ca, print_ca, sign_keys 

17from meshadmin.server import assets 

18from meshadmin.server.networks.models import ( 

19 CA, 

20 Group, 

21 Host, 

22 HostCert, 

23 HostConfig, 

24 Network, 

25 NetworkMembership, 

26 Rule, 

27 SigningCA, 

28 Template, 

29) 

30 

# Structured logger bound to this module's import path.
logger = structlog.get_logger(__name__)
# The project's active user model class (resolved via Django's get_user_model).
User = get_user_model()

33 

34 

def create_available_hosts_iterator(cidr, unavailable_ips):
    """Lazily yield the host addresses of ``cidr`` that are not already taken.

    :param cidr: IPv4 network in CIDR notation (e.g. ``"10.0.0.0/24"``),
        parsed strictly by ``ipaddress.IPv4Network`` (host bits must be zero).
    :param unavailable_ips: iterable of addresses already in use. Entries are
        compared by their string form, so both ``str`` and ``IPv4Address``
        values are honoured.
    :returns: a generator of ``ipaddress.IPv4Address`` in the order produced
        by ``IPv4Network.hosts()``, skipping every unavailable address.
    """
    network = ipaddress.IPv4Network(cidr)
    # Build the lookup set once: testing membership against the caller's list
    # would cost O(len(unavailable_ips)) for every candidate address.
    taken = {str(ip) for ip in unavailable_ips}
    return (host for host in network.hosts() if str(host) not in taken)

41 

42 

def network_available_hosts_iterator(network):
    """Return a lazy iterator over the free host addresses of ``network``.

    Every ``assigned_ip`` of a ``Host`` in the network is treated as reserved;
    the remaining addresses of ``network.cidr`` are produced by
    ``create_available_hosts_iterator``.
    """
    # Fetch only the column we need instead of instantiating full Host rows.
    reserved_ips = list(
        Host.objects.filter(network=network).values_list("assigned_ip", flat=True)
    )
    return create_available_hosts_iterator(network.cidr, reserved_ips)

49 

50 

def create_network_ca(ca_name, network):
    """Generate a new CA named ``ca_name`` and store it for ``network``.

    The certificate/key pair comes from ``create_ca``; the output of
    ``print_ca`` for the certificate is stored as ``cert_print``.
    Returns the saved ``CA`` instance.
    """
    cert, key = create_ca(ca_name)
    return CA.objects.create(
        network=network,
        name=ca_name,
        cert=cert,
        key=key,
        cert_print=print_ca(cert),
    )

58 

59 

def create_network(
    network_name: str, network_cidr: str, user: User, update_interval: int = 5
):
    """Create a network with ``user`` as admin and an initial signing CA.

    :param network_name: name of the new network.
    :param network_cidr: address range hosts are allocated from.
    :param user: user that receives the ADMIN membership role.
    :param update_interval: stored on the network as ``update_interval``.
    :returns: the created ``Network``.

    All rows (network, membership, CA, signing-CA link) are written in one
    transaction, so a failure while creating the CA does not leave behind a
    network that has no CA.
    """
    from django.db import transaction

    logger.info("creating network")
    with transaction.atomic():
        network = Network.objects.create(
            name=network_name, cidr=network_cidr, update_interval=update_interval
        )
        NetworkMembership.objects.create(
            network=network, user=user, role=NetworkMembership.Role.ADMIN
        )
        # Reuse the shared helper instead of duplicating the CA creation steps.
        ca = create_network_ca("auto created initial ca", network)
        SigningCA.objects.create(network=network, ca=ca)
    logger.info("created network", network=str(network))
    return network

81 

82 

def _stable_hash(*parts) -> int:
    """Return a process-independent signed 64-bit hash of ``parts``.

    The builtin ``hash()`` of ``str`` and ``datetime`` values is salted per
    interpreter process (PYTHONHASHSEED), so a value persisted by one worker,
    or before a restart, never matches one computed later. This hashes a
    canonical JSON encoding with SHA-256 instead and folds the first 8 bytes
    into a signed integer. That is the same range ``hash()`` yields on 64-bit
    CPython, so it can be stored wherever the builtin hash was stored.

    ``parts`` must be JSON-serializable (str, int, lists/tuples of those).
    """
    payload = json.dumps(parts, separators=(",", ":"))
    digest = hashlib.sha256(payload.encode()).digest()
    return int.from_bytes(digest[:8], "big", signed=True)


def _firewall_rules(groups):
    """Translate the rules of ``groups`` into entries for the ``firewall`` section.

    Returns ``(inbound, outbound)`` lists of dicts. Optional rule fields are
    emitted only when set on the rule. Any rule whose direction is not
    ``Rule.Direction.INBOUND`` goes to the outbound list.
    """
    inbound_rules = []
    outbound_rules = []
    for group in groups:
        for rule in group.rules.all():
            rule_data = {}

            if rule.local_cidr is not None:
                rule_data["local_cidr"] = rule.local_cidr

            if rule.cidr is not None:
                rule_data["cidr"] = rule.cidr

            rule_groups = rule.groups.all()
            if rule_groups:
                rule_data["groups"] = [rule_group.name for rule_group in rule_groups]

            if rule.group is not None:
                rule_data["group"] = rule.group.name

            if rule.proto is not None:
                rule_data["proto"] = rule.proto

            if rule.port is not None:
                rule_data["port"] = rule.port

            if rule.direction == Rule.Direction.INBOUND:
                inbound_rules.append(rule_data)
            else:
                outbound_rules.append(rule_data)
    return inbound_rules, outbound_rules


def generate_config_yaml(host_id: int, ignore_freeze: bool = False):
    """Render the YAML config for host ``host_id`` and record it.

    If the host has ``config_freeze`` set and ``ignore_freeze`` is false, the
    most recently created stored ``HostConfig`` is returned as-is when one
    exists. Otherwise the config is built from the ``config.yml`` asset:
      - all CA certs of the network are trusted;
      - a host certificate signed by the network's signing CA is reused, or
        (re)issued when its inputs changed;
      - the lighthouse, relay, tun and firewall sections are filled in from
        the database.
    A ``HostConfig`` row is stored when no config with the same SHA-256
    exists yet.

    :returns: the YAML text.
    :raises AssertionError: if the host has no public key. The check is an
        ``assert``, so it is skipped under ``python -O``.
    """
    host = Host.objects.get(id=host_id)
    if host.config_freeze and not ignore_freeze:
        last_config = host.hostconfig_set.order_by("-created_at").first()
        if last_config:
            logger.info("using frozen config", host_id=host.id, host_name=host.name)
            return last_config.config

    logger.info("generating config", host_id=host.id, host_name=host.name)

    network = host.network
    ca = network.signingca.ca

    assert host.public_key
    config_template = (assets.asset_path / "config.yml").read_text()
    config_data = yaml.safe_load(config_template)
    # Trust every CA of the network, not only the current signing CA.
    config_data["pki"]["ca"] = "".join(
        network_ca.cert for network_ca in network.ca_set.all()
    )

    host_groups = list(host.groups.all())
    groups = frozenset(group.name for group in host_groups)
    # Use the network's own prefix length instead of a hard-coded /24, so the
    # address in the certificate matches the CIDR it was allocated from.
    prefixlen = ipaddress.IPv4Network(network.cidr, strict=False).prefixlen
    assigned_ip = f"{host.assigned_ip}/{prefixlen}"

    # Everything that should trigger re-signing the host certificate. Sorted
    # lists give a canonical order; group names are included because they are
    # embedded in the certificate.
    query_key_hash = _stable_hash(
        ca.cert,
        host.public_key,
        host.name,
        assigned_ip,
        sorted(groups),
        sorted(str(group.updated_at) for group in host_groups),
        sorted(
            str(rule.updated_at)
            for group in host_groups
            for rule in group.rules.all()
        ),
    )
    host_cert = HostCert.objects.filter(host=host, ca=ca).first()
    if host_cert is None or host_cert.hash != query_key_hash:
        logger.info(
            "generating new host certificate",
            host_id=host.id,
            host_name=host.name,
            reason="initial" if host_cert is None else "hash_changed",
        )
        if host_cert:
            host_cert.delete()

        cert = sign_keys(
            ca_key=ca.key,
            ca_crt=ca.cert,
            public_key=host.public_key,
            name=host.name,
            ip=assigned_ip,
            groups=groups,
        )
        host_cert = HostCert.objects.create(
            host=host, ca=ca, cert=cert, hash=query_key_hash
        )

    config_data["pki"]["cert"] = host_cert.cert
    config_data["pki"]["key"] = "host.key"

    lighthouse_config = config_data["lighthouse"]
    if host.is_lighthouse:
        lighthouse_config["am_lighthouse"] = True
        lighthouse_config["hosts"] = []
    else:
        # Materialize once; the queryset is read twice below.
        lighthouses = list(Host.objects.filter(network=network, is_lighthouse=True))
        lighthouse_config["am_lighthouse"] = False
        lighthouse_config["hosts"] = [
            lighthouse.assigned_ip for lighthouse in lighthouses
        ]
        config_data["static_host_map"] = {
            lighthouse.assigned_ip: [f"{lighthouse.public_ip_or_hostname}:4242"]
            for lighthouse in lighthouses
        }
    config_data["relay"]["am_relay"] = host.is_relay
    config_data["relay"]["use_relays"] = host.use_relay

    config_data["tun"]["dev"] = host.interface

    inbound_rules, outbound_rules = _firewall_rules(host_groups)
    config_data["firewall"]["inbound"] = inbound_rules
    config_data["firewall"]["outbound"] = outbound_rules

    yaml_config = yaml.safe_dump(config_data, indent=4)

    sha256 = hashlib.sha256(yaml_config.encode()).hexdigest()

    if not HostConfig.objects.filter(sha256=sha256).exists():
        logger.info("saving new config version", host_id=host.id, host_name=host.name)
        HostConfig.objects.create(host=host, config=yaml_config, sha256=sha256)
    else:
        logger.debug("config unchanged", host_id=host.id, host_name=host.name)

    return yaml_config

207 

208 

def create_template(
    name: str,
    network_name: str,
    is_lighthouse: bool = False,
    is_relay: bool = False,
    use_relay: bool = True,
    groups: list[str] = (),
    reusable: bool = True,
    usage_limit: Optional[int] = None,
    ephemeral_peers: bool = False,
    expires_at: Optional[datetime] = None,
):
    """Create an enrollment ``Template`` in the network named ``network_name``.

    :param groups: names of groups of that network to attach to the template
        (any iterable of str).
    :returns: the created ``Template``.
    :raises LookupError: if a group name does not exist in the network. No
        template row is created in that case.

    ``Network.objects.get`` is used for the network lookup, so an unknown
    network name raises ``Network.DoesNotExist``.
    """
    network = Network.objects.get(name=network_name)

    # Resolve every group before writing anything, so that an unknown group
    # name aborts without leaving an orphan template behind.
    resolved_groups = []
    for group_name in groups:
        try:
            resolved_groups.append(Group.objects.get(name=group_name, network=network))
        except Group.DoesNotExist:
            raise LookupError(
                f"Group does not exist in network {group_name}/{network_name}"
            )

    template = Template.objects.create(
        name=name,
        network=network,
        is_lighthouse=is_lighthouse,
        is_relay=is_relay,
        use_relay=use_relay,
        reusable=reusable,
        usage_limit=usage_limit,
        ephemeral_peers=ephemeral_peers,
        expires_at=expires_at,
    )
    for group in resolved_groups:
        template.groups.add(group)

    return template

242 

243 

def get_server_signing_key():
    """Return the server's enrollment signing JWK, creating it on first use.

    The key is stored as a private JWK in ``enrollment_signing.key``, a path
    relative to the process's working directory. NOTE(review): confirm that
    every server process runs with the same CWD, or they will not share a key.
    When the file is missing, a new EC P-256 key is generated and written with
    mode 0o600.
    """
    key_path = Path("enrollment_signing.key")
    if key_path.exists():
        return JWK.from_json(key_path.read_text())

    key = JWK.generate(kty="EC", crv="P-256")
    try:
        # O_EXCL never clobbers a key another process created in the meantime.
        # Passing 0o600 at creation means the private key is never readable by
        # other users, not even briefly before a later chmod.
        fd = os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
    except FileExistsError:
        # Another process won the creation race; use the key it stored.
        return JWK.from_json(key_path.read_text())
    with os.fdopen(fd, "w") as f:
        f.write(key.export_private())
    os.chmod(key_path, 0o600)
    return key

255 

256 

def generate_enrollment_token(template: Template):
    """Issue a signed (ES256) enrollment JWT for ``template``.

    The token's ``jti`` is the template's ``enrollment_key`` and ``sub`` is
    ``template:<id>``. The template's enrollment options are copied in as
    claims. An ``exp`` claim is added only when ``template.expires_at`` is set.
    The token is signed with the key from ``get_server_signing_key``.
    Returns the compact serialized token.
    """
    key = get_server_signing_key()
    issued_at = int(now().timestamp())
    claims = {
        "jti": str(template.enrollment_key),
        "iss": "meshadmin",
        "sub": f"template:{template.id}",
        "iat": issued_at,
        "template_id": template.id,
        "network_id": template.network_id,
    }
    # Template options are copied verbatim, keeping their attribute names.
    for option in (
        "is_lighthouse",
        "is_relay",
        "use_relay",
        "reusable",
        "usage_limit",
        "ephemeral_peers",
    ):
        claims[option] = getattr(template, option)

    expires_at = template.expires_at
    if expires_at:
        claims["exp"] = int(expires_at.timestamp())

    jwt = JWT(header={"alg": "ES256"}, claims=claims)
    jwt.make_signed_token(key)
    return jwt.serialize()

278 

279 

def verify_enrollment_token(token_string):
    """Validate an enrollment JWT and return the ``template_id`` claim.

    Validation uses the key from ``get_server_signing_key``. Any failure,
    including a missing or falsy ``template_id``, is re-raised as
    ``ValueError``. The message is chosen by inspecting the underlying error
    text: "Expired" maps to an expiry message, "Invalid signature" to a
    signature message, and anything else is passed through after a generic
    prefix.
    """
    signing_key = get_server_signing_key()
    try:
        jwt = JWT(jwt=token_string)
        jwt.validate(signing_key)
        claims = json.loads(jwt.token.objects["payload"])
        template_id = claims.get("template_id")
        if not template_id:
            raise ValueError("Invalid token: missing template_id")
    except Exception as exc:
        reason = str(exc)
        known_failures = (
            ("Expired", "Enrollment token has expired"),
            ("Invalid signature", "Invalid enrollment token signature"),
        )
        for needle, message in known_failures:
            if needle in reason:
                raise ValueError(message)
        raise ValueError(f"Invalid enrollment token: {reason}")
    return template_id

298 

299 

def _check_template_usage(template):
    """Raise ValueError if ``template`` may not be used for another enrollment."""
    if not template.reusable:
        if template.usage_count >= 1:
            logger.error(
                "single-use enrollment key has already been used",
                template_id=template.id,
            )
            raise ValueError("Single-use enrollment key has already been used")
    elif template.usage_limit:
        if template.usage_count >= template.usage_limit:
            logger.error("enrollment key usage limit exceeded", template_id=template.id)
            raise ValueError("Enrollment key usage limit exceeded")


def _unique_hostname(network, preferred_hostname):
    """Return ``preferred_hostname``, or the first unused ``<name>-<n>`` in ``network``."""
    existing_hostnames = set(
        Host.objects.filter(
            network=network, name__startswith=preferred_hostname
        ).values_list("name", flat=True)
    )
    candidate = preferred_hostname
    suffix = 1
    while candidate in existing_hostnames:
        candidate = f"{preferred_hostname}-{suffix}"
        suffix += 1
    return candidate


def enrollment(
    enrollment_key: str,
    public_auth_key: str,
    enroll_on_existence: bool,
    public_ip: str,
    preferred_hostname,
    public_net_key,
    interface: str = "nebula1",
):
    """Enroll a new host using the template named by a signed enrollment token.

    :param enrollment_key: serialized enrollment JWT (see
        ``verify_enrollment_token``).
    :param public_auth_key: host's public auth key as JWK JSON; its thumbprint
        identifies an already-enrolled host.
    :param enroll_on_existence: if a host with the same auth key already
        exists and this is true, enrollment is refused. Otherwise the existing
        host is deleted and re-created.
    :param public_ip: public address or hostname; required for lighthouses.
    :param preferred_hostname: desired name; ``-1``, ``-2``, ... is appended
        if the name is taken in the network.
    :param public_net_key: host public key stored as ``public_key``.
    :param interface: tun interface name stored on the host.
    :returns: the created ``Host``.
    :raises ValueError: for an invalid or exhausted token, an unknown template,
        an already-enrolled host, a lighthouse without ``public_ip``, or a
        network with no free address.

    All validation happens before anything is deleted. Deleting the old host,
    allocating an IP, creating the host and bumping the template's usage
    count run in one transaction, so a failure leaves the previous state
    intact.
    """
    from django.db import transaction

    try:
        template_id = verify_enrollment_token(enrollment_key)
        template = Template.objects.get(id=template_id)
    except ValueError as e:
        logger.error("invalid enrollment token", error=str(e))
        raise ValueError(f"Invalid enrollment token: {str(e)}") from e
    except Template.DoesNotExist:
        logger.error("template not found", template_id=template_id)
        raise ValueError("Template not found")

    _check_template_usage(template)

    # Parse the auth key once; its thumbprint identifies a previous enrollment.
    jwk_public_auth = JWK.from_json(public_auth_key)
    thumbprint = jwk_public_auth.thumbprint()
    existing_host: Optional[Host] = Host.objects.filter(
        public_auth_kid=thumbprint
    ).first()

    if existing_host and enroll_on_existence:
        logger.info(
            "host already exists, aborting enrollment",
            host_id=existing_host.id,
            enroll_on_existence=enroll_on_existence,
        )
        raise ValueError("Host already enrolled")

    # Validate before the destructive delete below.
    if template.is_lighthouse and not public_ip:
        raise ValueError("Cannot enroll a lighthouse without public_ip")

    network = template.network
    with transaction.atomic():
        if existing_host:
            existing_host.delete()

        # Computed after the delete so the old host's IP becomes available.
        assigned_ip = next(network_available_hosts_iterator(network), None)
        if assigned_ip is None:
            logger.error("no available ip address", network=str(network))
            raise ValueError("No available IP addresses in network")

        host = Host.objects.create(
            network=network,
            name=_unique_hostname(network, preferred_hostname),
            assigned_ip=assigned_ip,
            is_relay=template.is_relay,
            is_lighthouse=template.is_lighthouse,
            public_ip_or_hostname=public_ip,
            public_key=public_net_key,
            public_auth_key=public_auth_key,
            public_auth_kid=thumbprint,
            is_ephemeral=template.ephemeral_peers,
            interface=interface,
        )

        template.usage_count += 1
        template.save()

        for group in template.groups.all():
            host.groups.add(group)

        host.save()
    return host

392 

393 

def create_group(network_pk: int, group_name: str, description: str = ""):
    """Create a group called ``group_name`` in the network whose pk is ``network_pk``.

    Returns the new ``Group``. The network is fetched with
    ``Network.objects.get``, so a missing network raises
    ``Network.DoesNotExist``.
    """
    group = Group.objects.create(
        network=Network.objects.get(pk=network_pk),
        name=group_name,
        description=description,
    )
    return group