Fix: use SUPPORTS_USER_DEFINED_TYPES to set udt in schema _to_data_type (#2203)

* Fix: use SUPPORTS_USER_DEFINED_TYPES to set udt in schema _to_data_type

* Formatting
georgesittas authored Sep 12, 2023
1 parent 4ac4f62 commit 416b341
Showing 13 changed files with 24 additions and 18 deletions.
sqlglot/dialects/bigquery.py: 1 addition, 2 deletions

@@ -187,6 +187,7 @@ def _parse_to_hex(args: t.List) -> exp.Hex | exp.MD5:
 
 class BigQuery(Dialect):
     UNNEST_COLUMN_ONLY = True
+    SUPPORTS_USER_DEFINED_TYPES = False
 
     # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
@@ -278,8 +279,6 @@ class Parser(parser.Parser):
         LOG_BASE_FIRST = False
         LOG_DEFAULTS_TO_LN = True
 
-        SUPPORTS_USER_DEFINED_TYPES = False
-
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "DATE": _parse_date,
sqlglot/dialects/clickhouse.py: 1 addition, 2 deletions

@@ -25,6 +25,7 @@ class ClickHouse(Dialect):
     NORMALIZE_FUNCTIONS: bool | str = False
     NULL_ORDERING = "nulls_are_last"
     STRICT_STRING_CONCAT = True
+    SUPPORTS_USER_DEFINED_TYPES = False
 
     class Tokenizer(tokens.Tokenizer):
         COMMENTS = ["--", "#", "#!", ("/*", "*/")]
@@ -64,8 +65,6 @@ class Tokenizer(tokens.Tokenizer):
         }
 
     class Parser(parser.Parser):
-        SUPPORTS_USER_DEFINED_TYPES = False
-
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "ANY": exp.AnyValue.from_arg_list,
sqlglot/dialects/dialect.py: 3 additions, 0 deletions

@@ -153,6 +153,9 @@ class Dialect(metaclass=_Dialect):
     # Determines whether or not CONCAT's arguments must be strings
     STRICT_STRING_CONCAT = False
 
+    # Determines whether or not user-defined data types are supported
+    SUPPORTS_USER_DEFINED_TYPES = True
+
     # Determines how function names are going to be normalized
     NORMALIZE_FUNCTIONS: bool | str = "upper"
 
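The flag added above centralizes, at the Dialect level, whether an unrecognized type name may be parsed as a user-defined type. A minimal sketch of the effect, assuming sqlglot's parse_one and exp APIs around this commit; the IPADDRESS name mirrors the new test further down and is only an illustration:

    import sqlglot
    from sqlglot import exp

    # Postgres keeps the Dialect default (SUPPORTS_USER_DEFINED_TYPES = True), so an
    # unknown type name in a cast is parsed as a user-defined type.
    cast = sqlglot.parse_one("SELECT CAST(x AS IPADDRESS) FROM t", read="postgres").selects[0]
    print(cast.to.this is exp.DataType.Type.USERDEFINED)  # expected: True

    # Dialects that set the flag to False below (BigQuery, DuckDB, Hive, ...) do not
    # treat unknown type names as user-defined types.
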
sqlglot/dialects/drill.py: 1 addition, 1 deletion

@@ -39,6 +39,7 @@ class Drill(Dialect):
     DATE_FORMAT = "'yyyy-MM-dd'"
     DATEINT_FORMAT = "'yyyyMMdd'"
     TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"
+    SUPPORTS_USER_DEFINED_TYPES = False
 
     TIME_MAPPING = {
         "y": "%Y",
@@ -80,7 +81,6 @@ class Tokenizer(tokens.Tokenizer):
     class Parser(parser.Parser):
         STRICT_CAST = False
         CONCAT_NULL_OUTPUTS_STRING = True
-        SUPPORTS_USER_DEFINED_TYPES = False
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
sqlglot/dialects/duckdb.py: 1 addition, 1 deletion

@@ -105,6 +105,7 @@ def _json_format_sql(self: DuckDB.Generator, expression: exp.JSONFormat) -> str:
 
 class DuckDB(Dialect):
     NULL_ORDERING = "nulls_are_last"
+    SUPPORTS_USER_DEFINED_TYPES = False
 
     # https://duckdb.org/docs/sql/introduction.html#creating-a-new-table
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
@@ -135,7 +136,6 @@ class Tokenizer(tokens.Tokenizer):
 
     class Parser(parser.Parser):
         CONCAT_NULL_OUTPUTS_STRING = True
-        SUPPORTS_USER_DEFINED_TYPES = False
 
         BITWISE = {
             **parser.Parser.BITWISE,
sqlglot/dialects/hive.py: 1 addition, 1 deletion

@@ -150,6 +150,7 @@ def _to_date_sql(self: Hive.Generator, expression: exp.TsOrDsToDate) -> str:
 class Hive(Dialect):
     ALIAS_POST_TABLESAMPLE = True
     IDENTIFIERS_CAN_START_WITH_DIGIT = True
+    SUPPORTS_USER_DEFINED_TYPES = False
 
     # https://spark.apache.org/docs/latest/sql-ref-identifier.html#description
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
@@ -222,7 +223,6 @@ class Tokenizer(tokens.Tokenizer):
     class Parser(parser.Parser):
         LOG_DEFAULTS_TO_LN = True
         STRICT_CAST = False
-        SUPPORTS_USER_DEFINED_TYPES = False
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
sqlglot/dialects/mysql.py: 1 addition, 2 deletions

@@ -99,6 +99,7 @@ class MySQL(Dialect):
 
     TIME_FORMAT = "'%Y-%m-%d %T'"
     DPIPE_IS_STRING_CONCAT = False
+    SUPPORTS_USER_DEFINED_TYPES = False
 
     # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
     TIME_MAPPING = {
@@ -193,8 +194,6 @@ class Tokenizer(tokens.Tokenizer):
         COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}
 
     class Parser(parser.Parser):
-        SUPPORTS_USER_DEFINED_TYPES = False
-
        FUNC_TOKENS = {
             *parser.Parser.FUNC_TOKENS,
             TokenType.DATABASE,
sqlglot/dialects/redshift.py: 2 additions, 2 deletions

@@ -30,6 +30,8 @@ class Redshift(Postgres):
     # https://docs.aws.amazon.com/redshift/latest/dg/r_names.html
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
 
+    SUPPORTS_USER_DEFINED_TYPES = False
+
     TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
     TIME_MAPPING = {
         **Postgres.TIME_MAPPING,
@@ -38,8 +40,6 @@ class Redshift(Postgres):
     }
 
     class Parser(Postgres.Parser):
-        SUPPORTS_USER_DEFINED_TYPES = False
-
         FUNCTIONS = {
             **Postgres.Parser.FUNCTIONS,
             "ADD_MONTHS": lambda args: exp.DateAdd(
sqlglot/dialects/snowflake.py: 1 addition, 1 deletion

@@ -202,6 +202,7 @@ class Snowflake(Dialect):
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
     NULL_ORDERING = "nulls_are_large"
     TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
+    SUPPORTS_USER_DEFINED_TYPES = False
 
     TIME_MAPPING = {
         "YYYY": "%Y",
@@ -234,7 +235,6 @@ class Snowflake(Dialect):
 
     class Parser(parser.Parser):
         IDENTIFY_PIVOT_STRINGS = True
-        SUPPORTS_USER_DEFINED_TYPES = False
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
sqlglot/dialects/trino.py: 2 additions, 3 deletions

@@ -5,6 +5,8 @@
 
 
 class Trino(Presto):
+    SUPPORTS_USER_DEFINED_TYPES = False
+
     class Generator(Presto.Generator):
         TRANSFORMS = {
             **Presto.Generator.TRANSFORMS,
@@ -13,6 +15,3 @@ class Generator(Presto.Generator):
 
     class Tokenizer(Presto.Tokenizer):
         HEX_STRINGS = [("X'", "'")]
-
-    class Parser(Presto.Parser):
-        SUPPORTS_USER_DEFINED_TYPES = False
sqlglot/parser.py: 1 addition, 2 deletions

@@ -863,8 +863,6 @@ class Parser(metaclass=_Parser):
     LOG_BASE_FIRST = True
     LOG_DEFAULTS_TO_LN = False
 
-    SUPPORTS_USER_DEFINED_TYPES = True
-
     # Whether or not ADD is present for each column added by ALTER TABLE
     ALTER_TABLE_ADD_COLUMN_KEYWORD = True
 
@@ -892,6 +890,7 @@ class Parser(metaclass=_Parser):
     UNNEST_COLUMN_ONLY: bool = False
     ALIAS_POST_TABLESAMPLE: bool = False
     STRICT_STRING_CONCAT = False
+    SUPPORTS_USER_DEFINED_TYPES = True
     NORMALIZE_FUNCTIONS = "upper"
     NULL_ORDERING: str = "nulls_are_small"
     SHOW_TRIE: t.Dict = {}
sqlglot/schema.py: 2 additions, 1 deletion

@@ -398,9 +398,10 @@ def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
         """
         if schema_type not in self._type_mapping_cache:
             dialect = dialect or self.dialect
+            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES
 
             try:
-                expression = exp.DataType.build(schema_type, dialect=dialect)
+                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                 self._type_mapping_cache[schema_type] = expression
             except AttributeError:
                 in_dialect = f" in dialect {dialect}" if dialect else ""
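The schema hunk above is the core of the fix: _to_data_type now forwards the dialect's SUPPORTS_USER_DEFINED_TYPES as the udt argument of exp.DataType.build. A hedged sketch of what this enables, assuming MappingSchema's get_column_type accessor; the ipaddress column type is purely illustrative:

    from sqlglot import exp
    from sqlglot.schema import MappingSchema

    # A column whose declared type is not a built-in sqlglot type can be resolved as a
    # user-defined type for a dialect that supports them, since the flag now reaches
    # exp.DataType.build via udt.
    schema = MappingSchema({"t": {"x": "ipaddress"}}, dialect="postgres")
    data_type = schema.get_column_type("t", "x")

    print(data_type.this is exp.DataType.Type.USERDEFINED)  # expected: True
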
tests/test_optimizer.py: 7 additions, 0 deletions

@@ -776,6 +776,13 @@ def test_nested_type_annotation(self):
         self.assertEqual(exp.DataType.Type.ARRAY, expression.selects[0].type.this)
         self.assertEqual(expression.selects[0].type.sql(), "ARRAY<INT>")
 
+    def test_user_defined_type_annotation(self):
+        schema = MappingSchema({"t": {"x": "int"}}, dialect="postgres")
+        expression = annotate_types(parse_one("SELECT CAST(x AS IPADDRESS) FROM t"), schema=schema)
+
+        self.assertEqual(exp.DataType.Type.USERDEFINED, expression.selects[0].type.this)
+        self.assertEqual(expression.selects[0].type.sql(dialect="postgres"), "IPADDRESS")
+
     def test_recursive_cte(self):
         query = parse_one(
             """
