Coverage for /Users/davegaeddert/Development/dropseed/plain/plain-models/plain/models/expressions.py: 44%
983 statements
« prev ^ index » next — coverage.py v7.6.1, created at 2024-10-16 22:04 -0500
1import copy
2import datetime
3import functools
4import inspect
5from collections import defaultdict
6from decimal import Decimal
7from types import NoneType
8from uuid import UUID
10from plain.exceptions import EmptyResultSet, FieldError, FullResultSet
11from plain.models import fields
12from plain.models.constants import LOOKUP_SEP
13from plain.models.db import DatabaseError, NotSupportedError, connection
14from plain.models.query_utils import Q
15from plain.utils.deconstruct import deconstructible
16from plain.utils.functional import cached_property
17from plain.utils.hashable import make_hashable
class SQLiteNumericMixin:
    """
    Mixin for SQLite: when the expression's output_field is a DecimalField,
    wrap its SQL in CAST(... AS NUMERIC) so it is filtered numerically.
    """

    def as_sqlite(self, compiler, connection, **extra_context):
        sql, params = self.as_sql(compiler, connection, **extra_context)
        try:
            wants_numeric = self.output_field.get_internal_type() == "DecimalField"
        except FieldError:
            # No resolvable output type: leave the SQL untouched.
            wants_numeric = False
        if wants_numeric:
            sql = "CAST(%s AS NUMERIC)" % sql
        return sql, params
class Combinable:
    """
    Operator support for query expressions.

    Arithmetic operators build a CombinedExpression from ``self`` and the other
    operand, wrapping plain Python values in Value(); e.g. F('foo') + F('bar').
    ``&``, ``|`` and ``^`` only combine two conditional expressions into Q
    objects; the SQL bitwise operators are spelled bitand(), bitor(),
    bitxor(), bitleftshift() and bitrightshift().
    """

    # Arithmetic connectors
    ADD = "+"
    SUB = "-"
    MUL = "*"
    DIV = "/"
    POW = "^"
    # % is written doubled because the connector ends up in SQL strings that
    # also go through %-style parameter substitution.
    MOD = "%%"

    # Bitwise connectors. Only the bit*() methods produce these, because the
    # & and | operators are reserved for boolean combination.
    BITAND = "&"
    BITOR = "|"
    BITLEFTSHIFT = "<<"
    BITRIGHTSHIFT = ">>"
    BITXOR = "#"

    @staticmethod
    def _bitwise_operator_error():
        # Shared error for &, |, ^ and their reflected forms when the operands
        # are not both conditional.
        return NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def _combine(self, other, connector, reversed):
        """
        Return ``self <connector> other`` as a CombinedExpression, or
        ``other <connector> self`` when ``reversed`` is true. Operands without
        resolve_expression() are wrapped in Value().
        """
        if hasattr(other, "resolve_expression"):
            operand = other
        else:
            operand = Value(other)
        lhs, rhs = (operand, self) if reversed else (self, operand)
        return CombinedExpression(lhs, connector, rhs)

    #############
    # OPERATORS #
    #############

    def __neg__(self):
        # -x is built as x * -1.
        return self._combine(-1, self.MUL, False)

    def __add__(self, other):
        return self._combine(other, self.ADD, False)

    def __sub__(self, other):
        return self._combine(other, self.SUB, False)

    def __mul__(self, other):
        return self._combine(other, self.MUL, False)

    def __truediv__(self, other):
        return self._combine(other, self.DIV, False)

    def __mod__(self, other):
        return self._combine(other, self.MOD, False)

    def __pow__(self, other):
        return self._combine(other, self.POW, False)

    def __and__(self, other):
        if not (
            getattr(self, "conditional", False) and getattr(other, "conditional", False)
        ):
            raise self._bitwise_operator_error()
        return Q(self) & Q(other)

    def bitand(self, other):
        return self._combine(other, self.BITAND, False)

    def bitleftshift(self, other):
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other):
        return self._combine(other, self.BITRIGHTSHIFT, False)

    def __xor__(self, other):
        if not (
            getattr(self, "conditional", False) and getattr(other, "conditional", False)
        ):
            raise self._bitwise_operator_error()
        return Q(self) ^ Q(other)

    def bitxor(self, other):
        return self._combine(other, self.BITXOR, False)

    def __or__(self, other):
        if not (
            getattr(self, "conditional", False) and getattr(other, "conditional", False)
        ):
            raise self._bitwise_operator_error()
        return Q(self) | Q(other)

    def bitor(self, other):
        return self._combine(other, self.BITOR, False)

    # Reflected operators: the Python value is on the left-hand side.

    def __radd__(self, other):
        return self._combine(other, self.ADD, True)

    def __rsub__(self, other):
        return self._combine(other, self.SUB, True)

    def __rmul__(self, other):
        return self._combine(other, self.MUL, True)

    def __rtruediv__(self, other):
        return self._combine(other, self.DIV, True)

    def __rmod__(self, other):
        return self._combine(other, self.MOD, True)

    def __rpow__(self, other):
        return self._combine(other, self.POW, True)

    def __rand__(self, other):
        raise self._bitwise_operator_error()

    def __ror__(self, other):
        raise self._bitwise_operator_error()

    def __rxor__(self, other):
        raise self._bitwise_operator_error()

    def __invert__(self):
        return NegatedExpression(self)
