Skip to content

Commit

Permalink
feat(optimizer): canonicalize date arithmetic funcs (#2320)
Browse files Browse the repository at this point in the history
  • Loading branch information
barakalon authored Sep 25, 2023
1 parent aa2c4c3 commit 64a7b93
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 5 deletions.
27 changes: 22 additions & 5 deletions sqlglot/optimizer/canonicalize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import itertools
import typing as t

from sqlglot import exp

Expand Down Expand Up @@ -40,14 +41,30 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression:
return node


# Expression type to transform -> arg key -> (allowed types, type to cast to)
ARG_TYPES: t.Dict[
t.Type[exp.Expression], t.Dict[str, t.Tuple[t.Iterable[exp.DataType.Type], exp.DataType.Type]]
] = {
exp.DateAdd: {"this": (exp.DataType.TEMPORAL_TYPES, exp.DataType.Type.DATE)},
exp.DateSub: {"this": (exp.DataType.TEMPORAL_TYPES, exp.DataType.Type.DATE)},
exp.DatetimeAdd: {"this": (exp.DataType.TEMPORAL_TYPES, exp.DataType.Type.DATETIME)},
exp.DatetimeSub: {"this": (exp.DataType.TEMPORAL_TYPES, exp.DataType.Type.DATETIME)},
exp.Extract: {"expression": (exp.DataType.TEMPORAL_TYPES, exp.DataType.Type.DATETIME)},
}


def coerce_type(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Binary):
_coerce_date(node.left, node.right)
elif isinstance(node, exp.Between):
_coerce_date(node.this, node.args["low"])
elif isinstance(node, exp.Extract):
if node.expression.type.this not in exp.DataType.TEMPORAL_TYPES:
_replace_cast(node.expression, "datetime")
else:
arg_types = ARG_TYPES.get(node.__class__)
if arg_types:
for arg_key, (allowed, to) in arg_types.items():
arg = node.args.get(arg_key)
if arg and not arg.type.is_type(*allowed):
_replace_cast(arg, to)
return node


Expand Down Expand Up @@ -89,10 +106,10 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
and b.type
and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL)
):
_replace_cast(b, "date")
_replace_cast(b, exp.DataType.Type.DATE)


def _replace_cast(node: exp.Expression, to: str) -> None:
def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
node.replace(exp.cast(node.copy(), to=to))


Expand Down
9 changes: 9 additions & 0 deletions tests/fixtures/optimizer/canonicalize.sql
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,12 @@ CAST('2023-01-01' AS TIMESTAMP);

TIMESTAMP('2023-01-01', '12:00:00');
TIMESTAMP('2023-01-01', '12:00:00');

DATE_ADD(CAST("x" AS DATE), 1, 'YEAR');
DATE_ADD(CAST("x" AS DATE), 1, 'YEAR');

DATE_ADD('2023-01-01', 1, 'YEAR');
DATE_ADD(CAST('2023-01-01' AS DATE), 1, 'YEAR');

DATETIME_SUB('2023-01-01', 1, YEAR);
DATETIME_SUB(CAST('2023-01-01' AS DATETIME), 1, YEAR);

0 comments on commit 64a7b93

Please sign in to comment.