Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/expressions.py: 44%

983 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-23 11:16 -0600

1import copy 

2import datetime 

3import functools 

4import inspect 

5from collections import defaultdict 

6from decimal import Decimal 

7from types import NoneType 

8from uuid import UUID 

9 

10from plain.exceptions import EmptyResultSet, FieldError, FullResultSet 

11from plain.models import fields 

12from plain.models.constants import LOOKUP_SEP 

13from plain.models.db import DatabaseError, NotSupportedError, connection 

14from plain.models.query_utils import Q 

15from plain.utils.deconstruct import deconstructible 

16from plain.utils.functional import cached_property 

17from plain.utils.hashable import make_hashable 

18 

19 

class SQLiteNumericMixin:
    """
    Wrap the compiled SQL in ``CAST(... AS NUMERIC)`` on SQLite when the
    expression's output field is a DecimalField, so that filtering against
    it compares numerically.
    """

    def as_sqlite(self, compiler, connection, **extra_context):
        sql, params = self.as_sql(compiler, connection, **extra_context)
        try:
            needs_cast = self.output_field.get_internal_type() == "DecimalField"
        except FieldError:
            # Output type cannot be resolved: emit the SQL unchanged.
            needs_cast = False
        if needs_cast:
            sql = f"CAST({sql} AS NUMERIC)"
        return sql, params

34 

35 

class Combinable:
    """
    Give expressions Python operator support.

    Arithmetic operators build a CombinedExpression that joins ``self`` and
    the other operand with the matching SQL connector, for example
    ``F('foo') + F('bar')``. The ``&``, ``|`` and ``^`` operators only
    combine two conditional (boolean) expressions into Q objects. Bitwise SQL
    operations go through the explicit ``bit*()`` methods instead.
    """

    # Arithmetic connectors
    ADD = "+"
    SUB = "-"
    MUL = "*"
    DIV = "/"
    POW = "^"
    # The "%" operator is doubled because the connector may end up in SQL
    # strings that also go through "%s" parameter substitution.
    MOD = "%%"

    # Bitwise connectors, produced by .bitand(), .bitor(), etc. The Python
    # '&' and '|' operators are reserved for boolean combination.
    BITAND = "&"
    BITOR = "|"
    BITLEFTSHIFT = "<<"
    BITRIGHTSHIFT = ">>"
    BITXOR = "#"

    def _combine(self, other, connector, reversed):
        """
        Return ``CombinedExpression(self, connector, other)``, or with the
        operands swapped when ``reversed`` is true. A plain Python operand
        (anything without ``resolve_expression``) is wrapped in Value().
        """
        if hasattr(other, "resolve_expression"):
            operand = other
        else:
            operand = Value(other)
        lhs, rhs = (operand, self) if reversed else (self, operand)
        return CombinedExpression(lhs, connector, rhs)

    #########################
    # ARITHMETIC OPERATORS  #
    #########################

    def __neg__(self):
        return self._combine(-1, self.MUL, False)

    def __add__(self, other):
        return self._combine(other, self.ADD, False)

    def __sub__(self, other):
        return self._combine(other, self.SUB, False)

    def __mul__(self, other):
        return self._combine(other, self.MUL, False)

    def __truediv__(self, other):
        return self._combine(other, self.DIV, False)

    def __mod__(self, other):
        return self._combine(other, self.MOD, False)

    def __pow__(self, other):
        return self._combine(other, self.POW, False)

    # Reflected arithmetic: the plain Python value is on the left.

    def __radd__(self, other):
        return self._combine(other, self.ADD, True)

    def __rsub__(self, other):
        return self._combine(other, self.SUB, True)

    def __rmul__(self, other):
        return self._combine(other, self.MUL, True)

    def __rtruediv__(self, other):
        return self._combine(other, self.DIV, True)

    def __rmod__(self, other):
        return self._combine(other, self.MOD, True)

    def __rpow__(self, other):
        return self._combine(other, self.POW, True)

    #########################
    # BOOLEAN COMBINATION   #
    #########################

    def __and__(self, other):
        # Evaluated left to right with short-circuit: ``other.conditional``
        # is only consulted when ``self`` is conditional.
        both_conditional = getattr(self, "conditional", False) and getattr(
            other, "conditional", False
        )
        if not both_conditional:
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) & Q(other)

    def __or__(self, other):
        both_conditional = getattr(self, "conditional", False) and getattr(
            other, "conditional", False
        )
        if not both_conditional:
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) | Q(other)

    def __xor__(self, other):
        both_conditional = getattr(self, "conditional", False) and getattr(
            other, "conditional", False
        )
        if not both_conditional:
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) ^ Q(other)

    # Reflected boolean operators are always rejected.

    def __rand__(self, other):
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __ror__(self, other):
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __rxor__(self, other):
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    #########################
    # EXPLICIT BITWISE OPS  #
    #########################

    def bitand(self, other):
        return self._combine(other, self.BITAND, False)

    def bitor(self, other):
        return self._combine(other, self.BITOR, False)

    def bitxor(self, other):
        return self._combine(other, self.BITXOR, False)

    def bitleftshift(self, other):
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other):
        return self._combine(other, self.BITRIGHTSHIFT, False)

    def __invert__(self):
        return NegatedExpression(self)

166 

167 

