Coverage for src/abcd_graph/callbacks/stats_collector.py: 100%

35 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2024-11-17 23:31 +0100

1# Copyright (c) 2024 Jordan Barrett & Aleksander Wojnarowicz 

2# 

3# Permission is hereby granted, free of charge, to any person obtaining a copy 

4# of this software and associated documentation files (the "Software"), to deal 

5# in the Software without restriction, including without limitation the rights 

6# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

7# copies of the Software, and to permit persons to whom the Software is 

8# furnished to do so, subject to the following conditions: 

9# 

10# The above copyright notice and this permission notice shall be included in all 

11# copies or substantial portions of the Software. 

12# 

13# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

14# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

15# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

16# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

17# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

18# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

19# SOFTWARE. 

20 

21__all__ = ["StatsCollector"] 

22 

23from typing import Any 

24 

25from abcd_graph.callbacks.abstract import ( 

26 ABCDCallback, 

27 BuildContext, 

28) 

29from abcd_graph.exporter import GraphExporter 

30from abcd_graph.graph.core.abcd_objects.graph_impl import GraphImpl 

31 

32 

class StatsCollector(ABCDCallback):
    """Build callback that gathers statistics about a graph build.

    Every statistic is kept in a plain dictionary keyed by name. Entries are
    written through ``log_statistic`` and read through ``fetch_statistic`` or
    the ``statistics`` property.
    """

    def __init__(self) -> None:
        """Create a collector with an empty statistics dictionary."""
        self._statistics: dict[str, Any] = {}

    @property
    def statistics(self) -> dict[str, Any]:
        """Return the internal statistics dictionary itself (not a copy)."""
        return self._statistics

    def log_statistic(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``, overwriting any previous entry."""
        self._statistics[key] = value

    def fetch_statistic(self, key: str) -> Any:
        """Return the value stored under ``key``.

        Raises:
            KeyError: If no statistic has been logged under ``key``.
        """
        collected = self._statistics
        return collected[key]

    def before_build(self, context: BuildContext) -> None:
        """Log the model name, parameters and node count read from ``context``."""
        record = self.log_statistic
        record("model_used", context.model_used.__name__)
        record("params", context.params)
        record("number_of_nodes", context.number_of_nodes)

    def after_build(self, graph: "GraphImpl", context: BuildContext, exporter: GraphExporter) -> None:
        """Log timing figures from ``context`` and structural figures from ``graph``.

        ``exporter`` is not used by this method.
        """
        record = self.log_statistic

        # Timing information comes from the build context.
        record("start_time", context.start_time)
        record("end_time", context.end_time)
        record("time_to_build", context.raw_build_time)

        # Structural figures are read straight off the finished graph.
        record("number_of_edges", len(graph.edges))
        record("number_of_communities", graph.num_communities)
        record("expected_average_degree", graph.expected_average_degree)
        record("actual_average_degree", graph.average_degree)
        record("expected_average_community_size", graph.expected_average_community_size)
        record("actual_average_community_size", graph.actual_average_community_size)
        record("number_of_loops", graph.num_loops)
        record("number_of_multi_edges", graph.num_multi_edges)

        # Derived value computed by the module-level helper.
        record("empirical_xi", get_empirical_xi(graph))

67 

68 

def get_empirical_xi(graph: GraphImpl) -> float:
    """Return the share of the graph's edges that lie outside every community.

    The value is ``1 - community_edges / total_edges``, where
    ``community_edges`` is the sum of ``len(community.edges)`` over
    ``graph.communities`` and ``total_edges`` is ``len(graph.edges)``.

    Args:
        graph: Object exposing ``edges`` and ``communities``; each community
            must expose its own ``edges`` collection.

    Returns:
        The ratio described above, or ``0.0`` when the graph has no edges.
    """
    total_edges = len(graph.edges)
    if total_edges == 0:
        # The ratio is undefined for an edgeless graph; report 0.0 instead of
        # raising ZeroDivisionError so statistics collection can still finish.
        return 0.0

    num_community_edges = sum(len(community.edges) for community in graph.communities)
    return 1 - (num_community_edges / total_edges)

71 return 1 - (num_community_edges / len(graph.edges))