Skip to content

Commit

Permalink
chore(query): rewrite binder lambda and udf (#13818)
Browse files Browse the repository at this point in the history
* chore(query): rewrite binder lambda and udf

* fix

* fix

* fix

* fix

* fix

* fix
  • Loading branch information
b41sh authored Nov 28, 2023
1 parent d7cc57c commit 282b69d
Show file tree
Hide file tree
Showing 10 changed files with 274 additions and 194 deletions.
25 changes: 16 additions & 9 deletions src/query/sql/src/executor/physical_plans/physical_lambda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,26 @@ impl PhysicalPlanBuilder {
.map(|arg| {
match arg {
ScalarExpr::BoundColumnRef(col) => {
let index = input_schema
.index_of(&col.column.index.to_string())
.unwrap();
let index = match input_schema.index_of(&col.column.index.to_string()) {
Ok(index) => index,
Err(_) => {
// the argument of lambda function may be another lambda function
match lambda_index_map.get(&col.column.column_name) {
Some(index) => *index,
None => {
return Err(ErrorCode::Internal(format!(
"Unable to get lambda function's argument \"{}\".",
col.column.column_name
)))
}
}
}
};
Ok(index)
}
ScalarExpr::LambdaFunction(inner_func) => {
// nested lambda function as an argument of parent lambda function
let index = lambda_index_map.get(&inner_func.display_name).unwrap();
Ok(*index)
}
_ => {
Err(ErrorCode::Internal(
"lambda function's argument must be a BoundColumnRef or LambdaFunction"
"Lambda function's argument must be a BoundColumnRef"
.to_string(),
))
}
Expand Down
40 changes: 17 additions & 23 deletions src/query/sql/src/executor/physical_plans/physical_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,33 +99,27 @@ impl PhysicalPlanBuilder {
let arg_indices = func
.arguments
.iter()
.map(|arg| {
match arg {
ScalarExpr::BoundColumnRef(col) => {
let index = input_schema
.index_of(&col.column.index.to_string())
.unwrap();
Ok(index)
}
ScalarExpr::UDFServerCall(inner_udf) => {
// nested udf function as an argument of parent udf function
let index = udf_index_map.get(&inner_udf.display_name).unwrap();
Ok(*index)
}
_ => {
Err(ErrorCode::Internal(
"udf function's argument must be a BoundColumnRef or UDFServerCall"
.to_string(),
))
}
.map(|arg| match arg {
ScalarExpr::BoundColumnRef(col) => {
let index =
match input_schema.index_of(&col.column.index.to_string()) {
Ok(index) => index,
Err(_) => {
return Err(ErrorCode::Internal(format!(
"Unable to get udf function's argument \"{}\".",
col.column.column_name
)));
}
};
Ok(index)
}
_ => Err(ErrorCode::Internal(
"Udf function's argument must be a BoundColumnRef".to_string(),
)),
})
.collect::<Result<Vec<_>>>()?;

udf_index_map.insert(
func.display_name.clone(),
index,
);
udf_index_map.insert(func.display_name.clone(), index);
index += 1;

let arg_exprs = func
Expand Down
5 changes: 0 additions & 5 deletions src/query/sql/src/planner/binder/bind_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ use itertools::Itertools;
use super::AggregateInfo;
use super::INTERNAL_COLUMN_FACTORY;
use crate::binder::column_binding::ColumnBinding;
use crate::binder::lambda::LambdaInfo;
use crate::binder::window::WindowInfo;
use crate::binder::ColumnBindingBuilder;
use crate::normalize_identifier;
Expand Down Expand Up @@ -117,8 +116,6 @@ pub struct BindContext {

pub windows: WindowInfo,

pub lambda_info: LambdaInfo,

/// If the `BindContext` is created from a CTE, record the cte name
pub cte_name: Option<String>,

Expand Down Expand Up @@ -170,7 +167,6 @@ impl BindContext {
bound_internal_columns: BTreeMap::new(),
aggregate_info: AggregateInfo::default(),
windows: WindowInfo::default(),
lambda_info: LambdaInfo::default(),
cte_name: None,
cte_map_ref: Box::default(),
allow_internal_columns: true,
Expand All @@ -190,7 +186,6 @@ impl BindContext {
bound_internal_columns: BTreeMap::new(),
aggregate_info: Default::default(),
windows: Default::default(),
lambda_info: Default::default(),
cte_name: parent.cte_name,
cte_map_ref: parent.cte_map_ref.clone(),
allow_internal_columns: parent.allow_internal_columns,
Expand Down
202 changes: 123 additions & 79 deletions src/query/sql/src/planner/binder/lambda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,24 @@
// limitations under the License.

use std::collections::HashMap;
use std::mem;
use std::sync::Arc;

use common_exception::ErrorCode;
use common_exception::Result;

use super::select::SelectList;
use crate::binder::ColumnBindingBuilder;
use crate::optimizer::SExpr;
use crate::plans::walk_expr_mut;
use crate::plans::BoundColumnRef;
use crate::plans::EvalScalar;
use crate::plans::Lambda;
use crate::plans::LambdaFunc;
use crate::plans::RelOperator;
use crate::plans::ScalarExpr;
use crate::plans::ScalarItem;
use crate::plans::VisitorMut;
use crate::BindContext;
use crate::Binder;
use crate::ColumnBindingBuilder;
use crate::IndexType;
use crate::MetadataRef;
use crate::Visibility;

Expand All @@ -39,31 +41,117 @@ pub struct LambdaInfo {
/// Lambda functions
pub lambda_functions: Vec<ScalarItem>,
/// Mapping: (lambda function display name) -> (derived column ref)
/// This is used to generate column in projection.
/// This is used to replace lambda with a derived column.
pub lambda_functions_map: HashMap<String, BoundColumnRef>,
/// Mapping: (lambda function display name) -> (derived index)
/// This is used to reuse already generated derived columns
pub lambda_functions_index_map: HashMap<String, IndexType>,
}

pub(super) struct LambdaRewriter<'a> {
pub bind_context: &'a mut BindContext,
pub metadata: MetadataRef,
pub(crate) struct LambdaRewriter {
lambda_info: LambdaInfo,
metadata: MetadataRef,
}

impl<'a> LambdaRewriter<'a> {
pub fn new(bind_context: &'a mut BindContext, metadata: MetadataRef) -> Self {
impl LambdaRewriter {
pub(crate) fn new(metadata: MetadataRef) -> Self {
Self {
bind_context,
lambda_info: Default::default(),
metadata,
}
}

pub(crate) fn rewrite(&mut self, s_expr: &SExpr) -> Result<SExpr> {
let mut s_expr = s_expr.clone();
if !s_expr.children.is_empty() {
let mut children = Vec::with_capacity(s_expr.children.len());
for child in s_expr.children.iter() {
children.push(Arc::new(self.rewrite(child)?));
}
s_expr.children = children;
}

// Rewrite Lambda and its arguments as derived column.
match (*s_expr.plan).clone() {
RelOperator::EvalScalar(mut plan) => {
for item in &plan.items {
// The index of Lambda item can be reused.
if let ScalarExpr::LambdaFunction(lambda) = &item.scalar {
self.lambda_info
.lambda_functions_index_map
.insert(lambda.display_name.clone(), item.index);
}
}
for item in &mut plan.items {
self.visit(&mut item.scalar)?;
}
let child_expr = self.create_lambda_expr(s_expr.children[0].clone());
let new_expr = SExpr::create_unary(Arc::new(plan.into()), child_expr);
Ok(new_expr)
}
RelOperator::Filter(mut plan) => {
for scalar in &mut plan.predicates {
self.visit(scalar)?;
}
let child_expr = self.create_lambda_expr(s_expr.children[0].clone());
let new_expr = SExpr::create_unary(Arc::new(plan.into()), child_expr);
Ok(new_expr)
}
_ => Ok(s_expr),
}
}

fn create_lambda_expr(&mut self, mut child_expr: Arc<SExpr>) -> Arc<SExpr> {
let lambda_info = &mut self.lambda_info;
if !lambda_info.lambda_functions.is_empty() {
if !lambda_info.lambda_arguments.is_empty() {
// Add an EvalScalar for the arguments of Lambda.
let mut scalar_items = mem::take(&mut lambda_info.lambda_arguments);
scalar_items.sort_by_key(|item| item.index);
let eval_scalar = EvalScalar {
items: scalar_items,
};
child_expr = Arc::new(SExpr::create_unary(
Arc::new(eval_scalar.into()),
child_expr,
));
}

let lambda_plan = Lambda {
items: mem::take(&mut lambda_info.lambda_functions),
};
Arc::new(SExpr::create_unary(
Arc::new(lambda_plan.into()),
child_expr,
))
} else {
child_expr
}
}
}

impl<'a, 'b> VisitorMut<'a> for LambdaRewriter<'b> {
impl<'a> VisitorMut<'a> for LambdaRewriter {
fn visit(&mut self, expr: &'a mut ScalarExpr) -> Result<()> {
walk_expr_mut(self, expr)?;
// replace lambda with derived column
if let ScalarExpr::LambdaFunction(lambda) = expr {
if let Some(column_ref) = self
.lambda_info
.lambda_functions_map
.get(&lambda.display_name)
{
*expr = ScalarExpr::BoundColumnRef(column_ref.clone());
} else {
return Err(ErrorCode::Internal("Rewrite lambda function failed"));
}
}
Ok(())
}

fn visit_lambda_function(&mut self, lambda_func: &'a mut LambdaFunc) -> Result<()> {
for (i, arg) in lambda_func.args.iter_mut().enumerate() {
let arg_is_lambda = matches!(arg, ScalarExpr::LambdaFunction(_));
self.visit(arg)?;
if let ScalarExpr::LambdaFunction(_) = arg {
continue;
}

let new_column_ref = if let ScalarExpr::BoundColumnRef(ref column_ref) = &arg {
column_ref.clone()
Expand All @@ -89,22 +177,28 @@ impl<'a, 'b> VisitorMut<'a> for LambdaRewriter<'b> {
}
};

self.bind_context
.lambda_info
.lambda_arguments
.push(ScalarItem {
if !arg_is_lambda {
self.lambda_info.lambda_arguments.push(ScalarItem {
index: new_column_ref.column.index,
scalar: arg.clone(),
});

}
*arg = new_column_ref.into();
}

let index = self.metadata.write().add_derived_column(
lambda_func.display_name.clone(),
(*lambda_func.return_type).clone(),
);
let index = match self
.lambda_info
.lambda_functions_index_map
.get(&lambda_func.display_name)
{
Some(index) => *index,
None => self.metadata.write().add_derived_column(
lambda_func.display_name.clone(),
(*lambda_func.return_type).clone(),
),
};

// Generate a ColumnBinding for the lambda function
let column = ColumnBindingBuilder::new(
lambda_func.display_name.clone(),
index,
Expand All @@ -118,64 +212,14 @@ impl<'a, 'b> VisitorMut<'a> for LambdaRewriter<'b> {
column,
};

self.bind_context
.lambda_info
self.lambda_info
.lambda_functions_map
.insert(lambda_func.display_name.clone(), replaced_column);
self.bind_context
.lambda_info
.lambda_functions
.push(ScalarItem {
index,
scalar: lambda_func.clone().into(),
});

Ok(())
}
}

impl Binder {
/// Analyze lambda functions in select clause, this will rewrite lambda functions.
/// See [`LambdaRewriter`] for more details.
pub(crate) fn analyze_lambda(
&mut self,
bind_context: &mut BindContext,
select_list: &mut SelectList,
) -> Result<()> {
for item in select_list.items.iter_mut() {
let mut rewriter = LambdaRewriter::new(bind_context, self.metadata.clone());
rewriter.visit(&mut item.scalar)?;
}
self.lambda_info.lambda_functions.push(ScalarItem {
index,
scalar: lambda_func.clone().into(),
});

Ok(())
}

#[async_backtrace::framed]
pub async fn bind_lambda(
&mut self,
bind_context: &mut BindContext,
child: SExpr,
) -> Result<SExpr> {
let lambda_info = &bind_context.lambda_info;
if lambda_info.lambda_functions.is_empty() {
return Ok(child);
}

let mut new_expr = child;
if !lambda_info.lambda_arguments.is_empty() {
let mut scalar_items = lambda_info.lambda_arguments.clone();
scalar_items.sort_by_key(|item| item.index);
let eval_scalar = EvalScalar {
items: scalar_items,
};
new_expr = SExpr::create_unary(Arc::new(eval_scalar.into()), Arc::new(new_expr));
}

let lambda_plan = Lambda {
items: lambda_info.lambda_functions.clone(),
};
new_expr = SExpr::create_unary(Arc::new(lambda_plan.into()), Arc::new(new_expr));

Ok(new_expr)
}
}
Loading

0 comments on commit 282b69d

Please sign in to comment.