Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/lookups.py: 62%
448 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
1import itertools
2import math
4from plain.exceptions import EmptyResultSet, FullResultSet
5from plain.models.expressions import Expression, Func, Value
6from plain.models.fields import (
7 BooleanField,
8 CharField,
9 DateTimeField,
10 Field,
11 IntegerField,
12 UUIDField,
13)
14from plain.models.query_utils import RegisterLookupMixin
15from plain.utils.datastructures import OrderedSet
16from plain.utils.functional import cached_property
17from plain.utils.hashable import make_hashable
class Lookup(Expression):
    """
    Base class for a comparison between a left-hand side (``lhs``) and a
    right-hand side (``rhs``) in a query.

    On construction the rhs is prepared with get_prep_lookup() and the lhs
    with get_prep_lhs(). Bilateral transforms found on the lhs are recorded
    in ``bilateral_transforms`` so rhs values can be passed through the same
    transforms at compile time. A lookup's own output field is a
    BooleanField.
    """

    lookup_name = None
    # When False, get_prep_lookup() returns the rhs unchanged.
    prepare_rhs = True
    can_use_none_as_rhs = False

    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs
        # Prepare the rhs while self.lhs is still the raw lhs
        # (get_prep_lookup() reads self.lhs.output_field), then the lhs.
        self.rhs = self.get_prep_lookup()
        self.lhs = self.get_prep_lhs()
        if hasattr(self.lhs, "get_bilateral_transforms"):
            bilateral_transforms = self.lhs.get_bilateral_transforms()
        else:
            bilateral_transforms = []
        if bilateral_transforms:
            # Warn the user as soon as possible if they are trying to apply
            # a bilateral transformation on a nested QuerySet: that won't work.
            from plain.models.sql.query import Query  # avoid circular import

            if isinstance(rhs, Query):
                raise NotImplementedError(
                    "Bilateral transformations on nested querysets are not implemented."
                )
        self.bilateral_transforms = bilateral_transforms

    def apply_bilateral_transforms(self, value):
        """Wrap ``value`` in each recorded bilateral transform, in order."""
        for transform in self.bilateral_transforms:
            value = transform(value)
        return value

    def __repr__(self):
        return f"{self.__class__.__name__}({self.lhs!r}, {self.rhs!r})"

    def batch_process_rhs(self, compiler, connection, rhs=None):
        """
        Compile an iterable rhs (``self.rhs`` by default) item by item.

        Returns ``(sqls, params)``. With bilateral transforms each item is
        wrapped in a Value, transformed, resolved and compiled; otherwise
        get_db_prep_lookup() supplies the params and each param gets a
        ``%s`` placeholder.
        """
        if rhs is None:
            rhs = self.rhs
        if self.bilateral_transforms:
            sqls, sqls_params = [], []
            for p in rhs:
                value = Value(p, output_field=self.lhs.output_field)
                value = self.apply_bilateral_transforms(value)
                value = value.resolve_expression(compiler.query)
                sql, sql_params = compiler.compile(value)
                sqls.append(sql)
                sqls_params.extend(sql_params)
        else:
            _, params = self.get_db_prep_lookup(rhs, connection)
            sqls, sqls_params = ["%s"] * len(params), params
        return sqls, sqls_params

    def get_source_expressions(self):
        """Return ``[lhs]``, plus the rhs when it is an expression."""
        if self.rhs_is_direct_value():
            return [self.lhs]
        return [self.lhs, self.rhs]

    def set_source_expressions(self, new_exprs):
        """Accept ``[lhs]`` or ``[lhs, rhs]`` (see get_source_expressions)."""
        if len(new_exprs) == 1:
            self.lhs = new_exprs[0]
        else:
            self.lhs, self.rhs = new_exprs

    def get_prep_lookup(self):
        """
        Return the prepared rhs.

        Expressions, and any rhs when ``prepare_rhs`` is False, pass through
        unchanged. Otherwise the lhs output field's get_prep_value() is used
        when available. When the lhs has no output_field, a direct value is
        wrapped in Value().
        """
        if not self.prepare_rhs or hasattr(self.rhs, "resolve_expression"):
            return self.rhs
        if hasattr(self.lhs, "output_field"):
            if hasattr(self.lhs.output_field, "get_prep_value"):
                return self.lhs.output_field.get_prep_value(self.rhs)
        elif self.rhs_is_direct_value():
            return Value(self.rhs)
        return self.rhs

    def get_prep_lhs(self):
        """Return the lhs, wrapping a non-expression in Value()."""
        if hasattr(self.lhs, "resolve_expression"):
            return self.lhs
        return Value(self.lhs)

    def get_db_prep_lookup(self, value, connection):
        """Return ``("%s", [value])``; mixins override this to convert values."""
        return ("%s", [value])

    def process_lhs(self, compiler, connection, lhs=None):
        """
        Compile the lhs (or an explicit ``lhs`` override) to ``(sql, params)``.

        A nested Lookup is wrapped in parentheses so operator precedence is
        preserved.
        """
        # Compare against None explicitly (as batch_process_rhs() does) so an
        # override object that happens to evaluate as falsy is still used
        # instead of silently falling back to self.lhs.
        if lhs is None:
            lhs = self.lhs
        if hasattr(lhs, "resolve_expression"):
            lhs = lhs.resolve_expression(compiler.query)
        sql, params = compiler.compile(lhs)
        if isinstance(lhs, Lookup):
            # Wrapped in parentheses to respect operator precedence.
            sql = f"({sql})"
        return sql, params

    def process_rhs(self, compiler, connection):
        """
        Compile the rhs to ``(sql, params)``.

        With bilateral transforms, a direct value is first wrapped in Value
        and transformed. Anything with as_sql() is compiled and parenthesized.
        Other values go through get_db_prep_lookup().
        """
        value = self.rhs
        if self.bilateral_transforms:
            if self.rhs_is_direct_value():
                # Do not call get_db_prep_lookup here as the value will be
                # transformed before being used for lookup
                value = Value(value, output_field=self.lhs.output_field)
            value = self.apply_bilateral_transforms(value)
            value = value.resolve_expression(compiler.query)
        if hasattr(value, "as_sql"):
            sql, params = compiler.compile(value)
            # Ensure expression is wrapped in parentheses to respect operator
            # precedence but avoid double wrapping as it can be misinterpreted
            # on some backends (e.g. subqueries on SQLite).
            if sql and sql[0] != "(":
                sql = f"({sql})"
            return sql, params
        else:
            return self.get_db_prep_lookup(value, connection)

    def rhs_is_direct_value(self):
        """True when the rhs has no as_sql(), i.e. it is a plain value."""
        return not hasattr(self.rhs, "as_sql")

    def get_group_by_cols(self):
        """Collect GROUP BY columns from every source expression."""
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    @cached_property
    def output_field(self):
        return BooleanField()

    @property
    def identity(self):
        """Tuple used for equality and hashing: (class, lhs, rhs)."""
        return self.__class__, self.lhs, self.rhs

    def __eq__(self, other):
        if not isinstance(other, Lookup):
            return NotImplemented
        return self.identity == other.identity

    def __hash__(self):
        return hash(make_hashable(self.identity))

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Return a copy with the lhs (and an expression rhs) resolved."""
        c = self.copy()
        c.is_summary = summarize
        c.lhs = self.lhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if hasattr(self.rhs, "resolve_expression"):
            c.rhs = self.rhs.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        return c

    def select_format(self, compiler, sql, params):
        # Wrap filters with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params
class Transform(RegisterLookupMixin, Func):
    """
    Single-argument Func usable as a step in a lookup chain.

    RegisterLookupMixin() is first so that get_lookup() and get_transform()
    first examine self and then check output_field. Setting
    ``bilateral = True`` makes a Lookup apply this transform to its rhs too.
    """

    arity = 1
    bilateral = False

    @property
    def lhs(self):
        """The expression this transform wraps (its first source)."""
        sources = self.get_source_expressions()
        return sources[0]

    def get_bilateral_transforms(self):
        """
        Return the bilateral transform classes of this chain: those reported
        by the wrapped expression, followed by this class if it is bilateral.
        """
        source = self.lhs
        transforms = (
            source.get_bilateral_transforms()
            if hasattr(source, "get_bilateral_transforms")
            else []
        )
        if self.bilateral:
            transforms.append(self.__class__)
        return transforms
class BuiltinLookup(Lookup):
    """
    Lookup rendered as ``<lhs> <operator>`` where the operator template comes
    from ``connection.operators[lookup_name]`` and the lhs gets the backend's
    field and lookup casts.
    """

    def process_lhs(self, compiler, connection, lhs=None):
        sql, params = super().process_lhs(compiler, connection, lhs)
        field = self.lhs.output_field
        internal_type = field.get_internal_type()
        db_type = field.db_type(connection=connection)
        # The field cast wraps the raw column; the lookup cast wraps that.
        sql = connection.ops.field_cast_sql(db_type, internal_type) % sql
        sql = connection.ops.lookup_cast(self.lookup_name, internal_type) % sql
        return sql, list(params)

    def as_sql(self, compiler, connection):
        lhs_sql, params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        params.extend(rhs_params)
        operator_sql = self.get_rhs_op(connection, rhs_sql)
        return f"{lhs_sql} {operator_sql}", params

    def get_rhs_op(self, connection, rhs):
        template = connection.operators[self.lookup_name]
        return template % rhs
class FieldGetDbPrepValueMixin:
    """
    Lookup mixin that passes rhs values through Field.get_db_prep_value().

    For relational fields the output field's ``target_field`` converter is
    preferred when it exists. When ``get_db_prep_lookup_value_is_iterable``
    is true, each item of the value is converted separately.
    """

    get_db_prep_lookup_value_is_iterable = False

    def get_db_prep_lookup(self, value, connection):
        output_field = self.lhs.output_field
        target = getattr(output_field, "target_field", None)
        convert = (
            getattr(target, "get_db_prep_value", None)
            or output_field.get_db_prep_value
        )
        items = value if self.get_db_prep_lookup_value_is_iterable else [value]
        return "%s", [convert(item, connection, prepared=True) for item in items]
class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
    """
    Variant of FieldGetDbPrepValueMixin for an iterable rhs: each item is
    prepared and converted individually, and items that are expressions are
    compiled in place of their ``%s`` placeholder.
    """

    get_db_prep_lookup_value_is_iterable = True

    def get_prep_lookup(self):
        if hasattr(self.rhs, "resolve_expression"):
            return self.rhs

        def prep(item):
            # An expression will be handled by the database but can coexist
            # alongside real values.
            if hasattr(item, "resolve_expression"):
                return item
            if self.prepare_rhs and hasattr(self.lhs.output_field, "get_prep_value"):
                return self.lhs.output_field.get_prep_value(item)
            return item

        return [prep(item) for item in self.rhs]

    def process_rhs(self, compiler, connection):
        if not self.rhs_is_direct_value():
            return super().process_rhs(compiler, connection)
        # A direct rhs is an iterable of values; batch_process_rhs()
        # prepares/transforms each of them.
        return self.batch_process_rhs(compiler, connection)

    def resolve_expression_parameter(self, compiler, connection, sql, param):
        # The fallback keeps the param as it was before any resolving.
        fallback_params = [param]
        if hasattr(param, "resolve_expression"):
            param = param.resolve_expression(compiler.query)
        if hasattr(param, "as_sql"):
            compiled_sql, compiled_params = compiler.compile(param)
            return compiled_sql, compiled_params
        return sql, fallback_params

    def batch_process_rhs(self, compiler, connection, rhs=None):
        placeholders, values = super().batch_process_rhs(compiler, connection, rhs)
        # Values may be expressions that compile to their own sql/params
        # pair; resolve each placeholder/value pair and substitute the
        # compiled result where there is one.
        resolved = [
            self.resolve_expression_parameter(compiler, connection, sql, param)
            for sql, param in zip(placeholders, values)
        ]
        sqls, param_groups = zip(*resolved)
        return sqls, tuple(itertools.chain.from_iterable(param_groups))
class PostgresOperatorLookup(Lookup):
    """
    Lookup defined by operators on PostgreSQL.

    Subclasses set ``postgres_operator``; as_postgresql() renders
    ``<lhs> <postgres_operator> <rhs>`` with the params returned as a tuple.
    """

    postgres_operator = None

    def as_postgresql(self, compiler, connection):
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        sql = f"{lhs_sql} {self.postgres_operator} {rhs_sql}"
        return sql, (*lhs_params, *rhs_params)
@Field.register_lookup
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
    """
    ``exact`` lookup.

    A Query rhs must be limited to one result. If it has no explicit select
    fields, it is made to select only "pk". A bool rhs against a conditional
    lhs is rendered as a bare (or NOT-ed) condition when the backend allows.
    """

    lookup_name = "exact"

    def get_prep_lookup(self):
        from plain.models.sql.query import Query  # avoid circular import

        if isinstance(self.rhs, Query):
            if not self.rhs.has_limit_one():
                raise ValueError(
                    "The QuerySet value for an exact lookup must be limited to "
                    "one result using slicing."
                )
            if not self.rhs.has_select_fields:
                self.rhs.clear_select_clause()
                self.rhs.add_fields(["pk"])
        return super().get_prep_lookup()

    def as_sql(self, compiler, connection):
        # Avoid comparison against direct rhs if lhs is a boolean value. That
        # turns "boolfield__exact=True" into "WHERE boolean_field" instead of
        # "WHERE boolean_field = True" when allowed.
        use_bare_condition = (
            isinstance(self.rhs, bool)
            and getattr(self.lhs, "conditional", False)
            and connection.ops.conditional_expression_supported_in_where_clause(
                self.lhs
            )
        )
        if not use_bare_condition:
            return super().as_sql(compiler, connection)
        lhs_sql, params = self.process_lhs(compiler, connection)
        condition = lhs_sql if self.rhs else f"NOT {lhs_sql}"
        return condition, params
@Field.register_lookup
class IExact(BuiltinLookup):
    """
    ``iexact`` lookup. The rhs is not prepared, and its first param is
    passed through the backend's prep_for_iexact_query().
    """

    lookup_name = "iexact"
    prepare_rhs = False

    def process_rhs(self, qn, connection):
        rhs_sql, rhs_params = super().process_rhs(qn, connection)
        if rhs_params:
            rhs_params[0] = connection.ops.prep_for_iexact_query(rhs_params[0])
        return rhs_sql, rhs_params
@Field.register_lookup
class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Greater-than (``gt``) lookup; rhs converted via get_db_prep_value()."""

    lookup_name = "gt"


