Coverage for src/meshadmin/server/networks/services.py: 95%
201 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-10 16:08 +0200
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-10 16:08 +0200
1import hashlib
2import ipaddress
3import json
4import os
5from datetime import datetime
6from pathlib import Path
7from typing import Optional
9import structlog
10import yaml
11from django.contrib.auth import get_user_model
12from django.utils.timezone import now
13from jwcrypto.jwk import JWK
14from jwcrypto.jwt import JWT
16from meshadmin.common.utils import create_ca, print_ca, sign_keys
17from meshadmin.server import assets
18from meshadmin.server.networks.models import (
19 CA,
20 Group,
21 Host,
22 HostCert,
23 HostConfig,
24 Network,
25 NetworkMembership,
26 Rule,
27 SigningCA,
28 Template,
29)
31User = get_user_model()
32logger = structlog.get_logger(__name__)
def create_available_hosts_iterator(cidr, unavailable_ips):
    """Lazily yield the usable host addresses of *cidr* that are still free.

    Args:
        cidr: IPv4 network in CIDR notation, e.g. ``"10.0.0.0/24"``.
        unavailable_ips: Iterable of already-taken addresses, compared against
            the string form of each candidate address.

    Returns:
        A generator of ``ipaddress.IPv4Address`` objects in ascending order.
    """
    network = ipaddress.IPv4Network(cidr)
    # Snapshot once: O(1) membership per candidate, and one-shot iterables
    # (generators) are not silently exhausted by the first lookup.
    taken = set(unavailable_ips)
    return (host for host in network.hosts() if str(host) not in taken)
def network_available_hosts_iterator(network):
    """Return an iterator over the addresses of *network* not yet assigned.

    The ``assigned_ip`` of every ``Host`` in the network is collected eagerly
    and passed, together with ``network.cidr``, to
    ``create_available_hosts_iterator``.
    """
    hosts_in_network = Host.objects.filter(network=network).all()
    taken_ips = [existing.assigned_ip for existing in hosts_in_network]
    return create_available_hosts_iterator(network.cidr, taken_ips)
def create_network_ca(ca_name, network):
    """Generate a new CA named *ca_name* and store it as a ``CA`` row of *network*.

    Returns:
        The newly created ``CA`` instance.
    """
    certificate, private_key = create_ca(ca_name)
    return CA.objects.create(
        network=network,
        name=ca_name,
        cert=certificate,
        key=private_key,
        cert_print=print_ca(certificate),
    )
def create_network(
    network_name: str, network_cidr: str, user: User, update_interval: int = 5
):
    """Create a network with an admin membership and an initial signing CA.

    Args:
        network_name: Name of the new network.
        network_cidr: CIDR of the new network.
        user: User who becomes ``ADMIN`` member of the network.
        update_interval: Stored on the network as ``update_interval``.

    Returns:
        The created ``Network`` instance.
    """
    # Function-scope import keeps this block self-contained.
    from django.db import transaction

    logger.info("creating network")
    # All rows are created together or not at all, so a failure while
    # generating the CA cannot leave a network without a signing CA.
    with transaction.atomic():
        network = Network.objects.create(
            name=network_name, cidr=network_cidr, update_interval=update_interval
        )
        NetworkMembership.objects.create(
            network=network, user=user, role=NetworkMembership.Role.ADMIN
        )
        # Reuse the shared helper instead of duplicating the CA creation steps.
        ca = create_network_ca("auto created initial ca", network)
        SigningCA.objects.create(network=network, ca=ca)

    logger.info("created network", network=str(network))
    return network
def _stable_cert_hash(parts) -> int:
    """Return a process-independent signed 64-bit fingerprint of *parts*.

    The built-in ``hash()`` is salted per interpreter process for strings, so
    a value stored in the database never matches after a restart. This derives
    the fingerprint from SHA-256 instead and truncates it to the signed 64-bit
    range so it stays an ``int`` like the values produced before.
    """
    payload = json.dumps(parts, sort_keys=True, default=str).encode()
    digest = hashlib.sha256(payload).digest()
    return int.from_bytes(digest[:8], "big", signed=True)


def _rule_to_firewall_entry(rule) -> dict:
    """Translate one ``Rule`` into a firewall entry, skipping unset fields."""
    entry = {}
    if rule.local_cidr is not None:
        entry["local_cidr"] = rule.local_cidr
    if rule.cidr is not None:
        entry["cidr"] = rule.cidr
    rule_groups = rule.groups.all()
    if len(rule_groups) > 0:
        entry["groups"] = [rule_group.name for rule_group in rule_groups]
    if rule.group is not None:
        entry["group"] = rule.group.name
    if rule.proto is not None:
        entry["proto"] = rule.proto
    if rule.port is not None:
        entry["port"] = rule.port
    return entry


