Skip to content

Commit

Permalink
Fix!: MySQL Timestamp Data Types (#2173)
Browse files Browse the repository at this point in the history
* fix mysql timestamps

* update tests

* remove datetime + fix doris

* Update sqlglot/dialects/mysql.py

---------

Co-authored-by: Toby Mao <[email protected]>
  • Loading branch information
eakmanrq and tobymao authored Sep 6, 2023
1 parent dd454b3 commit 93b7ba2
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 2 deletions.
2 changes: 2 additions & 0 deletions sqlglot/dialects/doris.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class Generator(MySQL.Generator):
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
}

TIMESTAMP_FUNC_TYPES = set()

TRANSFORMS = {
**MySQL.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
Expand Down
15 changes: 15 additions & 0 deletions sqlglot/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,9 +563,16 @@ class Generator(generator.Generator):
exp.DataType.Type.UTINYINT: "TINYINT",
}

TIMESTAMP_TYPE_MAPPING = {
exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP",
}

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**UNSIGNED_TYPE_MAPPING,
**TIMESTAMP_TYPE_MAPPING,
}

TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT)
Expand All @@ -592,6 +599,11 @@ class Generator(generator.Generator):
exp.DataType.Type.VARCHAR: "CHAR",
}

TIMESTAMP_FUNC_TYPES = {
exp.DataType.Type.TIMESTAMPTZ,
exp.DataType.Type.TIMESTAMPLTZ,
}

def datatype_sql(self, expression: exp.DataType) -> str:
# https://dev.mysql.com/doc/refman/8.0/en/numeric-type-syntax.html
result = super().datatype_sql(expression)
Expand All @@ -618,6 +630,9 @@ def jsonarraycontains_sql(self, expression: exp.JSONArrayContains) -> str:
return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})"

def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
if expression.to.this in self.TIMESTAMP_FUNC_TYPES:
return self.func("TIMESTAMP", expression.this)

to = self.CAST_MAPPING.get(expression.to.this)

if to:
Expand Down
6 changes: 4 additions & 2 deletions tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,14 +282,16 @@ def test_cast(self):
"starrocks": "CAST(a AS DATETIME)",
"redshift": "CAST(a AS TIMESTAMP)",
"doris": "CAST(a AS DATETIME)",
"mysql": "CAST(a AS DATETIME)",
},
)
self.validate_all(
"CAST(a AS TIMESTAMPTZ)",
write={
"starrocks": "CAST(a AS DATETIME)",
"starrocks": "TIMESTAMP(a)",
"redshift": "CAST(a AS TIMESTAMP WITH TIME ZONE)",
"doris": "CAST(a AS DATETIME)",
"mysql": "TIMESTAMP(a)",
},
)
self.validate_all("CAST(a AS TINYINT)", write={"oracle": "CAST(a AS NUMBER)"})
Expand Down Expand Up @@ -870,7 +872,7 @@ def test_time(self):
"TIMESTAMP '2022-01-01'",
write={
"drill": "CAST('2022-01-01' AS TIMESTAMP)",
"mysql": "CAST('2022-01-01' AS TIMESTAMP)",
"mysql": "CAST('2022-01-01' AS DATETIME)",
"starrocks": "CAST('2022-01-01' AS DATETIME)",
"hive": "CAST('2022-01-01' AS TIMESTAMP)",
"doris": "CAST('2022-01-01' AS DATETIME)",
Expand Down
9 changes: 9 additions & 0 deletions tests/dialects/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ def test_ddl(self):
"mysql": "CREATE TABLE IF NOT EXISTS industry_info (a BIGINT(20) NOT NULL AUTO_INCREMENT, b BIGINT(20) NOT NULL, c VARCHAR(1000), PRIMARY KEY (a), UNIQUE d (b), INDEX e (b))",
},
)
self.validate_all(
"CREATE TABLE test (ts TIMESTAMP, ts_tz TIMESTAMPTZ, ts_ltz TIMESTAMPLTZ)",
write={
"mysql": "CREATE TABLE test (ts DATETIME, ts_tz TIMESTAMP, ts_ltz TIMESTAMP)",
},
)

def test_identity(self):
self.validate_identity(
Expand Down Expand Up @@ -215,6 +221,9 @@ def test_types(self):
"spark": "CAST(x AS BLOB) + CAST(y AS BLOB)",
},
)
self.validate_all("CAST(x AS TIMESTAMP)", write={"mysql": "CAST(x AS DATETIME)"})
self.validate_all("CAST(x AS TIMESTAMPTZ)", write={"mysql": "TIMESTAMP(x)"})
self.validate_all("CAST(x AS TIMESTAMPLTZ)", write={"mysql": "TIMESTAMP(x)"})

def test_canonical_functions(self):
self.validate_identity("SELECT LEFT('str', 2)", "SELECT LEFT('str', 2)")
Expand Down

0 comments on commit 93b7ba2

Please sign in to comment.