Coverage for src/abcd_graph/graph/graph.py: 100%

93 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2024-12-04 21:31 +0100

1# Copyright (c) 2024 Jordan Barrett & Aleksander Wojnarowicz 

2# 

3# Permission is hereby granted, free of charge, to any person obtaining a copy 

4# of this software and associated documentation files (the "Software"), to deal 

5# in the Software without restriction, including without limitation the rights 

6# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

7# copies of the Software, and to permit persons to whom the Software is 

8# furnished to do so, subject to the following conditions: 

9# 

10# The above copyright notice and this permission notice shall be included in all 

11# copies or substantial portions of the Software. 

12# 

13# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

14# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

15# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

16# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

17# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

18# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

19# SOFTWARE. 

20 

# Public API of this module: only the graph builder class is exported.
__all__ = ["ABCDGraph"]

22 

23import time 

24import warnings 

25from datetime import datetime 

26from typing import Optional 

27 

28from abcd_graph.callbacks.abstract import ( 

29 ABCDCallback, 

30 BuildContext, 

31) 

32from abcd_graph.exporter import GraphExporter 

33from abcd_graph.graph.community import ABCDCommunity 

34from abcd_graph.graph.core.abcd_objects import GraphImpl 

35from abcd_graph.graph.core.build import ( 

36 add_outliers, 

37 assign_degrees, 

38 build_communities, 

39 build_community_sizes, 

40 build_degrees, 

41 split_degrees, 

42) 

43from abcd_graph.logger import construct_logger 

44from abcd_graph.models import ( 

45 Model, 

46 configuration_model, 

47) 

48from abcd_graph.params import ABCDParams 

49 

50 

