Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/backends/sqlite3/operations.py: 45%
211 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
1import datetime
2import decimal
3import uuid
4from functools import lru_cache
6from plain import models
7from plain.exceptions import FieldError
8from plain.models.backends.base.operations import BaseDatabaseOperations
9from plain.models.constants import OnConflict
10from plain.models.db import DatabaseError, NotSupportedError
11from plain.models.expressions import Col
12from plain.utils import timezone
13from plain.utils.dateparse import parse_date, parse_datetime, parse_time
14from plain.utils.functional import cached_property
class DatabaseOperations(BaseDatabaseOperations):
    # SQLite has no length-limited VARCHAR type; CAST to plain TEXT instead.
    cast_char_field_without_max_length = "text"
    # Date/datetime values are stored as ISO-formatted text in SQLite, so
    # explicit casts target TEXT rather than a native temporal type.
    cast_data_types = {
        "DateField": "TEXT",
        "DateTimeField": "TEXT",
    }
    explain_prefix = "EXPLAIN QUERY PLAN"
    # List of datatypes that cannot be extracted with JSON_EXTRACT() on
    # SQLite. Use JSON_TYPE() instead.
    jsonfield_datatype_values = frozenset(["null", "false", "true"])
28 def bulk_batch_size(self, fields, objs):
29 """
30 SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
31 999 variables per query.
33 If there's only a single field to insert, the limit is 500
34 (SQLITE_MAX_COMPOUND_SELECT).
35 """
36 if len(fields) == 1:
37 return 500
38 elif len(fields) > 1:
39 return self.connection.features.max_query_params // len(fields)
40 else:
41 return len(objs)
43 def check_expression_support(self, expression):
44 bad_fields = (models.DateField, models.DateTimeField, models.TimeField)
45 bad_aggregates = (models.Sum, models.Avg, models.Variance, models.StdDev)
46 if isinstance(expression, bad_aggregates):
47 for expr in expression.get_source_expressions():
48 try:
49 output_field = expr.output_field
50 except (AttributeError, FieldError):
51 # Not every subexpression has an output_field which is fine
52 # to ignore.
53 pass
54 else:
55 if isinstance(output_field, bad_fields):
56 raise NotSupportedError(
57 "You cannot use Sum, Avg, StdDev, and Variance "
58 "aggregations on date/time fields in sqlite3 "
59 "since date/time is saved as text."
60 )
61 if (
62 isinstance(expression, models.Aggregate)
63 and expression.distinct
64 and len(expression.source_expressions) > 1
65 ):
66 raise NotSupportedError(
67 "SQLite doesn't support DISTINCT on aggregate functions "
68 "accepting multiple arguments."
69 )
71 def date_extract_sql(self, lookup_type, sql, params):
72 """
73 Support EXTRACT with a user-defined function plain_date_extract()
74 that's registered in connect(). Use single quotes because this is a
75 string and could otherwise cause a collision with a field name.
76 """
77 return f"plain_date_extract(%s, {sql})", (lookup_type.lower(), *params)
79 def fetch_returned_insert_rows(self, cursor):
80 """
81 Given a cursor object that has just performed an INSERT...RETURNING
82 statement into a table, return the list of returned data.
83 """
84 return cursor.fetchall()
    def format_for_duration_arithmetic(self, sql):
        """Do nothing since formatting is handled in the custom function."""
        # plain_format_dtdelta() (see combine_duration_expression) performs
        # the real duration formatting, so the SQL passes through untouched.
        return sql
90 def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
91 return f"plain_date_trunc(%s, {sql}, %s, %s)", (
92 lookup_type.lower(),
93 *params,
94 *self._convert_tznames_to_sql(tzname),
95 )
97 def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
98 return f"plain_time_trunc(%s, {sql}, %s, %s)", (
99 lookup_type.lower(),
100 *params,
101 *self._convert_tznames_to_sql(tzname),
102 )
104 def _convert_tznames_to_sql(self, tzname):
105 if tzname:
106 return tzname, self.connection.timezone_name
107 return None, None
109 def datetime_cast_date_sql(self, sql, params, tzname):
110 return f"plain_datetime_cast_date({sql}, %s, %s)", (
111 *params,
112 *self._convert_tznames_to_sql(tzname),
113 )
115 def datetime_cast_time_sql(self, sql, params, tzname):
116 return f"plain_datetime_cast_time({sql}, %s, %s)", (
117 *params,
118 *self._convert_tznames_to_sql(tzname),
119 )
121 def datetime_extract_sql(self, lookup_type, sql, params, tzname):
122 return f"plain_datetime_extract(%s, {sql}, %s, %s)", (
123 lookup_type.lower(),
124 *params,
125 *self._convert_tznames_to_sql(tzname),
126 )
128 def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
129 return f"plain_datetime_trunc(%s, {sql}, %s, %s)", (
130 lookup_type.lower(),
131 *params,
132 *self._convert_tznames_to_sql(tzname),
133 )
135 def time_extract_sql(self, lookup_type, sql, params):
136 return f"plain_time_extract(%s, {sql})", (lookup_type.lower(), *params)
    def pk_default_value(self):
        # Inserting NULL into an INTEGER PRIMARY KEY column makes SQLite
        # auto-generate the rowid value.
        return "NULL"