class BaseExpression:
    """Base class for all query expressions."""

    empty_result_set_value = NotImplemented
    # aggregate specific fields
    is_summary = False
    _output_field_resolved_to_none = False
    # Can the expression be used in a WHERE clause?
    filterable = True
    # Can the expression be used as a source expression in Window?
    window_compatible = False

    def __init__(self, output_field=None):
        # Only an explicit output_field is stored; otherwise the
        # ``output_field`` cached_property infers one on first access.
        if output_field is not None:
            self.output_field = output_field

    def __getstate__(self):
        # Leave the cached ``convert_value`` out of the pickled state: it may
        # hold a lambda (see convert_value()), which pickle cannot serialize.
        state = self.__dict__.copy()
        state.pop("convert_value", None)
        return state

    def get_db_converters(self, connection):
        """
        Return the converters for values fetched for this expression: its own
        convert_value (skipped when it is the no-op) followed by the
        converters of its output_field.
        """
        return (
            []
            if self.convert_value is self._convert_value_noop
            else [self.convert_value]
        ) + self.output_field.get_db_converters(connection)

    def get_source_expressions(self):
        """Return the child expressions of this node (none by default)."""
        return []

    def set_source_expressions(self, exprs):
        """Replace the child expressions; the default accepts only an empty list."""
        assert not exprs

    def _parse_expressions(self, *expressions):
        """
        Coerce arguments into expressions: objects with resolve_expression()
        are kept as-is, strings become F() references and anything else is
        wrapped in Value().
        """
        return [
            arg
            if hasattr(arg, "resolve_expression")
            else (F(arg) if isinstance(arg, str) else Value(arg))
            for arg in expressions
        ]

    def as_sql(self, compiler, connection):
        """
        Responsible for returning a (sql, [params]) tuple to be included
        in the current query.

        Different backends can provide their own implementation, by
        providing an `as_{vendor}` method and patching the Expression:

        ```
        def override_as_sql(self, compiler, connection):
            # custom logic
            return self.as_sql(compiler, connection)
        setattr(Expression, 'as_' + connection.vendor, override_as_sql)
        ```

        Arguments:
         * compiler: the query compiler responsible for generating the query.
           Must have a compile method, returning a (sql, [params]) tuple.
           compiler.compile(expression) returns the SQL and params of a
           child expression.

         * connection: the database connection used for the current query.

        Return: (sql, params)
          Where `sql` is a string containing ordered sql parameters to be
          replaced with the elements of the list `params`.
        """
        raise NotImplementedError("Subclasses must implement as_sql()")

    @cached_property
    def contains_aggregate(self):
        """True if any (non-None) source expression contains an aggregate."""
        return any(
            expr and expr.contains_aggregate for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_over_clause(self):
        """True if any (non-None) source expression contains an OVER clause."""
        return any(
            expr and expr.contains_over_clause for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_column_references(self):
        """True if any (non-None) source expression references a column."""
        return any(
            expr and expr.contains_column_references
            for expr in self.get_source_expressions()
        )

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Provide the chance to do any preprocessing or validation before being
        added to the query.

        Arguments:
         * query: the backend query implementation
         * allow_joins: boolean allowing or denying use of joins
           in this query
         * reuse: a set of reusable joins for multijoins
         * summarize: a terminal aggregate clause
         * for_save: whether this expression is about to be used in a save or
           update

        Return: an Expression to be added to the query.

        The default implementation resolves the source expressions of a copy;
        ``for_save`` is not forwarded to them.
        """
        c = self.copy()
        c.is_summary = summarize
        c.set_source_expressions(
            [
                expr.resolve_expression(query, allow_joins, reuse, summarize)
                if expr
                else None
                for expr in c.get_source_expressions()
            ]
        )
        return c

    @property
    def conditional(self):
        """True when the output field is a BooleanField."""
        return isinstance(self.output_field, fields.BooleanField)

    @property
    def field(self):
        """Alias for output_field."""
        return self.output_field

    @cached_property
    def output_field(self):
        """Return the output type of this expression."""
        output_field = self._resolve_output_field()
        if output_field is None:
            # Remember that inference yielded nothing so _output_field_or_none
            # can tell this apart from other FieldErrors.
            self._output_field_resolved_to_none = True
            raise FieldError("Cannot resolve expression type, unknown output_field")
        return output_field

    @cached_property
    def _output_field_or_none(self):
        """
        Return the output field of this expression, or None if
        _resolve_output_field() didn't return an output type. Any other
        FieldError is re-raised.
        """
        try:
            return self.output_field
        except FieldError:
            if not self._output_field_resolved_to_none:
                raise

    def _resolve_output_field(self):
        """
        Attempt to infer the output type of the expression.

        As a guess, if the output fields of all source fields match then simply
        infer the same type here.

        If a source's output field resolves to None, exclude it from this check.
        If all sources are None, then an error is raised higher up the stack in
        the output_field property.
        """
        # This guess is mostly a bad idea, but there is quite a lot of code
        # (especially 3rd party Func subclasses) that depend on it, we'd need a
        # deprecation path to fix it.
        sources_iter = (
            source for source in self.get_source_fields() if source is not None
        )
        # The outer loop takes the first source; the inner loop drains the
        # rest of the same generator, so the outer loop runs at most once.
        for output_field in sources_iter:
            for source in sources_iter:
                if not isinstance(output_field, source.__class__):
                    raise FieldError(
                        "Expression contains mixed types: {}, {}. You must "
                        "set output_field.".format(
                            output_field.__class__.__name__,
                            source.__class__.__name__,
                        )
                    )
            return output_field

    @staticmethod
    def _convert_value_noop(value, expression, connection):
        return value

    @cached_property
    def convert_value(self):
        """
        Expressions provide their own converters because users have the option
        of manually specifying the output_field which may be a different type
        from the one the database returns.
        """
        field = self.output_field
        internal_type = field.get_internal_type()
        if internal_type == "FloatField":
            return (
                lambda value, expression, connection: None
                if value is None
                else float(value)
            )
        elif internal_type.endswith("IntegerField"):
            return (
                lambda value, expression, connection: None
                if value is None
                else int(value)
            )
        elif internal_type == "DecimalField":
            return (
                lambda value, expression, connection: None
                if value is None
                else Decimal(value)
            )
        return self._convert_value_noop

    def get_lookup(self, lookup):
        return self.output_field.get_lookup(lookup)

    def get_transform(self, name):
        return self.output_field.get_transform(name)

    def relabeled_clone(self, change_map):
        """Return a copy whose source expressions are relabeled via change_map."""
        clone = self.copy()
        clone.set_source_expressions(
            [
                e.relabeled_clone(change_map) if e is not None else None
                for e in self.get_source_expressions()
            ]
        )
        return clone

    def replace_expressions(self, replacements):
        """
        Return ``replacements[self]`` if present (and truthy); otherwise a copy
        with replacements applied recursively to the source expressions.
        """
        if replacement := replacements.get(self):
            return replacement
        clone = self.copy()
        source_expressions = clone.get_source_expressions()
        clone.set_source_expressions(
            [
                expr.replace_expressions(replacements) if expr else None
                for expr in source_expressions
            ]
        )
        return clone

    def get_refs(self):
        """Return the union of get_refs() over the non-None source expressions."""
        refs = set()
        for expr in self.get_source_expressions():
            # Source expressions may be None (relabeled_clone() and flatten()
            # handle that case too).
            if expr is not None:
                refs |= expr.get_refs()
        return refs

    def copy(self):
        return copy.copy(self)

    def prefix_references(self, prefix):
        """
        Return a copy whose F() source expressions have ``prefix`` prepended
        to their names; other non-None sources are prefixed recursively.
        """
        clone = self.copy()
        clone.set_source_expressions(
            [
                None
                if expr is None
                else F(f"{prefix}{expr.name}")
                if isinstance(expr, F)
                else expr.prefix_references(prefix)
                for expr in self.get_source_expressions()
            ]
        )
        return clone

    def get_group_by_cols(self):
        """
        Return the expressions to GROUP BY: the expression itself, or, when it
        contains an aggregate, the group-by columns of its non-None sources.
        """
        if not self.contains_aggregate:
            return [self]
        cols = []
        for source in self.get_source_expressions():
            if source is not None:
                cols.extend(source.get_group_by_cols())
        return cols

    def get_source_fields(self):
        """Return the output fields (or None) of this expression's sources."""
        return [e._output_field_or_none for e in self.get_source_expressions()]

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def reverse_ordering(self):
        return self

    def flatten(self):
        """
        Recursively yield this expression and all subexpressions, in
        depth-first order.
        """
        yield self
        for expr in self.get_source_expressions():
            if expr:
                if hasattr(expr, "flatten"):
                    yield from expr.flatten()
                else:
                    yield expr

    def select_format(self, compiler, sql, params):
        """
        Custom format for select clauses. For example, EXISTS expressions need
        to be wrapped in CASE WHEN on Oracle.
        """
        if hasattr(self.output_field, "select_format"):
            return self.output_field.select_format(compiler, sql, params)
        return sql, params
@deconstructible
class Expression(BaseExpression, Combinable):
    """An expression that can be combined with other expressions."""

    @cached_property
    def identity(self):
        """
        Hashable tuple identifying this expression: its class followed by one
        (argument name, value) pair per __init__() argument, defaults filled in.

        A Field argument is reduced to (model label, field name) when it has
        both a name and a model, and to its type otherwise; every other value
        goes through make_hashable().
        """
        init_signature = inspect.signature(self.__init__)
        # _constructor_args is not set in this class; presumably it is
        # recorded by @deconstructible — verify.
        args, kwargs = self._constructor_args
        bound = init_signature.bind_partial(*args, **kwargs)
        bound.apply_defaults()
        parts = [self.__class__]
        for arg_name, arg_value in bound.arguments.items():
            if not isinstance(arg_value, fields.Field):
                key = make_hashable(arg_value)
            elif arg_value.name and arg_value.model:
                key = (arg_value.model._meta.label, arg_value.name)
            else:
                key = type(arg_value)
            parts.append((arg_name, key))
        return tuple(parts)

    def __eq__(self, other):
        if isinstance(other, Expression):
            return other.identity == self.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)
# Type inference for CombinedExpression.output_field.
# Missing items will result in FieldError, by design.
#
# The current approach for NULL is based on lowest common denominator behavior
# i.e. if one of the supported databases is raising an error (rather than
# returning NULL) for `val <op> NULL`, then Plain raises FieldError.
#
# Each dict maps a connector to a list of
# (lhs field class, rhs field class, result field class) triples. They are
# registered via register_combinable_fields() and matched (with issubclass)
# by _resolve_combined_type().
_connector_combinations = [
    # Numeric operations - operands of same type.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
            (fields.FloatField, fields.FloatField, fields.FloatField),
            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            # Behavior for DIV with integer arguments follows Postgres/SQLite,
            # not MySQL/Oracle.
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Numeric operations - operands of different type. POW is not listed, so
    # mixed-type exponentiation gets no inferred type unless registered
    # elsewhere.
    {
        connector: [
            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
            (fields.IntegerField, fields.FloatField, fields.FloatField),
            (fields.FloatField, fields.IntegerField, fields.FloatField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
        )
    },
    # Bitwise operators.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
        ]
        for connector in (
            Combinable.BITAND,
            Combinable.BITOR,
            Combinable.BITLEFTSHIFT,
            Combinable.BITRIGHTSHIFT,
            Combinable.BITXOR,
        )
    },
    # Numeric with NULL. NoneType stands for an operand whose output field
    # resolved to None. Every numeric type gets its own pair of rules: the
    # field types are expanded inside each connector's list, because keying a
    # comprehension on `connector` while also looping over field types would
    # overwrite each key and keep only the last (FloatField) rules.
    {
        connector: [
            rule
            for field_type in (
                fields.IntegerField,
                fields.DecimalField,
                fields.FloatField,
            )
            for rule in (
                (field_type, NoneType, field_type),
                (NoneType, field_type, field_type),
            )
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Date/DateTimeField/DurationField/TimeField.
    {
        Combinable.ADD: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DurationField, fields.DateField, fields.DateTimeField),
            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.DurationField, fields.TimeField, fields.TimeField),
        ],
    },
    {
        Combinable.SUB: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DateField, fields.DateField, fields.DurationField),
            (fields.DateField, fields.DateTimeField, fields.DurationField),
            (fields.DateTimeField, fields.DateField, fields.DurationField),
            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.TimeField, fields.TimeField, fields.DurationField),
        ],
    },
]