@Field.register_lookup
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Greater-than-or-equal (``gte``) lookup; rhs converted via get_db_prep_value()."""

    lookup_name = "gte"


@Field.register_lookup
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Less-than (``lt``) lookup; rhs converted via get_db_prep_value()."""

    lookup_name = "lt"


@Field.register_lookup
class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Less-than-or-equal (``lte``) lookup; rhs converted via get_db_prep_value()."""

    lookup_name = "lte"
class IntegerFieldOverflow:
    """
    Lookup mixin that short-circuits when an int rhs lies outside the
    backend's range for the lhs field's internal type.

    An rhs below the range raises ``underflow_exception`` and one above it
    raises ``overflow_exception``. Both default to EmptyResultSet, and the
    integer comparison lookups override one of them with FullResultSet.
    """

    underflow_exception = EmptyResultSet
    overflow_exception = EmptyResultSet

    def process_rhs(self, compiler, connection):
        value = self.rhs
        if not isinstance(value, int):
            return super().process_rhs(compiler, connection)
        internal_type = self.lhs.output_field.get_internal_type()
        lower, upper = connection.ops.integer_field_range(internal_type)
        if lower is not None and value < lower:
            raise self.underflow_exception
        if upper is not None and value > upper:
            raise self.overflow_exception
        return super().process_rhs(compiler, connection)
class IntegerFieldFloatRounding:
    """
    Allow floats to work as query values for IntegerField. Without this, the
    decimal portion of the float would always be discarded.

    A float rhs is replaced by math.ceil() of itself before the normal
    preparation runs.
    """

    def get_prep_lookup(self):
        current = self.rhs
        if isinstance(current, float):
            self.rhs = math.ceil(current)
        return super().get_prep_lookup()
@IntegerField.register_lookup
class IntegerFieldExact(IntegerFieldOverflow, Exact):
    """``exact`` for integer fields; an out-of-range rhs raises EmptyResultSet."""


@IntegerField.register_lookup
class IntegerGreaterThan(IntegerFieldOverflow, GreaterThan):
    """``gt`` for integer fields; an rhs below the range raises FullResultSet."""

    underflow_exception = FullResultSet


@IntegerField.register_lookup
class IntegerGreaterThanOrEqual(
    IntegerFieldOverflow, IntegerFieldFloatRounding, GreaterThanOrEqual
):
    """
    ``gte`` for integer fields; a float rhs is rounded up with math.ceil()
    and an rhs below the range raises FullResultSet.
    """

    underflow_exception = FullResultSet


@IntegerField.register_lookup
class IntegerLessThan(IntegerFieldOverflow, IntegerFieldFloatRounding, LessThan):
    """
    ``lt`` for integer fields; a float rhs is rounded up with math.ceil()
    and an rhs above the range raises FullResultSet.
    """

    overflow_exception = FullResultSet


@IntegerField.register_lookup
class IntegerLessThanOrEqual(IntegerFieldOverflow, LessThanOrEqual):
    """``lte`` for integer fields; an rhs above the range raises FullResultSet."""

    overflow_exception = FullResultSet
@Field.register_lookup
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    """
    ``in`` lookup against an iterable of values or a subquery.

    ``None`` items are dropped from a direct iterable, and an iterable left
    empty raises EmptyResultSet. When the rhs is longer than the backend's
    max_in_list_size(), the SQL is split into ORed ``IN`` clauses.
    """

    lookup_name = "in"

    def get_prep_lookup(self):
        from plain.models.sql.query import Query  # avoid circular import

        subquery = self.rhs
        if isinstance(subquery, Query):
            # Ordering is cleared; a subquery without explicit select fields
            # is made to select only "pk".
            subquery.clear_ordering(clear_default=True)
            if not subquery.has_select_fields:
                subquery.clear_select_clause()
                subquery.add_fields(["pk"])
        return super().get_prep_lookup()

    def process_rhs(self, compiler, connection):
        rhs_db = getattr(self.rhs, "_db", None)
        if rhs_db is not None and rhs_db != connection.alias:
            raise ValueError(
                "Subqueries aren't allowed across different databases. Force "
                "the inner query to be evaluated using `list(inner_query)`."
            )

        if not self.rhs_is_direct_value():
            return super().process_rhs(compiler, connection)

        # Remove None from the list as NULL is never equal to anything.
        try:
            values = OrderedSet(self.rhs)
            values.discard(None)
        except TypeError:  # Unhashable items in self.rhs
            values = [item for item in self.rhs if item is not None]

        if not values:
            raise EmptyResultSet

        # values is an iterable; batch_process_rhs() prepares/transforms it.
        sqls, sqls_params = self.batch_process_rhs(compiler, connection, values)
        return f"({', '.join(sqls)})", sqls_params

    def get_rhs_op(self, connection, rhs):
        return f"IN {rhs}"

    def as_sql(self, compiler, connection):
        limit = connection.ops.max_in_list_size()
        too_many = self.rhs_is_direct_value() and limit and len(self.rhs) > limit
        if too_many:
            return self.split_parameter_list_as_sql(compiler, connection)
        return super().as_sql(compiler, connection)

    def split_parameter_list_as_sql(self, compiler, connection):
        # This is a special case for databases which limit the number of
        # elements which can appear in an 'IN' clause.
        chunk_size = connection.ops.max_in_list_size()
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.batch_process_rhs(compiler, connection)
        clauses = []
        params = []
        for start in range(0, len(rhs_params), chunk_size):
            stop = start + chunk_size
            clauses.append(f"{lhs} IN ({', '.join(rhs[start:stop])})")
            params.extend(lhs_params)
            params.extend(rhs_params[start:stop])
        return "(" + " OR ".join(clauses) + ")", params
class PatternLookup(BuiltinLookup):
    """
    Base for LIKE-style lookups.

    A plain Python rhs is escaped with prep_for_like_query() and wrapped with
    ``param_pattern``; the default ``%%%s%%`` is the one Contains uses. For
    an expression rhs, or when bilateral transforms apply, the backend's
    ``pattern_ops`` template is used instead.
    """

    param_pattern = "%%%s%%"
    prepare_rhs = False

    def get_rhs_op(self, connection, rhs):
        # Assume we are in startswith. We need to produce SQL like:
        #     col LIKE %s, ['thevalue%']
        # For python values we can (and should) do that directly in Python,
        # but if the value is for example reference to other column, then
        # we need to add the % pattern match to the lookup by something like
        #     col LIKE othercol || '%%'
        # So, for Python values we don't need any special pattern, but for
        # SQL reference values or SQL transformations we need the correct
        # pattern added.
        if not hasattr(self.rhs, "as_sql") and not self.bilateral_transforms:
            return super().get_rhs_op(connection, rhs)
        pattern = connection.pattern_ops[self.lookup_name].format(
            connection.pattern_esc
        )
        return pattern.format(rhs)

    def process_rhs(self, qn, connection):
        rhs_sql, rhs_params = super().process_rhs(qn, connection)
        is_plain_value = (
            self.rhs_is_direct_value() and not self.bilateral_transforms
        )
        if is_plain_value and rhs_params:
            escaped = connection.ops.prep_for_like_query(rhs_params[0])
            rhs_params[0] = self.param_pattern % escaped
        return rhs_sql, rhs_params
@Field.register_lookup
class Contains(PatternLookup):
    """``contains`` lookup using the default ``%%%s%%`` pattern."""

    lookup_name = "contains"


@Field.register_lookup
class IContains(Contains):
    """``icontains`` lookup; same pattern as Contains."""

    lookup_name = "icontains"


@Field.register_lookup
class StartsWith(PatternLookup):
    """``startswith`` lookup: the value is followed by a ``%`` wildcard."""

    lookup_name = "startswith"
    param_pattern = "%s%%"


@Field.register_lookup
class IStartsWith(StartsWith):
    """``istartswith`` lookup; same pattern as StartsWith."""

    lookup_name = "istartswith"


@Field.register_lookup
class EndsWith(PatternLookup):
    """``endswith`` lookup: the value is preceded by a ``%`` wildcard."""

    lookup_name = "endswith"
    param_pattern = "%%%s"


@Field.register_lookup
class IEndsWith(EndsWith):
    """``iendswith`` lookup; same pattern as EndsWith."""

    lookup_name = "iendswith"
@Field.register_lookup
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    """``range`` lookup rendered as ``BETWEEN <first> AND <second>``."""

    lookup_name = "range"

    def get_rhs_op(self, connection, rhs):
        low, high = rhs[0], rhs[1]
        return f"BETWEEN {low} AND {high}"
@Field.register_lookup
class IsNull(BuiltinLookup):
    """``isnull`` lookup; the rhs must be a bool (True → IS NULL)."""

    lookup_name = "isnull"
    prepare_rhs = False

    def as_sql(self, compiler, connection):
        if not isinstance(self.rhs, bool):
            raise ValueError(
                "The QuerySet value for an isnull lookup must be True or False."
            )
        sql, params = self.process_lhs(compiler, connection)
        operator = "IS NULL" if self.rhs else "IS NOT NULL"
        return f"{sql} {operator}", params
@Field.register_lookup
class Regex(BuiltinLookup):
    """
    ``regex`` lookup.

    Uses the backend's ``connection.operators`` entry when one exists for
    this lookup name; otherwise renders the template returned by
    ``connection.ops.regex_lookup()`` with the lhs and rhs SQL.
    """

    lookup_name = "regex"
    prepare_rhs = False

    def as_sql(self, compiler, connection):
        if self.lookup_name in connection.operators:
            return super().as_sql(compiler, connection)
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        sql_template = connection.ops.regex_lookup(self.lookup_name)
        # process_rhs() returns whatever compiler.compile() produced for an
        # expression rhs, which is not guaranteed to be a list, so
        # `list + tuple` could raise TypeError. Build a fresh list instead,
        # matching the list BuiltinLookup.as_sql() returns.
        return sql_template % (lhs, rhs), [*lhs_params, *rhs_params]
@Field.register_lookup
class IRegex(Regex):
    """``iregex`` lookup; SQL generation is inherited from Regex."""

    lookup_name = "iregex"
class YearLookup(Lookup):
    """
    Base for comparisons against an extracted year.

    ``self.lhs`` is expected to be a transform over a date/datetime column
    (``self.lhs.lhs``) — NOTE(review): inferred from attribute access; confirm.
    When the rhs is a plain value the extract is skipped and the column is
    compared with the year's bounds; subclasses pick the bound params.
    """

    def year_lookup_bounds(self, connection, year):
        from plain.models.functions import ExtractIsoYear

        iso_year = isinstance(self.lhs, ExtractIsoYear)
        if isinstance(self.lhs.lhs.output_field, DateTimeField):
            compute_bounds = connection.ops.year_lookup_bounds_for_datetime_field
        else:
            compute_bounds = connection.ops.year_lookup_bounds_for_date_field
        return compute_bounds(year, iso_year=iso_year)

    def as_sql(self, compiler, connection):
        if not self.rhs_is_direct_value():
            return super().as_sql(compiler, connection)
        # Avoid the extract operation for a direct value so indexes can be
        # used: compare the originating field, self.lhs.lhs, directly.
        lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
        rhs_sql, _ = self.process_rhs(compiler, connection)
        rhs_sql = self.get_direct_rhs_sql(connection, rhs_sql)
        start, finish = self.year_lookup_bounds(connection, self.rhs)
        params.extend(self.get_bound_params(start, finish))
        return f"{lhs_sql} {rhs_sql}", params

    def get_direct_rhs_sql(self, connection, rhs):
        template = connection.operators[self.lookup_name]
        return template % rhs

    def get_bound_params(self, start, finish):
        raise NotImplementedError(
            "subclasses of YearLookup must provide a get_bound_params() method"
        )
class YearExact(YearLookup, Exact):
    """Year ``exact``: ``BETWEEN start AND finish`` of the year's bounds."""

    def get_direct_rhs_sql(self, connection, rhs):
        return "BETWEEN %s AND %s"

    def get_bound_params(self, start, finish):
        return (start, finish)


