sqlglot.transforms
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.errors import UnsupportedError
from sqlglot.helper import find_new_name, name_sequence


if t.TYPE_CHECKING:
    from sqlglot._typing import E
    from sqlglot.generator import Generator


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        try:
            expression = transforms[0](expression)
            for transform in transforms[1:]:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            self.unsupported(str(unsupported_error))

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql


def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
    """Replace UNNEST(GENERATE_DATE_ARRAY(...)) table sources with recursive CTEs that enumerate the dates."""
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes = []

        for unnest in expression.find_all(exp.Unnest):
            # Only rewrite UNNESTs used as table sources whose single argument is GENERATE_DATE_ARRAY
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            # Both bounds and an INTERVAL step are required for the rewrite
            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias = unnest.args.get("alias")
            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            # Anchor produces the start date; the recursive member adds `step` until `end` is passed
            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            with_expression = expression.args.get("with") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with", with_expression)

    return expression


def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    this = expression.this
    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
        unnest = exp.Unnest(expressions=[this])
        if expression.alias:
            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)

        return unnest

    return expression


def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        aliased_selects = {
            e.alias: i
            for i, e in enumerate(expression.parent.expressions, start=1)
            if isinstance(e, exp.Alias)
        }

        for group_by in expression.expressions:
            if (
                isinstance(group_by, exp.Column)
                and not group_by.table
                and group_by.name in aliased_selects
            ):
                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))

    return expression


def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        outer_selects = expression.selects
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            window.set("order", order.pop())
        else:
            # No ORDER BY: any ordering over the DISTINCT ON columns is acceptable
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression


def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression


def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for node in expression.find_all(exp.DataType):
        node.set(
            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
        )

    return expression


def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if isinstance(expression, exp.Select):
        unnest_aliases = {
            unnest.alias
            for unnest in find_all_in_scope(expression, exp.Unnest)
            if isinstance(unnest.parent, (exp.From, exp.Join))
        }
        if unnest_aliases:
            for column in expression.find_all(exp.Column):
                if column.table in unnest_aliases:
                    column.set("table", None)
                elif column.db in unnest_aliases:
                    column.set("db", None)

    return expression


def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """Convert cross join unnest into lateral view explode."""

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        # Pick the UDTF matching the UNNEST shape: POSEXPLODE when an offset is
        # requested, INLINE for zipped multi-array input, plain EXPLODE otherwise.
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        joins = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression


def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode into unnest."""

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)

                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest


def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    if (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    ):
        column = expression.this.pop()
        expression.set("this", expression.expression.pop())
        order = exp.Order(expressions=[exp.Ordered(this=column)])
        expression = exp.WithinGroup(this=expression, expression=order)

    return expression


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression


def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                if isinstance(query, exp.SetOperation):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression


def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression


def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead."""
    if isinstance(expression, exp.Select):
        for join in expression.args.get("joins") or []:
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression


def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in full_outer_join.args.get("using")
                ]
            )

            full_outer_join.set("side", "left")
            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
            expression.args.pop("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression


def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Insert the hoisted CTEs right before the CTE that contained them,
                # so dependent CTEs are defined before their dependents
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression


def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        if (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        ):
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _ensure_bool)

    return expression


def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Remove all qualifiers (table, db, catalog) from column references."""
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        for part in column.parts[:-1]:
            part.pop()

    return expression


def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """Remove UNIQUE column constraints from a CREATE statement."""
    assert isinstance(expression, exp.Create)
    for constraint in expression.find_all(exp.UniqueColumnConstraint):
        if constraint.parent:
            constraint.parent.pop()

    return expression


def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Convert a CTAS on a temporary table into a CREATE TEMPORARY VIEW statement.

    Args:
        expression: the CREATE statement to transform (must be an exp.Create).
        tmp_storage_provider: applied to temporary CREATE TABLE statements that
            have no source query (i.e. are not CTAS); defaults to identity.
    """
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    temporary = any(
        isinstance(prop, exp.TemporaryProperty)
        for prop in (properties.expressions if properties else [])
    )

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.kind == "TABLE" and temporary:
        if expression.expression:
            return exp.Create(
                kind="TEMPORARY VIEW",
                this=expression.this,
                expression=expression.expression,
            )
        return tmp_storage_provider(expression)

    return expression