# Registry: connector -> list of (lhs, rhs, result) field-class triples.
_connector_combinators = defaultdict(list)
def register_combinable_fields(lhs, connector, rhs, result):
    """
    Declare that ``lhs <connector> rhs`` produces a ``result`` field type.

    Example::

        register_combinable_fields(
            IntegerField, Combinable.ADD, FloatField, FloatField
        )

    NOTE(review): _resolve_combined_type() memoizes its lookups with
    lru_cache, so a rule registered after a matching lookup was cached is
    not seen until that cache is cleared.
    """
    rules_for_connector = _connector_combinators[connector]
    rules_for_connector.append((lhs, rhs, result))
# Register every built-in (lhs, connector, rhs, result) rule from
# _connector_combinations. NOTE(review): being a module-level loop, this
# leaves d, connector, field_types, lhs, rhs and result bound as module
# globals.
for d in _connector_combinations:
    for connector, field_types in d.items():
        for lhs, rhs, result in field_types:
            register_combinable_fields(lhs, connector, rhs, result)
@functools.lru_cache(maxsize=128)
def _resolve_combined_type(connector, lhs_type, rhs_type):
    """
    Return the field class produced by ``lhs_type <connector> rhs_type``.

    Rules registered for ``connector`` are tried in registration order; the
    first whose lhs/rhs classes are superclasses of the given types wins.
    Return None when no rule matches. Results are memoized per
    (connector, lhs_type, rhs_type).
    """
    rules = _connector_combinators.get(connector, ())
    matches = (
        result_type
        for lhs_base, rhs_base, result_type in rules
        if issubclass(lhs_type, lhs_base) and issubclass(rhs_type, rhs_base)
    )
    return next(matches, None)
