Skip to content

Commit

Permalink
Feat(optimizer): improve type annotation for nested types (#2061)
Browse files: browse the repository at this point in the history
* Feat(optimizer): improve type annotation for nested types

* Improve support for ARRAY, ARRAY_CAT
Loading branch information
georgesittas authored Aug 15, 2023
1 parent 21b061f commit d92a5b7
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 13 deletions.
1 change: 1 addition & 0 deletions sqlglot/dialects/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.AnyValue: any_value_to_max_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
exp.Explode: rename_func("UNNEST"),
Expand Down
1 change: 1 addition & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4005,6 +4005,7 @@ class ArrayAny(Func):


class ArrayConcat(Func):
_sql_names = ["ARRAY_CONCAT", "ARRAY_CAT"]
arg_types = {"this": True, "expressions": False}
is_var_len_args = True

Expand Down
39 changes: 29 additions & 10 deletions sqlglot/optimizer/annotate_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,15 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
for expr_type in expressions
},
exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
Expand All @@ -220,6 +225,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
}

NESTED_TYPES = {
exp.DataType.Type.ARRAY,
}

# Specifies what types a given type can be coerced into (autofilled)
COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

Expand Down Expand Up @@ -299,19 +308,22 @@ def _annotate_args(self, expression: E) -> E:

def _maybe_coerce(
self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
) -> exp.DataType.Type:
# We propagate the NULL / UNKNOWN types upwards if found
if isinstance(type1, exp.DataType):
type1 = type1.this
if isinstance(type2, exp.DataType):
type2 = type2.this
) -> exp.DataType | exp.DataType.Type:
type1_value = type1.this if isinstance(type1, exp.DataType) else type1
type2_value = type2.this if isinstance(type2, exp.DataType) else type2

if exp.DataType.Type.NULL in (type1, type2):
# We propagate the NULL / UNKNOWN types upwards if found
if exp.DataType.Type.NULL in (type1_value, type2_value):
return exp.DataType.Type.NULL
if exp.DataType.Type.UNKNOWN in (type1, type2):
if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
return exp.DataType.Type.UNKNOWN

return type2 if type2 in self.coerces_to.get(type1, {}) else type1 # type: ignore
if type1_value in self.NESTED_TYPES:
return type1
if type2_value in self.NESTED_TYPES:
return type2

return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value # type: ignore

# Note: the following "no_type_check" decorators were added because mypy was yelling due
# to assigning Type values to expression.type (since its getter returns Optional[DataType]).
Expand Down Expand Up @@ -368,7 +380,9 @@ def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) ->
return self._annotate_args(expression)

@t.no_type_check
def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E:
def _annotate_by_args(
self, expression: E, *args: str, promote: bool = False, array: bool = False
) -> E:
self._annotate_args(expression)

expressions: t.List[exp.Expression] = []
Expand All @@ -388,4 +402,9 @@ def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) ->
elif expression.type.this in exp.DataType.FLOAT_TYPES:
expression.type = exp.DataType.Type.DOUBLE

if array:
expression.type = exp.DataType(
this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
)

return expression
12 changes: 9 additions & 3 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,13 +305,19 @@ def test_duckdb(self):
)
self.validate_all(
"ARRAY_CONCAT(LIST_VALUE(1, 2), LIST_VALUE(3, 4))",
read={
"bigquery": "ARRAY_CONCAT([1, 2], [3, 4])",
"postgres": "ARRAY_CAT(ARRAY[1, 2], ARRAY[3, 4])",
"snowflake": "ARRAY_CAT([1, 2], [3, 4])",
},
write={
"bigquery": "ARRAY_CONCAT([1, 2], [3, 4])",
"duckdb": "ARRAY_CONCAT(LIST_VALUE(1, 2), LIST_VALUE(3, 4))",
"presto": "CONCAT(ARRAY[1, 2], ARRAY[3, 4])",
"hive": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))",
"spark": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))",
"postgres": "ARRAY_CAT(ARRAY[1, 2], ARRAY[3, 4])",
"presto": "CONCAT(ARRAY[1, 2], ARRAY[3, 4])",
"snowflake": "ARRAY_CAT([1, 2], [3, 4])",
"bigquery": "ARRAY_CONCAT([1, 2], [3, 4])",
"spark": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))",
},
)
self.validate_all(
Expand Down
18 changes: 18 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,24 @@ def test_root_subquery_annotation(self):
self.assertEqual(exp.DataType.Type.INT, expression.selects[0].type.this)
self.assertEqual(exp.DataType.Type.INT, expression.selects[1].type.this)

def test_nested_type_annotation(self):
schema = {"order": {"customer_id": "bigint", "item_id": "bigint", "item_price": "numeric"}}
sql = """
SELECT ARRAY_AGG(DISTINCT order.item_id) FILTER (WHERE order.item_price > 10) AS items,
FROM order AS order
GROUP BY order.customer_id
"""
expression = annotate_types(parse_one(sql), schema=schema)

self.assertEqual(exp.DataType.Type.ARRAY, expression.selects[0].type.this)
self.assertEqual(expression.selects[0].type.sql(), "ARRAY<BIGINT>")

expression = annotate_types(
parse_one("SELECT ARRAY_CAT(ARRAY[1,2,3], ARRAY[4,5])", read="postgres")
)
self.assertEqual(exp.DataType.Type.ARRAY, expression.selects[0].type.this)
self.assertEqual(expression.selects[0].type.sql(), "ARRAY<INT>")

def test_recursive_cte(self):
query = parse_one(
"""
Expand Down

0 comments on commit d92a5b7

Please sign in to comment.