Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/sql/compiler.py: 43%

1031 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-23 11:16 -0600

1import collections 

2import json 

3import re 

4from functools import partial 

5from itertools import chain 

6 

7from plain.exceptions import EmptyResultSet, FieldError, FullResultSet 

8from plain.models.constants import LOOKUP_SEP 

9from plain.models.db import DatabaseError, NotSupportedError 

10from plain.models.expressions import F, OrderBy, RawSQL, Ref, Value 

11from plain.models.functions import Cast, Random 

12from plain.models.lookups import Lookup 

13from plain.models.query_utils import select_related_descend 

14from plain.models.sql.constants import ( 

15 CURSOR, 

16 GET_ITERATOR_CHUNK_SIZE, 

17 MULTI, 

18 NO_RESULTS, 

19 ORDER_DIR, 

20 SINGLE, 

21) 

22from plain.models.sql.query import Query, get_order_dir 

23from plain.models.sql.where import AND 

24from plain.models.transaction import TransactionManagementError 

25from plain.utils.functional import cached_property 

26from plain.utils.hashable import make_hashable 

27from plain.utils.regex_helper import _lazy_re_compile 

28 

29 

class PositionRef(Ref):
    """A Ref that renders as its 1-based ordinal position in the SELECT list."""

    def __init__(self, ordinal, refs, source):
        # Remember the position before delegating the Ref bookkeeping.
        self.ordinal = ordinal
        super().__init__(refs, source)

    def as_sql(self, compiler, connection):
        # Ordering/grouping by position: emit the bare ordinal with no params.
        return f"{self.ordinal}", ()

37 

38 

class SQLCompiler:
    """Compile a Query into SQL for a specific database connection."""

    # Multiline ordering SQL clause may appear from RawSQL.
    # Captures everything before a trailing ASC/DESC marker so that duplicate
    # ORDER BY entries can be detected regardless of direction (see
    # get_order_by() and get_extra_select()).
    ordering_parts = _lazy_re_compile(
        r"^(.*)\s(?:ASC|DESC).*",
        re.MULTILINE | re.DOTALL,
    )

45 

    def __init__(self, query, connection, using, elide_empty=True):
        """
        Store compiler state for ``query`` on ``connection``.

        ``using`` is the database alias. When ``elide_empty`` is True, a query
        known to produce no rows may be skipped via EmptyResultSet rather than
        executed (see get_combinator_sql, which disables this for sliced
        unions).
        """
        self.query = query
        self.connection = connection
        self.using = using
        # Some queries, e.g. coalesced aggregation, need to be executed even if
        # they would return an empty result set.
        self.elide_empty = elide_empty
        # "*" never needs quoting; seed the cache used by
        # quote_name_unless_alias().
        self.quote_cache = {"*": "*"}
        # The select, klass_info, and annotations are needed by QuerySet.iterator()
        # these are set as a side-effect of executing the query. Note that we calculate
        # separately a list of extra select columns needed for grammatical correctness
        # of the query, but these columns are not included in self.select.
        self.select = None
        self.annotation_col_map = None
        self.klass_info = None
        # Set by _order_by_pairs() when meta ordering is in effect.
        self._meta_ordering = None

62 

63 def __repr__(self): 

64 return ( 

65 f"<{self.__class__.__qualname__} " 

66 f"model={self.query.model.__qualname__} " 

67 f"connection={self.connection!r} using={self.using!r}>" 

68 ) 

69 

70 def setup_query(self, with_col_aliases=False): 

71 if all(self.query.alias_refcount[a] == 0 for a in self.query.alias_map): 

72 self.query.get_initial_alias() 

73 self.select, self.klass_info, self.annotation_col_map = self.get_select( 

74 with_col_aliases=with_col_aliases, 

75 ) 

76 self.col_count = len(self.select) 

77 

    def pre_sql_setup(self, with_col_aliases=False):
        """
        Do any necessary class setup immediately prior to producing SQL. This
        is for things that can't necessarily be done in __init__ because we
        might not have all the pieces in place at that time.

        Return an (extra_select, order_by, group_by) triple and, as a side
        effect, populate self.where, self.having, self.qualify and
        self.has_extra_select.
        """
        self.setup_query(with_col_aliases=with_col_aliases)
        # Ordering must be resolved before splitting the WHERE tree, since it
        # can add select columns that GROUP BY must account for below.
        order_by = self.get_order_by()
        self.where, self.having, self.qualify = self.query.where.split_having_qualify(
            must_group_by=self.query.group_by is not None
        )
        extra_select = self.get_extra_select(order_by, self.select)
        self.has_extra_select = bool(extra_select)
        group_by = self.get_group_by(self.select + extra_select, order_by)
        return extra_select, order_by, group_by

93 

    def get_group_by(self, select, order_by):
        """
        Return a list of 2-tuples of form (sql, params).

        The logic of what exactly the GROUP BY clause contains is hard
        to describe in other words than "if it passes the test suite,
        then it is correct".
        """
        # Some examples:
        #     SomeModel.objects.annotate(Count('somecol'))
        #     GROUP BY: all fields of the model
        #
        #    SomeModel.objects.values('name').annotate(Count('somecol'))
        #    GROUP BY: name
        #
        #    SomeModel.objects.annotate(Count('somecol')).values('name')
        #    GROUP BY: all cols of the model
        #
        #    SomeModel.objects.values('name', 'pk')
        #    .annotate(Count('somecol')).values('pk')
        #    GROUP BY: name, pk
        #
        #    SomeModel.objects.values('name').annotate(Count('somecol')).values('pk')
        #    GROUP BY: name, pk
        #
        # In fact, the self.query.group_by is the minimal set to GROUP BY. It
        # can't be ever restricted to a smaller set, but additional columns in
        # HAVING, ORDER BY, and SELECT clauses are added to it. Unfortunately
        # the end result is that it is impossible to force the query to have
        # a chosen GROUP BY clause - you can almost do this by using the form:
        #     .values(*wanted_cols).annotate(AnAggregate())
        # but any later annotations, extra selects, values calls that
        # refer some column outside of the wanted_cols, order_by, or even
        # filter calls can alter the GROUP BY clause.

        # The query.group_by is either None (no GROUP BY at all), True
        # (group by select fields), or a list of expressions to be added
        # to the group by.
        if self.query.group_by is None:
            return []
        expressions = []
        group_by_refs = set()
        if self.query.group_by is not True:
            # If the group by is set to a list (by .values() call most likely),
            # then we need to add everything in it to the GROUP BY clause.
            # Backwards compatibility hack for setting query.group_by. Remove
            # when we have public API way of forcing the GROUP BY clause.
            # Converts string references to expressions.
            for expr in self.query.group_by:
                if not hasattr(expr, "as_sql"):
                    expr = self.query.resolve_ref(expr)
                if isinstance(expr, Ref):
                    if expr.refs not in group_by_refs:
                        group_by_refs.add(expr.refs)
                        expressions.append(expr.source)
                else:
                    expressions.append(expr)
        # Note that even if the group_by is set, it is only the minimal
        # set to group by. So, we need to add cols in select, order_by, and
        # having into the select in any case.
        selected_expr_positions = {}
        # Positions are 1-based because SQL's GROUP BY <ordinal> is 1-based.
        for ordinal, (expr, _, alias) in enumerate(select, start=1):
            if alias:
                selected_expr_positions[expr] = ordinal
                # Skip members of the select clause that are already explicitly
                # grouped against.
                if alias in group_by_refs:
                    continue
            expressions.extend(expr.get_group_by_cols())
        if not self._meta_ordering:
            for expr, (sql, params, is_ref) in order_by:
                # Skip references to the SELECT clause, as all expressions in
                # the SELECT clause are already part of the GROUP BY.
                if not is_ref:
                    expressions.extend(expr.get_group_by_cols())
        having_group_by = self.having.get_group_by_cols() if self.having else ()
        for expr in having_group_by:
            expressions.append(expr)
        result = []
        seen = set()
        expressions = self.collapse_group_by(expressions, having_group_by)

        allows_group_by_select_index = (
            self.connection.features.allows_group_by_select_index
        )
        for expr in expressions:
            try:
                sql, params = self.compile(expr)
            except (EmptyResultSet, FullResultSet):
                # Expressions that compile to no predicate contribute nothing
                # to the grouping.
                continue
            if (
                allows_group_by_select_index
                and (position := selected_expr_positions.get(expr)) is not None
            ):
                # Reference the select column by its ordinal position instead
                # of repeating the full expression.
                sql, params = str(position), ()
            else:
                sql, params = expr.select_format(self, sql, params)
            params_hash = make_hashable(params)
            if (sql, params_hash) not in seen:
                result.append((sql, params))
                seen.add((sql, params_hash))
        return result

196 

197 def collapse_group_by(self, expressions, having): 

198 # If the database supports group by functional dependence reduction, 

199 # then the expressions can be reduced to the set of selected table 

200 # primary keys as all other columns are functionally dependent on them. 

201 if self.connection.features.allows_group_by_selected_pks: 

202 # Filter out all expressions associated with a table's primary key 

203 # present in the grouped columns. This is done by identifying all 

204 # tables that have their primary key included in the grouped 

205 # columns and removing non-primary key columns referring to them. 

206 # Unmanaged models are excluded because they could be representing 

207 # database views on which the optimization might not be allowed. 

208 pks = { 

209 expr 

210 for expr in expressions 

211 if ( 

212 hasattr(expr, "target") 

213 and expr.target.primary_key 

214 and self.connection.features.allows_group_by_selected_pks_on_model( 

215 expr.target.model 

216 ) 

217 ) 

218 } 

219 aliases = {expr.alias for expr in pks} 