class CombinedExpression(SQLiteNumericMixin, Expression):
    """
    Binary expression ``lhs <connector> rhs``.

    resolve_expression() may return a DurationExpression (exactly one operand
    is a DurationField) or a TemporalSubtraction (SUB of two values of the
    same date/time type) instead of a resolved copy of this node.
    """

    def __init__(self, lhs, connector, rhs, output_field=None):
        super().__init__(output_field=output_field)
        self.connector = connector
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def __str__(self):
        return f"{self.lhs} {self.connector} {self.rhs}"

    def get_source_expressions(self):
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs):
        self.lhs, self.rhs = exprs

    def _resolve_output_field(self):
        # We avoid using super() here for reasons given in
        # BaseExpression._resolve_output_field()
        lhs_type = type(self.lhs._output_field_or_none)
        rhs_type = type(self.rhs._output_field_or_none)
        combined_type = _resolve_combined_type(self.connector, lhs_type, rhs_type)
        if combined_type is None:
            # Report the types computed above. Re-reading ``output_field`` on
            # the operands would raise a different, less helpful FieldError
            # when an operand's type resolved to None.
            raise FieldError(
                f"Cannot infer type of {self.connector!r} expression involving these "
                f"types: {lhs_type.__name__}, "
                f"{rhs_type.__name__}. You must set "
                f"output_field."
            )
        return combined_type()

    def as_sql(self, compiler, connection):
        expressions = []
        expression_params = []
        sql, params = compiler.compile(self.lhs)
        expressions.append(sql)
        expression_params.extend(params)
        sql, params = compiler.compile(self.rhs)
        expressions.append(sql)
        expression_params.extend(params)
        # order of precedence
        expression_wrapper = "(%s)"
        sql = connection.ops.combine_expression(self.connector, expressions)
        return expression_wrapper % sql, expression_params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        lhs = self.lhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        rhs = self.rhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if not isinstance(self, DurationExpression | TemporalSubtraction):
            try:
                lhs_type = lhs.output_field.get_internal_type()
            except (AttributeError, FieldError):
                lhs_type = None
            try:
                rhs_type = rhs.output_field.get_internal_type()
            except (AttributeError, FieldError):
                rhs_type = None
            # Exactly one operand is a duration.
            # NOTE(review): the replacement is built from the unresolved
            # operands and without this node's output_field.
            if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
                return DurationExpression(
                    self.lhs, self.connector, self.rhs
                ).resolve_expression(
                    query,
                    allow_joins,
                    reuse,
                    summarize,
                    for_save,
                )
            datetime_fields = {"DateField", "DateTimeField", "TimeField"}
            if (
                self.connector == self.SUB
                and lhs_type in datetime_fields
                and lhs_type == rhs_type
            ):
                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
                    query,
                    allow_joins,
                    reuse,
                    summarize,
                    for_save,
                )
        c = self.copy()
        c.is_summary = summarize
        c.lhs = lhs
        c.rhs = rhs
        return c
class DurationExpression(CombinedExpression):
    """
    CombinedExpression where exactly one operand is a DurationField (created
    by CombinedExpression.resolve_expression()). Backends without a native
    duration field get duration-specific SQL.
    """

    def compile(self, side, compiler, connection):
        """Compile one operand, formatting DurationField operands for arithmetic."""
        try:
            output = side.output_field
        except FieldError:
            is_duration = False
        else:
            is_duration = output.get_internal_type() == "DurationField"
        if not is_duration:
            return compiler.compile(side)
        sql, params = compiler.compile(side)
        return connection.ops.format_for_duration_arithmetic(sql), params

    def as_sql(self, compiler, connection):
        if connection.features.has_native_duration_field:
            return super().as_sql(compiler, connection)
        connection.ops.check_expression_support(self)
        sql_parts = []
        all_params = []
        for side in (self.lhs, self.rhs):
            part_sql, part_params = self.compile(side, compiler, connection)
            sql_parts.append(part_sql)
            all_params.extend(part_params)
        combined = connection.ops.combine_duration_expression(self.connector, sql_parts)
        # Parenthesize to preserve precedence inside larger expressions.
        return "(%s)" % combined, all_params

    def as_sqlite(self, compiler, connection, **extra_context):
        sql, params = self.as_sql(compiler, connection, **extra_context)
        if self.connector not in {Combinable.MUL, Combinable.DIV}:
            return sql, params
        try:
            operand_types = (
                self.lhs.output_field.get_internal_type(),
                self.rhs.output_field.get_internal_type(),
            )
        except (AttributeError, FieldError):
            return sql, params
        allowed_fields = {
            "DecimalField",
            "DurationField",
            "FloatField",
            "IntegerField",
        }
        # Multiplying/dividing a duration is only accepted with numeric or
        # duration operands.
        if not allowed_fields.issuperset(operand_types):
            raise DatabaseError(
                f"Invalid arguments for operator {self.connector}."
            )
        return sql, params
class TemporalSubtraction(CombinedExpression):
    """
    ``lhs - rhs`` producing a DurationField. CombinedExpression's
    resolve_expression() uses it when subtracting two operands of the same
    date/time type.
    """

    output_field = fields.DurationField()

    def __init__(self, lhs, rhs):
        super().__init__(lhs, self.SUB, rhs)

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        lhs_compiled = compiler.compile(self.lhs)
        rhs_compiled = compiler.compile(self.rhs)
        internal_type = self.lhs.output_field.get_internal_type()
        return connection.ops.subtract_temporals(
            internal_type, lhs_compiled, rhs_compiled
        )
@deconstructible(path="plain.models.F")
class F(Combinable):
    """A named reference, resolved lazily through query.resolve_ref()."""

    def __init__(self, name):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self):
        return f"{type(self).__name__}({self.name})"

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # for_save is accepted for interface compatibility; it is not passed on.
        return query.resolve_ref(self.name, allow_joins, reuse, summarize)

    def replace_expressions(self, replacements):
        return replacements.get(self, self)

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def __eq__(self, other):
        # Equal only to instances of exactly the same class with the same name.
        if self.__class__ == other.__class__:
            return self.name == other.name
        return False

    def __hash__(self):
        return hash(self.name)

    def copy(self):
        return copy.copy(self)
class ResolvedOuterRef(F):
    """
    An object that contains a reference to an outer query.

    In this case, the reference to the outer query has been resolved because
    the inner query has been used as a subquery.
    """

    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, *args, **kwargs):
        raise ValueError(
            "This queryset contains a reference to an outer query and may "
            "only be used in a subquery."
        )

    def resolve_expression(self, *args, **kwargs):
        resolved = super().resolve_expression(*args, **kwargs)
        if resolved.contains_over_clause:
            raise NotSupportedError(
                f"Referencing outer query window expression is not supported: "
                f"{self.name}."
            )
        # FIXME: Rename possibly_multivalued to multivalued and fix detection
        # for non-multivalued JOINs (e.g. foreign key fields). This should take
        # into account only many-to-many and one-to-many relationships.
        resolved.possibly_multivalued = LOOKUP_SEP in self.name
        return resolved

    def relabeled_clone(self, relabels):
        # Outer references are never relabeled.
        return self

    def get_group_by_cols(self):
        return []
