Skip to content

Commit

Permalink
Fix: unnest complex (closes #2284)
Browse files (browse the repository at this point in the history)
  • Loading branch information
tobymao committed Sep 22, 2023
1 parent fc793c4 commit 06e0869
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 26 deletions.
50 changes: 27 additions & 23 deletions sqlglot/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,22 +188,27 @@ def new_name(names: t.Set[str], name: str) -> str:
)

# we use list here because expression.selects is mutated inside the loop
for select in list(expression.selects):
to_replace = select
pos_alias = ""
explode_alias = ""

if isinstance(select, exp.Alias):
explode_alias = select.alias
select = select.this
elif isinstance(select, exp.Aliases):
pos_alias = select.aliases[0].name
explode_alias = select.aliases[1].name
select = select.this

if isinstance(select, (exp.Explode, exp.Posexplode)):
is_posexplode = isinstance(select, exp.Posexplode)
explode_arg = select.this
for select in expression.selects.copy():
explode = select.find(exp.Explode, exp.Posexplode)

if isinstance(explode, (exp.Explode, exp.Posexplode)):
pos_alias = ""
explode_alias = ""

if isinstance(select, exp.Alias):
explode_alias = select.alias
alias = select
elif isinstance(select, exp.Aliases):
pos_alias = select.aliases[0].name
explode_alias = select.aliases[1].name
alias = select.replace(exp.alias_(select.this, "", copy=False))
else:
alias = select.replace(exp.alias_(select, ""))
explode = alias.find(exp.Explode, exp.Posexplode)
assert explode

is_posexplode = isinstance(explode, exp.Posexplode)
explode_arg = explode.this

# This ensures that we won't use [POS]EXPLODE's argument as a new selection
if isinstance(explode_arg, exp.Column):
Expand All @@ -220,26 +225,25 @@ def new_name(names: t.Set[str], name: str) -> str:
if not pos_alias:
pos_alias = new_name(taken_select_names, "pos")

alias.set("alias", exp.to_identifier(explode_alias))

column = exp.If(
this=exp.column(series_alias).eq(exp.column(pos_alias)),
true=exp.column(explode_alias),
).as_(explode_alias)
)

explode.replace(column)

if is_posexplode:
expressions = expression.expressions
index = expressions.index(to_replace)
expressions.pop(index)
expressions.insert(index, column)
expressions.insert(
index + 1,
expressions.index(alias) + 1,
exp.If(
this=exp.column(series_alias).eq(exp.column(pos_alias)),
true=exp.column(pos_alias),
).as_(pos_alias),
)
expression.set("expressions", expressions)
else:
to_replace.replace(column)

if not arrays:
if expression.args.get("from"):
Expand Down
12 changes: 12 additions & 0 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,18 @@ def test_duckdb(self):
"presto": "SELECT IF(pos = pos_2, col) AS col, IF(pos = pos_3, col_2) AS col_2, IF(pos = pos_4, col_3) AS col_3 FROM x, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2, 3]), CARDINALITY(ARRAY[4, 5]), CARDINALITY(ARRAY[6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5]) WITH ORDINALITY AS _u_3(col_2, pos_3) CROSS JOIN UNNEST(ARRAY[6]) WITH ORDINALITY AS _u_4(col_3, pos_4) WHERE ((pos = pos_2 OR (pos > CARDINALITY(ARRAY[1, 2, 3]) AND pos_2 = CARDINALITY(ARRAY[1, 2, 3]))) AND (pos = pos_3 OR (pos > CARDINALITY(ARRAY[4, 5]) AND pos_3 = CARDINALITY(ARRAY[4, 5])))) AND (pos = pos_4 OR (pos > CARDINALITY(ARRAY[6]) AND pos_4 = CARDINALITY(ARRAY[6])))",
},
)
self.validate_all(
"SELECT UNNEST(x) + 1",
write={
"bigquery": "SELECT IF(pos = pos_2, col, NULL) + 1 AS col FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(x)) - 1)) AS pos CROSS JOIN UNNEST(x) AS col WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(x) - 1) AND pos_2 = (ARRAY_LENGTH(x) - 1))",
},
)
self.validate_all(
"SELECT UNNEST(x) + 1 AS y",
write={
"bigquery": "SELECT IF(pos = pos_2, y, NULL) + 1 AS y FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(x)) - 1)) AS pos CROSS JOIN UNNEST(x) AS y WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(x) - 1) AND pos_2 = (ARRAY_LENGTH(x) - 1))",
},
)

self.validate_identity("SELECT i FROM RANGE(5) AS _(i) ORDER BY i ASC")
self.validate_identity("[x.STRING_SPLIT(' ')[1] FOR x IN ['1', '2', 3] IF x.CONTAINS('1')]")
Expand Down
12 changes: 9 additions & 3 deletions tests/dialects/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,10 +616,16 @@ def test_explode_to_unnest(self):
},
)
self.validate_all(
"SELECT POSEXPLODE(ARRAY(2, 3))",
"SELECT POSEXPLODE(ARRAY(2, 3)) AS x",
write={
"bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col, IF(pos = pos_2, pos_2, NULL) AS pos_2 FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([2, 3])) - 1)) AS pos CROSS JOIN UNNEST([2, 3]) AS col WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH([2, 3]) - 1) AND pos_2 = (ARRAY_LENGTH([2, 3]) - 1))",
"presto": "SELECT IF(pos = pos_2, col) AS col, IF(pos = pos_2, pos_2) AS pos_2 FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(ARRAY[2, 3]) AND pos_2 = CARDINALITY(ARRAY[2, 3]))",
"bigquery": "SELECT IF(pos = pos_2, x, NULL) AS x, IF(pos = pos_2, pos_2, NULL) AS pos_2 FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([2, 3])) - 1)) AS pos CROSS JOIN UNNEST([2, 3]) AS x WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH([2, 3]) - 1) AND pos_2 = (ARRAY_LENGTH([2, 3]) - 1))",
"presto": "SELECT IF(pos = pos_2, x) AS x, IF(pos = pos_2, pos_2) AS pos_2 FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_2(x, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(ARRAY[2, 3]) AND pos_2 = CARDINALITY(ARRAY[2, 3]))",
},
)
self.validate_all(
"SELECT POSEXPLODE(x) AS (a, b)",
write={
"presto": "SELECT IF(pos = a, b) AS b, IF(pos = a, a) AS a FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(x)))) AS _u(pos) CROSS JOIN UNNEST(x) WITH ORDINALITY AS _u_2(b, a) WHERE pos = a OR (pos > CARDINALITY(x) AND a = CARDINALITY(x))",
},
)
self.validate_all(
Expand Down

0 comments on commit 06e0869

Please sign in to comment.