class ABCDGraph:
    """Generate an ABCD graph from :class:`ABCDParams` and expose the result.

    Constructing the object only records the parameters. Call :meth:`build` to
    generate the graph. Until then:

    - ``edges``, ``membership_list`` and ``communities`` are empty lists;
    - ``vcount`` is 0;
    - ``exporter`` raises ``RuntimeError``.
    """

    def __init__(
        self,
        params: Optional[ABCDParams] = None,
        logger: bool = False,
        callbacks: Optional[list[ABCDCallback]] = None,
    ) -> None:
        """Record the build configuration without generating the graph.

        Args:
            params: Generation parameters. A default ``ABCDParams()`` is used
                when ``None``.
            logger: Forwarded to ``construct_logger``. Presumably enables
                progress logging when ``True`` — confirm in
                ``abcd_graph.logger``.
            callbacks: Hooks receiving ``before_build(context)`` before each
                build and ``after_build(graph, context, exporter)`` after each
                successful build.
        """
        # Explicit None check: a caller-supplied params object is always used,
        # whatever it evaluates to in a boolean context.
        self.params: ABCDParams = params if params is not None else ABCDParams()

        self._vcount = self.params.vcount

        self.num_outliers = self.params.num_outliers

        self._has_outliers: bool = self.num_outliers > 0

        # Outliers are added in a separate step (see _build_impl), so only the
        # remaining vertices take part in degree/community generation.
        self._num_regular_vertices = self._vcount - self.num_outliers

        self.logger = construct_logger(logger)

        self._graph: Optional[GraphImpl] = None

        self._exporter: Optional[GraphExporter] = None
        self._callbacks = callbacks or []

    def reset(self) -> None:
        """Discard the built graph so that :meth:`build` can run again.

        The exporter is dropped too. It wraps the previous ``GraphImpl`` and
        would otherwise keep that graph alive after the reset.
        """
        self._graph = None
        self._exporter = None

    @property
    def is_built(self) -> bool:
        """Whether a graph is currently stored (set during :meth:`build`)."""
        return self._graph is not None

    @property
    def exporter(self) -> GraphExporter:
        """Exporter wrapping the built graph.

        Raises:
            RuntimeError: If the graph has not been built, or if the exporter
                has not been created yet.
        """
        if not self.is_built:
            raise RuntimeError("Exporter is not available if the graph has not been built.")

        if self._exporter is None:
            raise RuntimeError("Exporter is not available.")

        return self._exporter

    @property
    def vcount(self) -> int:
        """``params.vcount`` once the graph is built, otherwise 0."""
        return self._vcount if self.is_built else 0

    @property
    def edges(self) -> list[tuple[int, int]]:
        """Edge list of the built graph, or ``[]`` before :meth:`build`."""
        return self._graph.edges if self._graph is not None else []

    @property
    def membership_list(self) -> list[int]:
        """Community membership list of the built graph, or ``[]`` before :meth:`build`."""
        return self._graph.membership_list if self._graph is not None else []

    @property
    def communities(self) -> list[ABCDCommunity]:
        """Public :class:`ABCDCommunity` views of the built graph's communities.

        A new list of new ``ABCDCommunity`` objects is created on every access.
        Returns ``[]`` before :meth:`build`.
        """
        graph = self._graph
        if graph is None:
            return []
        return [
            ABCDCommunity(
                community_id=community.community_id,
                vertices=community.vertices,
                average_degree=community.average_degree,
                degree_sequence=community.degree_sequence,
                empirical_xi=community.empirical_xi,
            )
            for community in graph.communities
        ]

    def build(self, model: Optional[Model] = None) -> "ABCDGraph":
        """Generate the graph, create its exporter and notify callbacks.

        Args:
            model: Edge-generation model. ``configuration_model`` is used when
                not given.

        Returns:
            ``self``, to allow chaining. If the graph is already built, a
            warning is emitted and ``self`` is returned without rebuilding.

        Raises:
            Exception: Any error raised by the build steps. It is logged, the
                partial state is cleared with :meth:`reset`, and the original
                exception is re-raised.
        """
        if self.is_built:
            warnings.warn("Graph has already been built. Run `reset` and try again.")
            return self

        model = model if model else configuration_model

        context = BuildContext(
            model_used=model,
            start_time=datetime.now(),
            params=self.params,
            number_of_nodes=self._vcount,
        )

        for callback in self._callbacks:
            callback.before_build(context)

        try:
            build_start = time.perf_counter()
            build_end = self._build_impl(model)
            context.end_time = datetime.now()
        except Exception as e:
            self.logger.error(f"An error occurred while building the graph: {e}")
            # Drop any partially built graph so the object can be rebuilt.
            self.reset()
            # A bare raise re-raises the active exception with its traceback unchanged.
            raise

        # _build_impl returns its own perf_counter() reading, taken after the last step.
        context.raw_build_time = build_end - build_start

        # _build_impl always assigns self._graph; this also narrows the type.
        assert self._graph is not None
        self._exporter = GraphExporter(self._graph)

        for callback in self._callbacks:
            callback.after_build(self._graph, context, self._exporter)

        return self

    def _build_impl(self, model: Model) -> float:
        """Run the generation pipeline and store the result in ``self._graph``.

        Steps, in order:

        1. Build the degree sequence and community sizes for the regular
           (non-outlier) vertices.
        2. Form the communities and assign degrees to their vertices.
        3. Split the degrees into ``deg_c`` and ``deg_b``. These appear to be
           the community and background parts, judging by the edge-building
           log messages below.
        4. Optionally add outliers.
        5. Build the community and background edges with ``model``.
        6. Combine the edges and rewire the graph.

        Returns:
            A ``time.perf_counter()`` reading taken after the final step.
        """
        self.logger.info("Building degrees")

        degrees = build_degrees(
            self._num_regular_vertices,
            self.params.gamma,
            self.params.min_degree,
            self.params.max_degree,
        )

        self.logger.info("Building community sizes")

        community_sizes = build_community_sizes(
            self._num_regular_vertices,
            self.params.beta,
            self.params.min_community_size,
            self.params.max_community_size,
        )

        self.logger.info("Building communities")

        communities = build_communities(community_sizes)

        self.logger.info("Assigning degrees")

        deg = assign_degrees(degrees, communities, community_sizes, self.params.xi)

        self.logger.info("Splitting degrees")

        deg_c, deg_b = split_degrees(deg, communities, self.params.xi)

        if self._has_outliers:
            self.logger.info("Adding outliers")
            communities, deg_b, deg_c = add_outliers(
                vcount=self._vcount,
                num_outliers=self.num_outliers,
                gamma=self.params.gamma,
                min_degree=self.params.min_degree,
                max_degree=self.params.max_degree,
                communities=communities,
                deg_b=deg_b,
                deg_c=deg_c,
            )

        self._graph = GraphImpl(deg_b, deg_c, params=self.params)

        self.logger.info("Building community edges")
        self._graph.build_communities(communities, model)

        self.logger.info("Building background edges")
        self._graph.build_background_edges(model)

        self.logger.info("Resolving collisions")
        self._graph.combine_edges()

        self._graph.rewire_graph()

        return time.perf_counter()