Feat!: add support for casting to user defined types (#2096)
* Feat: add support for casting to user defined types

* Add test

* Refactor

* Simplify

* Simplify

* Refactor to make UDTs first-class

* Remove redundant test

* PR feedback

* Autoflake fixup

* Comment fixup

* Exception cleanup

* Cleanup

* Fix test

* make more forgiving

---------

Co-authored-by: tobymao <[email protected]>
georgesittas and tobymao authored Aug 21, 2023
1 parent 0b9a575 commit 28a0e20
Showing 16 changed files with 171 additions and 28 deletions.
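In practice, the headline change is that a `CAST` to an unrecognized type name now parses into a first-class `exp.DataType` of type `USERDEFINED` instead of erroring out. A minimal sketch of the new behavior, assuming a sqlglot build that includes this commit (the identifiers `profile`, `some_udt` and `users` are purely illustrative):

```python
import sqlglot
from sqlglot import exp

# A CAST to a type name the parser doesn't recognize becomes a USERDEFINED
# DataType for dialects that keep SUPPORTS_USER_DEFINED_TYPES enabled
# (the new default in the base parser).
ast = sqlglot.parse_one("SELECT CAST(profile AS some_udt) FROM users")
cast = ast.find(exp.Cast)

assert cast.to.this == exp.DataType.Type.USERDEFINED
assert cast.to.is_type("some_udt")

# The original type name round-trips through the generator unchanged.
print(ast.sql())  # SELECT CAST(profile AS some_udt) FROM users
```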
2 changes: 1 addition & 1 deletion README.md
@@ -178,7 +178,7 @@ for table in parse_one("SELECT * FROM x JOIN y JOIN z").find_all(exp.Table):

### Parser Errors

When the parser detects an error in the syntax, it raises a ParserError:
When the parser detects an error in the syntax, it raises a ParseError:

```python
import sqlglot
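# A sketch of catching the error (the README's exact example may differ):
try:
    sqlglot.transpile("SELECT foo( FROM bar")
except sqlglot.ParseError as e:
    print(e.errors)
```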
2 changes: 2 additions & 0 deletions sqlglot/dialects/bigquery.py
@@ -270,6 +270,8 @@ class Parser(parser.Parser):
LOG_BASE_FIRST = False
LOG_DEFAULTS_TO_LN = True

SUPPORTS_USER_DEFINED_TYPES = False

FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE": _parse_date,
2 changes: 2 additions & 0 deletions sqlglot/dialects/clickhouse.py
@@ -63,6 +63,8 @@ class Tokenizer(tokens.Tokenizer):
}

class Parser(parser.Parser):
SUPPORTS_USER_DEFINED_TYPES = False

FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ANY": exp.AnyValue.from_arg_list,
1 change: 1 addition & 0 deletions sqlglot/dialects/drill.py
@@ -80,6 +80,7 @@ class Tokenizer(tokens.Tokenizer):
class Parser(parser.Parser):
STRICT_CAST = False
CONCAT_NULL_OUTPUTS_STRING = True
SUPPORTS_USER_DEFINED_TYPES = False

FUNCTIONS = {
**parser.Parser.FUNCTIONS,
1 change: 1 addition & 0 deletions sqlglot/dialects/duckdb.py
@@ -134,6 +134,7 @@ class Tokenizer(tokens.Tokenizer):

class Parser(parser.Parser):
CONCAT_NULL_OUTPUTS_STRING = True
SUPPORTS_USER_DEFINED_TYPES = False

BITWISE = {
**parser.Parser.BITWISE,
1 change: 1 addition & 0 deletions sqlglot/dialects/hive.py
@@ -220,6 +220,7 @@ class Tokenizer(tokens.Tokenizer):
class Parser(parser.Parser):
LOG_DEFAULTS_TO_LN = True
STRICT_CAST = False
SUPPORTS_USER_DEFINED_TYPES = False

FUNCTIONS = {
**parser.Parser.FUNCTIONS,
2 changes: 2 additions & 0 deletions sqlglot/dialects/mysql.py
@@ -185,6 +185,8 @@ class Tokenizer(tokens.Tokenizer):
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}

class Parser(parser.Parser):
SUPPORTS_USER_DEFINED_TYPES = False

FUNC_TOKENS = {
*parser.Parser.FUNC_TOKENS,
TokenType.DATABASE,
2 changes: 2 additions & 0 deletions sqlglot/dialects/redshift.py
@@ -37,6 +37,8 @@ class Redshift(Postgres):
}

class Parser(Postgres.Parser):
SUPPORTS_USER_DEFINED_TYPES = False

FUNCTIONS = {
**Postgres.Parser.FUNCTIONS,
"ADD_MONTHS": lambda args: exp.DateAdd(
1 change: 1 addition & 0 deletions sqlglot/dialects/snowflake.py
@@ -216,6 +216,7 @@ class Snowflake(Dialect):

class Parser(parser.Parser):
IDENTIFY_PIVOT_STRINGS = True
SUPPORTS_USER_DEFINED_TYPES = False

FUNCTIONS = {
**parser.Parser.FUNCTIONS,
3 changes: 3 additions & 0 deletions sqlglot/dialects/trino.py
@@ -13,3 +13,6 @@ class Generator(Presto.Generator):

class Tokenizer(Presto.Tokenizer):
HEX_STRINGS = [("X'", "'")]

class Parser(Presto.Parser):
SUPPORTS_USER_DEFINED_TYPES = False
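Each dialect touched above opts out of UDT parsing by setting `SUPPORTS_USER_DEFINED_TYPES = False`, while dialects that don't override the flag (Postgres, Oracle, Teradata, etc.) pick up the new default of `True` from the base parser. A rough sketch of the resulting split in behavior (dialect names chosen for illustration):

```python
from sqlglot import ParseError, exp, parse_one

# Postgres leaves the flag at its default, so the unknown type name parses
# into a user-defined DataType.
cast = parse_one("CAST(x AS some_udt)", read="postgres").find(exp.Cast)
assert cast.to.is_type("some_udt")

# BigQuery disables the flag, so the same statement still fails to parse.
try:
    parse_one("CAST(x AS some_udt)", read="bigquery")
except ParseError as e:
    print("bigquery rejected it:", e)
```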
89 changes: 75 additions & 14 deletions sqlglot/expressions.py
@@ -3399,6 +3399,7 @@ class DataTypeParam(Expression):
class DataType(Expression):
arg_types = {
"this": True,
"expression": False,
"expressions": False,
"nested": False,
"values": False,
@@ -3515,7 +3516,10 @@ class Type(AutoName):
Type.DOUBLE,
}

NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES}
NUMERIC_TYPES = {
*INTEGER_TYPES,
*FLOAT_TYPES,
}

TEMPORAL_TYPES = {
Type.TIME,
@@ -3528,23 +3532,45 @@ class Type(AutoName):
Type.DATETIME64,
}

META_TYPES = {"UNKNOWN", "NULL"}
META_TYPES = {
"UNKNOWN",
"NULL",
}

@classmethod
def build(
cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs
cls,
dtype: str | DataType | DataType.Type,
dialect: DialectType = None,
udt: bool = False,
**kwargs,
) -> DataType:
"""
Constructs a DataType object.
Args:
dtype: the data type of interest.
dialect: the dialect to use for parsing `dtype`, in case it's a string.
udt: when set to True, `dtype` will be used as-is if it can't be parsed into a
DataType, thus creating a user-defined type.
kwargs: additional arguments to pass to the constructor of DataType.
Returns:
The constructed DataType object.
"""
from sqlglot import parse_one

if isinstance(dtype, str):
upper = dtype.upper()
if upper in DataType.META_TYPES:
data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[upper])
data_type_exp = DataType(this=DataType.Type[upper])
else:
data_type_exp = parse_one(dtype, read=dialect, into=DataType)

if data_type_exp is None:
raise ValueError(f"Unparsable data type value: {dtype}")
try:
data_type_exp = parse_one(dtype, read=dialect, into=DataType)
except ParseError:
if udt:
return DataType(this=DataType.Type.USERDEFINED, expression=dtype, **kwargs)
raise
elif isinstance(dtype, DataType.Type):
data_type_exp = DataType(this=dtype)
elif isinstance(dtype, DataType):
@@ -3555,7 +3581,31 @@ def build(
return DataType(**{**data_type_exp.args, **kwargs})
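The `udt` flag lets `DataType.build` degrade gracefully when the type string isn't a recognized built-in. A small sketch of the difference (`inventory_item` is just an illustrative placeholder name):

```python
from sqlglot import ParseError, exp

# Built-in type strings still resolve to DataType.Type members.
assert exp.DataType.build("ARRAY<INT>").this == exp.DataType.Type.ARRAY

# Without udt=True, an unknown type name keeps raising ParseError.
try:
    exp.DataType.build("inventory_item")
except ParseError:
    pass

# With udt=True, the raw name is preserved on a USERDEFINED DataType and the
# generator emits it verbatim (see the datatype_sql change below).
dtype = exp.DataType.build("inventory_item", udt=True)
assert dtype.this == exp.DataType.Type.USERDEFINED
assert dtype.sql() == "inventory_item"
```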

def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool:
return any(self.this == DataType.build(dtype).this for dtype in dtypes)
"""
Checks whether this DataType matches one of the provided data types. Nested types or precision
will be compared using "structural equivalence" semantics, so e.g. array<int> != array<float>.
Args:
dtypes: the data types to compare this DataType to.
Returns:
True, if and only if there is a type in `dtypes` which is equal to this DataType.
"""
for dtype in dtypes:
other = DataType.build(dtype, udt=True)

if (
other.expressions
or self.this == DataType.Type.USERDEFINED
or other.this == DataType.Type.USERDEFINED
):
matches = self == other
else:
matches = self.this == other.this

if matches:
return True
return False
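The branch above makes the comparison deliberately one-sided: a bare type on the right-hand side only compares the top-level `DataType.Type`, while a parameterized (or user-defined) right-hand side requires full structural equality. Concretely, a sketch of what that means:

```python
from sqlglot import exp

varchar5 = exp.DataType.build("VARCHAR(5)")
assert varchar5.is_type("VARCHAR")         # bare type: only DataType.Type compared
assert varchar5.is_type("VARCHAR(5)")      # parameterized: full structural equality
assert not varchar5.is_type("VARCHAR(4)")

array_int = exp.DataType.build("ARRAY<INT>")
assert array_int.is_type("ARRAY")
assert not array_int.is_type("ARRAY<FLOAT>")
```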


# https://www.postgresql.org/docs/15/datatype-pseudo.html
@@ -4112,18 +4162,29 @@ def output_name(self) -> str:
return self.name

    def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool:
        """
        Checks whether this Cast's DataType matches one of the provided data types. Nested types
        like arrays or structs will be compared using "structural equivalence" semantics, so e.g.
        array<int> != array<float>.

        Args:
            dtypes: the data types to compare this Cast's DataType to.

        Returns:
            True, if and only if there is a type in `dtypes` which is equal to this Cast's DataType.
        """
        return self.to.is_type(*dtypes)


class TryCast(Cast):
    pass


class CastToStrType(Func):
    arg_types = {"this": True, "expression": True}
    arg_types = {"this": True, "to": True}


class Collate(Binary):
    pass


13 changes: 8 additions & 5 deletions sqlglot/generator.py
@@ -842,11 +842,14 @@ def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str:
def datatype_sql(self, expression: exp.DataType) -> str:
type_value = expression.this

type_sql = (
self.TYPE_MAPPING.get(type_value, type_value.value)
if isinstance(type_value, exp.DataType.Type)
else type_value
)
if type_value == exp.DataType.Type.USERDEFINED and expression.expression:
type_sql = self.sql(expression.expression)
else:
type_sql = (
self.TYPE_MAPPING.get(type_value, type_value.value)
if isinstance(type_value, exp.DataType.Type)
else type_value
)

nested = ""
interior = self.expressions(expression, flat=True)
14 changes: 8 additions & 6 deletions sqlglot/parser.py
@@ -833,6 +833,8 @@ class Parser(metaclass=_Parser):
LOG_BASE_FIRST = True
LOG_DEFAULTS_TO_LN = False

SUPPORTS_USER_DEFINED_TYPES = True

__slots__ = (
"error_level",
"error_message_context",
@@ -3859,17 +3861,17 @@ def _parse_cast(self, strict: bool) -> exp.Expression:

if not self._match(TokenType.ALIAS):
if self._match(TokenType.COMMA):
return self.expression(
exp.CastToStrType, this=this, expression=self._parse_string()
)
else:
self.raise_error("Expected AS after CAST")
return self.expression(exp.CastToStrType, this=this, to=self._parse_string())

self.raise_error("Expected AS after CAST")

fmt = None
to = self._parse_types()
to = self._parse_types() or (self.SUPPORTS_USER_DEFINED_TYPES and self._parse_id_var())

if not to:
self.raise_error("Expected TYPE after CAST")
elif isinstance(to, exp.Identifier):
to = exp.DataType.build(to.name, udt=True)
elif to.this == exp.DataType.Type.CHAR:
if self._match(TokenType.CHARACTER_SET):
to = self.expression(exp.CharacterSet, this=self._parse_var_or_string())
25 changes: 24 additions & 1 deletion tests/dialects/test_dialect.py
@@ -1,6 +1,13 @@
import unittest

from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one
from sqlglot import (
Dialect,
Dialects,
ErrorLevel,
ParseError,
UnsupportedError,
parse_one,
)
from sqlglot.dialects import Hive


@@ -1764,3 +1771,19 @@ def test_count_if(self):
"tsql": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo",
},
)

def test_cast_to_user_defined_type(self):
self.validate_all(
"CAST(x AS some_udt)",
write={
"": "CAST(x AS some_udt)",
"oracle": "CAST(x AS some_udt)",
"postgres": "CAST(x AS some_udt)",
"presto": "CAST(x AS some_udt)",
"teradata": "CAST(x AS some_udt)",
"tsql": "CAST(x AS some_udt)",
},
)

with self.assertRaises(ParseError):
parse_one("CAST(x AS some_udt)", read="bigquery")
3 changes: 3 additions & 0 deletions tests/dialects/test_oracle.py
@@ -23,6 +23,9 @@ def test_oracle(self):
self.validate_identity("SELECT * FROM table_name@dblink_name.database_link_domain")
self.validate_identity("SELECT * FROM table_name SAMPLE (25) s")
self.validate_identity("SELECT * FROM V$SESSION")
self.validate_identity(
"SELECT COUNT(1) INTO V_Temp FROM TABLE(CAST(somelist AS data_list)) WHERE col LIKE '%contact'"
)
self.validate_identity(
"SELECT MIN(column_name) KEEP (DENSE_RANK FIRST ORDER BY column_name DESC) FROM table_name"
)
38 changes: 37 additions & 1 deletion tests/test_expressions.py
@@ -2,7 +2,7 @@
import math
import unittest

from sqlglot import alias, exp, parse_one
from sqlglot import ParseError, alias, exp, parse_one


class TestExpressions(unittest.TestCase):
@@ -896,3 +896,39 @@ def test_unnest(self):
second_subquery = ast.args["from"].this.this
innermost_subquery = list(ast.find_all(exp.Select))[1].parent
self.assertIs(second_subquery, innermost_subquery.unwrap())

def test_is_type(self):
ast = parse_one("CAST(x AS VARCHAR)")
assert ast.is_type("VARCHAR")
assert not ast.is_type("VARCHAR(5)")
assert not ast.is_type("FLOAT")

ast = parse_one("CAST(x AS VARCHAR(5))")
assert ast.is_type("VARCHAR")
assert ast.is_type("VARCHAR(5)")
assert not ast.is_type("VARCHAR(4)")
assert not ast.is_type("FLOAT")

ast = parse_one("CAST(x AS ARRAY<INT>)")
assert ast.is_type("ARRAY")
assert ast.is_type("ARRAY<INT>")
assert not ast.is_type("ARRAY<FLOAT>")
assert not ast.is_type("INT")

ast = parse_one("CAST(x AS ARRAY)")
assert ast.is_type("ARRAY")
assert not ast.is_type("ARRAY<INT>")
assert not ast.is_type("ARRAY<FLOAT>")
assert not ast.is_type("INT")

ast = parse_one("CAST(x AS STRUCT<a INT, b FLOAT>)")
assert ast.is_type("STRUCT")
assert ast.is_type("STRUCT<a INT, b FLOAT>")
assert not ast.is_type("STRUCT<a VARCHAR, b INT>")

dtype = exp.DataType.build("foo", udt=True)
assert dtype.is_type("foo")
assert not dtype.is_type("bar")

with self.assertRaises(ParseError):
exp.DataType.build("foo")
