Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# orm/persistence.py 

2# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors 

3# <see AUTHORS file> 

4# 

5# This module is part of SQLAlchemy and is released under 

6# the MIT License: http://www.opensource.org/licenses/mit-license.php 

7 

8"""private module containing functions used to emit INSERT, UPDATE 

9and DELETE statements on behalf of a :class:`_orm.Mapper` and its descending 

10mappers. 

11 

12The functions here are called only by the unit of work functions 

13in unitofwork.py. 

14 

15""" 

16 

17from itertools import chain 

18from itertools import groupby 

19import operator 

20 

21from . import attributes 

22from . import evaluator 

23from . import exc as orm_exc 

24from . import loading 

25from . import sync 

26from .base import _entity_descriptor 

27from .base import state_str 

28from .. import exc as sa_exc 

29from .. import sql 

30from .. import util 

31from ..sql import expression 

32from ..sql.base import _from_objects 

33 

34 

def _bulk_insert(
    mapper,
    mappings,
    session_transaction,
    isstates,
    return_defaults,
    render_nulls,
):
    """Emit INSERT statements for a collection of mappings or states.

    :param mapper: the :class:`_orm.Mapper` against which rows are inserted.
    :param mappings: an iterable of dictionaries, or of ``InstanceState``
        objects when ``isstates`` is True.
    :param session_transaction: the current session transaction, used to
        locate the connection to execute on.
    :param isstates: when True, ``mappings`` contains ``InstanceState``
        objects rather than plain dictionaries.
    :param return_defaults: when True, server-generated values are fetched
        back (``bookkeeping`` is enabled in the emit step) and, for states,
        identity keys are assigned afterwards.
    :param render_nulls: passed through to ``_collect_insert_commands``;
        when True, None values are rendered into the INSERT rather than
        omitted.

    """
    base_mapper = mapper.base_mapper

    cached_connections = _cached_connection_dict(base_mapper)

    # per-instance connection routing is incompatible with the single
    # connection used for the whole bulk operation below
    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_insert()"
        )

    if isstates:
        if return_defaults:
            # keep (state, dict) pairs so identity keys can be assigned
            # after the INSERTs populate the dictionaries
            states = [(state, state.dict) for state in mappings]
            mappings = [dict_ for (state, dict_) in states]
        else:
            mappings = [state.dict for state in mappings]
    else:
        mappings = list(mappings)

    connection = session_transaction.connection(base_mapper)
    # walk the inheritance hierarchy's tables in dependency order,
    # skipping tables not part of this mapper's lineage
    for table, super_mapper in base_mapper._sorted_tables.items():
        if not mapper.isa(super_mapper):
            continue

        # re-pack each collected record, substituting None for the state
        # and the single bulk mapper/connection for the per-record ones
        records = (
            (
                None,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_pks,
                has_all_defaults,
            )
            for (
                state,
                state_dict,
                params,
                mp,
                conn,
                value_params,
                has_all_pks,
                has_all_defaults,
            ) in _collect_insert_commands(
                table,
                ((None, mapping, mapper, connection) for mapping in mappings),
                bulk=True,
                return_defaults=return_defaults,
                render_nulls=render_nulls,
            )
        )
        _emit_insert_statements(
            base_mapper,
            None,
            cached_connections,
            super_mapper,
            table,
            records,
            bookkeeping=return_defaults,
        )

    if return_defaults and isstates:
        # assign identity keys to the states now that the dictionaries
        # contain the (possibly server-generated) primary key values
        identity_cls = mapper._identity_class
        identity_props = [p.key for p in mapper._identity_key_props]
        for state, dict_ in states:
            state.key = (
                identity_cls,
                tuple([dict_[key] for key in identity_props]),
            )

113 

114 

def _bulk_update(
    mapper, mappings, session_transaction, isstates, update_changed_only
):
    """Emit UPDATE statements for a collection of mappings or states.

    :param mapper: the :class:`_orm.Mapper` against which rows are updated.
    :param mappings: an iterable of dictionaries, or of ``InstanceState``
        objects when ``isstates`` is True.
    :param session_transaction: the current session transaction, used to
        locate the connection to execute on.
    :param isstates: when True, ``mappings`` contains ``InstanceState``
        objects rather than plain dictionaries.
    :param update_changed_only: when True (and ``isstates``), only
        attributes present in each state's ``committed_state``, plus the
        primary key / version id keys, are included in the UPDATE.

    """
    base_mapper = mapper.base_mapper

    cached_connections = _cached_connection_dict(base_mapper)

    # keys that must always be carried along: primary key attributes,
    # plus the version id attribute when one is configured
    search_keys = mapper._primary_key_propkeys
    if mapper._version_id_prop:
        search_keys = {mapper._version_id_prop.key}.union(search_keys)

    def _changed_dict(mapper, state):
        # reduce a state's dict to changed attributes plus search keys
        return dict(
            (k, v)
            for k, v in state.dict.items()
            if k in state.committed_state or k in search_keys
        )

    if isstates:
        if update_changed_only:
            mappings = [_changed_dict(mapper, state) for state in mappings]
        else:
            mappings = [state.dict for state in mappings]
    else:
        mappings = list(mappings)

    # per-instance connection routing is incompatible with the single
    # connection used for the whole bulk operation below
    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_update()"
        )

    connection = session_transaction.connection(base_mapper)

    for table, super_mapper in base_mapper._sorted_tables.items():
        if not mapper.isa(super_mapper):
            continue

        records = _collect_update_commands(
            None,
            table,
            (
                (
                    None,
                    mapping,
                    mapper,
                    connection,
                    (
                        # the expected version id is taken from the
                        # mapping itself in bulk mode
                        mapping[mapper._version_id_prop.key]
                        if mapper._version_id_prop
                        else None
                    ),
                )
                for mapping in mappings
            ),
            bulk=True,
        )

        _emit_update_statements(
            base_mapper,
            None,
            cached_connections,
            super_mapper,
            table,
            records,
            bookkeeping=False,
        )

182 

183 

def save_obj(base_mapper, states, uowtransaction, single=False):
    """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
    of objects.

    This is called within the context of a UOWTransaction during a
    flush operation, given a list of states to be flushed.  The
    base mapper in an inheritance hierarchy handles the inserts/
    updates for all descendant mappers.

    :param single: when True, this call is handling exactly one state
        and will not recurse; used to implement non-batched mode.

    """

    # if batch=false, call _save_obj separately for each object
    if not single and not base_mapper.batch:
        for state in _sort_states(base_mapper, states):
            save_obj(base_mapper, [state], uowtransaction, single=True)
        return

    states_to_update = []
    states_to_insert = []
    cached_connections = _cached_connection_dict(base_mapper)

    # partition states: those with an identity (or a detected
    # "row switch") become UPDATEs, the rest become INSERTs
    for (
        state,
        dict_,
        mapper,
        connection,
        has_identity,
        row_switch,
        update_version_id,
    ) in _organize_states_for_save(base_mapper, states, uowtransaction):
        if has_identity or row_switch:
            states_to_update.append(
                (state, dict_, mapper, connection, update_version_id)
            )
        else:
            states_to_insert.append((state, dict_, mapper, connection))

    # emit per-table statements in dependency order
    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue
        insert = _collect_insert_commands(table, states_to_insert)

        update = _collect_update_commands(
            uowtransaction, table, states_to_update
        )

        _emit_update_statements(
            base_mapper,
            uowtransaction,
            cached_connections,
            mapper,
            table,
            update,
        )

        _emit_insert_statements(
            base_mapper,
            uowtransaction,
            cached_connections,
            mapper,
            table,
            insert,
        )

    # run post-save bookkeeping over all states; the final flag marks
    # whether the state was an UPDATE (True) or an INSERT (False)
    _finalize_insert_update_commands(
        base_mapper,
        uowtransaction,
        chain(
            (
                (state, state_dict, mapper, connection, False)
                for (state, state_dict, mapper, connection) in states_to_insert
            ),
            (
                (state, state_dict, mapper, connection, True)
                for (
                    state,
                    state_dict,
                    mapper,
                    connection,
                    update_version_id,
                ) in states_to_update
            ),
        ),
    )

