Skip to content

Commit

Permalink
Fix(mysql): timestamp add/sub closes #2214
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Sep 14, 2023
1 parent 9357769 commit b3c97de
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 22 deletions.
32 changes: 10 additions & 22 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sqlglot.dialects.dialect import (
Dialect,
binary_from_function,
date_add_interval_sql,
datestrtodate_sql,
format_time_lambda,
inline_array_sql,
Expand All @@ -28,19 +29,6 @@
logger = logging.getLogger("sqlglot")


def _date_add_sql(
data_type: str, kind: str
) -> t.Callable[[BigQuery.Generator, exp.Expression], str]:
def func(self: BigQuery.Generator, expression: exp.Expression) -> str:
this = self.sql(expression, "this")
unit = expression.args.get("unit")
unit = exp.var(unit.name.upper() if unit else "DAY")
interval = exp.Interval(this=expression.expression.copy(), unit=unit)
return f"{data_type}_{kind}({this}, {self.sql(interval)})"

return func


def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Values) -> str:
if not expression.find_ancestor(exp.From, exp.Join):
return self.values_sql(expression)
Expand Down Expand Up @@ -435,13 +423,13 @@ class Generator(generator.Generator):
exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
exp.Create: _create_sql,
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateAdd: date_add_interval_sql("DATE", "ADD"),
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.DateFromParts: rename_func("DATE"),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("DATE", "SUB"),
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
exp.DateSub: date_add_interval_sql("DATE", "SUB"),
exp.DatetimeAdd: date_add_interval_sql("DATETIME", "ADD"),
exp.DatetimeSub: date_add_interval_sql("DATETIME", "SUB"),
exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")),
exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
exp.GroupConcat: rename_func("STRING_AGG"),
Expand Down Expand Up @@ -483,13 +471,13 @@ class Generator(generator.Generator):
exp.StrToTime: lambda self, e: self.func(
"PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
),
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.TimeAdd: date_add_interval_sql("TIME", "ADD"),
exp.TimeSub: date_add_interval_sql("TIME", "SUB"),
exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"),
exp.TsOrDsAdd: date_add_interval_sql("DATE", "ADD"),
exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
exp.Unhex: rename_func("FROM_HEX"),
exp.Values: _derived_table_values_to_unnest,
Expand Down
13 changes: 13 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,19 @@ def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
def func(self: Generator, expression: exp.Expression) -> str:
this = self.sql(expression, "this")
unit = expression.args.get("unit")
unit = exp.var(unit.name.upper() if unit else "DAY")
interval = exp.Interval(this=expression.expression.copy(), unit=unit)
return f"{data_type}_{kind}({this}, {self.sql(interval)})"

return func


def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
return self.func(
"DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
Expand Down
3 changes: 3 additions & 0 deletions sqlglot/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
date_add_interval_sql,
datestrtodate_sql,
format_time_lambda,
json_keyvalue_comma_sql,
Expand Down Expand Up @@ -559,6 +560,8 @@ class Generator(generator.Generator):
exp.StrToTime: _str_to_date_sql,
exp.Stuff: rename_func("INSERT"),
exp.TableSample: no_tablesample_sql,
exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"),
exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),
exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
Expand Down
8 changes: 8 additions & 0 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,17 @@ def test_bigquery(self):
write={
"bigquery": "SELECT TIMESTAMP_ADD(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL 10 MINUTE)",
"databricks": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
"mysql": "SELECT DATE_ADD(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL 10 MINUTE)",
"spark": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
},
)
self.validate_all(
'SELECT TIMESTAMP_SUB(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE)',
write={
"bigquery": "SELECT TIMESTAMP_SUB(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL 10 MINUTE)",
"mysql": "SELECT DATE_SUB(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL 10 MINUTE)",
},
)
self.validate_all(
"MD5(x)",
write={
Expand Down

0 comments on commit b3c97de

Please sign in to comment.