141 def _quote_params_for_last_executed_query(self, params):
142 """
143 Only for last_executed_query! Don't use this to execute SQL queries!
144 """
145 # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
146 # number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
147 # number of return values, default = 2000). Since Python's sqlite3
148 # module doesn't expose the get_limit() C API, assume the default
149 # limits are in effect and split the work in batches if needed.
150 BATCH_SIZE = 999
151 if len(params) > BATCH_SIZE:
152 results = ()
153 for index in range(0, len(params), BATCH_SIZE):
154 chunk = params[index : index + BATCH_SIZE]
155 results += self._quote_params_for_last_executed_query(chunk)
156 return results
158 sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
159 # Bypass Plain's wrappers and use the underlying sqlite3 connection
160 # to avoid logging this query - it would trigger infinite recursion.
161 cursor = self.connection.connection.cursor()
162 # Native sqlite3 cursors cannot be used as context managers.
163 try:
164 return cursor.execute(sql, params).fetchone()
165 finally:
166 cursor.close()
168 def last_executed_query(self, cursor, sql, params):
169 # Python substitutes parameters in Modules/_sqlite/cursor.c with:
170 # bind_parameters(state, self->statement, parameters);
171 # Unfortunately there is no way to reach self->statement from Python,
172 # so we quote and substitute parameters manually.
173 if params:
174 if isinstance(params, list | tuple):
175 params = self._quote_params_for_last_executed_query(params)
176 else:
177 values = tuple(params.values())
178 values = self._quote_params_for_last_executed_query(values)
179 params = dict(zip(params, values))
180 return sql % params
181 # For consistency with SQLiteCursorWrapper.execute(), just return sql
182 # when there are no parameters. See #13648 and #17158.
183 else:
184 return sql
186 def quote_name(self, name):
187 if name.startswith('"') and name.endswith('"'):
188 return name # Quoting once is enough.
189 return f'"{name}"'
    def no_limit_value(self):
        # SQLite uses LIMIT -1 to mean "no limit" (needed when an OFFSET is
        # used without a LIMIT).
        return -1
    def __references_graph(self, table_name):
        # Iteratively build the set of tables reachable from table_name via
        # foreign-key REFERENCES clauses by regex-matching each table's
        # CREATE statement in sqlite_master. Relies on the REGEXP function
        # registered on the connection at connect time.
        query = """
        WITH tables AS (
            SELECT %s name
            UNION
            SELECT sqlite_master.name
            FROM sqlite_master
            JOIN tables ON (sql REGEXP %s || tables.name || %s)
        ) SELECT name FROM tables;
        """
        params = (
            table_name,
            # Case-insensitive match of `REFERENCES <name> (` with optional
            # single or double quoting around the table name.
            r'(?i)\s+references\s+("|\')?',
            r'("|\')?\s*\(',
        )
        with self.connection.cursor() as cursor:
            results = cursor.execute(query, params)
            return [row[0] for row in results.fetchall()]
    @cached_property
    def _references_graph(self):
        # 512 is large enough to fit the ~330 tables (as of this writing) in
        # Plain's test suite.
        # cached_property keeps one lru_cache per instance, so cached results
        # are not shared across different connections.
        return lru_cache(maxsize=512)(self.__references_graph)
