Skip to content

Commit

Permalink
Feat(snowflake): add support for staged file table syntax (#2333)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Sep 27, 2023
1 parent 088d212 commit 8af4054
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 28 deletions.
22 changes: 22 additions & 0 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,28 @@ class Parser(parser.Parser):
"TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
}

# Token types that act as connectors inside a staged file reference.
# `_parse_table_parts` keeps appending token text to the staged name for as
# long as the upcoming token's type is in this set.
# NOTE(review): presumably these correspond to `.`, `%` and `/` in paths such
# as `@namespace.%table_name/path/to/file` — confirm against the tokenizer.
STAGED_FILE_SINGLE_TOKENS = {
TokenType.DOT,
TokenType.MOD,
TokenType.SLASH,
}

def _parse_table_parts(self, schema: bool = False) -> exp.Table:
    """Parse a table reference, with support for staged file names.

    When the upcoming token text is ``@``, the tokens that follow are glued
    together into a single name: one token is consumed, then any run of
    ``STAGED_FILE_SINGLE_TOKENS`` connectors, and this repeats for as long
    as a connector immediately follows the token just consumed. The
    collected text (including the leading ``@``) becomes the ``this`` of an
    ``exp.Identifier`` wrapped in an ``exp.Table``.

    Anything that does not start with ``@`` is delegated to the parent
    implementation.

    Args:
        schema: Forwarded unchanged to the parent implementation; not used
            on the staged file path.

    Returns:
        The parsed ``exp.Table`` expression.
    """
    # https://docs.snowflake.com/en/user-guide/querying-stage
    if not self._match_text_seq("@"):
        return super()._parse_table_parts(schema=schema)

    pieces = ["@"]
    more_to_read = True
    while more_to_read:
        # NOTE(review): this advances unconditionally, so it presumably
        # assumes at least one token follows '@' — confirm.
        self._advance()
        pieces.append(self._prev.text)

        # Peek only: continue the name solely when a connector comes next.
        more_to_read = bool(
            self._match_set(self.STAGED_FILE_SINGLE_TOKENS, advance=False)
        )
        if more_to_read:
            # Swallow every consecutive connector before the next segment.
            while self._match_set(self.STAGED_FILE_SINGLE_TOKENS):
                pieces.append(self._prev.text)

    staged_name = "".join(pieces)
    return self.expression(exp.Table, this=exp.Identifier(this=staged_name))

def _parse_id_var(
self,
any_token: bool = True,
Expand Down
63 changes: 35 additions & 28 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,11 @@ class TestSnowflake(Validator):
dialect = "snowflake"

def test_snowflake(self):
self.validate_identity(
'DESCRIBE TABLE "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF100TCL"."WEB_SITE" type=stage'
)

self.validate_all(
"SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
read={
"oracle": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
},
write={
"oracle": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
"snowflake": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
},
)
self.validate_all(
"SELECT INSERT(a, 0, 0, 'b')",
read={
"mysql": "SELECT INSERT(a, 0, 0, 'b')",
"snowflake": "SELECT INSERT(a, 0, 0, 'b')",
"tsql": "SELECT STUFF(a, 0, 0, 'b')",
},
write={
"mysql": "SELECT INSERT(a, 0, 0, 'b')",
"snowflake": "SELECT INSERT(a, 0, 0, 'b')",
"tsql": "SELECT STUFF(a, 0, 0, 'b')",
},
)

self.validate_identity("SELECT * FROM @~")
self.validate_identity("SELECT * FROM @~/some/path/to/file.csv")
self.validate_identity("SELECT * FROM @mystage")
self.validate_identity("SELECT * FROM @namespace.mystage/path/to/file.json.gz")
self.validate_identity("SELECT * FROM @namespace.%table_name/path/to/file.json.gz")
self.validate_identity("LISTAGG(data['some_field'], ',')")
self.validate_identity("WEEKOFYEAR(tstamp)")
self.validate_identity("SELECT SUM(amount) FROM mytable GROUP BY ALL")
Expand Down Expand Up @@ -64,6 +41,9 @@ def test_snowflake(self):
self.validate_identity("COMMENT IF EXISTS ON TABLE foo IS 'bar'")
self.validate_identity("SELECT CONVERT_TIMEZONE('UTC', 'America/Los_Angeles', col)")
self.validate_identity("REGEXP_REPLACE('target', 'pattern', '\n')")
self.validate_identity(
'DESCRIBE TABLE "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF100TCL"."WEB_SITE" type=stage'
)
self.validate_identity(
'COPY INTO NEW_TABLE ("foo", "bar") FROM (SELECT $1, $2, $3, $4 FROM @%old_table)'
)
Expand All @@ -82,11 +62,38 @@ def test_snowflake(self):
"SELECT {'test': 'best'}::VARIANT",
"SELECT CAST(OBJECT_CONSTRUCT('test', 'best') AS VARIANT)",
)
self.validate_identity(
"SELECT parse_json($1):a.b FROM @mystage2/data1.json.gz",
"SELECT PARSE_JSON($1)['a'].b FROM @mystage2/data1.json.gz",
)

self.validate_all("CAST(x AS BYTEINT)", write={"snowflake": "CAST(x AS INT)"})
self.validate_all("CAST(x AS CHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"})
self.validate_all("CAST(x AS CHARACTER VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"})
self.validate_all("CAST(x AS NCHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"})
self.validate_all(
"SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
read={
"oracle": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
},
write={
"oracle": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
"snowflake": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
},
)
self.validate_all(
"SELECT INSERT(a, 0, 0, 'b')",
read={
"mysql": "SELECT INSERT(a, 0, 0, 'b')",
"snowflake": "SELECT INSERT(a, 0, 0, 'b')",
"tsql": "SELECT STUFF(a, 0, 0, 'b')",
},
write={
"mysql": "SELECT INSERT(a, 0, 0, 'b')",
"snowflake": "SELECT INSERT(a, 0, 0, 'b')",
"tsql": "SELECT STUFF(a, 0, 0, 'b')",
},
)
self.validate_all(
"ARRAY_GENERATE_RANGE(0, 3)",
write={
Expand Down

0 comments on commit 8af4054

Please sign in to comment.