sqlglot.dialects.dialect
from __future__ import annotations

import typing as t
from enum import Enum

from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer
from sqlglot.trie import new_trie

# Type variable bound to Expression, used by helper factories that build a
# specific expression subclass from parsed function arguments.
E = t.TypeVar("E", bound=exp.Expression)


class Dialects(str, Enum):
    """Enumeration of the dialect names that `_Dialect` registers classes under."""

    # The base dialect registers under the empty string.
    DIALECT = ""

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    REDSHIFT = "redshift"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TRINO = "trino"
    TSQL = "tsql"
    DATABRICKS = "databricks"
    DRILL = "drill"
    TERADATA = "teradata"


class _Dialect(type):
    """Metaclass that registers every `Dialect` subclass and autofills its
    derived attributes (time tries, tokenizer/parser/generator classes and
    quote/identifier/bit/hex/byte delimiters) from its nested classes."""

    # Registry mapping dialect name -> Dialect subclass, shared by all dialects.
    classes: t.Dict[str, t.Type[Dialect]] = {}

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        # Allows `Dialect["hive"]`-style lookup; raises KeyError when unknown.
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        # Non-raising registry lookup, mirroring dict.get.
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Register under the matching Dialects enum value if one exists,
        # otherwise under the lowercased class name.
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        # Precompute tries for fast matching of time-format tokens, plus the
        # inverse mapping used when generating SQL back out of this dialect.
        klass.time_trie = new_trie(klass.time_mapping)
        klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()}
        klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)

        # A dialect customizes behavior by declaring nested Tokenizer/Parser/
        # Generator classes; fall back to the base classes when absent.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        # The first configured pair of quote/identifier delimiters becomes the
        # canonical one used by this dialect's generator.
        klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.identifier_start, klass.identifier_end = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        # Bit/hex/byte string delimiters are optional; (None, None) when the
        # tokenizer declares none.
        klass.bit_start, klass.bit_end = seq_get(
            list(klass.tokenizer_class._BIT_STRINGS.items()), 0
        ) or (None, None)

        klass.hex_start, klass.hex_end = seq_get(
            list(klass.tokenizer_class._HEX_STRINGS.items()), 0
        ) or (None, None)

        klass.byte_start, klass.byte_end = seq_get(
            list(klass.tokenizer_class._BYTE_STRINGS.items()), 0
        ) or (None, None)

        return klass


