Skip to content

Commit

Permalink
Fix(redshift): generate correct SQL VALUES clause alias (#2298)
Browse files Browse the repository at this point in the history
* Fix(redshift): generate correct SQL VALUES clause alias

* Fixup

* Refactor with functools.reduce

* Mypy fix
  • Loading branch information
georgesittas authored Sep 22, 2023
1 parent ef10fdf commit 8fe91e2
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
20 changes: 9 additions & 11 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import typing as t
from collections import defaultdict
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
Expand Down Expand Up @@ -1423,9 +1424,10 @@ def values_sql(self, expression: exp.Values) -> str:

# Converts `VALUES...` expression into a series of select unions.
expression = expression.copy()
column_names = expression.alias and expression.args["alias"].columns
alias_node = expression.args.get("alias")
column_names = alias_node and alias_node.columns

selects = []
selects: t.List[exp.Subqueryable] = []

for i, tup in enumerate(expression.expressions):
row = tup.expressions
Expand All @@ -1441,16 +1443,12 @@ def values_sql(self, expression: exp.Values) -> str:
# This may result in poor performance for large-cardinality `VALUES` tables, due to
# the deep nesting of the resulting exp.Unions. If this is a problem, either increase
# `sys.setrecursionlimit` to avoid RecursionErrors, or don't set `pretty`.
subquery_expression: exp.Select | exp.Union = selects[0]
if len(selects) > 1:
for select in selects[1:]:
subquery_expression = exp.union(
subquery_expression, select, distinct=False, copy=False
)

return self.subquery_sql(subquery_expression.subquery(expression.alias, copy=False))
subqueryable = reduce(lambda x, y: exp.union(x, y, distinct=False, copy=False), selects)
return self.subquery_sql(
subqueryable.subquery(alias_node and alias_node.this, copy=False)
)

alias = f" AS {expression.alias}" if expression.alias else ""
alias = f" AS {self.sql(alias_node, 'this')}" if alias_node else ""
unions = " UNION ALL ".join(self.sql(select) for select in selects)
return f"({unions}){alias}"

Expand Down
14 changes: 12 additions & 2 deletions tests/dialects/test_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,9 @@ def test_values(self):
},
)
self.validate_all(
"SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t (a, b)",
'SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS "t" (a, b)',
write={
"redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS t",
"redshift": 'SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS "t"',
},
)
self.validate_all(
Expand All @@ -341,6 +341,16 @@ def test_values(self):
"redshift": "INSERT INTO t (a, b) VALUES (1, 2), (3, 4)",
},
)
self.validate_identity(
'SELECT * FROM (VALUES (1)) AS "t"(a)',
'''SELECT
*
FROM (
SELECT
1 AS a
) AS "t"''',
pretty=True,
)

def test_create_table_like(self):
self.validate_all(
Expand Down

0 comments on commit 8fe91e2

Please sign in to comment.