Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/backends/sqlite3/operations.py: 45%

211 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-23 11:16 -0600

1import datetime 

2import decimal 

3import uuid 

4from functools import lru_cache 

5 

6from plain import models 

7from plain.exceptions import FieldError 

8from plain.models.backends.base.operations import BaseDatabaseOperations 

9from plain.models.constants import OnConflict 

10from plain.models.db import DatabaseError, NotSupportedError 

11from plain.models.expressions import Col 

12from plain.utils import timezone 

13from plain.utils.dateparse import parse_date, parse_datetime, parse_time 

14from plain.utils.functional import cached_property 

15 

16 

17class DatabaseOperations(BaseDatabaseOperations): 

18 cast_char_field_without_max_length = "text" 

19 cast_data_types = { 

20 "DateField": "TEXT", 

21 "DateTimeField": "TEXT", 

22 } 

23 explain_prefix = "EXPLAIN QUERY PLAN" 

24 # List of datatypes to that cannot be extracted with JSON_EXTRACT() on 

25 # SQLite. Use JSON_TYPE() instead. 

26 jsonfield_datatype_values = frozenset(["null", "false", "true"]) 

27 

28 def bulk_batch_size(self, fields, objs): 

29 """ 

30 SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of 

31 999 variables per query. 

32 

33 If there's only a single field to insert, the limit is 500 

34 (SQLITE_MAX_COMPOUND_SELECT). 

35 """ 

36 if len(fields) == 1: 

37 return 500 

38 elif len(fields) > 1: 

39 return self.connection.features.max_query_params // len(fields) 

40 else: 

41 return len(objs) 

42 

43 def check_expression_support(self, expression): 

44 bad_fields = (models.DateField, models.DateTimeField, models.TimeField) 

45 bad_aggregates = (models.Sum, models.Avg, models.Variance, models.StdDev) 

46 if isinstance(expression, bad_aggregates): 

47 for expr in expression.get_source_expressions(): 

48 try: 

49 output_field = expr.output_field 

50 except (AttributeError, FieldError): 

51 # Not every subexpression has an output_field which is fine 

52 # to ignore. 

53 pass 

54 else: 

55 if isinstance(output_field, bad_fields): 

56 raise NotSupportedError( 

57 "You cannot use Sum, Avg, StdDev, and Variance " 

58 "aggregations on date/time fields in sqlite3 " 

59 "since date/time is saved as text." 

60 ) 

61 if ( 

62 isinstance(expression, models.Aggregate) 

63 and expression.distinct 

64 and len(expression.source_expressions) > 1 

65 ): 

66 raise NotSupportedError( 

67 "SQLite doesn't support DISTINCT on aggregate functions " 

68 "accepting multiple arguments." 

69 ) 

70 

71 def date_extract_sql(self, lookup_type, sql, params): 

72 """ 

73 Support EXTRACT with a user-defined function plain_date_extract() 

74 that's registered in connect(). Use single quotes because this is a 

75 string and could otherwise cause a collision with a field name. 

76 """ 

77 return f"plain_date_extract(%s, {sql})", (lookup_type.lower(), *params) 

78 

79 def fetch_returned_insert_rows(self, cursor): 

80 """ 

81 Given a cursor object that has just performed an INSERT...RETURNING 

82 statement into a table, return the list of returned data. 

83 """ 

84 return cursor.fetchall() 

85 

86 def format_for_duration_arithmetic(self, sql): 

87 """Do nothing since formatting is handled in the custom function.""" 

88 return sql 

89 

90 def date_trunc_sql(self, lookup_type, sql, params, tzname=None): 

91 return f"plain_date_trunc(%s, {sql}, %s, %s)", ( 

92 lookup_type.lower(), 

93 *params, 

94 *self._convert_tznames_to_sql(tzname), 

95 ) 

96 

97 def time_trunc_sql(self, lookup_type, sql, params, tzname=None): 

98 return f"plain_time_trunc(%s, {sql}, %s, %s)", ( 

99 lookup_type.lower(), 

100 *params, 

101 *self._convert_tznames_to_sql(tzname), 

102 ) 