class OuterRef(F):
    """Reference to a name in an enclosing query, resolved to ResolvedOuterRef."""

    contains_aggregate = False

    def resolve_expression(self, *args, **kwargs):
        target = self.name
        # OuterRef(OuterRef(...)) unwraps one level, presumably so a nested
        # subquery can reach a query further out.
        if isinstance(target, self.__class__):
            return target
        return ResolvedOuterRef(target)

    def relabeled_clone(self, relabels):
        return self
@deconstructible(path="plain.models.Func")
class Func(SQLiteNumericMixin, Expression):
    """An SQL function call."""

    function = None
    template = "%(function)s(%(expressions)s)"
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(self, *expressions, output_field=None, **extra):
        if self.arity is not None and len(expressions) != self.arity:
            noun = "argument" if self.arity == 1 else "arguments"
            raise TypeError(
                "'{}' takes exactly {} {} ({} given)".format(
                    self.__class__.__name__, self.arity, noun, len(expressions)
                )
            )
        super().__init__(output_field=output_field)
        self.source_expressions = self._parse_expressions(*expressions)
        # Extra keyword arguments become template substitutions in as_sql().
        self.extra = extra

    def __repr__(self):
        args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
        options = {**self.extra, **self._get_repr_options()}
        if not options:
            return f"{self.__class__.__name__}({args})"
        rendered = ", ".join(
            "=".join((str(key), str(val))) for key, val in sorted(options.items())
        )
        return f"{self.__class__.__name__}({args}, {rendered})"

    def _get_repr_options(self):
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self):
        return self.source_expressions

    def set_source_expressions(self, exprs):
        self.source_expressions = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        c = self.copy()
        c.is_summary = summarize
        # copy() gave c its own list, so it can be updated in place.
        c.source_expressions[:] = [
            arg.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for arg in c.source_expressions
        ]
        return c

    def as_sql(
        self,
        compiler,
        connection,
        function=None,
        template=None,
        arg_joiner=None,
        **extra_context,
    ):
        connection.ops.check_expression_support(self)
        compiled_args = []
        params = []
        for arg in self.source_expressions:
            try:
                arg_sql, arg_params = compiler.compile(arg)
            except EmptyResultSet:
                # Substitute the argument's fallback value, if it has one.
                fallback = getattr(arg, "empty_result_set_value", NotImplemented)
                if fallback is NotImplemented:
                    raise
                arg_sql, arg_params = compiler.compile(Value(fallback))
            except FullResultSet:
                arg_sql, arg_params = compiler.compile(Value(True))
            compiled_args.append(arg_sql)
            params.extend(arg_params)
        data = {**self.extra, **extra_context}
        # Use the first supplied value in this order: the parameter to this
        # method, a value supplied in __init__()'s **extra (the value in
        # `data`), or the value defined on the class.
        if function is None:
            data.setdefault("function", self.function)
        else:
            data["function"] = function
        template = template or data.get("template", self.template)
        arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
        joined = arg_joiner.join(compiled_args)
        data["expressions"] = joined
        data["field"] = joined
        return template % data, params

    def copy(self):
        clone = super().copy()
        # Give the clone its own argument list and extra dict.
        clone.source_expressions = self.source_expressions[:]
        clone.extra = self.extra.copy()
        return clone
989@deconstructible(path="plain.models.Value")
990class Value(SQLiteNumericMixin, Expression):
991 """Represent a wrapped value as a node within an expression."""
993 # Provide a default value for `for_save` in order to allow unresolved
994 # instances to be compiled until a decision is taken in #25425.
995 for_save = False
997 def __init__(self, value, output_field=None):
998 """
999 Arguments:
1000 * value: the value this expression represents. The value will be
1001 added into the sql parameter list and properly quoted.
1003 * output_field: an instance of the model field type that this
1004 expression will return, such as IntegerField() or CharField().
1005 """
1006 super().__init__(output_field=output_field)
1007 self.value = value
1009 def __repr__(self):
1010 return f"{self.__class__.__name__}({self.value!r})"
1012 def as_sql(self, compiler, connection):
1013 connection.ops.check_expression_support(self)
1014 val = self.value
1015 output_field = self._output_field_or_none
1016 if output_field is not None:
1017 if self.for_save:
1018 val = output_field.get_db_prep_save(val, connection=connection)
1019 else:
1020 val = output_field.get_db_prep_value(val, connection=connection)
1021 if hasattr(output_field, "get_placeholder"):
1022 return output_field.get_placeholder(val, compiler, connection), [val]
1023 if val is None:
1024 # cx_Oracle does not always convert None to the appropriate
1025 # NULL type (like in case expressions using numbers), so we
1026 # use a literal SQL NULL
1027 return "NULL", []
1028 return "%s", [val]
1030 def resolve_expression(
1031 self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
1032 ):
1033 c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
1034 c.for_save = for_save
1035 return c
1037 def get_group_by_cols(self):
1038 return []
def _resolve_output_field(self):
    """
    Infer an output field instance from the Python type of ``self.value``.

    Returns None when the type is not one of the recognised ones.
    """
    # Checked in order: bool is a subclass of int and datetime.datetime is a
    # subclass of datetime.date, so the more specific type must come first.
    type_to_field = (
        (str, fields.CharField),
        (bool, fields.BooleanField),
        (int, fields.IntegerField),
        (float, fields.FloatField),
        (datetime.datetime, fields.DateTimeField),
        (datetime.date, fields.DateField),
        (datetime.time, fields.TimeField),
        (datetime.timedelta, fields.DurationField),
        (Decimal, fields.DecimalField),
        (bytes, fields.BinaryField),
        (UUID, fields.UUIDField),
    )
    for python_type, field_class in type_to_field:
        if isinstance(self.value, python_type):
            return field_class()
    return None
@property
def empty_result_set_value(self):
    """The value reported when the query is known to match nothing: the literal itself."""
    literal = self.value
    return literal