268 

269 

def post_update(base_mapper, states, uowtransaction, post_update_cols):
    """Issue UPDATE statements on behalf of a relationship() which
    specifies post_update.

    :param post_update_cols: collection of columns that should be
        included in the UPDATE beyond those with ``onupdate`` defaults.

    """
    cached_connections = _cached_connection_dict(base_mapper)

    states_to_update = list(
        _organize_states_for_post_update(base_mapper, states, uowtransaction)
    )

    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue

        # augment each record with the expected version id value when
        # versioning is in effect; restrict to states mapped to this table
        update = (
            (
                state,
                state_dict,
                sub_mapper,
                connection,
                mapper._get_committed_state_attr_by_column(
                    state, state_dict, mapper.version_id_col
                )
                if mapper.version_id_col is not None
                else None,
            )
            for state, state_dict, sub_mapper, connection in states_to_update
            if table in sub_mapper._pks_by_table
        )

        update = _collect_post_update_commands(
            base_mapper, uowtransaction, table, update, post_update_cols
        )

        _emit_post_update_statements(
            base_mapper,
            uowtransaction,
            cached_connections,
            mapper,
            table,
            update,
        )

313 

314 

def delete_obj(base_mapper, states, uowtransaction):
    """Issue ``DELETE`` statements for a list of objects.

    This is called within the context of a UOWTransaction during a
    flush operation.

    """

    cached_connections = _cached_connection_dict(base_mapper)

    states_to_delete = list(
        _organize_states_for_delete(base_mapper, states, uowtransaction)
    )

    table_to_mapper = base_mapper._sorted_tables

    # delete from tables in reverse dependency order so child rows
    # go before parent rows
    for table in reversed(list(table_to_mapper.keys())):
        mapper = table_to_mapper[table]
        if table not in mapper._pks_by_table:
            continue
        elif mapper.inherits and mapper.passive_deletes:
            # with passive_deletes on an inheriting mapper, rely on the
            # database (e.g. ON DELETE CASCADE) rather than emitting here
            continue

        delete = _collect_delete_commands(
            base_mapper, uowtransaction, table, states_to_delete
        )

        _emit_delete_statements(
            base_mapper,
            uowtransaction,
            cached_connections,
            mapper,
            table,
            delete,
        )

    # fire after_delete events once all DELETE statements have run
    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_delete:
        mapper.dispatch.after_delete(mapper, connection, state)

359 

360 

