sqlglot.optimizer.scope
1from __future__ import annotations 2 3import itertools 4import logging 5import typing as t 6from collections import defaultdict 7from enum import Enum, auto 8 9from sqlglot import exp 10from sqlglot.errors import OptimizeError 11from sqlglot.helper import ensure_collection, find_new_name, seq_get 12 13logger = logging.getLogger("sqlglot") 14 15TRAVERSABLES = (exp.Query, exp.DDL, exp.DML) 16 17 18class ScopeType(Enum): 19 ROOT = auto() 20 SUBQUERY = auto() 21 DERIVED_TABLE = auto() 22 CTE = auto() 23 UNION = auto() 24 UDTF = auto() 25 26 27class Scope: 28 """ 29 Selection scope. 30 31 Attributes: 32 expression (exp.Select|exp.SetOperation): Root expression of this scope 33 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 34 a Table expression or another Scope instance. For example: 35 SELECT * FROM x {"x": Table(this="x")} 36 SELECT * FROM x AS y {"y": Table(this="x")} 37 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 38 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 39 For example: 40 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 41 The LATERAL VIEW EXPLODE gets x as a source. 42 cte_sources (dict[str, Scope]): Sources from CTES 43 outer_columns (list[str]): If this is a derived table or CTE, and the outer query 44 defines a column list for the alias of this scope, this is that list of columns. 45 For example: 46 SELECT * FROM (SELECT ...) 
AS y(col1, col2) 47 The inner query would have `["col1", "col2"]` for its `outer_columns` 48 parent (Scope): Parent scope 49 scope_type (ScopeType): Type of this scope, relative to it's parent 50 subquery_scopes (list[Scope]): List of all child scopes for subqueries 51 cte_scopes (list[Scope]): List of all child scopes for CTEs 52 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 53 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 54 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 55 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 56 a list of the left and right child scopes. 57 """ 58 59 def __init__( 60 self, 61 expression, 62 sources=None, 63 outer_columns=None, 64 parent=None, 65 scope_type=ScopeType.ROOT, 66 lateral_sources=None, 67 cte_sources=None, 68 can_be_correlated=None, 69 ): 70 self.expression = expression 71 self.sources = sources or {} 72 self.lateral_sources = lateral_sources or {} 73 self.cte_sources = cte_sources or {} 74 self.sources.update(self.lateral_sources) 75 self.sources.update(self.cte_sources) 76 self.outer_columns = outer_columns or [] 77 self.parent = parent 78 self.scope_type = scope_type 79 self.subquery_scopes = [] 80 self.derived_table_scopes = [] 81 self.table_scopes = [] 82 self.cte_scopes = [] 83 self.union_scopes = [] 84 self.udtf_scopes = [] 85 self.can_be_correlated = can_be_correlated 86 self.clear_cache() 87 88 def clear_cache(self): 89 self._collected = False 90 self._raw_columns = None 91 self._stars = None 92 self._derived_tables = None 93 self._udtfs = None 94 self._tables = None 95 self._ctes = None 96 self._subqueries = None 97 self._selected_sources = None 98 self._columns = None 99 self._external_columns = None 100 self._join_hints = None 101 self._pivots = None 102 self._references = None 103 104 def branch( 105 self, expression, scope_type, sources=None, 
cte_sources=None, lateral_sources=None, **kwargs 106 ): 107 """Branch from the current scope to a new, inner scope""" 108 return Scope( 109 expression=expression.unnest(), 110 sources=sources.copy() if sources else None, 111 parent=self, 112 scope_type=scope_type, 113 cte_sources={**self.cte_sources, **(cte_sources or {})}, 114 lateral_sources=lateral_sources.copy() if lateral_sources else None, 115 can_be_correlated=self.can_be_correlated 116 or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF), 117 **kwargs, 118 ) 119 120 def _collect(self): 121 self._tables = [] 122 self._ctes = [] 123 self._subqueries = [] 124 self._derived_tables = [] 125 self._udtfs = [] 126 self._raw_columns = [] 127 self._stars = [] 128 self._join_hints = [] 129 130 for node in self.walk(bfs=False): 131 if node is self.expression: 132 continue 133 134 if isinstance(node, exp.Dot) and node.is_star: 135 self._stars.append(node) 136 elif isinstance(node, exp.Column): 137 if isinstance(node.this, exp.Star): 138 self._stars.append(node) 139 else: 140 self._raw_columns.append(node) 141 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 142 self._tables.append(node) 143 elif isinstance(node, exp.JoinHint): 144 self._join_hints.append(node) 145 elif isinstance(node, exp.UDTF): 146 self._udtfs.append(node) 147 elif isinstance(node, exp.CTE): 148 self._ctes.append(node) 149 elif _is_derived_table(node) and _is_from_or_join(node): 150 self._derived_tables.append(node) 151 elif isinstance(node, exp.UNWRAPPED_QUERIES): 152 self._subqueries.append(node) 153 154 self._collected = True 155 156 def _ensure_collected(self): 157 if not self._collected: 158 self._collect() 159 160 def walk(self, bfs=True, prune=None): 161 return walk_in_scope(self.expression, bfs=bfs, prune=None) 162 163 def find(self, *expression_types, bfs=True): 164 return find_in_scope(self.expression, expression_types, bfs=bfs) 165 166 def find_all(self, *expression_types, bfs=True): 167 return 
find_all_in_scope(self.expression, expression_types, bfs=bfs) 168 169 def replace(self, old, new): 170 """ 171 Replace `old` with `new`. 172 173 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 174 175 Args: 176 old (exp.Expression): old node 177 new (exp.Expression): new node 178 """ 179 old.replace(new) 180 self.clear_cache() 181 182 @property 183 def tables(self): 184 """ 185 List of tables in this scope. 186 187 Returns: 188 list[exp.Table]: tables 189 """ 190 self._ensure_collected() 191 return self._tables 192 193 @property 194 def ctes(self): 195 """ 196 List of CTEs in this scope. 197 198 Returns: 199 list[exp.CTE]: ctes 200 """ 201 self._ensure_collected() 202 return self._ctes 203 204 @property 205 def derived_tables(self): 206 """ 207 List of derived tables in this scope. 208 209 For example: 210 SELECT * FROM (SELECT ...) <- that's a derived table 211 212 Returns: 213 list[exp.Subquery]: derived tables 214 """ 215 self._ensure_collected() 216 return self._derived_tables 217 218 @property 219 def udtfs(self): 220 """ 221 List of "User Defined Tabular Functions" in this scope. 222 223 Returns: 224 list[exp.UDTF]: UDTFs 225 """ 226 self._ensure_collected() 227 return self._udtfs 228 229 @property 230 def subqueries(self): 231 """ 232 List of subqueries in this scope. 233 234 For example: 235 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 236 237 Returns: 238 list[exp.Select | exp.SetOperation]: subqueries 239 """ 240 self._ensure_collected() 241 return self._subqueries 242 243 @property 244 def stars(self) -> t.List[exp.Column | exp.Dot]: 245 """ 246 List of star expressions (columns or dots) in this scope. 247 """ 248 self._ensure_collected() 249 return self._stars 250 251 @property 252 def columns(self): 253 """ 254 List of columns in this scope. 255 256 Returns: 257 list[exp.Column]: Column instances in this scope, plus any 258 Columns that reference this scope from correlated subqueries. 
259 """ 260 if self._columns is None: 261 self._ensure_collected() 262 columns = self._raw_columns 263 264 external_columns = [ 265 column 266 for scope in itertools.chain( 267 self.subquery_scopes, 268 self.udtf_scopes, 269 (dts for dts in self.derived_table_scopes if dts.can_be_correlated), 270 ) 271 for column in scope.external_columns 272 ] 273 274 named_selects = set(self.expression.named_selects) 275 276 self._columns = [] 277 for column in columns + external_columns: 278 ancestor = column.find_ancestor( 279 exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star 280 ) 281 if ( 282 not ancestor 283 or column.table 284 or isinstance(ancestor, exp.Select) 285 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 286 or ( 287 isinstance(ancestor, exp.Order) 288 and ( 289 isinstance(ancestor.parent, exp.Window) 290 or column.name not in named_selects 291 ) 292 ) 293 or (isinstance(ancestor, exp.Star) and not column.arg_key == "except") 294 ): 295 self._columns.append(column) 296 297 return self._columns 298 299 @property 300 def selected_sources(self): 301 """ 302 Mapping of nodes and sources that are actually selected from in this scope. 303 304 That is, all tables in a schema are selectable at any point. But a 305 table only becomes a selected source if it's included in a FROM or JOIN clause. 
306 307 Returns: 308 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 309 """ 310 if self._selected_sources is None: 311 result = {} 312 313 for name, node in self.references: 314 if name in result: 315 raise OptimizeError(f"Alias already used: {name}") 316 if name in self.sources: 317 result[name] = (node, self.sources[name]) 318 319 self._selected_sources = result 320 return self._selected_sources 321 322 @property 323 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 324 if self._references is None: 325 self._references = [] 326 327 for table in self.tables: 328 self._references.append((table.alias_or_name, table)) 329 for expression in itertools.chain(self.derived_tables, self.udtfs): 330 self._references.append( 331 ( 332 expression.alias, 333 expression if expression.args.get("pivots") else expression.unnest(), 334 ) 335 ) 336 337 return self._references 338 339 @property 340 def external_columns(self): 341 """ 342 Columns that appear to reference sources in outer scopes. 343 344 Returns: 345 list[exp.Column]: Column instances that don't reference 346 sources in the current scope. 347 """ 348 if self._external_columns is None: 349 if isinstance(self.expression, exp.SetOperation): 350 left, right = self.union_scopes 351 self._external_columns = left.external_columns + right.external_columns 352 else: 353 self._external_columns = [ 354 c for c in self.columns if c.table not in self.selected_sources 355 ] 356 357 return self._external_columns 358 359 @property 360 def unqualified_columns(self): 361 """ 362 Unqualified columns in the current scope. 
363 364 Returns: 365 list[exp.Column]: Unqualified columns 366 """ 367 return [c for c in self.columns if not c.table] 368 369 @property 370 def join_hints(self): 371 """ 372 Hints that exist in the scope that reference tables 373 374 Returns: 375 list[exp.JoinHint]: Join hints that are referenced within the scope 376 """ 377 if self._join_hints is None: 378 return [] 379 return self._join_hints 380 381 @property 382 def pivots(self): 383 if not self._pivots: 384 self._pivots = [ 385 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 386 ] 387 388 return self._pivots 389 390 def source_columns(self, source_name): 391 """ 392 Get all columns in the current scope for a particular source. 393 394 Args: 395 source_name (str): Name of the source 396 Returns: 397 list[exp.Column]: Column instances that reference `source_name` 398 """ 399 return [column for column in self.columns if column.table == source_name] 400 401 @property 402 def is_subquery(self): 403 """Determine if this scope is a subquery""" 404 return self.scope_type == ScopeType.SUBQUERY 405 406 @property 407 def is_derived_table(self): 408 """Determine if this scope is a derived table""" 409 return self.scope_type == ScopeType.DERIVED_TABLE 410 411 @property 412 def is_union(self): 413 """Determine if this scope is a union""" 414 return self.scope_type == ScopeType.UNION 415 416 @property 417 def is_cte(self): 418 """Determine if this scope is a common table expression""" 419 return self.scope_type == ScopeType.CTE 420 421 @property 422 def is_root(self): 423 """Determine if this is the root scope""" 424 return self.scope_type == ScopeType.ROOT 425 426 @property 427 def is_udtf(self): 428 """Determine if this scope is a UDTF (User Defined Table Function)""" 429 return self.scope_type == ScopeType.UDTF 430 431 @property 432 def is_correlated_subquery(self): 433 """Determine if this scope is a correlated subquery""" 434 return bool(self.can_be_correlated and self.external_columns) 
435 436 def rename_source(self, old_name, new_name): 437 """Rename a source in this scope""" 438 columns = self.sources.pop(old_name or "", []) 439 self.sources[new_name] = columns 440 441 def add_source(self, name, source): 442 """Add a source to this scope""" 443 self.sources[name] = source 444 self.clear_cache() 445 446 def remove_source(self, name): 447 """Remove a source from this scope""" 448 self.sources.pop(name, None) 449 self.clear_cache() 450 451 def __repr__(self): 452 return f"Scope<{self.expression.sql()}>" 453 454 def traverse(self): 455 """ 456 Traverse the scope tree from this node. 457 458 Yields: 459 Scope: scope instances in depth-first-search post-order 460 """ 461 stack = [self] 462 result = [] 463 while stack: 464 scope = stack.pop() 465 result.append(scope) 466 stack.extend( 467 itertools.chain( 468 scope.cte_scopes, 469 scope.union_scopes, 470 scope.table_scopes, 471 scope.subquery_scopes, 472 ) 473 ) 474 475 yield from reversed(result) 476 477 def ref_count(self): 478 """ 479 Count the number of times each scope in this tree is referenced. 480 481 Returns: 482 dict[int, int]: Mapping of Scope instance ID to reference count 483 """ 484 scope_ref_count = defaultdict(lambda: 0) 485 486 for scope in self.traverse(): 487 for _, source in scope.selected_sources.values(): 488 scope_ref_count[id(source)] += 1 489 490 return scope_ref_count 491 492 493def traverse_scope(expression: exp.Expression) -> t.List[Scope]: 494 """ 495 Traverse an expression by its "scopes". 496 497 "Scope" represents the current context of a Select statement. 498 499 This is helpful for optimizing queries, where we need more information than 500 the expression tree itself. For example, we might care about the source 501 names within a subquery. Returns a list because a generator could result in 502 incomplete properties which is confusing. 
503 504 Examples: 505 >>> import sqlglot 506 >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") 507 >>> scopes = traverse_scope(expression) 508 >>> scopes[0].expression.sql(), list(scopes[0].sources) 509 ('SELECT a FROM x', ['x']) 510 >>> scopes[1].expression.sql(), list(scopes[1].sources) 511 ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) 512 513 Args: 514 expression: Expression to traverse 515 516 Returns: 517 A list of the created scope instances 518 """ 519 if isinstance(expression, TRAVERSABLES): 520 return list(_traverse_scope(Scope(expression))) 521 return [] 522 523 524def build_scope(expression: exp.Expression) -> t.Optional[Scope]: 525 """ 526 Build a scope tree. 527 528 Args: 529 expression: Expression to build the scope tree for. 530 531 Returns: 532 The root scope 533 """ 534 return seq_get(traverse_scope(expression), -1) 535 536 537def _traverse_scope(scope): 538 expression = scope.expression 539 540 if isinstance(expression, exp.Select): 541 yield from _traverse_select(scope) 542 elif isinstance(expression, exp.SetOperation): 543 yield from _traverse_ctes(scope) 544 yield from _traverse_union(scope) 545 return 546 elif isinstance(expression, exp.Subquery): 547 if scope.is_root: 548 yield from _traverse_select(scope) 549 else: 550 yield from _traverse_subqueries(scope) 551 elif isinstance(expression, exp.Table): 552 yield from _traverse_tables(scope) 553 elif isinstance(expression, exp.UDTF): 554 yield from _traverse_udtfs(scope) 555 elif isinstance(expression, exp.DDL): 556 if isinstance(expression.expression, exp.Query): 557 yield from _traverse_ctes(scope) 558 yield from _traverse_scope(Scope(expression.expression, cte_sources=scope.cte_sources)) 559 return 560 elif isinstance(expression, exp.DML): 561 yield from _traverse_ctes(scope) 562 for query in find_all_in_scope(expression, exp.Query): 563 # This check ensures we don't yield the CTE/nested queries twice 564 if not isinstance(query.parent, (exp.CTE, exp.Subquery)): 
565 yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources)) 566 return 567 else: 568 logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression)) 569 return 570 571 yield scope 572 573 574def _traverse_select(scope): 575 yield from _traverse_ctes(scope) 576 yield from _traverse_tables(scope) 577 yield from _traverse_subqueries(scope) 578 579 580def _traverse_union(scope): 581 prev_scope = None 582 union_scope_stack = [scope] 583 expression_stack = [scope.expression.right, scope.expression.left] 584 585 while expression_stack: 586 expression = expression_stack.pop() 587 union_scope = union_scope_stack[-1] 588 589 new_scope = union_scope.branch( 590 expression, 591 outer_columns=union_scope.outer_columns, 592 scope_type=ScopeType.UNION, 593 ) 594 595 if isinstance(expression, exp.SetOperation): 596 yield from _traverse_ctes(new_scope) 597 598 union_scope_stack.append(new_scope) 599 expression_stack.extend([expression.right, expression.left]) 600 continue 601 602 for scope in _traverse_scope(new_scope): 603 yield scope 604 605 if prev_scope: 606 union_scope_stack.pop() 607 union_scope.union_scopes = [prev_scope, scope] 608 prev_scope = union_scope 609 610 yield union_scope 611 else: 612 prev_scope = scope 613 614 615def _traverse_ctes(scope): 616 sources = {} 617 618 for cte in scope.ctes: 619 cte_name = cte.alias 620 621 # if the scope is a recursive cte, it must be in the form of base_case UNION recursive. 622 # thus the recursive scope is the first section of the union. 
623 with_ = scope.expression.args.get("with") 624 if with_ and with_.recursive: 625 union = cte.this 626 627 if isinstance(union, exp.SetOperation): 628 sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE) 629 630 child_scope = None 631 632 for child_scope in _traverse_scope( 633 scope.branch( 634 cte.this, 635 cte_sources=sources, 636 outer_columns=cte.alias_column_names, 637 scope_type=ScopeType.CTE, 638 ) 639 ): 640 yield child_scope 641 642 # append the final child_scope yielded 643 if child_scope: 644 sources[cte_name] = child_scope 645 scope.cte_scopes.append(child_scope) 646 647 scope.sources.update(sources) 648 scope.cte_sources.update(sources) 649 650 651def _is_derived_table(expression: exp.Subquery) -> bool: 652 """ 653 We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table", 654 as it doesn't introduce a new scope. If an alias is present, it shadows all names 655 under the Subquery, so that's one exception to this rule. 656 """ 657 return isinstance(expression, exp.Subquery) and bool( 658 expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES) 659 ) 660 661 662def _is_from_or_join(expression: exp.Expression) -> bool: 663 """ 664 Determine if `expression` is the FROM or JOIN clause of a SELECT statement. 
665 """ 666 parent = expression.parent 667 668 # Subqueries can be arbitrarily nested 669 while isinstance(parent, exp.Subquery): 670 parent = parent.parent 671 672 return isinstance(parent, (exp.From, exp.Join)) 673 674 675def _traverse_tables(scope): 676 sources = {} 677 678 # Traverse FROMs, JOINs, and LATERALs in the order they are defined 679 expressions = [] 680 from_ = scope.expression.args.get("from") 681 if from_: 682 expressions.append(from_.this) 683 684 for join in scope.expression.args.get("joins") or []: 685 expressions.append(join.this) 686 687 if isinstance(scope.expression, exp.Table): 688 expressions.append(scope.expression) 689 690 expressions.extend(scope.expression.args.get("laterals") or []) 691 692 for expression in expressions: 693 if isinstance(expression, exp.Final): 694 expression = expression.this 695 if isinstance(expression, exp.Table): 696 table_name = expression.name 697 source_name = expression.alias_or_name 698 699 if table_name in scope.sources and not expression.db: 700 # This is a reference to a parent source (e.g. a CTE), not an actual table, unless 701 # it is pivoted, because then we get back a new table and hence a new source. 
702 pivots = expression.args.get("pivots") 703 if pivots: 704 sources[pivots[0].alias] = expression 705 else: 706 sources[source_name] = scope.sources[table_name] 707 elif source_name in sources: 708 sources[find_new_name(sources, table_name)] = expression 709 else: 710 sources[source_name] = expression 711 712 # Make sure to not include the joins twice 713 if expression is not scope.expression: 714 expressions.extend(join.this for join in expression.args.get("joins") or []) 715 716 continue 717 718 if not isinstance(expression, exp.DerivedTable): 719 continue 720 721 if isinstance(expression, exp.UDTF): 722 lateral_sources = sources 723 scope_type = ScopeType.UDTF 724 scopes = scope.udtf_scopes 725 elif _is_derived_table(expression): 726 lateral_sources = None 727 scope_type = ScopeType.DERIVED_TABLE 728 scopes = scope.derived_table_scopes 729 expressions.extend(join.this for join in expression.args.get("joins") or []) 730 else: 731 # Makes sure we check for possible sources in nested table constructs 732 expressions.append(expression.this) 733 expressions.extend(join.this for join in expression.args.get("joins") or []) 734 continue 735 736 for child_scope in _traverse_scope( 737 scope.branch( 738 expression, 739 lateral_sources=lateral_sources, 740 outer_columns=expression.alias_column_names, 741 scope_type=scope_type, 742 ) 743 ): 744 yield child_scope 745 746 # Tables without aliases will be set as "" 747 # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. 748 # Until then, this means that only a single, unaliased derived table is allowed (rather, 749 # the latest one wins. 
750 sources[expression.alias] = child_scope 751 752 # append the final child_scope yielded 753 scopes.append(child_scope) 754 scope.table_scopes.append(child_scope) 755 756 scope.sources.update(sources) 757 758 759def _traverse_subqueries(scope): 760 for subquery in scope.subqueries: 761 top = None 762 for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)): 763 yield child_scope 764 top = child_scope 765 scope.subquery_scopes.append(top) 766 767 768def _traverse_udtfs(scope): 769 if isinstance(scope.expression, exp.Unnest): 770 expressions = scope.expression.expressions 771 elif isinstance(scope.expression, exp.Lateral): 772 expressions = [scope.expression.this] 773 else: 774 expressions = [] 775 776 sources = {} 777 for expression in expressions: 778 if _is_derived_table(expression): 779 top = None 780 for child_scope in _traverse_scope( 781 scope.branch( 782 expression, 783 scope_type=ScopeType.SUBQUERY, 784 outer_columns=expression.alias_column_names, 785 ) 786 ): 787 yield child_scope 788 top = child_scope 789 sources[expression.alias] = child_scope 790 791 scope.subquery_scopes.append(top) 792 793 scope.sources.update(sources) 794 795 796def walk_in_scope(expression, bfs=True, prune=None): 797 """ 798 Returns a generator object which visits all nodes in the syntrax tree, stopping at 799 nodes that start child scopes. 800 801 Args: 802 expression (exp.Expression): 803 bfs (bool): if set to True the BFS traversal order will be applied, 804 otherwise the DFS traversal will be used instead. 805 prune ((node, parent, arg_key) -> bool): callable that returns True if 806 the generator should stop traversing this branch of the tree. 807 808 Yields: 809 tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key 810 """ 811 # We'll use this variable to pass state into the dfs generator. 812 # Whenever we set it to True, we exclude a subtree from traversal. 
813 crossed_scope_boundary = False 814 815 for node in expression.walk( 816 bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n)) 817 ): 818 crossed_scope_boundary = False 819 820 yield node 821 822 if node is expression: 823 continue 824 if ( 825 isinstance(node, exp.CTE) 826 or ( 827 isinstance(node.parent, (exp.From, exp.Join, exp.Subquery)) 828 and (_is_derived_table(node) or isinstance(node, exp.UDTF)) 829 ) 830 or isinstance(node, exp.UNWRAPPED_QUERIES) 831 ): 832 crossed_scope_boundary = True 833 834 if isinstance(node, (exp.Subquery, exp.UDTF)): 835 # The following args are not actually in the inner scope, so we should visit them 836 for key in ("joins", "laterals", "pivots"): 837 for arg in node.args.get(key) or []: 838 yield from walk_in_scope(arg, bfs=bfs) 839 840 841def find_all_in_scope(expression, expression_types, bfs=True): 842 """ 843 Returns a generator object which visits all nodes in this scope and only yields those that 844 match at least one of the specified expression types. 845 846 This does NOT traverse into subscopes. 847 848 Args: 849 expression (exp.Expression): 850 expression_types (tuple[type]|type): the expression type(s) to match. 851 bfs (bool): True to use breadth-first search, False to use depth-first. 852 853 Yields: 854 exp.Expression: nodes 855 """ 856 for expression in walk_in_scope(expression, bfs=bfs): 857 if isinstance(expression, tuple(ensure_collection(expression_types))): 858 yield expression 859 860 861def find_in_scope(expression, expression_types, bfs=True): 862 """ 863 Returns the first node in this scope which matches at least one of the specified types. 864 865 This does NOT traverse into subscopes. 866 867 Args: 868 expression (exp.Expression): 869 expression_types (tuple[type]|type): the expression type(s) to match. 870 bfs (bool): True to use breadth-first search, False to use depth-first. 
871 872 Returns: 873 exp.Expression: the node which matches the criteria or None if no node matching 874 the criteria was found. 875 """ 876 return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
class ScopeType(Enum):
    """Type of a scope, relative to its parent scope."""

    # Values are assigned by `auto()`, so they follow declaration order.
    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()
An enumeration.
28class Scope: 29 """ 30 Selection scope. 31 32 Attributes: 33 expression (exp.Select|exp.SetOperation): Root expression of this scope 34 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 35 a Table expression or another Scope instance. For example: 36 SELECT * FROM x {"x": Table(this="x")} 37 SELECT * FROM x AS y {"y": Table(this="x")} 38 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 39 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 40 For example: 41 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 42 The LATERAL VIEW EXPLODE gets x as a source. 43 cte_sources (dict[str, Scope]): Sources from CTES 44 outer_columns (list[str]): If this is a derived table or CTE, and the outer query 45 defines a column list for the alias of this scope, this is that list of columns. 46 For example: 47 SELECT * FROM (SELECT ...) AS y(col1, col2) 48 The inner query would have `["col1", "col2"]` for its `outer_columns` 49 parent (Scope): Parent scope 50 scope_type (ScopeType): Type of this scope, relative to it's parent 51 subquery_scopes (list[Scope]): List of all child scopes for subqueries 52 cte_scopes (list[Scope]): List of all child scopes for CTEs 53 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 54 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 55 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 56 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 57 a list of the left and right child scopes. 
58 """ 59 60 def __init__( 61 self, 62 expression, 63 sources=None, 64 outer_columns=None, 65 parent=None, 66 scope_type=ScopeType.ROOT, 67 lateral_sources=None, 68 cte_sources=None, 69 can_be_correlated=None, 70 ): 71 self.expression = expression 72 self.sources = sources or {} 73 self.lateral_sources = lateral_sources or {} 74 self.cte_sources = cte_sources or {} 75 self.sources.update(self.lateral_sources) 76 self.sources.update(self.cte_sources) 77 self.outer_columns = outer_columns or [] 78 self.parent = parent 79 self.scope_type = scope_type 80 self.subquery_scopes = [] 81 self.derived_table_scopes = [] 82 self.table_scopes = [] 83 self.cte_scopes = [] 84 self.union_scopes = [] 85 self.udtf_scopes = [] 86 self.can_be_correlated = can_be_correlated 87 self.clear_cache() 88 89 def clear_cache(self): 90 self._collected = False 91 self._raw_columns = None 92 self._stars = None 93 self._derived_tables = None 94 self._udtfs = None 95 self._tables = None 96 self._ctes = None 97 self._subqueries = None 98 self._selected_sources = None 99 self._columns = None 100 self._external_columns = None 101 self._join_hints = None 102 self._pivots = None 103 self._references = None 104 105 def branch( 106 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 107 ): 108 """Branch from the current scope to a new, inner scope""" 109 return Scope( 110 expression=expression.unnest(), 111 sources=sources.copy() if sources else None, 112 parent=self, 113 scope_type=scope_type, 114 cte_sources={**self.cte_sources, **(cte_sources or {})}, 115 lateral_sources=lateral_sources.copy() if lateral_sources else None, 116 can_be_correlated=self.can_be_correlated 117 or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF), 118 **kwargs, 119 ) 120 121 def _collect(self): 122 self._tables = [] 123 self._ctes = [] 124 self._subqueries = [] 125 self._derived_tables = [] 126 self._udtfs = [] 127 self._raw_columns = [] 128 self._stars = [] 129 self._join_hints = [] 
130 131 for node in self.walk(bfs=False): 132 if node is self.expression: 133 continue 134 135 if isinstance(node, exp.Dot) and node.is_star: 136 self._stars.append(node) 137 elif isinstance(node, exp.Column): 138 if isinstance(node.this, exp.Star): 139 self._stars.append(node) 140 else: 141 self._raw_columns.append(node) 142 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 143 self._tables.append(node) 144 elif isinstance(node, exp.JoinHint): 145 self._join_hints.append(node) 146 elif isinstance(node, exp.UDTF): 147 self._udtfs.append(node) 148 elif isinstance(node, exp.CTE): 149 self._ctes.append(node) 150 elif _is_derived_table(node) and _is_from_or_join(node): 151 self._derived_tables.append(node) 152 elif isinstance(node, exp.UNWRAPPED_QUERIES): 153 self._subqueries.append(node) 154 155 self._collected = True 156 157 def _ensure_collected(self): 158 if not self._collected: 159 self._collect() 160 161 def walk(self, bfs=True, prune=None): 162 return walk_in_scope(self.expression, bfs=bfs, prune=None) 163 164 def find(self, *expression_types, bfs=True): 165 return find_in_scope(self.expression, expression_types, bfs=bfs) 166 167 def find_all(self, *expression_types, bfs=True): 168 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 169 170 def replace(self, old, new): 171 """ 172 Replace `old` with `new`. 173 174 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 175 176 Args: 177 old (exp.Expression): old node 178 new (exp.Expression): new node 179 """ 180 old.replace(new) 181 self.clear_cache() 182 183 @property 184 def tables(self): 185 """ 186 List of tables in this scope. 187 188 Returns: 189 list[exp.Table]: tables 190 """ 191 self._ensure_collected() 192 return self._tables 193 194 @property 195 def ctes(self): 196 """ 197 List of CTEs in this scope. 
198 199 Returns: 200 list[exp.CTE]: ctes 201 """ 202 self._ensure_collected() 203 return self._ctes 204 205 @property 206 def derived_tables(self): 207 """ 208 List of derived tables in this scope. 209 210 For example: 211 SELECT * FROM (SELECT ...) <- that's a derived table 212 213 Returns: 214 list[exp.Subquery]: derived tables 215 """ 216 self._ensure_collected() 217 return self._derived_tables 218 219 @property 220 def udtfs(self): 221 """ 222 List of "User Defined Tabular Functions" in this scope. 223 224 Returns: 225 list[exp.UDTF]: UDTFs 226 """ 227 self._ensure_collected() 228 return self._udtfs 229 230 @property 231 def subqueries(self): 232 """ 233 List of subqueries in this scope. 234 235 For example: 236 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 237 238 Returns: 239 list[exp.Select | exp.SetOperation]: subqueries 240 """ 241 self._ensure_collected() 242 return self._subqueries 243 244 @property 245 def stars(self) -> t.List[exp.Column | exp.Dot]: 246 """ 247 List of star expressions (columns or dots) in this scope. 248 """ 249 self._ensure_collected() 250 return self._stars 251 252 @property 253 def columns(self): 254 """ 255 List of columns in this scope. 256 257 Returns: 258 list[exp.Column]: Column instances in this scope, plus any 259 Columns that reference this scope from correlated subqueries. 
260 """ 261 if self._columns is None: 262 self._ensure_collected() 263 columns = self._raw_columns 264 265 external_columns = [ 266 column 267 for scope in itertools.chain( 268 self.subquery_scopes, 269 self.udtf_scopes, 270 (dts for dts in self.derived_table_scopes if dts.can_be_correlated), 271 ) 272 for column in scope.external_columns 273 ] 274 275 named_selects = set(self.expression.named_selects) 276 277 self._columns = [] 278 for column in columns + external_columns: 279 ancestor = column.find_ancestor( 280 exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star 281 ) 282 if ( 283 not ancestor 284 or column.table 285 or isinstance(ancestor, exp.Select) 286 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 287 or ( 288 isinstance(ancestor, exp.Order) 289 and ( 290 isinstance(ancestor.parent, exp.Window) 291 or column.name not in named_selects 292 ) 293 ) 294 or (isinstance(ancestor, exp.Star) and not column.arg_key == "except") 295 ): 296 self._columns.append(column) 297 298 return self._columns 299 300 @property 301 def selected_sources(self): 302 """ 303 Mapping of nodes and sources that are actually selected from in this scope. 304 305 That is, all tables in a schema are selectable at any point. But a 306 table only becomes a selected source if it's included in a FROM or JOIN clause. 
307 308 Returns: 309 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 310 """ 311 if self._selected_sources is None: 312 result = {} 313 314 for name, node in self.references: 315 if name in result: 316 raise OptimizeError(f"Alias already used: {name}") 317 if name in self.sources: 318 result[name] = (node, self.sources[name]) 319 320 self._selected_sources = result 321 return self._selected_sources 322 323 @property 324 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 325 if self._references is None: 326 self._references = [] 327 328 for table in self.tables: 329 self._references.append((table.alias_or_name, table)) 330 for expression in itertools.chain(self.derived_tables, self.udtfs): 331 self._references.append( 332 ( 333 expression.alias, 334 expression if expression.args.get("pivots") else expression.unnest(), 335 ) 336 ) 337 338 return self._references 339 340 @property 341 def external_columns(self): 342 """ 343 Columns that appear to reference sources in outer scopes. 344 345 Returns: 346 list[exp.Column]: Column instances that don't reference 347 sources in the current scope. 348 """ 349 if self._external_columns is None: 350 if isinstance(self.expression, exp.SetOperation): 351 left, right = self.union_scopes 352 self._external_columns = left.external_columns + right.external_columns 353 else: 354 self._external_columns = [ 355 c for c in self.columns if c.table not in self.selected_sources 356 ] 357 358 return self._external_columns 359 360 @property 361 def unqualified_columns(self): 362 """ 363 Unqualified columns in the current scope. 
364 365 Returns: 366 list[exp.Column]: Unqualified columns 367 """ 368 return [c for c in self.columns if not c.table] 369 370 @property 371 def join_hints(self): 372 """ 373 Hints that exist in the scope that reference tables 374 375 Returns: 376 list[exp.JoinHint]: Join hints that are referenced within the scope 377 """ 378 if self._join_hints is None: 379 return [] 380 return self._join_hints 381 382 @property 383 def pivots(self): 384 if not self._pivots: 385 self._pivots = [ 386 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 387 ] 388 389 return self._pivots 390 391 def source_columns(self, source_name): 392 """ 393 Get all columns in the current scope for a particular source. 394 395 Args: 396 source_name (str): Name of the source 397 Returns: 398 list[exp.Column]: Column instances that reference `source_name` 399 """ 400 return [column for column in self.columns if column.table == source_name] 401 402 @property 403 def is_subquery(self): 404 """Determine if this scope is a subquery""" 405 return self.scope_type == ScopeType.SUBQUERY 406 407 @property 408 def is_derived_table(self): 409 """Determine if this scope is a derived table""" 410 return self.scope_type == ScopeType.DERIVED_TABLE 411 412 @property 413 def is_union(self): 414 """Determine if this scope is a union""" 415 return self.scope_type == ScopeType.UNION 416 417 @property 418 def is_cte(self): 419 """Determine if this scope is a common table expression""" 420 return self.scope_type == ScopeType.CTE 421 422 @property 423 def is_root(self): 424 """Determine if this is the root scope""" 425 return self.scope_type == ScopeType.ROOT 426 427 @property 428 def is_udtf(self): 429 """Determine if this scope is a UDTF (User Defined Table Function)""" 430 return self.scope_type == ScopeType.UDTF 431 432 @property 433 def is_correlated_subquery(self): 434 """Determine if this scope is a correlated subquery""" 435 return bool(self.can_be_correlated and self.external_columns) 
436 437 def rename_source(self, old_name, new_name): 438 """Rename a source in this scope""" 439 columns = self.sources.pop(old_name or "", []) 440 self.sources[new_name] = columns 441 442 def add_source(self, name, source): 443 """Add a source to this scope""" 444 self.sources[name] = source 445 self.clear_cache() 446 447 def remove_source(self, name): 448 """Remove a source from this scope""" 449 self.sources.pop(name, None) 450 self.clear_cache() 451 452 def __repr__(self): 453 return f"Scope<{self.expression.sql()}>" 454 455 def traverse(self): 456 """ 457 Traverse the scope tree from this node. 458 459 Yields: 460 Scope: scope instances in depth-first-search post-order 461 """ 462 stack = [self] 463 result = [] 464 while stack: 465 scope = stack.pop() 466 result.append(scope) 467 stack.extend( 468 itertools.chain( 469 scope.cte_scopes, 470 scope.union_scopes, 471 scope.table_scopes, 472 scope.subquery_scopes, 473 ) 474 ) 475 476 yield from reversed(result) 477 478 def ref_count(self): 479 """ 480 Count the number of times each scope in this tree is referenced. 481 482 Returns: 483 dict[int, int]: Mapping of Scope instance ID to reference count 484 """ 485 scope_ref_count = defaultdict(lambda: 0) 486 487 for scope in self.traverse(): 488 for _, source in scope.selected_sources.values(): 489 scope_ref_count[id(source)] += 1 490 491 return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.SetOperation): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals For example: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; The LATERAL VIEW EXPLODE gets x as a source.
- cte_sources (dict[str, Scope]): Sources from CTEs
- outer_columns (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_columns`
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_columns=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
    cte_sources=None,
    can_be_correlated=None,
):
    """
    Set up a scope rooted at `expression`.

    Args:
        expression: Root expression of this scope.
        sources: Mapping of source name to source; defaults to an empty dict.
        outer_columns: Column list defined by the outer query for this scope's alias;
            defaults to an empty list.
        parent: Parent scope, if any.
        scope_type: Type of this scope, relative to its parent.
        lateral_sources: Sources from laterals; merged into `sources`.
        cte_sources: Sources from CTEs; merged into `sources` after the lateral ones.
        can_be_correlated: Stored as-is on the scope.
    """
    # Identity of the scope and its position in the scope tree.
    self.expression = expression
    self.parent = parent
    self.scope_type = scope_type
    self.can_be_correlated = can_be_correlated
    self.outer_columns = outer_columns or []

    # A non-empty `sources` dict is used (and updated) in place, not copied.
    self.sources = sources or {}
    self.lateral_sources = lateral_sources or {}
    self.cte_sources = cte_sources or {}
    for extra_sources in (self.lateral_sources, self.cte_sources):
        self.sources.update(extra_sources)

    # Child scopes start out empty.
    self.subquery_scopes = []
    self.derived_table_scopes = []
    self.table_scopes = []
    self.cte_scopes = []
    self.union_scopes = []
    self.udtf_scopes = []

    self.clear_cache()
def clear_cache(self):
    """Reset the collection flag and every lazily computed attribute of this scope."""
    self._collected = False
    for cached_attr in (
        "_raw_columns",
        "_stars",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_join_hints",
        "_pivots",
        "_references",
    ):
        setattr(self, cached_attr, None)
def branch(
    self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
):
    """
    Branch from the current scope to a new, inner scope.

    Args:
        expression: Expression of the inner scope; it is unnested before use.
        scope_type: Type of the inner scope.
        sources: Optional sources for the inner scope (copied).
        cte_sources: Optional CTE sources, layered on top of this scope's CTE sources.
        lateral_sources: Optional lateral sources for the inner scope (copied).
        **kwargs: Forwarded to the `Scope` constructor.

    Returns:
        Scope: the new scope, with this scope as its parent
    """
    inner_expression = expression.unnest()
    inner_sources = sources.copy() if sources else None
    visible_ctes = {**self.cte_sources, **(cte_sources or {})}
    inner_laterals = lateral_sources.copy() if lateral_sources else None

    # Correlation is inherited, and subqueries/UDTFs can always be correlated.
    correlatable = self.can_be_correlated or scope_type in (
        ScopeType.SUBQUERY,
        ScopeType.UDTF,
    )

    return Scope(
        expression=inner_expression,
        sources=inner_sources,
        parent=self,
        scope_type=scope_type,
        cte_sources=visible_ctes,
        lateral_sources=inner_laterals,
        can_be_correlated=correlatable,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def replace(self, old, new):
    """
    Swap `old` for `new` in the expression tree and invalidate this scope's caches.

    Use this rather than calling `exp.Expression.replace` directly so that the
    `Scope` does not keep serving stale cached data.

    Args:
        old (exp.Expression): node to be replaced
        new (exp.Expression): node to put in its place
    """
    old.replace(new)
    # Anything collected earlier may reference the node that was just replaced.
    self.clear_cache()
Replace `old` with `new`.
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
@property
def tables(self):
    """
    Tables collected for this scope.

    Collection is triggered on first access.

    Returns:
        list[exp.Table]: tables
    """
    self._ensure_collected()
    return self._tables
List of tables in this scope.
Returns:
list[exp.Table]: tables
@property
def ctes(self):
    """
    CTEs collected for this scope.

    Collection is triggered on first access.

    Returns:
        list[exp.CTE]: ctes
    """
    self._ensure_collected()
    return self._ctes
List of CTEs in this scope.
Returns:
list[exp.CTE]: ctes
@property
def derived_tables(self):
    """
    Derived tables collected for this scope.

    Example:
        SELECT * FROM (SELECT ...) <- that's a derived table

    Collection is triggered on first access.

    Returns:
        list[exp.Subquery]: derived tables
    """
    self._ensure_collected()
    return self._derived_tables
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
@property
def udtfs(self):
    """
    "User Defined Tabular Functions" collected for this scope.

    Collection is triggered on first access.

    Returns:
        list[exp.UDTF]: UDTFs
    """
    self._ensure_collected()
    return self._udtfs
List of "User Defined Tabular Functions" in this scope.
Returns:
list[exp.UDTF]: UDTFs
@property
def subqueries(self):
    """
    Subqueries collected for this scope.

    Example:
        SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

    Collection is triggered on first access.

    Returns:
        list[exp.Select | exp.SetOperation]: subqueries
    """
    self._ensure_collected()
    return self._subqueries
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Select | exp.SetOperation]: subqueries
@property
def stars(self) -> t.List[exp.Column | exp.Dot]:
    """
    Star expressions (columns or dots) collected for this scope.

    Collection is triggered on first access.
    """
    self._ensure_collected()
    return self._stars
List of star expressions (columns or dots) in this scope.
@property
def columns(self):
    """
    Columns that belong to this scope.

    The result is cached in `self._columns` after the first computation.

    Returns:
        list[exp.Column]: Column instances in this scope, plus any
            Columns that reference this scope from correlated subqueries.
    """
    if self._columns is not None:
        return self._columns

    self._ensure_collected()
    own_columns = self._raw_columns

    # Child scopes whose external columns may refer back to this scope.
    correlated_children = itertools.chain(
        self.subquery_scopes,
        self.udtf_scopes,
        (child for child in self.derived_table_scopes if child.can_be_correlated),
    )
    outer_references = [
        col for child in correlated_children for col in child.external_columns
    ]

    select_aliases = set(self.expression.named_selects)

    def _belongs_here(col):
        # Each test below mirrors one alternative of the acceptance condition,
        # evaluated in the same short-circuit order.
        ancestor = col.find_ancestor(
            exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star
        )
        if not ancestor or col.table or isinstance(ancestor, exp.Select):
            return True
        if isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func):
            return True
        if isinstance(ancestor, exp.Order) and (
            isinstance(ancestor.parent, exp.Window) or col.name not in select_aliases
        ):
            return True
        return isinstance(ancestor, exp.Star) and not col.arg_key == "except"

    self._columns = []
    for col in own_columns + outer_references:
        if _belongs_here(col):
            self._columns.append(col)

    return self._columns
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
@property
def selected_sources(self):
    """
    Nodes and sources that this scope actually selects from.

    Every table in a schema is selectable at any point, but a table only becomes
    a selected source once it is included in a FROM or JOIN clause.

    Returns:
        dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

    Raises:
        OptimizeError: if a name that was already selected is referenced again.
    """
    if self._selected_sources is not None:
        return self._selected_sources

    selected = {}
    for alias, node in self.references:
        if alias in selected:
            raise OptimizeError(f"Alias already used: {alias}")
        # References whose name is not a known source are ignored.
        if alias in self.sources:
            selected[alias] = (node, self.sources[alias])

    self._selected_sources = selected
    return self._selected_sources
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
@property
def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
    """
    (name, node) pairs for the tables, derived tables and UDTFs in this scope.

    Tables are listed first under their alias or name. Derived tables and UDTFs
    follow under their alias; the node is kept as-is when it carries pivots and is
    unnested otherwise. The list is cached after the first computation.
    """
    if self._references is not None:
        return self._references

    refs = self._references = []
    refs.extend((table.alias_or_name, table) for table in self.tables)

    for node in itertools.chain(self.derived_tables, self.udtfs):
        alias = node.alias
        target = node if node.args.get("pivots") else node.unnest()
        refs.append((alias, target))

    return self._references
@property
def external_columns(self):
    """
    Columns that appear to reference sources in outer scopes.

    For a set operation this is the concatenation of the left and right child
    scopes' external columns; otherwise it is every column whose table is not one
    of the selected sources. The result is cached.

    Returns:
        list[exp.Column]: Column instances that don't reference
            sources in the current scope.
    """
    if self._external_columns is not None:
        return self._external_columns

    if isinstance(self.expression, exp.SetOperation):
        lhs, rhs = self.union_scopes
        found = lhs.external_columns + rhs.external_columns
    else:
        found = [col for col in self.columns if col.table not in self.selected_sources]

    self._external_columns = found
    return self._external_columns
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
@property
def unqualified_columns(self):
    """
    Columns in the current scope that carry no table qualifier.

    Returns:
        list[exp.Column]: Unqualified columns
    """
    return [column for column in self.columns if not column.table]
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
@property
def join_hints(self):
    """
    Hints that exist in the scope that reference tables.

    Like the other collected properties, this triggers collection on first access,
    so hints are reported even if nothing else has been read from the scope yet.

    Returns:
        list[exp.JoinHint]: Join hints that are referenced within the scope
    """
    self._ensure_collected()
    # Defensive fallback: keep returning a list even if collection left this unset.
    if self._join_hints is None:
        return []
    return self._join_hints
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Select the columns of the current scope that belong to one source.

    Args:
        source_name (str): Name of the source
    Returns:
        list[exp.Column]: Column instances whose table equals `source_name`
    """
    return [col for col in self.columns if col.table == source_name]
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
@property
def is_subquery(self):
    """Whether this scope's type is `ScopeType.SUBQUERY`."""
    return self.scope_type == ScopeType.SUBQUERY
