Skip to content

Commit

Permalink
Fix(optimizer): don't propagate equality constraints from IF/CASE out…
Browse files Browse the repository at this point in the history
…wards (#2396)
  • Loading branch information
georgesittas authored Oct 11, 2023
1 parent 0d3b77d commit d7021d1
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 23 deletions.
18 changes: 11 additions & 7 deletions sqlglot/optimizer/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ def _ensure_collected(self):
if not self._collected:
self._collect()

def walk(self, bfs=True):
return walk_in_scope(self.expression, bfs=bfs)
def walk(self, bfs=True, prune=None):
    """
    Walk the nodes of this scope's expression via `walk_in_scope`.

    Args:
        bfs (bool): forwarded to `walk_in_scope`; BFS traversal order when True,
            DFS otherwise.
        prune ((node, parent, arg_key) -> bool): optional callable forwarded to
            `walk_in_scope`; when it returns True, traversal of that branch stops.

    Returns:
        The generator produced by `walk_in_scope` for `self.expression`.
    """
    # Forward the caller's `prune` callback; passing a literal None here would
    # accept the argument but silently ignore it.
    return walk_in_scope(self.expression, bfs=bfs, prune=prune)

def find(self, *expression_types, bfs=True):
    """
    Delegate to `find_in_scope` for this scope's expression.

    Args:
        *expression_types: passed through as a single tuple to `find_in_scope`
            (presumably the expression classes to match -- confirm against
            `find_in_scope`).
        bfs (bool): forwarded unchanged to `find_in_scope`.

    Returns:
        Whatever `find_in_scope` returns for these arguments.
    """
    return find_in_scope(self.expression, expression_types, bfs=bfs)
Expand Down Expand Up @@ -731,7 +731,7 @@ def _traverse_ddl(scope):
yield from _traverse_scope(query_scope)


def walk_in_scope(expression, bfs=True):
def walk_in_scope(expression, bfs=True, prune=None):
"""
Returns a generator object which visits all nodes in the syntax tree, stopping at
nodes that start child scopes.
Expand All @@ -740,16 +740,20 @@ def walk_in_scope(expression, bfs=True):
expression (exp.Expression):
bfs (bool): if set to True the BFS traversal order will be applied,
otherwise the DFS traversal will be used instead.
prune ((node, parent, arg_key) -> bool): callable that returns True if
the generator should stop traversing this branch of the tree.
Yields:
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
"""
# We'll use this variable to pass state into the dfs generator.
# Whenever we set it to True, we exclude a subtree from traversal.
prune = False
crossed_scope_boundary = False

for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
prune = False
for node, parent, key in expression.walk(
bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
):
crossed_scope_boundary = False

yield node, parent, key

Expand All @@ -765,7 +769,7 @@ def walk_in_scope(expression, bfs=True):
or isinstance(node, exp.UDTF)
or isinstance(node, exp.Subqueryable)
):
prune = True
crossed_scope_boundary = True

if isinstance(node, (exp.Subquery, exp.UDTF)):
# The following args are not actually in the inner scope, so we should visit them
Expand Down
29 changes: 15 additions & 14 deletions sqlglot/optimizer/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope

# Final means that an expression should not be simplified
FINAL = "final"
Expand Down Expand Up @@ -392,19 +392,20 @@ def propagate_constants(expression, root=True):
and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
):
constant_mapping = {}
for eq in find_all_in_scope(expression, exp.EQ):
l, r = eq.left, eq.right

# TODO: create a helper that can be used to detect nested literal expressions such
# as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
pass
elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
l, r = r, l
else:
continue

constant_mapping[l] = (id(l), r)
for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
if isinstance(expr, exp.EQ):
l, r = expr.left, expr.right

# TODO: create a helper that can be used to detect nested literal expressions such
# as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
pass
elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
l, r = r, l
else:
continue

constant_mapping[l] = (id(l), r)

if constant_mapping:
for column in find_all_in_scope(expression, exp.Column):
Expand Down
16 changes: 14 additions & 2 deletions tests/fixtures/optimizer/simplify.sql
Original file line number Diff line number Diff line change
Expand Up @@ -907,5 +907,17 @@ x = 1 AND y > 0 AND (SELECT z = 5 FROM t WHERE y = 1);
x = 1 AND x = y AND (SELECT z FROM t WHERE a AND (b OR c));
x = 1 AND (SELECT z FROM t WHERE a AND (b OR c)) AND 1 = y;

SELECT * FROM t1, t2, t3 WHERE t1.a = 39 AND t2.b = t1.a AND t3.c = t2.b;
SELECT * FROM t1, t2, t3 WHERE t1.a = 39 AND t2.b = 39 AND t3.c = 39;
t1.a = 39 AND t2.b = t1.a AND t3.c = t2.b;
t1.a = 39 AND t2.b = 39 AND t3.c = 39;

x = 1 AND CASE WHEN x = 5 THEN FALSE ELSE TRUE END;
x = 1 AND CASE WHEN FALSE THEN FALSE ELSE TRUE END;

x = 1 AND IF(x = 5, FALSE, TRUE);
x = 1 AND CASE WHEN FALSE THEN FALSE ELSE TRUE END;

x = y AND CASE WHEN x = 5 THEN FALSE ELSE TRUE END;
x = y AND CASE WHEN x = 5 THEN FALSE ELSE TRUE END;

x = 1 AND CASE WHEN y = 5 THEN x = z END;
x = 1 AND CASE WHEN y = 5 THEN 1 = z END;

0 comments on commit d7021d1

Please sign in to comment.