Edit on GitHub

sqlglot.lineage

  1from __future__ import annotations
  2
  3import json
  4import logging
  5import typing as t
  6from dataclasses import dataclass, field
  7
  8from sqlglot import Schema, exp, maybe_parse
  9from sqlglot.errors import SqlglotError
 10from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify
 11
 12if t.TYPE_CHECKING:
 13    from sqlglot.dialects.dialect import DialectType
 14
 15logger = logging.getLogger("sqlglot")
 16
 17
 18@dataclass(frozen=True)
 19class Node:
 20    name: str
 21    expression: exp.Expression
 22    source: exp.Expression
 23    downstream: t.List[Node] = field(default_factory=list)
 24    source_name: str = ""
 25    reference_node_name: str = ""
 26
 27    def walk(self) -> t.Iterator[Node]:
 28        yield self
 29
 30        for d in self.downstream:
 31            yield from d.walk()
 32
 33    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
 34        nodes = {}
 35        edges = []
 36
 37        for node in self.walk():
 38            if isinstance(node.expression, exp.Table):
 39                label = f"FROM {node.expression.this}"
 40                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
 41                group = 1
 42            else:
 43                label = node.expression.sql(pretty=True, dialect=dialect)
 44                source = node.source.transform(
 45                    lambda n: (
 46                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
 47                    ),
 48                    copy=False,
 49                ).sql(pretty=True, dialect=dialect)
 50                title = f"<pre>{source}</pre>"
 51                group = 0
 52
 53            node_id = id(node)
 54
 55            nodes[node_id] = {
 56                "id": node_id,
 57                "label": label,
 58                "title": title,
 59                "group": group,
 60            }
 61
 62            for d in node.downstream:
 63                edges.append({"from": node_id, "to": id(d)})
 64        return GraphHTML(nodes, edges, **opts)
 65
 66
 67def lineage(
 68    column: str | exp.Column,
 69    sql: str | exp.Expression,
 70    schema: t.Optional[t.Dict | Schema] = None,
 71    sources: t.Optional[t.Dict[str, str | exp.Query]] = None,
 72    dialect: DialectType = None,
 73    **kwargs,
 74) -> Node:
 75    """Build the lineage graph for a column of a SQL query.
 76
 77    Args:
 78        column: The column to build the lineage for.
 79        sql: The SQL string or expression.
 80        schema: The schema of tables.
 81        sources: A mapping of queries which will be used to continue building lineage.
 82        dialect: The dialect of input SQL.
 83        **kwargs: Qualification optimizer kwargs.
 84
 85    Returns:
 86        A lineage node.
 87    """
 88
 89    expression = maybe_parse(sql, dialect=dialect)
 90    column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
 91
 92    if sources:
 93        expression = exp.expand(
 94            expression,
 95            {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()},
 96            dialect=dialect,
 97        )
 98
 99    qualified = qualify.qualify(
100        expression,
101        dialect=dialect,
102        schema=schema,
103        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
104    )
105
106    scope = build_scope(qualified)
107
108    if not scope:
109        raise SqlglotError("Cannot build lineage, sql must be SELECT")
110
111    if not any(select.alias_or_name == column for select in scope.expression.selects):
112        raise SqlglotError(f"Cannot find column '{column}' in query.")
113
114    return to_node(column, scope, dialect)
115
116
def to_node(
    column: str | int,
    scope: Scope,
    dialect: DialectType,
    scope_name: t.Optional[str] = None,
    upstream: t.Optional[Node] = None,
    source_name: t.Optional[str] = None,
    reference_node_name: t.Optional[str] = None,
) -> Node:
    """Build the lineage node for one column of `scope`, recursing into its sources.

    Args:
        column: The column to trace, given by name or by position in the select list.
        scope: The scope whose expression produces the column.
        dialect: Passed to the recursive calls and used when rendering SQL for the
            warning that is logged for an unknown subquery scope.
        scope_name: When set, the node is named `{scope_name}.{column}`.
        upstream: The node the new node is appended to as a downstream, if any.
        source_name: Stored on the new node and handed on to the recursive calls.
        reference_node_name: Stored on the new node.

    Returns:
        The node created for `column`. For a set operation (`exp.Union`) this is
        `upstream` if one was given, otherwise a new node named "UNION".

    Raises:
        ValueError: If `scope` is a set operation and `column` is a name that
            matches none of its selects.
    """
    # Map derived-table alias -> the name recorded in a leading "source: <name>" comment.
    source_names = {
        dt.alias: dt.comments[0].split()[1]
        for dt in scope.derived_tables
        if dt.comments and dt.comments[0].startswith("source: ")
    }

    # Find the specific select clause that is the source of the column we want.
    # This can either be a specific, named select or a generic `*` clause.
    select = (
        scope.expression.selects[column]
        if isinstance(column, int)
        else next(
            (
                candidate
                for candidate in scope.expression.selects
                if candidate.alias_or_name == column
            ),
            exp.Star() if scope.expression.is_star else scope.expression,
        )
    )

    if isinstance(scope.expression, exp.Subquery):
        # Only the first nested scope is followed. When there is none, execution
        # falls through to the generic handling below.
        inner_scope = next(iter(scope.subquery_scopes), None)
        if inner_scope is not None:
            return to_node(
                column,
                scope=inner_scope,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
            )

    if isinstance(scope.expression, exp.Union):
        upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)

        # Every branch of the set operation is traced by position, not by name.
        index = (
            column
            if isinstance(column, int)
            else next(
                (
                    i
                    for i, candidate in enumerate(scope.expression.selects)
                    if candidate.alias_or_name == column or candidate.is_star
                ),
                -1,  # mypy will not allow a None here, but a negative index should never be returned
            )
        )

        if index == -1:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        for union_scope in scope.union_scopes:
            to_node(
                index,
                scope=union_scope,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
            )

        return upstream

    if isinstance(scope.expression, exp.Select):
        # For better ergonomics in our node labels, replace the full select with
        # a version that has only the column we care about.
        #   "x", SELECT x, y FROM foo
        #     => "x", SELECT x FROM foo
        source = t.cast(exp.Expression, scope.expression.select(select, append=False))
    else:
        source = scope.expression

    # Create the node for this step in the lineage chain, and attach it to the previous one.
    node = Node(
        name=f"{scope_name}.{column}" if scope_name else str(column),
        source=source,
        expression=select,
        source_name=source_name or "",
        reference_node_name=reference_node_name or "",
    )

    if upstream:
        upstream.downstream.append(node)

    # Subquery scopes are looked up by the identity of the expression they wrap.
    scopes_by_expression_id = {
        id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
    }

    for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
        subquery_scope = scopes_by_expression_id.get(id(subquery))
        if not subquery_scope:
            logger.warning("Unknown subquery scope: %s", subquery.sql(dialect=dialect))
            continue

        for name in subquery.named_selects:
            to_node(name, scope=subquery_scope, dialect=dialect, upstream=node)

    # if the select is a star add all scope sources as downstreams
    if select.is_star:
        for source in scope.sources.values():
            if isinstance(source, Scope):
                source = source.expression
            node.downstream.append(Node(name=select.sql(), source=source, expression=source))

    # Find all columns that went into creating this one to list their lineage nodes.
    source_columns = set(find_all_in_scope(select, exp.Column))

    # If the source is a UDTF find columns used in the UDTF to generate the table.
    # NOTE(review): `source` is rebound by the star loop above, so for a star select
    # this check sees the last scope source rather than the node's own source.
    # Kept as is to preserve behavior - confirm whether that is intended.
    if isinstance(source, exp.UDTF):
        source_columns |= set(source.find_all(exp.Column))

    for c in source_columns:
        table = c.table
        source = scope.sources.get(table)

        if isinstance(source, Scope):
            selected_node, _ = scope.selected_sources.get(table, (None, None))
            # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
            to_node(
                c.name,
                scope=source,
                dialect=dialect,
                scope_name=table,
                upstream=node,
                source_name=source_names.get(table) or source_name,
                reference_node_name=selected_node.name if selected_node else None,
            )
        else:
            # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
            # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
            # is not passed into the `sources` map.
            source = source or exp.Placeholder()
            node.downstream.append(Node(name=c.sql(), source=source, expression=source))

    return node
