Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/migrations/state.py: 57%

550 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-23 11:16 -0600

1import copy 

2from collections import defaultdict 

3from contextlib import contextmanager 

4from functools import partial 

5 

6from plain import models 

7from plain.exceptions import FieldDoesNotExist 

8from plain.models.fields import NOT_PROVIDED 

9from plain.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT 

10from plain.models.migrations.utils import field_is_referenced, get_references 

11from plain.models.options import DEFAULT_NAMES 

12from plain.models.utils import make_model_tuple 

13from plain.packages import PackageConfig 

14from plain.packages.registry import Packages 

15from plain.packages.registry import packages as global_packages 

16from plain.runtime import settings 

17from plain.utils.functional import cached_property 

18from plain.utils.module_loading import import_string 

19 

20from .exceptions import InvalidBasesError 

21from .utils import resolve_relation 

22 

23 

def _get_package_label_and_model_name(model, package_label=""):
    """
    Split a model reference into a ``(package_label, model_name)`` pair.

    ``model`` is either a model class (its ``_meta`` supplies both parts) or
    a string. A string containing a dot is split at the first dot; a string
    without one is treated as a bare model name and paired with the
    ``package_label`` argument.
    """
    if not isinstance(model, str):
        meta = model._meta
        return meta.package_label, meta.model_name
    head, dot, tail = model.partition(".")
    if dot:
        return head, tail
    return package_label, head

30 

31 

def _get_related_models(m):
    """
    Return all models that have a direct relationship to the given model.

    The returned list (which may contain duplicates) is made of:

    * direct subclasses of ``m`` that are ``models.Model`` subclasses, and
    * the ``related_model`` of every relational field of ``m`` (parent and
      hidden fields included) whose target is set and is not a string.
    """
    related_models = [
        subclass
        for subclass in m.__subclasses__()
        if issubclass(subclass, models.Model)
    ]
    for f in m._meta.get_fields(include_parents=True, include_hidden=True):
        # String targets have not been resolved to a model class; skip them.
        if (
            f.is_relation
            and f.related_model is not None
            and not isinstance(f.related_model, str)
        ):
            related_models.append(f.related_model)
    return related_models

49 

50 

def get_related_models_tuples(model):
    """
    Return a set of ``(package_label, model_name)`` tuples, one for each
    model directly related to ``model`` as reported by
    ``_get_related_models()``. Duplicates collapse because the result is a
    set.
    """
    return {
        (rel_mod._meta.package_label, rel_mod._meta.model_name)
        for rel_mod in _get_related_models(model)
    }

60 

61 

def get_related_models_recursive(model):
    """
    Return the ``(package_label, model_name)`` keys of every model reachable
    from ``model`` by repeatedly following direct relationships.

    A "direct relationship" is whatever ``_get_related_models()`` reports:
    the targets of relational fields, plus ``models.Model`` subclasses of the
    model (a superclass is related to its subclasses, not vice versa).
    ``model``'s own key is removed from the result.
    """
    visited = set()
    stack = _get_related_models(model)
    while stack:
        candidate = stack.pop()
        meta = candidate._meta
        key = (meta.package_label, meta.model_name)
        if key in visited:
            continue
        visited.add(key)
        stack.extend(_get_related_models(candidate))
    own_key = (model._meta.package_label, model._meta.model_name)
    return visited - {own_key}

85 

86 

