muutils.statcounter
StatCounter
class for counting and calculating statistics on numbers
cleaner and more efficient than just using a Counter
or array
"""`StatCounter` class for counting and calculating statistics on numbers

cleaner and more efficient than just using a `Counter` or array"""

from __future__ import annotations

import json
import math
from collections import Counter
from functools import cached_property
from itertools import chain
from typing import Callable, Optional, Sequence, Union


# _GeneralArray = Union[np.ndarray, "torch.Tensor"]
NumericSequence = Sequence[Union[float, int, "NumericSequence"]]

# pylint: disable=abstract-method

# misc
# ==================================================


def universal_flatten(
    arr: Union[NumericSequence, float, int], require_rectangular: bool = True
) -> NumericSequence:
    """flatten any nested sequence (or array-like with a `.flatten` method) into a flat sequence

    # Parameters:
    - `arr` : nested sequence, scalar, or array-like to flatten
    - `require_rectangular` : if `True`, raise `ValueError` when `arr` mixes
      iterable and non-iterable elements at the same nesting level

    # Returns:
    a flat sequence of the scalar elements (a bare scalar is wrapped in a list)
    """
    # array-likes (e.g. numpy/torch) provide their own flatten; defer to it
    # mypy complains that the sequence has no attribute "flatten"
    if hasattr(arr, "flatten") and callable(arr.flatten):  # type: ignore
        return arr.flatten()  # type: ignore
    elif isinstance(arr, Sequence):
        elements_iterable: list[bool] = [isinstance(x, Sequence) for x in arr]
        if require_rectangular and (all(elements_iterable) != any(elements_iterable)):
            raise ValueError("arr contains mixed iterable and non-iterable elements")
        if any(elements_iterable):
            return list(chain.from_iterable(universal_flatten(x) for x in arr))  # type: ignore[misc]
        else:
            return arr
    else:
        return [arr]


# StatCounter
# ==================================================


