diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index b8ba78ba4d..83d1684cc2 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -32,9 +32,13 @@ def _create_sql(self: Hive.Generator, e: exp.Create) -> str: def _map_sql(self: Hive.Generator, expression: exp.Map) -> str: - keys = self.sql(expression.args["keys"]) - values = self.sql(expression.args["values"]) - return f"MAP_FROM_ARRAYS({keys}, {values})" + keys = expression.args.get("keys") + values = expression.args.get("values") + + if not keys or not values: + return "MAP()" + + return f"MAP_FROM_ARRAYS({self.sql(keys)}, {self.sql(values)})" def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]: diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index a984025c6f..1808f53178 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -244,6 +244,23 @@ def test_spark(self): "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')", ) + self.validate_all( + "MAP(1, 2, 3, 4)", + write={ + "spark": "MAP(1, 2, 3, 4)", + "trino": "MAP(ARRAY[1, 3], ARRAY[2, 4])", + }, + ) + self.validate_all( + "MAP()", + read={ + "spark": "MAP()", + "trino": "MAP()", + }, + write={ + "trino": "MAP(ARRAY[], ARRAY[])", + }, + ) self.validate_all( "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')", read={