220 expressions = [ 

221 expr 

222 for expr in expressions 

223 if expr in pks 

224 or expr in having 

225 or getattr(expr, "alias", None) not in aliases 

226 ] 

227 return expressions 

228 

    def get_select(self, with_col_aliases=False):
        """
        Return three values:
        - a list of 3-tuples of (expression, (sql, params), alias)
        - a klass_info structure,
        - a dictionary of annotations

        The (sql, params) is what the expression will produce, and alias is the
        "AS alias" for the column (possibly None).

        The klass_info structure contains the following information:
        - The base model of the query.
        - Which columns for that model are present in the query (by
          position of the select clause).
        - related_klass_infos: [f, klass_info] to descent into

        The annotations is a dictionary of {'attname': column position} values.
        """
        select = []
        klass_info = None
        annotations = {}
        select_idx = 0
        # extra(select=...) columns come first.
        for alias, (sql, params) in self.query.extra_select.items():
            annotations[alias] = select_idx
            select.append((RawSQL(sql, params), alias))
            select_idx += 1
        # A query either selects its model's default columns or an explicit
        # list, never both.
        assert not (self.query.select and self.query.default_cols)
        select_mask = self.query.get_select_mask()
        if self.query.default_cols:
            cols = self.get_default_columns(select_mask)
        else:
            # self.query.select is a special case. These columns never go to
            # any model.
            cols = self.query.select
        if cols:
            select_list = []
            for col in cols:
                select_list.append(select_idx)
                select.append((col, None))
                select_idx += 1
            klass_info = {
                "model": self.query.model,
                "select_fields": select_list,
            }
        # Annotations come after the model columns.
        for alias, annotation in self.query.annotation_select.items():
            annotations[alias] = select_idx
            select.append((annotation, alias))
            select_idx += 1

        if self.query.select_related:
            related_klass_infos = self.get_related_selections(select, select_mask)
            klass_info["related_klass_infos"] = related_klass_infos

            def get_select_from_parent(klass_info):
                # Prepend the parent's select_fields to children selected
                # "from parent", recursively.
                for ki in klass_info["related_klass_infos"]:
                    if ki["from_parent"]:
                        ki["select_fields"] = (
                            klass_info["select_fields"] + ki["select_fields"]
                        )
                    get_select_from_parent(ki)

            get_select_from_parent(klass_info)

        ret = []
        col_idx = 1
        for col, alias in select:
            try:
                sql, params = self.compile(col)
            except EmptyResultSet:
                empty_result_set_value = getattr(
                    col, "empty_result_set_value", NotImplemented
                )
                if empty_result_set_value is NotImplemented:
                    # Select a predicate that's always False.
                    sql, params = "0", ()
                else:
                    sql, params = self.compile(Value(empty_result_set_value))
            except FullResultSet:
                sql, params = self.compile(Value(True))
            else:
                sql, params = col.select_format(self, sql, params)
            if alias is None and with_col_aliases:
                # Generate col1, col2, ... aliases for unaliased columns.
                alias = f"col{col_idx}"
                col_idx += 1
            ret.append((col, (sql, params), alias))
        return ret, klass_info, annotations

315 

    def _order_by_pairs(self):
        """
        Yield (OrderBy expression, is_ref) pairs for the query's ordering,
        where is_ref indicates the expression references the SELECT clause.
        """
        # Precedence: extra(order_by=...) > explicit order_by() > meta
        # ordering.
        if self.query.extra_order_by:
            ordering = self.query.extra_order_by
        elif not self.query.default_ordering:
            ordering = self.query.order_by
        elif self.query.order_by:
            ordering = self.query.order_by
        elif (meta := self.query.get_meta()) and meta.ordering:
            ordering = meta.ordering
            self._meta_ordering = ordering
        else:
            ordering = []
        if self.query.standard_ordering:
            default_order, _ = ORDER_DIR["ASC"]
        else:
            default_order, _ = ORDER_DIR["DESC"]

        # Map both aliases and expressions to their 1-based select position so
        # ordering can reference columns by ordinal.
        selected_exprs = {}
        if select := self.select:
            for ordinal, (expr, _, alias) in enumerate(select, start=1):
                pos_expr = PositionRef(ordinal, alias, expr)
                if alias:
                    selected_exprs[alias] = pos_expr
                selected_exprs[expr] = pos_expr

        for field in ordering:
            if hasattr(field, "resolve_expression"):
                if isinstance(field, Value):
                    # output_field must be resolved for constants.
                    field = Cast(field, field.output_field)
                if not isinstance(field, OrderBy):
                    field = field.asc()
                if not self.query.standard_ordering:
                    field = field.copy()
                    field.reverse_ordering()
                select_ref = selected_exprs.get(field.expression)
                if select_ref or (
                    isinstance(field.expression, F)
                    and (select_ref := selected_exprs.get(field.expression.name))
                ):
                    # Emulation of NULLS (FIRST|LAST) cannot be combined with
                    # the usage of ordering by position.
                    if (
                        field.nulls_first is None and field.nulls_last is None
                    ) or self.connection.features.supports_order_by_nulls_modifier:
                        field = field.copy()
                        field.expression = select_ref
                    # Alias collisions are not possible when dealing with
                    # combined queries so fallback to it if emulation of NULLS
                    # handling is required.
                    elif self.query.combinator:
                        field = field.copy()
                        field.expression = Ref(select_ref.refs, select_ref.source)
                yield field, select_ref is not None
                continue
            if field == "?":  # random
                yield OrderBy(Random()), False
                continue

            col, order = get_order_dir(field, default_order)
            descending = order == "DESC"

            if select_ref := selected_exprs.get(col):
                # Reference to expression in SELECT clause
                yield (
                    OrderBy(
                        select_ref,
                        descending=descending,
                    ),
                    True,
                )
                continue
            if col in self.query.annotations:
                # References to an expression which is masked out of the SELECT
                # clause.
                if self.query.combinator and self.select:
                    # Don't use the resolved annotation because other
                    # combinated queries might define it differently.
                    expr = F(col)
                else:
                    expr = self.query.annotations[col]
                if isinstance(expr, Value):
                    # output_field must be resolved for constants.
                    expr = Cast(expr, expr.output_field)
                yield OrderBy(expr, descending=descending), False
                continue

            if "." in field:
                # This came in through an extra(order_by=...) addition. Pass it
                # on verbatim.
                table, col = col.split(".", 1)
                yield (
                    OrderBy(
                        RawSQL(f"{self.quote_name_unless_alias(table)}.{col}", []),
                        descending=descending,
                    ),
                    False,
                )
                continue

            if self.query.extra and col in self.query.extra:
                if col in self.query.extra_select:
                    yield (
                        OrderBy(
                            Ref(col, RawSQL(*self.query.extra[col])),
                            descending=descending,
                        ),
                        True,
                    )
                else:
                    yield (
                        OrderBy(RawSQL(*self.query.extra[col]), descending=descending),
                        False,
                    )
            else:
                if self.query.combinator and self.select:
                    # Don't use the first model's field because other
                    # combinated queries might define it differently.
                    yield OrderBy(F(col), descending=descending), False
                else:
                    # 'col' is of the form 'field' or 'field1__field2' or
                    # '-field1__field2__field', etc.
                    yield from self.find_ordering_name(
                        field,
                        self.query.get_meta(),
                        default_order=default_order,
                    )

    def get_order_by(self):
        """
        Return a list of 2-tuples of the form (expr, (sql, params, is_ref)) for
        the ORDER BY clause.

        The order_by clause can alter the select clause (for example it can add
        aliases to clauses that do not yet have one, or it can add totally new
        select clauses).
        """
        result = []
        seen = set()
        for expr, is_ref in self._order_by_pairs():
            resolved = expr.resolve_expression(self.query, allow_joins=True, reuse=None)
            if not is_ref and self.query.combinator and self.select:
                src = resolved.expression
                expr_src = expr.expression
                for sel_expr, _, col_alias in self.select:
                    if src == sel_expr:
                        # When values() is used the exact alias must be used to
                        # reference annotations.
                        if (
                            self.query.has_select_fields
                            and col_alias in self.query.annotation_select
                            and not (
                                isinstance(expr_src, F) and col_alias == expr_src.name
                            )
                        ):
                            continue
                        resolved.set_source_expressions(
                            [Ref(col_alias if col_alias else src.target.column, src)]
                        )
                        break
                else:
                    # Add column used in ORDER BY clause to the selected
                    # columns and to each combined query.
                    order_by_idx = len(self.query.select) + 1
                    col_alias = f"__orderbycol{order_by_idx}"
                    for q in self.query.combined_queries:
                        # If fields were explicitly selected through values()
                        # combined queries cannot be augmented.
                        if q.has_select_fields:
                            raise DatabaseError(
                                "ORDER BY term does not match any column in "
                                "the result set."
                            )
                        q.add_annotation(expr_src, col_alias)
                    self.query.add_select_col(resolved, col_alias)
                    resolved.set_source_expressions([Ref(col_alias, src)])
            sql, params = self.compile(resolved)
            # Don't add the same column twice, but the order direction is
            # not taken into account so we strip it. When this entire method
            # is refactored into expressions, then we can check each part as we
            # generate it.
            without_ordering = self.ordering_parts.search(sql)[1]
            params_hash = make_hashable(params)
            if (without_ordering, params_hash) in seen:
                continue
            seen.add((without_ordering, params_hash))
            result.append((resolved, (sql, params, is_ref)))
        return result

504 

505 def get_extra_select(self, order_by, select): 

506 extra_select = [] 

507 if self.query.distinct and not self.query.distinct_fields: 

508 select_sql = [t[1] for t in select] 

509 for expr, (sql, params, is_ref) in order_by: 

510 without_ordering = self.ordering_parts.search(sql)[1] 