256
257
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    # Characters that could end or confuse an inline <script> element, mapped to
    # JSON/JavaScript unicode escapes that decode to the very same characters.
    _SCRIPT_ESCAPES = str.maketrans({"<": "\\u003c", ">": "\\u003e", "&": "\\u0026"})

    def __init__(
        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
    ):
        """Store the graph data and the vis.js options used to render it.

        Args:
            nodes: Mapping whose values are the node records handed to vis.js.
            edges: List of edge records handed to vis.js.
            imports: Whether the rendered HTML pulls in the vis.js scripts and stylesheet.
            options: Network options laid over the defaults. The merge is shallow:
                a top-level key given here replaces the default for that key as a whole.
        """
        self.imports = imports

        defaults = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
        }
        self.options = {**defaults, **(options or {})}

        self.nodes = nodes
        self.edges = edges

    @classmethod
    def _script_json(cls, value: t.Any) -> str:
        """Serialize `value` to JSON that is safe to place inside an inline <script>.

        Node labels and titles hold SQL text and markup. Left as is, a value that
        contains `</script>` would close the script element early.
        """
        return json.dumps(value).translate(cls._SCRIPT_ESCAPES)

    def __str__(self):
        """Render the graph as an HTML fragment with an inline vis.js script."""
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        # Each node title arrives as an HTML string; the script turns it into a DOM
        # element before handing the data set to vis.Network.
        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Return the same HTML as `__str__`, for rich display in notebooks."""
        return str(self)
logger = <Logger sqlglot (WARNING)>
@dataclass(frozen=True)
class Node:
@dataclass(frozen=True)
class Node:
    """A single step in a lineage graph, linked to the steps listed in `downstream`.

    The dataclass is frozen, so its fields cannot be rebound after construction.
    The `downstream` list object itself is still a regular, appendable list.
    """

    # Display name of this step; shown in the tooltip of table nodes by `to_html`.
    name: str
    # The expression this step stands for; rendered as the node label by `to_html`.
    expression: exp.Expression
    # The expression inside which `expression` is highlighted for the tooltip.
    source: exp.Expression
    # Child nodes; `walk` visits them, in order, after this node.
    downstream: t.List[Node] = field(default_factory=list)
    # Extra labels stored on the node; no method of this class reads them.
    source_name: str = ""
    reference_node_name: str = ""

    def walk(self) -> t.Iterator[Node]:
        """Yield this node, then every descendant, depth-first in `downstream` order."""
        yield self

        for child in self.downstream:
            yield from child.walk()

    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
        """Build a `GraphHTML` visualization of this node and everything below it.

        Args:
            dialect: The dialect used when rendering expressions as SQL text.
            **opts: Extra keyword arguments passed through to `GraphHTML`.

        Returns:
            A `GraphHTML` with one record per walked node and one edge per
            parent/child pair.
        """
        graph_nodes = {}
        graph_edges = []

        for current in self.walk():
            target = current.expression

            if isinstance(target, exp.Table):
                # Table nodes get their own group and a synthetic SELECT as tooltip.
                group = 1
                label = f"FROM {target.this}"
                title = f"<pre>SELECT {current.name} FROM {target.this}</pre>"
            else:
                group = 0
                label = target.sql(pretty=True, dialect=dialect)

                def emphasize(candidate, target=target):
                    # Identity, not equality: only the exact expression object is wrapped.
                    if candidate is target:
                        return exp.Tag(this=candidate, prefix="<b>", postfix="</b>")
                    return candidate

                # copy=False keeps the original objects so the identity check can match.
                # NOTE(review): this looks like it rewrites `current.source` in place - confirm.
                highlighted = current.source.transform(emphasize, copy=False)
                title = f"<pre>{highlighted.sql(pretty=True, dialect=dialect)}</pre>"

            # The object id doubles as the vis.js node id, so edges can refer to it.
            key = id(current)
            graph_nodes[key] = {"id": key, "label": label, "title": title, "group": group}
            graph_edges.extend({"from": key, "to": id(child)} for child in current.downstream)

        return GraphHTML(graph_nodes, graph_edges, **opts)
Node( name: str, expression: sqlglot.expressions.Expression, source: sqlglot.expressions.Expression, downstream: List[Node] = <factory>, source_name: str = '', reference_node_name: str = '')
name: str
downstream: List[Node]
source_name: str = ''
reference_node_name: str = ''
def walk(self) -> Iterator[Node]:
28    def walk(self) -> t.Iterator[Node]:
29        yield self
30
31        for d in self.downstream:
32            yield from d.walk()
def to_html( self, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **opts) -> GraphHTML:
34    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
35        nodes = {}
36        edges = []
37
38        for node in self.walk():
39            if isinstance(node.expression, exp.Table):
40                label = f"FROM {node.expression.this}"
41                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
42                group = 1
43            else:
44                label = node.expression.sql(pretty=True, dialect=dialect)
45                source = node.source.transform(
46                    lambda n: (
47                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
48                    ),
49                    copy=False,
50                ).sql(pretty=True, dialect=dialect)
51                title = f"<pre>{source}</pre>"
52                group = 0
53
54            node_id = id(node)
55
56            nodes[node_id] = {
57                "id": node_id,
58                "label": label,
59                "title": title,
60                "group": group,
61            }
62
63            for d in node.downstream:
64                edges.append({"from": node_id, "to": id(d)})
65        return GraphHTML(nodes, edges, **opts)
def lineage( column: str | sqlglot.expressions.Column, sql: str | sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, sources: Optional[Dict[str, str | sqlglot.expressions.Query]] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **kwargs) -> Node:
 68def lineage(
 69    column: str | exp.Column,
 70    sql: str | exp.Expression,
 71    schema: t.Optional[t.Dict | Schema] = None,
 72    sources: t.Optional[t.Dict[str, str | exp.Query]] = None,
 73    dialect: DialectType = None,
 74    **kwargs,
 75) -> Node:
 76    """Build the lineage graph for a column of a SQL query.
 77
 78    Args:
 79        column: The column to build the lineage for.
 80        sql: The SQL string or expression.
 81        schema: The schema of tables.
 82        sources: A mapping of queries which will be used to continue building lineage.
 83        dialect: The dialect of input SQL.
 84        **kwargs: Qualification optimizer kwargs.
 85
 86    Returns:
 87        A lineage node.
 88    """
 89
 90    expression = maybe_parse(sql, dialect=dialect)
 91    column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
 92
 93    if sources:
 94        expression = exp.expand(
 95            expression,
 96            {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()},
 97            dialect=dialect,
 98        )
 99
100    qualified = qualify.qualify(
101        expression,
102        dialect=dialect,
103        schema=schema,
104        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
105    )
106
107    scope = build_scope(qualified)
108
109    if not scope:
110        raise SqlglotError("Cannot build lineage, sql must be SELECT")
111
112    if not any(select.alias_or_name == column for select in scope.expression.selects):
113        raise SqlglotError(f"Cannot find column '{column}' in query.")
114
115    return to_node(column, scope, dialect)

Build the lineage graph for a column of a SQL query.

Arguments:
  • column: The column to build the lineage for.
  • sql: The SQL string or expression.
  • schema: The schema of tables.
  • sources: A mapping of queries which will be used to continue building lineage.
  • dialect: The dialect of input SQL.
  • **kwargs: Qualification optimizer kwargs.
Returns:

A lineage node.

def to_node( column: str | int, scope: sqlglot.optimizer.scope.Scope, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType], scope_name: Optional[str] = None, upstream: Optional[Node] = None, source_name: Optional[str] = None, reference_node_name: Optional[str] = None) -> Node:
def to_node(
    column: str | int,
    scope: Scope,
    dialect: DialectType,
    scope_name: t.Optional[str] = None,
    upstream: t.Optional[Node] = None,
    source_name: t.Optional[str] = None,
    reference_node_name: t.Optional[str] = None,
) -> Node:
    """Build the lineage node for one column of `scope`, recursing into its sources.

    Args:
        column: The column to trace, given by name or by position in the select list.
        scope: The scope whose expression produces the column.
        dialect: Passed to the recursive calls and used when rendering SQL for the
            warning that is logged for an unknown subquery scope.
        scope_name: When set, the node is named `{scope_name}.{column}`.
        upstream: The node the new node is appended to as a downstream, if any.
        source_name: Stored on the new node and handed on to the recursive calls.
        reference_node_name: Stored on the new node.

    Returns:
        The node created for `column`. For a set operation (`exp.Union`) this is
        `upstream` if one was given, otherwise a new node named "UNION".

    Raises:
        ValueError: If `scope` is a set operation and `column` is a name that
            matches none of its selects.
    """
    # Map derived-table alias -> the name recorded in a leading "source: <name>" comment.
    source_names = {
        dt.alias: dt.comments[0].split()[1]
        for dt in scope.derived_tables
        if dt.comments and dt.comments[0].startswith("source: ")
    }

    # Find the specific select clause that is the source of the column we want.
    # This can either be a specific, named select or a generic `*` clause.
    select = (
        scope.expression.selects[column]
        if isinstance(column, int)
        else next(
            (
                candidate
                for candidate in scope.expression.selects
                if candidate.alias_or_name == column
            ),
            exp.Star() if scope.expression.is_star else scope.expression,
        )
    )

    if isinstance(scope.expression, exp.Subquery):
        # Only the first nested scope is followed. When there is none, execution
        # falls through to the generic handling below.
        inner_scope = next(iter(scope.subquery_scopes), None)
        if inner_scope is not None:
            return to_node(
                column,
                scope=inner_scope,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
            )

    if isinstance(scope.expression, exp.Union):
        upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)

        # Every branch of the set operation is traced by position, not by name.
        index = (
            column
            if isinstance(column, int)
            else next(
                (
                    i
                    for i, candidate in enumerate(scope.expression.selects)
                    if candidate.alias_or_name == column or candidate.is_star
                ),
                -1,  # mypy will not allow a None here, but a negative index should never be returned
            )
        )

        if index == -1:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        for union_scope in scope.union_scopes:
            to_node(
                index,
                scope=union_scope,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
            )

        return upstream

    if isinstance(scope.expression, exp.Select):
        # For better ergonomics in our node labels, replace the full select with
        # a version that has only the column we care about.
        #   "x", SELECT x, y FROM foo
        #     => "x", SELECT x FROM foo
        source = t.cast(exp.Expression, scope.expression.select(select, append=False))
    else:
        source = scope.expression

    # Create the node for this step in the lineage chain, and attach it to the previous one.
    node = Node(
        name=f"{scope_name}.{column}" if scope_name else str(column),
        source=source,
        expression=select,
        source_name=source_name or "",
        reference_node_name=reference_node_name or "",
    )

    if upstream:
        upstream.downstream.append(node)

    # Subquery scopes are looked up by the identity of the expression they wrap.
    scopes_by_expression_id = {
        id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
    }

    for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
        subquery_scope = scopes_by_expression_id.get(id(subquery))
        if not subquery_scope:
            logger.warning("Unknown subquery scope: %s", subquery.sql(dialect=dialect))
            continue

        for name in subquery.named_selects:
            to_node(name, scope=subquery_scope, dialect=dialect, upstream=node)

    # if the select is a star add all scope sources as downstreams
    if select.is_star:
        for source in scope.sources.values():
            if isinstance(source, Scope):
                source = source.expression
            node.downstream.append(Node(name=select.sql(), source=source, expression=source))

    # Find all columns that went into creating this one to list their lineage nodes.
    source_columns = set(find_all_in_scope(select, exp.Column))

    # If the source is a UDTF find columns used in the UDTF to generate the table.
    # NOTE(review): `source` is rebound by the star loop above, so for a star select
    # this check sees the last scope source rather than the node's own source.
    # Kept as is to preserve behavior - confirm whether that is intended.
    if isinstance(source, exp.UDTF):
        source_columns |= set(source.find_all(exp.Column))

    for c in source_columns:
        table = c.table
        source = scope.sources.get(table)

        if isinstance(source, Scope):
            selected_node, _ = scope.selected_sources.get(table, (None, None))
            # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
            to_node(
                c.name,
                scope=source,
                dialect=dialect,
                scope_name=table,
                upstream=node,
                source_name=source_names.get(table) or source_name,
                reference_node_name=selected_node.name if selected_node else None,
            )
        else:
            # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
            # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
            # is not passed into the `sources` map.
            source = source or exp.Placeholder()
            node.downstream.append(Node(name=c.sql(), source=source, expression=source))

    return node
class GraphHTML:
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    # Characters that could end or confuse an inline <script> element, mapped to
    # JSON/JavaScript unicode escapes that decode to the very same characters.
    _SCRIPT_ESCAPES = str.maketrans({"<": "\\u003c", ">": "\\u003e", "&": "\\u0026"})

    def __init__(
        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
    ):
        """Store the graph data and the vis.js options used to render it.

        Args:
            nodes: Mapping whose values are the node records handed to vis.js.
            edges: List of edge records handed to vis.js.
            imports: Whether the rendered HTML pulls in the vis.js scripts and stylesheet.
            options: Network options laid over the defaults. The merge is shallow:
                a top-level key given here replaces the default for that key as a whole.
        """
        self.imports = imports

        defaults = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
        }
        self.options = {**defaults, **(options or {})}

        self.nodes = nodes
        self.edges = edges

    @classmethod
    def _script_json(cls, value: t.Any) -> str:
        """Serialize `value` to JSON that is safe to place inside an inline <script>.

        Node labels and titles hold SQL text and markup. Left as is, a value that
        contains `</script>` would close the script element early.
        """
        return json.dumps(value).translate(cls._SCRIPT_ESCAPES)

    def __str__(self):
        """Render the graph as an HTML fragment with an inline vis.js script."""
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        # Each node title arrives as an HTML string; the script turns it into a DOM
        # element before handing the data set to vis.Network.
        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Return the same HTML as `__str__`, for rich display in notebooks."""
        return str(self)

Node to HTML generator using vis.js.

https://visjs.github.io/vis-network/docs/network/

GraphHTML( nodes: Dict, edges: List, imports: bool = True, options: Optional[Dict] = None)
265    def __init__(
266        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
267    ):
268        self.imports = imports
269
270        self.options = {
271            "height": "500px",
272            "width": "100%",
273            "layout": {
274                "hierarchical": {
275                    "enabled": True,
276                    "nodeSpacing": 200,
277                    "sortMethod": "directed",
278                },
279            },
280            "interaction": {
281                "dragNodes": False,
282                "selectable": False,
283            },
284            "physics": {
285                "enabled": False,
286            },
287            "edges": {
288                "arrows": "to",
289            },
290            "nodes": {
291                "font": "20px monaco",
292                "shape": "box",
293                "widthConstraint": {
294                    "maximum": 300,
295                },
296            },
297            **(options or {}),
298        }
299
300        self.nodes = nodes
301        self.edges = edges
imports
options
nodes
edges