sqlglot.dialects.dialect
"""Base dialect machinery for sqlglot.

Defines the ``Dialects`` registry enum, the ``_Dialect`` metaclass that wires a
dialect's Tokenizer/Parser/Generator together, the ``Dialect`` base class, and a
collection of small helper functions shared by the concrete dialect modules
(mostly SQL-generation shims and parser argument re-shufflers).
"""

from __future__ import annotations

import typing as t
from enum import Enum
from functools import reduce

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

B = t.TypeVar("B", bound=exp.Binary)


class Dialects(str, Enum):
    """Dialect names understood by the registry (``str`` members, so they compare equal to their values)."""

    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    Doris = "doris"  # NOTE(review): member name breaks the UPPERCASE convention of this enum


class _Dialect(type):
    """Metaclass that registers each ``Dialect`` subclass and autofills its derived attributes."""

    # Registry mapping dialect name -> Dialect subclass, populated in __new__.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        # A dialect class compares equal to itself, to its registered name
        # (string), and to instances of itself.
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        # Hash by lowercased class name so it stays consistent with string equality above.
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        # Dialect["hive"] -> registered class; raises KeyError for unknown names.
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        # Non-raising registry lookup.
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Register under the matching Dialects enum value when one exists,
        # otherwise under the lowercased class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Build time-format tries and inverse mappings from the class's declared mappings.
        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}

        # A dialect may declare nested Tokenizer/Parser/Generator classes; fall
        # back to the library defaults otherwise.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # Take the first declared quote/identifier delimiter pair as canonical.
        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            # Find the (start, end) delimiters registered for a given format-string
            # token type, or (None, None) if the tokenizer has none.
            # NOTE(review): the loop variable `t` shadows the `typing as t` import
            # inside this generator expression only.
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)

        # Collect the dialect's plain (non-callable, non-dunder) class attributes
        # so they can be propagated to the tokenizer/parser/generator classes.
        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "TOKENIZER_CLASS": klass.tokenizer_class,
        }

        # Dialects.BIGQUERY is a str member, so this string comparison also matches
        # the enum; only the base dialect ("") and BigQuery keep SELECT kinds.
        if enum not in ("", "bigquery"):
            dialect_properties["SELECT_KINDS"] = ()

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        # When '||' is concatenation but arguments need not be strings, parse it
        # as the coercing SafeDPipe instead of the bitwise operator.
        if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        # If SEMI/ANTI aren't join modifiers in this dialect, allow them as table aliases.
        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        klass.generator_class.can_identify = klass.can_identify

        return klass


class Dialect(metaclass=_Dialect):
    """Base class for all SQL dialects; holds tunable behavior flags and parse/generate entry points."""

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines whether or not SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Determines whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Whether the behavior of a / b depends on the types of a and b.
    # False means a / b is always float division.
    # True means a / b is integer division if both a and b are integers.
    TYPED_DIVISION = False

    # False means 1 / 0 throws an error.
    # True means 1 / 0 returns null.
    SAFE_DIVISION = False

    # Default literal formats (note: the quotes are part of the format strings).
    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Mapping of an unescaped escape sequence to the corresponding character
    ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled (overwritten by the _Dialect metaclass)
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    def __eq__(self, other: t.Any) -> bool:
        # Instances compare by class; combined with _Dialect.__eq__ this also
        # makes an instance equal to its class and registered name.
        return type(self) == other

    def __hash__(self) -> int:
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve a name/class/instance to a Dialect class, raising ValueError for unknown names."""
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Convert a dialect time-format string (or literal) into the python strftime equivalent."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized to lowercase regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        # A character is "unsafe" when it differs from the case the dialect folds to.
        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        """Mark an Identifier as quoted when requested, case-sensitive, or not a safe identifier."""
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql` into a list of expression trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse `sql` expecting it to produce the given expression type."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        """Render an expression tree as SQL in this dialect."""
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and regenerate each statement in this dialect."""
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily instantiated and cached per Dialect instance.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)


