Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat(optimizer): propagate constants #2386

Merged
merged 16 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def flatten(self, unnest=True):
"""
for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__):
if not type(node) is self.__class__:
yield node.unnest() if unnest else node
yield node.unnest() if unnest and not isinstance(node, Subquery) else node
georgesittas marked this conversation as resolved.
Show resolved Hide resolved

def __str__(self) -> str:
return self.sql()
Expand Down
8 changes: 8 additions & 0 deletions sqlglot/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,14 @@ def first(it: t.Iterable[T]) -> T:


def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
"""
Merges a sequence of ranges, represented as tuples (low, high) whose values
belong to some totally-ordered set.

Example:
>>> merge_ranges([(1, 3), (2, 6)])
[(1, 6)]
"""
if not ranges:
return []

Expand Down
37 changes: 29 additions & 8 deletions sqlglot/optimizer/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sqlglot.errors import OptimizeError
from sqlglot.generator import cached_generator
from sqlglot.helper import while_changing
from sqlglot.optimizer.scope import find_all_in_scope
from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort

logger = logging.getLogger("sqlglot")
Expand Down Expand Up @@ -63,15 +64,33 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
return expression


def normalized(expression, dnf=False):
ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
"""
Checks whether a given expression is in a normal form of interest.

return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root))
Example:
>>> from sqlglot import parse_one
>>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
True
>>> normalized(parse_one("(a OR b) AND c")) # Checks CNF by default
True
>>> normalized(parse_one("a AND (b OR c)"), dnf=True)
False

Args:
expression: The expression to check if it's normalized.
dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
"""
ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
return not any(
connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root)
)

def normalization_distance(expression, dnf=False):

def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int:
"""
The difference in the number of predicates between the current expression and the normalized form.
The difference in the number of predicates between a given expression and its normalized form.

This is used as an estimate of the cost of the conversion which is exponential in complexity.

Expand All @@ -82,10 +101,12 @@ def normalization_distance(expression, dnf=False):
4

Args:
expression (sqlglot.Expression): expression to compute distance
dnf (bool): compute to dnf distance instead
expression: The expression to compute the normalization distance for.
dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).

