Skip to content

Commit

Permalink
Fix!: don't parse SEMI, ANTI as table aliases, fix join side issue (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Sep 17, 2023
1 parent 94d56be commit ff19f4c
Show file tree
Hide file tree
Showing 16 changed files with 40 additions and 5 deletions.
1 change: 1 addition & 0 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def _parse_to_hex(args: t.List) -> exp.Hex | exp.MD5:
class BigQuery(Dialect):
UNNEST_COLUMN_ONLY = True
SUPPORTS_USER_DEFINED_TYPES = False
SUPPORTS_SEMI_ANTI_JOIN = False

# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
Expand Down
4 changes: 0 additions & 4 deletions sqlglot/dialects/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,11 @@ class Parser(parser.Parser):
*parser.Parser.JOIN_KINDS,
TokenType.ANY,
TokenType.ASOF,
TokenType.ANTI,
TokenType.SEMI,
TokenType.ARRAY,
}

TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
TokenType.ANY,
TokenType.SEMI,
TokenType.ANTI,
TokenType.SETTINGS,
TokenType.FORMAT,
TokenType.ARRAY,
Expand Down
9 changes: 9 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[
if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

if not klass.SUPPORTS_SEMI_ANTI_JOIN:
klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
TokenType.ANTI,
TokenType.SEMI,
}

klass.generator_class.can_identify = klass.can_identify

return klass
Expand Down Expand Up @@ -156,6 +162,9 @@ class Dialect(metaclass=_Dialect):
# Determines whether or not user-defined data types are supported
SUPPORTS_USER_DEFINED_TYPES = True

# Determines whether or not SEMI/ANTI JOINs are supported; when False, SEMI and ANTI are allowed as table aliases
SUPPORTS_SEMI_ANTI_JOIN = True

# Determines how function names are going to be normalized
NORMALIZE_FUNCTIONS: bool | str = "upper"

Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/drill.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class Drill(Dialect):
DATEINT_FORMAT = "'yyyyMMdd'"
TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"
SUPPORTS_USER_DEFINED_TYPES = False
SUPPORTS_SEMI_ANTI_JOIN = False