class ProjectState:
    """
    Represent the entire project's overall state. This is the item that is
    passed around - do it here rather than at the package level so that
    cross-package FKs/etc. resolve properly.

    ``models`` maps ``(package_label, model_name_lower)`` keys to
    ``ModelState`` objects. ``real_packages`` is a set of package labels whose
    models are pulled from the global registry when ``packages`` is rendered.
    """

    def __init__(self, models=None, real_packages=None):
        self.models = models or {}
        # Packages to include from main registry, usually unmigrated ones
        if real_packages is None:
            real_packages = set()
        else:
            assert isinstance(real_packages, set)
        self.real_packages = real_packages
        # Set to True by _find_reload_model() when a reload is delayed;
        # clear_delayed_packages_cache() checks it.
        self.is_delayed = False
        # Reverse relation index, built lazily by the ``relations`` property:
        # {remote_model_key: {model_key: {field_name: field}}}
        self._relations = None

    @property
    def relations(self):
        """Return the reverse relation index, building it on first access."""
        if self._relations is None:
            self.resolve_fields_and_relations()
        return self._relations

    def add_model(self, model_state):
        """
        Store ``model_state`` under its ``(package_label, name_lower)`` key,
        updating the relation index and rendered ``packages`` if they exist.
        """
        model_key = model_state.package_label, model_state.name_lower
        self.models[model_key] = model_state
        if self._relations is not None:
            self.resolve_model_relations(model_key)
        if "packages" in self.__dict__:  # hasattr would cache the property
            self.reload_model(*model_key)

    def remove_model(self, package_label, model_name):
        """
        Delete a model, drop every relation-index entry pointing to or from
        it, and unregister it from rendered ``packages`` if they exist.
        """
        model_key = package_label, model_name
        del self.models[model_key]
        if self._relations is not None:
            self._relations.pop(model_key, None)
            # Call list() since _relations can change size during iteration.
            for related_model_key, model_relations in list(self._relations.items()):
                model_relations.pop(model_key, None)
                if not model_relations:
                    del self._relations[related_model_key]
        if "packages" in self.__dict__:  # hasattr would cache the property
            self.packages.unregister_model(*model_key)
            # Need to do this explicitly since unregister_model() doesn't clear
            # the cache automatically (#24513)
            self.packages.clear_cache()

    def rename_model(self, package_label, old_name, new_name):
        """
        Rename a model within ``package_label`` and repoint every field that
        references it (via ``to`` or ``through``) at the new name.
        """
        # Add a new model.
        old_name_lower = old_name.lower()
        new_name_lower = new_name.lower()
        renamed_model = self.models[package_label, old_name_lower].clone()
        renamed_model.name = new_name
        self.models[package_label, new_name_lower] = renamed_model
        # Repoint all fields pointing to the old model to the new one.
        old_model_tuple = (package_label, old_name_lower)
        new_remote_model = f"{package_label}.{new_name}"
        to_reload = set()
        for model_state, name, field, reference in get_references(
            self, old_model_tuple
        ):
            changed_field = None
            if reference.to:
                changed_field = field.clone()
                changed_field.remote_field.model = new_remote_model
            if reference.through:
                if changed_field is None:
                    changed_field = field.clone()
                changed_field.remote_field.through = new_remote_model
            if changed_field:
                model_state.fields[name] = changed_field
                to_reload.add((model_state.package_label, model_state.name_lower))
        if self._relations is not None:
            old_name_key = package_label, old_name_lower
            new_name_key = package_label, new_name_lower
            if old_name_key in self._relations:
                self._relations[new_name_key] = self._relations.pop(old_name_key)
            for model_relations in self._relations.values():
                if old_name_key in model_relations:
                    model_relations[new_name_key] = model_relations.pop(old_name_key)
        # Reload models related to old model before removing the old model.
        self.reload_models(to_reload, delay=True)
        # Remove the old model.
        self.remove_model(package_label, old_name_lower)
        self.reload_model(package_label, new_name_lower, delay=True)

    def alter_model_options(self, package_label, model_name, options, option_keys=None):
        """
        Merge ``options`` into the model's options. Any key listed in
        ``option_keys`` that is absent from ``options`` is removed.
        """
        model_state = self.models[package_label, model_name]
        model_state.options = {**model_state.options, **options}
        if option_keys:
            for key in option_keys:
                if key not in options:
                    model_state.options.pop(key, False)
        self.reload_model(package_label, model_name, delay=True)

    def remove_model_options(
        self, package_label, model_name, option_name, value_to_remove
    ):
        """
        Drop every entry of the list-valued option ``option_name`` that equals
        ``value_to_remove`` when both are compared as tuples.
        """
        model_state = self.models[package_label, model_name]
        if objs := model_state.options.get(option_name):
            model_state.options[option_name] = [
                obj for obj in objs if tuple(obj) != tuple(value_to_remove)
            ]
        self.reload_model(package_label, model_name, delay=True)

    def alter_model_managers(self, package_label, model_name, managers):
        """Replace the model's managers with a list copy of ``managers``."""
        model_state = self.models[package_label, model_name]
        model_state.managers = list(managers)
        self.reload_model(package_label, model_name, delay=True)

    def _append_option(self, package_label, model_name, option_name, obj):
        """Append ``obj`` to the list-valued option ``option_name``."""
        model_state = self.models[package_label, model_name]
        # Build a new list instead of appending in place: ModelState.clone()
        # copies options shallowly, so the old list may be shared.
        model_state.options[option_name] = [*model_state.options[option_name], obj]
        self.reload_model(package_label, model_name, delay=True)

    def _remove_option(self, package_label, model_name, option_name, obj_name):
        """Remove entries whose ``.name`` equals ``obj_name`` from an option list."""
        model_state = self.models[package_label, model_name]
        objs = model_state.options[option_name]
        model_state.options[option_name] = [obj for obj in objs if obj.name != obj_name]
        self.reload_model(package_label, model_name, delay=True)

    def add_index(self, package_label, model_name, index):
        """Add ``index`` to the model's ``indexes`` option."""
        self._append_option(package_label, model_name, "indexes", index)

    def remove_index(self, package_label, model_name, index_name):
        """Remove the index named ``index_name`` from the model."""
        self._remove_option(package_label, model_name, "indexes", index_name)

    def rename_index(self, package_label, model_name, old_index_name, new_index_name):
        """Rename an index, cloning it rather than mutating the shared object."""
        model_state = self.models[package_label, model_name]
        objs = model_state.options["indexes"]

        new_indexes = []
        for obj in objs:
            if obj.name == old_index_name:
                obj = obj.clone()
                obj.name = new_index_name
            new_indexes.append(obj)

        model_state.options["indexes"] = new_indexes
        self.reload_model(package_label, model_name, delay=True)

    def add_constraint(self, package_label, model_name, constraint):
        """Add ``constraint`` to the model's ``constraints`` option."""
        self._append_option(package_label, model_name, "constraints", constraint)

    def remove_constraint(self, package_label, model_name, constraint_name):
        """Remove the constraint named ``constraint_name`` from the model."""
        self._remove_option(package_label, model_name, "constraints", constraint_name)

    def add_field(self, package_label, model_name, name, field, preserve_default):
        """
        Add ``field`` to the model under ``name``.

        When ``preserve_default`` is false, a clone of the field with its
        default reset to NOT_PROVIDED is stored instead.
        """
        # If preserve default is off, don't use the default for future state.
        if not preserve_default:
            field = field.clone()
            field.default = NOT_PROVIDED
        model_key = package_label, model_name
        self.models[model_key].fields[name] = field
        if self._relations is not None:
            self.resolve_model_field_relations(model_key, name, field)
        # Delay rendering of relationships if it's not a relational field.
        delay = not field.is_relation
        self.reload_model(*model_key, delay=delay)

    def remove_field(self, package_label, model_name, name):
        """Remove field ``name`` from the model and from the relation index."""
        model_key = package_label, model_name
        model_state = self.models[model_key]
        old_field = model_state.fields.pop(name)
        if self._relations is not None:
            # The field is no longer in model_state.fields, so this removes
            # its relation entries.
            self.resolve_model_field_relations(model_key, name, old_field)
        # Delay rendering of relationships if it's not a relational field.
        delay = not old_field.is_relation
        self.reload_model(*model_key, delay=delay)

    def alter_field(self, package_label, model_name, name, field, preserve_default):
        """
        Replace field ``name`` with ``field``, updating the relation index.

        When ``preserve_default`` is false, a clone of the field with its
        default reset to NOT_PROVIDED is stored instead.
        """
        if not preserve_default:
            field = field.clone()
            field.default = NOT_PROVIDED
        model_key = package_label, model_name
        fields = self.models[model_key].fields
        if self._relations is not None:
            old_field = fields.pop(name)
            if old_field.is_relation:
                self.resolve_model_field_relations(model_key, name, old_field)
            fields[name] = field
            if field.is_relation:
                self.resolve_model_field_relations(model_key, name, field)
        else:
            fields[name] = field
        # TODO: investigate if old relational fields must be reloaded or if
        # it's sufficient if the new field is (#27737).
        # Delay rendering of relationships if it's not a relational field and
        # not referenced by a foreign key.
        delay = not field.is_relation and not field_is_referenced(
            self, model_key, (name, field)
        )
        self.reload_model(*model_key, delay=delay)

    def rename_field(self, package_label, model_name, old_name, new_name):
        """
        Rename field ``old_name`` to ``new_name`` and fix ``from_fields``,
        ``to_fields``, remote ``field_name`` and relation-index entries that
        refer to it.

        Raises FieldDoesNotExist if the model has no field ``old_name``.
        """
        model_key = package_label, model_name
        model_state = self.models[model_key]
        # Rename the field.
        fields = model_state.fields
        try:
            found = fields.pop(old_name)
        except KeyError:
            raise FieldDoesNotExist(
                f"{package_label}.{model_name} has no field named '{old_name}'"
            )
        fields[new_name] = found
        for local_field in fields.values():
            # Fix from_fields to refer to the new field.
            from_fields = getattr(local_field, "from_fields", None)
            if from_fields:
                local_field.from_fields = tuple(
                    [
                        new_name if from_field_name == old_name else from_field_name
                        for from_field_name in from_fields
                    ]
                )

        # Fix to_fields to refer to the new field.
        delay = True
        references = get_references(self, model_key, (old_name, found))
        for *_, field, reference in references:
            delay = False
            if reference.to:
                remote_field, to_fields = reference.to
                if getattr(remote_field, "field_name", None) == old_name:
                    remote_field.field_name = new_name
                if to_fields:
                    field.to_fields = tuple(
                        [
                            new_name if to_field_name == old_name else to_field_name
                            for to_field_name in to_fields
                        ]
                    )
        if self._relations is not None:
            old_name_lower = old_name.lower()
            new_name_lower = new_name.lower()
            for to_model in self._relations.values():
                # Use .get() so the inner defaultdict doesn't grow an empty
                # entry for model_key under every remote model.
                model_relations = to_model.get(model_key)
                if model_relations and old_name_lower in model_relations:
                    relation_field = model_relations.pop(old_name_lower)
                    relation_field.name = new_name_lower
                    model_relations[new_name_lower] = relation_field
        self.reload_model(*model_key, delay=delay)

    def _find_reload_model(self, package_label, model_name, delay=False):
        """
        Return the set of model keys to re-render when the given model changes.

        The set includes the model itself, the models related to its currently
        rendered class (if one is registered), and the models its fields point
        to together with their relations. With ``delay=True`` only direct
        relations are collected instead of the transitive closure, and
        ``is_delayed`` is set.
        """
        if delay:
            self.is_delayed = True

        related_models = set()

        try:
            old_model = self.packages.get_model(package_label, model_name)
        except LookupError:
            pass
        else:
            # Get all relations to and from the old model before reloading,
            # as _meta.packages may change
            if delay:
                related_models = get_related_models_tuples(old_model)
            else:
                related_models = get_related_models_recursive(old_model)

        # Get all outgoing references from the model to be rendered
        model_state = self.models[(package_label, model_name)]
        # Directly related models are the models pointed to by ForeignKeys,
        # OneToOneFields, and ManyToManyFields.
        direct_related_models = set()
        for field in model_state.fields.values():
            if field.is_relation:
                if field.remote_field.model == RECURSIVE_RELATIONSHIP_CONSTANT:
                    continue
                rel_package_label, rel_model_name = _get_package_label_and_model_name(
                    field.related_model, package_label
                )
                direct_related_models.add((rel_package_label, rel_model_name.lower()))

        # For all direct related models recursively get all related models.
        related_models.update(direct_related_models)
        for rel_package_label, rel_model_name in direct_related_models:
            try:
                rel_model = self.packages.get_model(rel_package_label, rel_model_name)
            except LookupError:
                pass
            else:
                if delay:
                    related_models.update(get_related_models_tuples(rel_model))
                else:
                    related_models.update(get_related_models_recursive(rel_model))

        # Include the model itself
        related_models.add((package_label, model_name))

        return related_models

    def reload_model(self, package_label, model_name, delay=False):
        """Re-render one model and its related models, if packages are rendered."""
        if "packages" in self.__dict__:  # hasattr would cache the property
            related_models = self._find_reload_model(package_label, model_name, delay)
            self._reload(related_models)

    def reload_models(self, models, delay=True):
        """Re-render several models and their related models in one pass."""
        if "packages" in self.__dict__:  # hasattr would cache the property
            related_models = set()
            for package_label, model_name in models:
                related_models.update(
                    self._find_reload_model(package_label, model_name, delay)
                )
            self._reload(related_models)

    def _reload(self, related_models):
        """Unregister the given model keys and render them again from state."""
        # Unregister all related models
        with self.packages.bulk_update():
            for rel_package_label, rel_model_name in related_models:
                self.packages.unregister_model(rel_package_label, rel_model_name)

        states_to_be_rendered = []
        # Gather all models states of those models that will be rerendered.
        # This includes:
        # 1. All related models of unmigrated packages
        for model_state in self.packages.real_models:
            if (model_state.package_label, model_state.name_lower) in related_models:
                states_to_be_rendered.append(model_state)

        # 2. All related models of migrated packages
        for rel_package_label, rel_model_name in related_models:
            try:
                model_state = self.models[rel_package_label, rel_model_name]
            except KeyError:
                pass
            else:
                states_to_be_rendered.append(model_state)

        # Render all models
        self.packages.render_multiple(states_to_be_rendered)

    def update_model_field_relation(
        self,
        model,
        model_key,
        field_name,
        field,
        concretes,
    ):
        """
        Add ``field`` to, or remove it from, the relation-index entry of the
        model that ``model`` resolves to (via resolve_relation(), relative to
        ``model_key``).

        The field is added if ``field_name`` is still present on the model
        state, and removed otherwise.
        """
        remote_model_key = resolve_relation(model, *model_key)
        if (
            remote_model_key[0] not in self.real_packages
            and remote_model_key in concretes
        ):
            remote_model_key = concretes[remote_model_key]
        relations_to_remote_model = self._relations[remote_model_key]
        if field_name in self.models[model_key].fields:
            # The assert holds because it's a new relation, or an altered
            # relation, in which case references have been removed by
            # alter_field().
            assert field_name not in relations_to_remote_model[model_key]
            relations_to_remote_model[model_key][field_name] = field
        else:
            del relations_to_remote_model[model_key][field_name]
            if not relations_to_remote_model[model_key]:
                del relations_to_remote_model[model_key]

    def resolve_model_field_relations(
        self,
        model_key,
        field_name,
        field,
        concretes=None,
    ):
        """
        Sync the relation index for one field: its remote model and, if
        present, its ``through`` model. Fields without a ``remote_field`` are
        ignored.
        """
        remote_field = field.remote_field
        if not remote_field:
            return
        if concretes is None:
            concretes = self._get_concrete_models_mapping()

        self.update_model_field_relation(
            remote_field.model,
            model_key,
            field_name,
            field,
            concretes,
        )

        through = getattr(remote_field, "through", None)
        if not through:
            return
        self.update_model_field_relation(
            through, model_key, field_name, field, concretes
        )

    def resolve_model_relations(self, model_key, concretes=None):
        """Sync the relation index for every field of the given model."""
        if concretes is None:
            concretes = self._get_concrete_models_mapping()

        model_state = self.models[model_key]
        for field_name, field in model_state.fields.items():
            self.resolve_model_field_relations(model_key, field_name, field, concretes)

    def resolve_fields_and_relations(self):
        """Set every field's ``name`` and rebuild the relation index from scratch."""
        # Resolve fields.
        for model_state in self.models.values():
            for field_name, field in model_state.fields.items():
                field.name = field_name
        # Resolve relations.
        # {remote_model_key: {model_key: {field_name: field}}}
        self._relations = defaultdict(partial(defaultdict, dict))
        concretes = self._get_concrete_models_mapping()

        for model_key in concretes:
            self.resolve_model_relations(model_key, concretes)

    def get_concrete_model_key(self, model):
        """
        Return the concrete ``(package_label, model_name)`` key for ``model``.

        ``model`` is anything make_model_tuple() accepts. Raises KeyError if
        the resulting key is not part of this state.
        """
        # _get_concrete_models_mapping() returns the mapping itself, not a
        # 1-tuple, so it must not be unpacked.
        concrete_models_mapping = self._get_concrete_models_mapping()
        model_key = make_model_tuple(model)
        return concrete_models_mapping[model_key]

    def _get_concrete_models_mapping(self):
        """
        Map each model key in this state to its concrete model key.

        Every model currently maps to itself.
        """
        return {model_key: model_key for model_key in self.models}

    def clone(self):
        """
        Return a copy of this ProjectState with every ModelState cloned.

        Rendered ``packages`` are cloned if present. The relation index is not
        copied; it is rebuilt lazily on the copy.
        """
        new_state = ProjectState(
            models={k: v.clone() for k, v in self.models.items()},
            real_packages=self.real_packages,
        )
        if "packages" in self.__dict__:
            new_state.packages = self.packages.clone()
        new_state.is_delayed = self.is_delayed
        return new_state

    def clear_delayed_packages_cache(self):
        """Drop cached ``packages`` if a delayed reload left them stale."""
        if self.is_delayed and "packages" in self.__dict__:
            del self.__dict__["packages"]

    @cached_property
    def packages(self):
        """Render this state into a StatePackages registry (cached)."""
        return StatePackages(self.real_packages, self.models)

    @classmethod
    def from_packages(cls, packages):
        """Take a Packages registry and return a ProjectState matching it."""
        app_models = {}
        for model in packages.get_models(include_swapped=True):
            model_state = ModelState.from_model(model)
            app_models[(model_state.package_label, model_state.name_lower)] = (
                model_state
            )
        return cls(app_models)

    def __eq__(self, other):
        # Let Python fall back to identity for unrelated types instead of
        # raising AttributeError on ``other.models``.
        if not isinstance(other, ProjectState):
            return NotImplemented
        return self.models == other.models and self.real_packages == other.real_packages