def generate_config_yaml(host_id: int, ignore_freeze: bool = False):
    """Render the YAML configuration for a host, signing a certificate if needed.

    Args:
        host_id: Primary key of the ``Host`` to render the config for.
        ignore_freeze: When true, regenerate even if ``host.config_freeze`` is
            set; otherwise a frozen host gets its most recent stored config.

    Returns:
        The configuration as a YAML string. A ``HostConfig`` row is stored
        when no row with the same SHA-256 digest exists yet.

    Raises:
        Host.DoesNotExist: If no host has the given id.
        AssertionError: If the host has no public key.
    """
    host = Host.objects.get(id=host_id)
    if host.config_freeze and not ignore_freeze:
        last_config = host.hostconfig_set.order_by("-created_at").first()
        if last_config:
            logger.info("using frozen config", host_id=host.id, host_name=host.name)
            return last_config.config

    logger.info("generating config", host_id=host.id, host_name=host.name)

    network = host.network
    ca = network.signingca.ca

    # NOTE(review): assert is stripped under ``python -O``; kept so callers
    # still see AssertionError.
    assert host.public_key

    # load config yaml
    config_template = (assets.asset_path / "config.yml").read_text()
    config_data = yaml.safe_load(config_template)
    config_data["pki"]["ca"] = "".join(
        network_ca.cert for network_ca in network.ca_set.all()
    )

    # Load groups and their rules once; both feed the certificate fingerprint
    # and the firewall section below.
    host_groups = list(host.groups.all())
    rules_per_group = [list(group.rules.all()) for group in host_groups]

    group_names = frozenset(group.name for group in host_groups)
    # NOTE(review): the prefix length is hard-coded to /24 regardless of
    # network.cidr - confirm this is intended for networks of other sizes.
    assigned_ip = f"{host.assigned_ip}/24"

    query_key_hash = _stable_cert_hash(
        [
            ca.cert,
            host.public_key,
            host.name,
            assigned_ip,
            # Names are part of the signed certificate, so they are part of
            # the fingerprint as well.
            sorted(group_names),
            sorted({str(group.updated_at) for group in host_groups}),
            sorted(
                {str(rule.updated_at) for rules in rules_per_group for rule in rules}
            ),
        ]
    )

    host_cert = HostCert.objects.filter(host=host, ca=ca).first()
    # Compared as strings so the check works whether the model stores the
    # fingerprint in an integer or a text column.
    if host_cert is None or str(host_cert.hash) != str(query_key_hash):
        logger.info(
            "generating new host certificate",
            host_id=host.id,
            host_name=host.name,
            reason="initial" if host_cert is None else "hash_changed",
        )
        if host_cert:
            host_cert.delete()

        cert = sign_keys(
            ca_key=ca.key,
            ca_crt=ca.cert,
            public_key=host.public_key,
            name=host.name,
            ip=assigned_ip,
            groups=group_names,
        )
        host_cert = HostCert.objects.create(
            host=host, ca=ca, cert=cert, hash=query_key_hash
        )

    config_data["pki"]["cert"] = host_cert.cert
    config_data["pki"]["key"] = "host.key"

    lighthouses = Host.objects.filter(network=network, is_lighthouse=True).all()
    if host.is_lighthouse:
        config_data["lighthouse"]["am_lighthouse"] = True
        config_data["lighthouse"]["hosts"] = []
    else:
        config_data["lighthouse"]["am_lighthouse"] = False
        config_data["lighthouse"]["hosts"] = [
            lighthouse.assigned_ip for lighthouse in lighthouses
        ]
        config_data["static_host_map"] = {
            lighthouse.assigned_ip: [f"{lighthouse.public_ip_or_hostname}:4242"]
            for lighthouse in lighthouses
        }
    config_data["relay"]["am_relay"] = host.is_relay
    config_data["relay"]["use_relays"] = host.use_relay

    config_data["tun"]["dev"] = host.interface

    inbound_rules = []
    outbound_rules = []
    for rules in rules_per_group:
        for rule in rules:
            entry = _rule_to_firewall_entry(rule)
            if rule.direction == Rule.Direction.INBOUND:
                inbound_rules.append(entry)
            else:
                outbound_rules.append(entry)

    config_data["firewall"]["inbound"] = inbound_rules
    config_data["firewall"]["outbound"] = outbound_rules

    yaml_config = yaml.safe_dump(config_data, indent=4)

    sha256 = hashlib.sha256(yaml_config.encode()).hexdigest()

    if not HostConfig.objects.filter(sha256=sha256).exists():
        logger.info("saving new config version", host_id=host.id, host_name=host.name)
        HostConfig.objects.create(host=host, config=yaml_config, sha256=sha256)
    else:
        logger.debug("config unchanged", host_id=host.id, host_name=host.name)

    return yaml_config
