Skip to content

Commit

Permalink
Avoid changing expression names during constant folding (#1319)
Browse files Browse the repository at this point in the history
* Avoid changing output column name before pushdown projection.

* Avoid early optimization which is duplicate.

* Modify test.

* Revert "Modify test."

This reverts commit ebf67d3.

* Revert "Avoid early optimization which is duplicate."

This reverts commit fe8445d.

* Add aliases during constant folding.

* Some expressions don't support name.

* Don't create redundant alias.

* Only add alias for certain plans.

* Fix clippy.

* Fix.

* Revert "Fix."

This reverts commit d767aeb.

* Apply to all nodes and update tests.

* Unalias when pushing down to TableScan.

* Update more tests.

* Remove previous change.
  • Loading branch information
viirya authored Nov 22, 2021
1 parent 7905926 commit 0df9b99
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 40 deletions.
9 changes: 9 additions & 0 deletions datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,15 @@ pub fn unnormalize_cols(exprs: impl IntoIterator<Item = Expr>) -> Vec<Expr> {
exprs.into_iter().map(unnormalize_col).collect()
}

/// Strips every enclosing `Expr::Alias` wrapper from an expression.
///
/// Nested aliases are peeled one layer at a time until the outermost node
/// is no longer an `Expr::Alias`; that node is returned as-is. Only the
/// outer wrappers are removed: aliases buried deeper inside the returned
/// expression are not touched.
#[inline]
pub fn unalias(expr: Expr) -> Expr {
    let mut current = expr;
    // Each pass discards one alias layer and continues with its inner expression.
    while let Expr::Alias(inner, _) = current {
        current = *inner;
    }
    current
}

/// Create an expression to represent the min() aggregate function
pub fn min(expr: Expr) -> Expr {
Expr::AggregateFunction {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub use expr::{
max, md5, min, normalize_col, normalize_cols, now, octet_length, or, random,
regexp_match, regexp_replace, repeat, replace, replace_col, reverse, right, round,
rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt,
starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc,
starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, unalias,
unnormalize_col, unnormalize_cols, upper, when, Column, Expr, ExprRewriter,
ExpressionVisitor, Literal, Recursion, RewriteRecursion,
};
Expand Down
46 changes: 31 additions & 15 deletions datafusion/src/optimizer/constant_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ impl OptimizerRule for ConstantFolding {
.expressions()
.into_iter()
.map(|e| {
// We need to keep original expression name, if any.
// Constant folding should not change expression name.
let name = &e.name(plan.schema());

// TODO iterate until no changes are made
// during rewrite (evaluating constants can
// enable new simplifications and
Expand All @@ -101,7 +105,18 @@ impl OptimizerRule for ConstantFolding {
// fold constants and then simplify
.rewrite(&mut const_evaluator)?
.rewrite(&mut simplifier)?;
Ok(new_e)

let new_name = &new_e.name(plan.schema());

if let (Ok(expr_name), Ok(new_expr_name)) = (name, new_name) {
if expr_name != new_expr_name {
Ok(new_e.alias(expr_name))
} else {
Ok(new_e)
}
} else {
Ok(new_e)
}
})
.collect::<Result<Vec<_>>>()?;

Expand Down Expand Up @@ -626,8 +641,8 @@ mod tests {

let expected = "\
Projection: #test.a\
\n Filter: NOT #test.c\
\n Filter: #test.b\
\n Filter: NOT #test.c AS test.c = Boolean(false)\
\n Filter: #test.b AS test.b = Boolean(true)\
\n TableScan: test projection=None";

assert_optimized_plan_eq(&plan, expected);
Expand All @@ -647,8 +662,8 @@ mod tests {
let expected = "\
Projection: #test.a\
\n Limit: 1\
\n Filter: #test.c\
\n Filter: NOT #test.b\
\n Filter: #test.c AS test.c != Boolean(false)\
\n Filter: NOT #test.b AS test.b != Boolean(true)\
\n TableScan: test projection=None";

assert_optimized_plan_eq(&plan, expected);
Expand All @@ -665,7 +680,7 @@ mod tests {

let expected = "\
Projection: #test.a\
\n Filter: NOT #test.b AND #test.c\
\n Filter: NOT #test.b AND #test.c AS test.b != Boolean(true) AND test.c = Boolean(true)\
\n TableScan: test projection=None";

assert_optimized_plan_eq(&plan, expected);
Expand All @@ -682,7 +697,7 @@ mod tests {

let expected = "\
Projection: #test.a\
\n Filter: NOT #test.b OR NOT #test.c\
\n Filter: NOT #test.b OR NOT #test.c AS test.b != Boolean(true) OR test.c = Boolean(false)\
\n TableScan: test projection=None";

assert_optimized_plan_eq(&plan, expected);
Expand All @@ -699,7 +714,7 @@ mod tests {

let expected = "\
Projection: #test.a\
\n Filter: #test.b\
\n Filter: #test.b AS NOT test.b = Boolean(false)\
\n TableScan: test projection=None";

assert_optimized_plan_eq(&plan, expected);
Expand All @@ -714,7 +729,7 @@ mod tests {
.build()?;

let expected = "\
Projection: #test.a, #test.d, NOT #test.b\
Projection: #test.a, #test.d, NOT #test.b AS test.b = Boolean(false)\
\n TableScan: test projection=None";

assert_optimized_plan_eq(&plan, expected);
Expand All @@ -733,7 +748,7 @@ mod tests {
.build()?;

let expected = "\
Aggregate: groupBy=[[#test.a, #test.c]], aggr=[[MAX(#test.b), MIN(#test.b)]]\
Aggregate: groupBy=[[#test.a, #test.c]], aggr=[[MAX(#test.b) AS MAX(test.b = Boolean(true)), MIN(#test.b)]]\
\n Projection: #test.a, #test.c, #test.b\
\n TableScan: test projection=None";

Expand Down Expand Up @@ -789,7 +804,7 @@ mod tests {
.build()
.unwrap();

let expected = "Projection: TimestampNanosecond(1599566400000000000)\
let expected = "Projection: TimestampNanosecond(1599566400000000000) AS totimestamp(Utf8(\"2020-09-08T12:00:00+00:00\"))\
\n TableScan: test projection=None"
.to_string();
let actual = get_optimized_plan_formatted(&plan, &Utc::now());
Expand Down Expand Up @@ -824,7 +839,7 @@ mod tests {
.build()
.unwrap();

let expected = "Projection: Int32(0)\
let expected = "Projection: Int32(0) AS CAST(Utf8(\"0\") AS Int32)\
\n TableScan: test projection=None";
let actual = get_optimized_plan_formatted(&plan, &Utc::now());
assert_eq!(expected, actual);
Expand Down Expand Up @@ -873,7 +888,7 @@ mod tests {
// expect the same timestamp appears in both exprs
let actual = get_optimized_plan_formatted(&plan, &time);
let expected = format!(
"Projection: TimestampNanosecond({}), TimestampNanosecond({}) AS t2\
"Projection: TimestampNanosecond({}) AS now(), TimestampNanosecond({}) AS t2\
\n TableScan: test projection=None",
time.timestamp_nanos(),
time.timestamp_nanos()
Expand All @@ -897,7 +912,8 @@ mod tests {
.unwrap();

let actual = get_optimized_plan_formatted(&plan, &time);
let expected = "Projection: NOT #test.a\
let expected =
"Projection: NOT #test.a AS Boolean(true) OR Boolean(false) != test.a\
\n TableScan: test projection=None";

assert_eq!(actual, expected);
Expand Down Expand Up @@ -929,7 +945,7 @@ mod tests {

// Note that constant folder runs and folds the entire
// expression down to a single constant (true)
let expected = "Filter: Boolean(true)\
let expected = "Filter: Boolean(true) AS CAST(now() AS Int64) < CAST(totimestamp(Utf8(\"2020-09-08T12:05:00+00:00\")) AS Int64) + Int32(50000)\
\n TableScan: test projection=None";
let actual = get_optimized_plan_formatted(&plan, &time);

Expand Down
7 changes: 4 additions & 3 deletions datafusion/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use super::{
use crate::execution::context::ExecutionContextState;
use crate::logical_plan::plan::{EmptyRelation, Filter, Projection, Window};
use crate::logical_plan::{
unnormalize_cols, CrossJoin, DFSchema, Expr, LogicalPlan, Operator,
unalias, unnormalize_cols, CrossJoin, DFSchema, Expr, LogicalPlan, Operator,
Partitioning as LogicalPartitioning, PlanType, Repartition, ToStringifiedPlan, Union,
UserDefinedLogicalNode,
};
Expand Down Expand Up @@ -339,7 +339,8 @@ impl DefaultPhysicalPlanner {
// doesn't know (nor should care) how the relation was
// referred to in the query
let filters = unnormalize_cols(filters.iter().cloned());
source.scan(projection, batch_size, &filters, *limit).await
let unaliased: Vec<Expr> = filters.into_iter().map(unalias).collect();
source.scan(projection, batch_size, &unaliased, *limit).await
}
LogicalPlan::Values(Values {
values,
Expand Down Expand Up @@ -1340,7 +1341,7 @@ impl DefaultPhysicalPlanner {
physical_input_schema: &Schema,
ctx_state: &ExecutionContextState,
) -> Result<Arc<dyn AggregateExpr>> {
// unpack aliased logical expressions, e.g. "sum(col) as total"
// unpack (nested) aliased logical expressions, e.g. "sum(col) as total"
let (name, e) = match e {
Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()),
_ => (physical_name(e)?, e),
Expand Down
58 changes: 37 additions & 21 deletions datafusion/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1391,6 +1391,22 @@ async fn csv_query_approx_count() -> Result<()> {
Ok(())
}

/// Runs `SELECT count(1 + 1)` with no `FROM` clause and checks the rendered
/// result table.
///
/// The expected column header is the original expression name,
/// `COUNT(Int64(1) + Int64(1))`, not a name derived from a folded constant.
#[tokio::test]
async fn query_count_without_from() -> Result<()> {
    let mut ctx = ExecutionContext::new();
    let sql = "SELECT count(1 + 1)";
    let actual = execute_to_batches(&mut ctx, sql).await;
    // Every row of the expected table must be exactly as wide as the border
    // (30 characters); the data row is padded out to the column width.
    let expected = vec![
        "+----------------------------+",
        "| COUNT(Int64(1) + Int64(1)) |",
        "+----------------------------+",
        "| 1                          |",
        "+----------------------------+",
    ];
    assert_batches_eq!(expected, &actual);
    Ok(())
}

#[tokio::test]
async fn csv_query_array_agg() -> Result<()> {
let mut ctx = ExecutionContext::new();
Expand Down Expand Up @@ -1663,12 +1679,12 @@ async fn csv_query_cast_literal() -> Result<()> {
let actual = execute_to_batches(&mut ctx, sql).await;

let expected = vec![
"+--------------------+------------+",
"| c12 | Float64(1) |",
"+--------------------+------------+",
"| 0.9294097332465232 | 1 |",
"| 0.3114712539863804 | 1 |",
"+--------------------+------------+",
"+--------------------+---------------------------+",
"| c12 | CAST(Int64(1) AS Float64) |",
"+--------------------+---------------------------+",
"| 0.9294097332465232 | 1 |",
"| 0.3114712539863804 | 1 |",
"+--------------------+---------------------------+",
];

assert_batches_eq!(expected, &actual);
Expand Down Expand Up @@ -4410,11 +4426,11 @@ async fn query_without_from() -> Result<()> {
let sql = "SELECT 1+2, 3/4, cos(0)";
let actual = execute_to_batches(&mut ctx, sql).await;
let expected = vec![
"+----------+----------+------------+",
"| Int64(3) | Int64(0) | Float64(1) |",
"+----------+----------+------------+",
"| 3 | 0 | 1 |",
"+----------+----------+------------+",
"+---------------------+---------------------+---------------+",
"| Int64(1) + Int64(2) | Int64(3) / Int64(4) | cos(Int64(0)) |",
"+---------------------+---------------------+---------------+",
"| 3 | 0 | 1 |",
"+---------------------+---------------------+---------------+",
];
assert_batches_eq!(expected, &actual);

Expand Down Expand Up @@ -5865,11 +5881,11 @@ async fn case_with_bool_type_result() -> Result<()> {
let sql = "select case when 'cpu' != 'cpu' then true else false end";
let actual = execute_to_batches(&mut ctx, sql).await;
let expected = vec![
"+----------------+",
"| Boolean(false) |",
"+----------------+",
"| false |",
"+----------------+",
"+---------------------------------------------------------------------------------+",
"| CASE WHEN Utf8(\"cpu\") != Utf8(\"cpu\") THEN Boolean(true) ELSE Boolean(false) END |",
"+---------------------------------------------------------------------------------+",
"| false |",
"+---------------------------------------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
Expand All @@ -5882,11 +5898,11 @@ async fn use_between_expression_in_select_query() -> Result<()> {
let sql = "SELECT 1 NOT BETWEEN 3 AND 5";
let actual = execute_to_batches(&mut ctx, sql).await;
let expected = vec![
"+---------------+",
"| Boolean(true) |",
"+---------------+",
"| true |",
"+---------------+",
"+--------------------------------------------+",
"| Int64(1) NOT BETWEEN Int64(3) AND Int64(5) |",
"+--------------------------------------------+",
"| true |",
"+--------------------------------------------+",
];
assert_batches_eq!(expected, &actual);

Expand Down

0 comments on commit 0df9b99

Please sign in to comment.