Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.errors import OptimizeError
  8from sqlglot.helper import seq_get
  9from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
 10from sqlglot.schema import Schema, ensure_schema
 11
 12
 13def qualify_columns(
 14    expression: exp.Expression,
 15    schema: dict | Schema,
 16    expand_alias_refs: bool = True,
 17    infer_schema: t.Optional[bool] = None,
 18) -> exp.Expression:
 19    """
 20    Rewrite sqlglot AST to have fully qualified columns.
 21
 22    Example:
 23        >>> import sqlglot
 24        >>> schema = {"tbl": {"col": "INT"}}
 25        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 26        >>> qualify_columns(expression, schema).sql()
 27        'SELECT tbl.col AS col FROM tbl'
 28
 29    Args:
 30        expression: expression to qualify
 31        schema: Database schema
 32        expand_alias_refs: whether or not to expand references to aliases
 33        infer_schema: whether or not to infer the schema if missing
 34    Returns:
 35        sqlglot.Expression: qualified expression
 36    """
 37    schema = ensure_schema(schema)
 38    infer_schema = schema.empty if infer_schema is None else infer_schema
 39
 40    for scope in traverse_scope(expression):
 41        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 42        _pop_table_column_aliases(scope.ctes)
 43        _pop_table_column_aliases(scope.derived_tables)
 44        using_column_tables = _expand_using(scope, resolver)
 45
 46        if schema.empty and expand_alias_refs:
 47            _expand_alias_refs(scope, resolver)
 48
 49        _qualify_columns(scope, resolver)
 50
 51        if not schema.empty and expand_alias_refs:
 52            _expand_alias_refs(scope, resolver)
 53
 54        if not isinstance(scope.expression, exp.UDTF):
 55            _expand_stars(scope, resolver, using_column_tables)
 56            _qualify_outputs(scope)
 57        _expand_group_by(scope, resolver)
 58        _expand_order_by(scope)
 59
 60    return expression
 61
 62
 63def validate_qualify_columns(expression):
 64    """Raise an `OptimizeError` if any columns aren't qualified"""
 65    unqualified_columns = []
 66    for scope in traverse_scope(expression):
 67        if isinstance(scope.expression, exp.Select):
 68            unqualified_columns.extend(scope.unqualified_columns)
 69            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 70                column = scope.external_columns[0]
 71                raise OptimizeError(
 72                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
 73                )
 74
 75    if unqualified_columns:
 76        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
 77    return expression
 78
 79
 80def _pop_table_column_aliases(derived_tables):
 81    """
 82    Remove table column aliases.
 83
 84    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
 85    """
 86    for derived_table in derived_tables:
 87        table_alias = derived_table.args.get("alias")
 88        if table_alias:
 89            table_alias.args.pop("columns", None)
 90
 91
 92def _expand_using(scope, resolver):
 93    joins = list(scope.find_all(exp.Join))
 94    names = {join.this.alias for join in joins}
 95    ordered = [key for key in scope.selected_sources if key not in names]
 96
 97    # Mapping of automatically joined column names to an ordered set of source names (dict).
 98    column_tables = {}
 99
