Skip to content

Commit

Permalink
Fix: improve performance of VALUES -> UNION ALL transpilation (#2283)
Browse files Browse the repository at this point in the history
* Fix: improve performance of VALUES -> UNION ALL transpilation

* Create Union expressions when pretty=True
  • Loading branch information
georgesittas authored Sep 21, 2023
1 parent ef062d1 commit 310d691
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 10 deletions.
25 changes: 15 additions & 10 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,9 +1414,6 @@ def values_sql(self, expression: exp.Values) -> str:
return f"{values} AS {alias}" if alias else values

# Converts `VALUES...` expression into a series of select unions.
# Note: If you have a lot of unions then this will result in a large number of recursive statements to
# evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be
# very slow.
expression = expression.copy()
column_names = expression.alias and expression.args["alias"].columns

Expand All @@ -1432,14 +1429,22 @@ def values_sql(self, expression: exp.Values) -> str:

selects.append(exp.Select(expressions=row))

subquery_expression: exp.Select | exp.Union = selects[0]
if len(selects) > 1:
for select in selects[1:]:
subquery_expression = exp.union(
subquery_expression, select, distinct=False, copy=False
)
if self.pretty:
# This may result in poor performance for large-cardinality `VALUES` tables, due to
# the deep nesting of the resulting exp.Unions. If this is a problem, either increase
# `sys.setrecursionlimit` to avoid RecursionErrors, or don't set `pretty`.
subquery_expression: exp.Select | exp.Union = selects[0]
if len(selects) > 1:
for select in selects[1:]:
subquery_expression = exp.union(
subquery_expression, select, distinct=False, copy=False
)

return self.subquery_sql(subquery_expression.subquery(expression.alias, copy=False))

return self.subquery_sql(subquery_expression.subquery(expression.alias, copy=False))
alias = f" AS {expression.alias}" if expression.alias else ""
unions = " UNION ALL ".join(self.sql(select) for select in selects)
return f"({unions}){alias}"

def var_sql(self, expression: exp.Var) -> str:
return self.sql(expression, "this")
Expand Down
21 changes: 21 additions & 0 deletions tests/dialects/test_redshift.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from sqlglot import transpile
from tests.dialects.test_dialect import Validator


Expand Down Expand Up @@ -270,6 +271,26 @@ def test_identity(self):
)

def test_values(self):
# Test crazy-sized VALUES clause to UNION ALL conversion to ensure we don't get RecursionError
values = [str(v) for v in range(0, 10000)]
values_query = f"SELECT * FROM (VALUES {', '.join('(' + v + ')' for v in values)})"
union_query = f"SELECT * FROM ({' UNION ALL '.join('SELECT ' + v for v in values)})"
self.assertEqual(transpile(values_query, write="redshift")[0], union_query)

self.validate_identity(
"SELECT * FROM (VALUES (1), (2))",
"""SELECT
*
FROM (
SELECT
1
UNION ALL
SELECT
2
)""",
pretty=True,
)

self.validate_all(
"SELECT * FROM (VALUES (1, 2)) AS t",
write={
Expand Down

0 comments on commit 310d691

Please sign in to comment.