Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid changing expression names during constant folding #1319

Merged
merged 17 commits into from
Nov 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,15 @@ pub fn unnormalize_cols(exprs: impl IntoIterator<Item = Expr>) -> Vec<Expr> {
exprs.into_iter().map(unnormalize_col).collect()
}

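/// Recursively un-alias an expression: strips every alias wrapping the outermost expression (aliases nested inside sub-expressions are left untouched)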
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The "recursively" part may be misleading: this function only unwraps the aliases that wrap the outermost expression.

So an expr like (a as "foo") + (b as "bar") will not be unaliased, but an expr like (a as "foo") as "bar" will be unaliased all the way down to a

#[inline]
pub fn unalias(expr: Expr) -> Expr {
    // Peel alias wrappers off the top of the expression one layer at a
    // time; stop at the first node that is not an `Expr::Alias`.
    let mut current = expr;
    while let Expr::Alias(inner, _) = current {
        current = *inner;
    }
    // Aliases nested deeper inside `current` (e.g. in operands of a
    // binary expression) are intentionally left as they are.
    current
}

/// Create an expression to represent the min() aggregate function
pub fn min(expr: Expr) -> Expr {
Expr::AggregateFunction {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub use expr::{
max, md5, min, normalize_col, normalize_cols, now, octet_length, or, random,
regexp_match, regexp_replace, repeat, replace, replace_col, reverse, right, round,
rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt,
starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc,
starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, unalias,
unnormalize_col, unnormalize_cols, upper, when, Column, Expr, ExprRewriter,
ExpressionVisitor, Literal, Recursion, RewriteRecursion,
};
Expand Down
46 changes: 31 additions & 15 deletions datafusion/src/optimizer/constant_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ impl OptimizerRule for ConstantFolding {
.expressions()
.into_iter()
.map(|e| {
// Remember the original expression name, if it has one, so that
// constant folding does not change the name of the expression.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

let name = &e.name(plan.schema());

// TODO iterate until no changes are made
// during rewrite (evaluating constants can
// enable new simplifications and
Expand All @@ -101,7 +105,18 @@ impl OptimizerRule for ConstantFolding {
// fold constants and then simplify
.rewrite(&mut const_evaluator)?
.rewrite(&mut simplifier)?;
Ok(new_e)

let new_name = &new_e.name(plan.schema());

if let (Ok(expr_name), Ok(new_expr_name)) = (name, new_name) {
if expr_name != new_expr_name {
Ok(new_e.alias(expr_name))
} else {
Ok(new_e)
}
} else {
Ok(new_e)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I worry we may be silently ignoring some real issues in the future.

However, I tried checking expr_name and new_expr_name for errors and I got a bunch of errors like

---- execution::context::tests::window_partition_by stdout ----
Error: Internal("Create name does not support sort expression")
thread 'execution::context::tests::window_partition_by' panicked at 'assertion failed: `(left == right)`
  left: `1`,
 right: `0`: the test returned a termination value with a non-zero status code (1) which indicates a failure', /rustc/59eed8a2aac0230a8b53e89d4e99d55912ba6b35/library/test/src/lib.rs:194:5
stack backtrace:
   0: rust_begin_unwind
             at /rustc/59eed8a2aac0230a8b53e89d4e99d55912ba6b35/library/std/src/panicking.rs:517:5
   1: core::panicking::panic_fmt
             at /rustc/59eed8a2aac0230a8b53e89d4e99d55912ba6b35/library/core/src/panicking.rs:101:14
   2: core::panicking::assert_failed_inner
             at /rustc/59eed8a2aac0230a8b53e89d4e99d55912ba6b35/library/core/src/panicking.rs:177:23
   3: core::panicking::assert_failed
             at /rustc/59eed8a2aac0230a8b53e89d4e99d55912ba6b35/library/core/src/panicking.rs:140:5
   4: test::assert_test_result
             at /rustc/59eed8a2aac0230a8b53e89d4e99d55912ba6b35/library/test/src/lib.rs:194:5
   5: datafusion::execution::context::tests::window_partition_by::{{closure}}
             at ./src/execution/context.rs:1771:11
   6: core::ops::function::FnOnce::call_once
             at /rustc/59eed8a2aac0230a8b53e89d4e99d55912ba6b35/library/core/src/ops/function.rs:227:5
   7: core::ops::function::FnOnce::call_once
             at /rustc/59eed8a2aac0230a8b53e89d4e99d55912ba6b35/library/core/src/ops/function.rs:227:5
note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.

So I suppose this is as good as we are going to do for now

}
})
.collect::<Result<Vec<_>>>()?;

Expand Down Expand Up @@ -626,8 +641,8 @@ mod tests {

let expected = "\
Projection: #test.a\
\n Filter: NOT #test.c\
\n Filter: #test.b\
\n Filter: NOT #test.c AS test.c = Boolean(false)\
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

\n Filter: #test.b AS test.b = Boolean(true)\
\n TableScan: test projection=None";

assert_optimized_plan_eq(&plan, expected);
Expand All @@ -647,8 +662,8 @@ mod tests {
let expected = "\
Projection: #test.a\
\n Limit: 1\
\n Filter: #test.c\
\n Filter: NOT #test.b\
\n Filter: #test.c AS test.c != Boolean(false)\
\n Filter: NOT #test.b AS test.b != Boolean(true)\
\n TableScan: test projection=None";

assert_optimized_plan_eq(&plan, expected);
Expand All @@ -665,7 +680,7 @@ mod tests {

let expected = "\
Projection: #test.a\
\n Filter: NOT #test.b AND #test.c\
\n Filter: NOT #test.b AND #test.c AS test.b != Boolean(true) AND test.c = Boolean(true)\
\n TableScan: test projection=None";

assert_optimized_plan_eq(&plan, expected);
Expand All @@ -682,7 +697,7 @@ mod tests {

let expected = "\
Projection: #test.a\
\n Filter: NOT #test.b OR NOT #test.c\
\n Filter: NOT #test.b OR NOT #test.c AS test.b != Boolean(true) OR test.c = Boolean(false)\
\n TableScan: test projection=None";

assert_optimized_plan_eq(&plan, expected);
Expand All @@ -699,7 +714,7 @@ mod tests {

let expected = "\
Projection: #test.a\
\n Filter: #test.b\
\n Filter: #test.b AS NOT test.b = Boolean(false)\
\n TableScan: test projection=None";

assert_optimized_plan_eq(&plan, expected);
Expand All @@ -714,7 +729,7 @@ mod tests {
.build()?;

let expected = "\
Projection: #test.a, #test.d, NOT #test.b\
Projection: #test.a, #test.d, NOT #test.b AS test.b = Boolean(false)\
\n TableScan: test projection=None";

assert_optimized_plan_eq(&plan, expected);
Expand All @@ -733,7 +748,7 @@ mod tests {
.build()?;

let expected = "\
Aggregate: groupBy=[[#test.a, #test.c]], aggr=[[MAX(#test.b), MIN(#test.b)]]\
Aggregate: groupBy=[[#test.a, #test.c]], aggr=[[MAX(#test.b) AS MAX(test.b = Boolean(true)), MIN(#test.b)]]\
\n Projection: #test.a, #test.c, #test.b\
\n TableScan: test projection=None";

Expand Down Expand Up @@ -789,7 +804,7 @@ mod tests {
.build()
.unwrap();

let expected = "Projection: TimestampNanosecond(1599566400000000000)\
let expected = "Projection: TimestampNanosecond(1599566400000000000) AS totimestamp(Utf8(\"2020-09-08T12:00:00+00:00\"))\
\n TableScan: test projection=None"
.to_string();
let actual = get_optimized_plan_formatted(&plan, &Utc::now());
Expand Down Expand Up @@ -824,7 +839,7 @@ mod tests {
.build()
.unwrap();

let expected = "Projection: Int32(0)\
let expected = "Projection: Int32(0) AS CAST(Utf8(\"0\") AS Int32)\
\n TableScan: test projection=None";
let actual = get_optimized_plan_formatted(&plan, &Utc::now());
assert_eq!(expected, actual);
Expand Down Expand Up @@ -873,7 +888,7 @@ mod tests {
// expect the same timestamp appears in both exprs
let actual = get_optimized_plan_formatted(&plan, &time);
let expected = format!(
"Projection: TimestampNanosecond({}), TimestampNanosecond({}) AS t2\
"Projection: TimestampNanosecond({}) AS now(), TimestampNanosecond({}) AS t2\
\n TableScan: test projection=None",
time.timestamp_nanos(),
time.timestamp_nanos()
Expand All @@ -897,7 +912,8 @@ mod tests {
.unwrap();

let actual = get_optimized_plan_formatted(&plan, &time);
let expected = "Projection: NOT #test.a\
let expected =
"Projection: NOT #test.a AS Boolean(true) OR Boolean(false) != test.a\
\n TableScan: test projection=None";

assert_eq!(actual, expected);
Expand Down Expand Up @@ -929,7 +945,7 @@ mod tests {

// Note that constant folder runs and folds the entire
// expression down to a single constant (true)
let expected = "Filter: Boolean(true)\
let expected = "Filter: Boolean(true) AS CAST(now() AS Int64) < CAST(totimestamp(Utf8(\"2020-09-08T12:05:00+00:00\")) AS Int64) + Int32(50000)\
\n TableScan: test projection=None";
let actual = get_optimized_plan_formatted(&plan, &time);

Expand Down
7 changes: 4 additions & 3 deletions datafusion/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use super::{
use crate::execution::context::ExecutionContextState;
use crate::logical_plan::plan::EmptyRelation;
use crate::logical_plan::{
unnormalize_cols, CrossJoin, DFSchema, Expr, LogicalPlan, Operator,
unalias, unnormalize_cols, CrossJoin, DFSchema, Expr, LogicalPlan, Operator,
Partitioning as LogicalPartitioning, PlanType, Repartition, ToStringifiedPlan, Union,
UserDefinedLogicalNode,
};
Expand Down Expand Up @@ -346,7 +346,8 @@ impl DefaultPhysicalPlanner {
// doesn't know (nor should care) how the relation was
// referred to in the query
let filters = unnormalize_cols(filters.iter().cloned());
source.scan(projection, batch_size, &filters, *limit).await
let unaliased: Vec<Expr> = filters.into_iter().map(unalias).collect();
source.scan(projection, batch_size, &unaliased, *limit).await
}
LogicalPlan::Values(Values {
values,
Expand Down Expand Up @@ -1347,7 +1348,7 @@ impl DefaultPhysicalPlanner {
physical_input_schema: &Schema,
ctx_state: &ExecutionContextState,
) -> Result<Arc<dyn AggregateExpr>> {
// unpack aliased logical expressions, e.g. "sum(col) as total"
// unpack (nested) aliased logical expressions, e.g. "sum(col) as total"
let (name, e) = match e {
Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()),
_ => (physical_name(e)?, e),
Expand Down
58 changes: 37 additions & 21 deletions datafusion/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1281,6 +1281,22 @@ async fn csv_query_approx_count() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn query_count_without_from() -> Result<()> {
    // A COUNT over a constant expression with no FROM clause must yield a
    // single row, and the column header must keep the original
    // (un-folded) expression text.
    let expected = vec![
        "+----------------------------+",
        "| COUNT(Int64(1) + Int64(1)) |",
        "+----------------------------+",
        "| 1 |",
        "+----------------------------+",
    ];

    let mut context = ExecutionContext::new();
    let batches = execute_to_batches(&mut context, "SELECT count(1 + 1)").await;

    assert_batches_eq!(expected, &batches);
    Ok(())
}

#[tokio::test]
async fn csv_query_array_agg() -> Result<()> {
let mut ctx = ExecutionContext::new();
Expand Down Expand Up @@ -1553,12 +1569,12 @@ async fn csv_query_cast_literal() -> Result<()> {
let actual = execute_to_batches(&mut ctx, sql).await;

let expected = vec![
"+--------------------+------------+",
"| c12 | Float64(1) |",
"+--------------------+------------+",
"| 0.9294097332465232 | 1 |",
"| 0.3114712539863804 | 1 |",
"+--------------------+------------+",
"+--------------------+---------------------------+",
"| c12 | CAST(Int64(1) AS Float64) |",
"+--------------------+---------------------------+",
"| 0.9294097332465232 | 1 |",
"| 0.3114712539863804 | 1 |",
"+--------------------+---------------------------+",
];

assert_batches_eq!(expected, &actual);
Expand Down Expand Up @@ -4264,11 +4280,11 @@ async fn query_without_from() -> Result<()> {
let sql = "SELECT 1+2, 3/4, cos(0)";
let actual = execute_to_batches(&mut ctx, sql).await;
let expected = vec![
"+----------+----------+------------+",
"| Int64(3) | Int64(0) | Float64(1) |",
"+----------+----------+------------+",
"| 3 | 0 | 1 |",
"+----------+----------+------------+",
"+---------------------+---------------------+---------------+",
"| Int64(1) + Int64(2) | Int64(3) / Int64(4) | cos(Int64(0)) |",
"+---------------------+---------------------+---------------+",
"| 3 | 0 | 1 |",
"+---------------------+---------------------+---------------+",
];
assert_batches_eq!(expected, &actual);

Expand Down Expand Up @@ -5717,11 +5733,11 @@ async fn case_with_bool_type_result() -> Result<()> {
let sql = "select case when 'cpu' != 'cpu' then true else false end";
let actual = execute_to_batches(&mut ctx, sql).await;
let expected = vec![
"+----------------+",
"| Boolean(false) |",
"+----------------+",
"| false |",
"+----------------+",
"+---------------------------------------------------------------------------------+",
"| CASE WHEN Utf8(\"cpu\") != Utf8(\"cpu\") THEN Boolean(true) ELSE Boolean(false) END |",
"+---------------------------------------------------------------------------------+",
"| false |",
"+---------------------------------------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
Expand All @@ -5734,11 +5750,11 @@ async fn use_between_expression_in_select_query() -> Result<()> {
let sql = "SELECT 1 NOT BETWEEN 3 AND 5";
let actual = execute_to_batches(&mut ctx, sql).await;
let expected = vec![
"+---------------+",
"| Boolean(true) |",
"+---------------+",
"| true |",
"+---------------+",
"+--------------------------------------------+",
"| Int64(1) NOT BETWEEN Int64(3) AND Int64(5) |",
"+--------------------------------------------+",
"| true |",
"+--------------------------------------------+",
];
assert_batches_eq!(expected, &actual);

Expand Down