103 

104 def _convert_tznames_to_sql(self, tzname): 

105 if tzname: 

106 return tzname, self.connection.timezone_name 

107 return None, None 

108 

109 def datetime_cast_date_sql(self, sql, params, tzname): 

110 return f"plain_datetime_cast_date({sql}, %s, %s)", ( 

111 *params, 

112 *self._convert_tznames_to_sql(tzname), 

113 ) 

114 

115 def datetime_cast_time_sql(self, sql, params, tzname): 

116 return f"plain_datetime_cast_time({sql}, %s, %s)", ( 

117 *params, 

118 *self._convert_tznames_to_sql(tzname), 

119 ) 

120 

121 def datetime_extract_sql(self, lookup_type, sql, params, tzname): 

122 return f"plain_datetime_extract(%s, {sql}, %s, %s)", ( 

123 lookup_type.lower(), 

124 *params, 

125 *self._convert_tznames_to_sql(tzname), 

126 ) 

127 

128 def datetime_trunc_sql(self, lookup_type, sql, params, tzname): 

129 return f"plain_datetime_trunc(%s, {sql}, %s, %s)", ( 

130 lookup_type.lower(), 

131 *params, 

132 *self._convert_tznames_to_sql(tzname), 

133 ) 

134 

135 def time_extract_sql(self, lookup_type, sql, params): 

136 return f"plain_time_extract(%s, {sql})", (lookup_type.lower(), *params) 

137 

138 def pk_default_value(self): 

139 return "NULL" 

140 

141 def _quote_params_for_last_executed_query(self, params): 

142 """ 

143 Only for last_executed_query! Don't use this to execute SQL queries! 

144 """ 

145 # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the 

146 # number of parameters, default = 999) and SQLITE_MAX_COLUMN (the 

147 # number of return values, default = 2000). Since Python's sqlite3 

148 # module doesn't expose the get_limit() C API, assume the default 

149 # limits are in effect and split the work in batches if needed. 

150 BATCH_SIZE = 999 

151 if len(params) > BATCH_SIZE: 

152 results = () 

153 for index in range(0, len(params), BATCH_SIZE): 

154 chunk = params[index : index + BATCH_SIZE] 

155 results += self._quote_params_for_last_executed_query(chunk) 

156 return results 

157 

158 sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params)) 

159 # Bypass Plain's wrappers and use the underlying sqlite3 connection 

160 # to avoid logging this query - it would trigger infinite recursion. 

161 cursor = self.connection.connection.cursor() 

162 # Native sqlite3 cursors cannot be used as context managers. 

163 try: 

164 return cursor.execute(sql, params).fetchone() 

165 finally: 

166 cursor.close() 

167 

168 def last_executed_query(self, cursor, sql, params): 

169 # Python substitutes parameters in Modules/_sqlite/cursor.c with: 

170 # bind_parameters(state, self->statement, parameters); 

171 # Unfortunately there is no way to reach self->statement from Python, 

172 # so we quote and substitute parameters manually. 

173 if params: 

174 if isinstance(params, list | tuple): 

175 params = self._quote_params_for_last_executed_query(params) 

176 else: 

177 values = tuple(params.values()) 

178 values = self._quote_params_for_last_executed_query(values) 

179 params = dict(zip(params, values)) 

180 return sql % params 

181 # For consistency with SQLiteCursorWrapper.execute(), just return sql 

182 # when there are no parameters. See #13648 and #17158. 

183 else: 

184 return sql 

185 

186 def quote_name(self, name): 

187 if name.startswith('"') and name.endswith('"'): 

188 return name # Quoting once is enough. 

189 return f'"{name}"' 

190 

191 def no_limit_value(self): 

192 return -1 

193 

194 def __references_graph(self, table_name): 

195 query = """ 

196 WITH tables AS ( 

197 SELECT %s name 

198 UNION 

199 SELECT sqlite_master.name 

200 FROM sqlite_master 

201 JOIN tables ON (sql REGEXP %s || tables.name || %s) 

202 ) SELECT name FROM tables; 

203 """ 

