From 72aa018379f39e9cf937910fc71d31f5bc2dbf43 Mon Sep 17 00:00:00 2001 From: Jo <46752250+GeorgeSittas@users.noreply.github.com> Date: Thu, 29 Jun 2023 20:01:41 +0300 Subject: [PATCH] Fix!(optimizer): preserve predecence when merging derived tables (#1857) * Fix!(optimizer): preserve predecence when merging derived tables * Rephrase comment * Add another test in simplify * Formatting * Fix typo * Add another test --- sqlglot/optimizer/merge_subqueries.py | 22 ++++++++++++++++++- sqlglot/optimizer/simplify.py | 1 + tests/fixtures/optimizer/merge_subqueries.sql | 22 +++++++++++++++++++ tests/fixtures/optimizer/optimizer.sql | 8 +++++++ tests/fixtures/optimizer/simplify.sql | 6 +++++ 5 files changed, 58 insertions(+), 1 deletion(-) diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index fefe96e71a..e156d5e4fc 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -47,6 +47,17 @@ def merge_subqueries(expression, leave_tables_isolated=False): } +# Projections in the outer query that are instances of these types can be replaced +# without getting wrapped in parentheses, because the precedence won't be altered. +SAFE_TO_REPLACE_UNWRAPPED = ( + exp.Column, + exp.EQ, + exp.Func, + exp.NEQ, + exp.Paren, +) + + def merge_ctes(expression, leave_tables_isolated=False): scopes = traverse_scope(expression) @@ -293,8 +304,17 @@ def _merge_expressions(outer_scope, inner_scope, alias): if not projection_name: continue columns_to_replace = outer_columns.get(projection_name, []) + + expression = expression.unalias() + must_wrap_expression = not isinstance(expression, SAFE_TO_REPLACE_UNWRAPPED) + for column in columns_to_replace: - column.replace(expression.unalias().copy()) + # Ensures we don't alter the intended operator precedence if there's additional + # context surrounding the outer expression (i.e. it's not a simple projection). + if isinstance(column.parent, (exp.Unary, exp.Binary)) and must_wrap_expression: + expression = exp.paren(expression, copy=False) + + column.replace(expression.copy()) def _merge_where(outer_scope, inner_scope, from_or_join): diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 34005d961a..1a2d82ceeb 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -400,6 +400,7 @@ def simplify_parens(expression): or not isinstance(this, exp.Binary) or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) + or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) ): return expression.this return expression diff --git a/tests/fixtures/optimizer/merge_subqueries.sql b/tests/fixtures/optimizer/merge_subqueries.sql index bd56e07dc8..fb69ea7bbf 100644 --- a/tests/fixtures/optimizer/merge_subqueries.sql +++ b/tests/fixtures/optimizer/merge_subqueries.sql @@ -2,6 +2,28 @@ SELECT a, b FROM (SELECT a, b FROM x); SELECT x.a AS a, x.b AS b FROM x AS x; +# title: Wrap addition in a multiplication +SELECT c * 2 AS d FROM (SELECT a + b AS c FROM x); +SELECT (x.a + x.b) * 2 AS d FROM x AS x; + +# title: Wrap addition in an addition +# note: The "simplify" rule will unwrap this +SELECT c + d AS e FROM (SELECT a + b AS c, a AS d FROM x); +SELECT (x.a + x.b) + x.a AS e FROM x AS x; + +# title: Wrap multiplication in an addition +# note: The "simplify" rule will unwrap this +WITH cte AS (SELECT a * b AS c, a AS d FROM x) SELECT c + d AS e FROM cte; +SELECT (x.a * x.b) + x.a AS e FROM x AS x; + +# title: Don't wrap function +SELECT 2 * foo AS bar FROM (SELECT CAST(b AS DOUBLE) AS foo FROM x); +SELECT 2 * CAST(x.b AS DOUBLE) AS bar FROM x AS x; + +# title: Don't wrap a wrapped expression +SELECT foo * 2 AS bar FROM (SELECT (1 + 2 + 3) AS foo FROM x); +SELECT (1 + 2 + 3) * 2 AS bar FROM x AS x; + # title: Inner table alias is merged SELECT a, b FROM (SELECT a, b FROM x AS q) AS r; SELECT q.a AS a, q.b AS b FROM x AS q; diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index f71ddde102..38e64d7d17 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -693,3 +693,11 @@ GROUP BY "x"."a" + 1 + 1 HAVING "x"."a" + 1 + 1 + 1 + 1 > 1; + +# title: replace alias with mult expression without wrapping it +WITH cte AS (SELECT a * b AS c, a AS d, b as e FROM x) SELECT c + d - (c - e) AS f FROM cte; +SELECT + "x"."a" * "x"."b" + "x"."a" - ( + "x"."a" * "x"."b" - "x"."b" + ) AS "f" +FROM "x" AS "x"; diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index e0aded4e97..f821575d75 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -354,6 +354,12 @@ a + 4; a + (1 + 1) + (10); a + 12; +a + (1 * 1) + (1 - (1 * 1)); +a + 1; + +a + (b * c) + (d - (e * f)); +a + b * c + (d - e * f); + 5 + 4 * 3; 17;