Coverage for /Users/davegaeddert/Development/dropseed/plain/plain-models/plain/models/aggregates.py: 56%

127 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-16 22:04 -0500

"""
Aggregate function expressions: Avg, Count, Max, Min, StdDev, Sum, Variance.

``Aggregate`` is the shared base class; each subclass names the SQL aggregate
function it renders.
"""

4from plain.exceptions import FieldError, FullResultSet 

5from plain.models.expressions import Case, Func, Star, Value, When 

6from plain.models.fields import IntegerField 

7from plain.models.functions.comparison import Coalesce 

8from plain.models.functions.mixins import ( 

9 FixDurationInputMixin, 

10 NumericOutputFieldMixin, 

11) 

12 

# Public names of this module (the names a star-import brings in).
__all__ = ["Aggregate", "Avg", "Count", "Max", "Min", "StdDev", "Sum", "Variance"]

23 

24 

class Aggregate(Func):
    """
    Base class for SQL aggregate function expressions.

    Subclasses set ``function`` (the SQL function name) and ``name`` (used in
    error messages and in ``default_alias``). Optional features:

    * ``distinct`` -- render ``DISTINCT`` inside the call; only permitted when
      the subclass sets ``allow_distinct = True``.
    * ``filter`` -- restrict the rows fed to the aggregate. Rendered as a
      ``FILTER (WHERE ...)`` clause when the backend reports
      ``supports_aggregate_filter_clause``; otherwise emulated by wrapping the
      first source expression in ``Case(When(filter, then=expr))``.
    * ``default`` -- after resolution the aggregate is wrapped in
      ``Coalesce(aggregate, default)``. Not allowed for subclasses that define
      a non-None ``empty_result_set_value``.
    """

    template = "%(function)s(%(distinct)s%(expressions)s)"
    contains_aggregate = True
    # Display name used in error messages and default_alias; set by subclasses.
    name = None
    # Wraps the regular template. "%%(filter)s" survives the first
    # %-substitution and is later filled from the ``filter`` context value.
    filter_template = "%s FILTER (WHERE %%(filter)s)"
    # NOTE(review): presumably allows use inside window expressions — confirm.
    window_compatible = True
    allow_distinct = False
    # A non-None value here makes __init__ reject ``default``.
    empty_result_set_value = None

    def __init__(
        self, *expressions, distinct=False, filter=None, default=None, **extra
    ):
        """
        Args:
            *expressions: Source expressions, passed through to ``Func``.
            distinct: Render ``DISTINCT`` inside the function call.
            filter: Optional condition restricting the aggregated rows.
            default: Optional value used via ``Coalesce`` after resolution.
            **extra: Extra context forwarded to ``Func``.

        Raises:
            TypeError: If ``distinct`` is requested but the class does not set
                ``allow_distinct``, or if ``default`` is given while the class
                defines a non-None ``empty_result_set_value``.
        """
        if distinct and not self.allow_distinct:
            raise TypeError(f"{self.__class__.__name__} does not allow distinct.")
        if default is not None and self.empty_result_set_value is not None:
            raise TypeError(f"{self.__class__.__name__} does not allow default.")
        self.distinct = distinct
        self.filter = filter
        self.default = default
        super().__init__(*expressions, **extra)

    def get_source_fields(self):
        """Return output fields of the source expressions, excluding the filter."""
        # Don't return the filter expression since it's not a source field.
        return [e._output_field_or_none for e in super().get_source_expressions()]

    def get_source_expressions(self):
        """Return the source expressions, with the filter (if any) appended last."""
        source_expressions = super().get_source_expressions()
        if self.filter:
            return source_expressions + [self.filter]
        return source_expressions

    def set_source_expressions(self, exprs):
        """
        Inverse of ``get_source_expressions``.

        When a filter is set, the last item of ``exprs`` is taken (popped from
        the passed list) as the new filter.
        """
        self.filter = self.filter and exprs.pop()
        return super().set_source_expressions(exprs)

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve the source expressions and the filter against ``query``.

        ``for_save`` is accepted for interface compatibility but ignored.

        Returns:
            The resolved copy of this aggregate or, when ``default`` is set,
            ``Coalesce(<resolved aggregate>, default)``.

        Raises:
            FieldError: When not summarizing and a source expression itself
                contains an aggregate.
        """
        # Aggregates are not allowed in UPDATE queries, so ignore for_save
        c = super().resolve_expression(query, allow_joins, reuse, summarize)
        c.filter = c.filter and c.filter.resolve_expression(
            query, allow_joins, reuse, summarize
        )
        if not summarize:
            # Call Aggregate.get_source_expressions() to avoid
            # returning self.filter and including that in this loop.
            expressions = super(Aggregate, c).get_source_expressions()
            for index, expr in enumerate(expressions):
                if expr.contains_aggregate:
                    # Report the expression as the user wrote it (pre-resolution).
                    before_resolved = self.get_source_expressions()[index]
                    name = (
                        before_resolved.name
                        if hasattr(before_resolved, "name")
                        else repr(before_resolved)
                    )
                    raise FieldError(
                        f"Cannot compute {c.name}('{name}'): '{name}' is an aggregate"
                    )
        if (default := c.default) is None:
            return c
        if hasattr(default, "resolve_expression"):
            default = default.resolve_expression(query, allow_joins, reuse, summarize)
            # An untyped default expression adopts the aggregate's output field.
            if default._output_field_or_none is None:
                default.output_field = c._output_field_or_none
        else:
            default = Value(default, c._output_field_or_none)
        c.default = None  # Reset the default argument before wrapping.
        coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
        coalesce.is_summary = c.is_summary
        return coalesce

    @property
    def default_alias(self):
        """
        Alias ``<source name>__<lowercased aggregate name>`` for single named sources.

        The filter counts as a source expression here, so filtered aggregates
        (and any with several or unnamed sources) raise ``TypeError`` and need
        an explicit alias.
        """
        expressions = self.get_source_expressions()
        if len(expressions) == 1 and hasattr(expressions[0], "name"):
            return f"{expressions[0].name}__{self.name.lower()}"
        raise TypeError("Complex expressions require an alias")

    def get_group_by_cols(self):
        """Aggregates contribute no GROUP BY columns."""
        return []

    def as_sql(self, compiler, connection, **extra_context):
        """
        Compile the aggregate to SQL, honoring ``distinct`` and ``filter``.

        A ``template`` supplied in ``extra_context`` (by a caller or subclass)
        is respected. When a FILTER clause is emitted, that template, rather
        than ``self.template``, is wrapped by ``filter_template``.
        """
        extra_context["distinct"] = "DISTINCT " if self.distinct else ""
        if self.filter:
            if connection.features.supports_aggregate_filter_clause:
                try:
                    filter_sql, filter_params = self.filter.as_sql(compiler, connection)
                except FullResultSet:
                    # The filter compiles to nothing (presumably always true);
                    # fall through and render the aggregate without FILTER.
                    pass
                else:
                    # Store the values in extra_context instead of passing them
                    # as separate keywords: passing template=/filter= alongside
                    # an extra_context that already holds those keys would
                    # raise TypeError (multiple values for keyword argument).
                    extra_context["template"] = self.filter_template % extra_context.get(
                        "template", self.template
                    )
                    extra_context["filter"] = filter_sql
                    sql, params = super().as_sql(compiler, connection, **extra_context)
                    return sql, (*params, *filter_params)
            else:
                # No FILTER support: wrap the first source expression in
                # Case(When(filter, then=expr)) so non-matching rows feed the
                # Case default (presumably NULL) to the aggregate instead.
                copy = self.copy()
                copy.filter = None
                source_expressions = copy.get_source_expressions()
                condition = When(self.filter, then=source_expressions[0])
                copy.set_source_expressions([Case(condition)] + source_expressions[1:])
                return super(Aggregate, copy).as_sql(
                    compiler, connection, **extra_context
                )
        return super().as_sql(compiler, connection, **extra_context)

    def _get_repr_options(self):
        """Extend the repr options with ``distinct`` and ``filter`` when set."""
        options = super()._get_repr_options()
        if self.distinct:
            options["distinct"] = self.distinct
        if self.filter:
            options["filter"] = self.filter
        return options

144 

145 

class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
    """Average of the input values (SQL ``AVG``); accepts ``distinct=True``."""

    allow_distinct = True
    name = "Avg"
    function = "AVG"

150 

151 

class Count(Aggregate):
    """
    Count of the input values (SQL ``COUNT``); accepts ``distinct=True``.

    The string ``"*"`` is converted to ``Star()``. A star count cannot be
    combined with ``filter``.
    """

    allow_distinct = True
    name = "Count"
    function = "COUNT"
    output_field = IntegerField()
    # Non-None, so Aggregate.__init__ rejects ``default`` for Count.
    empty_result_set_value = 0

    def __init__(self, expression, filter=None, **extra):
        source = Star() if expression == "*" else expression
        if filter is not None and isinstance(source, Star):
            raise ValueError("Star cannot be used with filter. Please specify a field.")
        super().__init__(source, filter=filter, **extra)

165 

166 

class Max(Aggregate):
    """Largest input value (SQL ``MAX``)."""

    name = "Max"
    function = "MAX"

170 

171 

class Min(Aggregate):
    """Smallest input value (SQL ``MIN``)."""

    name = "Min"
    function = "MIN"

175 

176 

class StdDev(NumericOutputFieldMixin, Aggregate):
    """
    Standard deviation of the input values.

    Renders ``STDDEV_POP`` by default, or ``STDDEV_SAMP`` when ``sample`` is true.
    """

    name = "StdDev"

    def __init__(self, expression, sample=False, **extra):
        if sample:
            self.function = "STDDEV_SAMP"
        else:
            self.function = "STDDEV_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        # ``sample`` is not stored; derive it back from the chosen SQL function.
        return dict(super()._get_repr_options(), sample=self.function == "STDDEV_SAMP")

186 

187 

class Sum(FixDurationInputMixin, Aggregate):
    """Sum of the input values (SQL ``SUM``); accepts ``distinct=True``."""

    allow_distinct = True
    name = "Sum"
    function = "SUM"

192 

193 

class Variance(NumericOutputFieldMixin, Aggregate):
    """
    Variance of the input values.

    Renders ``VAR_POP`` by default, or ``VAR_SAMP`` when ``sample`` is true.
    """

    name = "Variance"

    def __init__(self, expression, sample=False, **extra):
        if sample:
            self.function = "VAR_SAMP"
        else:
            self.function = "VAR_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        # ``sample`` is not stored; derive it back from the chosen SQL function.
        return dict(super()._get_repr_options(), sample=self.function == "VAR_SAMP")