Skip to content

Commit

Permalink
Feat(optimizer): propagate constants (#2386)
Browse files: browse the repository at this point in the history
* Feat!(optimizer): propagate constants

* Rephrase docstring

* Type hint fix

* Don't replace variables that are compared to NULL

* Formatting

* Cleanup

* Use find_all_in_scope instead of find_all

* Leave a TODO comment to add helper that detects literals

* Fix another bug involving find_all vs find_all_in_scope

* Move normalized helper in expressions.py

* Move normalized back to normalize.py, refactor..

* Update comment

* Increase test coverage

* Increase test coverage

* Make the constant propagation rule opt-in

* Fix optimize_joins bug
  • Loading branch information
georgesittas authored Oct 9, 2023
1 parent a849794 commit cca58dd
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 25 deletions.
2 changes: 1 addition & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def flatten(self, unnest=True):
"""
for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__):
if not type(node) is self.__class__:
yield node.unnest() if unnest else node
yield node.unnest() if unnest and not isinstance(node, Subquery) else node

def __str__(self) -> str:
return self.sql()
Expand Down
8 changes: 8 additions & 0 deletions sqlglot/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,14 @@ def first(it: t.Iterable[T]) -> T:


def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
"""
Merges a sequence of ranges, represented as tuples (low, high) whose values
belong to some totally-ordered set.
Example:
>>> merge_ranges([(1, 3), (2, 6)])
[(1, 6)]
"""
if not ranges:
return []

Expand Down
37 changes: 29 additions & 8 deletions sqlglot/optimizer/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sqlglot.errors import OptimizeError
from sqlglot.generator import cached_generator
from sqlglot.helper import while_changing
from sqlglot.optimizer.scope import find_all_in_scope
from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort

logger = logging.getLogger("sqlglot")
Expand Down Expand Up @@ -63,15 +64,33 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
return expression


def normalized(expression, dnf=False):
ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
"""
Checks whether a given expression is in a normal form of interest.
return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root))
Example:
>>> from sqlglot import parse_one
>>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
True
>>> normalized(parse_one("(a OR b) AND c")) # Checks CNF by default
True
>>> normalized(parse_one("a AND (b OR c)"), dnf=True)
False
Args:
expression: The expression to check if it's normalized.
dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
"""
ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
return not any(
connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root)
)

def normalization_distance(expression, dnf=False):

def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int:
"""
The difference in the number of predicates between the current expression and the normalized form.
The difference in the number of predicates between a given expression and its normalized form.
This is used as an estimate of the cost of the conversion which is exponential in complexity.
Expand All @@ -82,10 +101,12 @@ def normalization_distance(expression, dnf=False):
4
Args:
expression (sqlglot.Expression): expression to compute distance
dnf (bool): compute to dnf distance instead
expression: The expression to compute the normalization distance for.
dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
Returns:
int: difference
The normalization distance.
"""
return sum(_predicate_lengths(expression, dnf)) - (
sum(1 for _ in expression.find_all(exp.Connector)) + 1
Expand Down
6 changes: 5 additions & 1 deletion sqlglot/optimizer/optimize_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,14 @@ def optimize_joins(expression):
if len(other_table_names(dep)) < 2:
continue

operator = type(on)
for predicate in on.flatten():
if name in exp.column_table_names(predicate):
predicate.replace(exp.true())
join.on(predicate, copy=False)
predicate = exp._combine(
[join.args.get("on"), predicate], operator, copy=False
)
join.on(predicate, append=False, copy=False)

expression = reorder_joins(expression)
expression = normalize(expression)
Expand Down
53 changes: 52 additions & 1 deletion sqlglot/optimizer/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from collections import deque
from decimal import Decimal

import sqlglot
from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope

# Final means that an expression should not be simplified
FINAL = "final"
Expand All @@ -17,7 +19,7 @@ class UnsupportedUnit(Exception):
pass


def simplify(expression):
def simplify(expression, constant_propagation=False):
"""
Rewrite sqlglot AST to simplify expressions.
Expand All @@ -29,6 +31,8 @@ def simplify(expression):
Args:
expression (sqlglot.Expression): expression to simplify
constant_propagation: whether or not the constant propagation rule should be used
Returns:
sqlglot.Expression: simplified expression
"""
Expand Down Expand Up @@ -67,6 +71,9 @@ def _simplify(expression, root=True):
node = absorb_and_eliminate(node, root)
node = simplify_concat(node)

if constant_propagation:
node = propagate_constants(node, root)

exp.replace_children(node, lambda e: _simplify(e, False))

# Post-order transformations
Expand Down Expand Up @@ -369,6 +376,50 @@ def absorb_and_eliminate(expression, root=True):
return expression


def _is_conjunct(node, conjunction):
    """
    Check whether `node` is an operand of `conjunction`, i.e. whether every ancestor
    between the two is either an AND connector or a pair of parentheses.
    """
    parent = node.parent
    while parent is not None and parent is not conjunction:
        if not isinstance(parent, (exp.And, exp.Paren)):
            return False
        parent = parent.parent
    return parent is conjunction


def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Only equalities that are operands of the conjunction itself are used as a source of
    constants. An equality nested under another node, e.g. NOT (b = 5) or a CASE branch,
    is not known to hold whenever the conjunction holds, so nothing is propagated from it.

    Reference: https://www.sqlite.org/optoverview.html

    Args:
        expression: the expression to rewrite; matching columns are replaced in place.
        root: when False, the rule is only applied if `expression.same_parent` is falsy
            (presumably meaning the AND isn't nested directly under another AND, so that
            a chain is only processed once from its topmost node - TODO confirm).

    Returns:
        The given expression, with columns replaced by the literals they're equal to.
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # Maps a column to (id of the column node that defines it, literal it's equal to)
        constant_mapping = {}
        for eq in find_all_in_scope(expression, exp.EQ):
            # An equality that's nested under e.g. NOT, CASE or a function call isn't
            # guaranteed to be true for the conjunction, so it can't yield a constant
            if not _is_conjunct(eq, expression):
                continue

            left, right = eq.left, eq.right

            # TODO: create a helper that can be used to detect nested literal expressions such
            # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
            if isinstance(left, exp.Column) and isinstance(right, exp.Literal):
                pass
            elif isinstance(right, exp.Column) and isinstance(left, exp.Literal):
                left, right = right, left
            else:
                continue

            constant_mapping[left] = (id(left), right)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    # Don't replace the occurrence that defines the constant, e.g. b in b = 5
                    and id(column) != column_id
                    # Leave columns that are compared to NULL untouched
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression


INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
exp.DateAdd: exp.Sub,
exp.DateSub: exp.Add,
Expand Down
47 changes: 46 additions & 1 deletion tests/fixtures/optimizer/simplify.sql
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ t0.x = t1.x AND t0.y < t1.y AND t0.y <= t1.y;
t0.x = t1.x AND t0.y < t1.y AND t0.y <= t1.y;

--------------------------------------
-- Coalesce
-- COALESCE
--------------------------------------
COALESCE(x);
x;
Expand Down Expand Up @@ -864,3 +864,48 @@ x < CAST('2020-01-07' AS DATE);

x - INTERVAL '1' day = CAST(y AS DATE);
x - INTERVAL '1' day = CAST(y AS DATE);

--------------------------------------
-- Constant Propagation
--------------------------------------
x = 5 AND y = x;
x = 5 AND y = 5;

5 = x AND y = x;
y = 5 AND 5 = x;

x = 5 OR y = x;
x = 5 OR y = x;

(x = 5 AND y = x) OR y = 1;
(x = 5 AND y = 5) OR y = 1;

t.x = 5 AND y = x;
t.x = 5 AND y = x;

t.x = 'a' AND y = CONCAT_WS('-', t.x, 'b');
t.x = 'a' AND y = 'a-b';

x = 5 AND y = x AND y + 1 < 5;
FALSE;

x = 5 AND x = 6;
FALSE;

x = 5 AND (y = x OR z = 1);
x = 5 AND (y = x OR z = 1);

x = 5 AND x + 3 = 8;
x = 5;

x = 5 AND (SELECT x FROM t WHERE y = 1);
x = 5 AND (SELECT x FROM t WHERE y = 1);

x = 1 AND y > 0 AND (SELECT z = 5 FROM t WHERE y = 1);
x = 1 AND y > 0 AND (SELECT z = 5 FROM t WHERE y = 1);

x = 1 AND x = y AND (SELECT z FROM t WHERE a AND (b OR c));
x = 1 AND (SELECT z FROM t WHERE a AND (b OR c)) AND 1 = y;

SELECT * FROM t1, t2, t3 WHERE t1.a = 39 AND t2.b = t1.a AND t3.c = t2.b;
SELECT * FROM t1, t2, t3 WHERE t1.a = 39 AND t2.b = 39 AND t3.c = 39;
39 changes: 27 additions & 12 deletions tests/fixtures/optimizer/tpc-ds/tpc-ds.sql
Original file line number Diff line number Diff line change
Expand Up @@ -2029,18 +2029,33 @@ JOIN "date_dim" AS "date_dim"
ON "date_dim"."d_year" = 2001
AND "store_sales"."ss_sold_date_sk" = "date_dim"."d_date_sk"
JOIN "household_demographics" AS "household_demographics"
ON "customer_demographics"."cd_demo_sk" = "store_sales"."ss_cdemo_sk"
AND "customer_demographics"."cd_education_status" = 'Advanced Degree'
AND "customer_demographics"."cd_education_status" = 'Primary'
AND "customer_demographics"."cd_education_status" = 'Secondary'
AND "customer_demographics"."cd_marital_status" = 'D'
AND "customer_demographics"."cd_marital_status" = 'M'
AND "customer_demographics"."cd_marital_status" = 'U'
AND "household_demographics"."hd_dep_count" = 1
AND "household_demographics"."hd_dep_count" = 3
AND "store_sales"."ss_hdemo_sk" = "household_demographics"."hd_demo_sk"
AND "store_sales"."ss_sales_price" <= 100.00
AND "store_sales"."ss_sales_price" >= 150.00
ON (
"customer_demographics"."cd_demo_sk" = "store_sales"."ss_cdemo_sk"
AND "customer_demographics"."cd_education_status" = 'Advanced Degree'
AND "customer_demographics"."cd_marital_status" = 'U'
AND "household_demographics"."hd_dep_count" = 3
AND "store_sales"."ss_hdemo_sk" = "household_demographics"."hd_demo_sk"
AND "store_sales"."ss_sales_price" <= 150.00
AND "store_sales"."ss_sales_price" >= 100.00
)
OR (
"customer_demographics"."cd_demo_sk" = "store_sales"."ss_cdemo_sk"
AND "customer_demographics"."cd_education_status" = 'Primary'
AND "customer_demographics"."cd_marital_status" = 'M'
AND "household_demographics"."hd_dep_count" = 1
AND "store_sales"."ss_hdemo_sk" = "household_demographics"."hd_demo_sk"
AND "store_sales"."ss_sales_price" <= 100.00
AND "store_sales"."ss_sales_price" >= 50.00
)
OR (
"customer_demographics"."cd_demo_sk" = "store_sales"."ss_cdemo_sk"
AND "customer_demographics"."cd_education_status" = 'Secondary'
AND "customer_demographics"."cd_marital_status" = 'D'
AND "household_demographics"."hd_dep_count" = 1
AND "store_sales"."ss_hdemo_sk" = "household_demographics"."hd_demo_sk"
AND "store_sales"."ss_sales_price" <= 200.00
AND "store_sales"."ss_sales_price" >= 150.00
)
JOIN "store" AS "store"
ON "store"."s_store_sk" = "store_sales"."ss_store_sk";

Expand Down
6 changes: 5 additions & 1 deletion tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ def normalize(expression, **kwargs):
return optimizer.simplify.simplify(expression)


def simplify(expression, **kwargs):
    """
    Test wrapper around optimizer.simplify.simplify that always passes
    constant_propagation=True and forwards any other keyword arguments unchanged.
    """
    run_simplify = optimizer.simplify.simplify
    return run_simplify(
        expression,
        constant_propagation=True,
        **kwargs,
    )


class TestOptimizer(unittest.TestCase):
maxDiff = None

Expand Down Expand Up @@ -271,7 +275,7 @@ def test_pushdown_projection(self):
self.check_file("pushdown_projections", pushdown_projections, schema=self.schema)

def test_simplify(self):
self.check_file("simplify", optimizer.simplify.simplify)
self.check_file("simplify", simplify)

expression = parse_one("TRUE AND TRUE AND TRUE")
self.assertEqual(exp.true(), optimizer.simplify.simplify(expression))
Expand Down

0 comments on commit cca58dd

Please sign in to comment.