Skip to content

Commit

Permalink
Feat(presto): transpile explode/posexplode into (cross join) unnest (#…
Browse files Browse the repository at this point in the history
…1501)

* Feat(presto): transpile explode/posexplode into (cross join) unnest

* Use more descriptive auto-generated names

* Make exp.alias_ pure, change base names to pos, col
  • Loading branch information
georgesittas authored Apr 30, 2023
1 parent 2dcbc7f commit 80287dd
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 2 deletions.
4 changes: 3 additions & 1 deletion sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,6 @@ class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
**transforms.ELIMINATE_QUALIFY, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
Expand Down Expand Up @@ -303,6 +302,9 @@ class Generator(generator.Generator):
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
[transforms.eliminate_qualify, transforms.explode_to_unnest]
),
exp.SortArray: _no_sort_array,
exp.StrPosition: rename_func("STRPOS"),
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
Expand Down
2 changes: 2 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4741,6 +4741,8 @@ def alias_(

if table:
table_alias = TableAlias(this=alias)

exp = exp.copy()
exp.set("alias", table_alias)

if not isinstance(table, bool):
Expand Down
63 changes: 62 additions & 1 deletion sqlglot/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr


def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
"""Convert cross join unnest into lateral view explode. Used in presto -> hive"""
"""Convert cross join unnest into lateral view explode (used in presto -> hive)."""
if isinstance(expression, exp.Select):
for join in expression.args.get("joins") or []:
unnest = join.this
Expand All @@ -161,6 +161,67 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
return expression


def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
"""Convert explode/posexplode into unnest (used in hive -> presto)."""
if isinstance(expression, exp.Select):
from sqlglot.optimizer.scope import build_scope

taken_select_names = set(expression.named_selects)
taken_source_names = set(build_scope(expression).selected_sources)

for select in expression.selects:
to_replace = select

pos_alias = ""
explode_alias = ""

if isinstance(select, exp.Alias):
explode_alias = select.alias
select = select.this
elif isinstance(select, exp.Aliases):
pos_alias = select.aliases[0].name
explode_alias = select.aliases[1].name
select = select.this

if isinstance(select, (exp.Explode, exp.Posexplode)):
is_posexplode = isinstance(select, exp.Posexplode)

explode_arg = select.this
unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)

# This ensures that we won't use [POS]EXPLODE's argument as a new selection
if isinstance(explode_arg, exp.Column):
taken_select_names.add(explode_arg.output_name)

unnest_source_alias = find_new_name(taken_source_names, "_u")
taken_source_names.add(unnest_source_alias)

if not explode_alias:
explode_alias = find_new_name(taken_select_names, "col")
taken_select_names.add(explode_alias)

if is_posexplode:
pos_alias = find_new_name(taken_select_names, "pos")
taken_select_names.add(pos_alias)

if is_posexplode:
column_names = [explode_alias, pos_alias]
to_replace.pop()
expression.select(pos_alias, explode_alias, copy=False)
else:
column_names = [explode_alias]
to_replace.replace(exp.column(explode_alias))

unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)

if not expression.args.get("from"):
expression.from_(unnest, copy=False)
else:
expression.join(unnest, join_type="CROSS", copy=False)

return expression


def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
"""Remove table refs from columns in when statements."""
if isinstance(expression, exp.Merge):
Expand Down
38 changes: 38 additions & 0 deletions tests/dialects/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,3 +676,41 @@ def test_json(self):
"presto": "SELECT CAST(ARRAY[1, 23, 456] AS JSON)",
},
)

def test_explode_to_unnest(self):
self.validate_all(
"SELECT col FROM tbl CROSS JOIN UNNEST(x) AS _u(col)",
read={"spark": "SELECT EXPLODE(x) FROM tbl"},
)
self.validate_all(
"SELECT col_2 FROM _u CROSS JOIN UNNEST(col) AS _u_2(col_2)",
read={"spark": "SELECT EXPLODE(col) FROM _u"},
)
self.validate_all(
"SELECT exploded FROM schema.tbl CROSS JOIN UNNEST(col) AS _u(exploded)",
read={"spark": "SELECT EXPLODE(col) AS exploded FROM schema.tbl"},
)
self.validate_all(
"SELECT col FROM UNNEST(SEQUENCE(1, 2)) AS _u(col)",
read={"spark": "SELECT EXPLODE(SEQUENCE(1, 2))"},
)
self.validate_all(
"SELECT col FROM tbl AS t CROSS JOIN UNNEST(t.c) AS _u(col)",
read={"spark": "SELECT EXPLODE(t.c) FROM tbl t"},
)
self.validate_all(
"SELECT pos, col FROM UNNEST(SEQUENCE(2, 3)) WITH ORDINALITY AS _u(col, pos)",
read={"spark": "SELECT POSEXPLODE(SEQUENCE(2, 3))"},
)
self.validate_all(
"SELECT pos, col FROM tbl CROSS JOIN UNNEST(SEQUENCE(2, 3)) WITH ORDINALITY AS _u(col, pos)",
read={"spark": "SELECT POSEXPLODE(SEQUENCE(2, 3)) FROM tbl"},
)
self.validate_all(
"SELECT pos, col FROM tbl AS t CROSS JOIN UNNEST(t.c) WITH ORDINALITY AS _u(col, pos)",
read={"spark": "SELECT POSEXPLODE(t.c) FROM tbl t"},
)
self.validate_all(
"SELECT col, pos, pos_2, col_2 FROM _u CROSS JOIN UNNEST(SEQUENCE(2, 3)) WITH ORDINALITY AS _u_2(col_2, pos_2)",
read={"spark": "SELECT col, pos, POSEXPLODE(SEQUENCE(2, 3)) FROM _u"},
)

0 comments on commit 80287dd

Please sign in to comment.