Returns:
int: difference
The normalization distance.
"""
return sum(_predicate_lengths(expression, dnf)) - (
sum(1 for _ in expression.find_all(exp.Connector)) + 1
Expand Down
47 changes: 47 additions & 0 deletions sqlglot/optimizer/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from collections import deque
from decimal import Decimal

import sqlglot
from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope

# Final means that an expression should not be simplified
FINAL = "final"
Expand Down Expand Up @@ -65,6 +67,7 @@ def _simplify(expression, root=True):
node = rewrite_between(node)
node = uniq_sort(node, generate, root)
node = absorb_and_eliminate(node, root)
node = propagate_constants(node, root)
node = simplify_concat(node)

exp.replace_children(node, lambda e: _simplify(e, False))
Expand Down Expand Up @@ -369,6 +372,50 @@ def absorb_and_eliminate(expression, root=True):
return expression


def propagate_constants(expression, root=True):
"""
Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes
SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html
"""

if (
isinstance(expression, exp.And)
and (root or not expression.same_parent)
and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
):
constant_mapping = {}
for eq in find_all_in_scope(expression, exp.EQ):
l, r = eq.left, eq.right

# TODO: create a helper that can be used to detect nested literal expressions such
# as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
georgesittas marked this conversation as resolved.
Show resolved Hide resolved
pass
elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
l, r = r, l
else:
continue

constant_mapping[l] = (id(l), r)
tobymao marked this conversation as resolved.
Show resolved Hide resolved

if constant_mapping:
for column in find_all_in_scope(expression, exp.Column):
parent = column.parent
column_id, constant = constant_mapping.get(column) or (None, None)
if (
column_id is not None
and id(column) != column_id
and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
):
column.replace(constant.copy())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is replacement right? That goes a step further than the SQLite rule, which says it only adds the extra predicate.

Do any planners take advantage of column equality in join conditions?

Copy link
Collaborator Author

@georgesittas georgesittas Oct 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought SQLite does the variable replacement as well, but I'm not 100% sure. Will need to check their source to verify this assumption.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The example doesn’t show replacement. I also wonder if there is different behavior on join conditions.

Copy link
Collaborator Author

@georgesittas georgesittas Oct 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, will dig deeper. Also a good catch, thanks!

Copy link
Collaborator Author

@georgesittas georgesittas Oct 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw I've discovered that there are some gnarly cases that are tricky to take care of - mostly because of coercions (e.g. int to string - see SQLite documentation linked below).

I also wonder if there is different behavior on join conditions.

Interestingly, SQLite doesn't seem to apply this transformation through the ON clause of a LEFT JOIN, nor to any ON clause when the query contains a RIGHT JOIN:

      ...
      /* Do not propagate constants on any ON clause if there is a
      ** RIGHT JOIN anywhere in the query */
      x.mExcludeOn = EP_InnerON | EP_OuterON;
    }else{
      /* Do not propagate constants through the ON clause of a LEFT JOIN */
      x.mExcludeOn = EP_OuterON;
      ...

I'll need to understand their code a bit better to make sure this is the case, though.

Some documentation on SQLite's implementation can be found here (screenshot of their source code).

On the other hand, MariaDB mentions this transformation in the context of WHERE clauses, but towards the end of the page they also have this (note the remark about ON expressions):

Screenshot 2023-10-09 at 5 46 56 AM

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding

The example doesn’t show replacement

It's now clear that SQLite does indeed replace the variable with the value:

Screenshot 2023-10-09 at 6 03 43 AM


return expression


INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
exp.DateAdd: exp.Sub,
exp.DateSub: exp.Add,
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/optimizer/optimizer.sql
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ SELECT
"y"."b" AS "b"
FROM "x" AS "x"
RIGHT JOIN "y_2" AS "y"
ON "x"."a" = "y"."b";
ON "x"."a" = 1;


# title: lateral column alias reference
Expand Down
8 changes: 4 additions & 4 deletions tests/fixtures/optimizer/pushdown_predicates.sql
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ SELECT x.a FROM (SELECT * FROM x) AS x JOIN y WHERE (x.a = y.a AND x.a = 1 AND x
SELECT x.a FROM (SELECT * FROM x) AS x JOIN y ON x.a = y.a WHERE TRUE;

SELECT x.a FROM (SELECT * FROM x) AS x JOIN y WHERE (x.a = y.a AND x.a = 1 AND x.b = 1) OR x.a = y.b;
SELECT x.a FROM (SELECT * FROM x) AS x JOIN y ON x.a = y.a OR x.a = y.b WHERE (x.a = y.a AND x.a = 1 AND x.b = 1) OR x.a = y.b;
SELECT x.a FROM (SELECT * FROM x) AS x JOIN y ON x.a = y.b OR 1 = y.a WHERE x.a = y.b OR (x.a = 1 AND x.b = 1 AND 1 = y.a);

SELECT x.a FROM (SELECT x.a AS a, x.b * 1 AS c FROM x) AS x WHERE x.c = 1;
SELECT x.a FROM (SELECT x.a AS a, x.b * 1 AS c FROM x WHERE x.b * 1 = 1) AS x WHERE TRUE;
Expand All @@ -23,13 +23,13 @@ SELECT x.a AS a FROM (SELECT x.a FROM x AS x) AS x JOIN y WHERE x.a = 1 AND x.b
SELECT x.a AS a FROM (SELECT x.a FROM x AS x WHERE x.a = 1 AND x.b = 1) AS x JOIN y ON x.c = 1 OR y.c = 1 WHERE TRUE AND TRUE AND (TRUE);

SELECT x.a FROM x AS x JOIN (SELECT y.a FROM y AS y) AS y ON y.a = 1 AND x.a = y.a;
SELECT x.a FROM x AS x JOIN (SELECT y.a FROM y AS y WHERE y.a = 1) AS y ON x.a = y.a AND TRUE;
SELECT x.a FROM x AS x JOIN (SELECT y.a FROM y AS y WHERE y.a = 1) AS y ON x.a = 1 AND TRUE;

SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y) AS y ON y.a = 1 WHERE x.a = 1 AND x.b = 1 AND y.a = x.a;
SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON y.a = x.a AND TRUE WHERE x.a = 1 AND x.b = 1 AND TRUE;
SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON TRUE WHERE x.a = 1 AND x.b = 1 AND TRUE;

SELECT x.a AS a FROM x AS x CROSS JOIN (SELECT * FROM y AS y) AS y WHERE x.a = 1 AND x.b = 1 AND y.a = x.a AND y.a = 1;
SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON y.a = x.a AND TRUE WHERE x.a = 1 AND x.b = 1 AND TRUE AND TRUE;
SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON TRUE WHERE x.a = 1 AND x.b = 1 AND TRUE;

with t1 as (SELECT x.a, x.b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) as row_num FROM x) SELECT t1.a, t1.b FROM t1 WHERE row_num = 1;
WITH t1 AS (SELECT x.a, x.b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) AS row_num FROM x) SELECT t1.a, t1.b FROM t1 WHERE row_num = 1;
Expand Down
47 changes: 46 additions & 1 deletion tests/fixtures/optimizer/simplify.sql
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ t0.x = t1.x AND t0.y < t1.y AND t0.y <= t1.y;
t0.x = t1.x AND t0.y < t1.y AND t0.y <= t1.y;

--------------------------------------
-- Coalesce
-- COALESCE
--------------------------------------
COALESCE(x);
x;
Expand Down Expand Up @@ -864,3 +864,48 @@ x < CAST('2020-01-07' AS DATE);

x - INTERVAL '1' day = CAST(y AS DATE);
x - INTERVAL '1' day = CAST(y AS DATE);

--------------------------------------
-- Constant Propagation
--------------------------------------
x = 5 AND y = x;
x = 5 AND y = 5;

5 = x AND y = x;
y = 5 AND 5 = x;

x = 5 OR y = x;
x = 5 OR y = x;

(x = 5 AND y = x) OR y = 1;
(x = 5 AND y = 5) OR y = 1;

t.x = 5 AND y = x;
t.x = 5 AND y = x;

t.x = 'a' AND y = CONCAT_WS('-', t.x, 'b');
t.x = 'a' AND y = 'a-b';

x = 5 AND y = x AND y + 1 < 5;
FALSE;

x = 5 AND x = 6;
FALSE;

x = 5 AND (y = x OR z = 1);
x = 5 AND (y = x OR z = 1);

x = 5 AND x + 3 = 8;
x = 5;

x = 5 AND (SELECT x FROM t WHERE y = 1);
x = 5 AND (SELECT x FROM t WHERE y = 1);

x = 1 AND y > 0 AND (SELECT z = 5 FROM t WHERE y = 1);
x = 1 AND y > 0 AND (SELECT z = 5 FROM t WHERE y = 1);

x = 1 AND x = y AND (SELECT z FROM t WHERE a AND (b OR c));
x = 1 AND (SELECT z FROM t WHERE a AND (b OR c)) AND 1 = y;

SELECT * FROM t1, t2, t3 WHERE t1.a = 39 AND t2.b = t1.a AND t3.c = t2.b;
SELECT * FROM t1, t2, t3 WHERE t1.a = 39 AND t2.b = 39 AND t3.c = 39;
31 changes: 10 additions & 21 deletions tests/fixtures/optimizer/tpc-ds/tpc-ds.sql
Original file line number Diff line number Diff line change
Expand Up @@ -2029,18 +2029,7 @@ JOIN "date_dim" AS "date_dim"
ON "date_dim"."d_year" = 2001
AND "store_sales"."ss_sold_date_sk" = "date_dim"."d_date_sk"
JOIN "household_demographics" AS "household_demographics"
ON "customer_demographics"."cd_demo_sk" = "store_sales"."ss_cdemo_sk"
AND "customer_demographics"."cd_education_status" = 'Advanced Degree'
AND "customer_demographics"."cd_education_status" = 'Primary'
AND "customer_demographics"."cd_education_status" = 'Secondary'
AND "customer_demographics"."cd_marital_status" = 'D'
AND "customer_demographics"."cd_marital_status" = 'M'
AND "customer_demographics"."cd_marital_status" = 'U'
AND "household_demographics"."hd_dep_count" = 1
AND "household_demographics"."hd_dep_count" = 3
AND "store_sales"."ss_hdemo_sk" = "household_demographics"."hd_demo_sk"
AND "store_sales"."ss_sales_price" <= 100.00
AND "store_sales"."ss_sales_price" >= 150.00
ON FALSE
georgesittas marked this conversation as resolved.
Show resolved Hide resolved
JOIN "store" AS "store"
ON "store"."s_store_sk" = "store_sales"."ss_store_sk";

Expand Down Expand Up @@ -11687,28 +11676,28 @@ JOIN "customer_demographics" AS "cd1"
ON "cd1"."cd_demo_sk" = "web_returns"."wr_refunded_cdemo_sk"
AND (
(
"cd1"."cd_education_status" = "cd2"."cd_education_status"
AND "cd1"."cd_education_status" = 'Advanced Degree'
AND "cd1"."cd_marital_status" = "cd2"."cd_marital_status"
"cd1"."cd_education_status" = 'Advanced Degree'
AND "cd1"."cd_marital_status" = 'M'
AND "web_sales"."ws_sales_price" <= 200.00
AND "web_sales"."ws_sales_price" >= 150.00
AND 'Advanced Degree' = "cd2"."cd_education_status"
AND 'M' = "cd2"."cd_marital_status"
)
OR (
"cd1"."cd_education_status" = "cd2"."cd_education_status"
AND "cd1"."cd_education_status" = 'Primary'
AND "cd1"."cd_marital_status" = "cd2"."cd_marital_status"
"cd1"."cd_education_status" = 'Primary'
AND "cd1"."cd_marital_status" = 'W'
AND "web_sales"."ws_sales_price" <= 150.00
AND "web_sales"."ws_sales_price" >= 100.00
AND 'Primary' = "cd2"."cd_education_status"
AND 'W' = "cd2"."cd_marital_status"
)
OR (
"cd1"."cd_education_status" = "cd2"."cd_education_status"
AND "cd1"."cd_education_status" = 'Secondary'
AND "cd1"."cd_marital_status" = "cd2"."cd_marital_status"
"cd1"."cd_education_status" = 'Secondary'
AND "cd1"."cd_marital_status" = 'D'
AND "web_sales"."ws_sales_price" <= 100.00
AND "web_sales"."ws_sales_price" >= 50.00
AND 'D' = "cd2"."cd_marital_status"
AND 'Secondary' = "cd2"."cd_education_status"
)
)
GROUP BY
Expand Down