TIME_MAPPING = {
"y": "%Y",
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
RENAME_TABLE_WITH_DB = False
NVL2_SUPPORTED = False
SEMI_ANTI_JOIN_WITH_SIDE = False

TRANSFORMS = {
**generator.Generator.TRANSFORMS,
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class MySQL(Dialect):
TIME_FORMAT = "'%Y-%m-%d %T'"
DPIPE_IS_STRING_CONCAT = False
SUPPORTS_USER_DEFINED_TYPES = False
SUPPORTS_SEMI_ANTI_JOIN = False

# https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
TIME_MAPPING = {
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ class Presto(Dialect):
TIME_FORMAT = MySQL.TIME_FORMAT
TIME_MAPPING = MySQL.TIME_MAPPING
STRICT_STRING_CONCAT = True
SUPPORTS_SEMI_ANTI_JOIN = False

# https://github.com/trinodb/trino/issues/17
# https://github.com/trinodb/trino/issues/12289
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ class Snowflake(Dialect):
NULL_ORDERING = "nulls_are_large"
TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
SUPPORTS_USER_DEFINED_TYPES = False
SUPPORTS_SEMI_ANTI_JOIN = False

TIME_MAPPING = {
"YYYY": "%Y",
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def _transform_create(expression: exp.Expression) -> exp.Expression:
class SQLite(Dialect):
# https://sqlite.org/forum/forumpost/5e575586ac5c711b?raw
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
SUPPORTS_SEMI_ANTI_JOIN = False

class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
Expand Down
2 changes: 2 additions & 0 deletions sqlglot/dialects/teradata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@


class Teradata(Dialect):
SUPPORTS_SEMI_ANTI_JOIN = False

TIME_MAPPING = {
"Y": "%Y",
"YYYY": "%Y",
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ class TSQL(Dialect):
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
NULL_ORDERING = "nulls_are_small"
TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
SUPPORTS_SEMI_ANTI_JOIN = False

TIME_MAPPING = {
"year": "%Y",
Expand Down
10 changes: 9 additions & 1 deletion sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ class Generator:
# Whether or not FILTER (WHERE cond) can be used for conditional aggregation
AGGREGATE_FILTER_SUPPORTED = True

# Whether or not JOIN sides (LEFT, RIGHT) are supported in conjunction with SEMI/ANTI join kinds; when False, the side is omitted for SEMI/ANTI joins
SEMI_ANTI_JOIN_WITH_SIDE = True

TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
Expand Down Expand Up @@ -1491,12 +1494,17 @@ def prior_sql(self, expression: exp.Prior) -> str:
return f"PRIOR {self.sql(expression, 'this')}"

def join_sql(self, expression: exp.Join) -> str:
if not self.SEMI_ANTI_JOIN_WITH_SIDE and expression.kind in ("SEMI", "ANTI"):
side = None
else:
side = expression.side

op_sql = " ".join(
op
for op in (
expression.method,
"GLOBAL" if expression.args.get("global") else None,
expression.side,
side,
expression.kind,
expression.hint if self.JOIN_HINTS else None,
)
Expand Down
2 changes: 2 additions & 0 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ class Parser(metaclass=_Parser):
INTERVAL_VARS = ID_VAR_TOKENS - {TokenType.END}

TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
TokenType.ANTI,
TokenType.APPLY,
TokenType.ASOF,
TokenType.FULL,
Expand All @@ -324,6 +325,7 @@ class Parser(metaclass=_Parser):
TokenType.NATURAL,
TokenType.OFFSET,
TokenType.RIGHT,
TokenType.SEMI,
TokenType.WINDOW,
}

Expand Down
8 changes: 8 additions & 0 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class TestDuckDB(Validator):
def test_duckdb(self):
for join_type in ("SEMI", "ANTI"):
exists = "EXISTS" if join_type == "SEMI" else "NOT EXISTS"

self.validate_all(
f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x",
write={
Expand All @@ -32,6 +33,13 @@ def test_duckdb(self):
"tsql": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)",
},
)
self.validate_all(
f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x",
read={
"duckdb": f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x",
"spark": f"SELECT * FROM t1 LEFT {join_type} JOIN t2 ON t1.x = t2.x",
},
)

self.validate_all(
"WITH cte(x) AS (SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) SELECT AVG(x) FILTER (WHERE x > 1) FROM cte",
Expand Down
1 change: 1 addition & 0 deletions tests/dialects/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def test_spark(self):
self.assertIsInstance(expr.args.get("ignore_nulls"), exp.Boolean)
self.assertEqual(expr.sql(dialect="spark"), "ANY_VALUE(col, TRUE)")

self.validate_identity("SELECT * FROM t1 SEMI JOIN t2 ON t1.x = t2.x")
self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), x -> x + 1)")
self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), (x, i) -> x + i)")
self.validate_identity("REFRESH table a.b.c")
Expand Down
1 change: 1 addition & 0 deletions tests/fixtures/identity.sql
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ SELECT 1 FROM a INNER JOIN b ON a.x = b.x
SELECT 1 FROM a LEFT JOIN b ON a.x = b.x
SELECT 1 FROM a RIGHT JOIN b ON a.x = b.x
SELECT 1 FROM a CROSS JOIN b ON a.x = b.x
SELECT 1 FROM a SEMI JOIN b ON a.x = b.x
SELECT 1 FROM a LEFT SEMI JOIN b ON a.x = b.x
SELECT 1 FROM a LEFT ANTI JOIN b ON a.x = b.x
SELECT 1 FROM a RIGHT SEMI JOIN b ON a.x = b.x
Expand Down

0 comments on commit ff19f4c

Please sign in to comment.