From d92a5b73895ed3d843f44e4b24b68bf283376ee6 Mon Sep 17 00:00:00 2001 From: Jo <46752250+GeorgeSittas@users.noreply.github.com> Date: Tue, 15 Aug 2023 18:41:49 +0300 Subject: [PATCH] Feat(optimizer): improve type annotation for nested types (#2061) * Feat(optimizer): improve type annotation for nested types * Improve support for ARRAY, ARRAY_CAT --- sqlglot/dialects/postgres.py | 1 + sqlglot/expressions.py | 1 + sqlglot/optimizer/annotate_types.py | 39 +++++++++++++++++++++-------- tests/dialects/test_duckdb.py | 12 ++++++--- tests/test_optimizer.py | 18 +++++++++++++ 5 files changed, 58 insertions(+), 13 deletions(-) diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 88cdf092a4..32904954fe 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -360,6 +360,7 @@ class Generator(generator.Generator): TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.AnyValue: any_value_to_max_sql, + exp.ArrayConcat: rename_func("ARRAY_CAT"), exp.BitwiseXor: lambda self, e: self.binary(e, "#"), exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]), exp.Explode: rename_func("UNNEST"), diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 768f5ee3f6..28174dd903 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -4005,6 +4005,7 @@ class ArrayAny(Func): class ArrayConcat(Func): + _sql_names = ["ARRAY_CONCAT", "ARRAY_CAT"] arg_types = {"this": True, "expressions": False} is_var_len_args = True diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index e7cb80b5e6..a4296557f1 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -203,10 +203,15 @@ class TypeAnnotator(metaclass=_TypeAnnotator): for expr_type in expressions }, exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), + exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), + exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), + exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), + exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), + exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), @@ -220,6 +225,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), } + NESTED_TYPES = { + exp.DataType.Type.ARRAY, + } + # Specifies what types a given type can be coerced into (autofilled) COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {} @@ -299,19 +308,22 @@ def _annotate_args(self, expression: E) -> E: def _maybe_coerce( self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type - ) -> exp.DataType.Type: - # We propagate the NULL / UNKNOWN types upwards if found - if isinstance(type1, exp.DataType): - type1 = type1.this - if isinstance(type2, exp.DataType): - type2 = type2.this + ) -> exp.DataType | exp.DataType.Type: + type1_value = type1.this if isinstance(type1, exp.DataType) else type1 + type2_value = type2.this if isinstance(type2, exp.DataType) else type2 - if exp.DataType.Type.NULL in (type1, type2): + # We propagate the NULL / UNKNOWN types upwards if found + if exp.DataType.Type.NULL in (type1_value, type2_value): return exp.DataType.Type.NULL - if exp.DataType.Type.UNKNOWN in (type1, type2): + if exp.DataType.Type.UNKNOWN in (type1_value, type2_value): return exp.DataType.Type.UNKNOWN - return type2 if type2 in self.coerces_to.get(type1, {}) else type1 # type: ignore + if type1_value in self.NESTED_TYPES: + return type1 + if type2_value in self.NESTED_TYPES: + return type2 + + return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value # type: ignore # Note: the following "no_type_check" decorators were added because mypy was yelling due # to assigning Type values to expression.type (since its getter returns Optional[DataType]). @@ -368,7 +380,9 @@ def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> return self._annotate_args(expression) @t.no_type_check - def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E: + def _annotate_by_args( + self, expression: E, *args: str, promote: bool = False, array: bool = False + ) -> E: self._annotate_args(expression) expressions: t.List[exp.Expression] = [] @@ -388,4 +402,9 @@ def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> elif expression.type.this in exp.DataType.FLOAT_TYPES: expression.type = exp.DataType.Type.DOUBLE + if array: + expression.type = exp.DataType( + this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True + ) + return expression diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index c33c899bc7..c800e5891f 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -305,13 +305,19 @@ def test_duckdb(self): ) self.validate_all( "ARRAY_CONCAT(LIST_VALUE(1, 2), LIST_VALUE(3, 4))", + read={ + "bigquery": "ARRAY_CONCAT([1, 2], [3, 4])", + "postgres": "ARRAY_CAT(ARRAY[1, 2], ARRAY[3, 4])", + "snowflake": "ARRAY_CAT([1, 2], [3, 4])", + }, write={ + "bigquery": "ARRAY_CONCAT([1, 2], [3, 4])", "duckdb": "ARRAY_CONCAT(LIST_VALUE(1, 2), LIST_VALUE(3, 4))", - "presto": "CONCAT(ARRAY[1, 2], ARRAY[3, 4])", "hive": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))", - "spark": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))", + "postgres": "ARRAY_CAT(ARRAY[1, 2], ARRAY[3, 4])", + "presto": "CONCAT(ARRAY[1, 2], ARRAY[3, 4])", "snowflake": "ARRAY_CAT([1, 2], [3, 4])", - "bigquery": "ARRAY_CONCAT([1, 2], [3, 4])", + "spark": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))", }, ) self.validate_all( diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index a1bd309afe..e001c1fdb5 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -758,6 +758,24 @@ def test_root_subquery_annotation(self): self.assertEqual(exp.DataType.Type.INT, expression.selects[0].type.this) self.assertEqual(exp.DataType.Type.INT, expression.selects[1].type.this) + def test_nested_type_annotation(self): + schema = {"order": {"customer_id": "bigint", "item_id": "bigint", "item_price": "numeric"}} + sql = """ + SELECT ARRAY_AGG(DISTINCT order.item_id) FILTER (WHERE order.item_price > 10) AS items, + FROM order AS order + GROUP BY order.customer_id + """ + expression = annotate_types(parse_one(sql), schema=schema) + + self.assertEqual(exp.DataType.Type.ARRAY, expression.selects[0].type.this) + self.assertEqual(expression.selects[0].type.sql(), "ARRAY") + + expression = annotate_types( + parse_one("SELECT ARRAY_CAT(ARRAY[1,2,3], ARRAY[4,5])", read="postgres") + ) + self.assertEqual(exp.DataType.Type.ARRAY, expression.selects[0].type.this) + self.assertEqual(expression.selects[0].type.sql(), "ARRAY") + def test_recursive_cte(self): query = parse_one( """