544 

545 

class PackageConfigStub(PackageConfig):
    """
    Lightweight stand-in for a PackageConfig that holds only a label and a
    dict of models. PackageConfig.__init__() is not called.
    """

    def __init__(self, label):
        self.packages = None
        self.models = {}
        # A package label is not a package name, so reusing the label as the
        # name is technically incorrect. Migrations never look at the name,
        # though; they only need something unique, and the label qualifies.
        self.label = self.name = label

    def import_models(self):
        """Point ``models`` at this label's entry in the owning registry."""
        registry = self.packages
        self.models = registry.all_models[self.label]

560 

561 

class StatePackages(Packages):
    """
    Packages registry populated from ModelState objects instead of imported
    package modules, so that models can be registered and unregistered
    dynamically as migration state changes.
    """

    def __init__(self, real_packages, models, ignore_swappable=False):
        # Models of "real" packages are re-created from fresh ModelStates
        # rather than reusing the original classes, because those classes
        # hold references to the global Packages object. Their relational
        # fields are dropped (exclude_rels=True): with a partial state the
        # models they point at may be missing.
        self.real_models = [
            ModelState.from_model(real_model, exclude_rels=True)
            for label in real_packages
            for real_model in global_packages.get_package_config(label).get_models()
        ]

        # Register a stub config for every real package and for every
        # package that owns one of the given model states.
        state_labels = {state.package_label for state in models.values()}
        stubs = [
            PackageConfigStub(label)
            for label in sorted([*real_packages, *state_labels])
        ]
        super().__init__(stubs)

        # clone() copies this registry whenever Plain duplicates it before an
        # update; the lock would get in the way of that.
        self._lock = None

        self.render_multiple([*models.values(), *self.real_models])

        # Nothing should still be pending once all models are rendered.
        from plain.models.preflight import _check_lazy_references

        if ignore_swappable:
            ignore = {make_model_tuple(settings.AUTH_USER_MODEL)}
        else:
            ignore = set()
        errors = _check_lazy_references(self, ignore=ignore)
        if errors:
            raise ValueError("\n".join(error.msg for error in errors))

    @contextmanager
    def bulk_update(self):
        """
        Mark the registry as not ready for the duration of the block, then
        restore the previous ``ready`` value and clear caches once on exit
        (even if the block raises). This avoids clearing each model's cache
        for every individual change.
        """
        previous_ready = self.ready
        self.ready = False
        try:
            yield
        finally:
            self.ready = previous_ready
            self.clear_cache()

    def render_multiple(self, model_states):
        """
        Render ``model_states`` into this registry.

        States whose bases cannot be resolved yet are retried on a later
        pass. If a whole pass makes no progress (a base dependency loop or a
        missing base), InvalidBasesError is raised.
        """
        if not model_states:
            return
        # Prevent that all model caches are expired for each render.
        with self.bulk_update():
            remaining = model_states
            while remaining:
                retry = []
                for state in remaining:
                    try:
                        state.render(self)
                    except InvalidBasesError:
                        retry.append(state)
                if len(retry) == len(remaining):
                    raise InvalidBasesError(
                        f"Cannot resolve bases for {retry!r}\nThis can happen if you are "
                        "inheriting models from an app with migrations (e.g. "
                        "contrib.auth)\n in an app with no migrations"
                    )
                remaining = retry

    def clone(self):
        """Return a clone of this registry."""
        duplicate = StatePackages([], {})
        duplicate.all_models = copy.deepcopy(self.all_models)

        for package_label in self.package_configs:
            stub = PackageConfigStub(package_label)
            stub.packages = duplicate
            stub.import_models()
            duplicate.package_configs[package_label] = stub

        # Real models are never modified, so the list can be shared.
        duplicate.real_models = self.real_models
        return duplicate

    def register_model(self, package_label, model):
        """
        Register ``model`` under ``package_label``, creating a stub config for
        the label on first use, then run pending operations and clear caches.
        """
        model_name = model._meta.model_name
        self.all_models[package_label][model_name] = model
        if package_label not in self.package_configs:
            stub = PackageConfigStub(package_label)
            self.package_configs[package_label] = stub
            stub.packages = self
        self.package_configs[package_label].models[model_name] = model
        self.do_pending_operations(model)
        self.clear_cache()

    def unregister_model(self, package_label, model_name):
        """
        Remove a model from ``all_models`` and then from its stub config.

        A missing entry is silently ignored. If the first deletion fails, the
        second one is not attempted. Caches are not cleared here.
        """
        try:
            registered = self.all_models[package_label]
            del registered[model_name]
            del self.package_configs[package_label].models[model_name]
        except KeyError:
            pass