511 if not is_ref and (without_ordering, params) not in select_sql: 

512 extra_select.append((expr, (without_ordering, params), None)) 

513 return extra_select 

514 

515 def quote_name_unless_alias(self, name): 

516 """ 

517 A wrapper around connection.ops.quote_name that doesn't quote aliases 

518 for table names. This avoids problems with some SQL dialects that treat 

519 quoted strings specially (e.g. PostgreSQL). 

520 """ 

521 if name in self.quote_cache: 

522 return self.quote_cache[name] 

523 if ( 

524 (name in self.query.alias_map and name not in self.query.table_map) 

525 or name in self.query.extra_select 

526 or ( 

527 self.query.external_aliases.get(name) 

528 and name not in self.query.table_map 

529 ) 

530 ): 

531 self.quote_cache[name] = name 

532 return name 

533 r = self.connection.ops.quote_name(name) 

534 self.quote_cache[name] = r 

535 return r 

536 

537 def compile(self, node): 

538 vendor_impl = getattr(node, "as_" + self.connection.vendor, None) 

539 if vendor_impl: 

540 sql, params = vendor_impl(self, self.connection) 

541 else: 

542 sql, params = node.as_sql(self, self.connection) 

543 return sql, params 

544 

    def get_combinator_sql(self, combinator, all):
        """
        Return (sql_parts, params) joining the combined queries with the
        ``combinator`` set operation; ``all`` selects UNION ALL semantics.

        Raises DatabaseError for unsupported slicing/ordering in subqueries
        and EmptyResultSet when every part is empty.
        """
        features = self.connection.features
        compilers = [
            query.get_compiler(self.using, self.connection, self.elide_empty)
            for query in self.query.combined_queries
        ]
        if not features.supports_slicing_ordering_in_compound:
            for compiler in compilers:
                if compiler.query.is_sliced:
                    raise DatabaseError(
                        "LIMIT/OFFSET not allowed in subqueries of compound statements."
                    )
                if compiler.get_order_by():
                    raise DatabaseError(
                        "ORDER BY not allowed in subqueries of compound statements."
                    )
        elif self.query.is_sliced and combinator == "union":
            for compiler in compilers:
                # A sliced union cannot have its parts elided as some of them
                # might be sliced as well and in the event where only a single
                # part produces a non-empty resultset it might be impossible to
                # generate valid SQL.
                compiler.elide_empty = False
        parts = ()
        for compiler in compilers:
            try:
                # If the columns list is limited, then all combined queries
                # must have the same columns list. Set the selects defined on
                # the query on all combined queries, if not already set.
                if not compiler.query.values_select and self.query.values_select:
                    compiler.query = compiler.query.clone()
                    compiler.query.set_values(
                        (
                            *self.query.extra_select,
                            *self.query.values_select,
                            *self.query.annotation_select,
                        )
                    )
                part_sql, part_args = compiler.as_sql(with_col_aliases=True)
                if compiler.query.combinator:
                    # Wrap in a subquery if wrapping in parentheses isn't
                    # supported.
                    if not features.supports_parentheses_in_compound:
                        part_sql = f"SELECT * FROM ({part_sql})"
                    # Add parentheses when combining with compound query if not
                    # already added for all compound queries.
                    elif (
                        self.query.subquery
                        or not features.supports_slicing_ordering_in_compound
                    ):
                        part_sql = f"({part_sql})"
                elif (
                    self.query.subquery
                    and features.supports_slicing_ordering_in_compound
                ):
                    part_sql = f"({part_sql})"
                parts += ((part_sql, part_args),)
            except EmptyResultSet:
                # Omit the empty queryset with UNION and with DIFFERENCE if the
                # first queryset is nonempty.
                if combinator == "union" or (combinator == "difference" and parts):
                    continue
                raise
        if not parts:
            raise EmptyResultSet
        combinator_sql = self.connection.ops.set_operators[combinator]
        if all and combinator == "union":
            combinator_sql += " ALL"
        braces = "{}"
        if not self.query.subquery and features.supports_slicing_ordering_in_compound:
            braces = "({})"
        sql_parts, args_parts = zip(
            *((braces.format(sql), args) for sql, args in parts)
        )
        result = [f" {combinator_sql} ".join(sql_parts)]
        params = []
        for part in args_parts:
            params.extend(part)
        return result, params

624 

    def get_qualify_sql(self):
        """
        Return (result, params) wrapping the query in subqueries so that
        filtering against window functions (a QUALIFY-like clause) can be
        expressed as a plain WHERE on the outer query.
        """
        where_parts = []
        if self.where:
            where_parts.append(self.where)
        if self.having:
            where_parts.append(self.having)
        inner_query = self.query.clone()
        inner_query.subquery = True
        inner_query.where = inner_query.where.__class__(where_parts)
        # Augment the inner query with any window function references that
        # might have been masked via values() and alias(). If any masked
        # aliases are added they'll be masked again to avoid fetching
        # the data in the `if qual_aliases` branch below.
        select = {
            expr: alias for expr, _, alias in self.get_select(with_col_aliases=True)[0]
        }
        select_aliases = set(select.values())
        qual_aliases = set()
        replacements = {}

        def collect_replacements(expressions):
            # Map each expression used by the qualify/order-by clauses to a
            # column alias of the inner query, annotating the inner query
            # with new "qualN" aliases for expressions not already selected.
            while expressions:
                expr = expressions.pop()
                if expr in replacements:
                    continue
                elif select_alias := select.get(expr):
                    replacements[expr] = select_alias
                elif isinstance(expr, Lookup):
                    expressions.extend(expr.get_source_expressions())
                elif isinstance(expr, Ref):
                    if expr.refs not in select_aliases:
                        expressions.extend(expr.get_source_expressions())
                else:
                    num_qual_alias = len(qual_aliases)
                    select_alias = f"qual{num_qual_alias}"
                    qual_aliases.add(select_alias)
                    inner_query.add_annotation(expr, select_alias)
                    replacements[expr] = select_alias

        collect_replacements(list(self.qualify.leaves()))
        self.qualify = self.qualify.replace_expressions(
            {expr: Ref(alias, expr) for expr, alias in replacements.items()}
        )
        order_by = []
        for order_by_expr, *_ in self.get_order_by():
            collect_replacements(order_by_expr.get_source_expressions())
            order_by.append(
                order_by_expr.replace_expressions(
                    {expr: Ref(alias, expr) for expr, alias in replacements.items()}
                )
            )
        inner_query_compiler = inner_query.get_compiler(
            self.using, connection=self.connection, elide_empty=self.elide_empty
        )
        inner_sql, inner_params = inner_query_compiler.as_sql(
            # The limits must be applied to the outer query to avoid pruning
            # results too eagerly.
            with_limits=False,
            # Force unique aliasing of selected columns to avoid collisions
            # and make rhs predicates referencing easier.
            with_col_aliases=True,
        )
        qualify_sql, qualify_params = self.compile(self.qualify)
        result = [
            "SELECT * FROM (",
            inner_sql,
            ")",
            self.connection.ops.quote_name("qualify"),
            "WHERE",
            qualify_sql,
        ]
        if qual_aliases:
            # If some select aliases were unmasked for filtering purposes they
            # must be masked back.
            cols = [self.connection.ops.quote_name(alias) for alias in select.values()]
            result = [
                "SELECT",
                ", ".join(cols),
                "FROM (",
                *result,
                ")",
                self.connection.ops.quote_name("qualify_mask"),
            ]
        params = list(inner_params) + qualify_params
        # As the SQL spec is unclear on whether or not derived tables
        # ordering must propagate it has to be explicitly repeated on the
        # outer-most query to ensure it's preserved.
        if order_by:
            ordering_sqls = []
            for ordering in order_by:
                ordering_sql, ordering_params = self.compile(ordering)
                ordering_sqls.append(ordering_sql)
                params.extend(ordering_params)
            result.extend(["ORDER BY", ", ".join(ordering_sqls)])
        return result, params

720 

721 def as_sql(self, with_limits=True, with_col_aliases=False): 

722 """ 

723 Create the SQL for this query. Return the SQL string and list of 

724 parameters. 

725 

726 If 'with_limits' is False, any limit/offset information is not included 

727 in the query. 

728 """ 

729 refcounts_before = self.query.alias_refcount.copy() 

730 try: 

731 combinator = self.query.combinator 

732 extra_select, order_by, group_by = self.pre_sql_setup( 

733 with_col_aliases=with_col_aliases or bool(combinator), 

734 ) 

735 for_update_part = None 

736 # Is a LIMIT/OFFSET clause needed? 

737 with_limit_offset = with_limits and self.query.is_sliced 

738 combinator = self.query.combinator 

739 features = self.connection.features 

740 if combinator: 

741 if not getattr(features, f"supports_select_{combinator}"): 

742 raise NotSupportedError( 

743 f"{combinator} is not supported on this database backend." 

744 ) 

745 result, params = self.get_combinator_sql( 

746 combinator, self.query.combinator_all 

747 ) 

748 elif self.qualify: 

749 result, params = self.get_qualify_sql() 

750 order_by = None 

751 else: 

752 distinct_fields, distinct_params = self.get_distinct() 

753 # This must come after 'select', 'ordering', and 'distinct' 

754 # (see docstring of get_from_clause() for details). 

755 from_, f_params = self.get_from_clause() 

756 try: 

757 where, w_params = ( 

758 self.compile(self.where) if self.where is not None else ("", []) 

759 ) 

760 except EmptyResultSet: 

761 if self.elide_empty: 

762 raise 

763 # Use a predicate that's always False. 

764 where, w_params = "0 = 1", [] 

765 except FullResultSet: 

766 where, w_params = "", [] 