class BaseExpression:
    """
    Base class for all query expressions.

    An expression exposes its children through get_source_expressions() /
    set_source_expressions(), renders itself in as_sql(), and infers its
    result type through ``output_field``. Child slots may hold None. Every
    tree-walking helper in this class skips such slots.
    """

    # Substitute value used when compiling this expression raises
    # EmptyResultSet. NotImplemented means "no substitute".
    empty_result_set_value = NotImplemented
    # aggregate specific fields
    is_summary = False
    # Set when _resolve_output_field() returned None. This lets
    # _output_field_or_none return None instead of re-raising FieldError.
    _output_field_resolved_to_none = False
    # Can the expression be used in a WHERE clause?
    filterable = True
    # Can the expression be used as a source expression in Window?
    window_compatible = False

    def __init__(self, output_field=None):
        # An explicit output_field is stored on the instance, where it shadows
        # the inferring ``output_field`` cached property.
        if output_field is not None:
            self.output_field = output_field

    def __getstate__(self):
        # convert_value may have cached a lambda, which pickle cannot handle.
        # Drop it so it is recomputed on demand after unpickling.
        state = self.__dict__.copy()
        state.pop("convert_value", None)
        return state

    def get_db_converters(self, connection):
        """
        Return the value converters to apply to database results: this
        expression's own convert_value (omitted when it is the no-op),
        followed by the output field's converters.
        """
        return (
            []
            if self.convert_value is self._convert_value_noop
            else [self.convert_value]
        ) + self.output_field.get_db_converters(connection)

    def get_source_expressions(self):
        """Return the child expressions; leaf expressions have none."""
        return []

    def set_source_expressions(self, exprs):
        """Replace the child expressions; leaf expressions accept only []."""
        assert not exprs

    def _parse_expressions(self, *expressions):
        """
        Coerce raw arguments to expressions. Objects with resolve_expression()
        are kept, strings become F() references, and anything else is wrapped
        in Value().
        """
        return [
            arg
            if hasattr(arg, "resolve_expression")
            else (F(arg) if isinstance(arg, str) else Value(arg))
            for arg in expressions
        ]

    def as_sql(self, compiler, connection):
        """
        Responsible for returning a (sql, [params]) tuple to be included
        in the current query.

        Different backends can provide their own implementation, by
        providing an `as_{vendor}` method and patching the Expression:

        ```
        def override_as_sql(self, compiler, connection):
            # custom logic
            return super().as_sql(compiler, connection)
        setattr(Expression, 'as_' + connection.vendor, override_as_sql)
        ```

        Arguments:
         * compiler: the query compiler responsible for generating the query.
           Must have a compile method, returning a (sql, [params]) tuple.
           Calling compiler(value) will return a quoted `value`.

         * connection: the database connection used for the current query.

        Return: (sql, params)
          Where `sql` is a string containing ordered sql parameters to be
          replaced with the elements of the list `params`.
        """
        raise NotImplementedError("Subclasses must implement as_sql()")

    @cached_property
    def contains_aggregate(self):
        """True if any (non-empty) child contains an aggregate."""
        return any(
            expr and expr.contains_aggregate for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_over_clause(self):
        """True if any (non-empty) child contains an OVER clause."""
        return any(
            expr and expr.contains_over_clause for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_column_references(self):
        """True if any (non-empty) child references a column."""
        return any(
            expr and expr.contains_column_references
            for expr in self.get_source_expressions()
        )

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Provide the chance to do any preprocessing or validation before being
        added to the query.

        Arguments:
         * query: the backend query implementation
         * allow_joins: boolean allowing or denying use of joins
           in this query
         * reuse: a set of reusable joins for multijoins
         * summarize: a terminal aggregate clause
         * for_save: whether this expression about to be used in a save or update

        Return: an Expression to be added to the query.
        """
        c = self.copy()
        c.is_summary = summarize
        # Note: for_save is not forwarded to children here.
        c.set_source_expressions(
            [
                expr.resolve_expression(query, allow_joins, reuse, summarize)
                if expr
                else None
                for expr in c.get_source_expressions()
            ]
        )
        return c

    @property
    def conditional(self):
        """True when the output field is a BooleanField."""
        return isinstance(self.output_field, fields.BooleanField)

    @property
    def field(self):
        """Alias of output_field."""
        return self.output_field

    @cached_property
    def output_field(self):
        """
        Return the output type of this expression.

        Raises FieldError when the type cannot be inferred.
        """
        output_field = self._resolve_output_field()
        if output_field is None:
            self._output_field_resolved_to_none = True
            raise FieldError("Cannot resolve expression type, unknown output_field")
        return output_field

    @cached_property
    def _output_field_or_none(self):
        """
        Return the output field of this expression, or None if
        _resolve_output_field() didn't return an output type. Any other
        FieldError (e.g. mixed source types) propagates.
        """
        try:
            return self.output_field
        except FieldError:
            if not self._output_field_resolved_to_none:
                raise

    def _resolve_output_field(self):
        """
        Attempt to infer the output type of the expression.

        As a guess, if the output fields of all source fields match then simply
        infer the same type here.

        If a source's output field resolves to None, exclude it from this check.
        If all sources are None, then an error is raised higher up the stack in
        the output_field property.
        """
        # This guess is mostly a bad idea, but there is quite a lot of code
        # (especially 3rd party Func subclasses) that depend on it, we'd need a
        # deprecation path to fix it.
        sources_iter = (
            source for source in self.get_source_fields() if source is not None
        )
        # The outer loop takes the first non-None field. The inner loop shares
        # the same iterator and consumes all remaining fields to check that
        # each is compatible with the first one.
        for output_field in sources_iter:
            for source in sources_iter:
                if not isinstance(output_field, source.__class__):
                    raise FieldError(
                        f"Expression contains mixed types: {output_field.__class__.__name__}, {source.__class__.__name__}. You must "
                        "set output_field."
                    )
            return output_field

    @staticmethod
    def _convert_value_noop(value, expression, connection):
        return value

    @cached_property
    def convert_value(self):
        """
        Expressions provide their own converters because users have the option
        of manually specifying the output_field which may be a different type
        from the one the database returns.
        """
        field = self.output_field
        internal_type = field.get_internal_type()
        if internal_type == "FloatField":
            return (
                lambda value, expression, connection: None
                if value is None
                else float(value)
            )
        elif internal_type.endswith("IntegerField"):
            return (
                lambda value, expression, connection: None
                if value is None
                else int(value)
            )
        elif internal_type == "DecimalField":
            return (
                lambda value, expression, connection: None
                if value is None
                else Decimal(value)
            )
        return self._convert_value_noop

    def get_lookup(self, lookup):
        return self.output_field.get_lookup(lookup)

    def get_transform(self, name):
        return self.output_field.get_transform(name)

    def relabeled_clone(self, change_map):
        """Return a copy whose children are relabeled with change_map."""
        clone = self.copy()
        clone.set_source_expressions(
            [
                e.relabeled_clone(change_map) if e is not None else None
                for e in self.get_source_expressions()
            ]
        )
        return clone

    def replace_expressions(self, replacements):
        """
        Return ``replacements[self]`` if present (and truthy). Otherwise return
        a copy with replacements applied recursively to the children.
        """
        if replacement := replacements.get(self):
            return replacement
        clone = self.copy()
        source_expressions = clone.get_source_expressions()
        clone.set_source_expressions(
            [
                expr.replace_expressions(replacements) if expr else None
                for expr in source_expressions
            ]
        )
        return clone

    def get_refs(self):
        """Return the union of get_refs() over all non-None children."""
        refs = set()
        for expr in self.get_source_expressions():
            # Children may be empty slots (None); skip them like the other
            # tree walkers in this class do.
            if expr is not None:
                refs |= expr.get_refs()
        return refs

    def copy(self):
        """Return a shallow copy of this expression."""
        return copy.copy(self)

    def prefix_references(self, prefix):
        """
        Return a copy in which every F() child's name is prefixed with
        ``prefix``. Other children are prefixed recursively and None slots
        are preserved.
        """

        def _prefixed(expr):
            if expr is None:
                return None
            if isinstance(expr, F):
                return F(f"{prefix}{expr.name}")
            return expr.prefix_references(prefix)

        clone = self.copy()
        clone.set_source_expressions(
            [_prefixed(expr) for expr in self.get_source_expressions()]
        )
        return clone

    def get_group_by_cols(self):
        """
        Return the expressions to GROUP BY for this expression. A
        non-aggregate expression groups by itself. An aggregate collects the
        group-by columns of its non-None children.
        """
        if not self.contains_aggregate:
            return [self]
        cols = []
        for source in self.get_source_expressions():
            if source is not None:
                cols.extend(source.get_group_by_cols())
        return cols

    def get_source_fields(self):
        """Return the underlying field types used by this aggregate."""
        return [e._output_field_or_none for e in self.get_source_expressions()]

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def reverse_ordering(self):
        return self

    def flatten(self):
        """
        Recursively yield this expression and all subexpressions, in
        depth-first order.
        """
        yield self
        for expr in self.get_source_expressions():
            if expr:
                if hasattr(expr, "flatten"):
                    yield from expr.flatten()
                else:
                    yield expr

    def select_format(self, compiler, sql, params):
        """
        Custom format for select clauses. For example, EXISTS expressions need
        to be wrapped in CASE WHEN on Oracle.
        """
        if hasattr(self.output_field, "select_format"):
            return self.output_field.select_format(compiler, sql, params)
        return sql, params

466 

467 

@deconstructible
class Expression(BaseExpression, Combinable):
    """An expression that can be combined with other expressions."""

    @cached_property
    def identity(self):
        """
        Return a hashable tuple that identifies this expression.

        The tuple is the class followed by ``(name, value)`` for every
        constructor argument (with defaults applied). A bound Field argument
        is represented by ``(model label, field name)``. An unbound field is
        represented by its type, and any other value by make_hashable().
        """
        init_signature = inspect.signature(self.__init__)
        ctor_args, ctor_kwargs = self._constructor_args
        bound = init_signature.bind_partial(*ctor_args, **ctor_kwargs)
        bound.apply_defaults()

        def hashable_form(value):
            if not isinstance(value, fields.Field):
                return make_hashable(value)
            if value.name and value.model:
                return (value.model._meta.label, value.name)
            return type(value)

        return tuple(
            [self.__class__]
            + [(name, hashable_form(value)) for name, value in bound.arguments.items()]
        )

    def __eq__(self, other):
        if isinstance(other, Expression):
            return other.identity == self.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)

498 

499 

# Type inference for CombinedExpression.output_field.
# Missing items will result in FieldError, by design.
#
# The current approach for NULL is based on lowest common denominator behavior
# i.e. if one of the supported databases is raising an error (rather than
# return NULL) for `val <op> NULL`, then Plain raises FieldError.
#
# Each dict maps a connector to a list of (lhs field class, rhs field class,
# result field class) triples. Order matters: _resolve_combined_type() returns
# the first registered triple whose classes match via issubclass().

_connector_combinations = [
    # Numeric operations - operands of same type.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
            (fields.FloatField, fields.FloatField, fields.FloatField),
            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            # Behavior for DIV with integer arguments follows Postgres/SQLite,
            # not MySQL/Oracle.
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Numeric operations - operands of different type.
    {
        connector: [
            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
            (fields.IntegerField, fields.FloatField, fields.FloatField),
            (fields.FloatField, fields.IntegerField, fields.FloatField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
        )
    },
    # Bitwise operators.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
        ]
        for connector in (
            Combinable.BITAND,
            Combinable.BITOR,
            Combinable.BITLEFTSHIFT,
            Combinable.BITRIGHTSHIFT,
            Combinable.BITXOR,
        )
    },
    # Numeric with NULL: `<numeric> <op> NULL` (and the reverse) keeps the
    # numeric type. The field types are iterated inside each connector's list.
    # Iterating them at the dict level would overwrite the connector key on
    # every pass and keep only the last (FloatField) pair.
    {
        connector: [
            combination
            for field_type in (
                fields.IntegerField,
                fields.DecimalField,
                fields.FloatField,
            )
            for combination in (
                (field_type, NoneType, field_type),
                (NoneType, field_type, field_type),
            )
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Date/DateTimeField/DurationField/TimeField.
    {
        Combinable.ADD: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DurationField, fields.DateField, fields.DateTimeField),
            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.DurationField, fields.TimeField, fields.TimeField),
        ],
    },
    {
        Combinable.SUB: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DateField, fields.DateField, fields.DurationField),
            (fields.DateField, fields.DateTimeField, fields.DurationField),
            (fields.DateTimeField, fields.DateField, fields.DurationField),
            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.TimeField, fields.TimeField, fields.DurationField),
        ],
    },
]

603 

# Registry: connector -> list of (lhs_type, rhs_type, result_type) triples,
# kept in registration order.
_connector_combinators = defaultdict(list)


def register_combinable_fields(lhs, connector, rhs, result):
    """
    Record that ``lhs <connector> rhs`` yields a ``result`` field type.

    Example::

        register_combinable_fields(
            IntegerField, Combinable.ADD, FloatField, FloatField
        )

    Lookups through _resolve_combined_type() are memoized with lru_cache.
    A triple registered after a given (connector, lhs, rhs) lookup has been
    cached is therefore not seen for that lookup.
    """
    triple = (lhs, rhs, result)
    _connector_combinators[connector].append(triple)

617 

618 

# Load the built-in combination table into the registry. Registration order is
# preserved, and _resolve_combined_type() returns the first matching triple.
# NOTE: because this loop runs at module level, d, connector, field_types,
# lhs, rhs and result stay bound as module globals afterwards.
for d in _connector_combinations:
    for connector, field_types in d.items():
        for lhs, rhs, result in field_types:
            register_combinable_fields(lhs, connector, rhs, result)

623 

624 

@functools.lru_cache(maxsize=128)
def _resolve_combined_type(connector, lhs_type, rhs_type):
    """
    Return the result field class registered for
    ``lhs_type <connector> rhs_type``, or None when nothing matches.

    Registered triples are scanned in order. The first whose lhs/rhs classes
    are base classes of (or equal to) the given types wins. Results are
    memoized per (connector, lhs_type, rhs_type).
    """
    candidates = _connector_combinators.get(connector, ())
    return next(
        (
            combined_type
            for wanted_lhs, wanted_rhs, combined_type in candidates
            if issubclass(lhs_type, wanted_lhs) and issubclass(rhs_type, wanted_rhs)
        ),
        None,
    )

633 

634 

class CombinedExpression(SQLiteNumericMixin, Expression):
    """
    Binary expression ``lhs <connector> rhs``, where ``connector`` is one of
    the Combinable operator constants.
    """

    def __init__(self, lhs, connector, rhs, output_field=None):
        super().__init__(output_field=output_field)
        self.connector = connector
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def __str__(self):
        return f"{self.lhs} {self.connector} {self.rhs}"

    def get_source_expressions(self):
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs):
        self.lhs, self.rhs = exprs

    def _resolve_output_field(self):
        """
        Look up the result type in the connector combination registry.

        This deliberately does not fall back to the "all sources share a type"
        guess in BaseExpression._resolve_output_field().
        """
        lhs_field_type = type(self.lhs._output_field_or_none)
        rhs_field_type = type(self.rhs._output_field_or_none)
        combined_type = _resolve_combined_type(
            self.connector, lhs_field_type, rhs_field_type
        )
        if combined_type is not None:
            return combined_type()
        # Building this message reads each operand's output_field. An operand
        # with no resolvable type therefore raises its own FieldError here.
        raise FieldError(
            f"Cannot infer type of {self.connector!r} expression involving these "
            f"types: {self.lhs.output_field.__class__.__name__}, "
            f"{self.rhs.output_field.__class__.__name__}. You must set "
            f"output_field."
        )

    def as_sql(self, compiler, connection):
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        rhs_sql, rhs_params = compiler.compile(self.rhs)
        combined = connection.ops.combine_expression(
            self.connector, [lhs_sql, rhs_sql]
        )
        # Parenthesize to preserve operator precedence when nested.
        return f"({combined})", [*lhs_params, *rhs_params]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve both operands. Plain combinations involving temporal fields
        are rerouted: a DurationField mixed with a different type becomes a
        DurationExpression, and ``X - X`` for same-typed Date/DateTime/Time
        operands becomes a TemporalSubtraction.
        """
        resolve_args = (query, allow_joins, reuse, summarize, for_save)
        lhs = self.lhs.resolve_expression(*resolve_args)
        rhs = self.rhs.resolve_expression(*resolve_args)

        if not isinstance(self, DurationExpression | TemporalSubtraction):

            def internal_type_of(expr):
                try:
                    return expr.output_field.get_internal_type()
                except (AttributeError, FieldError):
                    return None

            lhs_type = internal_type_of(lhs)
            rhs_type = internal_type_of(rhs)

            if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
                return DurationExpression(
                    self.lhs, self.connector, self.rhs
                ).resolve_expression(*resolve_args)

            temporal_types = {"DateField", "DateTimeField", "TimeField"}
            is_temporal_difference = (
                self.connector == self.SUB
                and lhs_type in temporal_types
                and lhs_type == rhs_type
            )
            if is_temporal_difference:
                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
                    *resolve_args
                )

        resolved = self.copy()
        resolved.is_summary = summarize
        resolved.lhs = lhs
        resolved.rhs = rhs
        return resolved

731 

732 

class DurationExpression(CombinedExpression):
    """
    CombinedExpression mixing a DurationField with another type.

    On backends without a native duration field, DurationField operands are
    passed through ``format_for_duration_arithmetic()`` and the result is
    built with ``combine_duration_expression()``.
    """

    def compile(self, side, compiler, connection):
        """Compile one operand, adapting it when it is a DurationField."""
        try:
            output = side.output_field
        except FieldError:
            return compiler.compile(side)
        if output.get_internal_type() != "DurationField":
            return compiler.compile(side)
        sql, params = compiler.compile(side)
        return connection.ops.format_for_duration_arithmetic(sql), params

    def as_sql(self, compiler, connection):
        if connection.features.has_native_duration_field:
            return super().as_sql(compiler, connection)
        connection.ops.check_expression_support(self)
        lhs_sql, lhs_params = self.compile(self.lhs, compiler, connection)
        rhs_sql, rhs_params = self.compile(self.rhs, compiler, connection)
        combined = connection.ops.combine_duration_expression(
            self.connector, [lhs_sql, rhs_sql]
        )
        # Parenthesize to preserve operator precedence when nested.
        return f"({combined})", [*lhs_params, *rhs_params]

    def as_sqlite(self, compiler, connection, **extra_context):
        """
        Compile as usual, but reject ``*`` and ``/`` unless both operands are
        numeric or duration fields. Operands of unknown type are allowed.
        """
        sql, params = self.as_sql(compiler, connection, **extra_context)
        if self.connector not in {Combinable.MUL, Combinable.DIV}:
            return sql, params
        try:
            lhs_type = self.lhs.output_field.get_internal_type()
            rhs_type = self.rhs.output_field.get_internal_type()
        except (AttributeError, FieldError):
            return sql, params
        allowed_fields = {
            "DecimalField",
            "DurationField",
            "FloatField",
            "IntegerField",
        }
        if lhs_type in allowed_fields and rhs_type in allowed_fields:
            return sql, params
        raise DatabaseError(f"Invalid arguments for operator {self.connector}.")

782 

783 

class TemporalSubtraction(CombinedExpression):
    """
    ``lhs - rhs`` for temporal operands, typed as a DurationField.
    CombinedExpression.resolve_expression() produces it when both operands
    share a Date/DateTime/Time type. SQL generation is delegated to the
    backend's ``subtract_temporals()``.
    """

    output_field = fields.DurationField()

    def __init__(self, lhs, rhs):
        super().__init__(lhs, self.SUB, rhs)

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        compiled_lhs = compiler.compile(self.lhs)
        compiled_rhs = compiler.compile(self.rhs)
        internal_type = self.lhs.output_field.get_internal_type()
        return connection.ops.subtract_temporals(
            internal_type, compiled_lhs, compiled_rhs
        )

797 

798 

@deconstructible(path="plain.models.F")
class F(Combinable):
    """
    Reference to a field by name. The name is turned into a real expression
    when resolved, via ``query.resolve_ref()``.
    """

    def __init__(self, name):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self):
        return f"{type(self).__name__}({self.name})"

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # for_save is accepted for interface compatibility but not forwarded.
        return query.resolve_ref(self.name, allow_joins, reuse, summarize)

    def replace_expressions(self, replacements):
        return replacements.get(self, self)

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def __eq__(self, other):
        # Equal only to an instance of exactly the same class with the same name.
        if self.__class__ != other.__class__:
            return False
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)

    def copy(self):
        return copy.copy(self)

