Skip to content

Commit

Permalink
Fix: expand alias refs was buggy and did the samething expand lateral… (
Browse files Browse the repository at this point in the history
#1599)

* Fix[BREAKING]: expand alias refs was buggy and did the samething expand laterals did

* mend
  • Loading branch information
tobymao authored May 12, 2023
1 parent f585eef commit 4dd413b
Show file tree
Hide file tree
Showing 10 changed files with 119 additions and 180 deletions.
11 changes: 7 additions & 4 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4835,21 +4835,24 @@ def paren(expression, copy=True) -> Paren:


@t.overload
def to_identifier(name: None, quoted: t.Optional[bool] = None) -> None:
def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None:
...


@t.overload
def to_identifier(name: str | Identifier, quoted: t.Optional[bool] = None) -> Identifier:
def to_identifier(
name: str | Identifier, quoted: t.Optional[bool] = None, copy: bool = True
) -> Identifier:
...


def to_identifier(name, quoted=None):
def to_identifier(name, quoted=None, copy=True):
"""Builds an identifier.
Args:
name: The name to turn into an identifier.
quoted: Whether or not force quote the identifier.
copy: Whether or not to copy a passed in Identefier node.
Returns:
The identifier ast node.
Expand All @@ -4859,7 +4862,7 @@ def to_identifier(name, quoted=None):
return None

if isinstance(name, Identifier):
identifier = name
identifier = _maybe_copy(name, copy)
elif isinstance(name, str):
identifier = Identifier(
this=name,
Expand Down
4 changes: 2 additions & 2 deletions sqlglot/lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from sqlglot import Schema, exp, maybe_parse
from sqlglot.optimizer import Scope, build_scope, optimize
from sqlglot.optimizer.expand_laterals import expand_laterals
from sqlglot.optimizer.lower_identities import lower_identities
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables

Expand Down Expand Up @@ -40,7 +40,7 @@ def lineage(
sql: str | exp.Expression,
schema: t.Optional[t.Dict | Schema] = None,
sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns, expand_laterals),
rules: t.Sequence[t.Callable] = (lower_identities, qualify_tables, qualify_columns),
dialect: DialectType = None,
) -> Node:
"""Build the lineage graph for a column of a SQL query.
Expand Down
34 changes: 0 additions & 34 deletions sqlglot/optimizer/expand_laterals.py

This file was deleted.

123 changes: 44 additions & 79 deletions sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@

from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.expand_laterals import expand_laterals as _expand_laterals
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
from sqlglot.schema import Schema, ensure_schema


def qualify_columns(
expression: exp.Expression,
schema: dict | Schema,
expand_laterals: bool = True,
infer_schema: bool = True,
expand_alias_refs: bool = True,
infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
"""
Rewrite sqlglot AST to have fully qualified columns.
Expand All @@ -29,32 +28,34 @@ def qualify_columns(
Args:
expression: expression to qualify
schema: Database schema
expand_laterals: whether or not to expand laterals
expand_alias_refs: whether or not to expand references to aliases
infer_schema: whether or not to infer the schema if missing
Returns:
sqlglot.Expression: qualified expression
"""
schema = ensure_schema(schema)

if schema.empty and expand_laterals:
expression = _expand_laterals(expression)
infer_schema = schema.empty if infer_schema is None else infer_schema

for scope in traverse_scope(expression):
resolver = Resolver(scope, schema, infer_schema=infer_schema)
_pop_table_column_aliases(scope.ctes)
_pop_table_column_aliases(scope.derived_tables)
using_column_tables = _expand_using(scope, resolver)

if schema.empty and expand_alias_refs:
_expand_alias_refs(scope, resolver)

_qualify_columns(scope, resolver)

if not schema.empty and expand_alias_refs:
_expand_alias_refs(scope, resolver)

if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver, using_column_tables)
_qualify_outputs(scope)
_expand_alias_refs(scope, resolver)
_expand_group_by(scope, resolver)
_expand_order_by(scope)

if not schema.empty and expand_laterals:
expression = _expand_laterals(expression)

return expression


Expand All @@ -66,7 +67,9 @@ def validate_qualify_columns(expression):
unqualified_columns.extend(scope.unqualified_columns)
if scope.external_columns and not scope.is_correlated_subquery:
column = scope.external_columns[0]
raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")
raise OptimizeError(
f"""Column '{column}' could not be resolved{" for table: '{column.table}'" if column.table else ''}"""
)

if unqualified_columns:
raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
Expand Down Expand Up @@ -158,45 +161,40 @@ def _expand_using(scope, resolver):
return column_tables


def _expand_alias_refs(scope, resolver):
selects = {}
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
expression = scope.expression

# Replace references to select aliases
def transform(node, source_first=True):
if isinstance(node, exp.Column) and not node.table:
table = resolver.get_table(node.name)
if not isinstance(expression, exp.Select):
return

# Source columns get priority over select aliases
if source_first and table:
node.set("table", table)
return node
alias_to_expression: t.Dict[str, exp.Expression] = {}

if not selects:
for s in scope.selects:
selects[s.alias_or_name] = s
select = selects.get(node.name)
def replace_columns(
node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False
):
if not node:
return

if select:
scope.clear_cache()
if isinstance(select, exp.Alias):
select = select.this
return select.copy()
for column, *_ in walk_in_scope(node):
if not isinstance(column, exp.Column):
continue
table = resolver.get_table(column.name) if resolve_agg and not column.table else None
if table and column.find_ancestor(exp.AggFunc):
column.set("table", table)
elif expand and not column.table and column.name in alias_to_expression:
column.replace(alias_to_expression[column.name].copy())

node.set("table", table)
elif isinstance(node, exp.Expression) and not isinstance(node, exp.Subqueryable):
exp.replace_children(node, transform, source_first)
for projection in scope.selects:
replace_columns(projection)

return node
if isinstance(projection, exp.Alias):
alias_to_expression[projection.alias] = projection.this

for select in scope.expression.selects:
transform(select)

for modifier, source_first in (
("where", True),
("group", True),
("having", False),
):
transform(scope.expression.args.get(modifier), source_first=source_first)
replace_columns(expression.args.get("where"))
replace_columns(expression.args.get("group"))
replace_columns(expression.args.get("having"), resolve_agg=True)
replace_columns(expression.args.get("order"), expand=False, resolve_agg=True)
scope.clear_cache()


def _expand_group_by(scope, resolver):
Expand Down Expand Up @@ -274,39 +272,6 @@ def _qualify_columns(scope, resolver):
if column_table:
column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

columns_missing_from_scope = []

# Determine whether each reference in the order by clause is to a column or an alias.
order = scope.expression.args.get("order")

if order:
for ordered in order.expressions:
for column in ordered.find_all(exp.Column):
if (
not column.table
and column.parent is not ordered
and column.name in resolver.all_columns
):
columns_missing_from_scope.append(column)

# Determine whether each reference in the having clause is to a column or an alias.
having = scope.expression.args.get("having")

if having:
for column in having.find_all(exp.Column):
if (
not column.table
and column.find_ancestor(exp.AggFunc)
and column.name in resolver.all_columns
):
columns_missing_from_scope.append(column)

for column in columns_missing_from_scope:
column_table = resolver.get_table(column.name)

if column_table:
column.set("table", column_table)


def _expand_stars(scope, resolver, using_column_tables):
"""Expand stars to lists of column selections"""
Expand Down Expand Up @@ -460,7 +425,7 @@ def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:

node_alias = node.args.get("alias")
if node_alias:
return node_alias.this
return exp.to_identifier(node_alias.this)

return exp.to_identifier(
table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
Expand Down
8 changes: 5 additions & 3 deletions sqlglot/optimizer/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,14 @@ def columns(self):

self._columns = []
for column in columns + external_columns:
ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint)
ancestor = column.find_ancestor(
exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint
)
if (
not ancestor
# Window functions can have an ORDER BY clause
or not isinstance(ancestor.parent, exp.Select)
or column.table
or isinstance(ancestor, exp.Select)
or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window))
or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
):
self._columns.append(column)
Expand Down
40 changes: 0 additions & 40 deletions tests/fixtures/optimizer/expand_laterals.sql

