Coverage for /Users/davegaeddert/Development/dropseed/plain/plain-models/plain/models/query.py: 26%
1237 statements
coverage.py v7.6.1, created at 2024-10-16 22:03 -0500
1"""
2The main QuerySet implementation. This provides the public API for the ORM.
3"""
5import copy
6import operator
7import warnings
8from itertools import chain, islice
10import plain.runtime
11from plain import exceptions
12from plain.exceptions import ValidationError
13from plain.models import (
14 sql,
15 transaction,
16)
17from plain.models.constants import LOOKUP_SEP, OnConflict
18from plain.models.db import (
19 PLAIN_VERSION_PICKLE_KEY,
20 IntegrityError,
21 NotSupportedError,
22 connections,
23 router,
24)
25from plain.models.expressions import Case, F, Value, When
26from plain.models.fields import (
27 AutoField,
28 DateField,
29 DateTimeField,
30 Field,
31)
32from plain.models.functions import Cast, Trunc
33from plain.models.query_utils import FilteredRelation, Q
34from plain.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
35from plain.models.utils import (
36 AltersData,
37 create_namedtuple_class,
38 resolve_callables,
39)
40from plain.utils import timezone
41from plain.utils.functional import cached_property, partition
43# The maximum number of results to fetch in a get() query.
44MAX_GET_RESULTS = 21
46# The maximum number of items to display in a QuerySet.__repr__
47REPR_OUTPUT_SIZE = 20
50class BaseIterable:
51 def __init__(
52 self, queryset, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
53 ):
54 self.queryset = queryset
55 self.chunked_fetch = chunked_fetch
56 self.chunk_size = chunk_size
59class ModelIterable(BaseIterable):
60 """Iterable that yields a model instance for each row."""
62 def __iter__(self):
63 queryset = self.queryset
64 db = queryset.db
65 compiler = queryset.query.get_compiler(using=db)
66 # Execute the query. This will also fill compiler.select, klass_info,
67 # and annotations.
68 results = compiler.execute_sql(
69 chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
70 )
71 select, klass_info, annotation_col_map = (
72 compiler.select,
73 compiler.klass_info,
74 compiler.annotation_col_map,
75 )
76 model_cls = klass_info["model"]
77 select_fields = klass_info["select_fields"]
78 model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1
79 init_list = [
80 f[0].target.attname for f in select[model_fields_start:model_fields_end]
81 ]
82 related_populators = get_related_populators(klass_info, select, db)
83 known_related_objects = [
84 (
85 field,
86 related_objs,
87 operator.attrgetter(
88 *[
89 field.attname
90 if from_field == "self"
91 else queryset.model._meta.get_field(from_field).attname
92 for from_field in field.from_fields
93 ]
94 ),
95 )
96 for field, related_objs in queryset._known_related_objects.items()
97 ]
98 for row in compiler.results_iter(results):
99 obj = model_cls.from_db(
100 db, init_list, row[model_fields_start:model_fields_end]
101 )
102 for rel_populator in related_populators:
103 rel_populator.populate(row, obj)
104 if annotation_col_map:
105 for attr_name, col_pos in annotation_col_map.items():
106 setattr(obj, attr_name, row[col_pos])
108 # Add the known related objects to the model.
109 for field, rel_objs, rel_getter in known_related_objects:
110 # Avoid overwriting objects loaded by, e.g., select_related().
111 if field.is_cached(obj):
112 continue
113 rel_obj_id = rel_getter(obj)
114 try:
115 rel_obj = rel_objs[rel_obj_id]
116 except KeyError:
117 pass # May happen in qs1 | qs2 scenarios.
118 else:
119 setattr(obj, field.name, rel_obj)
121 yield obj
124class RawModelIterable(BaseIterable):
125 """
126 Iterable that yields a model instance for each row from a raw queryset.
127 """
129 def __iter__(self):
130 # Cache some things for performance reasons outside the loop.
131 db = self.queryset.db
132 query = self.queryset.query
133 connection = connections[db]
134 compiler = connection.ops.compiler("SQLCompiler")(query, connection, db)
135 query_iterator = iter(query)
137 try:
138 (
139 model_init_names,
140 model_init_pos,
141 annotation_fields,
142 ) = self.queryset.resolve_model_init_order()
143 model_cls = self.queryset.model
144 if model_cls._meta.pk.attname not in model_init_names:
145 raise exceptions.FieldDoesNotExist(
146 "Raw query must include the primary key"
147 )
148 fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns]
149 converters = compiler.get_converters(
150 [f.get_col(f.model._meta.db_table) if f else None for f in fields]
151 )
152 if converters:
153 query_iterator = compiler.apply_converters(query_iterator, converters)
154 for values in query_iterator:
155 # Associate fields to values
156 model_init_values = [values[pos] for pos in model_init_pos]
157 instance = model_cls.from_db(db, model_init_names, model_init_values)
158 if annotation_fields:
159 for column, pos in annotation_fields:
160 setattr(instance, column, values[pos])
161 yield instance
162 finally:
163 # Done iterating the Query. If it has its own cursor, close it.
164 if hasattr(query, "cursor") and query.cursor:
165 query.cursor.close()
168class ValuesIterable(BaseIterable):
169 """
170 Iterable returned by QuerySet.values() that yields a dict for each row.
171 """
173 def __iter__(self):
174 queryset = self.queryset
175 query = queryset.query
176 compiler = query.get_compiler(queryset.db)
178 # extra(select=...) cols are always at the start of the row.
179 names = [
180 *query.extra_select,
181 *query.values_select,
182 *query.annotation_select,
183 ]
184 indexes = range(len(names))
185 for row in compiler.results_iter(
186 chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
187 ):
188 yield {names[i]: row[i] for i in indexes}
191class ValuesListIterable(BaseIterable):
192 """
193 Iterable returned by QuerySet.values_list(flat=False) that yields a tuple
194 for each row.
195 """
197 def __iter__(self):
198 queryset = self.queryset
199 query = queryset.query
200 compiler = query.get_compiler(queryset.db)
202 if queryset._fields:
203 # extra(select=...) cols are always at the start of the row.
204 names = [
205 *query.extra_select,
206 *query.values_select,
207 *query.annotation_select,
208 ]
209 fields = [
210 *queryset._fields,
211 *(f for f in query.annotation_select if f not in queryset._fields),
212 ]
213 if fields != names:
214 # Reorder according to fields.
215 index_map = {name: idx for idx, name in enumerate(names)}
216 rowfactory = operator.itemgetter(*[index_map[f] for f in fields])
217 return map(
218 rowfactory,
219 compiler.results_iter(
220 chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
221 ),
222 )
223 return compiler.results_iter(
224 tuple_expected=True,
225 chunked_fetch=self.chunked_fetch,
226 chunk_size=self.chunk_size,
227 )
230class NamedValuesListIterable(ValuesListIterable):
231 """
232 Iterable returned by QuerySet.values_list(named=True) that yields a
233 namedtuple for each row.
234 """
236 def __iter__(self):
237 queryset = self.queryset
238 if queryset._fields:
239 names = queryset._fields
240 else:
241 query = queryset.query
242 names = [
243 *query.extra_select,
244 *query.values_select,
245 *query.annotation_select,
246 ]
247 tuple_class = create_namedtuple_class(*names)
248 new = tuple.__new__
249 for row in super().__iter__():
250 yield new(tuple_class, row)
253class FlatValuesListIterable(BaseIterable):
254 """
255 Iterable returned by QuerySet.values_list(flat=True) that yields single
256 values.
257 """
259 def __iter__(self):
260 queryset = self.queryset
261 compiler = queryset.query.get_compiler(queryset.db)
262 for row in compiler.results_iter(
263 chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
264 ):
265 yield row[0]
268class QuerySet(AltersData):
269 """Represent a lazy database lookup for a set of objects."""
271 def __init__(self, model=None, query=None, using=None, hints=None):
272 self.model = model
273 self._db = using
274 self._hints = hints or {}
275 self._query = query or sql.Query(self.model)
276 self._result_cache = None
277 self._sticky_filter = False
278 self._for_write = False
279 self._prefetch_related_lookups = ()
280 self._prefetch_done = False
281 self._known_related_objects = {} # {rel_field: {pk: rel_obj}}
282 self._iterable_class = ModelIterable
283 self._fields = None
284 self._defer_next_filter = False
285 self._deferred_filter = None
287 @property
288 def query(self):
289 if self._deferred_filter:
290 negate, args, kwargs = self._deferred_filter
291 self._filter_or_exclude_inplace(negate, args, kwargs)
292 self._deferred_filter = None
293 return self._query
295 @query.setter
296 def query(self, value):
297 if value.values_select:
298 self._iterable_class = ValuesIterable
299 self._query = value
301 def as_manager(cls):
302 # Address the circular dependency between `QuerySet` and `Manager`.
303 from plain.models.manager import Manager
305 manager = Manager.from_queryset(cls)()
306 manager._built_with_as_manager = True
307 return manager
309 as_manager.queryset_only = True
310 as_manager = classmethod(as_manager)
312 ########################
313 # PYTHON MAGIC METHODS #
314 ########################
316 def __deepcopy__(self, memo):
317 """Don't populate the QuerySet's cache."""
318 obj = self.__class__()
319 for k, v in self.__dict__.items():
320 if k == "_result_cache":
321 obj.__dict__[k] = None
322 else:
323 obj.__dict__[k] = copy.deepcopy(v, memo)
324 return obj
326 def __getstate__(self):
327 # Force the cache to be fully populated.
328 self._fetch_all()
329 return {**self.__dict__, PLAIN_VERSION_PICKLE_KEY: plain.runtime.__version__}
331 def __setstate__(self, state):
332 pickled_version = state.get(PLAIN_VERSION_PICKLE_KEY)
333 if pickled_version:
334 if pickled_version != plain.runtime.__version__:
335 warnings.warn(
336 "Pickled queryset instance's Plain version {} does not "
337 "match the current version {}.".format(
338 pickled_version, plain.runtime.__version__
339 ),
340 RuntimeWarning,
341 stacklevel=2,
342 )
343 else:
344 warnings.warn(
345 "Pickled queryset instance's Plain version is not specified.",
346 RuntimeWarning,
347 stacklevel=2,
348 )
349 self.__dict__.update(state)
351 def __repr__(self):
352 data = list(self[: REPR_OUTPUT_SIZE + 1])
353 if len(data) > REPR_OUTPUT_SIZE:
354 data[-1] = "...(remaining elements truncated)..."
355 return f"<{self.__class__.__name__} {data!r}>"
357 def __len__(self):
358 self._fetch_all()
359 return len(self._result_cache)
361 def __iter__(self):
362 """
363 The queryset iterator protocol uses three nested iterators in the
364 default case:
365 1. sql.compiler.execute_sql()
366 - Returns 100 rows at a time (constants.GET_ITERATOR_CHUNK_SIZE)
367 using cursor.fetchmany(). This part is responsible for
368 doing some column masking, and returning the rows in chunks.
369 2. sql.compiler.results_iter()
370 - Returns one row at a time. At this point the rows are still just
371 tuples. In some cases the return values are converted to
372 Python values at this location.
373 3. self.iterator()
374 - Responsible for turning the rows into model objects.
375 """
376 self._fetch_all()
377 return iter(self._result_cache)
379 def __bool__(self):
380 self._fetch_all()
381 return bool(self._result_cache)
383 def __getitem__(self, k):
384 """Retrieve an item or slice from the set of results."""
385 if not isinstance(k, int | slice):
386 raise TypeError(
387 "QuerySet indices must be integers or slices, not %s."
388 % type(k).__name__
389 )
390 if (isinstance(k, int) and k < 0) or (
391 isinstance(k, slice)
392 and (
393 (k.start is not None and k.start < 0)
394 or (k.stop is not None and k.stop < 0)
395 )
396 ):
397 raise ValueError("Negative indexing is not supported.")
399 if self._result_cache is not None:
400 return self._result_cache[k]
402 if isinstance(k, slice):
403 qs = self._chain()
404 if k.start is not None:
405 start = int(k.start)
406 else:
407 start = None
408 if k.stop is not None:
409 stop = int(k.stop)
410 else:
411 stop = None
412 qs.query.set_limits(start, stop)
413 return list(qs)[:: k.step] if k.step else qs
415 qs = self._chain()
416 qs.query.set_limits(k, k + 1)
417 qs._fetch_all()
418 return qs._result_cache[0]
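    # Usage sketch for slicing (illustrative addition, not from the original module;
    # "Article" is a hypothetical model with a default "objects" manager):
    #
    #     Article.objects.all()[:5]      # LIMIT 5; still lazy because no step is given
    #     Article.objects.all()[5:10]    # OFFSET 5 LIMIT 5
    #     Article.objects.all()[0]       # evaluates immediately and returns one object
    #     Article.objects.all()[-1]      # raises ValueError: negative indexing is not supported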
420 def __class_getitem__(cls, *args, **kwargs):
421 return cls
423 def __and__(self, other):
424 self._check_operator_queryset(other, "&")
425 self._merge_sanity_check(other)
426 if isinstance(other, EmptyQuerySet):
427 return other
428 if isinstance(self, EmptyQuerySet):
429 return self
430 combined = self._chain()
431 combined._merge_known_related_objects(other)
432 combined.query.combine(other.query, sql.AND)
433 return combined
435 def __or__(self, other):
436 self._check_operator_queryset(other, "|")
437 self._merge_sanity_check(other)
438 if isinstance(self, EmptyQuerySet):
439 return other
440 if isinstance(other, EmptyQuerySet):
441 return self
442 query = (
443 self
444 if self.query.can_filter()
445 else self.model._base_manager.filter(pk__in=self.values("pk"))
446 )
447 combined = query._chain()
448 combined._merge_known_related_objects(other)
449 if not other.query.can_filter():
450 other = other.model._base_manager.filter(pk__in=other.values("pk"))
451 combined.query.combine(other.query, sql.OR)
452 return combined
454 def __xor__(self, other):
455 self._check_operator_queryset(other, "^")
456 self._merge_sanity_check(other)
457 if isinstance(self, EmptyQuerySet):
458 return other
459 if isinstance(other, EmptyQuerySet):
460 return self
461 query = (
462 self
463 if self.query.can_filter()
464 else self.model._base_manager.filter(pk__in=self.values("pk"))
465 )
466 combined = query._chain()
467 combined._merge_known_related_objects(other)
468 if not other.query.can_filter():
469 other = other.model._base_manager.filter(pk__in=other.values("pk"))
470 combined.query.combine(other.query, sql.XOR)
471 return combined
473 ####################################
474 # METHODS THAT DO DATABASE QUERIES #
475 ####################################
477 def _iterator(self, use_chunked_fetch, chunk_size):
478 iterable = self._iterable_class(
479 self,
480 chunked_fetch=use_chunked_fetch,
481 chunk_size=chunk_size or 2000,
482 )
483 if not self._prefetch_related_lookups or chunk_size is None:
484 yield from iterable
485 return
487 iterator = iter(iterable)
488 while results := list(islice(iterator, chunk_size)):
489 prefetch_related_objects(results, *self._prefetch_related_lookups)
490 yield from results
492 def iterator(self, chunk_size=None):
493 """
494 An iterator over the results from applying this QuerySet to the
495 database. chunk_size must be provided for QuerySets that prefetch
496 related objects. Otherwise, a default chunk_size of 2000 is supplied.
497 """
498 if chunk_size is None:
499 if self._prefetch_related_lookups:
500 raise ValueError(
501 "chunk_size must be provided when using QuerySet.iterator() after "
502 "prefetch_related()."
503 )
504 elif chunk_size <= 0:
505 raise ValueError("Chunk size must be strictly positive.")
506 use_chunked_fetch = not connections[self.db].settings_dict.get(
507 "DISABLE_SERVER_SIDE_CURSORS"
508 )
509 return self._iterator(use_chunked_fetch, chunk_size)
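    # Usage sketch for iterator() (illustrative addition, not from the original module;
    # "Article" is a hypothetical model):
    #
    #     for article in Article.objects.all().iterator(chunk_size=500):
    #         ...  # rows are fetched in chunks rather than all at once
    #
    #     # chunk_size is mandatory once prefetch_related() is involved:
    #     for article in Article.objects.prefetch_related("tags").iterator(chunk_size=100):
    #         ...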
511 def aggregate(self, *args, **kwargs):
512 """
513 Return a dictionary containing the calculations (aggregation)
514 over the current queryset.
516 If args is present, the expression is passed as a kwarg using
517 the Aggregate object's default alias.
518 """
519 if self.query.distinct_fields:
520 raise NotImplementedError("aggregate() + distinct(fields) not implemented.")
521 self._validate_values_are_expressions(
522 (*args, *kwargs.values()), method_name="aggregate"
523 )
524 for arg in args:
525 # The default_alias property raises TypeError if default_alias
526 # can't be set automatically or AttributeError if it isn't an
527 # attribute.
528 try:
529 arg.default_alias
530 except (AttributeError, TypeError):
531 raise TypeError("Complex aggregates require an alias")
532 kwargs[arg.default_alias] = arg
534 return self.query.chain().get_aggregation(self.db, kwargs)
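    # Usage sketch for aggregate() (illustrative addition; "Article" is a hypothetical
    # model, and it is assumed that aggregates such as Count/Avg are importable from
    # plain.models as in Django):
    #
    #     from plain.models import Avg, Count
    #     Article.objects.aggregate(total=Count("id"), avg_rating=Avg("rating"))
    #     # -> {"total": 42, "avg_rating": 3.5}  (one dict, evaluated immediately)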
536 def count(self):
537 """
538 Perform a SELECT COUNT() and return the number of records as an
539 integer.
541 If the QuerySet is already fully cached, return the length of the
542 cached results set to avoid multiple SELECT COUNT(*) calls.
543 """
544 if self._result_cache is not None:
545 return len(self._result_cache)
547 return self.query.get_count(using=self.db)
549 def get(self, *args, **kwargs):
550 """
551 Perform the query and return a single object matching the given
552 keyword arguments.
553 """
554 if self.query.combinator and (args or kwargs):
555 raise NotSupportedError(
556 "Calling QuerySet.get(...) with filters after %s() is not "
557 "supported." % self.query.combinator
558 )
559 clone = self._chain() if self.query.combinator else self.filter(*args, **kwargs)
560 if self.query.can_filter() and not self.query.distinct_fields:
561 clone = clone.order_by()
562 limit = None
563 if (
564 not clone.query.select_for_update
565 or connections[clone.db].features.supports_select_for_update_with_limit
566 ):
567 limit = MAX_GET_RESULTS
568 clone.query.set_limits(high=limit)
569 num = len(clone)
570 if num == 1:
571 return clone._result_cache[0]
572 if not num:
573 raise self.model.DoesNotExist(
574 "%s matching query does not exist." % self.model._meta.object_name
575 )
576 raise self.model.MultipleObjectsReturned(
577 "get() returned more than one {} -- it returned {}!".format(
578 self.model._meta.object_name,
579 num if not limit or num < limit else "more than %s" % (limit - 1),
580 )
581 )
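    # Usage sketch for get() (illustrative addition; "Article" is a hypothetical model):
    #
    #     article = Article.objects.get(pk=1)   # expects exactly one matching row
    #     # Raises Article.DoesNotExist when nothing matches and
    #     # Article.MultipleObjectsReturned when more than one row matches.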
583 def create(self, **kwargs):
584 """
585 Create a new object with the given kwargs, saving it to the database
586 and returning the created object.
587 """
588 obj = self.model(**kwargs)
589 self._for_write = True
590 obj.save(force_insert=True, using=self.db)
591 return obj
593 def _prepare_for_bulk_create(self, objs):
594 for obj in objs:
595 if obj.pk is None:
596 # Populate new PK values.
597 obj.pk = obj._meta.pk.get_pk_value_on_save(obj)
598 obj._prepare_related_fields_for_save(operation_name="bulk_create")
600 def _check_bulk_create_options(
601 self, ignore_conflicts, update_conflicts, update_fields, unique_fields
602 ):
603 if ignore_conflicts and update_conflicts:
604 raise ValueError(
605 "ignore_conflicts and update_conflicts are mutually exclusive."
606 )
607 db_features = connections[self.db].features
608 if ignore_conflicts:
609 if not db_features.supports_ignore_conflicts:
610 raise NotSupportedError(
611 "This database backend does not support ignoring conflicts."
612 )
613 return OnConflict.IGNORE
614 elif update_conflicts:
615 if not db_features.supports_update_conflicts:
616 raise NotSupportedError(
617 "This database backend does not support updating conflicts."
618 )
619 if not update_fields:
620 raise ValueError(
621 "Fields that will be updated when a row insertion fails "
622 "on conflicts must be provided."
623 )
624 if unique_fields and not db_features.supports_update_conflicts_with_target:
625 raise NotSupportedError(
626 "This database backend does not support updating "
627 "conflicts with specifying unique fields that can trigger "
628 "the upsert."
629 )
630 if not unique_fields and db_features.supports_update_conflicts_with_target:
631 raise ValueError(
632 "Unique fields that can trigger the upsert must be provided."
633 )
634 # Updating primary keys and non-concrete fields is forbidden.
635 if any(not f.concrete or f.many_to_many for f in update_fields):
636 raise ValueError(
637 "bulk_create() can only be used with concrete fields in "
638 "update_fields."
639 )
640 if any(f.primary_key for f in update_fields):
641 raise ValueError(
642 "bulk_create() cannot be used with primary keys in "
643 "update_fields."
644 )
645 if unique_fields:
646 if any(not f.concrete or f.many_to_many for f in unique_fields):
647 raise ValueError(
648 "bulk_create() can only be used with concrete fields "
649 "in unique_fields."
650 )
651 return OnConflict.UPDATE
652 return None
654 def bulk_create(
655 self,
656 objs,
657 batch_size=None,
658 ignore_conflicts=False,
659 update_conflicts=False,
660 update_fields=None,
661 unique_fields=None,
662 ):
663 """
664 Insert each of the instances into the database. Do *not* call
665 save() on each of the instances, do not send any pre/post_save
666 signals, and do not set the primary key attribute if it is an
667 autoincrement field (except if features.can_return_rows_from_bulk_insert=True).
668 Multi-table models are not supported.
669 """
670 # When you bulk insert you don't get the primary keys back (if it's an
671 # autoincrement, except if can_return_rows_from_bulk_insert=True), so
672 # you can't insert into the child tables which reference this. There
673 # are two workarounds:
674 # 1) This could be implemented if you didn't have an autoincrement pk
675 # 2) You could do it by doing O(n) normal inserts into the parent
676 # tables to get the primary keys back and then doing a single bulk
677 # insert into the childmost table.
678 # We currently set the primary keys on the objects when using
679 # PostgreSQL via the RETURNING ID clause. It should be possible for
680 # Oracle as well, but the semantics for extracting the primary keys are
681 # trickier, so it's not done yet.
682 if batch_size is not None and batch_size <= 0:
683 raise ValueError("Batch size must be a positive integer.")
684 # Check that the parents share the same concrete model with our
685 # model to detect the inheritance pattern ConcreteGrandParent ->
686 # MultiTableParent -> ProxyChild. Simply checking self.model._meta.proxy
687 # would not identify that case as involving multiple tables.
688 for parent in self.model._meta.get_parent_list():
689 if parent._meta.concrete_model is not self.model._meta.concrete_model:
690 raise ValueError("Can't bulk create a multi-table inherited model")
691 if not objs:
692 return objs
693 opts = self.model._meta
694 if unique_fields:
695 # Primary key is allowed in unique_fields.
696 unique_fields = [
697 self.model._meta.get_field(opts.pk.name if name == "pk" else name)
698 for name in unique_fields
699 ]
700 if update_fields:
701 update_fields = [self.model._meta.get_field(name) for name in update_fields]
702 on_conflict = self._check_bulk_create_options(
703 ignore_conflicts,
704 update_conflicts,
705 update_fields,
706 unique_fields,
707 )
708 self._for_write = True
709 fields = opts.concrete_fields
710 objs = list(objs)
711 self._prepare_for_bulk_create(objs)
712 with transaction.atomic(using=self.db, savepoint=False):
713 objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs)
714 if objs_with_pk:
715 returned_columns = self._batched_insert(
716 objs_with_pk,
717 fields,
718 batch_size,
719 on_conflict=on_conflict,
720 update_fields=update_fields,
721 unique_fields=unique_fields,
722 )
723 for obj_with_pk, results in zip(objs_with_pk, returned_columns):
724 for result, field in zip(results, opts.db_returning_fields):
725 if field != opts.pk:
726 setattr(obj_with_pk, field.attname, result)
727 for obj_with_pk in objs_with_pk:
728 obj_with_pk._state.adding = False
729 obj_with_pk._state.db = self.db
730 if objs_without_pk:
731 fields = [f for f in fields if not isinstance(f, AutoField)]
732 returned_columns = self._batched_insert(
733 objs_without_pk,
734 fields,
735 batch_size,
736 on_conflict=on_conflict,
737 update_fields=update_fields,
738 unique_fields=unique_fields,
739 )
740 connection = connections[self.db]
741 if (
742 connection.features.can_return_rows_from_bulk_insert
743 and on_conflict is None
744 ):
745 assert len(returned_columns) == len(objs_without_pk)
746 for obj_without_pk, results in zip(objs_without_pk, returned_columns):
747 for result, field in zip(results, opts.db_returning_fields):
748 setattr(obj_without_pk, field.attname, result)
749 obj_without_pk._state.adding = False
750 obj_without_pk._state.db = self.db
752 return objs
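    # Usage sketch for bulk_create() (illustrative addition; "Article" is a hypothetical
    # model):
    #
    #     Article.objects.bulk_create(
    #         [Article(title="a"), Article(title="b")],
    #         batch_size=100,
    #     )
    #     # An upsert-style call, where the backend supports it:
    #     # Article.objects.bulk_create(objs, update_conflicts=True,
    #     #                             unique_fields=["slug"], update_fields=["title"])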
754 def bulk_update(self, objs, fields, batch_size=None):
755 """
756 Update the given fields in each of the given objects in the database.
757 """
758 if batch_size is not None and batch_size <= 0:
759 raise ValueError("Batch size must be a positive integer.")
760 if not fields:
761 raise ValueError("Field names must be given to bulk_update().")
762 objs = tuple(objs)
763 if any(obj.pk is None for obj in objs):
764 raise ValueError("All bulk_update() objects must have a primary key set.")
765 fields = [self.model._meta.get_field(name) for name in fields]
766 if any(not f.concrete or f.many_to_many for f in fields):
767 raise ValueError("bulk_update() can only be used with concrete fields.")
768 if any(f.primary_key for f in fields):
769 raise ValueError("bulk_update() cannot be used with primary key fields.")
770 if not objs:
771 return 0
772 for obj in objs:
773 obj._prepare_related_fields_for_save(
774 operation_name="bulk_update", fields=fields
775 )
776 # PK is used twice in the resulting update query, once in the filter
777 # and once in the WHEN. Each field will also have one CAST.
778 self._for_write = True
779 connection = connections[self.db]
780 max_batch_size = connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs)
781 batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
782 requires_casting = connection.features.requires_casted_case_in_updates
783 batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size))
784 updates = []
785 for batch_objs in batches:
786 update_kwargs = {}
787 for field in fields:
788 when_statements = []
789 for obj in batch_objs:
790 attr = getattr(obj, field.attname)
791 if not hasattr(attr, "resolve_expression"):
792 attr = Value(attr, output_field=field)
793 when_statements.append(When(pk=obj.pk, then=attr))
794 case_statement = Case(*when_statements, output_field=field)
795 if requires_casting:
796 case_statement = Cast(case_statement, output_field=field)
797 update_kwargs[field.attname] = case_statement
798 updates.append(([obj.pk for obj in batch_objs], update_kwargs))
799 rows_updated = 0
800 queryset = self.using(self.db)
801 with transaction.atomic(using=self.db, savepoint=False):
802 for pks, update_kwargs in updates:
803 rows_updated += queryset.filter(pk__in=pks).update(**update_kwargs)
804 return rows_updated
806 bulk_update.alters_data = True
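    # Usage sketch for bulk_update() (illustrative addition; "Article" is a hypothetical
    # model):
    #
    #     articles = list(Article.objects.filter(published=False))
    #     for article in articles:
    #         article.published = True
    #     rows = Article.objects.bulk_update(articles, ["published"], batch_size=200)
    #     # Every object must already have a primary key; returns the number of rows updated.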
808 def get_or_create(self, defaults=None, **kwargs):
809 """
810 Look up an object with the given kwargs, creating one if necessary.
811 Return a tuple of (object, created), where created is a boolean
812 specifying whether an object was created.
813 """
814 # The get() needs to be targeted at the write database in order
815 # to avoid potential transaction consistency problems.
816 self._for_write = True
817 try:
818 return self.get(**kwargs), False
819 except self.model.DoesNotExist:
820 params = self._extract_model_params(defaults, **kwargs)
821 # Try to create an object using passed params.
822 try:
823 with transaction.atomic(using=self.db):
824 params = dict(resolve_callables(params))
825 return self.create(**params), True
826 except (IntegrityError, ValidationError):
827 # Since create() also validates by default,
828 # we can get any kind of ValidationError here,
829 # or it can flow through and get an IntegrityError from the database.
830 # The main thing we're concerned about is uniqueness failures,
831 # but ValidationError could include other things too.
832 # In all cases though it should be fine to try the get() again
833 # and return an existing object.
834 try:
835 return self.get(**kwargs), False
836 except self.model.DoesNotExist:
837 pass
838 raise
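    # Usage sketch for get_or_create() (illustrative addition; "Article" is a hypothetical
    # model):
    #
    #     article, created = Article.objects.get_or_create(
    #         slug="hello-world",
    #         defaults={"title": "Hello, world"},
    #     )
    #     # The keyword arguments form the lookup; defaults are only applied when a new
    #     # row has to be created.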
840 def update_or_create(self, defaults=None, create_defaults=None, **kwargs):
841 """
842 Look up an object with the given kwargs, updating one with defaults
843 if it exists, otherwise create a new one. Optionally, an object can
844 be created with different values than defaults by using
845 create_defaults.
846 Return a tuple (object, created), where created is a boolean
847 specifying whether an object was created.
848 """
849 if create_defaults is None:
850 update_defaults = create_defaults = defaults or {}
851 else:
852 update_defaults = defaults or {}
853 self._for_write = True
854 with transaction.atomic(using=self.db):
855 # Lock the row so that a concurrent update is blocked until
856 # update_or_create() has performed its save.
857 obj, created = self.select_for_update().get_or_create(
858 create_defaults, **kwargs
859 )
860 if created:
861 return obj, created
862 for k, v in resolve_callables(update_defaults):
863 setattr(obj, k, v)
865 update_fields = set(update_defaults)
866 concrete_field_names = self.model._meta._non_pk_concrete_field_names
867 # update_fields does not support non-concrete fields.
868 if concrete_field_names.issuperset(update_fields):
869 # Add fields which are set on pre_save(), e.g. auto_now fields.
870 # This is to maintain backward compatibility as these fields
871 # are not updated unless explicitly specified in the
872 # update_fields list.
873 for field in self.model._meta.local_concrete_fields:
874 if not (
875 field.primary_key or field.__class__.pre_save is Field.pre_save
876 ):
877 update_fields.add(field.name)
878 if field.name != field.attname:
879 update_fields.add(field.attname)
880 obj.save(using=self.db, update_fields=update_fields)
881 else:
882 obj.save(using=self.db)
883 return obj, False
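    # Usage sketch for update_or_create() (illustrative addition; "Article" is a
    # hypothetical model):
    #
    #     article, created = Article.objects.update_or_create(
    #         slug="hello-world",
    #         defaults={"title": "Updated title"},
    #         create_defaults={"title": "Initial title"},
    #     )
    #     # defaults are applied when an existing row is updated; create_defaults (if
    #     # given) are used instead when the row is created.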
885 def _extract_model_params(self, defaults, **kwargs):
886 """
887 Prepare `params` for creating a model instance based on the given
888 kwargs; for use by get_or_create().
889 """
890 defaults = defaults or {}
891 params = {k: v for k, v in kwargs.items() if LOOKUP_SEP not in k}
892 params.update(defaults)
893 property_names = self.model._meta._property_names
894 invalid_params = []
895 for param in params:
896 try:
897 self.model._meta.get_field(param)
898 except exceptions.FieldDoesNotExist:
899 # It's okay to use a model's property if it has a setter.
900 if not (param in property_names and getattr(self.model, param).fset):
901 invalid_params.append(param)
902 if invalid_params:
903 raise exceptions.FieldError(
904 "Invalid field name(s) for model {}: '{}'.".format(
905 self.model._meta.object_name,
906 "', '".join(sorted(invalid_params)),
907 )
908 )
909 return params
911 def _earliest(self, *fields):
912 """
913 Return the earliest object according to fields (if given) or by the
914 model's Meta.get_latest_by.
915 """
916 if fields:
917 order_by = fields
918 else:
919 order_by = getattr(self.model._meta, "get_latest_by")
920 if order_by and not isinstance(order_by, tuple | list):
921 order_by = (order_by,)
922 if order_by is None:
923 raise ValueError(
924 "earliest() and latest() require either fields as positional "
925 "arguments or 'get_latest_by' in the model's Meta."
926 )
927 obj = self._chain()
928 obj.query.set_limits(high=1)
929 obj.query.clear_ordering(force=True)
930 obj.query.add_ordering(*order_by)
931 return obj.get()
933 def earliest(self, *fields):
934 if self.query.is_sliced:
935 raise TypeError("Cannot change a query once a slice has been taken.")
936 return self._earliest(*fields)
938 def latest(self, *fields):
939 """
940 Return the latest object according to fields (if given) or by the
941 model's Meta.get_latest_by.
942 """
943 if self.query.is_sliced:
944 raise TypeError("Cannot change a query once a slice has been taken.")
945 return self.reverse()._earliest(*fields)
947 def first(self):
948 """Return the first object of a query or None if no match is found."""
949 if self.ordered:
950 queryset = self
951 else:
952 self._check_ordering_first_last_queryset_aggregation(method="first")
953 queryset = self.order_by("pk")
954 for obj in queryset[:1]:
955 return obj
957 def last(self):
958 """Return the last object of a query or None if no match is found."""
959 if self.ordered:
960 queryset = self.reverse()
961 else:
962 self._check_ordering_first_last_queryset_aggregation(method="last")
963 queryset = self.order_by("-pk")
964 for obj in queryset[:1]:
965 return obj
967 def in_bulk(self, id_list=None, *, field_name="pk"):
968 """
969 Return a dictionary mapping each of the given IDs to the object with
970 that ID. If `id_list` isn't provided, evaluate the entire QuerySet.
971 """
972 if self.query.is_sliced:
973 raise TypeError("Cannot use 'limit' or 'offset' with in_bulk().")
974 opts = self.model._meta
975 unique_fields = [
976 constraint.fields[0]
977 for constraint in opts.total_unique_constraints
978 if len(constraint.fields) == 1
979 ]
980 if (
981 field_name != "pk"
982 and not opts.get_field(field_name).unique
983 and field_name not in unique_fields
984 and self.query.distinct_fields != (field_name,)
985 ):
986 raise ValueError(
987 "in_bulk()'s field_name must be a unique field but %r isn't."
988 % field_name
989 )
990 if id_list is not None:
991 if not id_list:
992 return {}
993 filter_key = f"{field_name}__in"
994 batch_size = connections[self.db].features.max_query_params
995 id_list = tuple(id_list)
996 # If the database has a limit on the number of query parameters
997 # (e.g. SQLite), retrieve objects in batches if necessary.
998 if batch_size and batch_size < len(id_list):
999 qs = ()
1000 for offset in range(0, len(id_list), batch_size):
1001 batch = id_list[offset : offset + batch_size]
1002 qs += tuple(self.filter(**{filter_key: batch}))
1003 else:
1004 qs = self.filter(**{filter_key: id_list})
1005 else:
1006 qs = self._chain()
1007 return {getattr(obj, field_name): obj for obj in qs}
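    # Usage sketch for in_bulk() (illustrative addition; "Article" is a hypothetical
    # model):
    #
    #     Article.objects.in_bulk([1, 2, 3])
    #     # -> {1: <Article: 1>, 2: <Article: 2>, 3: <Article: 3>}
    #     Article.objects.in_bulk(["hello-world"], field_name="slug")  # field must be unique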
1009 def delete(self):
1010 """Delete the records in the current QuerySet."""
1011 self._not_support_combined_queries("delete")
1012 if self.query.is_sliced:
1013 raise TypeError("Cannot use 'limit' or 'offset' with delete().")
1014 if self.query.distinct or self.query.distinct_fields:
1015 raise TypeError("Cannot call delete() after .distinct().")
1016 if self._fields is not None:
1017 raise TypeError("Cannot call delete() after .values() or .values_list()")
1019 del_query = self._chain()
1021 # The delete is actually 2 queries - one to find related objects,
1022 # and one to delete. Make sure that the discovery of related
1023 # objects is performed on the same database as the deletion.
1024 del_query._for_write = True
1026 # Disable non-supported fields.
1027 del_query.query.select_for_update = False
1028 del_query.query.select_related = False
1029 del_query.query.clear_ordering(force=True)
1031 from plain.models.deletion import Collector
1033 collector = Collector(using=del_query.db, origin=self)
1034 collector.collect(del_query)
1035 deleted, _rows_count = collector.delete()
1037 # Clear the result cache, in case this QuerySet gets reused.
1038 self._result_cache = None
1039 return deleted, _rows_count
1041 delete.alters_data = True
1042 delete.queryset_only = True
1044 def _raw_delete(self, using):
1045 """
1046 Delete objects found from the given queryset in a single direct SQL
1047 query. No signals are sent and there is no protection for cascades.
1048 """
1049 query = self.query.clone()
1050 query.__class__ = sql.DeleteQuery
1051 cursor = query.get_compiler(using).execute_sql(CURSOR)
1052 if cursor:
1053 with cursor:
1054 return cursor.rowcount
1055 return 0
1057 _raw_delete.alters_data = True
1059 def update(self, **kwargs):
1060 """
1061 Update all elements in the current QuerySet, setting all the given
1062 fields to the appropriate values.
1063 """
1064 self._not_support_combined_queries("update")
1065 if self.query.is_sliced:
1066 raise TypeError("Cannot update a query once a slice has been taken.")
1067 self._for_write = True
1068 query = self.query.chain(sql.UpdateQuery)
1069 query.add_update_values(kwargs)
1071 # Inline annotations in order_by(), if possible.
1072 new_order_by = []
1073 for col in query.order_by:
1074 alias = col
1075 descending = False
1076 if isinstance(alias, str) and alias.startswith("-"):
1077 alias = alias.removeprefix("-")
1078 descending = True
1079 if annotation := query.annotations.get(alias):
1080 if getattr(annotation, "contains_aggregate", False):
1081 raise exceptions.FieldError(
1082 f"Cannot update when ordering by an aggregate: {annotation}"
1083 )
1084 if descending:
1085 annotation = annotation.desc()
1086 new_order_by.append(annotation)
1087 else:
1088 new_order_by.append(col)
1089 query.order_by = tuple(new_order_by)
1091 # Clear any annotations so that they won't be present in subqueries.
1092 query.annotations = {}
1093 with transaction.mark_for_rollback_on_error(using=self.db):
1094 rows = query.get_compiler(self.db).execute_sql(CURSOR)
1095 self._result_cache = None
1096 return rows
1098 update.alters_data = True
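    # Usage sketch for update() (illustrative addition; "Article" is a hypothetical
    # model):
    #
    #     Article.objects.filter(published=False).update(published=True)
    #     # Issues a single UPDATE and returns the number of rows matched; save() is
    #     # not called on the individual instances.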
1100 def _update(self, values):
1101 """
1102 A version of update() that accepts field objects instead of field names.
1103 Used primarily for model saving and not intended for use by general
1104 code (it requires too much poking around at model internals to be
1105 useful at that level).
1106 """
1107 if self.query.is_sliced:
1108 raise TypeError("Cannot update a query once a slice has been taken.")
1109 query = self.query.chain(sql.UpdateQuery)
1110 query.add_update_fields(values)
1111 # Clear any annotations so that they won't be present in subqueries.
1112 query.annotations = {}
1113 self._result_cache = None
1114 return query.get_compiler(self.db).execute_sql(CURSOR)
1116 _update.alters_data = True
1117 _update.queryset_only = False
1119 def exists(self):
1120 """
1121 Return True if the QuerySet would have any results, False otherwise.
1122 """
1123 if self._result_cache is None:
1124 return self.query.has_results(using=self.db)
1125 return bool(self._result_cache)
1127 def contains(self, obj):
1128 """
1129 Return True if the QuerySet contains the provided obj,
1130 False otherwise.
1131 """
1132 self._not_support_combined_queries("contains")
1133 if self._fields is not None:
1134 raise TypeError(
1135 "Cannot call QuerySet.contains() after .values() or .values_list()."
1136 )
1137 try:
1138 if obj._meta.concrete_model != self.model._meta.concrete_model:
1139 return False
1140 except AttributeError:
1141 raise TypeError("'obj' must be a model instance.")
1142 if obj.pk is None:
1143 raise ValueError("QuerySet.contains() cannot be used on unsaved objects.")
1144 if self._result_cache is not None:
1145 return obj in self._result_cache
1146 return self.filter(pk=obj.pk).exists()
1148 def _prefetch_related_objects(self):
1149 # This method can only be called once the result cache has been filled.
1150 prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
1151 self._prefetch_done = True
1153 def explain(self, *, format=None, **options):
1154 """
1155 Run an EXPLAIN on the SQL query this QuerySet would perform, and
1156 return the results.
1157 """
1158 return self.query.explain(using=self.db, format=format, **options)
1160 ##################################################
1161 # PUBLIC METHODS THAT RETURN A QUERYSET SUBCLASS #
1162 ##################################################
1164 def raw(self, raw_query, params=(), translations=None, using=None):
1165 if using is None:
1166 using = self.db
1167 qs = RawQuerySet(
1168 raw_query,
1169 model=self.model,
1170 params=params,
1171 translations=translations,
1172 using=using,
1173 )
1174 qs._prefetch_related_lookups = self._prefetch_related_lookups[:]
1175 return qs
1177 def _values(self, *fields, **expressions):
1178 clone = self._chain()
1179 if expressions:
1180 clone = clone.annotate(**expressions)
1181 clone._fields = fields
1182 clone.query.set_values(fields)
1183 return clone
1185 def values(self, *fields, **expressions):
1186 fields += tuple(expressions)
1187 clone = self._values(*fields, **expressions)
1188 clone._iterable_class = ValuesIterable
1189 return clone
1191 def values_list(self, *fields, flat=False, named=False):
1192 if flat and named:
1193 raise TypeError("'flat' and 'named' can't be used together.")
1194 if flat and len(fields) > 1:
1195 raise TypeError(
1196 "'flat' is not valid when values_list is called with more than one "
1197 "field."
1198 )
1200 field_names = {f for f in fields if not hasattr(f, "resolve_expression")}
1201 _fields = []
1202 expressions = {}
1203 counter = 1
1204 for field in fields:
1205 if hasattr(field, "resolve_expression"):
1206 field_id_prefix = getattr(
1207 field, "default_alias", field.__class__.__name__.lower()
1208 )
1209 while True:
1210 field_id = field_id_prefix + str(counter)
1211 counter += 1
1212 if field_id not in field_names:
1213 break
1214 expressions[field_id] = field
1215 _fields.append(field_id)
1216 else:
1217 _fields.append(field)
1219 clone = self._values(*_fields, **expressions)
1220 clone._iterable_class = (
1221 NamedValuesListIterable
1222 if named
1223 else FlatValuesListIterable
1224 if flat
1225 else ValuesListIterable
1226 )
1227 return clone
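    # Usage sketch for values()/values_list() (illustrative addition; "Article" is a
    # hypothetical model):
    #
    #     Article.objects.values("id", "title")                    # dicts
    #     Article.objects.values_list("id", "title")               # tuples
    #     Article.objects.values_list("id", flat=True)             # bare values
    #     Article.objects.values_list("id", "title", named=True)   # namedtuples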
1229 def dates(self, field_name, kind, order="ASC"):
1230 """
1231 Return a list of date objects representing all available dates for
1232 the given field_name, scoped to 'kind'.
1233 """
1234 if kind not in ("year", "month", "week", "day"):
1235 raise ValueError("'kind' must be one of 'year', 'month', 'week', or 'day'.")
1236 if order not in ("ASC", "DESC"):
1237 raise ValueError("'order' must be either 'ASC' or 'DESC'.")
1238 return (
1239 self.annotate(
1240 datefield=Trunc(field_name, kind, output_field=DateField()),
1241 plain_field=F(field_name),
1242 )
1243 .values_list("datefield", flat=True)
1244 .distinct()
1245 .filter(plain_field__isnull=False)
1246 .order_by(("-" if order == "DESC" else "") + "datefield")
1247 )
1249 def datetimes(self, field_name, kind, order="ASC", tzinfo=None):
1250 """
1251 Return a list of datetime objects representing all available
1252 datetimes for the given field_name, scoped to 'kind'.
1253 """
1254 if kind not in ("year", "month", "week", "day", "hour", "minute", "second"):
1255 raise ValueError(
1256 "'kind' must be one of 'year', 'month', 'week', 'day', "
1257 "'hour', 'minute', or 'second'."
1258 )
1259 if order not in ("ASC", "DESC"):
1260 raise ValueError("'order' must be either 'ASC' or 'DESC'.")
1262 if tzinfo is None:
1263 tzinfo = timezone.get_current_timezone()
1265 return (
1266 self.annotate(
1267 datetimefield=Trunc(
1268 field_name,
1269 kind,
1270 output_field=DateTimeField(),
1271 tzinfo=tzinfo,
1272 ),
1273 plain_field=F(field_name),
1274 )
1275 .values_list("datetimefield", flat=True)
1276 .distinct()
1277 .filter(plain_field__isnull=False)
1278 .order_by(("-" if order == "DESC" else "") + "datetimefield")
1279 )
1281 def none(self):
1282 """Return an empty QuerySet."""
1283 clone = self._chain()
1284 clone.query.set_empty()
1285 return clone
1287 ##################################################################
1288 # PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET #
1289 ##################################################################
1291 def all(self):
1292 """
1293 Return a new QuerySet that is a copy of the current one. This allows a
1294 QuerySet to proxy for a model manager in some cases.
1295 """
1296 return self._chain()
1298 def filter(self, *args, **kwargs):
1299 """
1300 Return a new QuerySet instance with the args ANDed to the existing
1301 set.
1302 """
1303 self._not_support_combined_queries("filter")
1304 return self._filter_or_exclude(False, args, kwargs)
1306 def exclude(self, *args, **kwargs):
1307 """
1308 Return a new QuerySet instance with NOT (args) ANDed to the existing
1309 set.
1310 """
1311 self._not_support_combined_queries("exclude")
1312 return self._filter_or_exclude(True, args, kwargs)
1314 def _filter_or_exclude(self, negate, args, kwargs):
1315 if (args or kwargs) and self.query.is_sliced:
1316 raise TypeError("Cannot filter a query once a slice has been taken.")
1317 clone = self._chain()
1318 if self._defer_next_filter:
1319 self._defer_next_filter = False
1320 clone._deferred_filter = negate, args, kwargs
1321 else:
1322 clone._filter_or_exclude_inplace(negate, args, kwargs)
1323 return clone
1325 def _filter_or_exclude_inplace(self, negate, args, kwargs):
1326 if negate:
1327 self._query.add_q(~Q(*args, **kwargs))
1328 else:
1329 self._query.add_q(Q(*args, **kwargs))
1331 def complex_filter(self, filter_obj):
1332 """
1333 Return a new QuerySet instance with filter_obj added to the filters.
1335 filter_obj can be a Q object or a dictionary of keyword lookup
1336 arguments.
1338 This exists to support framework features such as 'limit_choices_to',
1339 and usually it will be more natural to use other methods.
1340 """
1341 if isinstance(filter_obj, Q):
1342 clone = self._chain()
1343 clone.query.add_q(filter_obj)
1344 return clone
1345 else:
1346 return self._filter_or_exclude(False, args=(), kwargs=filter_obj)
1348 def _combinator_query(self, combinator, *other_qs, all=False):
1349 # Clone the query to inherit the select list and everything
1350 clone = self._chain()
1351 # Clear limits and ordering so they can be reapplied
1352 clone.query.clear_ordering(force=True)
1353 clone.query.clear_limits()
1354 clone.query.combined_queries = (self.query,) + tuple(
1355 qs.query for qs in other_qs
1356 )
1357 clone.query.combinator = combinator
1358 clone.query.combinator_all = all
1359 return clone
1361 def union(self, *other_qs, all=False):
1362 # If the query is an EmptyQuerySet, combine all nonempty querysets.
1363 if isinstance(self, EmptyQuerySet):
1364 qs = [q for q in other_qs if not isinstance(q, EmptyQuerySet)]
1365 if not qs:
1366 return self
1367 if len(qs) == 1:
1368 return qs[0]
1369 return qs[0]._combinator_query("union", *qs[1:], all=all)
1370 return self._combinator_query("union", *other_qs, all=all)
1372 def intersection(self, *other_qs):
1373 # If any query is an EmptyQuerySet, return it.
1374 if isinstance(self, EmptyQuerySet):
1375 return self
1376 for other in other_qs:
1377 if isinstance(other, EmptyQuerySet):
1378 return other
1379 return self._combinator_query("intersection", *other_qs)
1381 def difference(self, *other_qs):
1382 # If the query is an EmptyQuerySet, return it.
1383 if isinstance(self, EmptyQuerySet):
1384 return self
1385 return self._combinator_query("difference", *other_qs)
1387 def select_for_update(self, nowait=False, skip_locked=False, of=(), no_key=False):
1388 """
1389 Return a new QuerySet instance that will select objects with a
1390 FOR UPDATE lock.
1391 """
1392 if nowait and skip_locked:
1393 raise ValueError("The nowait option cannot be used with skip_locked.")
1394 obj = self._chain()
1395 obj._for_write = True
1396 obj.query.select_for_update = True
1397 obj.query.select_for_update_nowait = nowait
1398 obj.query.select_for_update_skip_locked = skip_locked
1399 obj.query.select_for_update_of = of
1400 obj.query.select_for_no_key_update = no_key
1401 return obj
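    # Usage sketch for select_for_update() (illustrative addition; "Article" is a
    # hypothetical model; the lock requires an open transaction):
    #
    #     with transaction.atomic():
    #         article = Article.objects.select_for_update().get(pk=1)
    #         article.title = "edited while locked"
    #         article.save()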
1403 def select_related(self, *fields):
1404 """
1405 Return a new QuerySet instance that will select related objects.
1407 If fields are specified, they must be ForeignKey fields and only those
1408 related objects are included in the selection.
1410 If select_related(None) is called, clear the list.
1411 """
1412 self._not_support_combined_queries("select_related")
1413 if self._fields is not None:
1414 raise TypeError(
1415 "Cannot call select_related() after .values() or .values_list()"
1416 )
1418 obj = self._chain()
1419 if fields == (None,):
1420 obj.query.select_related = False
1421 elif fields:
1422 obj.query.add_select_related(fields)
1423 else:
1424 obj.query.select_related = True
1425 return obj
1427 def prefetch_related(self, *lookups):
1428 """
1429 Return a new QuerySet instance that will prefetch the specified
1430 Many-To-One and Many-To-Many related objects when the QuerySet is
1431 evaluated.
1433 When prefetch_related() is called more than once, append to the list of
1434 prefetch lookups. If prefetch_related(None) is called, clear the list.
1435 """
1436 self._not_support_combined_queries("prefetch_related")
1437 clone = self._chain()
1438 if lookups == (None,):
1439 clone._prefetch_related_lookups = ()
1440 else:
1441 for lookup in lookups:
1442 if isinstance(lookup, Prefetch):
1443 lookup = lookup.prefetch_to
1444 lookup = lookup.split(LOOKUP_SEP, 1)[0]
1445 if lookup in self.query._filtered_relations:
1446 raise ValueError(
1447 "prefetch_related() is not supported with FilteredRelation."
1448 )
1449 clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
1450 return clone
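    # Usage sketch for select_related()/prefetch_related() (illustrative addition;
    # "Article", "author", and "tags" are hypothetical model/relation names):
    #
    #     Article.objects.select_related("author")    # follows the FK with a JOIN, one query
    #     Article.objects.prefetch_related("tags")    # separate query, joined in Python
    #     Article.objects.prefetch_related(None)      # clears any prefetch lookups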
1452 def annotate(self, *args, **kwargs):
1453 """
1454 Return a query set in which the returned objects have been annotated
1455 with extra data or aggregations.
1456 """
1457 self._not_support_combined_queries("annotate")
1458 return self._annotate(args, kwargs, select=True)
1460 def alias(self, *args, **kwargs):
1461 """
1462 Return a query set with added aliases for extra data or aggregations.
1463 """
1464 self._not_support_combined_queries("alias")
1465 return self._annotate(args, kwargs, select=False)
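    # Usage sketch for annotate()/alias() (illustrative addition; "Article" and
    # "comments" are hypothetical names, and Count is assumed to be importable from
    # plain.models as in Django):
    #
    #     from plain.models import Count
    #     Article.objects.annotate(num_comments=Count("comments"))  # selected per row
    #     Article.objects.alias(num_comments=Count("comments")).filter(num_comments__gt=5)
    #     # alias() registers the expression for reuse without selecting it.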
1467 def _annotate(self, args, kwargs, select=True):
1468 self._validate_values_are_expressions(
1469 args + tuple(kwargs.values()), method_name="annotate"
1470 )
1471 annotations = {}
1472 for arg in args:
1473 # The default_alias property may raise a TypeError.
1474 try:
1475 if arg.default_alias in kwargs:
1476 raise ValueError(
1477 "The named annotation '%s' conflicts with the "
1478 "default name for another annotation." % arg.default_alias
1479 )
1480 except TypeError:
1481 raise TypeError("Complex annotations require an alias")
1482 annotations[arg.default_alias] = arg
1483 annotations.update(kwargs)
1485 clone = self._chain()
1486 names = self._fields
1487 if names is None:
1488 names = set(
1489 chain.from_iterable(
1490 (field.name, field.attname)
1491 if hasattr(field, "attname")
1492 else (field.name,)
1493 for field in self.model._meta.get_fields()
1494 )
1495 )
1497 for alias, annotation in annotations.items():
1498 if alias in names:
1499 raise ValueError(
1500 "The annotation '%s' conflicts with a field on "
1501 "the model." % alias
1502 )
1503 if isinstance(annotation, FilteredRelation):
1504 clone.query.add_filtered_relation(annotation, alias)
1505 else:
1506 clone.query.add_annotation(
1507 annotation,
1508 alias,
1509 select=select,
1510 )
1511 for alias, annotation in clone.query.annotations.items():
1512 if alias in annotations and annotation.contains_aggregate:
1513 if clone._fields is None:
1514 clone.query.group_by = True
1515 else:
1516 clone.query.set_group_by()
1517 break
1519 return clone
1521 def order_by(self, *field_names):
1522 """Return a new QuerySet instance with the ordering changed."""
1523 if self.query.is_sliced:
1524 raise TypeError("Cannot reorder a query once a slice has been taken.")
1525 obj = self._chain()
1526 obj.query.clear_ordering(force=True, clear_default=False)
1527 obj.query.add_ordering(*field_names)
1528 return obj
1530 def distinct(self, *field_names):
1531 """
1532 Return a new QuerySet instance that will select only distinct results.
1533 """
1534 self._not_support_combined_queries("distinct")
1535 if self.query.is_sliced:
1536 raise TypeError(
1537 "Cannot create distinct fields once a slice has been taken."
1538 )
1539 obj = self._chain()
1540 obj.query.add_distinct_fields(*field_names)
1541 return obj
1543 def extra(
1544 self,
1545 select=None,
1546 where=None,
1547 params=None,
1548 tables=None,
1549 order_by=None,
1550 select_params=None,
1551 ):
1552 """Add extra SQL fragments to the query."""
1553 self._not_support_combined_queries("extra")
1554 if self.query.is_sliced:
1555 raise TypeError("Cannot change a query once a slice has been taken.")
1556 clone = self._chain()
1557 clone.query.add_extra(select, select_params, where, params, tables, order_by)
1558 return clone
1560 def reverse(self):
1561 """Reverse the ordering of the QuerySet."""
1562 if self.query.is_sliced:
1563 raise TypeError("Cannot reverse a query once a slice has been taken.")
1564 clone = self._chain()
1565 clone.query.standard_ordering = not clone.query.standard_ordering
1566 return clone
1568 def defer(self, *fields):
1569 """
1570 Defer the loading of data for certain fields until they are accessed.
1571 Add the set of deferred fields to any existing set of deferred fields.
1572 The only exception to this is if None is passed in as the only
1573 parameter, in which case all deferrals are removed.
1574 """
1575 self._not_support_combined_queries("defer")
1576 if self._fields is not None:
1577 raise TypeError("Cannot call defer() after .values() or .values_list()")
1578 clone = self._chain()
1579 if fields == (None,):
1580 clone.query.clear_deferred_loading()
1581 else:
1582 clone.query.add_deferred_loading(fields)
1583 return clone
1585 def only(self, *fields):
1586 """
1587 Essentially, the opposite of defer(). Only the fields passed into this
1588 method and that are not already specified as deferred are loaded
1589 immediately when the queryset is evaluated.
1590 """
1591 self._not_support_combined_queries("only")
1592 if self._fields is not None:
1593 raise TypeError("Cannot call only() after .values() or .values_list()")
1594 if fields == (None,):
1595 # Can only pass None to defer(), not only(), as the rest option.
1596 # That won't stop people trying to do this, so let's be explicit.
1597 raise TypeError("Cannot pass None as an argument to only().")
1598 for field in fields:
1599 field = field.split(LOOKUP_SEP, 1)[0]
1600 if field in self.query._filtered_relations:
1601 raise ValueError("only() is not supported with FilteredRelation.")
1602 clone = self._chain()
1603 clone.query.add_immediate_loading(fields)
1604 return clone
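    # Usage sketch for defer()/only() (illustrative addition; "Article" is a hypothetical
    # model):
    #
    #     Article.objects.defer("body")          # load everything except "body"
    #     Article.objects.only("id", "title")    # load only these fields eagerly
    #     # Deferred fields are fetched from the database on first attribute access.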
1606 def using(self, alias):
1607 """Select which database this QuerySet should execute against."""
1608 clone = self._chain()
1609 clone._db = alias
1610 return clone
1612 ###################################
1613 # PUBLIC INTROSPECTION ATTRIBUTES #
1614 ###################################
1616 @property
1617 def ordered(self):
1618 """
1619 Return True if the QuerySet is ordered -- i.e. has an order_by()
1620 clause or a default ordering on the model (or is empty).
1621 """
1622 if isinstance(self, EmptyQuerySet):
1623 return True
1624 if self.query.extra_order_by or self.query.order_by:
1625 return True
1626 elif (
1627 self.query.default_ordering
1628 and self.query.get_meta().ordering
1629 and
1630 # A default ordering doesn't affect GROUP BY queries.
1631 not self.query.group_by
1632 ):
1633 return True
1634 else:
1635 return False
1637 @property
1638 def db(self):
1639 """Return the database used if this query is executed now."""
1640 if self._for_write:
1641 return self._db or router.db_for_write(self.model, **self._hints)
1642 return self._db or router.db_for_read(self.model, **self._hints)
1644 ###################
1645 # PRIVATE METHODS #
1646 ###################
1648 def _insert(
1649 self,
1650 objs,
1651 fields,
1652 returning_fields=None,
1653 raw=False,
1654 using=None,
1655 on_conflict=None,
1656 update_fields=None,
1657 unique_fields=None,
1658 ):
1659 """
1660 Insert a new record for the given model. This provides an interface to
1661 the InsertQuery class and is how Model.save() is implemented.
1662 """
1663 self._for_write = True
1664 if using is None:
1665 using = self.db
1666 query = sql.InsertQuery(
1667 self.model,
1668 on_conflict=on_conflict,
1669 update_fields=update_fields,
1670 unique_fields=unique_fields,
1671 )
1672 query.insert_values(fields, objs, raw=raw)
1673 return query.get_compiler(using=using).execute_sql(returning_fields)
1675 _insert.alters_data = True
1676 _insert.queryset_only = False
1678 def _batched_insert(
1679 self,
1680 objs,
1681 fields,
1682 batch_size,
1683 on_conflict=None,
1684 update_fields=None,
1685 unique_fields=None,
1686 ):
1687 """
1688 Helper method for bulk_create() to insert objs one batch at a time.
1689 """
1690 connection = connections[self.db]
1691 ops = connection.ops
1692 max_batch_size = max(ops.bulk_batch_size(fields, objs), 1)
1693 batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
1694 inserted_rows = []
1695 bulk_return = connection.features.can_return_rows_from_bulk_insert
1696 for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]:
1697 if bulk_return and on_conflict is None:
1698 inserted_rows.extend(
1699 self._insert(
1700 item,
1701 fields=fields,
1702 using=self.db,
1703 returning_fields=self.model._meta.db_returning_fields,
1704 )
1705 )
1706 else:
1707 self._insert(
1708 item,
1709 fields=fields,
1710 using=self.db,
1711 on_conflict=on_conflict,
1712 update_fields=update_fields,
1713 unique_fields=unique_fields,
1714 )
1715 return inserted_rows
1717 def _chain(self):
1718 """
1719 Return a copy of the current QuerySet that's ready for another
1720 operation.
1721 """
1722 obj = self._clone()
1723 if obj._sticky_filter:
1724 obj.query.filter_is_sticky = True
1725 obj._sticky_filter = False
1726 return obj
1728 def _clone(self):
1729 """
1730 Return a copy of the current QuerySet. A lightweight alternative
1731 to deepcopy().
1732 """
1733 c = self.__class__(
1734 model=self.model,
1735 query=self.query.chain(),
1736 using=self._db,
1737 hints=self._hints,
1738 )
1739 c._sticky_filter = self._sticky_filter
1740 c._for_write = self._for_write
1741 c._prefetch_related_lookups = self._prefetch_related_lookups[:]
1742 c._known_related_objects = self._known_related_objects
1743 c._iterable_class = self._iterable_class
1744 c._fields = self._fields
1745 return c
1747 def _fetch_all(self):
1748 if self._result_cache is None:
1749 self._result_cache = list(self._iterable_class(self))
1750 if self._prefetch_related_lookups and not self._prefetch_done:
1751 self._prefetch_related_objects()
1753 def _next_is_sticky(self):
1754 """
1755 Indicate that the next filter call and the one following that should
1756 be treated as a single filter. This is only important when it comes to
1757 determining when to reuse tables for many-to-many filters. Required so
1758 that we can filter naturally on the results of related managers.
1760 This doesn't return a clone of the current QuerySet (it returns
1761 "self"). The method is only used internally and should be immediately
1762 followed by a filter() that does create a clone.
1763 """
1764 self._sticky_filter = True
1765 return self
1767 def _merge_sanity_check(self, other):
1768 """Check that two QuerySet classes may be merged."""
1769 if self._fields is not None and (
1770 set(self.query.values_select) != set(other.query.values_select)
1771 or set(self.query.extra_select) != set(other.query.extra_select)
1772 or set(self.query.annotation_select) != set(other.query.annotation_select)
1773 ):
1774 raise TypeError(
1775 "Merging '%s' classes must involve the same values in each case."
1776 % self.__class__.__name__
1777 )
1779 def _merge_known_related_objects(self, other):
1780 """
1781 Keep track of all known related objects from either QuerySet instance.
1782 """
1783 for field, objects in other._known_related_objects.items():
1784 self._known_related_objects.setdefault(field, {}).update(objects)
1786 def resolve_expression(self, *args, **kwargs):
1787 if self._fields and len(self._fields) > 1:
1788            # A values() queryset can only be used as a nested query
1789            # if it is set up to select only a single field.
1790 raise TypeError("Cannot use multi-field values as a filter value.")
1791 query = self.query.resolve_expression(*args, **kwargs)
1792 query._db = self._db
1793 return query
1795 resolve_expression.queryset_only = True
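# Illustrative sketch (models and fields are hypothetical): the single-field
# restriction above is what allows a values() queryset to be used as a nested
# subquery in a filter:
#
#     Book.objects.filter(author_id__in=Author.objects.values("id"))          # OK
#     Book.objects.filter(author_id__in=Author.objects.values("id", "name"))  # TypeError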
1797 def _add_hints(self, **hints):
1798 """
1799 Update hinting information for use by routers. Add new key/values or
1800 overwrite existing key/values.
1801 """
1802 self._hints.update(hints)
1804 def _has_filters(self):
1805 """
1806 Check if this QuerySet has any filtering going on. This isn't
1807        equivalent to checking if all objects are present in results, for
1808 example, qs[1:]._has_filters() -> False.
1809 """
1810 return self.query.has_filters()
1812 @staticmethod
1813 def _validate_values_are_expressions(values, method_name):
1814 invalid_args = sorted(
1815 str(arg) for arg in values if not hasattr(arg, "resolve_expression")
1816 )
1817 if invalid_args:
1818 raise TypeError(
1819 "QuerySet.{}() received non-expression(s): {}.".format(
1820 method_name,
1821 ", ".join(invalid_args),
1822 )
1823 )
1825 def _not_support_combined_queries(self, operation_name):
1826 if self.query.combinator:
1827 raise NotSupportedError(
1828 "Calling QuerySet.{}() after {}() is not supported.".format(
1829 operation_name, self.query.combinator
1830 )
1831 )
1833 def _check_operator_queryset(self, other, operator_):
1834 if self.query.combinator or other.query.combinator:
1835 raise TypeError(f"Cannot use {operator_} operator with combined queryset.")
1837 def _check_ordering_first_last_queryset_aggregation(self, method):
1838 if isinstance(self.query.group_by, tuple) and not any(
1839 col.output_field is self.model._meta.pk for col in self.query.group_by
1840 ):
1841 raise TypeError(
1842 f"Cannot use QuerySet.{method}() on an unordered queryset performing "
1843 f"aggregation. Add an ordering with order_by()."
1844 )
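# Illustrative sketch (hypothetical models; Count assumed to come from the
# aggregates module): first()/last() on an unordered, aggregated queryset
# triggers the TypeError above, and adding an explicit ordering resolves it:
#
#     Book.objects.values("author").annotate(n=Count("id")).first()                 # TypeError
#     Book.objects.values("author").annotate(n=Count("id")).order_by("n").first()   # OK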
1847class InstanceCheckMeta(type):
1848 def __instancecheck__(self, instance):
1849 return isinstance(instance, QuerySet) and instance.query.is_empty()
1852class EmptyQuerySet(metaclass=InstanceCheckMeta):
1853 """
1854    Marker class for checking whether a queryset is empty via .none():
1855 isinstance(qs.none(), EmptyQuerySet) -> True
1856 """
1858 def __init__(self, *args, **kwargs):
1859 raise TypeError("EmptyQuerySet can't be instantiated")
1862class RawQuerySet:
1863 """
1864 Provide an iterator which converts the results of raw SQL queries into
1865 annotated model instances.
1866 """
1868 def __init__(
1869 self,
1870 raw_query,
1871 model=None,
1872 query=None,
1873 params=(),
1874 translations=None,
1875 using=None,
1876 hints=None,
1877 ):
1878 self.raw_query = raw_query
1879 self.model = model
1880 self._db = using
1881 self._hints = hints or {}
1882 self.query = query or sql.RawQuery(sql=raw_query, using=self.db, params=params)
1883 self.params = params
1884 self.translations = translations or {}
1885 self._result_cache = None
1886 self._prefetch_related_lookups = ()
1887 self._prefetch_done = False
1889 def resolve_model_init_order(self):
1890 """Resolve the init field names and value positions."""
1891 converter = connections[self.db].introspection.identifier_converter
1892 model_init_fields = [
1893 f for f in self.model._meta.fields if converter(f.column) in self.columns
1894 ]
1895 annotation_fields = [
1896 (column, pos)
1897 for pos, column in enumerate(self.columns)
1898 if column not in self.model_fields
1899 ]
1900 model_init_order = [
1901 self.columns.index(converter(f.column)) for f in model_init_fields
1902 ]
1903 model_init_names = [f.attname for f in model_init_fields]
1904 return model_init_names, model_init_order, annotation_fields
1906 def prefetch_related(self, *lookups):
1907 """Same as QuerySet.prefetch_related()"""
1908 clone = self._clone()
1909 if lookups == (None,):
1910 clone._prefetch_related_lookups = ()
1911 else:
1912 clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
1913 return clone
1915 def _prefetch_related_objects(self):
1916 prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
1917 self._prefetch_done = True
1919 def _clone(self):
1920 """Same as QuerySet._clone()"""
1921 c = self.__class__(
1922 self.raw_query,
1923 model=self.model,
1924 query=self.query,
1925 params=self.params,
1926 translations=self.translations,
1927 using=self._db,
1928 hints=self._hints,
1929 )
1930 c._prefetch_related_lookups = self._prefetch_related_lookups[:]
1931 return c
1933 def _fetch_all(self):
1934 if self._result_cache is None:
1935 self._result_cache = list(self.iterator())
1936 if self._prefetch_related_lookups and not self._prefetch_done:
1937 self._prefetch_related_objects()
1939 def __len__(self):
1940 self._fetch_all()
1941 return len(self._result_cache)
1943 def __bool__(self):
1944 self._fetch_all()
1945 return bool(self._result_cache)
1947 def __iter__(self):
1948 self._fetch_all()
1949 return iter(self._result_cache)
1951 def iterator(self):
1952 yield from RawModelIterable(self)
1954 def __repr__(self):
1955 return f"<{self.__class__.__name__}: {self.query}>"
1957 def __getitem__(self, k):
1958 return list(self)[k]
1960 @property
1961 def db(self):
1962 """Return the database used if this query is executed now."""
1963 return self._db or router.db_for_read(self.model, **self._hints)
1965 def using(self, alias):
1966 """Select the database this RawQuerySet should execute against."""
1967 return RawQuerySet(
1968 self.raw_query,
1969 model=self.model,
1970 query=self.query.chain(using=alias),
1971 params=self.params,
1972 translations=self.translations,
1973 using=alias,
1974 )
1976 @cached_property
1977 def columns(self):
1978 """
1979        A list of column names, adjusted to model field names via the
1980        translations mapping, in the order they'll appear in the query results.
1981 """
1982 columns = self.query.get_columns()
1983 # Adjust any column names which don't match field names
1984 for query_name, model_name in self.translations.items():
1985 # Ignore translations for nonexistent column names
1986 try:
1987 index = columns.index(query_name)
1988 except ValueError:
1989 pass
1990 else:
1991 columns[index] = model_name
1992 return columns
1994 @cached_property
1995 def model_fields(self):
1996 """A dict mapping column names to model field names."""
1997 converter = connections[self.db].introspection.identifier_converter
1998 model_fields = {}
1999 for field in self.model._meta.fields:
2000 name, column = field.get_attname_column()
2001 model_fields[converter(column)] = field
2002 return model_fields
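# Illustrative sketch (table, model, and column names are hypothetical): the
# translations mapping reconciles raw column names with model attribute names:
#
#     RawQuerySet(
#         "SELECT id, full_name FROM people",
#         model=Person,
#         translations={"full_name": "name"},
#     )
#
# In typical usage a RawQuerySet is obtained from a manager's raw() method,
# where available, rather than constructed directly.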
2005class Prefetch:
2006 def __init__(self, lookup, queryset=None, to_attr=None):
2007 # `prefetch_through` is the path we traverse to perform the prefetch.
2008 self.prefetch_through = lookup
2009 # `prefetch_to` is the path to the attribute that stores the result.
2010 self.prefetch_to = lookup
2011 if queryset is not None and (
2012 isinstance(queryset, RawQuerySet)
2013 or (
2014 hasattr(queryset, "_iterable_class")
2015 and not issubclass(queryset._iterable_class, ModelIterable)
2016 )
2017 ):
2018 raise ValueError(
2019 "Prefetch querysets cannot use raw(), values(), and values_list()."
2020 )
2021 if to_attr:
2022 self.prefetch_to = LOOKUP_SEP.join(
2023 lookup.split(LOOKUP_SEP)[:-1] + [to_attr]
2024 )
2026 self.queryset = queryset
2027 self.to_attr = to_attr
2029 def __getstate__(self):
2030 obj_dict = self.__dict__.copy()
2031 if self.queryset is not None:
2032 queryset = self.queryset._chain()
2033 # Prevent the QuerySet from being evaluated
2034 queryset._result_cache = []
2035 queryset._prefetch_done = True
2036 obj_dict["queryset"] = queryset
2037 return obj_dict
2039 def add_prefix(self, prefix):
2040 self.prefetch_through = prefix + LOOKUP_SEP + self.prefetch_through
2041 self.prefetch_to = prefix + LOOKUP_SEP + self.prefetch_to
2043 def get_current_prefetch_to(self, level):
2044 return LOOKUP_SEP.join(self.prefetch_to.split(LOOKUP_SEP)[: level + 1])
2046 def get_current_to_attr(self, level):
2047 parts = self.prefetch_to.split(LOOKUP_SEP)
2048 to_attr = parts[level]
2049 as_attr = self.to_attr and level == len(parts) - 1
2050 return to_attr, as_attr
2052 def get_current_queryset(self, level):
2053 if self.get_current_prefetch_to(level) == self.prefetch_to:
2054 return self.queryset
2055 return None
2057 def __eq__(self, other):
2058 if not isinstance(other, Prefetch):
2059 return NotImplemented
2060 return self.prefetch_to == other.prefetch_to
2062 def __hash__(self):
2063 return hash((self.__class__, self.prefetch_to))
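# Illustrative sketch (hypothetical models): a Prefetch object lets a lookup
# carry its own queryset and/or store the results under a separate attribute:
#
#     Author.objects.prefetch_related(
#         Prefetch(
#             "books",
#             queryset=Book.objects.filter(published=True),
#             to_attr="published_books",
#         )
#     )
#
# With to_attr, each author gets a plain list in author.published_books instead
# of the related manager's cache being populated.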
2066def normalize_prefetch_lookups(lookups, prefix=None):
2067 """Normalize lookups into Prefetch objects."""
2068 ret = []
2069 for lookup in lookups:
2070 if not isinstance(lookup, Prefetch):
2071 lookup = Prefetch(lookup)
2072 if prefix:
2073 lookup.add_prefix(prefix)
2074 ret.append(lookup)
2075 return ret
2078def prefetch_related_objects(model_instances, *related_lookups):
2079 """
2080 Populate prefetched object caches for a list of model instances based on
2081 the lookups/Prefetch instances given.
2082 """
2083 if not model_instances:
2084 return # nothing to do
2086 # We need to be able to dynamically add to the list of prefetch_related
2087    # lookups that we look up (see below). So we need some bookkeeping to
2088 # ensure we don't do duplicate work.
2089 done_queries = {} # dictionary of things like 'foo__bar': [results]
2091 auto_lookups = set() # we add to this as we go through.
2092 followed_descriptors = set() # recursion protection
2094 all_lookups = normalize_prefetch_lookups(reversed(related_lookups))
2095 while all_lookups:
2096 lookup = all_lookups.pop()
2097 if lookup.prefetch_to in done_queries:
2098 if lookup.queryset is not None:
2099 raise ValueError(
2100 "'%s' lookup was already seen with a different queryset. "
2101 "You may need to adjust the ordering of your lookups."
2102 % lookup.prefetch_to
2103 )
2105 continue
2107        # At the top level, the list of objects to decorate is the result cache
2108        # from the primary QuerySet; at deeper levels it is replaced as we descend.
2109 obj_list = model_instances
2111 through_attrs = lookup.prefetch_through.split(LOOKUP_SEP)
2112 for level, through_attr in enumerate(through_attrs):
2113 # Prepare main instances
2114 if not obj_list:
2115 break
2117 prefetch_to = lookup.get_current_prefetch_to(level)
2118 if prefetch_to in done_queries:
2119 # Skip any prefetching, and any object preparation
2120 obj_list = done_queries[prefetch_to]
2121 continue
2123 # Prepare objects:
2124 good_objects = True
2125 for obj in obj_list:
2126 # Since prefetching can re-use instances, it is possible to have
2127 # the same instance multiple times in obj_list, so obj might
2128 # already be prepared.
2129 if not hasattr(obj, "_prefetched_objects_cache"):
2130 try:
2131 obj._prefetched_objects_cache = {}
2132 except (AttributeError, TypeError):
2133 # Must be an immutable object from
2134 # values_list(flat=True), for example (TypeError) or
2135 # a QuerySet subclass that isn't returning Model
2136 # instances (AttributeError), either in Plain or a 3rd
2137 # party. prefetch_related() doesn't make sense, so quit.
2138 good_objects = False
2139 break
2140 if not good_objects:
2141 break
2143 # Descend down tree
2145 # We assume that objects retrieved are homogeneous (which is the premise
2146        # of prefetch_related), so what applies to the first object applies to all.
2147 first_obj = obj_list[0]
2148 to_attr = lookup.get_current_to_attr(level)[0]
2149 prefetcher, descriptor, attr_found, is_fetched = get_prefetcher(
2150 first_obj, through_attr, to_attr
2151 )
2153 if not attr_found:
2154 raise AttributeError(
2155 "Cannot find '{}' on {} object, '{}' is an invalid "
2156 "parameter to prefetch_related()".format(
2157 through_attr,
2158 first_obj.__class__.__name__,
2159 lookup.prefetch_through,
2160 )
2161 )
2163 if level == len(through_attrs) - 1 and prefetcher is None:
2164 # Last one, this *must* resolve to something that supports
2165 # prefetching, otherwise there is no point adding it and the
2166 # developer asking for it has made a mistake.
2167 raise ValueError(
2168 "'%s' does not resolve to an item that supports "
2169 "prefetching - this is an invalid parameter to "
2170 "prefetch_related()." % lookup.prefetch_through
2171 )
2173 obj_to_fetch = None
2174 if prefetcher is not None:
2175 obj_to_fetch = [obj for obj in obj_list if not is_fetched(obj)]
2177 if obj_to_fetch:
2178 obj_list, additional_lookups = prefetch_one_level(
2179 obj_to_fetch,
2180 prefetcher,
2181 lookup,
2182 level,
2183 )
2184 # We need to ensure we don't keep adding lookups from the
2185 # same relationships to stop infinite recursion. So, if we
2186 # are already on an automatically added lookup, don't add
2187 # the new lookups from relationships we've seen already.
2188 if not (
2189 prefetch_to in done_queries
2190 and lookup in auto_lookups
2191 and descriptor in followed_descriptors
2192 ):
2193 done_queries[prefetch_to] = obj_list
2194 new_lookups = normalize_prefetch_lookups(
2195 reversed(additional_lookups), prefetch_to
2196 )
2197 auto_lookups.update(new_lookups)
2198 all_lookups.extend(new_lookups)
2199 followed_descriptors.add(descriptor)
2200 else:
2201 # Either a singly related object that has already been fetched
2202 # (e.g. via select_related), or hopefully some other property
2203 # that doesn't support prefetching but needs to be traversed.
2205 # We replace the current list of parent objects with the list
2206 # of related objects, filtering out empty or missing values so
2207 # that we can continue with nullable or reverse relations.
2208 new_obj_list = []
2209 for obj in obj_list:
2210 if through_attr in getattr(obj, "_prefetched_objects_cache", ()):
2211 # If related objects have been prefetched, use the
2212 # cache rather than the object's through_attr.
2213 new_obj = list(obj._prefetched_objects_cache.get(through_attr))
2214 else:
2215 try:
2216 new_obj = getattr(obj, through_attr)
2217 except exceptions.ObjectDoesNotExist:
2218 continue
2219 if new_obj is None:
2220 continue
2221 # We special-case `list` rather than something more generic
2222 # like `Iterable` because we don't want to accidentally match
2223 # user models that define __iter__.
2224 if isinstance(new_obj, list):
2225 new_obj_list.extend(new_obj)
2226 else:
2227 new_obj_list.append(new_obj)
2228 obj_list = new_obj_list
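# Illustrative sketch (hypothetical models): the same machinery can be applied
# after the fact to instances fetched without prefetch_related():
#
#     authors = list(Author.objects.all())
#     prefetch_related_objects(authors, "books", "profile")
#
# Afterwards, iterating author.books.all() for each author hits the prefetch
# cache instead of issuing additional queries.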
2231def get_prefetcher(instance, through_attr, to_attr):
2232 """
2233 For the attribute 'through_attr' on the given instance, find
2234 an object that has a get_prefetch_queryset().
2235    Return a 4-tuple containing:
2236 (the object with get_prefetch_queryset (or None),
2237 the descriptor object representing this relationship (or None),
2238 a boolean that is False if the attribute was not found at all,
2239 a function that takes an instance and returns a boolean that is True if
2240 the attribute has already been fetched for that instance)
2241 """
2243 def has_to_attr_attribute(instance):
2244 return hasattr(instance, to_attr)
2246 prefetcher = None
2247 is_fetched = has_to_attr_attribute
2249 # For singly related objects, we have to avoid getting the attribute
2250 # from the object, as this will trigger the query. So we first try
2251 # on the class, in order to get the descriptor object.
2252 rel_obj_descriptor = getattr(instance.__class__, through_attr, None)
2253 if rel_obj_descriptor is None:
2254 attr_found = hasattr(instance, through_attr)
2255 else:
2256 attr_found = True
2257 if rel_obj_descriptor:
2258 # singly related object, descriptor object has the
2259 # get_prefetch_queryset() method.
2260 if hasattr(rel_obj_descriptor, "get_prefetch_queryset"):
2261 prefetcher = rel_obj_descriptor
2262 is_fetched = rel_obj_descriptor.is_cached
2263 else:
2264 # descriptor doesn't support prefetching, so we go ahead and get
2265 # the attribute on the instance rather than the class to
2266 # support many related managers
2267 rel_obj = getattr(instance, through_attr)
2268 if hasattr(rel_obj, "get_prefetch_queryset"):
2269 prefetcher = rel_obj
2270 if through_attr != to_attr:
2271 # Special case cached_property instances because hasattr
2272 # triggers attribute computation and assignment.
2273 if isinstance(
2274 getattr(instance.__class__, to_attr, None), cached_property
2275 ):
2277 def has_cached_property(instance):
2278 return to_attr in instance.__dict__
2280 is_fetched = has_cached_property
2281 else:
2283 def in_prefetched_cache(instance):
2284 return through_attr in instance._prefetched_objects_cache
2286 is_fetched = in_prefetched_cache
2287 return prefetcher, rel_obj_descriptor, attr_found, is_fetched
2290def prefetch_one_level(instances, prefetcher, lookup, level):
2291 """
2292 Helper function for prefetch_related_objects().
2294 Run prefetches on all instances using the prefetcher object,
2295 assigning results to relevant caches in instance.
2297 Return the prefetched objects along with any additional prefetches that
2298 must be done due to prefetch_related lookups found from default managers.
2299 """
2300 # prefetcher must have a method get_prefetch_queryset() which takes a list
2301 # of instances, and returns a tuple:
2303 # (queryset of instances of self.model that are related to passed in instances,
2304 # callable that gets value to be matched for returned instances,
2305 # callable that gets value to be matched for passed in instances,
2306 # boolean that is True for singly related objects,
2307 # cache or field name to assign to,
2308 # boolean that is True when the previous argument is a cache name vs a field name).
2310 # The 'values to be matched' must be hashable as they will be used
2311 # in a dictionary.
2313 (
2314 rel_qs,
2315 rel_obj_attr,
2316 instance_attr,
2317 single,
2318 cache_name,
2319 is_descriptor,
2320 ) = prefetcher.get_prefetch_queryset(instances, lookup.get_current_queryset(level))
2321 # We have to handle the possibility that the QuerySet we just got back
2322 # contains some prefetch_related lookups. We don't want to trigger the
2323 # prefetch_related functionality by evaluating the query. Rather, we need
2324 # to merge in the prefetch_related lookups.
2325 # Copy the lookups in case it is a Prefetch object which could be reused
2326 # later (happens in nested prefetch_related).
2327 additional_lookups = [
2328 copy.copy(additional_lookup)
2329 for additional_lookup in getattr(rel_qs, "_prefetch_related_lookups", ())
2330 ]
2331 if additional_lookups:
2332 # Don't need to clone because the manager should have given us a fresh
2333        # instance, so we access an internal attribute instead of using the public
2334        # interface, for performance reasons.
2335 rel_qs._prefetch_related_lookups = ()
2337 all_related_objects = list(rel_qs)
2339 rel_obj_cache = {}
2340 for rel_obj in all_related_objects:
2341 rel_attr_val = rel_obj_attr(rel_obj)
2342 rel_obj_cache.setdefault(rel_attr_val, []).append(rel_obj)
2344 to_attr, as_attr = lookup.get_current_to_attr(level)
2345 # Make sure `to_attr` does not conflict with a field.
2346 if as_attr and instances:
2347 # We assume that objects retrieved are homogeneous (which is the premise
2348        # of prefetch_related), so what applies to the first object applies to all.
2349 model = instances[0].__class__
2350 try:
2351 model._meta.get_field(to_attr)
2352 except exceptions.FieldDoesNotExist:
2353 pass
2354 else:
2355 msg = "to_attr={} conflicts with a field on the {} model."
2356 raise ValueError(msg.format(to_attr, model.__name__))
2358 # Whether or not we're prefetching the last part of the lookup.
2359 leaf = len(lookup.prefetch_through.split(LOOKUP_SEP)) - 1 == level
2361 for obj in instances:
2362 instance_attr_val = instance_attr(obj)
2363 vals = rel_obj_cache.get(instance_attr_val, [])
2365 if single:
2366 val = vals[0] if vals else None
2367 if as_attr:
2368 # A to_attr has been given for the prefetch.
2369 setattr(obj, to_attr, val)
2370 elif is_descriptor:
2371 # cache_name points to a field name in obj.
2372 # This field is a descriptor for a related object.
2373 setattr(obj, cache_name, val)
2374 else:
2375 # No to_attr has been given for this prefetch operation and the
2376 # cache_name does not point to a descriptor. Store the value of
2377 # the field in the object's field cache.
2378 obj._state.fields_cache[cache_name] = val
2379 else:
2380 if as_attr:
2381 setattr(obj, to_attr, vals)
2382 else:
2383 manager = getattr(obj, to_attr)
2384 if leaf and lookup.queryset is not None:
2385 qs = manager._apply_rel_filters(lookup.queryset)
2386 else:
2387 qs = manager.get_queryset()
2388 qs._result_cache = vals
2389 # We don't want the individual qs doing prefetch_related now,
2390 # since we have merged this into the current work.
2391 qs._prefetch_done = True
2392 obj._prefetched_objects_cache[cache_name] = qs
2393 return all_related_objects, additional_lookups
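# Sketch of the prefetcher contract used above: any object exposing
# get_prefetch_queryset(instances, queryset=None) and returning the six-tuple
# (rel_qs, rel_obj_attr, instance_attr, single, cache_name, is_descriptor) can
# participate; the relation descriptors and related managers located by
# get_prefetcher() are the usual implementers.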
2396class RelatedPopulator:
2397 """
2398 RelatedPopulator is used for select_related() object instantiation.
2400 The idea is that each select_related() model will be populated by a
2401 different RelatedPopulator instance. The RelatedPopulator instances get
2402 klass_info and select (computed in SQLCompiler) plus the used db as
2403 input for initialization. That data is used to compute which columns
2404 to use, how to instantiate the model, and how to populate the links
2405 between the objects.
2407    The actual creation of the objects is done in the populate() method. This
2408 method gets row and from_obj as input and populates the select_related()
2409 model instance.
2410 """
2412 def __init__(self, klass_info, select, db):
2413 self.db = db
2414 # Pre-compute needed attributes. The attributes are:
2415 # - model_cls: the possibly deferred model class to instantiate
2416 # - either:
2417 # - cols_start, cols_end: usually the columns in the row are
2418 # in the same order model_cls.__init__ expects them, so we
2419 # can instantiate by model_cls(*row[cols_start:cols_end])
2420 # - reorder_for_init: When select_related descends to a child
2421 # class, then we want to reuse the already selected parent
2422 # data. However, in this case the parent data isn't necessarily
2423 # in the same order that Model.__init__ expects it to be, so
2424 # we have to reorder the parent data. The reorder_for_init
2425 # attribute contains a function used to reorder the field data
2426 # in the order __init__ expects it.
2427 # - pk_idx: the index of the primary key field in the reordered
2428 # model data. Used to check if a related object exists at all.
2429 # - init_list: the field attnames fetched from the database. For
2430 # deferred models this isn't the same as all attnames of the
2431 # model's fields.
2432 # - related_populators: a list of RelatedPopulator instances if
2433 # select_related() descends to related models from this model.
2434 # - local_setter, remote_setter: Methods to set cached values on
2435 # the object being populated and on the remote object. Usually
2436 # these are Field.set_cached_value() methods.
2437 select_fields = klass_info["select_fields"]
2438 from_parent = klass_info["from_parent"]
2439 if not from_parent:
2440 self.cols_start = select_fields[0]
2441 self.cols_end = select_fields[-1] + 1
2442 self.init_list = [
2443 f[0].target.attname for f in select[self.cols_start : self.cols_end]
2444 ]
2445 self.reorder_for_init = None
2446 else:
2447 attname_indexes = {
2448 select[idx][0].target.attname: idx for idx in select_fields
2449 }
2450 model_init_attnames = (
2451 f.attname for f in klass_info["model"]._meta.concrete_fields
2452 )
2453 self.init_list = [
2454 attname for attname in model_init_attnames if attname in attname_indexes
2455 ]
2456 self.reorder_for_init = operator.itemgetter(
2457 *[attname_indexes[attname] for attname in self.init_list]
2458 )
2460 self.model_cls = klass_info["model"]
2461 self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname)
2462 self.related_populators = get_related_populators(klass_info, select, self.db)
2463 self.local_setter = klass_info["local_setter"]
2464 self.remote_setter = klass_info["remote_setter"]
2466 def populate(self, row, from_obj):
2467 if self.reorder_for_init:
2468 obj_data = self.reorder_for_init(row)
2469 else:
2470 obj_data = row[self.cols_start : self.cols_end]
2471 if obj_data[self.pk_idx] is None:
2472 obj = None
2473 else:
2474 obj = self.model_cls.from_db(self.db, self.init_list, obj_data)
2475 for rel_iter in self.related_populators:
2476 rel_iter.populate(row, obj)
2477 self.local_setter(from_obj, obj)
2478 if obj is not None:
2479 self.remote_setter(obj, from_obj)
2482def get_related_populators(klass_info, select, db):
2483 iterators = []
2484 related_klass_infos = klass_info.get("related_klass_infos", [])
2485 for rel_klass_info in related_klass_infos:
2486 rel_cls = RelatedPopulator(rel_klass_info, select, db)
2487 iterators.append(rel_cls)
2488 return iterators