Skip to content

Commit

Permalink
fix(mysql,optimizer): TO_DAYS transpilation and more date casting (#2334
Browse files Browse the repository at this point in the history
)
  • Loading branch information
barakalon authored Sep 28, 2023
1 parent e2c8366 commit fcc2d8f
Show file tree
Hide file tree
Showing 9 changed files with 186 additions and 42 deletions.
5 changes: 4 additions & 1 deletion sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def inner_func(args: t.List) -> E:


def parse_date_delta_with_interval(
expression_class: t.Type[E],
expression_class: t.Type[E], invert: bool = False
) -> t.Callable[[t.List], t.Optional[E]]:
def func(args: t.List) -> t.Optional[E]:
if len(args) < 2:
Expand All @@ -553,6 +553,9 @@ def func(args: t.List) -> t.Optional[E]:
if expression and expression.is_string:
expression = exp.Literal.number(expression.this)

if expression and invert:
expression = expression * -1

return expression_class(
this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
)
Expand Down
85 changes: 71 additions & 14 deletions sqlglot/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def _str_to_date(args: t.List) -> exp.StrToDate:
return exp.StrToDate(this=seq_get(args, 0), format=date_format)


def _str_to_date_sql(self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime) -> str:
def _str_to_date_sql(
self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
) -> str:
date_format = self.format_time(expression)
return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"

Expand All @@ -86,15 +88,41 @@ def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str:
return f"TRIM({trim_type}{remove_chars}{from_part}{target})"


def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.DateSub], str]:
def func(self: MySQL.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
def _date_add_sql(
kind: str,
) -> t.Callable[[MySQL.Generator, exp.Expression], str]:
def func(self: MySQL.Generator, expression: exp.Expression) -> str:
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"

return func


def _ts_or_ds_to_date_sql(self: MySQL.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = expression.args.get("format")
if time_format:
return _str_to_date_sql(self, expression)
return f"DATE({self.sql(expression, 'this')})"


def _remove_ts_or_ds_to_date(
to_sql: t.Optional[t.Callable[[MySQL.Generator, exp.Expression], str]] = None,
args: t.Tuple[str, ...] = ("this",),
) -> t.Callable[[MySQL.Generator, exp.Func], str]:
def func(self: MySQL.Generator, expression: exp.Func) -> str:
expression = expression.copy()

for arg_key in args:
arg = expression.args.get(arg_key)
if isinstance(arg, exp.TsOrDsToDate) and not arg.args.get("format"):
expression.set(arg_key, arg.this)

return to_sql(self, expression) if to_sql else self.function_fallback_sql(expression)

return func


class MySQL(Dialect):
# https://dev.mysql.com/doc/refman/8.0/en/identifiers.html
IDENTIFIERS_CAN_START_WITH_DIGIT = True
Expand Down Expand Up @@ -233,17 +261,36 @@ class Parser(parser.Parser):

FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
"DATE": lambda args: exp.TsOrDsToDate(this=seq_get(args, 0)),
"DATE_ADD": parse_date_delta_with_interval(exp.TsOrDsAdd),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
"DATE_SUB": parse_date_delta_with_interval(exp.TsOrDsAdd, invert=True),
"INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
"ISNULL": isnull_to_is_null,
"LOCATE": locate_to_strposition,
"MONTHNAME": lambda args: exp.TimeToStr(
this=seq_get(args, 0),
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
format=exp.Literal.string("%B"),
),
"STR_TO_DATE": _str_to_date,
"TO_DAYS": lambda args: exp.paren(
exp.DateDiff(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
expression=exp.TsOrDsToDate(this=exp.Literal.string("0000-01-01")),
unit=exp.var("DAY"),
)
+ 1
),
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"WEEK": lambda args: exp.Week(
this=exp.TsOrDsToDate(this=seq_get(args, 0)), mode=seq_get(args, 1)
),
"WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
}

FUNCTION_PARSERS = {
Expand Down Expand Up @@ -557,20 +604,24 @@ class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.CurrentDate: no_paren_current_date_sql,
exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
exp.DateAdd: _date_add_sql("ADD"),
exp.DateDiff: _remove_ts_or_ds_to_date(
lambda self, e: self.func("DATEDIFF", e.this, e.expression), ("this", "expression")
),
exp.DateAdd: _remove_ts_or_ds_to_date(_date_add_sql("ADD")),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("SUB"),
exp.DateSub: _remove_ts_or_ds_to_date(_date_add_sql("SUB")),
exp.DateTrunc: _date_trunc_sql,
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.Day: _remove_ts_or_ds_to_date(),
exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")),
exp.DayOfWeek: _remove_ts_or_ds_to_date(rename_func("DAYOFWEEK")),
exp.DayOfYear: _remove_ts_or_ds_to_date(rename_func("DAYOFYEAR")),
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.ILike: no_ilike_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONKeyValue: json_keyvalue_comma_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Month: _remove_ts_or_ds_to_date(),
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
exp.Pivot: no_pivot_sql,
Expand All @@ -590,10 +641,16 @@ class Generator(generator.Generator):
exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),
exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
exp.TimeToStr: _remove_ts_or_ds_to_date(
lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e))
),
exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
exp.TsOrDsAdd: _date_add_sql("ADD"),
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
exp.Week: _remove_ts_or_ds_to_date(),
exp.WeekOfYear: _remove_ts_or_ds_to_date(rename_func("WEEKOFYEAR")),
exp.Year: _remove_ts_or_ds_to_date(),
}

