Coverage for src/abcd_graph/graph/graph.py: 100%
93 statements
« prev ^ index » next coverage.py v7.5.3, created at 2024-12-04 21:31 +0100
« prev ^ index » next coverage.py v7.5.3, created at 2024-12-04 21:31 +0100
1# Copyright (c) 2024 Jordan Barrett & Aleksander Wojnarowicz
2#
3# Permission is hereby granted, free of charge, to any person obtaining a copy
4# of this software and associated documentation files (the "Software"), to deal
5# in the Software without restriction, including without limitation the rights
6# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7# copies of the Software, and to permit persons to whom the Software is
8# furnished to do so, subject to the following conditions:
9#
10# The above copyright notice and this permission notice shall be included in all
11# copies or substantial portions of the Software.
12#
13# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19# SOFTWARE.
21__all__ = ["ABCDGraph"]
23import time
24import warnings
25from datetime import datetime
26from typing import Optional
28from abcd_graph.callbacks.abstract import (
29 ABCDCallback,
30 BuildContext,
31)
32from abcd_graph.exporter import GraphExporter
33from abcd_graph.graph.community import ABCDCommunity
34from abcd_graph.graph.core.abcd_objects import GraphImpl
35from abcd_graph.graph.core.build import (
36 add_outliers,
37 assign_degrees,
38 build_communities,
39 build_community_sizes,
40 build_degrees,
41 split_degrees,
42)
43from abcd_graph.logger import construct_logger
44from abcd_graph.models import (
45 Model,
46 configuration_model,
47)
48from abcd_graph.params import ABCDParams
class ABCDGraph:
    """Build a graph from ``ABCDParams`` and expose the built result.

    The object starts out unbuilt. ``build`` runs the generation steps and
    stores the resulting ``GraphImpl``; ``reset`` discards it again. While the
    graph is unbuilt the read-only properties return empty values instead of
    raising (``exporter`` is the exception and raises ``RuntimeError``).
    """

    def __init__(
        self,
        params: Optional[ABCDParams] = None,
        logger: bool = False,
        callbacks: Optional[list[ABCDCallback]] = None,
    ) -> None:
        """Store the configuration and start in the unbuilt state.

        Args:
            params: Generation parameters. A default ``ABCDParams()`` is used
                when none (or a falsy value) is given.
            logger: Passed straight to ``construct_logger``.
                NOTE(review): presumably enables/disables logging - confirm.
            callbacks: Callbacks notified before and after ``build``. Defaults
                to an empty list.
        """
        self.params: ABCDParams = params or ABCDParams()

        self._vcount = self.params.vcount

        self.num_outliers = self.params.num_outliers

        self._has_outliers: bool = self.num_outliers > 0

        # Vertices that are not outliers.
        self._num_regular_vertices = self._vcount - self.num_outliers

        self.logger = construct_logger(logger)

        self._graph: Optional[GraphImpl] = None

        self._exporter: Optional[GraphExporter] = None
        self._callbacks = callbacks or []

    def reset(self) -> None:
        """Discard the built graph so that ``build`` can be run again."""
        self._graph = None
        # The exporter wraps the graph that was just dropped; clear it too so
        # a stale exporter is never kept around after a reset.
        self._exporter = None

    @property
    def is_built(self) -> bool:
        """Whether a graph is currently stored on this object."""
        return self._graph is not None

    @property
    def exporter(self) -> GraphExporter:
        """Exporter created for the built graph.

        Raises:
            RuntimeError: If the graph has not been built, or if no exporter
                has been created.
        """
        if not self.is_built:
            raise RuntimeError("Exporter is not available if the graph has not been built.")

        if self._exporter is None:
            raise RuntimeError("Exporter is not available.")

        return self._exporter

    @property
    def vcount(self) -> int:
        """Configured number of vertices, or ``0`` while the graph is unbuilt."""
        return self._vcount if self.is_built else 0

    @property
    def edges(self) -> list[tuple[int, int]]:
        """Edges of the built graph, or an empty list while unbuilt."""
        graph = self._graph
        if graph is None:
            return []
        return graph.edges

    @property
    def membership_list(self) -> list[int]:
        """Membership list of the built graph, or an empty list while unbuilt."""
        graph = self._graph
        if graph is None:
            return []
        return graph.membership_list

    @property
    def communities(self) -> list[ABCDCommunity]:
        """Communities of the built graph as ``ABCDCommunity`` objects.

        Each internal community is copied field by field into a new
        ``ABCDCommunity``. Returns an empty list while the graph is unbuilt.
        """
        graph = self._graph
        if graph is None:
            return []

        return [
            ABCDCommunity(
                community_id=community.community_id,
                vertices=community.vertices,
                average_degree=community.average_degree,
                degree_sequence=community.degree_sequence,
                empirical_xi=community.empirical_xi,
            )
            for community in graph.communities
        ]

    def build(self, model: Optional[Model] = None) -> "ABCDGraph":
        """Build the graph and notify the registered callbacks.

        If the graph is already built, a warning is issued and nothing else
        happens; call ``reset`` first to rebuild.

        Args:
            model: Model used to build the edges. ``configuration_model`` is
                used when omitted.

        Returns:
            This object, so calls can be chained.

        Raises:
            Exception: Any error raised by the build steps is logged, the
                object is reset to the unbuilt state, and the error is
                re-raised unchanged.
        """
        if self.is_built:
            # stacklevel=2 attributes the warning to the caller of build().
            warnings.warn(
                "Graph has already been built. Run `reset` and try again.",
                stacklevel=2,
            )
            return self

        model = model if model else configuration_model

        context = BuildContext(
            model_used=model,
            start_time=datetime.now(),
            params=self.params,
            number_of_nodes=self._vcount,
        )

        for callback in self._callbacks:
            callback.before_build(context)

        try:
            build_start = time.perf_counter()
            build_end = self._build_impl(model)
            context.end_time = datetime.now()
        except Exception as e:
            self.logger.error(f"An error occurred while building the graph: {e}")
            # Drop any partially built state before propagating the error.
            self.reset()
            raise

        context.raw_build_time = build_end - build_start

        # _build_impl always assigns the graph on success; this narrows the
        # Optional type for the code below.
        assert self._graph is not None
        self._exporter = GraphExporter(self._graph)

        for callback in self._callbacks:
            callback.after_build(self._graph, context, self._exporter)

        return self

    def _build_impl(self, model: Model) -> float:
        """Run the build steps in order and store the result in ``_graph``.

        Args:
            model: Model passed to the community and background edge builders.

        Returns:
            The ``time.perf_counter()`` reading taken after the last step, which
            ``build`` uses to compute the raw build time.
        """
        degrees = build_degrees(
            self._num_regular_vertices,
            self.params.gamma,
            self.params.min_degree,
            self.params.max_degree,
        )

        self.logger.info("Building community sizes")

        community_sizes = build_community_sizes(
            self._num_regular_vertices,
            self.params.beta,
            self.params.min_community_size,
            self.params.max_community_size,
        )

        self.logger.info("Building communities")

        communities = build_communities(community_sizes)

        self.logger.info("Assigning degrees")

        deg = assign_degrees(degrees, communities, community_sizes, self.params.xi)

        self.logger.info("Splitting degrees")

        deg_c, deg_b = split_degrees(deg, communities, self.params.xi)

        if self._has_outliers:
            self.logger.info("Adding outliers")
            # Note the return order: communities first, then deg_b, then deg_c.
            communities, deg_b, deg_c = add_outliers(
                vcount=self._vcount,
                num_outliers=self.num_outliers,
                gamma=self.params.gamma,
                min_degree=self.params.min_degree,
                max_degree=self.params.max_degree,
                communities=communities,
                deg_b=deg_b,
                deg_c=deg_c,
            )

        self._graph = GraphImpl(deg_b, deg_c, params=self.params)

        self.logger.info("Building community edges")
        self._graph.build_communities(communities, model)

        self.logger.info("Building background edges")
        self._graph.build_background_edges(model)

        self.logger.info("Resolving collisions")
        self._graph.combine_edges()

        self._graph.rewire_graph()

        return time.perf_counter()