100    for join in joins:
101        using = join.args.get("using")
102
103        if not using:
104            continue
105
106        join_table = join.this.alias_or_name
107
108        columns = {}
109
110        for k in scope.selected_sources:
111            if k in ordered:
112                for column in resolver.get_source_columns(k):
113                    if column not in columns:
114                        columns[column] = k
115
116        source_table = ordered[-1]
117        ordered.append(join_table)
118        join_columns = resolver.get_source_columns(join_table)
119        conditions = []
120
121        for identifier in using:
122            identifier = identifier.name
123            table = columns.get(identifier)
124
125            if not table or identifier not in join_columns:
126                if columns and join_columns:
127                    raise OptimizeError(f"Cannot automatically join: {identifier}")
128
129            table = table or source_table
130            conditions.append(
131                exp.condition(
132                    exp.EQ(
133                        this=exp.column(identifier, table=table),
134                        expression=exp.column(identifier, table=join_table),
135                    )
136                )
137            )
138
139            # Set all values in the dict to None, because we only care about the key ordering
140            tables = column_tables.setdefault(identifier, {})
141            if table not in tables:
142                tables[table] = None
143            if join_table not in tables:
144                tables[join_table] = None
145
146        join.args.pop("using")
147        join.set("on", exp.and_(*conditions, copy=False))
148
149    if column_tables:
150        for column in scope.columns:
151            if not column.table and column.name in column_tables:
152                tables = column_tables[column.name]
153                coalesce = [exp.column(column.name, table=table) for table in tables]
154                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
155
156                # Ensure selects keep their output name
157                if isinstance(column.parent, exp.Select):
158                    replacement = alias(replacement, alias=column.name, copy=False)
159
160                scope.replace(column, replacement)
161
162    return column_tables
163
164
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Inline references to SELECT aliases elsewhere in the same SELECT.

    An unqualified column named after an alias defined by an earlier projection is
    replaced with a copy of the aliased expression. This applies in later projections,
    WHERE, GROUP BY and HAVING; ORDER BY references are not expanded. In HAVING and
    ORDER BY, unqualified columns inside aggregate functions are qualified with the
    table the resolver finds for them. Scopes that are not SELECTs are left alone.
    """
    select = scope.expression
    if not isinstance(select, exp.Select):
        return

    aliased: t.Dict[str, exp.Expression] = {}

    def rewrite(
        node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False
    ):
        if not node:
            return

        for candidate, *_ in walk_in_scope(node):
            if not isinstance(candidate, exp.Column):
                continue

            unqualified = not candidate.table
            table = None
            if resolve_agg and unqualified:
                table = resolver.get_table(candidate.name)

            if table and candidate.find_ancestor(exp.AggFunc):
                candidate.set("table", table)
            elif expand and unqualified and candidate.name in aliased:
                candidate.replace(aliased[candidate.name].copy())

    # A projection may only use aliases defined by the projections before it.
    for projection in scope.selects:
        rewrite(projection)
        if isinstance(projection, exp.Alias):
            aliased[projection.alias] = projection.this

    for clause, expand, resolve_agg in (
        ("where", True, False),
        ("group", True, False),
        ("having", True, True),
        ("order", False, True),
    ):
        rewrite(select.args.get(clause), expand=expand, resolve_agg=resolve_agg)

    scope.clear_cache()
199
200
def _expand_group_by(scope, resolver):
    """
    Replace positional GROUP BY references (e.g. `GROUP BY 1`) with the projections
    they point to. `resolver` is unused; it is kept for signature compatibility.
    """
    select = scope.expression
    group = select.args.get("group")
    if group:
        group.set("expressions", _expand_positional_references(scope, group.expressions))
        select.set("group", group)
208
209
def _expand_order_by(scope):
    """
    Replace positional ORDER BY references (e.g. `ORDER BY 1`) with the projections
    they point to. Each ORDER BY item keeps its wrapper; only its inner expression
    is swapped.
    """
    order = scope.expression.args.get("order")
    if order:
        ordereds = order.expressions
        replacements = _expand_positional_references(scope, (o.this for o in ordereds))
        for ordered, replacement in zip(ordereds, replacements):
            ordered.set("this", replacement)
221
222
def _expand_positional_references(scope, expressions):
    """
    Replace integer position references with copies of the projections they denote.

    Positions are 1-based, as in `GROUP BY 1` / `ORDER BY 2`. An aliased projection
    contributes its underlying expression, not the alias. Nodes that are not integer
    literals are returned unchanged.

    Args:
        scope: the scope whose `selects` the positions refer to.
        expressions: iterable of expressions to expand.
    Returns:
        list: the expanded expressions, in input order.
    Raises:
        OptimizeError: if a position is outside `1..len(scope.selects)`.
    """
    new_nodes = []
    for node in expressions:
        if not node.is_int:
            new_nodes.append(node)
            continue

        position = int(node.name)
        selects = scope.selects
        # Check bounds explicitly: a plain `selects[position - 1]` would silently map
        # 0 (or a negative value) onto a projection counted from the end.
        if not 1 <= position <= len(selects):
            raise OptimizeError(f"Unknown output column: {node.name}")

        select = selects[position - 1]
        if isinstance(select, exp.Alias):
            select = select.this
        new_nodes.append(select.copy())
        scope.clear_cache()

    return new_nodes
239
240
def _qualify_columns(scope, resolver):
    """
    Disambiguate columns, ensuring each column specifies a source.

    Raises:
        OptimizeError: if a column is qualified with a source of this scope whose
            known columns do not include it.
    """
    for column in scope.columns:
        name = column.name
        table = column.table

        if not table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # A column outside the PIVOT refers to the pivot's output, so it is
                # qualified with the pivot's alias. Columns inside the PIVOT reference
                # the pivoted source instead and use the regular lookup below. Any left
                # unqualified are handled by the pivot loop at the end of this function.
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            resolved = resolver.get_table(name)
            # The resolved table can be '' because bigquery unnest has no table alias.
            if resolved:
                column.set("table", resolved)
            continue

        if table in scope.sources:
            # Qualified with a local source: check it against that source's known columns.
            known = resolver.get_source_columns(table)
            if known and name not in known and "*" not in known:
                raise OptimizeError(f"Unknown column: {name}")
            continue

        if scope.parent and table in scope.parent.sources:
            # Qualified with a source of the enclosing scope; left as-is.
            continue

        # Structs are used like tables (e.g. "struct"."field"), so they are qualified
        # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...)).
        root, *parts = column.parts
        if root.name in scope.sources:
            # The struct is already qualified, but the AST representation still changes.
            qualifier = root
            root, *parts = parts
        else:
            qualifier = resolver.get_table(root.name)

        if qualifier:
            column.replace(exp.Dot.build([exp.column(root, table=qualifier), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if column.table or column.name not in resolver.all_columns:
                continue
            pivot_table = resolver.get_table(column.name)
            if pivot_table:
                column.set("table", pivot_table)
288
289
def _expand_stars(scope, resolver, using_column_tables):
    """
    Expand stars to lists of column selections.

    Both `*` and `<table>.*` are replaced with one column per source column, honouring
    `EXCEPT (...)` and `REPLACE (... AS ...)` modifiers. Columns joined with USING are
    emitted once, as `COALESCE(...) AS <name>`. When the scope has a PIVOT that is not
    an UNPIVOT, the source's non-pivoted columns plus the pivot's output columns are
    selected from the pivot alias instead. If a star's source has no known columns (or
    lists "*"), the projections are left untouched.

    Raises:
        OptimizeError: if a star refers to a table that is not a source of the scope.
    """
    expanded = []
    excluded_by_table = {}
    renamed_by_table = {}
    coalesced = set()

    # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
    pivot = seq_get(scope.pivots, 0)
    pivot_inputs = None
    pivot_outputs = None
    pivoting = pivot and not pivot.args.get("unpivot")
    if pivoting:
        pivot_inputs = {col.output_name for col in pivot.find_all(exp.Column)}
        pivot_outputs = [col.output_name for col in pivot.args.get("columns", [])]
        if not pivot_outputs:
            pivot_outputs = [col.alias_or_name for col in pivot.expressions]

    for selection in scope.selects:
        if isinstance(selection, exp.Star):
            star_tables = list(scope.selected_sources)
            star = selection
        elif selection.is_star:
            star_tables = [selection.table]
            star = selection.this
        else:
            expanded.append(selection)
            continue

        _add_except_columns(star, star_tables, excluded_by_table)
        _add_replace_columns(star, star_tables, renamed_by_table)

        for table in star_tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            if not columns or "*" in columns:
                # The source's columns are unknown, so the star cannot be expanded.
                return

            if pivoting:
                passthrough = [col for col in columns if col not in pivot_inputs]
                expanded.extend(
                    exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                    for name in passthrough + pivot_outputs
                )
                continue

            excluded = excluded_by_table.get(id(table), set())
            renamed = renamed_by_table.get(id(table), {})
            for name in columns:
                if name in using_column_tables and table in using_column_tables[name]:
                    if name in coalesced:
                        continue
                    coalesced.add(name)
                    operands = [
                        exp.column(name, table=joined) for joined in using_column_tables[name]
                    ]
                    expanded.append(
                        alias(
                            exp.Coalesce(this=operands[0], expressions=operands[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                elif name not in excluded:
                    new_name = renamed.get(name, name)
                    column = exp.column(name, table=table)
                    expanded.append(
                        column if new_name == name else alias(column, new_name, copy=False)
                    )

    scope.expression.set("expressions", expanded)
366
367
def _add_except_columns(expression, tables, except_columns):
    """
    Record a star's `EXCEPT (...)` column names for each of `tables`.

    Entries are keyed by `id(table)` and all share the same set. Nothing is recorded
    when the star has no (or an empty) EXCEPT list.
    """
    except_ = expression.args.get("except")
    if except_:
        excluded = {e.name for e in except_}
        except_columns.update((id(table), excluded) for table in tables)
378
379
def _add_replace_columns(expression, tables, replace_columns):
    """
    Record a star's `REPLACE (... AS ...)` renames for each of `tables`.

    Each entry maps a source column name to its new alias. Entries are keyed by
    `id(table)` and all share the same dict. Nothing is recorded when the star has
    no (or an empty) REPLACE list.
    """
    replace = expression.args.get("replace")
    if replace:
        renames = {item.this.name: item.alias for item in replace}
        replace_columns.update((id(table), renames) for table in tables)
390
391
def _qualify_outputs(scope):
    """
    Ensure all output columns are aliased.

    A projection without an alias gets its output name as alias, or `_col_<i>` when it
    has none; the alias is quoted when the projection is a quoted column. An unnamed
    subquery projection gets a `_col_<i>` table alias, and stars are left as they are.
    When the scope has an outer column list, its names override the aliases
    positionally.
    """
    new_selections = []

    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.selects, scope.outer_column_list)
    ):
        # zip_longest pads the shorter side with None. Extra outer column names have no
        # projection to rename, so stop instead of failing with an AttributeError on None.
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
                quoted=True
                if isinstance(selection, exp.Column) and selection.this.quoted
                else None,
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)
416
417
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches, populated on first use.
        self._source_columns = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns = None
        # When True, get_table falls back to the single source with unknown columns.
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        The column is first looked up among the columns that belong to exactly one
        source. Failing that, and if schema inference is enabled, it is attributed to
        the only source whose columns are unknown (empty or containing "*"), provided
        there is exactly one such source.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            unknown_sources = [
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            ]
            if len(unknown_sources) == 1:
                table_name = unknown_sources[0]

        selected_sources = self.scope.selected_sources
        if table_name not in selected_sources:
            # A lateral source, or no table was found at all.
            return exp.to_identifier(table_name)

        node, _ = selected_sources.get(table_name)

        # For subqueries, climb up to the node that carries this source's alias.
        if isinstance(node, exp.Subqueryable):
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        # Preserve the quoting of a physical table's name.
        quoted = node.this.quoted if isinstance(node, exp.Table) else None
        return exp.to_identifier(table_name, quoted=quoted)

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """
        Resolve the source columns for a given source `name`.

        Tables are looked up in the schema. VALUES sources yield their alias column
        names, and any other scope yields its named selects.

        Raises:
            OptimizeError: if `name` is not a source of the current scope.
        """
        sources = self.scope.sources
        if name not in sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Map each selected and lateral source name to its columns (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous only if its name occurs exactly once across *all*
        sources. A name repeated within one source, or shared by two sources, cannot
        be attributed to a single table. The previous incremental version missed both
        cases: duplicates in the first source and duplicates in later sources.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        owner = {}
        every_column = []
        for table, columns in source_columns.items():
            for column in columns:
                owner[column] = table
                every_column.append(column)

        # A column that occurs exactly once has exactly one owner.
        return {column: owner[column] for column in self._find_unique_columns(every_column)}

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def qualify_columns( expression: sqlglot.expressions.Expression, schema: dict | sqlglot.schema.Schema, expand_alias_refs: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Every scope of `expression` (as yielded by `traverse_scope`) is processed: table
    column aliases are dropped, USING joins are rewritten, columns are qualified with
    their source, stars are expanded, outputs are aliased and positional GROUP BY /
    ORDER BY references are replaced. The tree is rewritten in place and the same
    object is returned.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: expression to qualify
        schema: Database schema
        expand_alias_refs: whether or not to expand references to aliases
        infer_schema: whether or not to infer the schema if missing; when None it
            defaults to whether `schema` is empty.
    Returns:
        sqlglot.Expression: qualified expression
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)

        # Drop column lists on table aliases, e.g. `(SELECT ...) AS foo(col1, col2)`.
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)

        using_column_tables = _expand_using(scope, resolver)

        # Without a schema, alias references are expanded before columns are qualified;
        # with a schema, they are expanded afterwards.
        has_schema = not schema.empty
        if expand_alias_refs and not has_schema:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if expand_alias_refs and has_schema:
            _expand_alias_refs(scope, resolver)

        # Star expansion and output aliasing are skipped for UDTF scopes.
        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver, using_column_tables)
            _qualify_outputs(scope)

        _expand_group_by(scope, resolver)
        _expand_order_by(scope)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: expression to qualify
  • schema: Database schema
  • expand_alias_refs: whether or not to expand references to aliases
  • infer_schema: whether or not to infer the schema if missing
Returns:

sqlglot.Expression: qualified expression

def validate_qualify_columns(expression):
def validate_qualify_columns(expression):
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only SELECT scopes are inspected. If such a scope has external columns and is
    neither a correlated subquery nor pivoted, the first external column is reported
    as unresolvable. Otherwise, all unqualified columns found across the scopes are
    reported together as ambiguous.

    Returns:
        The same `expression`, unchanged, when validation passes.
    """
    ambiguous = []
    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        ambiguous.extend(scope.unqualified_columns)

        if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
            column = scope.external_columns[0]
            for_table = f" for table: '{column.table}'" if column.table else ""
            raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

    if ambiguous:
        raise OptimizeError(f"Ambiguous columns: {ambiguous}")
    return expression