UNSIGNED_TYPE_MAPPING = {
Expand Down
4 changes: 4 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4391,6 +4391,10 @@ class DayOfYear(Func):
_sql_names = ["DAY_OF_YEAR", "DAYOFYEAR"]


class ToDays(Func):
pass


class WeekOfYear(Func):
_sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"]

Expand Down
42 changes: 23 additions & 19 deletions sqlglot/optimizer/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,10 +387,6 @@ def _is_number(expression: exp.Expression) -> bool:
return expression.is_number


def _is_date(expression: exp.Expression) -> bool:
return isinstance(expression, exp.Cast) and extract_date(expression) is not None


def _is_interval(expression: exp.Expression) -> bool:
return isinstance(expression, exp.Interval) and extract_interval(expression) is not None

Expand Down Expand Up @@ -422,8 +418,8 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:
if r.is_number:
a_predicate = _is_number
b_predicate = _is_number
elif _is_date(r):
a_predicate = _is_date
elif _is_date_literal(r):
a_predicate = _is_date_literal
b_predicate = _is_interval
else:
return expression
Expand Down Expand Up @@ -509,14 +505,14 @@ def _simplify_binary(expression, a, b):

if boolean:
return boolean
elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
elif _is_date_literal(a) and isinstance(b, exp.Interval):
a, b = extract_date(a), extract_interval(b)
if a and b:
if isinstance(expression, exp.Add):
return date_literal(a + b)
if isinstance(expression, exp.Sub):
return date_literal(a - b)
elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
elif isinstance(a, exp.Interval) and _is_date_literal(b):
a, b = extract_interval(a), extract_date(b)
# you cannot subtract a date from an interval
if a and b and isinstance(expression, exp.Add):
Expand Down Expand Up @@ -702,11 +698,7 @@ def _datetrunc_neq(


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
return (
isinstance(left, (exp.DateTrunc, exp.TimestampTrunc))
and isinstance(right, exp.Cast)
and right.is_type(*exp.DataType.TEMPORAL_TYPES)
)
return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)


@catch(ModuleNotFoundError, UnsupportedUnit)
Expand Down Expand Up @@ -854,15 +846,25 @@ def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.da
return None


def extract_date(cast: exp.Cast) -> t.Optional[t.Union[datetime.date, datetime.date]]:
value: t.Any
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
if isinstance(cast, exp.Cast):
to = cast.to
elif isinstance(cast, exp.TsOrDsToDate):
to = exp.DataType.build(exp.DataType.Type.DATE)
else:
return None

if isinstance(cast.this, exp.Literal):
value = cast.this.name
elif isinstance(cast.this, exp.Cast):
value: t.Any = cast.this.name
elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
value = extract_date(cast.this)
else:
return None
return cast_value(value, cast.to)
return cast_value(value, to)


def _is_date_literal(expression: exp.Expression) -> bool:
return extract_date(expression) is not None


def extract_interval(expression):
Expand All @@ -878,7 +880,9 @@ def extract_interval(expression):
def date_literal(date):
return exp.cast(
exp.Literal.string(date),
"DATETIME" if isinstance(date, datetime.datetime) else "DATE",
exp.DataType.Type.DATETIME
if isinstance(date, datetime.datetime)
else exp.DataType.Type.DATE,
)


Expand Down
24 changes: 19 additions & 5 deletions tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,7 @@ def test_time(self):
"presto": "CAST(CAST(x AS TIMESTAMP) AS DATE)",
"snowflake": "CAST(x AS DATE)",
"doris": "TO_DATE(x)",
"mysql": "DATE(x)",
},
)
self.validate_all(
Expand Down Expand Up @@ -687,9 +688,7 @@ def test_time(self):
self.validate_all(
"DATE_ADD(x, 1, 'DAY')",
read={
"mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
"snowflake": "DATEADD('DAY', 1, x)",
"starrocks": "DATE_ADD(x, INTERVAL 1 DAY)",
},
write={
"bigquery": "DATE_ADD(x, INTERVAL 1 DAY)",
Expand Down Expand Up @@ -881,6 +880,7 @@ def test_time(self):
"hive": "DATE_ADD('2021-02-01', 1)",
"presto": "DATE_ADD('DAY', 1, CAST(CAST('2021-02-01' AS TIMESTAMP) AS DATE))",
"spark": "DATE_ADD('2021-02-01', 1)",
"mysql": "DATE_ADD('2021-02-01', INTERVAL 1 DAY)",
},
)
self.validate_all(
Expand Down Expand Up @@ -936,10 +936,7 @@ def test_time(self):
"bigquery",
"drill",
"duckdb",
"mysql",
"presto",
"starrocks",
"doris",
)
},
write={
Expand All @@ -952,8 +949,25 @@ def test_time(self):
"presto",
"hive",
"spark",
)
},
)
self.validate_all(
f"{unit}(TS_OR_DS_TO_DATE(x))",
read={
dialect: f"{unit}(x)"
for dialect in (
"mysql",
"doris",
"starrocks",
)
},
write={
dialect: f"{unit}(x)"
for dialect in (
"mysql",
"doris",
"starrocks",
)
},
)
Expand Down
Loading

0 comments on commit fcc2d8f

Please sign in to comment.