671 

672 

class ModelState:
    """
    Represent a Plain Model. Don't use the actual Model class as it's not
    designed to have its options changed - instead, mutate this one and then
    render it into a Model as required.

    Note that while you are allowed to mutate .fields, you are not allowed
    to mutate the Field instances inside there themselves - you must instead
    assign new ones, as these are not detached during a clone.
    """

    def __init__(
        self, package_label, name, fields, options=None, bases=None, managers=None
    ):
        """
        Build a model state.

        ``fields`` is an iterable of ``(name, field)`` pairs (or a mapping)
        holding unbound fields whose relations use string references.
        ``options`` is copied, and ``indexes``/``constraints`` default to
        empty lists. ``bases`` defaults to ``(models.Model,)``.

        Raises ValueError if a field is already bound to a model, if a
        relation refers to a model class instead of a string, or if an index
        has no name.
        """
        self.package_label = package_label
        self.name = name
        self.fields = dict(fields)
        # Copy the options (like ``fields`` above) so the setdefault() calls
        # below don't write into a dict owned by the caller.
        self.options = dict(options) if options else {}
        self.options.setdefault("indexes", [])
        self.options.setdefault("constraints", [])
        self.bases = bases or (models.Model,)
        self.managers = managers or []
        # ``field_name`` avoids shadowing the ``name`` argument.
        for field_name, field in self.fields.items():
            # Sanity-check that fields are NOT already bound to a model.
            if hasattr(field, "model"):
                raise ValueError(
                    f'ModelState.fields cannot be bound to a model - "{field_name}" is.'
                )
            # Sanity-check that relation fields are NOT referring to a model class.
            if field.is_relation and hasattr(field.related_model, "_meta"):
                raise ValueError(
                    f'ModelState.fields cannot refer to a model class - "{field_name}.to" does. '
                    "Use a string reference instead."
                )
            if field.many_to_many and hasattr(field.remote_field.through, "_meta"):
                raise ValueError(
                    f'ModelState.fields cannot refer to a model class - "{field_name}.through" '
                    "does. Use a string reference instead."
                )
        # Sanity-check that indexes have their name set.
        for index in self.options["indexes"]:
            if not index.name:
                raise ValueError(
                    "Indexes passed to ModelState require a name attribute. "
                    f"{index!r} doesn't have one."
                )

    @cached_property
    def name_lower(self):
        """Lowercased model name, as used in ``(package_label, name)`` keys."""
        return self.name.lower()

    def get_field(self, field_name):
        """
        Return the field named ``field_name``.

        ``"_order"`` is looked up through the ``order_with_respect_to`` option
        when that option is set. Raises KeyError for unknown names.
        """
        if field_name == "_order":
            field_name = self.options.get("order_with_respect_to", field_name)
        return self.fields[field_name]

    @classmethod
    def from_model(cls, model, exclude_rels=False):
        """
        Given a model, return a ModelState representing it.

        With ``exclude_rels=True``, relational and many-to-many fields and the
        ``order_with_respect_to`` option are left out, producing a stub.
        """
        # Deconstruct the fields
        fields = []
        for field in model._meta.local_fields:
            if getattr(field, "remote_field", None) and exclude_rels:
                continue
            if isinstance(field, models.OrderWrt):
                continue
            field_name = field.name
            try:
                fields.append((field_name, field.clone()))
            except TypeError as e:
                raise TypeError(
                    f"Couldn't reconstruct field {field_name} on {model._meta.label}: {e}"
                ) from e
        if not exclude_rels:
            for field in model._meta.local_many_to_many:
                field_name = field.name
                try:
                    fields.append((field_name, field.clone()))
                except TypeError as e:
                    raise TypeError(
                        f"Couldn't reconstruct m2m field {field_name} on {model._meta.object_name}: {e}"
                    ) from e
        # Extract the options
        options = {}
        for option_name in DEFAULT_NAMES:
            # Ignore some special options
            if option_name in ["packages", "package_label"]:
                continue
            elif option_name in model._meta.original_attrs:
                if option_name == "indexes":
                    indexes = [idx.clone() for idx in model._meta.indexes]
                    for index in indexes:
                        if not index.name:
                            index.set_name_with_model(model)
                    options["indexes"] = indexes
                elif option_name == "constraints":
                    options["constraints"] = [
                        con.clone() for con in model._meta.constraints
                    ]
                else:
                    options[option_name] = model._meta.original_attrs[option_name]
        # If we're ignoring relationships, remove all field-listing model
        # options (that option basically just means "make a stub model")
        if exclude_rels:
            if "order_with_respect_to" in options:
                del options["order_with_respect_to"]
        # Private fields are ignored, so remove options that refer to them.
        elif options.get("order_with_respect_to") in {
            field.name for field in model._meta.private_fields
        }:
            del options["order_with_respect_to"]

        def flatten_bases(model):
            """Return ``model``'s bases with abstract model bases expanded."""
            bases = []
            for base in model.__bases__:
                if hasattr(base, "_meta") and base._meta.abstract:
                    bases.extend(flatten_bases(base))
                else:
                    bases.append(base)
            return bases

        # We can't rely on __mro__ directly because we only want to flatten
        # abstract models and not the whole tree. However by recursing on
        # __bases__ we may end up with duplicates and ordering issues, we
        # therefore discard any duplicates and reorder the bases according
        # to their index in the MRO.
        flattened_bases = sorted(
            set(flatten_bases(model)), key=lambda x: model.__mro__.index(x)
        )

        # Make our record
        bases = tuple(
            (base._meta.label_lower if hasattr(base, "_meta") else base)
            for base in flattened_bases
        )
        # Ensure at least one base inherits from models.Model
        if not any(
            (isinstance(base, str) or issubclass(base, models.Model)) for base in bases
        ):
            bases = (models.Model,)

        managers = []
        manager_names = set()
        default_manager_shim = None
        for manager in model._meta.managers:
            if manager.name in manager_names:
                # Skip overridden managers.
                continue
            elif manager.use_in_migrations:
                # Copy managers usable in migrations.
                new_manager = copy.copy(manager)
                new_manager._set_creation_counter()
            elif manager is model._base_manager or manager is model._default_manager:
                # Shim custom managers used as default and base managers.
                new_manager = models.Manager()
                new_manager.model = manager.model
                new_manager.name = manager.name
                if manager is model._default_manager:
                    default_manager_shim = new_manager
            else:
                continue
            manager_names.add(manager.name)
            managers.append((manager.name, new_manager))

        # Ignore a shimmed default manager called objects if it's the only one.
        if managers == [("objects", default_manager_shim)]:
            managers = []

        # Construct the new ModelState
        return cls(
            model._meta.package_label,
            model._meta.object_name,
            fields,
            options,
            bases,
            managers,
        )

    def construct_managers(self):
        """Deep-clone the managers using deconstruction."""
        # Sort all managers by their creation counter
        sorted_managers = sorted(self.managers, key=lambda v: v[1].creation_counter)
        for mgr_name, manager in sorted_managers:
            as_manager, manager_path, qs_path, args, kwargs = manager.deconstruct()
            if as_manager:
                qs_class = import_string(qs_path)
                yield mgr_name, qs_class.as_manager()
            else:
                manager_class = import_string(manager_path)
                yield mgr_name, manager_class(*args, **kwargs)

    def clone(self):
        """Return an exact copy of this ModelState."""
        return self.__class__(
            package_label=self.package_label,
            name=self.name,
            fields=dict(self.fields),
            # Since options are shallow-copied here, operations such as
            # AddIndex must replace their option (e.g 'indexes') rather
            # than mutating it.
            options=dict(self.options),
            bases=self.bases,
            managers=list(self.managers),
        )

    def render(self, packages):
        """
        Create a Model object from our current state into the given packages.

        Raises InvalidBasesError if a string base cannot be looked up.
        """
        # First, make a Meta object
        meta_contents = {
            "package_label": self.package_label,
            "packages": packages,
            **self.options,
        }
        meta = type("Meta", (), meta_contents)
        # Then, work out our bases
        try:
            bases = tuple(
                (packages.get_model(base) if isinstance(base, str) else base)
                for base in self.bases
            )
        except LookupError:
            raise InvalidBasesError(
                f"Cannot resolve one or more bases from {self.bases!r}"
            )
        # Clone fields for the body, add other bits.
        body = {name: field.clone() for name, field in self.fields.items()}
        body["Meta"] = meta
        body["__module__"] = "__fake__"

        # Restore managers
        body.update(self.construct_managers())
        # Then, make a Model object (packages.register_model is called in __new__)
        return type(self.name, bases, body)

    def get_index_by_name(self, name):
        """Return the index called ``name``; raise ValueError if absent."""
        for index in self.options["indexes"]:
            if index.name == name:
                return index
        raise ValueError(f"No index named {name} on model {self.name}")

    def get_constraint_by_name(self, name):
        """Return the constraint called ``name``; raise ValueError if absent."""
        for constraint in self.options["constraints"]:
            if constraint.name == name:
                return constraint
        raise ValueError(f"No constraint named {name} on model {self.name}")

    def __repr__(self):
        return f"<{self.__class__.__name__}: '{self.package_label}.{self.name}'>"

    def __eq__(self, other):
        # Fields are compared by their deconstructed form (minus the name
        # element), not by identity. Unrelated types fall back to identity.
        if not isinstance(other, ModelState):
            return NotImplemented
        return (
            (self.package_label == other.package_label)
            and (self.name == other.name)
            and (len(self.fields) == len(other.fields))
            and all(
                k1 == k2 and f1.deconstruct()[1:] == f2.deconstruct()[1:]
                for (k1, f1), (k2, f2) in zip(
                    sorted(self.fields.items()),
                    sorted(other.fields.items()),
                )
            )
            and (self.options == other.options)
            and (self.bases == other.bases)
            and (self.managers == other.managers)
        )