Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/migrations/state.py: 57%
550 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
1import copy
2from collections import defaultdict
3from contextlib import contextmanager
4from functools import partial
6from plain import models
7from plain.exceptions import FieldDoesNotExist
8from plain.models.fields import NOT_PROVIDED
9from plain.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
10from plain.models.migrations.utils import field_is_referenced, get_references
11from plain.models.options import DEFAULT_NAMES
12from plain.models.utils import make_model_tuple
13from plain.packages import PackageConfig
14from plain.packages.registry import Packages
15from plain.packages.registry import packages as global_packages
16from plain.runtime import settings
17from plain.utils.functional import cached_property
18from plain.utils.module_loading import import_string
20from .exceptions import InvalidBasesError
21from .utils import resolve_relation
def _get_package_label_and_model_name(model, package_label=""):
    """
    Split a model reference into a ``(package_label, model_name)`` tuple.

    ``model`` is either a string or an object exposing ``_meta``:

    * ``"label.Name"`` is split on the first dot only;
    * a bare ``"Name"`` is paired with the ``package_label`` argument;
    * otherwise both parts are read from ``model._meta``.

    String names are returned exactly as written (no lowercasing).
    """
    if not isinstance(model, str):
        meta = model._meta
        return meta.package_label, meta.model_name
    head, dot, tail = model.partition(".")
    if dot:
        return head, tail
    return package_label, head
def _get_related_models(m):
    """
    Return all models that have a direct relationship to the given model.

    The result lists, in order:

    * every direct subclass of ``m`` that is a ``models.Model`` subclass, and
    * the ``related_model`` of every relational field returned by
      ``m._meta.get_fields(include_parents=True, include_hidden=True)``,
      skipping fields whose ``related_model`` is ``None`` or still a string.

    The list may contain duplicates (one entry per relating field).
    """
    related_models = [
        subclass
        for subclass in m.__subclasses__()
        if issubclass(subclass, models.Model)
    ]
    related_models.extend(
        f.related_model
        for f in m._meta.get_fields(include_parents=True, include_hidden=True)
        if f.is_relation
        and f.related_model is not None
        # String references have not been resolved to a model class yet.
        and not isinstance(f.related_model, str)
    )
    return related_models
def get_related_models_tuples(model):
    """
    Return the set of ``(package_label, model_name)`` tuples identifying the
    models directly related to ``model`` (as found by ``_get_related_models``).
    """
    keys = set()
    for related in _get_related_models(model):
        meta = related._meta
        keys.add((meta.package_label, meta.model_name))
    return keys
def get_related_models_recursive(model):
    """
    Return ``(package_label, model_name)`` tuples for every model reachable
    from ``model`` through direct relationships, followed transitively.

    Direct relationships are those reported by ``_get_related_models``:
    relational fields, plus subclasses (a superclass is related to its
    subclasses, but not vice versa). The key of ``model`` itself is never
    part of the result.
    """
    visited = set()
    worklist = _get_related_models(model)
    cursor = 0
    # Breadth-first walk over a growing list; each model is expanded once.
    while cursor < len(worklist):
        current = worklist[cursor]
        cursor += 1
        key = (current._meta.package_label, current._meta.model_name)
        if key in visited:
            continue
        visited.add(key)
        worklist.extend(_get_related_models(current))
    visited.discard((model._meta.package_label, model._meta.model_name))
    return visited
