sqlglot.dialects.spark2
from __future__ import annotations

import typing as t

from sqlglot import exp, parser, transforms
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get


def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
    kind = e.args["kind"]
    properties = e.args.get("properties")

    if kind.upper() == "TABLE" and any(
        isinstance(prop, exp.TemporaryProperty)
        for prop in (properties.expressions if properties else [])
    ):
        return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
    return create_with_partitions_sql(self, e)


def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
    keys = self.sql(expression.args["keys"])
    values = self.sql(expression.args["values"])
    return f"MAP_FROM_ARRAYS({keys}, {values})"


def _parse_as_cast(to_type: str) -> t.Callable[[t.Sequence], exp.Expression]:
    return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))


def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
    this = self.sql(expression, "this")
    time_format = self.format_time(expression)
    if time_format == Hive.date_format:
        return f"TO_DATE({this})"
    return f"TO_DATE({this}, {time_format})"


def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
    scale = expression.args.get("scale")
    timestamp = self.sql(expression, "this")
    if scale is None:
        return f"CAST(FROM_UNIXTIME({timestamp}) AS TIMESTAMP)"
    if scale == exp.UnixToTime.SECONDS:
        return f"TIMESTAMP_SECONDS({timestamp})"
    if scale == exp.UnixToTime.MILLIS:
        return f"TIMESTAMP_MILLIS({timestamp})"
    if scale == exp.UnixToTime.MICROS:
        return f"TIMESTAMP_MICROS({timestamp})"

    raise ValueError("Improper scale for timestamp")


def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
    """
    Spark doesn't allow PIVOT aliases, so we need to remove them and possibly wrap a
    pivoted source in a subquery with the same alias to preserve the query's semantics.

    Example:
        >>> from sqlglot import parse_one
        >>> expr = parse_one("SELECT piv.x FROM tbl PIVOT (SUM(a) FOR b IN ('x')) piv")
        >>> print(_unalias_pivot(expr).sql(dialect="spark"))
        SELECT piv.x FROM (SELECT * FROM tbl PIVOT(SUM(a) FOR b IN ('x'))) AS piv
    """
    if isinstance(expression, exp.From) and expression.this.args.get("pivots"):
        pivot = expression.this.args["pivots"][0]
        if pivot.alias:
            alias = pivot.args["alias"].pop()
            return exp.From(
                this=expression.this.replace(
                    exp.select("*").from_(expression.this.copy()).subquery(alias=alias)
                )
            )

    return expression


def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark doesn't allow the column referenced in the PIVOT's field to be qualified,
    so we need to unqualify it.

    Example:
        >>> from sqlglot import parse_one
        >>> expr = parse_one("SELECT * FROM tbl PIVOT (SUM(tbl.sales) FOR tbl.quarter IN ('Q1', 'Q2'))")
        >>> print(_unqualify_pivot_columns(expr).sql(dialect="spark"))
        SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q2'))
    """
    if isinstance(expression, exp.Pivot):
        expression.args["field"].transform(
            lambda node: exp.column(node.output_name, quoted=node.this.quoted)
            if isinstance(node, exp.Column)
            else node,
            copy=False,
        )

    return expression


class Spark2(Hive):
    class Parser(Hive.Parser):
        FUNCTIONS = {
            **Hive.Parser.FUNCTIONS,  # type: ignore
            "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
            "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
            "LEFT": lambda args: exp.Substring(
                this=seq_get(args, 0),
                start=exp.Literal.number(1),
                length=seq_get(args, 1),
            ),
            "SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
                this=seq_get(args, 0),
                expression=seq_get(args, 1),
            ),
            "SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
                this=seq_get(args, 0),
                expression=seq_get(args, 1),
            ),
            "RIGHT": lambda args: exp.Substring(
                this=seq_get(args, 0),
                start=exp.Sub(
                    this=exp.Length(this=seq_get(args, 0)),
                    expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
                ),
                length=seq_get(args, 1),
            ),
            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
            "IIF": exp.If.from_arg_list,
            "AGGREGATE": exp.Reduce.from_arg_list,
            "DAYOFWEEK": lambda args: exp.DayOfWeek(
                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
            ),
            "DAYOFMONTH": lambda args: exp.DayOfMonth(
                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
            ),
            "DAYOFYEAR": lambda args: exp.DayOfYear(
                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
            ),
            "WEEKOFYEAR": lambda args: exp.WeekOfYear(
                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
            ),
            "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
                this=seq_get(args, 1),
                unit=exp.var(seq_get(args, 0)),
            ),
            "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
            "BOOLEAN": _parse_as_cast("boolean"),
            "DOUBLE": _parse_as_cast("double"),
            "FLOAT": _parse_as_cast("float"),
            "INT": _parse_as_cast("int"),
            "STRING": _parse_as_cast("string"),
            "TIMESTAMP": _parse_as_cast("timestamp"),
        }

        FUNCTION_PARSERS = {
            **parser.Parser.FUNCTION_PARSERS,  # type: ignore
            "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
            "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
            "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
            "MERGE": lambda self: self._parse_join_hint("MERGE"),
            "SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
            "MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
            "SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
            "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
        }

        def _parse_add_column(self) -> t.Optional[exp.Expression]:
            return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()

        def _parse_drop_column(self) -> t.Optional[exp.Expression]:
            return self._match_text_seq("DROP", "COLUMNS") and self.expression(
                exp.Drop,
                this=self._parse_schema(),
                kind="COLUMNS",
            )

        def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
            # Spark doesn't add a suffix to the pivot columns when there's a single aggregation
            if len(pivot_columns) == 1:
                return [""]

            names = []
            for agg in pivot_columns:
                if isinstance(agg, exp.Alias):
                    names.append(agg.alias)
                else:
                    """
                    This case corresponds to aggregations without aliases being used as suffixes
                    (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
                    be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
                    Otherwise, we'd end up with `col_avg(`foo`)` (note the nested backticks).

                    Moreover, function names are lowercased in order to mimic Spark's naming scheme.
                    """
                    agg_all_unquoted = agg.transform(
                        lambda node: exp.Identifier(this=node.name, quoted=False)
                        if isinstance(node, exp.Identifier)
                        else node
                    )
                    names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))

            return names

    class Generator(Hive.Generator):
        TYPE_MAPPING = {
            **Hive.Generator.TYPE_MAPPING,  # type: ignore
            exp.DataType.Type.TINYINT: "BYTE",
            exp.DataType.Type.SMALLINT: "SHORT",
            exp.DataType.Type.BIGINT: "LONG",
        }

        PROPERTIES_LOCATION = {
            **Hive.Generator.PROPERTIES_LOCATION,  # type: ignore
            exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
            exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
            exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
            exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
        }

        TRANSFORMS = {
            **Hive.Generator.TRANSFORMS,  # type: ignore
            exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
            exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
            exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
            exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
            exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
            exp.Create: _create_sql,
            exp.DateFromParts: rename_func("MAKE_DATE"),
            exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
            exp.DayOfMonth: rename_func("DAYOFMONTH"),
            exp.DayOfWeek: rename_func("DAYOFWEEK"),
            exp.DayOfYear: rename_func("DAYOFYEAR"),
            exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
            exp.From: transforms.preprocess([_unalias_pivot]),
            exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
            exp.LogicalAnd: rename_func("BOOL_AND"),
            exp.LogicalOr: rename_func("BOOL_OR"),
            exp.Map: _map_sql,
            exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]),
            exp.Reduce: rename_func("AGGREGATE"),
            exp.StrToDate: _str_to_date,
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.TimestampTrunc: lambda self, e: self.func(
                "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
            ),
            exp.Trim: trim_sql,
            exp.UnixToTime: _unix_to_time_sql,
            exp.VariancePop: rename_func("VAR_POP"),
            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
            exp.WithinGroup: transforms.preprocess(
                [transforms.remove_within_group_for_percentiles]
            ),
        }
        # Fall back to the default generator behavior for these expressions
        TRANSFORMS.pop(exp.ArrayJoin)
        TRANSFORMS.pop(exp.ArraySort)
        TRANSFORMS.pop(exp.ILike)

        WRAP_DERIVED_VALUES = False
        CREATE_FUNCTION_RETURN_AS = False

        def cast_sql(self, expression: exp.Cast) -> str:
            if isinstance(expression.this, exp.Cast) and expression.this.is_type(
                exp.DataType.Type.JSON
            ):
                schema = f"'{self.sql(expression, 'to')}'"
                return self.func("FROM_JSON", expression.this.this, schema)
            if expression.to.is_type(exp.DataType.Type.JSON):
                return self.func("TO_JSON", expression.this)

            # Bypass Hive.Generator's cast_sql and defer to the base Generator implementation
            return super(Hive.Generator, self).cast_sql(expression)

        def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
            # Struct fields are rendered as "name: type" instead of "name type"
            return super().columndef_sql(
                expression,
                sep=": "
                if isinstance(expression.parent, exp.DataType)
                and expression.parent.is_type(exp.DataType.Type.STRUCT)
                else sep,
            )

    class Tokenizer(Hive.Tokenizer):
        HEX_STRINGS = [("X'", "'")]
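The dialect registers itself under the name "spark2" (the lowercased class name), so the mappings above can be exercised through the top-level sqlglot API. A few illustrative round trips follow; this is a sketch, and the exact output strings are indicative of the FUNCTIONS, TRANSFORMS, and cast_sql logic shown above rather than guaranteed across versions:

>>> import sqlglot
>>> sqlglot.transpile("SELECT x << 1 FROM t", write="spark2")[0]
'SELECT SHIFTLEFT(x, 1) FROM t'
>>> sqlglot.transpile("SELECT SHIFTLEFT(x, 1) FROM t", read="spark2")[0]
'SELECT x << 1 FROM t'
>>> sqlglot.transpile("SELECT CAST(x AS JSON) FROM t", write="spark2")[0]
'SELECT TO_JSON(x) FROM t'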
class Spark2(Hive):

class Spark2.Parser(Hive.Parser):
Parser consumes a list of tokens produced by the sqlglot.tokens.Tokenizer and produces a parsed syntax tree.
Arguments:
- error_level: the desired error level. Default: ErrorLevel.RAISE.
- error_message_context: the amount of context to capture from a query string when displaying the error message, in number of characters. Default: 50.
- index_offset: index offset for arrays, e.g. ARRAY[0] vs ARRAY[1] as the head of a list. Default: 0.
- alias_post_tablesample: whether the table alias comes after TABLESAMPLE. Default: False.
- max_errors: the maximum number of error messages to include in a raised ParseError. Only relevant if error_level is ErrorLevel.RAISE. Default: 3.
- null_ordering: the default null ordering method to use if not explicitly set. Options are "nulls_are_small", "nulls_are_large" and "nulls_are_last". Default: "nulls_are_small".
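These options are forwarded by the top-level helpers, so a Spark 2 parse can be tuned per call without constructing a Parser directly. A minimal sketch (the query is made up; the ApproxQuantile mapping comes from the FUNCTIONS table above):

>>> import sqlglot
>>> from sqlglot import exp
>>> from sqlglot.errors import ErrorLevel
>>> ast = sqlglot.parse_one(
...     "SELECT APPROX_PERCENTILE(x, 0.5) FROM t",
...     read="spark2",
...     error_level=ErrorLevel.RAISE,
... )
>>> isinstance(ast.expressions[0], exp.ApproxQuantile)
True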
class Spark2.Generator(Hive.Generator):
Generator interprets the given syntax tree and produces a SQL string as output.

Arguments:
- time_mapping (dict): the dictionary of custom time mappings in which the key represents a Python time format and the value the target time format.
- time_trie (trie): a trie of the time_mapping keys.
- pretty (bool): if set to True the returned string will be formatted. Default: False.
- quote_start (str): specifies which starting character to use to delimit quotes. Default: '.
- quote_end (str): specifies which ending character to use to delimit quotes. Default: '.
- identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ".
- identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ".
- bit_start (str): specifies which starting character to use to delimit bit literals. Default: None.
- bit_end (str): specifies which ending character to use to delimit bit literals. Default: None.
- hex_start (str): specifies which starting character to use to delimit hex literals. Default: None.
- hex_end (str): specifies which ending character to use to delimit hex literals. Default: None.
- byte_start (str): specifies which starting character to use to delimit byte literals. Default: None.
- byte_end (str): specifies which ending character to use to delimit byte literals. Default: None.
- identify (bool | str): 'always' quotes all identifiers; 'safe' quotes only identifiers that don't contain an uppercase character; True is equivalent to 'always'.
- normalize (bool): if set to True, all identifiers will be lowercased.
- string_escape (str): specifies a string escape character. Default: '.
- identifier_escape (str): specifies an identifier escape character. Default: ".
- pad (int): determines the padding in a formatted string. Default: 2.
- indent (int): determines the size of indentation in a formatted string. Default: 4.
- unnest_column_only (bool): if True, UNNEST table aliases are considered only as column aliases.
- normalize_functions (str): normalize function names to "upper" or "lower" case, or pass None to leave them as written. Default: "upper".
- alias_post_tablesample (bool): whether the table alias comes after TABLESAMPLE. Default: False.
- unsupported_level (ErrorLevel): determines the generator's behavior when it encounters unsupported expressions. Default: ErrorLevel.WARN.
- null_ordering (str): the default null ordering method to use if not explicitly set. Options are "nulls_are_small", "nulls_are_large" and "nulls_are_last". Default: "nulls_are_small".
- max_unsupported (int): the maximum number of unsupported messages to include in a raised UnsupportedError. Only relevant if unsupported_level is ErrorLevel.RAISE. Default: 3.
- leading_comma (bool): whether the comma is leading or trailing in SELECT statements. Default: False.
- max_text_width (int): the maximum number of characters in a segment before creating new lines in pretty mode. The default is on the smaller end because the length only represents a segment and not the true line length. Default: 80.
- comments (bool): whether or not to preserve comments in the output SQL code. Default: True.
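As an illustration of how a few of these options interact, the sketch below re-generates a parsed query as Spark 2 SQL with pretty printing and forced identifier quoting. The exact layout shown is indicative of this sqlglot version's pretty-printer, not a stable contract:

>>> import sqlglot
>>> ast = sqlglot.parse_one("SELECT a, b FROM t WHERE a > 1")
>>> print(ast.sql(dialect="spark2", pretty=True, identify=True))
SELECT
  `a`,
  `b`
FROM `t`
WHERE
  `a` > 1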