class YearGt(YearLookup, GreaterThan):
    """Year ``gt``: compares the column with the year's ``finish`` bound."""

    def get_bound_params(self, start, finish):
        return (finish,)


class YearGte(YearLookup, GreaterThanOrEqual):
    """Year ``gte``: compares the column with the year's ``start`` bound."""

    def get_bound_params(self, start, finish):
        return (start,)


class YearLt(YearLookup, LessThan):
    """Year ``lt``: compares the column with the year's ``start`` bound."""

    def get_bound_params(self, start, finish):
        return (start,)


class YearLte(YearLookup, LessThanOrEqual):
    """Year ``lte``: compares the column with the year's ``finish`` bound."""

    def get_bound_params(self, start, finish):
        return (finish,)
class UUIDTextMixin:
    """
    Strip hyphens from a value when filtering a UUIDField on backends without
    a native datatype for UUID.

    The rhs (wrapped in Value if it is a direct value) is replaced by a
    Replace(rhs, "-", "") expression before normal rhs processing.
    """

    def process_rhs(self, qn, connection):
        needs_stripping = not connection.features.has_native_uuid_field
        if needs_stripping:
            from plain.models.functions import Replace

            source = Value(self.rhs) if self.rhs_is_direct_value() else self.rhs
            # NOTE(review): self.rhs is overwritten, so a second call wraps
            # the already-wrapped rhs again — confirm this runs once per lookup.
            self.rhs = Replace(source, Value("-"), Value(""), output_field=CharField())
        sql, params = super().process_rhs(qn, connection)
        return sql, params
