Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.errors import UnsupportedError
  7from sqlglot.helper import find_new_name, name_sequence
  8
  9
 10if t.TYPE_CHECKING:
 11    from sqlglot._typing import E
 12    from sqlglot.generator import Generator
 13
 14
 15def preprocess(
 16    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
 17) -> t.Callable[[Generator, exp.Expression], str]:
 18    """
 19    Creates a new transform by chaining a sequence of transformations and converts the resulting
 20    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
 21    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
 22
 23    Args:
 24        transforms: sequence of transform functions. These will be called in order.
 25
 26    Returns:
 27        Function that can be used as a generator transform.
 28    """
 29
 30    def _to_sql(self, expression: exp.Expression) -> str:
 31        expression_type = type(expression)
 32
 33        try:
 34            expression = transforms[0](expression)
 35            for transform in transforms[1:]:
 36                expression = transform(expression)
 37        except UnsupportedError as unsupported_error:
 38            self.unsupported(str(unsupported_error))
 39
 40        _sql_handler = getattr(self, expression.key + "_sql", None)
 41        if _sql_handler:
 42            return _sql_handler(expression)
 43
 44        transforms_handler = self.TRANSFORMS.get(type(expression))
 45        if transforms_handler:
 46            if expression_type is type(expression):
 47                if isinstance(expression, exp.Func):
 48                    return self.function_fallback_sql(expression)
 49
 50                # Ensures we don't enter an infinite loop. This can happen when the original expression
 51                # has the same type as the final expression and there's no _sql method available for it,
 52                # because then it'd re-enter _to_sql.
 53                raise ValueError(
 54                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
 55                )
 56
 57            return transforms_handler(self, expression)
 58
 59        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
 60
 61    return _to_sql
 62
 63
 64def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
 65    if isinstance(expression, exp.Select):
 66        count = 0
 67        recursive_ctes = []
 68
 69        for unnest in expression.find_all(exp.Unnest):
 70            if (
 71                not isinstance(unnest.parent, (exp.From, exp.Join))
 72                or len(unnest.expressions) != 1
 73                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
 74            ):
 75                continue
 76
 77            generate_date_array = unnest.expressions[0]
 78            start = generate_date_array.args.get("start")
 79            end = generate_date_array.args.get("end")
 80            step = generate_date_array.args.get("step")
 81
 82            if not start or not end or not isinstance(step, exp.Interval):
 83                continue
 84
 85            alias = unnest.args.get("alias")
 86            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
 87
 88            start = exp.cast(start, "date")
 89            date_add = exp.func(
 90                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
 91            )
 92            cast_date_add = exp.cast(date_add, "date")
 93
 94            cte_name = "_generated_dates" + (f"_{count}" if count else "")
 95
 96            base_query = exp.select(start.as_(column_name))
 97            recursive_query = (
 98                exp.select(cast_date_add)
 99                .from_(cte_name)
100                .where(cast_date_add <= exp.cast(end, "date"))
101            )
102            cte_query = base_query.union(recursive_query, distinct=False)
103
104            generate_dates_query = exp.select(column_name).from_(cte_name)
105            unnest.replace(generate_dates_query.subquery(cte_name))
106
107            recursive_ctes.append(
108                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
109            )
110            count += 1
111
112        if recursive_ctes:
113            with_expression = expression.args.get("with") or exp.With()
114            with_expression.set("recursive", True)
115            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
116            expression.set("with", with_expression)
117
118    return expression
119
120
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Unnests GENERATE_SERIES or SEQUENCE table references.

    A table whose source is a GENERATE_SERIES node is replaced by `UNNEST(<series>)`. If the table
    had an alias, it is kept as the single column alias of a `_u` table alias.
    """
    if not isinstance(expression, exp.Table):
        return expression

    series = expression.this
    if not isinstance(series, exp.GenerateSeries):
        return expression

    unnested = exp.Unnest(expressions=[series])
    table_alias = expression.alias
    if not table_alias:
        return unnested

    return exp.alias_(unnested, alias="_u", table=[table_alias], copy=False)
132
133
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group) or not isinstance(expression.parent, exp.Select):
        return expression

    # Alias -> 1-based position of the aliased projection in the SELECT list.
    alias_positions = {}
    for position, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            alias_positions[projection.alias] = position

    for group_key in list(expression.expressions):
        # Only unqualified column references can refer to a projection alias.
        if not isinstance(group_key, exp.Column) or group_key.table:
            continue

        position = alias_positions.get(group_key.name)
        if position is not None:
            group_key.replace(exp.Literal.number(position))

    return expression
165
166
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The original query becomes a subquery `_t` with an extra ROW_NUMBER() projection partitioned by
    the DISTINCT ON columns, and the outer query keeps only the rows numbered 1.

    Every projection of the subquery is given an alias (when it doesn't already have one), and the
    outer query refers to the projections through those aliases. This keeps computed or qualified
    projections (e.g. `a + 1`, `t.c`) valid outside the subquery. A star projection is forwarded
    as `*`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        row_number = find_new_name(expression.named_selects, "_row_number")
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        outer_selects: t.List[exp.Expression] = []
        taken_names = {row_number}
        for select in list(expression.selects):
            if select.is_star:
                outer_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                alias = find_new_name(taken_names, select.output_name or "_col")
                # Keep the quoting of a plain column so that e.g. "A" stays "A".
                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                select = select.replace(exp.alias_(select, alias, quoted=quoted))

            taken_names.add(select.output_name)
            # Copy the identifier so the inner and outer queries don't share nodes.
            outer_selects.append(exp.column(select.args["alias"].copy()))

        expression.select(exp.alias_(window, row_number), copy=False)

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression
206
207
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        `SELECT <original output names> FROM (<query>) AS _t WHERE <qualify condition>` when
        `expression` is a SELECT with a QUALIFY clause, otherwise `expression` unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh `_c...` alias so the outer query can name it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # Outer projection for an inner one: a column (keeping the identifier's quoting) when the
        # name comes from an Identifier, otherwise the plain name string.
        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # Built now, before helper projections are appended to the inner query below, so the
        # outer query only exposes the original projections.
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        # Detach the QUALIFY condition; it becomes the outer WHERE clause.
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With a star projection, columns are already available to the outer query, so only
        # window functions need to be projected.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                # Inside the window, replace references to projection aliases with the aliased
                # expressions, since the window is about to become a projection itself.
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window under a fresh `_w...` alias and refer to it by that alias.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                # If the window is the whole QUALIFY condition, the outer filter is just its
                # alias; otherwise swap the window for its alias inside the condition.
                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # A column used by QUALIFY but not projected: project it so the outer WHERE sees it.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
270
271
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes every `DataTypeParam` argument from each data type
    found in the given expression; other type arguments (e.g. nested data types) are kept.
    """
    for data_type in expression.find_all(exp.DataType):
        remaining_args = []
        for type_arg in data_type.expressions:
            if not isinstance(type_arg, exp.DataTypeParam):
                remaining_args.append(type_arg)
        data_type.set("expressions", remaining_args)

    return expression
283
284
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    Collects the aliases of UNNESTs used as FROM/JOIN sources in this SELECT's scope, then drops
    that alias from every column qualified with it (in the table position, or else in the db
    position).
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in unnest_aliases:
            column.set("table", None)
        # NOTE(review): presumably covers references like alias.col.field, where the alias
        # lands in the db slot — confirm.
        elif column.db in unnest_aliases:
            column.set("db", None)

    return expression
303
304
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    `FROM UNNEST(...)` becomes a table-valued EXPLODE / POSEXPLODE / INLINE call, and each joined
    `UNNEST(...)` (optionally wrapped in LATERAL) becomes a `LATERAL VIEW` using the same function.

    Args:
        expression: the expression that will be transformed.
        unnest_using_arrays_zip: whether an UNNEST over several arrays may be rewritten as
            `INLINE(ARRAYS_ZIP(...))`; when False, such an UNNEST raises an UnsupportedError.

    Returns:
        The transformed expression.

    Raises:
        UnsupportedError: if an UNNEST cannot be expressed this way. Joined UNNESTs are validated
            before their join is modified, so the offending join is left in place.
    """

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        # Several arrays are collapsed into one ARRAYS_ZIP(...) argument, which is also stored back
        # on the UNNEST node; a single array is returned as is.
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        # An UNNEST with an offset needs the position too; zipped arrays yield structs.
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        joins = expression.args.get("joins") or []
        # Iterate over a copy because converted joins are removed from `joins` below.
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if not isinstance(unnest, exp.Unnest):
                continue

            if is_lateral:
                alias = join_expr.args.get("alias")
            else:
                alias = unnest.args.get("alias")
            exprs = unnest.expressions
            # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
            has_multi_expr = len(exprs) > 1
            alias_cols = alias.columns if alias else []

            # Spark LATERAL VIEW EXPLODE requires a single alias for an array/struct and two for a
            # map type column, unlike unnest in trino/presto which can take an arbitrary amount, and
            # the LATERAL VIEW built for zipped arrays needs its column aliases too.
            # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html
            # Validate before touching the UNNEST or the join list: previously the join was removed
            # first, so an unconvertible UNNEST silently disappeared from the query.
            if (not has_multi_expr and len(alias_cols) not in (1, 2)) or (
                has_multi_expr and unnest_using_arrays_zip and not alias_cols
            ):
                raise UnsupportedError(
                    "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                )

            exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)
            joins.remove(join)

            udtf = _udtf_type(unnest, has_multi_expr)
            # At this point `exprs` holds at most one expression (several arrays were zipped into
            # one), and the single LATERAL VIEW carries all of the column aliases.
            for unnested in exprs:
                expression.append(
                    "laterals",
                    exp.Lateral(
                        this=udtf(this=unnested),
                        view=True,
                        alias=exp.TableAlias(
                            this=alias.this,  # type: ignore
                            columns=alias_cols,
                        ),
                    ),
                )

    return expression
396
397
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    The returned transform rewrites every [POS]EXPLODE projection of a SELECT as follows.
    A single position series `_u(pos)` is cross-joined into the query, and each exploded array
    is cross-joined as `UNNEST(array) WITH OFFSET`. The projection becomes
    `IF(series_pos = element_pos, element)`. A WHERE filter keeps the row whose element position
    equals the series position or, once the series runs past this array's end, the row of its
    last element (for which the IF yields NULL). The series ends at the largest array size.

    Args:
        index_offset: the value the generated position series starts at (passed as the
            GENERATE_SERIES start). NOTE(review): presumably the target dialect's array index
            base (0 or 1) — confirm.

    Returns:
        A transform function usable with `preprocess`.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already used by projections / sources, so generated aliases don't clash.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            # Find a free name based on `name`, reserve it in `names` and return it.
            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            # ARRAY_SIZE(...) of every exploded array; used for the series' end bound.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # Position series shared by all exploded arrays; its "end" is filled in at the bottom.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Work out the projection's alias node and any user-given aliases:
                    # `x AS a` gives the value alias, `AS (p, a)` gives position and value,
                    # and an unaliased projection gets an empty alias to be filled in below.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # alias_ copied the projection, so look the EXPLODE up again in the copy.
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # EXPLODE_OUTER: when the array is NULL or empty, unnest a one-element array
                    # holding a safe, offset-based access of its first element instead.
                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    # The element offset always needs a name, even for a plain EXPLODE.
                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Replace the EXPLODE call with IF(series_pos = element_pos, element).
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    # POSEXPLODE also projects the position, right after the value projection.
                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # Add the position series once, on the first exploded projection.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # NOTE(review): presumably turns the size into the last element's offset
                    # when offsets are not 1-based — confirm.
                    if index_offset != 1:
                        size = size - 1

                    # Keep the row whose element offset matches the series position, or, when the
                    # series is past this array's end, the row of its last element.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            # The series runs up to the largest array size, adjusted for index_offset.
            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
547
548
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    A percentile node that is not already inside WITHIN GROUP and has a second argument is
    rewritten so that its `this` (the ranked value) becomes the WITHIN GROUP ORDER BY key, and its
    `expression` (the quantile) becomes its only argument.
    """
    needs_within_group = (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if not needs_within_group:
        return expression

    ranked_value = expression.this.pop()
    quantile = expression.expression.pop()
    expression.set("this", quantile)

    ordering = exp.Order(expressions=[exp.Ordered(this=ranked_value)])
    return exp.WithinGroup(this=expression, expression=ordering)
562
563
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    `<percentile>(q) WITHIN GROUP (ORDER BY x)` is replaced by an ApproxQuantile node with
    `this=x` and `quantile=q`.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, exp.PERCENTILES):
        return expression
    if not isinstance(expression.expression, exp.Order):
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    replacement = exp.ApproxQuantile(this=ordered.this, quantile=percentile.this)
    return expression.replace(replacement)
576
577
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Applies to every CTE of a RECURSIVE WITH that has no explicit column list. Column names come
    from the CTE query's projections; for a set operation, the left-hand query is used. Unnamed
    projections get names from one `_c_` sequence shared by the whole WITH clause.
    """
    if not isinstance(expression, exp.With) or not expression.recursive:
        return expression

    fallback_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            continue

        body = cte.this
        if isinstance(body, exp.SetOperation):
            body = body.this

        column_names = []
        for projection in body.selects:
            column_names.append(exp.to_identifier(projection.alias_or_name or fallback_name()))
        cte_alias.set("columns", column_names)

    return expression
595
596
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Only the string literal 'epoch' (case-insensitive) cast to a temporal type is rewritten to
    '1970-01-01 00:00:00'. A column or other expression that merely happens to be named `epoch`
    (e.g. `CAST(epoch AS TIMESTAMP)`) is left alone.
    """
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.this.is_string
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
607
608
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    `SEMI JOIN y ON c` becomes a `WHERE EXISTS(SELECT 1 FROM y WHERE c)` filter and
    `ANTI JOIN y ON c` becomes `WHERE NOT EXISTS(SELECT 1 FROM y WHERE c)`. SEMI/ANTI joins
    without an ON condition are left untouched.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: join.pop() removes the join from the very list held in
        # expression.args["joins"], and iterating that list directly would skip the join that
        # follows each removed one (e.g. two consecutive SEMI joins).
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
624
625
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    The LEFT JOIN half yields all matched rows plus the unmatched left rows. The RIGHT JOIN half
    is filtered with `NOT EXISTS (SELECT 1 FROM <left source> WHERE <join condition>)`, so it only
    adds right rows without a match. The halves are combined with UNION ALL.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The UNION ALL of both halves, or `expression` unchanged when it is not a SELECT with
        exactly one FULL join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            # Snapshot taken before any mutation; it becomes the RIGHT JOIN half.
            expression_copy = expression.copy()
            # Only the last union operand (the copy) keeps LIMIT, and only it keeps ORDER BY (below).
            # NOTE(review): presumably so the trailing ORDER BY/LIMIT apply to the whole UNION — confirm.
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            # Source names used to qualify the columns of a USING (...) list.
            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
            # NOTE(review): if the join has neither ON nor USING, `using` is None and the
            # comprehension below raises TypeError.
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in full_outer_join.args.get("using")
                ]
            )

            full_outer_join.set("side", "left")
            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
            expression.args.pop("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression
662
663
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the root expression; nested WITH clauses are moved into its own WITH.

    Returns:
        The same `expression`, mutated in place.
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        # The root's own WITH clause is already at the top level.
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: promote this nested one as a whole.
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            # Keep the RECURSIVE marker if any merged WITH had it.
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # The nested WITH lived inside a CTE's query: insert its CTEs right before that
                # CTE, so they are defined before the CTE that uses them.
                # NOTE(review): index() raises ValueError if parent_cte is not a direct entry of
                # top_level_with — confirm this cannot happen for deeper nesting.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                # Re-set the list, presumably so the inserted CTEs get re-parented.
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
701
702
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    The optimizer's canonicalize `ensure_bools` helper is applied to every node with a callback.
    The callback rewrites a node into `<node> <> 0` when the node is a number, is typed as numeric
    or UNKNOWN (subquery predicates excepted), or is a column without type information.
    NOTE(review): the helper presumably only invokes the callback on nodes in condition
    position — confirm in sqlglot.optimizer.canonicalize.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as apply_to_conditions

    def _looks_numeric(node: exp.Expression) -> bool:
        if node.is_number:
            return True
        if not isinstance(node, exp.SubqueryPredicate) and node.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        ):
            return True
        return isinstance(node, exp.Column) and not node.type

    def _make_explicit(node: exp.Expression) -> None:
        if _looks_numeric(node):
            node.replace(node.neq(0))

    for node in expression.walk():
        apply_to_conditions(node, _make_explicit)

    return expression
722
723
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """
    Strip the table, db and catalog qualifiers from every column in the expression.

    Column `parts` are ordered catalog, db, table, name; everything except the last part is
    popped, leaving only the column name.
    """
    for column in expression.find_all(exp.Column):
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
731
732
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Remove every UNIQUE constraint from a CREATE statement.

    When the UNIQUE constraint is held by a ColumnConstraint or a (named) Constraint node, that
    wrapper node is removed, as before. Otherwise the UNIQUE constraint alone is removed.
    NOTE(review): this presumably covers a table-level `UNIQUE (a)` listed directly in the
    schema. Popping its parent there would drop the whole Schema, i.e. every column definition.

    Args:
        expression: the CREATE statement to transform (asserted to be `exp.Create`).

    Returns:
        The transformed expression.
    """
    assert isinstance(expression, exp.Create)
    for constraint in expression.find_all(exp.UniqueColumnConstraint):
        parent = constraint.parent
        if isinstance(parent, (exp.ColumnConstraint, exp.Constraint)):
            parent.pop()
        elif parent:
            constraint.pop()

    return expression
740
741
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Rewrite temporary-table CTAS statements as CREATE TEMPORARY VIEW.

    For a `CREATE TABLE` carrying a TemporaryProperty:
    - with a query, returns `CREATE TEMPORARY VIEW <name> AS <query>`;
    - without a query, returns `tmp_storage_provider(expression)` (identity by default).
    Any other CREATE statement is returned unchanged.
    """
    assert isinstance(expression, exp.Create)

    if expression.kind != "TABLE":
        return expression

    properties = expression.args.get("properties")
    property_list = properties.expressions if properties else []
    if not any(isinstance(prop, exp.TemporaryProperty) for prop in property_list):
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
764
765
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    Move Hive partition columns out of the table schema and into PARTITIONED BY.

    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When a
    CREATE TABLE/VIEW with an explicit schema has a PARTITIONED BY list of column names (not
    already a schema), the schema columns with those names (compared case-insensitively) are
    removed from the schema and become the PARTITIONED BY schema instead.

    Args:
        expression: a `Create` expression (mutated in place).

    Returns:
        The same expression object.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    partition_names = {name_expr.name.upper() for name_expr in prop.this.expressions}
    partitions = [
        column_def for column_def in schema.expressions if column_def.name.upper() in partition_names
    ]
    remaining = [column_def for column_def in schema.expressions if column_def not in partitions]

    schema.set("expressions", remaining)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
    expression.set("this", schema)

    return expression
787
788
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Move typed PARTITIONED BY column definitions into the table schema.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE, and SQLGlot
    currently uses the DATASOURCE format for Spark 3. When every PARTITIONED BY entry is a column
    definition with an explicit type, those definitions are appended to the statement's schema
    and the property is reduced to a tuple of the column names.

    Args:
        expression: a `Create` expression (mutated in place).

    Returns:
        The same expression object.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or not isinstance(prop.this, exp.Schema):
        return expression

    partition_defs = prop.this.expressions
    if not all(isinstance(col_def, exp.ColumnDef) and col_def.kind for col_def in partition_defs):
        return expression

    names = exp.Tuple(expressions=[exp.to_identifier(col_def.this) for col_def in partition_defs])
    table_schema = expression.this
    for col_def in partition_defs:
        table_schema.append("expressions", col_def)
    prop.set("this", names)

    return expression
812
813
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite key/value (`PropertyEQ`) struct arguments as aliases, producing e.g. STRUCT(1 AS y).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression object.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    def _as_alias(arg: exp.Expression) -> exp.Expression:
        # PropertyEQ holds the key in `this` and the value in `expression`.
        if isinstance(arg, exp.PropertyEQ):
            return exp.alias_(arg.expression, arg.this)
        return arg

    expression.set("expressions", list(map(_as_alias, expression.expressions)))
    return expression
826
827
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Each binary predicate containing (+) is moved out of the WHERE clause into the ON condition
    of a LEFT JOIN against the marked table. Several predicates marking the same table are
    combined with AND. If the FROM table itself is marked, the first unmarked joined table
    (in the original join order) becomes the new FROM.

    Joined tables that are not marked are kept and placed ahead of the new LEFT JOINs, so the
    ON conditions can reference them. Plain comma joins among them are made explicit CROSS
    JOINs, because mixing comma joins with explicit JOINs changes name resolution in ON
    conditions. Scopes without any join marks are left untouched.

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # The marked table is either one of the joined tables or the FROM table.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        if not new_joins:
            # No (+) markers in this scope: keep its FROM/JOIN structure exactly as written.
            continue

        if query_from.this.alias_or_name in new_joins:
            # Pick deterministically (original join order) instead of relying on set ordering.
            only_old_joins = [name for name in old_joins if name not in new_joins]
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = only_old_joins[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        # Keep every joined table that did not become a LEFT JOIN (nor the new FROM), ahead of
        # the LEFT JOINs so that their ON conditions can reference it.
        from_name = query.args["from"].this.alias_or_name
        kept_joins = []
        for name, join in old_joins.items():
            if name in new_joins or name == from_name:
                continue
            if not any(join.args.get(arg) for arg in ("method", "side", "kind", "on", "using")):
                # A bare comma join: make it explicit so it can be mixed with LEFT JOINs.
                join.set("kind", "CROSS")
            kept_joins.append(join)

        query.set("joins", [*kept_joins, *new_joins.values()])

        if not where.this:
            where.pop()

    return expression
928
929
def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> 5 > x)

    The binary comparison surrounding ANY becomes the body of a lambda over `x`, with the ANY
    operand replaced by `x`. Both ANY and EXISTS accept queries, but currently only array
    expressions are supported for this transformation; ANY over a query is left unchanged.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # `any_expr` rather than `any`, so the builtin is not shadowed.
        for any_expr in expression.find_all(exp.Any):
            array = any_expr.this
            if isinstance(array, exp.Query):
                continue

            binop = any_expr.parent
            if isinstance(binop, exp.Binary):
                lambda_arg = exp.to_identifier("x")
                any_expr.replace(lambda_arg)
                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
                binop.replace(exp.Exists(this=array.unnest(), expression=lambda_expr))

    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL.

    The transformed expression is rendered with the generator's "<key>_sql" method when one
    exists. Otherwise the matching `Generator.TRANSFORMS` entry is used, but only if the
    transformations changed the expression's type. If the type is unchanged, functions fall back
    to `function_fallback_sql` and any other expression raises a ValueError; this prevents
    re-entering the same transform forever. If a transformation raises `UnsupportedError`, the
    error is reported through `self.unsupported` and the last successfully produced expression
    is rendered.

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        original_type = type(expression)

        try:
            expression = transforms[0](expression)
            for step in transforms[1:]:
                expression = step(expression)
        except UnsupportedError as error:
            self.unsupported(str(error))

        sql_method = getattr(self, f"{expression.key}_sql", None)
        if sql_method:
            return sql_method(expression)

        final_type = type(expression)
        transform_fn = self.TRANSFORMS.get(final_type)
        if not transform_fn:
            raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

        if final_type is not original_type:
            return transform_fn(self, expression)

        if isinstance(expression, exp.Func):
            return self.function_fallback_sql(expression)

        # Same type and no _sql method: the TRANSFORMS entry would presumably lead back into
        # _to_sql, so bail out instead of recursing forever.
        raise ValueError(
            f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
        )

    return _to_sql

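Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function. The TRANSFORMS function is used only when no "_sql" method exists and the transformations changed the expression's type; otherwise functions fall back to the generic function SQL and other expressions raise a ValueError, which avoids infinite recursion.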

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.

def unnest_generate_date_array_using_recursive_cte( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 65def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
 66    if isinstance(expression, exp.Select):
 67        count = 0
 68        recursive_ctes = []
 69
 70        for unnest in expression.find_all(exp.Unnest):
 71            if (
 72                not isinstance(unnest.parent, (exp.From, exp.Join))
 73                or len(unnest.expressions) != 1
 74                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
 75            ):
 76                continue
 77
 78            generate_date_array = unnest.expressions[0]
 79            start = generate_date_array.args.get("start")
 80            end = generate_date_array.args.get("end")
 81            step = generate_date_array.args.get("step")
 82
 83            if not start or not end or not isinstance(step, exp.Interval):
 84                continue
 85
 86            alias = unnest.args.get("alias")
 87            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
 88
 89            start = exp.cast(start, "date")
 90            date_add = exp.func(
 91                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
 92            )
 93            cast_date_add = exp.cast(date_add, "date")
 94
 95            cte_name = "_generated_dates" + (f"_{count}" if count else "")
 96
 97            base_query = exp.select(start.as_(column_name))
 98            recursive_query = (
 99                exp.select(cast_date_add)
100                .from_(cte_name)
101                .where(cast_date_add <= exp.cast(end, "date"))
102            )
103            cte_query = base_query.union(recursive_query, distinct=False)
104
105            generate_dates_query = exp.select(column_name).from_(cte_name)
106            unnest.replace(generate_dates_query.subquery(cte_name))
107
108            recursive_ctes.append(
109                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
110            )
111            count += 1
112
113        if recursive_ctes:
114            with_expression = expression.args.get("with") or exp.With()
115            with_expression.set("recursive", True)
116            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
117            expression.set("with", with_expression)
118
119    return expression
def unnest_generate_series( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Unnests GENERATE_SERIES or SEQUENCE table references.

    A table whose `this` is a GenerateSeries is replaced by `UNNEST(<series>)`. If the table was
    aliased, that alias becomes the UNNEST's single column alias, and `_u` its table alias.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The replacement UNNEST, or the original expression when it does not apply.
    """
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])
    table_alias = expression.alias
    if not table_alias:
        return unnested

    return exp.alias_(unnested, alias="_u", table=[table_alias], copy=False)

Unnests GENERATE_SERIES or SEQUENCE table references.

def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Unqualified GROUP BY columns that match an aliased projection are replaced by the
    projection's 1-based position.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not (isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select)):
        return expression

    # If an alias is repeated, the last projection using it wins.
    position_by_alias = {}
    for position, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for key in expression.expressions:
        if not isinstance(key, exp.Column) or key.table:
            continue
        position = position_by_alias.get(key.name)
        if position is not None:
            key.replace(exp.Literal.number(position))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window
    functions. The DISTINCT ON keys become the ROW_NUMBER() partition. The query's ORDER BY
    (or, if absent, the DISTINCT ON keys) becomes the window ordering. The outer query keeps
    only rows numbered 1.

    The outer query refers to the subquery's projections by their output names, because the
    original projection expressions (e.g. `t.a` or `a AS b`) would not resolve against the
    derived table `_t`. Unaliased computed projections are given an alias for that purpose.
    A star projection makes the outer query select `*`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        row_number = find_new_name(expression.named_selects, "_row_number")

        # Build the outer projections before the window column is appended to the inner query.
        taken_names = {*expression.named_selects, row_number}
        outer_selects: t.List[exp.Expression] = []
        for select in list(expression.selects):
            if select.is_star:
                outer_selects = [exp.Star()]
                break

            if isinstance(select, exp.Alias):
                identifier = select.args["alias"]
            elif isinstance(select, exp.Column):
                identifier = select.this
            else:
                alias = find_new_name(taken_names, select.output_name or "_col")
                taken_names.add(alias)
                select = select.replace(exp.alias_(select, alias))
                identifier = select.args["alias"]

            outer_selects.append(exp.Column(this=identifier.copy()))

        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if an
    aliased projection is referenced inside a window function in the QUALIFY clause, it will be
    replaced by the corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        `SELECT <projection names> FROM (<original query>) AS _t WHERE <qualify condition>`,
        or the original expression if it has no QUALIFY clause.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give unnamed projections a `_c...` alias so the outer query can refer to them.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # Outer projection for an inner one: a column named after its alias/name that keeps the
        # identifier's quoting; falls back to the plain name string (e.g. for `*`).
        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # Built before any `_w`/extra columns are appended below, so those stay hidden.
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With `SELECT *` every source column is already projected, so only windows need moving.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Inline projection aliases used inside the window, since the inner query
                    # cannot reference its own projection aliases.
                    # NOTE(review): `expr` is inserted without copying, so the node is shared
                    # with the aliased projection; confirm this sharing is harmless.
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window (alias_ copies it) and refer to it by its new `_w` name.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                # The QUALIFY node was popped but still parents its condition: if the window is
                # the whole condition, the outer filter is just the column reference.
                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Columns used by QUALIFY but not projected must be exposed by the subquery.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip precision/scale parameters from data types.

    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions. This transform removes every `DataTypeParam` argument from every
    data type found in `expression`.

    Args:
        expression: the expression that will be transformed (mutated in place).

    Returns:
        The same expression object.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_args = [
            arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)
        ]
        data_type.set("expressions", kept_args)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    For a SELECT, the aliases of UNNESTs used directly in its FROM/JOINs are collected. Any
    column whose table (or, failing that, db) qualifier is one of those aliases loses that
    qualifier.

    Args:
        expression: the expression that will be transformed (mutated in place).

    Returns:
        The same expression object.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in unnest_aliases:
            column.set("table", None)
        elif column.db in unnest_aliases:
            column.set("db", None)

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.Expression, unnest_using_arrays_zip: bool = True) -> sqlglot.expressions.Expression:
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    - `FROM UNNEST(...)` becomes a table over EXPLODE/POSEXPLODE/INLINE, keeping the alias.
    - `CROSS JOIN UNNEST(...)` / `JOIN LATERAL UNNEST(...)` joins are removed and re-emitted as
      `LATERAL VIEW` entries on the SELECT.
    - UNNEST over several arrays is rewritten as `INLINE(ARRAYS_ZIP(...))`.

    Args:
        expression: the expression that will be transformed.
        unnest_using_arrays_zip: when False, UNNEST over several arrays raises
            `UnsupportedError` instead of being zipped.

    Returns:
        The transformed expression.

    Raises:
        UnsupportedError: for multi-array UNNEST when `unnest_using_arrays_zip` is False, or when
            a single-array joined UNNEST does not have exactly 1 or 2 column aliases.
    """

    # Collapses multiple UNNEST arguments into one ARRAYS_ZIP(...) call (also written back into
    # `u`); single-argument lists are returned as-is.
    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    # POSEXPLODE when the UNNEST has an offset, INLINE for zipped arrays, EXPLODE otherwise.
    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        # `joins` is the SELECT's own list (when present), so `joins.remove` edits the query;
        # iterate over a copy because of that removal.
        joins = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                # For LATERAL UNNEST the alias sits on the Lateral node rather than the Unnest.
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # Handle UNNEST to LATERAL VIEW EXPLODE: for a single array, an exception is raised when there are 0 or > 2 aliases
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                # `exprs` has a single element at this point, so at most one LATERAL VIEW is
                # added, carrying all alias columns.
                # NOTE(review): for a multi-array UNNEST without aliases `alias_cols` is empty,
                # so the join is removed above and nothing replaces it; confirm this is intended.
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode.

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Each projection containing [POS]EXPLODE(arr) is rewritten as
    `IF(_u.pos = _u_n.pos, _u_n.col)` over `CROSS JOIN UNNEST(arr) WITH OFFSET pos AS _u_n(col)`.
    A shared `UNNEST(GENERATE_SERIES(index_offset, <end>)) AS _u(pos)` drives the rows. The
    series end is the largest array size (adjusted by the offset), and the WHERE clause pairs
    series positions with array positions. Positions past an array's end are matched to its last
    element, so the IF yields NULL for those rows.

    Args:
        index_offset: starting value of the generated position series; presumably the target
            dialect's array index base (0 or 1) — confirm against callers.

    Returns:
        A transform that rewrites SELECT expressions and returns other expressions unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already used by projections / sources, so generated aliases don't collide.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            # Returns an unused variant of `name` and reserves it in `names`.
            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            # ARRAY_SIZE(...) of every exploded array; its end bound is set after the loop.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # With `table=...`, the value used below as `series` is the UNNEST itself carrying a
            # TableAlias (hence `series.args["alias"]` and `series.expressions[0]` further down).
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Reuse user-provided aliases: `x AS col`, or `POSEXPLODE(x) AS (pos, col)`.
                    # Otherwise wrap the projection in an alias that is named below.
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # EXPLODE_OUTER: replace a NULL/empty array with a one-element array built
                    # from a safe element access at offset 0 (presumably NULL there), so the row
                    # is kept.
                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The exploded value is only taken on the row whose series position matches
                    # this array's offset.
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    # POSEXPLODE also exposes the position, as a projection right after the value.
                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The position series is attached once, on the first EXPLODE found.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # `size` becomes this array's last index (size - 1 unless index_offset is 1).
                    if index_offset != 1:
                        size = size - 1

                    # Keep matching positions; past this array's end, keep only its last element
                    # (the IF above then yields no value for it).
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            # The series must cover the longest exploded array.
            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest

Convert explode/posexplode into unnest.

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a two-argument percentile call into its WITHIN GROUP form.

    Applies only to a percentile node (one of `exp.PERCENTILES`) that is not already the
    child of an `exp.WithinGroup` and that has an `expression` argument. The node's current
    `this` argument is moved into an ORDER BY clause. Its `expression` argument takes the
    place of `this`, and the node is wrapped in a new `exp.WithinGroup`, which is returned.
    Any other expression is returned unchanged.
    """
    is_bare_percentile = (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if not is_bare_percentile:
        return expression

    ordered_column = expression.this.pop()
    expression.set("this", expression.expression.pop())
    return exp.WithinGroup(
        this=expression,
        expression=exp.Order(expressions=[exp.Ordered(this=ordered_column)]),
    )

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Collapse a percentile's WITHIN GROUP clause into an `exp.ApproxQuantile` call.

    Applies only to an `exp.WithinGroup` that wraps a percentile node and whose second
    argument is an `exp.Order`. The percentile's argument becomes the quantile, and the
    operand of the first `exp.Ordered` found in the node becomes the input value. The
    WITHIN GROUP node is replaced in the tree, and the replacement is returned. Any other
    expression is returned unchanged.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, exp.PERCENTILES) or not isinstance(
        expression.expression, exp.Order
    ):
        return expression

    first_ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    return expression.replace(
        exp.ApproxQuantile(this=first_ordered.this, quantile=percentile.this)
    )

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH clause that lacks a column list an explicit one.

    Column names come from the projections of the CTE's query. When the CTE body is a set
    operation, its left operand's projections are used. A projection without an alias or
    name gets a generated name from a single `name_sequence("_c_")` shared by the whole
    WITH clause. Non-recursive WITH clauses and other expressions are returned unchanged.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    fresh_name = name_sequence("_c_")
    for cte in expression.expressions:
        alias = cte.args["alias"]
        if alias.columns:
            continue

        body = cte.this
        anchor = body.this if isinstance(body, exp.SetOperation) else body
        alias.set(
            "columns",
            [exp.to_identifier(sel.alias_or_name or fresh_name()) for sel in anchor.selects],
        )

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    For example, CAST('epoch' AS TIMESTAMP) becomes CAST('1970-01-01 00:00:00' AS TIMESTAMP).
    Only casts (or TRY_CASTs) to a temporal type whose operand is the string literal 'epoch'
    (case-insensitive) are rewritten. The cast node is modified in place and returned.
    """
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        # `expression.name` reads the name of the cast's operand, which is also "epoch" for a
        # *column* called epoch. Only the string literal is the special value, so a column
        # reference must not be replaced by a constant.
        and expression.this.is_string
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    A SEMI join `b ON <cond>` is removed from the query, and `EXISTS(SELECT 1 FROM b WHERE
    <cond>)` is added to the query's WHERE clause. An ANTI join gets the negated
    `NOT EXISTS(...)` instead. SEMI/ANTI joins without an ON condition are left untouched.
    The SELECT is modified in place and returned.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot of the joins. `join.pop()` detaches the join from the
        # query and can shrink the very list stored in `expression.args["joins"]`. That
        # would make the loop skip the join that follows a removed one.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a query with a single FULL OUTER join as a UNION ALL of two copies of it.

    The first (LEFT) half is the original query with the FULL join turned into a LEFT
    join. It keeps the CTEs but loses LIMIT and ORDER BY. The second (RIGHT) half is a copy
    taken before any change, with the FULL join turned into a RIGHT join. It drops the CTEs
    and is filtered with NOT EXISTS (SELECT 1 FROM <from table> WHERE <join condition>), so
    matched rows, which the LEFT half already returns, are not repeated.

    Only queries with exactly one FULL OUTER join are transformed. Non-SELECT expressions
    and SELECTs with zero or several FULL joins are returned unchanged.
    """
    if not isinstance(expression, exp.Select):
        return expression

    candidates = [
        (position, join)
        for position, join in enumerate(expression.args.get("joins") or [])
        if join.side == "FULL"
    ]
    if len(candidates) != 1:
        return expression

    # Snapshot the untouched query first; the snapshot becomes the RIGHT-join half.
    right_half = expression.copy()
    expression.set("limit", None)
    position, full_join = candidates[0]

    from_name = expression.args["from"].alias_or_name
    join_name = full_join.alias_or_name
    condition = full_join.args.get("on")
    if not condition:
        # USING (c1, c2, ...) is spelled out as from.c1 = join.c1 AND from.c2 = join.c2 ...
        condition = exp.and_(
            *[
                exp.column(name, from_name).eq(exp.column(name, join_name))
                for name in full_join.args.get("using")
            ]
        )

    full_join.set("side", "left")
    matched_in_from = exp.select("1").from_(expression.args["from"]).where(condition)
    right_half.args["joins"][position].set("side", "right")
    right_half = right_half.where(exp.Exists(this=matched_in_from).not_())
    right_half.args.pop("with", None)  # CTEs stay with the LEFT half only
    expression.args.pop("order", None)  # ORDER BY stays with the RIGHT half only

    return exp.union(expression, right_half, copy=False, distinct=False)

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level(expression: ~E) -> ~E:
def move_ctes_to_top_level(expression: E) -> E:
    """
    Hoist every nested WITH clause into the top-level WITH clause of `expression`.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    If `expression` has no WITH clause yet, the first nested one encountered during the
    traversal is detached and attached to `expression`. Otherwise each nested WITH is
    handled as follows:
    - Its CTEs are inserted right before the nearest enclosing CTE in the top-level list,
      or appended at the end when the WITH was not inside a CTE.
    - A nested RECURSIVE clause also marks the top-level WITH as RECURSIVE.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")

    for nested in expression.find_all(exp.With):
        if nested.parent is expression:
            # This WITH is already attached directly to `expression`.
            continue

        if not top_level_with:
            top_level_with = nested.pop()
            expression.set("with", top_level_with)
            continue

        if nested.recursive:
            top_level_with.set("recursive", True)

        owning_cte = nested.find_ancestor(exp.CTE)
        nested.pop()
        moved_ctes = nested.expressions

        if owning_cte:
            position = top_level_with.expressions.index(owning_cte)
            top_level_with.expressions[position:position] = moved_ctes
            top_level_with.set("expressions", top_level_with.expressions)
        else:
            top_level_with.set("expressions", top_level_with.expressions + moved_ctes)

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Every node of the tree is passed to the optimizer's `canonicalize.ensure_bools` helper
    with a callback. The callback rewrites a node as `node <> 0` when the node is:
    - a number literal,
    - a non-`exp.SubqueryPredicate` node typed as UNKNOWN or as a numeric type, or
    - a column with no type information.
    The expression is modified in place and returned.
    """
    # Aliased so the helper does not shadow this function's own name inside its body.
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _as_bool(node: exp.Expression) -> None:
        needs_comparison = (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        )
        if needs_comparison:
            node.replace(node.neq(0))

    for node in expression.walk():
        _canonicalize_ensure_bools(node, _as_bool)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """
    Strip the table, database and catalog qualifiers from every column in the tree.

    For each `exp.Column`, all of its parts except the last one (the column name) are
    popped, so e.g. `db.tbl.col` becomes `col`. The expression is modified in place and
    returned.
    """
    for col in expression.find_all(exp.Column):
        qualifiers = col.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()
    return expression
def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Drop every UNIQUE column constraint from a CREATE statement.

    Each `exp.UniqueColumnConstraint` found in the tree is removed by popping its parent
    node, when it has one. The statement is modified in place and returned.

    Raises:
        AssertionError: if `expression` is not an `exp.Create` (when assertions are enabled).
    """
    assert isinstance(expression, exp.Create)
    unique_constraints = expression.find_all(exp.UniqueColumnConstraint)
    for unique in unique_constraints:
        wrapper = unique.parent
        if wrapper:
            wrapper.pop()
    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = lambda e: e) -> sqlglot.expressions.Expression:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Rewrite CREATE TABLE statements that carry a temporary property.

    The statement is a CREATE of kind TABLE with an `exp.TemporaryProperty`:
    - If it has a query (CTAS), a new CREATE of kind "TEMPORARY VIEW" with the same target
      and query is returned.
    - Otherwise the result of `tmp_storage_provider(expression)` is returned.
    Everything else is returned unchanged.

    Args:
        expression: the CREATE statement to transform.
        tmp_storage_provider: hook applied to temporary tables without a query; the
            identity function by default.

    Raises:
        AssertionError: if `expression` is not an `exp.Create` (when assertions are enabled).
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_list = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in property_list)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema.

    When the PARTITIONED BY value is an array of column names (not already a schema), the
    column definitions with those names are moved out of the table schema. They go into a
    new schema that becomes the PARTITIONED BY value. Names are matched case-insensitively.
    Only CREATE TABLE / CREATE VIEW statements whose target is an `exp.Schema` are affected.

    Raises:
        AssertionError: if `expression` is not an `exp.Create` (when assertions are enabled).
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    schema = expression.this
    wanted = {v.name.upper() for v in prop.this.expressions}

    partition_cols = []
    kept_cols = []
    for col in schema.expressions:
        if col.name.upper() in wanted:
            partition_cols.append(col)
        else:
            kept_cols.append(col)

    schema.set("expressions", kept_cols)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_cols)))
    expression.set("this", schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3. A PARTITIONED BY clause whose
    value is a schema made only of typed column definitions (each `exp.ColumnDef` has a
    `kind`) is rewritten as follows:
    - those column definitions are appended to the statement's target schema, and
    - PARTITIONED BY keeps only a tuple of the column names.

    Raises:
        AssertionError: if `expression` is not an `exp.Create` (when assertions are enabled).
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or not isinstance(prop.this, exp.Schema):
        return expression

    column_defs = prop.this.expressions
    if not all(isinstance(e, exp.ColumnDef) and e.kind for e in column_defs):
        return expression

    names = exp.Tuple(expressions=[exp.to_identifier(e.this) for e in column_defs])
    table_schema = expression.this
    for column_def in column_defs:
        table_schema.append("expressions", column_def)
    prop.set("this", names)

    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

    Every `exp.PropertyEQ` argument of an `exp.Struct` (key `this`, value `expression`)
    becomes `<value> AS <key>`; all other arguments are kept as they are. Non-struct
    expressions are returned unchanged.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    def _to_alias(arg: exp.Expression) -> exp.Expression:
        if isinstance(arg, exp.PropertyEQ):
            return exp.alias_(arg.expression, arg.this)
        return arg

    expression.set("expressions", [_to_alias(arg) for arg in expression.expressions])
    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def eliminate_join_marks( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Some tables of the original FROM list are not outer-joined by a marked predicate. They
    are preserved: plain comma joins become explicit CROSS JOINs, placed before the new
    LEFT JOINs so that the LEFT JOINs' ON clauses may reference them. Scopes without any
    join mark are left untouched.

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # The marked table is the one that gets outer-joined. LEFT is stored as the join's
            # `side` (like LEFT/RIGHT/FULL elsewhere in this module), not as its `kind`.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, side="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # No join marks in this scope: keep its FROM / JOIN structure exactly as it was
        # (otherwise all of its joins would be replaced by an empty list below).
        if not new_joins:
            continue

        new_from_name = None
        if query_from.alias_or_name in new_joins:
            # The FROM table itself is outer-joined, so promote the first (in query order)
            # joined table that is not outer-joined. Picking from a set here would make the
            # choice depend on string hash randomization.
            new_from_name = next((name for name in old_joins if name not in new_joins), None)
            assert (
                new_from_name is not None
            ), "Cannot determine which table to use in the new FROM clause"
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        # Keep every join that was neither replaced by an outer join nor promoted to FROM.
        # Plain comma joins are made explicit CROSS JOINs so they can precede the LEFT JOINs.
        preserved_joins = []
        for name, join in old_joins.items():
            if name in new_joins or name == new_from_name:
                continue
            if not any(join.args.get(arg) for arg in ("kind", "side", "method", "on", "using")):
                join.set("kind", "CROSS")
            preserved_joins.append(join)

        query.set("joins", preserved_joins + list(new_joins.values()))

        if not where.this:
            where.pop()

    return expression

Remove join marks from an AST. This rule assumes that all marked columns are qualified. If this does not hold for a query, consider running sqlglot.optimizer.qualify first.

For example, `SELECT * FROM a, b WHERE a.id = b.id(+)` is converted to `SELECT * FROM a LEFT JOIN b ON a.id = b.id`.

Arguments:
  • expression: The AST to remove join marks from.
Returns:

The AST with join marks removed.

def any_to_exists( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform the ANY operator into Spark's EXISTS higher-order function.

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> 5 > x)

    Both ANY and EXISTS accept queries, but currently only array expressions are supported
    for this transformation; ANY over a query is left unchanged. The lambda parameter is
    named `x` unless an identifier with that name (compared case-insensitively) already
    appears in the comparison. In that case a fresh name is derived from `x` with
    `find_new_name`, so the lambda parameter never captures a column of the same name.
    """
    if isinstance(expression, exp.Select):
        for any_expr in expression.find_all(exp.Any):
            this = any_expr.this
            if isinstance(this, exp.Query):
                continue

            binop = any_expr.parent
            if isinstance(binop, exp.Binary):
                # e.g. `x > ANY(arr)` must not become `EXISTS(arr, x -> x > x)`
                taken = {ident.name.lower() for ident in binop.find_all(exp.Identifier)}
                lambda_arg = exp.to_identifier(find_new_name(taken, "x"))
                any_expr.replace(lambda_arg)
                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))

    return expression

Transform the ANY operator into Spark's EXISTS.

For example:
- Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
- Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

Both ANY and EXISTS accept queries, but currently only array expressions are supported for this transformation.