Edit on GitHub

sqlglot.schema

  1from __future__ import annotations
  2
  3import abc
  4import typing as t
  5
  6import sqlglot
  7from sqlglot import expressions as exp
  8from sqlglot._typing import T
  9from sqlglot.dialects.dialect import Dialect
 10from sqlglot.errors import ParseError, SchemaError
 11from sqlglot.helper import dict_depth
 12from sqlglot.trie import TrieResult, in_trie, new_trie
 13
 14if t.TYPE_CHECKING:
 15    from sqlglot.dataframe.sql.types import StructType
 16    from sqlglot.dialects.dialect import DialectType
 17
 18    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
 19
# Table-expression arg names, ordered from most specific (the table itself)
# to least specific; also the canonical order of parts in schema lookups.
TABLE_ARGS = ("this", "db", "catalog")
 21
 22
 23class Schema(abc.ABC):
 24    """Abstract base class for database schemas"""
 25
 26    dialect: DialectType
 27
 28    @abc.abstractmethod
 29    def add_table(
 30        self,
 31        table: exp.Table | str,
 32        column_mapping: t.Optional[ColumnMapping] = None,
 33        dialect: DialectType = None,
 34        normalize: t.Optional[bool] = None,
 35    ) -> None:
 36        """
 37        Register or update a table. Some implementing classes may require column information to also be provided.
 38
 39        Args:
 40            table: the `Table` expression instance or string representing the table.
 41            column_mapping: a column mapping that describes the structure of the table.
 42            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 43            normalize: whether to normalize identifiers according to the dialect of interest.
 44        """
 45
 46    @abc.abstractmethod
 47    def column_names(
 48        self,
 49        table: exp.Table | str,
 50        only_visible: bool = False,
 51        dialect: DialectType = None,
 52        normalize: t.Optional[bool] = None,
 53    ) -> t.List[str]:
 54        """
 55        Get the column names for a table.
 56
 57        Args:
 58            table: the `Table` expression instance.
 59            only_visible: whether to include invisible columns.
 60            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 61            normalize: whether to normalize identifiers according to the dialect of interest.
 62
 63        Returns:
 64            The list of column names.
 65        """
 66
 67    @abc.abstractmethod
 68    def get_column_type(
 69        self,
 70        table: exp.Table | str,
 71        column: exp.Column,
 72        dialect: DialectType = None,
 73        normalize: t.Optional[bool] = None,
 74    ) -> exp.DataType:
 75        """
 76        Get the `sqlglot.exp.DataType` type of a column in the schema.
 77
 78        Args:
 79            table: the source table.
 80            column: the target column.
 81            dialect: the SQL dialect that will be used to parse `table` if it's a string.
 82            normalize: whether to normalize identifiers according to the dialect of interest.
 83
 84        Returns:
 85            The resulting column type.
 86        """
 87
 88    @property
 89    @abc.abstractmethod
 90    def supported_table_args(self) -> t.Tuple[str, ...]:
 91        """
 92        Table arguments this schema support, e.g. `("this", "db", "catalog")`
 93        """
 94
 95    @property
 96    def empty(self) -> bool:
 97        """Returns whether or not the schema is empty."""
 98        return True
 99