835 

836 

class ResolvedOuterRef(F):
    """
    An object that contains a reference to an outer query.

    In this case, the reference to the outer query has been resolved because
    the inner query has been used as a subquery.
    """

    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, *args, **kwargs):
        raise ValueError(
            "This queryset contains a reference to an outer query and may "
            "only be used in a subquery."
        )

    def resolve_expression(self, *args, **kwargs):
        resolved = super().resolve_expression(*args, **kwargs)
        if resolved.contains_over_clause:
            raise NotSupportedError(
                f"Referencing outer query window expression is not supported: {self.name}."
            )
        # FIXME: Rename possibly_multivalued to multivalued and fix detection
        # for non-multivalued JOINs (e.g. foreign key fields). This should take
        # into account only many-to-many and one-to-many relationships.
        resolved.possibly_multivalued = LOOKUP_SEP in self.name
        return resolved

    def relabeled_clone(self, relabels):
        # The reference points outside this query, so relabeling ignores it.
        return self

    def get_group_by_cols(self):
        return []

869 

870 def get_group_by_cols(self): 

871 return [] 

872 

873 

class OuterRef(F):
    """
    Reference to a field of the enclosing query. Resolving it yields a
    ResolvedOuterRef. When the wrapped name is itself an instance of this
    class (a nested outer reference), that instance is returned instead.
    """

    contains_aggregate = False

    def resolve_expression(self, *args, **kwargs):
        target = self.name
        if isinstance(target, self.__class__):
            return target
        return ResolvedOuterRef(target)

    def relabeled_clone(self, relabels):
        return self

