Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/query_utils.py: 45%
233 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
1"""
2Various data structures used in query construction.
4Factored out from plain.models.query to avoid making the main module very
5large and/or so that they can be used by other modules without getting into
6circular import difficulties.
7"""
9import functools
10import inspect
11import logging
12from collections import namedtuple
14from plain.exceptions import FieldError
15from plain.models.constants import LOOKUP_SEP
16from plain.models.db import DEFAULT_DB_ALIAS, DatabaseError, connections
17from plain.utils import tree
logger = logging.getLogger("plain.models")

# PathInfo is used when converting lookups such as ``fk__somecol``. Its
# fields describe the relation in model terms: the model Options and Fields
# on both sides of the relation. ``join_field`` is the field that backs the
# relation.
PathInfo = namedtuple(
    "PathInfo",
    [
        "from_opts",
        "to_opts",
        "target_fields",
        "join_field",
        "m2m",
        "direct",
        "filtered_relation",
    ],
)
def subclasses(cls):
    """
    Yield ``cls`` and then every class that inherits from it, recursively.

    Classes come out in depth-first pre-order, following each class's
    ``__subclasses__()`` order. A class that is reachable through several
    bases is yielded once per path.
    """
    pending = [cls]
    while pending:
        current = pending.pop()
        yield current
        # Push children reversed so that the first direct subclass is
        # popped (and therefore yielded) first.
        pending.extend(reversed(current.__subclasses__()))