class RawSQL(Expression):
    """
    A raw SQL fragment with its own parameters.

    Compiles to ``(<sql>)`` with ``params`` passed through unchanged. The
    output field defaults to a generic ``fields.Field()``.
    """

    def __init__(self, sql, params, output_field=None):
        if output_field is None:
            output_field = fields.Field()
        self.sql, self.params = sql, params
        super().__init__(output_field=output_field)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.sql}, {self.params})"

    def as_sql(self, compiler, connection):
        return "(%s)" % self.sql, self.params

    def get_group_by_cols(self):
        return [self]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve parent-model fields referenced in the raw SQL, then resolve
        normally.

        For each parent model of ``query.model``, the first local field whose
        column name appears (case-insensitively) in ``self.sql`` is resolved on
        the query. When no query is given there is nothing to resolve against,
        so this step is skipped instead of failing on ``None.model``.
        """
        if query is not None and query.model:
            # Invariant across both loops; compute once.
            sql_lower = self.sql.lower()
            for parent in query.model._meta.get_parent_list():
                for parent_field in parent._meta.local_fields:
                    _, column_name = parent_field.get_attname_column()
                    if column_name.lower() in sql_lower:
                        query.resolve_ref(
                            parent_field.name, allow_joins, reuse, summarize
                        )
                        # One matching field per parent is enough.
                        break
        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
class Star(Expression):
    """The bare SQL ``*`` wildcard; compiles to ``*`` with no parameters."""

    def __repr__(self):
        return "'*'"

    def as_sql(self, compiler, connection):
        wildcard = "*"
        return wildcard, []
class Col(Expression):
    """
    A reference to the column of ``target``, optionally qualified by a table
    alias. The output field defaults to ``target`` itself.
    """

    contains_column_references = True
    possibly_multivalued = False

    def __init__(self, alias, target, output_field=None):
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias, self.target = alias, target

    def __repr__(self):
        parts = [str(self.target)]
        if self.alias:
            parts.insert(0, self.alias)
        return "%s(%s)" % (self.__class__.__name__, ", ".join(parts))

    def as_sql(self, compiler, connection):
        quote = compiler.quote_name_unless_alias
        parts = [self.target.column]
        if self.alias:
            parts.insert(0, self.alias)
        return ".".join(quote(part) for part in parts), []

    def relabeled_clone(self, relabels):
        # An unqualified column has no alias to relabel.
        if self.alias is None:
            return self
        new_alias = relabels.get(self.alias, self.alias)
        return type(self)(new_alias, self.target, self.output_field)

    def get_group_by_cols(self):
        return [self]

    def get_db_converters(self, connection):
        same_field = self.target == self.output_field
        converters = self.output_field.get_db_converters(connection)
        if same_field:
            return converters
        # A distinct output field: apply its converters, then the target's.
        return converters + self.target.get_db_converters(connection)
class Ref(Expression):
    """
    Refer by name to a column alias already present in the query, e.g.
    ``Ref('sum_cost')`` within ``qs.annotate(sum_cost=Sum('cost'))``.

    ``refs`` is the alias and ``source`` the expression it names. The SQL is
    just the quoted alias.
    """

    def __init__(self, refs, source):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self):
        return "{}({}, {})".format(type(self).__name__, self.refs, self.source)

    def get_source_expressions(self):
        return [self.source]

    def set_source_expressions(self, exprs):
        [self.source] = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # Nothing to do: `source` was resolved before this alias was created;
        # this object only refers to it by name.
        return self

    def get_refs(self):
        return set([self.refs])

    def relabeled_clone(self, relabels):
        return self

    def as_sql(self, compiler, connection):
        quoted_alias = connection.ops.quote_name(self.refs)
        return quoted_alias, []

    def get_group_by_cols(self):
        return [self]
class ExpressionList(Func):
    """
    A bare, joined list of one or more expressions. Useful as an argument to
    another expression, for example a PARTITION BY clause.
    """

    template = "%(expressions)s"

    def __init__(self, *expressions, **extra):
        if len(expressions) == 0:
            raise ValueError(
                f"{self.__class__.__name__} requires at least one expression."
            )
        super().__init__(*expressions, **extra)

    def __str__(self):
        return self.arg_joiner.join(map(str, self.source_expressions))

    def as_sqlite(self, compiler, connection, **extra_context):
        # Compile as-is: casting the list to NUMERIC is unnecessary.
        return self.as_sql(compiler, connection, **extra_context)
class OrderByList(Func):
    """
    An ``ORDER BY`` clause built from expressions or field-name strings.

    A string with a leading ``-`` (e.g. ``"-created"``) becomes a descending
    ``OrderBy`` on the named field. Every other argument is handed to
    ``Func`` unchanged. With no expressions, the clause compiles to an empty
    string.
    """

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions, **extra):
        # startswith() (rather than expr[0]) keeps an empty string from
        # raising IndexError here; it is passed through like any other string.
        expressions = (
            (
                OrderBy(F(expr[1:]), descending=True)
                if isinstance(expr, str) and expr.startswith("-")
                else expr
            )
            for expr in expressions
        )
        super().__init__(*expressions, **extra)

    def as_sql(self, *args, **kwargs):
        if not self.source_expressions:
            return "", ()
        return super().as_sql(*args, **kwargs)

    def get_group_by_cols(self):
        # Group by the columns of every ordering term, in order.
        return [
            col
            for order_by in self.get_source_expressions()
            for col in order_by.get_group_by_cols()
        ]
@deconstructible(path="plain.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
    """
    Wrap another expression to give it extra context, most notably an
    explicit ``output_field``. Compiles to the wrapped expression's SQL.
    """

    def __init__(self, expression, output_field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def get_group_by_cols(self):
        if not isinstance(self.expression, Expression):
            # For non-expressions e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()
        # Group by the inner expression, typed with this wrapper's output field.
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(self, compiler, connection):
        return compiler.compile(self.expression)

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self.expression)
class NegatedExpression(ExpressionWrapper):
    """Logical NOT of a conditional expression; its output is a BooleanField."""

    def __init__(self, expression):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self):
        # Negating a negation gives back (a copy of) the wrapped expression.
        return self.expression.copy()

    def as_sql(self, compiler, connection):
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            # NOT of something that matches nothing matches everything.
            features = compiler.connection.features
            if features.supports_boolean_expr_in_select_clause:
                return compiler.compile(Value(True))
            return "1=1", ()
        ops = compiler.connection.ops
        # Some database backends (e.g. Oracle) don't allow EXISTS() and filters
        # to be compared to another expression unless they're wrapped in a CASE
        # WHEN.
        if ops.conditional_expression_supported_in_where_clause(self.expression):
            return f"NOT {sql}", params
        return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        is_conditional = getattr(resolved.expression, "conditional", False)
        if is_conditional:
            return resolved
        raise TypeError("Cannot negate non-conditional expressions.")

    def select_format(self, compiler, sql, params):
        # Backends that can't put a boolean expression in SELECT or GROUP BY
        # (e.g. Oracle) need a CASE WHEN wrapper -- unless as_sql() already
        # produced one, which is the case when the expression isn't supported
        # in a WHERE clause.
        conn = compiler.connection
        needs_case_wrapper = (
            not conn.features.supports_boolean_expr_in_select_clause
            and conn.ops.conditional_expression_supported_in_where_clause(
                self.expression
            )
        )
        if needs_case_wrapper:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params
@deconstructible(path="plain.models.When")
class When(Expression):
    """
    A single ``WHEN <condition> THEN <result>`` branch, meant to be used
    inside ``Case()``.

    The condition may be a ``Q`` object or a conditional expression. Keyword
    lookups are folded into a ``Q``, together with the condition if one was
    given.
    """

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False

    def __init__(self, condition=None, then=None, **lookups):
        if lookups and condition is None:
            condition, lookups = Q(**lookups), None
        elif lookups and getattr(condition, "conditional", False):
            condition, lookups = Q(condition, **lookups), None
        is_valid_condition = (
            condition is not None
            and getattr(condition, "conditional", False)
            and not lookups
        )
        if not is_valid_condition:
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition
        self.result = self._parse_expressions(then)[0]

    def __str__(self):
        return "WHEN %r THEN %r" % (self.condition, self.result)

    def __repr__(self):
        return "<{}: {}>".format(type(self).__name__, self)

    def get_source_expressions(self):
        return [self.condition, self.result]

    def set_source_expressions(self, exprs):
        self.condition, self.result = exprs

    def get_source_fields(self):
        # Only the result's field matters for the output type.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        clone = self.copy()
        clone.is_summary = summarize
        if hasattr(clone.condition, "resolve_expression"):
            # The condition is always resolved with for_save=False.
            clone.condition = clone.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        clone.result = clone.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return clone

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        condition_sql, condition_params = compiler.compile(self.condition)
        result_sql, result_params = compiler.compile(self.result)
        placeholders = {
            **extra_context,
            "condition": condition_sql,
            "result": result_sql,
        }
        sql = (template or self.template) % placeholders
        return sql, (*condition_params, *result_params)

    def get_group_by_cols(self):
        # A bare WHEN branch is not a complete expression and cannot itself
        # appear in GROUP BY; group by the columns of its parts instead.
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]
@deconstructible(path="plain.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    An SQL searched CASE expression built from ``When`` branches and a default:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(self, *cases, default=None, output_field=None, **extra):
        if any(not isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self):
        branches = ", ".join(map(str, self.cases))
        return f"CASE {branches}, ELSE {self.default!r}"

    def __repr__(self):
        return "<{}: {}>".format(type(self).__name__, self)

    def get_source_expressions(self):
        return [*self.cases, self.default]

    def set_source_expressions(self, exprs):
        *self.cases, self.default = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        clone = self.copy()
        clone.is_summary = summarize
        clone.cases = [
            branch.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for branch in clone.cases
        ]
        clone.default = clone.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return clone

    def copy(self):
        clone = super().copy()
        # Give the copy its own list so changes to its branches don't leak back.
        clone.cases = list(clone.cases)
        return clone

    def as_sql(
        self, compiler, connection, template=None, case_joiner=None, **extra_context
    ):
        connection.ops.check_expression_support(self)
        if not self.cases:
            return compiler.compile(self.default)
        template_params = {**self.extra, **extra_context}
        default_sql, default_params = compiler.compile(self.default)
        branch_sqls = []
        branch_params = []
        for branch in self.cases:
            try:
                branch_sql, params = compiler.compile(branch)
            except EmptyResultSet:
                # This branch can never match; leave it out.
                continue
            except FullResultSet:
                # This branch always matches: its result replaces the ELSE
                # value and any later branches are unreachable.
                default_sql, default_params = compiler.compile(branch.result)
                break
            branch_sqls.append(branch_sql)
            branch_params.extend(params)
        if not branch_sqls:
            return default_sql, default_params
        template_params["cases"] = (case_joiner or self.case_joiner).join(branch_sqls)
        template_params["default"] = default_sql
        chosen_template = template or template_params.get("template", self.template)
        sql = chosen_template % template_params
        if self._output_field_or_none is not None:
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, [*branch_params, *default_params]

    def get_group_by_cols(self):
        if self.cases:
            return super().get_group_by_cols()
        # With no branches the CASE compiles to just its default.
        return self.default.get_group_by_cols()
class Subquery(BaseExpression, Combinable):
    """
    A query embedded in another query. It may contain OuterRef() references,
    which are resolved once it is applied to the outer query.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(self, queryset, output_field=None, **extra):
        # Accept either a QuerySet (use its .query) or an sql.Query directly.
        inner = getattr(queryset, "query", queryset)
        self.query = inner.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self):
        return [self.query]

    def set_source_expressions(self, exprs):
        self.query = exprs[0]

    def _resolve_output_field(self):
        return self.query.output_field

    def copy(self):
        duplicate = super().copy()
        # Clone the inner query as well so the two objects don't share it.
        duplicate.query = duplicate.query.clone()
        return duplicate

    @property
    def external_aliases(self):
        return self.query.external_aliases

    def get_external_cols(self):
        return self.query.get_external_cols()

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        subquery_sql, sql_params = self.query.as_sql(compiler, connection)
        template_params = {
            **self.extra,
            **extra_context,
            # Drop the first and last characters (presumably the parentheses
            # the compiled subquery comes wrapped in); the template supplies
            # its own.
            "subquery": subquery_sql[1:-1],
        }
        chosen_template = template or template_params.get("template", self.template)
        return chosen_template % template_params, sql_params

    def get_group_by_cols(self):
        return self.query.get_group_by_cols(wrapper=self)