def _organize_states_for_save(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for INSERT or
    UPDATE.

    This includes splitting out into distinct lists for
    each, calling before_insert/before_update, obtaining
    key information for each state including its dictionary,
    mapper, the connection to use for the execution per state,
    and the identity flag.

    Yields 7-tuples of ``(state, dict_, mapper, connection,
    has_identity, row_switch, update_version_id)``.

    """

    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):

        # a state with a key has already been persisted (UPDATE path)
        has_identity = bool(state.key)

        instance_key = state.key or mapper._identity_key_from_state(state)

        row_switch = update_version_id = None

        # call before_XXX extensions
        if not has_identity:
            mapper.dispatch.before_insert(mapper, connection, state)
        else:
            mapper.dispatch.before_update(mapper, connection, state)

        if mapper._validate_polymorphic_identity:
            mapper._validate_polymorphic_identity(mapper, state, dict_)

        # detect if we have a "pending" instance (i.e. has
        # no instance_key attached to it), and another instance
        # with the same identity key already exists as persistent.
        # convert to an UPDATE if so.
        if (
            not has_identity
            and instance_key in uowtransaction.session.identity_map
        ):
            instance = uowtransaction.session.identity_map[instance_key]
            existing = attributes.instance_state(instance)

            if not uowtransaction.was_already_deleted(existing):
                if not uowtransaction.is_deleted(existing):
                    # the existing persistent instance is not being
                    # deleted in this flush: that's a genuine conflict
                    raise orm_exc.FlushError(
                        "New instance %s with identity key %s conflicts "
                        "with persistent instance %s"
                        % (state_str(state), instance_key, state_str(existing))
                    )

                base_mapper._log_debug(
                    "detected row switch for identity %s. "
                    "will update %s, remove %s from "
                    "transaction",
                    instance_key,
                    state_str(state),
                    state_str(existing),
                )

                # remove the "delete" flag from the existing element
                uowtransaction.remove_state_actions(existing)
                row_switch = existing

        # capture the committed version id value to use in the UPDATE's
        # WHERE clause; on a row switch, read it from the existing state
        if (has_identity or row_switch) and mapper.version_id_col is not None:
            update_version_id = mapper._get_committed_state_attr_by_column(
                row_switch if row_switch else state,
                row_switch.dict if row_switch else dict_,
                mapper.version_id_col,
            )

        yield (
            state,
            dict_,
            mapper,
            connection,
            has_identity,
            row_switch,
            update_version_id,
        )

440 

441 

def _organize_states_for_post_update(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for a post_update
    UPDATE.

    This resolves, for each state, its dictionary, its mapper, and the
    connection on which its statement should execute, by delegating to
    :func:`_connections_for_states`.

    """
    per_state_connections = _connections_for_states(
        base_mapper, uowtransaction, states
    )
    return per_state_connections

452 

453 

def _organize_states_for_delete(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for DELETE.

    This includes calling out before_delete and obtaining
    key information for each state including its dictionary,
    mapper, the connection to use for the execution per state.

    Yields 5-tuples of ``(state, dict_, mapper, connection,
    update_version_id)``.

    """
    resolved = _connections_for_states(base_mapper, uowtransaction, states)
    for state, dict_, mapper, connection in resolved:

        # fire the before_delete event prior to collecting version info
        mapper.dispatch.before_delete(mapper, connection, state)

        # the committed version id value is used to match the row in
        # the DELETE's WHERE clause when versioning is configured
        update_version_id = (
            mapper._get_committed_state_attr_by_column(
                state, dict_, mapper.version_id_col
            )
            if mapper.version_id_col is not None
            else None
        )

        yield (state, dict_, mapper, connection, update_version_id)

476 

477 

def _collect_insert_commands(
    table,
    states_to_insert,
    bulk=False,
    return_defaults=False,
    render_nulls=False,
):
    """Identify sets of values to use in INSERT statements for a
    list of states.

    :param table: the table currently being processed.
    :param states_to_insert: iterable of ``(state, state_dict, mapper,
        connection)`` tuples.
    :param bulk: when True, skip per-attribute SQL-expression handling
        and the legacy explicit-None fill-in.
    :param return_defaults: when True (with ``bulk``), still compute the
        has_all_pks / has_all_defaults flags.
    :param render_nulls: when True, None values are included in the
        parameters rather than omitted.

    Yields 8-tuples of ``(state, state_dict, params, mapper, connection,
    value_params, has_all_pks, has_all_defaults)``.

    """
    for state, state_dict, mapper, connection in states_to_insert:
        if table not in mapper._pks_by_table:
            continue

        params = {}
        # value_params holds SQL expressions to embed in the statement's
        # VALUES clause, as opposed to plain bound parameters
        value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        # columns for which an explicit None is meaningful on INSERT
        eval_none = mapper._insert_cols_evaluating_none[table]

        for propkey in set(propkey_to_col).intersection(state_dict):
            value = state_dict[propkey]
            col = propkey_to_col[propkey]
            if value is None and col not in eval_none and not render_nulls:
                continue
            elif not bulk and (
                hasattr(value, "__clause_element__")
                or isinstance(value, sql.ClauseElement)
            ):
                # the attribute value is itself a SQL expression
                value_params[col] = (
                    value.__clause_element__()
                    if hasattr(value, "__clause_element__")
                    else value
                )
            else:
                params[col.key] = value

        if not bulk:
            # for all the columns that have no default and we don't have
            # a value and where "None" is not a special value, add
            # explicit None to the INSERT.   This is a legacy behavior
            # which might be worth removing, as it should not be necessary
            # and also produces confusion, given that "missing" and None
            # now have distinct meanings
            for colkey in (
                mapper._insert_cols_as_none[table]
                .difference(params)
                .difference([c.key for c in value_params])
            ):
                params[colkey] = None

        if not bulk or return_defaults:
            # params are in terms of Column key objects, so
            # compare to pk_keys_by_table
            has_all_pks = mapper._pk_keys_by_table[table].issubset(params)

            if mapper.base_mapper.eager_defaults:
                has_all_defaults = mapper._server_default_cols[table].issubset(
                    params
                )
            else:
                has_all_defaults = True
        else:
            has_all_defaults = has_all_pks = True

        # generate a new version id value for versioned mappers that
        # use a client-side generator
        if (
            mapper.version_id_generator is not False
            and mapper.version_id_col is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            params[mapper.version_id_col.key] = mapper.version_id_generator(
                None
            )

        yield (
            state,
            state_dict,
            params,
            mapper,
            connection,
            value_params,
            has_all_pks,
            has_all_defaults,
        )

564 

565 

def _collect_update_commands(
    uowtransaction, table, states_to_update, bulk=False
):
    """Identify sets of values to use in UPDATE statements for a
    list of states.

    This function works intricately with the history system
    to determine exactly what values should be updated
    as well as how the row should be matched within an UPDATE
    statement.  Includes some tricky scenarios where the primary
    key of an object might have been changed.

    Yields 8-tuples of ``(state, state_dict, params, mapper,
    connection, value_params, has_all_defaults, has_all_pks)``.

    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:

        if table not in mapper._pks_by_table:
            continue

        pks = mapper._pks_by_table[table]

        # value_params holds SQL expressions destined for the SET
        # clause, as opposed to plain bound parameters
        value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            params = dict(
                (propkey_to_col[propkey].key, state_dict[propkey])
                for propkey in set(propkey_to_col)
                .intersection(state_dict)
                .difference(mapper._pk_attr_keys_by_table[table])
            )
            has_all_defaults = True
        else:
            params = {}
            # only attributes with committed history are candidates
            for propkey in set(propkey_to_col).intersection(
                state.committed_state
            ):
                value = state_dict[propkey]
                col = propkey_to_col[propkey]

                if hasattr(value, "__clause_element__") or isinstance(
                    value, sql.ClauseElement
                ):
                    value_params[col] = (
                        value.__clause_element__()
                        if hasattr(value, "__clause_element__")
                        else value
                    )
                # guard against values that generate non-__nonzero__
                # objects for __eq__()
                elif (
                    state.manager[propkey].impl.is_equal(
                        value, state.committed_state[propkey]
                    )
                    is not True
                ):
                    params[col.key] = value

            if mapper.base_mapper.eager_defaults:
                has_all_defaults = (
                    mapper._server_onupdate_default_cols[table]
                ).issubset(params)
            else:
                has_all_defaults = True

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):

            if not bulk and not (params or value_params):
                # HACK: check for history in other tables, in case the
                # history is only in a different table than the one
                # where the version_id_col is.  This logic was lost
                # from 0.9 -> 1.0.0 and restored in 1.0.6.
                for prop in mapper._columntoproperty.values():
                    history = state.manager[prop.key].impl.get_history(
                        state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                    )
                    if history.added:
                        break
                else:
                    # no net change, break
                    continue

            col = mapper.version_id_col
            no_params = not params and not value_params
            # the expected (old) version value goes into the WHERE
            # clause, keyed by the column's label
            params[col._label] = update_version_id

            if (
                bulk or col.key not in params
            ) and mapper.version_id_generator is not False:
                val = mapper.version_id_generator(update_version_id)
                params[col.key] = val
            elif mapper.version_id_generator is False and no_params:
                # no version id generator, no values set on the table,
                # and version id wasn't manually incremented.
                # set version id to itself so we get an UPDATE
                # statement
                params[col.key] = update_version_id

        elif not (params or value_params):
            # nothing to update on this table for this state
            continue

        has_all_pks = True
        expect_pk_cascaded = False
        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            pk_params = dict(
                (propkey_to_col[propkey]._label, state_dict.get(propkey))
                for propkey in set(propkey_to_col).intersection(
                    mapper._pk_attr_keys_by_table[table]
                )
            )
        else:
            pk_params = {}
            for col in pks:
                propkey = mapper._columntoproperty[col].key

                history = state.manager[propkey].impl.get_history(
                    state, state_dict, attributes.PASSIVE_OFF
                )

                if history.added:
                    if (
                        not history.deleted
                        or ("pk_cascaded", state, col)
                        in uowtransaction.attributes
                    ):
                        # the new PK value was set via cascade; match
                        # the row using the new value
                        expect_pk_cascaded = True
                        pk_params[col._label] = history.added[0]
                        params.pop(col.key, None)
                    else:
                        # else, use the old value to locate the row
                        pk_params[col._label] = history.deleted[0]
                        if col in value_params:
                            has_all_pks = False
                else:
                    pk_params[col._label] = history.unchanged[0]
                    if pk_params[col._label] is None:
                        raise orm_exc.FlushError(
                            "Can't update table %s using NULL for primary "
                            "key value on column %s" % (table, col)
                        )

        if params or value_params:
            params.update(pk_params)
            yield (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            )
        elif expect_pk_cascaded:
            # no UPDATE occurs on this table, but we expect that CASCADE rules
            # have changed the primary key of the row; propagate this event to
            # other columns that expect to have been modified. this normally
            # occurs after the UPDATE is emitted however we invoke it here
            # explicitly in the absence of our invoking an UPDATE
            for m, equated_pairs in mapper._table_to_equated[table]:
                sync.populate(
                    state,
                    m,
                    state,
                    m,
                    equated_pairs,
                    uowtransaction,
                    mapper.passive_updates,
                )

749 

750 

def _collect_post_update_commands(
    base_mapper, uowtransaction, table, states_to_update, post_update_cols
):
    """Identify sets of values to use in UPDATE statements for a
    list of states within a post_update operation.

    Yields 5-tuples of ``(state, state_dict, mapper, connection,
    params)`` for each state that has data to update.

    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:

        # assert table in mapper._pks_by_table

        pks = mapper._pks_by_table[table]
        params = {}
        hasdata = False

        for col in mapper._cols_by_table[table]:
            if col in pks:
                # primary key values go into the WHERE clause, keyed
                # by the column's label
                params[col._label] = mapper._get_state_attr_by_column(
                    state, state_dict, col, passive=attributes.PASSIVE_OFF
                )

            elif col in post_update_cols or col.onupdate is not None:
                prop = mapper._columntoproperty[col]
                history = state.manager[prop.key].impl.get_history(
                    state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                )
                if history.added:
                    value = history.added[0]
                    params[col.key] = value
                    hasdata = True
        if hasdata:
            if (
                update_version_id is not None
                and mapper.version_id_col in mapper._cols_by_table[table]
            ):

                col = mapper.version_id_col
                # expected (old) version value for the WHERE clause
                params[col._label] = update_version_id

                if (
                    bool(state.key)
                    and col.key not in params
                    and mapper.version_id_generator is not False
                ):
                    # generate the next version id value for the SET clause
                    val = mapper.version_id_generator(update_version_id)
                    params[col.key] = val
            yield state, state_dict, mapper, connection, params

805 

806 

def _collect_delete_commands(
    base_mapper, uowtransaction, table, states_to_delete
):
    """Identify values to use in DELETE statements for a list of
    states to be deleted."""

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_delete:

        # skip states whose mapper has no primary key columns in this table
        if table not in mapper._pks_by_table:
            continue

        # build the WHERE-clause parameters from the committed primary
        # key values; a NULL committed value cannot match a row
        params = {}
        for col in mapper._pks_by_table[table]:
            value = mapper._get_committed_state_attr_by_column(
                state, state_dict, col
            )
            params[col.key] = value
            if value is None:
                raise orm_exc.FlushError(
                    "Can't delete from table %s "
                    "using NULL for primary "
                    "key value on column %s" % (table, col)
                )

        # include the expected version id in the row match when this
        # table carries the version column
        versioning_applies = (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        )
        if versioning_applies:
            params[mapper.version_id_col.key] = update_version_id
        yield params, connection

844 

845 

def _emit_update_statements(
    base_mapper,
    uowtransaction,
    cached_connections,
    mapper,
    table,
    update,
    bookkeeping=True,
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_update_commands().

    :param update: iterable of 8-tuples as produced by
        ``_collect_update_commands``.
    :param bookkeeping: when True, run ``_postfetch`` for each row so
        server-generated values are reconciled with the states.

    """

    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def update_stmt():
        # construct the cached UPDATE statement: match on primary key
        # columns (and the version id column when versioning applies)
        # via labeled bindparams
        clause = sql.and_()

        for col in mapper._pks_by_table[table]:
            clause.clauses.append(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clause.clauses.append(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        stmt = table.update(clause)
        return stmt

    cached_stmt = base_mapper._memo(("update", table), update_stmt)

    # group records so that rows sharing a connection, parameter set
    # and flags can potentially be executed via executemany()
    for (
        (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
        records,
    ) in groupby(
        update,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # set of parameter keys
            bool(rec[5]),  # whether or not we have "value" parameters
            rec[6],  # has_all_defaults
            rec[7],  # has all pks
        ),
    ):
        rows = 0
        records = list(records)

        statement = cached_stmt
        return_defaults = False

        # decide whether RETURNING (return_defaults) is needed for
        # this group of records
        if not has_all_pks:
            statement = statement.return_defaults()
            return_defaults = True
        elif (
            bookkeeping
            and not has_all_defaults
            and mapper.base_mapper.eager_defaults
        ):
            statement = statement.return_defaults()
            return_defaults = True
        elif mapper.version_id_col is not None:
            statement = statement.return_defaults(mapper.version_id_col)
            return_defaults = True

        assert_singlerow = (
            connection.dialect.supports_sane_rowcount
            if not return_defaults
            else connection.dialect.supports_sane_rowcount_returning
        )

        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )
        allow_multirow = has_all_defaults and not needs_version_id

        if hasvalue:
            # SQL-expression values require a per-row execute so the
            # expression can be embedded into the statement
            for (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            ) in records:
                c = connection.execute(statement.values(value_params), params)
                if bookkeeping:
                    _postfetch(
                        mapper,
                        uowtransaction,
                        table,
                        state,
                        state_dict,
                        c,
                        c.context.compiled_parameters[0],
                        value_params,
                        True,
                    )
                rows += c.rowcount
                check_rowcount = assert_singlerow
        else:
            if not allow_multirow:
                check_rowcount = assert_singlerow
                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    c = cached_connections[connection].execute(
                        statement, params
                    )

                    # TODO: why with bookkeeping=False?
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                        )
                    rows += c.rowcount
            else:
                # executemany() path: all rows in a single execute
                multiparams = [rec[2] for rec in records]

                check_rowcount = assert_multirow or (
                    assert_singlerow and len(multiparams) == 1
                )

                c = cached_connections[connection].execute(
                    statement, multiparams
                )

                rows += c.rowcount

                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                        )

        # verify the expected number of rows was matched; a mismatch
        # indicates a concurrent modification (stale data)
        if check_rowcount:
            if rows != len(records):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )

