Skip to content

Commit

Permalink
Feat: implement transform to add column names to recursive CTEs (#1687)
Browse files Browse the repository at this point in the history
* Feat: implement transform to add column names to recursive CTEs

* Add names to unnamed projections

* Use alias_or_name, and fix AST with to_identifier
  • Loading branch information
georgesittas authored May 24, 2023
1 parent a392114 commit 1cb9614
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 0 deletions.
1 change: 1 addition & 0 deletions sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ class Generator(generator.Generator):
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
exp.VariancePop: rename_func("VAR_POP"),
exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]),
exp.WithinGroup: transforms.preprocess(
[transforms.remove_within_group_for_percentiles]
),
Expand Down
20 changes: 20 additions & 0 deletions sqlglot/transforms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import itertools
import typing as t

from sqlglot import expressions as exp
Expand Down Expand Up @@ -247,6 +248,25 @@ def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expre
return expression


def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.With) and expression.recursive:
sequence = itertools.count()
next_name = lambda: f"_c_{next(sequence)}"

for cte in expression.expressions:
if not cte.args["alias"].columns:
query = cte.this
if isinstance(query, exp.Union):
query = query.this

cte.args["alias"].set(
"columns",
[exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
)

return expression


def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
Expand Down
30 changes: 30 additions & 0 deletions tests/dialects/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,36 @@ def test_presto(self):
self.validate_all("INTERVAL '1 day'", write={"trino": "INTERVAL '1' day"})
self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' week"})
self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' WEEKS"})
self.validate_all(
"WITH RECURSIVE t(n) AS (SELECT 1 AS n UNION ALL SELECT n + 1 AS n FROM t WHERE n < 4) SELECT SUM(n) FROM t",
read={
"postgres": "WITH RECURSIVE t AS (SELECT 1 AS n UNION ALL SELECT n + 1 AS n FROM t WHERE n < 4) SELECT SUM(n) FROM t",
},
)
self.validate_all(
"WITH RECURSIVE t(n, k) AS (SELECT 1 AS n, 2 AS k) SELECT SUM(n) FROM t",
read={
"postgres": "WITH RECURSIVE t AS (SELECT 1 AS n, 2 as k) SELECT SUM(n) FROM t",
},
)
self.validate_all(
"WITH RECURSIVE t1(n) AS (SELECT 1 AS n), t2(n) AS (SELECT 2 AS n) SELECT SUM(t1.n), SUM(t2.n) FROM t1, t2",
read={
"postgres": "WITH RECURSIVE t1 AS (SELECT 1 AS n), t2 AS (SELECT 2 AS n) SELECT SUM(t1.n), SUM(t2.n) FROM t1, t2",
},
)
self.validate_all(
"WITH RECURSIVE t(n, _c_0) AS (SELECT 1 AS n, (1 + 2)) SELECT * FROM t",
read={
"postgres": "WITH RECURSIVE t AS (SELECT 1 AS n, (1 + 2)) SELECT * FROM t",
},
)
self.validate_all(
'WITH RECURSIVE t(n, "1") AS (SELECT n, 1 FROM tbl) SELECT * FROM t',
read={
"postgres": "WITH RECURSIVE t AS (SELECT n, 1 FROM tbl) SELECT * FROM t",
},
)
self.validate_all(
"SELECT JSON_OBJECT(KEY 'key1' VALUE 1, KEY 'key2' VALUE TRUE)",
write={
Expand Down

0 comments on commit 1cb9614

Please sign in to comment.