Determine if this scope is a subquery
@property
def is_derived_table(self):
    """Whether this scope's type is `ScopeType.DERIVED_TABLE`."""
    return self.scope_type == ScopeType.DERIVED_TABLE
Determine if this scope is a derived table
@property
def is_union(self):
    """Whether this scope's type is `ScopeType.UNION`."""
    return self.scope_type == ScopeType.UNION
Determine if this scope is a union
@property
def is_cte(self):
    """Whether this scope's type is `ScopeType.CTE` (common table expression)."""
    return self.scope_type == ScopeType.CTE
Determine if this scope is a common table expression
@property
def is_root(self):
    """Whether this scope's type is `ScopeType.ROOT`."""
    return self.scope_type == ScopeType.ROOT
Determine if this is the root scope
@property
def is_udtf(self):
    """Whether this scope's type is `ScopeType.UDTF` (User Defined Table Function)."""
    return self.scope_type == ScopeType.UDTF
Determine if this scope is a UDTF (User Defined Table Function)
def rename_source(self, old_name, new_name):
    """
    Rename a source in this scope.

    Args:
        old_name (str | None): Current name of the source; a falsy value is treated
            as the empty-string name.
        new_name (str): Name to register the source under.

    If no source is registered under `old_name`, the scope is left untouched
    rather than registering an empty placeholder under `new_name`.
    """
    old_name = old_name or ""
    if old_name in self.sources:
        self.sources[new_name] = self.sources.pop(old_name)