884 

885 

@deconstructible(path="plain.models.Func")
class Func(SQLiteNumericMixin, Expression):
    """An SQL function call."""

    function = None
    template = "%(function)s(%(expressions)s)"
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(self, *expressions, output_field=None, **extra):
        """
        Arguments:
         * expressions: the function arguments. Strings become F() references
           and other non-expressions are wrapped in Value().
         * output_field: optional explicit result type.
         * extra: template context overrides (e.g. ``function``, ``template``,
           ``arg_joiner``) used by as_sql().

        Raises TypeError when ``arity`` is set and the argument count differs.
        """
        if self.arity is not None and len(expressions) != self.arity:
            noun = "argument" if self.arity == 1 else "arguments"
            raise TypeError(
                f"'{self.__class__.__name__}' takes exactly {self.arity} "
                f"{noun} ({len(expressions)} given)"
            )
        super().__init__(output_field=output_field)
        self.source_expressions = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self):
        name = self.__class__.__name__
        args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
        options = {**self.extra, **self._get_repr_options()}
        if not options:
            return f"{name}({args})"
        rendered = ", ".join(
            f"{key!s}={val!s}" for key, val in sorted(options.items())
        )
        return f"{name}({args}, {rendered})"

    def _get_repr_options(self):
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self):
        return self.source_expressions

    def set_source_expressions(self, exprs):
        self.source_expressions = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        c = self.copy()
        c.is_summary = summarize
        # copy() gave c its own list, so overwrite its contents in place.
        c.source_expressions[:] = [
            arg.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for arg in c.source_expressions
        ]
        return c

    def as_sql(
        self,
        compiler,
        connection,
        function=None,
        template=None,
        arg_joiner=None,
        **extra_context,
    ):
        """
        Render ``template % data``. ``data`` holds the compiled arguments
        (under both ``expressions`` and ``field``) plus the function name and
        any extra context.

        An argument that raises EmptyResultSet is replaced by its
        ``empty_result_set_value`` when it defines one; otherwise the error
        propagates. An argument that raises FullResultSet is replaced by
        Value(True).
        """
        connection.ops.check_expression_support(self)
        sql_parts = []
        params = []
        for arg in self.source_expressions:
            try:
                arg_sql, arg_params = compiler.compile(arg)
            except EmptyResultSet:
                fallback = getattr(arg, "empty_result_set_value", NotImplemented)
                if fallback is NotImplemented:
                    raise
                arg_sql, arg_params = compiler.compile(Value(fallback))
            except FullResultSet:
                arg_sql, arg_params = compiler.compile(Value(True))
            sql_parts.append(arg_sql)
            params.extend(arg_params)

        data = {**self.extra, **extra_context}
        # Use the first supplied value in this order: the parameter to this
        # method, a value supplied in __init__()'s **extra (the value in
        # `data`), or the value defined on the class.
        if function is None:
            data.setdefault("function", self.function)
        else:
            data["function"] = function
        template = template or data.get("template", self.template)
        arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
        joined = arg_joiner.join(sql_parts)
        data["expressions"] = joined
        data["field"] = joined
        return template % data, params

    def copy(self):
        """Shallow copy with its own argument list and extra dict."""
        clone = super().copy()
        clone.source_expressions = self.source_expressions[:]
        clone.extra = self.extra.copy()
        return clone

