Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import seq_get, SingleValuedMapping
 10from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
 11from sqlglot.optimizer.simplify import simplify_parens
 12from sqlglot.schema import Schema, ensure_schema
 13
 14if t.TYPE_CHECKING:
 15    from sqlglot._typing import E
 16
 17
 18def qualify_columns(
 19    expression: exp.Expression,
 20    schema: t.Dict | Schema,
 21    expand_alias_refs: bool = True,
 22    expand_stars: bool = True,
 23    infer_schema: t.Optional[bool] = None,
 24) -> exp.Expression:
 25    """
 26    Rewrite sqlglot AST to have fully qualified columns.
 27
 28    Example:
 29        >>> import sqlglot
 30        >>> schema = {"tbl": {"col": "INT"}}
 31        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 32        >>> qualify_columns(expression, schema).sql()
 33        'SELECT tbl.col AS col FROM tbl'
 34
 35    Args:
 36        expression: Expression to qualify.
 37        schema: Database schema.
 38        expand_alias_refs: Whether to expand references to aliases.
 39        expand_stars: Whether to expand star queries. This is a necessary step
 40            for most of the optimizer's rules to work; do not set to False unless you
 41            know what you're doing!
 42        infer_schema: Whether to infer the schema if missing.
 43
 44    Returns:
 45        The qualified expression.
 46
 47    Notes:
 48        - Currently only handles a single PIVOT or UNPIVOT operator
 49    """
 50    schema = ensure_schema(schema)
 51    infer_schema = schema.empty if infer_schema is None else infer_schema
 52    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
 53
 54    for scope in traverse_scope(expression):
 55        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 56        _pop_table_column_aliases(scope.ctes)
 57        _pop_table_column_aliases(scope.derived_tables)
 58        using_column_tables = _expand_using(scope, resolver)
 59
 60        if schema.empty and expand_alias_refs:
 61            _expand_alias_refs(scope, resolver)
 62
 63        _qualify_columns(scope, resolver)
 64
 65        if not schema.empty and expand_alias_refs:
 66            _expand_alias_refs(scope, resolver)
 67
 68        if not isinstance(scope.expression, exp.UDTF):
 69            if expand_stars:
 70                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
 71            qualify_outputs(scope)
 72
 73        _expand_group_by(scope)
 74        _expand_order_by(scope, resolver)
 75
 76    return expression
 77
 78
 79def validate_qualify_columns(expression: E) -> E:
 80    """Raise an `OptimizeError` if any columns aren't qualified"""
 81    all_unqualified_columns = []
 82    for scope in traverse_scope(expression):
 83        if isinstance(scope.expression, exp.Select):
 84            unqualified_columns = scope.unqualified_columns
 85
 86            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 87                column = scope.external_columns[0]
 88                for_table = f" for table: '{column.table}'" if column.table else ""
 89                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 90
 91            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 92                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 93                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 94                # this list here to ensure those in the former category will be excluded.
 95                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 96                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 97
 98            all_unqualified_columns.extend(unqualified_columns)
 99
100    if all_unqualified_columns:
101        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
102
103    return expression
104
105
def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
    """
    Return an iterator over the columns an UNPIVOT produces.

    The first item (if any) is the `this` column of the UNPIVOT's `IN` field, i.e. the
    name column. It is followed by every column found under the UNPIVOT's expressions,
    i.e. the value columns, which are searched lazily as the iterator is consumed.
    """
    field = unpivot.args.get("field")
    has_name_column = isinstance(field, exp.In) and isinstance(field.this, exp.Column)
    name_columns = [field.this] if has_name_column else []

    value_columns = (
        column
        for value_expression in unpivot.expressions
        for column in value_expression.find_all(exp.Column)
    )
    return itertools.chain(name_columns, value_columns)