204 params = ( 

205 table_name, 

206 r'(?i)\s+references\s+("|\')?', 

207 r'("|\')?\s*\(', 

208 ) 

209 with self.connection.cursor() as cursor: 

210 results = cursor.execute(query, params) 

211 return [row[0] for row in results.fetchall()] 

212 

213 @cached_property 

214 def _references_graph(self): 

215 # 512 is large enough to fit the ~330 tables (as of this writing) in 

216 # Plain's test suite. 

217 return lru_cache(maxsize=512)(self.__references_graph) 

218 

219 def sequence_reset_by_name_sql(self, style, sequences): 

220 if not sequences: 

221 return [] 

222 return [ 

223 "{} {} {} {} = 0 {} {} {} ({});".format( 

224 style.SQL_KEYWORD("UPDATE"), 

225 style.SQL_TABLE(self.quote_name("sqlite_sequence")), 

226 style.SQL_KEYWORD("SET"), 

227 style.SQL_FIELD(self.quote_name("seq")), 

228 style.SQL_KEYWORD("WHERE"), 

229 style.SQL_FIELD(self.quote_name("name")), 

230 style.SQL_KEYWORD("IN"), 

231 ", ".join( 

232 [ 

233 "'{}'".format(sequence_info["table"]) 

234 for sequence_info in sequences 

235 ] 

236 ), 

237 ), 

238 ] 

239 

240 def adapt_datetimefield_value(self, value): 

241 if value is None: 

242 return None 

243 

244 # Expression values are adapted by the database. 

245 if hasattr(value, "resolve_expression"): 

246 return value 

247 

248 # SQLite doesn't support tz-aware datetimes 

249 if timezone.is_aware(value): 

250 value = timezone.make_naive(value, self.connection.timezone) 

251 

252 return str(value) 

253 

254 def adapt_timefield_value(self, value): 

255 if value is None: 

256 return None 

257 

258 # Expression values are adapted by the database. 

259 if hasattr(value, "resolve_expression"): 

260 return value 

261 

262 # SQLite doesn't support tz-aware datetimes 

263 if timezone.is_aware(value): 

264 raise ValueError("SQLite backend does not support timezone-aware times.") 

265 

266 return str(value) 

267 

268 def get_db_converters(self, expression): 

269 converters = super().get_db_converters(expression) 

270 internal_type = expression.output_field.get_internal_type() 

271 if internal_type == "DateTimeField": 

272 converters.append(self.convert_datetimefield_value) 

273 elif internal_type == "DateField": 

274 converters.append(self.convert_datefield_value) 

275 elif internal_type == "TimeField": 

276 converters.append(self.convert_timefield_value) 

277 elif internal_type == "DecimalField": 

278 converters.append(self.get_decimalfield_converter(expression)) 

279 elif internal_type == "UUIDField": 

280 converters.append(self.convert_uuidfield_value) 

281 elif internal_type == "BooleanField": 

282 converters.append(self.convert_booleanfield_value) 

283 return converters 

284 

285 def convert_datetimefield_value(self, value, expression, connection): 

286 if value is not None: 

287 if not isinstance(value, datetime.datetime): 

288 value = parse_datetime(value) 

289 if not timezone.is_aware(value): 

290 value = timezone.make_aware(value, self.connection.timezone) 

291 return value 

292 

293 def convert_datefield_value(self, value, expression, connection): 

294 if value is not None: 

295 if not isinstance(value, datetime.date): 

296 value = parse_date(value) 

297 return value 

298 

299 def convert_timefield_value(self, value, expression, connection): 

300 if value is not None: 

301 if not isinstance(value, datetime.time): 

302 value = parse_time(value) 

303 return value 

304 

305 def get_decimalfield_converter(self, expression): 

306 # SQLite stores only 15 significant digits. Digits coming from 

307 # float inaccuracy must be removed. 

308 create_decimal = decimal.Context(prec=15).create_decimal_from_float 

309 if isinstance(expression, Col): 

