Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name, name_sequence
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def preprocess(
 13    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
 14) -> t.Callable[[Generator, exp.Expression], str]:
 15    """
 16    Creates a new transform by chaining a sequence of transformations and converts the resulting
 17    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
 18    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
 19
 20    Args:
 21        transforms: sequence of transform functions. These will be called in order.
 22
 23    Returns:
 24        Function that can be used as a generator transform.
 25    """
 26
 27    def _to_sql(self, expression: exp.Expression) -> str:
 28        expression_type = type(expression)
 29
 30        expression = transforms[0](expression)
 31        for transform in transforms[1:]:
 32            expression = transform(expression)
 33
 34        _sql_handler = getattr(self, expression.key + "_sql", None)
 35        if _sql_handler:
 36            return _sql_handler(expression)
 37
 38        transforms_handler = self.TRANSFORMS.get(type(expression))
 39        if transforms_handler:
 40            if expression_type is type(expression):
 41                if isinstance(expression, exp.Func):
 42                    return self.function_fallback_sql(expression)
 43
 44                # Ensures we don't enter an infinite loop. This can happen when the original expression
 45                # has the same type as the final expression and there's no _sql method available for it,
 46                # because then it'd re-enter _to_sql.
 47                raise ValueError(
 48                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
 49                )
 50
 51            return transforms_handler(self, expression)
 52
 53        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
 54
 55    return _to_sql
 56
 57
 58def unnest_generate_date_array_using_recursive_cte(
 59    bubble_up_recursive_cte: bool = False,
 60) -> t.Callable[[exp.Expression], exp.Expression]:
 61    def _unnest_generate_date_array_using_recursive_cte(
 62        expression: exp.Expression,
 63    ) -> exp.Expression:
 64        if (
 65            isinstance(expression, exp.Unnest)
 66            and isinstance(expression.parent, (exp.From, exp.Join))
 67            and len(expression.expressions) == 1
 68            and isinstance(expression.expressions[0], exp.GenerateDateArray)
 69        ):
 70            generate_date_array = expression.expressions[0]
 71            start = generate_date_array.args.get("start")
 72            end = generate_date_array.args.get("end")
 73            step = generate_date_array.args.get("step")
 74
 75            if not start or not end or not isinstance(step, exp.Interval):
 76                return expression
 77
 78            start = exp.cast(start, "date")
 79            date_add = exp.func(
 80                "date_add", "date_value", exp.Literal.number(step.name), step.args.get("unit")
 81            )
 82            cast_date_add = exp.cast(date_add, "date")
 83
 84            base_query = exp.select(start.as_("date_value"))
 85            recursive_query = (
 86                exp.select(cast_date_add.as_("date_value"))
 87                .from_("_generated_dates")
 88                .where(cast_date_add < exp.cast(end, "date"))
 89            )
 90            cte_query = base_query.union(recursive_query, distinct=False)
 91            generate_dates_query = exp.select("date_value").from_("_generated_dates")
 92
 93            query_to_add_cte: exp.Query = generate_dates_query
 94
 95            if bubble_up_recursive_cte:
 96                parent: t.Optional[exp.Expression] = expression.parent
 97
 98                while parent:
 99                    if isinstance(parent, exp.Query):
100                        query_to_add_cte = parent
101                    parent = parent.parent
102
103            query_to_add_cte.with_(
104                "_generated_dates(date_value)", as_=cte_query, recursive=True, copy=False
105            )
106            return generate_dates_query.subquery("_generated_dates")
107
108        return expression
109
110    return _unnest_generate_date_array_using_recursive_cte
111
112
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Unnests GENERATE_SERIES or SEQUENCE table references.

    A table whose `this` is a GenerateSeries becomes `UNNEST(<series>)`. If the table was
    aliased, the alias becomes the UNNEST's column alias under the table alias `_u`.
    """
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])
    table_alias = expression.alias
    if not table_alias:
        return unnested

    return exp.alias_(unnested, alias="_u", table=[table_alias], copy=False)
124
125
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Unqualified GROUP BY columns that name an aliased projection of the enclosing SELECT are
    replaced by that projection's 1-based position.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    select = expression.parent
    if not (isinstance(expression, exp.Group) and isinstance(select, exp.Select)):
        return expression

    # alias name -> 1-based projection position (a later duplicate alias overrides an earlier one)
    position_by_alias = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for key in expression.expressions:
        is_alias_reference = (
            isinstance(key, exp.Column) and not key.table and key.name in position_by_alias
        )
        if is_alias_reference:
            key.replace(exp.Literal.number(position_by_alias.get(key.name)))

    return expression
157
158
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    The SELECT loses its DISTINCT ON and ORDER BY and becomes the subquery `_t`. The subquery
    additionally projects `ROW_NUMBER() OVER (PARTITION BY <distinct on columns> ORDER BY
    <original order, or the distinct on columns>)`. The outer query keeps only the rows where
    that number is 1.

    The outer projections are fresh column references to the subquery's output names. The inner
    projection nodes are not reused, so no AST node is shared between the two queries, and
    qualified columns (`t.a`) or computed projections still resolve outside `_t`. Projections
    without a usable output name, or whose column name repeats an earlier one, are aliased inside
    the subquery.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        row_number = find_new_name(expression.named_selects, "_row_number")

        # Build the outer projections before the window is added to the inner query.
        outer_selects: t.List[exp.Expression] = []
        taken_names = {row_number}
        for projection in list(expression.selects):
            if projection.is_star:
                # The subquery's columns can't be enumerated here, so select all of them.
                outer_selects = [exp.Star()]
                break

            if isinstance(projection, exp.Alias):
                identifier = projection.args["alias"]
            elif isinstance(projection, exp.Column) and projection.output_name not in taken_names:
                # A plain column is addressable by its own name from the outer query.
                identifier = projection.this
            else:
                alias = find_new_name(taken_names, projection.output_name or "_col")
                quoted = (
                    projection.this.args.get("quoted")
                    if isinstance(projection, exp.Column)
                    else None
                )
                projection = projection.replace(exp.alias_(projection, alias, quoted=quoted))
                identifier = projection.args["alias"]

            taken_names.add(identifier.name)
            outer_selects.append(exp.column(identifier.copy()))

        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        # The original ORDER BY decides which row of each partition is kept.
        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        expression.select(exp.alias_(window, row_number), copy=False)

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression
198
199
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references. A copy of that
    expression is used, so the projection and the window never share an AST node.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection an alias so the outer query can refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Preserve the quoting of the projection's name when referencing it from outside.
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            # Copy: `expr` is still the child of a projection in this SELECT.
                            column.replace(expr.copy())

                # Project the window in the subquery and filter on its alias outside.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Columns only referenced by QUALIFY must be exposed by the subquery.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
262
263
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip precision/size parameters from every data type inside `expression`.

    Some dialects only allow the precision for parameterized types to be defined in the DDL and not
    in other expressions. Each DataType keeps only the arguments that are not DataTypeParam nodes.
    """

    def _is_kept(argument: exp.Expression) -> bool:
        return not isinstance(argument, exp.DataTypeParam)

    for data_type in expression.find_all(exp.DataType):
        data_type.set("expressions", list(filter(_is_kept, data_type.expressions)))

    return expression