class Dialect(metaclass=_Dialect):
    """Base dialect: ties together a Tokenizer, Parser and Generator and
    exposes parse/generate/transpile entry points. Subclasses override the
    class-level settings below and/or the nested component classes."""

    # Settings forwarded to the Parser and/or Generator (see parser()/generator()).
    index_offset = 0
    unnest_column_only = False
    alias_post_tablesample = False
    normalize_functions: t.Optional[str] = "upper"
    null_ordering = "nulls_are_small"

    # Canonical (quoted) time/date format strings for this dialect.
    date_format = "'%Y-%m-%d'"
    dateint_format = "'%Y%m%d'"
    time_format = "'%Y-%m-%d %H:%M:%S'"
    time_mapping: t.Dict[str, str] = {}

    # autofilled by the _Dialect metaclass -- do not set these manually.
    quote_start = None
    quote_end = None
    identifier_start = None
    identifier_end = None

    time_trie = None
    inverse_time_mapping = None
    inverse_time_trie = None
    tokenizer_class = None
    parser_class = None
    generator_class = None

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
        """Resolve `dialect` (name, class, or instance) to a Dialect class.

        Falsy input returns `cls` itself; an unknown name raises ValueError.
        """
        if not dialect:
            return cls
        if isinstance(dialect, _Dialect):
            # Already a Dialect class (its metaclass is _Dialect).
            return dialect
        if isinstance(dialect, Dialect):
            return dialect.__class__

        result = cls.get(dialect)
        if not result:
            raise ValueError(f"Unknown dialect '{dialect}'")

        return result

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Convert a time-format string (or string literal expression) into a
        string literal rewritten through this dialect's time mapping.

        Non-string expressions (and None) are returned unchanged.
        """
        if isinstance(expression, str):
            return exp.Literal.string(
                format_time(
                    expression[1:-1],  # the time formats are quoted
                    cls.time_mapping,
                    cls.time_trie,
                )
            )
        if expression and expression.is_string:
            return exp.Literal.string(
                format_time(
                    expression.this,
                    cls.time_mapping,
                    cls.time_trie,
                )
            )
        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        """Tokenize and parse `sql` into a list of expression trees."""
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        """Parse `sql`, coercing the result into `expression_type`."""
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
        """Render an expression tree as SQL in this dialect."""
        return self.generator(**opts).generate(expression)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        """Parse `sql` and generate it back out, one string per statement."""
        return [self.generate(expression, **opts) for expression in self.parse(sql)]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        # Lazily create and cache a single tokenizer instance per Dialect
        # instance; tokenizers are reusable across calls.
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()  # type: ignore
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        """Build a new Parser seeded with this dialect's settings; `opts` win
        over the defaults."""
        return self.parser_class(  # type: ignore
            **{
                "index_offset": self.index_offset,
                "unnest_column_only": self.unnest_column_only,
                "alias_post_tablesample": self.alias_post_tablesample,
                "null_ordering": self.null_ordering,
                **opts,
            },
        )

    def generator(self, **opts) -> Generator:
        """Build a new Generator seeded with this dialect's settings; `opts`
        win over the defaults."""
        return self.generator_class(  # type: ignore
            **{
                "quote_start": self.quote_start,
                "quote_end": self.quote_end,
                "bit_start": self.bit_start,
                "bit_end": self.bit_end,
                "hex_start": self.hex_start,
                "hex_end": self.hex_end,
                "byte_start": self.byte_start,
                "byte_end": self.byte_end,
                "identifier_start": self.identifier_start,
                "identifier_end": self.identifier_end,
                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
                "index_offset": self.index_offset,
                # Generation maps this dialect's formats back to the canonical ones.
                "time_mapping": self.inverse_time_mapping,
                "time_trie": self.inverse_time_trie,
                "unnest_column_only": self.unnest_column_only,
                "alias_post_tablesample": self.alias_post_tablesample,
                "normalize_functions": self.normalize_functions,
                "null_ordering": self.null_ordering,
                **opts,
            }
        )


# Anything accepted wherever a dialect can be specified.
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    """Return a generator transform that renders an expression as a call to
    `name`, passing through all of the expression's (flattened) args."""
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    """Render APPROX_COUNT_DISTINCT, warning that `accuracy` is dropped."""
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(self: Generator, expression: exp.If) -> str:
    """Render a conditional as IF(condition, true, false)."""
    return self.func(
        "IF", expression.this, expression.args.get("true"), expression.args.get("false")
    )


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
    """Render JSON extraction with the `->` arrow operator."""
    return self.binary(expression, "->")


def arrow_json_extract_scalar_sql(
    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
    """Render scalar JSON extraction with the `->>` arrow operator."""
    return self.binary(expression, "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    """Render an array as a bracketed literal: [a, b, ...]."""
    return f"[{self.expressions(expression)}]"


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    """Emulate ILIKE with LIKE for dialects that lack it.

    NOTE(review): only the left-hand side is wrapped in LOWER; the pattern
    expression is passed through unchanged -- confirm callers supply
    already-lowercased patterns.
    """
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this),
            expression=expression.args["expression"],
        )
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    """Render CURRENT_DATE without parentheses, with an optional time zone."""
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    """Drop the RECURSIVE keyword (with a warning) for dialects without it."""
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    """Emulate SAFE_DIVIDE as IF(d <> 0, n / d, NULL)."""
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF({d} <> 0, {n} / {d}, NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    """Drop TABLESAMPLE (with a warning), keeping the sampled table."""
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    """Drop PIVOT entirely (with a warning)."""
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    """Render TRY_CAST as a plain CAST for dialects without it."""
    return self.cast_sql(expression)


def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
    """Drop properties entirely (with a warning)."""
    self.unsupported("Properties unsupported")
    return ""


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    """Drop a column COMMENT constraint (with a warning)."""
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as STRPOS, emulating an optional start `position`
    by searching a SUBSTR and re-offsetting the result."""
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    if position:
        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
    return f"STRPOS({this}, {substr})"


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    """Render struct field access as `struct."key"` (key always quoted)."""
    this = self.sql(expression, "this")
    struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
    return f"{this}.{struct_key}"


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    """Render a MAP built from key/value arrays as a variadic call:
    MAP(k1, v1, k2, v2, ...). Requires literal arrays; otherwise falls back
    to MAP(keys, values) with a warning."""
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    # Interleave keys and values into a single flat argument list.
    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))
    return self.func(map_func_name, *args)


