Skip to content

Commit

Permalink
Fix: parsing unknown into data type build
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Jun 15, 2023
1 parent 3233c73 commit b29a421
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 33 deletions.
6 changes: 3 additions & 3 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3364,6 +3364,7 @@ class Type(AutoName):
NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES}

TEMPORAL_TYPES = {
Type.TIME,
Type.TIMESTAMP,
Type.TIMESTAMPTZ,
Type.TIMESTAMPLTZ,
Expand All @@ -3379,9 +3380,8 @@ def build(
from sqlglot import parse_one

if isinstance(dtype, str):
data_type = cls.Type.__members__.get(dtype.upper())
if not dialect and data_type:
data_type_exp: t.Optional[Expression] = DataType(this=data_type)
if dtype.upper() == "UNKNOWN":
data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type.UNKNOWN)
else:
data_type_exp = parse_one(dtype, read=dialect, into=DataType)

Expand Down
19 changes: 5 additions & 14 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2986,23 +2986,14 @@ def _parse_types(

value: t.Optional[exp.Expression] = None
if type_token in self.TIMESTAMPS:
if self._match_text_seq("WITH", "TIME", "ZONE") or type_token == TokenType.TIMESTAMPTZ:
if self._match_text_seq("WITH", "TIME", "ZONE"):
maybe_func = False
value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions)
elif (
self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE")
or type_token == TokenType.TIMESTAMPLTZ
):
elif self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE"):
maybe_func = False
value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions)
elif self._match_text_seq("WITHOUT", "TIME", "ZONE"):
if type_token == TokenType.TIME:
value = exp.DataType(this=exp.DataType.Type.TIME, expressions=expressions)
else:
value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)

maybe_func = maybe_func and value is None

if value is None:
value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
maybe_func = False
elif type_token == TokenType.INTERVAL:
unit = self._parse_var()

Expand Down
21 changes: 5 additions & 16 deletions tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ def test_data_type_builder(self):
self.assertEqual(exp.DataType.build("DECIMAL").sql(), "DECIMAL")
self.assertEqual(exp.DataType.build("BOOLEAN").sql(), "BOOLEAN")
self.assertEqual(exp.DataType.build("JSON").sql(), "JSON")
self.assertEqual(exp.DataType.build("JSONB").sql(), "JSONB")
self.assertEqual(exp.DataType.build("JSONB", dialect="postgres").sql(), "JSONB")
self.assertEqual(exp.DataType.build("INTERVAL").sql(), "INTERVAL")
self.assertEqual(exp.DataType.build("TIME").sql(), "TIME")
self.assertEqual(exp.DataType.build("TIMESTAMP").sql(), "TIMESTAMP")
Expand All @@ -802,22 +802,11 @@ def test_data_type_builder(self):
self.assertEqual(exp.DataType.build("GEOMETRY").sql(), "GEOMETRY")
self.assertEqual(exp.DataType.build("STRUCT").sql(), "STRUCT")
self.assertEqual(exp.DataType.build("NULLABLE").sql(), "NULLABLE")
self.assertEqual(exp.DataType.build("HLLSKETCH").sql(), "HLLSKETCH")
self.assertEqual(exp.DataType.build("HSTORE").sql(), "HSTORE")
self.assertEqual(exp.DataType.build("SUPER").sql(), "SUPER")
self.assertEqual(exp.DataType.build("SERIAL").sql(), "SERIAL")
self.assertEqual(exp.DataType.build("SMALLSERIAL").sql(), "SMALLSERIAL")
self.assertEqual(exp.DataType.build("BIGSERIAL").sql(), "BIGSERIAL")
self.assertEqual(exp.DataType.build("XML").sql(), "XML")
self.assertEqual(exp.DataType.build("UNIQUEIDENTIFIER").sql(), "UNIQUEIDENTIFIER")
self.assertEqual(exp.DataType.build("MONEY").sql(), "MONEY")
self.assertEqual(exp.DataType.build("SMALLMONEY").sql(), "SMALLMONEY")
self.assertEqual(exp.DataType.build("ROWVERSION").sql(), "ROWVERSION")
self.assertEqual(exp.DataType.build("IMAGE").sql(), "IMAGE")
self.assertEqual(exp.DataType.build("VARIANT").sql(), "VARIANT")
self.assertEqual(exp.DataType.build("OBJECT").sql(), "OBJECT")
self.assertEqual(exp.DataType.build("NULL").sql(), "NULL")
self.assertEqual(exp.DataType.build("HLLSKETCH", dialect="redshift").sql(), "HLLSKETCH")
self.assertEqual(exp.DataType.build("HSTORE", dialect="postgres").sql(), "HSTORE")
self.assertEqual(exp.DataType.build("UNKNOWN").sql(), "UNKNOWN")
self.assertEqual(exp.DataType.build("UNKNOWN", dialect="bigquery").sql(), "UNKNOWN")
self.assertEqual(exp.DataType.build("UNKNOWN", dialect="snowflake").sql(), "UNKNOWN")
self.assertEqual(exp.DataType.build("TIMESTAMP", dialect="bigquery").sql(), "TIMESTAMPTZ")
self.assertEqual(
exp.DataType.build("struct<x int>", dialect="spark").sql(), "STRUCT<x INT>"
Expand Down

0 comments on commit b29a421

Please sign in to comment.