Coverage for /Users/davegaeddert/Development/dropseed/plain/plain-models/plain/models/backends/sqlite3/introspection.py: 14%

196 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-16 22:03 -0500

1from collections import namedtuple 

2 

3import sqlparse 

4 

5from plain.models import Index 

6from plain.models.backends.base.introspection import ( 

7 BaseDatabaseIntrospection, 

8 TableInfo, 

9) 

10from plain.models.backends.base.introspection import FieldInfo as BaseFieldInfo 

11from plain.models.db import DatabaseError 

12from plain.utils.regex_helper import _lazy_re_compile 

13 

# SQLite column description: every generic FieldInfo field, followed by two
# backend-specific flags -- "pk" (column is the primary key) and
# "has_json_constraint" (the table SQL has a json_valid() check on it).
FieldInfo = namedtuple(
    "FieldInfo",
    (*BaseFieldInfo._fields, "pk", "has_json_constraint"),
)

# Recognizes a sized "char(N)" or "varchar(N)" declared type, tolerating
# surrounding whitespace, and captures N in group 1.
field_size_re = _lazy_re_compile(r"^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$")

19 

20 

def get_field_size(name):
    """
    Extract the size number from a "varchar(11)" / "char(11)" type name.

    SQLite stores a column's declared type exactly as it was written, so the
    name may be upper- or mixed-case (e.g. "VARCHAR(11)"). The name is
    lowercased before matching so the size is found regardless of case, in
    line with the type lookup, which also lowercases type names.

    Return the size as an int, or None when *name* is not a sized
    char/varchar type.
    """
    # field_size_re is compiled without IGNORECASE, hence the lower().
    match = field_size_re.search(name.lower())
    return int(match[1]) if match else None

25 

26 

class FlexibleFieldLookupDict:
    """
    Read-only, dict-like mapping from a SQLite declared column type to the
    name of a Plain field class.

    A plain dict is not enough here. SQLite keeps declared types verbatim,
    including parameters such as "varchar(30)". Lookups therefore go through
    ``__getitem__``, which reduces the key to its lowercase base type first.
    """

    # Declared SQL type -> Plain field class name. SQLite accepts any type
    # name and never normalizes it, so several spellings of one type are
    # listed.
    base_data_types_reverse = {
        "bool": "BooleanField",
        "boolean": "BooleanField",
        "smallint": "SmallIntegerField",
        "smallint unsigned": "PositiveSmallIntegerField",
        "smallinteger": "SmallIntegerField",
        "int": "IntegerField",
        "integer": "IntegerField",
        "bigint": "BigIntegerField",
        "integer unsigned": "PositiveIntegerField",
        "bigint unsigned": "PositiveBigIntegerField",
        "decimal": "DecimalField",
        "real": "FloatField",
        "text": "TextField",
        "char": "CharField",
        "varchar": "CharField",
        "blob": "BinaryField",
        "date": "DateField",
        "datetime": "DateTimeField",
        "time": "TimeField",
    }

    def __getitem__(self, key):
        """
        Return the field class name for the declared type *key*.

        Case, surrounding whitespace and any "(...)" parameter suffix are
        ignored. Raise KeyError for an unknown base type.
        """
        base_type, _paren, _params = key.lower().partition("(")
        return self.base_data_types_reverse[base_type.strip()]

59 

60 

