Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(optimizer): simplify_equality #2281

Merged
merged 2 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4010,6 +4010,10 @@ def __init__(self, **args):

super().__init__(**args)

@property
def unit(self) -> t.Optional[Var]:
return self.args.get("unit")


# https://www.oracletutorial.com/oracle-basics/oracle-interval/
# https://trino.io/docs/current/language/types.html#interval-day-to-second
Expand All @@ -4021,10 +4025,6 @@ class IntervalSpan(Expression):
class Interval(TimeUnit):
arg_types = {"this": False, "unit": False}

@property
def unit(self) -> t.Optional[Var]:
return self.args.get("unit")


class IgnoreNulls(Expression):
pass
Expand Down Expand Up @@ -4398,10 +4398,6 @@ class TimestampDiff(Func, TimeUnit):
class TimestampTrunc(Func, TimeUnit):
arg_types = {"this": True, "unit": True, "zone": False}

@property
def unit(self) -> Expression:
return self.args["unit"]


class TimeAdd(Func, TimeUnit):
arg_types = {"this": True, "expression": True, "unit": False}
Expand Down
88 changes: 88 additions & 0 deletions sqlglot/optimizer/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def _simplify(expression, root=True):
node = simplify_coalesce(node)
node.parent = expression.parent
node = simplify_literals(node, root)
node = simplify_equality(node)
node = simplify_parens(node)
node = simplify_datetrunc_predicate(node)

Expand Down Expand Up @@ -368,6 +369,87 @@ def absorb_and_eliminate(expression, root=True):
return expression


INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
exp.DateAdd: exp.Sub,
exp.DateSub: exp.Add,
exp.DatetimeAdd: exp.Sub,
exp.DatetimeSub: exp.Add,
}

INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
**INVERSE_DATE_OPS,
exp.Add: exp.Sub,
exp.Sub: exp.Add,
}


def _is_number(expression: exp.Expression) -> bool:
return expression.is_number


def _is_date(expression: exp.Expression) -> bool:
return isinstance(expression, exp.Cast) and extract_date(expression) is not None


def _is_interval(expression: exp.Expression) -> bool:
return isinstance(expression, exp.Interval) and extract_interval(expression) is not None


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
"""
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and =
Here's how we reference all the operands in the code below:

l r
x + 1 = 3
a b
"""
if isinstance(expression, COMPARISONS):
l, r = expression.left, expression.right

if l.__class__ in INVERSE_OPS:
pass
elif r.__class__ in INVERSE_OPS:
l, r = r, l
else:
return expression

if r.is_number:
a_predicate = _is_number
b_predicate = _is_number
elif _is_date(r):
a_predicate = _is_date
b_predicate = _is_interval
else:
return expression

if l.__class__ in INVERSE_DATE_OPS:
a = l.this
b = exp.Interval(
this=l.expression.copy(),
unit=l.unit.copy(),
)
else:
a, b = l.left, l.right

if not a_predicate(a) and b_predicate(b):
pass
elif not a_predicate(b) and b_predicate(a):
a, b = b, a
else:
return expression

return expression.__class__(
this=a.copy(), expression=INVERSE_OPS[l.__class__](this=r.copy(), expression=b.copy())
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are these copies needed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah shit sorry i hit merge.
I'm not sure.
In general, I copy when creating a new expression using child nodes from an existing tree. When do you?

Copy link
Owner

@tobymao tobymao Sep 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the optimizer, i only copy when i have to, because it's mutation heavy and expensive,

in generator, where the contract is idempotency and no mutations, i copy

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

@georgesittas georgesittas Sep 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree on copy being a bit redundant here: It seems like we're dropping the expression arg in this return statement anyway & a and b are not associated with r's AST, so this completely replaces the AST corresponding to the input expression.

)
return expression


def simplify_literals(expression, root=True):
if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
return _flat_simplify(expression, _simplify_binary, root)
Expand Down Expand Up @@ -771,6 +853,12 @@ def interval(unit: str, n: int = 1):
return relativedelta(weeks=1 * n)
if unit == "day":
return relativedelta(days=1 * n)
if unit == "hour":
return relativedelta(hours=1 * n)
if unit == "minute":
return relativedelta(minutes=1 * n)
if unit == "second":
return relativedelta(seconds=1 * n)

raise UnsupportedUnit(f"Unsupported unit: {unit}")

Expand Down
57 changes: 57 additions & 0 deletions tests/fixtures/optimizer/simplify.sql
Original file line number Diff line number Diff line change
Expand Up @@ -757,3 +757,60 @@ x < CAST('2022-01-01' AS DATE) AND x >= CAST('2021-01-01' AS DATE);

TIMESTAMP_TRUNC(x, YEAR) = CAST('2021-01-01' AS DATETIME);
x < CAST('2022-01-01 00:00:00' AS DATETIME) AND x >= CAST('2021-01-01 00:00:00' AS DATETIME);

--------------------------------------
-- EQUALITY
--------------------------------------
x + 1 = 3;
x = 2;

1 + x = 3;
x = 2;

3 = x + 1;
x = 2;

x - 1 = 3;
x = 4;

x + 1 > 3;
x > 2;

x + 1 >= 3;
x >= 2;

x + 1 <= 3;
x <= 2;

x + 1 <= 3;
x <= 2;

x + 1 <> 3;
x <> 2;

1 + x + 1 = 3 + 1;
x = 2;

x - INTERVAL 1 DAY = CAST('2021-01-01' AS DATE);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what would happen if it's

x - interval 1 day = cast(y as date)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing. The right side must be a date literal.

x = CAST('2021-01-02' AS DATE);

x - INTERVAL 1 HOUR > CAST('2021-01-01' AS DATETIME);
x > CAST('2021-01-01 01:00:00' AS DATETIME);

DATETIME_ADD(x, 1, HOUR) < CAST('2021-01-01' AS DATETIME);
x < CAST('2020-12-31 23:00:00' AS DATETIME);

DATETIME_SUB(x, 1, DAY) >= CAST('2021-01-01' AS DATETIME);
x >= CAST('2021-01-02 00:00:00' AS DATETIME);

DATE_ADD(x, 1, DAY) <= CAST('2021-01-01' AS DATE);
x <= CAST('2020-12-31' AS DATE);

DATE_SUB(x, 1, DAY) <> CAST('2021-01-01' AS DATE);
x <> CAST('2021-01-02' AS DATE);

DATE_ADD(DATE_ADD(DATE_TRUNC('week', DATE_SUB(x, 1, DAY)), 1, DAY), 1, YEAR) < CAST('2021-01-08' AS DATE);
x < CAST('2020-01-07' AS DATE);

x - INTERVAL '1' day = CAST(y AS DATE);
x - INTERVAL '1' day = CAST(y AS DATE);