sqlglot.transforms
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name, name_sequence

if t.TYPE_CHECKING:
    from sqlglot.generator import Generator


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed and leaves the expression untouched.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: (from the returned function) if the resulting expression has neither a
            "_sql" method nor a usable `Generator.TRANSFORMS` entry.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # Iterating over the whole sequence (instead of special-casing transforms[0])
        # keeps the call order and also tolerates an empty sequence.
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql


def unnest_generate_date_array_using_recursive_cte(
    bubble_up_recursive_cte: bool = False,
) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Builds a transform that rewrites `UNNEST(GENERATE_DATE_ARRAY(...))` found directly under a
    FROM or JOIN into a subquery over a recursive CTE named `_generated_dates`.

    Args:
        bubble_up_recursive_cte: if True, the CTE is attached (in place) to the outermost
            enclosing query instead of the generated subquery itself.

    Returns:
        The transform function. Expressions that don't match, or whose start/end are missing
        or whose step is not an interval, are returned unchanged.
    """

    def _unnest_generate_date_array_using_recursive_cte(
        expression: exp.Expression,
    ) -> exp.Expression:
        if (
            isinstance(expression, exp.Unnest)
            and isinstance(expression.parent, (exp.From, exp.Join))
            and len(expression.expressions) == 1
            and isinstance(expression.expressions[0], exp.GenerateDateArray)
        ):
            generate_date_array = expression.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            if not start or not end or not isinstance(step, exp.Interval):
                return expression

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", "date_value", exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            base_query = exp.select(start.as_("date_value"))
            # The end date is part of the generated range, so the recursive step must keep
            # producing rows while the next value is less than OR EQUAL to it.
            recursive_query = (
                exp.select(cast_date_add.as_("date_value"))
                .from_("_generated_dates")
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)
            generate_dates_query = exp.select("date_value").from_("_generated_dates")

            query_to_add_cte: exp.Query = generate_dates_query

            if bubble_up_recursive_cte:
                parent: t.Optional[exp.Expression] = expression.parent

                # Walk all the way up; the last Query seen is the outermost one
                while parent:
                    if isinstance(parent, exp.Query):
                        query_to_add_cte = parent
                    parent = parent.parent

            query_to_add_cte.with_(
                "_generated_dates(date_value)", as_=cte_query, recursive=True, copy=False
            )
            return generate_dates_query.subquery("_generated_dates")

        return expression

    return _unnest_generate_date_array_using_recursive_cte


def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    this = expression.this
    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
        unnest = exp.Unnest(expressions=[this])
        if expression.alias:
            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)

        return unnest

    return expression


def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        # Maps each projection alias to its 1-based position in the SELECT list
        aliased_selects = {
            e.alias: i
            for i, e in enumerate(expression.parent.expressions, start=1)
            if isinstance(e, exp.Alias)
        }

        for group_by in expression.expressions:
            if (
                isinstance(group_by, exp.Column)
                and not group_by.table
                and group_by.name in aliased_selects
            ):
                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))

    return expression


def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        outer_selects = expression.selects
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression


def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression


def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transforms removes the precision from parameterized types in expressions.
    """
    for node in expression.find_all(exp.DataType):
        node.set(
            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
        )

    return expression


def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if isinstance(expression, exp.Select):
        unnest_aliases = {
            unnest.alias
            for unnest in find_all_in_scope(expression, exp.Unnest)
            if isinstance(unnest.parent, (exp.From, exp.Join))
        }
        if unnest_aliases:
            for column in expression.find_all(exp.Column):
                if column.table in unnest_aliases:
                    column.set("table", None)
                elif column.db in unnest_aliases:
                    column.set("db", None)

    return expression