1037 

1038 

def _emit_insert_statements(
    base_mapper,
    uowtransaction,
    cached_connections,
    mapper,
    table,
    insert,
    bookkeeping=True,
):
    """Emit INSERT statements corresponding to value lists collected
    by _collect_insert_commands().

    :param insert: iterable of 8-tuples as produced by
        ``_collect_insert_commands``.
    :param bookkeeping: when True, run ``_postfetch`` /
        ``_postfetch_bulk_save`` for each row.

    """

    cached_stmt = base_mapper._memo(("insert", table), table.insert)

    # group records so that rows sharing a connection, parameter set
    # and flags can potentially use executemany()
    for (
        (connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
        records,
    ) in groupby(
        insert,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # parameter keys
            bool(rec[5]),  # whether we have "value" parameters
            rec[6],
            rec[7],
        ),
    ):

        statement = cached_stmt

        # executemany() path: possible only when all PKs are present,
        # no SQL-expression values, and no RETURNING is needed
        if (
            not bookkeeping
            or (
                has_all_defaults
                or not base_mapper.eager_defaults
                or not connection.dialect.implicit_returning
            )
            and has_all_pks
            and not hasvalue
        ):

            records = list(records)
            multiparams = [rec[2] for rec in records]

            c = cached_connections[connection].execute(statement, multiparams)
            if bookkeeping:
                for (
                    (
                        state,
                        state_dict,
                        params,
                        mapper_rec,
                        conn,
                        value_params,
                        has_all_pks,
                        has_all_defaults,
                    ),
                    last_inserted_params,
                ) in zip(records, c.context.compiled_parameters):
                    if state:
                        _postfetch(
                            mapper_rec,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            last_inserted_params,
                            value_params,
                            False,
                        )
                    else:
                        # bulk-mode records carry no state object
                        _postfetch_bulk_save(mapper_rec, state_dict, table)

        else:
            # per-row path: needed to retrieve server-generated values
            # and/or to embed SQL-expression values
            if not has_all_defaults and base_mapper.eager_defaults:
                statement = statement.return_defaults()
            elif mapper.version_id_col is not None:
                statement = statement.return_defaults(mapper.version_id_col)

            for (
                state,
                state_dict,
                params,
                mapper_rec,
                connection,
                value_params,
                has_all_pks,
                has_all_defaults,
            ) in records:

                if value_params:
                    result = connection.execute(
                        statement.values(value_params), params
                    )
                else:
                    result = cached_connections[connection].execute(
                        statement, params
                    )

                primary_key = result.context.inserted_primary_key
                if primary_key is not None:
                    # set primary key attributes
                    for pk, col in zip(
                        primary_key, mapper._pks_by_table[table]
                    ):
                        prop = mapper_rec._columntoproperty[col]
                        if pk is not None and (
                            col in value_params
                            or state_dict.get(prop.key) is None
                        ):
                            state_dict[prop.key] = pk
                if bookkeeping:
                    if state:
                        _postfetch(
                            mapper_rec,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            result,
                            result.context.compiled_parameters[0],
                            value_params,
                            False,
                        )
                    else:
                        # bulk-mode records carry no state object
                        _postfetch_bulk_save(mapper_rec, state_dict, table)

