Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/functions/comparison.py: 43%

95 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-23 11:16 -0600

1"""Database functions that do comparisons or type conversions.""" 

2 

3from plain.models.db import NotSupportedError 

4from plain.models.expressions import Func, Value 

5from plain.models.fields import TextField 

6from plain.models.fields.json import JSONField 

7from plain.utils.regex_helper import _lazy_re_compile 

8 

9 

class Cast(Func):
    """
    Convert an expression to the database type of ``output_field``.

    The generic rendering is ``CAST(<expression> AS <db_type>)``; the SQLite,
    MySQL and PostgreSQL hooks below substitute backend-specific spellings.
    """

    function = "CAST"
    template = "%(function)s(%(expressions)s AS %(db_type)s)"

    def __init__(self, expression, output_field):
        super().__init__(expression, output_field=output_field)

    def as_sql(self, compiler, connection, **extra_context):
        # The target type name is taken from the field's cast_db_type() for
        # this connection and interpolated into the template as %(db_type)s.
        extra_context["db_type"] = self.output_field.cast_db_type(connection)
        return super().as_sql(compiler, connection, **extra_context)

    def as_sqlite(self, compiler, connection, **extra_context):
        sqlite_type = self.output_field.db_type(connection)

        if sqlite_type == "date":
            return super().as_sql(
                compiler,
                connection,
                template="date(%(expressions)s)",
                **extra_context,
            )

        if sqlite_type in {"datetime", "time"}:
            # Plain datetime/time conversion would lose fractional seconds, so
            # format through strftime() instead. "%%s" survives template
            # interpolation as a "%s" placeholder; the format string is
            # prepended to params to fill it.
            sql, params = super().as_sql(
                compiler,
                connection,
                template="strftime(%%s, %(expressions)s)",
                **extra_context,
            )
            if sqlite_type == "datetime":
                pattern = "%Y-%m-%d %H:%M:%f"
            else:
                pattern = "%H:%M:%f"
            params.insert(0, pattern)
            return sql, params

        # Every other type goes through the generic CAST rendering.
        return self.as_sql(compiler, connection, **extra_context)

    def as_mysql(self, compiler, connection, **extra_context):
        internal_type = self.output_field.get_internal_type()
        if internal_type == "FloatField":
            # MySQL doesn't support an explicit cast to float.
            mysql_template = "(%(expressions)s + 0.0)"
        elif internal_type == "JSONField" and connection.mysql_is_mariadb:
            # MariaDB doesn't support an explicit cast to JSON.
            mysql_template = "JSON_EXTRACT(%(expressions)s, '$')"
        else:
            # No override for this type.
            mysql_template = None
        return self.as_sql(
            compiler, connection, template=mysql_template, **extra_context
        )

    def as_postgresql(self, compiler, connection, **extra_context):
        # PostgreSQL's "::" shorthand is equivalent to CAST and easier to read.
        # The expression is parenthesised in case it is a compound expression.
        postgres_template = "(%(expressions)s)::%(db_type)s"
        return self.as_sql(
            compiler,
            connection,
            template=postgres_template,
            **extra_context,
        )

62 

63 

class Coalesce(Func):
    """Evaluate to the first non-null expression, scanning left to right."""

    function = "COALESCE"

    def __init__(self, *expressions, **extra):
        if len(expressions) <= 1:
            raise ValueError("Coalesce must take at least two expressions")
        super().__init__(*expressions, **extra)

    @property
    def empty_result_set_value(self):
        """
        Return the empty-result value of the first source expression whose
        value is not None, or None if every source yields None.
        """
        # Lazily walk the sources so later ones are not consulted once a
        # match is found. NotImplemented is itself "not None", so it is
        # returned as-is when encountered first.
        candidates = (
            source.empty_result_set_value
            for source in self.get_source_expressions()
        )
        return next(
            (value for value in candidates if value is not None),
            None,
        )

81 

82 

class Collate(Func):
    """
    Apply a collation to an expression: ``<expression> COLLATE <collation>``.

    The collation name is validated against ``collation_re`` on construction
    and quoted with the backend's ``quote_name`` when rendered.
    """

    function = "COLLATE"
    template = "%(expressions)s %(function)s %(collation)s"
    # Inspired from
    # https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
    # Anchored with ``\Z`` rather than ``$``: with re.match, ``$`` also matches
    # just before a trailing newline, which would let names such as "C\n"
    # pass validation.
    collation_re = _lazy_re_compile(r"^[\w\-]+\Z")

    def __init__(self, expression, collation):
        """
        :param expression: the expression to collate.
        :param collation: collation name made only of word characters
            (``\\w``) and hyphens.
        :raises ValueError: if ``collation`` is empty or contains any other
            character (including a trailing newline).
        """
        if not (collation and self.collation_re.match(collation)):
            raise ValueError(f"Invalid collation name: {collation!r}.")
        self.collation = collation
        super().__init__(expression)

    def as_sql(self, compiler, connection, **extra_context):
        # A "collation" supplied by the caller in extra_context wins;
        # otherwise use this instance's name, quoted for the backend.
        extra_context.setdefault("collation", connection.ops.quote_name(self.collation))
        return super().as_sql(compiler, connection, **extra_context)

99 

100 

class Greatest(Func):
    """
    Return the largest of two or more expressions.

    Null handling is backend-dependent: PostgreSQL returns the largest
    non-null expression, while MySQL, Oracle, and SQLite return null when
    any expression is null.
    """

    function = "GREATEST"

    def __init__(self, *expressions, **extra):
        if len(expressions) <= 1:
            raise ValueError("Greatest must take at least two expressions")
        super().__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        """Render with SQLite's MAX function in place of GREATEST."""
        sqlite_function = "MAX"
        return super().as_sqlite(
            compiler, connection, function=sqlite_function, **extra_context
        )

120 

121 

class JSONObject(Func):
    """
    Build a JSON object from keyword arguments.

    Each keyword becomes a key (wrapped in ``Value``) and its argument the
    corresponding value expression.
    """

    function = "JSON_OBJECT"
    output_field = JSONField()

    def __init__(self, **fields):
        # Arguments are passed as one flat, alternating sequence:
        # Value(key1), value1, Value(key2), value2, ...
        flattened = [
            item
            for key, value in fields.items()
            for item in (Value(key), value)
        ]
        super().__init__(*flattened)

    def as_sql(self, compiler, connection, **extra_context):
        if connection.features.has_json_object_function:
            return super().as_sql(compiler, connection, **extra_context)
        raise NotSupportedError(
            "JSONObject() is not supported on this database backend."
        )

    def as_postgresql(self, compiler, connection, **extra_context):
        # Keys sit at the even positions of the argument list; cast each to
        # text. The change is made on a copy rather than on self.
        clone = self.copy()
        clone.set_source_expressions(
            [
                expr if position % 2 else Cast(expr, TextField())
                for position, expr in enumerate(clone.get_source_expressions())
            ]
        )
        # super(JSONObject, clone) skips JSONObject.as_sql (and its
        # has_json_object_function check) and renders JSONB_BUILD_OBJECT.
        return super(JSONObject, clone).as_sql(
            compiler,
            connection,
            function="JSONB_BUILD_OBJECT",
            **extra_context,
        )

153 

154 

class Least(Func):
    """
    Return the smallest of two or more expressions.

    Null handling is backend-dependent: PostgreSQL returns the smallest
    non-null expression, while MySQL, Oracle, and SQLite return null when
    any expression is null.
    """

    function = "LEAST"

    def __init__(self, *expressions, **extra):
        if len(expressions) <= 1:
            raise ValueError("Least must take at least two expressions")
        super().__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        """Render with SQLite's MIN function in place of LEAST."""
        sqlite_function = "MIN"
        return super().as_sqlite(
            compiler, connection, function=sqlite_function, **extra_context
        )

174 

175 

176class NullIf(Func): 

177 function = "NULLIF" 

178 arity = 2