767 try: 

768 having, h_params = ( 

769 self.compile(self.having) 

770 if self.having is not None 

771 else ("", []) 

772 ) 

773 except FullResultSet: 

774 having, h_params = "", [] 

775 result = ["SELECT"] 

776 params = [] 

777 

778 if self.query.distinct: 

779 distinct_result, distinct_params = self.connection.ops.distinct_sql( 

780 distinct_fields, 

781 distinct_params, 

782 ) 

783 result += distinct_result 

784 params += distinct_params 

785 

786 out_cols = [] 

787 for _, (s_sql, s_params), alias in self.select + extra_select: 

788 if alias: 

789 s_sql = f"{s_sql} AS {self.connection.ops.quote_name(alias)}" 

790 params.extend(s_params) 

791 out_cols.append(s_sql) 

792 

793 result += [", ".join(out_cols)] 

794 if from_: 

795 result += ["FROM", *from_] 

796 elif self.connection.features.bare_select_suffix: 

797 result += [self.connection.features.bare_select_suffix] 

798 params.extend(f_params) 

799 

800 if self.query.select_for_update and features.has_select_for_update: 

801 if ( 

802 self.connection.get_autocommit() 

803 # Don't raise an exception when database doesn't 

804 # support transactions, as it's a noop. 

805 and features.supports_transactions 

806 ): 

807 raise TransactionManagementError( 

808 "select_for_update cannot be used outside of a transaction." 

809 ) 

810 

811 if ( 

812 with_limit_offset 

813 and not features.supports_select_for_update_with_limit 

814 ): 

815 raise NotSupportedError( 

816 "LIMIT/OFFSET is not supported with " 

817 "select_for_update on this database backend." 

818 ) 

819 nowait = self.query.select_for_update_nowait 

820 skip_locked = self.query.select_for_update_skip_locked 

821 of = self.query.select_for_update_of 

822 no_key = self.query.select_for_no_key_update 

823 # If it's a NOWAIT/SKIP LOCKED/OF/NO KEY query but the 

824 # backend doesn't support it, raise NotSupportedError to 

825 # prevent a possible deadlock. 

826 if nowait and not features.has_select_for_update_nowait: 

827 raise NotSupportedError( 

828 "NOWAIT is not supported on this database backend." 

829 ) 

830 elif skip_locked and not features.has_select_for_update_skip_locked: 

831 raise NotSupportedError( 

832 "SKIP LOCKED is not supported on this database backend." 

833 ) 

834 elif of and not features.has_select_for_update_of: 

835 raise NotSupportedError( 

836 "FOR UPDATE OF is not supported on this database backend." 

837 ) 

838 elif no_key and not features.has_select_for_no_key_update: 

839 raise NotSupportedError( 

840 "FOR NO KEY UPDATE is not supported on this " 

841 "database backend." 

842 ) 

843 for_update_part = self.connection.ops.for_update_sql( 

844 nowait=nowait, 

845 skip_locked=skip_locked, 

846 of=self.get_select_for_update_of_arguments(), 

847 no_key=no_key, 

848 ) 

849 

850 if for_update_part and features.for_update_after_from: 

851 result.append(for_update_part) 

852 

853 if where: 

854 result.append(f"WHERE {where}") 

855 params.extend(w_params) 

856 

857 grouping = [] 

858 for g_sql, g_params in group_by: 

859 grouping.append(g_sql) 

860 params.extend(g_params) 

861 if grouping: 

862 if distinct_fields: 

863 raise NotImplementedError( 

864 "annotate() + distinct(fields) is not implemented." 

865 ) 

866 order_by = order_by or self.connection.ops.force_no_ordering() 

867 result.append("GROUP BY {}".format(", ".join(grouping))) 

868 if self._meta_ordering: 

869 order_by = None 

870 if having: 

871 result.append(f"HAVING {having}") 

872 params.extend(h_params) 

873 

874 if self.query.explain_info: 

875 result.insert( 

876 0, 

877 self.connection.ops.explain_query_prefix( 

878 self.query.explain_info.format, 

879 **self.query.explain_info.options, 

880 ), 

881 ) 

882 

883 if order_by: 

884 ordering = [] 

885 for _, (o_sql, o_params, _) in order_by: 

886 ordering.append(o_sql) 

887 params.extend(o_params) 

888 order_by_sql = "ORDER BY {}".format(", ".join(ordering)) 

889 if combinator and features.requires_compound_order_by_subquery: 

890 result = ["SELECT * FROM (", *result, ")", order_by_sql] 

891 else: 

892 result.append(order_by_sql) 

893 

894 if with_limit_offset: 

895 result.append( 

896 self.connection.ops.limit_offset_sql( 

897 self.query.low_mark, self.query.high_mark 

898 ) 

899 ) 

900 

901 if for_update_part and not features.for_update_after_from: 

902 result.append(for_update_part) 

903 

904 if self.query.subquery and extra_select: 

905 # If the query is used as a subquery, the extra selects would 

906 # result in more columns than the left-hand side expression is 

907 # expecting. This can happen when a subquery uses a combination 

908 # of order_by() and distinct(), forcing the ordering expressions 

909 # to be selected as well. Wrap the query in another subquery 

910 # to exclude extraneous selects. 

911 sub_selects = [] 

912 sub_params = [] 

913 for index, (select, _, alias) in enumerate(self.select, start=1): 

914 if alias: 

915 sub_selects.append( 

916 "{}.{}".format( 

917 self.connection.ops.quote_name("subquery"), 

918 self.connection.ops.quote_name(alias), 

919 ) 

920 ) 

921 else: 

922 select_clone = select.relabeled_clone( 

923 {select.alias: "subquery"} 

924 ) 

925 subselect, subparams = select_clone.as_sql( 

926 self, self.connection 

927 ) 

928 sub_selects.append(subselect) 

929 sub_params.extend(subparams) 

930 return "SELECT {} FROM ({}) subquery".format( 

931 ", ".join(sub_selects), 

932 " ".join(result), 

933 ), tuple(sub_params + params) 

934 

935 return " ".join(result), tuple(params) 

936 finally: 

937 # Finally do cleanup - get rid of the joins we created above. 

938 self.query.reset_refcounts(refcounts_before) 

939 

940 def get_default_columns( 

941 self, select_mask, start_alias=None, opts=None, from_parent=None 

942 ): 

943 """ 

944 Compute the default columns for selecting every field in the base 

945 model. Will sometimes be called to pull in related models (e.g. via 

946 select_related), in which case "opts" and "start_alias" will be given 

947 to provide a starting point for the traversal. 

948 

949 Return a list of strings, quoted appropriately for use in SQL 

950 directly, as well as a set of aliases used in the select statement (if 

951 'as_pairs' is True, return a list of (alias, col_name) pairs instead 

952 of strings as the first component and None as the second component). 

953 """ 

954 result = [] 

955 if opts is None: 

956 if (opts := self.query.get_meta()) is None: 

957 return result 

958 start_alias = start_alias or self.query.get_initial_alias() 

959 # The 'seen_models' is used to optimize checking the needed parent 

960 # alias for a given field. This also includes None -> start_alias to 

961 # be used by local fields. 

962 seen_models = {None: start_alias} 

963 

964 for field in opts.concrete_fields: 

965 model = field.model._meta.concrete_model 

966 # A proxy model will have a different model and concrete_model. We 

967 # will assign None if the field belongs to this model. 

968 if model == opts.model: 

969 model = None 

970 if ( 

971 from_parent 

972 and model is not None 

973 and issubclass( 

974 from_parent._meta.concrete_model, model._meta.concrete_model 

975 ) 

976 ): 

977 # Avoid loading data for already loaded parents. 

978 # We end up here in the case select_related() resolution 

979 # proceeds from parent model to child model. In that case the 

980 # parent model data is already present in the SELECT clause, 

981 # and we want to avoid reloading the same data again. 

982 continue 

983 if select_mask and field not in select_mask: 

984 continue 

985 alias = self.query.join_parent_model(opts, model, start_alias, seen_models) 

986 column = field.get_col(alias) 

987 result.append(column) 

988 return result 

989 

990 def get_distinct(self): 

991 """ 

992 Return a quoted list of fields to use in DISTINCT ON part of the query. 

993 

994 This method can alter the tables in the query, and thus it must be 

995 called before get_from_clause(). 

996 """ 

997 result = [] 

998 params = [] 

999 opts = self.query.get_meta() 

1000 

1001 for name in self.query.distinct_fields: 

1002 parts = name.split(LOOKUP_SEP) 

1003 _, targets, alias, joins, path, _, transform_function = self._setup_joins( 

1004 parts, opts, None 

1005 ) 

1006 targets, alias, _ = self.query.trim_joins(targets, joins, path) 

1007 for target in targets: 

1008 if name in self.query.annotation_select: 

1009 result.append(self.connection.ops.quote_name(name)) 

1010 else: 

1011 r, p = self.compile(transform_function(target, alias)) 

1012 result.append(r) 

1013 params.append(p) 

1014 return result, params 