def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return expression


def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        prop_this = exp.Tuple(
            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
        )
        schema = expression.this
        for e in prop.this.expressions:
            schema.append("expressions", e)
        prop.set("this", prop_this)

    return expression


def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        expression.set(
            "expressions",
            [
                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
                for e in expression.expressions
            ],
        )

    return expression


def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        query.set("joins", list(new_joins.values()))

        if not where.this:
            where.pop()

    return expression


def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
    transformation
    """
    if isinstance(expression, exp.Select):
        for any in expression.find_all(exp.Any):
            this = any.this
            if isinstance(this, exp.Query):
                continue

            binop = any.parent
            if isinstance(binop, exp.Binary):
                lambda_arg = exp.to_identifier("x")
                any.replace(lambda_arg)
                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))

    return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Chain several transformation functions into a single generator transform.

    The returned callable applies each transform in order and then renders the
    final expression as SQL, preferring the generator's `<key>_sql` method and
    falling back to `Generator.TRANSFORMS` when applicable.

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        original_type = type(expression)

        try:
            expression = transforms[0](expression)
            for transform in transforms[1:]:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            self.unsupported(str(unsupported_error))

        sql_handler = getattr(self, expression.key + "_sql", None)
        if sql_handler:
            return sql_handler(expression)

        fallback_handler = self.TRANSFORMS.get(type(expression))
        if not fallback_handler:
            raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

        if original_type is not type(expression):
            return fallback_handler(self, expression)

        if isinstance(expression, exp.Func):
            return self.function_fallback_sql(expression)

        # Guard against infinite recursion: when the transformed expression kept its
        # original type and has no _sql method, dispatching through TRANSFORMS would
        # re-enter _to_sql forever.
        raise ValueError(
            f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
        )

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.
def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
    """
    Rewrites `UNNEST(GENERATE_DATE_ARRAY(...))` table sources into recursive CTEs.

    Useful for dialects without a native date-array generator that do support
    `WITH RECURSIVE`: each qualifying unnest becomes a subquery over a freshly
    generated recursive CTE that yields one date per row.
    """
    if isinstance(expression, exp.Select):
        count = 0  # suffix counter so multiple generated CTEs get unique names
        recursive_ctes = []

        for unnest in expression.find_all(exp.Unnest):
            # Only rewrite unnests used as table sources whose single argument
            # is a GENERATE_DATE_ARRAY call
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            # Both bounds are required and the step must be an INTERVAL literal
            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias = unnest.args.get("alias")
            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            # Anchor member selects the start date; the recursive member adds one
            # step per iteration, bounded (inclusively) by the end date
            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            # Prepend the generated CTEs to any existing WITH clause and flag it recursive
            with_expression = expression.args.get("with") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with", with_expression)

    return expression
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    series = expression.this
    if isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries):
        unnest = exp.Unnest(expressions=[series])
        table_alias = expression.alias
        if not table_alias:
            return unnest

        # Preserve the original table alias as the unnest's column alias
        return exp.alias_(unnest, alias="_u", table=[table_alias], copy=False)

    return expression
Unnests GENERATE_SERIES or SEQUENCE table references.
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    parent = expression.parent
    if isinstance(expression, exp.Group) and isinstance(parent, exp.Select):
        # Map each aliased projection to its 1-based position in the SELECT list
        alias_positions = {}
        for position, projection in enumerate(parent.expressions, start=1):
            if isinstance(projection, exp.Alias):
                alias_positions[projection.alias] = position

        for group_expr in expression.expressions:
            # Only unqualified columns can refer to select aliases
            if (
                isinstance(group_expr, exp.Column)
                and not group_expr.table
                and group_expr.name in alias_positions
            ):
                group_expr.replace(exp.Literal.number(alias_positions[group_expr.name]))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() 'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        # Popping the DISTINCT node strips the DISTINCT ON clause from the query
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        outer_selects = expression.selects
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # Reuse the query's ORDER BY inside the window (pop moves it off the SELECT)
            window.set("order", order.pop())
        else:
            # No ORDER BY: order by the DISTINCT ON columns themselves
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        # Outer query keeps only the first row of each partition
        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        taken = set(expression.named_selects)
        # Give every unnamed projection a synthetic alias so it can be referenced
        # from the outer query
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Preserve identifier quoting when projecting the alias in the outer query
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Replace alias references inside the window with the aliased
                    # expression itself, to avoid invalid column references
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window function in the subquery and refer to it by alias
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Column used in QUALIFY but not selected: include it in the subquery
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_args = [
            arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)
        ]
        data_type.set("expressions", kept_args)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Collect aliases of unnests that act as table sources in this scope
    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if unnest_aliases:
        for column in expression.find_all(exp.Column):
            if column.table in unnest_aliases:
                column.set("table", None)
            elif column.db in unnest_aliases:
                column.set("db", None)

    return expression
Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """Convert cross join unnest into lateral view explode."""

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        # Collapses multiple unnest arguments into a single ARRAYS_ZIP call,
        # mutating the unnest in place; single arguments pass through unchanged.
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        # Picks the UDTF used in the rewrite: POSEXPLODE when an offset is
        # requested, INLINE for zipped multi-arrays, otherwise EXPLODE.
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        # Case 1: the unnest is the FROM source itself -> rewrite as a UDTF table
        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        # Case 2: the unnest appears as a (possibly lateral) join -> turn each
        # join into a LATERAL VIEW entry
        joins = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs,
                # so we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols = alias.columns if alias else []

                # Handle UNNEST to LATERAL VIEW EXPLODE: an exception is raised when there are 0 or > 2 aliases.
                # Spark LATERAL VIEW EXPLODE requires a single alias for array/struct and two for Map type
                # columns, unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html
                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols,
                            ),
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode into unnest."""

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Generates a fresh name and reserves it so later calls can't reuse it
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            # Series of positions used to align all exploded arrays row-by-row
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Normalize the projection into an Alias node so we can rename it
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # EXPLODE_OUTER must still yield a row for empty/NULL arrays
                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Only emit the value when the series position matches the
                    # unnest's offset column (the cross join produces all pairs)
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    # POSEXPLODE also projects the position, right after the value
                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The position series source is added exactly once
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Keep aligned rows, plus the last element of shorter arrays
                    # once the series has run past their size
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # Bound the series by the longest exploded array
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
Convert explode/posexplode into unnest.
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    if (
        not isinstance(expression, exp.PERCENTILES)
        or isinstance(expression.parent, exp.WithinGroup)
        or not expression.expression
    ):
        return expression

    # Swap the ordering column out for the quantile argument, then wrap the
    # percentile in WITHIN GROUP (ORDER BY <column>)
    order_column = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=order_column)])
    return exp.WithinGroup(this=expression, expression=ordering)
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if not (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        return expression

    quantile = expression.this.this
    # The ordered expression inside WITHIN GROUP supplies the input value
    input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
    return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        alias = cte.args["alias"]
        if alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.SetOperation):
            # Column names come from the first side of a set operation
            query = query.this

        alias.set(
            "columns",
            [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
        )

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    is_epoch_cast = (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    )
    if is_epoch_cast:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
Replace 'epoch' in casts by the equivalent date literal.
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead."""
    if not isinstance(expression, exp.Select):
        return expression

    for join in expression.args.get("joins") or []:
        on = join.args.get("on")
        if not on or join.kind not in ("SEMI", "ANTI"):
            continue

        # SEMI -> WHERE EXISTS(...), ANTI -> WHERE NOT EXISTS(...)
        subquery = exp.select("1").from_(join.this).where(on)
        exists = exp.Exists(this=subquery)
        if join.kind == "ANTI":
            exists = exists.not_(copy=False)

        join.pop()
        expression.where(exists, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            # Copy BEFORE mutating so the RIGHT-side branch starts from the original query
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            # FROM table name and joined table name, used to expand USING into ON conditions
            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in full_outer_join.args.get("using")
                ]
            )

            full_outer_join.set("side", "left")
            # NOT EXISTS guard keeps the RIGHT branch from re-emitting rows the LEFT branch matched
            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
            expression.args.pop("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            # Already a top-level WITH clause, leave it where it is
            continue

        if not top_level_with:
            # No top-level WITH yet: promote this nested one as-is
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Splice the moved CTEs in *before* the CTE that contained them,
                # so they're defined before they're referenced
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        # Coerce when the node is a numeric literal, has a numeric/unknown type
        # (and isn't a subquery predicate), or is an untyped column
        if node.is_number:
            needs_coercion = True
        elif not isinstance(node, exp.SubqueryPredicate) and node.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        ):
            needs_coercion = True
        else:
            needs_coercion = isinstance(node, exp.Column) and not node.type

        if needs_coercion:
            # Rewrite `x` as `x <> 0` to make the condition explicitly boolean
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _ensure_bool)

    return expression
Converts numeric values used in conditions into explicit boolean expressions.
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Rewrites CREATE TEMPORARY TABLE ... AS SELECT into CREATE TEMPORARY VIEW.

    Non-CTAS temporary tables are passed through `tmp_storage_provider`, which
    defaults to the identity function.
    """
    assert isinstance(expression, exp.Create)

    props = expression.args.get("properties")
    prop_list = props.expressions if props else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in prop_list)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.expression:
        return exp.Create(
            kind="TEMPORARY VIEW",
            this=expression.this,
            expression=expression.expression,
        )

    return tmp_storage_provider(expression)
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)

    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if not (has_schema and is_partitionable):
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if prop and prop.this and not isinstance(prop.this, exp.Schema):
        schema = expression.this
        # Names referenced by PARTITIONED BY, compared case-insensitively
        partition_names = {v.name.upper() for v in prop.this.expressions}
        partitions = [col for col in schema.expressions if col.name.upper() in partition_names]

        # Move the matched column definitions out of the schema and into the property
        schema.set("expressions", [e for e in schema.expressions if e not in partitions])
        prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
        expression.set("this", schema)

    return expression
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        return expression

    # Replace the property's schema of column definitions with a tuple of names,
    # and append the full definitions to the table's schema instead
    identifiers = exp.Tuple(expressions=[exp.to_identifier(e.this) for e in prop.this.expressions])
    schema = expression.this
    for column_def in prop.this.expressions:
        schema.append("expressions", column_def)
    prop.set("this", identifiers)

    return expression
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        converted = []
        for arg in expression.expressions:
            if isinstance(arg, exp.PropertyEQ):
                # y := 1 becomes 1 AS y
                converted.append(exp.alias_(arg.expression, arg.this))
            else:
                converted.append(arg)

        expression.set("expressions", converted)

    return expression
Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Join marks only matter for queries with both a WHERE clause and joins
        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            # Pop the predicate from WHERE; it becomes the new join's ON condition
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # `col` is the last marked column from the loop above; its table
            # identifies the LEFT-joined side
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # If the FROM table itself became a LEFT-joined side, pick another
        # unconverted join's table as the new FROM source
        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        query.set("joins", list(new_joins.values()))

        # Drop the WHERE clause entirely if all its predicates were moved to joins
        if not where.this:
            where.pop()

    return expression
Remove join marks from an AST. This rule assumes that all marked columns are qualified.
If this does not hold for a query, consider running sqlglot.optimizer.qualify
first.
For example, SELECT * FROM a, b WHERE a.id = b.id(+) -- ... is converted to SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this
Arguments:
- expression: The AST to remove join marks from.
Returns:
The AST with join marks removed.
def any_to_exists(expression: exp.Expression) -> exp.Expression:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
    transformation
    """
    if isinstance(expression, exp.Select):
        # `any_expr` avoids shadowing the builtin `any`
        for any_expr in expression.find_all(exp.Any):
            operand = any_expr.this
            if isinstance(operand, exp.Query):
                # Subquery operands are left untouched; only arrays are rewritten
                continue

            comparison = any_expr.parent
            if isinstance(comparison, exp.Binary):
                lambda_arg = exp.to_identifier("x")
                any_expr.replace(lambda_arg)
                lambda_expr = exp.Lambda(this=comparison.copy(), expressions=[lambda_arg])
                comparison.replace(exp.Exists(this=operand.unnest(), expression=lambda_expr))

    return expression
Transform ANY operator to Spark's EXISTS
For example, - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col) - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)
Both ANY and EXISTS accept queries but currently only array expressions are supported for this transformation