Skip to content

Commit

Permalink
Fix(oracle): set post_tablesample_alias=True to fix alias parsing (#1548
Browse files Browse the repository at this point in the history
)

* Fix(oracle): set post_tablesample_alias=True to fix alias parsing

* Fixups
  • Loading branch information
georgesittas authored May 4, 2023
1 parent b7e08cc commit f21abb7
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 24 deletions.
6 changes: 4 additions & 2 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,5 +233,7 @@ class Generator(generator.Generator):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}

def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
return super().tablesample_sql(expression, seed_prefix="REPEATABLE")
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
) -> str:
return super().tablesample_sql(expression, seed_prefix="REPEATABLE", sep=sep)
4 changes: 4 additions & 0 deletions sqlglot/dialects/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def _parse_xml_table(self) -> exp.XMLTable:


class Oracle(Dialect):
alias_post_tablesample = True

# https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
# https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
time_mapping = {
Expand Down Expand Up @@ -133,6 +135,7 @@ class Generator(generator.Generator):
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
exp.Substring: rename_func("SUBSTR"),
exp.Table: lambda self, e: self.table_sql(e, sep=" "),
exp.TableSample: lambda self, e: self.tablesample_sql(e, sep=" "),
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
Expand Down Expand Up @@ -175,6 +178,7 @@ class Tokenizer(tokens.Tokenizer):
"MINUS": TokenType.EXCEPT,
"NVARCHAR2": TokenType.NVARCHAR,
"RETURNING": TokenType.RETURNING,
"SAMPLE": TokenType.TABLE_SAMPLE,
"START": TokenType.BEGIN,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
Expand Down
6 changes: 4 additions & 2 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,10 +1115,12 @@ def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:

return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}"

def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
) -> str:
if self.alias_post_tablesample and expression.this.alias:
this = self.sql(expression.this, "this")
alias = f" AS {self.sql(expression.this, 'alias')}"
alias = f"{sep}{self.sql(expression.this, 'alias')}"
else:
this = self.sql(expression, "this")
alias = ""
Expand Down
4 changes: 3 additions & 1 deletion sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2333,7 +2333,9 @@ def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.Expre
size = None
seed = None

kind = "TABLESAMPLE" if self._prev.token_type == TokenType.TABLE_SAMPLE else "USING SAMPLE"
kind = (
self._prev.text if self._prev.token_type == TokenType.TABLE_SAMPLE else "USING SAMPLE"
)
method = self._parse_var(tokens=(TokenType.ROW,))

self._match(TokenType.L_PAREN)
Expand Down
2 changes: 1 addition & 1 deletion tests/dialects/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ class TestOracle(Validator):

def test_oracle(self):
self.validate_identity("SELECT * FROM table_name@dblink_name.database_link_domain")
self.validate_identity("SELECT * FROM table_name SAMPLE (25) s")
self.validate_identity("SELECT * FROM V$SESSION")
self.validate_identity(
"SELECT MIN(column_name) KEEP (DENSE_RANK FIRST ORDER BY column_name DESC) FROM table_name"
Expand All @@ -18,7 +19,6 @@ def test_oracle(self):
"": "IFNULL(NULL, 1)",
},
)

self.validate_all(
"DATE '2022-01-01'",
write={
Expand Down
25 changes: 7 additions & 18 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,29 +426,18 @@ def test_null_treatment(self):
def test_sample(self):
self.validate_identity("SELECT * FROM testtable TABLESAMPLE BERNOULLI (20.3)")
self.validate_identity("SELECT * FROM testtable TABLESAMPLE (100)")
self.validate_identity("SELECT * FROM testtable TABLESAMPLE SYSTEM (3) SEED (82)")
self.validate_identity("SELECT * FROM testtable TABLESAMPLE (10 ROWS)")
self.validate_identity("SELECT * FROM testtable SAMPLE (10)")
self.validate_identity("SELECT * FROM testtable SAMPLE ROW (0)")
self.validate_identity("SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)")
self.validate_identity(
"SELECT i, j FROM table1 AS t1 INNER JOIN table2 AS t2 TABLESAMPLE (50) WHERE t2.j = t1.i"
)
self.validate_identity(
"SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE (1)"
)
self.validate_identity("SELECT * FROM testtable TABLESAMPLE SYSTEM (3) SEED (82)")
self.validate_identity("SELECT * FROM testtable TABLESAMPLE (10 ROWS)")

self.validate_all(
"SELECT * FROM testtable SAMPLE (10)",
write={"snowflake": "SELECT * FROM testtable TABLESAMPLE (10)"},
)
self.validate_all(
"SELECT * FROM testtable SAMPLE ROW (0)",
write={"snowflake": "SELECT * FROM testtable TABLESAMPLE ROW (0)"},
)
self.validate_all(
"SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)",
write={
"snowflake": "SELECT a FROM test TABLESAMPLE BLOCK (0.5) SEED (42)",
},
)
self.validate_all(
"""
SELECT i, j
Expand All @@ -458,13 +447,13 @@ def test_sample(self):
table2 AS t2 SAMPLE (50) -- 50% of rows in table2
WHERE t2.j = t1.i""",
write={
"snowflake": "SELECT i, j FROM table1 AS t1 TABLESAMPLE (25) /* 25% of rows in table1 */ INNER JOIN table2 AS t2 TABLESAMPLE (50) /* 50% of rows in table2 */ WHERE t2.j = t1.i",
"snowflake": "SELECT i, j FROM table1 AS t1 SAMPLE (25) /* 25% of rows in table1 */ INNER JOIN table2 AS t2 SAMPLE (50) /* 50% of rows in table2 */ WHERE t2.j = t1.i",
},
)
self.validate_all(
"SELECT * FROM testtable SAMPLE BLOCK (0.012) REPEATABLE (99992)",
write={
"snowflake": "SELECT * FROM testtable TABLESAMPLE BLOCK (0.012) SEED (99992)",
"snowflake": "SELECT * FROM testtable SAMPLE BLOCK (0.012) SEED (99992)",
},
)

Expand Down

0 comments on commit f21abb7

Please sign in to comment.