class StatCounter(Counter):
    """`Counter`, but with some stat calculation methods which assume the keys are numerical

    works best when the keys are `int`s
    """

    def validate(self) -> bool:
        """validate the counter as having only numeric (or bool/None) keys"""
        return all(isinstance(k, (bool, int, float, type(None))) for k in self.keys())

    def min(self):
        "minimum key with a positive count"
        return min(x for x, v in self.items() if v > 0)

    def max(self):
        "maximum key with a positive count"
        return max(x for x, v in self.items() if v > 0)

    def total(self):
        """Sum of the counts"""
        return sum(self.values())

    @cached_property
    def keys_sorted(self) -> list:
        """return the keys, sorted

        NOTE(review): cached -- goes stale if the counter is mutated after first access
        """
        return sorted(list(self.keys()))

    def percentile(self, p: float):
        """return the value at the given percentile `p` (between 0 and 1)

        linearly interpolates between adjacent keys when the target falls
        between two observations

        this could be log time if we did binary search, but that would be a lot of added complexity

        # Raises:
        - `ValueError` : if `p` is outside `[0, 1]`
        """

        if p < 0 or p > 1:
            raise ValueError(f"percentile must be between 0 and 1: {p}")
        # flip for speed: walk from whichever end of the sorted keys is nearer
        sorted_keys: list[float] = [float(x) for x in self.keys_sorted]
        sort: int = 1
        if p > 0.51:
            sort = -1
            p = 1 - p

        sorted_keys = sorted_keys[::sort]
        real_target: float = p * (self.total() - 1)

        n_target_f: int = math.floor(real_target)
        n_target_c: int = math.ceil(real_target)

        n_sofar: float = -1

        for i, k in enumerate(sorted_keys):
            n_sofar += self[k]

            if n_sofar > n_target_f:
                return k

            elif n_sofar == n_target_f:
                if n_sofar == n_target_c:
                    return k
                else:
                    # target falls between this key and the next: interpolate
                    return sorted_keys[i] * (n_sofar + 1 - real_target) + sorted_keys[
                        i + 1
                    ] * (real_target - n_sofar)
            else:
                continue

        raise ValueError(f"percentile {p} not found???")

    def median(self) -> float:
        "the 50th percentile"
        return self.percentile(0.5)

    def mean(self) -> float:
        """return the mean of the values"""
        return float(sum(k * c for k, c in self.items()) / self.total())

    def mode(self) -> float:
        "the most common key (raises IndexError if the counter is empty)"
        return self.most_common()[0][0]

    def std(self) -> float:
        """return the (population) standard deviation of the values"""
        mean: float = self.mean()
        deviations: float = sum(c * (k - mean) ** 2 for k, c in self.items())

        return (deviations / self.total()) ** 0.5

    def summary(
        self,
        typecast: Callable = lambda x: x,
        *,
        extra_percentiles: Optional[list[float]] = None,
    ) -> dict[str, Union[float, int, None]]:
        """return a summary of the stats, without the raw data. human readable and small

        # Parameters:
        - `typecast` : applied to min/max/percentile values before output
        - `extra_percentiles` : extra percentiles in `[0, 1]` to include,
          under `percentile_{p}` keys

        # Returns:
        dict with `total_items`, `n_keys`, `mode`, and -- when the counter is
        nonempty and numeric -- `mean`, `std`, `min`, `q1`, `median`, `q3`, `max`
        """
        # common stats that always work
        # BUGFIX: `mode` was previously computed unconditionally, which raised
        # IndexError on an empty counter -- it is now None when empty
        output: dict = dict(
            total_items=self.total(),
            n_keys=len(self.keys()),
            mode=self.mode() if self.total() > 0 else None,
        )

        if self.total() > 0:
            if self.validate():
                # if its a numeric counter, we can do some stats
                output = {
                    **output,
                    **dict(
                        mean=float(self.mean()),
                        std=float(self.std()),
                        min=typecast(self.min()),
                        q1=typecast(self.percentile(0.25)),
                        median=typecast(self.median()),
                        q3=typecast(self.percentile(0.75)),
                        max=typecast(self.max()),
                    ),
                }

                if extra_percentiles is not None:
                    for p in extra_percentiles:
                        output[f"percentile_{p}"] = typecast(self.percentile(p))
            else:
                # if its not, we can only do the simpler things
                # mean mode and total are done in the initial declaration of `output`
                pass

        return output

    def serialize(
        self,
        typecast: Callable = lambda x: x,
        *,
        extra_percentiles: Optional[list[float]] = None,
    ) -> dict:
        """return a json-serializable version of the counter

        includes both the output of `summary` and the raw data:

        ```json
        {
            "StatCounter": { <keys, values from raw data> },
            "summary": self.summary(typecast, extra_percentiles=extra_percentiles),
        }
        ```
        """

        return {
            "StatCounter": {
                typecast(k): v
                for k, v in sorted(dict(self).items(), key=lambda x: x[0])
            },
            "summary": self.summary(typecast, extra_percentiles=extra_percentiles),
        }

    def __str__(self) -> str:
        "summary as json with 2 space indent, good for printing"
        return json.dumps(self.summary(), indent=2)

    def __repr__(self) -> str:
        "full serialization (raw data + summary) as json with 2 space indent"
        return json.dumps(self.serialize(), indent=2)

    @classmethod
    def load(cls, data: dict) -> "StatCounter":
        "load from the output of `StatCounter.serialize` (keys are cast to float)"
        if "StatCounter" in data:
            loadme = data["StatCounter"]
        else:
            loadme = data

        return cls({float(k): v for k, v in loadme.items()})

    @classmethod
    def from_list_arrays(
        cls,
        arr,
        map_func: Callable = float,
    ) -> "StatCounter":
        """calls `map_func` on each element of `universal_flatten(arr)`"""
        return cls([map_func(x) for x in universal_flatten(arr)])