100
class AbstractMappingSchema(t.Generic[T]):
    """Base class for schemas backed by a nested mapping, indexed by a trie of table parts."""

    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
    ) -> None:
        self.mapping = mapping or {}
        # Trie keys are table parts ordered (table, db, catalog), i.e. each
        # flattened (catalog, db, table) path reversed.
        self.mapping_trie = new_trie(
            tuple(reversed(path)) for path in flatten_schema(self.mapping, depth=self._depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        return not self.mapping

    def _depth(self) -> int:
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        # Lazily derived from the mapping's nesting depth and cached.
        if not self._supported_table_args and self.mapping:
            depth = self._depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = TABLE_ARGS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the non-empty (table, db, catalog) parts of `table`."""
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in TABLE_ARGS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[T]:
        """Look `table` up in the trie index and return its schema entry, if any."""
        parts = self.table_parts(table)[: len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if value == TrieResult.FAILED:
            return None

        if value == TrieResult.PREFIX:
            # A partial match: resolve it only if exactly one completion exists.
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)

            if len(possibilities) != 1:
                if raise_on_missing:
                    message = ", ".join(".".join(candidate) for candidate in possibilities)
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None

            parts.extend(possibilities[0])

        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """Fetch the entry for `parts` from `d` (defaults to this schema's mapping)."""
        source = d or self.mapping
        path = zip(self.supported_table_args, reversed(parts))
        return nested_get(source, *path, raise_on_missing=raise_on_missing)
169
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        # Cache of raw type string -> parsed DataType, so each distinct type
        # string is only parsed once (see _to_data_type).
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}

        # Normalize all identifiers up-front so later lookups are consistent.
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new `MappingSchema` from an existing instance's attributes."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs: t.Any) -> MappingSchema:
        """Return a copy of this schema; `kwargs` override the copied attributes."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                "normalize": self.normalize,
                **kwargs,
            }
        )

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_mapping = {
            self._normalize_name(key, dialect=dialect, normalize=normalize): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        # Don't clobber an existing entry unless new column info was supplied.
        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # The mapping is keyed catalog->db->table (hence reversed), while the
        # trie is keyed in table_parts order; new_trie mutates the trie in place.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance.
            only_visible: whether to restrict the returned columns to visible ones only.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The list of column names.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        # Without visibility info, every column is considered visible.
        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
        return [col for col in schema if col in visible]

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """
        Get the `sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.

        Returns:
            The resulting column type; an "unknown" type if it can't be resolved.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        # `column` may also be a plain string despite the annotation.
        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type.upper(), dialect=dialect)

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        # depth - 1 because the innermost {col: type} level is a payload,
        # not part of the (catalog, db, table) key path.
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for keys in flattened_schema:
            columns = nested_get(schema, *zip(keys, keys))
            assert columns is not None

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if it's a string and normalize its name, db and catalog identifiers."""
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(
                    arg,
                    exp.to_identifier(
                        self._normalize_name(
                            value, dialect=dialect, is_table=True, normalize=normalize
                        )
                    ),
                )

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Return `name` as a plain string, normalized per `dialect` when `normalize` is set."""
        dialect = dialect or self.dialect
        normalize = self.normalize if normalize is None else normalize

        try:
            identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
        except ParseError:
            # Unparseable names are returned as-is rather than failing the lookup.
            return name if isinstance(name, str) else name.name

        name = identifier.name
        if not normalize:
            return name

        # This can be useful for normalize_identifier
        identifier.meta["is_table"] = is_table
        return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            dialect = dialect or self.dialect

            try:
                expression = exp.DataType.build(schema_type, dialect=dialect)
                self._type_mapping_cache[schema_type] = expression
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

        return self._type_mapping_cache[schema_type]
400
401
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Coerce the input into a `Schema`.

    A `Schema` instance is returned as-is; anything else (a nested dict or
    None) is wrapped in a new `MappingSchema` constructed with `kwargs`.
    """
    return schema if isinstance(schema, Schema) else MappingSchema(schema, **kwargs)
407
408
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce the supported column mapping inputs into a plain dict.

    Args:
        mapping: one of None (empty mapping), a dict (returned as-is), a
            "name: type, ..." string, a DataFrame StructType-like object,
            or a list of column names.

    Returns:
        A dict of column name to column type (the type is None for list inputs).

    Raises:
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        # Split each entry on the *first* ":" only, so types that themselves
        # contain colons (e.g. "map<int:string>") aren't truncated. Entries
        # without a ":" map to an empty type instead of raising IndexError.
        column_mapping = {}
        for name_type_str in mapping.split(","):
            name, _, col_type = name_type_str.partition(":")
            column_mapping[name.strip()] = col_type.strip()
        return column_mapping
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
427
428
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Flatten a nested schema dict into a list of key paths, each `depth` levels long."""
    prefix = keys or []
    paths: t.List[t.List[str]] = []

    for name, nested in schema.items():
        if depth >= 2:
            # Recurse one level down, carrying the accumulated path prefix.
            paths.extend(flatten_schema(nested, depth - 1, prefix + [name]))
        elif depth == 1:
            paths.append(prefix + [name])

    return paths
442
443
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    node: t.Optional[t.Any] = d
    for name, key in path:
        node = node.get(key)  # type: ignore
        if node is None:
            if not raise_on_missing:
                return None
            # "this" is the table itself, so report it as "table" to the user.
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return node
468
469
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    # Walk down to the parent of the final key, creating levels as needed.
    node = d
    for key in keys[:-1]:
        node = node.setdefault(key, {})

    node[keys[-1]] = value
    return d
TABLE_ARGS = ('this', 'db', 'catalog')
class Schema(abc.ABC):
24class Schema(abc.ABC):
25    """Abstract base class for database schemas"""
26
27    dialect: DialectType
28
29    @abc.abstractmethod
30    def add_table(
31        self,
32        table: exp.Table | str,
33        column_mapping: t.Optional[ColumnMapping] = None,
34        dialect: DialectType = None,
35        normalize: t.Optional[bool] = None,
36    ) -> None:
37        """
38        Register or update a table. Some implementing classes may require column information to also be provided.
39
40        Args:
41            table: the `Table` expression instance or string representing the table.
42            column_mapping: a column mapping that describes the structure of the table.
43            dialect: the SQL dialect that will be used to parse `table` if it's a string.
44            normalize: whether to normalize identifiers according to the dialect of interest.
45        """
46
47    @abc.abstractmethod
48    def column_names(
49        self,
50        table: exp.Table | str,
51        only_visible: bool = False,
52        dialect: DialectType = None,
53        normalize: t.Optional[bool] = None,
54    ) -> t.List[str]:
55        """
56        Get the column names for a table.
57
58        Args:
59            table: the `Table` expression instance.
60            only_visible: whether to include invisible columns.
61            dialect: the SQL dialect that will be used to parse `table` if it's a string.
62            normalize: whether to normalize identifiers according to the dialect of interest.
63
64        Returns:
65            The list of column names.
66        """
67
68    @abc.abstractmethod
69    def get_column_type(
70        self,
71        table: exp.Table | str,
72        column: exp.Column,
73        dialect: DialectType = None,
74        normalize: t.Optional[bool] = None,
75    ) -> exp.DataType:
76        """
77        Get the `sqlglot.exp.DataType` type of a column in the schema.
78
79        Args:
80            table: the source table.
81            column: the target column.
82            dialect: the SQL dialect that will be used to parse `table` if it's a string.
83            normalize: whether to normalize identifiers according to the dialect of interest.
84
85        Returns:
86            The resulting column type.
87        """
88
89    @property
90    @abc.abstractmethod
91    def supported_table_args(self) -> t.Tuple[str, ...]:
92        """
93        Table arguments this schema support, e.g. `("this", "db", "catalog")`
94        """
95
96    @property
97    def empty(self) -> bool:
98        """Returns whether or not the schema is empty."""
99        return True

Abstract base class for database schemas

@abc.abstractmethod
def add_table( self, table: sqlglot.expressions.Table | str, column_mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> None:
29    @abc.abstractmethod
30    def add_table(
31        self,
32        table: exp.Table | str,
33        column_mapping: t.Optional[ColumnMapping] = None,
34        dialect: DialectType = None,
35        normalize: t.Optional[bool] = None,
36    ) -> None:
37        """
38        Register or update a table. Some implementing classes may require column information to also be provided.
39
40        Args:
41            table: the `Table` expression instance or string representing the table.
42            column_mapping: a column mapping that describes the structure of the table.
43            dialect: the SQL dialect that will be used to parse `table` if it's a string.
44            normalize: whether to normalize identifiers according to the dialect of interest.
45        """

Register or update a table. Some implementing classes may require column information to also be provided.

Arguments:
  • table: the Table expression instance or string representing the table.
  • column_mapping: a column mapping that describes the structure of the table.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
@abc.abstractmethod
def column_names( self, table: sqlglot.expressions.Table | str, only_visible: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> List[str]:
47    @abc.abstractmethod
48    def column_names(
49        self,
50        table: exp.Table | str,
51        only_visible: bool = False,
52        dialect: DialectType = None,
53        normalize: t.Optional[bool] = None,
54    ) -> t.List[str]:
55        """
56        Get the column names for a table.
57
58        Args:
59            table: the `Table` expression instance.
60            only_visible: whether to include invisible columns.
61            dialect: the SQL dialect that will be used to parse `table` if it's a string.
62            normalize: whether to normalize identifiers according to the dialect of interest.
63
64        Returns:
65            The list of column names.
66        """

Get the column names for a table.

Arguments:
  • table: the Table expression instance.
  • only_visible: whether to include invisible columns.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The list of column names.

@abc.abstractmethod
def get_column_type( self, table: sqlglot.expressions.Table | str, column: sqlglot.expressions.Column, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: Optional[bool] = None) -> sqlglot.expressions.DataType:
68    @abc.abstractmethod
69    def get_column_type(
70        self,
71        table: exp.Table | str,
72        column: exp.Column,
73        dialect: DialectType = None,
74        normalize: t.Optional[bool] = None,
75    ) -> exp.DataType:
76        """
77        Get the `sqlglot.exp.DataType` type of a column in the schema.
78
79        Args:
80            table: the source table.
81            column: the target column.
82            dialect: the SQL dialect that will be used to parse `table` if it's a string.
83            normalize: whether to normalize identifiers according to the dialect of interest.
84
85        Returns:
86            The resulting column type.
87        """

Get the sqlglot.exp.DataType type of a column in the schema.

Arguments:
  • table: the source table.
  • column: the target column.
  • dialect: the SQL dialect that will be used to parse table if it's a string.
  • normalize: whether to normalize identifiers according to the dialect of interest.
Returns:

The resulting column type.

supported_table_args: Tuple[str, ...]

Table arguments this schema supports, e.g. ("this", "db", "catalog")

empty: bool

Returns whether or not the schema is empty.

class AbstractMappingSchema(typing.Generic[~T]):
102class AbstractMappingSchema(t.Generic[T]):
103    def __init__(
104        self,
105        mapping: t.Optional[t.Dict] = None,
106    ) -> None:
107        self.mapping = mapping or {}
108        self.mapping_trie = new_trie(
109            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
110        )
111        self._supported_table_args: t.Tuple[str, ...] = tuple()
112
113    @property
114    def empty(self) -> bool:
115        return not self.mapping
116
117    def _depth(self) -> int:
118        return dict_depth(self.mapping)
119
120    @property
121    def supported_table_args(self) -> t.Tuple[str, ...]:
122        if not self._supported_table_args and self.mapping:
123            depth = self._depth()
124
125            if not depth:  # None
126                self._supported_table_args = tuple()
127            elif 1 <= depth <= 3:
128                self._supported_table_args = TABLE_ARGS[:depth]
129            else:
130                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
131
132        return self._supported_table_args
133
134    def table_parts(self, table: exp.Table) -> t.List[str]:
135        if isinstance(table.this, exp.ReadCSV):
136            return [table.this.name]
137        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
138
139    def find(
140        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
141    ) -> t.Optional[T]:
142        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
143        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
144
145        if value == TrieResult.FAILED:
146            return None
147
148        if value == TrieResult.PREFIX:
149            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
150
151            if len(possibilities) == 1:
152                parts.extend(possibilities[0])
153            else:
154                message = ", ".join(".".join(parts) for parts in possibilities)
155                if raise_on_missing:
156                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
157                return None
158
159        return self.nested_get(parts, raise_on_missing=raise_on_missing)
160
161    def nested_get(
162        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
163    ) -> t.Optional[t.Any]:
164        return nested_get(
165            d or self.mapping,
166            *zip(self.supported_table_args, reversed(parts)),
167            raise_on_missing=raise_on_missing,
168        )

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

class Mapping(Generic[KT, VT]):
    def __getitem__(self, key: KT) -> VT:
        ...
    # Etc.

This class can then be used as follows::

def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
    try:
        return mapping[key]
    except KeyError:
        return default

AbstractMappingSchema(mapping: Optional[Dict] = None)
103    def __init__(
104        self,
105        mapping: t.Optional[t.Dict] = None,
106    ) -> None:
107        self.mapping = mapping or {}
108        self.mapping_trie = new_trie(
109            tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
110        )
111        self._supported_table_args: t.Tuple[str, ...] = tuple()
mapping
mapping_trie
empty: bool
supported_table_args: Tuple[str, ...]
def table_parts(self, table: sqlglot.expressions.Table) -> List[str]:
134    def table_parts(self, table: exp.Table) -> t.List[str]:
135        if isinstance(table.this, exp.ReadCSV):
136            return [table.this.name]
137        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
def find( self, table: sqlglot.expressions.Table, trie: Optional[Dict] = None, raise_on_missing: bool = True) -> Optional[~T]:
139    def find(
140        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
141    ) -> t.Optional[T]:
142        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
143        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
144
145        if value == TrieResult.FAILED:
146            return None
147
148        if value == TrieResult.PREFIX:
149            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
150
151            if len(possibilities) == 1:
152                parts.extend(possibilities[0])
153            else:
154                message = ", ".join(".".join(parts) for parts in possibilities)
155                if raise_on_missing:
156                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
157                return None
158
159        return self.nested_get(parts, raise_on_missing=raise_on_missing)
def nested_get( self, parts: Sequence[str], d: Optional[Dict] = None, raise_on_missing=True) -> Optional[Any]:
161    def nested_get(
162        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
163    ) -> t.Optional[t.Any]:
164        return nested_get(
165            d or self.mapping,
166            *zip(self.supported_table_args, reversed(parts)),
167            raise_on_missing=raise_on_missing,
168        )
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema: Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible: Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect: The dialect to be used for custom type mappings & parsing string arguments.
        normalize: Whether to normalize identifier names according to the given dialect or not.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
        normalize: bool = True,
    ) -> None:
        self.dialect = dialect
        self.visible = visible or {}
        self.normalize = normalize
        # Caches string type -> parsed exp.DataType so repeated lookups skip re-parsing.
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}

        # Attributes above must be set first: _normalize reads self.dialect / self.normalize.
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a new instance mirroring `mapping_schema`'s data and settings."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
            normalize=mapping_schema.normalize,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a copy of this schema; `kwargs` override the constructor arguments."""
        args: t.Dict[str, t.Any] = {
            "schema": self.mapping.copy(),
            "visible": self.visible.copy(),
            "dialect": self.dialect,
            "normalize": self.normalize,
        }
        args.update(kwargs)
        return MappingSchema(**args)  # type: ignore

    def add_table(
        self,
        table: exp.Table | str,
        column_mapping: t.Optional[ColumnMapping] = None,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
            dialect: the SQL dialect that will be used to parse `table` if it's a string.
            normalize: whether to normalize identifiers according to the dialect of interest.
        """
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_columns = {}
        for key, value in ensure_column_mapping(column_mapping).items():
            normalized_key = self._normalize_name(key, dialect=dialect, normalize=normalize)
            normalized_columns[normalized_key] = value

        # Never clobber an existing table entry with an empty column mapping.
        if not normalized_columns and self.find(normalized_table, raise_on_missing=False):
            return

        parts = self.table_parts(normalized_table)

        # The nested mapping is keyed outermost-first, hence the reversal;
        # new_trie mutates mapping_trie in place to index the new path.
        nested_set(self.mapping, tuple(reversed(parts)), normalized_columns)
        new_trie([parts], self.mapping_trie)

    def column_names(
        self,
        table: exp.Table | str,
        only_visible: bool = False,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> t.List[str]:
        """Return the column names for `table`, optionally restricted to visible ones."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        schema = self.find(normalized_table)
        if schema is None:
            return []

        columns = list(schema)
        if only_visible and self.visible:
            visible = self.nested_get(self.table_parts(normalized_table), self.visible) or []
            columns = [col for col in columns if col in visible]

        return columns

    def get_column_type(
        self,
        table: exp.Table | str,
        column: exp.Column,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
        """Return the type of `column` in `table`, or the UNKNOWN type if unresolvable."""
        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)

        normalized_column_name = self._normalize_name(
            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
        )

        table_schema = self.find(normalized_table, raise_on_missing=False)
        if table_schema:
            column_type = table_schema.get(normalized_column_name)

            # Types may be stored pre-built or as raw strings that still need parsing.
            if isinstance(column_type, exp.DataType):
                return column_type
            if isinstance(column_type, str):
                return self._to_data_type(column_type.upper(), dialect=dialect)

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        # depth - 1 because the innermost dict holds columns, not table parts.
        normalized_mapping: t.Dict = {}

        for keys in flatten_schema(schema, depth=dict_depth(schema) - 1):
            columns = nested_get(schema, *zip(keys, keys))
            assert columns is not None

            normalized_keys = [
                self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
            ]
            for column_name, column_type in columns.items():
                normalized_column = self._normalize_name(column_name, dialect=self.dialect)
                nested_set(normalized_mapping, normalized_keys + [normalized_column], column_type)

        return normalized_mapping

    def _normalize_table(
        self,
        table: exp.Table | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.Table:
        """Parse `table` if needed and normalize each of its name parts."""
        normalized_table = exp.maybe_parse(
            table, into=exp.Table, dialect=dialect or self.dialect, copy=True
        )

        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_name = self._normalize_name(
                    value, dialect=dialect, is_table=True, normalize=normalize
                )
                normalized_table.set(arg, exp.to_identifier(normalized_name))

        return normalized_table

    def _normalize_name(
        self,
        name: str | exp.Identifier,
        dialect: DialectType = None,
        is_table: bool = False,
        normalize: t.Optional[bool] = None,
    ) -> str:
        """Return `name` as a plain string, normalized per the dialect when requested."""
        dialect = dialect or self.dialect
        if normalize is None:
            normalize = self.normalize

        try:
            identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
        except ParseError:
            # Unparseable names are returned as-is rather than failing the lookup.
            return name if isinstance(name, str) else name.name

        name = identifier.name
        if not normalize:
            return name

        # This can be useful for normalize_identifier
        identifier.meta["is_table"] = is_table
        return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.
            dialect: the SQL dialect that will be used to parse `schema_type`, if needed.

        Returns:
            The resulting expression type.
        """
        data_type = self._type_mapping_cache.get(schema_type)

        if data_type is None:
            dialect = dialect or self.dialect

            try:
                data_type = exp.DataType.build(schema_type, dialect=dialect)
            except AttributeError:
                in_dialect = f" in dialect {dialect}" if dialect else ""
                raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.")

            self._type_mapping_cache[schema_type] = data_type

        return data_type

Schema based on a nested mapping.

Arguments:
  • schema: Mapping in one of the following forms:
    1. {table: {col: type}}
    2. {db: {table: {col: type}}}
    3. {catalog: {db: {table: {col: type}}}}
    4. None - Tables will be added later
  • visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema:
    1. {table: set(*cols)}
    2. {db: {table: set(*cols)}}
    3. {catalog: {db: {table: set(*cols)}}}
  • dialect: The dialect to be used for custom type mappings & parsing string arguments.
  • normalize: Whether to normalize identifier names according to the given dialect or not.
MappingSchema( schema: Optional[Dict] = None, visible: Optional[Dict] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, normalize: bool = True)
190    def __init__(
191        self,
192        schema: t.Optional[t.Dict] = None,
193        visible: t.Optional[t.Dict] = None,
194        dialect: DialectType = None,
195        normalize: bool = True,
196    ) -> None:
197        self.dialect = dialect
198        self.visible = visible or {}
199        self.normalize = normalize
200        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
201
202        super().__init__(self._normalize(schema or {}))
visible
normalize
@classmethod
def from_mapping_schema( cls, mapping_schema: sqlglot.schema.MappingSchema) -> sqlglot.schema.MappingSchema:
204    @classmethod
205    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
206        return MappingSchema(
207            schema=mapping_schema.mapping,
208            visible=mapping_schema.visible,
209            dialect=mapping_schema.dialect,
210            normalize=mapping_schema.normalize,
211        )
def copy(self, **kwargs) -> sqlglot.schema.MappingSchema:
213    def copy(self, **kwargs) -> MappingSchema:
214        return MappingSchema(
215            **{  # type: ignore
216                "schema": self.mapping.copy(),
217                "visible": self.visible.copy(),
218                "dialect": self.dialect,
219                "normalize": self.normalize,
220                **kwargs,
221            }
222        )
def ensure_schema( schema: Union[sqlglot.schema.Schema, Dict, NoneType], **kwargs: Any) -> sqlglot.schema.Schema:
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
    """Coerce `schema` into a Schema instance, wrapping dicts in a MappingSchema."""
    if isinstance(schema, Schema):
        return schema
    return MappingSchema(schema, **kwargs)
def ensure_column_mapping( mapping: Union[Dict, str, sqlglot.dataframe.sql.types.StructType, List, NoneType]) -> Dict:
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Coerce a column mapping into a {column name: type} dict.

    Args:
        mapping: one of None, a dict, a "name: type, ..." string, a StructType-like
            object (anything exposing `simpleString`), or a list of column names.

    Returns:
        A dict of column names to types; types are None for a plain list of names.

    Raises:
        ValueError: if `mapping` is of an unsupported type.
    """
    if mapping is None:
        return {}
    elif isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        columns = {}
        # NOTE(review): commas inside a type (e.g. "decimal(10, 2)") are not
        # handled by this split — verify callers don't pass such types here.
        for name_type_str in mapping.split(","):
            # Split on the first ":" only, so types that themselves contain a
            # colon (e.g. "struct<x:int>") are kept intact; the old
            # split(":")[1] truncated them and raised on colon-less entries.
            name, _, data_type = name_type_str.strip().partition(":")
            columns[name.strip()] = data_type.strip()
        return columns
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}

    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema( schema: Dict, depth: int, keys: Optional[List[str]] = None) -> List[List[str]]:
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """Flatten a nested schema dict into a list of key paths of length `depth`."""
    prefix = keys or []
    paths: t.List[t.List[str]] = []

    for key, value in schema.items():
        if depth == 1:
            # Reached table level: record the complete path.
            paths.append(prefix + [key])
        elif depth >= 2:
            # Still above table level: descend with the current key appended.
            paths.extend(flatten_schema(value, depth - 1, prefix + [key]))

    return paths
def nested_get( d: Dict, *path: Tuple[str, str], raise_on_missing: bool = True) -> Optional[Any]:
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist.
    """
    current: t.Any = d

    for name, key in path:
        current = current.get(key)
        if current is None:
            if not raise_on_missing:
                return None
            # "this" is an AST-internal arg name; report it as "table" to users.
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}: {key}")

    return current

Get a value for a nested dictionary.

Arguments:
  • d: the dictionary to search.
  • *path: tuples of (name, key), where:
    – key is the key in the dictionary to get.
    – name is a string to use in the error if key isn't found.
Returns:

The value or None if it doesn't exist.

def nested_set(d: Dict, keys: Sequence[str], value: Any) -> Dict:
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that makeup the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    # Walk (creating intermediate dicts as needed) down to the parent of the leaf.
    target = d
    for key in keys[:-1]:
        target = target.setdefault(key, {})

    target[keys[-1]] = value
    return d

In-place set a value for a nested dictionary

Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
  • d: dictionary to update.
  • keys: the keys that makeup the path to value.
  • value: the value to set in the dictionary for the given key path.
Returns:

The (possibly) updated dictionary.