219 def sequence_reset_by_name_sql(self, style, sequences):
220 if not sequences:
221 return []
222 return [
223 "{} {} {} {} = 0 {} {} {} ({});".format(
224 style.SQL_KEYWORD("UPDATE"),
225 style.SQL_TABLE(self.quote_name("sqlite_sequence")),
226 style.SQL_KEYWORD("SET"),
227 style.SQL_FIELD(self.quote_name("seq")),
228 style.SQL_KEYWORD("WHERE"),
229 style.SQL_FIELD(self.quote_name("name")),
230 style.SQL_KEYWORD("IN"),
231 ", ".join(
232 [
233 "'{}'".format(sequence_info["table"])
234 for sequence_info in sequences
235 ]
236 ),
237 ),
238 ]
240 def adapt_datetimefield_value(self, value):
241 if value is None:
242 return None
244 # Expression values are adapted by the database.
245 if hasattr(value, "resolve_expression"):
246 return value
248 # SQLite doesn't support tz-aware datetimes
249 if timezone.is_aware(value):
250 value = timezone.make_naive(value, self.connection.timezone)
252 return str(value)
254 def adapt_timefield_value(self, value):
255 if value is None:
256 return None
258 # Expression values are adapted by the database.
259 if hasattr(value, "resolve_expression"):
260 return value
262 # SQLite doesn't support tz-aware datetimes
263 if timezone.is_aware(value):
264 raise ValueError("SQLite backend does not support timezone-aware times.")
266 return str(value)
268 def get_db_converters(self, expression):
269 converters = super().get_db_converters(expression)
270 internal_type = expression.output_field.get_internal_type()
271 if internal_type == "DateTimeField":
272 converters.append(self.convert_datetimefield_value)
273 elif internal_type == "DateField":
274 converters.append(self.convert_datefield_value)
275 elif internal_type == "TimeField":
276 converters.append(self.convert_timefield_value)
277 elif internal_type == "DecimalField":
278 converters.append(self.get_decimalfield_converter(expression))
279 elif internal_type == "UUIDField":
280 converters.append(self.convert_uuidfield_value)
281 elif internal_type == "BooleanField":
282 converters.append(self.convert_booleanfield_value)
283 return converters
285 def convert_datetimefield_value(self, value, expression, connection):
286 if value is not None:
287 if not isinstance(value, datetime.datetime):
288 value = parse_datetime(value)
289 if not timezone.is_aware(value):
290 value = timezone.make_aware(value, self.connection.timezone)
291 return value
293 def convert_datefield_value(self, value, expression, connection):
294 if value is not None:
295 if not isinstance(value, datetime.date):
296 value = parse_date(value)
297 return value
299 def convert_timefield_value(self, value, expression, connection):
300 if value is not None:
301 if not isinstance(value, datetime.time):
302 value = parse_time(value)
303 return value
305 def get_decimalfield_converter(self, expression):
306 # SQLite stores only 15 significant digits. Digits coming from
307 # float inaccuracy must be removed.
308 create_decimal = decimal.Context(prec=15).create_decimal_from_float
309 if isinstance(expression, Col):
310 quantize_value = decimal.Decimal(1).scaleb(
311 -expression.output_field.decimal_places
312 )
314 def converter(value, expression, connection):
315 if value is not None:
316 return create_decimal(value).quantize(
317 quantize_value, context=expression.output_field.context
318 )
320 else:
322 def converter(value, expression, connection):
323 if value is not None:
324 return create_decimal(value)
326 return converter
328 def convert_uuidfield_value(self, value, expression, connection):
329 if value is not None:
330 value = uuid.UUID(value)
331 return value
333 def convert_booleanfield_value(self, value, expression, connection):
334 return bool(value) if value in (1, 0) else value
336 def bulk_insert_sql(self, fields, placeholder_rows):
337 placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
338 values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
339 return f"VALUES {values_sql}"
341 def combine_expression(self, connector, sub_expressions):
342 # SQLite doesn't have a ^ operator, so use the user-defined POWER
343 # function that's registered in connect().
344 if connector == "^":
345 return "POWER({})".format(",".join(sub_expressions))
346 elif connector == "#":
347 return "BITXOR({})".format(",".join(sub_expressions))
348 return super().combine_expression(connector, sub_expressions)
350 def combine_duration_expression(self, connector, sub_expressions):
351 if connector not in ["+", "-", "*", "/"]:
352 raise DatabaseError(f"Invalid connector for timedelta: {connector}.")
353 fn_params = [f"'{connector}'"] + sub_expressions
354 if len(fn_params) > 3:
355 raise ValueError("Too many params for timedelta operations.")
356 return "plain_format_dtdelta({})".format(", ".join(fn_params))
358 def integer_field_range(self, internal_type):
359 # SQLite doesn't enforce any integer constraints, but sqlite3 supports
360 # integers up to 64 bits.
361 if internal_type in [
362 "PositiveBigIntegerField",
363 "PositiveIntegerField",
364 "PositiveSmallIntegerField",
365 ]:
366 return (0, 9223372036854775807)
367 return (-9223372036854775808, 9223372036854775807)
369 def subtract_temporals(self, internal_type, lhs, rhs):
370 lhs_sql, lhs_params = lhs
371 rhs_sql, rhs_params = rhs
372 params = (*lhs_params, *rhs_params)
373 if internal_type == "TimeField":
374 return f"plain_time_diff({lhs_sql}, {rhs_sql})", params
375 return f"plain_timestamp_diff({lhs_sql}, {rhs_sql})", params
377 def insert_statement(self, on_conflict=None):
378 if on_conflict == OnConflict.IGNORE:
379 return "INSERT OR IGNORE INTO"
380 return super().insert_statement(on_conflict=on_conflict)
382 def return_insert_columns(self, fields):
383 # SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
384 if not fields:
385 return "", ()
386 columns = [
387 f"{self.quote_name(field.model._meta.db_table)}.{self.quote_name(field.column)}"
388 for field in fields
389 ]
390 return "RETURNING {}".format(", ".join(columns)), ()
392 def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
393 if (
394 on_conflict == OnConflict.UPDATE
395 and self.connection.features.supports_update_conflicts_with_target
396 ):
397 return "ON CONFLICT({}) DO UPDATE SET {}".format(
398 ", ".join(map(self.quote_name, unique_fields)),
399 ", ".join(
400 [
401 f"{field} = EXCLUDED.{field}"
402 for field in map(self.quote_name, update_fields)
403 ]
404 ),
405 )
406 return super().on_conflict_suffix_sql(
407 fields,
408 on_conflict,
409 update_fields,
410 unique_fields,
411 )