Skip to content

Commit

Permalink
fix(mysql): DATE_ADD for datetimes (#2360)
Browse files Browse the repository at this point in the history
  • Loading branch information
barakalon authored Oct 3, 2023
1 parent 55e2d15 commit 2bc30a5
Show file tree
Hide file tree
Showing 11 changed files with 181 additions and 49 deletions.
5 changes: 1 addition & 4 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def inner_func(args: t.List) -> E:


def parse_date_delta_with_interval(
expression_class: t.Type[E], invert: bool = False
expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
def func(args: t.List) -> t.Optional[E]:
if len(args) < 2:
Expand All @@ -553,9 +553,6 @@ def func(args: t.List) -> t.Optional[E]:
if expression and expression.is_string:
expression = exp.Literal.number(expression.this)

if expression and invert:
expression = expression * -1

return expression_class(
this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
)
Expand Down
4 changes: 2 additions & 2 deletions sqlglot/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,9 @@ class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE": lambda args: exp.TsOrDsToDate(this=seq_get(args, 0)),
"DATE_ADD": parse_date_delta_with_interval(exp.TsOrDsAdd),
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
"DATE_SUB": parse_date_delta_with_interval(exp.TsOrDsAdd, invert=True),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
"INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
"ISNULL": isnull_to_is_null,
"LOCATE": locate_to_strposition,
Expand Down
18 changes: 14 additions & 4 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4053,6 +4053,16 @@ def unit(self) -> t.Optional[Var]:
return self.args.get("unit")


class IntervalOp(TimeUnit):
    """Time-unit expression that carries an ``expression`` amount and a ``unit``."""

    arg_types = {"unit": True, "expression": True}

    def interval(self):
        """Build an ``Interval`` from copies of this node's amount and unit.

        Returns:
            An ``Interval`` whose ``this`` is a copy of ``self.expression`` and
            whose ``unit`` is a copy of ``self.unit``, or ``None`` when this
            node has no unit.
        """
        # `unit` is optional on subclasses that declare `"unit": False`, so it
        # can be None here; copying it unconditionally would raise
        # AttributeError.
        unit = self.unit
        return Interval(
            this=self.expression.copy(),
            unit=unit.copy() if unit is not None else None,
        )


# https://www.oracletutorial.com/oracle-basics/oracle-interval/
# https://trino.io/docs/current/language/types.html#interval-day-to-second
# https://docs.databricks.com/en/sql/language-manual/data-types/interval-type.html
Expand Down Expand Up @@ -4358,11 +4368,11 @@ class CurrentUser(Func):
arg_types = {"this": False}


class DateAdd(Func, TimeUnit):
class DateAdd(Func, IntervalOp):
arg_types = {"this": True, "expression": True, "unit": False}


class DateSub(Func, TimeUnit):
class DateSub(Func, IntervalOp):
arg_types = {"this": True, "expression": True, "unit": False}


Expand All @@ -4379,11 +4389,11 @@ def unit(self) -> Expression:
return self.args["unit"]


class DatetimeAdd(Func, TimeUnit):
class DatetimeAdd(Func, IntervalOp):
arg_types = {"this": True, "expression": True, "unit": False}


class DatetimeSub(Func, TimeUnit):
class DatetimeSub(Func, IntervalOp):
arg_types = {"this": True, "expression": True, "unit": False}


Expand Down
115 changes: 109 additions & 6 deletions sqlglot/optimizer/annotate_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import datetime
import functools
import typing as t

from sqlglot import exp
Expand All @@ -11,6 +13,16 @@
if t.TYPE_CHECKING:
B = t.TypeVar("B", bound=exp.Binary)

BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type]
BinaryCoercions = t.Dict[
t.Tuple[exp.DataType.Type, exp.DataType.Type],
BinaryCoercionFunc,
]


# Interval units that operate on date components
DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}