Raise an OptimizeError if any columns aren't qualified

class Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches, populated on first use.
        self._source_columns = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns = None
        # When True, get_table falls back to the single source with unknown columns.
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        The column is first looked up among the columns that belong to exactly one
        source. Failing that, and if schema inference is enabled, it is attributed to
        the only source whose columns are unknown (empty or containing "*"), provided
        there is exactly one such source.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            unknown_sources = [
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            ]
            if len(unknown_sources) == 1:
                table_name = unknown_sources[0]

        selected_sources = self.scope.selected_sources
        if table_name not in selected_sources:
            # A lateral source, or no table was found at all.
            return exp.to_identifier(table_name)

        node, _ = selected_sources.get(table_name)

        # For subqueries, climb up to the node that carries this source's alias.
        if isinstance(node, exp.Subqueryable):
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        # Preserve the quoting of a physical table's name.
        quoted = node.this.quoted if isinstance(node, exp.Table) else None
        return exp.to_identifier(table_name, quoted=quoted)

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """
        Resolve the source columns for a given source `name`.

        Tables are looked up in the schema. VALUES sources yield their alias column
        names, and any other scope yields its named selects.

        Raises:
            OptimizeError: if `name` is not a source of the current scope.
        """
        sources = self.scope.sources
        if name not in sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Map each selected and lateral source name to its columns (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous only if its name occurs exactly once across *all*
        sources. A name repeated within one source, or shared by two sources, cannot
        be attributed to a single table. The previous incremental version missed both
        cases: duplicates in the first source and duplicates in later sources.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        owner = {}
        every_column = []
        for table, columns in source_columns.items():
            for column in columns:
                owner[column] = table
                every_column.append(column)

        # A column that occurs exactly once has exactly one owner.
        return {column: owner[column] for column in self._find_unique_columns(every_column)}

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver(scope, schema, infer_schema: bool = True)
426    def __init__(self, scope, schema, infer_schema: bool = True):
427        self.scope = scope
428        self.schema = schema
429        self._source_columns = None
430        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
431        self._all_columns = None
432        self._infer_schema = infer_schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
434    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
435        """
436        Get the table for a column name.
437
438        Args:
439            column_name: The column name to find the table for.
440        Returns:
441            The table name if it can be found/inferred.
442        """
443        if self._unambiguous_columns is None:
444            self._unambiguous_columns = self._get_unambiguous_columns(
445                self._get_all_source_columns()
446            )
447
448        table_name = self._unambiguous_columns.get(column_name)
449
450        if not table_name and self._infer_schema:
451            sources_without_schema = tuple(
452                source
453                for source, columns in self._get_all_source_columns().items()
454                if not columns or "*" in columns
455            )
456            if len(sources_without_schema) == 1:
457                table_name = sources_without_schema[0]
458
459        if table_name not in self.scope.selected_sources:
460            return exp.to_identifier(table_name)
461
462        node, _ = self.scope.selected_sources.get(table_name)
463
464        if isinstance(node, exp.Subqueryable):
465            while node and node.alias != table_name:
466                node = node.parent
467
468        node_alias = node.args.get("alias")
469        if node_alias:
470            return exp.to_identifier(node_alias.this)
471
472        return exp.to_identifier(
473            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
474        )

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns

All available columns of all sources in this scope

def get_source_columns(self, name, only_visible=False):
485    def get_source_columns(self, name, only_visible=False):
486        """Resolve the source columns for a given source `name`"""
487        if name not in self.scope.sources:
488            raise OptimizeError(f"Unknown table: {name}")
489
490        source = self.scope.sources[name]
491
492        # If referencing a table, return the columns from the schema
493        if isinstance(source, exp.Table):
494            return self.schema.column_names(source, only_visible)
495
496        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
497            return source.expression.alias_column_names
498
499        # Otherwise, if referencing another scope, return that scope's named selects
500        return source.expression.named_selects

Resolve the source columns for a given source name