Skip to content

Commit

Permalink
Fix: cluster/distribute/sort by for hive
Browse files: browse the repository at this point in the history
  • Loading branch information
tobymao committed Jun 16, 2023
1 parent cf0e28a commit 18db68c
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 23 deletions.
6 changes: 3 additions & 3 deletions sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,9 @@ class Parser(parser.Parser):

QUERY_MODIFIER_PARSERS = {
**parser.Parser.QUERY_MODIFIER_PARSERS,
"distribute": lambda self: self._parse_sort(exp.Distribute, "DISTRIBUTE", "BY"),
"sort": lambda self: self._parse_sort(exp.Sort, "SORT", "BY"),
"cluster": lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"),
"cluster": lambda self: self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY),
"distribute": lambda self: self._parse_sort(exp.Distribute, TokenType.DISTRIBUTE_BY),
"sort": lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY),
}

def _parse_types(
Expand Down
14 changes: 5 additions & 9 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ class Parser(metaclass=_Parser):
}

EXPRESSION_PARSERS = {
exp.Cluster: lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"),
exp.Cluster: lambda self: self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY),
exp.Column: lambda self: self._parse_column(),
exp.Condition: lambda self: self._parse_conjunction(),
exp.DataType: lambda self: self._parse_types(),
Expand All @@ -484,7 +484,7 @@ class Parser(metaclass=_Parser):
exp.Properties: lambda self: self._parse_properties(),
exp.Qualify: lambda self: self._parse_qualify(),
exp.Returning: lambda self: self._parse_returning(),
exp.Sort: lambda self: self._parse_sort(exp.Sort, "SORT", "BY"),
exp.Sort: lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY),
exp.Table: lambda self: self._parse_table_parts(),
exp.TableAlias: lambda self: self._parse_table_alias(),
exp.Where: lambda self: self._parse_where(),
Expand Down Expand Up @@ -584,7 +584,7 @@ class Parser(metaclass=_Parser):
"BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
"CHARACTER SET": lambda self: self._parse_character_set(),
"CHECKSUM": lambda self: self._parse_checksum(),
"CLUSTER": lambda self: self._parse_cluster(),
"CLUSTER BY": lambda self: self._parse_cluster(),
"COLLATE": lambda self: self._parse_property_assignment(exp.CollateProperty),
"COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
"COPY": lambda self: self._parse_copy_property(),
Expand Down Expand Up @@ -1424,10 +1424,6 @@ def _parse_checksum(self) -> exp.ChecksumProperty:
return self.expression(exp.ChecksumProperty, on=on, default=self._match(TokenType.DEFAULT))

def _parse_cluster(self) -> t.Optional[exp.Cluster]:
if not self._match_text_seq("BY"):
self._retreat(self._index - 1)
return None

return self.expression(exp.Cluster, expressions=self._parse_csv(self._parse_ordered))

def _parse_copy_property(self) -> t.Optional[exp.CopyGrantsProperty]:
Expand Down Expand Up @@ -2618,8 +2614,8 @@ def _parse_order(
exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)
)

def _parse_sort(self, exp_class: t.Type[E], *texts: str) -> t.Optional[E]:
if not self._match_text_seq(*texts):
def _parse_sort(self, exp_class: t.Type[E], token: TokenType) -> t.Optional[E]:
if not self._match(token):
return None
return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))

Expand Down
6 changes: 6 additions & 0 deletions sqlglot/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ class TokenType(AutoName):
CACHE = auto()
CASE = auto()
CHARACTER_SET = auto()
CLUSTER_BY = auto()
COLLATE = auto()
COMMAND = auto()
COMMENT = auto()
Expand All @@ -182,6 +183,7 @@ class TokenType(AutoName):
DESCRIBE = auto()
DICTIONARY = auto()
DISTINCT = auto()
DISTRIBUTE_BY = auto()
DIV = auto()
DROP = auto()
ELSE = auto()
Expand Down Expand Up @@ -282,6 +284,7 @@ class TokenType(AutoName):
SHOW = auto()
SIMILAR_TO = auto()
SOME = auto()
SORT_BY = auto()
STRUCT = auto()
TABLE_SAMPLE = auto()
TEMPORARY = auto()
Expand Down Expand Up @@ -509,6 +512,7 @@ class Tokenizer(metaclass=_Tokenizer):
"UNCACHE": TokenType.UNCACHE,
"CASE": TokenType.CASE,
"CHARACTER SET": TokenType.CHARACTER_SET,
"CLUSTER BY": TokenType.CLUSTER_BY,
"COLLATE": TokenType.COLLATE,
"COLUMN": TokenType.COLUMN,
"COMMIT": TokenType.COMMIT,
Expand All @@ -526,6 +530,7 @@ class Tokenizer(metaclass=_Tokenizer):
"DESC": TokenType.DESC,
"DESCRIBE": TokenType.DESCRIBE,
"DISTINCT": TokenType.DISTINCT,
"DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
"DIV": TokenType.DIV,
"DROP": TokenType.DROP,
"ELSE": TokenType.ELSE,
Expand Down Expand Up @@ -617,6 +622,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SHOW": TokenType.SHOW,
"SIMILAR TO": TokenType.SIMILAR_TO,
"SOME": TokenType.SOME,
"SORT BY": TokenType.SORT_BY,
"TABLE": TokenType.TABLE,
"TABLESAMPLE": TokenType.TABLE_SAMPLE,
"TEMP": TokenType.TEMPORARY,
Expand Down
21 changes: 10 additions & 11 deletions tests/dialects/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,6 @@
class TestHive(Validator):
dialect = "hive"

def test_hive(self):
self.validate_identity("SELECT * FROM test DISTRIBUTE BY y SORT BY x DESC ORDER BY l")
self.validate_identity(
"SELECT * FROM test WHERE RAND() <= 0.1 DISTRIBUTE BY RAND() SORT BY RAND()"
)
self.validate_identity("(SELECT 1 UNION SELECT 2) DISTRIBUTE BY z")
self.validate_identity("(SELECT 1 UNION SELECT 2) DISTRIBUTE BY z SORT BY x")
self.validate_identity("(SELECT 1 UNION SELECT 2) CLUSTER BY y DESC")
self.validate_identity("SELECT * FROM test CLUSTER BY y")
self.validate_identity("(SELECT 1 UNION SELECT 2) SORT BY z")

def test_bits(self):
self.validate_all(
"x & 1",
Expand Down Expand Up @@ -381,6 +370,16 @@ def test_order_by(self):
)

def test_hive(self):
self.validate_identity("SELECT * FROM test DISTRIBUTE BY y SORT BY x DESC ORDER BY l")
self.validate_identity(
"SELECT * FROM test WHERE RAND() <= 0.1 DISTRIBUTE BY RAND() SORT BY RAND()"
)
self.validate_identity("(SELECT 1 UNION SELECT 2) DISTRIBUTE BY z")
self.validate_identity("(SELECT 1 UNION SELECT 2) DISTRIBUTE BY z SORT BY x")
self.validate_identity("(SELECT 1 UNION SELECT 2) CLUSTER BY y DESC")
self.validate_identity("SELECT * FROM test CLUSTER BY y")

self.validate_identity("(SELECT 1 UNION SELECT 2) SORT BY z")
self.validate_identity(
"INSERT OVERWRITE TABLE zipcodes PARTITION(state = '0') VALUES (896, 'US', 'TAMPA', 33607)"
)
Expand Down

0 comments on commit 18db68c

Please sign in to comment.