Skip to content

Commit

Permalink
Feat: start with connect by closes #2112
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Aug 26, 2023
1 parent 0316f7f commit ca5c999
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 1 deletion.
10 changes: 10 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1503,6 +1503,15 @@ class Check(Expression):
pass


# https://docs.snowflake.com/en/sql-reference/constructs/connect-by
class Connect(Expression):
arg_types = {"start": False, "connect": True}


class Prior(Expression):
pass


class Directory(Expression):
# https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-dml-insert-overwrite-directory-hive.html
arg_types = {"this": True, "local": False, "row_format": False}
Expand Down Expand Up @@ -2351,6 +2360,7 @@ def with_(
"match": False,
"laterals": False,
"joins": False,
"connect": False,
"pivots": False,
"where": False,
"group": False,
Expand Down
11 changes: 11 additions & 0 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1446,6 +1446,16 @@ def having_sql(self, expression: exp.Having) -> str:
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('HAVING')}{self.sep()}{this}"

def connect_sql(self, expression: exp.Connect) -> str:
start = self.sql(expression, "start")
start = self.seg(f"START WITH {start}") if start else ""
connect = self.sql(expression, "connect")
connect = self.seg(f"CONNECT BY {connect}")
return start + connect

def prior_sql(self, expression: exp.Prior) -> str:
return f"PRIOR {self.sql(expression, 'this')}"

def join_sql(self, expression: exp.Join) -> str:
op_sql = " ".join(
op
Expand Down Expand Up @@ -1688,6 +1698,7 @@ def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
return csv(
*sqls,
*[self.sql(join) for join in expression.args.get("joins") or []],
self.sql(expression, "connect"),
self.sql(expression, "match"),
*[self.sql(lateral) for lateral in expression.args.get("laterals") or []],
self.sql(expression, "where"),
Expand Down
20 changes: 19 additions & 1 deletion sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,8 @@ class Parser(metaclass=_Parser):
self._parse_sort(exp.Distribute, TokenType.DISTRIBUTE_BY),
),
TokenType.SORT_BY: lambda self: ("sort", self._parse_sort(exp.Sort, TokenType.SORT_BY)),
TokenType.CONNECT_BY: lambda self: ("connect", self._parse_connect(skip_start_token=True)),
TokenType.START_WITH: lambda self: ("connect", self._parse_connect()),
}

SET_PARSERS = {
Expand Down Expand Up @@ -2814,6 +2816,22 @@ def _parse_qualify(self) -> t.Optional[exp.Qualify]:
return None
return self.expression(exp.Qualify, this=self._parse_conjunction())

def _parse_connect(self, skip_start_token: bool = False) -> t.Optional[exp.Connect]:
if skip_start_token:
start = None
elif self._match(TokenType.START_WITH):
start = self._parse_conjunction()
else:
return None

self._match(TokenType.CONNECT_BY)
self.NO_PAREN_FUNCTION_PARSERS["PRIOR"] = lambda self: self.expression(
exp.Prior, this=self._parse_bitwise()
)
connect = self._parse_conjunction()
self.NO_PAREN_FUNCTION_PARSERS.pop("PRIOR")
return self.expression(exp.Connect, start=start, connect=connect)

def _parse_order(
self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False
) -> t.Optional[exp.Expression]:
Expand Down Expand Up @@ -3623,7 +3641,7 @@ def _parse_generated_as_identity(self) -> exp.GeneratedAsIdentityColumnConstrain
identity = self._match_text_seq("IDENTITY")

if self._match(TokenType.L_PAREN):
if self._match_text_seq("START", "WITH"):
if self._match(TokenType.START_WITH):
this.set("start", self._parse_bitwise())
if self._match_text_seq("INCREMENT", "BY"):
this.set("increment", self._parse_bitwise())
Expand Down
4 changes: 4 additions & 0 deletions sqlglot/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ class TokenType(AutoName):
COMMAND = auto()
COMMENT = auto()
COMMIT = auto()
CONNECT_BY = auto()
CONSTRAINT = auto()
CREATE = auto()
CROSS = auto()
Expand Down Expand Up @@ -302,6 +303,7 @@ class TokenType(AutoName):
SIMILAR_TO = auto()
SOME = auto()
SORT_BY = auto()
START_WITH = auto()
STRUCT = auto()
TABLE_SAMPLE = auto()
TEMPORARY = auto()
Expand Down Expand Up @@ -534,6 +536,7 @@ class Tokenizer(metaclass=_Tokenizer):
"COLLATE": TokenType.COLLATE,
"COLUMN": TokenType.COLUMN,
"COMMIT": TokenType.COMMIT,
"CONNECT BY": TokenType.CONNECT_BY,
"CONSTRAINT": TokenType.CONSTRAINT,
"CREATE": TokenType.CREATE,
"CROSS": TokenType.CROSS,
Expand Down Expand Up @@ -640,6 +643,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SIMILAR TO": TokenType.SIMILAR_TO,
"SOME": TokenType.SOME,
"SORT BY": TokenType.SORT_BY,
"START WITH": TokenType.START_WITH,
"TABLE": TokenType.TABLE,
"TABLESAMPLE": TokenType.TABLE_SAMPLE,
"TEMP": TokenType.TEMPORARY,
Expand Down
10 changes: 10 additions & 0 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@ class TestSnowflake(Validator):
dialect = "snowflake"

def test_snowflake(self):
self.validate_all(
"SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
read={
"oracle": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
},
write={
"oracle": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
"snowflake": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
},
)
self.validate_all(
"SELECT INSERT(a, 0, 0, 'b')",
read={
Expand Down
1 change: 1 addition & 0 deletions tests/fixtures/identity.sql
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,7 @@ JSON_OBJECT('x': NULL, 'y': 1 WITH UNIQUE KEYS)
JSON_OBJECT('x': NULL, 'y': 1 ABSENT ON NULL WITH UNIQUE KEYS)
JSON_OBJECT('x': 1 RETURNING VARCHAR(100))
JSON_OBJECT('x': 1 RETURNING VARBINARY FORMAT JSON ENCODING UTF8)
PRIOR AS x
SELECT if.x
SELECT NEXT VALUE FOR db.schema.sequence_name
SELECT NEXT VALUE FOR db.schema.sequence_name OVER (ORDER BY foo), col
Expand Down

0 comments on commit ca5c999

Please sign in to comment.