Skip to content

Commit

Permalink
Fix(optimizer): ensure TableAlias column names shadow source columns (#…
Browse files Browse the repository at this point in the history
…2002)

* Fix(optimizer): TableAlias column names must shadow source columns

* Naming cleanup

* Refactor

* Fix bug

* More coverage

* Cleanup

* Improve coverage

* Formatting

* Add invalid qualify columns test

* Comment improvement
  • Loading branch information
georgesittas authored Aug 8, 2023
1 parent 900bec3 commit c73790d
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 43 deletions.
14 changes: 7 additions & 7 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,13 @@ def alias(self) -> str:
return self.args["alias"].name
return self.text("alias")

@property
def alias_column_names(self) -> t.List[str]:
table_alias = self.args.get("alias")
if not table_alias:
return []
return [c.name for c in table_alias.args.get("columns") or []]

@property
def name(self) -> str:
return self.text("this")
Expand Down Expand Up @@ -883,13 +890,6 @@ class Predicate(Condition):


class DerivedTable(Expression):
@property
def alias_column_names(self) -> t.List[str]:
table_alias = self.args.get("alias")
if not table_alias:
return []
return [c.name for c in table_alias.args.get("columns") or []]

@property
def selects(self) -> t.List[Expression]:
return self.this.selects if isinstance(self.this, Subqueryable) else []
Expand Down
19 changes: 14 additions & 5 deletions sqlglot/optimizer/pushdown_projections.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,17 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
"""
# Map of Scope to all columns being selected by outer queries.
schema = ensure_schema(schema)
source_column_alias_count = {}
referenced_columns = defaultdict(set)

# We build the scope tree (which is traversed in DFS postorder), then iterate
# over the result in reverse order. This should ensure that the set of selected
# columns for a particular scope are completely build by the time we get to it.
for scope in reversed(traverse_scope(expression)):
parent_selections = referenced_columns.get(scope, {SELECT_ALL})
alias_count = source_column_alias_count.get(scope, 0)

if scope.expression.args.get("distinct") or scope.parent and scope.parent.pivots:
if scope.expression.args.get("distinct") or (scope.parent and scope.parent.pivots):
# We can't remove columns SELECT DISTINCT nor UNION DISTINCT. The same holds if
# we select from a pivoted source in the parent scope.
parent_selections = {SELECT_ALL}
Expand All @@ -59,7 +61,7 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)

if isinstance(scope.expression, exp.Select):
if remove_unused_selections:
_remove_unused_selections(scope, parent_selections, schema)
_remove_unused_selections(scope, parent_selections, schema, alias_count)

if scope.expression.is_star:
continue
Expand All @@ -72,15 +74,19 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
selects[table_name].add(col_name)

# Push the selected columns down to the next scope
for name, (_, source) in scope.selected_sources.items():
for name, (node, source) in scope.selected_sources.items():
if isinstance(source, Scope):
columns = selects.get(name) or set()
referenced_columns[source].update(columns)

column_aliases = node.alias_column_names
if column_aliases:
source_column_alias_count[source] = len(column_aliases)

return expression


def _remove_unused_selections(scope, parent_selections, schema):
def _remove_unused_selections(scope, parent_selections, schema, alias_count):
order = scope.expression.args.get("order")

if order:
Expand All @@ -93,11 +99,14 @@ def _remove_unused_selections(scope, parent_selections, schema):
removed = False
star = False

select_all = SELECT_ALL in parent_selections

for selection in scope.expression.selects:
name = selection.alias_or_name

if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
if select_all or name in parent_selections or name in order_refs or alias_count > 0:
new_selections.append(selection)
alias_count -= 1
else:
if selection.is_star:
star = True
Expand Down
77 changes: 46 additions & 31 deletions sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def qualify_columns(
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver, using_column_tables, pseudocolumns)
_qualify_outputs(scope)

_expand_group_by(scope)
_expand_order_by(scope, resolver)

Expand Down Expand Up @@ -86,7 +87,7 @@ def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) ->
"""
Remove table column aliases.
(e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
"""
for derived_table in derived_tables:
table_alias = derived_table.args.get("alias")
Expand All @@ -112,11 +113,11 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:

columns = {}

for k in scope.selected_sources:
if k in ordered:
for column in resolver.get_source_columns(k):
if column not in columns:
columns[column] = k
for source_name in scope.selected_sources:
if source_name in ordered:
for column_name in resolver.get_source_columns(source_name):
if column_name not in columns:
columns[column_name] = source_name

source_table = ordered[-1]
ordered.append(join_table)
Expand Down Expand Up @@ -218,7 +219,7 @@ def replace_columns(
scope.clear_cache()


def _expand_group_by(scope: Scope):
def _expand_group_by(scope: Scope) -> None:
expression = scope.expression
group = expression.args.get("group")
if not group:
Expand All @@ -228,7 +229,7 @@ def _expand_group_by(scope: Scope):
expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver):
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
order = scope.expression.args.get("order")
if not order:
return
Expand Down Expand Up @@ -447,7 +448,7 @@ def _add_replace_columns(
replace_columns[id(table)] = columns


def _qualify_outputs(scope: Scope):
def _qualify_outputs(scope: Scope) -> None:
"""Ensure all output columns are aliased"""
new_selections = []

Expand Down Expand Up @@ -487,9 +488,9 @@ class Resolver:
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
self.scope = scope
self.schema = schema
self._source_columns = None
self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
self._all_columns = None
self._all_columns: t.Optional[t.Set[str]] = None
self._infer_schema = infer_schema

def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
Expand Down Expand Up @@ -533,61 +534,75 @@ def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
return exp.to_identifier(table_name)

@property
def all_columns(self):
def all_columns(self) -> t.Set[str]:
"""All available columns of all sources in this scope"""
if self._all_columns is None:
self._all_columns = {
column for columns in self._get_all_source_columns().values() for column in columns
}
return self._all_columns

def get_source_columns(self, name, only_visible=False):
"""Resolve the source columns for a given source `name`"""
def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
"""Resolve the source columns for a given source `name`."""
if name not in self.scope.sources:
raise OptimizeError(f"Unknown table: {name}")

source = self.scope.sources[name]

# If referencing a table, return the columns from the schema
if isinstance(source, exp.Table):
return self.schema.column_names(source, only_visible)
columns = self.schema.column_names(source, only_visible)
elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
columns = source.expression.alias_column_names
else:
columns = source.expression.named_selects

if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
return source.expression.alias_column_names
node, _ = self.scope.selected_sources.get(name) or (None, None)
if isinstance(node, Scope):
column_aliases = node.expression.alias_column_names
elif isinstance(node, exp.Expression):
column_aliases = node.alias_column_names
else:
column_aliases = []

# Otherwise, if referencing another scope, return that scope's named selects
return source.expression.named_selects
# If the source's columns are aliased, their aliases shadow the corresponding column names
return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)]

def _get_all_source_columns(self):
def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
if self._source_columns is None:
self._source_columns = {
k: self.get_source_columns(k)
for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
source_name: self.get_source_columns(source_name)
for source_name, source in itertools.chain(
self.scope.selected_sources.items(), self.scope.lateral_sources.items()
)
}
return self._source_columns

def _get_unambiguous_columns(self, source_columns):
def _get_unambiguous_columns(
self, source_columns: t.Dict[str, t.List[str]]
) -> t.Dict[str, str]:
"""
Find all the unambiguous columns in sources.
Args:
source_columns (dict): Mapping of names to source columns
source_columns: Mapping of names to source columns.
Returns:
dict: Mapping of column name to source name
Mapping of column name to source name.
"""
if not source_columns:
return {}

source_columns = list(source_columns.items())
source_columns_pairs = list(source_columns.items())

first_table, first_columns = source_columns[0]
first_table, first_columns = source_columns_pairs[0]
unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
all_columns = set(unambiguous_columns)

for table, columns in source_columns[1:]:
for table, columns in source_columns_pairs[1:]:
unique = self._find_unique_columns(columns)
ambiguous = set(all_columns).intersection(unique)
all_columns.update(columns)

for column in ambiguous:
unambiguous_columns.pop(column, None)
for column in unique.difference(ambiguous):
Expand All @@ -596,7 +611,7 @@ def _get_unambiguous_columns(self, source_columns):
return unambiguous_columns

@staticmethod
def _find_unique_columns(columns):
def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
"""
Find the unique columns in a list of columns.
Expand All @@ -606,7 +621,7 @@ def _find_unique_columns(columns):
This is necessary because duplicate column names are ambiguous.
"""
counts = {}
counts: t.Dict[str, int] = {}
for column in columns:
counts[column] = counts.get(column, 0) + 1
return {column for column, count in counts.items() if count == 1}
37 changes: 37 additions & 0 deletions tests/fixtures/optimizer/optimizer.sql
Original file line number Diff line number Diff line change
Expand Up @@ -907,3 +907,40 @@ JOIN "x" AS "x"
ON "y"."b" = "x"."b"
GROUP BY
"x"."a";

# title: select * from a cte, which had one of its two columns aliased
WITH cte(x, y) AS (SELECT 1, 2) SELECT * FROM cte AS cte(a);
WITH "cte" AS (
SELECT
1 AS "x",
2 AS "y"
)
SELECT
"cte"."a" AS "a",
"cte"."y" AS "y"
FROM "cte" AS "cte"("a");

# title: select single column from a cte using its alias
WITH cte(x) AS (SELECT 1) SELECT a FROM cte AS cte(a);
WITH "cte" AS (
SELECT
1 AS "x"
)
SELECT
"cte"."a" AS "a"
FROM "cte" AS "cte"("a");

# title: joined ctes with a "using" clause, one of which has had its column aliased
WITH m(a) AS (SELECT 1), n(b) AS (SELECT 1) SELECT * FROM m JOIN n AS foo(a) USING (a);
WITH "m" AS (
SELECT
1 AS "a"
), "n" AS (
SELECT
1 AS "b"
)
SELECT
COALESCE("m"."a", "foo"."a") AS "a"
FROM "m"
JOIN "n" AS "foo"("a")
ON "m"."a" = "foo"."a";
9 changes: 9 additions & 0 deletions tests/fixtures/optimizer/pushdown_projections.sql
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ SELECT i.a AS a FROM x AS i LEFT JOIN (SELECT _q_0.a AS a FROM (SELECT x.a AS a
WITH cte AS (SELECT source.a AS a, ROW_NUMBER() OVER (PARTITION BY source.id, source.timestamp ORDER BY source.a DESC) AS index FROM source AS source QUALIFY index) SELECT cte.a AS a FROM cte;
WITH cte AS (SELECT source.a AS a FROM source AS source QUALIFY ROW_NUMBER() OVER (PARTITION BY source.id, source.timestamp ORDER BY source.a DESC)) SELECT cte.a AS a FROM cte;

WITH cte AS (SELECT 1 AS x, 2 AS y, 3 AS z) SELECT cte.a FROM cte AS cte(a);
WITH cte AS (SELECT 1 AS x) SELECT cte.a AS a FROM cte AS cte(a);

WITH cte(x, y, z) AS (SELECT 1, 2, 3) SELECT a, z FROM cte AS cte(a);
WITH cte AS (SELECT 1 AS x, 3 AS z) SELECT cte.a AS a, cte.z AS z FROM cte AS cte(a);

WITH cte(x, y, z) AS (SELECT 1, 2, 3) SELECT a, z FROM (SELECT * FROM cte AS cte(b)) AS cte(a);
WITH cte AS (SELECT 1 AS x, 3 AS z) SELECT cte.a AS a, cte.z AS z FROM (SELECT cte.b AS a, cte.z AS z FROM cte AS cte(b)) AS cte;

--------------------------------------
-- Unknown Star Expansion
--------------------------------------
Expand Down
9 changes: 9 additions & 0 deletions tests/fixtures/optimizer/qualify_columns.sql
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,12 @@ WITH z AS (SELECT x.a AS a, x.b AS b FROM x AS x), q AS (SELECT z.b AS b FROM z)
WITH z AS ((SELECT b FROM x UNION ALL SELECT b FROM y) ORDER BY b) SELECT * FROM z;
WITH z AS ((SELECT x.b AS b FROM x AS x UNION ALL SELECT y.b AS b FROM y AS y) ORDER BY b) SELECT z.b AS b FROM z;

WITH cte(x) AS (SELECT 1) SELECT * FROM cte AS cte(a);
WITH cte AS (SELECT 1 AS x) SELECT cte.a AS a FROM cte AS cte(a);

WITH cte(x, y) AS (SELECT 1, 2) SELECT cte.* FROM cte AS cte(a);
WITH cte AS (SELECT 1 AS x, 2 AS y) SELECT cte.a AS a, cte.y AS y FROM cte AS cte(a);

--------------------------------------
-- Except and Replace
--------------------------------------
Expand Down Expand Up @@ -383,6 +389,9 @@ SELECT x.b AS b FROM t AS t JOIN x AS x ON t.a = x.a;
SELECT a FROM t1 JOIN t2 USING(a);
SELECT COALESCE(t1.a, t2.a) AS a FROM t1 AS t1 JOIN t2 AS t2 ON t1.a = t2.a;

WITH m(a) AS (SELECT 1), n(b) AS (SELECT 1) SELECT * FROM m JOIN n AS foo(a) USING (a);
WITH m AS (SELECT 1 AS a), n AS (SELECT 1 AS b) SELECT COALESCE(m.a, foo.a) AS a FROM m JOIN n AS foo(a) ON m.a = foo.a;

--------------------------------------
-- Hint with table reference
--------------------------------------
Expand Down
1 change: 1 addition & 0 deletions tests/fixtures/optimizer/qualify_columns__invalid.sql
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ SELECT x.a FROM x JOIN y USING (a);
SELECT a, SUM(b) FROM x GROUP BY 3;
SELECT p FROM (SELECT x from xx) y CROSS JOIN yy CROSS JOIN zz
SELECT a FROM (SELECT * FROM x CROSS JOIN y);
SELECT x FROM tbl AS tbl(a);
4 changes: 4 additions & 0 deletions tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,10 @@ def test_alias_column_names(self):
cte = expression.find(exp.CTE)
self.assertEqual(cte.alias_column_names, ["a", "b"])

expression = parse_one("SELECT * FROM tbl AS tbl(a, b)")
table = expression.find(exp.Table)
self.assertEqual(table.alias_column_names, ["a", "b"])

def test_ctes(self):
expression = parse_one("SELECT a FROM x")
self.assertEqual(expression.ctes, [])
Expand Down

0 comments on commit c73790d

Please sign in to comment.