Skip to content

Commit

Permalink
Fix: full support for Spark CLUSTERED BY
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Jun 27, 2023
1 parent b60e19b commit 40928b7
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 13 deletions.
6 changes: 0 additions & 6 deletions sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ class Tokenizer(tokens.Tokenizer):
"ADD FILES": TokenType.COMMAND,
"ADD JAR": TokenType.COMMAND,
"ADD JARS": TokenType.COMMAND,
"CLUSTERED BY": TokenType.CLUSTER_BY,
"MSCK REPAIR": TokenType.COMMAND,
"WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
}
Expand Down Expand Up @@ -437,8 +436,3 @@ def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
]

def cluster_sql(self, expression: exp.Cluster) -> str:
    """Render an exp.Cluster node.

    When the node sits inside an exp.Properties (table DDL), emit Hive's
    CLUSTERED BY (...) property form; otherwise emit a plain CLUSTER BY
    query modifier.
    """
    if not isinstance(expression.parent, exp.Properties):
        return self.op_expressions("CLUSTER BY", expression)

    clustered_cols = self.expressions(expression, flat=True)
    return f"CLUSTERED BY ({clustered_cols})"
6 changes: 6 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1930,6 +1930,11 @@ class LanguageProperty(Property):
arg_types = {"this": True}


# Spark DDL: CLUSTERED BY (...) [SORTED BY (...)] INTO n BUCKETS
class ClusteredByProperty(Property):
    """Property node for Spark's bucketing clause on CREATE TABLE."""

    # "expressions" (bucket columns) and "buckets" (bucket count) are
    # required; "sorted_by" is optional since SORTED BY may be omitted.
    arg_types = {"expressions": True, "sorted_by": False, "buckets": True}


class DictProperty(Property):
arg_types = {"this": True, "kind": True, "settings": False}

Expand Down Expand Up @@ -2070,6 +2075,7 @@ class Properties(Expression):
"ALGORITHM": AlgorithmProperty,
"AUTO_INCREMENT": AutoIncrementProperty,
"CHARACTER SET": CharacterSetProperty,
"CLUSTERED_BY": ClusteredByProperty,
"COLLATE": CollateProperty,
"COMMENT": SchemaCommentProperty,
"DEFINER": DefinerProperty,
Expand Down
8 changes: 8 additions & 0 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ class Generator:
exp.CollateProperty: exp.Properties.Location.POST_SCHEMA,
exp.CopyGrantsProperty: exp.Properties.Location.POST_SCHEMA,
exp.Cluster: exp.Properties.Location.POST_SCHEMA,
exp.ClusteredByProperty: exp.Properties.Location.POST_SCHEMA,
exp.DataBlocksizeProperty: exp.Properties.Location.POST_NAME,
exp.DefinerProperty: exp.Properties.Location.POST_CREATE,
exp.DictRange: exp.Properties.Location.POST_SCHEMA,
Expand Down Expand Up @@ -2413,6 +2414,13 @@ def dictsubproperty_sql(self, expression: exp.DictSubProperty) -> str:
def oncluster_sql(self, expression: exp.OnCluster) -> str:
    """Render nothing for ON CLUSTER — this generator drops the clause."""
    return ""

def clusteredbyproperty_sql(self, expression: exp.ClusteredByProperty) -> str:
    """Render Spark's CLUSTERED BY (...) [SORTED BY (...)] INTO n BUCKETS.

    The SORTED BY segment is emitted only when the node carries sorted_by
    expressions; the bucket count always follows as INTO n BUCKETS.
    """
    bucket_cols = self.expressions(expression, key="expressions", flat=True)
    bucket_count = self.sql(expression, "buckets")

    sort_cols = self.expressions(expression, key="sorted_by", flat=True)
    if sort_cols:
        sort_clause = f" SORTED BY ({sort_cols})"
    else:
        sort_clause = ""

    return f"CLUSTERED BY ({bucket_cols}){sort_clause} INTO {bucket_count} BUCKETS"


