From 857e38075945ee0057fdfb4140d9636c0400c587 Mon Sep 17 00:00:00 2001 From: barak Date: Thu, 14 Sep 2023 17:09:55 -0700 Subject: [PATCH] fix(mysql): TIMESTAMP -> CAST (#2223) * fix(mysql): TIMESTAMP -> CAST * fixup * move to generator --- sqlglot/dialects/presto.py | 1 + sqlglot/expressions.py | 4 ++++ sqlglot/transforms.py | 9 +++++++++ tests/dialects/test_presto.py | 4 ++++ 4 files changed, 18 insertions(+) diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 85dc29edad..9ae4c32e9d 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -362,6 +362,7 @@ class Generator(generator.Generator): exp.WithinGroup: transforms.preprocess( [transforms.remove_within_group_for_percentiles] ), + exp.Timestamp: transforms.preprocess([transforms.timestamp_to_cast]), } def interval_sql(self, expression: exp.Interval) -> str: diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index a50640e737..1c3d42a227 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -4372,6 +4372,10 @@ class Extract(Func): arg_types = {"this": True, "expression": True} +class Timestamp(Func): + arg_types = {"this": False, "expression": False} + + class TimestampAdd(Func, TimeUnit): arg_types = {"this": True, "expression": True, "unit": False} diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index d4961e4033..70b9a31242 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -327,3 +327,12 @@ def _to_sql(self, expression: exp.Expression) -> str: raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.") return _to_sql + + +def timestamp_to_cast(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.Timestamp) and not expression.expression: + return exp.cast( + expression.this, + to=exp.DataType.Type.TIMESTAMP, + ) + return expression diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 9572992a09..a92f04f181 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -358,6 +358,10 @@ def test_time(self): write={"presto": "CAST(x AS TIMESTAMP)"}, read={"mysql": "CAST(x AS DATETIME)", "clickhouse": "CAST(x AS DATETIME64)"}, ) + self.validate_all( + "CAST(x AS TIMESTAMP)", + read={"mysql": "TIMESTAMP(x)"}, + ) def test_ddl(self): self.validate_all(