class Q(tree.Node):
    """
    Encapsulate filters as objects that can be combined logically with
    ``&`` (AND), ``|`` (OR) and ``^`` (XOR), and negated with ``~``.

    Positional arguments become children unchanged. Keyword arguments become
    ``(lookup, value)`` children, sorted by lookup name.
    """

    # Connection types
    AND = "AND"
    OR = "OR"
    XOR = "XOR"
    default = AND
    # _combine() rejects operands whose ``conditional`` attribute is False
    # or missing; Q opts in.
    conditional = True

    def __init__(self, *args, _connector=None, _negated=False, **kwargs):
        # Sorting the keyword children gives Q(a=1, b=2) and Q(b=2, a=1)
        # identical children.
        keyword_children = sorted(kwargs.items())
        super().__init__(
            children=[*args, *keyword_children],
            connector=_connector,
            negated=_negated,
        )

    def _combine(self, other, conn):
        """
        Return a new node that joins ``self`` and ``other`` with ``conn``.

        Raise TypeError when ``other`` is not conditional. If ``self`` is
        empty, return a copy of ``other``. If ``other`` is an empty Q, return
        a copy of ``self``.
        """
        if getattr(other, "conditional", False) is False:
            raise TypeError(other)
        if not self:
            return other.copy()
        if not other and isinstance(other, Q):
            return self.copy()
        combined = self.create(connector=conn)
        for operand in (self, other):
            combined.add(operand, conn)
        return combined

    def __or__(self, other):
        return self._combine(other, self.OR)

    def __and__(self, other):
        return self._combine(other, self.AND)

    def __xor__(self, other):
        return self._combine(other, self.XOR)

    def __invert__(self):
        negated = self.copy()
        negated.negate()
        return negated

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Resolve this Q against ``query`` and return the resulting clause."""
        # Any joins this introduces are promoted to LEFT OUTER joins, so that
        # using Q as an expression does not filter out rows via the joins.
        where_clause, new_joins = query._add_q(
            self,
            reuse,
            allow_joins=allow_joins,
            split_subq=False,
            check_filterable=False,
            summarize=summarize,
        )
        query.promote_joins(new_joins)
        return where_clause

    def flatten(self):
        """
        Recursively yield this Q and all of its subexpressions, depth first.

        For ``(lookup, value)`` children the value is what gets yielded or
        flattened further.
        """
        yield self
        for child in self.children:
            node = child[1] if isinstance(child, tuple) else child
            if hasattr(node, "flatten"):
                yield from node.flatten()
            else:
                yield node

    def check(self, against, using=DEFAULT_DB_ALIAS):
        """
        Evaluate this Q in a database query against the values in ``against``.

        Each ``against`` entry is added as an annotation named by its key.
        Non-expression values are wrapped in ``Value``. Return True when the
        query yields a row. If a DatabaseError occurs, log a warning and
        return True.
        """
        # Imported here to avoid circular imports.
        from plain.models.expressions import Value
        from plain.models.fields import BooleanField
        from plain.models.functions import Coalesce
        from plain.models.sql import Query
        from plain.models.sql.constants import SINGLE

        query = Query(None)
        for name, value in against.items():
            expression = (
                value if hasattr(value, "resolve_expression") else Value(value)
            )
            query.add_annotation(expression, name, select=False)
        query.add_annotation(Value(1), "_check")
        if connections[using].features.supports_comparing_boolean_expr:
            # Coalesce turns a NULL result of the condition into True.
            condition = Q(Coalesce(self, True, output_field=BooleanField()))
        else:
            condition = self
        # This will raise a FieldError if a field is missing in "against".
        query.add_q(condition)
        compiler = query.get_compiler(using=using)
        try:
            return compiler.execute_sql(SINGLE) is not None
        except DatabaseError as e:
            logger.warning("Got a database error calling check() on %r: %s", self, e)
            return True

    def deconstruct(self):
        """
        Return ``(path, args, kwargs)`` from which an equivalent Q can be rebuilt.

        A path inside ``plain.models.query_utils`` is shortened to
        ``plain.models``. ``_connector`` and ``_negated`` appear in kwargs
        only when they differ from the defaults.
        """
        klass = self.__class__
        path = f"{klass.__module__}.{klass.__name__}"
        if path.startswith("plain.models.query_utils"):
            path = path.replace("plain.models.query_utils", "plain.models")
        kwargs = {}
        if self.connector != self.default:
            kwargs["_connector"] = self.connector
        if self.negated:
            kwargs["_negated"] = True
        return path, tuple(self.children), kwargs
class DeferredAttribute:
    """
    Descriptor for a field whose value may not have been loaded yet.

    Reading the attribute on an instance returns the value from the instance
    ``__dict__``. If it is absent, the value is first taken from an
    already-loaded parent link (primary-key fields only). Failing that,
    ``instance.refresh_from_db(fields=[attname])`` is called, which is
    expected to populate ``__dict__``.
    """

    def __init__(self, field):
        self.field = field

    def __get__(self, instance, cls=None):
        """Return the field value for ``instance``, loading it on first access."""
        if instance is None:
            # Accessed on the class itself: expose the descriptor.
            return self
        attname = self.field.attname
        cache = instance.__dict__
        if attname in cache:
            return cache[attname]
        # Reuse an already-loaded value from the parent chain when possible
        # (refs #18343). None means "not available", so fall back to a query.
        inherited = self._check_parent_chain(instance)
        if inherited is not None:
            cache[attname] = inherited
        else:
            instance.refresh_from_db(fields=[attname])
        return cache[attname]

    def _check_parent_chain(self, instance):
        """
        Return the value of the ancestor link field on ``instance`` when the
        deferred field is a primary key other than that link. Return None
        otherwise.
        """
        field = self.field
        link_field = instance._meta.get_ancestor_link(field.model)
        # NOTE(review): if get_ancestor_link() returns None for a primary-key
        # field, ``link_field.attname`` raises AttributeError. Confirm callers
        # never reach this with a non-ancestor model.
        if field.primary_key and field != link_field:
            return getattr(instance, link_field.attname)
        return None
class class_or_instance_method:
    """
    Descriptor that dispatches to one of two functions depending on how it
    is accessed. Used by RegisterLookupMixin.

    Accessed on an instance, it returns ``instance_method`` partially applied
    to that instance. Accessed on the class, it returns ``class_method``
    partially applied to the owner class.
    """

    def __init__(self, class_method, instance_method):
        self.class_method = class_method
        self.instance_method = instance_method

    def __get__(self, instance, owner):
        if instance is not None:
            return functools.partial(self.instance_method, instance)
        return functools.partial(self.class_method, owner)
class RegisterLookupMixin:
    """
    Mixin for registering lookups and transforms by name on a class, or on
    an individual instance, and resolving them later.

    Class-level registrations are stored in each class's own
    ``class_lookups`` dict and merged along the MRO, with nearer classes
    winning. Instance-level registrations are stored in ``instance_lookups``
    and override the class-level ones.
    """

    def _get_lookup(self, lookup_name):
        # get_lookups() dispatches to the class or instance variant below.
        # Unknown names give None.
        return self.get_lookups().get(lookup_name)

    @functools.cache
    def get_class_lookups(cls):
        # Take each class's *own* class_lookups (not inherited ones) in MRO
        # order. merge_dicts() prefers earlier entries, so a subclass's
        # registration shadows a base class's. The result is memoized per
        # class and invalidated by _clear_cached_class_lookups().
        own_lookups = [
            klass.__dict__.get("class_lookups", {})
            for klass in inspect.getmro(cls)
        ]
        return cls.merge_dicts(own_lookups)

    def get_instance_lookups(self):
        from_class = self.get_class_lookups()
        from_instance = getattr(self, "instance_lookups", None)
        if not from_instance:
            return from_class
        # Instance registrations take precedence over class-level ones.
        return {**from_class, **from_instance}

    # Build the dispatcher from the plain (cached) function *before*
    # get_class_lookups is rebound as a classmethod on the next line.
    get_lookups = class_or_instance_method(get_class_lookups, get_instance_lookups)
    get_class_lookups = classmethod(get_class_lookups)

    def get_lookup(self, lookup_name):
        """
        Return the Lookup subclass registered as ``lookup_name``, or None.

        A registered object that is not a Lookup subclass gives None. When
        nothing is registered and the object has an ``output_field``, the
        query is delegated to that field.
        """
        from plain.models.lookups import Lookup

        found = self._get_lookup(lookup_name)
        if found is not None:
            return found if issubclass(found, Lookup) else None
        if hasattr(self, "output_field"):
            return self.output_field.get_lookup(lookup_name)
        return None

    def get_transform(self, lookup_name):
        """
        Return the Transform subclass registered as ``lookup_name``, or None.

        A registered object that is not a Transform subclass gives None. When
        nothing is registered and the object has an ``output_field``, the
        query is delegated to that field.
        """
        from plain.models.lookups import Transform

        found = self._get_lookup(lookup_name)
        if found is not None:
            return found if issubclass(found, Transform) else None
        if hasattr(self, "output_field"):
            return self.output_field.get_transform(lookup_name)
        return None

    @staticmethod
    def merge_dicts(dicts):
        """
        Merge ``dicts`` into a new dict, giving priority to earlier entries:
        merge_dicts([a, b]) keeps a's value for any key found in both.
        """
        merged = {}
        # Apply the lowest-priority dict first, so higher-priority dicts
        # overwrite its values.
        for mapping in reversed(dicts):
            merged.update(mapping)
        return merged

    @classmethod
    def _clear_cached_class_lookups(cls):
        # A registration on cls changes the merged lookups of every subclass,
        # so drop the memoized get_class_lookups() results for the subtree.
        for klass in subclasses(cls):
            klass.get_class_lookups.cache_clear()

    def register_class_lookup(cls, lookup, lookup_name=None):
        name = lookup.lookup_name if lookup_name is None else lookup_name
        # Give cls its own dict rather than writing into an inherited one.
        if "class_lookups" not in cls.__dict__:
            cls.class_lookups = {}
        cls.class_lookups[name] = lookup
        cls._clear_cached_class_lookups()
        return lookup

    def register_instance_lookup(self, lookup, lookup_name=None):
        name = lookup.lookup_name if lookup_name is None else lookup_name
        if "instance_lookups" not in self.__dict__:
            self.instance_lookups = {}
        self.instance_lookups[name] = lookup
        return lookup

    # Same dispatcher/classmethod rebinding pattern as get_lookups above.
    register_lookup = class_or_instance_method(
        register_class_lookup, register_instance_lookup
    )
    register_class_lookup = classmethod(register_class_lookup)

    def _unregister_class_lookup(cls, lookup, lookup_name=None):
        """
        Remove the given lookup from the class lookups. Intended for tests
        only, as it is not thread-safe.
        """
        name = lookup.lookup_name if lookup_name is None else lookup_name
        # NOTE(review): if cls has no class_lookups of its own, this resolves
        # to (and deletes from) a base class's dict.
        del cls.class_lookups[name]
        cls._clear_cached_class_lookups()

    def _unregister_instance_lookup(self, lookup, lookup_name=None):
        """
        Remove the given lookup from the instance lookups. Intended for tests
        only, as it is not thread-safe.
        """
        name = lookup.lookup_name if lookup_name is None else lookup_name
        del self.instance_lookups[name]

    _unregister_lookup = class_or_instance_method(
        _unregister_class_lookup, _unregister_instance_lookup
    )
    _unregister_class_lookup = classmethod(_unregister_class_lookup)
def select_related_descend(field, restricted, requested, select_mask, reverse=False):
    """
    Decide whether select_related() should descend through ``field``.

    Shared by query construction (compiler.get_related_selections()) and
    model instance creation (compiler.klass_info).

    Arguments:
     * field - the field being considered.
     * restricted - True when select_related() was given an explicit list of
       relations to follow.
     * requested - the select_related() dictionary.
     * select_mask - the dictionary of selected fields.
     * reverse - True when checking a reverse relation.

    Raise FieldError when ``field`` is requested for select_related() but
    excluded by a non-empty ``select_mask``.
    """
    remote_field = field.remote_field
    if not remote_field:
        # Not a relation: nothing to descend into.
        return False
    if remote_field.parent_link and not reverse:
        return False
    if restricted:
        # Explicit select_related() list: follow only relations named in it,
        # using the related query name when going in reverse.
        requested_name = field.related_query_name() if reverse else field.name
        if requested_name not in requested:
            return False
    elif field.null:
        # Without an explicit list, nullable relations are not followed.
        return False
    if (
        restricted
        and select_mask
        and field.name in requested
        and field not in select_mask
    ):
        raise FieldError(
            f"Field {field.model._meta.object_name}.{field.name} cannot be both "
            "deferred and traversed using select_related at the same time."
        )
    return True
def refs_expression(lookup_parts, annotations):
    """
    Find the annotation that ``lookup_parts`` refers to, if any.

    Annotation names may themselves contain LOOKUP_SEP, so each prefix of
    ``lookup_parts`` is joined with LOOKUP_SEP and tried, shortest first.
    Return ``(annotation_name, remaining_parts)`` for the first prefix whose
    entry in ``annotations`` is truthy, or ``(None, ())`` if no prefix
    matches.
    """
    for size in range(1, len(lookup_parts) + 1):
        candidate = LOOKUP_SEP.join(lookup_parts[:size])
        # .get() truthiness: a name mapped to a falsy value does not match.
        if annotations.get(candidate):
            return candidate, lookup_parts[size:]
    return None, ()
def check_rel_lookup_compatibility(model, target_opts, field):
    """
    Check that ``model`` is compatible with ``target_opts``.

    Compatibility holds when:
      1) both have the same concrete model (proxy inheritance removed), or
      2) one model is a parent of the other.
    If ``field`` is a primary key, compatibility with ``field.model`` is
    also accepted.
    """

    def _related(opts):
        return (
            model._meta.concrete_model == opts.concrete_model
            or opts.concrete_model in model._meta.get_parent_list()
            or model in opts.get_parent_list()
        )

    compatible = _related(target_opts)
    if compatible:
        return compatible
    # For a primary key, querying against the field's own model is fine too:
    #   class Restaurant(models.Model):
    #       place = OneToOneField(Place, primary_key=True)
    #   Restaurant.objects.filter(pk__in=Restaurant.objects.all())
    # Here pk__in (== place__in) yields Place's opts as the target, which
    # Restaurant is not compatible with. This only applies to primary keys,
    # because __in=qs is later turned into __in=qs.values('pk').
    is_primary_key = getattr(field, "primary_key", False)
    if not is_primary_key:
        return is_primary_key
    return _related(field.model._meta)
class FilteredRelation:
    """Specify custom filtering in the ON clause of SQL joins."""

    # NOTE(review): the default Q() is created once, when the class is
    # defined, and shared by every FilteredRelation built without an explicit
    # condition. It must not be mutated in place.
    def __init__(self, relation_name, *, condition=Q()):
        if not relation_name:
            raise ValueError("relation_name cannot be empty.")
        if not isinstance(condition, Q):
            raise ValueError("condition argument must be a Q() instance.")
        self.relation_name = relation_name
        self.condition = condition
        # alias and path start empty. They are presumably filled in during
        # query construction (not visible here); as_sql() reads path.
        self.alias = None
        self.path = []

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return (
            self.relation_name == other.relation_name
            and self.alias == other.alias
            and self.condition == other.condition
        )

    def clone(self):
        """Return a copy sharing the condition, with its own copy of path."""
        duplicate = FilteredRelation(self.relation_name, condition=self.condition)
        duplicate.alias = self.alias
        duplicate.path = self.path[:]
        return duplicate

    def resolve_expression(self, *args, **kwargs):
        """
        Present only so FilteredRelation is accepted by QuerySet.annotate(),
        which requires expression-like arguments. Never meant to be called.
        """
        raise NotImplementedError("FilteredRelation.resolve_expression() is unused.")

    def as_sql(self, compiler, connection):
        # The condition is resolved here, as part of the join
        # (Join.filtered_relation), reusing the aliases recorded in path.
        where = compiler.query.build_filtered_relation_q(
            self.condition, reuse=set(self.path)
        )
        return compiler.compile(where)