Fix: parse time[(p)] with time zone correctly (#2041)
* Fix: parse time[(p)] with time zone correctly

* Generate TIMETZ for duckdb

* Include a redshift test as it also supports TIMETZ

* Refactor - Redshift doesn't support the precision parameter either

* Revert last commit - Redshift only allows precision with the WITH TIME ZONE syntax

* Small perf improvement
georgesittas authored Aug 11, 2023
1 parent 32eb129 commit 4591092
Showing 10 changed files with 77 additions and 17 deletions.
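
For context, a sketch of the end-to-end behavior this commit targets, based on the test expectations added below (illustrative only, not part of the diff):

    >>> import sqlglot
    >>> # TIME(p) WITH TIME ZONE is now parsed as the new TIMETZ type instead of TIMESTAMPTZ
    >>> sqlglot.transpile("CAST(x AS TIME(5) WITH TIME ZONE)", read="presto", write="postgres")
    ['CAST(x AS TIMETZ(5))']
    >>> sqlglot.transpile("CAST(x AS TIME(5) WITH TIME ZONE)", read="presto", write="duckdb")
    ['CAST(x AS TIMETZ)']
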
5 changes: 5 additions & 0 deletions sqlglot/dialects/duckdb.py
@@ -89,6 +89,11 @@ def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:
 def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
     if expression.is_type("array"):
         return f"{self.expressions(expression, flat=True)}[]"
+
+    # Type TIMESTAMP / TIME WITH TIME ZONE does not support any modifiers
+    if expression.is_type("timestamptz", "timetz"):
+        return expression.this.value
+
     return self.datatype_sql(expression)
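
Because DuckDB's TIMETZ and TIMESTAMPTZ types accept no modifiers, the generator now emits the bare type name and drops any precision, matching the duckdb expectations in the Presto tests further down (a sketch, not part of the diff):

    >>> import sqlglot
    >>> sqlglot.transpile("CAST(x AS TIMESTAMP(9) WITH TIME ZONE)", read="presto", write="duckdb")
    ['CAST(x AS TIMESTAMPTZ)']
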
10 changes: 2 additions & 8 deletions sqlglot/dialects/presto.py
@@ -32,13 +32,6 @@ def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str:
     return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"


-def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
-    sql = self.datatype_sql(expression)
-    if expression.is_type("timestamptz"):
-        sql = f"{sql} WITH TIME ZONE"
-    return sql
-
-
 def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str:
     if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
         expression = expression.copy()
@@ -231,6 +224,7 @@ class Generator(generator.Generator):
         TABLE_HINTS = False
         QUERY_HINTS = False
         IS_BOOL_ALLOWED = False
+        TZ_TO_WITH_TIME_ZONE = True
         STRUCT_DELIMITER = ("(", ")")

         PROPERTIES_LOCATION = {
@@ -245,6 +239,7 @@ class Generator(generator.Generator):
             exp.DataType.Type.FLOAT: "REAL",
             exp.DataType.Type.BINARY: "VARBINARY",
             exp.DataType.Type.TEXT: "VARCHAR",
+            exp.DataType.Type.TIMETZ: "TIME",
             exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
             exp.DataType.Type.STRUCT: "ROW",
         }
@@ -265,7 +260,6 @@ class Generator(generator.Generator):
             exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
             exp.Cast: transforms.preprocess([transforms.epoch_cast_to_ts]),
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
-            exp.DataType: _datatype_sql,
             exp.DateAdd: lambda self, e: self.func(
                 "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
             ),
7 changes: 4 additions & 3 deletions sqlglot/dialects/redshift.py
@@ -85,8 +85,6 @@ class Tokenizer(Postgres.Tokenizer):
             "HLLSKETCH": TokenType.HLLSKETCH,
             "SUPER": TokenType.SUPER,
             "SYSDATE": TokenType.CURRENT_TIMESTAMP,
-            "TIME": TokenType.TIMESTAMP,
-            "TIMETZ": TokenType.TIMESTAMPTZ,
             "TOP": TokenType.TOP,
             "UNLOAD": TokenType.COMMAND,
             "VARBYTE": TokenType.VARBINARY,
@@ -101,12 +99,15 @@ class Generator(Postgres.Generator):
         RENAME_TABLE_WITH_DB = False
         QUERY_HINTS = False
         VALUES_AS_TABLE = False
+        TZ_TO_WITH_TIME_ZONE = True

         TYPE_MAPPING = {
             **Postgres.Generator.TYPE_MAPPING,
             exp.DataType.Type.BINARY: "VARBYTE",
-            exp.DataType.Type.VARBINARY: "VARBYTE",
             exp.DataType.Type.INT: "INTEGER",
+            exp.DataType.Type.TIMETZ: "TIME",
+            exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
+            exp.DataType.Type.VARBINARY: "VARBYTE",
         }

         PROPERTIES_LOCATION = {
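
Per the Redshift test added below, a Postgres TIMETZ now round-trips through Redshift's WITH TIME ZONE spelling with the precision preserved (a sketch based on those test expectations):

    >>> import sqlglot
    >>> sqlglot.transpile("SELECT CAST('01:03:05.124' AS TIMETZ(2))", read="postgres", write="redshift")
    ["SELECT CAST('01:03:05.124' AS TIME(2) WITH TIME ZONE)"]
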
2 changes: 2 additions & 0 deletions sqlglot/expressions.py
@@ -3452,6 +3452,7 @@ class Type(AutoName):
         SUPER = auto()
         TEXT = auto()
         TIME = auto()
+        TIMETZ = auto()
         TIMESTAMP = auto()
         TIMESTAMPLTZ = auto()
         TIMESTAMPTZ = auto()
@@ -3501,6 +3502,7 @@

     TEMPORAL_TYPES = {
         Type.TIME,
+        Type.TIMETZ,
         Type.TIMESTAMP,
         Type.TIMESTAMPTZ,
         Type.TIMESTAMPLTZ,
15 changes: 14 additions & 1 deletion sqlglot/generator.py
@@ -168,6 +168,9 @@ class Generator:
     # Whether or not to generate an unquoted value for EXTRACT's date part argument
     EXTRACT_ALLOWS_QUOTES = True

+    # Whether or not TIMETZ / TIMESTAMPTZ will be generated using the "WITH TIME ZONE" syntax
+    TZ_TO_WITH_TIME_ZONE = False
+
     # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
     SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")

@@ -835,14 +838,17 @@ def datatypesize_sql(self, expression: exp.DataTypeSize) -> str:

     def datatype_sql(self, expression: exp.DataType) -> str:
         type_value = expression.this
+
         type_sql = (
             self.TYPE_MAPPING.get(type_value, type_value.value)
             if isinstance(type_value, exp.DataType.Type)
             else type_value
         )
+
         nested = ""
         interior = self.expressions(expression, flat=True)
         values = ""
+
         if interior:
             if expression.args.get("nested"):
                 nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
@@ -855,7 +861,14 @@ def datatype_sql(self, expression: exp.DataType) -> str:
             else:
                 nested = f"({interior})"

-        return f"{type_sql}{nested}{values}"
+        type_sql = f"{type_sql}{nested}{values}"
+        if self.TZ_TO_WITH_TIME_ZONE and type_value in (
+            exp.DataType.Type.TIMETZ,
+            exp.DataType.Type.TIMESTAMPTZ,
+        ):
+            type_sql = f"{type_sql} WITH TIME ZONE"
+
+        return type_sql

     def directory_sql(self, expression: exp.Directory) -> str:
         local = "LOCAL " if expression.args.get("local") else ""
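
A rough illustration of the new flag: with a dialect that sets TZ_TO_WITH_TIME_ZONE = True (Presto and Redshift in this commit), the WITH TIME ZONE suffix is appended after the size/nested part of the type. Outputs taken from the test expectations added below; the example itself is a sketch:

    >>> from sqlglot import parse_one
    >>> parse_one("CAST(x AS TIMESTAMP(9) WITH TIME ZONE)", read="presto").sql(dialect="presto")
    'CAST(x AS TIMESTAMP(9) WITH TIME ZONE)'
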
15 changes: 13 additions & 2 deletions sqlglot/parser.py
@@ -154,6 +154,7 @@ class Parser(metaclass=_Parser):
         TokenType.JSONB,
         TokenType.INTERVAL,
         TokenType.TIME,
+        TokenType.TIMETZ,
         TokenType.TIMESTAMP,
         TokenType.TIMESTAMPTZ,
         TokenType.TIMESTAMPLTZ,
@@ -393,11 +394,16 @@ class Parser(metaclass=_Parser):
         TokenType.STAR: exp.Mul,
     }

-    TIMESTAMPS = {
+    TIMES = {
         TokenType.TIME,
+        TokenType.TIMETZ,
+    }
+
+    TIMESTAMPS = {
         TokenType.TIMESTAMP,
         TokenType.TIMESTAMPTZ,
         TokenType.TIMESTAMPLTZ,
+        *TIMES,
     }

     SET_OPERATIONS = {
@@ -3165,7 +3171,12 @@ def _parse_types(
         if type_token in self.TIMESTAMPS:
             if self._match_text_seq("WITH", "TIME", "ZONE"):
                 maybe_func = False
-                this = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions)
+                tz_type = (
+                    exp.DataType.Type.TIMETZ
+                    if type_token in self.TIMES
+                    else exp.DataType.Type.TIMESTAMPTZ
+                )
+                this = exp.DataType(this=tz_type, expressions=expressions)
             elif self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE"):
                 maybe_func = False
                 this = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions)
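
With this parser change, the two WITH TIME ZONE forms are distinguished: a TIME base type now yields TIMETZ, while TIMESTAMP still yields TIMESTAMPTZ. A sketch mirroring the assertion added to the Presto tests (exact enum reprs may vary by version):

    >>> from sqlglot import exp, parse_one
    >>> parse_one("TIME(7) WITH TIME ZONE", read="presto", into=exp.DataType).this
    <Type.TIMETZ: 'TIMETZ'>
    >>> parse_one("TIMESTAMP(7) WITH TIME ZONE", read="presto", into=exp.DataType).this
    <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>
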
2 changes: 2 additions & 0 deletions sqlglot/tokens.py
@@ -110,6 +110,7 @@ class TokenType(AutoName):
     JSON = auto()
     JSONB = auto()
     TIME = auto()
+    TIMETZ = auto()
     TIMESTAMP = auto()
     TIMESTAMPTZ = auto()
     TIMESTAMPLTZ = auto()
@@ -705,6 +706,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "BYTEA": TokenType.VARBINARY,
         "VARBINARY": TokenType.VARBINARY,
         "TIME": TokenType.TIME,
+        "TIMETZ": TokenType.TIMETZ,
         "TIMESTAMP": TokenType.TIMESTAMP,
         "TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
         "TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
2 changes: 1 addition & 1 deletion tests/dialects/test_dialect.py
@@ -280,7 +280,7 @@ def test_cast(self):
             "CAST(a AS TIMESTAMPTZ)",
             write={
                 "starrocks": "CAST(a AS DATETIME)",
-                "redshift": "CAST(a AS TIMESTAMPTZ)",
+                "redshift": "CAST(a AS TIMESTAMP WITH TIME ZONE)",
                 "doris": "CAST(a AS DATETIME)",
             },
         )
16 changes: 14 additions & 2 deletions tests/dialects/test_presto.py
@@ -1,6 +1,6 @@
 from unittest import mock

-from sqlglot import UnsupportedError
+from sqlglot import UnsupportedError, exp, parse_one
 from tests.dialects.test_dialect import Validator


@@ -116,11 +116,20 @@ def test_cast(self):
                 "snowflake": "CAST(OBJECT_CONSTRUCT('a', [1], 'b', [2], 'c', [3]) AS OBJECT)",
             },
         )
+        self.validate_all(
+            "CAST(x AS TIME(5) WITH TIME ZONE)",
+            write={
+                "duckdb": "CAST(x AS TIMETZ)",
+                "postgres": "CAST(x AS TIMETZ(5))",
+                "presto": "CAST(x AS TIME(5) WITH TIME ZONE)",
+                "redshift": "CAST(x AS TIME(5) WITH TIME ZONE)",
+            },
+        )
         self.validate_all(
             "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)",
             write={
                 "bigquery": "CAST(x AS TIMESTAMP)",
-                "duckdb": "CAST(x AS TIMESTAMPTZ(9))",
+                "duckdb": "CAST(x AS TIMESTAMPTZ)",
                 "presto": "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)",
                 "hive": "CAST(x AS TIMESTAMP)",
                 "spark": "CAST(x AS TIMESTAMP)",
@@ -194,6 +203,9 @@ def test_interval_plural_to_singular(self):
         )

     def test_time(self):
+        expr = parse_one("TIME(7) WITH TIME ZONE", into=exp.DataType, read="presto")
+        self.assertEqual(expr.this, exp.DataType.Type.TIMETZ)
+
         self.validate_identity("FROM_UNIXTIME(a, b)")
         self.validate_identity("FROM_UNIXTIME(a, b, c)")
         self.validate_identity("TRIM(a, b)")
20 changes: 20 additions & 0 deletions tests/dialects/test_redshift.py
@@ -5,6 +5,26 @@ class TestRedshift(Validator):
     dialect = "redshift"

     def test_redshift(self):
+        self.validate_all(
+            "SELECT CAST('01:03:05.124' AS TIME(2) WITH TIME ZONE)",
+            read={
+                "postgres": "SELECT CAST('01:03:05.124' AS TIMETZ(2))",
+            },
+            write={
+                "postgres": "SELECT CAST('01:03:05.124' AS TIMETZ(2))",
+                "redshift": "SELECT CAST('01:03:05.124' AS TIME(2) WITH TIME ZONE)",
+            },
+        )
+        self.validate_all(
+            "SELECT CAST('2020-02-02 01:03:05.124' AS TIMESTAMP(2) WITH TIME ZONE)",
+            read={
+                "postgres": "SELECT CAST('2020-02-02 01:03:05.124' AS TIMESTAMPTZ(2))",
+            },
+            write={
+                "postgres": "SELECT CAST('2020-02-02 01:03:05.124' AS TIMESTAMPTZ(2))",
+                "redshift": "SELECT CAST('2020-02-02 01:03:05.124' AS TIMESTAMP(2) WITH TIME ZONE)",
+            },
+        )
         self.validate_all(
             "SELECT INTERVAL '5 days'",
             read={