def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """Convert cross join unnest into lateral view explode."""
    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
            this, *expressions = unnest.expressions
            unnest.replace(
                exp.Table(
                    this=udtf(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        # Iterate over a snapshot: joins are removed from the live list inside the loop,
        # which would otherwise make the iterator skip the join following each removal.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression


def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode into unnest."""

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest


def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    if (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    ):
        column = expression.this.pop()
        expression.set("this", expression.expression.pop())
        order = exp.Order(expressions=[exp.Ordered(this=column)])
        expression = exp.WithinGroup(this=expression, expression=order)

    return expression


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression


def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                if isinstance(query, exp.SetOperation):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression


def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression


def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead."""
    if isinstance(expression, exp.Select):
        for join in expression.args.get("joins") or []:
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression


def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]
            full_outer_join.set("side", "left")
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side

            return exp.union(expression, expression_copy, copy=False)

    return expression


def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Insert the nested CTEs right before the CTE that contained them
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression


def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        if (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        ):
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _ensure_bool)

    return expression


def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Removes every qualifier (all parts except the last one) from each column in the expression."""
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        for part in column.parts[:-1]:
            part.pop()

    return expression


def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """Pops the parent node of every UniqueColumnConstraint found in a CREATE expression."""
    assert isinstance(expression, exp.Create)
    for constraint in expression.find_all(exp.UniqueColumnConstraint):
        if constraint.parent:
            constraint.parent.pop()

    return expression


def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Rewrites a CREATE TABLE that carries a TemporaryProperty.

    Args:
        expression: the CREATE expression to transform.
        tmp_storage_provider: called with the expression when the temporary table has no
            defining query; by default it returns the expression unchanged.

    Returns:
        A `CREATE TEMPORARY VIEW` for temporary CTAS statements, the result of
        `tmp_storage_provider` for other temporary tables, or the input expression otherwise.
    """
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    temporary = any(
        isinstance(prop, exp.TemporaryProperty)
        for prop in (properties.expressions if properties else [])
    )

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.kind == "TABLE" and temporary:
        if expression.expression:
            return exp.Create(
                kind="TEMPORARY VIEW",
                this=expression.this,
                expression=expression.expression,
            )
        return tmp_storage_provider(expression)

    return expression


def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return expression


def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        prop_this = exp.Tuple(
            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
        )
        schema = expression.this
        for e in prop.this.expressions:
            schema.append("expressions", e)
        prop.set("this", prop_this)

    return expression


def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        expression.set(
            "expressions",
            [
                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
                for e in expression.expressions
            ],
        )

    return expression


def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # `col` is the last marked column from the loop above; the assertion guarantees
            # that all marked columns share its table.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        query.set("joins", list(new_joins.values()))

        if not where.this:
            where.pop()

    return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed and leaves the expression untouched.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: (from the returned function) if the resulting expression has neither a
            "_sql" method nor a usable `Generator.TRANSFORMS` entry.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remember the input type so that re-entering this same transform can be detected below
        expression_type = type(expression)

        # Iterating over the whole sequence (instead of special-casing transforms[0])
        # keeps the call order and also tolerates an empty sequence.
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.
def unnest_generate_date_array_using_recursive_cte(
    bubble_up_recursive_cte: bool = False,
) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Builds a transform that rewrites `UNNEST(GENERATE_DATE_ARRAY(start, end, step))`, when used
    as a FROM / JOIN source, into a subquery over a recursive CTE named `_generated_dates`.

    Args:
        bubble_up_recursive_cte: if True, the recursive CTE is attached to the outermost
            enclosing query instead of the generated subquery itself.

    Returns:
        The transform function. Expressions that don't match the pattern above (or whose
        GENERATE_DATE_ARRAY lacks a start, an end or an INTERVAL step) are returned unchanged.
    """

    def _unnest_generate_date_array_using_recursive_cte(
        expression: exp.Expression,
    ) -> exp.Expression:
        if (
            isinstance(expression, exp.Unnest)
            and isinstance(expression.parent, (exp.From, exp.Join))
            and len(expression.expressions) == 1
            and isinstance(expression.expressions[0], exp.GenerateDateArray)
        ):
            generate_date_array = expression.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            if not start or not end or not isinstance(step, exp.Interval):
                return expression

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", "date_value", exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            # Anchor member: the first date of the series.
            base_query = exp.select(start.as_("date_value"))

            # Recursive member: keep stepping while the next date does not go past `end`.
            # GENERATE_DATE_ARRAY includes its end date, hence `<=` rather than `<`.
            recursive_query = (
                exp.select(cast_date_add.as_("date_value"))
                .from_("_generated_dates")
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)
            generate_dates_query = exp.select("date_value").from_("_generated_dates")

            query_to_add_cte: exp.Query = generate_dates_query

            if bubble_up_recursive_cte:
                parent: t.Optional[exp.Expression] = expression.parent

                # Walk all the way up; the last Query ancestor seen is the outermost one.
                while parent:
                    if isinstance(parent, exp.Query):
                        query_to_add_cte = parent
                    parent = parent.parent

            query_to_add_cte.with_(
                "_generated_dates(date_value)", as_=cte_query, recursive=True, copy=False
            )
            return generate_dates_query.subquery("_generated_dates")

        return expression

    return _unnest_generate_date_array_using_recursive_cte
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a GENERATE_SERIES (or SEQUENCE) table reference in an UNNEST.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The UNNEST node (aliased as `_u(<original alias>)` when the table had an alias), or
        the original expression if it is not a table over a GENERATE_SERIES call.
    """
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])
    table_alias = expression.alias
    if not table_alias:
        return unnested

    # The original table alias becomes the column alias of the unnested source.
    return exp.alias_(unnested, alias="_u", table=[table_alias], copy=False)
