Skip to content

Commit

Permalink
Chore(snowflake): remove select_sql and values_sql as the bug was res…
Browse files Browse the repository at this point in the history
…olved (#1653)
  • Loading branch information
georgesittas authored May 18, 2023
1 parent 19a56d9 commit 862deab
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 45 deletions.
45 changes: 1 addition & 44 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
var_map_sql,
)
from sqlglot.expressions import Literal
from sqlglot.helper import flatten, seq_get
from sqlglot.helper import seq_get
from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType

Expand Down Expand Up @@ -361,53 +361,10 @@ def intersect_op(self, expression: exp.Intersect) -> str:
self.unsupported("INTERSECT with All is not supported in Snowflake")
return super().intersect_op(expression)

def values_sql(self, expression: exp.Values) -> str:
"""Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted.
We also want to make sure that after we find matches where we need to unquote a column that we prevent users
from adding quotes to the column by using the `identify` argument when generating the SQL.
"""
alias = expression.args.get("alias")
if alias and alias.args.get("columns"):
expression = expression.transform(
lambda node: exp.Identifier(**{**node.args, "quoted": False})
if isinstance(node, exp.Identifier)
and isinstance(node.parent, exp.TableAlias)
and node.arg_key == "columns"
else node,
)
return self.no_identify(lambda: super(self.__class__, self).values_sql(expression))
return super().values_sql(expression)

def settag_sql(self, expression: exp.SetTag) -> str:
action = "UNSET" if expression.args.get("unset") else "SET"
return f"{action} TAG {self.expressions(expression)}"

def select_sql(self, expression: exp.Select) -> str:
"""Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also
that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need
to unquote a column that we prevent users from adding quotes to the column by using the `identify` argument when
generating the SQL.
Note: We make an assumption that any columns referenced in a VALUES expression should be unquoted throughout the
expression. This might not be true in a case where the same column name can be sourced from another table that can
properly quote but should be true in most cases.
"""
values_identifiers = set(
flatten(
(v.args.get("alias") or exp.Alias()).args.get("columns", [])
for v in expression.find_all(exp.Values)
)
)
if values_identifiers:
expression = expression.transform(
lambda node: exp.Identifier(**{**node.args, "quoted": False})
if isinstance(node, exp.Identifier) and node in values_identifiers
else node,
)
return self.no_identify(lambda: super(self.__class__, self).select_sql(expression))
return super().select_sql(expression)

def describe_sql(self, expression: exp.Describe) -> str:
# Default to table if kind is unknown
kind_value = expression.args.get("kind") or "TABLE"
Expand Down
2 changes: 1 addition & 1 deletion tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ def test_minus(self):

def test_values(self):
self.validate_all(
'SELECT c0, c1 FROM (VALUES (1, 2), (3, 4)) AS "t0"(c0, c1)',
'SELECT "c0", "c1" FROM (VALUES (1, 2), (3, 4)) AS "t0"("c0", "c1")',
read={
"spark": "SELECT `c0`, `c1` FROM (VALUES (1, 2), (3, 4)) AS `t0`(`c0`, `c1`)",
},
Expand Down

0 comments on commit 862deab

Please sign in to comment.