114
115
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Strip the column list from each derived table's alias, in place.

    For example, `col1` and `col2` are dropped from
    SELECT ... FROM (SELECT ...) AS foo(col1, col2)

    Derived tables without an alias are skipped.
    """
    table_aliases = (derived.args.get("alias") for derived in derived_tables)
    for table_alias in filter(None, table_aliases):
        table_alias.args.pop("columns", None)
126
127
128def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
       # Rewrite every `JOIN ... USING (...)` in this scope into `JOIN ... ON` with one
       # equality per USING column, then replace unqualified references to those columns
       # with COALESCE over all sources joined on them.
       #
       # Returns a mapping from each USING column name to a dict whose keys are the names of
       # the sources joined on that column, in insertion order (the values are always None).
       # Raises OptimizeError when a USING column is missing on either side while the
       # left-hand columns are known (and contain no `*`) and the joined table's columns are known.
129    joins = list(scope.find_all(exp.Join))
130    names = {join.alias_or_name for join in joins}
       # Sources that are not the target of any JOIN, in selection order. Each processed
       # USING join is appended, so it is visible to the USING joins that follow it.
131    ordered = [key for key in scope.selected_sources if key not in names]
132
133    # Mapping of automatically joined column names to an ordered set of source names (dict).
134    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
135
136    for join in joins:
137        using = join.args.get("using")
138
139        if not using:
140            continue
141
142        join_table = join.alias_or_name
143
144        columns = {}
145
           # Map each column name to the first source (in selection order) among `ordered`
           # that provides it; these are the candidate left-hand sides of the join.
146        for source_name in scope.selected_sources:
147            if source_name in ordered:
148                for column_name in resolver.get_source_columns(source_name):
149                    if column_name not in columns:
150                        columns[column_name] = source_name
151
           # Fallback left-hand side: the most recently added source.
           # NOTE(review): raises IndexError if `ordered` is empty - confirm that cannot happen.
152        source_table = ordered[-1]
153        ordered.append(join_table)
154        join_columns = resolver.get_source_columns(join_table)
155        conditions = []
156
157        for identifier in using:
158            identifier = identifier.name
159            table = columns.get(identifier)
160
161            if not table or identifier not in join_columns:
162                if (columns and "*" not in columns) and join_columns:
163                    raise OptimizeError(f"Cannot automatically join: {identifier}")
164
165            table = table or source_table
166            conditions.append(
167                exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table))
168            )
169
170            # Set all values in the dict to None, because we only care about the key ordering
171            tables = column_tables.setdefault(identifier, {})
172            if table not in tables:
173                tables[table] = None
174            if join_table not in tables:
175                tables[join_table] = None
176
           # Replace USING with the AND of the equality conditions built above.
177        join.args.pop("using")
178        join.set("on", exp.and_(*conditions, copy=False))
179
180    if column_tables:
           # Unqualified references to a USING column become COALESCE(t1.col, t2.col, ...).
181        for column in scope.columns:
182            if not column.table and column.name in column_tables:
183                tables = column_tables[column.name]
184                coalesce = [exp.column(column.name, table=table) for table in tables]
185                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
186
187                # Ensure selects keep their output name
188                if isinstance(column.parent, exp.Select):
189                    replacement = alias(replacement, alias=column.name, copy=False)
190
191                scope.replace(column, replacement)
192
193    return column_tables
194
195
196def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
       # Inline references to projection aliases in a SELECT scope, in place:
       #   - a projection may reference the aliases of earlier projections;
       #   - WHERE, GROUP BY, HAVING and QUALIFY may reference any projection alias.
       # Scopes whose expression is not a SELECT are left untouched.
197    expression = scope.expression
198
199    if not isinstance(expression, exp.Select):
200        return
201
       # alias name -> (aliased expression, 1-based position of its projection)
202    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
203
204    def replace_columns(
205        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
206    ) -> None:
           # Rewrite the columns under `node`; star nodes are pruned from the walk.
           # resolve_table: also qualify unqualified columns with their source table.
           # literal_index: replace references to literal-valued aliases by their position.
207        if not node:
208            return
209
210        for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star):
211            if not isinstance(column, exp.Column):
212                continue
213
214            table = resolver.get_table(column.name) if resolve_table and not column.table else None
215            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
               # Referencing an aggregate alias from inside another aggregate would nest
               # aggregate functions, so such references are not expanded.
216            double_agg = (
217                (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
218                if alias_expr
219                else False
220            )
221
               # With resolve_table, qualify the column with its source when no alias of that
               # name exists, or when expanding the alias would nest aggregates.
222            if table and (not alias_expr or double_agg):
223                column.set("table", table)
224            elif not column.table and alias_expr and not double_agg:
                   # Literal aliases become their position with literal_index (GROUP BY) and are
                   # left as-is with resolve_table (HAVING / QUALIFY); any other alias reference
                   # is inlined in parentheses, which simplify_parens may then remove.
225                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
226                    if literal_index:
227                        column.replace(exp.Literal.number(i))
228                else:
229                    column = column.replace(exp.paren(alias_expr))
230                    simplified = simplify_parens(column)
231                    if simplified is not column:
232                        column.replace(simplified)
233
       # Each projection is expanded before its own alias is registered, so it can only see
       # the aliases of earlier projections.
234    for i, projection in enumerate(scope.expression.selects):
235        replace_columns(projection)
236
237        if isinstance(projection, exp.Alias):
238            alias_to_expression[projection.alias] = (projection.this, i + 1)
239
240    replace_columns(expression.args.get("where"))
241    replace_columns(expression.args.get("group"), literal_index=True)
242    replace_columns(expression.args.get("having"), resolve_table=True)
243    replace_columns(expression.args.get("qualify"), resolve_table=True)
244
       # The tree was rewritten above; presumably this drops the scope's cached column lists.
245    scope.clear_cache()
246
247
def _expand_group_by(scope: Scope) -> None:
    """
    Expand positional references in the scope's GROUP BY clause, in place.

    Each integer literal (e.g. the `1` in `GROUP BY 1`) is replaced by a copy of the
    expression of the projection it points to, unless that expression is itself a
    literal, in which case the positional reference is kept. Scopes without a GROUP BY
    are left untouched.
    """
    select = scope.expression
    group_clause = select.args.get("group")

    if group_clause:
        group_clause.set(
            "expressions",
            _expand_positional_references(scope, group_clause.expressions),
        )
        select.set("group", group_clause)
256
257
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Normalize the ORDER BY clause of the scope's expression, in place.

    - Positional references (e.g. `ORDER BY 1`) become a column named after the alias of
      the referenced projection.
    - Unqualified columns inside aggregate functions are qualified via `resolver`.
    - If the query also has a GROUP BY, ordering expressions equal to a projection's
      expression are replaced by a column named after that projection's output name.

    Scopes without an ORDER BY are left untouched.
    """
    select = scope.expression
    order_clause = select.args.get("order")
    if not order_clause:
        return

    ordered_terms = order_clause.expressions
    replacements = _expand_positional_references(
        scope, (term.this for term in ordered_terms), alias=True
    )

    for term, replacement in zip(ordered_terms, replacements):
        for aggregate in term.find_all(exp.AggFunc):
            for column in aggregate.find_all(exp.Column):
                if not column.table:
                    column.set("table", resolver.get_table(column.name))

        term.set("this", replacement)

    if not select.args.get("group"):
        return

    output_by_expression = {
        projection.this: exp.column(projection.alias_or_name) for projection in select.selects
    }

    for term in ordered_terms:
        target = term.this
        if target.is_int:
            new_node = exp.to_identifier(_select_by_pos(scope, target).alias)
        else:
            new_node = output_by_expression.get(target, target)
        target.replace(new_node)