984 

985 

@deconstructible(path="plain.models.Value")
class Value(SQLiteNumericMixin, Expression):
    """Represent a wrapped value as a node within an expression."""

    # Class-level default so that unresolved instances can still be compiled
    # (until a decision is taken in #25425); resolve_expression() sets the
    # real per-instance value.
    for_save = False

    def __init__(self, value, output_field=None):
        """
        Arguments:
         * value: the value this expression represents. It is normally sent to
           the database as a query parameter rather than spliced into the SQL.

         * output_field: an instance of the model field type that this
           expression will return, such as IntegerField() or CharField().
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self):
        return f"{self.__class__.__name__}({self.value!r})"

    def as_sql(self, compiler, connection):
        """
        Return ``(sql, params)`` for the wrapped value.

        With a known output field the value is first prepared via the field's
        save/value preparation and the field's own placeholder is used when
        it provides one. A None value is emitted as a literal ``NULL``.
        """
        connection.ops.check_expression_support(self)
        prepared = self.value
        field = self._output_field_or_none
        if field is not None:
            prepare = field.get_db_prep_save if self.for_save else field.get_db_prep_value
            prepared = prepare(prepared, connection=connection)
            if hasattr(field, "get_placeholder"):
                return field.get_placeholder(prepared, compiler, connection), [prepared]
        if prepared is None:
            # cx_Oracle does not always convert None to the appropriate NULL
            # type (like in case expressions using numbers), so a literal SQL
            # NULL is emitted instead of a parameter.
            return "NULL", []
        return "%s", [prepared]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        resolved.for_save = for_save
        return resolved

    def get_group_by_cols(self):
        # A constant never contributes GROUP BY columns.
        return []

    def _resolve_output_field(self):
        """
        Infer a field from the Python type of the value; return None when the
        type is not recognised. The checks run in order, so bool is tested
        before int and datetime before date (the former are subclasses of
        the latter).
        """
        type_to_field = (
            (str, fields.CharField),
            (bool, fields.BooleanField),
            (int, fields.IntegerField),
            (float, fields.FloatField),
            (datetime.datetime, fields.DateTimeField),
            (datetime.date, fields.DateField),
            (datetime.time, fields.TimeField),
            (datetime.timedelta, fields.DurationField),
            (Decimal, fields.DecimalField),
            (bytes, fields.BinaryField),
            (UUID, fields.UUIDField),
        )
        for python_type, field_class in type_to_field:
            if isinstance(self.value, python_type):
                return field_class()
        return None

    @property
    def empty_result_set_value(self):
        # When used in place of an empty result set, the value stands in as-is.
        return self.value

1064 

1065 

class RawSQL(Expression):
    """
    A raw SQL fragment with its own parameters, rendered wrapped in
    parentheses. Without an explicit output field a generic Field is used.
    """

    def __init__(self, sql, params, output_field=None):
        if output_field is None:
            output_field = fields.Field()
        self.sql, self.params = sql, params
        super().__init__(output_field=output_field)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.sql}, {self.params})"

    def as_sql(self, compiler, connection):
        # NOTE: returns self.params itself (not a copy); callers must not
        # mutate the returned params in place.
        return f"({self.sql})", self.params

    def get_group_by_cols(self):
        return [self]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve this fragment against `query`.

        For each parent model of the query's model, if any of the parent's
        local column names appears (case-insensitively) in the raw SQL, a
        reference to the first such field is resolved on the query so that
        the parent's columns are available to the fragment.

        `query` may be None (its declared default); in that case there is no
        model to inspect and the parent-field step is skipped instead of
        failing with AttributeError.
        """
        if query is not None and query.model:
            # Loop-invariant: lowercase the SQL once rather than per field.
            sql_lower = self.sql.lower()
            for parent in query.model._meta.get_parent_list():
                for parent_field in parent._meta.local_fields:
                    _, column_name = parent_field.get_attname_column()
                    if column_name.lower() in sql_lower:
                        query.resolve_ref(
                            parent_field.name, allow_joins, reuse, summarize
                        )
                        # One reference per parent model is sufficient.
                        break
        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )

1097 ) 

1098 

1099 

class Star(Expression):
    """Expression that renders as the bare SQL ``*`` token with no params."""

    def as_sql(self, compiler, connection):
        return "*", []

    def __repr__(self):
        return "'*'"

1106 

1107 

class Col(Expression):
    """
    A reference to a column (`target.column`), optionally qualified by a
    table alias. When no output field is given, the target field is used.
    """

    contains_column_references = True
    possibly_multivalued = False

    def __init__(self, alias, target, output_field=None):
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias, self.target = alias, target

    def __repr__(self):
        parts = [str(self.target)]
        if self.alias:
            parts.insert(0, self.alias)
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def as_sql(self, compiler, connection):
        # Render "alias"."column" (or just "column" without an alias), quoting
        # each part through the compiler.
        names = [self.target.column]
        if self.alias:
            names.insert(0, self.alias)
        quoted = [compiler.quote_name_unless_alias(name) for name in names]
        return ".".join(quoted), []

    def relabeled_clone(self, relabels):
        if self.alias is None:
            return self
        new_alias = relabels.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self):
        return [self]

    def get_db_converters(self, connection):
        # The output field's converters come first; the target field's are
        # appended only when it differs from the output field.
        converters = self.output_field.get_db_converters(connection)
        if self.target == self.output_field:
            return converters
        return converters + self.target.get_db_converters(connection)

1145 

1146 

class Ref(Expression):
    """
    Reference to column alias of the query. For example, Ref('sum_cost') in
    qs.annotate(sum_cost=Sum('cost')) query.
    """

    def __init__(self, refs, source):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self):
        return f"{self.__class__.__name__}({self.refs}, {self.source})"

    def get_source_expressions(self):
        return [self.source]

    def set_source_expressions(self, exprs):
        # Exactly one expression is expected.
        [self.source] = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # Nothing to do: `source` was already resolved; this node only
        # names it.
        return self

    def get_refs(self):
        return {self.refs}

    def relabeled_clone(self, relabels):
        return self

    def as_sql(self, compiler, connection):
        # Emit the quoted alias name only.
        return connection.ops.quote_name(self.refs), []

    def get_group_by_cols(self):
        return [self]

1184 

1185 

class ExpressionList(Func):
    """
    An expression containing multiple expressions. Can be used to provide a
    list of expressions as an argument to another expression, like a partition
    clause. The expressions are rendered joined by `arg_joiner` with no
    surrounding function call.
    """

    template = "%(expressions)s"

    def __init__(self, *expressions, **extra):
        if expressions:
            super().__init__(*expressions, **extra)
            return
        raise ValueError(
            f"{self.__class__.__name__} requires at least one expression."
        )

    def __str__(self):
        return self.arg_joiner.join(map(str, self.source_expressions))

    def as_sqlite(self, compiler, connection, **extra_context):
        # Bypass SQLiteNumericMixin's NUMERIC cast; a bare list of expressions
        # does not need it.
        return self.as_sql(compiler, connection, **extra_context)

1208 

1209 

class OrderByList(Func):
    """
    An ``ORDER BY`` clause built from expressions or field-name strings.

    A string starting with ``-`` is turned into a descending OrderBy on the
    field named by the rest of the string. With no expressions the clause
    renders as an empty string.
    """

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions, **extra):
        # str.startswith() (rather than indexing expr[0]) keeps an empty
        # string from raising a confusing IndexError here; it is passed
        # through unchanged like any other non-prefixed string.
        expressions = (
            (
                OrderBy(F(expr[1:]), descending=True)
                if isinstance(expr, str) and expr.startswith("-")
                else expr
            )
            for expr in expressions
        )
        super().__init__(*expressions, **extra)

    def as_sql(self, *args, **kwargs):
        if not self.source_expressions:
            return "", ()
        return super().as_sql(*args, **kwargs)

    def get_group_by_cols(self):
        return [
            col
            for order_by in self.get_source_expressions()
            for col in order_by.get_group_by_cols()
        ]

1234 

1235 

@deconstructible(path="plain.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
    """
    An expression that can wrap another expression so that it can provide
    extra context to the inner expression, such as the output_field.
    The wrapped expression is compiled unchanged.
    """

    def __init__(self, expression, output_field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def get_source_expressions(self):
        return [self.expression]

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_group_by_cols(self):
        if not isinstance(self.expression, Expression):
            # For non-expressions e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()
        # Ask a copy carrying this wrapper's output field, leaving the
        # original inner expression untouched.
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(self, compiler, connection):
        return compiler.compile(self.expression)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.expression})"

1267 

1268 