275
276
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    For a SELECT, collect the aliases of UNNESTs used directly as FROM/JOIN sources in its scope.
    Then, for every column in the SELECT, drop a `table` part that matches one of them; failing
    that, drop a matching `db` part.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in unnest_aliases:
            column.set("table", None)
        elif column.db in unnest_aliases:
            column.set("db", None)

    return expression
295
296
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    - An UNNEST in the FROM clause becomes a table-valued EXPLODE, or POSEXPLODE when the UNNEST
      has an offset. Its table alias and column aliases are kept.
    - Each UNNEST join is removed. Each unnested array that has a column alias at the same
      position becomes one `LATERAL VIEW EXPLODE(...)`.

    NOTE(review): arrays without a matching alias column produce no LATERAL VIEW. That includes
    every array of an unaliased UNNEST join, so such a join is dropped entirely.
    """
    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
            this, *expressions = unnest.expressions
            unnest.replace(
                exp.Table(
                    this=udtf(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        # Iterate over a snapshot: joins are removed from the live list below, and removing from
        # the list being iterated would skip the join that follows each removed one.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
337
338
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the first index produced by the generated series (0 or 1 in practice; the
            code special-cases `index_offset != 1`). It is the `start` of the GENERATE_SERIES.

    Returns:
        A transform function suitable for `preprocess`.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already used by projections / sources, so generated aliases don't collide.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Pick a fresh name and reserve it in `names`.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # ARRAY_SIZE(...) of every exploded array; their GREATEST bounds the series below.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # UNNEST(GENERATE_SERIES(index_offset)) with a generated table alias and the column
            # alias `series_alias`; its `end` is filled in after the loop.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into an Alias node (`alias`) whose name will hold
                    # the exploded value; reuse user-provided aliases when present.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # When the array is NULL or empty (size of COALESCE(arr, []) is 0), unnest
                        # a one-element array built from a safe/offset bracket access instead.
                        # NOTE(review): presumably this yields a single NULL element so OUTER
                        # semantics keep a row; confirm against the target dialect's bracket SQL.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    # A position alias is always needed: it names the UNNEST offset column.
                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Replace the EXPLODE call with
                    # IF(series.pos = unnest.pos, unnest.<explode_alias>), i.e. the element at the
                    # current series position (NULL otherwise).
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also exposes the position: insert a matching IF(...) AS pos
                        # projection right after the value projection.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The series source is added once, on the first exploded projection.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # With a non-1 offset the last valid position is ARRAY_SIZE - 1.
                    if index_offset != 1:
                        size = size - 1

                    # Keep the row whose element position equals the series position, or, once
                    # the series runs past this array's last position, the row for the last
                    # element (the IF above then yields NULL for this array).
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must cover the longest exploded array.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
488
489
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    A percentile whose `this` is the ordered column and whose `expression` is the quantile is
    rebuilt as `<percentile>(quantile) WITHIN GROUP (ORDER BY column)`. Percentiles already inside
    a WITHIN GROUP, or without a second argument, are returned unchanged.
    """
    if (
        not isinstance(expression, exp.PERCENTILES)
        or isinstance(expression.parent, exp.WithinGroup)
        or not expression.expression
    ):
        return expression

    ordered_column = expression.this.pop()
    quantile = expression.expression.pop()
    expression.set("this", quantile)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_column)])
    return exp.WithinGroup(this=expression, expression=ordering)