286
287
def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
) -> t.List[exp.Expression]:
    """
    Resolve integer positional references against the scope's projections.

    Args:
        scope: Scope whose expression's projections are referenced (1-based).
        expressions: Expressions to expand; non-integer ones are returned unchanged.
        alias: If True, a positional reference becomes a column named after the
            referenced projection's alias. Otherwise it becomes a copy of the aliased
            expression, unless that expression is a literal, in which case the
            positional reference itself is kept.

    Returns:
        A new list with one entry per input expression, in the same order. Positions
        that cannot be resolved are reported by `_select_by_pos`.
    """

    def expand_one(node: exp.Expression) -> exp.Expression:
        if not node.is_int:
            return node

        projection = _select_by_pos(scope, t.cast(exp.Literal, node))
        if alias:
            return exp.column(projection.args["alias"].copy())

        target = projection.this
        return node if isinstance(target, exp.Literal) else target.copy()

    return [expand_one(node) for node in expressions]
309
310
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection referenced by a 1-based positional literal.

    For example, the `2` in `ORDER BY 2` refers to the second projection.

    Args:
        scope: Scope whose expression's projections are indexed.
        node: Integer literal holding the 1-based position.

    Returns:
        The referenced projection, which must be an `exp.Alias`.

    Raises:
        OptimizeError: if the position is not between 1 and the number of projections.
    """
    position = int(node.this)
    selects = scope.expression.selects

    # Positions are 1-based. Without this explicit range check, 0 or a negative position
    # would index the list from the end and silently resolve to the wrong projection.
    if not 1 <= position <= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[position - 1].assert_is(exp.Alias)
316
317
318def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
319    """Disambiguate columns, ensuring each column specifies a source"""
       # Columns are qualified in place. Raises OptimizeError for a column qualified with a
       # source of this scope that does not provide it (when that source's columns are known
       # and contain no `*`).
320    for column in scope.columns:
321        column_table = column.table
322        column_name = column.name
323
324        if column_table and column_table in scope.sources:
325            source_columns = resolver.get_source_columns(column_table)
326            if source_columns and column_name not in source_columns and "*" not in source_columns:
327                raise OptimizeError(f"Unknown column: {column_name}")
328
329        if not column_table:
               # In a pivoted scope, unqualified columns outside the PIVOT clause take the first
               # pivot's alias; columns inside the clause fall through to normal resolution
               # below (see also the loop over pivots at the end).
330            if scope.pivots and not column.find_ancestor(exp.Pivot):
331                # If the column is under the Pivot expression, we need to qualify it
332                # using the name of the pivoted source instead of the pivot's alias
333                column.set("table", exp.to_identifier(scope.pivots[0].alias))
334                continue
335
336            column_table = resolver.get_table(column_name)
337
338            # column_table can be a '' because bigquery unnest has no table alias
339            if column_table:
340                column.set("table", column_table)
           # The qualifier is not a source of this scope, and this is not a correlated subquery
           # referring to a source of its parent scope: treat the qualifier as a struct access.
341        elif column_table not in scope.sources and (
342            not scope.parent
343            or column_table not in scope.parent.sources
344            or not scope.is_correlated_subquery
345        ):
346            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
347            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
348
349            root, *parts = column.parts
350
351            if root.name in scope.sources:
352                # struct is already qualified, but we still need to change the AST representation
353                column_table = root
354                root, *parts = parts
355            else:
356                column_table = resolver.get_table(root.name)
357
358            if column_table:
359                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
360
       # Unqualified columns under PIVOT clauses are qualified only when some source of this
       # scope provides a column of that name.
361    for pivot in scope.pivots:
362        for column in pivot.find_all(exp.Column):
363            if not column.table and column.name in resolver.all_columns:
364                column_table = resolver.get_table(column.name)
365                if column_table:
366                    column.set("table", column_table)
367
368
369def _expand_stars(
370    scope: Scope,
371    resolver: Resolver,
372    using_column_tables: t.Dict[str, t.Any],
373    pseudocolumns: t.Set[str],
374) -> None:
375    """Expand stars to lists of column selections"""
       # Rewrites the projections of scope.expression in place. Non-star projections are kept
       # as-is. Each `*` and `table.*` is replaced by the columns of the sources it covers,
       # honoring the following:
       #   - EXCEPT and REPLACE;
       #   - USING-join coalescing (`using_column_tables`, as returned by `_expand_using`);
       #   - the first PIVOT / UNPIVOT of the scope.
       # Columns whose upper-cased name is in `pseudocolumns` are skipped.
       # Raises OptimizeError for a star over a table that is not a source of this scope.
376
377    new_selections = []
       # EXCEPT / REPLACE details keyed by id() of the table-name object.
       # NOTE(review): lookups only hit when the very same name object is used here and in
       # `_add_except_columns` / `_add_replace_columns`; an equal but distinct string misses.
378    except_columns: t.Dict[int, t.Set[str]] = {}
379    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
       # Names already emitted as a COALESCE over a USING join, so each is emitted only once.
380    coalesced_columns = set()
381
382    pivot_output_columns = None
383    pivot_exclude_columns = None
384
       # Unless the pivot declares explicit alias column names, work out which source columns
       # it consumes (pivot_exclude_columns) and which columns it produces (pivot_output_columns).
385    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
386    if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
387        if pivot.unpivot:
388            pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]
389
390            field = pivot.args.get("field")
391            if isinstance(field, exp.In):
392                pivot_exclude_columns = {
393                    c.output_name for e in field.expressions for c in e.find_all(exp.Column)
394                }
395        else:
396            pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))
397
398            pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
399            if not pivot_output_columns:
400                pivot_output_columns = [c.alias_or_name for c in pivot.expressions]
401
402    for expression in scope.expression.selects:
403        if isinstance(expression, exp.Star):
404            tables = list(scope.selected_sources)
405            _add_except_columns(expression, tables, except_columns)
406            _add_replace_columns(expression, tables, replace_columns)
407        elif expression.is_star:
408            tables = [expression.table]
409            _add_except_columns(expression.this, tables, except_columns)
410            _add_replace_columns(expression.this, tables, replace_columns)
411        else:
412            new_selections.append(expression)
413            continue
414
415        for table in tables:
416            if table not in scope.sources:
417                raise OptimizeError(f"Unknown table: {table}")
418
               # Fall back to the scope's outer column list when the source's columns are unknown.
419            columns = resolver.get_source_columns(table, only_visible=True)
420            columns = columns or scope.outer_column_list
421
422            if pseudocolumns:
423                columns = [name for name in columns if name.upper() not in pseudocolumns]
424
               # A star over a source with unknown columns (or a `*` among them) abandons the
               # whole expansion: returning here skips the final `set`, so nothing changes.
425            if not columns or "*" in columns:
426                return
427
428            table_id = id(table)
429            columns_to_exclude = except_columns.get(table_id) or set()
430
               # With a pivot, the star expands to the pivot's columns qualified by the pivot
               # alias; if none can be determined, fall through to the plain expansion below.
431            if pivot:
432                if pivot_output_columns and pivot_exclude_columns:
433                    pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
434                    pivot_columns.extend(pivot_output_columns)
435                else:
436                    pivot_columns = pivot.alias_column_names
437
438                if pivot_columns:
439                    new_selections.extend(
440                        exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
441                        for name in pivot_columns
442                        if name not in columns_to_exclude
443                    )
444                    continue
445
               # Plain expansion: a USING-join column becomes one COALESCE over all joined
               # tables; a column found in the REPLACE mapping gets the mapped alias.
446            for name in columns:
447                if name in columns_to_exclude or name in coalesced_columns:
448                    continue
449                if name in using_column_tables and table in using_column_tables[name]:
450                    coalesced_columns.add(name)
                       # NOTE(review): `tables` is rebound while the enclosing `for table in tables`
                       # loop runs; that loop keeps its original iterator, so iteration is unaffected.
451                    tables = using_column_tables[name]
452                    coalesce = [exp.column(name, table=table) for table in tables]
453
454                    new_selections.append(
455                        alias(
456                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
457                            alias=name,
458                            copy=False,
459                        )
460                    )
461                else:
                       # NOTE(review): `_add_replace_columns` keys the mapping by the name of each
                       # REPLACE entry's inner expression, so `name` is matched against that name -
                       # confirm this matches the intended `* REPLACE (expr AS col)` semantics.
462                    alias_ = replace_columns.get(table_id, {}).get(name, name)
463                    column = exp.column(name, table=table)
464                    new_selections.append(
465                        alias(column, alias_, copy=False) if alias_ != name else column
466                    )
467
468    # Ensures we don't overwrite the initial selections with an empty list
469    if new_selections:
470        scope.expression.set("expressions", new_selections)
471
472
def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """
    Record the names listed in a star's EXCEPT clause for each of `tables`.

    The same set object is stored under `id(table)` for every table, replacing any
    previous entry. Nothing is recorded if the star has no (or an empty) EXCEPT clause.
    """
    excluded = expression.args.get("except")
    if excluded:
        names = {column.name for column in excluded}
        except_columns.update((id(table), names) for table in tables)
485
486
def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """
    Record a star's REPLACE clause for each of `tables`.

    Each REPLACE entry contributes `name of its inner expression -> its alias` (a later
    entry wins on duplicate names). The same mapping object is stored under `id(table)`
    for every table, replacing any previous entry. Nothing is recorded if the star has
    no (or an empty) REPLACE clause.
    """
    replacements = expression.args.get("replace")
    if not replacements:
        return

    mapping = dict((entry.this.name, entry.alias) for entry in replacements)
    for table in tables:
        replace_columns[id(table)] = mapping
499
500
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Args:
        scope_or_expression: A scope, or an expression from which a scope is built.
            If no `Scope` can be built from the expression, nothing happens.

    Projections that are neither aliased nor stars are aliased with their output name,
    or `_col_<index>` if they have none. Unnamed subquery projections get a
    `_col_<index>` table alias. Names from the scope's `outer_column_list`, when present,
    override the alias of the projection at the same position. The projection list of
    the scope's expression is replaced in place.
    """
    if isinstance(scope_or_expression, exp.Expression):
        built = build_scope(scope_or_expression)
        if not isinstance(built, Scope):
            return
        scope = built
    else:
        scope = scope_or_expression

    select = scope.expression
    pairs = itertools.zip_longest(select.selects, scope.outer_column_list)
    aliased_selections = []

    for index, (projection, outer_name) in enumerate(pairs):
        # More outer names than projections: stop after the last projection.
        if projection is None:
            break

        fallback_name = f"_col_{index}"
        if isinstance(projection, exp.Subquery):
            # Subqueries carry a table alias rather than a column alias.
            if not projection.output_name:
                projection.set("alias", exp.TableAlias(this=exp.to_identifier(fallback_name)))
        elif not (isinstance(projection, exp.Alias) or projection.is_star):
            projection = alias(
                projection,
                alias=projection.output_name or fallback_name,
                copy=False,
            )

        if outer_name:
            projection.set("alias", exp.to_identifier(outer_name))

        aliased_selections.append(projection)

    select.set("expressions", aliased_selections)