Unnests GENERATE_SERIES or SEQUENCE table references.
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Swap select-alias references inside GROUP BY for the 1-based position of the projection.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not (isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select)):
        return expression

    # Map every aliased projection to its ordinal position in the SELECT list.
    position_by_alias = {}
    for position, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for group_key in expression.expressions:
        # Only bare (unqualified) column references can be alias references.
        if not isinstance(group_key, exp.Column) or group_key.table:
            continue

        position = position_by_alias.get(group_key.name)
        if position is not None:
            group_key.replace(exp.Literal.number(position))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON (...) as a ROW_NUMBER() window inside a subquery, filtered to
    the first row of each partition by the outer query.

    Intended for dialects without SELECT DISTINCT ON that do support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    distinct_on = distinct.args.get("on")
    if not distinct_on or not isinstance(distinct_on, exp.Tuple):
        return expression

    # Detach the DISTINCT clause; its ON columns become the window partition.
    distinct.pop()
    partition_columns = distinct_on.expressions

    # Captured before the window projection is appended, so the outer query only re-selects
    # the original projections.
    original_projections = expression.selects
    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    row_number_window = exp.Window(this=exp.RowNumber(), partition_by=partition_columns)

    # Reuse the query's ORDER BY if present, otherwise order by the DISTINCT ON columns.
    existing_order = expression.args.get("order")
    window_order = (
        existing_order.pop()
        if existing_order
        else exp.Order(expressions=[col.copy() for col in partition_columns])
    )
    row_number_window.set("order", window_order)

    expression.select(exp.alias_(row_number_window, row_number_alias), copy=False)

    inner = expression.subquery("_t", copy=False)
    outer = exp.select(*original_projections, copy=False).from_(inner, copy=False)
    return outer.where(exp.column(row_number_alias).eq(1), copy=False)
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new outer SELECT over the original query (as subquery `_t`) when a QUALIFY clause is
        present, otherwise the original expression.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh "_c" alias so the outer query can refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Returns a Column that carries over the identifier's "quoted" flag when the
            # projection's alias (or its `this`) is an Identifier; otherwise the plain name.
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # The outer projections are computed before any helper projection is appended below.
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With a star projection only windows are searched for; otherwise columns too.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Replace alias references inside the window with the aliased expression.
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window under a fresh "_w" alias and filter on that alias instead.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window was the whole QUALIFY condition (its parent is the Qualify node).
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Column used by the filter but not projected: add it to the subquery.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Strip precision parameters from every data type found in the expression.

    Some dialects only allow the precision of parameterized types to be defined in the DDL and
    not in other expressions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, with `DataTypeParam` children removed from each `DataType`.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_params = [
            param for param in data_type.expressions if not isinstance(param, exp.DataTypeParam)
        ]
        data_type.set("expressions", kept_params)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Drop column qualifiers that point at UNNEST table aliases (as added by the optimizer's
    qualify_columns step).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, mutated in place when it is a SELECT.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Aliases of the UNNESTs used directly as FROM / JOIN sources in this scope.
    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in unnest_aliases:
            column.set("table", None)
        elif column.db in unnest_aliases:
            column.set("db", None)

    return expression
Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    An UNNEST in the FROM clause is replaced by a table over EXPLODE (or POSEXPLODE when the
    UNNEST has an offset). Every joined UNNEST is removed from the joins and replaced by one
    LATERAL VIEW per (unnested expression, alias column) pair.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, mutated in place when it is a SELECT.
    """
    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
            this, *expressions = unnest.expressions
            unnest.replace(
                exp.Table(
                    this=udtf(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # mutating a list while iterating it would skip the join following each removed one.
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # Without an alias there are no columns to pair with, so no lateral is added.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the start value of the generated position series; it also controls the
            index adjustments applied to the array sizes below.

    Returns:
        The transform function, which mutates and returns SELECT expressions and returns any
        other expression unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Picks a name not yet in `names` and reserves it.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # Sizes of all exploded arrays; used at the end to bound the position series.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # UNNEST(GENERATE_SERIES(index_offset, <end set later>)) AS _u(pos)
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Two aliases: the first names the position, the second the value.
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # alias_ was called without copy=False here, so look the explode up again
                        # in the new node.
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # Builds IF(ARRAY_SIZE(COALESCE(arg, ARRAY())) = 0, ARRAY(arg[0]), arg).
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The explode call becomes IF(series.pos = unnest.pos, unnest.value).
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # Insert the position projection right after the value projection.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # Add the position series only once, when the first explode is handled.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series position equals the unnest position, or where
                    # the series is past `size` and the unnest position equals `size`.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series ends at the greatest array size, adjusted for the index offset.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
Convert explode/posexplode into unnest.
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a two-argument percentile call in a WITHIN GROUP (ORDER BY ...) clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A `WithinGroup` node whose function keeps the original second argument and whose
        ORDER BY holds the original first argument, or the original expression if it is not
        a percentile with a second argument or is already inside a WITHIN GROUP.
    """
    if not isinstance(expression, exp.PERCENTILES):
        return expression
    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_value = expression.this.pop()
    quantile = expression.expression.pop()
    expression.set("this", quantile)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])
    return exp.WithinGroup(this=expression, expression=ordering)
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace `<percentile>(q) WITHIN GROUP (ORDER BY x)` with an `ApproxQuantile` over x and q.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The replacement node, or the original expression if it does not match that pattern.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, exp.PERCENTILES):
        return expression
    if not isinstance(expression.expression, exp.Order):
        return expression

    quantile = percentile.this
    # The first Ordered node found carries the value the percentile is computed over.
    first_ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=first_ordered.this, quantile=quantile)
    return expression.replace(approx_quantile)
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH an explicit column list, derived from its projections.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression; CTEs that already declare columns are left untouched.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    # Fallback names ("_c_...") for projections that have no output name.
    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            continue

        # For set operations, the column names come from the left-hand query.
        query = cte.this
        if isinstance(query, exp.SetOperation):
            query = query.this

        column_names = []
        for projection in query.selects:
            column_names.append(exp.to_identifier(projection.alias_or_name or next_name()))

        cte_alias.set("columns", column_names)

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the value 'epoch' in casts to temporal types by the equivalent date literal.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression; a matching cast has its operand replaced by the string literal
        '1970-01-01 00:00:00'.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression
    if expression.name.lower() != "epoch":
        return expression
    if expression.to.this not in exp.DataType.TEMPORAL_TYPES:
        return expression

    expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))
    return expression