class ProjectState:
    """
    Represent the entire project's overall state. This is the item that is
    passed around - do it here rather than at the package level so that
    cross-package FKs/etc. resolve properly.
    """

    def __init__(self, models=None, real_packages=None):
        # {(package_label, model_name_lower): ModelState}
        self.models = models or {}
        # Packages to include from main registry, usually unmigrated ones
        if real_packages is None:
            real_packages = set()
        else:
            assert isinstance(real_packages, set)
        self.real_packages = real_packages
        # Set once any reload is requested with delay=True; consulted by
        # clear_delayed_packages_cache().
        self.is_delayed = False
        # Built lazily by resolve_fields_and_relations():
        # {remote_model_key: {model_key: {field_name: field}}}
        self._relations = None

    @property
    def relations(self):
        """Return the relation index, building it on first access."""
        if self._relations is None:
            self.resolve_fields_and_relations()
        return self._relations

    def add_model(self, model_state):
        """Add ``model_state``, updating relations/rendered packages if built."""
        model_key = model_state.package_label, model_state.name_lower
        self.models[model_key] = model_state
        if self._relations is not None:
            self.resolve_model_relations(model_key)
        if "packages" in self.__dict__:  # hasattr would cache the property
            self.reload_model(*model_key)

    def remove_model(self, package_label, model_name):
        """Remove a model and every relation entry that mentions it."""
        model_key = package_label, model_name
        del self.models[model_key]
        if self._relations is not None:
            self._relations.pop(model_key, None)
            # Call list() since _relations can change size during iteration.
            for related_model_key, model_relations in list(self._relations.items()):
                model_relations.pop(model_key, None)
                if not model_relations:
                    del self._relations[related_model_key]
        if "packages" in self.__dict__:  # hasattr would cache the property
            self.packages.unregister_model(*model_key)
            # Need to do this explicitly since unregister_model() doesn't clear
            # the cache automatically (#24513)
            self.packages.clear_cache()

    def rename_model(self, package_label, old_name, new_name):
        """
        Rename a model, repointing fields (``to`` and ``through``) and the
        relation index from the old name to the new one.
        """
        # Add a new model.
        old_name_lower = old_name.lower()
        new_name_lower = new_name.lower()
        renamed_model = self.models[package_label, old_name_lower].clone()
        renamed_model.name = new_name
        self.models[package_label, new_name_lower] = renamed_model
        # Repoint all fields pointing to the old model to the new one.
        old_model_tuple = (package_label, old_name_lower)
        new_remote_model = f"{package_label}.{new_name}"
        to_reload = set()
        for model_state, name, field, reference in get_references(
            self, old_model_tuple
        ):
            changed_field = None
            if reference.to:
                changed_field = field.clone()
                changed_field.remote_field.model = new_remote_model
            if reference.through:
                if changed_field is None:
                    changed_field = field.clone()
                changed_field.remote_field.through = new_remote_model
            if changed_field:
                model_state.fields[name] = changed_field
                to_reload.add((model_state.package_label, model_state.name_lower))
        if self._relations is not None:
            old_name_key = package_label, old_name_lower
            new_name_key = package_label, new_name_lower
            if old_name_key in self._relations:
                self._relations[new_name_key] = self._relations.pop(old_name_key)
            for model_relations in self._relations.values():
                if old_name_key in model_relations:
                    model_relations[new_name_key] = model_relations.pop(old_name_key)
        # Reload models related to old model before removing the old model.
        self.reload_models(to_reload, delay=True)
        # Remove the old model.
        self.remove_model(package_label, old_name_lower)
        self.reload_model(package_label, new_name_lower, delay=True)

    def alter_model_options(self, package_label, model_name, options, option_keys=None):
        """
        Merge ``options`` into the model's options; any key listed in
        ``option_keys`` but absent from ``options`` is removed.
        """
        model_state = self.models[package_label, model_name]
        model_state.options = {**model_state.options, **options}
        if option_keys:
            for key in option_keys:
                if key not in options:
                    model_state.options.pop(key, False)
        self.reload_model(package_label, model_name, delay=True)

    def remove_model_options(
        self, package_label, model_name, option_name, value_to_remove
    ):
        """Drop entries of a list option whose tuple form equals ``value_to_remove``."""
        model_state = self.models[package_label, model_name]
        if objs := model_state.options.get(option_name):
            model_state.options[option_name] = [
                obj for obj in objs if tuple(obj) != tuple(value_to_remove)
            ]
        self.reload_model(package_label, model_name, delay=True)

    def alter_model_managers(self, package_label, model_name, managers):
        """Replace the model's managers with a list copy of ``managers``."""
        model_state = self.models[package_label, model_name]
        model_state.managers = list(managers)
        self.reload_model(package_label, model_name, delay=True)

    def _append_option(self, package_label, model_name, option_name, obj):
        # Build a new list instead of appending: ModelState.clone() only
        # shallow-copies options, so the list may be shared with a clone.
        model_state = self.models[package_label, model_name]
        model_state.options[option_name] = [*model_state.options[option_name], obj]
        self.reload_model(package_label, model_name, delay=True)

    def _remove_option(self, package_label, model_name, option_name, obj_name):
        # Rebuild the list (see _append_option) without objects named obj_name.
        model_state = self.models[package_label, model_name]
        objs = model_state.options[option_name]
        model_state.options[option_name] = [obj for obj in objs if obj.name != obj_name]
        self.reload_model(package_label, model_name, delay=True)

    def add_index(self, package_label, model_name, index):
        self._append_option(package_label, model_name, "indexes", index)

    def remove_index(self, package_label, model_name, index_name):
        self._remove_option(package_label, model_name, "indexes", index_name)

    def rename_index(self, package_label, model_name, old_index_name, new_index_name):
        """Rename an index, cloning it rather than mutating the shared object."""
        model_state = self.models[package_label, model_name]
        objs = model_state.options["indexes"]

        new_indexes = []
        for obj in objs:
            if obj.name == old_index_name:
                obj = obj.clone()
                obj.name = new_index_name
            new_indexes.append(obj)

        model_state.options["indexes"] = new_indexes
        self.reload_model(package_label, model_name, delay=True)

    def add_constraint(self, package_label, model_name, constraint):
        self._append_option(package_label, model_name, "constraints", constraint)

    def remove_constraint(self, package_label, model_name, constraint_name):
        self._remove_option(package_label, model_name, "constraints", constraint_name)

    def add_field(self, package_label, model_name, name, field, preserve_default):
        """Add ``field`` as ``name``; drop its default unless ``preserve_default``."""
        # If preserve default is off, don't use the default for future state.
        if not preserve_default:
            field = field.clone()
            field.default = NOT_PROVIDED
        model_key = package_label, model_name
        self.models[model_key].fields[name] = field
        if self._relations is not None:
            self.resolve_model_field_relations(model_key, name, field)
        # Delay rendering of relationships if it's not a relational field.
        delay = not field.is_relation
        self.reload_model(*model_key, delay=delay)

    def remove_field(self, package_label, model_name, name):
        """Remove field ``name``; its relation entries are removed as well."""
        model_key = package_label, model_name
        model_state = self.models[model_key]
        old_field = model_state.fields.pop(name)
        if self._relations is not None:
            # The field is no longer in model_state.fields, so this deletes
            # its relation entries (see update_model_field_relation()).
            self.resolve_model_field_relations(model_key, name, old_field)
        # Delay rendering of relationships if it's not a relational field.
        delay = not old_field.is_relation
        self.reload_model(*model_key, delay=delay)

    def alter_field(self, package_label, model_name, name, field, preserve_default):
        """Replace field ``name``; drop its default unless ``preserve_default``."""
        if not preserve_default:
            field = field.clone()
            field.default = NOT_PROVIDED
        model_key = package_label, model_name
        fields = self.models[model_key].fields
        if self._relations is not None:
            old_field = fields.pop(name)
            if old_field.is_relation:
                # Removed from fields first, so this deletes old entries.
                self.resolve_model_field_relations(model_key, name, old_field)
            fields[name] = field
            if field.is_relation:
                self.resolve_model_field_relations(model_key, name, field)
        else:
            fields[name] = field
        # TODO: investigate if old relational fields must be reloaded or if
        # it's sufficient if the new field is (#27737).
        # Delay rendering of relationships if it's not a relational field and
        # not referenced by a foreign key.
        delay = not field.is_relation and not field_is_referenced(
            self, model_key, (name, field)
        )
        self.reload_model(*model_key, delay=delay)

    def rename_field(self, package_label, model_name, old_name, new_name):
        """
        Rename a field and update from_fields/to_fields references to it.

        Raises:
            FieldDoesNotExist: if the model has no field named ``old_name``.
        """
        model_key = package_label, model_name
        model_state = self.models[model_key]
        # Rename the field.
        fields = model_state.fields
        try:
            found = fields.pop(old_name)
        except KeyError:
            raise FieldDoesNotExist(
                f"{package_label}.{model_name} has no field named '{old_name}'"
            )
        fields[new_name] = found
        for field in fields.values():
            # Fix from_fields to refer to the new field.
            from_fields = getattr(field, "from_fields", None)
            if from_fields:
                field.from_fields = tuple(
                    [
                        new_name if from_field_name == old_name else from_field_name
                        for from_field_name in from_fields
                    ]
                )
        # Fix to_fields to refer to the new field.
        delay = True
        references = get_references(self, model_key, (old_name, found))
        for *_, field, reference in references:
            delay = False
            if reference.to:
                remote_field, to_fields = reference.to
                if getattr(remote_field, "field_name", None) == old_name:
                    remote_field.field_name = new_name
                if to_fields:
                    field.to_fields = tuple(
                        [
                            new_name if to_field_name == old_name else to_field_name
                            for to_field_name in to_fields
                        ]
                    )
        if self._relations is not None:
            old_name_lower = old_name.lower()
            new_name_lower = new_name.lower()
            for to_model in self._relations.values():
                # Use .get(): indexing the inner defaultdict would insert an
                # empty entry for model_key under every remote model.
                model_relations = to_model.get(model_key)
                if model_relations and old_name_lower in model_relations:
                    field = model_relations.pop(old_name_lower)
                    field.name = new_name_lower
                    model_relations[new_name_lower] = field
        self.reload_model(*model_key, delay=delay)

    def _find_reload_model(self, package_label, model_name, delay=False):
        """
        Return the set of model keys to re-render when reloading a model.

        With ``delay`` only directly related models are collected (and
        ``is_delayed`` is set); otherwise relations are followed recursively.
        The model's own key is always included.
        """
        if delay:
            self.is_delayed = True

        related_models = set()

        try:
            old_model = self.packages.get_model(package_label, model_name)
        except LookupError:
            pass
        else:
            # Get all relations to and from the old model before reloading,
            # as _meta.packages may change
            if delay:
                related_models = get_related_models_tuples(old_model)
            else:
                related_models = get_related_models_recursive(old_model)

        # Get all outgoing references from the model to be rendered
        model_state = self.models[(package_label, model_name)]
        # Directly related models are the models pointed to by ForeignKeys,
        # OneToOneFields, and ManyToManyFields.
        direct_related_models = set()
        for field in model_state.fields.values():
            if field.is_relation:
                if field.remote_field.model == RECURSIVE_RELATIONSHIP_CONSTANT:
                    continue
                rel_package_label, rel_model_name = _get_package_label_and_model_name(
                    field.related_model, package_label
                )
                direct_related_models.add((rel_package_label, rel_model_name.lower()))

        # For all direct related models recursively get all related models.
        related_models.update(direct_related_models)
        for rel_package_label, rel_model_name in direct_related_models:
            try:
                rel_model = self.packages.get_model(rel_package_label, rel_model_name)
            except LookupError:
                pass
            else:
                if delay:
                    related_models.update(get_related_models_tuples(rel_model))
                else:
                    related_models.update(get_related_models_recursive(rel_model))

        # Include the model itself
        related_models.add((package_label, model_name))

        return related_models

    def reload_model(self, package_label, model_name, delay=False):
        """Re-render one model and its related models, if packages are built."""
        if "packages" in self.__dict__:  # hasattr would cache the property
            related_models = self._find_reload_model(package_label, model_name, delay)
            self._reload(related_models)

    def reload_models(self, models, delay=True):
        """Re-render several models (and related ones) in a single pass."""
        if "packages" in self.__dict__:  # hasattr would cache the property
            related_models = set()
            for package_label, model_name in models:
                related_models.update(
                    self._find_reload_model(package_label, model_name, delay)
                )
            self._reload(related_models)

    def _reload(self, related_models):
        # Unregister all related models
        with self.packages.bulk_update():
            for rel_package_label, rel_model_name in related_models:
                self.packages.unregister_model(rel_package_label, rel_model_name)

        states_to_be_rendered = []
        # Gather all models states of those models that will be rerendered.
        # This includes:
        # 1. All related models of unmigrated packages
        for model_state in self.packages.real_models:
            if (model_state.package_label, model_state.name_lower) in related_models:
                states_to_be_rendered.append(model_state)

        # 2. All related models of migrated packages
        for rel_package_label, rel_model_name in related_models:
            try:
                model_state = self.models[rel_package_label, rel_model_name]
            except KeyError:
                pass
            else:
                states_to_be_rendered.append(model_state)

        # Render all models
        self.packages.render_multiple(states_to_be_rendered)

    def update_model_field_relation(
        self,
        model,
        model_key,
        field_name,
        field,
        concretes,
    ):
        """
        Record or remove the relation from ``model_key.field_name`` to the
        model referenced by ``model``.

        The entry is added when the field is present in the model's state and
        deleted otherwise. A remote model key with no remaining relations is
        pruned.
        """
        remote_model_key = resolve_relation(model, *model_key)
        if (
            remote_model_key[0] not in self.real_packages
            and remote_model_key in concretes
        ):
            remote_model_key = concretes[remote_model_key]
        relations_to_remote_model = self._relations[remote_model_key]
        if field_name in self.models[model_key].fields:
            # The assert holds because it's a new relation, or an altered
            # relation, in which case references have been removed by
            # alter_field().
            assert field_name not in relations_to_remote_model[model_key]
            relations_to_remote_model[model_key][field_name] = field
        else:
            del relations_to_remote_model[model_key][field_name]
            if not relations_to_remote_model[model_key]:
                del relations_to_remote_model[model_key]

    def resolve_model_field_relations(
        self,
        model_key,
        field_name,
        field,
        concretes=None,
    ):
        """Update relation entries for a field's remote model and its ``through`` model."""
        remote_field = field.remote_field
        if not remote_field:
            return
        if concretes is None:
            concretes = self._get_concrete_models_mapping()

        self.update_model_field_relation(
            remote_field.model,
            model_key,
            field_name,
            field,
            concretes,
        )

        through = getattr(remote_field, "through", None)
        if not through:
            return
        self.update_model_field_relation(
            through, model_key, field_name, field, concretes
        )

    def resolve_model_relations(self, model_key, concretes=None):
        """Update relation entries for every field of the given model."""
        if concretes is None:
            concretes = self._get_concrete_models_mapping()

        model_state = self.models[model_key]
        for field_name, field in model_state.fields.items():
            self.resolve_model_field_relations(model_key, field_name, field, concretes)

    def resolve_fields_and_relations(self):
        """Set each field's ``name`` and rebuild the relation index from scratch."""
        # Resolve fields.
        for model_state in self.models.values():
            for field_name, field in model_state.fields.items():
                field.name = field_name
        # Resolve relations.
        # {remote_model_key: {model_key: {field_name: field}}}
        self._relations = defaultdict(partial(defaultdict, dict))
        concretes = self._get_concrete_models_mapping()

        for model_key in concretes:
            self.resolve_model_relations(model_key, concretes)

    def get_concrete_model_key(self, model):
        """
        Return the key ``model`` maps to in the concrete-models mapping.

        Raises:
            KeyError: if the model is not part of this state.
        """
        # _get_concrete_models_mapping() returns a single dict; it must not
        # be tuple-unpacked.
        concrete_models_mapping = self._get_concrete_models_mapping()
        model_key = make_model_tuple(model)
        return concrete_models_mapping[model_key]

    def _get_concrete_models_mapping(self):
        # Every model in the state maps to its own key.
        return {model_key: model_key for model_key in self.models}

    def clone(self):
        """Return an exact copy of this ProjectState."""
        new_state = ProjectState(
            models={k: v.clone() for k, v in self.models.items()},
            real_packages=self.real_packages,
        )
        if "packages" in self.__dict__:
            new_state.packages = self.packages.clone()
        new_state.is_delayed = self.is_delayed
        return new_state

    def clear_delayed_packages_cache(self):
        """Discard the cached packages registry if any reload was delayed."""
        if self.is_delayed and "packages" in self.__dict__:
            del self.__dict__["packages"]

    @cached_property
    def packages(self):
        return StatePackages(self.real_packages, self.models)

    @classmethod
    def from_packages(cls, packages):
        """Take a Packages registry and return a ProjectState matching it."""
        app_models = {}
        for model in packages.get_models(include_swapped=True):
            model_state = ModelState.from_model(model)
            app_models[(model_state.package_label, model_state.name_lower)] = (
                model_state
            )
        return cls(app_models)

    def __eq__(self, other):
        if not isinstance(other, ProjectState):
            return NotImplemented
        return self.models == other.models and self.real_packages == other.real_packages