1015 

    def find_ordering_name(
        self, name, opts, alias=None, default_order="ASC", already_seen=None
    ):
        """
        Return the table alias (the name might be ambiguous, the alias will
        not be) and column name for ordering by the given 'name' parameter.
        The 'name' is of the form 'field1__field2__...__fieldN'.

        Returns a list of (OrderBy expression, is_ref) pairs.
        """
        name, order = get_order_dir(name, default_order)
        descending = order == "DESC"
        pieces = name.split(LOOKUP_SEP)
        (
            field,
            targets,
            alias,
            joins,
            path,
            opts,
            transform_function,
        ) = self._setup_joins(pieces, opts, alias)

        # If we get to this point and the field is a relation to another model,
        # append the default ordering for that model unless it is the pk
        # shortcut or the attribute name of the field that is specified or
        # there are transforms to process.
        if (
            field.is_relation
            and opts.ordering
            and getattr(field, "attname", None) != pieces[-1]
            and name != "pk"
            and not getattr(transform_function, "has_transforms", False)
        ):
            # Firstly, avoid infinite loops.
            already_seen = already_seen or set()
            join_tuple = tuple(
                getattr(self.query.alias_map[j], "join_cols", None) for j in joins
            )
            if join_tuple in already_seen:
                raise FieldError("Infinite loop caused by ordering.")
            already_seen.add(join_tuple)

            results = []
            for item in opts.ordering:
                if hasattr(item, "resolve_expression") and not isinstance(
                    item, OrderBy
                ):
                    # Bare expressions get wrapped in an OrderBy with the
                    # requested direction applied.
                    item = item.desc() if descending else item.asc()
                if isinstance(item, OrderBy):
                    # Rewrite references so they are relative to the root
                    # model rather than the related model.
                    results.append(
                        (item.prefix_references(f"{name}{LOOKUP_SEP}"), False)
                    )
                    continue
                # Plain string orderings: recurse to resolve the related
                # model's own default ordering.
                results.extend(
                    (expr.prefix_references(f"{name}{LOOKUP_SEP}"), is_ref)
                    for expr, is_ref in self.find_ordering_name(
                        item, opts, alias, order, already_seen
                    )
                )
            return results
        targets, alias, _ = self.query.trim_joins(targets, joins, path)
        return [
            (OrderBy(transform_function(t, alias), descending=descending), False)
            for t in targets
        ]

1080 

1081 def _setup_joins(self, pieces, opts, alias): 

1082 """ 

1083 Helper method for get_order_by() and get_distinct(). 

1084 

1085 get_ordering() and get_distinct() must produce same target columns on 

1086 same input, as the prefixes of get_ordering() and get_distinct() must 

1087 match. Executing SQL where this is not true is an error. 

1088 """ 

1089 alias = alias or self.query.get_initial_alias() 

1090 field, targets, opts, joins, path, transform_function = self.query.setup_joins( 

1091 pieces, opts, alias 

1092 ) 

1093 alias = joins[-1] 

1094 return field, targets, alias, joins, path, opts, transform_function 

1095 

1096 def get_from_clause(self): 

1097 """ 

1098 Return a list of strings that are joined together to go after the 

1099 "FROM" part of the query, as well as a list any extra parameters that 

1100 need to be included. Subclasses, can override this to create a 

1101 from-clause via a "select". 

1102 

1103 This should only be called after any SQL construction methods that 

1104 might change the tables that are needed. This means the select columns, 

1105 ordering, and distinct must be done first. 

1106 """ 

1107 result = [] 

1108 params = [] 

1109 for alias in tuple(self.query.alias_map): 

1110 if not self.query.alias_refcount[alias]: 

1111 continue 

1112 try: 

1113 from_clause = self.query.alias_map[alias] 

1114 except KeyError: 

1115 # Extra tables can end up in self.tables, but not in the 

1116 # alias_map if they aren't in a join. That's OK. We skip them. 

1117 continue 

1118 clause_sql, clause_params = self.compile(from_clause) 

1119 result.append(clause_sql) 

1120 params.extend(clause_params) 

1121 for t in self.query.extra_tables: 

1122 alias, _ = self.query.table_alias(t) 

1123 # Only add the alias if it's not already present (the table_alias() 

1124 # call increments the refcount, so an alias refcount of one means 

1125 # this is the only reference). 

1126 if ( 

1127 alias not in self.query.alias_map 

1128 or self.query.alias_refcount[alias] == 1 

1129 ): 

1130 result.append(f", {self.quote_name_unless_alias(alias)}") 

1131 return result, params 

1132 

    def get_related_selections(
        self,
        select,
        select_mask,
        opts=None,
        root_alias=None,
        cur_depth=1,
        requested=None,
        restricted=None,
    ):
        """
        Fill in the information needed for a select_related query. The current
        depth is measured as the number of connections away from the root model
        (for example, cur_depth=1 means we are looking at models with direct
        connections to the root model).

        Mutates `select` in place (appending (column, None) pairs) and returns
        a list of klass_info dicts describing how to build related instances
        from the resulting rows.
        """

        def _get_field_choices():
            # Names valid in select_related(): forward relations, unique
            # reverse relations, and declared filtered relations.
            direct_choices = (f.name for f in opts.fields if f.is_relation)
            reverse_choices = (
                f.field.related_query_name()
                for f in opts.related_objects
                if f.field.unique
            )
            return chain(
                direct_choices, reverse_choices, self.query._filtered_relations
            )

        related_klass_infos = []
        if not restricted and cur_depth > self.query.max_depth:
            # We've recursed far enough; bail out.
            return related_klass_infos

        if not opts:
            opts = self.query.get_meta()
            root_alias = self.query.get_initial_alias()

        # Setup for the case when only particular related fields should be
        # included in the related selection.
        fields_found = set()
        if requested is None:
            restricted = isinstance(self.query.select_related, dict)
            if restricted:
                requested = self.query.select_related

        def get_related_klass_infos(klass_info, related_klass_infos):
            klass_info["related_klass_infos"] = related_klass_infos

        # Forward relations.
        for f in opts.fields:
            fields_found.add(f.name)

            if restricted:
                # NOTE(review): `next` shadows the builtin here; it holds the
                # nested select_related dict for this field.
                next = requested.get(f.name, {})
                if not f.is_relation:
                    # If a non-related field is used like a relation,
                    # or if a single non-relational field is given.
                    if next or f.name in requested:
                        raise FieldError(
                            "Non-relational field given in select_related: '{}'. "
                            "Choices are: {}".format(
                                f.name,
                                ", ".join(_get_field_choices()) or "(none)",
                            )
                        )
            else:
                next = False

            if not select_related_descend(f, restricted, requested, select_mask):
                continue
            related_select_mask = select_mask.get(f) or {}
            klass_info = {
                "model": f.remote_field.model,
                "field": f,
                "reverse": False,
                "local_setter": f.set_cached_value,
                "remote_setter": f.remote_field.set_cached_value
                if f.unique
                else lambda x, y: None,
                "from_parent": False,
            }
            related_klass_infos.append(klass_info)
            select_fields = []
            _, _, _, joins, _, _ = self.query.setup_joins([f.name], opts, root_alias)
            alias = joins[-1]
            columns = self.get_default_columns(
                related_select_mask, start_alias=alias, opts=f.remote_field.model._meta
            )
            # Record the indexes of the appended columns so rows can be
            # sliced back out per related model.
            for col in columns:
                select_fields.append(len(select))
                select.append((col, None))
            klass_info["select_fields"] = select_fields
            next_klass_infos = self.get_related_selections(
                select,
                related_select_mask,
                f.remote_field.model._meta,
                alias,
                cur_depth + 1,
                next,
                restricted,
            )
            get_related_klass_infos(klass_info, next_klass_infos)

        if restricted:
            # Reverse side of unique (one-to-one style) relations.
            related_fields = [
                (o.field, o.related_model)
                for o in opts.related_objects
                if o.field.unique and not o.many_to_many
            ]
            for related_field, model in related_fields:
                related_select_mask = select_mask.get(related_field) or {}
                if not select_related_descend(
                    related_field,
                    restricted,
                    requested,
                    related_select_mask,
                    reverse=True,
                ):
                    continue

                related_field_name = related_field.related_query_name()
                fields_found.add(related_field_name)

                join_info = self.query.setup_joins(
                    [related_field_name], opts, root_alias
                )
                alias = join_info.joins[-1]
                from_parent = issubclass(model, opts.model) and model is not opts.model
                klass_info = {
                    "model": model,
                    "field": related_field,
                    "reverse": True,
                    "local_setter": related_field.remote_field.set_cached_value,
                    "remote_setter": related_field.set_cached_value,
                    "from_parent": from_parent,
                }
                related_klass_infos.append(klass_info)
                select_fields = []
                columns = self.get_default_columns(
                    related_select_mask,
                    start_alias=alias,
                    opts=model._meta,
                    from_parent=opts.model,
                )
                for col in columns:
                    select_fields.append(len(select))
                    select.append((col, None))
                klass_info["select_fields"] = select_fields
                next = requested.get(related_field.related_query_name(), {})
                next_klass_infos = self.get_related_selections(
                    select,
                    related_select_mask,
                    model._meta,
                    alias,
                    cur_depth + 1,
                    next,
                    restricted,
                )
                get_related_klass_infos(klass_info, next_klass_infos)

            def local_setter(final_field, obj, from_obj):
                # Set a reverse fk object when relation is non-empty.
                if from_obj:
                    final_field.remote_field.set_cached_value(from_obj, obj)

            def local_setter_noop(obj, from_obj):
                pass

            def remote_setter(name, obj, from_obj):
                setattr(from_obj, name, obj)

            for name in list(requested):
                # Filtered relations work only on the topmost level.
                if cur_depth > 1:
                    break
                if name in self.query._filtered_relations:
                    fields_found.add(name)
                    final_field, _, join_opts, joins, _, _ = self.query.setup_joins(
                        [name], opts, root_alias
                    )
                    model = join_opts.model
                    alias = joins[-1]
                    from_parent = (
                        issubclass(model, opts.model) and model is not opts.model
                    )
                    klass_info = {
                        "model": model,
                        "field": final_field,
                        "reverse": True,
                        # Only set the reverse cache for a direct (single-join)
                        # relation; deeper paths get a no-op setter.
                        "local_setter": (
                            partial(local_setter, final_field)
                            if len(joins) <= 2
                            else local_setter_noop
                        ),
                        "remote_setter": partial(remote_setter, name),
                        "from_parent": from_parent,
                    }
                    related_klass_infos.append(klass_info)
                    select_fields = []
                    field_select_mask = select_mask.get((name, final_field)) or {}
                    columns = self.get_default_columns(
                        field_select_mask,
                        start_alias=alias,
                        opts=model._meta,
                        from_parent=opts.model,
                    )
                    for col in columns:
                        select_fields.append(len(select))
                        select.append((col, None))
                    klass_info["select_fields"] = select_fields
                    next_requested = requested.get(name, {})
                    next_klass_infos = self.get_related_selections(
                        select,
                        field_select_mask,
                        opts=model._meta,
                        root_alias=alias,
                        cur_depth=cur_depth + 1,
                        requested=next_requested,
                        restricted=restricted,
                    )
                    get_related_klass_infos(klass_info, next_klass_infos)
            # Any requested name we never matched is invalid.
            fields_not_found = set(requested).difference(fields_found)
            if fields_not_found:
                invalid_fields = (f"'{s}'" for s in fields_not_found)
                raise FieldError(
                    "Invalid field name(s) given in select_related: {}. "
                    "Choices are: {}".format(
                        ", ".join(invalid_fields),
                        ", ".join(_get_field_choices()) or "(none)",
                    )
                )
        return related_klass_infos