Replace 'epoch' in casts by the equivalent date literal.
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each SEMI / ANTI join that has an ON condition is removed and replaced by a
    `[NOT] EXISTS (SELECT 1 FROM <joined source> WHERE <on condition>)` predicate that is
    added to the query's WHERE clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, mutated in place when it is a SELECT.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are popped out of the query inside the loop, and
        # removing items from a list while iterating it can skip the element that follows.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a query with one FULL OUTER join as the union of two copies of it, one using a
    LEFT OUTER join and the other a RIGHT OUTER join.

    Queries with zero or several FULL OUTER joins are returned unchanged.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The union of the two rewritten queries, or the original expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_joins = []
    for position, join in enumerate(expression.args.get("joins") or []):
        if join.side == "FULL":
            full_joins.append((position, join))

    if len(full_joins) != 1:
        return expression

    position, full_join = full_joins[0]

    # The copy is taken before any change, so it still carries the original LIMIT.
    right_query = expression.copy()
    expression.set("limit", None)

    full_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the expression whose nested WITH clauses will be hoisted.

    Returns:
        The same expression, mutated in place.
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        # The WITH that already belongs to the root expression stays where it is.
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: the first nested one found is moved up and takes that role.
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            # Must be looked up before popping, while inner_with is still attached to the tree.
            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # The nested WITH lived inside a CTE: its CTEs are inserted right before that
                # CTE in the top-level list.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                # Otherwise the nested CTEs are appended at the end.
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Turn numeric values used as conditions into explicit `<value> <> 0` comparisons.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, mutated in place.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _is_implicit_bool(node: exp.Expression) -> bool:
        # Numeric literals, nodes typed as UNKNOWN / numeric (subquery predicates excluded)
        # and columns without type information all qualify.
        if node.is_number:
            return True
        if not isinstance(node, exp.SubqueryPredicate) and node.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        ):
            return True
        return isinstance(node, exp.Column) and not node.type

    def _ensure_bool(node: exp.Expression) -> None:
        if _is_implicit_bool(node):
            node.replace(node.neq(0))

    for visited in expression.walk():
        _canonicalize_ensure_bools(visited, _ensure_bool)

    return expression