This file was deleted.

28 changes: 28 additions & 0 deletions tests/fixtures/optimizer/qualify_columns.sql
Original file line number Diff line number Diff line change
Expand Up @@ -401,3 +401,31 @@ SELECT ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.b) AS row_num FROM x AS x
# dialect: bigquery
SELECT x.b, x.a FROM x LEFT JOIN y ON x.b = y.b QUALIFY ROW_NUMBER() OVER(PARTITION BY x.b ORDER BY x.a DESC) = 1;
SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_NUMBER() OVER (PARTITION BY x.b ORDER BY x.a DESC) = 1;


--------------------------------------
-- Expand laterals
--------------------------------------
# title: expand alias reference
SELECT
x.a + 1 AS i,
i + 1 AS j,
j + 1 AS k
FROM x;
SELECT x.a + 1 AS i, x.a + 1 + 1 AS j, x.a + 1 + 1 + 1 AS k FROM x AS x;

# title: noop - reference comes before alias
# execute: false
SELECT i + 1 AS j, x.a + 1 AS i FROM x;
SELECT i + 1 AS j, x.a + 1 AS i FROM x AS x;

# title: subquery
SELECT
*
FROM (
SELECT
x.a + 1 AS i,
i + 1 AS j
FROM x
);
SELECT _q_0.i AS i, _q_0.j AS j FROM (SELECT x.a + 1 AS i, x.a + 1 + 1 AS j FROM x AS x) AS _q_0;
Loading

0 comments on commit 4dd413b

Please sign in to comment.