Skip to content

Commit

Permalink
Fix(hive): parse <number> <date_part> as an interval instead of an al…
Browse files Browse the repository at this point in the history
…ias (#2151)

* Fix(hive): parse <number> <date_part> as an interval instead of an alias

* Use super
  • Loading branch information
georgesittas authored Sep 4, 2023
1 parent 57c87e1 commit 9c4a9cd
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 1 deletion.
27 changes: 27 additions & 0 deletions sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,31 @@
"HOUR": " / 3600",
}

# Upper-cased unit names that, when they directly follow a numeric literal, make
# _parse_number build an interval out of the pair instead of returning the bare number.
# Every unit is accepted in both its singular and its plural ("S"-suffixed) spelling.
INTERVAL_VARS = {
    f"{unit}{suffix}"
    for unit in ("SECOND", "MINUTE", "DAY", "MONTH", "YEAR")
    for suffix in ("", "S")
}


DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")


def _parse_number(self: Hive.Parser, token: TokenType) -> t.Optional[exp.Expression]:
    """Parse a numeric literal, turning `<number> <date_part>` into an interval.

    The literal itself is parsed by the base parser's NUMBER handler. If the next
    token is a VAR whose upper-cased text is one of INTERVAL_VARS, that token is
    consumed as the unit and an `exp.Interval` wrapping the number is returned.
    Otherwise the parsed number is returned unchanged and the following token is
    left in place for the caller to handle.

    Args:
        token: The NUMBER token handed to the primary parser.

    Returns:
        An `exp.Interval` when a recognised unit follows the number, else whatever
        the base NUMBER parser produced.
    """
    # Call the base parser's handler explicitly. `super(type(self), self)` is resolved
    # relative to the *runtime* class, so for any subclass that inherits this override
    # the lookup would land back on this very function and recurse without end.
    number = parser.Parser.PRIMARY_PARSERS[TokenType.NUMBER](self, token)

    # Peek only (advance=False): the unit token is consumed by _parse_var below, and
    # must stay untouched when it turns out not to be an interval unit.
    if self._match(TokenType.VAR, advance=False) and self._curr.text.upper() in INTERVAL_VARS:
        return exp.Interval(this=number, unit=self._parse_var())

    return number


def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
unit = expression.text("unit").upper()
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
Expand Down Expand Up @@ -284,6 +306,11 @@ class Parser(parser.Parser):
),
}

# Start from the base parser's primary parsers and replace only the NUMBER entry,
# so that numeric literals are routed through _parse_number.
PRIMARY_PARSERS = {
**parser.Parser.PRIMARY_PARSERS,
TokenType.NUMBER: _parse_number,
}

def _parse_transform(self) -> t.Optional[exp.Transform | exp.QueryTransform]:
if not self._match(TokenType.L_PAREN, advance=False):
self._retreat(self._index - 1)
Expand Down
5 changes: 4 additions & 1 deletion sqlglot/dialects/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import typing as t

from sqlglot import exp, transforms
from sqlglot import exp, parser, transforms
from sqlglot.dialects.dialect import (
binary_from_function,
create_with_partitions_sql,
Expand Down Expand Up @@ -165,6 +165,9 @@ class Parser(Hive.Parser):
"SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
}

# We don't want to inherit Hive's TokenType.NUMBER override
PRIMARY_PARSERS = parser.Parser.PRIMARY_PARSERS.copy()

def _parse_add_column(self) -> t.Optional[exp.Expression]:
return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()

Expand Down
8 changes: 8 additions & 0 deletions tests/dialects/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,14 @@ def test_hive(self):
self.validate_identity("SELECT * FROM my_table TIMESTAMP AS OF DATE_ADD(CURRENT_DATE, -1)")
self.validate_identity("SELECT * FROM my_table VERSION AS OF DATE_ADD(CURRENT_DATE, -1)")

self.validate_identity(
"SELECT CAST('1998-01-01' AS DATE) + 30 years",
"SELECT CAST('1998-01-01' AS DATE) + INTERVAL 30 years",
)
self.validate_identity(
"SELECT 30 + 50 bla",
"SELECT 30 + 50 AS bla",
)
self.validate_identity(
"SELECT ROW() OVER (DISTRIBUTE BY x SORT BY y)",
"SELECT ROW() OVER (PARTITION BY x ORDER BY y)",
Expand Down
4 changes: 4 additions & 0 deletions tests/dialects/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,10 @@ def test_spark(self):
"SELECT STR_TO_MAP('a:1,b:2,c:3')",
"SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')",
)
self.validate_identity(
"SELECT CAST('1998-01-01' AS DATE) + 30 years",
"SELECT CAST('1998-01-01' AS DATE) + 30 AS years",
)

self.validate_all(
"foo.bar",
Expand Down

0 comments on commit 9c4a9cd

Please sign in to comment.