532
533
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Make sure all identifiers that need to be quoted are quoted.

    Args:
        expression: Expression whose identifiers are processed; it is transformed in
            place (`copy=False`).
        dialect: Dialect whose `quote_identifier` rule is applied to every node.
        identify: Forwarded to the dialect's `quote_identifier` (presumably: when True,
            quote every identifier rather than only those that require it).

    Returns:
        The transformed expression.
    """
    quote_node = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote_node, identify=identify, copy=False)
539
540
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    The i-th alias column of a CTE becomes the alias of the i-th projection of its
    SELECT. Projections beyond the alias list are kept unchanged, and CTEs whose body is
    not a plain SELECT (e.g. VALUES or set operations) are left untouched.

    Args:
        expression: Expression to pushdown. It is modified in place.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        column_names = cte.alias_column_names
        query = cte.this

        # Only a SELECT has a projection list to rename; for other bodies (e.g. the rows of a
        # VALUES clause) pairing them positionally with the alias columns would be wrong.
        if not column_names or not isinstance(query, exp.Select):
            continue

        projections = query.selects
        renamed = []
        for column_name, projection in zip(column_names, projections):
            if isinstance(projection, exp.Alias):
                # Store an Identifier, as `alias()` does in the other branch, not a bare str.
                projection.set("alias", exp.to_identifier(column_name))
            else:
                projection = alias(projection, alias=column_name)
            renamed.append(projection)

        # zip() stops at the shorter sequence: keep the remaining projections instead of
        # silently dropping them from the SELECT.
        query.set("expressions", renamed + projections[len(renamed):])

    return expression
571
572
573class Resolver:
574    """
575    Helper for resolving columns.
576
577    This is a class so we can lazily load some things and easily share them across functions.
578    """
579
580    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
581        self.scope = scope
582        self.schema = schema
           # Lazily computed caches (None until first use):
           # source name -> column names, for every selected and lateral source of the scope.
583        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
           # column name -> name of the only source providing it.
584        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
           # union of the column names of all sources.
585        self._all_columns: t.Optional[t.Set[str]] = None
           # When True, get_table may attribute an unresolved column to the single source
           # whose columns are unknown (or contain `*`).
586        self._infer_schema = infer_schema
587
588    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
589        """
590        Get the table for a column name.
591
592        Args:
593            column_name: The column name to find the table for.
594        Returns:
595            The table name if it can be found/inferred.
596        """
           # First choice: the unique source that provides this column.
597        if self._unambiguous_columns is None:
598            self._unambiguous_columns = self._get_unambiguous_columns(
599                self._get_all_source_columns()
600            )
601
602        table_name = self._unambiguous_columns.get(column_name)
603
           # Fallback when inference is enabled: the only source without a known schema.
604        if not table_name and self._infer_schema:
605            sources_without_schema = tuple(
606                source
607                for source, columns in self._get_all_source_columns().items()
608                if not columns or "*" in columns
609            )
610            if len(sources_without_schema) == 1:
611                table_name = sources_without_schema[0]
612
           # Names that are not selected sources (including a missing name) are returned as
           # plain identifiers. NOTE(review): presumably `exp.to_identifier(None)` is None, so an
           # unresolved column yields None - confirm against sqlglot.
613        if table_name not in self.scope.selected_sources:
614            return exp.to_identifier(table_name)
615
616        node, _ = self.scope.selected_sources.get(table_name)
617
           # For query sources, climb to the ancestor carrying this alias and prefer the
           # identifier stored in its alias node. NOTE(review): if no ancestor matches, the loop
           # ends on a falsy `node` and `node.args` below would raise AttributeError.
618        if isinstance(node, exp.Subqueryable):
619            while node and node.alias != table_name:
620                node = node.parent
621
622        node_alias = node.args.get("alias")
623        if node_alias:
624            return exp.to_identifier(node_alias.this)
625
626        return exp.to_identifier(table_name)
627
628    @property
629    def all_columns(self) -> t.Set[str]:
630        """All available columns of all sources in this scope"""
631        if self._all_columns is None:
632            self._all_columns = {
633                column for columns in self._get_all_source_columns().values() for column in columns
634            }
635        return self._all_columns
636
637    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
638        """Resolve the source columns for a given source `name`."""
639        if name not in self.scope.sources:
640            raise OptimizeError(f"Unknown table: {name}")
641
642        source = self.scope.sources[name]
643
           # Tables are resolved through the schema, VALUES sources through their alias column
           # names, and any other source through the output names of its query.
644        if isinstance(source, exp.Table):
645            columns = self.schema.column_names(source, only_visible)
646        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
647            columns = source.expression.alias_column_names
648        else:
649            columns = source.expression.named_selects
650
           # Column aliases declared on the selected source override the resolved names by position.
651        node, _ = self.scope.selected_sources.get(name) or (None, None)
652        if isinstance(node, Scope):
653            column_aliases = node.expression.alias_column_names
654        elif isinstance(node, exp.Expression):
655            column_aliases = node.alias_column_names
656        else:
657            column_aliases = []
658
659        if column_aliases:
660            # If the source's columns are aliased, their aliases shadow the corresponding column names.
661            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
662            return [
663                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
664            ]
665        return columns
666
667    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
           # Computed once and cached: the columns of every selected and lateral source.
668        if self._source_columns is None:
669            self._source_columns = {
670                source_name: self.get_source_columns(source_name)
671                for source_name, source in itertools.chain(
672                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
673                )
674            }
675        return self._source_columns
676
677    def _get_unambiguous_columns(
678        self, source_columns: t.Dict[str, t.Sequence[str]]
679    ) -> t.Mapping[str, str]:
680        """
681        Find all the unambiguous columns in sources.
682
683        Args:
684            source_columns: Mapping of names to source columns.
685
686        Returns:
687            Mapping of column name to source name.
688        """
689        if not source_columns:
690            return {}
691
692        source_columns_pairs = list(source_columns.items())
693
694        first_table, first_columns = source_columns_pairs[0]
695
696        if len(source_columns_pairs) == 1:
697            # Performance optimization - avoid copying first_columns if there is only one table.
698            return SingleValuedMapping(first_columns, first_table)
699
700        unambiguous_columns = {col: first_table for col in first_columns}
701        all_columns = set(unambiguous_columns)
702
           # A column seen in more than one source is dropped; otherwise it maps to its only source.
703        for table, columns in source_columns_pairs[1:]:
704            unique = set(columns)
705            ambiguous = all_columns.intersection(unique)
706            all_columns.update(columns)
707
708            for column in ambiguous:
709                unambiguous_columns.pop(column, None)
710            for column in unique.difference(ambiguous):
711                unambiguous_columns[column] = table
712
713        return unambiguous_columns
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, expand_stars: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    expand_stars: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify. It is rewritten in place and also returned.
        schema: Database schema (a `Schema` or anything accepted by `ensure_schema`).
        expand_alias_refs: Whether to expand references to aliases.
        expand_stars: Whether to expand star queries. This is a necessary step
            for most of the optimizer's rules to work; do not set to False unless you
            know what you're doing!
        infer_schema: Whether to infer the schema if missing. When omitted, inference
            is enabled exactly when `schema` is empty.

    Returns:
        The qualified expression.

    Notes:
        - Currently only handles a single PIVOT or UNPIVOT operator
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    dialect_pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        column_resolver = Resolver(scope, schema, infer_schema=infer_schema)

        # Column lists on CTE / derived-table aliases, e.g. `AS foo(col1, col2)`, are dropped.
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)

        # USING joins are rewritten; keep track of which sources each USING column came from.
        joined_using = _expand_using(scope, column_resolver)

        # Alias references are expanded before column qualification when the schema is
        # empty, and after it otherwise.
        if expand_alias_refs and schema.empty:
            _expand_alias_refs(scope, column_resolver)

        _qualify_columns(scope, column_resolver)

        if expand_alias_refs and not schema.empty:
            _expand_alias_refs(scope, column_resolver)

        # UDTF scopes keep their projections: no star expansion and no output aliasing.
        if not isinstance(scope.expression, exp.UDTF):
            if expand_stars:
                _expand_stars(scope, column_resolver, joined_using, dialect_pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, column_resolver)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether to expand references to aliases.
  • expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing!
  • infer_schema: Whether to infer the schema if missing.
Returns:

The qualified expression.

Notes:
  • Currently only handles a single PIVOT or UNPIVOT operator
def validate_qualify_columns(expression: ~E) -> ~E:
 80def validate_qualify_columns(expression: E) -> E:
 81    """Raise an `OptimizeError` if any columns aren't qualified"""
 82    all_unqualified_columns = []
 83    for scope in traverse_scope(expression):
 84        if isinstance(scope.expression, exp.Select):
 85            unqualified_columns = scope.unqualified_columns
 86
 87            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 88                column = scope.external_columns[0]
 89                for_table = f" for table: '{column.table}'" if column.table else ""
 90                raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
 91
 92            if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
 93                # New columns produced by the UNPIVOT can't be qualified, but there may be columns
 94                # under the UNPIVOT's IN clause that can and should be qualified. We recompute
 95                # this list here to ensure those in the former category will be excluded.
 96                unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
 97                unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
 98
 99            all_unqualified_columns.extend(unqualified_columns)
