From c36330ad8ee4de0622b94a59bfbacba454117530 Mon Sep 17 00:00:00 2001 From: zhang2014 Date: Sun, 18 Jul 2021 12:48:56 +0800 Subject: [PATCH 1/5] try refactor constant folding --- common/datavalues/src/data_schema.rs | 4 + common/functions/src/scalars/function.rs | 4 + common/planners/src/plan_rewriter.rs | 2 +- fusequery/query/src/optimizers/optimizer.rs | 3 +- .../optimizers/optimizer_constant_folding.rs | 175 +++++++++--------- 5 files changed, 96 insertions(+), 92 deletions(-) diff --git a/common/datavalues/src/data_schema.rs b/common/datavalues/src/data_schema.rs index 0bbb4fd6c2f93..c122968bb2f63 100644 --- a/common/datavalues/src/data_schema.rs +++ b/common/datavalues/src/data_schema.rs @@ -96,6 +96,10 @@ pub type DataSchemaRef = Arc; pub struct DataSchemaRefExt; impl DataSchemaRefExt { + pub fn empty() -> DataSchemaRef { + Arc::new(DataSchema::empty()) + } + pub fn create(fields: Vec) -> DataSchemaRef { Arc::new(DataSchema::new(fields)) } diff --git a/common/functions/src/scalars/function.rs b/common/functions/src/scalars/function.rs index 29823f85b104f..5cbc773434344 100644 --- a/common/functions/src/scalars/function.rs +++ b/common/functions/src/scalars/function.rs @@ -26,4 +26,8 @@ pub trait Function: fmt::Display + Sync + Send + DynClone { fn return_type(&self, args: &[DataType]) -> Result; fn nullable(&self, _input_schema: &DataSchema) -> Result; fn eval(&self, columns: &[DataColumn], _input_rows: usize) -> Result; + + fn is_deterministic(&self) -> bool { + true + } } diff --git a/common/planners/src/plan_rewriter.rs b/common/planners/src/plan_rewriter.rs index b4736d90c0ce8..fa6c3beffa92b 100644 --- a/common/planners/src/plan_rewriter.rs +++ b/common/planners/src/plan_rewriter.rs @@ -13,7 +13,7 @@ use common_exception::Result; use crate::plan_broadcast::BroadcastPlan; use crate::plan_subqueries_set::SubQueriesSetPlan; -use crate::AggregatorFinalPlan; +use crate::{AggregatorFinalPlan, rebase_expr_from_input}; use crate::AggregatorPartialPlan; use crate::CreateDatabasePlan; use crate::CreateTablePlan; diff --git a/fusequery/query/src/optimizers/optimizer.rs b/fusequery/query/src/optimizers/optimizer.rs index 7cef222190a7d..1a50b9f39ebe2 100644 --- a/fusequery/query/src/optimizers/optimizer.rs +++ b/fusequery/query/src/optimizers/optimizer.rs @@ -7,7 +7,7 @@ use common_planners::PlanNode; use common_tracing::tracing; use crate::optimizers::optimizer_scatters::ScattersOptimizer; -use crate::optimizers::ProjectionPushDownOptimizer; +use crate::optimizers::{ProjectionPushDownOptimizer, ConstantFoldingOptimizer}; use crate::optimizers::StatisticsExactOptimizer; use crate::sessions::FuseQueryContextRef; @@ -32,6 +32,7 @@ impl Optimizers { pub fn without_scatters(ctx: FuseQueryContextRef) -> Self { Optimizers { inner: vec![ + Box::new(ConstantFoldingOptimizer::create(ctx.clone())), Box::new(ProjectionPushDownOptimizer::create(ctx.clone())), Box::new(StatisticsExactOptimizer::create(ctx)), ], diff --git a/fusequery/query/src/optimizers/optimizer_constant_folding.rs b/fusequery/query/src/optimizers/optimizer_constant_folding.rs index 28129abda0776..ca61fcd4abc7f 100644 --- a/fusequery/query/src/optimizers/optimizer_constant_folding.rs +++ b/fusequery/query/src/optimizers/optimizer_constant_folding.rs @@ -5,7 +5,7 @@ use common_datavalues::prelude::*; use common_exception::ErrorCode; use common_exception::Result; -use common_planners::AggregatorFinalPlan; +use common_planners::{AggregatorFinalPlan, Expressions}; use common_planners::AggregatorPartialPlan; use common_planners::Expression; 
use common_planners::PlanBuilder; @@ -14,107 +14,102 @@ use common_planners::PlanRewriter; use crate::optimizers::Optimizer; use crate::sessions::FuseQueryContextRef; +use common_functions::scalars::FunctionFactory; +use crate::pipelines::transforms::ExpressionExecutor; +use common_datablocks::DataBlock; pub struct ConstantFoldingOptimizer {} -fn is_boolean_type(schema: &DataSchemaRef, expr: &Expression) -> Result { - if let DataType::Boolean = expr.to_data_field(schema)?.data_type() { - return Ok(true); - } - Ok(false) -} - struct ConstantFoldingImpl { before_group_by_schema: Option, } -fn constant_folding(schema: &DataSchemaRef, expr: Expression) -> Result { - let new_expr = match expr { - Expression::BinaryExpression { left, op, right } => match op.as_str() { - "=" => match (left.as_ref(), right.as_ref()) { - ( - Expression::Literal(DataValue::Boolean(l)), - Expression::Literal(DataValue::Boolean(r)), - ) => match (l, r) { - (Some(l), Some(r)) => Expression::Literal(DataValue::Boolean(Some(l == r))), - _ => Expression::Literal(DataValue::Boolean(None)), - }, - (Expression::Literal(DataValue::Boolean(b)), _) - if is_boolean_type(schema, &right)? => - { - match b { - Some(true) => *right, - // Fix this after we implement NOT - Some(false) => Expression::BinaryExpression { left, op, right }, - None => Expression::Literal(DataValue::Boolean(None)), - } - } - (_, Expression::Literal(DataValue::Boolean(b))) - if is_boolean_type(schema, &left)? => - { - match b { - Some(true) => *left, - // Fix this after we implement NOT - Some(false) => Expression::BinaryExpression { left, op, right }, - None => Expression::Literal(DataValue::Boolean(None)), - } - } - _ => Expression::BinaryExpression { - left, - op: "=".to_string(), - right, - }, - }, - "!=" => match (left.as_ref(), right.as_ref()) { - ( - Expression::Literal(DataValue::Boolean(l)), - Expression::Literal(DataValue::Boolean(r)), - ) => match (l, r) { - (Some(l), Some(r)) => Expression::Literal(DataValue::Boolean(Some(l != r))), - _ => Expression::Literal(DataValue::Boolean(None)), - }, - (Expression::Literal(DataValue::Boolean(b)), _) - if is_boolean_type(schema, &right)? => - { - match b { - Some(true) => Expression::BinaryExpression { left, op, right }, - Some(false) => *right, - None => Expression::Literal(DataValue::Boolean(None)), - } - } - (_, Expression::Literal(DataValue::Boolean(b))) - if is_boolean_type(schema, &left)? 
=> - { - match b { - Some(true) => Expression::BinaryExpression { left, op, right }, - Some(false) => *left, - None => Expression::Literal(DataValue::Boolean(None)), - } - } - _ => Expression::BinaryExpression { - left, - op: "!=".to_string(), - right, - }, - }, - _ => Expression::BinaryExpression { left, op, right }, - }, - expr => { - // do nothing - expr +impl ConstantFoldingImpl { + fn rewrite_alias(alias: &str, expr: Expression) -> Result { + Ok(Expression::Alias(alias.to_string(), Box::new(expr))) + } + + fn constants_arguments(args: &[Expression]) -> bool { + !args.iter().any(|expr| !matches!(expr, Expression::Literal(_))) + } + + fn rewrite_function(name: &str, args: Expressions) -> Result { + println!("rewrite function"); + let function = FunctionFactory::get(name)?; + match function.is_deterministic() { + true => Self::rewrite_deterministic_function(name, args), + false => Ok(Expression::ScalarFunction { + op: name.to_string(), + args: args.to_vec(), + }) + } + } + + fn rewrite_deterministic_function(name: &str, args: Expressions) -> Result { + println!("rewrite deterministic function"); + match ConstantFoldingImpl::constants_arguments(&args) { + true => Self::rewrite_const_deterministic_function(name, args), + false => Ok(Expression::ScalarFunction { op: name.to_string(), args }) } - }; - Ok(new_expr) + } + + fn expr_executor(schema: &DataSchemaRef, expr: Expression) -> Result { + let output_fields = vec![expr.to_data_field(&schema)?]; + let output_schema = DataSchemaRefExt::create(output_fields); + ExpressionExecutor::try_create( + "Constant folding optimizer.", + schema.clone(), + output_schema, + vec![expr], + false, + ) + } + + fn rewrite_const_deterministic_function(name: &str, args: Expressions) -> Result { + println!("rewrite const deterministic function"); + let input_fields = vec![DataField::new("_dummy", DataType::UInt8, false)]; + let input_schema = Arc::new(DataSchema::new(input_fields)); + + let expression = Expression::ScalarFunction { op: name.to_string(), args }; + let expression_executor = ConstantFoldingImpl::expr_executor(&input_schema, expression)?; + let dummy_columns = vec![DataColumn::Constant(DataValue::UInt8(Some(1)), 1)]; + let data_block = DataBlock::create(input_schema, dummy_columns); + let executed_data_block = expression_executor.execute(&data_block)?; + + assert_eq!(executed_data_block.num_rows(), 1); + assert_eq!(executed_data_block.num_columns(), 1); + Ok(Expression::Literal(executed_data_block.column(0).to_values()?[0].clone())) + } } impl PlanRewriter for ConstantFoldingImpl { fn rewrite_expr(&mut self, schema: &DataSchemaRef, expr: &Expression) -> Result { - /* TODO: Recursively optimize constant expressions. - * such as: - * subquery, - * a + (1 + (2 + 3)) => a + 6 - */ - constant_folding(schema, expr.clone()) + println!("rewrite_expr"); + match expr { + Expression::Alias(alias, expr) => { + Self::rewrite_alias(alias, self.rewrite_expr(schema, expr)?) + }, + Expression::ScalarFunction { op, args } => { + Ok(Expression::Alias(expr.column_name(), Box::new(Self::rewrite_function(op, args + .iter() + .map(|expr| Self::rewrite_expr(self, schema, expr)) + .collect::>>()?)? + ))) + } + Expression::UnaryExpression { op, expr } => { + Ok(Expression::Alias(expr.column_name(), Box::new( + Self::rewrite_function(op, vec![self.rewrite_expr(schema, expr)?])? 
+ ))) + } + Expression::BinaryExpression { op, left, right } => { + Ok(Expression::Alias(expr.column_name(), Box::new(Self::rewrite_function(op, vec![ + self.rewrite_expr(schema, left)?, + self.rewrite_expr(schema, right)? + ])?))) + }, + + _ => Ok(expr.clone()), + } } fn rewrite_aggregate_partial(&mut self, plan: &AggregatorPartialPlan) -> Result { From 0c188dc4ea3ef906fd84613e25e0360d31161f1a Mon Sep 17 00:00:00 2001 From: zhang2014 Date: Sun, 18 Jul 2021 22:50:15 +0800 Subject: [PATCH 2/5] Implement constant folding for deterministic functions --- common/functions/src/scalars/udfs/crash_me.rs | 4 + common/functions/src/scalars/udfs/database.rs | 4 + common/functions/src/scalars/udfs/exists.rs | 4 + common/functions/src/scalars/udfs/sleep.rs | 4 + common/functions/src/scalars/udfs/version.rs | 4 + common/planners/src/lib.rs | 1 + common/planners/src/plan_rewriter.rs | 53 +++++-- .../optimizers/optimizer_constant_folding.rs | 99 +++++++----- .../optimizer_constant_folding_test.rs | 146 ++++++++++++------ .../optimizer_projection_push_down.rs | 11 +- 10 files changed, 224 insertions(+), 106 deletions(-) diff --git a/common/functions/src/scalars/udfs/crash_me.rs b/common/functions/src/scalars/udfs/crash_me.rs index a90329f23f357..49aeca1ae3d5a 100644 --- a/common/functions/src/scalars/udfs/crash_me.rs +++ b/common/functions/src/scalars/udfs/crash_me.rs @@ -44,6 +44,10 @@ impl Function for CrashMeFunction { fn eval(&self, _columns: &[DataColumn], _input_rows: usize) -> Result { panic!("crash me function"); } + + fn is_deterministic(&self) -> bool { + false + } } impl fmt::Display for CrashMeFunction { diff --git a/common/functions/src/scalars/udfs/database.rs b/common/functions/src/scalars/udfs/database.rs index 5f8d615eef28d..179194e24c8af 100644 --- a/common/functions/src/scalars/udfs/database.rs +++ b/common/functions/src/scalars/udfs/database.rs @@ -41,6 +41,10 @@ impl Function for DatabaseFunction { fn num_arguments(&self) -> usize { 1 } + + fn is_deterministic(&self) -> bool { + false + } } impl fmt::Display for DatabaseFunction { diff --git a/common/functions/src/scalars/udfs/exists.rs b/common/functions/src/scalars/udfs/exists.rs index 7c64067e47888..c99fb61bac335 100644 --- a/common/functions/src/scalars/udfs/exists.rs +++ b/common/functions/src/scalars/udfs/exists.rs @@ -64,6 +64,10 @@ impl Function for ExistsFunction { fn num_arguments(&self) -> usize { 1 } + + fn is_deterministic(&self) -> bool { + false + } } impl fmt::Display for ExistsFunction { diff --git a/common/functions/src/scalars/udfs/sleep.rs b/common/functions/src/scalars/udfs/sleep.rs index 9b718f33f2659..99fab0e541c62 100644 --- a/common/functions/src/scalars/udfs/sleep.rs +++ b/common/functions/src/scalars/udfs/sleep.rs @@ -91,6 +91,10 @@ impl Function for SleepFunction { } } } + + fn is_deterministic(&self) -> bool { + false + } } impl fmt::Display for SleepFunction { diff --git a/common/functions/src/scalars/udfs/version.rs b/common/functions/src/scalars/udfs/version.rs index 12b15f15ca9e6..69965d1fc984c 100644 --- a/common/functions/src/scalars/udfs/version.rs +++ b/common/functions/src/scalars/udfs/version.rs @@ -44,6 +44,10 @@ impl Function for VersionFunction { fn num_arguments(&self) -> usize { 1 } + + fn is_deterministic(&self) -> bool { + false + } } impl fmt::Display for VersionFunction { diff --git a/common/planners/src/lib.rs b/common/planners/src/lib.rs index f7bce11856f5f..59c005f9e8540 100644 --- a/common/planners/src/lib.rs +++ b/common/planners/src/lib.rs @@ -136,6 +136,7 @@ pub use 
plan_projection::ProjectionPlan; pub use plan_read_datasource::ReadDataSourcePlan; pub use plan_remote::RemotePlan; pub use plan_rewriter::PlanRewriter; +pub use plan_rewriter::SchemaChanges; pub use plan_rewriter::RewriteHelper; pub use plan_scan::ScanPlan; pub use plan_select::SelectPlan; diff --git a/common/planners/src/plan_rewriter.rs b/common/planners/src/plan_rewriter.rs index fa6c3beffa92b..4cb348f738bc5 100644 --- a/common/planners/src/plan_rewriter.rs +++ b/common/planners/src/plan_rewriter.rs @@ -13,7 +13,7 @@ use common_exception::Result; use crate::plan_broadcast::BroadcastPlan; use crate::plan_subqueries_set::SubQueriesSetPlan; -use crate::{AggregatorFinalPlan, rebase_expr_from_input}; +use crate::AggregatorFinalPlan; use crate::AggregatorPartialPlan; use crate::CreateDatabasePlan; use crate::CreateTablePlan; @@ -100,30 +100,30 @@ pub trait PlanRewriter { } // TODO: Move it to ExpressionsRewrite trait - fn rewrite_expr(&mut self, schema: &DataSchemaRef, expr: &Expression) -> Result { + fn rewrite_expr(&mut self, changes: &SchemaChanges, expr: &Expression) -> Result { match expr { Expression::Alias(alias, input) => Ok(Expression::Alias( alias.clone(), - Box::new(self.rewrite_expr(schema, input.as_ref())?), + Box::new(self.rewrite_expr(changes, input.as_ref())?), )), Expression::UnaryExpression { op, expr } => Ok(Expression::UnaryExpression { op: op.clone(), - expr: Box::new(self.rewrite_expr(schema, expr.as_ref())?), + expr: Box::new(self.rewrite_expr(changes, expr.as_ref())?), }), Expression::BinaryExpression { op, left, right } => Ok(Expression::BinaryExpression { op: op.clone(), - left: Box::new(self.rewrite_expr(schema, left.as_ref())?), - right: Box::new(self.rewrite_expr(schema, right.as_ref())?), + left: Box::new(self.rewrite_expr(changes, left.as_ref())?), + right: Box::new(self.rewrite_expr(changes, right.as_ref())?), }), Expression::ScalarFunction { op, args } => Ok(Expression::ScalarFunction { op: op.clone(), - args: self.rewrite_exprs(schema, args)?, + args: self.rewrite_exprs(changes, args)?, }), Expression::AggregateFunction { op, distinct, args } => { Ok(Expression::AggregateFunction { op: op.clone(), distinct: *distinct, - args: self.rewrite_exprs(schema, args)?, + args: self.rewrite_exprs(changes, args)?, }) } Expression::Sort { @@ -131,12 +131,12 @@ pub trait PlanRewriter { asc, nulls_first, } => Ok(Expression::Sort { - expr: Box::new(self.rewrite_expr(schema, expr.as_ref())?), + expr: Box::new(self.rewrite_expr(changes, expr.as_ref())?), asc: *asc, nulls_first: *nulls_first, }), Expression::Cast { expr, data_type } => Ok(Expression::Cast { - expr: Box::new(self.rewrite_expr(schema, expr.as_ref())?), + expr: Box::new(self.rewrite_expr(changes, expr.as_ref())?), data_type: data_type.clone(), }), Expression::Wildcard => Ok(Expression::Wildcard), @@ -162,12 +162,12 @@ pub trait PlanRewriter { // TODO: Move it to ExpressionsRewrite trait fn rewrite_exprs( &mut self, - schema: &DataSchemaRef, + changes: &SchemaChanges, exprs: &[Expression], ) -> Result { exprs .iter() - .map(|expr| Self::rewrite_expr(self, schema, expr)) + .map(|expr| Self::rewrite_expr(self, changes, expr)) .collect::>>() } @@ -201,13 +201,15 @@ pub trait PlanRewriter { fn rewrite_projection(&mut self, plan: &ProjectionPlan) -> Result { let new_input = self.rewrite_plan_node(plan.input.as_ref())?; - let new_exprs = self.rewrite_exprs(&new_input.schema(), &plan.expr)?; + let schema_change = SchemaChanges::new(&plan.input.schema(), &new_input.schema()); + let new_exprs = 
self.rewrite_exprs(&schema_change, &plan.expr)?; PlanBuilder::from(&new_input).project(&new_exprs)?.build() } fn rewrite_expression(&mut self, plan: &ExpressionPlan) -> Result { let new_input = self.rewrite_plan_node(plan.input.as_ref())?; - let new_exprs = self.rewrite_exprs(&new_input.schema(), &plan.exprs)?; + let schema_change = SchemaChanges::new(&plan.input.schema(), &new_input.schema()); + let new_exprs = self.rewrite_exprs(&schema_change, &plan.exprs)?; PlanBuilder::from(&new_input) .expression(&new_exprs, &plan.desc)? .build() @@ -220,19 +222,22 @@ pub trait PlanRewriter { fn rewrite_filter(&mut self, plan: &FilterPlan) -> Result { let new_input = self.rewrite_plan_node(plan.input.as_ref())?; - let new_predicate = self.rewrite_expr(&new_input.schema(), &plan.predicate)?; + let schema_change = SchemaChanges::new(&plan.input.schema(), &new_input.schema()); + let new_predicate = self.rewrite_expr(&schema_change, &plan.predicate)?; PlanBuilder::from(&new_input).filter(new_predicate)?.build() } fn rewrite_having(&mut self, plan: &HavingPlan) -> Result { let new_input = self.rewrite_plan_node(plan.input.as_ref())?; - let new_predicate = self.rewrite_expr(&new_input.schema(), &plan.predicate)?; + let schema_change = SchemaChanges::new(&plan.input.schema(), &new_input.schema()); + let new_predicate = self.rewrite_expr(&schema_change, &plan.predicate)?; PlanBuilder::from(&new_input).having(new_predicate)?.build() } fn rewrite_sort(&mut self, plan: &SortPlan) -> Result { let new_input = self.rewrite_plan_node(plan.input.as_ref())?; - let new_order_by = self.rewrite_exprs(&new_input.schema(), &plan.order_by)?; + let schema_change = SchemaChanges::new(&plan.input.schema(), &new_input.schema()); + let new_order_by = self.rewrite_exprs(&schema_change, &plan.order_by)?; PlanBuilder::from(&new_input).sort(&new_order_by)?.build() } @@ -319,6 +324,20 @@ pub trait PlanRewriter { } } +pub struct SchemaChanges { + pub before_input_schema: DataSchemaRef, + pub after_input_schema: DataSchemaRef, +} + +impl SchemaChanges { + pub fn new(before: &DataSchemaRef, after: &DataSchemaRef) -> SchemaChanges { + SchemaChanges { + before_input_schema: before.clone(), + after_input_schema: after.clone(), + } + } +} + pub struct RewriteHelper {} struct QueryAliasData { diff --git a/fusequery/query/src/optimizers/optimizer_constant_folding.rs b/fusequery/query/src/optimizers/optimizer_constant_folding.rs index ca61fcd4abc7f..4b720805f049d 100644 --- a/fusequery/query/src/optimizers/optimizer_constant_folding.rs +++ b/fusequery/query/src/optimizers/optimizer_constant_folding.rs @@ -5,7 +5,7 @@ use common_datavalues::prelude::*; use common_exception::ErrorCode; use common_exception::Result; -use common_planners::{AggregatorFinalPlan, Expressions}; +use common_planners::{AggregatorFinalPlan, Expressions, SchemaChanges}; use common_planners::AggregatorPartialPlan; use common_planners::Expression; use common_planners::PlanBuilder; @@ -33,23 +33,21 @@ impl ConstantFoldingImpl { !args.iter().any(|expr| !matches!(expr, Expression::Literal(_))) } - fn rewrite_function(name: &str, args: Expressions) -> Result { - println!("rewrite function"); - let function = FunctionFactory::get(name)?; + fn rewrite_function(op: &str, args: Expressions) -> Result> { + let function = FunctionFactory::get(op)?; match function.is_deterministic() { - true => Self::rewrite_deterministic_function(name, args), - false => Ok(Expression::ScalarFunction { - op: name.to_string(), - args: args.to_vec(), - }) + true => 
Self::rewrite_deterministic_function(op, args), + false => Ok(None) } } - fn rewrite_deterministic_function(name: &str, args: Expressions) -> Result { - println!("rewrite deterministic function"); + fn rewrite_deterministic_function(op: &str, args: Expressions) -> Result> { match ConstantFoldingImpl::constants_arguments(&args) { - true => Self::rewrite_const_deterministic_function(name, args), - false => Ok(Expression::ScalarFunction { op: name.to_string(), args }) + false => Ok(None), + true => ConstantFoldingImpl::execute_expression(Expression::ScalarFunction { + op: op.to_string(), + args: args, + }), } } @@ -65,49 +63,80 @@ impl ConstantFoldingImpl { ) } - fn rewrite_const_deterministic_function(name: &str, args: Expressions) -> Result { - println!("rewrite const deterministic function"); + fn execute_expression(expression: Expression) -> Result> { let input_fields = vec![DataField::new("_dummy", DataType::UInt8, false)]; let input_schema = Arc::new(DataSchema::new(input_fields)); - let expression = Expression::ScalarFunction { op: name.to_string(), args }; - let expression_executor = ConstantFoldingImpl::expr_executor(&input_schema, expression)?; + let expression_executor = Self::expr_executor(&input_schema, expression)?; let dummy_columns = vec![DataColumn::Constant(DataValue::UInt8(Some(1)), 1)]; - let data_block = DataBlock::create(input_schema, dummy_columns); + let data_block = DataBlock::create(input_schema.clone(), dummy_columns); let executed_data_block = expression_executor.execute(&data_block)?; assert_eq!(executed_data_block.num_rows(), 1); assert_eq!(executed_data_block.num_columns(), 1); - Ok(Expression::Literal(executed_data_block.column(0).to_values()?[0].clone())) + Ok(Some(Expression::Literal(executed_data_block.column(0).to_values()?[0].clone()))) } } impl PlanRewriter for ConstantFoldingImpl { - fn rewrite_expr(&mut self, schema: &DataSchemaRef, expr: &Expression) -> Result { - println!("rewrite_expr"); + fn rewrite_expr(&mut self, changes: &SchemaChanges, expr: &Expression) -> Result { match expr { Expression::Alias(alias, expr) => { - Self::rewrite_alias(alias, self.rewrite_expr(schema, expr)?) + Self::rewrite_alias(alias, self.rewrite_expr(changes, expr)?) }, Expression::ScalarFunction { op, args } => { - Ok(Expression::Alias(expr.column_name(), Box::new(Self::rewrite_function(op, args + let new_args = args .iter() - .map(|expr| Self::rewrite_expr(self, schema, expr)) - .collect::>>()?)? - ))) + .map(|expr| Self::rewrite_expr(self, changes, expr)) + .collect::>>()?; + + match Self::rewrite_function(op, new_args.clone())? { + Some(new_expr) => Ok(new_expr), + None => Ok(Expression::ScalarFunction { + op: op.clone(), + args: new_args, + }), + } } - Expression::UnaryExpression { op, expr } => { - Ok(Expression::Alias(expr.column_name(), Box::new( - Self::rewrite_function(op, vec![self.rewrite_expr(schema, expr)?])? - ))) + Expression::UnaryExpression { op, expr: inner_expr } => { + let new_expr = self.rewrite_expr(changes, inner_expr)?; + match Self::rewrite_function(op, vec![new_expr.clone()])? { + Some(new_expr) => Ok(new_expr), + None => Ok(Expression::UnaryExpression { + op: op.clone(), + expr: Box::new(new_expr), + }), + } } Expression::BinaryExpression { op, left, right } => { - Ok(Expression::Alias(expr.column_name(), Box::new(Self::rewrite_function(op, vec![ - self.rewrite_expr(schema, left)?, - self.rewrite_expr(schema, right)? 
- ])?))) + let new_left = self.rewrite_expr(changes, left)?; + let new_right = self.rewrite_expr(changes, right)?; + match Self::rewrite_function(op, vec![new_left.clone(), new_right.clone()])? { + Some(new_expr) => Ok(new_expr), + None => Ok(Expression::BinaryExpression { + op: op.clone(), + left: Box::new(new_left), + right: Box::new(new_right), + }), + } }, - + Expression::Cast { expr, data_type } => { + let new_expr = self.rewrite_expr(changes, expr)?; + match &new_expr { + Expression::Literal(_) => Ok(Self::execute_expression( + Expression::Cast { + expr: Box::new(new_expr), + data_type: data_type.clone(), + } + )?.unwrap()), + _ => Ok(new_expr) + } + } + Expression::Column(column_name) => { + let field_pos = changes.before_input_schema.index_of(column_name)?; + let new_field = changes.after_input_schema.field(field_pos); + Ok(Expression::Column(new_field.name().to_string())) + } _ => Ok(expr.clone()), } } diff --git a/fusequery/query/src/optimizers/optimizer_constant_folding_test.rs b/fusequery/query/src/optimizers/optimizer_constant_folding_test.rs index 39fa18930cf08..f2031813d99a2 100644 --- a/fusequery/query/src/optimizers/optimizer_constant_folding_test.rs +++ b/fusequery/query/src/optimizers/optimizer_constant_folding_test.rs @@ -4,64 +4,110 @@ #[cfg(test)] mod tests { - use std::mem::size_of; - use std::sync::Arc; - - use common_datavalues::prelude::*; use common_exception::Result; - use common_planners::*; - use pretty_assertions::assert_eq; - - use crate::optimizers::optimizer_test::*; use crate::optimizers::*; #[test] fn test_constant_folding_optimizer() -> Result<()> { - let ctx = crate::tests::try_create_context()?; - - let total = ctx.get_settings().get_max_block_size()? as u64; - let statistics = - Statistics::new_exact(total as usize, ((total) * size_of::() as u64) as usize); - ctx.try_set_statistics(&statistics)?; - let source_plan = PlanNode::ReadSource(ReadDataSourcePlan { - db: "system".to_string(), - table: "test".to_string(), - schema: DataSchemaRefExt::create(vec![ - DataField::new("a", DataType::Utf8, false), - DataField::new("b", DataType::Utf8, false), - DataField::new("c", DataType::Utf8, false), - ]), - parts: generate_partitions(8, total as u64), - statistics: statistics.clone(), - description: format!( - "(Read from system.{} table, Read Rows:{}, Read Bytes:{})", - "test".to_string(), - statistics.read_rows, - statistics.read_bytes - ), - scan_plan: Arc::new(ScanPlan::empty()), - remote: false, - }); - - let filter_plan = PlanBuilder::from(&source_plan) - .filter(col("a").gt(lit(6)).eq(lit(true)))? 
- .build()?; + #[allow(dead_code)] + struct Test { + name: &'static str, + query: &'static str, + expect: &'static str, + } - let plan = PlanNode::Projection(ProjectionPlan { - expr: vec![col("a")], - schema: DataSchemaRefExt::create(vec![DataField::new("a", DataType::Utf8, false)]), - input: Arc::from(filter_plan), - }); + let tests: Vec = vec![ + Test { + name: "Projection const recursion", + query: "SELECT 1 + 2 + 3", + expect: "\ + Projection: 6:UInt32\ + \n Expression: 6:UInt32 (Before Projection)\ + \n ReadDataSource: scan partitions: [1], scan schema: [dummy:UInt8], statistics: [read_rows: 1, read_bytes: 1]", + }, + Test { + name: "Projection left non const recursion", + query: "SELECT dummy + 1 + 2 + 3", + expect: "\ + Projection: (((dummy + 1) + 2) + 3):UInt64\ + \n Expression: (((dummy + 1) + 2) + 3):UInt64 (Before Projection)\ + \n ReadDataSource: scan partitions: [1], scan schema: [dummy:UInt8], statistics: [read_rows: 1, read_bytes: 1]", + }, + Test { + name: "Projection right non const recursion", + query: "SELECT 1 + 2 + 3 + dummy", + expect: "\ + Projection: (6 + dummy):UInt64\ + \n Expression: (6 + dummy):UInt64 (Before Projection)\ + \n ReadDataSource: scan partitions: [1], scan schema: [dummy:UInt8], statistics: [read_rows: 1, read_bytes: 1]", + }, + Test { + name: "Projection arithmetic const recursion", + query: "SELECT 1 + 2 + 3 / 3", + expect: "\ + Projection: 4:Float64\ + \n Expression: 4:Float64 (Before Projection)\ + \n ReadDataSource: scan partitions: [1], scan schema: [dummy:UInt8], statistics: [read_rows: 1, read_bytes: 1]", + }, + Test { + name: "Projection comparisons const recursion", + query: "SELECT 1 + 2 + 3 > 3", + expect: "\ + Projection: true:Boolean\ + \n Expression: true:Boolean (Before Projection)\ + \n ReadDataSource: scan partitions: [1], scan schema: [dummy:UInt8], statistics: [read_rows: 1, read_bytes: 1]", + }, + Test { + name: "Projection cast const recursion", + query: "SELECT CAST(1 AS bigint)", + expect: "\ + Projection: 1:Int64\ + \n Expression: 1:Int64 (Before Projection)\ + \n ReadDataSource: scan partitions: [1], scan schema: [dummy:UInt8], statistics: [read_rows: 1, read_bytes: 1]", + }, + Test { + name: "Projection hash const recursion", + query: "SELECT sipHash('test_string')", + expect: "\ + Projection: 17123704338732264132:UInt64\ + \n Expression: 17123704338732264132:UInt64 (Before Projection)\ + \n ReadDataSource: scan partitions: [1], scan schema: [dummy:UInt8], statistics: [read_rows: 1, read_bytes: 1]", + }, + Test { + name: "Projection logics const recursion", + query: "SELECT 1 = 1 AND 2 > 1", + expect: "\ + Projection: true:Boolean\ + \n Expression: true:Boolean (Before Projection)\ + \n ReadDataSource: scan partitions: [1], scan schema: [dummy:UInt8], statistics: [read_rows: 1, read_bytes: 1]", + }, + Test { + name: "Projection strings const recursion", + query: "SELECT SUBSTRING('1234567890' FROM 3 FOR 3)", + expect: "\ + Projection: 345:Utf8\ + \n Expression: 345:Utf8 (Before Projection)\ + \n ReadDataSource: scan partitions: [1], scan schema: [dummy:UInt8], statistics: [read_rows: 1, read_bytes: 1]", + }, + Test { + name: "Projection to type name const recursion", + query: "SELECT toTypeName('1234567890')", + expect: "\ + Projection: Utf8:Utf8\ + \n Expression: Utf8:Utf8 (Before Projection)\ + \n ReadDataSource: scan partitions: [1], scan schema: [dummy:UInt8], statistics: [read_rows: 1, read_bytes: 1]", + }, + ]; - let mut constant_folding = ConstantFoldingOptimizer::create(ctx); - let optimized = 
constant_folding.optimize(&plan)?; + for test in tests { + let ctx = crate::tests::try_create_context()?; - let expect = "\ - Projection: a:Utf8\ - \n Filter: (a > 6)\ - \n ReadDataSource: scan partitions: [8], scan schema: [a:Utf8, b:Utf8, c:Utf8], statistics: [read_rows: 10000, read_bytes: 80000]"; - let actual = format!("{:?}", optimized); - assert_eq!(expect, actual); + let plan = crate::tests::parse_query(test.query)?; + let mut optimizer = ConstantFoldingOptimizer::create(ctx); + let optimized = optimizer.optimize(&plan)?; + let actual = format!("{:?}", optimized); + assert_eq!(test.expect, actual, "{:#?}", test.name); + } Ok(()) } } diff --git a/fusequery/query/src/optimizers/optimizer_projection_push_down.rs b/fusequery/query/src/optimizers/optimizer_projection_push_down.rs index 6bcef5bec3487..c36292f2b05a7 100644 --- a/fusequery/query/src/optimizers/optimizer_projection_push_down.rs +++ b/fusequery/query/src/optimizers/optimizer_projection_push_down.rs @@ -10,7 +10,7 @@ use common_datavalues::DataSchemaRef; use common_datavalues::DataSchemaRefExt; use common_exception::ErrorCode; use common_exception::Result; -use common_planners::AggregatorFinalPlan; +use common_planners::{AggregatorFinalPlan, SchemaChanges}; use common_planners::AggregatorPartialPlan; use common_planners::EmptyPlan; use common_planners::Expression; @@ -76,24 +76,27 @@ impl PlanRewriter for ProjectionPushDownImpl { self.collect_column_names_from_expr_vec(plan.expr.as_slice())?; self.has_projection = true; let new_input = self.rewrite_plan_node(&plan.input)?; + let schema_changes = SchemaChanges::new(&plan.schema, &new_input.schema()); PlanBuilder::from(&new_input) - .project(&self.rewrite_exprs(&new_input.schema(), &plan.expr)?)? + .project(&self.rewrite_exprs(&schema_changes, &plan.expr)?)? .build() } fn rewrite_filter(&mut self, plan: &FilterPlan) -> Result { self.collect_column_names_from_expr(&plan.predicate)?; let new_input = self.rewrite_plan_node(&plan.input)?; + let schema_changes = SchemaChanges::new(&plan.schema, &new_input.schema()); PlanBuilder::from(&new_input) - .filter(self.rewrite_expr(&new_input.schema(), &plan.predicate)?)? + .filter(self.rewrite_expr(&schema_changes, &plan.predicate)?)? .build() } fn rewrite_sort(&mut self, plan: &SortPlan) -> Result { self.collect_column_names_from_expr_vec(plan.order_by.as_slice())?; let new_input = self.rewrite_plan_node(&plan.input)?; + let schema_changes = SchemaChanges::new(&plan.schema, &new_input.schema()); PlanBuilder::from(&new_input) - .sort(&self.rewrite_exprs(&new_input.schema(), &plan.order_by)?)? + .sort(&self.rewrite_exprs(&schema_changes, &plan.order_by)?)? 
.build() } From 288ca4da5c7fef16b61e25f334a40ddfdb913920 Mon Sep 17 00:00:00 2001 From: zhang2014 Date: Mon, 19 Jul 2021 22:22:49 +0800 Subject: [PATCH 3/5] Implement constant folding for deterministic functions --- common/datavalues/src/data_schema.rs | 4 - common/planners/src/lib.rs | 1 - common/planners/src/plan_expression.rs | 48 ++++- common/planners/src/plan_expression_chain.rs | 4 +- common/planners/src/plan_expression_common.rs | 2 +- .../planners/src/plan_expression_literal.rs | 6 +- common/planners/src/plan_rewriter.rs | 64 +++--- .../query/src/api/rpc/flight_actions_test.rs | 4 +- .../src/api/rpc/flight_dispatcher_test.rs | 2 +- .../query/src/api/rpc/flight_scatter_hash.rs | 2 +- .../query/src/api/rpc/flight_service_test.rs | 2 +- .../src/datasources/system/numbers_table.rs | 4 +- .../datasources/system/numbers_table_test.rs | 2 +- .../query/src/functions/context_function.rs | 4 +- .../src/interpreters/plan_scheduler_test.rs | 24 +-- fusequery/query/src/optimizers/optimizer.rs | 3 +- .../optimizers/optimizer_constant_folding.rs | 197 +++++++++++------- .../optimizer_constant_folding_test.rs | 19 +- .../optimizer_projection_push_down.rs | 11 +- .../src/optimizers/optimizer_scatters.rs | 6 +- .../optimizers/optimizer_statistics_exact.rs | 6 +- .../optimizer_statistics_exact_test.rs | 4 +- fusequery/query/src/sql/plan_parser.rs | 20 +- fusequery/query/src/sql/sql_common.rs | 12 +- fusequery/query/src/tests/number.rs | 2 +- .../0_stateless/08_0000_optimizer.result | 2 - .../suites/0_stateless/08_0000_optimizer.sql | 6 +- 27 files changed, 265 insertions(+), 196 deletions(-) diff --git a/common/datavalues/src/data_schema.rs b/common/datavalues/src/data_schema.rs index c122968bb2f63..0bbb4fd6c2f93 100644 --- a/common/datavalues/src/data_schema.rs +++ b/common/datavalues/src/data_schema.rs @@ -96,10 +96,6 @@ pub type DataSchemaRef = Arc; pub struct DataSchemaRefExt; impl DataSchemaRefExt { - pub fn empty() -> DataSchemaRef { - Arc::new(DataSchema::empty()) - } - pub fn create(fields: Vec) -> DataSchemaRef { Arc::new(DataSchema::new(fields)) } diff --git a/common/planners/src/lib.rs b/common/planners/src/lib.rs index 59c005f9e8540..f7bce11856f5f 100644 --- a/common/planners/src/lib.rs +++ b/common/planners/src/lib.rs @@ -136,7 +136,6 @@ pub use plan_projection::ProjectionPlan; pub use plan_read_datasource::ReadDataSourcePlan; pub use plan_remote::RemotePlan; pub use plan_rewriter::PlanRewriter; -pub use plan_rewriter::SchemaChanges; pub use plan_rewriter::RewriteHelper; pub use plan_scan::ScanPlan; pub use plan_select::SelectPlan; diff --git a/common/planners/src/plan_expression.rs b/common/planners/src/plan_expression.rs index af8f6a92c9fda..2f33408960a6f 100644 --- a/common/planners/src/plan_expression.rs +++ b/common/planners/src/plan_expression.rs @@ -48,7 +48,10 @@ pub enum Expression { /// Column name. Column(String), /// Constant value. - Literal(DataValue), + Literal { + value: DataValue, + column_name: Option, + }, /// A unary expression such as "NOT foo" UnaryExpression { op: String, expr: Box }, @@ -101,15 +104,50 @@ pub enum Expression { } impl Expression { + pub fn create_literal(value: DataValue) -> Expression { + Expression::Literal { + value, + column_name: None, + } + } + pub fn column_name(&self) -> String { match self { Expression::Alias(name, _expr) => name.clone(), - Expression::ScalarFunction { op, .. } => { + Expression::Column(name) => name.clone(), + Expression::Literal { + column_name: Some(name), + .. 
+ } => name.clone(), + Expression::UnaryExpression { op, expr } => { + format!("({} {})", op, expr.column_name()) + } + Expression::BinaryExpression { op, left, right } => { + format!("({} {} {})", left.column_name(), op, right.column_name()) + } + Expression::ScalarFunction { op, args } => { match OP_SET.get(&op.to_lowercase().as_ref()) { Some(_) => format!("{}()", op), - None => format!("{:?}", self), + None => { + let args_column_name = + args.iter().map(Expression::column_name).collect::>(); + + format!("{}({})", op, args_column_name.join(", ")) + } + } + } + Expression::AggregateFunction { op, distinct, args } => { + let args_column_name = args.iter().map(Expression::column_name).collect::>(); + + match distinct { + true => format!("{}(distinct {})", op, args_column_name.join(", ")), + false => format!("{}({})", op, args_column_name.join(", ")), } } + Expression::Sort { expr, .. } => expr.column_name(), + Expression::Cast { expr, data_type } => { + format!("cast({} as {:?})", expr.column_name(), data_type) + } Expression::Subquery { name, .. } => name.clone(), Expression::ScalarSubquery { name, .. } => name.clone(), _ => format!("{:?}", self), @@ -164,7 +202,7 @@ impl Expression { match self { Expression::Alias(_, expr) => expr.to_data_type(input_schema), Expression::Column(s) => Ok(input_schema.field_with_name(s)?.data_type().clone()), - Expression::Literal(v) => Ok(v.data_type()), + Expression::Literal { value, .. } => Ok(value.data_type()), Expression::Subquery { query_plan, .. } => Ok(Self::to_subquery_type(query_plan)), Expression::ScalarSubquery { query_plan, .. } => { Ok(Self::to_scalar_subquery_type(query_plan)) @@ -246,7 +284,7 @@ impl fmt::Debug for Expression { match self { Expression::Alias(alias, v) => write!(f, "{:?} as {:#}", v, alias), Expression::Column(ref v) => write!(f, "{:#}", v), - Expression::Literal(ref v) => write!(f, "{:#}", v), + Expression::Literal { ref value, .. } => write!(f, "{:#}", value), Expression::Subquery { name, .. } => write!(f, "subquery({})", name), Expression::ScalarSubquery { name, .. } => write!(f, "scalar subquery({})", name), Expression::BinaryExpression { op, left, right } => { diff --git a/common/planners/src/plan_expression_chain.rs b/common/planners/src/plan_expression_chain.rs index 515b04ad1c710..9d5d8eed68edc 100644 --- a/common/planners/src/plan_expression_chain.rs +++ b/common/planners/src/plan_expression_chain.rs @@ -56,10 +56,10 @@ impl ExpressionChain { }; self.actions.push(ExpressionAction::Input(input)); } - Expression::Literal(l) => { + Expression::Literal { value, .. } => { let value = ActionConstant { name: expr.column_name(), - value: l.clone(), + value: value.clone(), }; self.actions.push(ExpressionAction::Constant(value)); diff --git a/common/planners/src/plan_expression_common.rs b/common/planners/src/plan_expression_common.rs index 6b47bcf9c4ed9..39febe2f7d6f7 100644 --- a/common/planners/src/plan_expression_common.rs +++ b/common/planners/src/plan_expression_common.rs @@ -295,7 +295,7 @@ where F: Fn(&Expression) -> Result> { }), Expression::Column(_) - | Expression::Literal(_) + | Expression::Literal { .. } | Expression::Subquery { .. } | Expression::ScalarSubquery { .. 
} => Ok(expr.clone()), }, diff --git a/common/planners/src/plan_expression_literal.rs b/common/planners/src/plan_expression_literal.rs index 255236bb349b0..209b5d90e1ed5 100644 --- a/common/planners/src/plan_expression_literal.rs +++ b/common/planners/src/plan_expression_literal.rs @@ -12,13 +12,13 @@ pub trait Literal { impl Literal for &str { fn to_literal(&self) -> Expression { - Expression::Literal(DataValue::Utf8(Some(self.to_string()))) + Expression::create_literal(DataValue::Utf8(Some(self.to_string()))) } } impl Literal for String { fn to_literal(&self) -> Expression { - Expression::Literal(DataValue::Utf8(Some(self.clone()))) + Expression::create_literal(DataValue::Utf8(Some(self.clone()))) } } @@ -27,7 +27,7 @@ macro_rules! make_literal { #[allow(missing_docs)] impl Literal for $TYPE { fn to_literal(&self) -> Expression { - Expression::Literal(DataValue::$SCALAR(Some(self.clone()))) + Expression::create_literal(DataValue::$SCALAR(Some(self.clone()))) } } }; diff --git a/common/planners/src/plan_rewriter.rs b/common/planners/src/plan_rewriter.rs index 4cb348f738bc5..258f006c3ebb8 100644 --- a/common/planners/src/plan_rewriter.rs +++ b/common/planners/src/plan_rewriter.rs @@ -100,30 +100,30 @@ pub trait PlanRewriter { } // TODO: Move it to ExpressionsRewrite trait - fn rewrite_expr(&mut self, changes: &SchemaChanges, expr: &Expression) -> Result { + fn rewrite_expr(&mut self, schema: &DataSchemaRef, expr: &Expression) -> Result { match expr { Expression::Alias(alias, input) => Ok(Expression::Alias( alias.clone(), - Box::new(self.rewrite_expr(changes, input.as_ref())?), + Box::new(self.rewrite_expr(schema, input.as_ref())?), )), Expression::UnaryExpression { op, expr } => Ok(Expression::UnaryExpression { op: op.clone(), - expr: Box::new(self.rewrite_expr(changes, expr.as_ref())?), + expr: Box::new(self.rewrite_expr(schema, expr.as_ref())?), }), Expression::BinaryExpression { op, left, right } => Ok(Expression::BinaryExpression { op: op.clone(), - left: Box::new(self.rewrite_expr(changes, left.as_ref())?), - right: Box::new(self.rewrite_expr(changes, right.as_ref())?), + left: Box::new(self.rewrite_expr(schema, left.as_ref())?), + right: Box::new(self.rewrite_expr(schema, right.as_ref())?), }), Expression::ScalarFunction { op, args } => Ok(Expression::ScalarFunction { op: op.clone(), - args: self.rewrite_exprs(changes, args)?, + args: self.rewrite_exprs(schema, args)?, }), Expression::AggregateFunction { op, distinct, args } => { Ok(Expression::AggregateFunction { op: op.clone(), distinct: *distinct, - args: self.rewrite_exprs(changes, args)?, + args: self.rewrite_exprs(schema, args)?, }) } Expression::Sort { @@ -131,17 +131,20 @@ pub trait PlanRewriter { asc, nulls_first, } => Ok(Expression::Sort { - expr: Box::new(self.rewrite_expr(changes, expr.as_ref())?), + expr: Box::new(self.rewrite_expr(schema, expr.as_ref())?), asc: *asc, nulls_first: *nulls_first, }), Expression::Cast { expr, data_type } => Ok(Expression::Cast { - expr: Box::new(self.rewrite_expr(changes, expr.as_ref())?), + expr: Box::new(self.rewrite_expr(schema, expr.as_ref())?), data_type: data_type.clone(), }), Expression::Wildcard => Ok(Expression::Wildcard), Expression::Column(column_name) => Ok(Expression::Column(column_name.clone())), - Expression::Literal(value) => Ok(Expression::Literal(value.clone())), + Expression::Literal { value, column_name } => Ok(Expression::Literal { + value: value.clone(), + column_name: column_name.clone(), + }), Expression::Subquery { name, query_plan } => { let new_subquery = 
self.rewrite_subquery_plan(query_plan)?; Ok(Expression::Subquery { @@ -162,12 +165,12 @@ pub trait PlanRewriter { // TODO: Move it to ExpressionsRewrite trait fn rewrite_exprs( &mut self, - changes: &SchemaChanges, + schema: &DataSchemaRef, exprs: &[Expression], ) -> Result { exprs .iter() - .map(|expr| Self::rewrite_expr(self, changes, expr)) + .map(|expr| Self::rewrite_expr(self, schema, expr)) .collect::>>() } @@ -201,15 +204,13 @@ pub trait PlanRewriter { fn rewrite_projection(&mut self, plan: &ProjectionPlan) -> Result { let new_input = self.rewrite_plan_node(plan.input.as_ref())?; - let schema_change = SchemaChanges::new(&plan.input.schema(), &new_input.schema()); - let new_exprs = self.rewrite_exprs(&schema_change, &plan.expr)?; + let new_exprs = self.rewrite_exprs(&new_input.schema(), &plan.expr)?; PlanBuilder::from(&new_input).project(&new_exprs)?.build() } fn rewrite_expression(&mut self, plan: &ExpressionPlan) -> Result { let new_input = self.rewrite_plan_node(plan.input.as_ref())?; - let schema_change = SchemaChanges::new(&plan.input.schema(), &new_input.schema()); - let new_exprs = self.rewrite_exprs(&schema_change, &plan.exprs)?; + let new_exprs = self.rewrite_exprs(&new_input.schema(), &plan.exprs)?; PlanBuilder::from(&new_input) .expression(&new_exprs, &plan.desc)? .build() @@ -222,22 +223,19 @@ pub trait PlanRewriter { fn rewrite_filter(&mut self, plan: &FilterPlan) -> Result { let new_input = self.rewrite_plan_node(plan.input.as_ref())?; - let schema_change = SchemaChanges::new(&plan.input.schema(), &new_input.schema()); - let new_predicate = self.rewrite_expr(&schema_change, &plan.predicate)?; + let new_predicate = self.rewrite_expr(&new_input.schema(), &plan.predicate)?; PlanBuilder::from(&new_input).filter(new_predicate)?.build() } fn rewrite_having(&mut self, plan: &HavingPlan) -> Result { let new_input = self.rewrite_plan_node(plan.input.as_ref())?; - let schema_change = SchemaChanges::new(&plan.input.schema(), &new_input.schema()); - let new_predicate = self.rewrite_expr(&schema_change, &plan.predicate)?; + let new_predicate = self.rewrite_expr(&new_input.schema(), &plan.predicate)?; PlanBuilder::from(&new_input).having(new_predicate)?.build() } fn rewrite_sort(&mut self, plan: &SortPlan) -> Result { let new_input = self.rewrite_plan_node(plan.input.as_ref())?; - let schema_change = SchemaChanges::new(&plan.input.schema(), &new_input.schema()); - let new_order_by = self.rewrite_exprs(&schema_change, &plan.order_by)?; + let new_order_by = self.rewrite_exprs(&new_input.schema(), &plan.order_by)?; PlanBuilder::from(&new_input).sort(&new_order_by)?.build() } @@ -324,20 +322,6 @@ pub trait PlanRewriter { } } -pub struct SchemaChanges { - pub before_input_schema: DataSchemaRef, - pub after_input_schema: DataSchemaRef, -} - -impl SchemaChanges { - pub fn new(before: &DataSchemaRef, after: &DataSchemaRef) -> SchemaChanges { - SchemaChanges { - before_input_schema: before.clone(), - after_input_schema: after.clone(), - } - } -} - pub struct RewriteHelper {} struct QueryAliasData { @@ -499,7 +483,7 @@ impl RewriteHelper { }) } Expression::Wildcard - | Expression::Literal(_) + | Expression::Literal { .. } | Expression::Subquery { .. } | Expression::ScalarSubquery { .. } | Expression::Sort { .. } => Ok(expr.clone()), @@ -552,7 +536,7 @@ impl RewriteHelper { Ok(match expr { Expression::Alias(_, expr) => vec![expr.as_ref().clone()], Expression::Column(_) => vec![], - Expression::Literal(_) => vec![], + Expression::Literal { .. } => vec![], Expression::Subquery { .. 
} => vec![], Expression::ScalarSubquery { .. } => vec![], Expression::UnaryExpression { expr, .. } => { @@ -574,7 +558,7 @@ impl RewriteHelper { Ok(match expr { Expression::Alias(_, expr) => Self::expression_plan_columns(expr)?, Expression::Column(_) => vec![expr.clone()], - Expression::Literal(_) => vec![], + Expression::Literal { .. } => vec![], Expression::Subquery { .. } => vec![], Expression::ScalarSubquery { .. } => vec![], Expression::UnaryExpression { expr, .. } => Self::expression_plan_columns(expr)?, @@ -647,7 +631,7 @@ impl RewriteHelper { Expression::Alias(alias.clone(), Box::from(expressions[0].clone())) } Expression::Column(_) => expr.clone(), - Expression::Literal(_) => expr.clone(), + Expression::Literal { .. } => expr.clone(), Expression::BinaryExpression { op, .. } => Expression::BinaryExpression { left: Box::new(expressions[0].clone()), op: op.clone(), diff --git a/fusequery/query/src/api/rpc/flight_actions_test.rs b/fusequery/query/src/api/rpc/flight_actions_test.rs index 4b3686b597d31..88656eaf2fe1d 100644 --- a/fusequery/query/src/api/rpc/flight_actions_test.rs +++ b/fusequery/query/src/api/rpc/flight_actions_test.rs @@ -21,7 +21,7 @@ async fn test_shuffle_action_try_into() -> Result<()> { stage_id: String::from("stage_id"), plan: parse_query("SELECT number FROM numbers(5)")?, sinks: vec![String::from("stream_id")], - scatters_expression: Expression::Literal(DataValue::UInt64(Some(1))), + scatters_expression: Expression::create_literal(DataValue::UInt64(Some(1))), }; let from_action = FlightAction::PrepareShuffleAction(shuffle_action); @@ -36,7 +36,7 @@ async fn test_shuffle_action_try_into() -> Result<()> { assert_eq!(action.sinks, vec![String::from("stream_id")]); assert_eq!( action.scatters_expression, - Expression::Literal(DataValue::UInt64(Some(1))) + Expression::create_literal(DataValue::UInt64(Some(1))) ); } } diff --git a/fusequery/query/src/api/rpc/flight_dispatcher_test.rs b/fusequery/query/src/api/rpc/flight_dispatcher_test.rs index 6d163839c450a..d891ce92d5892 100644 --- a/fusequery/query/src/api/rpc/flight_dispatcher_test.rs +++ b/fusequery/query/src/api/rpc/flight_dispatcher_test.rs @@ -51,7 +51,7 @@ async fn test_run_shuffle_action_with_no_scatters() -> Result<()> { stage_id: stage_id.clone(), plan: parse_query("SELECT number FROM numbers(5)")?, sinks: vec![stream_id.clone()], - scatters_expression: Expression::Literal(DataValue::UInt64(Some(1))), + scatters_expression: Expression::create_literal(DataValue::UInt64(Some(1))), }), )?; diff --git a/fusequery/query/src/api/rpc/flight_scatter_hash.rs b/fusequery/query/src/api/rpc/flight_scatter_hash.rs index a940dfb1d0e92..f33df20488f37 100644 --- a/fusequery/query/src/api/rpc/flight_scatter_hash.rs +++ b/fusequery/query/src/api/rpc/flight_scatter_hash.rs @@ -84,7 +84,7 @@ impl HashFlightScatter { expr: Box::new(expr), data_type: DataType::UInt64, }, - Expression::Literal(DataValue::UInt64(Some(num as u64))), + Expression::create_literal(DataValue::UInt64(Some(num as u64))), ], } } diff --git a/fusequery/query/src/api/rpc/flight_service_test.rs b/fusequery/query/src/api/rpc/flight_service_test.rs index 6151548211b18..8107299cc3df1 100644 --- a/fusequery/query/src/api/rpc/flight_service_test.rs +++ b/fusequery/query/src/api/rpc/flight_service_test.rs @@ -159,7 +159,7 @@ fn do_action_request(query_id: &str, stage_id: &str) -> Result> stage_id: String::from(stage_id), plan: parse_query("SELECT number FROM numbers(5)")?, sinks: vec![String::from("stream_id")], - scatters_expression: 
Expression::Literal(DataValue::UInt64(Some(1))), + scatters_expression: Expression::create_literal(DataValue::UInt64(Some(1))), }); Ok(Request::new(flight_action.try_into()?)) diff --git a/fusequery/query/src/datasources/system/numbers_table.rs b/fusequery/query/src/datasources/system/numbers_table.rs index aca4f6a9a0bc0..1be55ce96f4b6 100644 --- a/fusequery/query/src/datasources/system/numbers_table.rs +++ b/fusequery/query/src/datasources/system/numbers_table.rs @@ -78,8 +78,8 @@ impl Table for NumbersTable { ) -> Result { let mut total = None; let ScanPlan { table_args, .. } = scan.clone(); - if let Some(Expression::Literal(v)) = table_args { - total = Some(v.as_u64()?); + if let Some(Expression::Literal { value, .. }) = table_args { + total = Some(value.as_u64()?); } let total = total.ok_or_else(|| { diff --git a/fusequery/query/src/datasources/system/numbers_table_test.rs b/fusequery/query/src/datasources/system/numbers_table_test.rs index 5ff3f98308c0c..9ca050f88fa6d 100644 --- a/fusequery/query/src/datasources/system/numbers_table_test.rs +++ b/fusequery/query/src/datasources/system/numbers_table_test.rs @@ -19,7 +19,7 @@ async fn test_number_table() -> Result<()> { let scan = &ScanPlan { schema_name: "scan_test".to_string(), table_schema: DataSchemaRefExt::create(vec![]), - table_args: Some(Expression::Literal(DataValue::UInt64(Some(8)))), + table_args: Some(Expression::create_literal(DataValue::UInt64(Some(8)))), projected_schema: DataSchemaRefExt::create(vec![DataField::new( "number", DataType::UInt64, diff --git a/fusequery/query/src/functions/context_function.rs b/fusequery/query/src/functions/context_function.rs index eb8dd7e783453..34eb68c496824 100644 --- a/fusequery/query/src/functions/context_function.rs +++ b/fusequery/query/src/functions/context_function.rs @@ -26,10 +26,10 @@ impl ContextFunction { } Ok(match name.to_lowercase().as_str() { - "database" => vec![Expression::Literal(DataValue::Utf8(Some( + "database" => vec![Expression::create_literal(DataValue::Utf8(Some( ctx.get_current_database(), )))], - "version" => vec![Expression::Literal(DataValue::Utf8(Some( + "version" => vec![Expression::create_literal(DataValue::Utf8(Some( ctx.get_fuse_version(), )))], _ => vec![], diff --git a/fusequery/query/src/interpreters/plan_scheduler_test.rs b/fusequery/query/src/interpreters/plan_scheduler_test.rs index 909c23415441e..1fa3e8bcdfc99 100644 --- a/fusequery/query/src/interpreters/plan_scheduler_test.rs +++ b/fusequery/query/src/interpreters/plan_scheduler_test.rs @@ -52,7 +52,7 @@ async fn test_scheduler_plan_with_one_convergent_stage() -> Result<()> { let scheduler = PlanScheduler::try_create(context.clone())?; let scheduled_tasks = scheduler.reschedule(&PlanNode::Stage(StagePlan { kind: StageKind::Convergent, - scatters_expr: Expression::Literal(DataValue::UInt64(Some(0))), + scatters_expr: Expression::create_literal(DataValue::UInt64(Some(0))), input: Arc::new(PlanNode::Empty(EmptyPlan::cluster())), }))?; @@ -69,7 +69,7 @@ async fn test_scheduler_plan_with_one_convergent_stage() -> Result<()> { assert_eq!(remote_actions[0].1.sinks, vec![String::from("dummy_local")]); assert_eq!( remote_actions[0].1.scatters_expression, - Expression::Literal(DataValue::UInt64(Some(0))) + Expression::create_literal(DataValue::UInt64(Some(0))) ); assert_eq!( remote_actions[0].1.plan, @@ -80,7 +80,7 @@ async fn test_scheduler_plan_with_one_convergent_stage() -> Result<()> { assert_eq!(remote_actions[1].1.sinks, vec![String::from("dummy_local")]); assert_eq!( 
remote_actions[1].1.scatters_expression, - Expression::Literal(DataValue::UInt64(Some(0))) + Expression::create_literal(DataValue::UInt64(Some(0))) ); assert_eq!( remote_actions[1].1.plan, @@ -122,7 +122,7 @@ async fn test_scheduler_plan_with_convergent_and_expansive_stage() -> Result<()> let scheduled_tasks = scheduler.reschedule(&PlanNode::Select(SelectPlan { input: Arc::new(PlanNode::Stage(StagePlan { kind: StageKind::Convergent, - scatters_expr: Expression::Literal(DataValue::UInt64(Some(0))), + scatters_expr: Expression::create_literal(DataValue::UInt64(Some(0))), input: Arc::new(PlanNode::Select(SelectPlan { input: Arc::new(PlanNode::Stage(StagePlan { kind: StageKind::Expansive, @@ -165,14 +165,14 @@ async fn test_scheduler_plan_with_convergent_and_expansive_stage() -> Result<()> assert_eq!(remote_actions[1].1.sinks, vec![String::from("dummy_local")]); assert_eq!( remote_actions[1].1.scatters_expression, - Expression::Literal(DataValue::UInt64(Some(0))) + Expression::create_literal(DataValue::UInt64(Some(0))) ); assert_eq!(remote_actions[2].0.name, String::from("dummy")); assert_eq!(remote_actions[2].1.sinks, vec![String::from("dummy_local")]); assert_eq!( remote_actions[2].1.scatters_expression, - Expression::Literal(DataValue::UInt64(Some(0))) + Expression::create_literal(DataValue::UInt64(Some(0))) ); // Perform the same plan in different nodes @@ -225,11 +225,11 @@ async fn test_scheduler_plan_with_convergent_and_normal_stage() -> Result<()> { let scheduled_tasks = plan_scheduler.reschedule(&PlanNode::Select(SelectPlan { input: Arc::new(PlanNode::Stage(StagePlan { kind: StageKind::Convergent, - scatters_expr: Expression::Literal(DataValue::UInt64(Some(1))), + scatters_expr: Expression::create_literal(DataValue::UInt64(Some(1))), input: Arc::new(PlanNode::Select(SelectPlan { input: Arc::new(PlanNode::Stage(StagePlan { kind: StageKind::Normal, - scatters_expr: Expression::Literal(DataValue::UInt64(Some(0))), + scatters_expr: Expression::create_literal(DataValue::UInt64(Some(0))), input: Arc::new(PlanNode::Empty(EmptyPlan::cluster())), })), })), @@ -252,7 +252,7 @@ async fn test_scheduler_plan_with_convergent_and_normal_stage() -> Result<()> { ]); assert_eq!( remote_actions[0].1.scatters_expression, - Expression::Literal(DataValue::UInt64(Some(0))) + Expression::create_literal(DataValue::UInt64(Some(0))) ); assert_eq!( remote_actions[0].1.plan, @@ -266,7 +266,7 @@ async fn test_scheduler_plan_with_convergent_and_normal_stage() -> Result<()> { ]); assert_eq!( remote_actions[2].1.scatters_expression, - Expression::Literal(DataValue::UInt64(Some(0))) + Expression::create_literal(DataValue::UInt64(Some(0))) ); assert_eq!( remote_actions[2].1.plan, @@ -277,14 +277,14 @@ async fn test_scheduler_plan_with_convergent_and_normal_stage() -> Result<()> { assert_eq!(remote_actions[1].1.sinks, vec![String::from("dummy_local")]); assert_eq!( remote_actions[1].1.scatters_expression, - Expression::Literal(DataValue::UInt64(Some(1))) + Expression::create_literal(DataValue::UInt64(Some(1))) ); assert_eq!(remote_actions[3].0.name, String::from("dummy")); assert_eq!(remote_actions[3].1.sinks, vec![String::from("dummy_local")]); assert_eq!( remote_actions[3].1.scatters_expression, - Expression::Literal(DataValue::UInt64(Some(1))) + Expression::create_literal(DataValue::UInt64(Some(1))) ); // Perform the same plan in different nodes diff --git a/fusequery/query/src/optimizers/optimizer.rs b/fusequery/query/src/optimizers/optimizer.rs index 1a50b9f39ebe2..800a31d71edca 100644 --- 
a/fusequery/query/src/optimizers/optimizer.rs +++ b/fusequery/query/src/optimizers/optimizer.rs @@ -7,7 +7,8 @@ use common_planners::PlanNode; use common_tracing::tracing; use crate::optimizers::optimizer_scatters::ScattersOptimizer; -use crate::optimizers::{ProjectionPushDownOptimizer, ConstantFoldingOptimizer}; +use crate::optimizers::ConstantFoldingOptimizer; +use crate::optimizers::ProjectionPushDownOptimizer; use crate::optimizers::StatisticsExactOptimizer; use crate::sessions::FuseQueryContextRef; diff --git a/fusequery/query/src/optimizers/optimizer_constant_folding.rs b/fusequery/query/src/optimizers/optimizer_constant_folding.rs index 4b720805f049d..79c9d6a234a57 100644 --- a/fusequery/query/src/optimizers/optimizer_constant_folding.rs +++ b/fusequery/query/src/optimizers/optimizer_constant_folding.rs @@ -2,21 +2,22 @@ // // SPDX-License-Identifier: Apache-2.0. +use common_datablocks::DataBlock; use common_datavalues::prelude::*; use common_exception::ErrorCode; use common_exception::Result; -use common_planners::{AggregatorFinalPlan, Expressions, SchemaChanges}; +use common_functions::scalars::FunctionFactory; +use common_planners::AggregatorFinalPlan; use common_planners::AggregatorPartialPlan; use common_planners::Expression; +use common_planners::Expressions; use common_planners::PlanBuilder; use common_planners::PlanNode; use common_planners::PlanRewriter; use crate::optimizers::Optimizer; -use crate::sessions::FuseQueryContextRef; -use common_functions::scalars::FunctionFactory; use crate::pipelines::transforms::ExpressionExecutor; -use common_datablocks::DataBlock; +use crate::sessions::FuseQueryContextRef; pub struct ConstantFoldingOptimizer {} @@ -30,29 +31,46 @@ impl ConstantFoldingImpl { } fn constants_arguments(args: &[Expression]) -> bool { - !args.iter().any(|expr| !matches!(expr, Expression::Literal(_))) + !args + .iter() + .any(|expr| !matches!(expr, Expression::Literal { .. 
})) } - fn rewrite_function(op: &str, args: Expressions) -> Result> { + fn rewrite_function(op: &str, args: Expressions, name: String, f: F) -> Result + where F: Fn(&str, Expressions) -> Expression { let function = FunctionFactory::get(op)?; - match function.is_deterministic() { - true => Self::rewrite_deterministic_function(op, args), - false => Ok(None) + + if function.is_deterministic() && ConstantFoldingImpl::constants_arguments(&args) { + let op = op.to_string(); + return ConstantFoldingImpl::execute_expression( + Expression::ScalarFunction { op, args }, + name, + ); } + + Ok(f(op, args)) } - fn rewrite_deterministic_function(op: &str, args: Expressions) -> Result> { - match ConstantFoldingImpl::constants_arguments(&args) { - false => Ok(None), - true => ConstantFoldingImpl::execute_expression(Expression::ScalarFunction { - op: op.to_string(), - args: args, - }), - } + fn create_scalar_function(op: &str, args: Expressions) -> Expression { + let op = op.to_string(); + Expression::ScalarFunction { op, args } + } + + fn create_unary_expression(op: &str, mut args: Expressions) -> Expression { + let op = op.to_string(); + let expr = Box::new(args.remove(0)); + Expression::UnaryExpression { op, expr } + } + + fn create_binary_expression(op: &str, mut args: Expressions) -> Expression { + let op = op.to_string(); + let left = Box::new(args.remove(0)); + let right = Box::new(args.remove(0)); + Expression::BinaryExpression { op, left, right } } fn expr_executor(schema: &DataSchemaRef, expr: Expression) -> Result { - let output_fields = vec![expr.to_data_field(&schema)?]; + let output_fields = vec![expr.to_data_field(schema)?]; let output_schema = DataSchemaRefExt::create(output_fields); ExpressionExecutor::try_create( "Constant folding optimizer.", @@ -63,95 +81,112 @@ impl ConstantFoldingImpl { ) } - fn execute_expression(expression: Expression) -> Result> { + fn execute_expression(expression: Expression, origin_name: String) -> Result { let input_fields = vec![DataField::new("_dummy", DataType::UInt8, false)]; let input_schema = Arc::new(DataSchema::new(input_fields)); let expression_executor = Self::expr_executor(&input_schema, expression)?; let dummy_columns = vec![DataColumn::Constant(DataValue::UInt8(Some(1)), 1)]; - let data_block = DataBlock::create(input_schema.clone(), dummy_columns); + let data_block = DataBlock::create(input_schema, dummy_columns); let executed_data_block = expression_executor.execute(&data_block)?; - assert_eq!(executed_data_block.num_rows(), 1); - assert_eq!(executed_data_block.num_columns(), 1); - Ok(Some(Expression::Literal(executed_data_block.column(0).to_values()?[0].clone()))) + ConstantFoldingImpl::convert_to_expression(origin_name, executed_data_block) + } + + fn convert_to_expression(column_name: String, data_block: DataBlock) -> Result { + assert_eq!(data_block.num_rows(), 1); + assert_eq!(data_block.num_columns(), 1); + + let column_name = Some(column_name); + let value = data_block.column(0).to_values()?.remove(0); + Ok(Expression::Literal { value, column_name }) } } impl PlanRewriter for ConstantFoldingImpl { - fn rewrite_expr(&mut self, changes: &SchemaChanges, expr: &Expression) -> Result { - match expr { + fn rewrite_expr(&mut self, schema: &DataSchemaRef, origin: &Expression) -> Result { + /* TODO: constant folding for subquery and scalar subquery + * For example: + * before optimize: SELECT (SELECT 1 + 2) + * after optimize: SELECT 3 + */ + match origin { Expression::Alias(alias, expr) => { - Self::rewrite_alias(alias, self.rewrite_expr(changes, 
expr)?) - }, + Self::rewrite_alias(alias, self.rewrite_expr(schema, expr)?) + } Expression::ScalarFunction { op, args } => { let new_args = args .iter() - .map(|expr| Self::rewrite_expr(self, changes, expr)) + .map(|expr| Self::rewrite_expr(self, schema, expr)) .collect::>>()?; - match Self::rewrite_function(op, new_args.clone())? { - Some(new_expr) => Ok(new_expr), - None => Ok(Expression::ScalarFunction { - op: op.clone(), - args: new_args, - }), - } + let origin_name = origin.column_name(); + Self::rewrite_function(op, new_args, origin_name, Self::create_scalar_function) } - Expression::UnaryExpression { op, expr: inner_expr } => { - let new_expr = self.rewrite_expr(changes, inner_expr)?; - match Self::rewrite_function(op, vec![new_expr.clone()])? { - Some(new_expr) => Ok(new_expr), - None => Ok(Expression::UnaryExpression { - op: op.clone(), - expr: Box::new(new_expr), - }), - } + Expression::UnaryExpression { op, expr } => { + let origin_name = origin.column_name(); + let new_expr = vec![self.rewrite_expr(schema, expr)?]; + Self::rewrite_function(op, new_expr, origin_name, Self::create_unary_expression) } Expression::BinaryExpression { op, left, right } => { - let new_left = self.rewrite_expr(changes, left)?; - let new_right = self.rewrite_expr(changes, right)?; - match Self::rewrite_function(op, vec![new_left.clone(), new_right.clone()])? { - Some(new_expr) => Ok(new_expr), - None => Ok(Expression::BinaryExpression { - op: op.clone(), - left: Box::new(new_left), - right: Box::new(new_right), - }), - } - }, + let new_left = self.rewrite_expr(schema, left)?; + let new_right = self.rewrite_expr(schema, right)?; + + let origin_name = origin.column_name(); + let new_exprs = vec![new_left, new_right]; + Self::rewrite_function(op, new_exprs, origin_name, Self::create_binary_expression) + } Expression::Cast { expr, data_type } => { - let new_expr = self.rewrite_expr(changes, expr)?; - match &new_expr { - Expression::Literal(_) => Ok(Self::execute_expression( - Expression::Cast { - expr: Box::new(new_expr), - data_type: data_type.clone(), - } - )?.unwrap()), - _ => Ok(new_expr) + let new_expr = self.rewrite_expr(schema, expr)?; + + if matches!(&new_expr, Expression::Literal { .. 
}) { + let optimize_expr = Expression::Cast { + expr: Box::new(new_expr), + data_type: data_type.clone(), + }; + + return Self::execute_expression(optimize_expr, origin.column_name()); } + + Ok(Expression::Cast { + expr: Box::new(new_expr), + data_type: data_type.clone(), + }) + } + Expression::Sort { + expr, + asc, + nulls_first, + } => { + let new_expr = self.rewrite_expr(schema, expr)?; + Ok(ConstantFoldingImpl::create_sort(asc, nulls_first, new_expr)) } - Expression::Column(column_name) => { - let field_pos = changes.before_input_schema.index_of(column_name)?; - let new_field = changes.after_input_schema.field(field_pos); - Ok(Expression::Column(new_field.name().to_string())) + Expression::AggregateFunction { op, distinct, args } => { + let args = args + .iter() + .map(|expr| Self::rewrite_expr(self, schema, expr)) + .collect::>>()?; + + let op = op.clone(); + let distinct = *distinct; + Ok(Expression::AggregateFunction { op, distinct, args }) } - _ => Ok(expr.clone()), + _ => Ok(origin.clone()), } } fn rewrite_aggregate_partial(&mut self, plan: &AggregatorPartialPlan) -> Result { let new_input = self.rewrite_plan_node(&plan.input)?; - match self.before_group_by_schema { Some(_) => Err(ErrorCode::LogicalError( "Logical error: before group by schema must be None", )), None => { self.before_group_by_schema = Some(new_input.schema()); + let new_aggr_expr = self.rewrite_exprs(&new_input.schema(), &plan.aggr_expr)?; + let new_group_expr = self.rewrite_exprs(&new_input.schema(), &plan.group_expr)?; PlanBuilder::from(&new_input) - .aggregate_partial(&plan.aggr_expr, &plan.group_expr)? + .aggregate_partial(&new_aggr_expr, &new_group_expr)? .build() } } @@ -164,9 +199,13 @@ impl PlanRewriter for ConstantFoldingImpl { None => Err(ErrorCode::LogicalError( "Logical error: before group by schema must be Some", )), - Some(schema_before_group_by) => PlanBuilder::from(&new_input) - .aggregate_final(schema_before_group_by, &plan.aggr_expr, &plan.group_expr)? - .build(), + Some(schema_before_group_by) => { + let new_aggr_expr = self.rewrite_exprs(&new_input.schema(), &plan.aggr_expr)?; + let new_group_expr = self.rewrite_exprs(&new_input.schema(), &plan.group_expr)?; + PlanBuilder::from(&new_input) + .aggregate_final(schema_before_group_by, &new_aggr_expr, &new_group_expr)? 
+ .build() + } } } } @@ -195,3 +234,13 @@ impl ConstantFoldingOptimizer { ConstantFoldingOptimizer {} } } + +impl ConstantFoldingImpl { + fn create_sort(asc: &bool, nulls_first: &bool, new_expr: Expression) -> Expression { + Expression::Sort { + expr: Box::new(new_expr), + asc: *asc, + nulls_first: *nulls_first, + } + } +} diff --git a/fusequery/query/src/optimizers/optimizer_constant_folding_test.rs b/fusequery/query/src/optimizers/optimizer_constant_folding_test.rs index f2031813d99a2..f21effb725853 100644 --- a/fusequery/query/src/optimizers/optimizer_constant_folding_test.rs +++ b/fusequery/query/src/optimizers/optimizer_constant_folding_test.rs @@ -5,6 +5,7 @@ #[cfg(test)] mod tests { use common_exception::Result; + use crate::optimizers::*; #[test] @@ -21,7 +22,7 @@ mod tests { name: "Projection const recursion", query: "SELECT 1 + 2 + 3", expect: "\ - Projection: 6:UInt32\ + Projection: ((1 + 2) + 3):UInt32\ \n Expression: 6:UInt32 (Before Projection)\ \n ReadDataSource: scan partitions: [1], scan schema: [dummy:UInt8], statistics: [read_rows: 1, read_bytes: 1]", }, @@ -37,7 +38,7 @@ mod tests { name: "Projection right non const recursion", query: "SELECT 1 + 2 + 3 + dummy", expect: "\ - Projection: (6 + dummy):UInt64\ + Projection: (((1 + 2) + 3) + dummy):UInt64\ \n Expression: (6 + dummy):UInt64 (Before Projection)\ \n ReadDataSource: scan partitions: [1], scan schema: [dummy:UInt8], statistics: [read_rows: 1, read_bytes: 1]", }, @@ -45,7 +46,7 @@ mod tests { name: "Projection arithmetic const recursion", query: "SELECT 1 + 2 + 3 / 3", expect: "\ - Projection: 4:Float64\ + Projection: ((1 + 2) + (3 / 3)):Float64\ \n Expression: 4:Float64 (Before Projection)\ \n ReadDataSource: scan partitions: [1], scan schema: [dummy:UInt8], statistics: [read_rows: 1, read_bytes: 1]", }, @@ -53,7 +54,7 @@ mod tests { name: "Projection comparisons const recursion", query: "SELECT 1 + 2 + 3 > 3", expect: "\ - Projection: true:Boolean\ + Projection: (((1 + 2) + 3) > 3):Boolean\ \n Expression: true:Boolean (Before Projection)\ \n ReadDataSource: scan partitions: [1], scan schema: [dummy:UInt8], statistics: [read_rows: 1, read_bytes: 1]", }, @@ -61,7 +62,7 @@ mod tests { name: "Projection cast const recursion", query: "SELECT CAST(1 AS bigint)", expect: "\ - Projection: 1:Int64\ + Projection: cast(1 as Int64):Int64\ \n Expression: 1:Int64 (Before Projection)\ \n ReadDataSource: scan partitions: [1], scan schema: [dummy:UInt8], statistics: [read_rows: 1, read_bytes: 1]", }, @@ -69,7 +70,7 @@ mod tests { name: "Projection hash const recursion", query: "SELECT sipHash('test_string')", expect: "\ - Projection: 17123704338732264132:UInt64\ + Projection: sipHash(test_string):UInt64\ \n Expression: 17123704338732264132:UInt64 (Before Projection)\ \n ReadDataSource: scan partitions: [1], scan schema: [dummy:UInt8], statistics: [read_rows: 1, read_bytes: 1]", }, @@ -77,7 +78,7 @@ mod tests { name: "Projection logics const recursion", query: "SELECT 1 = 1 AND 2 > 1", expect: "\ - Projection: true:Boolean\ + Projection: ((1 = 1) AND (2 > 1)):Boolean\ \n Expression: true:Boolean (Before Projection)\ \n ReadDataSource: scan partitions: [1], scan schema: [dummy:UInt8], statistics: [read_rows: 1, read_bytes: 1]", }, @@ -85,7 +86,7 @@ mod tests { name: "Projection strings const recursion", query: "SELECT SUBSTRING('1234567890' FROM 3 FOR 3)", expect: "\ - Projection: 345:Utf8\ + Projection: substring(1234567890, 3, 3):Utf8\ \n Expression: 345:Utf8 (Before Projection)\ \n ReadDataSource: scan partitions: [1], scan 
schema: [dummy:UInt8], statistics: [read_rows: 1, read_bytes: 1]", }, @@ -93,7 +94,7 @@ mod tests { name: "Projection to type name const recursion", query: "SELECT toTypeName('1234567890')", expect: "\ - Projection: Utf8:Utf8\ + Projection: toTypeName(1234567890):Utf8\ \n Expression: Utf8:Utf8 (Before Projection)\ \n ReadDataSource: scan partitions: [1], scan schema: [dummy:UInt8], statistics: [read_rows: 1, read_bytes: 1]", }, diff --git a/fusequery/query/src/optimizers/optimizer_projection_push_down.rs b/fusequery/query/src/optimizers/optimizer_projection_push_down.rs index c36292f2b05a7..6bcef5bec3487 100644 --- a/fusequery/query/src/optimizers/optimizer_projection_push_down.rs +++ b/fusequery/query/src/optimizers/optimizer_projection_push_down.rs @@ -10,7 +10,7 @@ use common_datavalues::DataSchemaRef; use common_datavalues::DataSchemaRefExt; use common_exception::ErrorCode; use common_exception::Result; -use common_planners::{AggregatorFinalPlan, SchemaChanges}; +use common_planners::AggregatorFinalPlan; use common_planners::AggregatorPartialPlan; use common_planners::EmptyPlan; use common_planners::Expression; @@ -76,27 +76,24 @@ impl PlanRewriter for ProjectionPushDownImpl { self.collect_column_names_from_expr_vec(plan.expr.as_slice())?; self.has_projection = true; let new_input = self.rewrite_plan_node(&plan.input)?; - let schema_changes = SchemaChanges::new(&plan.schema, &new_input.schema()); PlanBuilder::from(&new_input) - .project(&self.rewrite_exprs(&schema_changes, &plan.expr)?)? + .project(&self.rewrite_exprs(&new_input.schema(), &plan.expr)?)? .build() } fn rewrite_filter(&mut self, plan: &FilterPlan) -> Result { self.collect_column_names_from_expr(&plan.predicate)?; let new_input = self.rewrite_plan_node(&plan.input)?; - let schema_changes = SchemaChanges::new(&plan.schema, &new_input.schema()); PlanBuilder::from(&new_input) - .filter(self.rewrite_expr(&schema_changes, &plan.predicate)?)? + .filter(self.rewrite_expr(&new_input.schema(), &plan.predicate)?)? .build() } fn rewrite_sort(&mut self, plan: &SortPlan) -> Result { self.collect_column_names_from_expr_vec(plan.order_by.as_slice())?; let new_input = self.rewrite_plan_node(&plan.input)?; - let schema_changes = SchemaChanges::new(&plan.schema, &new_input.schema()); PlanBuilder::from(&new_input) - .sort(&self.rewrite_exprs(&schema_changes, &plan.order_by)?)? + .sort(&self.rewrite_exprs(&new_input.schema(), &plan.order_by)?)? 
.build() } diff --git a/fusequery/query/src/optimizers/optimizer_scatters.rs b/fusequery/query/src/optimizers/optimizer_scatters.rs index d09793522d044..6c131a3e3acab 100644 --- a/fusequery/query/src/optimizers/optimizer_scatters.rs +++ b/fusequery/query/src/optimizers/optimizer_scatters.rs @@ -168,7 +168,7 @@ impl ScattersOptimizerImpl { fn convergent_shuffle_stage_builder(input: Arc) -> PlanBuilder { PlanBuilder::from(&PlanNode::Stage(StagePlan { kind: StageKind::Convergent, - scatters_expr: Expression::Literal(DataValue::UInt64(Some(0))), + scatters_expr: Expression::create_literal(DataValue::UInt64(Some(0))), input, })) } @@ -176,7 +176,7 @@ impl ScattersOptimizerImpl { fn convergent_shuffle_stage(input: PlanNode) -> Result { Ok(PlanNode::Stage(StagePlan { kind: StageKind::Convergent, - scatters_expr: Expression::Literal(DataValue::UInt64(Some(0))), + scatters_expr: Expression::create_literal(DataValue::UInt64(Some(0))), input: Arc::new(input), })) } @@ -307,7 +307,7 @@ impl Optimizer for ScattersOptimizer { RunningMode::Standalone => Ok(rewrite_plan), RunningMode::Cluster => Ok(PlanNode::Stage(StagePlan { kind: StageKind::Convergent, - scatters_expr: Expression::Literal(DataValue::UInt64(Some(0))), + scatters_expr: Expression::create_literal(DataValue::UInt64(Some(0))), input: Arc::new(rewrite_plan), })), } diff --git a/fusequery/query/src/optimizers/optimizer_statistics_exact.rs b/fusequery/query/src/optimizers/optimizer_statistics_exact.rs index c43fc4d449e3a..9098ac01e6677 100644 --- a/fusequery/query/src/optimizers/optimizer_statistics_exact.rs +++ b/fusequery/query/src/optimizers/optimizer_statistics_exact.rs @@ -41,7 +41,7 @@ impl PlanRewriter for StatisticsExactImpl<'_> { }], PlanNode::Expression(ExpressionPlan { input, .. }), ) if op == "count" && args.len() == 1 => match (&args[0], input.as_ref()) { - (Expression::Literal(_), PlanNode::ReadSource(read_source_plan)) + (Expression::Literal { .. }, PlanNode::ReadSource(read_source_plan)) if read_source_plan.statistics.is_exact => { let db_name = "system"; @@ -73,7 +73,9 @@ impl PlanRewriter for StatisticsExactImpl<'_> { let ser = serde_json::to_string(&states)?; PlanBuilder::from(&dummy_read_plan) .expression( - &[Expression::Literal(DataValue::Utf8(Some(ser.clone())))], + &[Expression::create_literal(DataValue::Utf8(Some( + ser.clone(), + )))], "Exact Statistics", )? .project(&[Expression::Column(ser).alias("count(0)")])? diff --git a/fusequery/query/src/optimizers/optimizer_statistics_exact_test.rs b/fusequery/query/src/optimizers/optimizer_statistics_exact_test.rs index 8660ab74a30bc..646475a0c6549 100644 --- a/fusequery/query/src/optimizers/optimizer_statistics_exact_test.rs +++ b/fusequery/query/src/optimizers/optimizer_statistics_exact_test.rs @@ -46,12 +46,12 @@ mod tests { let aggr_expr = Expression::AggregateFunction { op: "count".to_string(), distinct: false, - args: vec![Expression::Literal(DataValue::UInt64(Some(0)))], + args: vec![Expression::create_literal(DataValue::UInt64(Some(0)))], }; let plan = PlanBuilder::from(&source_plan) .expression( - &[Expression::Literal(DataValue::UInt64(Some(0)))], + &[Expression::create_literal(DataValue::UInt64(Some(0)))], "Before GroupBy", )? .aggregate_partial(&[aggr_expr.clone()], &[])? 
diff --git a/fusequery/query/src/sql/plan_parser.rs b/fusequery/query/src/sql/plan_parser.rs index 95096a6efb08e..a8bc63ee51c77 100644 --- a/fusequery/query/src/sql/plan_parser.rs +++ b/fusequery/query/src/sql/plan_parser.rs @@ -803,11 +803,11 @@ impl PlanParser { fn value_to_rex(value: &sqlparser::ast::Value) -> Result { match value { sqlparser::ast::Value::Number(ref n, _) => { - DataValue::try_from_literal(n).map(Expression::Literal) - } - sqlparser::ast::Value::SingleQuotedString(ref value) => { - Ok(Expression::Literal(DataValue::Utf8(Some(value.clone())))) + DataValue::try_from_literal(n).map(Expression::create_literal) } + sqlparser::ast::Value::SingleQuotedString(ref value) => Ok( + Expression::create_literal(DataValue::Utf8(Some(value.clone()))), + ), sqlparser::ast::Value::Interval { value, leading_field, @@ -822,7 +822,7 @@ impl PlanParser { fractional_seconds_precision, ), sqlparser::ast::Value::Boolean(b) => { - Ok(Expression::Literal(DataValue::Boolean(Some(*b)))) + Ok(Expression::create_literal(DataValue::Boolean(Some(*b)))) } other => Result::Err(ErrorCode::SyntaxException(format!( "Unsupported value expression: {}, type: {:?}", @@ -903,7 +903,9 @@ impl PlanParser { sqlparser::ast::Expr::Wildcard => Ok(Expression::Wildcard), sqlparser::ast::Expr::TypedString { data_type, value } => { SQLCommon::make_data_type(data_type).map(|data_type| Expression::Cast { - expr: Box::new(Expression::Literal(DataValue::Utf8(Some(value.clone())))), + expr: Box::new(Expression::create_literal(DataValue::Utf8(Some( + value.clone(), + )))), data_type, }) } @@ -924,7 +926,7 @@ impl PlanParser { if let Some(from) = substring_from { args.push(self.sql_to_rex(from, schema, select)?); } else { - args.push(Expression::Literal(DataValue::Int64(Some(1)))); + args.push(Expression::create_literal(DataValue::Int64(Some(1)))); } if let Some(len) = substring_for { @@ -1097,7 +1099,7 @@ impl PlanParser { .map(|limit_expr| { self.sql_to_rex(limit_expr, &input.schema(), select) .and_then(|limit_expr| match limit_expr { - Expression::Literal(v) => Ok(v.as_u64()? as usize), + Expression::Literal { value, .. } => Ok(value.as_u64()? as usize), _ => Err(ErrorCode::SyntaxException(format!( "Unexpected expression for LIMIT clause: {:?}", limit_expr @@ -1112,7 +1114,7 @@ impl PlanParser { let offset_expr = &offset.value; self.sql_to_rex(offset_expr, &input.schema(), select) .and_then(|offset_expr| match offset_expr { - Expression::Literal(v) => Ok(v.as_u64()? as usize), + Expression::Literal { value, .. } => Ok(value.as_u64()? 
as usize), _ => Err(ErrorCode::SyntaxException(format!( "Unexpected expression for OFFSET clause: {:?}", offset_expr, diff --git a/fusequery/query/src/sql/sql_common.rs b/fusequery/query/src/sql/sql_common.rs index e5319ff77e7fd..cb9b9d9156f7f 100644 --- a/fusequery/query/src/sql/sql_common.rs +++ b/fusequery/query/src/sql/sql_common.rs @@ -223,14 +223,14 @@ impl SQLCommon { } if result_month != 0 { - return Ok(Expression::Literal(DataValue::IntervalYearMonth(Some( - result_month as i32, - )))); + return Ok(Expression::create_literal(DataValue::IntervalYearMonth( + Some(result_month as i32), + ))); } let result: i64 = (result_days << 32) | result_millis; - Ok(Expression::Literal(DataValue::IntervalDayTime(Some( - result, - )))) + Ok(Expression::create_literal(DataValue::IntervalDayTime( + Some(result), + ))) } } diff --git a/fusequery/query/src/tests/number.rs b/fusequery/query/src/tests/number.rs index 1eb6b4da8bee7..9b3e8ccecf7e3 100644 --- a/fusequery/query/src/tests/number.rs +++ b/fusequery/query/src/tests/number.rs @@ -45,7 +45,7 @@ impl NumberTestData { &ScanPlan { schema_name: self.db.to_string(), table_schema: Arc::new(DataSchema::empty()), - table_args: Some(Expression::Literal(DataValue::Int64(Some(numbers)))), + table_args: Some(Expression::create_literal(DataValue::Int64(Some(numbers)))), projected_schema: Arc::new(DataSchema::empty()), push_downs: Extras::default(), }, diff --git a/tests/suites/0_stateless/08_0000_optimizer.result b/tests/suites/0_stateless/08_0000_optimizer.result index 33000a1ec996c..06ac61bcd46c9 100644 --- a/tests/suites/0_stateless/08_0000_optimizer.result +++ b/tests/suites/0_stateless/08_0000_optimizer.result @@ -1,5 +1,3 @@ -filter push down: push (number+1) to filter -Projection: (number + 1) as a:UInt64\n Expression: (number + 1):UInt64 (Before Projection)\n Filter: (a > 2)\n ReadDataSource: scan partitions: [16], scan schema: [number:UInt64], statistics: [read_rows: 10000, read_bytes: 80000] limit push down: push (limit 10) to projection group by push down: push alias to group by Projection: max((number + 1)) as c1:UInt64, ((number % 3) + 1) as c2:UInt16\n AggregatorFinal: groupBy=[[((number % 3) + 1)]], aggr=[[max((number + 1))]]\n AggregatorPartial: groupBy=[[((number % 3) + 1)]], aggr=[[max((number + 1))]]\n Expression: ((number % 3) + 1):UInt16, (number + 1):UInt64 (Before GroupBy)\n ReadDataSource: scan partitions: [16], scan schema: [number:UInt64], statistics: [read_rows: 10000, read_bytes: 80000] diff --git a/tests/suites/0_stateless/08_0000_optimizer.sql b/tests/suites/0_stateless/08_0000_optimizer.sql index 76e40695636a4..b93f1f38c040c 100644 --- a/tests/suites/0_stateless/08_0000_optimizer.sql +++ b/tests/suites/0_stateless/08_0000_optimizer.sql @@ -1,6 +1,8 @@ SET max_threads=16; -SELECT 'filter push down: push (number+1) to filter'; -EXPLAIN SELECT (number+1) as a from numbers_mt(10000) where a > 2; + +-- https://github.com/datafuselabs/datafuse/issues/574 +-- SELECT 'filter push down: push (number+1) to filter'; +-- EXPLAIN SELECT (number+1) as a from numbers_mt(10000) where a > 2; SELECT 'limit push down: push (limit 10) to projection'; From 110d455f5b9ed52863bd92843b7df2a4c62e8702 Mon Sep 17 00:00:00 2001 From: zhang2014 Date: Mon, 19 Jul 2021 22:28:48 +0800 Subject: [PATCH 4/5] Try fix cluster stateless test --- tests/suites/0_stateless/08_0000_optimizer_cluster.result | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/suites/0_stateless/08_0000_optimizer_cluster.result 
b/tests/suites/0_stateless/08_0000_optimizer_cluster.result index e51ff78e862e5..90759ca03010c 100644 --- a/tests/suites/0_stateless/08_0000_optimizer_cluster.result +++ b/tests/suites/0_stateless/08_0000_optimizer_cluster.result @@ -1,5 +1,3 @@ -filter push down: push (number+1) to filter -RedistributeStage[expr: 0]\n Projection: (number + 1) as a:UInt64\n Expression: (number + 1):UInt64 (Before Projection)\n Filter: (a > 2)\n ReadDataSource: scan partitions: [16], scan schema: [number:UInt64], statistics: [read_rows: 10000, read_bytes: 80000] limit push down: push (limit 10) to projection group by push down: push alias to group by RedistributeStage[expr: 0]\n Projection: max((number + 1)) as c1:UInt64, ((number % 3) + 1) as c2:UInt16\n AggregatorFinal: groupBy=[[((number % 3) + 1)]], aggr=[[max((number + 1))]]\n RedistributeStage[expr: sipHash(_group_by_key)]\n AggregatorPartial: groupBy=[[((number % 3) + 1)]], aggr=[[max((number + 1))]]\n Expression: ((number % 3) + 1):UInt16, (number + 1):UInt64 (Before GroupBy)\n ReadDataSource: scan partitions: [16], scan schema: [number:UInt64], statistics: [read_rows: 10000, read_bytes: 80000] From 096e0a75f5e62e78c028c955fe3ab0e05c8048c2 Mon Sep 17 00:00:00 2001 From: zhang2014 Date: Tue, 20 Jul 2021 12:00:32 +0800 Subject: [PATCH 5/5] Add comment for constant folding --- common/functions/src/scalars/function.rs | 1 + common/planners/src/plan_expression.rs | 1 + 2 files changed, 2 insertions(+) diff --git a/common/functions/src/scalars/function.rs b/common/functions/src/scalars/function.rs index 5cbc773434344..8b207fa8a5437 100644 --- a/common/functions/src/scalars/function.rs +++ b/common/functions/src/scalars/function.rs @@ -27,6 +27,7 @@ pub trait Function: fmt::Display + Sync + Send + DynClone { fn nullable(&self, _input_schema: &DataSchema) -> Result; fn eval(&self, columns: &[DataColumn], _input_rows: usize) -> Result; + // If a function returns the same result for the same arguments, it is a deterministic function. fn is_deterministic(&self) -> bool { true } diff --git a/common/planners/src/plan_expression.rs b/common/planners/src/plan_expression.rs index 2f33408960a6f..ed03daff02bc2 100644 --- a/common/planners/src/plan_expression.rs +++ b/common/planners/src/plan_expression.rs @@ -48,6 +48,7 @@ pub enum Expression { /// Column name. Column(String), /// Constant value. + /// Note: When a literal represents a column, its column_name will not be None Literal { value: DataValue, column_name: Option,
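
The folding rule this series introduces can be summarized as: rewrite a scalar function call into a literal only when the function is reported as deterministic and every argument is already a literal, then obtain the value by evaluating the expression over a one-row dummy block. Below is a minimal, self-contained sketch of that rule. Expr, is_deterministic, eval_constant_call and fold are simplified stand-ins invented for illustration; they are not the Expression, Function, FunctionFactory or ExpressionExecutor APIs touched by this patch.

// Sketch only: a toy fold over integer expressions, mirroring the shape of the rule above.
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Literal(i64),
    ScalarFunction { op: String, args: Vec<Expr> },
}

// Stand-in for the per-function deterministic flag: rand-like functions must not be folded.
fn is_deterministic(op: &str) -> bool {
    op != "rand"
}

// Stand-in for evaluating a call whose arguments are all constants.
fn eval_constant_call(op: &str, args: &[i64]) -> Option<i64> {
    match op {
        "plus" => Some(args.iter().sum()),
        "multiply" => Some(args.iter().product()),
        _ => None,
    }
}

fn fold(expr: Expr) -> Expr {
    match expr {
        Expr::ScalarFunction { op, args } => {
            // Fold the arguments bottom-up first, as the plan rewriter's recursion does.
            let args: Vec<Expr> = args.into_iter().map(fold).collect();
            // Only when every argument is already a literal do we have constant inputs.
            let constants: Option<Vec<i64>> = args
                .iter()
                .map(|e| match e {
                    Expr::Literal(v) => Some(*v),
                    _ => None,
                })
                .collect();
            match constants {
                Some(values) if is_deterministic(&op) => eval_constant_call(&op, &values)
                    .map(Expr::Literal)
                    .unwrap_or(Expr::ScalarFunction { op, args }),
                _ => Expr::ScalarFunction { op, args },
            }
        }
        other => other,
    }
}

fn main() {
    // (1 + 2) * 3 folds to a single literal 9.
    let expr = Expr::ScalarFunction {
        op: "multiply".to_string(),
        args: vec![
            Expr::ScalarFunction {
                op: "plus".to_string(),
                args: vec![Expr::Literal(1), Expr::Literal(2)],
            },
            Expr::Literal(3),
        ],
    };
    assert_eq!(fold(expr), Expr::Literal(9));

    // rand(1) keeps its function form because it is not deterministic.
    let rand = Expr::ScalarFunction {
        op: "rand".to_string(),
        args: vec![Expr::Literal(1)],
    };
    assert!(matches!(fold(rand), Expr::ScalarFunction { .. }));
}

As the updated test expectations show, the real optimizer additionally threads the original column name through the folded literal (convert_to_expression keeps origin_name), so a projection header still reads ((1 + 2) + 3):UInt32 while the expression beneath it is reduced to 6:UInt32; the sketch above omits that detail.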