503
504
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    `<percentile>(q) WITHIN GROUP (ORDER BY x)` is replaced by `ApproxQuantile(this=x, quantile=q)`.
    Any other expression is returned unchanged.
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    quantile = expression.this.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    replacement = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(replacement)
517
518
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Only recursive WITH clauses are affected, and only CTEs whose alias has no column list yet.
    For set operations, the projections of the left-hand query are used. Projections without a
    name get generated names from a `_c_` sequence shared by the whole WITH clause.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    fresh_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            continue

        body = cte.this
        if isinstance(body, exp.SetOperation):
            body = body.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or fresh_name())
            for projection in body.selects
        ]
        cte_alias.set("columns", column_names)

    return expression
536
537
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Only a string literal 'epoch' (case-insensitive) cast to a temporal type is rewritten, e.g.
    `CAST('epoch' AS TIMESTAMP)` becomes `CAST('1970-01-01 00:00:00' AS TIMESTAMP)`. A column
    that happens to be named `epoch` is a reference, not the special value, and is left untouched.
    """
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.this.is_string
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
548
549
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each SEMI/ANTI join with an ON condition is removed. It is replaced by a WHERE predicate
    `EXISTS (SELECT 1 FROM <joined source> WHERE <on>)`, negated for ANTI joins. Joins without an
    ON condition are left untouched.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: `join.pop()` removes the join from the live list, which would
        # otherwise make the loop skip the join that follows it.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
565
566
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    The result is `<query with LEFT join> UNION <copy with RIGHT join>`. CTEs stay only on the
    left-hand query. ORDER BY, LIMIT and OFFSET stay only on the right-hand query: in SQL,
    modifiers after the last branch of a UNION apply to the combined result. Leaving them on the
    left branch would produce `SELECT ... ORDER BY x UNION SELECT ...`, which most engines reject.

    NOTE(review): `exp.union` is called without `distinct=False`, so the branches are combined with
    a de-duplicating UNION. Duplicate rows that a real FULL OUTER JOIN returns are collapsed.
    A UNION ALL with an anti-join filter on the right branch would be exact.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            # Result-level modifiers are kept only on the RIGHT side (the copy).
            for modifier in ("order", "limit", "offset"):
                expression.set(modifier, None)
            index, full_outer_join = full_outer_joins[0]
            full_outer_join.set("side", "left")
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side

            return exp.union(expression, expression_copy, copy=False)

    return expression
591
592
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    If there is no top-level WITH, the first nested WITH found becomes the top-level one.
    The CTEs of every other nested WITH are merged into it: directly before the top-level CTE
    that contains them, or at the end when they are not inside a CTE. A recursive nested WITH
    makes the top-level WITH recursive.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    root_with = expression.args.get("with")

    for nested_with in expression.find_all(exp.With):
        if nested_with.parent is expression:
            continue

        if not root_with:
            root_with = nested_with.pop()
            expression.set("with", root_with)
            continue

        if nested_with.recursive:
            root_with.set("recursive", True)

        enclosing_cte = nested_with.find_ancestor(exp.CTE)
        nested_with.pop()

        if not enclosing_cte:
            root_with.set("expressions", root_with.expressions + nested_with.expressions)
            continue

        # Definitions must precede the CTE that uses them.
        position = root_with.expressions.index(enclosing_cte)
        root_with.expressions[position:position] = nested_with.expressions
        root_with.set("expressions", root_with.expressions)

    return expression
630
631
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    For every node of `expression`, the optimizer's canonicalize `ensure_bools` helper is invoked
    with a callback. The callback rewrites a node as `<node> <> 0` when the node is a number
    literal, is typed UNKNOWN or numeric (subquery predicates excepted), or is an untyped column.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_bools

    def _needs_explicit_bool(node: exp.Expression) -> bool:
        if node.is_number:
            return True
        if not isinstance(node, exp.SubqueryPredicate) and node.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        ):
            return True
        return isinstance(node, exp.Column) and not node.type

    def _ensure_bool(node: exp.Expression) -> None:
        if _needs_explicit_bool(node):
            node.replace(node.neq(0))

    for node in expression.walk():
        _canonicalize_bools(node, _ensure_bool)

    return expression
651
652
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Strip the table, db and catalog qualifiers from every column, keeping only its name."""
    for column in expression.find_all(exp.Column):
        # `parts` ends with the column's own name; everything before it is a qualifier.
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
660
661
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Drop UNIQUE column constraints from a CREATE statement.

    The parent of each UniqueColumnConstraint, i.e. the node wrapping the constraint, is removed.
    Expects a `exp.Create`.
    """
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        wrapper = unique.parent
        if wrapper:
            wrapper.pop()

    return expression
669
670
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map CREATE TEMPORARY TABLE statements onto temporary views.

    A temporary CREATE TABLE ... AS <query> becomes `CREATE TEMPORARY VIEW <name> AS <query>`.
    A temporary CREATE TABLE without a query is handed to `tmp_storage_provider`, which is the
    identity function by default. Everything else is returned unchanged. Expects a `exp.Create`.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_list = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in property_list)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
693
694
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Column names are matched case-insensitively. Only CREATE TABLE/VIEW statements with an
    explicit schema are affected. Expects a `exp.Create`.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if (
        not partitioned_by
        or not partitioned_by.this
        or isinstance(partitioned_by.this, exp.Schema)
    ):
        return expression

    schema = expression.this
    partition_names = {ref.name.upper() for ref in partitioned_by.this.expressions}
    moved = [col for col in schema.expressions if col.name.upper() in partition_names]

    schema.set("expressions", [e for e in schema.expressions if e not in moved])
    partitioned_by.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=moved)))
    expression.set("this", schema)

    return expression
716
717
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When PARTITIONED BY holds a schema made only of typed column definitions, those definitions
    are appended to the table's own expressions. PARTITIONED BY is reduced to a tuple of the
    column names. Expects a `exp.Create`.
    """
    assert isinstance(expression, exp.Create)

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if not (
        partitioned_by
        and partitioned_by.this
        and isinstance(partitioned_by.this, exp.Schema)
    ):
        return expression

    column_defs = partitioned_by.this.expressions
    if not all(isinstance(e, exp.ColumnDef) and e.kind for e in column_defs):
        return expression

    name_tuple = exp.Tuple(expressions=[exp.to_identifier(e.this) for e in column_defs])
    target = expression.this
    for column_def in column_defs:
        target.append("expressions", column_def)
    partitioned_by.set("this", name_tuple)

    return expression