1364 

    def get_select_for_update_of_arguments(self):
        """
        Return a quoted list of arguments for the SELECT FOR UPDATE OF part of
        the query.
        """

        def _get_parent_klass_info(klass_info):
            # Build pseudo klass_infos for the concrete model's parents so
            # paths through inherited links can be traversed like relations.
            concrete_model = klass_info["model"]._meta.concrete_model
            for parent_model, parent_link in concrete_model._meta.parents.items():
                parent_list = parent_model._meta.get_parent_list()
                yield {
                    "model": parent_model,
                    "field": parent_link,
                    "reverse": False,
                    "select_fields": [
                        select_index
                        for select_index in klass_info["select_fields"]
                        # Selected columns from a model or its parents.
                        if (
                            self.select[select_index][0].target.model == parent_model
                            or self.select[select_index][0].target.model in parent_list
                        )
                    ],
                }

        def _get_first_selected_col_from_model(klass_info):
            """
            Find the first selected column from a model. If it doesn't exist,
            don't lock a model.

            select_fields is filled recursively, so it also contains fields
            from the parent models.
            """
            concrete_model = klass_info["model"]._meta.concrete_model
            for select_index in klass_info["select_fields"]:
                if self.select[select_index][0].target.model == concrete_model:
                    return self.select[select_index][0]

        def _get_field_choices():
            """Yield all allowed field paths in breadth-first search order."""
            queue = collections.deque([(None, self.klass_info)])
            while queue:
                parent_path, klass_info = queue.popleft()
                if parent_path is None:
                    path = []
                    yield "self"
                else:
                    field = klass_info["field"]
                    if klass_info["reverse"]:
                        field = field.remote_field
                    path = parent_path + [field.name]
                    yield LOOKUP_SEP.join(path)
                queue.extend(
                    (path, klass_info)
                    for klass_info in _get_parent_klass_info(klass_info)
                )
                queue.extend(
                    (path, klass_info)
                    for klass_info in klass_info.get("related_klass_infos", [])
                )

        if not self.klass_info:
            return []
        result = []
        invalid_names = []
        for name in self.query.select_for_update_of:
            klass_info = self.klass_info
            if name == "self":
                col = _get_first_selected_col_from_model(klass_info)
            else:
                # Walk the relation path segment by segment, descending
                # through related and parent klass_infos.
                for part in name.split(LOOKUP_SEP):
                    klass_infos = (
                        *klass_info.get("related_klass_infos", []),
                        *_get_parent_klass_info(klass_info),
                    )
                    for related_klass_info in klass_infos:
                        field = related_klass_info["field"]
                        if related_klass_info["reverse"]:
                            field = field.remote_field
                        if field.name == part:
                            klass_info = related_klass_info
                            break
                    else:
                        # No relation matched this path segment.
                        klass_info = None
                        break
                if klass_info is None:
                    invalid_names.append(name)
                    continue
                col = _get_first_selected_col_from_model(klass_info)
            if col is not None:
                if self.connection.features.select_for_update_of_column:
                    # Backend locks by column reference.
                    result.append(self.compile(col)[0])
                else:
                    # Backend locks by table alias.
                    result.append(self.quote_name_unless_alias(col.alias))
        if invalid_names:
            raise FieldError(
                "Invalid field name(s) given in select_for_update(of=(...)): {}. "
                "Only relational fields followed in the query are allowed. "
                "Choices are: {}.".format(
                    ", ".join(invalid_names),
                    ", ".join(_get_field_choices()),
                )
            )
        return result

1469 

1470 def get_converters(self, expressions): 

1471 converters = {} 

1472 for i, expression in enumerate(expressions): 

1473 if expression: 

1474 backend_converters = self.connection.ops.get_db_converters(expression) 

1475 field_converters = expression.get_db_converters(self.connection) 

1476 if backend_converters or field_converters: 

1477 converters[i] = (backend_converters + field_converters, expression) 

1478 return converters 

1479 

1480 def apply_converters(self, rows, converters): 

1481 connection = self.connection 

1482 converters = list(converters.items()) 

1483 for row in map(list, rows): 

1484 for pos, (convs, expression) in converters: 

1485 value = row[pos] 

1486 for converter in convs: 

1487 value = converter(value, expression, connection) 

1488 row[pos] = value 

1489 yield row 

1490 

1491 def results_iter( 

1492 self, 

1493 results=None, 

1494 tuple_expected=False, 

1495 chunked_fetch=False, 

1496 chunk_size=GET_ITERATOR_CHUNK_SIZE, 

1497 ): 

1498 """Return an iterator over the results from executing this query.""" 

1499 if results is None: 

1500 results = self.execute_sql( 

1501 MULTI, chunked_fetch=chunked_fetch, chunk_size=chunk_size 

1502 ) 

1503 fields = [s[0] for s in self.select[0 : self.col_count]] 

1504 converters = self.get_converters(fields) 

1505 rows = chain.from_iterable(results) 

1506 if converters: 

1507 rows = self.apply_converters(rows, converters) 

1508 if tuple_expected: 

1509 rows = map(tuple, rows) 

1510 return rows 

1511 

1512 def has_results(self): 

1513 """ 

1514 Backends (e.g. NoSQL) can override this in order to use optimized 

1515 versions of "query has any results." 

1516 """ 

1517 return bool(self.execute_sql(SINGLE)) 

1518 

    def execute_sql(
        self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
    ):
        """
        Run the query against the database and return the result(s). The
        return value is a single data item if result_type is SINGLE, or an
        iterator over the results if the result_type is MULTI.

        result_type is either MULTI (use fetchmany() to retrieve all rows),
        SINGLE (only retrieve a single row), or None. In this last case, the
        cursor is returned if any query is executed, since it's used by
        subclasses such as InsertQuery). It's possible, however, that no query
        is needed, as the filters describe an empty set. In that case, None is
        returned, to avoid any unnecessary database interaction.
        """
        result_type = result_type or NO_RESULTS
        try:
            sql, params = self.as_sql()
            if not sql:
                raise EmptyResultSet
        except EmptyResultSet:
            # An empty filter set needs no database round trip; return the
            # shape the caller expects for the result_type.
            if result_type == MULTI:
                return iter([])
            else:
                return
        if chunked_fetch:
            cursor = self.connection.chunked_cursor()
        else:
            cursor = self.connection.cursor()
        try:
            cursor.execute(sql, params)
        except Exception:
            # Might fail for server-side cursors (e.g. connection closed)
            cursor.close()
            raise

        if result_type == CURSOR:
            # Give the caller the cursor to process and close.
            return cursor
        if result_type == SINGLE:
            try:
                val = cursor.fetchone()
                if val:
                    # Trim any trailing columns beyond the selected count.
                    return val[0 : self.col_count]
                return val
            finally:
                # done with the cursor
                cursor.close()
        if result_type == NO_RESULTS:
            cursor.close()
            return

        result = cursor_iter(
            cursor,
            self.connection.features.empty_fetchmany_value,
            self.col_count if self.has_extra_select else None,
            chunk_size,
        )
        if not chunked_fetch or not self.connection.features.can_use_chunked_reads:
            # If we are using non-chunked reads, we return the same data
            # structure as normally, but ensure it is all read into memory
            # before going any further. Use chunked_fetch if requested,
            # unless the database doesn't support it.
            return list(result)
        return result

1584 

1585 def as_subquery_condition(self, alias, columns, compiler): 

1586 qn = compiler.quote_name_unless_alias 

1587 qn2 = self.connection.ops.quote_name 

1588 

1589 for index, select_col in enumerate(self.query.select): 

1590 lhs_sql, lhs_params = self.compile(select_col) 

1591 rhs = f"{qn(alias)}.{qn2(columns[index])}" 

1592 self.query.where.add(RawSQL(f"{lhs_sql} = {rhs}", lhs_params), AND) 

1593 

1594 sql, params = self.as_sql() 

1595 return f"EXISTS ({sql})", params 

1596 

1597 def explain_query(self): 

1598 result = list(self.execute_sql()) 

1599 # Some backends return 1 item tuples with strings, and others return 

1600 # tuples with integers and strings. Flatten them out into strings. 

1601 format_ = self.query.explain_info.format 

1602 output_formatter = json.dumps if format_ and format_.lower() == "json" else str 

1603 for row in result[0]: 

1604 if not isinstance(row, str): 