1166 

1167 

def _emit_post_update_statements(
    base_mapper, uowtransaction, cached_connections, mapper, table, update
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_post_update_commands().

    Each record in ``update`` is a tuple of
    (state, state_dict, mapper, connection, params); the statement itself
    is memoized per (mapper, table) and rowcounts are verified when a
    version id column is in play.
    """

    # whether this table carries the mapper's version id column, in which
    # case row counts must be verifiable for staleness detection
    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def update_stmt():
        # build the cached UPDATE: WHERE clause matches all primary key
        # columns (and the version id column, if versioning applies)
        # via bound parameters named after each column's label
        clause = sql.and_()

        for col in mapper._pks_by_table[table]:
            clause.clauses.append(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clause.clauses.append(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        stmt = table.update(clause)

        if mapper.version_id_col is not None:
            # fetch the newly generated version id back with the UPDATE
            stmt = stmt.return_defaults(mapper.version_id_col)

        return stmt

    statement = base_mapper._memo(("post_update", table), update_stmt)

    # execute each UPDATE in the order according to the original
    # list of states to guarantee row access order, but
    # also group them into common (connection, cols) sets
    # to support executemany().
    for key, records in groupby(
        update,
        lambda rec: (rec[3], set(rec[4])),  # connection  # parameter keys
    ):
        rows = 0

        records = list(records)
        connection = key[0]

        # whether the dialect reports reliable rowcounts for a
        # single-row execution (RETURNING variant when versioning)
        assert_singlerow = (
            connection.dialect.supports_sane_rowcount
            if mapper.version_id_col is None
            else connection.dialect.supports_sane_rowcount_returning
        )
        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )
        # executemany() is only usable if version staleness can still
        # be detected (or isn't needed at all)
        allow_multirow = not needs_version_id or assert_multirow

        if not allow_multirow:
            # fall back to one execute() per record so each rowcount
            # can be verified individually
            check_rowcount = assert_singlerow
            for state, state_dict, mapper_rec, connection, params in records:
                c = cached_connections[connection].execute(statement, params)
                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[0],
                )
                rows += c.rowcount
        else:
            # batch all parameter sets into one executemany() call
            multiparams = [
                params
                for state, state_dict, mapper_rec, conn, params in records
            ]

            check_rowcount = assert_multirow or (
                assert_singlerow and len(multiparams) == 1
            )

            c = cached_connections[connection].execute(statement, multiparams)

            rows += c.rowcount
            for state, state_dict, mapper_rec, connection, params in records:
                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[0],
                )

        if check_rowcount:
            # every record must have matched exactly one row; a mismatch
            # indicates concurrent modification / stale data
            if rows != len(records):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            # versioning was requested but the dialect can't verify it
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )

1281 

1282 

def _emit_delete_statements(
    base_mapper, uowtransaction, cached_connections, mapper, table, delete
):
    """Emit DELETE statements corresponding to value lists collected
    by _collect_delete_commands().

    ``delete`` yields (params, connection) tuples; deletes are grouped
    per connection and rowcounts are verified against the number of
    rows expected to be deleted.
    """

    # whether this table carries the version id column, requiring
    # per-row verification of the delete
    need_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def delete_stmt():
        # cached DELETE: WHERE matches all primary key columns (plus the
        # version id column when versioning applies), bound by column key
        clause = sql.and_()
        for col in mapper._pks_by_table[table]:
            clause.clauses.append(
                col == sql.bindparam(col.key, type_=col.type)
            )

        if need_version_id:
            clause.clauses.append(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col.key, type_=mapper.version_id_col.type
                )
            )

        return table.delete(clause)

    statement = base_mapper._memo(("delete", table), delete_stmt)
    for connection, recs in groupby(delete, lambda rec: rec[1]):  # connection
        del_objects = [params for params, connection in recs]

        # swap in the compiled-cache-enabled connection
        connection = cached_connections[connection]

        expected = len(del_objects)
        rows_matched = -1  # -1 means "rowcount not verifiable"
        only_warn = False

        if (
            need_version_id
            and not connection.dialect.supports_sane_multi_rowcount
        ):
            if connection.dialect.supports_sane_rowcount:
                rows_matched = 0
                # execute deletes individually so that versioned
                # rows can be verified
                for params in del_objects:
                    c = connection.execute(statement, params)
                    rows_matched += c.rowcount
            else:
                # dialect can't verify rowcounts at all; warn and
                # issue the deletes as a single executemany()
                util.warn(
                    "Dialect %s does not support deleted rowcount "
                    "- versioning cannot be verified."
                    % connection.dialect.dialect_description
                )
                connection.execute(statement, del_objects)
        else:
            c = connection.execute(statement, del_objects)

            if not need_version_id:
                # no versioning: a mismatch is only warned about,
                # not raised
                only_warn = True

            rows_matched = c.rowcount

        if (
            base_mapper.confirm_deleted_rows
            and rows_matched > -1
            and expected != rows_matched
            and (
                connection.dialect.supports_sane_multi_rowcount
                or len(del_objects) == 1
            )
        ):
            # TODO: why does this "only warn" if versioning is turned off,
            # whereas the UPDATE raises?
            if only_warn:
                util.warn(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched. Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
            else:
                raise orm_exc.StaleDataError(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched. Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )

1374 

1375 

def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
    """finalize state on states that have been inserted or updated,
    including calling after_insert/after_update events.

    ``states`` yields (state, state_dict, mapper, connection,
    has_identity) tuples; has_identity distinguishes an UPDATE (True)
    from an INSERT (False) for event dispatch.
    """
    for state, state_dict, mapper, connection, has_identity in states:

        if mapper._readonly_props:
            # expire readonly attributes that may have been affected
            # server-side by the flush, per each prop's expire_on_flush
            # and deferred settings
            readonly = state.unmodified_intersection(
                [
                    p.key
                    for p in mapper._readonly_props
                    if (
                        p.expire_on_flush
                        and (not p.deferred or p.key in state.dict)
                    )
                    or (
                        not p.expire_on_flush
                        and not p.deferred
                        and p.key not in state.dict
                    )
                ]
            )
            if readonly:
                state._expire_attributes(state.dict, readonly)

        # if eager_defaults option is enabled, load
        # all expired cols. Else if we have a version_id_col, make sure
        # it isn't expired.
        toload_now = []

        if base_mapper.eager_defaults:
            toload_now.extend(
                state._unloaded_non_object.intersection(
                    mapper._server_default_plus_onupdate_propkeys
                )
            )

        if (
            mapper.version_id_col is not None
            and mapper.version_id_generator is False
        ):
            # server-generated version id must be present in memory
            if mapper._version_id_prop.key in state.unloaded:
                toload_now.extend([mapper._version_id_prop.key])

        if toload_now:
            # emit a SELECT to refresh just the needed attributes
            state.key = base_mapper._identity_key_from_state(state)
            loading.load_on_ident(
                uowtransaction.session.query(mapper),
                state.key,
                refresh_state=state,
                only_load_props=toload_now,
            )

        # call after_XXX extensions
        if not has_identity:
            mapper.dispatch.after_insert(mapper, connection, state)
        else:
            mapper.dispatch.after_update(mapper, connection, state)

        if (
            mapper.version_id_generator is False
            and mapper.version_id_col is not None
        ):
            # with a server-side version generator, a NULL version value
            # after flush indicates misconfiguration
            if state_dict[mapper._version_id_prop.key] is None:
                raise orm_exc.FlushError(
                    "Instance does not contain a non-NULL version value"
                )

1444 

1445 

def _postfetch_post_update(
    mapper, uowtransaction, table, state, dict_, result, params
):
    """Populate / expire attributes on ``state`` following a
    post-update UPDATE statement.

    Prefetched column values are copied from ``params`` into the
    instance dict; postfetched (server-generated) columns are expired
    so they reload on next access.
    """
    # a state that's being deleted in this flush needs no refresh
    if uowtransaction.is_deleted(state):
        return

    prefetch_cols = result.context.compiled.prefetch
    postfetch_cols = result.context.compiled.postfetch

    if (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    ):
        # version id value is in params; treat it as prefetched
        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]

    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
    if refresh_flush:
        load_evt_attrs = []

    # copy prefetched values straight from the bound parameters
    for c in prefetch_cols:
        if c.key in params and c in mapper._columntoproperty:
            dict_[mapper._columntoproperty[c].key] = params[c.key]
            if refresh_flush:
                load_evt_attrs.append(mapper._columntoproperty[c].key)

    if refresh_flush and load_evt_attrs:
        mapper.class_manager.dispatch.refresh_flush(
            state, uowtransaction, load_evt_attrs
        )

    if postfetch_cols:
        # server-generated values: expire so they're fetched on access
        state._expire_attributes(
            state.dict,
            [
                mapper._columntoproperty[c].key
                for c in postfetch_cols
                if c in mapper._columntoproperty
            ],
        )

1485 

1486 

def _postfetch(
    mapper,
    uowtransaction,
    table,
    state,
    dict_,
    result,
    params,
    value_params,
    isupdate,
):
    """Expire attributes in need of newly persisted database state,
    after an INSERT or UPDATE statement has proceeded for that
    state.

    RETURNING values and prefetched parameters are copied into the
    instance dict; postfetched columns are expired; finally,
    primary/foreign key values are synchronized to dependent tables.
    """

    prefetch_cols = result.context.compiled.prefetch
    postfetch_cols = result.context.compiled.postfetch
    returning_cols = result.context.compiled.returning

    if (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    ):
        # version id value was sent as a parameter; treat as prefetched
        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]

    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
    if refresh_flush:
        load_evt_attrs = []

    if returning_cols:
        # copy server-returned (RETURNING) values into the instance dict
        row = result.context.returned_defaults
        if row is not None:
            for col in returning_cols:
                # pk cols returned from insert are handled
                # distinctly, don't step on the values here
                if col.primary_key and result.context.isinsert:
                    continue

                # note that columns can be in the "return defaults" that are
                # not mapped to this mapper, typically because they are
                # "excluded", which can be specified directly or also occurs
                # when using declarative w/ single table inheritance
                prop = mapper._columntoproperty.get(col)
                if prop:
                    dict_[prop.key] = row[col]
                    if refresh_flush:
                        load_evt_attrs.append(prop.key)

    # copy prefetched (client-side computed) values from the parameters
    for c in prefetch_cols:
        if c.key in params and c in mapper._columntoproperty:
            dict_[mapper._columntoproperty[c].key] = params[c.key]
            if refresh_flush:
                load_evt_attrs.append(mapper._columntoproperty[c].key)

    if refresh_flush and load_evt_attrs:
        mapper.class_manager.dispatch.refresh_flush(
            state, uowtransaction, load_evt_attrs
        )

    if isupdate and value_params:
        # explicitly suit the use case specified by
        # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING
        # database which are set to themselves in order to do a version bump.
        postfetch_cols.extend(
            [
                col
                for col in value_params
                if col.primary_key and col not in returning_cols
            ]
        )

    if postfetch_cols:
        # server-generated values: expire so they're fetched on access
        state._expire_attributes(
            state.dict,
            [
                mapper._columntoproperty[c].key
                for c in postfetch_cols
                if c in mapper._columntoproperty
            ],
        )

    # synchronize newly inserted ids from one table to the next
    # TODO: this still goes a little too often.  would be nice to
    # have definitive list of "columns that changed" here
    for m, equated_pairs in mapper._table_to_equated[table]:
        sync.populate(
            state,
            m,
            state,
            m,
            equated_pairs,
            uowtransaction,
            mapper.passive_updates,
        )

1581 

1582 

def _postfetch_bulk_save(mapper, dict_, table):
    """Propagate equated key values across inheriting tables for a
    single bulk-save parameter dictionary."""
    for related_mapper, pairs in mapper._table_to_equated[table]:
        sync.bulk_populate_inherit_keys(dict_, related_mapper, pairs)

1586 

1587 

def _connections_for_states(base_mapper, uowtransaction, states):
    """Return an iterator of (state, state.dict, mapper, connection).

    The states are sorted according to _sort_states, then paired
    with the connection they should be using for the given
    unit of work transaction.

    """
    # a session-level connection_callable selects a connection per
    # instance; otherwise all states share one connection
    connection_callable = uowtransaction.session.connection_callable
    if not connection_callable:
        connection = uowtransaction.transaction.connection(base_mapper)

    for state in _sort_states(base_mapper, states):
        if connection_callable:
            connection = connection_callable(base_mapper, state.obj())

        yield state, state.dict, state.manager.mapper, connection

1612 

1613 

def _cached_connection_dict(base_mapper):
    """Return a PopulateDict that lazily maps each connection to a
    variant of itself configured with the base mapper's compiled
    statement cache."""

    def _add_cache_options(conn):
        return conn.execution_options(
            compiled_cache=base_mapper._compiled_cache
        )

    return util.PopulateDict(_add_cache_options)

1621 

1622 

1623def _sort_states(mapper, states): 

1624 pending = set(states) 

1625 persistent = set(s for s in pending if s.key is not None) 

1626 pending.difference_update(persistent) 

1627 

1628 try: 

1629 persistent_sorted = sorted( 

1630 persistent, key=mapper._persistent_sortkey_fn 

1631 ) 

1632 except TypeError as err: 

1633 util.raise_( 

1634 sa_exc.InvalidRequestError( 

1635 "Could not sort objects by primary key; primary key " 

1636 "values must be sortable in Python (was: %s)" % err 

1637 ), 

1638 replace_context=err, 

1639 ) 

1640 return ( 

1641 sorted(pending, key=operator.attrgetter("insert_order")) 

1642 + persistent_sorted 

1643 ) 

1644 

1645 

class BulkUD(object):
    """Handle bulk update and deletes via a :class:`_query.Query`.

    Subclasses implement the UPDATE/DELETE specifics plus one of the
    session synchronization strategies ('evaluate', 'fetch', or False).
    ``exec_()`` drives the overall sequence of steps.
    """

    def __init__(self, query):
        # eager loaders don't apply to a bulk UPDATE/DELETE
        self.query = query.enable_eagerloads(False)
        self.mapper = self.query._bind_mapper()
        self._validate_query_state()

    def _validate_query_state(self):
        # reject query modifiers that have no defined meaning for a
        # bulk UPDATE/DELETE; each entry is (attribute, method name for
        # the error message, "unset" sentinel, comparison operator)
        for attr, methname, notset, op in (
            ("_limit", "limit()", None, operator.is_),
            ("_offset", "offset()", None, operator.is_),
            ("_order_by", "order_by()", False, operator.is_),
            ("_group_by", "group_by()", False, operator.is_),
            ("_distinct", "distinct()", False, operator.is_),
            (
                "_from_obj",
                "join(), outerjoin(), select_from(), or from_self()",
                (),
                operator.eq,
            ),
        ):
            if not op(getattr(self.query, attr), notset):
                raise sa_exc.InvalidRequestError(
                    "Can't call Query.update() or Query.delete() "
                    "when %s has been called" % (methname,)
                )

    @property
    def session(self):
        """The Session associated with the underlying Query."""
        return self.query.session

    @classmethod
    def _factory(cls, lookup, synchronize_session, *arg):
        """Instantiate the subclass in ``lookup`` matching the
        ``synchronize_session`` strategy, raising ArgumentError for an
        unknown strategy."""
        try:
            klass = lookup[synchronize_session]
        except KeyError as err:
            util.raise_(
                sa_exc.ArgumentError(
                    "Valid strategies for session synchronization "
                    "are %s" % (", ".join(sorted(repr(x) for x in lookup)))
                ),
                replace_context=err,
            )
        else:
            return klass(*arg)

    def exec_(self):
        """Run the full bulk operation: pre-compile hooks, setup,
        pre/post synchronization, statement execution and events."""
        self._do_before_compile()
        self._do_pre()
        self._do_pre_synchronize()
        self._do_exec()
        self._do_post_synchronize()
        self._do_post()

    def _execute_stmt(self, stmt):
        # execute the UPDATE/DELETE and record its rowcount
        self.result = self.query._execute_crud(stmt, self.mapper)
        self.rowcount = self.result.rowcount

    def _do_before_compile(self):
        raise NotImplementedError()

    @util.dependencies("sqlalchemy.orm.query")
    def _do_pre(self, querylib):
        # determine the single target table of the operation, either
        # from plain column entities or from the mapped entity
        query = self.query

        self.context = querylib.QueryContext(query)

        if isinstance(query._entities[0], querylib._ColumnEntity):
            # check for special case of query(table)
            tables = set()
            for ent in query._entities:
                if not isinstance(ent, querylib._ColumnEntity):
                    tables.clear()
                    break
                else:
                    tables.update(_from_objects(ent.column))

            if len(tables) != 1:
                raise sa_exc.InvalidRequestError(
                    "This operation requires only one Table or "
                    "entity be specified as the target."
                )
            else:
                self.primary_table = tables.pop()

        else:
            self.primary_table = query._only_entity_zero(
                "This operation requires only one Table or "
                "entity be specified as the target."
            ).mapper.local_table

        session = query.session

        if query._autoflush:
            session._autoflush()

    def _do_pre_synchronize(self):
        # hook for strategy subclasses; default is no synchronization
        pass

    def _do_post_synchronize(self):
        # hook for strategy subclasses; default is no synchronization
        pass

1748 

1749 

class BulkEvaluate(BulkUD):
    """BulkUD which does the 'evaluate' method of session state resolution.

    The Query's WHERE criteria is compiled into a Python-evaluable
    predicate and applied against objects already present in the
    session's identity map.
    """

    def _additional_evaluators(self, evaluator_compiler):
        # hook for subclasses to compile further expressions
        # (e.g. UPDATE value expressions)
        pass

    def _do_pre_synchronize(self):
        query = self.query
        target_cls = query._mapper_zero().class_

        try:
            evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
            if query.whereclause is not None:
                eval_condition = evaluator_compiler.process(query.whereclause)
            else:
                # no criteria: every object of the target class matches

                def eval_condition(obj):
                    return True

            self._additional_evaluators(evaluator_compiler)

        except evaluator.UnevaluatableError as err:
            # the criteria can't be expressed in Python; the caller must
            # choose a different synchronize_session strategy
            util.raise_(
                sa_exc.InvalidRequestError(
                    'Could not evaluate current criteria in Python: "%s". '
                    "Specify 'fetch' or False for the "
                    "synchronize_session parameter." % err
                ),
                from_=err,
            )

        # TODO: detect when the where clause is a trivial primary key match
        self.matched_objects = [
            obj
            for (
                cls,
                pk,
                identity_token,
            ), obj in query.session.identity_map.items()
            if issubclass(cls, target_cls) and eval_condition(obj)
        ]

1791 

1792 

class BulkFetch(BulkUD):
    """BulkUD which does the 'fetch' method of session state resolution.

    Primary key values for all rows matched by the criteria are
    SELECTed up front, before the UPDATE/DELETE is emitted.
    """

    def _do_pre_synchronize(self):
        # compile the query and reduce its SELECT down to just the
        # primary key columns of the target table
        compile_context = self.query._compile_context()
        pk_select = compile_context.statement.with_only_columns(
            self.primary_table.primary_key
        )
        result = self.query.session.execute(
            pk_select, mapper=self.mapper, params=self.query._params
        )
        self.matched_rows = result.fetchall()

1806 

1807 

class BulkUpdate(BulkUD):
    """BulkUD which handles UPDATEs.

    ``values`` is the mapping (or sequence of pairs) of attributes/
    columns to new values; ``update_kwargs`` is passed through to
    ``sql.update()``.
    """

    def __init__(self, query, values, update_kwargs):
        super(BulkUpdate, self).__init__(query)
        self.values = values
        self.update_kwargs = update_kwargs

    @classmethod
    def factory(cls, query, synchronize_session, values, update_kwargs):
        # select the subclass matching the synchronize_session strategy
        return BulkUD._factory(
            {
                "evaluate": BulkUpdateEvaluate,
                "fetch": BulkUpdateFetch,
                False: BulkUpdate,
            },
            synchronize_session,
            query,
            values,
            update_kwargs,
        )

    def _do_before_compile(self):
        # allow before_compile_update event hooks to replace the query
        if self.query.dispatch.before_compile_update:
            for fn in self.query.dispatch.before_compile_update:
                new_query = fn(self.query, self)
                if new_query is not None:
                    self.query = new_query

    @property
    def _resolved_values(self):
        """Normalize ``self.values`` into a list of (column, value)
        tuples, resolving string and attribute keys through the mapper
        where one is present."""
        values = []
        for k, v in (
            self.values.items()
            if hasattr(self.values, "items")
            else self.values
        ):
            if self.mapper:
                if isinstance(k, util.string_types):
                    desc = _entity_descriptor(self.mapper, k)
                    values.extend(desc._bulk_update_tuples(v))
                elif isinstance(k, attributes.QueryableAttribute):
                    values.extend(k._bulk_update_tuples(v))
                else:
                    values.append((k, v))
            else:
                values.append((k, v))
        return values

    @property
    def _resolved_values_keys_as_propnames(self):
        """Like ``_resolved_values`` but with each key converted to its
        mapped property name; raises for expressions that don't
        correspond to a mapped attribute."""
        values = []
        for k, v in self._resolved_values:
            if isinstance(k, attributes.QueryableAttribute):
                values.append((k.key, v))
                continue
            elif hasattr(k, "__clause_element__"):
                k = k.__clause_element__()

            if self.mapper and isinstance(k, expression.ColumnElement):
                try:
                    attr = self.mapper._columntoproperty[k]
                except orm_exc.UnmappedColumnError:
                    # column isn't mapped by this mapper; skip it
                    pass
                else:
                    values.append((attr.key, v))
            else:
                raise sa_exc.InvalidRequestError(
                    "Invalid expression type: %r" % k
                )
        return values

    def _do_exec(self):
        values = self._resolved_values

        # a plain dict loses ordering of the SET clause; only keep the
        # list form when the caller asked to preserve parameter order
        if not self.update_kwargs.get("preserve_parameter_order", False):
            values = dict(values)

        update_stmt = sql.update(
            self.primary_table,
            self.context.whereclause,
            values,
            **self.update_kwargs
        )

        self._execute_stmt(update_stmt)

    def _do_post(self):
        # notify session-level after_bulk_update listeners
        session = self.query.session
        session.dispatch.after_bulk_update(self)

1898 

1899 

class BulkDelete(BulkUD):
    """BulkUD which handles DELETEs."""

    def __init__(self, query):
        super(BulkDelete, self).__init__(query)

    @classmethod
    def factory(cls, query, synchronize_session):
        # select the subclass matching the synchronize_session strategy
        lookup = {
            "evaluate": BulkDeleteEvaluate,
            "fetch": BulkDeleteFetch,
            False: BulkDelete,
        }
        return BulkUD._factory(lookup, synchronize_session, query)

    def _do_before_compile(self):
        # allow before_compile_delete event hooks to replace the query
        if not self.query.dispatch.before_compile_delete:
            return
        for hook in self.query.dispatch.before_compile_delete:
            replacement = hook(self.query, self)
            if replacement is not None:
                self.query = replacement

    def _do_exec(self):
        self._execute_stmt(
            sql.delete(self.primary_table, self.context.whereclause)
        )

    def _do_post(self):
        # notify session-level after_bulk_delete listeners
        self.query.session.dispatch.after_bulk_delete(self)

1933 

1934 

class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate):
    """BulkUD which handles UPDATEs using the "evaluate"
    method of session resolution.

    UPDATE value expressions are also compiled to Python evaluators so
    matched in-memory objects can be updated without a reload.
    """

    def _additional_evaluators(self, evaluator_compiler):
        # compile each SET value expression into a Python callable
        self.value_evaluators = {}
        values = self._resolved_values_keys_as_propnames
        for key, value in values:
            self.value_evaluators[key] = evaluator_compiler.process(
                expression._literal_as_binds(value)
            )

    def _do_post_synchronize(self):
        session = self.query.session
        states = set()
        evaluated_keys = list(self.value_evaluators.keys())
        for obj in self.matched_objects:
            state, dict_ = (
                attributes.instance_state(obj),
                attributes.instance_dict(obj),
            )

            # only evaluate unmodified attributes
            to_evaluate = state.unmodified.intersection(evaluated_keys)
            for key in to_evaluate:
                dict_[key] = self.value_evaluators[key](obj)

            state.manager.dispatch.refresh(state, None, to_evaluate)

            # the evaluated values now reflect the database; mark them
            # committed so they aren't re-flushed
            state._commit(dict_, list(to_evaluate))

            # expire attributes with pending changes
            # (there was no autoflush, so they are overwritten)
            state._expire_attributes(
                dict_, set(evaluated_keys).difference(to_evaluate)
            )
            states.add(state)
        session._register_altered(states)

1973 

1974 

class BulkDeleteEvaluate(BulkEvaluate, BulkDelete):
    """BulkUD which handles DELETEs using the "evaluate"
    method of session resolution."""

    def _do_post_synchronize(self):
        # mark every matched in-memory object as deleted in the session
        deleted_states = [
            attributes.instance_state(obj) for obj in self.matched_objects
        ]
        self.query.session._remove_newly_deleted(deleted_states)

1983 

1984 

class BulkUpdateFetch(BulkFetch, BulkUpdate):
    """BulkUD which handles UPDATEs using the "fetch"
    method of session resolution."""

    def _do_post_synchronize(self):
        session = self.query.session
        target_mapper = self.query._mapper_zero()

        # collect the states of all matched rows that are present
        # in the session's identity map
        states = set()
        for primary_key in self.matched_rows:
            identity_key = target_mapper.identity_key_from_primary_key(
                list(primary_key)
            )
            if identity_key in session.identity_map:
                states.add(
                    attributes.instance_state(
                        session.identity_map[identity_key]
                    )
                )

        # expire each updated attribute that is currently loaded
        updated_keys = set(
            key for key, value in self._resolved_values_keys_as_propnames
        )
        for state in states:
            stale = updated_keys.intersection(state.dict)
            if stale:
                session._expire_state(state, stale)
        session._register_altered(states)

2013 

2014 

class BulkDeleteFetch(BulkFetch, BulkDelete):
    """BulkUD which handles DELETEs using the "fetch"
    method of session resolution."""

    def _do_post_synchronize(self):
        session = self.query.session
        target_mapper = self.query._mapper_zero()
        # TODO: inline this and call remove_newly_deleted
        # once
        for pk_row in self.matched_rows:
            ident_key = target_mapper.identity_key_from_primary_key(
                list(pk_row)
            )
            if ident_key not in session.identity_map:
                continue
            obj = session.identity_map[ident_key]
            session._remove_newly_deleted(
                [attributes.instance_state(obj)]
            )