def create_template(
    name: str,
    network_name: str,
    is_lighthouse=False,
    is_relay: bool = False,
    use_relay: bool = True,
    groups: list[str] = (),
    reusable: bool = True,
    usage_limit: Optional[int] = None,
    ephemeral_peers: bool = False,
    expires_at: Optional[datetime] = None,
):
    """Create an enrollment template in the named network.

    Args:
        name: Template name.
        network_name: Name of the network the template belongs to.
        is_lighthouse, is_relay, use_relay, reusable, usage_limit,
            ephemeral_peers, expires_at: Stored unchanged on the template.
        groups: Names of groups in the same network to attach to the template.

    Returns:
        The created ``Template`` instance.

    Raises:
        Network.DoesNotExist: If no network is named *network_name*.
        LookupError: If one of *groups* does not exist in the network.
    """
    network = Network.objects.get(name=network_name)

    # Resolve every group before writing anything, so an unknown group name
    # does not leave a half-configured template behind.
    resolved_groups = []
    for group in groups:
        try:
            resolved_groups.append(
                Group.objects.get(name=group, network__name=network_name)
            )
        except Group.DoesNotExist as err:
            raise LookupError(
                f"Group does not exist in network {group}/{network_name}"
            ) from err

    template = Template.objects.create(
        name=name,
        network=network,
        is_lighthouse=is_lighthouse,
        is_relay=is_relay,
        use_relay=use_relay,
        reusable=reusable,
        usage_limit=usage_limit,
        ephemeral_peers=ephemeral_peers,
        expires_at=expires_at,
    )
    if resolved_groups:
        template.groups.add(*resolved_groups)

    return template
def get_server_signing_key():
    """Load the enrollment signing key, generating and persisting it on first use.

    The key lives in ``enrollment_signing.key``, a relative path resolved
    against the current working directory.

    Returns:
        The ``JWK`` used to sign and verify enrollment tokens.
    """
    key_path = Path("enrollment_signing.key")
    try:
        return JWK.from_json(key_path.read_text())
    except FileNotFoundError:
        pass

    key = JWK.generate(kty="EC", crv="P-256")
    try:
        # O_EXCL: never overwrite a key another process created in the
        # meantime. Mode 0o600 at creation: the private key is never on disk
        # with wider permissions, unlike write-then-chmod.
        fd = os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
    except FileExistsError:
        # Lost the creation race; use the key that is already on disk.
        return JWK.from_json(key_path.read_text())
    with os.fdopen(fd, "w") as f:
        f.write(key.export_private())
    # The creation mode is masked by the umask; set the exact mode afterwards.
    os.chmod(key_path, 0o600)
    return key
def generate_enrollment_token(template: Template):
    """Build a signed ES256 JWT describing *template* for use as an enrollment key.

    The claims carry the template's id, network id and enrollment flags; an
    ``exp`` claim is added only when the template has an expiry.

    Returns:
        The serialized, signed token.
    """
    key = get_server_signing_key()
    payload = {
        "jti": str(template.enrollment_key),
        "iss": "meshadmin",
        "sub": f"template:{template.id}",
        "iat": int(now().timestamp()),
        "template_id": template.id,
        "network_id": template.network_id,
        "is_lighthouse": template.is_lighthouse,
        "is_relay": template.is_relay,
        "use_relay": template.use_relay,
        "reusable": template.reusable,
        "usage_limit": template.usage_limit,
        "ephemeral_peers": template.ephemeral_peers,
    }
    if template.expires_at:
        payload["exp"] = int(template.expires_at.timestamp())

    signed = JWT(header={"alg": "ES256"}, claims=payload)
    signed.make_signed_token(key)
    return signed.serialize()
def verify_enrollment_token(token_string):
    """Validate an enrollment token and return the template id it carries.

    Returns:
        The ``template_id`` claim of the token.

    Raises:
        ValueError: For every failure. The message depends on whether the
            underlying error text mentions expiry or an invalid signature.
    """
    key = get_server_signing_key()
    try:
        parsed = JWT(jwt=token_string)
        parsed.validate(key)
        claims = json.loads(parsed.token.objects["payload"])
        template_id = claims.get("template_id")
        if template_id:
            return template_id
        # NOTE(review): this ValueError is caught by the handler below and
        # re-raised with the generic "Invalid enrollment token: " prefix.
        raise ValueError("Invalid token: missing template_id")
    except Exception as e:
        # Every failure is mapped to ValueError; the category is taken from
        # the text of the original error.
        reason = str(e)
        if "Expired" in reason:
            message = "Enrollment token has expired"
        elif "Invalid signature" in reason:
            message = "Invalid enrollment token signature"
        else:
            message = f"Invalid enrollment token: {reason}"
        raise ValueError(message)
