fix(query): fix lambda function bind column failed #17402

Merged · 2 commits · Feb 4, 2025

29 changes: 19 additions & 10 deletions src/query/sql/src/planner/expression_parser.rs
@@ -356,20 +356,29 @@ pub fn parse_computed_expr_to_string(
 
 pub fn parse_lambda_expr(
     ctx: Arc<dyn TableContext>,
-    mut bind_context: BindContext,
-    columns: &[(String, DataType)],
+    lambda_context: &mut BindContext,
+    lambda_columns: &[(String, DataType)],
     ast: &AExpr,
 ) -> Result<Box<(ScalarExpr, DataType)>> {
     let metadata = Metadata::default();
-    bind_context.set_expr_context(ExprContext::InLambdaFunction);
+    lambda_context.set_expr_context(ExprContext::InLambdaFunction);
 
-    let column_len = bind_context.all_column_bindings().len();
-    for (idx, column) in columns.iter().enumerate() {
-        bind_context.add_column_binding(
+    // The column index may not be consecutive, and the length of columns
+    // cannot be used to calculate the column index of the lambda argument.
+    // We need to start from the current largest column index.
+    let mut column_index = lambda_context
+        .all_column_bindings()
+        .iter()
+        .map(|c| c.index)
+        .max()
+        .unwrap_or_default();
+    for (lambda_column, lambda_column_type) in lambda_columns.iter() {
+        column_index += 1;
+        lambda_context.add_column_binding(
             ColumnBindingBuilder::new(
-                column.0.clone(),
-                column_len + idx,
-                Box::new(column.1.clone()),
+                lambda_column.clone(),
+                column_index,
+                Box::new(lambda_column_type.clone()),
                 Visibility::Visible,
             )
             .build(),
@@ -379,7 +388,7 @@ pub fn parse_lambda_expr(
     let settings = ctx.get_settings();
     let name_resolution_ctx = NameResolutionContext::try_from(settings.as_ref())?;
     let mut type_checker = TypeChecker::try_create(
-        &mut bind_context,
+        lambda_context,
         ctx.clone(),
         &name_resolution_ctx,
         Arc::new(RwLock::new(metadata)),
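The comment added above states the crux of the fix: after operations such as joins or aggregations, the bind context's column indexes can be sparse, so `all_column_bindings().len()` is not a safe base for the lambda argument's column index, whereas the largest existing index is. The following standalone Rust sketch (a hypothetical `Binding` struct, not Databend's `ColumnBinding`) illustrates the collision that the old `column_len + idx` scheme allows and why `max + 1` avoids it.

// Minimal sketch of the index-allocation change, under simplified assumptions.

#[derive(Debug)]
struct Binding {
    name: &'static str,
    index: usize,
}

// Old scheme: next index = number of visible bindings.
// Collides as soon as indexes are sparse (len() can equal an existing index).
fn next_index_by_len(bindings: &[Binding]) -> usize {
    bindings.len()
}

// New scheme: next index = current largest index + 1, mirroring the
// `.map(|c| c.index).max().unwrap_or_default()` plus increment in the PR.
fn next_index_by_max(bindings: &[Binding]) -> usize {
    bindings.iter().map(|b| b.index).max().unwrap_or_default() + 1
}

fn main() {
    // Two visible bindings whose indexes are 0 and 2; index 1 was assigned to
    // a column that is no longer visible in this scope (e.g. pruned by a join).
    let bindings = vec![
        Binding { name: "id", index: 0 },
        Binding { name: "payload", index: 2 },
    ];

    // len() == 2 clashes with the binding that already owns index 2,
    // so the lambda argument would shadow a real column.
    assert_eq!(next_index_by_len(&bindings), 2);

    // max + 1 == 3 is guaranteed not to clash with any visible binding.
    assert_eq!(next_index_by_max(&bindings), 3);

    println!(
        "len-based: {}, max-based: {}",
        next_index_by_len(&bindings),
        next_index_by_max(&bindings),
    );
}
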
48 changes: 31 additions & 17 deletions src/query/sql/src/planner/semantic/type_check.rs
@@ -14,7 +14,9 @@
 
 use std::collections::BTreeMap;
 use std::collections::HashMap;
+use std::collections::HashSet;
 use std::collections::VecDeque;
+use std::mem;
 use std::str::FromStr;
 use std::sync::Arc;
 use std::vec;
@@ -166,6 +168,7 @@ use crate::BindContext;
 use crate::ColumnBinding;
 use crate::ColumnBindingBuilder;
 use crate::ColumnEntry;
+use crate::IndexType;
 use crate::MetadataRef;
 use crate::Visibility;
 
@@ -1932,16 +1935,17 @@ impl<'a> TypeChecker<'a> {
             vec![inner_ty.clone()]
         };
 
-        let columns = params
+        let lambda_columns = params
             .iter()
             .zip(inner_tys.iter())
             .map(|(col, ty)| (col.clone(), ty.clone()))
             .collect::<Vec<_>>();
 
+        let mut lambda_context = self.bind_context.clone();
         let box (lambda_expr, lambda_type) = parse_lambda_expr(
             self.ctx.clone(),
-            self.bind_context.clone(),
-            &columns,
+            &mut lambda_context,
+            &lambda_columns,
             &lambda.expr,
         )?;
 
@@ -2035,20 +2039,24 @@ impl<'a> TypeChecker<'a> {
         _ => {
             struct LambdaVisitor<'a> {
                 bind_context: &'a BindContext,
+                arg_index: HashSet<IndexType>,
                 args: Vec<ScalarExpr>,
                 fields: Vec<DataField>,
             }
 
             impl<'a> ScalarVisitor<'a> for LambdaVisitor<'a> {
                 fn visit_bound_column_ref(&mut self, col: &'a BoundColumnRef) -> Result<()> {
-                    let contains = self
+                    if self.arg_index.contains(&col.column.index) {
+                        return Ok(());
+                    }
+                    self.arg_index.insert(col.column.index);
+                    let is_outer_column = self
                         .bind_context
                         .all_column_bindings()
                         .iter()
                         .map(|c| c.index)
                         .contains(&col.column.index);
-                    // add outer scope columns first
-                    if contains {
+                    if is_outer_column {
                         let arg = ScalarExpr::BoundColumnRef(col.clone());
                         self.args.push(arg);
                         let field = DataField::new(
@@ -2061,24 +2069,30 @@ impl<'a> TypeChecker<'a> {
                 }
             }
 
+            // Collect outer scope columns as arguments first.
             let mut lambda_visitor = LambdaVisitor {
                 bind_context: self.bind_context,
+                arg_index: HashSet::new(),
                 args: Vec::new(),
                 fields: Vec::new(),
             };
             lambda_visitor.visit(&lambda_expr)?;
 
-            // add lambda columns at end
-            let mut fields = lambda_visitor.fields.clone();
-            let column_len = self.bind_context.all_column_bindings().len();
-            for (i, inner_ty) in inner_tys.into_iter().enumerate() {
-                let lambda_field = DataField::new(&format!("{}", column_len + i), inner_ty);
-                fields.push(lambda_field);
+            let mut lambda_args = mem::take(&mut lambda_visitor.args);
+            lambda_args.push(arg);
+            let mut lambda_fields = mem::take(&mut lambda_visitor.fields);
+            // Add lambda columns as arguments at end.
+            for (lambda_column_name, lambda_column_type) in lambda_columns.into_iter() {
+                for column in lambda_context.all_column_bindings().iter().rev() {
+                    if column.column_name == lambda_column_name {
+                        let lambda_field =
+                            DataField::new(&format!("{}", column.index), lambda_column_type);
+                        lambda_fields.push(lambda_field);
+                        break;
+                    }
+                }
             }
-            let lambda_schema = DataSchema::new(fields);
-            let mut args = lambda_visitor.args.clone();
-            args.push(arg);
-
+            let lambda_schema = DataSchema::new(lambda_fields);
             let expr = lambda_expr
                 .type_check(&lambda_schema)?
                 .project_column_ref(|index| {
@@ -2092,7 +2106,7 @@ impl<'a> TypeChecker<'a> {
             LambdaFunc {
                 span,
                 func_name: func_name.to_string(),
-                args,
+                args: lambda_args,
                 lambda_expr: Box::new(remote_lambda_expr),
                 lambda_display,
                 return_type: Box::new(return_type.clone()),
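On the type-checker side, the new `arg_index` set guarantees each referenced column is visited only once, and only columns that belong to the enclosing scope become extra arguments of the lambda function; the lambda's own parameters are then appended to the lambda schema under the index each one was given in `lambda_context`, instead of the old `column_len + i`. Below is a hedged sketch of that visiting logic, with simplified stand-in types rather than Databend's `ScalarVisitor`/`BoundColumnRef`.

use std::collections::HashSet;

// Simplified stand-in for a resolved column reference inside the lambda body.
#[derive(Debug, Clone)]
struct ColumnRef {
    name: &'static str,
    index: usize,
}

// Walk the lambda body's column references once, skipping duplicates, and keep
// only outer-scope columns: they are passed to the lambda function as arguments.
fn collect_outer_args(body_refs: &[ColumnRef], outer_indexes: &HashSet<usize>) -> Vec<ColumnRef> {
    let mut seen = HashSet::new();
    let mut args = Vec::new();
    for col in body_refs {
        // Same role as the PR's `arg_index`: each column index is handled once.
        if !seen.insert(col.index) {
            continue;
        }
        if outer_indexes.contains(&col.index) {
            args.push(col.clone());
        }
    }
    args
}

fn main() {
    // Outer scope owns indexes {0, 2}; the lambda parameter `x` was bound to index 3.
    let outer_indexes: HashSet<usize> = [0, 2].into_iter().collect();
    let body_refs = vec![
        ColumnRef { name: "payload", index: 2 },
        ColumnRef { name: "x", index: 3 },
        ColumnRef { name: "payload", index: 2 }, // duplicate reference, skipped
    ];

    let args = collect_outer_args(&body_refs, &outer_indexes);
    assert_eq!(args.len(), 1);
    assert_eq!(args[0].index, 2);
    println!("outer-scope lambda arguments: {:?}", args);
}

In the actual change, the outer-scope check still goes through `bind_context.all_column_bindings()`, while the lambda parameters are looked up by name in `lambda_context` to recover the non-consecutive indexes assigned by `parse_lambda_expr`.
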
@@ -419,6 +419,32 @@ SELECT arrays_zip(col1, col2) FROM t3;
 [(NULL,4)]
 [(7,5),(8,5)]
 
+#issue 16794
+
+statement ok
+CREATE OR REPLACE TABLE u (id VARCHAR NULL);
+
+statement ok
+INSERT INTO u VALUES(1),(2);
+
+statement ok
+CREATE OR REPLACE TABLE c (
+    id VARCHAR NULL,
+    what_fuck BOOLEAN NOT NULL,
+    payload VARIANT NULL
+);
+
+statement ok
+INSERT INTO c VALUES(1, true, '[1,2]'),(1, false, '[3,4]'),(2, true, '123');
+
+query IT
+SELECT ids.id, array_filter(array_agg(px.payload), x -> x is not null) AS px_payload
+FROM u ids LEFT JOIN c px ON px.id = ids.id
+GROUP BY ids.id ORDER BY ids.id;
+----
+1 ['[1,2]','[3,4]']
+2 ['123']
+
 statement ok
 USE default
 