sqlglot.transforms
1from __future__ import annotations 2 3import typing as t 4 5from sqlglot import expressions as exp 6from sqlglot.helper import find_new_name, name_sequence 7 8if t.TYPE_CHECKING: 9 from sqlglot.generator import Generator 10 11 12def unalias_group(expression: exp.Expression) -> exp.Expression: 13 """ 14 Replace references to select aliases in GROUP BY clauses. 15 16 Example: 17 >>> import sqlglot 18 >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() 19 'SELECT a AS b FROM x GROUP BY 1' 20 21 Args: 22 expression: the expression that will be transformed. 23 24 Returns: 25 The transformed expression. 26 """ 27 if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): 28 aliased_selects = { 29 e.alias: i 30 for i, e in enumerate(expression.parent.expressions, start=1) 31 if isinstance(e, exp.Alias) 32 } 33 34 for group_by in expression.expressions: 35 if ( 36 isinstance(group_by, exp.Column) 37 and not group_by.table 38 and group_by.name in aliased_selects 39 ): 40 group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name))) 41 42 return expression 43 44 45def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: 46 """ 47 Convert SELECT DISTINCT ON statements to a subquery with a window function. 48 49 This is useful for dialects that don't support SELECT DISTINCT ON but support window functions. 50 51 Args: 52 expression: the expression that will be transformed. 53 54 Returns: 55 The transformed expression. 
56 """ 57 if ( 58 isinstance(expression, exp.Select) 59 and expression.args.get("distinct") 60 and expression.args["distinct"].args.get("on") 61 and isinstance(expression.args["distinct"].args["on"], exp.Tuple) 62 ): 63 distinct_cols = expression.args["distinct"].pop().args["on"].expressions 64 outer_selects = expression.selects 65 row_number = find_new_name(expression.named_selects, "_row_number") 66 window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols) 67 order = expression.args.get("order") 68 69 if order: 70 window.set("order", order.pop()) 71 else: 72 window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols])) 73 74 window = exp.alias_(window, row_number) 75 expression.select(window, copy=False) 76 77 return ( 78 exp.select(*outer_selects, copy=False) 79 .from_(expression.subquery("_t", copy=False), copy=False) 80 .where(exp.column(row_number).eq(1), copy=False) 81 ) 82 83 return expression 84 85 86def eliminate_qualify(expression: exp.Expression) -> exp.Expression: 87 """ 88 Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently. 89 90 The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: 91 https://docs.snowflake.com/en/sql-reference/constructs/qualify 92 93 Some dialects don't support window functions in the WHERE clause, so we need to include them as 94 projections in the subquery, in order to refer to them in the outer filter using aliases. Also, 95 if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, 96 otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a 97 newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the 98 corresponding expression to avoid creating invalid column references. 
99 """ 100 if isinstance(expression, exp.Select) and expression.args.get("qualify"): 101 taken = set(expression.named_selects) 102 for select in expression.selects: 103 if not select.alias_or_name: 104 alias = find_new_name(taken, "_c") 105 select.replace(exp.alias_(select, alias)) 106 taken.add(alias) 107 108 outer_selects = exp.select(*[select.alias_or_name for select in expression.selects]) 109 qualify_filters = expression.args["qualify"].pop().this 110 expression_by_alias = { 111 select.alias: select.this 112 for select in expression.selects 113 if isinstance(select, exp.Alias) 114 } 115 116 select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column) 117 for select_candidate in qualify_filters.find_all(select_candidates): 118 if isinstance(select_candidate, exp.Window): 119 if expression_by_alias: 120 for column in select_candidate.find_all(exp.Column): 121 expr = expression_by_alias.get(column.name) 122 if expr: 123 column.replace(expr) 124 125 alias = find_new_name(expression.named_selects, "_w") 126 expression.select(exp.alias_(select_candidate, alias), copy=False) 127 column = exp.column(alias) 128 129 if isinstance(select_candidate.parent, exp.Qualify): 130 qualify_filters = column 131 else: 132 select_candidate.replace(column) 133 elif select_candidate.name not in expression.named_selects: 134 expression.select(select_candidate.copy(), copy=False) 135 136 return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where( 137 qualify_filters, copy=False 138 ) 139 140 return expression 141 142 143def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression: 144 """ 145 Some dialects only allow the precision for parameterized types to be defined in the DDL and not in 146 other expressions. This transforms removes the precision from parameterized types in expressions. 
147 """ 148 for node in expression.find_all(exp.DataType): 149 node.set( 150 "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)] 151 ) 152 153 return expression 154 155 156def unqualify_unnest(expression: exp.Expression) -> exp.Expression: 157 """Remove references to unnest table aliases, added by the optimizer's qualify_columns step.""" 158 from sqlglot.optimizer.scope import find_all_in_scope 159 160 if isinstance(expression, exp.Select): 161 unnest_aliases = { 162 unnest.alias 163 for unnest in find_all_in_scope(expression, exp.Unnest) 164 if isinstance(unnest.parent, (exp.From, exp.Join)) 165 } 166 if unnest_aliases: 167 for column in expression.find_all(exp.Column): 168 if column.table in unnest_aliases: 169 column.set("table", None) 170 elif column.db in unnest_aliases: 171 column.set("db", None) 172 173 return expression 174 175 176def unnest_to_explode(expression: exp.Expression) -> exp.Expression: 177 """Convert cross join unnest into lateral view explode.""" 178 if isinstance(expression, exp.Select): 179 for join in expression.args.get("joins") or []: 180 unnest = join.this 181 182 if isinstance(unnest, exp.Unnest): 183 alias = unnest.args.get("alias") 184 udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode 185 186 expression.args["joins"].remove(join) 187 188 for e, column in zip(unnest.expressions, alias.columns if alias else []): 189 expression.append( 190 "laterals", 191 exp.Lateral( 192 this=udtf(this=e), 193 view=True, 194 alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore 195 ), 196 ) 197 198 return expression 199 200 201def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]: 202 """Convert explode/posexplode into unnest.""" 203 204 def _explode_to_unnest(expression: exp.Expression) -> exp.Expression: 205 if isinstance(expression, exp.Select): 206 from sqlglot.optimizer.scope import Scope 207 208 taken_select_names = 
set(expression.named_selects) 209 taken_source_names = {name for name, _ in Scope(expression).references} 210 211 def new_name(names: t.Set[str], name: str) -> str: 212 name = find_new_name(names, name) 213 names.add(name) 214 return name 215 216 arrays: t.List[exp.Condition] = [] 217 series_alias = new_name(taken_select_names, "pos") 218 series = exp.alias_( 219 exp.Unnest( 220 expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))] 221 ), 222 new_name(taken_source_names, "_u"), 223 table=[series_alias], 224 ) 225 226 # we use list here because expression.selects is mutated inside the loop 227 for select in list(expression.selects): 228 explode = select.find(exp.Explode) 229 230 if explode: 231 pos_alias = "" 232 explode_alias = "" 233 234 if isinstance(select, exp.Alias): 235 explode_alias = select.args["alias"] 236 alias = select 237 elif isinstance(select, exp.Aliases): 238 pos_alias = select.aliases[0] 239 explode_alias = select.aliases[1] 240 alias = select.replace(exp.alias_(select.this, "", copy=False)) 241 else: 242 alias = select.replace(exp.alias_(select, "")) 243 explode = alias.find(exp.Explode) 244 assert explode 245 246 is_posexplode = isinstance(explode, exp.Posexplode) 247 explode_arg = explode.this 248 249 if isinstance(explode, exp.ExplodeOuter): 250 bracket = explode_arg[0] 251 bracket.set("safe", True) 252 bracket.set("offset", True) 253 explode_arg = exp.func( 254 "IF", 255 exp.func( 256 "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array()) 257 ).eq(0), 258 exp.array(bracket, copy=False), 259 explode_arg, 260 ) 261 262 # This ensures that we won't use [POS]EXPLODE's argument as a new selection 263 if isinstance(explode_arg, exp.Column): 264 taken_select_names.add(explode_arg.output_name) 265 266 unnest_source_alias = new_name(taken_source_names, "_u") 267 268 if not explode_alias: 269 explode_alias = new_name(taken_select_names, "col") 270 271 if is_posexplode: 272 pos_alias = new_name(taken_select_names, "pos") 273 274 
if not pos_alias: 275 pos_alias = new_name(taken_select_names, "pos") 276 277 alias.set("alias", exp.to_identifier(explode_alias)) 278 279 series_table_alias = series.args["alias"].this 280 column = exp.If( 281 this=exp.column(series_alias, table=series_table_alias).eq( 282 exp.column(pos_alias, table=unnest_source_alias) 283 ), 284 true=exp.column(explode_alias, table=unnest_source_alias), 285 ) 286 287 explode.replace(column) 288 289 if is_posexplode: 290 expressions = expression.expressions 291 expressions.insert( 292 expressions.index(alias) + 1, 293 exp.If( 294 this=exp.column(series_alias, table=series_table_alias).eq( 295 exp.column(pos_alias, table=unnest_source_alias) 296 ), 297 true=exp.column(pos_alias, table=unnest_source_alias), 298 ).as_(pos_alias), 299 ) 300 expression.set("expressions", expressions) 301 302 if not arrays: 303 if expression.args.get("from"): 304 expression.join(series, copy=False, join_type="CROSS") 305 else: 306 expression.from_(series, copy=False) 307 308 size: exp.Condition = exp.ArraySize(this=explode_arg.copy()) 309 arrays.append(size) 310 311 # trino doesn't support left join unnest with on conditions 312 # if it did, this would be much simpler 313 expression.join( 314 exp.alias_( 315 exp.Unnest( 316 expressions=[explode_arg.copy()], 317 offset=exp.to_identifier(pos_alias), 318 ), 319 unnest_source_alias, 320 table=[explode_alias], 321 ), 322 join_type="CROSS", 323 copy=False, 324 ) 325 326 if index_offset != 1: 327 size = size - 1 328 329 expression.where( 330 exp.column(series_alias, table=series_table_alias) 331 .eq(exp.column(pos_alias, table=unnest_source_alias)) 332 .or_( 333 (exp.column(series_alias, table=series_table_alias) > size).and_( 334 exp.column(pos_alias, table=unnest_source_alias).eq(size) 335 ) 336 ), 337 copy=False, 338 ) 339 340 if arrays: 341 end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:]) 342 343 if index_offset != 1: 344 end = end - (1 - index_offset) 345 
series.expressions[0].set("end", end) 346 347 return expression 348 349 return _explode_to_unnest 350 351 352PERCENTILES = (exp.PercentileCont, exp.PercentileDisc) 353 354 355def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: 356 """Transforms percentiles by adding a WITHIN GROUP clause to them.""" 357 if ( 358 isinstance(expression, PERCENTILES) 359 and not isinstance(expression.parent, exp.WithinGroup) 360 and expression.expression 361 ): 362 column = expression.this.pop() 363 expression.set("this", expression.expression.pop()) 364 order = exp.Order(expressions=[exp.Ordered(this=column)]) 365 expression = exp.WithinGroup(this=expression, expression=order) 366 367 return expression 368 369 370def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: 371 """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.""" 372 if ( 373 isinstance(expression, exp.WithinGroup) 374 and isinstance(expression.this, PERCENTILES) 375 and isinstance(expression.expression, exp.Order) 376 ): 377 quantile = expression.this.this 378 input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this 379 return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile)) 380 381 return expression 382 383 384def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression: 385 """Uses projection output names in recursive CTE definitions to define the CTEs' columns.""" 386 if isinstance(expression, exp.With) and expression.recursive: 387 next_name = name_sequence("_c_") 388 389 for cte in expression.expressions: 390 if not cte.args["alias"].columns: 391 query = cte.this 392 if isinstance(query, exp.Union): 393 query = query.this 394 395 cte.args["alias"].set( 396 "columns", 397 [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects], 398 ) 399 400 return expression 401 402 403def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression: 404 
"""Replace 'epoch' in casts by the equivalent date literal.""" 405 if ( 406 isinstance(expression, (exp.Cast, exp.TryCast)) 407 and expression.name.lower() == "epoch" 408 and expression.to.this in exp.DataType.TEMPORAL_TYPES 409 ): 410 expression.this.replace(exp.Literal.string("1970-01-01 00:00:00")) 411 412 return expression 413 414 415def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression: 416 """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead.""" 417 if isinstance(expression, exp.Select): 418 for join in expression.args.get("joins") or []: 419 on = join.args.get("on") 420 if on and join.kind in ("SEMI", "ANTI"): 421 subquery = exp.select("1").from_(join.this).where(on) 422 exists = exp.Exists(this=subquery) 423 if join.kind == "ANTI": 424 exists = exists.not_(copy=False) 425 426 join.pop() 427 expression.where(exists, copy=False) 428 429 return expression 430 431 432def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression: 433 """ 434 Converts a query with a FULL OUTER join to a union of identical queries that 435 use LEFT/RIGHT OUTER joins instead. This transformation currently only works 436 for queries that have a single FULL OUTER join. 437 """ 438 if isinstance(expression, exp.Select): 439 full_outer_joins = [ 440 (index, join) 441 for index, join in enumerate(expression.args.get("joins") or []) 442 if join.side == "FULL" 443 ] 444 445 if len(full_outer_joins) == 1: 446 expression_copy = expression.copy() 447 expression.set("limit", None) 448 index, full_outer_join = full_outer_joins[0] 449 full_outer_join.set("side", "left") 450 expression_copy.args["joins"][index].set("side", "right") 451 expression_copy.args.pop("with", None) # remove CTEs from RIGHT side 452 453 return exp.union(expression, expression_copy, copy=False) 454 455 return expression 456 457 458def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression: 459 """ 460 Some dialects (e.g. 
Hive, T-SQL, Spark prior to version 3) only allow CTEs to be 461 defined at the top-level, so for example queries like: 462 463 SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq 464 465 are invalid in those dialects. This transformation can be used to ensure all CTEs are 466 moved to the top level so that the final SQL code is valid from a syntax standpoint. 467 468 TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly). 469 """ 470 top_level_with = expression.args.get("with") 471 for node in expression.find_all(exp.With): 472 if node.parent is expression: 473 continue 474 475 inner_with = node.pop() 476 if not top_level_with: 477 top_level_with = inner_with 478 expression.set("with", top_level_with) 479 else: 480 if inner_with.recursive: 481 top_level_with.set("recursive", True) 482 483 top_level_with.set("expressions", inner_with.expressions + top_level_with.expressions) 484 485 return expression 486 487 488def ensure_bools(expression: exp.Expression) -> exp.Expression: 489 """Converts numeric values used in conditions into explicit boolean expressions.""" 490 from sqlglot.optimizer.canonicalize import ensure_bools 491 492 def _ensure_bool(node: exp.Expression) -> None: 493 if ( 494 node.is_number 495 or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES) 496 or (isinstance(node, exp.Column) and not node.type) 497 ): 498 node.replace(node.neq(0)) 499 500 for node in expression.walk(): 501 ensure_bools(node, _ensure_bool) 502 503 return expression 504 505 506def unqualify_columns(expression: exp.Expression) -> exp.Expression: 507 for column in expression.find_all(exp.Column): 508 # We only wanna pop off the table, db, catalog args 509 for part in column.parts[:-1]: 510 part.pop() 511 512 return expression 513 514 515def remove_unique_constraints(expression: exp.Expression) -> exp.Expression: 516 assert isinstance(expression, exp.Create) 517 for constraint in expression.find_all(exp.UniqueColumnConstraint): 
518 if constraint.parent: 519 constraint.parent.pop() 520 521 return expression 522 523 524def ctas_with_tmp_tables_to_create_tmp_view( 525 expression: exp.Expression, 526 tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e, 527) -> exp.Expression: 528 assert isinstance(expression, exp.Create) 529 properties = expression.args.get("properties") 530 temporary = any( 531 isinstance(prop, exp.TemporaryProperty) 532 for prop in (properties.expressions if properties else []) 533 ) 534 535 # CTAS with temp tables map to CREATE TEMPORARY VIEW 536 if expression.kind == "TABLE" and temporary: 537 if expression.expression: 538 return exp.Create( 539 kind="TEMPORARY VIEW", 540 this=expression.this, 541 expression=expression.expression, 542 ) 543 return tmp_storage_provider(expression) 544 545 return expression 546 547 548def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression: 549 """ 550 In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the 551 PARTITIONED BY value is an array of column names, they are transformed into a schema. 552 The corresponding columns are removed from the create statement. 
553 """ 554 assert isinstance(expression, exp.Create) 555 has_schema = isinstance(expression.this, exp.Schema) 556 is_partitionable = expression.kind in {"TABLE", "VIEW"} 557 558 if has_schema and is_partitionable: 559 prop = expression.find(exp.PartitionedByProperty) 560 if prop and prop.this and not isinstance(prop.this, exp.Schema): 561 schema = expression.this 562 columns = {v.name.upper() for v in prop.this.expressions} 563 partitions = [col for col in schema.expressions if col.name.upper() in columns] 564 schema.set("expressions", [e for e in schema.expressions if e not in partitions]) 565 prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) 566 expression.set("this", schema) 567 568 return expression 569 570 571def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression: 572 """ 573 Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE. 574 575 Currently, SQLGlot uses the DATASOURCE format for Spark 3. 576 """ 577 assert isinstance(expression, exp.Create) 578 prop = expression.find(exp.PartitionedByProperty) 579 if ( 580 prop 581 and prop.this 582 and isinstance(prop.this, exp.Schema) 583 and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions) 584 ): 585 prop_this = exp.Tuple( 586 expressions=[exp.to_identifier(e.this) for e in prop.this.expressions] 587 ) 588 schema = expression.this 589 for e in prop.this.expressions: 590 schema.append("expressions", e) 591 prop.set("this", prop_this) 592 593 return expression 594 595 596def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression: 597 """Converts struct arguments to aliases, e.g. 
STRUCT(1 AS y).""" 598 if isinstance(expression, exp.Struct): 599 expression.set( 600 "expressions", 601 [ 602 exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e 603 for e in expression.expressions 604 ], 605 ) 606 607 return expression 608 609 610def preprocess( 611 transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], 612) -> t.Callable[[Generator, exp.Expression], str]: 613 """ 614 Creates a new transform by chaining a sequence of transformations and converts the resulting 615 expression to SQL, using either the "_sql" method corresponding to the resulting expression, 616 or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below). 617 618 Args: 619 transforms: sequence of transform functions. These will be called in order. 620 621 Returns: 622 Function that can be used as a generator transform. 623 """ 624 625 def _to_sql(self, expression: exp.Expression) -> str: 626 expression_type = type(expression) 627 628 expression = transforms[0](expression) 629 for transform in transforms[1:]: 630 expression = transform(expression) 631 632 _sql_handler = getattr(self, expression.key + "_sql", None) 633 if _sql_handler: 634 return _sql_handler(expression) 635 636 transforms_handler = self.TRANSFORMS.get(type(expression)) 637 if transforms_handler: 638 if expression_type is type(expression): 639 if isinstance(expression, exp.Func): 640 return self.function_fallback_sql(expression) 641 642 # Ensures we don't enter an infinite loop. This can happen when the original expression 643 # has the same type as the final expression and there's no _sql method available for it, 644 # because then it'd re-enter _to_sql. 645 raise ValueError( 646 f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed." 647 ) 648 649 return transforms_handler(self, expression) 650 651 raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.") 652 653 return _to_sql
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Swap GROUP BY references to projection aliases for the projection's ordinal position.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    # alias -> 1-based position of the aliased projection in the SELECT list
    position_by_alias = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for term in expression.expressions:
        # Qualified columns and non-column terms can't refer to a projection alias.
        if not isinstance(term, exp.Column) or term.table:
            continue

        position = position_by_alias.get(term.name)
        if position is not None:
            term.replace(exp.Literal.number(position))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON as a subquery filtered on a ROW_NUMBER() window.

    Intended for dialects that lack SELECT DISTINCT ON but do support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    on = distinct.args.get("on") if distinct else None
    if not on or not isinstance(on, exp.Tuple):
        return expression

    # Detach the DISTINCT clause; its ON columns become the window's partition.
    distinct.pop()
    partition_cols = on.expressions

    # Read the projections before the row-number column is added to the inner query.
    projections = expression.selects
    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    ranking = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    existing_order = expression.args.get("order")
    if existing_order:
        # The query's ORDER BY is moved into the window.
        ordering = existing_order.pop()
    else:
        ordering = exp.Order(expressions=[col.copy() for col in partition_cols])
    ranking.set("order", ordering)

    expression.select(exp.alias_(ranking, row_number_alias), copy=False)

    outer = exp.select(*projections, copy=False)
    outer = outer.from_(expression.subquery("_t", copy=False), copy=False)
    return outer.where(exp.column(row_number_alias).eq(1), copy=False)
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new outer SELECT over the original query (wrapped as subquery `_t`) when the input is a
        SELECT with a QUALIFY clause; otherwise the input expression, unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh "_c"-based alias so the outer query can select it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # The outer query projects the original projections by name only; helper columns added
        # below are therefore not part of the final output.
        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
        # Detach the QUALIFY clause and keep its condition for the outer WHERE.
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # For star selects only windows are considered; plain columns are not added.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Columns inside the window that name a projection alias are replaced by the
                    # aliased expression itself.
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window in the inner query under a fresh "_w"-based alias.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window's parent is the Qualify node itself, so the alias column
                    # becomes the entire filter.
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # A column used by the filter but not projected: add it to the inner query.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Drop the parameters (e.g. precision) of every parameterized data type in the expression.

    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = []
        for child in data_type.expressions:
            if isinstance(child, exp.DataTypeParam):
                continue
            kept.append(child)

        data_type.set("expressions", kept)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Strip unnest table aliases from column references in a SELECT.

    These qualifiers are added by the optimizer's qualify_columns step.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Aliases of the unnests that sit directly under a FROM or a JOIN in this scope.
    aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            aliases.add(unnest.alias)

    if not aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in aliases:
            column.set("table", None)
        elif column.db in aliases:
            column.set("db", None)

    return expression
Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Every join of a SELECT whose target is an UNNEST is removed from the join list and one
    LATERAL VIEW is appended per (unnest argument, alias column) pair. POSEXPLODE is used when
    the unnest has an offset, EXPLODE otherwise.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # removing from a list that is being iterated would skip the join that follows
        # (e.g. the second of two consecutive UNNEST joins would be left untouched).
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # When the unnest has no alias, zip() yields nothing and no lateral is added.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the start value of the generated position series; it is also used to
            adjust the series end and the array-size comparisons.

    Returns:
        A transform function that rewrites SELECT expressions and returns any other
        expression unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already in use, so that generated column / source aliases don't clash.
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Picks a free name based on `name` and reserves it in `names`.
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ARRAY_SIZE expression per exploded projection; used to bound the series.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # A single position series shared by all exploded projections; its "end" is
            # filled in after the loop.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Two aliases: the first names the position, the second the value.
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # Unaliased projection: wrap it in an alias node (named further below)
                        # and look the explode up again inside the new node.
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        # IF(ARRAY_SIZE(COALESCE(arg, ARRAY())) = 0, ARRAY(arg[0]), arg)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        # NOTE(review): this branch's nesting under `if not explode_alias` was
                        # reconstructed from a flattened listing -- confirm against the file.
                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    # The exploded value is produced only where the series position equals the
                    # unnest's own position.
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also yields the position: insert it as an extra projection
                        # right after the value projection.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        # First explode seen: attach the shared series source exactly once.
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Rebinds the local name only; the node appended to `arrays` is unchanged.
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions match, or where the series has run past
                    # this array's size and the unnest position equals that size.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series runs up to the largest array size among the exploded arguments.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
Convert explode/posexplode into unnest.
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a two-argument percentile in a WITHIN GROUP clause.

    The percentile's first argument becomes the ORDER BY key of the new WITHIN GROUP node, and
    its second argument (the "expression" arg) becomes the percentile's only argument.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new `exp.WithinGroup` wrapping the percentile, or the expression unchanged when it is
        not a percentile, is already inside a WITHIN GROUP, or has no second argument.
    """
    needs_within_group = (
        isinstance(expression, PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if not needs_within_group:
        return expression

    # Detach both arguments first, in the same order as they are read.
    ordered_value = expression.this.pop()
    fraction = expression.expression.pop()
    expression.set("this", fraction)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])
    return exp.WithinGroup(this=expression, expression=ordering)
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace `<percentile> WITHIN GROUP (ORDER BY ...)` with an `exp.ApproxQuantile` call.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The `exp.ApproxQuantile` node that replaced the WITHIN GROUP, or the expression
        unchanged when it is not a WITHIN GROUP over a percentile with an ORDER BY.
    """
    if (
        not isinstance(expression, exp.WithinGroup)
        or not isinstance(expression.this, PERCENTILES)
        or not isinstance(expression.expression, exp.Order)
    ):
        return expression

    # The percentile's own argument is the quantile; the first ordered term is the input value.
    percentile = expression.this
    first_ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=first_ordered.this, quantile=percentile.this)
    return expression.replace(approx_quantile)
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give every CTE of a recursive WITH an explicit column list.

    CTEs whose alias already lists columns are left alone. For the others, the names come from
    the projections of the CTE's query (or of the left side of a UNION); projections without an
    output name get one from a "_c_" name sequence.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place when it is a recursive `exp.With`.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    fallback_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        body = cte.this
        anchor = body.this if isinstance(body, exp.Union) else body

        column_names = []
        for projection in anchor.selects:
            column_names.append(exp.to_identifier(projection.alias_or_name or fallback_name()))

        table_alias.set("columns", column_names)

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Only applies to CAST / TRY_CAST nodes whose operand is named "epoch" (case-insensitive) and
    whose target type is one of `exp.DataType.TEMPORAL_TYPES`.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, with the cast operand replaced in place when the conditions hold.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    if expression.name.lower() != "epoch":
        return expression

    if expression.to.this not in exp.DataType.TEMPORAL_TYPES:
        return expression

    epoch_literal = exp.Literal.string("1970-01-01 00:00:00")
    expression.this.replace(epoch_literal)
    return expression
Replace 'epoch' in casts by the equivalent date literal.
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each SEMI / ANTI join that has an ON condition is removed from the SELECT and replaced by a
    `[NOT] EXISTS (SELECT 1 FROM <joined source> WHERE <on condition>)` predicate that is added
    to the SELECT's WHERE clause. Joins without an ON condition are left untouched.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place when it is an `exp.Select`.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: join.pop() detaches the join from the SELECT, and removing
        # items from the list being iterated could make the loop skip the join that follows.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a SELECT with exactly one FULL OUTER join as a UNION of two queries.

    The left branch is the original query with the join side switched to "left" and its LIMIT
    removed. The right branch is a copy taken before those changes, with the join side
    switched to "right" and its WITH clause dropped. Queries with zero or several FULL joins
    are returned unchanged.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The union of both branches, or the original expression when the rewrite doesn't apply.
    """
    if not isinstance(expression, exp.Select):
        return expression

    joins = expression.args.get("joins") or []
    full_join_positions = [position for position, join in enumerate(joins) if join.side == "FULL"]

    if len(full_join_positions) != 1:
        return expression

    (position,) = full_join_positions

    # The copy must be taken before the original is altered.
    right_branch = expression.copy()

    expression.set("limit", None)
    joins[position].set("side", "left")

    right_branch.args["joins"][position].set("side", "right")
    right_branch.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_branch, copy=False)
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Hoist every nested WITH clause into the top-level WITH of the given expression.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. The first nested WITH found becomes the top-level WITH if
    there is none yet; otherwise its CTEs are placed in front of the existing ones, and the
    top-level WITH is marked recursive when the nested one is.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    hoisted_with = expression.args.get("with")

    for with_node in expression.find_all(exp.With):
        # The WITH that already belongs to the top-level expression stays where it is.
        if with_node.parent is not expression:
            detached = with_node.pop()

            if hoisted_with:
                if detached.recursive:
                    hoisted_with.set("recursive", True)

                merged_ctes = detached.expressions + hoisted_with.expressions
                hoisted_with.set("expressions", merged_ctes)
            else:
                hoisted_with = detached
                expression.set("with", hoisted_with)

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Turn numeric values used as conditions into explicit `<value> <> 0` comparisons.

    Every node of the tree is handed to `sqlglot.optimizer.canonicalize.ensure_bools` together
    with a callback that performs the replacement; which nodes the callback is invoked on is
    decided by that helper.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _compare_with_zero(node: exp.Expression) -> None:
        # Applies to number literals, nodes typed as UNKNOWN or numeric, and untyped columns.
        looks_numeric = (
            node.is_number
            or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            or (isinstance(node, exp.Column) and not node.type)
        )
        if looks_numeric:
            node.replace(node.neq(0))

    for visited in expression.walk():
        _canonicalize_ensure_bools(visited, _compare_with_zero)

    return expression
