Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/aggregates.py: 42%
127 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
1"""
2Classes to represent the definitions of aggregate functions.
3"""
5from plain.exceptions import FieldError, FullResultSet
6from plain.models.expressions import Case, Func, Star, Value, When
7from plain.models.fields import IntegerField
8from plain.models.functions.comparison import Coalesce
9from plain.models.functions.mixins import (
10 FixDurationInputMixin,
11 NumericOutputFieldMixin,
12)
14__all__ = [
15 "Aggregate",
16 "Avg",
17 "Count",
18 "Max",
19 "Min",
20 "StdDev",
21 "Sum",
22 "Variance",
23]
class Aggregate(Func):
    """
    Base class for SQL aggregate function expressions.

    On top of the positional source expressions handled by ``Func``, an
    aggregate accepts three keyword arguments:

    * ``distinct`` -- render ``DISTINCT `` in front of the argument list.
      Rejected with ``TypeError`` unless the class sets ``allow_distinct``.
    * ``filter`` -- a condition restricting which rows are aggregated. It is
      rendered with ``filter_template`` when the connection reports
      ``supports_aggregate_filter_clause``; otherwise the first source
      expression is wrapped in ``Case(When(filter, then=...))``.
    * ``default`` -- when given, ``resolve_expression()`` returns
      ``Coalesce(aggregate, default)`` instead of the bare aggregate.
      Rejected with ``TypeError`` when the class defines a non-None
      ``empty_result_set_value``.
    """

    # SQL skeleton; "distinct" is filled in by as_sql().
    template = "%(function)s(%(distinct)s%(expressions)s)"
    contains_aggregate = True
    # Display name; used by default_alias and in error messages.
    name = None
    # Wrapped around the active template when a FILTER clause is emitted.
    filter_template = "%s FILTER (WHERE %%(filter)s)"
    window_compatible = True
    # Subclasses opt in to accepting distinct=True.
    allow_distinct = False
    empty_result_set_value = None

    def __init__(
        self, *expressions, distinct=False, filter=None, default=None, **extra
    ):
        """
        Store the aggregate options and hand the expressions to ``Func``.

        Raises:
            TypeError: if ``distinct`` is requested but not allowed by the
                class, or if ``default`` is given while the class defines a
                non-None ``empty_result_set_value``.
        """
        if distinct and not self.allow_distinct:
            raise TypeError(f"{self.__class__.__name__} does not allow distinct.")
        if default is not None and self.empty_result_set_value is not None:
            raise TypeError(f"{self.__class__.__name__} does not allow default.")
        self.distinct = distinct
        self.filter = filter
        self.default = default
        super().__init__(*expressions, **extra)

    def get_source_fields(self):
        """Return the output fields of the source expressions, filter excluded."""
        # Don't return the filter expression since it's not a source field.
        return [e._output_field_or_none for e in super().get_source_expressions()]

    def get_source_expressions(self):
        """Return the source expressions, with the filter appended last if set."""
        source_expressions = super().get_source_expressions()
        if self.filter:
            return source_expressions + [self.filter]
        return source_expressions

    def set_source_expressions(self, exprs):
        """
        Inverse of ``get_source_expressions()``.

        When a filter is set, the last item of ``exprs`` becomes the new
        filter and the remaining items are passed on to ``Func``.
        """
        # Work on a copy so the caller's sequence is never mutated (and so
        # that non-list sequences such as tuples are accepted).
        exprs = list(exprs)
        self.filter = self.filter and exprs.pop()
        return super().set_source_expressions(exprs)

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve the aggregate, its filter and its default against ``query``.

        Returns the resolved copy, or a ``Coalesce`` wrapping it when a
        default was supplied.

        Raises:
            FieldError: if ``summarize`` is false and one of the source
                expressions is itself an aggregate.
        """
        # Aggregates are not allowed in UPDATE queries, so ignore for_save
        c = super().resolve_expression(query, allow_joins, reuse, summarize)
        c.filter = c.filter and c.filter.resolve_expression(
            query, allow_joins, reuse, summarize
        )
        if not summarize:
            # Call Aggregate.get_source_expressions() to avoid
            # returning self.filter and including that in this loop.
            expressions = super(Aggregate, c).get_source_expressions()
            for index, expr in enumerate(expressions):
                if expr.contains_aggregate:
                    # Report the expression as it was written, before
                    # resolution replaced it.
                    before_resolved = self.get_source_expressions()[index]
                    name = (
                        before_resolved.name
                        if hasattr(before_resolved, "name")
                        else repr(before_resolved)
                    )
                    raise FieldError(
                        f"Cannot compute {c.name}('{name}'): '{name}' is an aggregate"
                    )
        if (default := c.default) is None:
            return c
        if hasattr(default, "resolve_expression"):
            default = default.resolve_expression(query, allow_joins, reuse, summarize)
            # An expression default without its own output field inherits
            # the aggregate's.
            if default._output_field_or_none is None:
                default.output_field = c._output_field_or_none
        else:
            # Plain Python values are wrapped so they can be compiled.
            default = Value(default, c._output_field_or_none)
        c.default = None  # Reset the default argument before wrapping.
        coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
        coalesce.is_summary = c.is_summary
        return coalesce

    @property
    def default_alias(self):
        """
        Return ``"<source name>__<aggregate name, lowercased>"``.

        Raises:
            TypeError: unless there is exactly one source expression
                (the filter counts as one) and it has a ``name`` attribute.
        """
        expressions = self.get_source_expressions()
        if len(expressions) == 1 and hasattr(expressions[0], "name"):
            return f"{expressions[0].name}__{self.name.lower()}"
        raise TypeError("Complex expressions require an alias")

    def get_group_by_cols(self):
        """Return an empty list: an aggregate adds no GROUP BY columns."""
        return []

    def as_sql(self, compiler, connection, **extra_context):
        """
        Compile the aggregate to ``(sql, params)``.

        A ``template`` passed through ``extra_context`` replaces
        ``self.template`` as the base that ``filter_template`` wraps.
        """
        extra_context["distinct"] = "DISTINCT " if self.distinct else ""
        if self.filter:
            if connection.features.supports_aggregate_filter_clause:
                try:
                    filter_sql, filter_params = self.filter.as_sql(compiler, connection)
                except FullResultSet:
                    # Compiling the filter raised FullResultSet; fall through
                    # to the unfiltered SQL at the end of this method.
                    pass
                else:
                    template = self.filter_template % extra_context.get(
                        "template", self.template
                    )
                    # Merge the overrides into a single mapping: passing
                    # template=/filter= next to **extra_context raises
                    # TypeError when the caller already supplied either key.
                    filter_context = {
                        **extra_context,
                        "template": template,
                        "filter": filter_sql,
                    }
                    sql, params = super().as_sql(
                        compiler, connection, **filter_context
                    )
                    return sql, (*params, *filter_params)
            else:
                # No FILTER clause support: emulate it by aggregating over
                # CASE WHEN <filter> THEN <first expression> END.
                copy = self.copy()
                copy.filter = None
                source_expressions = copy.get_source_expressions()
                condition = When(self.filter, then=source_expressions[0])
                copy.set_source_expressions([Case(condition)] + source_expressions[1:])
                return super(Aggregate, copy).as_sql(
                    compiler, connection, **extra_context
                )
        return super().as_sql(compiler, connection, **extra_context)

    def _get_repr_options(self):
        """Extend the inherited repr options with ``distinct``/``filter`` when set."""
        options = super()._get_repr_options()
        if self.distinct:
            options["distinct"] = self.distinct
        if self.filter:
            options["filter"] = self.filter
        return options
class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
    """Aggregate rendered with the SQL ``AVG`` function; accepts ``distinct``."""

    name = "Avg"
    function = "AVG"
    allow_distinct = True
class Count(Aggregate):
    """
    Aggregate rendered with the SQL ``COUNT`` function.

    Produces an integer, accepts ``distinct``, and uses ``0`` as its
    ``empty_result_set_value``.
    """

    name = "Count"
    function = "COUNT"
    allow_distinct = True
    empty_result_set_value = 0
    output_field = IntegerField()

    def __init__(self, expression, filter=None, **extra):
        """
        Build the count; the string ``"*"`` is replaced by ``Star()``.

        Raises:
            ValueError: if a filter is combined with a ``Star`` expression.
        """
        target = Star() if expression == "*" else expression
        if filter is not None and isinstance(target, Star):
            raise ValueError("Star cannot be used with filter. Please specify a field.")
        super().__init__(target, filter=filter, **extra)
class Max(Aggregate):
    """Aggregate rendered with the SQL ``MAX`` function."""

    name = "Max"
    function = "MAX"
class Min(Aggregate):
    """Aggregate rendered with the SQL ``MIN`` function."""

    name = "Min"
    function = "MIN"
class StdDev(NumericOutputFieldMixin, Aggregate):
    """
    Standard-deviation aggregate.

    ``sample=True`` selects ``STDDEV_SAMP``; otherwise ``STDDEV_POP`` is used.
    """

    name = "StdDev"

    def __init__(self, expression, sample=False, **extra):
        """Choose the SQL function from ``sample``, then defer to ``Aggregate``."""
        if sample:
            self.function = "STDDEV_SAMP"
        else:
            self.function = "STDDEV_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        """Return the inherited repr options plus a boolean ``sample`` entry."""
        repr_options = dict(super()._get_repr_options())
        repr_options["sample"] = self.function == "STDDEV_SAMP"
        return repr_options
class Sum(FixDurationInputMixin, Aggregate):
    """Aggregate rendered with the SQL ``SUM`` function; accepts ``distinct``."""

    name = "Sum"
    function = "SUM"
    allow_distinct = True
class Variance(NumericOutputFieldMixin, Aggregate):
    """
    Variance aggregate.

    ``sample=True`` selects ``VAR_SAMP``; otherwise ``VAR_POP`` is used.
    """

    name = "Variance"

    def __init__(self, expression, sample=False, **extra):
        """Choose the SQL function from ``sample``, then defer to ``Aggregate``."""
        if sample:
            self.function = "VAR_SAMP"
        else:
            self.function = "VAR_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        """Return the inherited repr options plus a boolean ``sample`` entry."""
        repr_options = dict(super()._get_repr_options())
        repr_options["sample"] = self.function == "VAR_SAMP"
        return repr_options