@UUIDField.register_lookup
class UUIDIExact(UUIDTextMixin, IExact):
    """``iexact`` on UUIDField; hyphens stripped when UUIDs aren't native."""


@UUIDField.register_lookup
class UUIDContains(UUIDTextMixin, Contains):
    """``contains`` on UUIDField; hyphens stripped when UUIDs aren't native."""


@UUIDField.register_lookup
class UUIDIContains(UUIDTextMixin, IContains):
    """``icontains`` on UUIDField; hyphens stripped when UUIDs aren't native."""


@UUIDField.register_lookup
class UUIDStartsWith(UUIDTextMixin, StartsWith):
    """``startswith`` on UUIDField; hyphens stripped when UUIDs aren't native."""


@UUIDField.register_lookup
class UUIDIStartsWith(UUIDTextMixin, IStartsWith):
    """``istartswith`` on UUIDField; hyphens stripped when UUIDs aren't native."""


@UUIDField.register_lookup
class UUIDEndsWith(UUIDTextMixin, EndsWith):
    """``endswith`` on UUIDField; hyphens stripped when UUIDs aren't native."""


@UUIDField.register_lookup
class UUIDIEndsWith(UUIDTextMixin, IEndsWith):
    """``iendswith`` on UUIDField; hyphens stripped when UUIDs aren't native."""