diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index ad162a9222e6..a76e26aa7d98 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -532,7 +532,7 @@ impl Unparser<'_> { let input_exprs: Vec = union .inputs .iter() - .map(|input| self.select_to_sql_expr(input, &mut None)) + .map(|input| self.select_to_sql_expr(input, query)) .collect::>>()?; let union_expr = SetExpr::SetOperation { @@ -542,10 +542,12 @@ impl Unparser<'_> { right: Box::new(input_exprs[1].clone()), }; - query - .as_mut() - .expect("to have a query builder") - .body(Box::new(union_expr)); + let Some(query) = query.as_mut() else { + return internal_err!( + "UNION ALL operator only valid in a statement context" + ); + }; + query.body(Box::new(union_expr)); Ok(()) } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 76f19a23a0ae..49f4720ed137 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -146,7 +146,10 @@ fn roundtrip_statement() -> Result<()> { sum(id) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) from person"#, "SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person", "WITH t1 AS (SELECT j1_id AS id, j1_string name FROM j1), t2 AS (SELECT j2_id AS id, j2_string name FROM j2) SELECT * FROM t1 JOIN t2 USING (id, name)", - ]; + "WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3", + "WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col), w4 as (SELECT 'd' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3 UNION ALL SELECT * FROM w4", + "WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col) SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col", + ]; // For each test sql string, we transform as follows: // sql -> ast::Statement (s1) -> LogicalPlan (p1) -> ast::Statement (s2) -> LogicalPlan (p2)