Skip to content

Commit

Permalink
Fix: group and order cannot replace with literals
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Jun 29, 2023
1 parent 0357d63 commit 28e1024
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 17 deletions.
50 changes: 34 additions & 16 deletions sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,17 +170,19 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
if not isinstance(expression, exp.Select):
return

alias_to_expression: t.Dict[str, exp.Expression] = {}
alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

def replace_columns(node: t.Optional[exp.Expression], resolve_table: bool = False) -> None:
def replace_columns(
node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
) -> None:
if not node:
return

for column, *_ in walk_in_scope(node):
if not isinstance(column, exp.Column):
continue
table = resolver.get_table(column.name) if resolve_table and not column.table else None
alias_expr = alias_to_expression.get(column.name)
alias_expr, i = alias_to_expression.get(column.name, (None, 1))
double_agg = (
(alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
if alias_expr
Expand All @@ -190,16 +192,20 @@ def replace_columns(node: t.Optional[exp.Expression], resolve_table: bool = Fals
if table and (not alias_expr or double_agg):
column.set("table", table)
elif not column.table and alias_expr and not double_agg:
column.replace(alias_expr.copy())
if isinstance(alias_expr, exp.Literal):
if literal_index:
column.replace(exp.Literal.number(i))
else:
column.replace(alias_expr.copy())

for projection in scope.selects:
for i, projection in enumerate(scope.selects):
replace_columns(projection)

if isinstance(projection, exp.Alias):
alias_to_expression[projection.alias] = projection.this
alias_to_expression[projection.alias] = (projection.this, i + 1)

replace_columns(expression.args.get("where"))
replace_columns(expression.args.get("group"))
replace_columns(expression.args.get("group"), literal_index=True)
replace_columns(expression.args.get("having"), resolve_table=True)
replace_columns(expression.args.get("qualify"), resolve_table=True)
scope.clear_cache()
Expand Down Expand Up @@ -255,27 +261,39 @@ def _expand_order_by(scope: Scope, resolver: Resolver):
selects = {s.this: exp.column(s.alias_or_name) for s in scope.selects}

for ordered in ordereds:
ordered.set("this", selects.get(ordered.this, ordered.this))
ordered = ordered.this

ordered.replace(
exp.to_identifier(_select_by_pos(scope, ordered).alias)
if ordered.is_int
else selects.get(ordered, ordered)
)


def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
new_nodes = []
for node in expressions:
if node.is_int:
try:
select = scope.selects[int(node.name) - 1]
except IndexError:
raise OptimizeError(f"Unknown output column: {node.name}")
if isinstance(select, exp.Alias):
select = select.this
new_nodes.append(select.copy())
scope.clear_cache()
select = _select_by_pos(scope, t.cast(exp.Literal, node)).this

if isinstance(select, exp.Literal):
new_nodes.append(node)
else:
new_nodes.append(select.copy())
scope.clear_cache()
else:
new_nodes.append(node)

return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """Return the projection that a positional reference (e.g. ``GROUP BY 1``) points at.

    Args:
        scope: Scope whose ``selects`` list is indexed.
        node: Integer literal holding the 1-based position of the projection.

    Returns:
        The projection at that position, passed through ``assert_is(exp.Alias)``.

    Raises:
        OptimizeError: If the position is smaller than 1 or larger than the
            number of projections in ``scope.selects``.
    """
    position = int(node.this)

    # Positions are 1-based. Without this guard, position 0 would turn into
    # index -1 and silently resolve to the *last* projection instead of failing.
    if position < 1:
        raise OptimizeError(f"Unknown output column: {node.name}")

    try:
        select = scope.selects[position - 1]
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}") from None

    return select.assert_is(exp.Alias)


def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
"""Disambiguate columns, ensuring each column specifies a source"""
for column in scope.columns:
Expand Down
17 changes: 17 additions & 0 deletions tests/fixtures/optimizer/qualify_columns.sql
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,23 @@ SELECT x.a AS a, x.b AS b FROM x AS x GROUP BY x.a, x.b;
SELECT a, b FROM x ORDER BY 1, 2;
SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a, x.b;

SELECT 2 FROM x GROUP BY 1;
SELECT 2 AS "2" FROM x AS x GROUP BY 1;

SELECT 'a' AS a FROM x GROUP BY 1;
SELECT 'a' AS a FROM x AS x GROUP BY 1;

# execute: false
-- this query seems to be invalid in postgres and duckdb but valid in bigquery
SELECT 2 a FROM x GROUP BY 1 HAVING a > 1;
SELECT 2 AS a FROM x AS x GROUP BY 1 HAVING a > 1;

SELECT 2 d FROM x GROUP BY d HAVING d > 1;
SELECT 2 AS d FROM x AS x GROUP BY 1 HAVING d > 1;

SELECT 2 d FROM x GROUP BY 1 ORDER BY 1;
SELECT 2 AS d FROM x AS x GROUP BY 1 ORDER BY d;

# execute: false
SELECT DATE(a), DATE(b) AS c FROM x GROUP BY 1, 2;
SELECT DATE(x.a) AS _col_0, DATE(x.b) AS c FROM x AS x GROUP BY DATE(x.a), DATE(x.b);
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/optimizer/tpc-ds/tpc-ds.sql
Original file line number Diff line number Diff line change
Expand Up @@ -5985,7 +5985,7 @@ WITH "date_dim_2" AS (
WHERE
"store"."currency_rank" <= 10 OR "store"."return_rank" <= 10
ORDER BY
'store',
1,
"store"."return_rank",
"store"."currency_rank"
LIMIT 100
Expand Down

0 comments on commit 28e1024

Please sign in to comment.