Coverage for src/abcd_graph/callbacks/stats_collector.py: 100%
35 statements
« prev ^ index » next coverage.py v7.5.3, created at 2024-11-17 23:31 +0100
1# Copyright (c) 2024 Jordan Barrett & Aleksander Wojnarowicz
2#
3# Permission is hereby granted, free of charge, to any person obtaining a copy
4# of this software and associated documentation files (the "Software"), to deal
5# in the Software without restriction, including without limitation the rights
6# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7# copies of the Software, and to permit persons to whom the Software is
8# furnished to do so, subject to the following conditions:
9#
10# The above copyright notice and this permission notice shall be included in all
11# copies or substantial portions of the Software.
12#
13# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19# SOFTWARE.
21__all__ = ["StatsCollector"]
23from typing import Any
25from abcd_graph.callbacks.abstract import (
26 ABCDCallback,
27 BuildContext,
28)
29from abcd_graph.exporter import GraphExporter
30from abcd_graph.graph.core.abcd_objects.graph_impl import GraphImpl
class StatsCollector(ABCDCallback):
    """Callback that records build metadata and graph statistics in a dict.

    ``before_build`` stores the name of the model used, the build parameters
    and the node count; ``after_build`` stores timing values taken from the
    build context plus summary statistics of the finished graph. Everything
    recorded is available through :attr:`statistics`.
    """

    def __init__(self) -> None:
        # Statistic name -> recorded value; populated by the build hooks.
        self._statistics: dict[str, Any] = dict()

    @property
    def statistics(self) -> dict[str, Any]:
        """The live mapping of recorded statistics (the internal dict, not a copy)."""
        return self._statistics

    def log_statistic(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``, overwriting any previous value for that key."""
        self._statistics[key] = value

    def fetch_statistic(self, key: str) -> Any:
        """Return the value previously recorded under ``key``.

        Raises:
            KeyError: If nothing has been logged under ``key``.
        """
        stored = self._statistics[key]
        return stored

    def before_build(self, context: BuildContext) -> None:
        """Record pre-build information: model name, parameters and node count."""
        record = self.log_statistic
        record("model_used", context.model_used.__name__)
        record("params", context.params)
        record("number_of_nodes", context.number_of_nodes)

    def after_build(self, graph: "GraphImpl", context: BuildContext, exporter: GraphExporter) -> None:
        """Record build timing and summary statistics of the generated graph.

        ``exporter`` is part of the hook's signature but is not used by this collector.
        """
        record = self.log_statistic

        # Timing values, read from the build context.
        record("start_time", context.start_time)
        record("end_time", context.end_time)
        record("time_to_build", context.raw_build_time)

        # Structural statistics of the finished graph, expected vs. actual where available.
        record("number_of_edges", len(graph.edges))
        record("number_of_communities", graph.num_communities)
        record("expected_average_degree", graph.expected_average_degree)
        record("actual_average_degree", graph.average_degree)
        record("expected_average_community_size", graph.expected_average_community_size)
        record("actual_average_community_size", graph.actual_average_community_size)
        record("number_of_loops", graph.num_loops)
        record("number_of_multi_edges", graph.num_multi_edges)

        # Mixing parameter measured on the generated graph.
        record("empirical_xi", get_empirical_xi(graph))
def get_empirical_xi(graph: "GraphImpl") -> float:
    """Return the empirical mixing parameter (xi) of ``graph``.

    Computed as ``1 - (community_edges / total_edges)``, where
    ``community_edges`` is the number of edges summed over every community's
    ``edges`` collection and ``total_edges`` is ``len(graph.edges)``.
    (Presumably community edge sets are disjoint subsets of ``graph.edges``,
    making this the fraction of non-community edges — verify against
    ``GraphImpl``.)

    Args:
        graph: The built graph. Only ``graph.edges`` and ``graph.communities``
            (each community exposing an ``edges`` collection) are read.

    Returns:
        The empirical xi as a float. An edgeless graph returns ``0.0``
        instead of raising ``ZeroDivisionError``.
    """
    total_edges = len(graph.edges)
    if total_edges == 0:
        # Guard: 0/0 is undefined; report no mixing rather than crashing the
        # stats callback that calls this after every build.
        return 0.0
    num_community_edges = sum(len(community.edges) for community in graph.communities)
    return 1 - (num_community_edges / total_edges)