Converts numeric values used in conditions into explicit boolean expressions.
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map CREATE TEMPORARY TABLE statements onto CREATE TEMPORARY VIEW.

    Args:
        expression: the `exp.Create` expression that will be transformed.
        tmp_storage_provider: called with the expression when a temporary table is created
            without a query (no "expression" arg); its result is returned. Defaults to the
            identity function.

    Returns:
        A new `exp.Create` of kind "TEMPORARY VIEW" for a temporary CTAS, the result of
        `tmp_storage_provider` for a temporary table without a query, and the original
        expression otherwise.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_nodes = properties.expressions if properties else []

    is_temporary = False
    for prop in property_nodes:
        if isinstance(prop, exp.TemporaryProperty):
            is_temporary = True
            break

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    Move partition columns out of a table's schema and into its PARTITIONED BY property.

    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is a list of column names rather than a schema, the schema columns
    with those names (compared case-insensitively) are removed from the create statement and
    become the schema of the PARTITIONED BY property.

    Args:
        expression: the `exp.Create` expression that will be transformed.

    Returns:
        The same expression, modified in place when it is a TABLE or VIEW with a schema.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if (
        not partitioned_by
        or not partitioned_by.this
        or isinstance(partitioned_by.this, exp.Schema)
    ):
        return expression

    partition_names = {item.name.upper() for item in partitioned_by.this.expressions}
    partition_columns = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining_columns = [
        column for column in schema.expressions if column not in partition_columns
    ]

    schema.set("expressions", remaining_columns)
    partitioned_by.replace(
        exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns))
    )
    expression.set("this", schema)

    return expression
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Move typed PARTITIONED BY columns into the table's schema.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When the PARTITIONED BY property holds a schema made only of column definitions that have a
    kind, those definitions are appended to the create statement's schema and the property is
    reduced to a tuple of the column identifiers.

    Args:
        expression: the `exp.Create` expression that will be transformed.

    Returns:
        The same expression, modified in place when the conditions above hold.
    """
    assert isinstance(expression, exp.Create)

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if (
        not partitioned_by
        or not partitioned_by.this
        or not isinstance(partitioned_by.this, exp.Schema)
    ):
        return expression

    partition_defs = partitioned_by.this.expressions
    for column_def in partition_defs:
        if not (isinstance(column_def, exp.ColumnDef) and column_def.kind):
            return expression

    # Build the identifier list before the definitions are handed over to the schema.
    identifiers = exp.Tuple(
        expressions=[exp.to_identifier(column_def.this) for column_def in partition_defs]
    )

    table_schema = expression.this
    for column_def in partition_defs:
        table_schema.append("expressions", column_def)

    partitioned_by.set("this", identifiers)

    return expression
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Convert struct arguments to aliases, e.g. STRUCT(1 AS y).

    Each `exp.PropertyEQ` argument of an `exp.Struct` is replaced by an alias whose value is
    the PropertyEQ's "expression" and whose name is its "this"; other arguments are kept.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The same expression, modified in place when it is an `exp.Struct`.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    struct_fields = []
    for field in expression.expressions:
        if isinstance(field, exp.PropertyEQ):
            struct_fields.append(exp.alias_(field.expression, field.this))
        else:
            struct_fields.append(field)

    expression.set("expressions", struct_fields)
    return expression
Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the expression is generated as-is.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: raised by the returned function when no handler exists for the resulting
            expression, or when using the `Generator.TRANSFORMS` handler would re-enter this
            function (see the comment in the body).
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remember the incoming type so that a transform chain that doesn't change it can be
        # detected further down.
        expression_type = type(expression)

        # A plain loop (rather than special-casing transforms[0]) also covers an empty list.
        for transform in transforms:
            expression = transform(expression)

        # Prefer the generator's dedicated "<key>_sql" method for the resulting expression.
        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.