def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None
Expand Down
30 changes: 24 additions & 6 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ class Parser(metaclass=_Parser):
"CHARACTER SET": lambda self: self._parse_character_set(),
"CHECKSUM": lambda self: self._parse_checksum(),
"CLUSTER BY": lambda self: self._parse_cluster(),
"CLUSTERED BY": lambda self: self._parse_cluster(dml=True),
"CLUSTERED": lambda self: self._parse_clustered_by(),
"COLLATE": lambda self: self._parse_property_assignment(exp.CollateProperty),
"COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
"COPY": lambda self: self._parse_copy_property(),
Expand Down Expand Up @@ -1427,15 +1427,33 @@ def _parse_checksum(self) -> exp.ChecksumProperty:

return self.expression(exp.ChecksumProperty, on=on, default=self._match(TokenType.DEFAULT))

def _parse_cluster(self) -> exp.Cluster:
    """Parse a CLUSTER BY clause into an exp.Cluster node.

    The scraped diff fused the removed signature (with its ``dml`` flag and a
    dangling ``if dml:``) onto the new one, leaving two ``def`` lines; only
    the new, flag-free implementation is kept. CLUSTERED BY parsing now lives
    in ``_parse_clustered_by``.
    """
    return self.expression(exp.Cluster, expressions=self._parse_csv(self._parse_ordered))

def _parse_clustered_by(self) -> exp.ClusteredByProperty:
    """Parse Spark's CLUSTERED BY (...) [SORTED BY (...)] INTO n BUCKETS.

    Assumes the CLUSTERED keyword was already consumed by the property
    dispatcher; this consumes BY, the bucket-column list, an optional
    SORTED BY ordered-column list, and the INTO n BUCKETS tail.

    The scraped diff interleaved old and new hunk lines here: a duplicate
    ``expressions = self._parse_csv(self._parse_ordered)`` clobbered the
    bucket columns inside the SORTED BY branch, a stale ``else`` branch
    reassigned them again, and a premature ``return`` of an ``exp.Cluster``
    made the INTO/BUCKETS parsing unreachable. Only the intended new logic
    is kept.
    """
    self._match_text_seq("BY")

    # Bucket columns: CLUSTERED BY (col1, col2, ...)
    self._match_l_paren()
    expressions = self._parse_csv(self._parse_column)
    self._match_r_paren()

    # Optional sort spec: SORTED BY (col1 [ASC|DESC], ...)
    if self._match_text_seq("SORTED", "BY"):
        self._match_l_paren()
        sorted_by = self._parse_csv(self._parse_ordered)
        self._match_r_paren()
    else:
        sorted_by = None

    # Bucket count: INTO n BUCKETS
    self._match(TokenType.INTO)
    buckets = self._parse_number()
    self._match_text_seq("BUCKETS")

    return self.expression(
        exp.ClusteredByProperty,
        expressions=expressions,
        sorted_by=sorted_by,
        buckets=buckets,
    )

def _parse_copy_property(self) -> t.Optional[exp.CopyGrantsProperty]:
if not self._match_text_seq("GRANTS"):
Expand Down
5 changes: 4 additions & 1 deletion tests/dialects/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ class TestSpark(Validator):
def test_ddl(self):
self.validate_identity("CREATE TABLE foo (col VARCHAR(50))")
self.validate_identity("CREATE TABLE foo (col STRUCT<struct_col_a: VARCHAR((50))>)")
self.validate_identity("CREATE TABLE foo (col STRING) CLUSTERED BY (col)")
self.validate_identity("CREATE TABLE foo (col STRING) CLUSTERED BY (col) INTO 10 BUCKETS")
self.validate_identity(
"CREATE TABLE foo (col STRING) CLUSTERED BY (col) SORTED BY (col) INTO 10 BUCKETS"
)

self.validate_all(
"CREATE TABLE db.example_table (col_a struct<struct_col_a:int, struct_col_b:string>)",
Expand Down

0 comments on commit 40928b7

Please sign in to comment.