class NegatedExpression(ExpressionWrapper):
    """The logical negation of a conditional expression."""

    def __init__(self, expression):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self):
        # Double negation: hand back a copy of the wrapped expression.
        return self.expression.copy()

    def as_sql(self, compiler, connection):
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            # The negation of a condition that matches nothing always holds.
            if compiler.connection.features.supports_boolean_expr_in_select_clause:
                return compiler.compile(Value(True))
            return "1=1", ()
        ops = compiler.connection.ops
        if ops.conditional_expression_supported_in_where_clause(self.expression):
            return f"NOT {sql}", params
        # Some database backends (e.g. Oracle) don't allow EXISTS() and
        # filters to be compared to another expression unless they're
        # wrapped in a CASE WHEN.
        return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Resolve as usual, then reject a non-conditional inner expression (TypeError)."""
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if getattr(resolved.expression, "conditional", False):
            return resolved
        raise TypeError("Cannot negate non-conditional expressions.")

    def select_format(self, compiler, sql, params):
        # Backends that lack boolean expressions in SELECT/GROUP BY (e.g.
        # Oracle) get a CASE WHEN wrapper, unless as_sql() already produced
        # one (i.e. the inner expression is unsupported in WHERE).
        conn = compiler.connection
        needs_case = (
            not conn.features.supports_boolean_expr_in_select_clause
            and conn.ops.conditional_expression_supported_in_where_clause(
                self.expression
            )
        )
        if needs_case:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params

1318 

1319 

@deconstructible(path="plain.models.When")
class When(Expression):
    """
    A single ``WHEN condition THEN result`` branch for use inside Case().

    The condition may be a Q object, a conditional expression, keyword
    lookups, or a conditional expression combined with lookups.
    """

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False

    def __init__(self, condition=None, then=None, **lookups):
        # Fold keyword lookups into a Q object, alone or AND-ed with a
        # conditional `condition`; leftover lookups signal invalid input.
        if lookups and condition is None:
            condition, lookups = Q(**lookups), None
        elif lookups and getattr(condition, "conditional", False):
            condition, lookups = Q(condition, **lookups), None
        is_conditional = getattr(condition, "conditional", False)
        if condition is None or not is_conditional or lookups:
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition
        self.result = self._parse_expressions(then)[0]

    def __str__(self):
        return f"WHEN {self.condition!r} THEN {self.result!r}"

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self):
        return [self.condition, self.result]

    def set_source_expressions(self, exprs):
        self.condition, self.result = exprs

    def get_source_fields(self):
        # Only the result determines this branch's output field.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        resolved = self.copy()
        resolved.is_summary = summarize
        if hasattr(resolved.condition, "resolve_expression"):
            # The condition is always resolved with for_save=False; only the
            # result honours the caller's for_save.
            resolved.condition = resolved.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        resolved.result = resolved.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return resolved

    def as_sql(self, compiler, connection, template=None, **extra_context):
        """Return ``(sql, params)``; params are condition params then result params, as a tuple."""
        connection.ops.check_expression_support(self)
        condition_sql, condition_params = compiler.compile(self.condition)
        result_sql, result_params = compiler.compile(self.result)
        context = {**extra_context, "condition": condition_sql, "result": result_sql}
        sql = (template or self.template) % context
        return sql, (*condition_params, *result_params)

    def get_group_by_cols(self):
        # This is not a complete expression and cannot be used in GROUP BY;
        # contribute the columns of the condition and result instead.
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

1394 

1395 

@deconstructible(path="plain.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(self, *cases, default=None, output_field=None, **extra):
        if any(not isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self):
        branches = ", ".join(map(str, self.cases))
        return f"CASE {branches}, ELSE {self.default!r}"

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self):
        return [*self.cases, self.default]

    def set_source_expressions(self, exprs):
        # All but the last expression are the When branches; the last is
        # the default.
        *self.cases, self.default = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        resolved = self.copy()
        resolved.is_summary = summarize
        # copy() gave the clone its own cases list; refill it in place.
        resolved.cases[:] = [
            branch.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for branch in resolved.cases
        ]
        resolved.default = resolved.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return resolved

    def copy(self):
        clone = super().copy()
        clone.cases = clone.cases[:]
        return clone

    def as_sql(
        self, compiler, connection, template=None, case_joiner=None, **extra_context
    ):
        """
        Compile the CASE expression.

        Branches raising EmptyResultSet are dropped; the first branch raising
        FullResultSet makes its result the new default and ends the scan.
        With no remaining branches only the default is emitted. When an
        output field is known, the SQL is wrapped with the backend's
        unification cast.
        """
        connection.ops.check_expression_support(self)
        if not self.cases:
            return compiler.compile(self.default)
        context = {**self.extra, **extra_context}
        branch_sqls = []
        all_params = []
        fallback_sql, fallback_params = compiler.compile(self.default)
        for branch in self.cases:
            try:
                branch_sql, branch_params = compiler.compile(branch)
            except EmptyResultSet:
                continue
            except FullResultSet:
                fallback_sql, fallback_params = compiler.compile(branch.result)
                break
            branch_sqls.append(branch_sql)
            all_params.extend(branch_params)
        if not branch_sqls:
            return fallback_sql, fallback_params
        joiner = case_joiner or self.case_joiner
        context["cases"] = joiner.join(branch_sqls)
        context["default"] = fallback_sql
        all_params.extend(fallback_params)
        sql_template = template or context.get("template", self.template)
        sql = sql_template % context
        if self._output_field_or_none is not None:
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, all_params

    def get_group_by_cols(self):
        if self.cases:
            return super().get_group_by_cols()
        return self.default.get_group_by_cols()

1491 

1492 

class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery. It may contain OuterRef() references to the outer
    query which will be resolved when it is applied to that query.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(self, queryset, output_field=None, **extra):
        # Accept either a QuerySet (use its .query) or a sql.Query directly;
        # work on a clone so the caller's query is not modified.
        source_query = getattr(queryset, "query", queryset)
        self.query = source_query.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self):
        return [self.query]

    def set_source_expressions(self, exprs):
        self.query = exprs[0]

    def _resolve_output_field(self):
        return self.query.output_field

    def copy(self):
        duplicate = super().copy()
        duplicate.query = duplicate.query.clone()
        return duplicate

    @property
    def external_aliases(self):
        return self.query.external_aliases

    def get_external_cols(self):
        return self.query.get_external_cols()

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        context = {**self.extra, **extra_context}
        inner_sql, inner_params = self.query.as_sql(compiler, connection)
        # Drop the first and last characters of the compiled query (presumably
        # the parentheses the query adds when compiled as a subquery); the
        # template supplies its own wrapping.
        context["subquery"] = inner_sql[1:-1]
        sql_template = template or context.get("template", self.template)
        return sql_template % context, inner_params

    def get_group_by_cols(self):
        return self.query.get_group_by_cols(wrapper=self)

1543 

1544 

class Exists(Subquery):
    """A boolean ``EXISTS(...)`` test over a subquery."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, queryset, **kwargs):
        super().__init__(queryset, **kwargs)
        self.query = self.query.exists()

    def select_format(self, compiler, sql, params):
        if compiler.connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        # Backends (e.g. Oracle) without boolean expressions in SELECT or
        # GROUP BY need EXISTS() wrapped in a CASE WHEN.
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params

1561 

1562 

