diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index c20775120c..edd72ea6bb 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -3309,6 +3309,7 @@ class Pivot(Expression): "using": False, "group": False, "columns": False, + "include_nulls": False, } diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 8dba11fc74..0ab960ac1b 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1290,7 +1290,12 @@ def pivot_sql(self, expression: exp.Pivot) -> str: unpivot = expression.args.get("unpivot") direction = "UNPIVOT" if unpivot else "PIVOT" field = self.sql(expression, "field") - return f"{direction}({expressions} FOR {field}){alias}" + include_nulls = expression.args.get("include_nulls") + if include_nulls is not None: + nulls = " INCLUDE NULLS " if include_nulls else " EXCLUDE NULLS " + else: + nulls = "" + return f"{direction}{nulls}({expressions} FOR {field}){alias}" def tuple_sql(self, expression: exp.Tuple) -> str: return f"({self.expressions(expression, flat=True)})" diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 1f3e240df4..1847148052 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -2623,11 +2623,18 @@ def _parse_on() -> t.Optional[exp.Expression]: def _parse_pivot(self) -> t.Optional[exp.Pivot]: index = self._index + include_nulls = None if self._match(TokenType.PIVOT): unpivot = False elif self._match(TokenType.UNPIVOT): unpivot = True + + # https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select-unpivot.html#syntax + if self._match_text_seq("INCLUDE", "NULLS"): + include_nulls = True + elif self._match_text_seq("EXCLUDE", "NULLS"): + include_nulls = False else: return None @@ -2658,7 +2665,13 @@ def _parse_pivot(self) -> t.Optional[exp.Pivot]: self._match_r_paren() - pivot = self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot) + pivot = self.expression( + exp.Pivot, + expressions=expressions, + field=field, + unpivot=unpivot, + include_nulls=include_nulls, + ) if not self._match_set((TokenType.PIVOT, TokenType.UNPIVOT), advance=False): pivot.set("alias", self._parse_table_alias()) diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index 14f7cd048b..95e6635c65 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -11,6 +11,9 @@ def test_databricks(self): self.validate_identity("CREATE FUNCTION a AS b") self.validate_identity("SELECT ${x} FROM ${y} WHERE ${z} > 1") self.validate_identity("CREATE TABLE foo (x DATE GENERATED ALWAYS AS (CAST(y AS DATE)))") + self.validate_identity( + "SELECT * FROM sales UNPIVOT INCLUDE NULLS (sales FOR quarter IN (q1 AS `Jan-Mar`))" + ) self.validate_all( "CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(y)))",