sqlglot.optimizer.scope
import itertools
import logging
import typing as t
from collections import defaultdict
from enum import Enum, auto

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import find_new_name

logger = logging.getLogger("sqlglot")


class ScopeType(Enum):
    """Kind of a scope, relative to its parent scope."""

    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()


class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_column_list`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_column_list=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
    ):
        self.expression = expression
        # NOTE: a non-empty `sources` dict is stored as-is (not copied) and is
        # then updated in place with the lateral sources below.
        self.sources = sources or {}
        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
        self.sources.update(self.lateral_sources)
        self.outer_column_list = outer_column_list or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Reset every lazily computed attribute so it is rebuilt on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None
        self._references = None

    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            # Only CTE sources of this scope (plus any explicitly chained ones)
            # are handed down to the inner scope.
            sources={**self.cte_sources, **(chain_sources or {})},
            parent=self,
            scope_type=scope_type,
            **kwargs,
        )

    def _collect(self):
        """
        Walk this scope once (depth-first) and sort its nodes into the per-kind lists
        that back `tables`, `ctes`, `subqueries`, `derived_tables`, `udtfs`,
        `join_hints` and the raw column list used by `columns`.
        """
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        # The order of the branches matters: a node lands in the first list it matches.
        for node, parent, _ in self.walk(bfs=False):
            if node is self.expression:
                continue
            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif (
                isinstance(node, exp.Subquery)
                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
                and _is_derived_table(node)
            ):
                self._derived_tables.append(node)
            elif isinstance(node, exp.Subqueryable):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        """Run `_collect` unless it already ran since the last `clear_cache`."""
        if not self._collected:
            self._collect()

    def walk(self, bfs=True):
        """Walk the nodes of this scope via `walk_in_scope`, without entering child scopes."""
        return walk_in_scope(self.expression, bfs=bfs)

    def find(self, *expression_types, bfs=True):
        """
        Returns the first node in this scope which matches at least one of the specified types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Returns:
            exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
        """
        return next(self.find_all(*expression_types, bfs=bfs), None)

    def find_all(self, *expression_types, bfs=True):
        """
        Returns a generator object which visits all nodes in this scope and only yields those that
        match at least one of the specified expression types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Yields:
            exp.Expression: nodes
        """
        for expression, *_ in self.walk(bfs=bfs):
            if isinstance(expression, expression_types):
                yield expression

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Subqueryable]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table
                )
                # An unqualified column whose nearest such ancestor is QUALIFY, HAVING or a
                # hint is left out, as is one in an ORDER BY (outside a window) whose name
                # matches a named select -- presumably because it may refer to a select
                # alias rather than a source column.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                    or (
                        isinstance(ancestor, exp.Order)
                        and (
                            isinstance(ancestor.parent, exp.Window)
                            or column.name not in named_selects
                        )
                    )
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if the same name is referenced more than once in this scope.
        """
        if self._selected_sources is None:
            result = {}

            for name, node in self.references:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
        """
        Name/node pairs for everything this scope selects from.

        Tables come first (keyed by alias or name), followed by derived tables and
        UDTFs (keyed by alias). A derived table or UDTF is unnested unless it carries
        pivots, in which case the wrapping node itself is kept.
        """
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                self._references.append(
                    (
                        expression.alias,
                        expression if expression.args.get("pivots") else expression.unnest(),
                    )
                )

        return self._references

    @property
    def cte_sources(self):
        """
        Sources that are CTEs.

        Returns:
            dict[str, Scope]: Mapping of source alias to Scope
        """
        return {
            alias: scope
            for alias, scope in self.sources.items()
            if isinstance(scope, Scope) and scope.is_cte
        }

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            self._external_columns = [
                c for c in self.columns if c.table not in self.selected_sources
            ]
        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
             list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect on demand like the other node lists; previously this returned an
        # empty list whenever no other collected property had been read first.
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """
        Pivots attached to any node in `references`.

        Returns:
            list: the entries of each referenced node's "pivots" arg, flattened
        """
        # Compare against None so that an empty result is cached as well.
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(
            (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
            and self.external_columns
        )

    def rename_source(self, old_name, new_name):
        """
        Rename a source in this scope

        The entry registered under `old_name` (or under "" when `old_name` is falsy) is
        moved to `new_name`. If no such entry exists, `new_name` is mapped to an empty list.
        """
        # NOTE(review): unlike add_source/remove_source this does not call clear_cache(),
        # so a previously computed `selected_sources` keeps the old name -- confirm that
        # this is intentional.
        columns = self.sources.pop(old_name or "", [])
        self.sources[new_name] = columns

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        for child_scope in itertools.chain(
            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
        ):
            yield from child_scope.traverse()
        yield self

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count


def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances; empty if `expression` is not an `exp.Unionable`
    """
    if not isinstance(expression, exp.Unionable):
        return []
    return list(_traverse_scope(Scope(expression)))


def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope, or None if `traverse_scope` produced no scopes
    """
    scopes = traverse_scope(expression)
    if scopes:
        # The root scope is yielded last.
        return scopes[-1]
    return None


def _traverse_scope(scope):
    """
    Yield every scope nested in `scope`, then `scope` itself.

    Nothing is yielded (and a warning is logged) if the scope's expression is of a
    type that none of the traversal helpers handle.
    """
    if isinstance(scope.expression, exp.Select):
        yield from _traverse_select(scope)
    elif isinstance(scope.expression, exp.Union):
        yield from _traverse_union(scope)
    elif isinstance(scope.expression, exp.Subquery):
        yield from _traverse_subqueries(scope)
    elif isinstance(scope.expression, exp.Table):
        yield from _traverse_tables(scope)
    elif isinstance(scope.expression, exp.UDTF):
        yield from _traverse_udtfs(scope)
    else:
        logger.warning(
            "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
        )
        return

    yield scope


def _traverse_select(scope):
    """Yield the child scopes of a SELECT: CTEs first, then tables, then subqueries."""
    yield from _traverse_ctes(scope)
    yield from _traverse_tables(scope)
    yield from _traverse_subqueries(scope)


def _traverse_union(scope):
    """Yield the child scopes of a UNION and record its left/right scopes on `scope`."""
    yield from _traverse_ctes(scope)

    # The last scope to be yielded should be the top most scope
    left = None
    for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)):
        yield left

    right = None
    for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
        yield right

    scope.union_scopes = [left, right]


def _traverse_ctes(scope):
    """
    Yield the scopes of every CTE in `scope`, registering each CTE as a source.

    CTEs defined earlier are chained into the scopes of later ones, and all of them
    are added to `scope.sources` at the end.
    """
    sources = {}

    for cte in scope.ctes:
        recursive_scope = None

        # if the scope is a recursive cte, it must be in the form of
        # base_case UNION recursive. thus the recursive scope is the first
        # section of the union.
        if scope.expression.args["with"].recursive:
            union = cte.this

            if isinstance(union, exp.Union):
                recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)

        child_scope = None

        for child_scope in _traverse_scope(
            scope.branch(
                cte.this,
                chain_sources=sources,
                outer_column_list=cte.alias_column_names,
                scope_type=ScopeType.CTE,
            )
        ):
            yield child_scope

            alias = cte.alias
            sources[alias] = child_scope

            if recursive_scope:
                child_scope.add_source(alias, recursive_scope)

        # append the final child_scope yielded
        if child_scope:
            scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)


def _is_derived_table(expression: exp.Subquery) -> bool:
    """
    We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table",
    as it doesn't introduce a new scope. If an alias is present, it shadows all names
    under the Subquery, so that's one exception to this rule.
    """
    return bool(expression.alias or isinstance(expression.this, exp.Subqueryable))


def _traverse_tables(scope):
    """
    Yield the scopes of the derived tables and UDTFs that `scope` selects from, and
    register every table, derived table and UDTF it selects from in `scope.sources`.
    """
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    # `expressions` doubles as a work queue: it is appended to while being iterated
    # so that joins hanging off tables and nested table constructs get visited too.
    for expression in expressions:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression

            # Make sure to not include the joins twice
            if expression is not scope.expression:
                expressions.extend(join.this for join in expression.args.get("joins") or [])

            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        if isinstance(expression, exp.UDTF):
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        elif _is_derived_table(expression):
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes
            expressions.extend(join.this for join in expression.args.get("joins") or [])
        else:
            # Makes sure we check for possible sources in nested table constructs
            expressions.append(expression.this)
            expressions.extend(join.this for join in expression.args.get("joins") or [])
            continue

        # Reset per expression: if the traversal below yields nothing, we must neither
        # hit an unbound name nor re-register the scope of a previous expression.
        child_scope = None

        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_column_list=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins).
            sources[expression.alias] = child_scope

        # append the final child_scope yielded
        if child_scope is not None:
            scopes.append(child_scope)
            scope.table_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_subqueries(scope):
    """Yield the scopes of every subquery in `scope`, recording each top-most one."""
    for subquery in scope.subqueries:
        top = None
        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
            yield child_scope
            top = child_scope

        # Only register a scope that was actually produced (same guard as for CTEs).
        if top is not None:
            scope.subquery_scopes.append(top)


def _traverse_udtfs(scope):
    """
    Yield the scopes of derived tables nested directly in an UNNEST or LATERAL, and
    register them as sources of `scope`.
    """
    if isinstance(scope.expression, exp.Unnest):
        expressions = scope.expression.expressions
    elif isinstance(scope.expression, exp.Lateral):
        expressions = [scope.expression.this]
    else:
        expressions = []

    sources = {}
    for expression in expressions:
        if isinstance(expression, exp.Subquery) and _is_derived_table(expression):
            top = None
            for child_scope in _traverse_scope(
                scope.branch(
                    expression,
                    scope_type=ScopeType.DERIVED_TABLE,
                    outer_column_list=expression.alias_column_names,
                )
            ):
                yield child_scope
                top = child_scope
                sources[expression.alias] = child_scope

            # Only register a scope that was actually produced (same guard as for CTEs).
            if top is not None:
                scope.derived_table_scopes.append(top)
                scope.table_scopes.append(top)

    scope.sources.update(sources)


def walk_in_scope(expression, bfs=True):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    prune = False

    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
        prune = False

        yield node, parent, key

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node, exp.Subquery)
                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
                and _is_derived_table(node)
            )
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.Subqueryable)
        ):
            prune = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                for arg_key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(arg_key) or []:
                        yield from walk_in_scope(arg, bfs=bfs)
class ScopeType(Enum):
    """Kind of a scope, relative to its parent scope."""

    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()
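An enumeration of the kinds of scope (ROOT, SUBQUERY, DERIVED_TABLE, CTE, UNION, UDTF), describing a scope relative to its parent.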
Inherited Members
- enum.Enum
- name
- value
class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_column_list`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_column_list=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
    ):
        self.expression = expression
        # NOTE: a non-empty `sources` dict is stored as-is (not copied) and is
        # then updated in place with the lateral sources below.
        self.sources = sources or {}
        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
        self.sources.update(self.lateral_sources)
        self.outer_column_list = outer_column_list or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Reset every lazily computed attribute so it is rebuilt on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None
        self._references = None

    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            # Only CTE sources of this scope (plus any explicitly chained ones)
            # are handed down to the inner scope.
            sources={**self.cte_sources, **(chain_sources or {})},
            parent=self,
            scope_type=scope_type,
            **kwargs,
        )

    def _collect(self):
        """
        Walk this scope once (depth-first) and sort its nodes into the per-kind lists
        that back `tables`, `ctes`, `subqueries`, `derived_tables`, `udtfs`,
        `join_hints` and the raw column list used by `columns`.
        """
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        # The order of the branches matters: a node lands in the first list it matches.
        for node, parent, _ in self.walk(bfs=False):
            if node is self.expression:
                continue
            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif (
                isinstance(node, exp.Subquery)
                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
                and _is_derived_table(node)
            ):
                self._derived_tables.append(node)
            elif isinstance(node, exp.Subqueryable):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        """Run `_collect` unless it already ran since the last `clear_cache`."""
        if not self._collected:
            self._collect()

    def walk(self, bfs=True):
        """Walk the nodes of this scope via `walk_in_scope`, without entering child scopes."""
        return walk_in_scope(self.expression, bfs=bfs)

    def find(self, *expression_types, bfs=True):
        """
        Returns the first node in this scope which matches at least one of the specified types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Returns:
            exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
        """
        return next(self.find_all(*expression_types, bfs=bfs), None)

    def find_all(self, *expression_types, bfs=True):
        """
        Returns a generator object which visits all nodes in this scope and only yields those that
        match at least one of the specified expression types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Yields:
            exp.Expression: nodes
        """
        for expression, *_ in self.walk(bfs=bfs):
            if isinstance(expression, expression_types):
                yield expression

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Subqueryable]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table
                )
                # An unqualified column whose nearest such ancestor is QUALIFY, HAVING or a
                # hint is left out, as is one in an ORDER BY (outside a window) whose name
                # matches a named select -- presumably because it may refer to a select
                # alias rather than a source column.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                    or (
                        isinstance(ancestor, exp.Order)
                        and (
                            isinstance(ancestor.parent, exp.Window)
                            or column.name not in named_selects
                        )
                    )
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if the same name is referenced more than once in this scope.
        """
        if self._selected_sources is None:
            result = {}

            for name, node in self.references:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
        """
        Name/node pairs for everything this scope selects from.

        Tables come first (keyed by alias or name), followed by derived tables and
        UDTFs (keyed by alias). A derived table or UDTF is unnested unless it carries
        pivots, in which case the wrapping node itself is kept.
        """
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                self._references.append(
                    (
                        expression.alias,
                        expression if expression.args.get("pivots") else expression.unnest(),
                    )
                )

        return self._references

    @property
    def cte_sources(self):
        """
        Sources that are CTEs.

        Returns:
            dict[str, Scope]: Mapping of source alias to Scope
        """
        return {
            alias: scope
            for alias, scope in self.sources.items()
            if isinstance(scope, Scope) and scope.is_cte
        }

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            self._external_columns = [
                c for c in self.columns if c.table not in self.selected_sources
            ]
        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
             list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect on demand like the other node lists; previously this returned an
        # empty list whenever no other collected property had been read first.
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """
        Pivots attached to any node in `references`.

        Returns:
            list: the entries of each referenced node's "pivots" arg, flattened
        """
        # Compare against None so that an empty result is cached as well.
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(
            (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
            and self.external_columns
        )

    def rename_source(self, old_name, new_name):
        """
        Rename a source in this scope

        The entry registered under `old_name` (or under "" when `old_name` is falsy) is
        moved to `new_name`. If no such entry exists, `new_name` is mapped to an empty list.
        """
        # NOTE(review): unlike add_source/remove_source this does not call clear_cache(),
        # so a previously computed `selected_sources` keeps the old name -- confirm that
        # this is intentional.
        columns = self.sources.pop(old_name or "", [])
        self.sources[new_name] = columns

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        for child_scope in itertools.chain(
            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
        ):
            yield from child_scope.traverse()
        yield self

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.Union): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals For example: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; The LATERAL VIEW EXPLODE gets x as a source.
- outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_column_list`
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_column_list=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
):
    """
    Create a scope around `expression`.

    Args:
        expression: Root expression of this scope.
        sources: Mapping of source name to source. When a non-empty mapping is
            passed it is stored as-is (not copied) and is updated in place with
            `lateral_sources`.
        outer_column_list: Column list defined by the outer query for this
            scope's alias; defaults to an empty list.
        parent: Parent scope, if any.
        scope_type: Type of this scope, relative to its parent.
        lateral_sources: Sources contributed by laterals. A copy is stored, so
            later changes to the caller's mapping do not affect this scope.
    """
    self.expression = expression
    self.parent = parent
    self.scope_type = scope_type
    self.outer_column_list = outer_column_list if outer_column_list else []

    # Lateral sources are kept separately (as a private copy) and are also
    # registered as regular sources.
    self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
    self.sources = sources if sources else {}
    self.sources.update(self.lateral_sources)

    # Child scopes start out empty; nothing in this constructor populates them.
    self.cte_scopes = []
    self.union_scopes = []
    self.derived_table_scopes = []
    self.udtf_scopes = []
    self.table_scopes = []
    self.subquery_scopes = []

    # Initializes all memoized attributes to their "not computed" state.
    self.clear_cache()
79 def clear_cache(self): 80 self._collected = False 81 self._raw_columns = None 82 self._derived_tables = None 83 self._udtfs = None 84 self._tables = None 85 self._ctes = None 86 self._subqueries = None 87 self._selected_sources = None 88 self._columns = None 89 self._external_columns = None 90 self._join_hints = None 91 self._pivots = None 92 self._references = None
94 def branch(self, expression, scope_type, chain_sources=None, **kwargs): 95 """Branch from the current scope to a new, inner scope""" 96 return Scope( 97 expression=expression.unnest(), 98 sources={**self.cte_sources, **(chain_sources or {})}, 99 parent=self, 100 scope_type=scope_type, 101 **kwargs, 102 )
Branch from the current scope to a new, inner scope
144 def find(self, *expression_types, bfs=True): 145 """ 146 Returns the first node in this scope which matches at least one of the specified types. 147 148 This does NOT traverse into subscopes. 149 150 Args: 151 expression_types (type): the expression type(s) to match. 152 bfs (bool): True to use breadth-first search, False to use depth-first. 153 154 Returns: 155 exp.Expression: the node which matches the criteria or None if no node matching 156 the criteria was found. 157 """ 158 return next(self.find_all(*expression_types, bfs=bfs), None)
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression_types (type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.
160 def find_all(self, *expression_types, bfs=True): 161 """ 162 Returns a generator object which visits all nodes in this scope and only yields those that 163 match at least one of the specified expression types. 164 165 This does NOT traverse into subscopes. 166 167 Args: 168 expression_types (type): the expression type(s) to match. 169 bfs (bool): True to use breadth-first search, False to use depth-first. 170 171 Yields: 172 exp.Expression: nodes 173 """ 174 for expression, *_ in self.walk(bfs=bfs): 175 if isinstance(expression, expression_types): 176 yield expression
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression_types (type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
178 def replace(self, old, new): 179 """ 180 Replace `old` with `new`. 181 182 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 183 184 Args: 185 old (exp.Expression): old node 186 new (exp.Expression): new node 187 """ 188 old.replace(new) 189 self.clear_cache()
Replace `old` with `new`.
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Subqueryable]: subqueries
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
395 def source_columns(self, source_name): 396 """ 397 Get all columns in the current scope for a particular source. 398 399 Args: 400 source_name (str): Name of the source 401 Returns: 402 list[exp.Column]: Column instances that reference `source_name` 403 """ 404 return [column for column in self.columns if column.table == source_name]
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
444 def rename_source(self, old_name, new_name): 445 """Rename a source in this scope""" 446 columns = self.sources.pop(old_name or "", []) 447 self.sources[new_name] = columns
Rename a source in this scope
449 def add_source(self, name, source): 450 """Add a source to this scope""" 451 self.sources[name] = source 452 self.clear_cache()
Add a source to this scope
454 def remove_source(self, name): 455 """Remove a source from this scope""" 456 self.sources.pop(name, None) 457 self.clear_cache()
Remove a source from this scope
462 def traverse(self): 463 """ 464 Traverse the scope tree from this node. 465 466 Yields: 467 Scope: scope instances in depth-first-search post-order 468 """ 469 for child_scope in itertools.chain( 470 self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes 471 ): 472 yield from child_scope.traverse() 473 yield self
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
475 def ref_count(self): 476 """ 477 Count the number of times each scope in this tree is referenced. 478 479 Returns: 480 dict[int, int]: Mapping of Scope instance ID to reference count 481 """ 482 scope_ref_count = defaultdict(lambda: 0) 483 484 for scope in self.traverse(): 485 for _, source in scope.selected_sources.values(): 486 scope_ref_count[id(source)] += 1 487 488 return scope_ref_count
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Collect the scopes of an expression.

    A "Scope" captures the context of a Select statement, which gives
    optimizations more information than the expression tree alone (for
    example, the source names within a subquery).

    The result is materialized as a list rather than exposed as a generator,
    because a partially consumed generator could leave scopes with incomplete
    properties, which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances; empty when `expression` is not an
            `exp.Unionable`
    """
    if isinstance(expression, exp.Unionable):
        root_scope = Scope(expression)
        return list(_traverse_scope(root_scope))

    # Anything that is not unionable has no scopes to report.
    return []
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression (exp.Expression): expression to traverse
Returns:
list[Scope]: scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build the scope tree for an expression and return its root.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope (the last scope produced by `traverse_scope`), or
            None when `traverse_scope` produces no scopes
    """
    scopes = traverse_scope(expression)
    return scopes[-1] if scopes else None
Build a scope tree.
Arguments:
- expression (exp.Expression): expression to build the scope tree for
Returns:
Scope: root scope
def walk_in_scope(expression, bfs=True):
    """
    Walk the syntax tree rooted at `expression` without descending into child scopes.

    Nodes that start a child scope are still yielded themselves, but their
    subtrees are excluded from the walk.

    Args:
        expression (exp.Expression): root of the walk
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # State shared with the prune callback given to `expression.walk`. The
    # callback reads the flag when it is invoked, so setting it after a node has
    # been yielded excludes that node's subtree from the traversal.
    skip_subtree = False

    def _prune(*_):
        return skip_subtree

    for node, parent, key in expression.walk(bfs=bfs, prune=_prune):
        skip_subtree = False

        yield node, parent, key

        # The root of the walk is never treated as a child-scope boundary.
        if node is expression:
            continue

        starts_child_scope = (
            isinstance(node, exp.CTE)
            or (
                isinstance(node, exp.Subquery)
                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
                and _is_derived_table(node)
            )
            or isinstance(node, (exp.UDTF, exp.Subqueryable))
        )
        if not starts_child_scope:
            continue

        skip_subtree = True

        if isinstance(node, (exp.Subquery, exp.UDTF)):
            # The following args are not actually in the inner scope, so we should visit them
            for arg_name in ("joins", "laterals", "pivots"):
                for arg in node.args.get(arg_name) or []:
                    yield from walk_in_scope(arg, bfs=bfs)
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression): the root expression to walk
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
Yields:
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key