def enrollment(
    enrollment_key: str,
    public_auth_key: str,
    enroll_on_existence: bool,
    public_ip: str,
    preferred_hostname,
    public_net_key,
    interface: str = "nebula1",
):
    """Enroll a host into the network of the template named by *enrollment_key*.

    Args:
        enrollment_key: Signed enrollment token identifying the template.
        public_auth_key: JSON-serialized JWK of the host; its thumbprint
            identifies an already enrolled host.
        enroll_on_existence: When true and the key is already enrolled, the
            enrollment is rejected; when false the existing host is replaced.
            NOTE(review): this reads inverted relative to the parameter name -
            confirm with callers before changing.
        public_ip: Public IP or hostname; required for lighthouse templates.
        preferred_hostname: Desired host name; ``-<n>`` is appended when taken.
        public_net_key: Stored as the host's ``public_key``.
        interface: Stored as the host's ``interface``.

    Returns:
        The newly created ``Host``.

    Raises:
        ValueError: Invalid token, unknown template, exhausted usage limit,
            already enrolled host, lighthouse without public IP, or no free
            address left in the network.
    """
    # Function-scope import keeps this block self-contained.
    from django.db import transaction

    try:
        template_id = verify_enrollment_token(enrollment_key)
        template = Template.objects.get(id=template_id)
    except ValueError as e:
        logger.error("invalid enrollment token", error=str(e))
        raise ValueError(f"Invalid enrollment token: {str(e)}") from e
    except Template.DoesNotExist as e:
        logger.error("template not found", template_id=template_id)
        raise ValueError("Template not found") from e

    # Check usage limit
    if not template.reusable:
        if template.usage_count >= 1:
            logger.error(
                "single-use enrollment key has already been used",
                template_id=template.id,
            )
            raise ValueError("Single-use enrollment key has already been used")
    elif template.usage_limit:
        if template.usage_count >= template.usage_limit:
            logger.error("enrollment key usage limit exceeded", template_id=template.id)
            raise ValueError("Enrollment key usage limit exceeded")

    # check if public key is already enrolled (the JWK is parsed only once)
    thumbprint = JWK.from_json(public_auth_key).thumbprint()
    existing_host: Optional[Host] = Host.objects.filter(
        public_auth_kid=thumbprint
    ).first()

    if existing_host and enroll_on_existence:
        logger.info(
            "host already exists, aborting enrollment",
            host_id=existing_host.id,
            enroll_on_existence=enroll_on_existence,
        )
        raise ValueError("Host already enrolled")

    # Validate before anything is deleted, so a rejected request cannot
    # remove the previously enrolled host.
    if template.is_lighthouse and not public_ip:
        raise ValueError("Cannot enroll a lighthouse without public_ip")

    network = template.network

    # Replace-and-create runs atomically: if any step fails, the previous
    # host row is restored by the rollback.
    with transaction.atomic():
        if existing_host:
            existing_host.delete()

        # Computed after the delete so the replaced host's address is free.
        ipv4_iterator = network_available_hosts_iterator(network)

        already_registered_hostnames = (
            Host.objects.values("name")
            .filter(network=network, name__startswith=preferred_hostname)
            .all()
        )
        existing_hostnames = {row["name"] for row in already_registered_hostnames}
        final_hostname = preferred_hostname
        i = 1
        while final_hostname in existing_hostnames:
            final_hostname = f"{preferred_hostname}-{i}"
            i += 1

        assigned_ip = next(ipv4_iterator, None)
        if assigned_ip is None:
            # Previously surfaced as a bare StopIteration.
            logger.error("no free ip address left in network", network=str(network))
            raise ValueError("No available IP addresses in network")

        host = Host.objects.create(
            network=network,
            name=final_hostname,
            assigned_ip=assigned_ip,
            is_relay=template.is_relay,
            is_lighthouse=template.is_lighthouse,
            public_ip_or_hostname=public_ip,
            public_key=public_net_key,
            public_auth_key=public_auth_key,
            public_auth_kid=thumbprint,
            is_ephemeral=template.ephemeral_peers,
            interface=interface,
        )

        # Increment usage count
        # NOTE(review): read-modify-write; concurrent enrollments with the
        # same template may undercount.
        template.usage_count += 1
        template.save()

        for group in template.groups.all():
            host.groups.add(group)

        host.save()
    return host
def create_group(network_pk: int, group_name: str, description: str = ""):
    """Create a group named *group_name* in the network with primary key *network_pk*.

    Returns:
        The newly created ``Group`` instance.
    """
    return Group.objects.create(
        network=Network.objects.get(pk=network_pk),
        name=group_name,
        description=description,
    )