class PackageConfigStub(PackageConfig):
    """Minimal stand-in for a PackageConfig: just a label and a models dict."""

    def __init__(self, label):
        # PackageConfig.__init__ is not called; only these attributes are set.
        self.packages = None
        self.models = {}
        # The label is also used as the package name. Strictly speaking a
        # label is not a name, but migrations only need a unique value here.
        self.label = label
        self.name = label

    def import_models(self):
        # Share the owning registry's model dict for this label; the
        # registry must have been assigned to self.packages beforehand.
        registry = self.packages
        self.models = registry.all_models[self.label]
class StatePackages(Packages):
    """
    Package registry populated from migration state instead of real code.

    Extends the global Packages registry so that models can be rendered,
    re-rendered, and unregistered dynamically as the state changes.
    """

    def __init__(self, real_packages, models, ignore_swappable=False):
        # Every model of a package listed in real_packages is rendered too,
        # but from a fresh ModelState instead of the original class, because
        # the original classes refer to the global Packages object. Their
        # FKs/M2Ms are left out (exclude_rels=True): they only mess things up
        # in partial states that lack the corresponding dependencies.
        self.real_models = [
            ModelState.from_model(model, exclude_rels=True)
            for package_label in real_packages
            for model in global_packages.get_package_config(
                package_label
            ).get_models()
        ]

        # One stub config per label (real packages plus the labels used by
        # the state's models), in sorted order.
        state_labels = {state.package_label for state in models.values()}
        labels = list(real_packages)
        labels.extend(state_labels)
        labels.sort()
        super().__init__([PackageConfigStub(label) for label in labels])

        # Drop the lock: it would get in the way of copying in clone(), which
        # runs whenever a StatePackages is duplicated before being updated.
        self._lock = None

        self.render_multiple([*models.values(), *self.real_models])

        # There shouldn't be any operations pending at this point.
        from plain.models.preflight import _check_lazy_references

        if ignore_swappable:
            ignore = {make_model_tuple(settings.AUTH_USER_MODEL)}
        else:
            ignore = set()
        errors = _check_lazy_references(self, ignore=ignore)
        if errors:
            raise ValueError("\n".join(error.msg for error in errors))

    @contextmanager
    def bulk_update(self):
        # Mark the registry not ready for the duration of the block, then
        # restore the flag and clear caches once, instead of once per change.
        previous_ready = self.ready
        self.ready = False
        try:
            yield
        finally:
            self.ready = previous_ready
            self.clear_cache()

    def render_multiple(self, model_states):
        # Render in passes. A state whose bases are not registered yet is
        # retried in the next pass; a pass that renders nothing means a base
        # dependency loop or a missing base, so give up with an error.
        if not model_states:
            return
        # Prevent that all model caches are expired for each render.
        with self.bulk_update():
            pending = model_states
            while pending:
                deferred = []
                for state in pending:
                    try:
                        state.render(self)
                    except InvalidBasesError:
                        deferred.append(state)
                if len(deferred) == len(pending):
                    raise InvalidBasesError(
                        f"Cannot resolve bases for {deferred!r}\nThis can happen if you are "
                        "inheriting models from an app with migrations (e.g. "
                        "contrib.auth)\n in an app with no migrations"
                    )
                pending = deferred

    def clone(self):
        """Return a copy of this registry with deep-copied model registrations."""
        duplicate = StatePackages([], {})
        duplicate.all_models = copy.deepcopy(self.all_models)

        for label in self.package_configs:
            stub = PackageConfigStub(label)
            stub.packages = duplicate
            stub.import_models()
            duplicate.package_configs[label] = stub

        # real_models never change, so they are shared rather than copied.
        duplicate.real_models = self.real_models
        return duplicate

    def register_model(self, package_label, model):
        # Record the model and create a stub config for unseen labels.
        meta_name = model._meta.model_name
        self.all_models[package_label][meta_name] = model
        if package_label not in self.package_configs:
            stub = PackageConfigStub(package_label)
            stub.packages = self
            self.package_configs[package_label] = stub
        self.package_configs[package_label].models[meta_name] = model
        self.do_pending_operations(model)
        self.clear_cache()

    def unregister_model(self, package_label, model_name):
        # Missing entries are ignored. If the first deletion raises KeyError,
        # the second one is not attempted.
        try:
            del self.all_models[package_label][model_name]
            del self.package_configs[package_label].models[model_name]
        except KeyError:
            pass
