Skip to content

Commit

Permalink
Fix(clickhouse): don't generate parentheses, match R_PAREN conditionally (#2332)
Browse files Browse the repository at this point in the history

* Fix(clickhouse): don't generate parentheses, match R_PAREN conditionally

* Fixups
  • Loading branch information
georgesittas authored Sep 27, 2023
1 parent f80501c commit 58c7849
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 5 deletions.
1 change: 1 addition & 0 deletions sqlglot/dialects/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ class Generator(generator.Generator):
QUERY_HINTS = False
STRUCT_DELIMITER = ("(", ")")
NVL2_SUPPORTED = False
TABLESAMPLE_REQUIRES_PARENS = False

STRING_TYPE_MAPPING = {
exp.DataType.Type.CHAR: "String",
Expand Down
13 changes: 12 additions & 1 deletion sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ class Generator:
# Whether or not CREATE TABLE .. COPY .. is supported. False means we'll generate CLONE instead of COPY
SUPPORTS_TABLE_COPY = True

# Whether or not parentheses are required around the table sample's expression
TABLESAMPLE_REQUIRES_PARENS = True

TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
Expand Down Expand Up @@ -1353,6 +1356,7 @@ def tablesample_sql(
else:
this = self.sql(expression, "this")
alias = ""

method = self.sql(expression, "method")
method = f"{method.upper()} " if method and self.TABLESAMPLE_WITH_METHOD else ""
numerator = self.sql(expression, "bucket_numerator")
Expand All @@ -1364,13 +1368,20 @@ def tablesample_sql(
percent = f"{percent} PERCENT" if percent else ""
rows = self.sql(expression, "rows")
rows = f"{rows} ROWS" if rows else ""

size = self.sql(expression, "size")
if size and self.TABLESAMPLE_SIZE_IS_PERCENT:
size = f"{size} PERCENT"

seed = self.sql(expression, "seed")
seed = f" {seed_prefix} ({seed})" if seed else ""
kind = expression.args.get("kind", "TABLESAMPLE")
return f"{this} {kind} {method}({bucket}{percent}{rows}{size}){seed}{alias}"

expr = f"{bucket}{percent}{rows}{size}"
if self.TABLESAMPLE_REQUIRES_PARENS:
expr = f"({expr})"

return f"{this} {kind} {method}{expr}{seed}{alias}"

def pivot_sql(self, expression: exp.Pivot) -> str:
expressions = self.expressions(expression, flat=True)
Expand Down
11 changes: 8 additions & 3 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2734,14 +2734,18 @@ def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.Table
)
method = self._parse_var(tokens=(TokenType.ROW,))

self._match(TokenType.L_PAREN)
matched_l_paren = self._match(TokenType.L_PAREN)

if self.TABLESAMPLE_CSV:
num = None
expressions = self._parse_csv(self._parse_primary)
else:
expressions = None
num = self._parse_primary()
num = (
self._parse_factor()
if self._match(TokenType.NUMBER, advance=False)
else self._parse_primary()
)

if self._match_text_seq("BUCKET"):
bucket_numerator = self._parse_number()
Expand All @@ -2756,7 +2760,8 @@ def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.Table
elif num:
size = num

self._match(TokenType.R_PAREN)
if matched_l_paren:
self._match_r_paren()

if self._match(TokenType.L_PAREN):
method = self._parse_var()
Expand Down
4 changes: 3 additions & 1 deletion tests/dialects/test_clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ def test_clickhouse(self):
self.assertEqual(expr.sql(dialect="clickhouse"), "COUNT(x)")
self.assertIsNone(expr._meta)

self.validate_identity("SELECT sum(foo * bar) FROM bla SAMPLE (10000000)")
self.validate_identity("SELECT * FROM (SELECT a FROM b SAMPLE 0.01)")
self.validate_identity("SELECT * FROM (SELECT a FROM b SAMPLE 1 / 10 OFFSET 1 / 2)")
self.validate_identity("SELECT sum(foo * bar) FROM bla SAMPLE 10000000")
self.validate_identity("CAST(x AS Nested(ID UInt32, Serial UInt32, EventTime DATETIME))")
self.validate_identity("CAST(x AS Enum('hello' = 1, 'world' = 2))")
self.validate_identity("CAST(x AS Enum('hello', 'world'))")
Expand Down

0 comments on commit 58c7849

Please sign in to comment.