Skip to content

Commit

Permalink
Feat(databricks): add support for UNPIVOT nulls option
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas committed Aug 10, 2023
1 parent c3fd695 commit 12a7ba4
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 2 deletions.
1 change: 1 addition & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3309,6 +3309,7 @@ class Pivot(Expression):
"using": False,
"group": False,
"columns": False,
"include_nulls": False,
}


Expand Down
7 changes: 6 additions & 1 deletion sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1290,7 +1290,12 @@ def pivot_sql(self, expression: exp.Pivot) -> str:
unpivot = expression.args.get("unpivot")
direction = "UNPIVOT" if unpivot else "PIVOT"
field = self.sql(expression, "field")
return f"{direction}({expressions} FOR {field}){alias}"
include_nulls = expression.args.get("include_nulls")
if include_nulls is not None:
nulls = " INCLUDE NULLS " if include_nulls else " EXCLUDE NULLS "
else:
nulls = ""
return f"{direction}{nulls}({expressions} FOR {field}){alias}"

def tuple_sql(self, expression: exp.Tuple) -> str:
return f"({self.expressions(expression, flat=True)})"
Expand Down
15 changes: 14 additions & 1 deletion sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2623,11 +2623,18 @@ def _parse_on() -> t.Optional[exp.Expression]:

def _parse_pivot(self) -> t.Optional[exp.Pivot]:
index = self._index
include_nulls = None

if self._match(TokenType.PIVOT):
unpivot = False
elif self._match(TokenType.UNPIVOT):
unpivot = True

# https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select-unpivot.html#syntax
if self._match_text_seq("INCLUDE", "NULLS"):
include_nulls = True
elif self._match_text_seq("EXCLUDE", "NULLS"):
include_nulls = False
else:
return None

Expand Down Expand Up @@ -2658,7 +2665,13 @@ def _parse_pivot(self) -> t.Optional[exp.Pivot]:

self._match_r_paren()

pivot = self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot)
pivot = self.expression(
exp.Pivot,
expressions=expressions,
field=field,
unpivot=unpivot,
include_nulls=include_nulls,
)

if not self._match_set((TokenType.PIVOT, TokenType.UNPIVOT), advance=False):
pivot.set("alias", self._parse_table_alias())
Expand Down
3 changes: 3 additions & 0 deletions tests/dialects/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ def test_databricks(self):
self.validate_identity("CREATE FUNCTION a AS b")
self.validate_identity("SELECT ${x} FROM ${y} WHERE ${z} > 1")
self.validate_identity("CREATE TABLE foo (x DATE GENERATED ALWAYS AS (CAST(y AS DATE)))")
self.validate_identity(
"SELECT * FROM sales UNPIVOT INCLUDE NULLS (sales FOR quarter IN (q1 AS `Jan-Mar`))"
)

self.validate_all(
"CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(y)))",
Expand Down

0 comments on commit 12a7ba4

Please sign in to comment.