From 2f8bb1390d444af1828c0c73945e53cbcb659e41 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Sun, 8 Oct 2023 20:45:04 +0300 Subject: [PATCH 01/16] Feat!(optimizer): propagate constants --- sqlglot/helper.py | 8 ++++ sqlglot/optimizer/simplify.py | 39 +++++++++++++++++ tests/fixtures/optimizer/optimizer.sql | 2 +- .../optimizer/pushdown_predicates.sql | 8 ++-- tests/fixtures/optimizer/simplify.sql | 43 ++++++++++++++++--- tests/fixtures/optimizer/tpc-ds/tpc-ds.sql | 31 +++++-------- 6 files changed, 100 insertions(+), 31 deletions(-) diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 00d49ae389..74b61e39cf 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -441,6 +441,14 @@ def first(it: t.Iterable[T]) -> T: def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]: + """ + Merges a sequence of ranges, represented as tuples (low, high) whose values + belong to some totally-ordered set. + + Example: + >>> merge_ranges([(1, 3), (2, 6)]) + [(1, 6)] + """ if not ranges: return [] diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index cd266eee56..418fcb4120 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -65,6 +65,7 @@ def _simplify(expression, root=True): node = rewrite_between(node) node = uniq_sort(node, generate, root) node = absorb_and_eliminate(node, root) + node = propagate_constants(node, root) node = simplify_concat(node) exp.replace_children(node, lambda e: _simplify(e, False)) @@ -369,6 +370,44 @@ def absorb_and_eliminate(expression, root=True): return expression +def propagate_constants(expression, root=True): + """ + Propagate constants for conjunctions normalized into DNF: + + SELECT * FROM t WHERE a = b AND b = 5 becomes + SELECT * FROM t WHERE a = 5 AND b = 5 + + Reference: https://www.sqlite.org/optoverview.html + """ + from sqlglot.optimizer.normalize import normalized + + if ( + isinstance(expression, exp.And) + and (root or not expression.same_parent) + and normalized(expression, dnf=True) + ): + constant_mapping: t.Dict[exp.Column, [int, exp.Literal]] = {} + for eq in expression.find_all(exp.EQ): + l, r = eq.left, eq.right + + if isinstance(l, exp.Column) and isinstance(r, exp.Literal): + pass + elif isinstance(r, exp.Column) and isinstance(l, exp.Literal): + l, r = r, l + else: + continue + + constant_mapping[l] = (id(l), r) + + if constant_mapping: + for column in expression.find_all(exp.Column): + id_and_constant = constant_mapping.get(column) + if id_and_constant and id(column) != id_and_constant[0]: + column.replace(id_and_constant[1].copy()) + + return expression + + INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { exp.DateAdd: exp.Sub, exp.DateSub: exp.Add, diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index 4cc62c9b1f..70c68114c0 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -369,7 +369,7 @@ SELECT "y"."b" AS "b" FROM "x" AS "x" RIGHT JOIN "y_2" AS "y" - ON "x"."a" = "y"."b"; + ON "x"."a" = 1; # title: lateral column alias reference diff --git a/tests/fixtures/optimizer/pushdown_predicates.sql b/tests/fixtures/optimizer/pushdown_predicates.sql index cfa69fbda4..61c1ee2207 100644 --- a/tests/fixtures/optimizer/pushdown_predicates.sql +++ b/tests/fixtures/optimizer/pushdown_predicates.sql @@ -11,7 +11,7 @@ SELECT x.a FROM (SELECT * FROM x) AS x JOIN y WHERE (x.a = y.a AND x.a = 1 AND x SELECT x.a FROM (SELECT * FROM x) AS x JOIN y ON x.a = y.a WHERE TRUE; SELECT x.a FROM (SELECT * FROM x) AS x JOIN y WHERE (x.a = y.a AND x.a = 1 AND x.b = 1) OR x.a = y.b; -SELECT x.a FROM (SELECT * FROM x) AS x JOIN y ON x.a = y.a OR x.a = y.b WHERE (x.a = y.a AND x.a = 1 AND x.b = 1) OR x.a = y.b; +SELECT x.a FROM (SELECT * FROM x) AS x JOIN y ON x.a = y.b OR 1 = y.a WHERE x.a = y.b OR (x.a = 1 AND x.b = 1 AND 1 = y.a); SELECT x.a FROM (SELECT x.a AS a, x.b * 1 AS c FROM x) AS x WHERE x.c = 1; SELECT x.a FROM (SELECT x.a AS a, x.b * 1 AS c FROM x WHERE x.b * 1 = 1) AS x WHERE TRUE; @@ -23,13 +23,13 @@ SELECT x.a AS a FROM (SELECT x.a FROM x AS x) AS x JOIN y WHERE x.a = 1 AND x.b SELECT x.a AS a FROM (SELECT x.a FROM x AS x WHERE x.a = 1 AND x.b = 1) AS x JOIN y ON x.c = 1 OR y.c = 1 WHERE TRUE AND TRUE AND (TRUE); SELECT x.a FROM x AS x JOIN (SELECT y.a FROM y AS y) AS y ON y.a = 1 AND x.a = y.a; -SELECT x.a FROM x AS x JOIN (SELECT y.a FROM y AS y WHERE y.a = 1) AS y ON x.a = y.a AND TRUE; +SELECT x.a FROM x AS x JOIN (SELECT y.a FROM y AS y WHERE y.a = 1) AS y ON x.a = 1 AND TRUE; SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y) AS y ON y.a = 1 WHERE x.a = 1 AND x.b = 1 AND y.a = x.a; -SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON y.a = x.a AND TRUE WHERE x.a = 1 AND x.b = 1 AND TRUE; +SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON TRUE WHERE x.a = 1 AND x.b = 1 AND TRUE; SELECT x.a AS a FROM x AS x CROSS JOIN (SELECT * FROM y AS y) AS y WHERE x.a = 1 AND x.b = 1 AND y.a = x.a AND y.a = 1; -SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON y.a = x.a AND TRUE WHERE x.a = 1 AND x.b = 1 AND TRUE AND TRUE; +SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON TRUE WHERE x.a = 1 AND x.b = 1 AND TRUE; with t1 as (SELECT x.a, x.b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) as row_num FROM x) SELECT t1.a, t1.b FROM t1 WHERE row_num = 1; WITH t1 AS (SELECT x.a, x.b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) AS row_num FROM x) SELECT t1.a, t1.b FROM t1 WHERE row_num = 1; diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 7a6c8871b2..927c80292d 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -625,22 +625,22 @@ t0.x = t1.x AND t0.y < t1.y AND t0.y <= t1.y; t0.x = t1.x AND t0.y < t1.y AND t0.y <= t1.y; -------------------------------------- --- Coalesce +-- COALESCE -------------------------------------- COALESCE(x); x; COALESCE(x, 1) = 2; -x = 2 AND NOT x IS NULL; +x = 2; 2 = COALESCE(x, 1); -2 = x AND NOT x IS NULL; +2 = x; COALESCE(x, 1, 1) = 1 + 1; -x = 2 AND NOT x IS NULL; +x = 2; COALESCE(x, 1, 2) = 2; -x = 2 AND NOT x IS NULL; +x = 2; COALESCE(x, 3) <= 2; x <= 2 AND NOT x IS NULL; @@ -864,3 +864,36 @@ x < CAST('2020-01-07' AS DATE); x - INTERVAL '1' day = CAST(y AS DATE); x - INTERVAL '1' day = CAST(y AS DATE); + +-------------------------------------- +-- Constant Propagation +-------------------------------------- +x = 5 AND y = x; +x = 5 AND y = 5; + +x = 5 OR y = x; +x = 5 OR y = x; + +(x = 5 AND y = x) OR y = 1; +(x = 5 AND y = 5) OR y = 1; + +t.x = 5 AND y = x; +t.x = 5 AND y = x; + +t.x = 'a' AND y = CONCAT_WS('-', t.x, 'b'); +t.x = 'a' AND y = 'a-b'; + +x = 5 AND y = x AND y + 1 < 5; +FALSE; + +x = 5 AND x = 6; +FALSE; + +x = 5 AND (y = x OR z = 1); +x = 5 AND (y = x OR z = 1); + +x = 5 AND x + 3 = 8; +x = 5; + +SELECT * FROM t1 LEFT JOIN t2 ON t1.x = t2.y AND t2.y > 5 AND t1.x = 5; +SELECT * FROM t1 LEFT JOIN t2 ON FALSE; diff --git a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql index 22181821db..d89db19f1e 100644 --- a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql +++ b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql @@ -2029,18 +2029,7 @@ JOIN "date_dim" AS "date_dim" ON "date_dim"."d_year" = 2001 AND "store_sales"."ss_sold_date_sk" = "date_dim"."d_date_sk" JOIN "household_demographics" AS "household_demographics" - ON "customer_demographics"."cd_demo_sk" = "store_sales"."ss_cdemo_sk" - AND "customer_demographics"."cd_education_status" = 'Advanced Degree' - AND "customer_demographics"."cd_education_status" = 'Primary' - AND "customer_demographics"."cd_education_status" = 'Secondary' - AND "customer_demographics"."cd_marital_status" = 'D' - AND "customer_demographics"."cd_marital_status" = 'M' - AND "customer_demographics"."cd_marital_status" = 'U' - AND "household_demographics"."hd_dep_count" = 1 - AND "household_demographics"."hd_dep_count" = 3 - AND "store_sales"."ss_hdemo_sk" = "household_demographics"."hd_demo_sk" - AND "store_sales"."ss_sales_price" <= 100.00 - AND "store_sales"."ss_sales_price" >= 150.00 + ON FALSE JOIN "store" AS "store" ON "store"."s_store_sk" = "store_sales"."ss_store_sk"; @@ -11687,28 +11676,28 @@ JOIN "customer_demographics" AS "cd1" ON "cd1"."cd_demo_sk" = "web_returns"."wr_refunded_cdemo_sk" AND ( ( - "cd1"."cd_education_status" = "cd2"."cd_education_status" - AND "cd1"."cd_education_status" = 'Advanced Degree' - AND "cd1"."cd_marital_status" = "cd2"."cd_marital_status" + "cd1"."cd_education_status" = 'Advanced Degree' AND "cd1"."cd_marital_status" = 'M' AND "web_sales"."ws_sales_price" <= 200.00 AND "web_sales"."ws_sales_price" >= 150.00 + AND 'Advanced Degree' = "cd2"."cd_education_status" + AND 'M' = "cd2"."cd_marital_status" ) OR ( - "cd1"."cd_education_status" = "cd2"."cd_education_status" - AND "cd1"."cd_education_status" = 'Primary' - AND "cd1"."cd_marital_status" = "cd2"."cd_marital_status" + "cd1"."cd_education_status" = 'Primary' AND "cd1"."cd_marital_status" = 'W' AND "web_sales"."ws_sales_price" <= 150.00 AND "web_sales"."ws_sales_price" >= 100.00 + AND 'Primary' = "cd2"."cd_education_status" + AND 'W' = "cd2"."cd_marital_status" ) OR ( - "cd1"."cd_education_status" = "cd2"."cd_education_status" - AND "cd1"."cd_education_status" = 'Secondary' - AND "cd1"."cd_marital_status" = "cd2"."cd_marital_status" + "cd1"."cd_education_status" = 'Secondary' AND "cd1"."cd_marital_status" = 'D' AND "web_sales"."ws_sales_price" <= 100.00 AND "web_sales"."ws_sales_price" >= 50.00 + AND 'D' = "cd2"."cd_marital_status" + AND 'Secondary' = "cd2"."cd_education_status" ) ) GROUP BY From 2f6b484da36f7c636743b98945417ed4a21c4c93 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Sun, 8 Oct 2023 21:17:13 +0300 Subject: [PATCH 02/16] Rephrase docstring --- sqlglot/optimizer/simplify.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 418fcb4120..f9696f9c33 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -372,7 +372,7 @@ def absorb_and_eliminate(expression, root=True): def propagate_constants(expression, root=True): """ - Propagate constants for conjunctions normalized into DNF: + Propagate constants for conjunctions in DNF: SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5 From 7405ed931517538a182abd15e6d7b8dada158c43 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Sun, 8 Oct 2023 21:21:04 +0300 Subject: [PATCH 03/16] Type hint fix --- sqlglot/optimizer/simplify.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index f9696f9c33..dc0f6f088a 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -386,7 +386,7 @@ def propagate_constants(expression, root=True): and (root or not expression.same_parent) and normalized(expression, dnf=True) ): - constant_mapping: t.Dict[exp.Column, [int, exp.Literal]] = {} + constant_mapping: t.Dict[exp.Column, t.Tuple[int, exp.Literal]] = {} for eq in expression.find_all(exp.EQ): l, r = eq.left, eq.right From b5d0c7f5e0fcaf6fece342c7cdbd201563184fb1 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Sun, 8 Oct 2023 21:53:27 +0300 Subject: [PATCH 04/16] Don't replace variables that are compared to NULL --- sqlglot/optimizer/simplify.py | 4 +++- tests/fixtures/optimizer/simplify.sql | 8 ++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index dc0f6f088a..3467e2e66d 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -402,7 +402,9 @@ def propagate_constants(expression, root=True): if constant_mapping: for column in expression.find_all(exp.Column): id_and_constant = constant_mapping.get(column) - if id_and_constant and id(column) != id_and_constant[0]: + if id_and_constant and id(column) != id_and_constant[0] and not ( + isinstance(column.parent, exp.Is) and type(column.parent.expression) is exp.Null + ): column.replace(id_and_constant[1].copy()) return expression diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 927c80292d..9ae6e42175 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -631,16 +631,16 @@ COALESCE(x); x; COALESCE(x, 1) = 2; -x = 2; +x = 2 AND NOT x IS NULL; 2 = COALESCE(x, 1); -2 = x; +2 = x AND NOT x IS NULL; COALESCE(x, 1, 1) = 1 + 1; -x = 2; +x = 2 AND NOT x IS NULL; COALESCE(x, 1, 2) = 2; -x = 2; +x = 2 AND NOT x IS NULL; COALESCE(x, 3) <= 2; x <= 2 AND NOT x IS NULL; From 92452323640359bc2c4c56ac61fa3f09e1c64a8a Mon Sep 17 00:00:00 2001 From: George Sittas Date: Sun, 8 Oct 2023 21:57:52 +0300 Subject: [PATCH 05/16] Formatting --- sqlglot/optimizer/simplify.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 3467e2e66d..f657219f49 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -402,8 +402,13 @@ def propagate_constants(expression, root=True): if constant_mapping: for column in expression.find_all(exp.Column): id_and_constant = constant_mapping.get(column) - if id_and_constant and id(column) != id_and_constant[0] and not ( - isinstance(column.parent, exp.Is) and type(column.parent.expression) is exp.Null + if ( + id_and_constant + and id(column) != id_and_constant[0] + and not ( + isinstance(column.parent, exp.Is) + and type(column.parent.expression) is exp.Null + ) ): column.replace(id_and_constant[1].copy()) From ab54273ff5684bee39f6840d7607e625702136f0 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Sun, 8 Oct 2023 22:17:46 +0300 Subject: [PATCH 06/16] Cleanup --- sqlglot/optimizer/simplify.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index f657219f49..920e8cf398 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -401,16 +401,14 @@ def propagate_constants(expression, root=True): if constant_mapping: for column in expression.find_all(exp.Column): - id_and_constant = constant_mapping.get(column) + parent = column.parent + column_id, constant = constant_mapping.get(column) or (None, None) if ( - id_and_constant - and id(column) != id_and_constant[0] - and not ( - isinstance(column.parent, exp.Is) - and type(column.parent.expression) is exp.Null - ) + column_id is not None + and id(column) != column_id + and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) ): - column.replace(id_and_constant[1].copy()) + column.replace(constant.copy()) return expression From 83dd49759941fa390766c09e1c0eaf2188f691d9 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Mon, 9 Oct 2023 02:34:21 +0300 Subject: [PATCH 07/16] Use find_all_in_scope instead of find_all --- sqlglot/optimizer/simplify.py | 3 ++- tests/fixtures/optimizer/simplify.sql | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 920e8cf398..7c6a522570 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -8,6 +8,7 @@ from sqlglot import exp from sqlglot.generator import cached_generator from sqlglot.helper import first, merge_ranges, while_changing +from sqlglot.optimizer.scope import find_all_in_scope # Final means that an expression should not be simplified FINAL = "final" @@ -400,7 +401,7 @@ def propagate_constants(expression, root=True): constant_mapping[l] = (id(l), r) if constant_mapping: - for column in expression.find_all(exp.Column): + for column in find_all_in_scope(expression, exp.Column): parent = column.parent column_id, constant = constant_mapping.get(column) or (None, None) if ( diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 9ae6e42175..2273496fc7 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -895,5 +895,5 @@ x = 5 AND (y = x OR z = 1); x = 5 AND x + 3 = 8; x = 5; -SELECT * FROM t1 LEFT JOIN t2 ON t1.x = t2.y AND t2.y > 5 AND t1.x = 5; -SELECT * FROM t1 LEFT JOIN t2 ON FALSE; +x = 5 AND (SELECT x FROM t WHERE y = 1); +x = 5 AND (SELECT x FROM t WHERE y = 1); From 64377a0d80d653a68380b05fe67d830e3818a6dc Mon Sep 17 00:00:00 2001 From: George Sittas Date: Mon, 9 Oct 2023 02:50:18 +0300 Subject: [PATCH 08/16] Leave a TODO comment to add helper that detects literals --- sqlglot/optimizer/simplify.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 7c6a522570..f822aadbf6 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -391,6 +391,8 @@ def propagate_constants(expression, root=True): for eq in expression.find_all(exp.EQ): l, r = eq.left, eq.right + # TODO: create a helper that can be used to detect nested literal expressions such + # as CAST('2012-01-01' AS DATE), since we usually want to treat those as literals too if isinstance(l, exp.Column) and isinstance(r, exp.Literal): pass elif isinstance(r, exp.Column) and isinstance(l, exp.Literal): From 1adeb163899d9d10f6e3dcd1704383d16dcd360f Mon Sep 17 00:00:00 2001 From: George Sittas Date: Mon, 9 Oct 2023 02:57:05 +0300 Subject: [PATCH 09/16] Fix another bug involving find_all vs find_all_in_scope --- sqlglot/optimizer/simplify.py | 2 +- tests/fixtures/optimizer/simplify.sql | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index f822aadbf6..35291f491e 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -388,7 +388,7 @@ def propagate_constants(expression, root=True): and normalized(expression, dnf=True) ): constant_mapping: t.Dict[exp.Column, t.Tuple[int, exp.Literal]] = {} - for eq in expression.find_all(exp.EQ): + for eq in find_all_in_scope(expression, exp.EQ): l, r = eq.left, eq.right # TODO: create a helper that can be used to detect nested literal expressions such diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 2273496fc7..123baba64a 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -897,3 +897,6 @@ x = 5; x = 5 AND (SELECT x FROM t WHERE y = 1); x = 5 AND (SELECT x FROM t WHERE y = 1); + +x = 1 AND y > 0 AND (SELECT z = 5 FROM t WHERE y = 1); +x = 1 AND y > 0 AND (SELECT z = 5 FROM t WHERE y = 1); From d7e1064d381f92e110e07fe8cb350e672fb4a74a Mon Sep 17 00:00:00 2001 From: George Sittas Date: Mon, 9 Oct 2023 03:05:47 +0300 Subject: [PATCH 10/16] Move normalized helper in expressions.py --- sqlglot/expressions.py | 21 +++++++++++++++++++++ sqlglot/optimizer/eliminate_joins.py | 5 ++--- sqlglot/optimizer/normalize.py | 10 ++-------- sqlglot/optimizer/pushdown_predicates.py | 3 +-- sqlglot/optimizer/simplify.py | 3 +-- 5 files changed, 27 insertions(+), 15 deletions(-) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 80f1c0faf0..b581150b57 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -6390,6 +6390,27 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func: return function +def normalized(expression: Expression, dnf: bool = False) -> bool: + """ + Checks whether a given expression is in a normal form of interest. + + Example: + >>> normalized(maybe_parse("(a AND b) OR c OR (d AND e)"), dnf=True) + True + >>> normalized(maybe_parse("(a OR b) AND c")) + True + >>> normalized(maybe_parse("a AND (b OR c)"), dnf=True) + False + + Args: + expression: The target expression. + dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF). + Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). + """ + ancestor, root = (And, Or) if dnf else (Or, And) + return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root)) + + def true() -> Boolean: """ Returns a true Boolean expression. diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index 3134e65986..68e84440f0 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -1,5 +1,4 @@ from sqlglot import expressions as exp -from sqlglot.optimizer.normalize import normalized from sqlglot.optimizer.scope import Scope, traverse_scope @@ -152,13 +151,13 @@ def extract_condition(condition): # ON x.a = y.b AND y.b > 1 # # should pull y.b as the join key and x.a as the source key - if normalized(on): + if exp.normalized(on): on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False) for condition in on.flatten(): if isinstance(condition, exp.EQ): extract_condition(condition) - elif normalized(on, dnf=True): + elif exp.normalized(on, dnf=True): conditions = None for condition in on.flatten(): diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index 1db094ecec..bce432a7a7 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -32,7 +32,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))): if isinstance(node, exp.Connector): - if normalized(node, dnf=dnf): + if exp.normalized(node, dnf=dnf): continue root = node is expression original = node.copy() @@ -63,12 +63,6 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = return expression -def normalized(expression, dnf=False): - ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) - - return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root)) - - def normalization_distance(expression, dnf=False): """ The difference in the number of predicates between the current expression and the normalized form. @@ -117,7 +111,7 @@ def distributive_law(expression, dnf, max_distance, generate): x OR (y AND z) -> (x OR y) AND (x OR z) (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) """ - if normalized(expression, dnf=dnf): + if exp.normalized(expression, dnf=dnf): return expression distance = normalization_distance(expression, dnf=dnf) diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index f7348b58e3..bdef5c8f39 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -1,5 +1,4 @@ from sqlglot import exp -from sqlglot.optimizer.normalize import normalized from sqlglot.optimizer.scope import build_scope, find_in_scope from sqlglot.optimizer.simplify import simplify @@ -55,7 +54,7 @@ def pushdown(condition, sources, scope_ref_count): return condition = condition.replace(simplify(condition)) - cnf_like = normalized(condition) or not normalized(condition, dnf=True) + cnf_like = exp.normalized(condition) or not exp.normalized(condition, dnf=True) predicates = list( condition.flatten() diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 35291f491e..c8240260da 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -380,12 +380,11 @@ def propagate_constants(expression, root=True): Reference: https://www.sqlite.org/optoverview.html """ - from sqlglot.optimizer.normalize import normalized if ( isinstance(expression, exp.And) and (root or not expression.same_parent) - and normalized(expression, dnf=True) + and exp.normalized(expression, dnf=True) ): constant_mapping: t.Dict[exp.Column, t.Tuple[int, exp.Literal]] = {} for eq in find_all_in_scope(expression, exp.EQ): From 81dcfcc16f910df8bea99c30738ba7946861bea4 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Mon, 9 Oct 2023 04:48:02 +0300 Subject: [PATCH 11/16] Move normalized back to normalize.py, refactor.. --- sqlglot/expressions.py | 23 +------------ sqlglot/optimizer/eliminate_joins.py | 5 +-- sqlglot/optimizer/normalize.py | 41 ++++++++++++++++++++---- sqlglot/optimizer/pushdown_predicates.py | 3 +- sqlglot/optimizer/simplify.py | 5 +-- tests/fixtures/optimizer/simplify.sql | 3 ++ 6 files changed, 46 insertions(+), 34 deletions(-) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index b581150b57..f61e6e14aa 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -487,7 +487,7 @@ def flatten(self, unnest=True): """ for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__): if not type(node) is self.__class__: - yield node.unnest() if unnest else node + yield node.unnest() if unnest and not isinstance(node, Subquery) else node def __str__(self) -> str: return self.sql() @@ -6390,27 +6390,6 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func: return function -def normalized(expression: Expression, dnf: bool = False) -> bool: - """ - Checks whether a given expression is in a normal form of interest. - - Example: - >>> normalized(maybe_parse("(a AND b) OR c OR (d AND e)"), dnf=True) - True - >>> normalized(maybe_parse("(a OR b) AND c")) - True - >>> normalized(maybe_parse("a AND (b OR c)"), dnf=True) - False - - Args: - expression: The target expression. - dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF). - Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). - """ - ancestor, root = (And, Or) if dnf else (Or, And) - return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root)) - - def true() -> Boolean: """ Returns a true Boolean expression. diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index 68e84440f0..3134e65986 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -1,4 +1,5 @@ from sqlglot import expressions as exp +from sqlglot.optimizer.normalize import normalized from sqlglot.optimizer.scope import Scope, traverse_scope @@ -151,13 +152,13 @@ def extract_condition(condition): # ON x.a = y.b AND y.b > 1 # # should pull y.b as the join key and x.a as the source key - if exp.normalized(on): + if normalized(on): on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False) for condition in on.flatten(): if isinstance(condition, exp.EQ): extract_condition(condition) - elif exp.normalized(on, dnf=True): + elif normalized(on, dnf=True): conditions = None for condition in on.flatten(): diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index bce432a7a7..8d82b2da79 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -6,6 +6,7 @@ from sqlglot.errors import OptimizeError from sqlglot.generator import cached_generator from sqlglot.helper import while_changing +from sqlglot.optimizer.scope import find_all_in_scope from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort logger = logging.getLogger("sqlglot") @@ -32,7 +33,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))): if isinstance(node, exp.Connector): - if exp.normalized(node, dnf=dnf): + if normalized(node, dnf=dnf): continue root = node is expression original = node.copy() @@ -63,9 +64,33 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = return expression -def normalization_distance(expression, dnf=False): +def normalized(expression: exp.Expression, dnf: bool = False) -> bool: """ - The difference in the number of predicates between the current expression and the normalized form. + Checks whether a given expression is in a normal form of interest. + + Example: + >>> from sqlglot import parse_one + >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True) + True + >>> normalized(parse_one("(a OR b) AND c")) # Checks CNF by default + True + >>> normalized(parse_one("a AND (b OR c)"), dnf=True) + False + + Args: + expression: The expression to check if it's normalized. + dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF). + Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). + """ + ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) + return not any( + connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root) + ) + + +def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int: + """ + The difference in the number of predicates between a given expression and its normalized form. This is used as an estimate of the cost of the conversion which is exponential in complexity. @@ -76,10 +101,12 @@ def normalization_distance(expression, dnf=False): 4 Args: - expression (sqlglot.Expression): expression to compute distance - dnf (bool): compute to dnf distance instead + expression: The expression to compute the normalization distance for. + dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF). + Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). + Returns: - int: difference + The normalization distance. """ return sum(_predicate_lengths(expression, dnf)) - ( sum(1 for _ in expression.find_all(exp.Connector)) + 1 @@ -111,7 +138,7 @@ def distributive_law(expression, dnf, max_distance, generate): x OR (y AND z) -> (x OR y) AND (x OR z) (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) """ - if exp.normalized(expression, dnf=dnf): + if normalized(expression, dnf=dnf): return expression distance = normalization_distance(expression, dnf=dnf) diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index bdef5c8f39..f7348b58e3 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -1,4 +1,5 @@ from sqlglot import exp +from sqlglot.optimizer.normalize import normalized from sqlglot.optimizer.scope import build_scope, find_in_scope from sqlglot.optimizer.simplify import simplify @@ -54,7 +55,7 @@ def pushdown(condition, sources, scope_ref_count): return condition = condition.replace(simplify(condition)) - cnf_like = exp.normalized(condition) or not exp.normalized(condition, dnf=True) + cnf_like = normalized(condition) or not normalized(condition, dnf=True) predicates = list( condition.flatten() diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index c8240260da..ca6999a606 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -5,6 +5,7 @@ from collections import deque from decimal import Decimal +import sqlglot from sqlglot import exp from sqlglot.generator import cached_generator from sqlglot.helper import first, merge_ranges, while_changing @@ -384,9 +385,9 @@ def propagate_constants(expression, root=True): if ( isinstance(expression, exp.And) and (root or not expression.same_parent) - and exp.normalized(expression, dnf=True) + and sqlglot.optimizer.normalize.normalized(expression, dnf=True) ): - constant_mapping: t.Dict[exp.Column, t.Tuple[int, exp.Literal]] = {} + constant_mapping = {} for eq in find_all_in_scope(expression, exp.EQ): l, r = eq.left, eq.right diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 123baba64a..3d0bed0ec2 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -900,3 +900,6 @@ x = 5 AND (SELECT x FROM t WHERE y = 1); x = 1 AND y > 0 AND (SELECT z = 5 FROM t WHERE y = 1); x = 1 AND y > 0 AND (SELECT z = 5 FROM t WHERE y = 1); + +x = 1 AND x = y AND (SELECT z FROM t WHERE a AND (b OR c)); +x = 1 AND (SELECT z FROM t WHERE a AND (b OR c)) AND 1 = y; From 370a4d542b099cc6f7bf0b15ebff6c8a640bc03a Mon Sep 17 00:00:00 2001 From: George Sittas Date: Mon, 9 Oct 2023 05:04:06 +0300 Subject: [PATCH 12/16] Update comment --- sqlglot/optimizer/simplify.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index ca6999a606..df153018ae 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -392,7 +392,7 @@ def propagate_constants(expression, root=True): l, r = eq.left, eq.right # TODO: create a helper that can be used to detect nested literal expressions such - # as CAST('2012-01-01' AS DATE), since we usually want to treat those as literals too + # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too if isinstance(l, exp.Column) and isinstance(r, exp.Literal): pass elif isinstance(r, exp.Column) and isinstance(l, exp.Literal): From 866e3d0fadb73f34023faf77c9fd4f55fee4cea4 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Mon, 9 Oct 2023 05:23:19 +0300 Subject: [PATCH 13/16] Increase test coverage --- tests/fixtures/optimizer/simplify.sql | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 3d0bed0ec2..f420295825 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -871,6 +871,9 @@ x - INTERVAL '1' day = CAST(y AS DATE); x = 5 AND y = x; x = 5 AND y = 5; +5 = x AND y = x; +y = 5 AND 5 = x; + x = 5 OR y = x; x = 5 OR y = x; From 4943eb7ff371c147d9c29a4d9f4b6912f13adf9e Mon Sep 17 00:00:00 2001 From: George Sittas Date: Mon, 9 Oct 2023 05:30:55 +0300 Subject: [PATCH 14/16] Increase test coverage --- tests/fixtures/optimizer/simplify.sql | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index f420295825..91dda41f17 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -906,3 +906,6 @@ x = 1 AND y > 0 AND (SELECT z = 5 FROM t WHERE y = 1); x = 1 AND x = y AND (SELECT z FROM t WHERE a AND (b OR c)); x = 1 AND (SELECT z FROM t WHERE a AND (b OR c)) AND 1 = y; + +SELECT * FROM t1, t2, t3 WHERE t1.a = 39 AND t2.b = t1.a AND t3.c = t2.b; +SELECT * FROM t1, t2, t3 WHERE t1.a = 39 AND t2.b = 39 AND t3.c = 39; From 170f5ef74256002be58585b9df36a84cd4e3d199 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Mon, 9 Oct 2023 18:44:32 +0300 Subject: [PATCH 15/16] Make the constant propagation rule opt-in --- sqlglot/optimizer/simplify.py | 8 +++-- tests/fixtures/optimizer/optimizer.sql | 2 +- .../optimizer/pushdown_predicates.sql | 8 ++--- tests/fixtures/optimizer/tpc-ds/tpc-ds.sql | 31 +++++++++++++------ tests/test_optimizer.py | 6 +++- 5 files changed, 37 insertions(+), 18 deletions(-) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index df153018ae..92b4770bcb 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -19,7 +19,7 @@ class UnsupportedUnit(Exception): pass -def simplify(expression): +def simplify(expression, constant_propagation=False): """ Rewrite sqlglot AST to simplify expressions. @@ -31,6 +31,8 @@ def simplify(expression): Args: expression (sqlglot.Expression): expression to simplify + constant_propagation: whether or not the constant propagation rule should be used + Returns: sqlglot.Expression: simplified expression """ @@ -67,9 +69,11 @@ def _simplify(expression, root=True): node = rewrite_between(node) node = uniq_sort(node, generate, root) node = absorb_and_eliminate(node, root) - node = propagate_constants(node, root) node = simplify_concat(node) + if constant_propagation: + node = propagate_constants(node, root) + exp.replace_children(node, lambda e: _simplify(e, False)) # Post-order transformations diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index 70c68114c0..4cc62c9b1f 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -369,7 +369,7 @@ SELECT "y"."b" AS "b" FROM "x" AS "x" RIGHT JOIN "y_2" AS "y" - ON "x"."a" = 1; + ON "x"."a" = "y"."b"; # title: lateral column alias reference diff --git a/tests/fixtures/optimizer/pushdown_predicates.sql b/tests/fixtures/optimizer/pushdown_predicates.sql index 61c1ee2207..cfa69fbda4 100644 --- a/tests/fixtures/optimizer/pushdown_predicates.sql +++ b/tests/fixtures/optimizer/pushdown_predicates.sql @@ -11,7 +11,7 @@ SELECT x.a FROM (SELECT * FROM x) AS x JOIN y WHERE (x.a = y.a AND x.a = 1 AND x SELECT x.a FROM (SELECT * FROM x) AS x JOIN y ON x.a = y.a WHERE TRUE; SELECT x.a FROM (SELECT * FROM x) AS x JOIN y WHERE (x.a = y.a AND x.a = 1 AND x.b = 1) OR x.a = y.b; -SELECT x.a FROM (SELECT * FROM x) AS x JOIN y ON x.a = y.b OR 1 = y.a WHERE x.a = y.b OR (x.a = 1 AND x.b = 1 AND 1 = y.a); +SELECT x.a FROM (SELECT * FROM x) AS x JOIN y ON x.a = y.a OR x.a = y.b WHERE (x.a = y.a AND x.a = 1 AND x.b = 1) OR x.a = y.b; SELECT x.a FROM (SELECT x.a AS a, x.b * 1 AS c FROM x) AS x WHERE x.c = 1; SELECT x.a FROM (SELECT x.a AS a, x.b * 1 AS c FROM x WHERE x.b * 1 = 1) AS x WHERE TRUE; @@ -23,13 +23,13 @@ SELECT x.a AS a FROM (SELECT x.a FROM x AS x) AS x JOIN y WHERE x.a = 1 AND x.b SELECT x.a AS a FROM (SELECT x.a FROM x AS x WHERE x.a = 1 AND x.b = 1) AS x JOIN y ON x.c = 1 OR y.c = 1 WHERE TRUE AND TRUE AND (TRUE); SELECT x.a FROM x AS x JOIN (SELECT y.a FROM y AS y) AS y ON y.a = 1 AND x.a = y.a; -SELECT x.a FROM x AS x JOIN (SELECT y.a FROM y AS y WHERE y.a = 1) AS y ON x.a = 1 AND TRUE; +SELECT x.a FROM x AS x JOIN (SELECT y.a FROM y AS y WHERE y.a = 1) AS y ON x.a = y.a AND TRUE; SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y) AS y ON y.a = 1 WHERE x.a = 1 AND x.b = 1 AND y.a = x.a; -SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON TRUE WHERE x.a = 1 AND x.b = 1 AND TRUE; +SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON y.a = x.a AND TRUE WHERE x.a = 1 AND x.b = 1 AND TRUE; SELECT x.a AS a FROM x AS x CROSS JOIN (SELECT * FROM y AS y) AS y WHERE x.a = 1 AND x.b = 1 AND y.a = x.a AND y.a = 1; -SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON TRUE WHERE x.a = 1 AND x.b = 1 AND TRUE; +SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON y.a = x.a AND TRUE WHERE x.a = 1 AND x.b = 1 AND TRUE AND TRUE; with t1 as (SELECT x.a, x.b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) as row_num FROM x) SELECT t1.a, t1.b FROM t1 WHERE row_num = 1; WITH t1 AS (SELECT x.a, x.b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) AS row_num FROM x) SELECT t1.a, t1.b FROM t1 WHERE row_num = 1; diff --git a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql index d89db19f1e..22181821db 100644 --- a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql +++ b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql @@ -2029,7 +2029,18 @@ JOIN "date_dim" AS "date_dim" ON "date_dim"."d_year" = 2001 AND "store_sales"."ss_sold_date_sk" = "date_dim"."d_date_sk" JOIN "household_demographics" AS "household_demographics" - ON FALSE + ON "customer_demographics"."cd_demo_sk" = "store_sales"."ss_cdemo_sk" + AND "customer_demographics"."cd_education_status" = 'Advanced Degree' + AND "customer_demographics"."cd_education_status" = 'Primary' + AND "customer_demographics"."cd_education_status" = 'Secondary' + AND "customer_demographics"."cd_marital_status" = 'D' + AND "customer_demographics"."cd_marital_status" = 'M' + AND "customer_demographics"."cd_marital_status" = 'U' + AND "household_demographics"."hd_dep_count" = 1 + AND "household_demographics"."hd_dep_count" = 3 + AND "store_sales"."ss_hdemo_sk" = "household_demographics"."hd_demo_sk" + AND "store_sales"."ss_sales_price" <= 100.00 + AND "store_sales"."ss_sales_price" >= 150.00 JOIN "store" AS "store" ON "store"."s_store_sk" = "store_sales"."ss_store_sk"; @@ -11676,28 +11687,28 @@ JOIN "customer_demographics" AS "cd1" ON "cd1"."cd_demo_sk" = "web_returns"."wr_refunded_cdemo_sk" AND ( ( - "cd1"."cd_education_status" = 'Advanced Degree' + "cd1"."cd_education_status" = "cd2"."cd_education_status" + AND "cd1"."cd_education_status" = 'Advanced Degree' + AND "cd1"."cd_marital_status" = "cd2"."cd_marital_status" AND "cd1"."cd_marital_status" = 'M' AND "web_sales"."ws_sales_price" <= 200.00 AND "web_sales"."ws_sales_price" >= 150.00 - AND 'Advanced Degree' = "cd2"."cd_education_status" - AND 'M' = "cd2"."cd_marital_status" ) OR ( - "cd1"."cd_education_status" = 'Primary' + "cd1"."cd_education_status" = "cd2"."cd_education_status" + AND "cd1"."cd_education_status" = 'Primary' + AND "cd1"."cd_marital_status" = "cd2"."cd_marital_status" AND "cd1"."cd_marital_status" = 'W' AND "web_sales"."ws_sales_price" <= 150.00 AND "web_sales"."ws_sales_price" >= 100.00 - AND 'Primary' = "cd2"."cd_education_status" - AND 'W' = "cd2"."cd_marital_status" ) OR ( - "cd1"."cd_education_status" = 'Secondary' + "cd1"."cd_education_status" = "cd2"."cd_education_status" + AND "cd1"."cd_education_status" = 'Secondary' + AND "cd1"."cd_marital_status" = "cd2"."cd_marital_status" AND "cd1"."cd_marital_status" = 'D' AND "web_sales"."ws_sales_price" <= 100.00 AND "web_sales"."ws_sales_price" >= 50.00 - AND 'D' = "cd2"."cd_marital_status" - AND 'Secondary' = "cd2"."cd_education_status" ) ) GROUP BY diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 8fc3273155..ef57d49e79 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -45,6 +45,10 @@ def normalize(expression, **kwargs): return optimizer.simplify.simplify(expression) +def simplify(expression, **kwargs): + return optimizer.simplify.simplify(expression, constant_propagation=True, **kwargs) + + class TestOptimizer(unittest.TestCase): maxDiff = None @@ -271,7 +275,7 @@ def test_pushdown_projection(self): self.check_file("pushdown_projections", pushdown_projections, schema=self.schema) def test_simplify(self): - self.check_file("simplify", optimizer.simplify.simplify) + self.check_file("simplify", simplify) expression = parse_one("TRUE AND TRUE AND TRUE") self.assertEqual(exp.true(), optimizer.simplify.simplify(expression)) From d40fc163f593ff47a883428f9769a6e045776734 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Mon, 9 Oct 2023 18:59:49 +0300 Subject: [PATCH 16/16] Fix optimize_joins bug --- sqlglot/optimizer/optimize_joins.py | 6 +++- tests/fixtures/optimizer/tpc-ds/tpc-ds.sql | 39 +++++++++++++++------- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index 9d401fc73c..15304561a6 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -39,10 +39,14 @@ def optimize_joins(expression): if len(other_table_names(dep)) < 2: continue + operator = type(on) for predicate in on.flatten(): if name in exp.column_table_names(predicate): predicate.replace(exp.true()) - join.on(predicate, copy=False) + predicate = exp._combine( + [join.args.get("on"), predicate], operator, copy=False + ) + join.on(predicate, append=False, copy=False) expression = reorder_joins(expression) expression = normalize(expression) diff --git a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql index 22181821db..91b553edca 100644 --- a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql +++ b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql @@ -2029,18 +2029,33 @@ JOIN "date_dim" AS "date_dim" ON "date_dim"."d_year" = 2001 AND "store_sales"."ss_sold_date_sk" = "date_dim"."d_date_sk" JOIN "household_demographics" AS "household_demographics" - ON "customer_demographics"."cd_demo_sk" = "store_sales"."ss_cdemo_sk" - AND "customer_demographics"."cd_education_status" = 'Advanced Degree' - AND "customer_demographics"."cd_education_status" = 'Primary' - AND "customer_demographics"."cd_education_status" = 'Secondary' - AND "customer_demographics"."cd_marital_status" = 'D' - AND "customer_demographics"."cd_marital_status" = 'M' - AND "customer_demographics"."cd_marital_status" = 'U' - AND "household_demographics"."hd_dep_count" = 1 - AND "household_demographics"."hd_dep_count" = 3 - AND "store_sales"."ss_hdemo_sk" = "household_demographics"."hd_demo_sk" - AND "store_sales"."ss_sales_price" <= 100.00 - AND "store_sales"."ss_sales_price" >= 150.00 + ON ( + "customer_demographics"."cd_demo_sk" = "store_sales"."ss_cdemo_sk" + AND "customer_demographics"."cd_education_status" = 'Advanced Degree' + AND "customer_demographics"."cd_marital_status" = 'U' + AND "household_demographics"."hd_dep_count" = 3 + AND "store_sales"."ss_hdemo_sk" = "household_demographics"."hd_demo_sk" + AND "store_sales"."ss_sales_price" <= 150.00 + AND "store_sales"."ss_sales_price" >= 100.00 + ) + OR ( + "customer_demographics"."cd_demo_sk" = "store_sales"."ss_cdemo_sk" + AND "customer_demographics"."cd_education_status" = 'Primary' + AND "customer_demographics"."cd_marital_status" = 'M' + AND "household_demographics"."hd_dep_count" = 1 + AND "store_sales"."ss_hdemo_sk" = "household_demographics"."hd_demo_sk" + AND "store_sales"."ss_sales_price" <= 100.00 + AND "store_sales"."ss_sales_price" >= 50.00 + ) + OR ( + "customer_demographics"."cd_demo_sk" = "store_sales"."ss_cdemo_sk" + AND "customer_demographics"."cd_education_status" = 'Secondary' + AND "customer_demographics"."cd_marital_status" = 'D' + AND "household_demographics"."hd_dep_count" = 1 + AND "store_sales"."ss_hdemo_sk" = "household_demographics"."hd_demo_sk" + AND "store_sales"."ss_sales_price" <= 200.00 + AND "store_sales"."ss_sales_price" >= 150.00 + ) JOIN "store" AS "store" ON "store"."s_store_sk" = "store_sales"."ss_store_sk";