class Exists(Subquery):
    """An ``EXISTS(...)`` test over a subquery, typed as a BooleanField."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, queryset, **kwargs):
        super().__init__(queryset, **kwargs)
        self.query = self.query.exists()

    def select_format(self, compiler, sql, params):
        # Backends that can't select a boolean expression directly (e.g.
        # Oracle) get EXISTS() wrapped in CASE WHEN in SELECT / GROUP BY.
        if compiler.connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params
@deconstructible(path="plain.models.OrderBy")
class OrderBy(Expression):
    """
    One ORDER BY term: an expression, a direction, and optional NULL placement.

    ``nulls_first`` / ``nulls_last`` may each be True or None (not False),
    and at most one of them may be set.
    """

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(self, expression, descending=False, nulls_first=None, nulls_last=None):
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not hasattr(expression, "resolve_expression"):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self):
        return "{}({}, descending={})".format(
            self.__class__.__name__, self.expression, self.descending
        )

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def as_sql(self, compiler, connection, template=None, **extra_context):
        """
        Compile to ``<expr> ASC|DESC``, adding NULL-placement handling.

        Backends with a NULLS FIRST/LAST modifier get it appended. Otherwise
        the placement is emulated by first sorting on ``<expr> IS NULL`` /
        ``<expr> IS NOT NULL``, which references the expression a second time
        and therefore repeats its params.
        """
        template = template or self.template
        if connection.features.supports_order_by_nulls_modifier:
            if self.nulls_last:
                template = "%s NULLS LAST" % template
            elif self.nulls_first:
                template = "%s NULLS FIRST" % template
        else:
            # Emulation is skipped when features.order_by_nulls_first says the
            # backend's default placement already gives the requested order
            # for this direction.
            if self.nulls_last and not (
                self.descending and connection.features.order_by_nulls_first
            ):
                template = "%%(expression)s IS NULL, %s" % template
            elif self.nulls_first and not (
                not self.descending and connection.features.order_by_nulls_first
            ):
                template = "%%(expression)s IS NOT NULL, %s" % template
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # Build a new sequence instead of `params *= ...`: the compiled child
        # may hand back a list it still owns (RawSQL returns self.params), and
        # an in-place repeat would grow that list on every compilation.
        params = params * template.count("%(expression)s")
        return (template % placeholders).rstrip(), params

    def get_group_by_cols(self):
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def reverse_ordering(self):
        """Flip direction and NULL placement in place; return self."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_last = True
            self.nulls_first = None
        elif self.nulls_last:
            self.nulls_first = True
            self.nulls_last = None
        return self

    def asc(self):
        self.descending = False

    def desc(self):
        self.descending = True