@deconstructible(path="plain.models.OrderBy")
class OrderBy(Expression):
    """
    An ordering term: an expression plus ASC/DESC direction and an optional
    NULLS FIRST / NULLS LAST placement.
    """

    template = "%(expression)s %(ordering)s"
    # Not usable as a boolean condition.
    conditional = False

    def __init__(self, expression, descending=False, nulls_first=None, nulls_last=None):
        """
        Raise ValueError if both `nulls_first` and `nulls_last` are truthy,
        if either is passed as False (only True or None are accepted), or if
        `expression` has no resolve_expression() method.
        """
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not hasattr(expression, "resolve_expression"):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self):
        return f"{self.__class__.__name__}({self.expression}, descending={self.descending})"

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def as_sql(self, compiler, connection, template=None, **extra_context):
        """
        Return ``(sql, params)`` for this ordering term.

        On backends with a NULLS FIRST/LAST modifier it is appended directly.
        Otherwise NULL placement is emulated by prefixing an
        ``expr IS [NOT] NULL`` sort key, skipped when the backend's
        ``order_by_nulls_first`` default presumably already yields the
        requested placement for this direction.
        """
        template = template or self.template
        if connection.features.supports_order_by_nulls_modifier:
            if self.nulls_last:
                template = f"{template} NULLS LAST"
            elif self.nulls_first:
                template = f"{template} NULLS FIRST"
        else:
            if self.nulls_last and not (
                self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NULL, {template}"
            elif self.nulls_first and not (
                not self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NOT NULL, {template}"
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # The expression appears once per "%(expression)s" in the template,
        # so its params are repeated to match. Use a non-mutating multiply:
        # `params` may be the compiled child's own list (RawSQL.as_sql returns
        # self.params), and `*=` would permanently duplicate it on every
        # compile.
        params = params * template.count("%(expression)s")
        return (template % placeholders).rstrip(), params

    def get_group_by_cols(self):
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

    def reverse_ordering(self):
        """Flip direction and swap any explicit NULL placement, in place; return self."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_last = True
            self.nulls_first = None
        elif self.nulls_last:
            self.nulls_first = True
            self.nulls_last = None
        return self

    def asc(self):
        self.descending = False

    def desc(self):
        self.descending = True

1636 

1637 

class Window(SQLiteNumericMixin, Expression):
    """
    A window expression: ``<expression> OVER (PARTITION BY ... ORDER BY ...
    <frame>)``. The wrapped expression must be window-compatible.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True

    def __init__(
        self,
        expression,
        partition_by=None,
        order_by=None,
        frame=None,
        output_field=None,
    ):
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                f"Expression '{expression.__class__.__name__}' isn't compatible with OVER clauses."
            )

        if partition_by is not None:
            # A single item is accepted as well as a list/tuple of them.
            items = partition_by if isinstance(partition_by, tuple | list) else (partition_by,)
            self.partition_by = ExpressionList(*items)

        if order_by is not None:
            if isinstance(order_by, list | tuple):
                self.order_by = OrderByList(*order_by)
            elif isinstance(order_by, BaseExpression | str):
                self.order_by = OrderByList(order_by)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self):
        return self.source_expression.output_field

    def get_source_expressions(self):
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs):
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(self, compiler, connection, template=None):
        connection.ops.check_expression_support(self)
        if not connection.features.supports_over_clause:
            raise NotSupportedError("This backend does not support window expressions.")
        expression_sql, expression_params = compiler.compile(self.source_expression)
        clauses = []
        clause_params = []

        if self.partition_by is not None:
            partition_sql, partition_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            clauses.append(partition_sql)
            clause_params.extend(partition_params)

        if self.order_by is not None:
            ordering_sql, ordering_params = compiler.compile(self.order_by)
            clauses.append(ordering_sql)
            clause_params.extend(ordering_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            clauses.append(frame_sql)
            clause_params.extend(frame_params)

        sql_template = template or self.template
        sql = sql_template % {
            "expression": expression_sql,
            "window": " ".join(clauses).strip(),
        }
        return sql, (*expression_params, *clause_params)

    def as_sqlite(self, compiler, connection):
        if not isinstance(self.output_field, fields.DecimalField):
            return self.as_sql(compiler, connection)
        # Casting to numeric must be outside of the window expression: compile
        # a copy whose inner expression reports FloatField and let
        # SQLiteNumericMixin cast the whole OVER(...) result.
        clone = self.copy()
        sources = clone.get_source_expressions()
        sources[0].output_field = fields.FloatField()
        clone.set_source_expressions(sources)
        return super(Window, clone).as_sqlite(compiler, connection)

    def __str__(self):
        partition = f"PARTITION BY {self.partition_by}" if self.partition_by else ""
        ordering = self.order_by or ""
        frame = self.frame or ""
        return f"{self.source_expression} OVER ({partition}{ordering}{frame})"

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self):
        cols = []
        if self.partition_by:
            cols += self.partition_by.get_group_by_cols()
        if self.order_by is not None:
            cols += self.order_by.get_group_by_cols()
        return cols

1751 

1752 

class WindowFrame(Expression):
    """
    Model the frame clause in window expressions. There are two types of frame
    clauses which are subclasses, however, all processing and validation (by no
    means intended to be complete) is done here. Thus, providing an end for a
    frame is optional (the default is UNBOUNDED FOLLOWING, which is the last
    row in the frame).

    Subclasses set `frame_type` and implement window_frame_start_end().
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"

    def __init__(self, start=None, end=None):
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs):
        self.start, self.end = exprs

    def get_source_expressions(self):
        return [self.start, self.end]

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        start, end = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        sql = self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }
        return sql, []

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self):
        return []

    def __str__(self):
        # Render with the module-level default connection's operation names:
        # negative start = N PRECEDING, positive end = N FOLLOWING, zero =
        # CURRENT ROW, and None (or any other value) = UNBOUNDED.
        ops = connection.ops
        start_value, end_value = self.start.value, self.end.value

        if start_value is None:
            start = ops.UNBOUNDED_PRECEDING
        elif start_value < 0:
            start = "%d %s" % (abs(start_value), ops.PRECEDING)
        elif start_value == 0:
            start = ops.CURRENT_ROW
        else:
            start = ops.UNBOUNDED_PRECEDING

        if end_value is None:
            end = ops.UNBOUNDED_FOLLOWING
        elif end_value > 0:
            end = "%d %s" % (end_value, ops.FOLLOWING)
        elif end_value == 0:
            end = ops.CURRENT_ROW
        else:
            end = ops.UNBOUNDED_FOLLOWING

        return self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }

    def window_frame_start_end(self, connection, start, end):
        raise NotImplementedError("Subclasses must implement window_frame_start_end().")

1817 

1818 

class RowRange(WindowFrame):
    """A ``ROWS BETWEEN ...`` window frame."""

    frame_type = "ROWS"

    def window_frame_start_end(self, connection, start, end):
        # Delegate bound rendering to the backend's ROWS helper.
        return connection.ops.window_frame_rows_start_end(start, end)

1824 

1825 

class ValueRange(WindowFrame):
    """A ``RANGE BETWEEN ...`` window frame."""

    frame_type = "RANGE"

    def window_frame_start_end(self, connection, start, end):
        # Delegate bound rendering to the backend's RANGE helper.
        return connection.ops.window_frame_range_start_end(start, end)