
sqlglot.dataframe.sql

from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
from sqlglot.dataframe.sql.group import GroupedData
from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql.window import Window, WindowSpec

__all__ = [
    "SparkSession",
    "DataFrame",
    "GroupedData",
    "Column",
    "DataFrameNaFunctions",
    "Window",
    "WindowSpec",
    "DataFrameReader",
    "DataFrameWriter",
]
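
A minimal end-to-end sketch of the API (the rows, column names, and dialects below are illustrative):

import sqlglot
from sqlglot.dataframe.sql import SparkSession
from sqlglot.dataframe.sql import functions as F

spark = SparkSession()
df = spark.createDataFrame([(1, "Jack"), (2, "Kate")], schema=["id", "name"])

# DataFrame.sql() returns a list of SQL strings, one per output statement
for stmt in df.select(F.col("id"), F.col("name")).sql(dialect="spark"):
    print(stmt)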
class SparkSession:
class SparkSession:
    DEFAULT_DIALECT = "spark"
    _instance = None

    def __init__(self):
        if not hasattr(self, "known_ids"):
            self.known_ids = set()
            self.known_branch_ids = set()
            self.known_sequence_ids = set()
            self.name_to_sequence_id_mapping = defaultdict(list)
            self.incrementing_id = 1
            self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)

    def __new__(cls, *args, **kwargs) -> SparkSession:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def read(self) -> DataFrameReader:
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame

        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
        if schema is not None and (
            not isinstance(schema, (StructType, str, list))
            or (isinstance(schema, list) and not isinstance(schema[0], str))
        ):
            raise NotImplementedError(
                "Only a schema of type StructType, string, or list of strings is supported"
            )
        if not data:
            raise ValueError("Must provide data to create a DataFrame")

        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(data[0], dict):
            column_mapping = {col_name.strip(): None for col_name in data[0]}
        else:
            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

        data_expressions = [
            exp.tuple_(
                *map(
                    lambda x: F.lit(x).expression,
                    row if not isinstance(row, dict) else row.values(),
                )
            )
            for row in data
        ]

        sel_columns = [
            (
                F.col(name).cast(data_type).alias(name).expression
                if data_type is not None
                else F.col(name).expression
            )
            for name, data_type in column_mapping.items()
        ]

        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                this=exp.Values(
                    expressions=data_expressions,
                    alias=exp.TableAlias(
                        this=exp.to_identifier(self._auto_incrementing_name),
                        columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                    ),
                ),
            ),
        }

        sel_expression = exp.Select(**select_kwargs)
        return DataFrame(self, sel_expression)

    def _optimize(
        self, expression: exp.Expression, dialect: t.Optional[Dialect] = None
    ) -> exp.Expression:
        dialect = dialect or self.dialect
        quote_identifiers(expression, dialect=dialect)
        return optimize(expression, dialect=dialect)

    def sql(self, sqlQuery: str) -> DataFrame:
        expression = self._optimize(sqlglot.parse_one(sqlQuery, read=self.dialect))
        if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
            df = df._convert_leaf_to_cte()
        elif isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
            df = df._convert_leaf_to_cte()
        else:
            raise ValueError(
                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
            )
        return df

    @property
    def _auto_incrementing_name(self) -> str:
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

    @property
    def _random_branch_id(self) -> str:
        id = self._random_id
        self.known_branch_ids.add(id)
        return id

    @property
    def _random_sequence_id(self) -> str:
        id = self._random_id
        self.known_sequence_ids.add(id)
        return id

    @property
    def _random_id(self) -> str:
        id = "r" + uuid.uuid4().hex
        self.known_ids.add(id)
        return id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        self.name_to_sequence_id_mapping[name].append(sequence_id)

    class Builder:
        SQLFRAME_DIALECT_KEY = "sqlframe.dialect"

        def __init__(self):
            self.dialect = "spark"

        def __getattr__(self, item) -> SparkSession.Builder:
            return self

        def __call__(self, *args, **kwargs):
            return self

        def config(
            self,
            key: t.Optional[str] = None,
            value: t.Optional[t.Any] = None,
            *,
            map: t.Optional[t.Dict[str, t.Any]] = None,
            **kwargs: t.Any,
        ) -> SparkSession.Builder:
            if key == self.SQLFRAME_DIALECT_KEY:
                self.dialect = value
            elif map and self.SQLFRAME_DIALECT_KEY in map:
                self.dialect = map[self.SQLFRAME_DIALECT_KEY]
            return self

        def getOrCreate(self) -> SparkSession:
            spark = SparkSession()
            spark.dialect = Dialect.get_or_raise(self.dialect)
            return spark

    @classproperty
    def builder(cls) -> Builder:
        return cls.Builder()
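
The builder is intentionally permissive: __getattr__ and __call__ both return the builder itself, so typical PySpark chains such as .master(...) or .appName(...) are accepted and ignored, and only the "sqlframe.dialect" config key changes behavior. A sketch:

spark = (
    SparkSession
    .builder
    .appName("ignored")  # no-op: absorbed by __getattr__/__call__
    .config("sqlframe.dialect", "bigquery")
    .getOrCreate()
)
assert spark is SparkSession()  # SparkSession is a singleton via __new__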
DEFAULT_DIALECT = 'spark'
read: DataFrameReader
def table(self, tableName: str) -> DataFrame:
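table is a thin wrapper around the reader, so the two calls below are equivalent (the table name is illustrative):

df = spark.table("employee")
df = spark.read.table("employee")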
def createDataFrame(self, data: Sequence[Union[Dict[str, ColumnLiterals], List[ColumnLiterals], Tuple]], schema: Optional[SchemaInput] = None, samplingRatio: Optional[float] = None, verifySchema: bool = False) -> DataFrame:
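The accepted row shapes follow from the code above: dict rows take their column names from the first row's keys, list/tuple rows take names from the schema when one is given, and otherwise positional names _1, _2, ... are generated. A sketch:

df1 = spark.createDataFrame([{"id": 1, "name": "Jack"}])           # names from dict keys
df2 = spark.createDataFrame([(1, "Jack")], schema=["id", "name"])  # names from schema
df3 = spark.createDataFrame([(1, "Jack")])                         # columns _1, _2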
def sql(self, sqlQuery: str) -> DataFrame:
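Because the parsed query is run through the optimizer, column resolution needs table schemas, so tables should be registered with sqlglot.schema.add_table before calling sql (the table and columns here are illustrative):

import sqlglot

sqlglot.schema.add_table("employee", {"id": "INT", "name": "STRING"}, dialect="spark")
df = spark.sql("SELECT id, name FROM employee WHERE id > 1")
for stmt in df.sql(dialect="duckdb"):
    print(stmt)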
builder: SparkSession.Builder
class SparkSession.Builder:
SQLFRAME_DIALECT_KEY = 'sqlframe.dialect'
dialect
def config( self, key: Optional[str] = None, value: Optional[Any] = None, *, map: Optional[Dict[str, Any]] = None, **kwargs: Any) -> SparkSession.Builder:
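config also accepts a map of settings, mirroring PySpark's newer builder signature; any key other than "sqlframe.dialect" is silently ignored. A sketch:

spark = SparkSession.builder.config(map={"sqlframe.dialect": "snowflake"}).getOrCreate()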
def getOrCreate(self) -> SparkSession:
class DataFrame:
class DataFrame:
    def __init__(
        self,
        spark: SparkSession,
        expression: exp.Select,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        last_op: Operation = Operation.INIT,
        pending_hints: t.Optional[t.List[exp.Expression]] = None,
        output_expression_container: t.Optional[OutputExpressionContainer] = None,
        **kwargs,
    ):
        self.spark = spark
        self.expression = expression
        self.branch_id = branch_id or self.spark._random_branch_id
        self.sequence_id = sequence_id or self.spark._random_sequence_id
        self.last_op = last_op
        self.pending_hints = pending_hints or []
        self.output_expression_container = output_expression_container or exp.Select()

    def __getattr__(self, column_name: str) -> Column:
        return self[column_name]

    def __getitem__(self, column_name: str) -> Column:
        column_name = f"{self.branch_id}.{column_name}"
        return Column(column_name)

    def __copy__(self):
        return self.copy()

    @property
    def sparkSession(self):
        return self.spark

    @property
    def write(self):
        return DataFrameWriter(self)

    @property
    def latest_cte_name(self) -> str:
        if not self.expression.ctes:
            from_exp = self.expression.args["from"]
            if from_exp.alias_or_name:
                return from_exp.alias_or_name
            table_alias = from_exp.find(exp.TableAlias)
            if not table_alias:
                raise RuntimeError(
                    f"Could not find an alias name for this expression: {self.expression}"
                )
            return table_alias.alias_or_name
        return self.expression.ctes[-1].alias

    @property
    def pending_join_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]

    @property
    def pending_partition_hints(self):
        return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]

    @property
    def columns(self) -> t.List[str]:
        return self.expression.named_selects

    @property
    def na(self) -> DataFrameNaFunctions:
        return DataFrameNaFunctions(self)

    def _replace_cte_names_with_hashes(self, expression: exp.Select):
        replacement_mapping = {}
        for cte in expression.ctes:
            old_name_id = cte.args["alias"].this
            new_hashed_id = exp.to_identifier(
                self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
            )
            replacement_mapping[old_name_id] = new_hashed_id
            expression = expression.transform(replace_id_value, replacement_mapping)
        return expression

    def _create_cte_from_expression(
        self,
        expression: exp.Expression,
        branch_id: t.Optional[str] = None,
        sequence_id: t.Optional[str] = None,
        **kwargs,
    ) -> t.Tuple[exp.CTE, str]:
        name = self._create_hash_from_expression(expression)
        expression_to_cte = expression.copy()
        expression_to_cte.set("with", None)
        cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
        cte.set("branch_id", branch_id or self.branch_id)
        cte.set("sequence_id", sequence_id or self.sequence_id)
        return cte, name

    @t.overload
    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ...

    @t.overload
    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ...

    def _ensure_list_of_columns(self, cols):
        return Column.ensure_cols(ensure_list(cols))

    def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None):
        cols = self._ensure_list_of_columns(cols)
        normalize(self.spark, expression or self.expression, cols)
        return cols

    def _ensure_and_normalize_col(self, col):
        col = Column.ensure_col(col)
        normalize(self.spark, self.expression, col)
        return col

    def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
        df = self._resolve_pending_hints()
        sequence_id = sequence_id or df.sequence_id
        expression = df.expression.copy()
        cte_expression, cte_name = df._create_cte_from_expression(
            expression=expression, sequence_id=sequence_id
        )
        new_expression = df._add_ctes_to_expression(
            exp.Select(), expression.ctes + [cte_expression]
        )
        sel_columns = df._get_outer_select_columns(cte_expression)
        new_expression = new_expression.from_(cte_name).select(
            *[x.alias_or_name for x in sel_columns]
        )
        return df.copy(expression=new_expression, sequence_id=sequence_id)

    def _resolve_pending_hints(self) -> DataFrame:
        df = self.copy()
        if not self.pending_hints:
            return df
        expression = df.expression
        hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
        for hint in df.pending_partition_hints:
            hint_expression.append("expressions", hint)
            df.pending_hints.remove(hint)

        join_aliases = {
            join_table.alias_or_name
            for join_table in get_tables_from_expression_with_join(expression)
        }
        if join_aliases:
            for hint in df.pending_join_hints:
                for sequence_id_expression in hint.expressions:
                    sequence_id_or_name = sequence_id_expression.alias_or_name
                    sequence_ids_to_match = [sequence_id_or_name]
                    if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
                            sequence_id_or_name
                        ]
                    matching_ctes = [
                        cte
                        for cte in reversed(expression.ctes)
                        if cte.args["sequence_id"] in sequence_ids_to_match
                    ]
                    for matching_cte in matching_ctes:
                        if matching_cte.alias_or_name in join_aliases:
                            sequence_id_expression.set("this", matching_cte.args["alias"].this)
                            df.pending_hints.remove(hint)
                            break
                hint_expression.append("expressions", hint)
        if hint_expression.expressions:
            expression.set("hint", hint_expression)
        return df

    def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
        hint_name = hint_name.upper()
        hint_expression = (
            exp.JoinHint(
                this=hint_name,
                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
            )
            if hint_name in JOIN_HINTS
            else exp.Anonymous(
                this=hint_name, expressions=[parameter.expression for parameter in args]
            )
        )
        new_df = self.copy()
        new_df.pending_hints.append(hint_expression)
        return new_df

    def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
        other_df = other._convert_leaf_to_cte()
        base_expression = self.expression.copy()
        base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
        all_ctes = base_expression.ctes
        other_df.expression.set("with", None)
        base_expression.set("with", None)
        operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
        operation.set("with", exp.With(expressions=all_ctes))
        return self.copy(expression=operation)._convert_leaf_to_cte()

    def _cache(self, storage_level: str):
        df = self._convert_leaf_to_cte()
        df.expression.ctes[-1].set("cache_storage_level", storage_level)
        return df

    @classmethod
    def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
        expression = expression.copy()
        with_expression = expression.args.get("with")
        if with_expression:
            existing_ctes = with_expression.expressions
            existing_cte_names = {x.alias_or_name for x in existing_ctes}
            for cte in ctes:
                if cte.alias_or_name not in existing_cte_names:
                    existing_ctes.append(cte)
        else:
            existing_ctes = ctes
        expression.set("with", exp.With(expressions=existing_ctes))
        return expression

    @classmethod
    def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
        expression = item.expression if isinstance(item, DataFrame) else item
        return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]

    @classmethod
    def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
        from sqlglot.dataframe.sql.session import SparkSession

        value = expression.sql(dialect=SparkSession().dialect).encode("utf-8")
        return f"t{zlib.crc32(value)}"[:6]

    def _get_select_expressions(
        self,
    ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
        select_expressions: t.List[
            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
        ] = []
        main_select_ctes: t.List[exp.CTE] = []
        for cte in self.expression.ctes:
            cache_storage_level = cte.args.get("cache_storage_level")
            if cache_storage_level:
                select_expression = cte.this.copy()
                select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
                select_expression.set("cte_alias_name", cte.alias_or_name)
                select_expression.set("cache_storage_level", cache_storage_level)
                select_expressions.append((exp.Cache, select_expression))
            else:
                main_select_ctes.append(cte)
        main_select = self.expression.copy()
        if main_select_ctes:
            main_select.set("with", exp.With(expressions=main_select_ctes))
        expression_select_pair = (type(self.output_expression_container), main_select)
        select_expressions.append(expression_select_pair)  # type: ignore
        return select_expressions

    def sql(self, dialect: DialectType = None, optimize: bool = True, **kwargs) -> t.List[str]:
        from sqlglot.dataframe.sql.session import SparkSession

        dialect = Dialect.get_or_raise(dialect or SparkSession().dialect)

        df = self._resolve_pending_hints()
        select_expressions = df._get_select_expressions()
        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}

        for expression_type, select_expression in select_expressions:
            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
            if optimize:
                select_expression = t.cast(
                    exp.Select, self.spark._optimize(select_expression, dialect=dialect)
                )

            select_expression = df._replace_cte_names_with_hashes(select_expression)

            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
            if expression_type == exp.Cache:
                cache_table_name = df._create_hash_from_expression(select_expression)
                cache_table = exp.to_table(cache_table_name)
                original_alias_name = select_expression.args["cte_alias_name"]

                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
                    cache_table_name
                )
                sqlglot.schema.add_table(
                    cache_table_name,
                    {
                        expression.alias_or_name: expression.type.sql(dialect=dialect)
                        for expression in select_expression.expressions
                    },
                    dialect=dialect,
                )

                cache_storage_level = select_expression.args["cache_storage_level"]
                options = [
                    exp.Literal.string("storageLevel"),
                    exp.Literal.string(cache_storage_level),
                ]
                expression = exp.Cache(
                    this=cache_table, expression=select_expression, lazy=True, options=options
                )

                # Drop the "view" if it exists before running CACHE TABLE
                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
            elif expression_type == exp.Create:
                expression = df.output_expression_container.copy()
                expression.set("expression", select_expression)
            elif expression_type == exp.Insert:
                expression = df.output_expression_container.copy()
                select_without_ctes = select_expression.copy()
                select_without_ctes.set("with", None)
                expression.set("expression", select_without_ctes)

                if select_expression.ctes:
                    expression.set("with", exp.With(expressions=select_expression.ctes))
            elif expression_type == exp.Select:
                expression = select_expression
            else:
                raise ValueError(f"Invalid expression type: {expression_type}")

            output_expressions.append(expression)

        return [expression.sql(dialect=dialect, **kwargs) for expression in output_expressions]

    def copy(self, **kwargs) -> DataFrame:
        return DataFrame(**object_to_dict(self, **kwargs))

    @operation(Operation.SELECT)
    def select(self, *cols, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(cols)
        kwargs["append"] = kwargs.get("append", False)
        if self.expression.args.get("joins"):
            ambiguous_cols = [
                col
                for col in cols
                if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
            ]
            if ambiguous_cols:
                join_table_identifiers = [
                    x.this for x in get_tables_from_expression_with_join(self.expression)
                ]
                cte_names_in_join = [x.this for x in join_table_identifiers]
                # If a column resolves to multiple CTE expressions, use each CTE left-to-right,
                # allowing multiple columns with the same name in the result. This matches
                # Spark's behavior.
                resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
                for ambiguous_col in ambiguous_cols:
                    ctes_with_column = [
                        cte
                        for cte in self.expression.ctes
                        if cte.alias_or_name in cte_names_in_join
                        and ambiguous_col.alias_or_name in cte.this.named_selects
                    ]
                    # Check if there is a CTE with this column that we haven't used before. If so,
                    # use it. Otherwise, use the same CTE we used before.
                    cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
                    if cte:
                        resolved_column_position[ambiguous_col] += 1
                    else:
                        cte = ctes_with_column[resolved_column_position[ambiguous_col]]
                    ambiguous_col.expression.set("table", cte.alias_or_name)
        return self.copy(
            expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
        )

    @operation(Operation.NO_OP)
    def alias(self, name: str, **kwargs) -> DataFrame:
        new_sequence_id = self.spark._random_sequence_id
        df = self.copy()
        for join_hint in df.pending_join_hints:
            for expression in join_hint.expressions:
                if expression.alias_or_name == self.sequence_id:
                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
        df.spark._add_alias_to_mapping(name, new_sequence_id)
        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)

    @operation(Operation.WHERE)
    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
        col = self._ensure_and_normalize_col(column)
        return self.copy(expression=self.expression.where(col.expression))

    filter = where

    @operation(Operation.GROUP_BY)
    def groupBy(self, *cols, **kwargs) -> GroupedData:
        columns = self._ensure_and_normalize_cols(cols)
        return GroupedData(self, columns, self.last_op)

    @operation(Operation.SELECT)
    def agg(self, *exprs, **kwargs) -> DataFrame:
        cols = self._ensure_and_normalize_cols(exprs)
        return self.groupBy().agg(*cols)

    @operation(Operation.FROM)
    def join(
        self,
        other_df: DataFrame,
        on: t.Union[str, t.List[str], Column, t.List[Column]],
        how: str = "inner",
        **kwargs,
    ) -> DataFrame:
        other_df = other_df._convert_leaf_to_cte()
        join_columns = self._ensure_list_of_columns(on)
        # We determine the actual "join on" expression later, so we don't provide it at first
        join_expression = self.expression.join(
            other_df.latest_cte_name, join_type=how.replace("_", " ")
        )
        join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
        self_columns = self._get_outer_select_columns(join_expression)
        other_columns = self._get_outer_select_columns(other_df)
        # Determines the join clause and select columns to be used based on what type of columns
        # were provided for the join. The columns returned change based on how the on expression
        # is provided.
        if isinstance(join_columns[0].expression, exp.Column):
            """
            Unique characteristics of join on column names only:
            * The join column names are put at the front of the select list
            * Only the join column names are deduplicated across the entire select list
              (other duplicates are allowed)
            """
            table_names = [
                table.alias_or_name
                for table in get_tables_from_expression_with_join(join_expression)
            ]
            potential_ctes = [
                cte
                for cte in join_expression.ctes
                if cte.alias_or_name in table_names
                and cte.alias_or_name != other_df.latest_cte_name
            ]
            # Determine the table to reference for the left side of the join by checking each of
            # the left side tables and seeing whether it has the column being referenced.
            join_column_pairs = []
            for join_column in join_columns:
                num_matching_ctes = 0
                for cte in potential_ctes:
                    if join_column.alias_or_name in cte.this.named_selects:
                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
                        right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
                        join_column_pairs.append((left_column, right_column))
                        num_matching_ctes += 1
                if num_matching_ctes > 1:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
                    )
                elif num_matching_ctes == 0:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
                    )
            join_clause = functools.reduce(
                lambda x, y: x & y,
                [left_column == right_column for left_column, right_column in join_column_pairs],
            )
            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
            # To match Spark's behavior, only the join clause columns are deduplicated, and they
            # are put at the front of the column list
            select_column_names = [
                (
                    column.alias_or_name
                    if not isinstance(column.expression.this, exp.Star)
                    else column.sql()
                )
                for column in self_columns + other_columns
            ]
            select_column_names = [
                column_name
                for column_name in select_column_names
                if column_name not in join_column_names
            ]
            select_column_names = join_column_names + select_column_names
        else:
            """
            Unique characteristics of join on expressions:
            * There is no deduplication of the results.
            * The left DataFrame's columns go first and the right's come after. No sort
              preference is given to join columns.
            """
            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
            if len(join_columns) > 1:
                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
            join_clause = join_columns[0]
            select_column_names = [column.alias_or_name for column in self_columns + other_columns]

        # Update the on expression with the actual join clause to replace the dummy one from before
        join_expression.args["joins"][-1].set("on", join_clause.expression)
        new_df = self.copy(expression=join_expression)
        new_df.pending_join_hints.extend(self.pending_join_hints)
        new_df.pending_hints.extend(other_df.pending_hints)
        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
        return new_df

    @operation(Operation.ORDER_BY)
    def orderBy(
        self,
        *cols: t.Union[str, Column],
        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
    ) -> DataFrame:
        """
        This implementation lets any pre-ordered columns take priority over whatever is provided
        in `ascending`. Spark's behavior here is irregular and can result in runtime errors.
        Users shouldn't be mixing the two anyway, so this is unlikely to come up.
        """
        columns = self._ensure_and_normalize_cols(cols)
        pre_ordered_col_indexes = [
            i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered)
        ]
        if ascending is None:
            ascending = [True] * len(columns)
        elif not isinstance(ascending, list):
            ascending = [ascending] * len(columns)
        ascending = [bool(x) for x in ascending]
        assert len(columns) == len(
            ascending
        ), "The length of items in ascending must equal the number of columns provided"
        col_and_ascending = list(zip(columns, ascending))
        order_by_columns = [
            (
                exp.Ordered(this=col.expression, desc=not asc)
                if i not in pre_ordered_col_indexes
                else columns[i].column_expression
            )
            for i, (col, asc) in enumerate(col_and_ascending)
        ]
        return self.copy(expression=self.expression.order_by(*order_by_columns))

    sort = orderBy

    @operation(Operation.FROM)
    def union(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Union, other, False)

    unionAll = union

    @operation(Operation.FROM)
    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
        l_columns = self.columns
        r_columns = other.columns
        if not allowMissingColumns:
            l_expressions = l_columns
            r_expressions = l_columns
        else:
            l_expressions = []
            r_expressions = []
            r_columns_unused = copy(r_columns)
            for l_column in l_columns:
                l_expressions.append(l_column)
                if l_column in r_columns:
                    r_expressions.append(l_column)
                    r_columns_unused.remove(l_column)
                else:
                    r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
            for r_column in r_columns_unused:
                l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
                r_expressions.append(r_column)
        r_df = (
            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
        )
        l_df = self.copy()
        if allowMissingColumns:
            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
        return l_df._set_operation(exp.Union, r_df, False)

    @operation(Operation.FROM)
    def intersect(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, True)

    @operation(Operation.FROM)
    def intersectAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Intersect, other, False)

    @operation(Operation.FROM)
    def exceptAll(self, other: DataFrame) -> DataFrame:
        return self._set_operation(exp.Except, other, False)

    @operation(Operation.SELECT)
    def distinct(self) -> DataFrame:
        return self.copy(expression=self.expression.distinct())

    @operation(Operation.SELECT)
    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
        if not subset:
            return self.distinct()
        column_names = ensure_list(subset)
        window = Window.partitionBy(*column_names).orderBy(*column_names)
        return (
            self.copy()
            .withColumn("row_num", F.row_number().over(window))
            .where(F.col("row_num") == F.lit(1))
            .drop("row_num")
        )

    @operation(Operation.FROM)
    def dropna(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        minimum_non_null = thresh or 0  # determined below if thresh is None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        if subset:
            null_check_columns = self._ensure_and_normalize_cols(subset)
        else:
            null_check_columns = all_columns
        if thresh is None:
            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
        else:
            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
        if minimum_num_nulls > len(null_check_columns):
            raise RuntimeError(
                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
            )
        if_null_checks = [
            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
        ]
        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
        num_nulls = nulls_added_together.alias("num_nulls")
        new_df = new_df.select(num_nulls, append=True)
        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
        final_df = filtered_df.select(*all_columns)
        return final_df

    @operation(Operation.FROM)
    def fillna(
        self,
        value: ColumnLiterals,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """
        Functionality difference: if the type of the replacement value conflicts with the type of
        the column, PySpark simply ignores the replacement, whereas this implementation tries to
        cast the two to a common type in some cases, so results won't always match. It is best
        not to mix types: make sure the replacement value has the same type as the column.

        Possible improvement: use the `typeof` function to get the type of the column, check
        whether it matches the type of the value provided, and fall back to null if it does not.
        """
        from sqlglot.dataframe.sql.functions import lit

        values = None
        columns = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}
        if isinstance(value, dict):
            values = list(value.values())
            columns = self._ensure_and_normalize_cols(list(value))
        if not columns:
            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if not values:
            values = [value] * len(columns)
        value_columns = [lit(value) for value in values]

        null_replacement_mapping = {
            column.alias_or_name: (
                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
            )
            for column, value in zip(columns, value_columns)
        }
        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
        null_replacement_columns = [
            null_replacement_mapping[column.alias_or_name] for column in all_columns
        ]
        new_df = new_df.select(*null_replacement_columns)
        return new_df

    @operation(Operation.FROM)
    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.functions import lit

        old_values = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}

        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if isinstance(to_replace, dict):
            old_values = list(to_replace)
            new_values = list(to_replace.values())
        elif not old_values and isinstance(to_replace, list):
            assert isinstance(value, list), "value must be a list since the replacements are a list"
            assert len(to_replace) == len(
                value
            ), "the replacements and values must be the same length"
            old_values = to_replace
            new_values = value
        else:
            old_values = [to_replace] * len(columns)
            new_values = [value] * len(columns)
        old_values = [lit(value) for value in old_values]
        new_values = [lit(value) for value in new_values]

        replacement_mapping = {}
        for column in columns:
            expression = Column(None)
            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
                if i == 0:
                    expression = F.when(column == old_value, new_value)
                else:
                    expression = expression.when(column == old_value, new_value)  # type: ignore
            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
                column.expression.alias_or_name
            )

        replacement_mapping = {**all_column_mapping, **replacement_mapping}
        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
        new_df = new_df.select(*replacement_columns)
        return new_df

    @operation(Operation.SELECT)
    def withColumn(self, colName: str, col: Column) -> DataFrame:
        col = self._ensure_and_normalize_col(col)
        existing_col_names = self.expression.named_selects
        existing_col_index = (
            existing_col_names.index(colName) if colName in existing_col_names else None
        )
        if existing_col_index is not None:
            expression = self.expression.copy()
            expression.expressions[existing_col_index] = col.expression
            return self.copy(expression=expression)
        return self.copy().select(col.alias(colName), append=True)

    @operation(Operation.SELECT)
    def withColumnRenamed(self, existing: str, new: str):
        expression = self.expression.copy()
        existing_columns = [
            expression
            for expression in expression.expressions
            if expression.alias_or_name == existing
        ]
        if not existing_columns:
            raise ValueError("Tried to rename a column that doesn't exist")
        for existing_column in existing_columns:
            if isinstance(existing_column, exp.Column):
                existing_column.replace(exp.alias_(existing_column, new))
            else:
                existing_column.set("alias", exp.to_identifier(new))
        return self.copy(expression=expression)

    @operation(Operation.SELECT)
    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
        all_columns = self._get_outer_select_columns(self.expression)
        drop_cols = self._ensure_and_normalize_cols(cols)
        new_columns = [
            col
            for col in all_columns
            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
        ]
        return self.copy().select(*new_columns, append=False)

    @operation(Operation.LIMIT)
    def limit(self, num: int) -> DataFrame:
        return self.copy(expression=self.expression.limit(num))

    @operation(Operation.NO_OP)
    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
        parameter_list = ensure_list(parameters)
        parameter_columns = (
            self._ensure_list_of_columns(parameter_list)
            if parameters
            else Column.ensure_cols([self.sequence_id])
        )
        return self._hint(name, parameter_columns)

    @operation(Operation.NO_OP)
    def repartition(
        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
    ) -> DataFrame:
        num_partition_cols = self._ensure_list_of_columns(numPartitions)
        columns = self._ensure_and_normalize_cols(cols)
        args = num_partition_cols + columns
        return self._hint("repartition", args)

    @operation(Operation.NO_OP)
    def coalesce(self, numPartitions: int) -> DataFrame:
        num_partitions = Column.ensure_cols([numPartitions])
        return self._hint("coalesce", num_partitions)

    @operation(Operation.NO_OP)
    def cache(self) -> DataFrame:
        return self._cache(storage_level="MEMORY_AND_DISK")

    @operation(Operation.NO_OP)
    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
        """
        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
        """
        return self._cache(storageLevel)
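
Every transformation returns a new DataFrame wrapping a copied expression, so calls chain like PySpark's API and nothing runs until sql() generates the statements. A sketch continuing the examples above (the tables and columns are illustrative and assume their schemas were registered with sqlglot.schema.add_table):

df_emp = spark.table("employee")
df_dept = spark.table("department")

query = (
    df_emp
    .join(df_dept, on="department_id", how="inner")
    .where(F.col("age") > F.lit(30))
    .select("name", "department_name")
    .orderBy("name")
)
for stmt in query.sql(dialect="spark"):
    print(stmt)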
DataFrame(spark: SparkSession, expression: sqlglot.expressions.Select, branch_id: Optional[str] = None, sequence_id: Optional[str] = None, last_op: sqlglot.dataframe.sql.operations.Operation = <Operation.INIT: -1>, pending_hints: Optional[List[sqlglot.expressions.Expression]] = None, output_expression_container: Optional[OutputExpressionContainer] = None, **kwargs)
spark
expression
branch_id
sequence_id
last_op
pending_hints
output_expression_container
sparkSession
write
latest_cte_name: str
pending_join_hints
pending_partition_hints
columns: List[str]
na: DataFrameNaFunctions
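na wraps the DataFrame in DataFrameNaFunctions; its methods are assumed here to delegate to dropna/fillna/replace in the PySpark style:

df_clean = df.na.fill(0, subset=["age"])  # assumed equivalent to df.fillna(0, subset=["age"])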
def sql(self, dialect: DialectType = None, optimize: bool = True, **kwargs) -> List[str]:
297    def sql(self, dialect: DialectType = None, optimize: bool = True, **kwargs) -> t.List[str]:
298        from sqlglot.dataframe.sql.session import SparkSession
299
300        dialect = Dialect.get_or_raise(dialect or SparkSession().dialect)
301
302        df = self._resolve_pending_hints()
303        select_expressions = df._get_select_expressions()
304        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
305        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
306
307        for expression_type, select_expression in select_expressions:
308            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
309            if optimize:
310                select_expression = t.cast(
311                    exp.Select, self.spark._optimize(select_expression, dialect=dialect)
312                )
313
314            select_expression = df._replace_cte_names_with_hashes(select_expression)
315
316            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
317            if expression_type == exp.Cache:
318                cache_table_name = df._create_hash_from_expression(select_expression)
319                cache_table = exp.to_table(cache_table_name)
320                original_alias_name = select_expression.args["cte_alias_name"]
321
322                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
323                    cache_table_name
324                )
325                sqlglot.schema.add_table(
326                    cache_table_name,
327                    {
328                        expression.alias_or_name: expression.type.sql(dialect=dialect)
329                        for expression in select_expression.expressions
330                    },
331                    dialect=dialect,
332                )
333
334                cache_storage_level = select_expression.args["cache_storage_level"]
335                options = [
336                    exp.Literal.string("storageLevel"),
337                    exp.Literal.string(cache_storage_level),
338                ]
339                expression = exp.Cache(
340                    this=cache_table, expression=select_expression, lazy=True, options=options
341                )
342
343                # We will drop the "view" if it exists before running the cache table
344                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
345            elif expression_type == exp.Create:
346                expression = df.output_expression_container.copy()
347                expression.set("expression", select_expression)
348            elif expression_type == exp.Insert:
349                expression = df.output_expression_container.copy()
350                select_without_ctes = select_expression.copy()
351                select_without_ctes.set("with", None)
352                expression.set("expression", select_without_ctes)
353
354                if select_expression.ctes:
355                    expression.set("with", exp.With(expressions=select_expression.ctes))
356            elif expression_type == exp.Select:
357                expression = select_expression
358            else:
359                raise ValueError(f"Invalid expression type: {expression_type}")
360
361            output_expressions.append(expression)
362
363        return [expression.sql(dialect=dialect, **kwargs) for expression in output_expressions]
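As a usage sketch (the frame, column names, and the "bigquery" output dialect below are just illustrative):

    from sqlglot.dataframe.sql.session import SparkSession
    from sqlglot.dataframe.sql import functions as F

    spark = SparkSession()
    df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "name"])
    # sql() returns a list of statements; a plain select yields a single element
    print(df.select(F.col("id"), F.col("name")).sql(dialect="bigquery")[0])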
def copy(self, **kwargs) -> DataFrame:
365    def copy(self, **kwargs) -> DataFrame:
366        return DataFrame(**object_to_dict(self, **kwargs))
@operation(Operation.SELECT)
def select(self, *cols, **kwargs) -> DataFrame:
368    @operation(Operation.SELECT)
369    def select(self, *cols, **kwargs) -> DataFrame:
370        cols = self._ensure_and_normalize_cols(cols)
371        kwargs["append"] = kwargs.get("append", False)
372        if self.expression.args.get("joins"):
373            ambiguous_cols = [
374                col
375                for col in cols
376                if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
377            ]
378            if ambiguous_cols:
379                join_table_identifiers = [
380                    x.this for x in get_tables_from_expression_with_join(self.expression)
381                ]
382                cte_names_in_join = [x.this for x in join_table_identifiers]
383                # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
384                # and therefore we allow multiple columns with the same name in the result. This matches the behavior
385                # of Spark.
386                resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
387                for ambiguous_col in ambiguous_cols:
388                    ctes_with_column = [
389                        cte
390                        for cte in self.expression.ctes
391                        if cte.alias_or_name in cte_names_in_join
392                        and ambiguous_col.alias_or_name in cte.this.named_selects
393                    ]
394                    # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
395                    # use the same CTE we used before
396                    cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
397                    if cte:
398                        resolved_column_position[ambiguous_col] += 1
399                    else:
400                        cte = ctes_with_column[resolved_column_position[ambiguous_col]]
401                    ambiguous_col.expression.set("table", cte.alias_or_name)
402        return self.copy(
403            expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
404        )
@operation(Operation.NO_OP)
def alias(self, name: str, **kwargs) -> DataFrame:
406    @operation(Operation.NO_OP)
407    def alias(self, name: str, **kwargs) -> DataFrame:
408        new_sequence_id = self.spark._random_sequence_id
409        df = self.copy()
410        for join_hint in df.pending_join_hints:
411            for expression in join_hint.expressions:
412                if expression.alias_or_name == self.sequence_id:
413                    expression.set("this", Column.ensure_col(new_sequence_id).expression)
414        df.spark._add_alias_to_mapping(name, new_sequence_id)
415        return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
@operation(Operation.WHERE)
def where( self, column: Union[Column, bool], **kwargs) -> DataFrame:
417    @operation(Operation.WHERE)
418    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
419        col = self._ensure_and_normalize_col(column)
420        return self.copy(expression=self.expression.where(col.expression))
@operation(Operation.WHERE)
def filter( self, column: Union[Column, bool], **kwargs) -> DataFrame:
417    @operation(Operation.WHERE)
418    def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
419        col = self._ensure_and_normalize_col(column)
420        return self.copy(expression=self.expression.where(col.expression))
@operation(Operation.GROUP_BY)
def groupBy(self, *cols, **kwargs) -> GroupedData:
424    @operation(Operation.GROUP_BY)
425    def groupBy(self, *cols, **kwargs) -> GroupedData:
426        columns = self._ensure_and_normalize_cols(cols)
427        return GroupedData(self, columns, self.last_op)
@operation(Operation.SELECT)
def agg(self, *exprs, **kwargs) -> DataFrame:
429    @operation(Operation.SELECT)
430    def agg(self, *exprs, **kwargs) -> DataFrame:
431        cols = self._ensure_and_normalize_cols(exprs)
432        return self.groupBy().agg(*cols)
@operation(Operation.FROM)
def join( self, other_df: DataFrame, on: Union[str, List[str], Column, List[Column]], how: str = 'inner', **kwargs) -> DataFrame:
434    @operation(Operation.FROM)
435    def join(
436        self,
437        other_df: DataFrame,
438        on: t.Union[str, t.List[str], Column, t.List[Column]],
439        how: str = "inner",
440        **kwargs,
441    ) -> DataFrame:
442        other_df = other_df._convert_leaf_to_cte()
443        join_columns = self._ensure_list_of_columns(on)
444        # We will determine actual "join on" expression later so we don't provide it at first
445        join_expression = self.expression.join(
446            other_df.latest_cte_name, join_type=how.replace("_", " ")
447        )
448        join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
449        self_columns = self._get_outer_select_columns(join_expression)
450        other_columns = self._get_outer_select_columns(other_df)
451        # Determines the join clause and select columns to be used passed on what type of columns were provided for
452        # the join. The columns returned changes based on how the on expression is provided.
453        if isinstance(join_columns[0].expression, exp.Column):
454            """
455            Unique characteristics of join on column names only:
456            * The column names are put at the front of the select list
457            * The column names are deduplicated across the entire select list and only the column names (other dups are allowed)
458            """
459            table_names = [
460                table.alias_or_name
461                for table in get_tables_from_expression_with_join(join_expression)
462            ]
463            potential_ctes = [
464                cte
465                for cte in join_expression.ctes
466                if cte.alias_or_name in table_names
467                and cte.alias_or_name != other_df.latest_cte_name
468            ]
469            # Determine the table to reference for the left side of the join by checking each of the left side
470            # tables and see if they have the column being referenced.
471            join_column_pairs = []
472            for join_column in join_columns:
473                num_matching_ctes = 0
474                for cte in potential_ctes:
475                    if join_column.alias_or_name in cte.this.named_selects:
476                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
477                        right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
478                        join_column_pairs.append((left_column, right_column))
479                        num_matching_ctes += 1
480                if num_matching_ctes > 1:
481                    raise ValueError(
482                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
483                    )
484                elif num_matching_ctes == 0:
485                    raise ValueError(
486                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
487                    )
488            join_clause = functools.reduce(
489                lambda x, y: x & y,
490                [left_column == right_column for left_column, right_column in join_column_pairs],
491            )
492            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
493            # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list
494            select_column_names = [
495                (
496                    column.alias_or_name
497                    if not isinstance(column.expression.this, exp.Star)
498                    else column.sql()
499                )
500                for column in self_columns + other_columns
501            ]
502            select_column_names = [
503                column_name
504                for column_name in select_column_names
505                if column_name not in join_column_names
506            ]
507            select_column_names = join_column_names + select_column_names
508        else:
509            """
510            Unique characteristics of join on expressions:
511            * There is no deduplication of the results.
512            * The left join dataframe columns go first and right come after. No sort preference is given to join columns
513            """
514            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
515            if len(join_columns) > 1:
516                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
517            join_clause = join_columns[0]
518            select_column_names = [column.alias_or_name for column in self_columns + other_columns]
519
520        # Update the on expression with the actual join clause to replace the dummy one from before
521        join_expression.args["joins"][-1].set("on", join_clause.expression)
522        new_df = self.copy(expression=join_expression)
523        new_df.pending_join_hints.extend(self.pending_join_hints)
524        new_df.pending_hints.extend(other_df.pending_hints)
525        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
526        return new_df
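A sketch of joining on column names, which (per the logic above) dedupes the join columns and moves them to the front of the select list (frames and names are hypothetical):

    left = spark.createDataFrame([(1, "a")], ["id", "name"])
    right = spark.createDataFrame([(1, 10)], ["id", "value"])
    # Result selects id, name, value: a single "id" column, listed first
    joined = left.join(right, on="id", how="inner")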
@operation(Operation.ORDER_BY)
def orderBy( self, *cols: Union[str, Column], ascending: Union[Any, List[Any], NoneType] = None) -> DataFrame:
528    @operation(Operation.ORDER_BY)
529    def orderBy(
530        self,
531        *cols: t.Union[str, Column],
532        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
533    ) -> DataFrame:
534        """
535        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
536        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
537        is unlikely to come up.
538        """
539        columns = self._ensure_and_normalize_cols(cols)
540        pre_ordered_col_indexes = [
541            i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered)
542        ]
543        if ascending is None:
544            ascending = [True] * len(columns)
545        elif not isinstance(ascending, list):
546            ascending = [ascending] * len(columns)
547        ascending = [bool(x) for x in ascending]
548        assert len(columns) == len(
549            ascending
550        ), "The length of items in ascending must equal the number of columns provided"
551        col_and_ascending = list(zip(columns, ascending))
552        order_by_columns = [
553            (
554                exp.Ordered(this=col.expression, desc=not asc)
555                if i not in pre_ordered_col_indexes
556                else columns[i].column_expression
557            )
558            for i, (col, asc) in enumerate(col_and_ascending)
559        ]
560        return self.copy(expression=self.expression.order_by(*order_by_columns))

This implementation lets any pre-ordered columns take priority over whatever is provided in ascending. Spark's behavior when the two are mixed is irregular and can result in runtime errors; users shouldn't mix them anyway, so this is unlikely to come up.
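For example (a sketch; the two calls below are roughly equivalent, modulo explicit null ordering):

    df.orderBy("age", "name", ascending=[False, True])
    df.orderBy(F.col("age").desc(), F.col("name").asc())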

@operation(Operation.ORDER_BY)
def sort( self, *cols: Union[str, Column], ascending: Union[Any, List[Any], NoneType] = None) -> DataFrame:
528    @operation(Operation.ORDER_BY)
529    def orderBy(
530        self,
531        *cols: t.Union[str, Column],
532        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
533    ) -> DataFrame:
534        """
535        This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
536        has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
537        is unlikely to come up.
538        """
539        columns = self._ensure_and_normalize_cols(cols)
540        pre_ordered_col_indexes = [
541            i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered)
542        ]
543        if ascending is None:
544            ascending = [True] * len(columns)
545        elif not isinstance(ascending, list):
546            ascending = [ascending] * len(columns)
547        ascending = [bool(x) for x in ascending]
548        assert len(columns) == len(
549            ascending
550        ), "The length of items in ascending must equal the number of columns provided"
551        col_and_ascending = list(zip(columns, ascending))
552        order_by_columns = [
553            (
554                exp.Ordered(this=col.expression, desc=not asc)
555                if i not in pre_ordered_col_indexes
556                else columns[i].column_expression
557            )
558            for i, (col, asc) in enumerate(col_and_ascending)
559        ]
560        return self.copy(expression=self.expression.order_by(*order_by_columns))

This implementation lets any pre-ordered columns take priority over whatever is provided in ascending. Spark's behavior when the two are mixed is irregular and can result in runtime errors; users shouldn't mix them anyway, so this is unlikely to come up.

@operation(Operation.FROM)
def union( self, other: DataFrame) -> DataFrame:
564    @operation(Operation.FROM)
565    def union(self, other: DataFrame) -> DataFrame:
566        return self._set_operation(exp.Union, other, False)
@operation(Operation.FROM)
def unionAll( self, other: DataFrame) -> DataFrame:
564    @operation(Operation.FROM)
565    def union(self, other: DataFrame) -> DataFrame:
566        return self._set_operation(exp.Union, other, False)
@operation(Operation.FROM)
def unionByName( self, other: DataFrame, allowMissingColumns: bool = False):
570    @operation(Operation.FROM)
571    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
572        l_columns = self.columns
573        r_columns = other.columns
574        if not allowMissingColumns:
575            l_expressions = l_columns
576            r_expressions = l_columns
577        else:
578            l_expressions = []
579            r_expressions = []
580            r_columns_unused = copy(r_columns)
581            for l_column in l_columns:
582                l_expressions.append(l_column)
583                if l_column in r_columns:
584                    r_expressions.append(l_column)
585                    r_columns_unused.remove(l_column)
586                else:
587                    r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
588            for r_column in r_columns_unused:
589                l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
590                r_expressions.append(r_column)
591        r_df = (
592            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
593        )
594        l_df = self.copy()
595        if allowMissingColumns:
596            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
597        return l_df._set_operation(exp.Union, r_df, False)
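A sketch of the allowMissingColumns behavior (hypothetical frames):

    df1 = spark.createDataFrame([(1, 2)], ["a", "b"])
    df2 = spark.createDataFrame([(3, 4)], ["b", "c"])
    # Result columns are a, b, c: df1 rows get NULL for "c", df2 rows get NULL for "a"
    unioned = df1.unionByName(df2, allowMissingColumns=True)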
@operation(Operation.FROM)
def intersect( self, other: DataFrame) -> DataFrame:
599    @operation(Operation.FROM)
600    def intersect(self, other: DataFrame) -> DataFrame:
601        return self._set_operation(exp.Intersect, other, True)
@operation(Operation.FROM)
def intersectAll( self, other: DataFrame) -> DataFrame:
603    @operation(Operation.FROM)
604    def intersectAll(self, other: DataFrame) -> DataFrame:
605        return self._set_operation(exp.Intersect, other, False)
@operation(Operation.FROM)
def exceptAll( self, other: DataFrame) -> DataFrame:
607    @operation(Operation.FROM)
608    def exceptAll(self, other: DataFrame) -> DataFrame:
609        return self._set_operation(exp.Except, other, False)
@operation(Operation.SELECT)
def distinct(self) -> DataFrame:
611    @operation(Operation.SELECT)
612    def distinct(self) -> DataFrame:
613        return self.copy(expression=self.expression.distinct())
@operation(Operation.SELECT)
def dropDuplicates(self, subset: Optional[List[str]] = None):
615    @operation(Operation.SELECT)
616    def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
617        if not subset:
618            return self.distinct()
619        column_names = ensure_list(subset)
620        window = Window.partitionBy(*column_names).orderBy(*column_names)
621        return (
622            self.copy()
623            .withColumn("row_num", F.row_number().over(window))
624            .where(F.col("row_num") == F.lit(1))
625            .drop("row_num")
626        )
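For example (a sketch): with a subset, deduplication is expressed via the ROW_NUMBER() window above rather than DISTINCT:

    df.dropDuplicates(["id"])   # ROW_NUMBER() OVER (PARTITION BY id ORDER BY id) = 1
    df.dropDuplicates()         # no subset: falls back to SELECT DISTINCT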
@operation(Operation.FROM)
def dropna( self, how: str = 'any', thresh: Optional[int] = None, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> DataFrame:
628    @operation(Operation.FROM)
629    def dropna(
630        self,
631        how: str = "any",
632        thresh: t.Optional[int] = None,
633        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
634    ) -> DataFrame:
635        minimum_non_null = thresh or 0  # will be determined later if thresh is null
636        new_df = self.copy()
637        all_columns = self._get_outer_select_columns(new_df.expression)
638        if subset:
639            null_check_columns = self._ensure_and_normalize_cols(subset)
640        else:
641            null_check_columns = all_columns
642        if thresh is None:
643            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
644        else:
645            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
646        if minimum_num_nulls > len(null_check_columns):
647            raise RuntimeError(
648                f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
649                f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
650            )
651        if_null_checks = [
652            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
653        ]
654        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
655        num_nulls = nulls_added_together.alias("num_nulls")
656        new_df = new_df.select(num_nulls, append=True)
657        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
658        final_df = filtered_df.select(*all_columns)
659        return final_df
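A sketch of how how and thresh interact (thresh, when given, overrides how):

    df.dropna(how="any", subset=["a", "b"])      # drop rows where a or b is null
    df.dropna(how="all")                         # drop rows where every column is null
    df.dropna(thresh=2, subset=["a", "b", "c"])  # keep rows with at least 2 non-nulls among a, b, c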
@operation(Operation.FROM)
def fillna( self, value: ColumnLiterals, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> DataFrame:
661    @operation(Operation.FROM)
662    def fillna(
663        self,
664        value: t.Union[ColumnLiterals],
665        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
666    ) -> DataFrame:
667        """
668        Functionality Difference: If you provide a value to replace a null and that type conflicts
669        with the type of the column then PySpark will just ignore your replacement.
670        This will try to cast them to be the same in some cases. So they won't always match.
671        Best to not mix types so make sure replacement is the same type as the column
672
673        Possibility for improvement: Use `typeof` function to get the type of the column
674        and check if it matches the type of the value provided. If not then make it null.
675        """
676        from sqlglot.dataframe.sql.functions import lit
677
678        values = None
679        columns = None
680        new_df = self.copy()
681        all_columns = self._get_outer_select_columns(new_df.expression)
682        all_column_mapping = {column.alias_or_name: column for column in all_columns}
683        if isinstance(value, dict):
684            values = list(value.values())
685            columns = self._ensure_and_normalize_cols(list(value))
686        if not columns:
687            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
688        if not values:
689            values = [value] * len(columns)
690        value_columns = [lit(value) for value in values]
691
692        null_replacement_mapping = {
693            column.alias_or_name: (
694                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
695            )
696            for column, value in zip(columns, value_columns)
697        }
698        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
699        null_replacement_columns = [
700            null_replacement_mapping[column.alias_or_name] for column in all_columns
701        ]
702        new_df = new_df.select(*null_replacement_columns)
703        return new_df

Functionality Difference: If you provide a replacement value whose type conflicts with the type of the column, PySpark will simply ignore your replacement, whereas this implementation will try to cast the two to a common type in some cases, so the results won't always match. It's best not to mix types: make sure the replacement is the same type as the column.

Possibility for improvement: Use the typeof function to get the type of the column and check whether it matches the type of the provided value. If not, then make it null.
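For example (a sketch; both forms go through the CASE WHEN rewrite above):

    df.fillna(0, subset=["a"])           # replace nulls in "a" with 0
    df.fillna({"a": 0, "b": "unknown"})  # per-column replacement values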

@operation(Operation.FROM)
def replace( self, to_replace: Union[bool, int, float, str, List, Dict], value: Union[bool, int, float, str, List, NoneType] = None, subset: Union[Collection[ColumnOrName], ColumnOrName, NoneType] = None) -> DataFrame:
705    @operation(Operation.FROM)
706    def replace(
707        self,
708        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
709        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
710        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
711    ) -> DataFrame:
712        from sqlglot.dataframe.sql.functions import lit
713
714        old_values = None
715        new_df = self.copy()
716        all_columns = self._get_outer_select_columns(new_df.expression)
717        all_column_mapping = {column.alias_or_name: column for column in all_columns}
718
719        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
720        if isinstance(to_replace, dict):
721            old_values = list(to_replace)
722            new_values = list(to_replace.values())
723        elif not old_values and isinstance(to_replace, list):
724            assert isinstance(value, list), "value must be a list since the replacements are a list"
725            assert len(to_replace) == len(
726                value
727            ), "the replacements and values must be the same length"
728            old_values = to_replace
729            new_values = value
730        else:
731            old_values = [to_replace] * len(columns)
732            new_values = [value] * len(columns)
733        old_values = [lit(value) for value in old_values]
734        new_values = [lit(value) for value in new_values]
735
736        replacement_mapping = {}
737        for column in columns:
738            expression = Column(None)
739            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
740                if i == 0:
741                    expression = F.when(column == old_value, new_value)
742                else:
743                    expression = expression.when(column == old_value, new_value)  # type: ignore
744            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
745                column.expression.alias_or_name
746            )
747
748        replacement_mapping = {**all_column_mapping, **replacement_mapping}
749        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
750        new_df = new_df.select(*replacement_columns)
751        return new_df
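A sketch of the three accepted shapes (scalar, parallel lists, and dict):

    df.replace(0, -1, subset=["a"])  # scalar old -> new
    df.replace([1, 2], [10, 20])     # parallel lists of equal length
    df.replace({1: 10, 2: 20})       # dict of old -> new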
@operation(Operation.SELECT)
def withColumn( self, colName: str, col: Column) -> DataFrame:
753    @operation(Operation.SELECT)
754    def withColumn(self, colName: str, col: Column) -> DataFrame:
755        col = self._ensure_and_normalize_col(col)
756        existing_col_names = self.expression.named_selects
757        existing_col_index = (
758            existing_col_names.index(colName) if colName in existing_col_names else None
759        )
760        if existing_col_index:
761            expression = self.expression.copy()
762            expression.expressions[existing_col_index] = col.expression
763            return self.copy(expression=expression)
764        return self.copy().select(col.alias(colName), append=True)
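For example (a sketch; column names are hypothetical):

    df = df.withColumn("total", F.col("price") * F.col("qty"))  # appends a new column
    df = df.withColumn("qty", F.col("qty") + 1)                 # replaces an existing column in place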
@operation(Operation.SELECT)
def withColumnRenamed(self, existing: str, new: str):
766    @operation(Operation.SELECT)
767    def withColumnRenamed(self, existing: str, new: str):
768        expression = self.expression.copy()
769        existing_columns = [
770            expression
771            for expression in expression.expressions
772            if expression.alias_or_name == existing
773        ]
774        if not existing_columns:
775            raise ValueError("Tried to rename a column that doesn't exist")
776        for existing_column in existing_columns:
777            if isinstance(existing_column, exp.Column):
778                existing_column.replace(exp.alias_(existing_column, new))
779            else:
780                existing_column.set("alias", exp.to_identifier(new))
781        return self.copy(expression=expression)
@operation(Operation.SELECT)
def drop( self, *cols: Union[str, Column]) -> DataFrame:
783    @operation(Operation.SELECT)
784    def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
785        all_columns = self._get_outer_select_columns(self.expression)
786        drop_cols = self._ensure_and_normalize_cols(cols)
787        new_columns = [
788            col
789            for col in all_columns
790            if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
791        ]
792        return self.copy().select(*new_columns, append=False)
@operation(Operation.LIMIT)
def limit(self, num: int) -> DataFrame:
794    @operation(Operation.LIMIT)
795    def limit(self, num: int) -> DataFrame:
796        return self.copy(expression=self.expression.limit(num))
@operation(Operation.NO_OP)
def hint( self, name: str, *parameters: Union[str, int, NoneType]) -> DataFrame:
798    @operation(Operation.NO_OP)
799    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
800        parameter_list = ensure_list(parameters)
801        parameter_columns = (
802            self._ensure_list_of_columns(parameter_list)
803            if parameters
804            else Column.ensure_cols([self.sequence_id])
805        )
806        return self._hint(name, parameter_columns)
@operation(Operation.NO_OP)
def repartition( self, numPartitions: Union[int, ColumnOrName], *cols: ColumnOrName) -> DataFrame:
808    @operation(Operation.NO_OP)
809    def repartition(
810        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
811    ) -> DataFrame:
812        num_partition_cols = self._ensure_list_of_columns(numPartitions)
813        columns = self._ensure_and_normalize_cols(cols)
814        args = num_partition_cols + columns
815        return self._hint("repartition", args)
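A sketch of the hint mechanics: hints accumulate as pending hints on the DataFrame and are only resolved into the generated SQL when sql() is called (table names below are hypothetical):

    small = spark.table("dim_table").hint("broadcast")
    big = spark.table("fact_table")
    joined = big.join(small, on="id")      # the broadcast join hint carries over
    parted = big.repartition(8, "region")  # stored as a repartition hint with its args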
@operation(Operation.NO_OP)
def coalesce(self, numPartitions: int) -> DataFrame:
817    @operation(Operation.NO_OP)
818    def coalesce(self, numPartitions: int) -> DataFrame:
819        num_partitions = Column.ensure_cols([numPartitions])
820        return self._hint("coalesce", num_partitions)
@operation(Operation.NO_OP)
def cache(self) -> DataFrame:
822    @operation(Operation.NO_OP)
823    def cache(self) -> DataFrame:
824        return self._cache(storage_level="MEMORY_AND_DISK")
@operation(Operation.NO_OP)
def persist( self, storageLevel: str = 'MEMORY_AND_DISK_SER') -> DataFrame:
826    @operation(Operation.NO_OP)
827    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
828        """
829        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
830        """
831        return self._cache(storageLevel)
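Per the sql() logic earlier, caching surfaces as extra statements in the output rather than executing anything (a sketch):

    cached = df.persist("MEMORY_ONLY")
    for statement in cached.sql():
        # a DROP VIEW IF EXISTS, a CACHE LAZY TABLE ... OPTIONS('storageLevel' = ...),
        # and finally the SELECT that reads from the cached table
        print(statement)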
class GroupedData:
14class GroupedData:
15    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
16        self._df = df.copy()
17        self.spark = df.spark
18        self.last_op = last_op
19        self.group_by_cols = group_by_cols
20
21    def _get_function_applied_columns(
22        self, func_name: str, cols: t.Tuple[str, ...]
23    ) -> t.List[Column]:
24        func_name = func_name.lower()
25        return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]
26
27    @operation(Operation.SELECT)
28    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
29        columns = (
30            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
31            if isinstance(exprs[0], dict)
32            else exprs
33        )
34        cols = self._df._ensure_and_normalize_cols(columns)
35
36        expression = self._df.expression.group_by(
37            *[x.expression for x in self.group_by_cols]
38        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
39        return self._df.copy(expression=expression)
40
41    def count(self) -> DataFrame:
42        return self.agg(F.count("*").alias("count"))
43
44    def mean(self, *cols: str) -> DataFrame:
45        return self.avg(*cols)
46
47    def avg(self, *cols: str) -> DataFrame:
48        return self.agg(*self._get_function_applied_columns("avg", cols))
49
50    def max(self, *cols: str) -> DataFrame:
51        return self.agg(*self._get_function_applied_columns("max", cols))
52
53    def min(self, *cols: str) -> DataFrame:
54        return self.agg(*self._get_function_applied_columns("min", cols))
55
56    def sum(self, *cols: str) -> DataFrame:
57        return self.agg(*self._get_function_applied_columns("sum", cols))
58
59    def pivot(self, *cols: str) -> DataFrame:
60        raise NotImplementedError("Sum distinct is not currently implemented")
GroupedData( df: DataFrame, group_by_cols: List[Column], last_op: sqlglot.dataframe.sql.operations.Operation)
15    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
16        self._df = df.copy()
17        self.spark = df.spark
18        self.last_op = last_op
19        self.group_by_cols = group_by_cols
spark
last_op
group_by_cols
@operation(Operation.SELECT)
def agg( self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame:
27    @operation(Operation.SELECT)
28    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
29        columns = (
30            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
31            if isinstance(exprs[0], dict)
32            else exprs
33        )
34        cols = self._df._ensure_and_normalize_cols(columns)
35
36        expression = self._df.expression.group_by(
37            *[x.expression for x in self.group_by_cols]
38        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
39        return self._df.copy(expression=expression)
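For example (a sketch of both calling conventions):

    df.groupBy("region").agg(F.count("*").alias("orders"), F.max("amount"))
    df.groupBy("region").agg({"amount": "sum"})  # dict form: {column: agg_func} -> SUM(amount)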
def count(self) -> DataFrame:
41    def count(self) -> DataFrame:
42        return self.agg(F.count("*").alias("count"))
def mean(self, *cols: str) -> DataFrame:
44    def mean(self, *cols: str) -> DataFrame:
45        return self.avg(*cols)
def avg(self, *cols: str) -> DataFrame:
47    def avg(self, *cols: str) -> DataFrame:
48        return self.agg(*self._get_function_applied_columns("avg", cols))
def max(self, *cols: str) -> DataFrame:
50    def max(self, *cols: str) -> DataFrame:
51        return self.agg(*self._get_function_applied_columns("max", cols))
def min(self, *cols: str) -> DataFrame:
53    def min(self, *cols: str) -> DataFrame:
54        return self.agg(*self._get_function_applied_columns("min", cols))
def sum(self, *cols: str) -> DataFrame:
56    def sum(self, *cols: str) -> DataFrame:
57        return self.agg(*self._get_function_applied_columns("sum", cols))
def pivot(self, *cols: str) -> DataFrame:
59    def pivot(self, *cols: str) -> DataFrame:
60        raise NotImplementedError("Sum distinct is not currently implemented")
class Column:
 16class Column:
 17    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
 18        from sqlglot.dataframe.sql.session import SparkSession
 19
 20        if isinstance(expression, Column):
 21            expression = expression.expression  # type: ignore
 22        elif expression is None or not isinstance(expression, (str, exp.Expression)):
 23            expression = self._lit(expression).expression  # type: ignore
 24        elif not isinstance(expression, exp.Column):
 25            expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform(
 26                SparkSession().dialect.normalize_identifier, copy=False
 27            )
 28        if expression is None:
 29            raise ValueError(f"Could not parse {expression}")
 30
 31        self.expression: exp.Expression = expression  # type: ignore
 32
 33    def __repr__(self):
 34        return repr(self.expression)
 35
 36    def __hash__(self):
 37        return hash(self.expression)
 38
 39    def __eq__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
 40        return self.binary_op(exp.EQ, other)
 41
 42    def __ne__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
 43        return self.binary_op(exp.NEQ, other)
 44
 45    def __gt__(self, other: ColumnOrLiteral) -> Column:
 46        return self.binary_op(exp.GT, other)
 47
 48    def __ge__(self, other: ColumnOrLiteral) -> Column:
 49        return self.binary_op(exp.GTE, other)
 50
 51    def __lt__(self, other: ColumnOrLiteral) -> Column:
 52        return self.binary_op(exp.LT, other)
 53
 54    def __le__(self, other: ColumnOrLiteral) -> Column:
 55        return self.binary_op(exp.LTE, other)
 56
 57    def __and__(self, other: ColumnOrLiteral) -> Column:
 58        return self.binary_op(exp.And, other)
 59
 60    def __or__(self, other: ColumnOrLiteral) -> Column:
 61        return self.binary_op(exp.Or, other)
 62
 63    def __mod__(self, other: ColumnOrLiteral) -> Column:
 64        return self.binary_op(exp.Mod, other)
 65
 66    def __add__(self, other: ColumnOrLiteral) -> Column:
 67        return self.binary_op(exp.Add, other)
 68
 69    def __sub__(self, other: ColumnOrLiteral) -> Column:
 70        return self.binary_op(exp.Sub, other)
 71
 72    def __mul__(self, other: ColumnOrLiteral) -> Column:
 73        return self.binary_op(exp.Mul, other)
 74
 75    def __truediv__(self, other: ColumnOrLiteral) -> Column:
 76        return self.binary_op(exp.Div, other)
 77
 78    def __div__(self, other: ColumnOrLiteral) -> Column:
 79        return self.binary_op(exp.Div, other)
 80
 81    def __neg__(self) -> Column:
 82        return self.unary_op(exp.Neg)
 83
 84    def __radd__(self, other: ColumnOrLiteral) -> Column:
 85        return self.inverse_binary_op(exp.Add, other)
 86
 87    def __rsub__(self, other: ColumnOrLiteral) -> Column:
 88        return self.inverse_binary_op(exp.Sub, other)
 89
 90    def __rmul__(self, other: ColumnOrLiteral) -> Column:
 91        return self.inverse_binary_op(exp.Mul, other)
 92
 93    def __rdiv__(self, other: ColumnOrLiteral) -> Column:
 94        return self.inverse_binary_op(exp.Div, other)
 95
 96    def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
 97        return self.inverse_binary_op(exp.Div, other)
 98
 99    def __rmod__(self, other: ColumnOrLiteral) -> Column:
100        return self.inverse_binary_op(exp.Mod, other)
101
102    def __pow__(self, power: ColumnOrLiteral, modulo=None):
103        return Column(exp.Pow(this=self.expression, expression=Column(power).expression))
104
105    def __rpow__(self, power: ColumnOrLiteral):
106        return Column(exp.Pow(this=Column(power).expression, expression=self.expression))
107
108    def __invert__(self):
109        return self.unary_op(exp.Not)
110
111    def __rand__(self, other: ColumnOrLiteral) -> Column:
112        return self.inverse_binary_op(exp.And, other)
113
114    def __ror__(self, other: ColumnOrLiteral) -> Column:
115        return self.inverse_binary_op(exp.Or, other)
116
117    @classmethod
118    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column:
119        return cls(value)
120
121    @classmethod
122    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
123        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]
124
125    @classmethod
126    def _lit(cls, value: ColumnOrLiteral) -> Column:
127        if isinstance(value, dict):
128            columns = [cls._lit(v).alias(k).expression for k, v in value.items()]
129            return cls(exp.Struct(expressions=columns))
130        return cls(exp.convert(value))
131
132    @classmethod
133    def invoke_anonymous_function(
134        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
135    ) -> Column:
136        columns = [] if column is None else [cls.ensure_col(column)]
137        column_args = [cls.ensure_col(arg) for arg in args]
138        expressions = [x.expression for x in columns + column_args]
139        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
140        return Column(new_expression)
141
142    @classmethod
143    def invoke_expression_over_column(
144        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
145    ) -> Column:
146        ensured_column = None if column is None else cls.ensure_col(column)
147        ensure_expression_values = {
148            k: (
149                [Column.ensure_col(x).expression for x in v]
150                if is_iterable(v)
151                else Column.ensure_col(v).expression
152            )
153            for k, v in kwargs.items()
154            if v is not None
155        }
156        new_expression = (
157            callable_expression(**ensure_expression_values)
158            if ensured_column is None
159            else callable_expression(
160                this=ensured_column.column_expression, **ensure_expression_values
161            )
162        )
163        return Column(new_expression)
164
165    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
166        return Column(
167            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
168        )
169
170    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
171        return Column(
172            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
173        )
174
175    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
176        return Column(klass(this=self.column_expression, **kwargs))
177
178    @property
179    def is_alias(self):
180        return isinstance(self.expression, exp.Alias)
181
182    @property
183    def is_column(self):
184        return isinstance(self.expression, exp.Column)
185
186    @property
187    def column_expression(self) -> t.Union[exp.Column, exp.Literal]:
188        return self.expression.unalias()
189
190    @property
191    def alias_or_name(self) -> str:
192        return self.expression.alias_or_name
193
194    @classmethod
195    def ensure_literal(cls, value) -> Column:
196        from sqlglot.dataframe.sql.functions import lit
197
198        if isinstance(value, cls):
199            value = value.expression
200        if not isinstance(value, exp.Literal):
201            return lit(value)
202        return Column(value)
203
204    def copy(self) -> Column:
205        return Column(self.expression.copy())
206
207    def set_table_name(self, table_name: str, copy=False) -> Column:
208        expression = self.expression.copy() if copy else self.expression
209        expression.set("table", exp.to_identifier(table_name))
210        return Column(expression)
211
212    def sql(self, **kwargs) -> str:
213        from sqlglot.dataframe.sql.session import SparkSession
214
215        return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
216
217    def alias(self, name: str) -> Column:
218        from sqlglot.dataframe.sql.session import SparkSession
219
220        dialect = SparkSession().dialect
221        alias: exp.Expression = sqlglot.maybe_parse(name, dialect=dialect)
222        new_expression = exp.alias_(
223            self.column_expression,
224            alias.this if isinstance(alias, exp.Column) else name,
225            dialect=dialect,
226        )
227        return Column(new_expression)
228
229    def asc(self) -> Column:
230        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
231        return Column(new_expression)
232
233    def desc(self) -> Column:
234        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
235        return Column(new_expression)
236
237    asc_nulls_first = asc
238
239    def asc_nulls_last(self) -> Column:
240        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
241        return Column(new_expression)
242
243    def desc_nulls_first(self) -> Column:
244        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
245        return Column(new_expression)
246
247    desc_nulls_last = desc
248
249    def when(self, condition: Column, value: t.Any) -> Column:
250        from sqlglot.dataframe.sql.functions import when
251
252        column_with_if = when(condition, value)
253        if not isinstance(self.expression, exp.Case):
254            return column_with_if
255        new_column = self.copy()
256        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
257        return new_column
258
259    def otherwise(self, value: t.Any) -> Column:
260        from sqlglot.dataframe.sql.functions import lit
261
262        true_value = value if isinstance(value, Column) else lit(value)
263        new_column = self.copy()
264        new_column.expression.set("default", true_value.column_expression)
265        return new_column
266
267    def isNull(self) -> Column:
268        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
269        return Column(new_expression)
270
271    def isNotNull(self) -> Column:
272        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
273        return Column(new_expression)
274
275    def cast(self, dataType: t.Union[str, DataType]) -> Column:
276        """
277        Functionality Difference: PySpark cast accepts a datatype instance of the datatype class
278        Sqlglot doesn't currently replicate this class so it only accepts a string
279        """
280        from sqlglot.dataframe.sql.session import SparkSession
281
282        if isinstance(dataType, DataType):
283            dataType = dataType.simpleString()
284        return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect))
285
286    def startswith(self, value: t.Union[str, Column]) -> Column:
287        value = self._lit(value) if not isinstance(value, Column) else value
288        return self.invoke_anonymous_function(self, "STARTSWITH", value)
289
290    def endswith(self, value: t.Union[str, Column]) -> Column:
291        value = self._lit(value) if not isinstance(value, Column) else value
292        return self.invoke_anonymous_function(self, "ENDSWITH", value)
293
294    def rlike(self, regexp: str) -> Column:
295        return self.invoke_expression_over_column(
296            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
297        )
298
299    def like(self, other: str):
300        return self.invoke_expression_over_column(
301            self, exp.Like, expression=self._lit(other).expression
302        )
303
304    def ilike(self, other: str):
305        return self.invoke_expression_over_column(
306            self, exp.ILike, expression=self._lit(other).expression
307        )
308
309    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
310        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
311        length = self._lit(length) if not isinstance(length, Column) else length
312        return Column.invoke_expression_over_column(
313            self, exp.Substring, start=startPos.expression, length=length.expression
314        )
315
316    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
317        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
318        expressions = [self._lit(x).expression for x in columns]
319        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore
320
321    def between(
322        self,
323        lowerBound: t.Union[ColumnOrLiteral],
324        upperBound: t.Union[ColumnOrLiteral],
325    ) -> Column:
326        lower_bound_exp = (
327            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
328        )
329        upper_bound_exp = (
330            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
331        )
332        return Column(
333            exp.Between(
334                this=self.column_expression,
335                low=lower_bound_exp.expression,
336                high=upper_bound_exp.expression,
337            )
338        )
339
340    def over(self, window: WindowSpec) -> Column:
341        window_expression = window.expression.copy()
342        window_expression.set("this", self.column_expression)
343        return Column(window_expression)
Column( expression: Union[ColumnOrLiteral, sqlglot.expressions.Expression, NoneType])
17    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
18        from sqlglot.dataframe.sql.session import SparkSession
19
20        if isinstance(expression, Column):
21            expression = expression.expression  # type: ignore
22        elif expression is None or not isinstance(expression, (str, exp.Expression)):
23            expression = self._lit(expression).expression  # type: ignore
24        elif not isinstance(expression, exp.Column):
25            expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform(
26                SparkSession().dialect.normalize_identifier, copy=False
27            )
28        if expression is None:
29            raise ValueError(f"Could not parse {expression}")
30
31        self.expression: exp.Expression = expression  # type: ignore
@classmethod
def ensure_col( cls, value: Union[ColumnOrLiteral, sqlglot.expressions.Expression, NoneType]) -> Column:
117    @classmethod
118    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column:
119        return cls(value)
@classmethod
def ensure_cols( cls, args: List[Union[ColumnOrLiteral, sqlglot.expressions.Expression]]) -> List[Column]:
121    @classmethod
122    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
123        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]
@classmethod
def invoke_anonymous_function( cls, column: Optional[ColumnOrLiteral], func_name: str, *args: Optional[ColumnOrLiteral]) -> Column:
132    @classmethod
133    def invoke_anonymous_function(
134        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
135    ) -> Column:
136        columns = [] if column is None else [cls.ensure_col(column)]
137        column_args = [cls.ensure_col(arg) for arg in args]
138        expressions = [x.expression for x in columns + column_args]
139        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
140        return Column(new_expression)
@classmethod
def invoke_expression_over_column( cls, column: Optional[ColumnOrLiteral], callable_expression: Callable, **kwargs) -> Column:
142    @classmethod
143    def invoke_expression_over_column(
144        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
145    ) -> Column:
146        ensured_column = None if column is None else cls.ensure_col(column)
147        ensure_expression_values = {
148            k: (
149                [Column.ensure_col(x).expression for x in v]
150                if is_iterable(v)
151                else Column.ensure_col(v).expression
152            )
153            for k, v in kwargs.items()
154            if v is not None
155        }
156        new_expression = (
157            callable_expression(**ensure_expression_values)
158            if ensured_column is None
159            else callable_expression(
160                this=ensured_column.column_expression, **ensure_expression_values
161            )
162        )
163        return Column(new_expression)
def binary_op( self, klass: Callable, other: ColumnOrLiteral, **kwargs) -> Column:
165    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
166        return Column(
167            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
168        )
def inverse_binary_op( self, klass: Callable, other: ColumnOrLiteral, **kwargs) -> Column:
170    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
171        return Column(
172            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
173        )
def unary_op(self, klass: Callable, **kwargs) -> Column:
175    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
176        return Column(klass(this=self.column_expression, **kwargs))
is_alias
178    @property
179    def is_alias(self):
180        return isinstance(self.expression, exp.Alias)
is_column
182    @property
183    def is_column(self):
184        return isinstance(self.expression, exp.Column)
column_expression: Union[sqlglot.expressions.Column, sqlglot.expressions.Literal]
186    @property
187    def column_expression(self) -> t.Union[exp.Column, exp.Literal]:
188        return self.expression.unalias()
alias_or_name: str
190    @property
191    def alias_or_name(self) -> str:
192        return self.expression.alias_or_name
@classmethod
def ensure_literal(cls, value) -> Column:
194    @classmethod
195    def ensure_literal(cls, value) -> Column:
196        from sqlglot.dataframe.sql.functions import lit
197
198        if isinstance(value, cls):
199            value = value.expression
200        if not isinstance(value, exp.Literal):
201            return lit(value)
202        return Column(value)
def copy(self) -> Column:
204    def copy(self) -> Column:
205        return Column(self.expression.copy())
def set_table_name(self, table_name: str, copy=False) -> Column:
207    def set_table_name(self, table_name: str, copy=False) -> Column:
208        expression = self.expression.copy() if copy else self.expression
209        expression.set("table", exp.to_identifier(table_name))
210        return Column(expression)
def sql(self, **kwargs) -> str:
212    def sql(self, **kwargs) -> str:
213        from sqlglot.dataframe.sql.session import SparkSession
214
215        return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
def alias(self, name: str) -> Column:
217    def alias(self, name: str) -> Column:
218        from sqlglot.dataframe.sql.session import SparkSession
219
220        dialect = SparkSession().dialect
221        alias: exp.Expression = sqlglot.maybe_parse(name, dialect=dialect)
222        new_expression = exp.alias_(
223            self.column_expression,
224            alias.this if isinstance(alias, exp.Column) else name,
225            dialect=dialect,
226        )
227        return Column(new_expression)
def asc(self) -> Column:
229    def asc(self) -> Column:
230        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
231        return Column(new_expression)
def desc(self) -> Column:
233    def desc(self) -> Column:
234        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
235        return Column(new_expression)
def asc_nulls_first(self) -> Column:
229    def asc(self) -> Column:
230        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
231        return Column(new_expression)
def asc_nulls_last(self) -> Column:
239    def asc_nulls_last(self) -> Column:
240        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
241        return Column(new_expression)
def desc_nulls_first(self) -> Column:
243    def desc_nulls_first(self) -> Column:
244        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
245        return Column(new_expression)
def desc_nulls_last(self) -> Column:
233    def desc(self) -> Column:
234        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
235        return Column(new_expression)
def when( self, condition: Column, value: Any) -> Column:
249    def when(self, condition: Column, value: t.Any) -> Column:
250        from sqlglot.dataframe.sql.functions import when
251
252        column_with_if = when(condition, value)
253        if not isinstance(self.expression, exp.Case):
254            return column_with_if
255        new_column = self.copy()
256        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
257        return new_column
def otherwise(self, value: Any) -> Column:
259    def otherwise(self, value: t.Any) -> Column:
260        from sqlglot.dataframe.sql.functions import lit
261
262        true_value = value if isinstance(value, Column) else lit(value)
263        new_column = self.copy()
264        new_column.expression.set("default", true_value.column_expression)
265        return new_column
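A sketch of chaining: each additional when() extends the same CASE expression's ifs, and otherwise() sets its default (column names are hypothetical):

    from sqlglot.dataframe.sql import functions as F

    bucket = (
        F.when(F.col("age") < 13, "child")
        .when(F.col("age") < 20, "teen")
        .otherwise("adult")
        .alias("bucket")
    )
    df.select(F.col("age"), bucket)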
def isNull(self) -> Column:
267    def isNull(self) -> Column:
268        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
269        return Column(new_expression)
def isNotNull(self) -> Column:
271    def isNotNull(self) -> Column:
272        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
273        return Column(new_expression)
def cast( self, dataType: Union[str, sqlglot.dataframe.sql.types.DataType]) -> Column:
275    def cast(self, dataType: t.Union[str, DataType]) -> Column:
276        """
277        Functionality Difference: PySpark cast accepts a datatype instance of the datatype class
278        Sqlglot doesn't currently replicate this class so it only accepts a string
279        """
280        from sqlglot.dataframe.sql.session import SparkSession
281
282        if isinstance(dataType, DataType):
283            dataType = dataType.simpleString()
284        return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect))

Functionality Difference: PySpark's cast accepts an instance of PySpark's DataType class. Sqlglot doesn't replicate that class, so this cast accepts a string instead (an instance of sqlglot's own types.DataType is also accepted and converted via simpleString()).
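For example (a sketch):

    df.select(F.col("id").cast("bigint"), F.col("price").cast("decimal(10, 2)"))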

def startswith(self, value: Union[str, Column]) -> Column:
286    def startswith(self, value: t.Union[str, Column]) -> Column:
287        value = self._lit(value) if not isinstance(value, Column) else value
288        return self.invoke_anonymous_function(self, "STARTSWITH", value)
def endswith(self, value: Union[str, Column]) -> Column:
290    def endswith(self, value: t.Union[str, Column]) -> Column:
291        value = self._lit(value) if not isinstance(value, Column) else value
292        return self.invoke_anonymous_function(self, "ENDSWITH", value)
def rlike(self, regexp: str) -> Column:
294    def rlike(self, regexp: str) -> Column:
295        return self.invoke_expression_over_column(
296            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
297        )
def like(self, other: str):
299    def like(self, other: str):
300        return self.invoke_expression_over_column(
301            self, exp.Like, expression=self._lit(other).expression
302        )
def ilike(self, other: str):
304    def ilike(self, other: str):
305        return self.invoke_expression_over_column(
306            self, exp.ILike, expression=self._lit(other).expression
307        )
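
startswith and endswith are emitted as anonymous STARTSWITH/ENDSWITH function calls, while rlike, like and ilike map to sqlglot's RegexpLike, Like and ILike expressions. A sketch over an illustrative table:

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession
    from sqlglot.dataframe.sql import functions as F

    sqlglot.schema.add_table("employee", {"fname": "STRING"}, dialect="spark")

    df = SparkSession().table("employee").where(
        F.col("fname").startswith("A") & F.col("fname").rlike("^[A-Za-z]+$")
    )
    print(df.sql(dialect="spark")[0])  # e.g. ... STARTSWITH(fname, 'A') AND fname RLIKE '^[A-Za-z]+$'
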
def substr(self, startPos: Union[int, Column], length: Union[int, Column]) -> Column:
309    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
310        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
311        length = self._lit(length) if not isinstance(length, Column) else length
312        return Column.invoke_expression_over_column(
313            self, exp.Substring, start=startPos.expression, length=length.expression
314        )
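
substr builds an exp.Substring; startPos and length may be plain ints (wrapped as literals) or Columns. A sketch:

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession
    from sqlglot.dataframe.sql import functions as F

    sqlglot.schema.add_table("employee", {"fname": "STRING"}, dialect="spark")

    df = SparkSession().table("employee").select(F.col("fname").substr(1, 1).alias("initial"))
    print(df.sql(dialect="spark")[0])  # e.g. SELECT SUBSTRING(fname, 1, 1) AS initial ...
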
def isin(self, *cols: Union[ColumnOrLiteral, Iterable[ColumnOrLiteral]]):
316    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
317        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
318        expressions = [self._lit(x).expression for x in columns]
319        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore
def between(self, lowerBound: ColumnOrLiteral, upperBound: ColumnOrLiteral) -> Column:
321    def between(
322        self,
323        lowerBound: t.Union[ColumnOrLiteral],
324        upperBound: t.Union[ColumnOrLiteral],
325    ) -> Column:
326        lower_bound_exp = (
327            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
328        )
329        upper_bound_exp = (
330            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
331        )
332        return Column(
333            exp.Between(
334                this=self.column_expression,
335                low=lower_bound_exp.expression,
336                high=upper_bound_exp.expression,
337            )
338        )
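
isin flattens a single iterable argument into an IN list, and between builds an exp.Between, wrapping plain literals as needed. A sketch covering both, over an illustrative table:

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession
    from sqlglot.dataframe.sql import functions as F

    sqlglot.schema.add_table("employee", {"age": "INT"}, dialect="spark")

    df = SparkSession().table("employee").where(
        F.col("age").isin(25, 30, 35) | F.col("age").between(40, 50)
    )
    # isin also accepts a single iterable: F.col("age").isin([25, 30, 35])
    print(df.sql(dialect="spark")[0])  # e.g. ... age IN (25, 30, 35) OR age BETWEEN 40 AND 50
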
def over(self, window: WindowSpec) -> Column:
340    def over(self, window: WindowSpec) -> Column:
341        window_expression = window.expression.copy()
342        window_expression.set("this", self.column_expression)
343        return Column(window_expression)
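
over copies the window's expression and installs this column as its this, producing a windowed column. A sketch using the Window builder documented below; the table is illustrative:

    import sqlglot
    from sqlglot.dataframe.sql import Window
    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("employee", {"fname": "STRING", "age": "INT"}, dialect="spark")

    total = F.sum(F.col("age")).over(Window.partitionBy("fname")).alias("total_age")
    df = SparkSession().table("employee").select(F.col("fname"), total)
    print(df.sql(dialect="spark")[0])  # e.g. ... SUM(age) OVER (PARTITION BY fname) AS total_age ...
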
class DataFrameNaFunctions:
834class DataFrameNaFunctions:
835    def __init__(self, df: DataFrame):
836        self.df = df
837
838    def drop(
839        self,
840        how: str = "any",
841        thresh: t.Optional[int] = None,
842        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
843    ) -> DataFrame:
844        return self.df.dropna(how=how, thresh=thresh, subset=subset)
845
846    def fill(
847        self,
848        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
849        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
850    ) -> DataFrame:
851        return self.df.fillna(value=value, subset=subset)
852
853    def replace(
854        self,
855        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
856        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
857        subset: t.Optional[t.Union[str, t.List[str]]] = None,
858    ) -> DataFrame:
859        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
DataFrameNaFunctions(df: DataFrame)
835    def __init__(self, df: DataFrame):
836        self.df = df
df
def drop(self, how: str = 'any', thresh: Optional[int] = None, subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> DataFrame:
838    def drop(
839        self,
840        how: str = "any",
841        thresh: t.Optional[int] = None,
842        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
843    ) -> DataFrame:
844        return self.df.dropna(how=how, thresh=thresh, subset=subset)
def fill(self, value: Union[int, bool, float, str, Dict[str, Any]], subset: Union[str, Tuple[str, ...], List[str], NoneType] = None) -> DataFrame:
846    def fill(
847        self,
848        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
849        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
850    ) -> DataFrame:
851        return self.df.fillna(value=value, subset=subset)
def replace(self, to_replace: Union[bool, int, float, str, List, Dict], value: Union[bool, int, float, str, List, NoneType] = None, subset: Union[str, List[str], NoneType] = None) -> DataFrame:
853    def replace(
854        self,
855        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
856        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
857        subset: t.Optional[t.Union[str, t.List[str]]] = None,
858    ) -> DataFrame:
859        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
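
drop, fill and replace are thin wrappers over DataFrame.dropna, DataFrame.fillna and DataFrame.replace. A sketch, assuming (as in PySpark) that this class is reached through the DataFrame's na attribute; the table is illustrative:

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("employee", {"fname": "STRING", "age": "INT"}, dialect="spark")

    df = SparkSession().table("employee")
    cleaned = df.na.drop(how="any", subset=["fname"]).na.fill(0, subset=["age"])
    print(cleaned.sql(dialect="spark")[0])
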
class Window:
15class Window:
16    _JAVA_MIN_LONG = -(1 << 63)  # -9223372036854775808
17    _JAVA_MAX_LONG = (1 << 63) - 1  # 9223372036854775807
18    _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
19    _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)
20
21    unboundedPreceding: int = _JAVA_MIN_LONG
22
23    unboundedFollowing: int = _JAVA_MAX_LONG
24
25    currentRow: int = 0
26
27    @classmethod
28    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
29        return WindowSpec().partitionBy(*cols)
30
31    @classmethod
32    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
33        return WindowSpec().orderBy(*cols)
34
35    @classmethod
36    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
37        return WindowSpec().rowsBetween(start, end)
38
39    @classmethod
40    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
41        return WindowSpec().rangeBetween(start, end)
unboundedPreceding: int = -9223372036854775808
unboundedFollowing: int = 9223372036854775807
currentRow: int = 0
@classmethod
def partitionBy(cls, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> WindowSpec:
27    @classmethod
28    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
29        return WindowSpec().partitionBy(*cols)
@classmethod
def orderBy(cls, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> WindowSpec:
31    @classmethod
32    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
33        return WindowSpec().orderBy(*cols)
@classmethod
def rowsBetween(cls, start: int, end: int) -> WindowSpec:
35    @classmethod
36    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
37        return WindowSpec().rowsBetween(start, end)
@classmethod
def rangeBetween(cls, start: int, end: int) -> WindowSpec:
39    @classmethod
40    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
41        return WindowSpec().rangeBetween(start, end)
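
unboundedPreceding and unboundedFollowing are Java long sentinels; _calc_start_end turns any start at or below the former (or end at or above the latter) into an UNBOUNDED frame bound. A running-total sketch over an illustrative table:

    import sqlglot
    from sqlglot.dataframe.sql import Window
    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("employee", {"fname": "STRING", "age": "INT"}, dialect="spark")

    running = Window.orderBy("age").rowsBetween(Window.unboundedPreceding, Window.currentRow)
    df = SparkSession().table("employee").select(
        F.col("fname"), F.sum(F.col("age")).over(running).alias("running_total")
    )
    print(df.sql(dialect="spark")[0])
    # e.g. ... SUM(age) OVER (ORDER BY age ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) ...
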
class WindowSpec:
 44class WindowSpec:
 45    def __init__(self, expression: exp.Expression = exp.Window()):
 46        self.expression = expression
 47
 48    def copy(self):
 49        return WindowSpec(self.expression.copy())
 50
 51    def sql(self, **kwargs) -> str:
 52        from sqlglot.dataframe.sql.session import SparkSession
 53
 54        return self.expression.sql(dialect=SparkSession().dialect, **kwargs)
 55
 56    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
 57        from sqlglot.dataframe.sql.column import Column
 58
 59        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
 60        expressions = [Column.ensure_col(x).expression for x in cols]
 61        window_spec = self.copy()
 62        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
 63        partition_by_expressions.extend(expressions)
 64        window_spec.expression.set("partition_by", partition_by_expressions)
 65        return window_spec
 66
 67    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
 68        from sqlglot.dataframe.sql.column import Column
 69
 70        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
 71        expressions = [Column.ensure_col(x).expression for x in cols]
 72        window_spec = self.copy()
 73        if window_spec.expression.args.get("order") is None:
 74            window_spec.expression.set("order", exp.Order(expressions=[]))
 75        order_by = window_spec.expression.args["order"].expressions
 76        order_by.extend(expressions)
 77        window_spec.expression.args["order"].set("expressions", order_by)
 78        return window_spec
 79
 80    def _calc_start_end(
 81        self, start: int, end: int
 82    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
 83        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
 84            "start_side": None,
 85            "end_side": None,
 86        }
 87        if start == Window.currentRow:
 88            kwargs["start"] = "CURRENT ROW"
 89        else:
 90            kwargs = {
 91                **kwargs,
 92                **{
 93                    "start_side": "PRECEDING",
 94                    "start": (
 95                        "UNBOUNDED"
 96                        if start <= Window.unboundedPreceding
 97                        else F.lit(start).expression
 98                    ),
 99                },
100            }
101        if end == Window.currentRow:
102            kwargs["end"] = "CURRENT ROW"
103        else:
104            kwargs = {
105                **kwargs,
106                **{
107                    "end_side": "FOLLOWING",
108                    "end": (
109                        "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression
110                    ),
111                },
112            }
113        return kwargs
114
115    def rowsBetween(self, start: int, end: int) -> WindowSpec:
116        window_spec = self.copy()
117        spec = self._calc_start_end(start, end)
118        spec["kind"] = "ROWS"
119        window_spec.expression.set(
120            "spec",
121            exp.WindowSpec(
122                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
123            ),
124        )
125        return window_spec
126
127    def rangeBetween(self, start: int, end: int) -> WindowSpec:
128        window_spec = self.copy()
129        spec = self._calc_start_end(start, end)
130        spec["kind"] = "RANGE"
131        window_spec.expression.set(
132            "spec",
133            exp.WindowSpec(
134                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
135            ),
136        )
137        return window_spec
WindowSpec(expression: sqlglot.expressions.Expression = Window())
45    def __init__(self, expression: exp.Expression = exp.Window()):
46        self.expression = expression
expression
def copy(self):
48    def copy(self):
49        return WindowSpec(self.expression.copy())
def sql(self, **kwargs) -> str:
51    def sql(self, **kwargs) -> str:
52        from sqlglot.dataframe.sql.session import SparkSession
53
54        return self.expression.sql(dialect=SparkSession().dialect, **kwargs)
def partitionBy(self, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> WindowSpec:
56    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
57        from sqlglot.dataframe.sql.column import Column
58
59        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
60        expressions = [Column.ensure_col(x).expression for x in cols]
61        window_spec = self.copy()
62        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
63        partition_by_expressions.extend(expressions)
64        window_spec.expression.set("partition_by", partition_by_expressions)
65        return window_spec
def orderBy(self, *cols: Union[ColumnOrName, List[ColumnOrName]]) -> WindowSpec:
67    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
68        from sqlglot.dataframe.sql.column import Column
69
70        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
71        expressions = [Column.ensure_col(x).expression for x in cols]
72        window_spec = self.copy()
73        if window_spec.expression.args.get("order") is None:
74            window_spec.expression.set("order", exp.Order(expressions=[]))
75        order_by = window_spec.expression.args["order"].expressions
76        order_by.extend(expressions)
77        window_spec.expression.args["order"].set("expressions", order_by)
78        return window_spec
def rowsBetween(self, start: int, end: int) -> WindowSpec:
115    def rowsBetween(self, start: int, end: int) -> WindowSpec:
116        window_spec = self.copy()
117        spec = self._calc_start_end(start, end)
118        spec["kind"] = "ROWS"
119        window_spec.expression.set(
120            "spec",
121            exp.WindowSpec(
122                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
123            ),
124        )
125        return window_spec
def rangeBetween(self, start: int, end: int) -> WindowSpec:
127    def rangeBetween(self, start: int, end: int) -> WindowSpec:
128        window_spec = self.copy()
129        spec = self._calc_start_end(start, end)
130        spec["kind"] = "RANGE"
131        window_spec.expression.set(
132            "spec",
133            exp.WindowSpec(
134                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
135            ),
136        )
137        return window_spec
class DataFrameReader:
15class DataFrameReader:
16    def __init__(self, spark: SparkSession):
17        self.spark = spark
18
19    def table(self, tableName: str) -> DataFrame:
20        from sqlglot.dataframe.sql.dataframe import DataFrame
21        from sqlglot.dataframe.sql.session import SparkSession
22
23        sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect)
24
25        return DataFrame(
26            self.spark,
27            exp.Select()
28            .from_(
29                exp.to_table(tableName, dialect=SparkSession().dialect).transform(
30                    SparkSession().dialect.normalize_identifier
31                )
32            )
33            .select(
34                *(
35                    column
36                    for column in sqlglot.schema.column_names(
37                        tableName, dialect=SparkSession().dialect
38                    )
39                )
40            ),
41        )
DataFrameReader(spark: SparkSession)
16    def __init__(self, spark: SparkSession):
17        self.spark = spark
spark
def table(self, tableName: str) -> DataFrame:
19    def table(self, tableName: str) -> DataFrame:
20        from sqlglot.dataframe.sql.dataframe import DataFrame
21        from sqlglot.dataframe.sql.session import SparkSession
22
23        sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect)
24
25        return DataFrame(
26            self.spark,
27            exp.Select()
28            .from_(
29                exp.to_table(tableName, dialect=SparkSession().dialect).transform(
30                    SparkSession().dialect.normalize_identifier
31                )
32            )
33            .select(
34                *(
35                    column
36                    for column in sqlglot.schema.column_names(
37                        tableName, dialect=SparkSession().dialect
38                    )
39                )
40            ),
41        )
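
Because table selects whatever column names are already registered in sqlglot's global schema, registering the table's types first yields an explicit projection. A sketch; the employee table is illustrative:

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("employee", {"fname": "STRING", "age": "INT"}, dialect="spark")

    df = SparkSession().read.table("employee")  # equivalent to SparkSession().table("employee")
    print(df.sql(dialect="spark")[0])           # e.g. SELECT fname, age FROM employee
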
class DataFrameWriter:
 44class DataFrameWriter:
 45    def __init__(
 46        self,
 47        df: DataFrame,
 48        spark: t.Optional[SparkSession] = None,
 49        mode: t.Optional[str] = None,
 50        by_name: bool = False,
 51    ):
 52        self._df = df
 53        self._spark = spark or df.spark
 54        self._mode = mode
 55        self._by_name = by_name
 56
 57    def copy(self, **kwargs) -> DataFrameWriter:
 58        return DataFrameWriter(
 59            **{
 60                k[1:] if k.startswith("_") else k: v
 61                for k, v in object_to_dict(self, **kwargs).items()
 62            }
 63        )
 64
 65    def sql(self, **kwargs) -> t.List[str]:
 66        return self._df.sql(**kwargs)
 67
 68    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
 69        return self.copy(_mode=saveMode)
 70
 71    @property
 72    def byName(self):
 73        return self.copy(by_name=True)
 74
 75    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
 76        from sqlglot.dataframe.sql.session import SparkSession
 77
 78        output_expression_container = exp.Insert(
 79            **{
 80                "this": exp.to_table(tableName),
 81                "overwrite": overwrite,
 82            }
 83        )
 84        df = self._df.copy(output_expression_container=output_expression_container)
 85        if self._by_name:
 86            columns = sqlglot.schema.column_names(
 87                tableName, only_visible=True, dialect=SparkSession().dialect
 88            )
 89            df = df._convert_leaf_to_cte().select(*columns)
 90
 91        return self.copy(_df=df)
 92
 93    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
 94        if format is not None:
 95            raise NotImplementedError("Providing Format in the save as table is not supported")
 96        exists, replace, mode = None, None, mode or str(self._mode)
 97        if mode == "append":
 98            return self.insertInto(name)
 99        if mode == "ignore":
100            exists = True
101        if mode == "overwrite":
102            replace = True
103        output_expression_container = exp.Create(
104            this=exp.to_table(name),
105            kind="TABLE",
106            exists=exists,
107            replace=replace,
108        )
109        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
DataFrameWriter(df: DataFrame, spark: Optional[SparkSession] = None, mode: Optional[str] = None, by_name: bool = False)
45    def __init__(
46        self,
47        df: DataFrame,
48        spark: t.Optional[SparkSession] = None,
49        mode: t.Optional[str] = None,
50        by_name: bool = False,
51    ):
52        self._df = df
53        self._spark = spark or df.spark
54        self._mode = mode
55        self._by_name = by_name
def copy(self, **kwargs) -> DataFrameWriter:
57    def copy(self, **kwargs) -> DataFrameWriter:
58        return DataFrameWriter(
59            **{
60                k[1:] if k.startswith("_") else k: v
61                for k, v in object_to_dict(self, **kwargs).items()
62            }
63        )
def sql(self, **kwargs) -> List[str]:
65    def sql(self, **kwargs) -> t.List[str]:
66        return self._df.sql(**kwargs)
def mode(self, saveMode: Optional[str]) -> DataFrameWriter:
68    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
69        return self.copy(_mode=saveMode)
byName
71    @property
72    def byName(self):
73        return self.copy(by_name=True)
def insertInto(self, tableName: str, overwrite: Optional[bool] = None) -> DataFrameWriter:
75    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
76        from sqlglot.dataframe.sql.session import SparkSession
77
78        output_expression_container = exp.Insert(
79            **{
80                "this": exp.to_table(tableName),
81                "overwrite": overwrite,
82            }
83        )
84        df = self._df.copy(output_expression_container=output_expression_container)
85        if self._by_name:
86            columns = sqlglot.schema.column_names(
87                tableName, only_visible=True, dialect=SparkSession().dialect
88            )
89            df = df._convert_leaf_to_cte().select(*columns)
90
91        return self.copy(_df=df)
def saveAsTable(self, name: str, format: Optional[str] = None, mode: Optional[str] = None):
 93    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
 94        if format is not None:
 95            raise NotImplementedError("Providing Format in the save as table is not supported")
 96        exists, replace, mode = None, None, mode or str(self._mode)
 97        if mode == "append":
 98            return self.insertInto(name)
 99        if mode == "ignore":
100            exists = True
101        if mode == "overwrite":
102            replace = True
103        output_expression_container = exp.Create(
104            this=exp.to_table(name),
105            kind="TABLE",
106            exists=exists,
107            replace=replace,
108        )
109        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
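
Taken together, mode and saveAsTable map the familiar PySpark save modes onto plain SQL: "append" delegates to insertInto (an INSERT), "ignore" becomes CREATE TABLE IF NOT EXISTS, and "overwrite" becomes CREATE OR REPLACE TABLE. A sketch, assuming (as in PySpark) that the writer is reached through the DataFrame's write attribute; the table names are illustrative:

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession
    from sqlglot.dataframe.sql import functions as F

    sqlglot.schema.add_table("employee", {"fname": "STRING", "age": "INT"}, dialect="spark")

    df = SparkSession().table("employee").where(F.col("age") > 30)
    writer = df.write.mode("overwrite").saveAsTable("employee_over_30")
    print(writer.sql(dialect="spark")[0])  # e.g. CREATE OR REPLACE TABLE employee_over_30 AS SELECT ...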