class DatabaseIntrospection(BaseDatabaseIntrospection):
    """
    Schema introspection for the SQLite backend.

    Metadata comes from two sources:

    - ``sqlite_master`` and the ``table_info`` / ``foreign_key_list`` /
      ``index_list`` / ``index_info`` pragmas.
    - The stored ``CREATE TABLE`` SQL, parsed with sqlparse, for inline
      UNIQUE/CHECK constraints and column collations.
    """

    data_types_reverse = FlexibleFieldLookupDict()

    def get_field_type(self, data_type, description):
        """
        Return the field class name for a column.

        The base class mapping is applied first. The result is then
        overridden in two cases:

        - A primary-key column mapped to an integer field becomes
          "AutoField".
        - A column with a json_valid() constraint becomes "JSONField".
        """
        field_type = super().get_field_type(data_type, description)
        if description.pk and field_type in {
            "BigIntegerField",
            "IntegerField",
            "SmallIntegerField",
        }:
            # No support for BigAutoField or SmallAutoField as SQLite treats
            # all integer primary keys as signed 64-bit integers.
            return "AutoField"
        if description.has_json_constraint:
            return "JSONField"
        return field_type

    def get_table_list(self, cursor):
        """Return a list of table and view names in the current database."""
        # Skip the sqlite_sequence system table used for autoincrement key
        # generation.
        cursor.execute(
            """
            SELECT name, type FROM sqlite_master
            WHERE type in ('table', 'view') AND NOT name='sqlite_sequence'
            ORDER BY name"""
        )
        # row[1][0] abbreviates the type to "t" (table) or "v" (view).
        return [TableInfo(row[0], row[1][0]) for row in cursor.fetchall()]

    def get_table_description(self, cursor, table_name):
        """
        Return a description of the table with the DB-API cursor.description
        interface.

        Each entry is a FieldInfo built from one ``PRAGMA table_info`` row:

        - size: known only for char/varchar types; other size fields are None.
        - null_ok: the inverse of the NOT NULL flag.
        - default and collation: as declared.
        - pk: ``pk == 1``, i.e. only the first primary-key column.
        - has_json_constraint: the table SQL has a json_valid() check on the
          column.

        Raise DatabaseError when the pragma returns no rows, i.e. the table
        does not exist.
        """
        cursor.execute(
            "PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name)
        )
        table_info = cursor.fetchall()
        if not table_info:
            raise DatabaseError(f"Table {table_name} does not exist (empty pragma).")
        collations = self._get_column_collations(cursor, table_name)
        json_columns = set()
        if self.connection.features.can_introspect_json_field:
            for line in table_info:
                column = line[1]
                # LIKE pattern '%json_valid("<column>")%'; the doubled %% in
                # the format string yield literal % wildcards.
                json_constraint_sql = '%%json_valid("%s")%%' % column
                has_json_constraint = cursor.execute(
                    """
                    SELECT sql
                    FROM sqlite_master
                    WHERE
                        type = 'table' AND
                        name = %s AND
                        sql LIKE %s
                """,
                    [table_name, json_constraint_sql],
                ).fetchone()
                if has_json_constraint:
                    json_columns.add(column)
        return [
            FieldInfo(
                name,
                data_type,
                get_field_size(data_type),
                None,
                None,
                None,
                not notnull,
                default,
                collations.get(name),
                pk == 1,
                name in json_columns,
            )
            for cid, name, data_type, notnull, default, pk in table_info
        ]

    def get_sequences(self, cursor, table_name, table_fields=()):
        """
        Return a single sequence description for the table's primary key.

        The primary-key column comes from get_primary_key_column(), which is
        inherited from the base class.
        """
        pk_col = self.get_primary_key_column(cursor, table_name)
        return [{"table": table_name, "column": pk_col}]

    def get_relations(self, cursor, table_name):
        """
        Return a dictionary of {column_name: (ref_column_name, ref_table_name)}
        representing all foreign keys in the given table.
        """
        cursor.execute(
            "PRAGMA foreign_key_list(%s)" % self.connection.ops.quote_name(table_name)
        )
        # Row layout: id, seq, table, from, to, on_update, on_delete, match.
        return {
            column_name: (ref_column_name, ref_table_name)
            for (
                _,
                _,
                ref_table_name,
                column_name,
                ref_column_name,
                *_,
            ) in cursor.fetchall()
        }

    def get_primary_key_columns(self, cursor, table_name):
        """
        Return the names of the table's primary-key columns, in key order.

        ``PRAGMA table_info`` lists columns in table order. Its last field is
        the column's 1-based position within the primary key, or 0 for
        non-key columns. Sorting on that position yields ``[b, a]`` for
        ``PRIMARY KEY (b, a)`` even when column ``a`` is declared first.
        """
        cursor.execute(
            "PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name)
        )
        key_columns = [(pk, name) for _, name, *_, pk in cursor.fetchall() if pk]
        return [name for _, name in sorted(key_columns)]

    def _parse_column_or_constraint_definition(self, tokens, columns):
        """
        Consume one column or table-constraint definition from *tokens*.

        *tokens* is a shared iterator of non-whitespace sqlparse tokens,
        positioned just inside the column list. Parsing stops at a top-level
        "," (end of this definition) or at the ")" that closes the column
        list.

        Return ``(constraint_name, unique_constraint, check_constraint,
        last_token)``:

        - Each constraint dict is None when none was found.
        - CHECK columns are limited to names present in *columns*.
        - *last_token* is the token that ended parsing, or None if *tokens*
          was already exhausted.
        """
        token = None
        is_constraint_definition = None
        field_name = None
        constraint_name = None
        unique = False
        unique_columns = []
        check = False
        check_columns = []
        braces_deep = 0
        for token in tokens:
            if token.match(sqlparse.tokens.Punctuation, "("):
                braces_deep += 1
            elif token.match(sqlparse.tokens.Punctuation, ")"):
                braces_deep -= 1
                if braces_deep < 0:
                    # End of columns and constraints for table definition.
                    break
            elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ","):
                # End of current column or constraint definition.
                break
            # Detect column or constraint definition by first token.
            if is_constraint_definition is None:
                is_constraint_definition = token.match(
                    sqlparse.tokens.Keyword, "CONSTRAINT"
                )
                if is_constraint_definition:
                    continue
            if is_constraint_definition:
                # Detect constraint name by second token.
                if constraint_name is None:
                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                        constraint_name = token.value
                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                        constraint_name = token.value[1:-1]
                # Start constraint columns parsing after UNIQUE keyword.
                if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
                    unique = True
                    # Only read while `unique` is True, which is set together.
                    unique_braces_deep = braces_deep
                elif unique:
                    if unique_braces_deep == braces_deep:
                        if unique_columns:
                            # Stop constraint parsing.
                            unique = False
                        continue
                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                        unique_columns.append(token.value)
                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                        unique_columns.append(token.value[1:-1])
            else:
                # Detect field name by first token.
                if field_name is None:
                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                        field_name = token.value
                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                        field_name = token.value[1:-1]
                # A column-level UNIQUE applies to this column alone.
                if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
                    unique_columns = [field_name]
            # Start constraint columns parsing after CHECK keyword (applies to
            # both column and table-constraint definitions).
            if token.match(sqlparse.tokens.Keyword, "CHECK"):
                check = True
                # Only read while `check` is True, which is set together.
                check_braces_deep = braces_deep
            elif check:
                if check_braces_deep == braces_deep:
                    if check_columns:
                        # Stop constraint parsing.
                        check = False
                    continue
                if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                    if token.value in columns:
                        check_columns.append(token.value)
                elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                    if token.value[1:-1] in columns:
                        check_columns.append(token.value[1:-1])
        unique_constraint = (
            {
                "unique": True,
                "columns": unique_columns,
                "primary_key": False,
                "foreign_key": None,
                "check": False,
                "index": False,
            }
            if unique_columns
            else None
        )
        check_constraint = (
            {
                "check": True,
                "columns": check_columns,
                "primary_key": False,
                "unique": False,
                "foreign_key": None,
                "index": False,
            }
            if check_columns
            else None
        )
        return constraint_name, unique_constraint, check_constraint, token

    def _parse_table_constraints(self, sql, columns):
        """
        Parse a CREATE TABLE statement for inline UNIQUE and CHECK constraints.

        Return a dict keyed by constraint name. Unnamed constraints are keyed
        "__unnamed_constraint_<n>__", with n counting up from 1.
        """
        # Check constraint parsing is based of SQLite syntax diagram.
        # https://www.sqlite.org/syntaxdiagrams.html#table-constraint
        statement = sqlparse.parse(sql)[0]
        constraints = {}
        unnamed_constraints_index = 0
        tokens = (token for token in statement.flatten() if not token.is_whitespace)
        # Go to columns and constraint definition
        for token in tokens:
            if token.match(sqlparse.tokens.Punctuation, "("):
                break
        # Parse columns and constraint definition
        while True:
            (
                constraint_name,
                unique,
                check,
                end_token,
            ) = self._parse_column_or_constraint_definition(tokens, columns)
            if unique:
                if constraint_name:
                    constraints[constraint_name] = unique
                else:
                    unnamed_constraints_index += 1
                    constraints[
                        "__unnamed_constraint_%s__" % unnamed_constraints_index
                    ] = unique
            if check:
                if constraint_name:
                    constraints[constraint_name] = check
                else:
                    unnamed_constraints_index += 1
                    constraints[
                        "__unnamed_constraint_%s__" % unnamed_constraints_index
                    ] = check
            # A None end token means the token stream ran out before the
            # closing ")" (e.g. truncated SQL); stop instead of crashing.
            if end_token is None or end_token.match(sqlparse.tokens.Punctuation, ")"):
                break
        return constraints

    def get_constraints(self, cursor, table_name):
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index) across
        one or more columns.
        """
        constraints = {}
        # Find inline check constraints. Views have no type='table' row, so
        # there is no CREATE TABLE statement to parse for them.
        # The name is bound as a parameter. A double-quoted name would be an
        # identifier in SQLite, so a table called e.g. "sql" or "name" would
        # be compared with that sqlite_master column instead of the string.
        row = cursor.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' and name=%s",
            [table_name],
        ).fetchone()
        if row is not None:
            table_schema = row[0]
            columns = {
                info.name for info in self.get_table_description(cursor, table_name)
            }
            constraints.update(self._parse_table_constraints(table_schema, columns))

        # Get the index info
        cursor.execute(
            "PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name)
        )
        for row in cursor.fetchall():
            # SQLite 3.8.9+ has 5 columns, however older versions only give 3
            # columns. Discard last 2 columns if there.
            _, index, unique = row[:3]
            # Bound as a parameter for the same reason as the table lookup.
            cursor.execute(
                "SELECT sql FROM sqlite_master WHERE type='index' AND name=%s",
                [index],
            )
            # There's at most one row.
            (sql,) = cursor.fetchone() or (None,)
            # Inline constraints are already detected in
            # _parse_table_constraints(). The reasons to avoid fetching inline
            # constraints from `PRAGMA index_list` are:
            # - Inline constraints can have a different name and information
            #   than what `PRAGMA index_list` gives.
            # - Not all inline constraints may appear in `PRAGMA index_list`.
            if not sql:
                # An inline constraint
                continue
            # Get the index info for that index
            cursor.execute(
                "PRAGMA index_info(%s)" % self.connection.ops.quote_name(index)
            )
            # Row layout: seqno, cid, name.
            for _, _, column in cursor.fetchall():
                if index not in constraints:
                    constraints[index] = {
                        "columns": [],
                        "primary_key": False,
                        "unique": bool(unique),
                        "foreign_key": None,
                        "check": False,
                        "index": True,
                    }
                constraints[index]["columns"].append(column)
            # Add type and column orders for indexes (once per index, not once
            # per indexed column).
            if index in constraints and constraints[index]["index"]:
                # SQLite doesn't support any index type other than b-tree
                constraints[index]["type"] = Index.suffix
                orders = self._get_index_columns_orders(sql)
                if orders is not None:
                    constraints[index]["orders"] = orders
        # Get the PK
        pk_columns = self.get_primary_key_columns(cursor, table_name)
        if pk_columns:
            # SQLite doesn't actually give a name to the PK constraint,
            # so we invent one. This is fine, as the SQLite backend never
            # deletes PK constraints by name, as you can't delete constraints
            # in SQLite; we remake the table with a new PK instead.
            constraints["__primary__"] = {
                "columns": pk_columns,
                "primary_key": True,
                "unique": False,  # It's not actually a unique constraint.
                "foreign_key": None,
                "check": False,
                "index": False,
            }
        relations = enumerate(self.get_relations(cursor, table_name).items())
        constraints.update(
            {
                f"fk_{index}": {
                    "columns": [column_name],
                    "primary_key": False,
                    "unique": False,
                    "foreign_key": (ref_table_name, ref_column_name),
                    "check": False,
                    "index": False,
                }
                for index, (column_name, (ref_column_name, ref_table_name)) in relations
            }
        )
        return constraints

    def _get_index_columns_orders(self, sql):
        """
        Return "ASC"/"DESC" for each column of a CREATE INDEX statement.

        The columns are taken from the first parenthesized group and split on
        ", ". Return None if the statement has no parenthesized group.
        """
        tokens = sqlparse.parse(sql)[0]
        for token in tokens:
            if isinstance(token, sqlparse.sql.Parenthesis):
                columns = str(token).strip("()").split(", ")
                return ["DESC" if info.endswith("DESC") else "ASC" for info in columns]
        return None

    def _get_column_collations(self, cursor, table_name):
        """
        Return {column_name: collation or None} parsed from the table's SQL.

        Return {} when sqlite_master has no table row for *table_name*.
        """
        row = cursor.execute(
            """
            SELECT sql
            FROM sqlite_master
            WHERE type = 'table' AND name = %s
        """,
            [table_name],
        ).fetchone()
        if not row:
            return {}

        sql = row[0]
        # NOTE(review): assumes the last top-level token is the column list
        # and that definitions are separated by ", " -- holds for SQL written
        # as `("a" ..., "b" ...)`; confirm for hand-written schemas.
        columns = str(sqlparse.parse(sql)[0][-1]).strip("()").split(", ")
        collations = {}
        for column in columns:
            # column[1:] drops the identifier's opening quote; strip('"')
            # removes the closing one.
            tokens = column[1:].split()
            column_name = tokens[0].strip('"')
            for index, token in enumerate(tokens):
                if token == "COLLATE":
                    collation = tokens[index + 1]
                    break
            else:
                collation = None
            collations[column_name] = collation
        return collations