# Anything accepted wherever a dialect is expected: a name, class, instance or None.
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Return a generator callback that renders the expression as `name(...)` with flattened args."""
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT, warning when an unsupported accuracy arg is present."""
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Build a generator callback rendering exp.If as `name(cond, true, false)`."""

    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON extraction with the '->' operator."""
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON extraction with the '->>' operator."""
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array literal as `[a, b, ...]`."""
    return f"[{self.expressions(expression, flat=True)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE by lowercasing the left-hand side and using LIKE."""
    return self.like_sql(
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, with an optional AT TIME ZONE clause."""
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Drop the RECURSIVE keyword (with a warning) for dialects that lack recursive CTEs."""
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE with an IF guarding against division by zero."""
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Drop TABLESAMPLE (with a warning), rendering just the sampled table."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """PIVOT is unsupported: warn and emit nothing."""
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render TRY_CAST as a plain CAST."""
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Properties are unsupported: warn and emit nothing."""
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Column COMMENT constraints are unsupported: warn and emit nothing."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    """MAP_FROM_ENTRIES is unsupported: warn and emit nothing."""
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition via STRPOS, emulating the optional start position with SUBSTR."""
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        # Search the suffix, then shift the result back to full-string coordinates.
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access with dot notation."""
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map constructor as `MAP(k1, v1, k2, v2, ...)` from array keys/values."""
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _format_time


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Move the partition columns out of the main schema into the property.
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for DATE_ADD-style functions with either (unit, n, date) or (date, n) args."""

    def inner_func(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for DATE_ADD-style functions whose second argument is an INTERVAL."""

    def func(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse DATE_TRUNC(unit, value) into DateTrunc for date casts, TimestampTrunc otherwise."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a generator rendering e.g. `DATE_ADD(x, INTERVAL n UNIT)` for the given type/kind."""

    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        unit = expression.args.get("unit")
        unit = exp.var(unit.name.upper() if unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit)
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as `DATE_TRUNC('unit', value)`, defaulting the unit to 'day'."""
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
    )


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    """Emulate TIMESTAMP(x[, zone]) with a cast, adding AT TIME ZONE for known time zones."""
    if not expression.expression:
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
    if expression.text("expression").lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
            )
        )
    return self.function_fallback_sql(expression)


def locate_to_strposition(args: t.List) -> exp.Expression:
    """Parse LOCATE(substr, str[, pos]) into StrPosition (argument order is swapped)."""
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition back as LOCATE(substr, str[, pos])."""
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    """Emulate LEFT(x, n) as SUBSTRING(x, 1, n)."""
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    """Emulate RIGHT(x, n) as SUBSTRING(x, LENGTH(x) - (n - 1))."""
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render TimeStrToTime as a cast to TIMESTAMP."""
    return self.sql(exp.cast(expression.this, "timestamp"))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render DateStrToDate as a cast to DATE."""
    return self.sql(exp.cast(expression.this, "date"))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render MIN with multiple arguments as LEAST."""
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render MAX with multiple arguments as GREATEST."""
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Emulate COUNT_IF(cond) as SUM(IF(cond, 1, 0)); DISTINCT cannot be preserved."""
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM with explicit position/chars/collation using the verbose TRIM(... FROM ...) form."""
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render string-to-time conversion as STRPTIME with the dialect's time format."""
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a generator for TsOrDsToDate: parse with StrToTime when a custom format is given."""

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return self.sql(
                exp.cast(
                    exp.StrToTime(this=expression.this, format=expression.args["format"]),
                    "date",
                )
            )
        return self.sql(exp.cast(expression.this, "date"))

    return _ts_or_ds_to_date_sql


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    """Render CONCAT(a, b, c) as a || b || c."""
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    """Render CONCAT_WS(sep, a, b) as a || sep || b."""
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT, warning about args the target dialect can't express."""
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE, warning about args the target dialect can't express."""
    bad_args = list(
        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
    )
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Derive the output column names implied by a PIVOT's aggregation expressions."""
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    """Build a parser mapping a two-argument function call to a binary expression node."""
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    """Render ANY_VALUE as MAX."""
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    """Emulate boolean XOR with AND/OR/NOT."""
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"


def is_parse_json(expression: exp.Expression) -> bool:
    """True when the expression is PARSE_JSON or a cast to the JSON type."""
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    """Parse ISNULL(x) as `(x IS NULL)`."""
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    """Render an identity column constraint as IDENTITY(start, increment), defaulting both to 1."""
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    """Build a generator for ARG_MAX/ARG_MIN that drops the unsupported count argument."""

    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql
21class Dialects(str, Enum): 22 DIALECT = "" 23 24 BIGQUERY = "bigquery" 25 CLICKHOUSE = "clickhouse" 26 DATABRICKS = "databricks" 27 DRILL = "drill" 28 DUCKDB = "duckdb" 29 HIVE = "hive" 30 MYSQL = "mysql" 31 ORACLE = "oracle" 32 POSTGRES = "postgres" 33 PRESTO = "presto" 34 REDSHIFT = "redshift" 35 SNOWFLAKE = "snowflake" 36 SPARK = "spark" 37 SPARK2 = "spark2" 38 SQLITE = "sqlite" 39 STARROCKS = "starrocks" 40 TABLEAU = "tableau" 41 TERADATA = "teradata" 42 TRINO = "trino" 43 TSQL = "tsql" 44 Doris = "doris"
An enumeration of the SQL dialect names supported by sqlglot; members subclass `str`, so each compares equal to its lowercase dialect name.
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class Dialect(metaclass=_Dialect):
    """Base class for SQL dialects.

    Subclasses are registered by name through the ``_Dialect`` metaclass. The
    class-level flags below describe each dialect's syntactic/semantic quirks,
    and the ``tokenizer_class`` / ``parser_class`` / ``generator_class`` hooks
    wire up the dialect-specific machinery.
    """

    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    # Determines whether or not unquoted identifiers are resolved as uppercase
    # When set to None, it means that the dialect treats all identifiers as case-insensitive
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
    DPIPE_IS_STRING_CONCAT = True

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

    # Determines whether or not SEMI/ANTI JOINs are supported
    SUPPORTS_SEMI_ANTI_JOIN = True

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Determines whether the base comes first in the LOG function
    LOG_BASE_FIRST = True

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    # Whether the behavior of a / b depends on the types of a and b.
    # False means a / b is always float division.
    # True means a / b is integer division if both a and b are integers.
    TYPED_DIVISION = False

    # False means 1 / 0 throws an error.
    # True means 1 / 0 returns null.
    SAFE_DIVISION = False

    # Default (quoted) literal formats for dates and times in this dialect
    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Mapping of an unescaped escape sequence to the corresponding character
    ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Columns that are auto-generated by the engine corresponding to this dialect
    # Such columns may be excluded from SELECT * queries, for example
    PSEUDOCOLUMNS: t.Set[str] = set()

    # Autofilled by the _Dialect metaclass (see __new__ in the metaclass)
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    def __eq__(self, other: t.Any) -> bool:
        # An instance compares equal to its own class (mirrors _Dialect.__eq__).
        return type(self) == other

    def __hash__(self) -> int:
        # Hash by class so instances and the class itself collide in dicts/sets.
        return hash(type(self))

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        # Accepts a dialect name, class, or instance; raises for unknown names.
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        # Translates a dialect time format (string or string literal) into a
        # python time-format literal using this dialect's TIME_MAPPING.
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    @classmethod
    def normalize_identifier(cls, expression: E) -> E:
        """
        Normalizes an unquoted identifier to either lower or upper case, thus essentially
        making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
        they will be normalized to lowercase regardless of being quoted or not.
        """
        if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
        ):
            expression.set(
                "this",
                expression.this.upper()
                if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
                else expression.this.lower(),
            )

        return expression

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        """Checks if text contains any case sensitive characters, based on the dialect's rules."""
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False

        # A character is "unsafe" when its case differs from the dialect's
        # canonical case for unquoted identifiers.
        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                "always" or `True`: Always returns true.
                "safe": True if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not cls.case_sensitive(text)

        return False

    @classmethod
    def quote_identifier(cls, expression: E, identify: bool = True) -> E:
        # Marks an Identifier as quoted when forced (identify=True) or when
        # leaving it unquoted would be unsafe for this dialect.
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
                "quoted",
                identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        # Tokenize then parse `sql` into a list of expression trees.
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        # Like `parse`, but coerces results into `expression_type`.
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        # Render an expression tree back into SQL for this dialect.
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        # Parse `sql` and re-generate each statement in this dialect.
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily instantiate and cache the dialect's tokenizer.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(**opts)
@classmethod
def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
    """Resolve *dialect* (name, class, or instance) to a Dialect class.

    Raises:
        ValueError: if `dialect` is a name that no registered dialect matches.
    """
    if not dialect:
        return cls

    if isinstance(dialect, _Dialect):
        return dialect

    if isinstance(dialect, Dialect):
        return dialect.__class__

    resolved = cls.get(dialect)
    if resolved is None:
        raise ValueError(f"Unknown dialect '{dialect}'")

    return resolved
@classmethod
def format_time(
    cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
    """Translate a dialect time format into a python time-format string literal."""
    if isinstance(expression, str):
        # The time formats are quoted, so strip the surrounding quotes first.
        inner = expression[1:-1]
        return exp.Literal.string(format_time(inner, cls.TIME_MAPPING, cls.TIME_TRIE))

    if expression and expression.is_string:
        translated = format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)
        return exp.Literal.string(translated)

    return expression
@classmethod
def normalize_identifier(cls, expression: E) -> E:
    """
    Normalizes an unquoted identifier to either lower or upper case, thus essentially
    making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
    they will be normalized to lowercase regardless of being quoted or not.
    """
    if not isinstance(expression, exp.Identifier):
        return expression

    # Quoted identifiers keep their case unless the dialect ignores case entirely.
    if expression.quoted and cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is not None:
        return expression

    normalized = (
        expression.this.upper()
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
        else expression.this.lower()
    )
    expression.set("this", normalized)

    return expression
Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, they will be normalized to lowercase regardless of being quoted or not.
@classmethod
def case_sensitive(cls, text: str) -> bool:
    """Checks if text contains any case sensitive characters, based on the dialect's rules."""
    if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
        # The dialect treats every identifier as case-insensitive.
        return False

    if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE:
        # Unquoted identifiers normalize to uppercase, so lowercase is unsafe.
        return any(ch.islower() for ch in text)

    # Unquoted identifiers normalize to lowercase, so uppercase is unsafe.
    return any(ch.isupper() for ch in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
@classmethod
def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
    """Checks if text can be identified given an identify option.

    Args:
        text: The text to check.
        identify:
            "always" or `True`: Always returns true.
            "safe": True if the identifier is case-insensitive.

    Returns:
        Whether or not the given text can be identified.
    """
    if identify == "always" or identify is True:
        return True

    if identify == "safe":
        return not cls.case_sensitive(text)

    return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: `"always"` or `True`: always returns True; `"safe"`: returns True only if the identifier is case-insensitive.
Returns:
Whether or not the given text can be identified.
@classmethod
def quote_identifier(cls, expression: E, identify: bool = True) -> E:
    """Mark an Identifier as quoted when forced or when quoting is required for safety."""
    if isinstance(expression, exp.Identifier):
        name = expression.this
        needs_quotes = (
            identify
            or cls.case_sensitive(name)
            or not exp.SAFE_IDENTIFIER_RE.match(name)
        )
        expression.set("quoted", needs_quotes)

    return expression
def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    """Build a generator function rendering exp.If as `name(cond, true, false)`.

    Args:
        name: the SQL function name to emit.
        false_value: fallback for the false branch when the expression has none.
    """

    def _render(self: Generator, expression: exp.If) -> str:
        true_part = expression.args.get("true")
        false_part = expression.args.get("false") or false_value
        return self.func(name, expression.this, true_part, false_part)

    return _render
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition via STRPOS, emulating a start position arithmetically."""
    haystack = self.sql(expression, "this")
    needle = self.sql(expression, "substr")
    start = self.sql(expression, "position")

    if not start:
        return f"STRPOS({haystack}, {needle})"

    # Search the suffix beginning at `start`, then shift the match index back
    # into the coordinate space of the original string.
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a map as `MAP(k1, v1, k2, v2, ...)`, interleaving keys and values."""
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not (isinstance(keys, exp.Array) and isinstance(values, exp.Array)):
        # Non-literal (column) arrays can't be interleaved; emit as-is and warn.
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    interleaved = []
    for key, value in zip(keys.expressions, values.expressions):
        interleaved.extend((self.sql(key), self.sql(value)))

    return self.func(map_func_name, *interleaved)
def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        target = Dialect[dialect]
        fmt = seq_get(args, 1)
        if not fmt:
            # Fall back to the dialect's TIME_FORMAT (default=True) or the
            # caller-supplied default format.
            fmt = target.TIME_FORMAT if default is True else default or None
        return exp_class(this=seq_get(args, 0), format=target.format_time(fmt))

    return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    """Build a renderer for an expression's time format, suppressing the dialect default."""

    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        fmt = self.format_time(expression)
        default_fmt = Dialect.get_or_raise(dialect).TIME_FORMAT
        return None if fmt == default_fmt else fmt

    return _time_format
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        # Only rewrite when PARTITIONED BY holds plain column names, not an
        # already-built schema (i.e. the input wasn't in Hive form to begin with).
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Partition columns are matched against the schema case-insensitively.
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            # Remove partition columns (with their full definitions) from the main
            # schema and move them into the PARTITIONED BY property instead.
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
    """Build a parser for date-delta functions with an optional leading unit argument."""

    def _builder(args: t.List) -> E:
        # Three arguments means the unit comes first: f(unit, expr, this).
        has_unit = len(args) == 3
        this = args[2] if has_unit else seq_get(args, 0)
        unit = args[0] if has_unit else exp.Literal.string("DAY")
        if unit_mapping:
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder
def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    """Build a parser for date-delta functions whose second argument is an INTERVAL."""

    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        amount = interval.this
        # Normalize string amounts (e.g. INTERVAL '5' DAY) into numeric literals.
        if amount and amount.is_string:
            amount = exp.Literal.number(amount.this)

        return expression_class(
            this=args[0], expression=amount, unit=exp.Literal.string(interval.text("unit"))
        )

    return _builder
def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    """Build a renderer producing `<data_type>_<kind>(this, INTERVAL ...)` SQL."""

    def _render(self: Generator, expression: exp.Expression) -> str:
        this_sql = self.sql(expression, "this")
        raw_unit = expression.args.get("unit")
        # DAY is the implicit unit when none is given.
        unit_var = exp.var(raw_unit.name.upper() if raw_unit else "DAY")
        interval = exp.Interval(this=expression.expression, unit=unit_var)
        return f"{data_type}_{kind}({this_sql}, {self.sql(interval)})"

    return _render
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    """Render TIMESTAMP(x[, zone]) via CAST / AT TIME ZONE for dialects lacking it."""
    if not expression.expression:
        # Single-argument form: a plain cast suffices.
        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))

    if expression.text("expression").lower() in TIMEZONES:
        at_tz = exp.AtTimeZone(
            this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
            zone=expression.expression,
        )
        return self.sql(at_tz)

    # The second argument isn't a recognized timezone; fall back to a
    # generic function-call rendering.
    return self.function_fallback_sql(expression)
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    """Render an ENCODE/DECODE-style call, warning when a non-UTF-8 charset is requested."""
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    # Some dialects don't accept a `replace` argument; drop it when asked.
    replace_arg = expression.args.get("replace") if replace else None
    return self.func(name, expression.this, replace_arg)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Rewrite COUNT_IF(cond) as SUM(IF(cond, 1, 0)) for dialects without COUNT_IF."""
    cond = expression.this

    if isinstance(cond, exp.Distinct):
        # DISTINCT can't be expressed through SUM/IF; emit best-effort SQL and warn.
        cond = cond.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM, using TRIM([pos] [chars] FROM target [COLLATE c]) only when needed."""
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    if not remove_chars and not collation:
        # Nothing dialect-specific: the plain TRIM/LTRIM/RTRIM syntax is enough.
        return self.trim_sql(expression)

    pieces = [
        f"{trim_type} " if trim_type else "",
        f"{remove_chars} " if remove_chars else "",
        "FROM " if trim_type or remove_chars else "",
        target,
        f" COLLATE {collation}" if collation else "",
    ]
    return f"TRIM({''.join(pieces)})"
