Skip to content

Commit

Permalink
fix: pushdown predicate to HAVING (#2064)
Browse files Browse the repository at this point in the history
  • Loading branch information
barakalon authored Aug 15, 2023
1 parent 56a3d89 commit 7787342
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 34 deletions.
14 changes: 11 additions & 3 deletions sqlglot/optimizer/pushdown_predicates.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.scope import build_scope, find_in_scope
from sqlglot.optimizer.simplify import simplify


Expand Down Expand Up @@ -81,7 +81,11 @@ def pushdown_cnf(predicates, scope, scope_ref_count):
break
if isinstance(node, exp.Select):
predicate.replace(exp.true())
node.where(replace_aliases(node, predicate), copy=False)
inner_predicate = replace_aliases(node, predicate)
if find_in_scope(inner_predicate, exp.AggFunc):
node.having(inner_predicate, copy=False)
else:
node.where(inner_predicate, copy=False)


def pushdown_dnf(predicates, scope, scope_ref_count):
Expand Down Expand Up @@ -142,7 +146,11 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
if isinstance(node, exp.Join):
node.on(predicate, copy=False)
elif isinstance(node, exp.Select):
node.where(replace_aliases(node, predicate), copy=False)
inner_predicate = replace_aliases(node, predicate)
if find_in_scope(inner_predicate, exp.AggFunc):
node.having(inner_predicate, copy=False)
else:
node.where(inner_predicate, copy=False)


def nodes_for_predicate(predicate, sources, scope_ref_count):
Expand Down
72 changes: 41 additions & 31 deletions sqlglot/optimizer/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import find_new_name
from sqlglot.helper import ensure_collection, find_new_name

logger = logging.getLogger("sqlglot")

Expand Down Expand Up @@ -141,38 +141,10 @@ def walk(self, bfs=True):
return walk_in_scope(self.expression, bfs=bfs)

def find(self, *expression_types, bfs=True):
"""
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Args:
expression_types (type): the expression type(s) to match.
bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching
the criteria was found.
"""
return next(self.find_all(*expression_types, bfs=bfs), None)
return find_in_scope(self.expression, expression_types, bfs=bfs)

def find_all(self, *expression_types, bfs=True):
"""
Returns a generator object which visits all nodes in this scope and only yields those that
match at least one of the specified expression types.
This does NOT traverse into subscopes.
Args:
expression_types (type): the expression type(s) to match.
bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
"""
for expression, *_ in self.walk(bfs=bfs):
if isinstance(expression, expression_types):
yield expression
return find_all_in_scope(self.expression, expression_types, bfs=bfs)

def replace(self, old, new):
"""
Expand Down Expand Up @@ -800,3 +772,41 @@ def walk_in_scope(expression, bfs=True):
for key in ("joins", "laterals", "pivots"):
for arg in node.args.get(key) or []:
yield from walk_in_scope(arg, bfs=bfs)


def find_all_in_scope(expression, expression_types, bfs=True):
"""
Returns a generator object which visits all nodes in this scope and only yields those that
match at least one of the specified expression types.
This does NOT traverse into subscopes.
Args:
expression (exp.Expression):
expression_types (tuple[type]|type): the expression type(s) to match.
bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
"""
for expression, *_ in walk_in_scope(expression, bfs=bfs):
if isinstance(expression, tuple(ensure_collection(expression_types))):
yield expression


def find_in_scope(expression, expression_types, bfs=True):
"""
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Args:
expression (exp.Expression):
expression_types (tuple[type]|type): the expression type(s) to match.
bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching
the criteria was found.
"""
return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
8 changes: 8 additions & 0 deletions tests/fixtures/optimizer/pushdown_predicates.sql
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,11 @@ WITH t1 AS (SELECT x.a, x.b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) A

WITH m AS (SELECT a, b FROM (VALUES (1, 2)) AS a1(a, b)), n AS (SELECT a, b FROM m WHERE m.a = 1), o AS (SELECT a, b FROM m WHERE m.a = 2) SELECT n.a, n.b, n.a, o.b FROM n FULL OUTER JOIN o ON n.a = o.a;
WITH m AS (SELECT a, b FROM (VALUES (1, 2)) AS a1(a, b)), n AS (SELECT a, b FROM m WHERE m.a = 1), o AS (SELECT a, b FROM m WHERE m.a = 2) SELECT n.a, n.b, n.a, o.b FROM n FULL OUTER JOIN o ON n.a = o.a;

-- Pushdown predicate to HAVING (CNF)
SELECT x.cnt AS cnt FROM (SELECT COUNT(1) AS cnt FROM x AS x) AS x WHERE x.cnt > 0;
SELECT x.cnt AS cnt FROM (SELECT COUNT(1) AS cnt FROM x AS x HAVING COUNT(1) > 0) AS x WHERE TRUE;

-- Pushdown predicate to HAVING (DNF)
SELECT x.cnt AS cnt FROM (SELECT COUNT(1) AS cnt, COUNT(x.a) AS cnt_a, COUNT(x.b) AS cnt_b FROM x AS x) AS x WHERE (x.cnt_a > 0 AND x.cnt_b > 0) OR x.cnt > 0;
SELECT x.cnt AS cnt FROM (SELECT COUNT(1) AS cnt, COUNT(x.a) AS cnt_a, COUNT(x.b) AS cnt_b FROM x AS x HAVING COUNT(1) > 0 OR (COUNT(x.a) > 0 AND COUNT(x.b) > 0)) AS x WHERE x.cnt > 0 OR (x.cnt_a > 0 AND x.cnt_b > 0);

0 comments on commit 7787342

Please sign in to comment.