class ModelState:
    """
    Represent a Plain Model. Don't use the actual Model class as it's not
    designed to have its options changed - instead, mutate this one and then
    render it into a Model as required.

    Note that while you are allowed to mutate .fields, you are not allowed
    to mutate the Field instances inside there themselves - you must instead
    assign new ones, as these are not detached during a clone.
    """

    def __init__(
        self, package_label, name, fields, options=None, bases=None, managers=None
    ):
        """
        Args:
            package_label: label of the package owning the model.
            name: model class name, case preserved (see ``name_lower``).
            fields: mapping or iterable of ``(name, field)`` pairs, copied
                into a new dict. Fields must be unbound, and relations must
                reference models by string, not by class.
            options: Meta options. "indexes" and "constraints" default to
                empty lists. NOTE: a passed dict is stored and updated in place.
            bases: base classes or model label strings; defaults to
                ``(models.Model,)``.
            managers: list of ``(name, manager)`` pairs.

        Raises:
            ValueError: if a field is bound to a model, a relation refers to
                a model class, or an index has no name.
        """
        self.package_label = package_label
        self.name = name
        self.fields = dict(fields)
        self.options = options or {}
        self.options.setdefault("indexes", [])
        self.options.setdefault("constraints", [])
        self.bases = bases or (models.Model,)
        self.managers = managers or []
        for field_name, field in self.fields.items():
            # Sanity-check that fields are NOT already bound to a model.
            if hasattr(field, "model"):
                raise ValueError(
                    f'ModelState.fields cannot be bound to a model - "{field_name}" is.'
                )
            # Sanity-check that relation fields are NOT referring to a model class.
            if field.is_relation and hasattr(field.related_model, "_meta"):
                raise ValueError(
                    f'ModelState.fields cannot refer to a model class - "{field_name}.to" does. '
                    "Use a string reference instead."
                )
            if field.many_to_many and hasattr(field.remote_field.through, "_meta"):
                raise ValueError(
                    f'ModelState.fields cannot refer to a model class - "{field_name}.through" '
                    "does. Use a string reference instead."
                )
        # Sanity-check that indexes have their name set.
        for index in self.options["indexes"]:
            if not index.name:
                raise ValueError(
                    "Indexes passed to ModelState require a name attribute. "
                    f"{index!r} doesn't have one."
                )

    @cached_property
    def name_lower(self):
        return self.name.lower()

    def get_field(self, field_name):
        """Return a field by name; "_order" maps to the order_with_respect_to field."""
        if field_name == "_order":
            field_name = self.options.get("order_with_respect_to", field_name)
        return self.fields[field_name]

    @classmethod
    def from_model(cls, model, exclude_rels=False):
        """
        Given a model, return a ModelState representing it.

        With ``exclude_rels=True``, fields that have a ``remote_field``, all
        many-to-many fields, and the ``order_with_respect_to`` option are left
        out, producing a stub state.

        Raises:
            TypeError: if cloning one of the model's fields raises TypeError.
        """
        # Deconstruct the fields
        fields = []
        for field in model._meta.local_fields:
            if getattr(field, "remote_field", None) and exclude_rels:
                continue
            if isinstance(field, models.OrderWrt):
                continue
            name = field.name
            try:
                fields.append((name, field.clone()))
            except TypeError as e:
                raise TypeError(
                    f"Couldn't reconstruct field {name} on {model._meta.label}: {e}"
                ) from e
        if not exclude_rels:
            for field in model._meta.local_many_to_many:
                name = field.name
                try:
                    fields.append((name, field.clone()))
                except TypeError as e:
                    raise TypeError(
                        f"Couldn't reconstruct m2m field {name} on {model._meta.object_name}: {e}"
                    ) from e
        # Extract the options
        options = {}
        for name in DEFAULT_NAMES:
            # Ignore some special options
            if name in ["packages", "package_label"]:
                continue
            elif name in model._meta.original_attrs:
                if name == "indexes":
                    indexes = [idx.clone() for idx in model._meta.indexes]
                    for index in indexes:
                        if not index.name:
                            index.set_name_with_model(model)
                    options["indexes"] = indexes
                elif name == "constraints":
                    options["constraints"] = [
                        con.clone() for con in model._meta.constraints
                    ]
                else:
                    options[name] = model._meta.original_attrs[name]
        # If we're ignoring relationships, remove all field-listing model
        # options (that option basically just means "make a stub model")
        if exclude_rels:
            if "order_with_respect_to" in options:
                del options["order_with_respect_to"]
        # Private fields are ignored, so remove options that refer to them.
        elif options.get("order_with_respect_to") in {
            field.name for field in model._meta.private_fields
        }:
            del options["order_with_respect_to"]

        def flatten_bases(model):
            # Replace abstract bases by their own bases, recursively.
            bases = []
            for base in model.__bases__:
                if hasattr(base, "_meta") and base._meta.abstract:
                    bases.extend(flatten_bases(base))
                else:
                    bases.append(base)
            return bases

        # We can't rely on __mro__ directly because we only want to flatten
        # abstract models and not the whole tree. However by recursing on
        # __bases__ we may end up with duplicates and ordering issues, we
        # therefore discard any duplicates and reorder the bases according
        # to their index in the MRO.
        flattened_bases = sorted(
            set(flatten_bases(model)), key=lambda x: model.__mro__.index(x)
        )

        # Make our record
        bases = tuple(
            (base._meta.label_lower if hasattr(base, "_meta") else base)
            for base in flattened_bases
        )
        # Ensure at least one base inherits from models.Model
        if not any(
            (isinstance(base, str) or issubclass(base, models.Model)) for base in bases
        ):
            bases = (models.Model,)

        managers = []
        manager_names = set()
        default_manager_shim = None
        for manager in model._meta.managers:
            if manager.name in manager_names:
                # Skip overridden managers.
                continue
            elif manager.use_in_migrations:
                # Copy managers usable in migrations.
                new_manager = copy.copy(manager)
                new_manager._set_creation_counter()
            elif manager is model._base_manager or manager is model._default_manager:
                # Shim custom managers used as default and base managers.
                new_manager = models.Manager()
                new_manager.model = manager.model
                new_manager.name = manager.name
                if manager is model._default_manager:
                    default_manager_shim = new_manager
            else:
                continue
            manager_names.add(manager.name)
            managers.append((manager.name, new_manager))

        # Ignore a shimmed default manager called objects if it's the only one.
        if managers == [("objects", default_manager_shim)]:
            managers = []

        # Construct the new ModelState
        return cls(
            model._meta.package_label,
            model._meta.object_name,
            fields,
            options,
            bases,
            managers,
        )

    def construct_managers(self):
        """Deep-clone the managers using deconstruction."""
        # Sort all managers by their creation counter
        sorted_managers = sorted(self.managers, key=lambda v: v[1].creation_counter)
        for mgr_name, manager in sorted_managers:
            as_manager, manager_path, qs_path, args, kwargs = manager.deconstruct()
            if as_manager:
                qs_class = import_string(qs_path)
                yield mgr_name, qs_class.as_manager()
            else:
                manager_class = import_string(manager_path)
                yield mgr_name, manager_class(*args, **kwargs)

    def clone(self):
        """Return an exact copy of this ModelState."""
        return self.__class__(
            package_label=self.package_label,
            name=self.name,
            fields=dict(self.fields),
            # Since options are shallow-copied here, operations such as
            # AddIndex must replace their option (e.g 'indexes') rather
            # than mutating it.
            options=dict(self.options),
            bases=self.bases,
            managers=list(self.managers),
        )

    def render(self, packages):
        """
        Create a Model object from our current state into the given packages.

        Raises:
            InvalidBasesError: if a string base cannot be found in ``packages``.
        """
        # First, make a Meta object
        meta_contents = {
            "package_label": self.package_label,
            "packages": packages,
            **self.options,
        }
        meta = type("Meta", (), meta_contents)
        # Then, work out our bases
        try:
            bases = tuple(
                (packages.get_model(base) if isinstance(base, str) else base)
                for base in self.bases
            )
        except LookupError:
            raise InvalidBasesError(
                f"Cannot resolve one or more bases from {self.bases!r}"
            )
        # Clone fields for the body, add other bits.
        body = {name: field.clone() for name, field in self.fields.items()}
        body["Meta"] = meta
        body["__module__"] = "__fake__"

        # Restore managers
        body.update(self.construct_managers())
        # Then, make a Model object (packages.register_model is called in __new__)
        return type(self.name, bases, body)

    def get_index_by_name(self, name):
        """Return the index called ``name``; raise ValueError if absent."""
        for index in self.options["indexes"]:
            if index.name == name:
                return index
        raise ValueError(f"No index named {name} on model {self.name}")

    def get_constraint_by_name(self, name):
        """Return the constraint called ``name``; raise ValueError if absent."""
        for constraint in self.options["constraints"]:
            if constraint.name == name:
                return constraint
        raise ValueError(f"No constraint named {name} on model {self.name}")

    def __repr__(self):
        return f"<{self.__class__.__name__}: '{self.package_label}.{self.name}'>"

    def __eq__(self, other):
        # Comparing with a non-ModelState defers to the other operand instead
        # of failing on a missing attribute.
        if not isinstance(other, ModelState):
            return NotImplemented
        return (
            (self.package_label == other.package_label)
            and (self.name == other.name)
            and (len(self.fields) == len(other.fields))
            and all(
                k1 == k2 and f1.deconstruct()[1:] == f2.deconstruct()[1:]
                for (k1, f1), (k2, f2) in zip(
                    sorted(self.fields.items()),
                    sorted(other.fields.items()),
                )
            )
            and (self.options == other.options)
            and (self.bases == other.bases)
            and (self.managers == other.managers)
        )