def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a renderer that casts TS_OR_DS_TO_DATE input to DATE, parsing custom formats first."""

    def _render(self: Generator, expression: exp.TsOrDsToDate) -> str:
        target = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        # A format other than the dialect's defaults needs an explicit parse step.
        if time_format and time_format not in (target.TIME_FORMAT, target.DATE_FORMAT):
            parsed = exp.StrToTime(this=expression.this, format=expression.args["format"])
            return self.sql(exp.cast(parsed, "date"))
        return self.sql(exp.cast(expression.this, "date"))

    return _render
def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    """Render REGEXP_EXTRACT(this, pattern[, group]), warning about unsupported extras."""
    bad_args = [
        arg for arg in ("position", "occurrence", "parameters") if expression.args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    """Render REGEXP_REPLACE(this, pattern, replacement), warning about unsupported extras."""
    bad_args = [
        arg
        for arg in ("position", "occurrence", "parameters", "modifiers")
        if expression.args.get(arg)
    ]
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    """Derive the output column-name suffixes for a PIVOT's aggregation expressions.

    Aliased aggregations contribute their alias directly; unaliased ones are
    rendered to SQL (with identifiers unquoted) and used verbatim.

    Args:
        aggregations: the PIVOT's aggregation expressions.
        dialect: the dialect used to render unaliased aggregations.

    Returns:
        One name per aggregation, in order.
    """
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            # Fix: the explanation below was a bare triple-quoted string expression
            # inside the loop (re-evaluated every iteration), not a comment/docstring.
            #
            # This case corresponds to aggregations without aliases being used as suffixes
            # (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            # be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            # Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            agg_all_unquoted = agg.transform(
                lambda node: exp.Identifier(this=node.name, quoted=False)
                if isinstance(node, exp.Identifier)
                else node
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names
def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    """Build a two-argument ARG_MAX/ARG_MIN renderer, warning when a count argument is present."""

    def _render(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _render