Skip to content

Commit

Permalink
Fix: union lineage with > 2 sources closes #1934
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Jul 20, 2023
1 parent 2411bd3 commit b8de650
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 16 deletions.
44 changes: 28 additions & 16 deletions sqlglot/lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def lineage(
raise SqlglotError("Cannot build lineage, sql must be SELECT")

def to_node(
column_name: str,
column: str | int,
scope: Scope,
scope_name: t.Optional[str] = None,
upstream: t.Optional[Node] = None,
Expand All @@ -90,26 +90,38 @@ def to_node(
for dt in scope.derived_tables
if dt.comments and dt.comments[0].startswith("source: ")
}
if isinstance(scope.expression, exp.Union):
for scope in scope.union_scopes:
node = to_node(
column_name,
scope=scope,
scope_name=scope_name,
upstream=upstream,
alias=aliases.get(scope_name),
)
return node

# Find the specific select clause that is the source of the column we want.
# This can either be a specific, named select or a generic `*` clause.
select = next(
(select for select in scope.expression.selects if select.alias_or_name == column_name),
exp.Star() if scope.expression.is_star else None,
select = (
scope.expression.selects[column]
if isinstance(column, int)
else next(
(select for select in scope.expression.selects if select.alias_or_name == column),
exp.Star() if scope.expression.is_star else None,
)
)

if not select:
raise ValueError(f"Could not find {column_name} in {scope.expression}")
raise ValueError(f"Could not find {column} in {scope.expression}")

if isinstance(scope.expression, exp.Union):
upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)

index = (
column
if isinstance(column, int)
else next(
i
for i, select in enumerate(scope.expression.selects)
if select.alias_or_name == column
)
)

for s in scope.union_scopes:
to_node(index, scope=s, upstream=upstream)

return upstream

if isinstance(scope.expression, exp.Select):
# For better ergonomics in our node labels, replace the full select with
Expand All @@ -122,7 +134,7 @@ def to_node(

# Create the node for this step in the lineage chain, and attach it to the previous one.
node = Node(
name=f"{scope_name}.{column_name}" if scope_name else column_name,
name=f"{scope_name}.{column}" if scope_name else str(column),
source=source,
expression=select,
alias=alias or "",
Expand Down
13 changes: 13 additions & 0 deletions tests/test_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,16 @@ def test_lineage_cte_name_appears_in_schema(self) -> None:
self.assertEqual(downstream.alias, "")

self.assertEqual(downstream.downstream, [])

def test_lineage_union(self) -> None:
    """Check lineage of column ``x`` through a three-way UNION (issue #1934).

    Each query unions three SELECTs, so the node returned by ``lineage``
    is expected to expose exactly three downstream nodes. The first query
    has the UNION at the top level; the second wraps the same UNION in a
    subquery.
    """
    queries = (
        "SELECT ax AS x FROM a UNION SELECT bx FROM b UNION SELECT cx FROM c",
        "SELECT x FROM (SELECT ax AS x FROM a UNION SELECT bx FROM b UNION SELECT cx FROM c)",
    )

    for sql in queries:
        # subTest labels a failure with the offending query; assertEqual
        # (unlike a bare assert) survives `python -O` and reports both values.
        with self.subTest(sql=sql):
            node = lineage("x", sql)
            self.assertEqual(len(node.downstream), 3)

0 comments on commit b8de650

Please sign in to comment.