diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index d0f4e3655e..3650ff5c26 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -538,7 +538,7 @@ def inner_func(args: t.List) -> E: def parse_date_delta_with_interval( - expression_class: t.Type[E], invert: bool = False + expression_class: t.Type[E], ) -> t.Callable[[t.List], t.Optional[E]]: def func(args: t.List) -> t.Optional[E]: if len(args) < 2: @@ -553,9 +553,6 @@ def func(args: t.List) -> t.Optional[E]: if expression and expression.is_string: expression = exp.Literal.number(expression.this) - if expression and invert: - expression = expression * -1 - return expression_class( this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit")) ) diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 374968bc61..59a0a2a9ff 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -262,9 +262,9 @@ class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, "DATE": lambda args: exp.TsOrDsToDate(this=seq_get(args, 0)), - "DATE_ADD": parse_date_delta_with_interval(exp.TsOrDsAdd), + "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"), - "DATE_SUB": parse_date_delta_with_interval(exp.TsOrDsAdd, invert=True), + "DATE_SUB": parse_date_delta_with_interval(exp.DateSub), "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)), "ISNULL": isnull_to_is_null, "LOCATE": locate_to_strposition, diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 581453635d..204d393b6c 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -4053,6 +4053,16 @@ def unit(self) -> t.Optional[Var]: return self.args.get("unit") +class IntervalOp(TimeUnit): + arg_types = {"unit": True, "expression": True} + + def interval(self): + return Interval( + this=self.expression.copy(), + unit=self.unit.copy(), + ) + + # https://www.oracletutorial.com/oracle-basics/oracle-interval/ # https://trino.io/docs/current/language/types.html#interval-day-to-second # https://docs.databricks.com/en/sql/language-manual/data-types/interval-type.html @@ -4358,11 +4368,11 @@ class CurrentUser(Func): arg_types = {"this": False} -class DateAdd(Func, TimeUnit): +class DateAdd(Func, IntervalOp): arg_types = {"this": True, "expression": True, "unit": False} -class DateSub(Func, TimeUnit): +class DateSub(Func, IntervalOp): arg_types = {"this": True, "expression": True, "unit": False} @@ -4379,11 +4389,11 @@ def unit(self) -> Expression: return self.args["unit"] -class DatetimeAdd(Func, TimeUnit): +class DatetimeAdd(Func, IntervalOp): arg_types = {"this": True, "expression": True, "unit": False} -class DatetimeSub(Func, TimeUnit): +class DatetimeSub(Func, IntervalOp): arg_types = {"this": True, "expression": True, "unit": False} diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index afc6995d5f..17af6ac3fb 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,5 +1,7 @@ from __future__ import annotations +import datetime +import functools import typing as t from sqlglot import exp @@ -11,6 +13,16 @@ if t.TYPE_CHECKING: B = t.TypeVar("B", bound=exp.Binary) + BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type] + BinaryCoercions = t.Dict[ + t.Tuple[exp.DataType.Type, exp.DataType.Type], + BinaryCoercionFunc, + ] + + +# Interval units that operate on date components +DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"} + def annotate_types( expression: E, @@ -48,6 +60,59 @@ def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[Type return lambda self, e: self._annotate_with_type(e, data_type) +def _is_iso_date(text: str) -> bool: + try: + datetime.date.fromisoformat(text) + return True + except ValueError: + return False + + +def _is_iso_datetime(text: str) -> bool: + try: + datetime.datetime.fromisoformat(text) + return True + except ValueError: + return False + + +def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type: + date_text = l.name + unit = r.text("unit").lower() + + is_iso_date = _is_iso_date(date_text) + + if is_iso_date and unit in DATE_UNITS: + l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATE)) + return exp.DataType.Type.DATE + + # An ISO date is also an ISO datetime, but not vice versa + if is_iso_date or _is_iso_datetime(date_text): + l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATETIME)) + return exp.DataType.Type.DATETIME + + return exp.DataType.Type.UNKNOWN + + +def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type: + unit = r.text("unit").lower() + if unit not in DATE_UNITS: + return exp.DataType.Type.DATETIME + return l.type.this if l.type else exp.DataType.Type.UNKNOWN + + +def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc: + @functools.wraps(func) + def _swapped(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type: + return func(r, l) + + return _swapped + + +def swap_all(coercions: BinaryCoercions) -> BinaryCoercions: + return {**coercions, **{(b, a): swap_args(func) for (a, b), func in coercions.items()}} + + class _TypeAnnotator(type): def __new__(cls, clsname, bases, attrs): klass = super().__new__(cls, clsname, bases, attrs) @@ -104,10 +169,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.DataType.Type.DATE: { exp.CurrentDate, exp.Date, - exp.DateAdd, exp.DateFromParts, exp.DateStrToDate, - exp.DateSub, exp.DateTrunc, exp.DiToDate, exp.StrToDate, @@ -212,6 +275,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), + exp.DateAdd: lambda self, e: self._annotate_dateadd(e), + exp.DateSub: lambda self, e: self._annotate_dateadd(e), exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), @@ -234,21 +299,41 @@ class TypeAnnotator(metaclass=_TypeAnnotator): # Specifies what types a given type can be coerced into (autofilled) COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {} + # Coercion functions for binary operations. + # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type. + BINARY_COERCIONS: BinaryCoercions = { + **swap_all( + { + (t, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval + for t in exp.DataType.TEXT_TYPES + } + ), + **swap_all( + { + (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval, + } + ), + } + def __init__( self, schema: Schema, annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None, coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, + binary_coercions: t.Optional[BinaryCoercions] = None, ) -> None: self.schema = schema self.annotators = annotators or self.ANNOTATORS self.coerces_to = coerces_to or self.COERCES_TO + self.binary_coercions = binary_coercions or self.BINARY_COERCIONS # Caches the ids of annotated sub-Expressions, to ensure we only visit them once self._visited: t.Set[int] = set() - def _set_type(self, expression: exp.Expression, target_type: exp.DataType) -> None: - expression.type = target_type + def _set_type( + self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type + ) -> None: + expression.type = target_type # type: ignore self._visited.add(id(expression)) def annotate(self, expression: E) -> E: @@ -342,8 +427,8 @@ def _maybe_coerce( def _annotate_binary(self, expression: B) -> B: self._annotate_args(expression) - left_type = expression.left.type.this - right_type = expression.right.type.this + left, right = expression.left, expression.right + left_type, right_type = left.type.this, right.type.this if isinstance(expression, exp.Connector): if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: @@ -357,6 +442,8 @@ def _annotate_binary(self, expression: B) -> B: self._set_type(expression, exp.DataType.Type.BOOLEAN) elif isinstance(expression, exp.Predicate): self._set_type(expression, exp.DataType.Type.BOOLEAN) + elif (left_type, right_type) in self.binary_coercions: + self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right)) else: self._set_type(expression, self._maybe_coerce(left_type, right_type)) @@ -421,3 +508,19 @@ def _annotate_by_args( ) return expression + + def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp: + self._annotate_args(expression) + + if expression.this.type.this in exp.DataType.TEXT_TYPES: + datatype = _coerce_literal_and_interval(expression.this, expression.interval()) + elif ( + expression.this.type.is_type(exp.DataType.Type.DATE) + and expression.text("unit").lower() not in DATE_UNITS + ): + datatype = exp.DataType.Type.DATETIME + else: + datatype = expression.this.type + + self._set_type(expression, datatype) + return expression diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index 8acb37f6cf..ec3b3af13b 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -1,7 +1,6 @@ from __future__ import annotations import itertools -import typing as t from sqlglot import exp @@ -41,30 +40,16 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression: return node -# Expression type to transform -> arg key -> (allowed types, type to cast to) -ARG_TYPES: t.Dict[ - t.Type[exp.Expression], t.Dict[str, t.Tuple[t.Iterable[exp.DataType.Type], exp.DataType.Type]] -] = { - exp.DateAdd: {"this": (exp.DataType.TEMPORAL_TYPES, exp.DataType.Type.DATE)}, - exp.DateSub: {"this": (exp.DataType.TEMPORAL_TYPES, exp.DataType.Type.DATE)}, - exp.DatetimeAdd: {"this": (exp.DataType.TEMPORAL_TYPES, exp.DataType.Type.DATETIME)}, - exp.DatetimeSub: {"this": (exp.DataType.TEMPORAL_TYPES, exp.DataType.Type.DATETIME)}, - exp.Extract: {"expression": (exp.DataType.TEMPORAL_TYPES, exp.DataType.Type.DATETIME)}, -} - - def coerce_type(node: exp.Expression) -> exp.Expression: if isinstance(node, exp.Binary): _coerce_date(node.left, node.right) elif isinstance(node, exp.Between): _coerce_date(node.this, node.args["low"]) - else: - arg_types = ARG_TYPES.get(node.__class__) - if arg_types: - for arg_key, (allowed, to) in arg_types.items(): - arg = node.args.get(arg_key) - if arg and not arg.type.is_type(*allowed): - _replace_cast(arg, to) + elif isinstance(node, exp.Extract) and not node.expression.type.is_type( + *exp.DataType.TEMPORAL_TYPES + ): + _replace_cast(node.expression, exp.DataType.Type.DATETIME) + return node diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 8e646eac5c..de5b82c7be 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -426,10 +426,7 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression: if l.__class__ in INVERSE_DATE_OPS: a = l.this - b = exp.Interval( - this=l.expression.copy(), - unit=l.unit.copy(), - ) + b = l.interval() else: a, b = l.left, l.right diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 4e14a3b53a..11f921cefc 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -614,7 +614,7 @@ def test_mysql(self): self.validate_all( "SELECT DATE(DATE_SUB(`dt`, INTERVAL DAYOFMONTH(`dt`) - 1 DAY)) AS __timestamp FROM tableT", write={ - "mysql": "SELECT DATE(DATE_ADD(`dt`, INTERVAL ((DAYOFMONTH(`dt`) - 1) * -1) DAY)) AS __timestamp FROM tableT", + "mysql": "SELECT DATE(DATE_SUB(`dt`, INTERVAL (DAYOFMONTH(`dt`) - 1) DAY)) AS __timestamp FROM tableT", }, ) self.validate_identity("SELECT name FROM temp WHERE name = ? FOR UPDATE") diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 381b4fa60e..8edd31cd74 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -301,10 +301,6 @@ def test_time(self): "presto": "DATE_ADD('DAY', 1 * -1, x)", }, ) - self.validate_all( - "DATE_ADD('DAY', 1 * -1, CAST(CAST(x AS TIMESTAMP) AS DATE))", - read={"mysql": "DATE_SUB(x, INTERVAL 1 DAY)"}, - ) self.validate_all( "NOW()", write={ diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql index 9982227262..2ba762da28 100644 --- a/tests/fixtures/optimizer/canonicalize.sql +++ b/tests/fixtures/optimizer/canonicalize.sql @@ -52,6 +52,3 @@ DATE_ADD(CAST("x" AS DATE), 1, 'YEAR'); DATE_ADD('2023-01-01', 1, 'YEAR'); DATE_ADD(CAST('2023-01-01' AS DATE), 1, 'YEAR'); - -DATETIME_SUB('2023-01-01', 1, YEAR); -DATETIME_SUB(CAST('2023-01-01' AS DATETIME), 1, YEAR); diff --git a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql index cd4991ba50..22181821db 100644 --- a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql +++ b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql @@ -9775,7 +9775,7 @@ JOIN "date_dim" AS "d1" ON "catalog_sales"."cs_sold_date_sk" = "d1"."d_date_sk" AND "d1"."d_week_seq" = "d2"."d_week_seq" AND "d1"."d_year" = 2002 - AND "d3"."d_date" > CONCAT("d1"."d_date", INTERVAL '5' day) + AND "d3"."d_date" > "d1"."d_date" + INTERVAL '5' day GROUP BY "item"."i_item_desc", "warehouse"."w_warehouse_name", diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 8775852d4b..8fc3273155 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -546,6 +546,53 @@ def test_binary_annotation(self): self.assertEqual(expression.right.this.left.type.this, exp.DataType.Type.INT) self.assertEqual(expression.right.this.right.type.this, exp.DataType.Type.INT) + def test_interval_math_annotation(self): + schema = { + "x": { + "a": "DATE", + "b": "DATETIME", + } + } + for sql, expected_type, *expected_sql in [ + ( + "SELECT '2023-01-01' + INTERVAL '1' DAY", + exp.DataType.Type.DATE, + "SELECT CAST('2023-01-01' AS DATE) + INTERVAL '1' DAY", + ), + ( + "SELECT '2023-01-01' + INTERVAL '1' HOUR", + exp.DataType.Type.DATETIME, + "SELECT CAST('2023-01-01' AS DATETIME) + INTERVAL '1' HOUR", + ), + ( + "SELECT '2023-01-01 00:00:01' + INTERVAL '1' HOUR", + exp.DataType.Type.DATETIME, + "SELECT CAST('2023-01-01 00:00:01' AS DATETIME) + INTERVAL '1' HOUR", + ), + ("SELECT 'nonsense' + INTERVAL '1' DAY", exp.DataType.Type.UNKNOWN), + ("SELECT x.a + INTERVAL '1' DAY FROM x AS x", exp.DataType.Type.DATE), + ("SELECT x.a + INTERVAL '1' HOUR FROM x AS x", exp.DataType.Type.DATETIME), + ("SELECT x.b + INTERVAL '1' DAY FROM x AS x", exp.DataType.Type.DATETIME), + ("SELECT x.b + INTERVAL '1' HOUR FROM x AS x", exp.DataType.Type.DATETIME), + ( + "SELECT DATE_ADD('2023-01-01', 1, 'DAY')", + exp.DataType.Type.DATE, + "SELECT DATE_ADD(CAST('2023-01-01' AS DATE), 1, 'DAY')", + ), + ( + "SELECT DATE_ADD('2023-01-01 00:00:00', 1, 'DAY')", + exp.DataType.Type.DATETIME, + "SELECT DATE_ADD(CAST('2023-01-01 00:00:00' AS DATETIME), 1, 'DAY')", + ), + ("SELECT DATE_ADD(x.a, 1, 'DAY') FROM x AS x", exp.DataType.Type.DATE), + ("SELECT DATE_ADD(x.a, 1, 'HOUR') FROM x AS x", exp.DataType.Type.DATETIME), + ("SELECT DATE_ADD(x.b, 1, 'DAY') FROM x AS x", exp.DataType.Type.DATETIME), + ]: + with self.subTest(sql): + expression = annotate_types(parse_one(sql), schema=schema) + self.assertEqual(expected_type, expression.expressions[0].type.this) + self.assertEqual(expected_sql[0] if expected_sql else sql, expression.sql()) + def test_lateral_annotation(self): expression = optimizer.optimize( parse_one("SELECT c FROM (select 1 a) as x LATERAL VIEW EXPLODE (a) AS c")