def annotate_types(
expression: E,
Expand Down Expand Up @@ -48,6 +60,59 @@ def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[Type
return lambda self, e: self._annotate_with_type(e, data_type)


def _is_iso_date(text: str) -> bool:
try:
datetime.date.fromisoformat(text)
return True
except ValueError:
return False


def _is_iso_datetime(text: str) -> bool:
try:
datetime.datetime.fromisoformat(text)
return True
except ValueError:
return False


def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
    """Pick the result type for a text operand combined with an interval.

    When the operand's name parses as an ISO date or datetime, the operand is
    replaced in its tree by a CAST to the chosen type and that type is
    returned. Otherwise nothing is modified and UNKNOWN is returned.

    Args:
        l: The text operand; its ``name`` is checked as an ISO date/datetime.
        r: The interval operand; its ``unit`` text is compared, lower-cased,
            against ``DATE_UNITS``.
    """
    literal_text = l.name
    interval_unit = r.text("unit").lower()

    looks_like_date = _is_iso_date(literal_text)

    if looks_like_date and interval_unit in DATE_UNITS:
        target = exp.DataType.Type.DATE
    elif looks_like_date or _is_iso_datetime(literal_text):
        # An ISO date is also an ISO datetime, but not vice versa
        target = exp.DataType.Type.DATETIME
    else:
        return exp.DataType.Type.UNKNOWN

    l.replace(exp.cast(l.copy(), to=target))
    return target


def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
    """Pick the result type for a typed operand combined with an interval.

    Returns DATETIME when the interval's unit is not in ``DATE_UNITS``.
    Otherwise returns the operand's own type, or UNKNOWN if it has none.
    """
    interval_unit = r.text("unit").lower()

    if interval_unit in DATE_UNITS:
        return l.type.this if l.type else exp.DataType.Type.UNKNOWN

    return exp.DataType.Type.DATETIME


def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc:
    """Wrap ``func`` so that its two positional arguments are passed in reverse order.

    The wrapper carries ``func``'s metadata (name, docstring, ``__wrapped__``).
    """

    def _swapped(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
        return func(r, l)

    return functools.update_wrapper(_swapped, func)


def swap_all(coercions: BinaryCoercions) -> BinaryCoercions:
    """Return a copy of ``coercions`` extended with the mirrored type pairs.

    For every ``(a, b) -> func`` entry, a ``(b, a) -> swap_args(func)`` entry is
    added. A mirrored entry overwrites an existing entry with the same key.
    """
    combined = dict(coercions)
    for (first, second), func in coercions.items():
        combined[(second, first)] = swap_args(func)
    return combined


class _TypeAnnotator(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
Expand Down Expand Up @@ -104,10 +169,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.DataType.Type.DATE: {
exp.CurrentDate,
exp.Date,
exp.DateAdd,
exp.DateFromParts,
exp.DateStrToDate,
exp.DateSub,
exp.DateTrunc,
exp.DiToDate,
exp.StrToDate,
Expand Down Expand Up @@ -212,6 +275,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
exp.DateAdd: lambda self, e: self._annotate_dateadd(e),
exp.DateSub: lambda self, e: self._annotate_dateadd(e),
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
Expand All @@ -234,21 +299,41 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
# Specifies what types a given type can be coerced into (autofilled)
COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

# Coercion functions for binary operations.
# Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type.
BINARY_COERCIONS: BinaryCoercions = {
**swap_all(
{
(t, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval
for t in exp.DataType.TEXT_TYPES
}
),
**swap_all(
{
(exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval,
}
),
}

def __init__(
    self,
    schema: Schema,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
    binary_coercions: t.Optional[BinaryCoercions] = None,
) -> None:
    """Store the schema and the lookup tables used during annotation.

    Each optional table falls back to the matching class-level default
    (``ANNOTATORS``, ``COERCES_TO``, ``BINARY_COERCIONS``) when it is omitted
    or falsy.
    """
    # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
    self._visited: t.Set[int] = set()

    self.schema = schema
    self.annotators = annotators if annotators else self.ANNOTATORS
    self.coerces_to = coerces_to if coerces_to else self.COERCES_TO
    self.binary_coercions = binary_coercions if binary_coercions else self.BINARY_COERCIONS

def _set_type(self, expression: exp.Expression, target_type: exp.DataType) -> None:
expression.type = target_type
def _set_type(
self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
) -> None:
expression.type = target_type # type: ignore
self._visited.add(id(expression))

def annotate(self, expression: E) -> E:
Expand Down Expand Up @@ -342,8 +427,8 @@ def _maybe_coerce(
def _annotate_binary(self, expression: B) -> B:
self._annotate_args(expression)

left_type = expression.left.type.this
right_type = expression.right.type.this
left, right = expression.left, expression.right
left_type, right_type = left.type.this, right.type.this

if isinstance(expression, exp.Connector):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
Expand All @@ -357,6 +442,8 @@ def _annotate_binary(self, expression: B) -> B:
self._set_type(expression, exp.DataType.Type.BOOLEAN)
elif isinstance(expression, exp.Predicate):
self._set_type(expression, exp.DataType.Type.BOOLEAN)
elif (left_type, right_type) in self.binary_coercions:
self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
else:
self._set_type(expression, self._maybe_coerce(left_type, right_type))

Expand Down Expand Up @@ -421,3 +508,19 @@ def _annotate_by_args(
)

return expression

def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp:
    """Annotate an interval operation based on its operand type and unit.

    - Text operand: the type comes from ``_coerce_literal_and_interval``.
    - DATE operand with a unit outside ``DATE_UNITS``: DATETIME.
    - Anything else: the operand's own type.
    """
    self._annotate_args(expression)

    operand = expression.this
    operand_type = operand.type

    if operand_type.this in exp.DataType.TEXT_TYPES:
        result_type = _coerce_literal_and_interval(operand, expression.interval())
    elif not operand_type.is_type(exp.DataType.Type.DATE):
        result_type = operand_type
    elif expression.text("unit").lower() in DATE_UNITS:
        result_type = operand_type
    else:
        result_type = exp.DataType.Type.DATETIME

    self._set_type(expression, result_type)
    return expression
25 changes: 5 additions & 20 deletions sqlglot/optimizer/canonicalize.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import itertools
import typing as t

from sqlglot import exp

Expand Down Expand Up @@ -41,30 +40,16 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression:
return node


# Expression type to transform -> arg key -> (allowed types, type to cast to)
ARG_TYPES: t.Dict[
t.Type[exp.Expression], t.Dict[str, t.Tuple[t.Iterable[exp.DataType.Type], exp.DataType.Type]]
] = {
exp.DateAdd: {"this": (exp.DataType.TEMPORAL_TYPES, exp.DataType.Type.DATE)},
exp.DateSub: {"this": (exp.DataType.TEMPORAL_TYPES, exp.DataType.Type.DATE)},
exp.DatetimeAdd: {"this": (exp.DataType.TEMPORAL_TYPES, exp.DataType.Type.DATETIME)},
exp.DatetimeSub: {"this": (exp.DataType.TEMPORAL_TYPES, exp.DataType.Type.DATETIME)},
exp.Extract: {"expression": (exp.DataType.TEMPORAL_TYPES, exp.DataType.Type.DATETIME)},
}


def coerce_type(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Binary):
_coerce_date(node.left, node.right)
elif isinstance(node, exp.Between):
_coerce_date(node.this, node.args["low"])
else:
arg_types = ARG_TYPES.get(node.__class__)
if arg_types:
for arg_key, (allowed, to) in arg_types.items():
arg = node.args.get(arg_key)
if arg and not arg.type.is_type(*allowed):
_replace_cast(arg, to)
elif isinstance(node, exp.Extract) and not node.expression.type.is_type(
*exp.DataType.TEMPORAL_TYPES
):
_replace_cast(node.expression, exp.DataType.Type.DATETIME)

return node


Expand Down
5 changes: 1 addition & 4 deletions sqlglot/optimizer/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,10 +426,7 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:

if l.__class__ in INVERSE_DATE_OPS:
a = l.this
b = exp.Interval(
this=l.expression.copy(),
unit=l.unit.copy(),
)
b = l.interval()
else:
a, b = l.left, l.right

Expand Down
2 changes: 1 addition & 1 deletion tests/dialects/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ def test_mysql(self):
self.validate_all(
"SELECT DATE(DATE_SUB(`dt`, INTERVAL DAYOFMONTH(`dt`) - 1 DAY)) AS __timestamp FROM tableT",
write={
"mysql": "SELECT DATE(DATE_ADD(`dt`, INTERVAL ((DAYOFMONTH(`dt`) - 1) * -1) DAY)) AS __timestamp FROM tableT",
"mysql": "SELECT DATE(DATE_SUB(`dt`, INTERVAL (DAYOFMONTH(`dt`) - 1) DAY)) AS __timestamp FROM tableT",
},
)
self.validate_identity("SELECT name FROM temp WHERE name = ? FOR UPDATE")
Expand Down
4 changes: 0 additions & 4 deletions tests/dialects/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,6 @@ def test_time(self):
"presto": "DATE_ADD('DAY', 1 * -1, x)",
},
)
self.validate_all(
"DATE_ADD('DAY', 1 * -1, CAST(CAST(x AS TIMESTAMP) AS DATE))",
read={"mysql": "DATE_SUB(x, INTERVAL 1 DAY)"},
)
self.validate_all(
"NOW()",
write={
Expand Down
3 changes: 0 additions & 3 deletions tests/fixtures/optimizer/canonicalize.sql
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,3 @@ DATE_ADD(CAST("x" AS DATE), 1, 'YEAR');

DATE_ADD('2023-01-01', 1, 'YEAR');
DATE_ADD(CAST('2023-01-01' AS DATE), 1, 'YEAR');

DATETIME_SUB('2023-01-01', 1, YEAR);
DATETIME_SUB(CAST('2023-01-01' AS DATETIME), 1, YEAR);
2 changes: 1 addition & 1 deletion tests/fixtures/optimizer/tpc-ds/tpc-ds.sql
Original file line number Diff line number Diff line change
Expand Up @@ -9775,7 +9775,7 @@ JOIN "date_dim" AS "d1"
ON "catalog_sales"."cs_sold_date_sk" = "d1"."d_date_sk"
AND "d1"."d_week_seq" = "d2"."d_week_seq"
AND "d1"."d_year" = 2002
AND "d3"."d_date" > CONCAT("d1"."d_date", INTERVAL '5' day)
AND "d3"."d_date" > "d1"."d_date" + INTERVAL '5' day
GROUP BY
"item"."i_item_desc",
"warehouse"."w_warehouse_name",
Expand Down
Loading

0 comments on commit 2bc30a5

Please sign in to comment.