Skip to content

Commit

Permalink
Refactor: move group by finalizer to simplify because that is who cares
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Jun 30, 2023
1 parent 998969e commit d8eeda2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 21 deletions.
19 changes: 0 additions & 19 deletions sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,25 +220,6 @@ def _expand_group_by(scope: Scope):
group.set("expressions", _expand_positional_references(scope, group.expressions))
expression.set("group", group)

# group by expressions cannot be simplified, for example
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
# the projection must exactly match the group by key
groups = set(group.expressions)
group.meta["final"] = True

for e in expression.selects:
for node, *_ in e.walk():
if node in groups:
e.meta["final"] = True
break

having = expression.args.get("having")
if having:
for node, *_ in having.walk():
if node in groups:
having.meta["final"] = True
break


def _expand_order_by(scope: Scope, resolver: Resolver):
order = scope.expression.args.get("order")
Expand Down
28 changes: 26 additions & 2 deletions sqlglot/optimizer/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from sqlglot.generator import cached_generator
from sqlglot.helper import first, while_changing

# Final means that an expression should not be simplified
FINAL = "final"


def simplify(expression):
"""
Expand All @@ -27,8 +30,29 @@ def simplify(expression):

generate = cached_generator()

# group by expressions cannot be simplified, for example
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
# the projection must exactly match the group by key
for group in expression.find_all(exp.Group):
select = group.parent
groups = set(group.expressions)
group.meta[FINAL] = True

for e in select.selects:
for node, *_ in e.walk():
if node in groups:
e.meta[FINAL] = True
break

having = select.args.get("having")
if having:
for node, *_ in having.walk():
if node in groups:
having.meta[FINAL] = True
break

def _simplify(expression, root=True):
if expression.meta.get("final"):
if expression.meta.get(FINAL):
return expression
node = expression
node = rewrite_between(node)
Expand Down Expand Up @@ -58,7 +82,7 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression:
"""
if isinstance(expression, exp.Between):
return exp.and_(
exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
exp.GTE(this=expression.this.copy(), expression=expression.args[FINAL]),
exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
copy=False,
)
Expand Down

0 comments on commit d8eeda2

Please sign in to comment.