def format_time_lambda(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.Sequence], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _format_time(args: t.Sequence):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                # True selects the dialect's time_format; a string is used as-is.
                or (Dialect[dialect].time_format if default is True else default or None)
            ),
        )

    return _format_time


def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
    """
    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
    columns are removed from the create statement.
    """
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")

    if has_schema and is_partitionable:
        # Copy so the caller's tree is not mutated.
        expression = expression.copy()
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Move partition columns out of the main schema and into the
            # PARTITIONED BY schema (matched case-insensitively by name).
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return self.create_sql(expression)


def parse_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.Sequence], E]:
    """Build a parser for date-delta functions.

    Supports both the 3-arg form (unit, expression, this) and the 2-arg form
    (this, expression) with a default unit of DAY. `unit_mapping` optionally
    normalizes unit names (looked up by lowercased name).
    """

    def inner_func(args: t.Sequence) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return inner_func


def parse_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.Sequence], t.Optional[E]]:
    """Build a parser for date-delta functions whose second argument is an
    INTERVAL literal; the interval's value and unit are unpacked into the
    resulting expression. Returns None when fewer than 2 args are given."""

    def func(args: t.Sequence) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]
        expression = interval.this
        if expression and expression.is_string:
            # Interval values are stored as strings; coerce to a number literal.
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0],
            expression=expression,
            unit=exp.Literal.string(interval.text("unit")),
        )

    return func


def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
    """Parse DATE_TRUNC(unit, value): DateTrunc when the value is cast to
    DATE, TimestampTrunc otherwise."""
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
    """Render TimestampTrunc as DATE_TRUNC('unit', value), defaulting to 'day'."""
    return self.func(
        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
    )


def locate_to_strposition(args: t.Sequence) -> exp.Expression:
    """Parse LOCATE(substr, string[, position]) into StrPosition (note the
    swapped first two arguments relative to STRPOS)."""
    return exp.StrPosition(
        this=seq_get(args, 1),
        substr=seq_get(args, 0),
        position=seq_get(args, 2),
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    """Render StrPosition as LOCATE(substr, string[, position])."""
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    """Render a time-string conversion as CAST(... AS TIMESTAMP)."""
    return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    """Render a date-string conversion as CAST(... AS DATE)."""
    return f"CAST({self.sql(expression, 'this')} AS DATE)"


def min_or_least(self: Generator, expression: exp.Min) -> str:
    """Render MIN normally, but LEAST when called with multiple arguments."""
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    """Render MAX normally, but GREATEST when called with multiple arguments."""
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    """Emulate COUNT_IF(cond) as sum(if(cond, 1, 0)); DISTINCT is dropped
    with a warning."""
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    """Render TRIM with the extended FROM/COLLATE syntax when needed.

    For the plain case this delegates to the generator's own trim_sql method
    (this is a standalone module-level helper, not the bound method, so the
    call is not recursive).
    """
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    """Render a string-to-time conversion as STRPTIME(value, format)."""
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    """Build a renderer that casts a timestamp-or-date-string to DATE,
    parsing via STRPTIME first when a non-default format is present."""

    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.time_format, _dialect.date_format):
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
        return f"CAST({self.sql(expression, 'this')} AS DATE)"

    return _ts_or_ds_to_date_sql
