Coverage for /Users/davegaeddert/Development/dropseed/plain/plain-models/plain/models/backends/sqlite3/operations.py: 46%
211 statements
import datetime
import decimal
import uuid
from functools import lru_cache

from plain import models
from plain.exceptions import FieldError
from plain.models.backends.base.operations import BaseDatabaseOperations
from plain.models.constants import OnConflict
from plain.models.db import DatabaseError, NotSupportedError
from plain.models.expressions import Col
from plain.utils import timezone
from plain.utils.dateparse import parse_date, parse_datetime, parse_time
from plain.utils.functional import cached_property


class DatabaseOperations(BaseDatabaseOperations):
    cast_char_field_without_max_length = "text"
    cast_data_types = {
        "DateField": "TEXT",
        "DateTimeField": "TEXT",
    }
    explain_prefix = "EXPLAIN QUERY PLAN"
    # List of datatypes that cannot be extracted with JSON_EXTRACT() on
    # SQLite. Use JSON_TYPE() instead.
    jsonfield_datatype_values = frozenset(["null", "false", "true"])

    def bulk_batch_size(self, fields, objs):
        """
        SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
        999 variables per query.

        If there's only a single field to insert, the limit is 500
        (SQLITE_MAX_COMPOUND_SELECT).
        """
        if len(fields) == 1:
            return 500
        elif len(fields) > 1:
            return self.connection.features.max_query_params // len(fields)
        else:
            return len(objs)
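
    # For illustration (assuming the default SQLite limit of 999 query
    # parameters): inserting objects with three fields gives a batch size of
    # 999 // 3 == 333 rows, while a single-field insert is capped at 500.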

    def check_expression_support(self, expression):
        bad_fields = (models.DateField, models.DateTimeField, models.TimeField)
        bad_aggregates = (models.Sum, models.Avg, models.Variance, models.StdDev)
        if isinstance(expression, bad_aggregates):
            for expr in expression.get_source_expressions():
                try:
                    output_field = expr.output_field
                except (AttributeError, FieldError):
                    # Not every subexpression has an output_field, which is
                    # fine to ignore.
                    pass
                else:
                    if isinstance(output_field, bad_fields):
                        raise NotSupportedError(
                            "You cannot use Sum, Avg, StdDev, and Variance "
                            "aggregations on date/time fields in sqlite3 "
                            "since date/time is saved as text."
                        )
        if (
            isinstance(expression, models.Aggregate)
            and expression.distinct
            and len(expression.source_expressions) > 1
        ):
            raise NotSupportedError(
                "SQLite doesn't support DISTINCT on aggregate functions "
                "accepting multiple arguments."
            )

    def date_extract_sql(self, lookup_type, sql, params):
        """
        Support EXTRACT with a user-defined function plain_date_extract()
        that's registered in connect(). Use single quotes because this is a
        string and could otherwise cause a collision with a field name.
        """
        return f"plain_date_extract(%s, {sql})", (lookup_type.lower(), *params)

    def fetch_returned_insert_rows(self, cursor):
        """
        Given a cursor object that has just performed an INSERT...RETURNING
        statement into a table, return the list of returned data.
        """
        return cursor.fetchall()

    def format_for_duration_arithmetic(self, sql):
        """Do nothing since formatting is handled in the custom function."""
        return sql

    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
        return f"plain_date_trunc(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
        return f"plain_time_trunc(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def _convert_tznames_to_sql(self, tzname):
        if tzname:
            return tzname, self.connection.timezone_name
        return None, None

    def datetime_cast_date_sql(self, sql, params, tzname):
        return f"plain_datetime_cast_date({sql}, %s, %s)", (
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def datetime_cast_time_sql(self, sql, params, tzname):
        return f"plain_datetime_cast_time({sql}, %s, %s)", (
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
        return f"plain_datetime_extract(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
        return f"plain_datetime_trunc(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def time_extract_sql(self, lookup_type, sql, params):
        return f"plain_time_extract(%s, {sql})", (lookup_type.lower(), *params)

    def pk_default_value(self):
        return "NULL"

    def _quote_params_for_last_executed_query(self, params):
        """
        Only for last_executed_query! Don't use this to execute SQL queries!
        """
        # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
        # number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
        # number of return values, default = 2000). Since Python's sqlite3
        # module doesn't expose the get_limit() C API, assume the default
        # limits are in effect and split the work in batches if needed.
        BATCH_SIZE = 999
        if len(params) > BATCH_SIZE:
            results = ()
            for index in range(0, len(params), BATCH_SIZE):
                chunk = params[index : index + BATCH_SIZE]
                results += self._quote_params_for_last_executed_query(chunk)
            return results

        sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
        # Bypass Plain's wrappers and use the underlying sqlite3 connection
        # to avoid logging this query - it would trigger infinite recursion.
        cursor = self.connection.connection.cursor()
        # Native sqlite3 cursors cannot be used as context managers.
        try:
            return cursor.execute(sql, params).fetchone()
        finally:
            cursor.close()

    def last_executed_query(self, cursor, sql, params):
        # Python substitutes parameters in Modules/_sqlite/cursor.c with:
        # bind_parameters(state, self->statement, parameters);
        # Unfortunately there is no way to reach self->statement from Python,
        # so we quote and substitute parameters manually.
        if params:
            if isinstance(params, list | tuple):
                params = self._quote_params_for_last_executed_query(params)
            else:
                values = tuple(params.values())
                values = self._quote_params_for_last_executed_query(values)
                params = dict(zip(params, values))
            return sql % params
        # For consistency with SQLiteCursorWrapper.execute(), just return sql
        # when there are no parameters. See #13648 and #17158.
        else:
            return sql
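
    # Sketch of the intended result (illustrative values, not captured
    # output): for sql = "SELECT %s" and params = ["it's"], QUOTE() doubles
    # the inner quote and the method returns "SELECT 'it''s'".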

    def quote_name(self, name):
        if name.startswith('"') and name.endswith('"'):
            return name  # Quoting once is enough.
        return '"%s"' % name
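
    # For example, quote_name("user") returns '"user"', while a name that is
    # already double-quoted is returned unchanged.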

    def no_limit_value(self):
        return -1

    def __references_graph(self, table_name):
        query = """
        WITH tables AS (
            SELECT %s name
            UNION
            SELECT sqlite_master.name
            FROM sqlite_master
            JOIN tables ON (sql REGEXP %s || tables.name || %s)
        ) SELECT name FROM tables;
        """
        params = (
            table_name,
            r'(?i)\s+references\s+("|\')?',
            r'("|\')?\s*\(',
        )
        with self.connection.cursor() as cursor:
            results = cursor.execute(query, params)
            return [row[0] for row in results.fetchall()]
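
    # Descriptive note on the query above: the CTE seeds "tables" with
    # table_name and keeps adding tables whose schema SQL matches a
    # REFERENCES <name> clause, so the result is every table that directly
    # or transitively references the given table. REGEXP is not built into
    # SQLite by default; presumably it is provided by a user-defined
    # function registered at connection time, like the plain_* helpers above.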

    @cached_property
    def _references_graph(self):
        # 512 is large enough to fit the ~330 tables (as of this writing) in
        # Plain's test suite.
        return lru_cache(maxsize=512)(self.__references_graph)

    def sequence_reset_by_name_sql(self, style, sequences):
        if not sequences:
            return []
        return [
            "{} {} {} {} = 0 {} {} {} ({});".format(
                style.SQL_KEYWORD("UPDATE"),
                style.SQL_TABLE(self.quote_name("sqlite_sequence")),
                style.SQL_KEYWORD("SET"),
                style.SQL_FIELD(self.quote_name("seq")),
                style.SQL_KEYWORD("WHERE"),
                style.SQL_FIELD(self.quote_name("name")),
                style.SQL_KEYWORD("IN"),
                ", ".join(
                    ["'%s'" % sequence_info["table"] for sequence_info in sequences]
                ),
            ),
        ]
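
    # For illustration (assumed table name): a single sequence for table
    # "app_book" would generate
    # UPDATE "sqlite_sequence" SET "seq" = 0 WHERE "name" IN ('app_book');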

    def adapt_datetimefield_value(self, value):
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        # SQLite doesn't support tz-aware datetimes
        if timezone.is_aware(value):
            value = timezone.make_naive(value, self.connection.timezone)

        return str(value)

    def adapt_timefield_value(self, value):
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        # SQLite doesn't support tz-aware datetimes
        if timezone.is_aware(value):
            raise ValueError("SQLite backend does not support timezone-aware times.")

        return str(value)

    def get_db_converters(self, expression):
        converters = super().get_db_converters(expression)
        internal_type = expression.output_field.get_internal_type()
        if internal_type == "DateTimeField":
            converters.append(self.convert_datetimefield_value)
        elif internal_type == "DateField":
            converters.append(self.convert_datefield_value)
        elif internal_type == "TimeField":
            converters.append(self.convert_timefield_value)
        elif internal_type == "DecimalField":
            converters.append(self.get_decimalfield_converter(expression))
        elif internal_type == "UUIDField":
            converters.append(self.convert_uuidfield_value)
        elif internal_type == "BooleanField":
            converters.append(self.convert_booleanfield_value)
        return converters

    def convert_datetimefield_value(self, value, expression, connection):
        if value is not None:
            if not isinstance(value, datetime.datetime):
                value = parse_datetime(value)
            if not timezone.is_aware(value):
                value = timezone.make_aware(value, self.connection.timezone)
        return value

    def convert_datefield_value(self, value, expression, connection):
        if value is not None:
            if not isinstance(value, datetime.date):
                value = parse_date(value)
        return value

    def convert_timefield_value(self, value, expression, connection):
        if value is not None:
            if not isinstance(value, datetime.time):
                value = parse_time(value)
        return value

    def get_decimalfield_converter(self, expression):
        # SQLite stores only 15 significant digits. Digits coming from
        # float inaccuracy must be removed.
        create_decimal = decimal.Context(prec=15).create_decimal_from_float
        if isinstance(expression, Col):
            quantize_value = decimal.Decimal(1).scaleb(
                -expression.output_field.decimal_places
            )

            def converter(value, expression, connection):
                if value is not None:
                    return create_decimal(value).quantize(
                        quantize_value, context=expression.output_field.context
                    )

        else:

            def converter(value, expression, connection):
                if value is not None:
                    return create_decimal(value)

        return converter
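
    # For illustration (assumed values): a DecimalField column with
    # decimal_places=2 gets quantize_value == Decimal("0.01"), so a float
    # read back from SQLite such as 3.140000000000001 becomes Decimal("3.14").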

    def convert_uuidfield_value(self, value, expression, connection):
        if value is not None:
            value = uuid.UUID(value)
        return value

    def convert_booleanfield_value(self, value, expression, connection):
        return bool(value) if value in (1, 0) else value

    def bulk_insert_sql(self, fields, placeholder_rows):
        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
        values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
        return f"VALUES {values_sql}"

    def combine_expression(self, connector, sub_expressions):
        # SQLite doesn't have a ^ operator, so use the user-defined POWER
        # function that's registered in connect().
        if connector == "^":
            return "POWER(%s)" % ",".join(sub_expressions)
        elif connector == "#":
            return "BITXOR(%s)" % ",".join(sub_expressions)
        return super().combine_expression(connector, sub_expressions)
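
    # For example, combine_expression("^", ["price", "2"]) returns
    # "POWER(price,2)"; "#" maps to BITXOR(...) the same way, and any other
    # connector falls through to the base implementation.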

    def combine_duration_expression(self, connector, sub_expressions):
        if connector not in ["+", "-", "*", "/"]:
            raise DatabaseError("Invalid connector for timedelta: %s." % connector)
        fn_params = ["'%s'" % connector] + sub_expressions
        if len(fn_params) > 3:
            raise ValueError("Too many params for timedelta operations.")
        return "plain_format_dtdelta(%s)" % ", ".join(fn_params)

    def integer_field_range(self, internal_type):
        # SQLite doesn't enforce any integer constraints, but sqlite3 supports
        # integers up to 64 bits.
        if internal_type in [
            "PositiveBigIntegerField",
            "PositiveIntegerField",
            "PositiveSmallIntegerField",
        ]:
            return (0, 9223372036854775807)
        return (-9223372036854775808, 9223372036854775807)

    def subtract_temporals(self, internal_type, lhs, rhs):
        lhs_sql, lhs_params = lhs
        rhs_sql, rhs_params = rhs
        params = (*lhs_params, *rhs_params)
        if internal_type == "TimeField":
            return f"plain_time_diff({lhs_sql}, {rhs_sql})", params
        return f"plain_timestamp_diff({lhs_sql}, {rhs_sql})", params

    def insert_statement(self, on_conflict=None):
        if on_conflict == OnConflict.IGNORE:
            return "INSERT OR IGNORE INTO"
        return super().insert_statement(on_conflict=on_conflict)

    def return_insert_columns(self, fields):
        # SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
        if not fields:
            return "", ()
        columns = [
            "{}.{}".format(
                self.quote_name(field.model._meta.db_table),
                self.quote_name(field.column),
            )
            for field in fields
        ]
        return "RETURNING %s" % ", ".join(columns), ()

    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
        if (
            on_conflict == OnConflict.UPDATE
            and self.connection.features.supports_update_conflicts_with_target
        ):
            return "ON CONFLICT({}) DO UPDATE SET {}".format(
                ", ".join(map(self.quote_name, unique_fields)),
                ", ".join(
                    [
                        f"{field} = EXCLUDED.{field}"
                        for field in map(self.quote_name, update_fields)
                    ]
                ),
            )
        return super().on_conflict_suffix_sql(
            fields,
            on_conflict,
            update_fields,
            unique_fields,
        )
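
    # For illustration (assumed field names): with unique_fields=["email"]
    # and update_fields=["name"], the generated suffix would be
    # ON CONFLICT("email") DO UPDATE SET "name" = EXCLUDED."name"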