25def universal_flatten( 26 arr: Union[NumericSequence, float, int], require_rectangular: bool = True 27) -> NumericSequence: 28 """flattens any iterable""" 29 30 # mypy complains that the sequence has no attribute "flatten" 31 if hasattr(arr, "flatten") and callable(arr.flatten): # type: ignore 32 return arr.flatten() # type: ignore 33 elif isinstance(arr, Sequence): 34 elements_iterable: list[bool] = [isinstance(x, Sequence) for x in arr] 35 if require_rectangular and (all(elements_iterable) != any(elements_iterable)): 36 raise ValueError("arr contains mixed iterable and non-iterable elements") 37 if any(elements_iterable): 38 return list(chain.from_iterable(universal_flatten(x) for x in arr)) # type: ignore[misc] 39 else: 40 return arr 41 else: 42 return [arr]
flattens any iterable
49class StatCounter(Counter): 50 """`Counter`, but with some stat calculation methods which assume the keys are numerical 51 52 works best when the keys are `int`s 53 """ 54 55 def validate(self) -> bool: 56 """validate the counter as being all floats or ints""" 57 return all(isinstance(k, (bool, int, float, type(None))) for k in self.keys()) 58 59 def min(self): 60 "minimum value" 61 return min(x for x, v in self.items() if v > 0) 62 63 def max(self): 64 "maximum value" 65 return max(x for x, v in self.items() if v > 0) 66 67 def total(self): 68 """Sum of the counts""" 69 return sum(self.values()) 70 71 @cached_property 72 def keys_sorted(self) -> list: 73 """return the keys""" 74 return sorted(list(self.keys())) 75 76 def percentile(self, p: float): 77 """return the value at the given percentile 78 79 this could be log time if we did binary search, but that would be a lot of added complexity 80 """ 81 82 if p < 0 or p > 1: 83 raise ValueError(f"percentile must be between 0 and 1: {p}") 84 # flip for speed 85 sorted_keys: list[float] = [float(x) for x in self.keys_sorted] 86 sort: int = 1 87 if p > 0.51: 88 sort = -1 89 p = 1 - p 90 91 sorted_keys = sorted_keys[::sort] 92 real_target: float = p * (self.total() - 1) 93 94 n_target_f: int = math.floor(real_target) 95 n_target_c: int = math.ceil(real_target) 96 97 n_sofar: float = -1 98 99 # print(f'{p = } {real_target = } {n_target_f = } {n_target_c = }') 100 101 for i, k in enumerate(sorted_keys): 102 n_sofar += self[k] 103 104 # print(f'{k = } {n_sofar = }') 105 106 if n_sofar > n_target_f: 107 return k 108 109 elif n_sofar == n_target_f: 110 if n_sofar == n_target_c: 111 return k 112 else: 113 # print( 114 # sorted_keys[i], (n_sofar + 1 - real_target), 115 # sorted_keys[i + 1], (real_target - n_sofar), 116 # ) 117 return sorted_keys[i] * (n_sofar + 1 - real_target) + sorted_keys[ 118 i + 1 119 ] * (real_target - n_sofar) 120 else: 121 continue 122 123 raise ValueError(f"percentile {p} not found???") 124 125 def 
median(self) -> float: 126 return self.percentile(0.5) 127 128 def mean(self) -> float: 129 """return the mean of the values""" 130 return float(sum(k * c for k, c in self.items()) / self.total()) 131 132 def mode(self) -> float: 133 return self.most_common()[0][0] 134 135 def std(self) -> float: 136 """return the standard deviation of the values""" 137 mean: float = self.mean() 138 deviations: float = sum(c * (k - mean) ** 2 for k, c in self.items()) 139 140 return (deviations / self.total()) ** 0.5 141 142 def summary( 143 self, 144 typecast: Callable = lambda x: x, 145 *, 146 extra_percentiles: Optional[list[float]] = None, 147 ) -> dict[str, Union[float, int]]: 148 """return a summary of the stats, without the raw data. human readable and small""" 149 # common stats that always work 150 output: dict = dict( 151 total_items=self.total(), 152 n_keys=len(self.keys()), 153 mode=self.mode(), 154 ) 155 156 if self.total() > 0: 157 if self.validate(): 158 # if its a numeric counter, we can do some stats 159 output = { 160 **output, 161 **dict( 162 mean=float(self.mean()), 163 std=float(self.std()), 164 min=typecast(self.min()), 165 q1=typecast(self.percentile(0.25)), 166 median=typecast(self.median()), 167 q3=typecast(self.percentile(0.75)), 168 max=typecast(self.max()), 169 ), 170 } 171 172 if extra_percentiles is not None: 173 for p in extra_percentiles: 174 output[f"percentile_{p}"] = typecast(self.percentile(p)) 175 else: 176 # if its not, we can only do the simpler things 177 # mean mode and total are done in the initial declaration of `output` 178 pass 179 180 return output 181 182 def serialize( 183 self, 184 typecast: Callable = lambda x: x, 185 *, 186 extra_percentiles: Optional[list[float]] = None, 187 ) -> dict: 188 """return a json-serializable version of the counter 189 190 includes both the output of `summary` and the raw data: 191 192 ```json 193 { 194 "StatCounter": { <keys, values from raw data> }, 195 "summary": self.summary(typecast, 
extra_percentiles=extra_percentiles), 196 } 197 198 """ 199 200 return { 201 "StatCounter": { 202 typecast(k): v 203 for k, v in sorted(dict(self).items(), key=lambda x: x[0]) 204 }, 205 "summary": self.summary(typecast, extra_percentiles=extra_percentiles), 206 } 207 208 def __str__(self) -> str: 209 "summary as json with 2 space indent, good for printing" 210 return json.dumps(self.summary(), indent=2) 211 212 def __repr__(self) -> str: 213 return json.dumps(self.serialize(), indent=2) 214 215 @classmethod 216 def load(cls, data: dict) -> "StatCounter": 217 "load from a the output of `StatCounter.serialize`" 218 if "StatCounter" in data: 219 loadme = data["StatCounter"] 220 else: 221 loadme = data 222 223 return cls({float(k): v for k, v in loadme.items()}) 224 225 @classmethod 226 def from_list_arrays( 227 cls, 228 arr, 229 map_func: Callable = float, 230 ) -> "StatCounter": 231 """calls `map_func` on each element of `universal_flatten(arr)`""" 232 return cls([map_func(x) for x in universal_flatten(arr)])
Counter
, but with some stat calculation methods which assume the keys are numerical
works best when the keys are int
s
55 def validate(self) -> bool: 56 """validate the counter as being all floats or ints""" 57 return all(isinstance(k, (bool, int, float, type(None))) for k in self.keys())
validate the counter as being all floats or ints
71 @cached_property 72 def keys_sorted(self) -> list: 73 """return the keys""" 74 return sorted(list(self.keys()))
return the keys, sorted ascending
76 def percentile(self, p: float): 77 """return the value at the given percentile 78 79 this could be log time if we did binary search, but that would be a lot of added complexity 80 """ 81 82 if p < 0 or p > 1: 83 raise ValueError(f"percentile must be between 0 and 1: {p}") 84 # flip for speed 85 sorted_keys: list[float] = [float(x) for x in self.keys_sorted] 86 sort: int = 1 87 if p > 0.51: 88 sort = -1 89 p = 1 - p 90 91 sorted_keys = sorted_keys[::sort] 92 real_target: float = p * (self.total() - 1) 93 94 n_target_f: int = math.floor(real_target) 95 n_target_c: int = math.ceil(real_target) 96 97 n_sofar: float = -1 98 99 # print(f'{p = } {real_target = } {n_target_f = } {n_target_c = }') 100 101 for i, k in enumerate(sorted_keys): 102 n_sofar += self[k] 103 104 # print(f'{k = } {n_sofar = }') 105 106 if n_sofar > n_target_f: 107 return k 108 109 elif n_sofar == n_target_f: 110 if n_sofar == n_target_c: 111 return k 112 else: 113 # print( 114 # sorted_keys[i], (n_sofar + 1 - real_target), 115 # sorted_keys[i + 1], (real_target - n_sofar), 116 # ) 117 return sorted_keys[i] * (n_sofar + 1 - real_target) + sorted_keys[ 118 i + 1 119 ] * (real_target - n_sofar) 120 else: 121 continue 122 123 raise ValueError(f"percentile {p} not found???")
return the value at the given percentile
this could be log time if we did binary search, but that would be a lot of added complexity
128 def mean(self) -> float: 129 """return the mean of the values""" 130 return float(sum(k * c for k, c in self.items()) / self.total())
return the mean of the values
135 def std(self) -> float: 136 """return the standard deviation of the values""" 137 mean: float = self.mean() 138 deviations: float = sum(c * (k - mean) ** 2 for k, c in self.items()) 139 140 return (deviations / self.total()) ** 0.5
return the standard deviation of the values
142 def summary( 143 self, 144 typecast: Callable = lambda x: x, 145 *, 146 extra_percentiles: Optional[list[float]] = None, 147 ) -> dict[str, Union[float, int]]: 148 """return a summary of the stats, without the raw data. human readable and small""" 149 # common stats that always work 150 output: dict = dict( 151 total_items=self.total(), 152 n_keys=len(self.keys()), 153 mode=self.mode(), 154 ) 155 156 if self.total() > 0: 157 if self.validate(): 158 # if its a numeric counter, we can do some stats 159 output = { 160 **output, 161 **dict( 162 mean=float(self.mean()), 163 std=float(self.std()), 164 min=typecast(self.min()), 165 q1=typecast(self.percentile(0.25)), 166 median=typecast(self.median()), 167 q3=typecast(self.percentile(0.75)), 168 max=typecast(self.max()), 169 ), 170 } 171 172 if extra_percentiles is not None: 173 for p in extra_percentiles: 174 output[f"percentile_{p}"] = typecast(self.percentile(p)) 175 else: 176 # if its not, we can only do the simpler things 177 # mean mode and total are done in the initial declaration of `output` 178 pass 179 180 return output
return a summary of the stats, without the raw data. human readable and small
182 def serialize( 183 self, 184 typecast: Callable = lambda x: x, 185 *, 186 extra_percentiles: Optional[list[float]] = None, 187 ) -> dict: 188 """return a json-serializable version of the counter 189 190 includes both the output of `summary` and the raw data: 191 192 ```json 193 { 194 "StatCounter": { <keys, values from raw data> }, 195 "summary": self.summary(typecast, extra_percentiles=extra_percentiles), 196 } 197 198 """ 199 200 return { 201 "StatCounter": { 202 typecast(k): v 203 for k, v in sorted(dict(self).items(), key=lambda x: x[0]) 204 }, 205 "summary": self.summary(typecast, extra_percentiles=extra_percentiles), 206 }
return a json-serializable version of the counter
includes both the output of summary
and the raw data:
```json
{
"StatCounter": {
215 @classmethod 216 def load(cls, data: dict) -> "StatCounter": 217 "load from a the output of `StatCounter.serialize`" 218 if "StatCounter" in data: 219 loadme = data["StatCounter"] 220 else: 221 loadme = data 222 223 return cls({float(k): v for k, v in loadme.items()})
load from the output of StatCounter.serialize
225 @classmethod 226 def from_list_arrays( 227 cls, 228 arr, 229 map_func: Callable = float, 230 ) -> "StatCounter": 231 """calls `map_func` on each element of `universal_flatten(arr)`""" 232 return cls([map_func(x) for x in universal_flatten(arr)])
calls map_func
on each element of universal_flatten(arr)
Inherited Members
- collections.Counter
- Counter
- most_common
- elements
- fromkeys
- update
- subtract
- copy
- builtins.dict
- get
- setdefault
- pop
- popitem
- keys
- items
- values
- clear