1605 yield " ".join(output_formatter(c) for c in row) 

1606 else: 

1607 yield row 

1608 

1609 

class SQLInsertCompiler(SQLCompiler):
    # Fields whose database-generated values should be returned from the
    # INSERT (RETURNING clause); set by execute_sql().
    returning_fields = None
    # Params accompanying the RETURNING SQL fragment; set in as_sql().
    returning_params = ()

1613 

1614 def field_as_sql(self, field, val): 

1615 """ 

1616 Take a field and a value intended to be saved on that field, and 

1617 return placeholder SQL and accompanying params. Check for raw values, 

1618 expressions, and fields with get_placeholder() defined in that order. 

1619 

1620 When field is None, consider the value raw and use it as the 

1621 placeholder, with no corresponding parameters returned. 

1622 """ 

1623 if field is None: 

1624 # A field value of None means the value is raw. 

1625 sql, params = val, [] 

1626 elif hasattr(val, "as_sql"): 

1627 # This is an expression, let's compile it. 

1628 sql, params = self.compile(val) 

1629 elif hasattr(field, "get_placeholder"): 

1630 # Some fields (e.g. geo fields) need special munging before 

1631 # they can be inserted. 

1632 sql, params = field.get_placeholder(val, self, self.connection), [val] 

1633 else: 

1634 # Return the common case for the placeholder 

1635 sql, params = "%s", [val] 

1636 

1637 # The following hook is only used by Oracle Spatial, which sometimes 

1638 # needs to yield 'NULL' and [] as its placeholder and params instead 

1639 # of '%s' and [None]. The 'NULL' placeholder is produced earlier by 

1640 # OracleOperations.get_geom_placeholder(). The following line removes 

1641 # the corresponding None parameter. See ticket #10888. 

1642 params = self.connection.ops.modify_insert_params(sql, params) 

1643 

1644 return sql, params 

1645 

1646 def prepare_value(self, field, value): 

1647 """ 

1648 Prepare a value to be used in a query by resolving it if it is an 

1649 expression and otherwise calling the field's get_db_prep_save(). 

1650 """ 

1651 if hasattr(value, "resolve_expression"): 

1652 value = value.resolve_expression( 

1653 self.query, allow_joins=False, for_save=True 

1654 ) 

1655 # Don't allow values containing Col expressions. They refer to 

1656 # existing columns on a row, but in the case of insert the row 

1657 # doesn't exist yet. 

1658 if value.contains_column_references: 

1659 raise ValueError( 

1660 f'Failed to insert expression "{value}" on {field}. F() expressions ' 

1661 "can only be used to update, not to insert." 

1662 ) 

1663 if value.contains_aggregate: 

1664 raise FieldError( 

1665 "Aggregate functions are not allowed in this query " 

1666 f"({field.name}={value!r})." 

1667 ) 

1668 if value.contains_over_clause: 

1669 raise FieldError( 

1670 f"Window expressions are not allowed in this query ({field.name}={value!r})." 

1671 ) 

1672 return field.get_db_prep_save(value, connection=self.connection) 

1673 

1674 def pre_save_val(self, field, obj): 

1675 """ 

1676 Get the given field's value off the given obj. pre_save() is used for 

1677 things like auto_now on DateTimeField. Skip it if this is a raw query. 

1678 """ 

1679 if self.query.raw: 

1680 return getattr(obj, field.attname) 

1681 return field.pre_save(obj, add=True) 

1682 

1683 def assemble_as_sql(self, fields, value_rows): 

1684 """ 

1685 Take a sequence of N fields and a sequence of M rows of values, and 

1686 generate placeholder SQL and parameters for each field and value. 

1687 Return a pair containing: 

1688 * a sequence of M rows of N SQL placeholder strings, and 

1689 * a sequence of M rows of corresponding parameter values. 

1690 

1691 Each placeholder string may contain any number of '%s' interpolation 

1692 strings, and each parameter row will contain exactly as many params 

1693 as the total number of '%s's in the corresponding placeholder row. 

1694 """ 

1695 if not value_rows: 

1696 return [], [] 

1697 

1698 # list of (sql, [params]) tuples for each object to be saved 

1699 # Shape: [n_objs][n_fields][2] 

1700 rows_of_fields_as_sql = ( 

1701 (self.field_as_sql(field, v) for field, v in zip(fields, row)) 

1702 for row in value_rows 

1703 ) 

1704 

1705 # tuple like ([sqls], [[params]s]) for each object to be saved 

1706 # Shape: [n_objs][2][n_fields] 

1707 sql_and_param_pair_rows = (zip(*row) for row in rows_of_fields_as_sql) 

1708 

1709 # Extract separate lists for placeholders and params. 

1710 # Each of these has shape [n_objs][n_fields] 

1711 placeholder_rows, param_rows = zip(*sql_and_param_pair_rows) 

1712 

1713 # Params for each field are still lists, and need to be flattened. 

1714 param_rows = [[p for ps in row for p in ps] for row in param_rows] 

1715 

1716 return placeholder_rows, param_rows 

1717 

    def as_sql(self):
        """
        Build the INSERT statement(s). Return a list of (sql, params) pairs:
        a single pair when a bulk insert (or RETURNING) is used, otherwise
        one pair per object being saved.
        """
        # We don't need quote_name_unless_alias() here, since these are all
        # going to be column names (so we can avoid the extra overhead).
        qn = self.connection.ops.quote_name
        opts = self.query.get_meta()
        insert_statement = self.connection.ops.insert_statement(
            on_conflict=self.query.on_conflict,
        )
        result = [f"{insert_statement} {qn(opts.db_table)}"]
        fields = self.query.fields or [opts.pk]
        result.append("({})".format(", ".join(qn(f.column) for f in fields)))

        if self.query.fields:
            # One row of prepared values per object being inserted.
            value_rows = [
                [
                    self.prepare_value(field, self.pre_save_val(field, obj))
                    for field in fields
                ]
                for obj in self.query.objs
            ]
        else:
            # An empty object.
            value_rows = [
                [self.connection.ops.pk_default_value()] for _ in self.query.objs
            ]
            fields = [None]

        # Currently the backends just accept values when generating bulk
        # queries and generate their own placeholders. Doing that isn't
        # necessary and it should be possible to use placeholders and
        # expressions in bulk inserts too.
        can_bulk = (
            not self.returning_fields and self.connection.features.has_bulk_insert
        )

        placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)

        on_conflict_suffix_sql = self.connection.ops.on_conflict_suffix_sql(
            fields,
            self.query.on_conflict,
            (f.column for f in self.query.update_fields),
            (f.column for f in self.query.unique_fields),
        )
        if (
            self.returning_fields
            and self.connection.features.can_return_columns_from_insert
        ):
            if self.connection.features.can_return_rows_from_bulk_insert:
                result.append(
                    self.connection.ops.bulk_insert_sql(fields, placeholder_rows)
                )
                params = param_rows
            else:
                # Single-row insert: only the first placeholder/param row.
                result.append("VALUES ({})".format(", ".join(placeholder_rows[0])))
                params = [param_rows[0]]
            if on_conflict_suffix_sql:
                result.append(on_conflict_suffix_sql)
            # Skip empty r_sql to allow subclasses to customize behavior for
            # 3rd party backends. Refs #19096.
            r_sql, self.returning_params = self.connection.ops.return_insert_columns(
                self.returning_fields
            )
            if r_sql:
                result.append(r_sql)
                params += [self.returning_params]
            return [(" ".join(result), tuple(chain.from_iterable(params)))]

        if can_bulk:
            # One statement covering all rows.
            result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
            if on_conflict_suffix_sql:
                result.append(on_conflict_suffix_sql)
            return [(" ".join(result), tuple(p for ps in param_rows for p in ps))]
        else:
            if on_conflict_suffix_sql:
                result.append(on_conflict_suffix_sql)
            # One statement per row.
            return [
                (" ".join(result + ["VALUES ({})".format(", ".join(p))]), vals)
                for p, vals in zip(placeholder_rows, param_rows)
            ]

1797 

    def execute_sql(self, returning_fields=None):
        """
        Execute the INSERT statement(s) produced by as_sql().

        Return a list of rows holding the values of ``returning_fields`` for
        the inserted object(s), or an empty list when no returning fields
        were requested.
        """
        # Returning fields for multiple inserted rows requires backend
        # support for returning rows from bulk inserts.
        assert not (
            returning_fields
            and len(self.query.objs) != 1
            and not self.connection.features.can_return_rows_from_bulk_insert
        )
        opts = self.query.get_meta()
        self.returning_fields = returning_fields
        with self.connection.cursor() as cursor:
            for sql, params in self.as_sql():
                cursor.execute(sql, params)
            if not self.returning_fields:
                return []
            if (
                self.connection.features.can_return_rows_from_bulk_insert
                and len(self.query.objs) > 1
            ):
                # Bulk insert with RETURNING support: fetch every row.
                rows = self.connection.ops.fetch_returned_insert_rows(cursor)
            elif self.connection.features.can_return_columns_from_insert:
                assert len(self.query.objs) == 1
                rows = [
                    self.connection.ops.fetch_returned_insert_columns(
                        cursor,
                        self.returning_params,
                    )
                ]
            else:
                # No RETURNING support: fall back to the backend's
                # last-insert-id mechanism (yields only the pk).
                rows = [
                    (
                        self.connection.ops.last_insert_id(
                            cursor,
                            opts.db_table,
                            opts.pk.column,
                        ),
                    )
                ]
        # Apply any backend/field converters to the returned values.
        cols = [field.get_col(opts.db_table) for field in self.returning_fields]
        converters = self.get_converters(cols)
        if converters:
            rows = list(self.apply_converters(rows, converters))
        return rows

