Skip to content

Commit

Permalink
feat(optimizer): simplify date_trunc (#2271)
Browse files Browse the repository at this point in the history
* feat(optimizer): simplify date_trunc

* Refactor

* extract merge_ranges

* python 3.7 compatibility

* fixup

* fixup

* address pr comments, add timestamp_trunc

* move dunder methods
  • Loading branch information
barakalon authored Sep 21, 2023
1 parent ed8714f commit d1cfa01
Show file tree
Hide file tree
Showing 7 changed files with 326 additions and 29 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"python-dateutil",
"pdoc",
"pre-commit",
"types-python-dateutil",
],
},
classifiers=[
Expand Down
5 changes: 5 additions & 0 deletions sqlglot/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,10 @@

import sqlglot

# A little hack for backwards compatibility with Python 3.7.
# For example, we might want a TypeVar for objects that support comparison e.g. SupportsRichComparisonT from typeshed.
# But Python 3.7 doesn't support Protocols, so we'd also need typing_extensions, which we don't want as a dependency.
A = t.TypeVar("A", bound=t.Any)

# TypeVar bound (by forward reference) to sqlglot.exp.Expression.
E = t.TypeVar("E", bound="sqlglot.exp.Expression")
# Unconstrained, general-purpose TypeVar.
T = t.TypeVar("T")
30 changes: 20 additions & 10 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,16 +664,6 @@ def load(cls, obj):

return load(obj)


IntoType = t.Union[
str,
t.Type[Expression],
t.Collection[t.Union[str, t.Type[Expression]]],
]
ExpOrStr = t.Union[str, Expression]


class Condition(Expression):
def and_(
self,
*expressions: t.Optional[ExpOrStr],
Expand Down Expand Up @@ -886,6 +876,18 @@ def __invert__(self) -> Not:
return not_(self.copy())


# Either a string, an Expression subclass, or a collection whose items are
# strings and/or Expression subclasses.
IntoType = t.Union[
    str,
    t.Type[Expression],
    t.Collection[t.Union[str, t.Type[Expression]]],
]
# Either a string or an Expression instance.
ExpOrStr = t.Union[str, Expression]


class Condition(Expression):
"""Logical conditions like x AND y, or simply x"""


class Predicate(Condition):
"""Relationships like x = y, x > 1, x >= y."""

Expand Down Expand Up @@ -4328,6 +4330,10 @@ class DateDiff(Func, TimeUnit):
class DateTrunc(Func):
    """Function node with required `unit` and `this` args and an optional `zone` arg."""

    arg_types = {"unit": True, "this": True, "zone": False}

    @property
    def unit(self) -> Expression:
        """Return the `unit` argument; raises KeyError if it was never set."""
        return self.args["unit"]


class DatetimeAdd(Func, TimeUnit):
arg_types = {"this": True, "expression": True, "unit": False}
Expand Down Expand Up @@ -4392,6 +4398,10 @@ class TimestampDiff(Func, TimeUnit):
class TimestampTrunc(Func, TimeUnit):
    """Function node with required `this` and `unit` args and an optional `zone` arg."""

    arg_types = {"this": True, "unit": True, "zone": False}

    @property
    def unit(self) -> Expression:
        """Return the `unit` argument; raises KeyError if it was never set."""
        return self.args["unit"]


class TimeAdd(Func, TimeUnit):
arg_types = {"this": True, "expression": True, "unit": False}
Expand Down
22 changes: 21 additions & 1 deletion sqlglot/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@

if t.TYPE_CHECKING:
from sqlglot import exp
from sqlglot._typing import E, T
from sqlglot._typing import A, E, T
from sqlglot.expressions import Expression


CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
PYTHON_VERSION = sys.version_info[:2]
logger = logging.getLogger("sqlglot")
Expand Down Expand Up @@ -435,3 +436,22 @@ def dict_depth(d: t.Dict) -> int:
def first(it: t.Iterable[T]) -> T:
    """Return the first item produced by `it` (handy for unordered containers such as sets).

    Raises:
        StopIteration: if `it` yields nothing.
    """
    iterator = iter(it)
    return next(iterator)


def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
    """Merge overlapping or touching (start, end) ranges.

    Args:
        ranges: pairs of mutually comparable values; the input list is not modified.

    Returns:
        A new list, sorted by start, in which any range whose start is less than
        or equal to the end of the previous one has been folded into it.
    """
    if not ranges:
        return []

    head, *tail = sorted(ranges)
    merged = [head]

    for lo, hi in tail:
        prev_lo, prev_hi = merged[-1]

        if lo > prev_hi:
            # Strictly after the previous range: start a new one.
            merged.append((lo, hi))
            continue

        # Overlaps (or touches) the previous range: extend it if needed.
        merged[-1] = (prev_lo, max(prev_hi, hi))

    return merged
211 changes: 194 additions & 17 deletions sqlglot/optimizer/simplify.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import datetime
import functools
import itertools
import typing as t
from collections import deque
from decimal import Decimal

from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, while_changing
from sqlglot.helper import first, merge_ranges, while_changing

# Final means that an expression should not be simplified
FINAL = "final"


class UnsupportedUnit(Exception):
    """Raised by the date helpers in this module when given a unit they do not handle."""


def simplify(expression):
"""
Rewrite sqlglot AST to simplify expressions.
Expand Down Expand Up @@ -73,6 +78,7 @@ def _simplify(expression, root=True):
node.parent = expression.parent
node = simplify_literals(node, root)
node = simplify_parens(node)
node = simplify_datetrunc_predicate(node)

if root:
expression.replace(node)
Expand All @@ -84,6 +90,21 @@ def _simplify(expression, root=True):
return expression


def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    Args:
        *exceptions: exception classes to swallow.

    Returns:
        A decorator. The decorated function returns its first positional argument
        (`expression`) unchanged whenever one of `exceptions` is raised inside it;
        any other exception propagates normally.
    """

    def decorator(func):
        # Preserve the wrapped function's __name__/__doc__ for debugging and docs.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator


def rewrite_between(expression: exp.Expression) -> exp.Expression:
"""Rewrite x between y and z to x >= y AND x <= z.
Expand Down Expand Up @@ -196,7 +217,7 @@ def _simplify_connectors(expression, left, right):
exp.Is,
)

INVERSE_COMPARISONS = {
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
exp.LT: exp.GT,
exp.GT: exp.LT,
exp.LTE: exp.GTE,
Expand Down Expand Up @@ -530,6 +551,123 @@ def simplify_concat(expression):
return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)


DateRange = t.Tuple[datetime.date, datetime.date]


def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
    """
    Compute the range of values whose truncation to `unit` equals `date`.

    Example:
        _datetrunc_range(date(2021, 1, 1), 'year') == (date(2021, 1, 1), date(2022, 1, 1))

    Returns:
        A half-open `[min, max)` tuple, or None when `date` is not itself a
        truncated value for `unit` (so nothing can truncate to it).
    """
    lower = date_floor(date, unit)

    if date == lower:
        return lower, lower + interval(unit)

    # `date` is not on a unit boundary: the equality is always False, except for NULL values.
    return None


def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
    """Build `left >= start AND left < end` for the half-open range `drange`."""
    start, end = drange[0], drange[1]
    at_or_after_start = left >= date_literal(start)
    before_end = left < date_literal(end)
    return exp.and_(at_or_after_start, before_end, copy=False)


def _datetrunc_eq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """Rewrite `trunc(left, unit) = date` as a range check on `left`.

    Returns None when `_datetrunc_range` finds no matching range for `date`.
    """
    matching_range = _datetrunc_range(date, unit)
    return _datetrunc_eq_expression(left, matching_range) if matching_range else None


def _datetrunc_neq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """Rewrite `trunc(left, unit) <> date` as a range check on `left`.

    The truncated value differs from `date` exactly when `left` lies outside the
    half-open range `[start, end)` computed by `_datetrunc_range`, i.e.
    `left < start OR left >= end`.

    Returns:
        The rewritten predicate, or None when `_datetrunc_range` finds no range
        for `date` (the caller then leaves the original expression untouched).
    """
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    # The two halves are mutually exclusive, so they must be joined with OR;
    # joining them with AND would yield a predicate that can never be true.
    return exp.or_(
        left < date_literal(drange[0]),
        left >= date_literal(drange[1]),
        copy=False,
    )


# Signature of a rewrite rule: (truncated operand, comparison date, lowercase unit)
# -> replacement predicate, or None when no rewrite applies.
DateTruncBinaryTransform = t.Callable[
    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
]
# Rewrite rules for `trunc(l, u) <op> d`, keyed by the comparison node type.
# trunc() only ever yields unit boundaries, which is why `d` is rounded to one.
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    # trunc(l) < d   <=>  l < ceil(d). Using floor(d) here would wrongly exclude
    # values in [floor(d), ceil(d)) whenever d is not itself on a unit boundary.
    exp.LT: lambda l, d, u: l < date_literal(date_ceil(d, u)),
    # trunc(l) > d   <=>  l >= floor(d) + one unit
    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
    # trunc(l) <= d  <=>  l < floor(d) + one unit
    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
    # trunc(l) >= d  <=>  l >= ceil(d)
    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
# Every node type handled by simplify_datetrunc_predicate: IN plus the binary comparisons above.
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Tell whether `left` is a DateTrunc/TimestampTrunc node and `right` a cast to a temporal type."""
    if not isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)):
        return False
    if not isinstance(right, exp.Cast):
        return False
    return right.is_type(*exp.DataType.TEMPORAL_TYPES)


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`

    Handles the binary comparisons listed in DATETRUNC_BINARY_COMPARISONS (with
    the truncation on either side) and `IN` lists whose items are all temporal
    casts. The input expression is returned unchanged whenever it does not match,
    when a date cannot be extracted from a cast, or when the unit is unsupported
    / dateutil is missing (both swallowed by the `catch` decorator).
    """
    comparison = expression.__class__

    if comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        trunc, cast = expression.left, expression.right

        if _is_datetrunc_predicate(trunc, cast):
            pass
        elif _is_datetrunc_predicate(cast, trunc):
            # Truncation is on the right-hand side: flip operands and comparison.
            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
            trunc, cast = cast, trunc
        else:
            return expression

        unit = trunc.unit.name.lower()
        date = extract_date(cast)
        if date is None:
            # extract_date could not produce a date; there is nothing to simplify.
            return expression

        return DATETRUNC_BINARY_COMPARISONS[comparison](trunc.this, date, unit) or expression
    elif isinstance(expression, exp.In):
        trunc = expression.this
        candidates = expression.expressions

        # `candidates` being empty would make all() vacuously true even though
        # `trunc` was never checked to be a truncation node, so require items.
        if candidates and all(_is_datetrunc_predicate(trunc, c) for c in candidates):
            unit = trunc.unit.name.lower()

            dates = [extract_date(c) for c in candidates]
            if any(d is None for d in dates):
                # At least one item has no extractable date; leave the IN alone.
                return expression

            # Items for which _datetrunc_range returns None can never match and are dropped.
            ranges = [dr for dr in (_datetrunc_range(d, unit) for d in dates) if dr]
            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            return exp.or_(
                *[_datetrunc_eq_expression(trunc, drange) for drange in ranges], copy=False
            )

    return expression


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
Expand Down Expand Up @@ -603,31 +741,70 @@ def extract_date(cast):
return None


def extract_interval(interval):
def extract_interval(expression):
n = int(expression.name)
unit = expression.text("unit").lower()

try:
from dateutil.relativedelta import relativedelta # type: ignore
except ModuleNotFoundError:
return interval(unit, n)
except (UnsupportedUnit, ModuleNotFoundError):
return None

n = int(interval.name)
unit = interval.text("unit").lower()

def date_literal(date):
    """Wrap `date` in a string literal cast to DATETIME (for datetime.datetime values) or DATE."""
    type_name = "DATETIME" if isinstance(date, datetime.datetime) else "DATE"
    literal = exp.Literal.string(date)
    return exp.cast(literal, type_name)


def interval(unit: str, n: int = 1):
from dateutil.relativedelta import relativedelta

if unit == "year":
return relativedelta(years=n)
return relativedelta(years=1 * n)
if unit == "quarter":
return relativedelta(months=3 * n)
if unit == "month":
return relativedelta(months=n)
return relativedelta(months=1 * n)
if unit == "week":
return relativedelta(weeks=n)
return relativedelta(weeks=1 * n)
if unit == "day":
return relativedelta(days=n)
return None
return relativedelta(days=1 * n)

raise UnsupportedUnit(f"Unsupported unit: {unit}")

def date_literal(date):
return exp.cast(
exp.Literal.string(date),
"DATETIME" if isinstance(date, datetime.datetime) else "DATE",
)

def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` down to the start of the `unit` that contains it.

    Args:
        d: the value to round down.
        unit: one of "year", "quarter", "month", "week" or "day".

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "day":
        return d
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "month":
        return d.replace(day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10.
        quarter_start_month = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=quarter_start_month, day=1)
    if unit == "year":
        return d.replace(month=1, day=1)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` up to the next `unit` boundary, leaving it untouched if already on one."""
    lower = date_floor(d, unit)
    return d if lower == d else lower + interval(unit)


def boolean_literal(condition):
Expand Down
Loading

0 comments on commit d1cfa01

Please sign in to comment.