Converts numeric values used in conditions into explicit boolean expressions.
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map CREATE TABLE statements that carry a TEMPORARY property to CREATE TEMPORARY VIEW.

    Args:
        expression: the CREATE expression that will be transformed.
        tmp_storage_provider: called with the expression when the temporary table has no
            defining query; its result is returned. Defaults to the identity function.

    Returns:
        A new `CREATE TEMPORARY VIEW ... AS <query>` for temporary CTAS statements, the
        provider's result for temporary tables without a query, or the original expression.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_list = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in property_list)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if not expression.expression:
        return tmp_storage_provider(expression)

    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value lists column names rather than a schema, the matching column
    definitions are moved out of the CREATE statement's schema and into PARTITIONED BY.

    Args:
        expression: the CREATE expression that will be transformed.

    Returns:
        The same expression, mutated in place.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if not partitioned_by or not partitioned_by.this:
        return expression
    if isinstance(partitioned_by.this, exp.Schema):
        return expression

    schema = expression.this

    # Column names are matched case-insensitively.
    partition_names = {item.name.upper() for item in partitioned_by.this.expressions}
    partition_columns = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining_columns = [
        column for column in schema.expressions if column not in partition_columns
    ]

    schema.set("expressions", remaining_columns)
    partitioned_by.replace(
        exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns))
    )
    expression.set("this", schema)

    return expression
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3. When PARTITIONED BY holds a
    schema made only of typed column definitions, those definitions are appended to the
    table's schema and PARTITIONED BY is reduced to a tuple of the column identifiers.

    Args:
        expression: the CREATE expression that will be transformed.

    Returns:
        The same expression, mutated in place.
    """
    assert isinstance(expression, exp.Create)

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if not (partitioned_by and partitioned_by.this):
        return expression
    if not isinstance(partitioned_by.this, exp.Schema):
        return expression

    all_typed_columns = all(
        isinstance(column, exp.ColumnDef) and column.kind
        for column in partitioned_by.this.expressions
    )
    if not all_typed_columns:
        return expression

    partition_identifiers = exp.Tuple(
        expressions=[exp.to_identifier(column.this) for column in partitioned_by.this.expressions]
    )

    table_schema = expression.this
    for column in partitioned_by.this.expressions:
        table_schema.append("expressions", column)

    partitioned_by.set("this", partition_identifiers)
    return expression
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Convert key-value struct arguments into aliases, e.g. STRUCT(1 AS y).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression; in a `Struct`, every `PropertyEQ` argument is replaced by an
        alias of its value named after its key.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    struct_fields = []
    for field in expression.expressions:
        if isinstance(field, exp.PropertyEQ):
            field = exp.alias_(field.expression, field.this)
        struct_fields.append(field)

    expression.set("expressions", struct_fields)
    return expression
Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Only queries that have both a WHERE clause and joins are rewritten.
        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            # Closest enclosing predicate; exp.Select is included so the search stops at the
            # query boundary (the assertion below then fails for a non-binary result).
            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Detach the predicate from the WHERE clause; it becomes the join's ON condition.
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            # Clear the marks and collect the tables of the marked columns.
            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # `col` is the last marked column from the loop above; all marked columns share
            # one table at this point. Falls back to the FROM source if no join matches.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # The FROM source itself became a LEFT JOIN target, so another source that was not
        # turned into a new join has to take its place in the FROM clause.
        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        query.set("joins", list(new_joins.values()))

        # Drop the WHERE clause if nothing is left in it.
        if not where.this:
            where.pop()

    return expression
Remove join marks from an AST. This rule assumes that all marked columns are qualified.
If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.
For example,
    SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
    SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this
Arguments:
- expression: The AST to remove join marks from.
Returns:
The AST with join marks removed.