1839 

1840 

class SQLDeleteCompiler(SQLCompiler):
    @cached_property
    def single_alias(self):
        """
        True when exactly one table alias is in use, i.e. the delete can run
        directly against the base table without a pk__in subquery.
        """
        # Ensure base table is in aliases.
        self.query.get_initial_alias()
        return sum(self.query.alias_refcount[t] > 0 for t in self.query.alias_map) == 1

    @classmethod
    def _expr_refs_base_model(cls, expr, base_model):
        # Recursively determine whether expr (or any of its source
        # expressions) is a subquery against base_model.
        if isinstance(expr, Query):
            return expr.model == base_model
        if not hasattr(expr, "get_source_expressions"):
            return False
        return any(
            cls._expr_refs_base_model(source_expr, base_model)
            for source_expr in expr.get_source_expressions()
        )

    @cached_property
    def contains_self_reference_subquery(self):
        # Whether any annotation or where-clause node embeds a subquery that
        # references the model being deleted from; such queries can't use
        # the fast single-table path.
        return any(
            self._expr_refs_base_model(expr, self.query.model)
            for expr in chain(
                self.query.annotations.values(), self.query.where.children
            )
        )

    def _as_sql(self, query):
        # Build a single-table DELETE. A FullResultSet from compiling the
        # where clause means "match everything", so the WHERE is omitted.
        delete = f"DELETE FROM {self.quote_name_unless_alias(query.base_table)}"
        try:
            where, params = self.compile(query.where)
        except FullResultSet:
            return delete, ()
        return f"{delete} WHERE {where}", tuple(params)

    def as_sql(self):
        """
        Create the SQL for this query. Return the SQL string and list of
        parameters.
        """
        if self.single_alias and not self.contains_self_reference_subquery:
            # Fast path: delete straight from the base table.
            return self._as_sql(self.query)
        # Slow path: delete rows whose pk is selected by an inner query that
        # carries the original joins and filters.
        innerq = self.query.clone()
        # Downcast the clone to a plain Query so it compiles as a SELECT.
        innerq.__class__ = Query
        innerq.clear_select_clause()
        pk = self.query.model._meta.pk
        innerq.select = [pk.get_col(self.query.get_initial_alias())]
        outerq = Query(self.query.model)
        if not self.connection.features.update_can_self_select:
            # Force the materialization of the inner query to allow reference
            # to the target table on MySQL.
            sql, params = innerq.get_compiler(connection=self.connection).as_sql()
            innerq = RawSQL(f"SELECT * FROM ({sql}) subquery", params)
        outerq.add_filter("pk__in", innerq)
        return self._as_sql(outerq)

1896 

1897 

class SQLUpdateCompiler(SQLCompiler):
    def as_sql(self):
        """
        Create the SQL for this query. Return the SQL string and list of
        parameters.
        """
        self.pre_sql_setup()
        if not self.query.values:
            # Nothing to update.
            return "", ()
        qn = self.quote_name_unless_alias
        values, update_params = [], []
        for field, model, val in self.query.values:
            if hasattr(val, "resolve_expression"):
                # Expression values (F(), database functions, ...) resolve
                # against this query; joins aren't allowed in an UPDATE, and
                # neither are aggregates or window expressions.
                val = val.resolve_expression(
                    self.query, allow_joins=False, for_save=True
                )
                if val.contains_aggregate:
                    raise FieldError(
                        "Aggregate functions are not allowed in this query "
                        f"({field.name}={val!r})."
                    )
                if val.contains_over_clause:
                    raise FieldError(
                        "Window expressions are not allowed in this query "
                        f"({field.name}={val!r})."
                    )
            elif hasattr(val, "prepare_database_save"):
                # Model instances are only valid values for relation fields.
                if field.remote_field:
                    val = val.prepare_database_save(field)
                else:
                    raise TypeError(
                        f"Tried to update field {field} with a model instance, {val!r}. "
                        f"Use a value compatible with {field.__class__.__name__}."
                    )
            val = field.get_db_prep_save(val, connection=self.connection)

            # Getting the placeholder for the field.
            if hasattr(field, "get_placeholder"):
                placeholder = field.get_placeholder(val, self, self.connection)
            else:
                placeholder = "%s"
            name = field.column
            if hasattr(val, "as_sql"):
                # SQL-expression value: splice the compiled SQL into the
                # placeholder and collect its params.
                sql, params = self.compile(val)
                values.append(f"{qn(name)} = {placeholder % sql}")
                update_params.extend(params)
            elif val is not None:
                values.append(f"{qn(name)} = {placeholder}")
                update_params.append(val)
            else:
                # None becomes a literal NULL (no parameter).
                values.append(f"{qn(name)} = NULL")
        table = self.query.base_table
        result = [
            f"UPDATE {qn(table)} SET",
            ", ".join(values),
        ]
        try:
            where, params = self.compile(self.query.where)
        except FullResultSet:
            # The condition matches all rows; omit the WHERE clause.
            params = []
        else:
            result.append(f"WHERE {where}")
        return " ".join(result), tuple(update_params + params)

    def execute_sql(self, result_type):
        """
        Execute the specified update. Return the number of rows affected by
        the primary update query. The "primary update query" is the first
        non-empty query that is executed. Row counts for any subsequent,
        related queries are not available.
        """
        cursor = super().execute_sql(result_type)
        try:
            rows = cursor.rowcount if cursor else 0
            is_empty = cursor is None
        finally:
            if cursor:
                cursor.close()
        # Run any dependent updates on related tables; only the first
        # non-empty query's row count is reported.
        for query in self.query.get_related_updates():
            aux_rows = query.get_compiler(self.using).execute_sql(result_type)
            if is_empty and aux_rows:
                rows = aux_rows
                is_empty = False
        return rows

    def pre_sql_setup(self):
        """
        If the update depends on results from other tables, munge the "where"
        conditions to match the format required for (portable) SQL updates.

        If multiple updates are required, pull out the id values to update at
        this point so that they don't change as a result of the progressive
        updates.
        """
        refcounts_before = self.query.alias_refcount.copy()
        # Ensure base table is in the query
        self.query.get_initial_alias()
        count = self.query.count_active_tables()
        if not self.query.related_updates and count == 1:
            # Single-table update with no related updates: nothing to do.
            return
        # Build a SELECT over the same filters to fetch the pks (and any
        # related-table pks) that this update should touch.
        query = self.query.chain(klass=Query)
        query.select_related = False
        query.clear_ordering(force=True)
        query.extra = {}
        query.select = []
        meta = query.get_meta()
        fields = [meta.pk.name]
        related_ids_index = []
        for related in self.query.related_updates:
            if all(
                path.join_field.primary_key for path in meta.get_path_to_parent(related)
            ):
                # If a primary key chain exists to the targeted related update,
                # then the meta.pk value can be used for it.
                related_ids_index.append((related, 0))
            else:
                # This branch will only be reached when updating a field of an
                # ancestor that is not part of the primary key chain of a MTI
                # tree.
                related_ids_index.append((related, len(fields)))
                fields.append(related._meta.pk.name)
        query.add_fields(fields)
        super().pre_sql_setup()

        must_pre_select = (
            count > 1 and not self.connection.features.update_can_self_select
        )

        # Now we adjust the current query: reset the where clause and get rid
        # of all the tables we don't need (since they're in the sub-select).
        self.query.clear_where()
        if self.query.related_updates or must_pre_select:
            # Either we're using the idents in multiple update queries (so
            # don't want them to change), or the db backend doesn't support
            # selecting from the updating table (e.g. MySQL).
            idents = []
            related_ids = collections.defaultdict(list)
            for rows in query.get_compiler(self.using).execute_sql(MULTI):
                idents.extend(r[0] for r in rows)
                for parent, index in related_ids_index:
                    related_ids[parent].extend(r[index] for r in rows)
            self.query.add_filter("pk__in", idents)
            self.query.related_ids = related_ids
        else:
            # The fast path. Filters and updates in one query.
            self.query.add_filter("pk__in", query)
        self.query.reset_refcounts(refcounts_before)

2045 

2046 

class SQLAggregateCompiler(SQLCompiler):
    def as_sql(self):
        """
        Create the SQL for this query. Return the SQL string and list of
        parameters.
        """
        # Compile every selected annotation into its SQL fragment and
        # accumulate the parameters in select order.
        select_parts = []
        select_params = []
        for annotation in self.query.annotation_select.values():
            annotation_sql, annotation_params = self.compile(annotation)
            annotation_sql, annotation_params = annotation.select_format(
                self, annotation_sql, annotation_params
            )
            select_parts.append(annotation_sql)
            select_params.extend(annotation_params)
        self.col_count = len(self.query.annotation_select)

        # Wrap the inner query in a subquery and aggregate over it.
        inner_sql, inner_params = self.query.inner_query.get_compiler(
            self.using,
            elide_empty=self.elide_empty,
        ).as_sql(with_col_aliases=True)

        select_clause = ", ".join(select_parts)
        sql = f"SELECT {select_clause} FROM ({inner_sql}) subquery"
        params = tuple(select_params) + tuple(inner_params)
        return sql, params

2070 

2071 

def cursor_iter(cursor, sentinel, col_count, itersize):
    """
    Yield blocks of rows from a cursor and ensure the cursor is closed when
    done.

    Rows are fetched ``itersize`` at a time until ``fetchmany`` returns
    ``sentinel``; each yielded block is truncated to ``col_count`` columns
    unless ``col_count`` is None.
    """
    fetch_block = partial(cursor.fetchmany, itersize)
    try:
        for rows in iter(fetch_block, sentinel):
            if col_count is None:
                yield rows
            else:
                yield [row[:col_count] for row in rows]
    finally:
        cursor.close()