class
Dialects(builtins.str, enum.Enum):
18class Dialects(str, Enum): 19 DIALECT = "" 20 21 BIGQUERY = "bigquery" 22 CLICKHOUSE = "clickhouse" 23 DUCKDB = "duckdb" 24 HIVE = "hive" 25 MYSQL = "mysql" 26 ORACLE = "oracle" 27 POSTGRES = "postgres" 28 PRESTO = "presto" 29 REDSHIFT = "redshift" 30 SNOWFLAKE = "snowflake" 31 SPARK = "spark" 32 SPARK2 = "spark2" 33 SQLITE = "sqlite" 34 STARROCKS = "starrocks" 35 TABLEAU = "tableau" 36 TRINO = "trino" 37 TSQL = "tsql" 38 DATABRICKS = "databricks" 39 DRILL = "drill" 40 TERADATA = "teradata"
An enumeration.
DIALECT =
<Dialects.DIALECT: ''>
BIGQUERY =
<Dialects.BIGQUERY: 'bigquery'>
CLICKHOUSE =
<Dialects.CLICKHOUSE: 'clickhouse'>
DUCKDB =
<Dialects.DUCKDB: 'duckdb'>
HIVE =
<Dialects.HIVE: 'hive'>
MYSQL =
<Dialects.MYSQL: 'mysql'>
ORACLE =
<Dialects.ORACLE: 'oracle'>
POSTGRES =
<Dialects.POSTGRES: 'postgres'>
PRESTO =
<Dialects.PRESTO: 'presto'>
REDSHIFT =
<Dialects.REDSHIFT: 'redshift'>
SNOWFLAKE =
<Dialects.SNOWFLAKE: 'snowflake'>
SPARK =
<Dialects.SPARK: 'spark'>
SPARK2 =
<Dialects.SPARK2: 'spark2'>
SQLITE =
<Dialects.SQLITE: 'sqlite'>
STARROCKS =
<Dialects.STARROCKS: 'starrocks'>
TABLEAU =
<Dialects.TABLEAU: 'tableau'>
TRINO =
<Dialects.TRINO: 'trino'>
TSQL =
<Dialects.TSQL: 'tsql'>
DATABRICKS =
<Dialects.DATABRICKS: 'databricks'>
DRILL =
<Dialects.DRILL: 'drill'>
TERADATA =
<Dialects.TERADATA: 'teradata'>
Inherited Members
- enum.Enum
- name
- value
- builtins.str
- encode
- replace
- split
- rsplit
- join
- capitalize
- casefold
- title
- center
- count
- expandtabs
- find
- partition
- index
- ljust
- lower
- lstrip
- rfind
- rindex
- rjust
- rstrip
- rpartition
- splitlines
- strip
- swapcase
- translate
- upper
- startswith
- endswith
- removeprefix
- removesuffix
- isascii
- islower
- isupper
- istitle
- isspace
- isdecimal
- isdigit
- isnumeric
- isalpha
- isalnum
- isidentifier
- isprintable
- zfill
- format
- format_map
- maketrans
class
Dialect:
89class Dialect(metaclass=_Dialect): 90 index_offset = 0 91 unnest_column_only = False 92 alias_post_tablesample = False 93 normalize_functions: t.Optional[str] = "upper" 94 null_ordering = "nulls_are_small" 95 96 date_format = "'%Y-%m-%d'" 97 dateint_format = "'%Y%m%d'" 98 time_format = "'%Y-%m-%d %H:%M:%S'" 99 time_mapping: t.Dict[str, str] = {} 100 101 # autofilled 102 quote_start = None 103 quote_end = None 104 identifier_start = None 105 identifier_end = None 106 107 time_trie = None 108 inverse_time_mapping = None 109 inverse_time_trie = None 110 tokenizer_class = None 111 parser_class = None 112 generator_class = None 113 114 @classmethod 115 def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: 116 if not dialect: 117 return cls 118 if isinstance(dialect, _Dialect): 119 return dialect 120 if isinstance(dialect, Dialect): 121 return dialect.__class__ 122 123 result = cls.get(dialect) 124 if not result: 125 raise ValueError(f"Unknown dialect '{dialect}'") 126 127 return result 128 129 @classmethod 130 def format_time( 131 cls, expression: t.Optional[str | exp.Expression] 132 ) -> t.Optional[exp.Expression]: 133 if isinstance(expression, str): 134 return exp.Literal.string( 135 format_time( 136 expression[1:-1], # the time formats are quoted 137 cls.time_mapping, 138 cls.time_trie, 139 ) 140 ) 141 if expression and expression.is_string: 142 return exp.Literal.string( 143 format_time( 144 expression.this, 145 cls.time_mapping, 146 cls.time_trie, 147 ) 148 ) 149 return expression 150 151 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 152 return self.parser(**opts).parse(self.tokenize(sql), sql) 153 154 def parse_into( 155 self, expression_type: exp.IntoType, sql: str, **opts 156 ) -> t.List[t.Optional[exp.Expression]]: 157 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 158 159 def generate(self, expression: t.Optional[exp.Expression], **opts) -> str: 160 return 
self.generator(**opts).generate(expression) 161 162 def transpile(self, sql: str, **opts) -> t.List[str]: 163 return [self.generate(expression, **opts) for expression in self.parse(sql)] 164 165 def tokenize(self, sql: str) -> t.List[Token]: 166 return self.tokenizer.tokenize(sql) 167 168 @property 169 def tokenizer(self) -> Tokenizer: 170 if not hasattr(self, "_tokenizer"): 171 self._tokenizer = self.tokenizer_class() # type: ignore 172 return self._tokenizer 173 174 def parser(self, **opts) -> Parser: 175 return self.parser_class( # type: ignore 176 **{ 177 "index_offset": self.index_offset, 178 "unnest_column_only": self.unnest_column_only, 179 "alias_post_tablesample": self.alias_post_tablesample, 180 "null_ordering": self.null_ordering, 181 **opts, 182 }, 183 ) 184 185 def generator(self, **opts) -> Generator: 186 return self.generator_class( # type: ignore 187 **{ 188 "quote_start": self.quote_start, 189 "quote_end": self.quote_end, 190 "bit_start": self.bit_start, 191 "bit_end": self.bit_end, 192 "hex_start": self.hex_start, 193 "hex_end": self.hex_end, 194 "byte_start": self.byte_start, 195 "byte_end": self.byte_end, 196 "identifier_start": self.identifier_start, 197 "identifier_end": self.identifier_end, 198 "string_escape": self.tokenizer_class.STRING_ESCAPES[0], 199 "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0], 200 "index_offset": self.index_offset, 201 "time_mapping": self.inverse_time_mapping, 202 "time_trie": self.inverse_time_trie, 203 "unnest_column_only": self.unnest_column_only, 204 "alias_post_tablesample": self.alias_post_tablesample, 205 "normalize_functions": self.normalize_functions, 206 "null_ordering": self.null_ordering, 207 **opts, 208 } 209 )
@classmethod
def
get_or_raise( cls, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType]) -> Type[sqlglot.dialects.dialect.Dialect]:
114 @classmethod 115 def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: 116 if not dialect: 117 return cls 118 if isinstance(dialect, _Dialect): 119 return dialect 120 if isinstance(dialect, Dialect): 121 return dialect.__class__ 122 123 result = cls.get(dialect) 124 if not result: 125 raise ValueError(f"Unknown dialect '{dialect}'") 126 127 return result
@classmethod
def
format_time( cls, expression: Union[str, sqlglot.expressions.Expression, NoneType]) -> Optional[sqlglot.expressions.Expression]:
129 @classmethod 130 def format_time( 131 cls, expression: t.Optional[str | exp.Expression] 132 ) -> t.Optional[exp.Expression]: 133 if isinstance(expression, str): 134 return exp.Literal.string( 135 format_time( 136 expression[1:-1], # the time formats are quoted 137 cls.time_mapping, 138 cls.time_trie, 139 ) 140 ) 141 if expression and expression.is_string: 142 return exp.Literal.string( 143 format_time( 144 expression.this, 145 cls.time_mapping, 146 cls.time_trie, 147 ) 148 ) 149 return expression
def
parse_into( self, expression_type: Union[str, Type[sqlglot.expressions.Expression], Collection[Union[str, Type[sqlglot.expressions.Expression]]]], sql: str, **opts) -> List[Optional[sqlglot.expressions.Expression]]:
174 def parser(self, **opts) -> Parser: 175 return self.parser_class( # type: ignore 176 **{ 177 "index_offset": self.index_offset, 178 "unnest_column_only": self.unnest_column_only, 179 "alias_post_tablesample": self.alias_post_tablesample, 180 "null_ordering": self.null_ordering, 181 **opts, 182 }, 183 )
185 def generator(self, **opts) -> Generator: 186 return self.generator_class( # type: ignore 187 **{ 188 "quote_start": self.quote_start, 189 "quote_end": self.quote_end, 190 "bit_start": self.bit_start, 191 "bit_end": self.bit_end, 192 "hex_start": self.hex_start, 193 "hex_end": self.hex_end, 194 "byte_start": self.byte_start, 195 "byte_end": self.byte_end, 196 "identifier_start": self.identifier_start, 197 "identifier_end": self.identifier_end, 198 "string_escape": self.tokenizer_class.STRING_ESCAPES[0], 199 "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0], 200 "index_offset": self.index_offset, 201 "time_mapping": self.inverse_time_mapping, 202 "time_trie": self.inverse_time_trie, 203 "unnest_column_only": self.unnest_column_only, 204 "alias_post_tablesample": self.alias_post_tablesample, 205 "normalize_functions": self.normalize_functions, 206 "null_ordering": self.null_ordering, 207 **opts, 208 } 209 )
def
rename_func( name: str) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def
approx_count_distinct_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.ApproxDistinct) -> str:
def
arrow_json_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtract | sqlglot.expressions.JSONBExtract) -> str:
def
arrow_json_extract_scalar_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.JSONExtractScalar | sqlglot.expressions.JSONBExtractScalar) -> str:
def
inline_array_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Array) -> str:
def
no_paren_current_date_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CurrentDate) -> str:
def
no_recursive_cte_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.With) -> str:
def
no_safe_divide_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.SafeDivide) -> str:
def
no_tablesample_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TableSample) -> str:
def
no_trycast_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TryCast) -> str:
def
no_properties_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Properties) -> str:
def
no_comment_column_constraint_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CommentColumnConstraint) -> str:
def
str_position_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
298def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: 299 this = self.sql(expression, "this") 300 substr = self.sql(expression, "substr") 301 position = self.sql(expression, "position") 302 if position: 303 return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1" 304 return f"STRPOS({this}, {substr})"
def
struct_extract_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StructExtract) -> str:
def
var_map_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Map | sqlglot.expressions.VarMap, map_func_name: str = 'MAP') -> str:
313def var_map_sql( 314 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 315) -> str: 316 keys = expression.args["keys"] 317 values = expression.args["values"] 318 319 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 320 self.unsupported("Cannot convert array columns into map.") 321 return self.func(map_func_name, keys, values) 322 323 args = [] 324 for key, value in zip(keys.expressions, values.expressions): 325 args.append(self.sql(key)) 326 args.append(self.sql(value)) 327 return self.func(map_func_name, *args)
def
format_time_lambda( exp_class: Type[~E], dialect: str, default: Union[bool, str, NoneType] = None) -> Callable[[Sequence], ~E]:
330def format_time_lambda( 331 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 332) -> t.Callable[[t.Sequence], E]: 333 """Helper used for time expressions. 334 335 Args: 336 exp_class: the expression class to instantiate. 337 dialect: target sql dialect. 338 default: the default format, True being time. 339 340 Returns: 341 A callable that can be used to return the appropriately formatted time expression. 342 """ 343 344 def _format_time(args: t.Sequence): 345 return exp_class( 346 this=seq_get(args, 0), 347 format=Dialect[dialect].format_time( 348 seq_get(args, 1) 349 or (Dialect[dialect].time_format if default is True else default or None) 350 ), 351 ) 352 353 return _format_time
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
def
create_with_partitions_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Create) -> str:
356def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: 357 """ 358 In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the 359 PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding 360 columns are removed from the create statement. 361 """ 362 has_schema = isinstance(expression.this, exp.Schema) 363 is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW") 364 365 if has_schema and is_partitionable: 366 expression = expression.copy() 367 prop = expression.find(exp.PartitionedByProperty) 368 if prop and prop.this and not isinstance(prop.this, exp.Schema): 369 schema = expression.this 370 columns = {v.name.upper() for v in prop.this.expressions} 371 partitions = [col for col in schema.expressions if col.name.upper() in columns] 372 schema.set("expressions", [e for e in schema.expressions if e not in partitions]) 373 prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) 374 expression.set("this", schema) 375 376 return self.create_sql(expression)
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def
parse_date_delta( exp_class: Type[~E], unit_mapping: Optional[Dict[str, str]] = None) -> Callable[[Sequence], ~E]:
379def parse_date_delta( 380 exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None 381) -> t.Callable[[t.Sequence], E]: 382 def inner_func(args: t.Sequence) -> E: 383 unit_based = len(args) == 3 384 this = args[2] if unit_based else seq_get(args, 0) 385 unit = args[0] if unit_based else exp.Literal.string("DAY") 386 unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit 387 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 388 389 return inner_func
def
parse_date_delta_with_interval(expression_class: Type[~E]) -> Callable[[Sequence], Optional[~E]]:
392def parse_date_delta_with_interval( 393 expression_class: t.Type[E], 394) -> t.Callable[[t.Sequence], t.Optional[E]]: 395 def func(args: t.Sequence) -> t.Optional[E]: 396 if len(args) < 2: 397 return None 398 399 interval = args[1] 400 expression = interval.this 401 if expression and expression.is_string: 402 expression = exp.Literal.number(expression.this) 403 404 return expression_class( 405 this=args[0], 406 expression=expression, 407 unit=exp.Literal.string(interval.text("unit")), 408 ) 409 410 return func
def
date_trunc_to_time( args: Sequence) -> sqlglot.expressions.DateTrunc | sqlglot.expressions.TimestampTrunc:
413def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc: 414 unit = seq_get(args, 0) 415 this = seq_get(args, 1) 416 417 if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE): 418 return exp.DateTrunc(unit=unit, this=this) 419 return exp.TimestampTrunc(this=this, unit=unit)
def
timestamptrunc_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimestampTrunc) -> str:
def
strposition_to_locate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.StrPosition) -> str:
def
timestrtotime_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.TimeStrToTime) -> str:
def
datestrtodate_sql( self: sqlglot.generator.Generator, expression: sqlglot.expressions.DateStrToDate) -> str:
def
max_or_greatest( self: sqlglot.generator.Generator, expression: sqlglot.expressions.Max) -> str:
def
count_if_to_sum( self: sqlglot.generator.Generator, expression: sqlglot.expressions.CountIf) -> str:
460def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 461 cond = expression.this 462 463 if isinstance(expression.this, exp.Distinct): 464 cond = expression.this.expressions[0] 465 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 466 467 return self.func("sum", exp.func("if", cond, 1, 0))
470def trim_sql(self: Generator, expression: exp.Trim) -> str: 471 target = self.sql(expression, "this") 472 trim_type = self.sql(expression, "position") 473 remove_chars = self.sql(expression, "expression") 474 collation = self.sql(expression, "collation") 475 476 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 477 if not remove_chars and not collation: 478 return self.trim_sql(expression) 479 480 trim_type = f"{trim_type} " if trim_type else "" 481 remove_chars = f"{remove_chars} " if remove_chars else "" 482 from_part = "FROM " if trim_type or remove_chars else "" 483 collation = f" COLLATE {collation}" if collation else "" 484 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
def
ts_or_ds_to_date_sql(dialect: str) -> Callable:
491def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: 492 def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str: 493 _dialect = Dialect.get_or_raise(dialect) 494 time_format = self.format_time(expression) 495 if time_format and time_format not in (_dialect.time_format, _dialect.date_format): 496 return f"CAST({str_to_time_sql(self, expression)} AS DATE)" 497 return f"CAST({self.sql(expression, 'this')} AS DATE)" 498 499 return _ts_or_ds_to_date_sql