741
742
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

    Every `key := value` (PropertyEQ) argument of a Struct becomes `value AS key`. Other arguments
    are kept as they are.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    converted = []
    for argument in expression.expressions:
        if isinstance(argument, exp.PropertyEQ):
            argument = exp.alias_(argument.expression, argument.this)
        converted.append(argument)

    expression.set("expressions", converted)
    return expression
755
756
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Every binary predicate of a scope that contains a (+)-marked column is moved out of the
    scope's WHERE clause into the ON condition of a LEFT JOIN against the marked table.
    Predicates marking the same table are AND-ed together. Joins of tables that no marked
    predicate refers to are preserved. Scopes without any marked column are left untouched.

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.

    Raises:
        AssertionError: if a marked column is not part of a binary predicate, both sides of a
            predicate are marked, a marked column is unqualified, a predicate marks columns of
            more than one table, or the FROM table is marked and no other table can replace it.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                # Clearing the mark also makes the outer loop skip the other marked columns
                # of this same predicate
                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # All marked columns share one table (asserted above), so the last `col` identifies it
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # No (+) marker in this scope: keep its FROM, joins and WHERE exactly as they are
        if not new_joins:
            continue

        new_from_name = None
        if query_from.alias_or_name in new_joins:
            # The FROM table is the outer-joined one, so promote the first table (in original
            # join order) that is not outer-joined to be the new FROM source
            only_old_joins = [name for name in old_joins if name not in new_joins]
            assert only_old_joins, "Cannot determine which table to use in the new FROM clause"

            new_from_name = only_old_joins[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        # Preserve joins of tables that no marked predicate refers to. They are placed before
        # the new LEFT JOINs so those ON conditions can reference them. Plain comma joins are
        # turned into CROSS JOINs, because mixing comma joins with explicit JOINs restricts
        # which tables an ON condition may reference.
        kept_joins = []
        for name, join in old_joins.items():
            if name in new_joins or name == new_from_name:
                continue
            if not any(join.args.get(arg) for arg in ("kind", "side", "method", "on", "using")):
                join.set("kind", "CROSS")
            kept_joins.append(join)

        query.set("joins", [*kept_joins, *new_joins.values()])

        # Every WHERE predicate may have been moved into ON conditions
        if not where.this:
            where.pop()

    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Build a generator transform that chains `transforms` and renders the resulting expression.

    The returned function applies every transform in order. It then renders the result with
    the generator's `<key>_sql` method when one exists. Otherwise it uses the generator's
    `TRANSFORMS` entry for the result's type. When the result still has the same type as the
    input, that entry would lead straight back into this function: `Func` nodes therefore
    fall back to `function_fallback_sql`, and any other type raises instead.

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform. It raises ValueError when the
        transformed expression cannot be rendered by any of the mechanisms above.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        original_type = type(expression)

        transformed = transforms[0](expression)
        for step in transforms[1:]:
            transformed = step(transformed)

        sql_method = getattr(self, transformed.key + "_sql", None)
        if sql_method:
            return sql_method(transformed)

        transformed_type = type(transformed)
        handler = self.TRANSFORMS.get(transformed_type)
        if not handler:
            raise ValueError(f"Unsupported expression type {transformed.__class__.__name__}.")

        if transformed_type is not original_type:
            return handler(self, transformed)

        if isinstance(transformed, exp.Func):
            return self.function_fallback_sql(transformed)

        # Same type in and out with no `_sql` method: the handler would re-enter this function,
        # so fail loudly instead of recursing forever.
        raise ValueError(
            f"Expression type {transformed.__class__.__name__} requires a _sql method in order to be transformed."
        )

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.

def unnest_generate_date_array_using_recursive_cte( bubble_up_recursive_cte: bool = False) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
 59def unnest_generate_date_array_using_recursive_cte(
 60    bubble_up_recursive_cte: bool = False,
 61) -> t.Callable[[exp.Expression], exp.Expression]:
 62    def _unnest_generate_date_array_using_recursive_cte(
 63        expression: exp.Expression,
 64    ) -> exp.Expression:
 65        if (
 66            isinstance(expression, exp.Unnest)
 67            and isinstance(expression.parent, (exp.From, exp.Join))
 68            and len(expression.expressions) == 1
 69            and isinstance(expression.expressions[0], exp.GenerateDateArray)
 70        ):
 71            generate_date_array = expression.expressions[0]
 72            start = generate_date_array.args.get("start")
 73            end = generate_date_array.args.get("end")
 74            step = generate_date_array.args.get("step")
 75
 76            if not start or not end or not isinstance(step, exp.Interval):
 77                return expression
 78
 79            start = exp.cast(start, "date")
 80            date_add = exp.func(
 81                "date_add", "date_value", exp.Literal.number(step.name), step.args.get("unit")
 82            )
 83            cast_date_add = exp.cast(date_add, "date")
 84
 85            base_query = exp.select(start.as_("date_value"))
 86            recursive_query = (
 87                exp.select(cast_date_add.as_("date_value"))
 88                .from_("_generated_dates")
 89                .where(cast_date_add < exp.cast(end, "date"))
 90            )
 91            cte_query = base_query.union(recursive_query, distinct=False)
 92            generate_dates_query = exp.select("date_value").from_("_generated_dates")
 93
 94            query_to_add_cte: exp.Query = generate_dates_query
 95
 96            if bubble_up_recursive_cte:
 97                parent: t.Optional[exp.Expression] = expression.parent
 98
 99                while parent:
100                    if isinstance(parent, exp.Query):
101                        query_to_add_cte = parent
102                    parent = parent.parent
103
104            query_to_add_cte.with_(
105                "_generated_dates(date_value)", as_=cte_query, recursive=True, copy=False
106            )
107            return generate_dates_query.subquery("_generated_dates")
108
109        return expression
110
111    return _unnest_generate_date_array_using_recursive_cte
def unnest_generate_series( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Unnests GENERATE_SERIES or SEQUENCE table references.

    A Table whose source is an exp.GenerateSeries is replaced by UNNEST(<series>). If the table
    reference was aliased, that alias becomes the UNNEST's column alias and the UNNEST itself
    is aliased `_u`. Every other expression is returned unchanged.
    """
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnest = exp.Unnest(expressions=[series])
    table_alias = expression.alias
    if not table_alias:
        return unnest

    return exp.alias_(unnest, alias="_u", table=[table_alias], copy=False)

Unnests GENERATE_SERIES or SEQUENCE table references.

def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Only unqualified column references whose name matches an aliased projection of the parent
    SELECT are rewritten, into that projection's 1-based position.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    select = expression.parent
    if not (isinstance(expression, exp.Group) and isinstance(select, exp.Select)):
        return expression

    position_by_alias = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for key in expression.expressions:
        if not isinstance(key, exp.Column) or key.table:
            continue

        position = position_by_alias.get(key.name)
        if position is not None:
            key.replace(exp.Literal.number(position))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The DISTINCT ON columns become the PARTITION BY of a ROW_NUMBER() window. The query's ORDER BY,
    if any, is moved into that window; otherwise the DISTINCT ON columns order it. The original
    query is wrapped as `_t`, and only the rows numbered 1 are kept.

    NOTE(review): the outer SELECT reuses the original projection nodes verbatim, so a
    projection such as `a + 1 AS b` is re-evaluated over `_t`, which may not expose `a`.
    Confirm whether callers rely on this.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    distinct_on = distinct.args.get("on") if distinct else None
    if not (distinct_on and isinstance(distinct_on, exp.Tuple)):
        return expression

    partition_cols = distinct.pop().args["on"].expressions
    # Snapshot of the projections before the row-number column is added below
    projections = expression.selects
    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    window = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    order = expression.args.get("order")
    if order:
        window_order = order.pop()
    else:
        window_order = exp.Order(expressions=[col.copy() for col in partition_cols])
    window.set("order", window_order)

    expression.select(exp.alias_(window, row_number_alias), copy=False)

    outer = exp.select(*projections, copy=False)
    outer = outer.from_(expression.subquery("_t", copy=False), copy=False)
    return outer.where(exp.column(row_number_alias).eq(1), copy=False)

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        For a SELECT with a QUALIFY clause, a new SELECT over the original query (aliased `_t`)
        whose WHERE is the former QUALIFY condition; any other expression is returned unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh `_c`-based alias so the outer query can name it
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # Outer projection for an inner one: a column (keeping the identifier's quoting) when the
        # name comes from an Identifier, otherwise the plain name string
        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # Built before any helper projections are appended below, so the outer query only
        # exposes the original projections
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        # Detach QUALIFY from the inner query; its condition becomes the outer WHERE
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With SELECT *, every column is already exposed by the subquery, so only windows
        # need to be projected
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                # Inline projection aliases used inside the window with their aliased expressions
                # NOTE(review): `expr` is inserted without a copy - confirm node sharing is fine
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window under a fresh `_w`-based alias and refer to it by that name
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                # If the window is the whole QUALIFY condition, the filter is just the alias
                # column; otherwise substitute it inside the larger condition
                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Column referenced by QUALIFY but not selected: project it so `_t` exposes it
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.

    Every exp.DataType found in `expression` loses its DataTypeParam arguments; arguments of
    any other kind are kept. The tree is modified in place and returned.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_args = [arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)]
        data_type.set("expressions", kept_args)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    For a SELECT, this collects the aliases of the UNNESTs that sit directly in its FROM/JOIN
    clauses within its own scope. Any column in the SELECT's tree whose table qualifier is one
    of those aliases loses it; failing that, a matching db qualifier is removed instead.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            aliases.add(unnest.alias)

    if not aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in aliases:
            column.set("table", None)
        elif column.db in aliases:
            column.set("db", None)

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    For a SELECT:
    - an UNNEST in the FROM clause becomes a table whose source is EXPLODE, or POSEXPLODE when
      the UNNEST has an `offset` argument; the UNNEST's alias is kept;
    - every UNNEST used as a join is removed from the joins and turned into one LATERAL VIEW
      per unnested expression, paired positionally with the alias's column names.

    NOTE(review): a joined UNNEST without an alias (or with fewer alias columns than unnested
    expressions) yields fewer LATERAL VIEWs than expressions - confirm this is intended.
    """
    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
            this, *expressions = unnest.expressions
            unnest.replace(
                exp.Table(
                    this=udtf(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        # Iterate over a snapshot: joins are removed from the query's joins list inside the
        # loop, and removing items from the list being iterated would skip the following join
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode.

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Returns a transform that rewrites SELECT projections containing [POS]EXPLODE(arr). The query
    gets a single position series, UNNEST(GENERATE_SERIES(index_offset, ...)), joined once, plus
    one CROSS JOIN UNNEST(arr) WITH OFFSET per exploded array. Each exploded projection becomes
    IF(series_pos = offset, element); for POSEXPLODE, a matching position projection is added
    right after it. WHERE conditions keep the rows where the series lines up with each array.

    Args:
        index_offset: start value of the generated position series; it also adjusts the series
            end and the per-array size comparisons (presumably the target dialect's array index
            base - TODO confirm).

    Returns:
        The transform function.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already used by projections / sources, so generated aliases don't collide
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            # Return a fresh name derived from `name` and reserve it in `names`
            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            # ARRAY_SIZE of every exploded array; their GREATEST becomes the series end below
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # Position series shared by all exploded arrays; its END is set after the loop
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # `x AS col` names the element; `AS (pos, col)` names position and element;
                    # anything else is wrapped in an empty alias that is filled in below
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # EXPLODE_OUTER: when the array is NULL or empty, unnest a one-element array
                    # built from a safe offset access arr[0] instead, so the row is still produced
                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    # Every UNNEST needs an offset column, even for a plain EXPLODE
                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The element is emitted only where the series position equals its offset
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    # POSEXPLODE also projects the position, right after the element projection
                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # First exploded array: attach the shared series (CROSS JOIN, or FROM when
                    # the query has no FROM clause)
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Presumably turns the size into this array's last offset - TODO confirm
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series lines up with this array's offset, or where the
                    # series is past this array's last position and the array sits on that last
                    # offset (the IF projection then yields NULL for this shorter array)
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            # Series end: the largest array size, shifted according to index_offset
            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest

Convert explode/posexplode into unnest.

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    Applies to exp.PERCENTILES nodes that have a second (`expression`) argument and are not
    already wrapped in a WithinGroup. The node's `this` argument becomes the ORDER BY key of
    the new WITHIN GROUP, and its `expression` argument becomes the function's `this`.
    """
    is_candidate = (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if not is_candidate:
        return expression

    ordered_by = expression.this.pop()
    expression.set("this", expression.expression.pop())
    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=ordered_by)]),
    )

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    A WithinGroup wrapping an exp.PERCENTILES function with an ORDER BY is replaced by an
    exp.ApproxQuantile. The quantile is the percentile function's `this`, and the input value
    is the `this` of the first exp.Ordered found inside the WithinGroup.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, exp.PERCENTILES):
        return expression
    if not isinstance(expression.expression, exp.Order):
        return expression

    quantile = percentile.this
    first_ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    return expression.replace(exp.ApproxQuantile(this=first_ordered.this, quantile=quantile))

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Only a WITH RECURSIVE clause is affected. CTEs that already declare columns are left alone.
    For the others, each projection of the CTE query contributes its output name; when the CTE
    query is a set operation, its `this` side is used. Unnamed projections get names from a
    `name_sequence("_c_")` generator that is shared by every CTE of the clause.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    fresh_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            continue

        body = cte.this
        if isinstance(body, exp.SetOperation):
            body = body.this

        column_names = [exp.to_identifier(sel.alias_or_name or fresh_name()) for sel in body.selects]
        cte_alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Matches CAST / TRY_CAST whose operand name is 'epoch' (compared case-insensitively) and
    whose target type is in exp.DataType.TEMPORAL_TYPES. The operand is replaced with the
    string literal '1970-01-01 00:00:00'.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression
    if expression.name.lower() != "epoch":
        return expression
    if expression.to.this not in exp.DataType.TEMPORAL_TYPES:
        return expression

    expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))
    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    ``SEMI JOIN t ON cond`` becomes the predicate ``EXISTS(SELECT 1 FROM t WHERE cond)``;
    an ANTI join becomes the negation of that predicate. The predicate is added to the
    query's WHERE clause and the join is removed. Joins without an ON condition are
    left untouched.

    Args:
        expression: any node; only ``Select`` nodes are modified (in place).

    Returns:
        The same node.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot of the joins: join.pop() below removes the join from
        # the Select's "joins" list, and when that removal happens in place, iterating
        # the live list would skip the join that follows each removed one (e.g. two
        # consecutive SEMI joins).
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    The original query becomes the LEFT-join operand and a copy of it the RIGHT-join
    operand. LIMIT and ORDER BY are kept only on the copy, which is written after the
    UNION keyword, so they apply to the union's result instead of appearing in front
    of UNION (which most dialects reject). CTEs are kept only on the original, the
    first operand.

    NOTE(review): the operands are combined with ``exp.union``'s default ``distinct``
    setting; if that is UNION DISTINCT, genuinely duplicated result rows collapse.

    Args:
        expression: any node; only ``Select`` nodes are transformed.

    Returns:
        A union when the Select has exactly one FULL OUTER join, otherwise the input.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            # LIMIT / ORDER BY must appear only once, after the last union operand.
            expression.set("limit", None)
            expression.set("order", None)
            index, full_outer_join = full_outer_joins[0]
            full_outer_join.set("side", "left")
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side

            return exp.union(expression, expression_copy, copy=False)

    return expression

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Hoist every nested WITH clause into the statement's top-level WITH.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only accept CTEs at the
    top level, so a query such as

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    is a syntax error there. Behaviour:
    - If there is no top-level WITH yet, the first nested one found is detached and
      becomes it.
    - Otherwise, the nested CTEs are spliced into the top-level WITH. When the nested
      WITH sits inside a CTE, they go right before that CTE; otherwise they are
      appended at the end.
    - A nested RECURSIVE flag is propagated to the top-level WITH.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the root of the statement; modified in place.

    Returns:
        The same node.
    """
    outer_with = expression.args.get("with")

    for nested_with in expression.find_all(exp.With):
        if nested_with.parent is expression:
            # This already is the top-level WITH.
            continue

        if not outer_with:
            outer_with = nested_with.pop()
            expression.set("with", outer_with)
            continue

        if nested_with.recursive:
            outer_with.set("recursive", True)

        enclosing_cte = nested_with.find_ancestor(exp.CTE)
        nested_with.pop()
        hoisted = nested_with.expressions

        if enclosing_cte:
            # Dependencies must be defined before the CTE that uses them.
            position = outer_with.expressions.index(enclosing_cte)
            outer_with.expressions[position:position] = hoisted
            outer_with.set("expressions", outer_with.expressions)
        else:
            outer_with.set("expressions", outer_with.expressions + hoisted)

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Turn numeric values used in conditions into explicit boolean comparisons.

    Every node of the tree is handed to ``sqlglot.optimizer.canonicalize.ensure_bools``
    together with a callback. Presumably that helper applies the callback to operands
    used as conditions. The callback replaces an operand ``x`` with ``x.neq(0)`` when
    one of the following holds:
    - ``x`` is a number;
    - ``x`` is typed UNKNOWN or numeric (subquery predicates excluded);
    - ``x`` is a column without a type.

    Args:
        expression: the tree to rewrite (in place).

    Returns:
        The same node.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as canonical_ensure_bools

    def _needs_comparison(operand: exp.Expression) -> bool:
        if operand.is_number:
            return True
        if not isinstance(operand, exp.SubqueryPredicate) and operand.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        ):
            return True
        return isinstance(operand, exp.Column) and not operand.type

    def _compare_with_zero(operand: exp.Expression) -> None:
        if _needs_comparison(operand):
            operand.replace(operand.neq(0))

    for node in expression.walk():
        canonical_ensure_bools(node, _compare_with_zero)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """
    Strip the qualifiers from every column reference in the tree.

    For each ``Column``, every part except the last one (the column name itself) is
    popped, i.e. its table, db and catalog parts.

    Args:
        expression: the tree to rewrite (in place).

    Returns:
        The same node.
    """
    for col in expression.find_all(exp.Column):
        qualifiers = col.parts[:-1]  # everything in front of the column name
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Strip UNIQUE constraints from a CREATE statement.

    When the ``UniqueColumnConstraint`` is wrapped in a column-level ``ColumnConstraint``
    or a named ``Constraint``, the wrapper is removed together with it. Otherwise only
    the constraint node itself is removed.

    An example of the unwrapped case is an unnamed table-level ``UNIQUE (a, b)``, which
    appears to be parsed as a direct child of the table's Schema. Popping its parent
    there would remove the whole Schema, i.e. the table name and all column definitions.

    Args:
        expression: the statement to transform; must be an ``exp.Create`` (asserted).

    Returns:
        The same ``Create`` node, modified in place.
    """
    assert isinstance(expression, exp.Create)

    # Materialize the matches first so the tree is not mutated while it is being walked.
    for constraint in list(expression.find_all(exp.UniqueColumnConstraint)):
        wrapper = constraint.parent
        if isinstance(wrapper, (exp.ColumnConstraint, exp.Constraint)):
            wrapper.pop()
        else:
            constraint.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = lambda e: e) -> sqlglot.expressions.Expression:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map temporary-table CREATE statements onto temporary views.

    Args:
        expression: the statement to transform; must be an ``exp.Create`` (asserted).
        tmp_storage_provider: hook applied to a temporary table created without an AS
            query; the default returns its argument unchanged.

    Returns:
        - For a temporary CTAS: a new ``CREATE TEMPORARY VIEW`` carrying the original
          target and query only.
        - For a temporary table without a query: the provider's result.
        - Otherwise: the input unchanged.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_nodes = properties.expressions if properties else []
    is_temporary = any(isinstance(node, exp.TemporaryProperty) for node in property_nodes)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    query = expression.expression
    if not query:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(kind="TEMPORARY VIEW", this=expression.this, expression=query)
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    Turn a PARTITIONED BY list of column names into a Hive-style partition schema.

    In Hive, PARTITIONED BY acts as an extension of a table's schema. For a CREATE
    TABLE/VIEW that has a schema and whose PARTITIONED BY value is not already a
    Schema, the column definitions whose names match the partition names
    (case-insensitively) are moved out of the table schema. They are placed in a new
    Schema under PARTITIONED BY.

    Args:
        expression: the statement to transform; must be an ``exp.Create`` (asserted).

    Returns:
        The same ``Create`` node, modified in place.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if not partitioned_by or not partitioned_by.this or isinstance(partitioned_by.this, exp.Schema):
        return expression

    table_schema = expression.this
    partition_names = {node.name.upper() for node in partitioned_by.this.expressions}
    moved = [col for col in table_schema.expressions if col.name.upper() in partition_names]
    kept = [col for col in table_schema.expressions if col not in moved]

    table_schema.set("expressions", kept)
    partitioned_by.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=moved)))
    expression.set("this", table_schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    This applies when PARTITIONED BY holds a Schema of fully typed column definitions
    (HIVEFORMAT style). Those definitions are appended to the table's schema, and
    PARTITIONED BY is rewritten to a tuple of the bare column names (DATASOURCE style).

    Args:
        expression: the statement to transform; must be an ``exp.Create`` (asserted).

    Returns:
        The same ``Create`` node, modified in place.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        # Copy each name. exp.to_identifier returns an Identifier argument unchanged, so
        # without the copy the same node would live both in the moved ColumnDef and in
        # the new Tuple, with its parent pointing at the Tuple.
        prop_this = exp.Tuple(
            expressions=[exp.to_identifier(e.this.copy()) for e in prop.this.expressions]
        )
        schema = expression.this
        for e in prop.this.expressions:
            schema.append("expressions", e)
        prop.set("this", prop_this)

    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite key/value struct entries as aliases, e.g. STRUCT(1 AS y).

    Each ``PropertyEQ`` entry of a ``Struct`` becomes ``alias_(value, key)``. Other
    entries, and nodes that are not structs, are left untouched.

    Args:
        expression: any node; only ``Struct`` nodes are modified (in place).

    Returns:
        The same node.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    def _as_alias(entry: exp.Expression) -> exp.Expression:
        if isinstance(entry, exp.PropertyEQ):
            return exp.alias_(entry.expression, entry.this)
        return entry

    expression.set("expressions", list(map(_as_alias, expression.expressions)))
    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def eliminate_join_marks( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Every binary predicate containing a (+)-marked column is removed from the WHERE
    clause and used as the ON condition of a LEFT JOIN against the marked column's
    table. Predicates that mark the same table are AND-ed into one ON condition.

    Args:
        expression: The AST to remove join marks from. It is modified in place.

    Returns:
       The AST with join marks removed.

    Raises:
        AssertionError: (via ``assert``) in any of these cases:
            - a marked column is not inside a binary predicate;
            - a marked column is not table-qualified;
            - marks appear on both sides of a predicate;
            - one predicate marks columns of more than one table;
            - the FROM table is marked and no other joined table remains to replace it.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Marked predicates live in WHERE and the tables to LEFT JOIN come from the
        # existing joins, so both are required.
        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            # Skip unmarked columns, including ones whose mark was cleared below.
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Detach the predicate from WHERE; it becomes (part of) a join condition.
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                # Clear the mark so the ON condition is generated without (+).
                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # `col` is the last marked column of the loop above; all marked columns share
            # one table (asserted). A table not found among the joins is the FROM table.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        if query_from.alias_or_name in new_joins:
            # The FROM table itself became the outer-joined side, so a table that only
            # appears among the original joins is promoted to FROM. Walk old_joins
            # (insertion-ordered) instead of a set difference: the iteration order of a
            # set of strings changes between interpreter runs (hash randomization), which
            # made the chosen FROM table, and thus the generated SQL, nondeterministic.
            only_old_joins = [name for name in old_joins if name not in new_joins]
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = only_old_joins[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        # NOTE(review): only the rebuilt LEFT JOINs are kept here. Original joins that no
        # (+) predicate referred to (other than one promoted to FROM) are not carried
        # over; confirm this is intended.
        query.set("joins", list(new_joins.values()))

        # Every WHERE predicate was moved into a join: drop the now-empty WHERE.
        if not where.this:
            where.pop()

    return expression

Remove join marks from an AST. This rule assumes that all marked columns are qualified. If this does not hold for a query, consider running sqlglot.optimizer.qualify first.

For example, `SELECT * FROM a, b WHERE a.id = b.id(+)` is converted to `SELECT * FROM a LEFT JOIN b ON a.id = b.id`.

Arguments:
  • expression: The AST to remove join marks from.
Returns:

The AST with join marks removed.