100
101    if all_unqualified_columns:
102        raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
103
104    return expression

Raise an OptimizeError if any columns aren't qualified

def qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Every projection that is neither an alias nor a star is wrapped in an alias named after its
    output name, or `_col_<i>` when it has none. An unnamed subquery projection gets a
    `_col_<i>` table alias. When the scope has an outer column list, its names override the
    projection aliases positionally.

    Args:
        scope_or_expression: The scope to process, or an expression to build a scope from.
            When no scope can be built from the expression, this is a no-op.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        if selection is None:
            # More outer column names than projections: nothing left to alias.
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                copy=False,
            )
        if aliased_column:
            # Outer column-list names take precedence over the projection's own alias.
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    # Only a SELECT owns its projection list. For other scope expressions (presumably set
    # operations such as UNION, whose `selects` come from a child query), setting
    # "expressions" here would re-parent the child's projection nodes onto the wrong node.
    if isinstance(scope.expression, exp.Select):
        scope.expression.set("expressions", new_selections)

Ensure all output columns are aliased

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: The expression to process. It is transformed in place (`copy=False`).
        dialect: The dialect whose `quote_identifier` rule is applied to every node.
        identify: Forwarded to the dialect's `quote_identifier`.

    Returns:
        The result of transforming `expression`.
    """
    rewrite = expression.transform
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return rewrite(quote, identify=identify, copy=False)

Makes sure all identifiers that need to be quoted are quoted.

def pushdown_cte_alias_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Aliases are applied positionally. Projections beyond the end of the CTE's alias column list
    keep their own names. Pushdown stops at the first star projection, because a star expands
    to an unknown number of columns and later positions can no longer be matched to aliases.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown. It is modified in place.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        column_aliases = cte.alias_column_names
        if not column_aliases:
            continue

        new_expressions = []
        can_alias = True
        for i, projection in enumerate(cte.this.expressions):
            if projection.is_star:
                # `* AS c` is invalid, and positions after a star can't be mapped to aliases.
                can_alias = False

            # Iterate over all projections (not zip() with the aliases) so that projections
            # without a corresponding alias are kept instead of silently dropped.
            if can_alias and i < len(column_aliases):
                _alias = column_aliases[i]
                if isinstance(projection, exp.Alias):
                    # Store an Identifier, as alias() does for the branch below, not a raw str.
                    projection.set("alias", exp.to_identifier(_alias))
                else:
                    projection = alias(projection, alias=_alias)
            new_expressions.append(projection)
        cte.this.set("expressions", new_expressions)

    return expression

Pushes down the CTE alias columns into the projection.

This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
>>> pushdown_cte_alias_columns(expression).sql()
'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
Arguments:
  • expression: Expression to pushdown.
Returns:

The expression with the CTE aliases pushed down into the projection.

class Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily-populated caches, filled on first use by the methods below.
        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        # When True, get_table may attribute an otherwise unresolved column to the single
        # source whose columns are unknown (empty or containing "*").
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table identifier if it can be found/inferred. Presumably None otherwise,
            since `exp.to_identifier(None)` is returned when no table name was found.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # Fall back to the one source whose columns are unknown, if there is exactly one.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources[table_name]

        if isinstance(node, exp.Subqueryable):
            # Climb to the ancestor whose alias is the source name.
            while node and node.alias != table_name:
                node = node.parent

        # The climb above can run past the root without finding a matching alias. Fall back
        # to the source name instead of dereferencing None.
        if node is None:
            return exp.to_identifier(table_name)

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope (computed once, then cached)."""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: The source name within this scope.
            only_visible: Forwarded to the schema lookup when the source is a table.

        Returns:
            The source's column names, with any column aliases declared on the source
            applied positionally.

        Raises:
            OptimizeError: If `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        if column_aliases:
            # If the source's columns are aliased, their aliases shadow the corresponding column names.
            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
            return [
                col_alias or col_name
                for col_name, col_alias in itertools.zip_longest(columns, column_aliases)
            ]
        return columns

    def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
        """Map each selected and lateral source name to its columns (computed once, then cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.Sequence[str]]
    ) -> t.Mapping[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name, containing only columns that appear in
            exactly one source.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]

        if len(source_columns_pairs) == 1:
            # Performance optimization - avoid copying first_columns if there is only one table.
            return SingleValuedMapping(first_columns, first_table)

        unambiguous_columns = {col: first_table for col in first_columns}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = set(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            # A column seen in an earlier source is ambiguous from now on, even if it was
            # already removed, so it is never re-added.
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
581    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
582        self.scope = scope
583        self.schema = schema
584        self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None
585        self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
586        self._all_columns: t.Optional[t.Set[str]] = None
587        self._infer_schema = infer_schema
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
589    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
590        """
591        Get the table for a column name.
592
593        Args:
594            column_name: The column name to find the table for.
595        Returns:
596            The table name if it can be found/inferred.
597        """
598        if self._unambiguous_columns is None:
599            self._unambiguous_columns = self._get_unambiguous_columns(
600                self._get_all_source_columns()
601            )
602
603        table_name = self._unambiguous_columns.get(column_name)
604
605        if not table_name and self._infer_schema:
606            sources_without_schema = tuple(
607                source
608                for source, columns in self._get_all_source_columns().items()
609                if not columns or "*" in columns
610            )
611            if len(sources_without_schema) == 1:
612                table_name = sources_without_schema[0]
613
614        if table_name not in self.scope.selected_sources:
615            return exp.to_identifier(table_name)
616
617        node, _ = self.scope.selected_sources.get(table_name)
618
619        if isinstance(node, exp.Subqueryable):
620            while node and node.alias != table_name:
621                node = node.parent
622
623        node_alias = node.args.get("alias")
624        if node_alias:
625            return exp.to_identifier(node_alias.this)
626
627        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]
629    @property
630    def all_columns(self) -> t.Set[str]:
631        """All available columns of all sources in this scope"""
632        if self._all_columns is None:
633            self._all_columns = {
634                column for columns in self._get_all_source_columns().values() for column in columns
635            }
636        return self._all_columns

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> Sequence[str]:
638    def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
639        """Resolve the source columns for a given source `name`."""
640        if name not in self.scope.sources:
641            raise OptimizeError(f"Unknown table: {name}")
642
643        source = self.scope.sources[name]
644
645        if isinstance(source, exp.Table):
646            columns = self.schema.column_names(source, only_visible)
647        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
648            columns = source.expression.alias_column_names
649        else:
650            columns = source.expression.named_selects
651
652        node, _ = self.scope.selected_sources.get(name) or (None, None)
653        if isinstance(node, Scope):
654            column_aliases = node.expression.alias_column_names
655        elif isinstance(node, exp.Expression):
656            column_aliases = node.alias_column_names
657        else:
658            column_aliases = []
659
660        if column_aliases:
661            # If the source's columns are aliased, their aliases shadow the corresponding column names.
662            # This can be expensive if there are lots of columns, so only do this if column_aliases exist.
663            return [
664                alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
665            ]
666        return columns

Resolve the source columns for a given source name.