class Window(SQLiteNumericMixin, Expression):
    """
    A window function call: ``<expression> OVER (<window>)``.

    The window can contain a PARTITION BY list, an ORDER BY list and a frame
    clause; each part is optional. ``expression`` must be window-compatible.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True

    def __init__(
        self,
        expression,
        partition_by=None,
        order_by=None,
        frame=None,
        output_field=None,
    ):
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                "Expression '%s' isn't compatible with OVER clauses."
                % expression.__class__.__name__
            )

        if self.partition_by is not None:
            if not isinstance(self.partition_by, tuple | list):
                self.partition_by = (self.partition_by,)
            self.partition_by = ExpressionList(*self.partition_by)

        if self.order_by is not None:
            if isinstance(self.order_by, list | tuple):
                self.order_by = OrderByList(*self.order_by)
            elif isinstance(self.order_by, BaseExpression | str):
                self.order_by = OrderByList(self.order_by)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self):
        return self.source_expression.output_field

    def get_source_expressions(self):
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs):
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(self, compiler, connection, template=None):
        connection.ops.check_expression_support(self)
        if not connection.features.supports_over_clause:
            raise NotSupportedError("This backend does not support window expressions.")
        expr_sql, params = compiler.compile(self.source_expression)
        window_sql, window_params = [], ()

        if self.partition_by is not None:
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            window_sql.append(sql_expr)
            window_params += tuple(sql_params)

        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append(order_sql)
            window_params += tuple(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_sql.append(frame_sql)
            window_params += tuple(frame_params)

        template = template or self.template

        return (
            template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
            (*params, *window_params),
        )

    def as_sqlite(self, compiler, connection):
        if isinstance(self.output_field, fields.DecimalField):
            # Casting to numeric must be outside of the window expression.
            clone = self.copy()
            source_expressions = clone.get_source_expressions()
            # copy() is shallow, so the source expression is still shared with
            # self; copy it before changing its output_field.
            inner = source_expressions[0].copy()
            inner.output_field = fields.FloatField()
            source_expressions[0] = inner
            clone.set_source_expressions(source_expressions)
            return super(Window, clone).as_sqlite(compiler, connection)
        return self.as_sql(compiler, connection)

    def __str__(self):
        # Separate the window parts with spaces so that e.g. a PARTITION BY
        # and an ORDER BY don't run together.
        window_parts = []
        if self.partition_by:
            window_parts.append("PARTITION BY " + str(self.partition_by))
        if self.order_by:
            window_parts.append(str(self.order_by))
        if self.frame:
            window_parts.append(str(self.frame))
        return "{} OVER ({})".format(
            str(self.source_expression), " ".join(window_parts)
        )

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self):
        group_by_cols = []
        if self.partition_by:
            group_by_cols.extend(self.partition_by.get_group_by_cols())
        if self.order_by is not None:
            group_by_cols.extend(self.order_by.get_group_by_cols())
        return group_by_cols
class WindowFrame(Expression):
    """
    Base class for the frame clause of a window expression.

    Subclasses set ``frame_type`` and implement ``window_frame_start_end``.
    ``start``/``end`` are wrapped in ``Value``. In ``__str__``:
      - a negative start means N PRECEDING, zero means CURRENT ROW, and
        anything else means UNBOUNDED PRECEDING;
      - a positive end means N FOLLOWING, zero means CURRENT ROW, and
        anything else means UNBOUNDED FOLLOWING.
    Validation is deliberately minimal.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"

    def __init__(self, start=None, end=None):
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs):
        self.start, self.end = exprs

    def get_source_expressions(self):
        return [self.start, self.end]

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        start, end = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        clause = self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }
        return clause, []

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self):
        return []

    def __str__(self):
        # Uses the module-level default connection's ops for the keywords.
        ops = connection.ops
        start_value = self.start.value
        end_value = self.end.value

        if start_value is None:
            start = ops.UNBOUNDED_PRECEDING
        elif start_value < 0:
            start = "%d %s" % (abs(start_value), ops.PRECEDING)
        elif start_value == 0:
            start = ops.CURRENT_ROW
        else:
            start = ops.UNBOUNDED_PRECEDING

        if end_value is None:
            end = ops.UNBOUNDED_FOLLOWING
        elif end_value > 0:
            end = "%d %s" % (end_value, ops.FOLLOWING)
        elif end_value == 0:
            end = ops.CURRENT_ROW
        else:
            end = ops.UNBOUNDED_FOLLOWING

        return self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }

    def window_frame_start_end(self, connection, start, end):
        raise NotImplementedError("Subclasses must implement window_frame_start_end().")
class RowRange(WindowFrame):
    """A ``ROWS BETWEEN ... AND ...`` window frame."""

    frame_type = "ROWS"

    def window_frame_start_end(self, connection, start, end):
        bounds = connection.ops.window_frame_rows_start_end(start, end)
        return bounds
class ValueRange(WindowFrame):
    """A ``RANGE BETWEEN ... AND ...`` window frame."""

    frame_type = "RANGE"

    def window_frame_start_end(self, connection, start, end):
        bounds = connection.ops.window_frame_range_start_end(start, end)
        return bounds