Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/aggregates.py: 42%

127 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-23 11:16 -0600

1""" 

2Classes to represent the definitions of aggregate functions. 

3""" 

4 

5from plain.exceptions import FieldError, FullResultSet 

6from plain.models.expressions import Case, Func, Star, Value, When 

7from plain.models.fields import IntegerField 

8from plain.models.functions.comparison import Coalesce 

9from plain.models.functions.mixins import ( 

10 FixDurationInputMixin, 

11 NumericOutputFieldMixin, 

12) 

13 

# Public names exported by ``from <this module> import *``.
__all__ = ["Aggregate", "Avg", "Count", "Max", "Min", "StdDev", "Sum", "Variance"]

24 

25 

class Aggregate(Func):
    """Base class for SQL aggregate functions such as COUNT, SUM and AVG.

    Subclasses provide ``function`` (the SQL function rendered into
    ``template``) and ``name`` (used by ``default_alias`` and in error
    messages), and set ``allow_distinct = True`` to permit ``DISTINCT``.

    Keyword-only options accepted by ``__init__``:

    * ``distinct`` -- render ``DISTINCT`` before the arguments. Raises
      ``TypeError`` unless the subclass sets ``allow_distinct``.
    * ``filter`` -- a condition limiting which rows are aggregated. It is
      rendered as ``FILTER (WHERE ...)`` when the backend reports
      ``supports_aggregate_filter_clause``; otherwise it is emulated by
      replacing the first argument with ``Case(When(filter, then=arg))``.
    * ``default`` -- a value or expression; when given, the resolved
      aggregate is wrapped in ``Coalesce(aggregate, default)``. Raises
      ``TypeError`` on subclasses whose ``empty_result_set_value`` is not
      ``None``.
    """

    template = "%(function)s(%(distinct)s%(expressions)s)"
    contains_aggregate = True
    name = None
    # "%s" receives the base template; "%%(filter)s" turns into the
    # "%(filter)s" placeholder after that first substitution.
    filter_template = "%s FILTER (WHERE %%(filter)s)"
    window_compatible = True
    allow_distinct = False
    empty_result_set_value = None

    def __init__(
        self, *expressions, distinct=False, filter=None, default=None, **extra
    ):
        if distinct and not self.allow_distinct:
            raise TypeError(f"{self.__class__.__name__} does not allow distinct.")
        if default is not None and self.empty_result_set_value is not None:
            raise TypeError(f"{self.__class__.__name__} does not allow default.")
        self.distinct = distinct
        self.filter = filter
        self.default = default
        super().__init__(*expressions, **extra)

    def get_source_fields(self):
        # Don't return the filter expression since it's not a source field.
        return [e._output_field_or_none for e in super().get_source_expressions()]

    def get_source_expressions(self):
        """Return the aggregated expressions, followed by the filter if set."""
        source_expressions = super().get_source_expressions()
        if self.filter:
            return source_expressions + [self.filter]
        return source_expressions

    def set_source_expressions(self, exprs):
        """Inverse of get_source_expressions().

        When a filter is set, the last item of ``exprs`` becomes the new
        filter and the rest become the aggregated expressions. The caller's
        sequence is left unmodified.
        """
        if self.filter:
            exprs = list(exprs)  # Copy so popping doesn't mutate the caller's list.
            self.filter = exprs.pop()
        return super().set_source_expressions(exprs)

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Resolve the aggregate and its filter against ``query``.

        Raises FieldError when not summarizing and one of the aggregated
        arguments itself contains an aggregate. If ``default`` was given,
        returns a Coalesce wrapping the resolved aggregate instead of the
        aggregate itself.
        """
        # Aggregates are not allowed in UPDATE queries, so ignore for_save
        c = super().resolve_expression(query, allow_joins, reuse, summarize)
        c.filter = c.filter and c.filter.resolve_expression(
            query, allow_joins, reuse, summarize
        )
        if not summarize:
            # Call Aggregate.get_source_expressions() to avoid
            # returning self.filter and including that in this loop.
            expressions = super(Aggregate, c).get_source_expressions()
            for index, expr in enumerate(expressions):
                if expr.contains_aggregate:
                    # Report the argument as the user wrote it (pre-resolution).
                    before_resolved = self.get_source_expressions()[index]
                    name = (
                        before_resolved.name
                        if hasattr(before_resolved, "name")
                        else repr(before_resolved)
                    )
                    raise FieldError(
                        f"Cannot compute {c.name}('{name}'): '{name}' is an aggregate"
                    )
        if (default := c.default) is None:
            return c
        if hasattr(default, "resolve_expression"):
            default = default.resolve_expression(query, allow_joins, reuse, summarize)
            if default._output_field_or_none is None:
                default.output_field = c._output_field_or_none
        else:
            default = Value(default, c._output_field_or_none)
        c.default = None  # Reset the default argument before wrapping.
        coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
        coalesce.is_summary = c.is_summary
        return coalesce

    @property
    def default_alias(self):
        """Return ``"<field>__<name lowercased>"`` for a single named argument.

        get_source_expressions() includes a set filter, so a filtered
        aggregate (or one with several or unnamed arguments) raises TypeError
        and needs an explicit alias.
        """
        expressions = self.get_source_expressions()
        if len(expressions) == 1 and hasattr(expressions[0], "name"):
            return f"{expressions[0].name}__{self.name.lower()}"
        raise TypeError("Complex expressions require an alias")

    def get_group_by_cols(self):
        # An aggregate contributes no GROUP BY columns of its own.
        return []

    def as_sql(self, compiler, connection, **extra_context):
        """Compile to SQL, applying DISTINCT and the optional filter."""
        extra_context["distinct"] = "DISTINCT " if self.distinct else ""
        if self.filter:
            if connection.features.supports_aggregate_filter_clause:
                try:
                    filter_sql, filter_params = self.filter.as_sql(compiler, connection)
                except FullResultSet:
                    # The filter matches every row: compile without FILTER.
                    pass
                else:
                    # Store the wrapped template back into extra_context
                    # instead of passing template= separately; a
                    # caller-supplied template would otherwise reach
                    # Func.as_sql() twice and raise TypeError.
                    base_template = extra_context.get("template", self.template)
                    extra_context["template"] = self.filter_template % base_template
                    sql, params = super().as_sql(
                        compiler,
                        connection,
                        filter=filter_sql,
                        **extra_context,
                    )
                    return sql, (*params, *filter_params)
            else:
                # No FILTER clause support: emulate it with CASE WHEN on a
                # filter-less copy so self is left untouched.
                copy = self.copy()
                copy.filter = None
                source_expressions = copy.get_source_expressions()
                condition = When(self.filter, then=source_expressions[0])
                copy.set_source_expressions([Case(condition)] + source_expressions[1:])
                return super(Aggregate, copy).as_sql(
                    compiler, connection, **extra_context
                )
        return super().as_sql(compiler, connection, **extra_context)

    def _get_repr_options(self):
        options = super()._get_repr_options()
        if self.distinct:
            options["distinct"] = self.distinct
        if self.filter:
            options["filter"] = self.filter
        return options

145 

146 

class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
    """The SQL ``AVG()`` aggregate; accepts ``distinct=True``."""

    name = "Avg"
    function = "AVG"
    allow_distinct = True

151 

152 

class Count(Aggregate):
    """The SQL ``COUNT()`` aggregate, with an ``IntegerField`` output.

    ``Count("*")`` is rewritten to ``Count(Star())``; a star cannot be
    combined with ``filter``. Because ``empty_result_set_value`` is 0,
    Aggregate.__init__ rejects a ``default`` for this class.
    """

    name = "Count"
    function = "COUNT"
    output_field = IntegerField()
    allow_distinct = True
    empty_result_set_value = 0

    def __init__(self, expression, filter=None, **extra):
        expression = Star() if expression == "*" else expression
        if filter is not None and isinstance(expression, Star):
            raise ValueError(
                "Star cannot be used with filter. Please specify a field."
            )
        super().__init__(expression, filter=filter, **extra)

166 

167 

class Max(Aggregate):
    """The SQL ``MAX()`` aggregate."""

    name = "Max"
    function = "MAX"

171 

172 

class Min(Aggregate):
    """The SQL ``MIN()`` aggregate."""

    name = "Min"
    function = "MIN"

176 

177 

class StdDev(NumericOutputFieldMixin, Aggregate):
    """Standard-deviation aggregate.

    A truthy ``sample`` renders ``STDDEV_SAMP``; otherwise ``STDDEV_POP`` is
    used. The choice is reported back as ``sample`` in the repr.
    """

    name = "StdDev"

    def __init__(self, expression, sample=False, **extra):
        if sample:
            self.function = "STDDEV_SAMP"
        else:
            self.function = "STDDEV_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        # Build a fresh dict rather than mutating the parent's result.
        repr_options = dict(super()._get_repr_options())
        repr_options["sample"] = self.function == "STDDEV_SAMP"
        return repr_options

187 

188 

class Sum(FixDurationInputMixin, Aggregate):
    """The SQL ``SUM()`` aggregate; accepts ``distinct=True``."""

    name = "Sum"
    function = "SUM"
    allow_distinct = True

193 

194 

class Variance(NumericOutputFieldMixin, Aggregate):
    """Variance aggregate.

    A truthy ``sample`` renders ``VAR_SAMP``; otherwise ``VAR_POP`` is used.
    The choice is reported back as ``sample`` in the repr.
    """

    name = "Variance"

    def __init__(self, expression, sample=False, **extra):
        if sample:
            self.function = "VAR_SAMP"
        else:
            self.function = "VAR_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        # Build a fresh dict rather than mutating the parent's result.
        repr_options = dict(super()._get_repr_options())
        repr_options["sample"] = self.function == "VAR_SAMP"
        return repr_options