310 quantize_value = decimal.Decimal(1).scaleb( 

311 -expression.output_field.decimal_places 

312 ) 

313 

314 def converter(value, expression, connection): 

315 if value is not None: 

316 return create_decimal(value).quantize( 

317 quantize_value, context=expression.output_field.context 

318 ) 

319 

320 else: 

321 

322 def converter(value, expression, connection): 

323 if value is not None: 

324 return create_decimal(value) 

325 

326 return converter 

327 

328 def convert_uuidfield_value(self, value, expression, connection): 

329 if value is not None: 

330 value = uuid.UUID(value) 

331 return value 

332 

333 def convert_booleanfield_value(self, value, expression, connection): 

334 return bool(value) if value in (1, 0) else value 

335 

336 def bulk_insert_sql(self, fields, placeholder_rows): 

337 placeholder_rows_sql = (", ".join(row) for row in placeholder_rows) 

338 values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql) 

339 return f"VALUES {values_sql}" 

340 

341 def combine_expression(self, connector, sub_expressions): 

342 # SQLite doesn't have a ^ operator, so use the user-defined POWER 

343 # function that's registered in connect(). 

344 if connector == "^": 

345 return "POWER({})".format(",".join(sub_expressions)) 

346 elif connector == "#": 

347 return "BITXOR({})".format(",".join(sub_expressions)) 

348 return super().combine_expression(connector, sub_expressions) 

349 

350 def combine_duration_expression(self, connector, sub_expressions): 

351 if connector not in ["+", "-", "*", "/"]: 

352 raise DatabaseError(f"Invalid connector for timedelta: {connector}.") 

353 fn_params = [f"'{connector}'"] + sub_expressions 

354 if len(fn_params) > 3: 

355 raise ValueError("Too many params for timedelta operations.") 

356 return "plain_format_dtdelta({})".format(", ".join(fn_params)) 

357 

358 def integer_field_range(self, internal_type): 

359 # SQLite doesn't enforce any integer constraints, but sqlite3 supports 

360 # integers up to 64 bits. 

361 if internal_type in [ 

362 "PositiveBigIntegerField", 

363 "PositiveIntegerField", 

364 "PositiveSmallIntegerField", 

365 ]: 

366 return (0, 9223372036854775807) 

367 return (-9223372036854775808, 9223372036854775807) 

368 

369 def subtract_temporals(self, internal_type, lhs, rhs): 

370 lhs_sql, lhs_params = lhs 

371 rhs_sql, rhs_params = rhs 

372 params = (*lhs_params, *rhs_params) 

373 if internal_type == "TimeField": 

374 return f"plain_time_diff({lhs_sql}, {rhs_sql})", params 

375 return f"plain_timestamp_diff({lhs_sql}, {rhs_sql})", params 

376 

377 def insert_statement(self, on_conflict=None): 

378 if on_conflict == OnConflict.IGNORE: 

379 return "INSERT OR IGNORE INTO" 

380 return super().insert_statement(on_conflict=on_conflict) 

381 

382 def return_insert_columns(self, fields): 

383 # SQLite < 3.35 doesn't support an INSERT...RETURNING statement. 

384 if not fields: 

385 return "", () 

386 columns = [ 

387 f"{self.quote_name(field.model._meta.db_table)}.{self.quote_name(field.column)}" 

388 for field in fields 

389 ] 

390 return "RETURNING {}".format(", ".join(columns)), () 

391 

392 def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): 

393 if ( 

394 on_conflict == OnConflict.UPDATE 

395 and self.connection.features.supports_update_conflicts_with_target 

396 ): 

397 return "ON CONFLICT({}) DO UPDATE SET {}".format( 

398 ", ".join(map(self.quote_name, unique_fields)), 

399 ", ".join( 

400 [ 

401 f"{field} = EXCLUDED.{field}" 

402 for field in map(self.quote_name, update_fields) 

403 ] 

404 ), 

405 ) 

406 return super().on_conflict_suffix_sql( 

407 fields, 

408 on_conflict, 

409 update_fields, 

410 unique_fields, 

411 )