Rename a source in this scope
def add_source(self, name, source):
    """
    Register `source` under `name` in this scope, replacing any existing entry.

    The scope's caches are cleared afterwards.
    """
    self.sources[name] = source
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """
    Drop the source registered under `name`, if there is one.

    The scope's caches are cleared afterwards, whether or not a source was removed.
    """
    self.sources.pop(name, None)
    self.clear_cache()
Remove a source from this scope
def traverse(self):
    """
    Walk the scope tree rooted at this scope.

    Scopes are gathered with an explicit stack and then emitted in reverse, so
    every child scope is yielded before its parent and this scope comes last.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    pending = [self]
    visited = []

    while pending:
        current = pending.pop()
        visited.append(current)

        children = [
            *current.cte_scopes,
            *current.union_scopes,
            *current.table_scopes,
            *current.subquery_scopes,
        ]
        pending.extend(children)

    yield from visited[::-1]
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Tally how often each source is selected across the whole scope tree.

    Sources are keyed by `id()`, and every selected source of every traversed
    scope contributes one reference.

    Returns:
        dict[int, int]: Mapping of Scope instance ID to reference count
    """
    counts = defaultdict(int)

    for scope in self.traverse():
        for selected in scope.selected_sources.values():
            _, source = selected
            counts[id(source)] += 1

    return counts
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    A "Scope" represents the current context of a Select statement. Scopes are
    useful when optimizing queries, because they carry more information than the
    expression tree alone, such as the source names available inside a subquery.

    The scopes are returned as a list rather than a generator, because a partially
    consumed generator could expose scopes with incomplete properties.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances; empty if `expression` is not one of
        the traversable expression types.
    """
    if not isinstance(expression, TRAVERSABLES):
        return []

    root = Scope(expression)
    return list(_traverse_scope(root))
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") >>> scopes = traverse_scope(expression) >>> scopes[0].expression.sql(), list(scopes[0].sources) ('SELECT a FROM x', ['x']) >>> scopes[1].expression.sql(), list(scopes[1].sources) ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression: Expression to traverse
Returns:
A list of the created scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope, taken as the last scope produced by `traverse_scope`.
    """
    scopes = traverse_scope(expression)
    return seq_get(scopes, -1)
Build a scope tree.
Arguments:
- expression: Expression to build the scope tree for.
Returns:
The root scope
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression): expression to walk.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree. It is also
            applied when walking the joins, laterals and pivots of a node that starts
            a child scope.

    Yields:
        exp.Expression: each visited node
    """
    # We'll use this variable to pass state into the underlying walk.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        crossed_scope_boundary = False

        yield node

        # The root itself never counts as a boundary; only its descendants do.
        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
                and (_is_derived_table(node) or isinstance(node, exp.UDTF))
            )
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        # Forward `prune` so the caller's filter applies here as well.
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression):
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
- prune ((node) -> bool): callable that returns True if the generator should stop traversing this branch of the tree.
Yields:
exp.Expression: each visited node
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): expression to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # The accepted types do not change while walking, so normalize them once
    # instead of once per visited node.
    accepted_types = tuple(ensure_collection(expression_types))

    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, accepted_types):
            yield node
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression):
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def find_in_scope(expression, expression_types, bfs=True):
    """
    Look up the first node in this scope that matches at least one of the given types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): expression to search.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match
    return None
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression):
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.