diff --git a/src/query/sql/src/planner/expression_parser.rs b/src/query/sql/src/planner/expression_parser.rs
index e5ba4a9ef159..f4552f6c5417 100644
--- a/src/query/sql/src/planner/expression_parser.rs
+++ b/src/query/sql/src/planner/expression_parser.rs
@@ -356,20 +356,29 @@ pub fn parse_computed_expr_to_string(
 
 pub fn parse_lambda_expr(
     ctx: Arc<dyn TableContext>,
-    mut bind_context: BindContext,
-    columns: &[(String, DataType)],
+    lambda_context: &mut BindContext,
+    lambda_columns: &[(String, DataType)],
     ast: &AExpr,
 ) -> Result<Box<(ScalarExpr, DataType)>> {
     let metadata = Metadata::default();
-    bind_context.set_expr_context(ExprContext::InLambdaFunction);
+    lambda_context.set_expr_context(ExprContext::InLambdaFunction);
 
-    let column_len = bind_context.all_column_bindings().len();
-    for (idx, column) in columns.iter().enumerate() {
-        bind_context.add_column_binding(
+    // Column indexes may not be consecutive, so the number of columns
+    // cannot be used to calculate the column index of a lambda argument.
+    // Start from the current largest column index instead.
+    let mut column_index = lambda_context
+        .all_column_bindings()
+        .iter()
+        .map(|c| c.index)
+        .max()
+        .unwrap_or_default();
+    for (lambda_column, lambda_column_type) in lambda_columns.iter() {
+        column_index += 1;
+        lambda_context.add_column_binding(
             ColumnBindingBuilder::new(
-                column.0.clone(),
-                column_len + idx,
-                Box::new(column.1.clone()),
+                lambda_column.clone(),
+                column_index,
+                Box::new(lambda_column_type.clone()),
                 Visibility::Visible,
             )
             .build(),
@@ -379,7 +388,7 @@ pub fn parse_lambda_expr(
     let settings = ctx.get_settings();
     let name_resolution_ctx = NameResolutionContext::try_from(settings.as_ref())?;
     let mut type_checker = TypeChecker::try_create(
-        &mut bind_context,
+        lambda_context,
         ctx.clone(),
         &name_resolution_ctx,
         Arc::new(RwLock::new(metadata)),
diff --git a/src/query/sql/src/planner/semantic/type_check.rs b/src/query/sql/src/planner/semantic/type_check.rs
index 94d8c5c1c089..8bc71e0ad94e 100644
--- a/src/query/sql/src/planner/semantic/type_check.rs
+++ b/src/query/sql/src/planner/semantic/type_check.rs
@@ -14,7 +14,9 @@
 
 use std::collections::BTreeMap;
 use std::collections::HashMap;
+use std::collections::HashSet;
 use std::collections::VecDeque;
+use std::mem;
 use std::str::FromStr;
 use std::sync::Arc;
 use std::vec;
@@ -166,6 +168,7 @@ use crate::BindContext;
 use crate::ColumnBinding;
 use crate::ColumnBindingBuilder;
 use crate::ColumnEntry;
+use crate::IndexType;
 use crate::MetadataRef;
 use crate::Visibility;
 
@@ -1932,16 +1935,17 @@ impl<'a> TypeChecker<'a> {
             vec![inner_ty.clone()]
         };
 
-        let columns = params
+        let lambda_columns = params
            .iter()
            .zip(inner_tys.iter())
            .map(|(col, ty)| (col.clone(), ty.clone()))
            .collect::<Vec<_>>();
 
+        let mut lambda_context = self.bind_context.clone();
         let box (lambda_expr, lambda_type) = parse_lambda_expr(
             self.ctx.clone(),
-            self.bind_context.clone(),
-            &columns,
+            &mut lambda_context,
+            &lambda_columns,
             &lambda.expr,
         )?;
 
@@ -2035,20 +2039,24 @@ impl<'a> TypeChecker<'a> {
             _ => {
                 struct LambdaVisitor<'a> {
                     bind_context: &'a BindContext,
+                    arg_index: HashSet<IndexType>,
                     args: Vec<ScalarExpr>,
                     fields: Vec<DataField>,
                 }
 
                 impl<'a> ScalarVisitor<'a> for LambdaVisitor<'a> {
                     fn visit_bound_column_ref(&mut self, col: &'a BoundColumnRef) -> Result<()> {
-                        let contains = self
+                        if self.arg_index.contains(&col.column.index) {
+                            return Ok(());
+                        }
+                        self.arg_index.insert(col.column.index);
+                        let is_outer_column = self
                             .bind_context
                             .all_column_bindings()
                             .iter()
                             .map(|c| c.index)
                             .contains(&col.column.index);
-                        // add outer scope columns first
-                        if contains {
+                        if is_outer_column {
                             let arg = ScalarExpr::BoundColumnRef(col.clone());
                             self.args.push(arg);
                             let field = DataField::new(
@@ -2061,24 +2069,30 @@ impl<'a> TypeChecker<'a> {
                     }
                 }
 
+                // Collect outer scope columns as arguments first.
                 let mut lambda_visitor = LambdaVisitor {
                     bind_context: self.bind_context,
+                    arg_index: HashSet::new(),
                     args: Vec::new(),
                     fields: Vec::new(),
                 };
                 lambda_visitor.visit(&lambda_expr)?;
 
-                // add lambda columns at end
-                let mut fields = lambda_visitor.fields.clone();
-                let column_len = self.bind_context.all_column_bindings().len();
-                for (i, inner_ty) in inner_tys.into_iter().enumerate() {
-                    let lambda_field = DataField::new(&format!("{}", column_len + i), inner_ty);
-                    fields.push(lambda_field);
+                let mut lambda_args = mem::take(&mut lambda_visitor.args);
+                lambda_args.push(arg);
+                let mut lambda_fields = mem::take(&mut lambda_visitor.fields);
+                // Add lambda columns as arguments at end.
+                for (lambda_column_name, lambda_column_type) in lambda_columns.into_iter() {
+                    for column in lambda_context.all_column_bindings().iter().rev() {
+                        if column.column_name == lambda_column_name {
+                            let lambda_field =
+                                DataField::new(&format!("{}", column.index), lambda_column_type);
+                            lambda_fields.push(lambda_field);
+                            break;
+                        }
+                    }
                 }
-                let lambda_schema = DataSchema::new(fields);
-                let mut args = lambda_visitor.args.clone();
-                args.push(arg);
-
+                let lambda_schema = DataSchema::new(lambda_fields);
                 let expr = lambda_expr
                     .type_check(&lambda_schema)?
                     .project_column_ref(|index| {
@@ -2092,7 +2106,7 @@ impl<'a> TypeChecker<'a> {
                     LambdaFunc {
                         span,
                         func_name: func_name.to_string(),
-                        args,
+                        args: lambda_args,
                         lambda_expr: Box::new(remote_lambda_expr),
                         lambda_display,
                         return_type: Box::new(return_type.clone()),
diff --git a/tests/sqllogictests/suites/query/functions/02_0061_function_array.test b/tests/sqllogictests/suites/query/functions/02_0061_function_array.test
index 40cb07f5a50a..4ffeef10a854 100644
--- a/tests/sqllogictests/suites/query/functions/02_0061_function_array.test
+++ b/tests/sqllogictests/suites/query/functions/02_0061_function_array.test
@@ -419,6 +419,32 @@ SELECT arrays_zip(col1, col2) FROM t3;
 [(NULL,4)]
 [(7,5),(8,5)]
 
+#issue 16794
+
+statement ok
+CREATE OR REPLACE TABLE u (id VARCHAR NULL);
+
+statement ok
+INSERT INTO u VALUES(1),(2);
+
+statement ok
+CREATE OR REPLACE TABLE c (
+    id VARCHAR NULL,
+    what_fuck BOOLEAN NOT NULL,
+    payload VARIANT NULL
+);
+
+statement ok
+INSERT INTO c VALUES(1, true, '[1,2]'),(1, false, '[3,4]'),(2, true, '123');
+
+query IT
+SELECT ids.id, array_filter(array_agg(px.payload), x -> x is not null) AS px_payload
+  FROM u ids LEFT JOIN c px ON px.id = ids.id
+  GROUP BY ids.id ORDER BY ids.id;
+----
+1 ['[1,2]','[3,4]']
+2 ['123']
+
 statement ok
 USE default
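For illustration, here is a minimal standalone sketch of the idea behind the fix. It uses simplified stand-in types (the names below are hypothetical, not Databend's actual BindContext/ColumnBinding API): when the outer context's column indexes are sparse, the next lambda-argument index has to come from the current maximum index, not from the number of bindings.

// Minimal sketch, assuming simplified stand-in types rather than Databend's
// real BindContext / ColumnBinding API: allocate the lambda argument's column
// index from the current maximum index instead of from the binding count.
#[derive(Debug)]
struct ColumnBinding {
    name: String,
    index: usize,
}

fn add_lambda_columns(bindings: &mut Vec<ColumnBinding>, lambda_columns: &[&str]) {
    // Outer column indexes may be sparse (e.g. 0 and 2 after joins or
    // aggregations), so `bindings.len()` can collide with an existing index.
    // Start from the current maximum instead.
    let mut column_index = bindings.iter().map(|c| c.index).max().unwrap_or_default();
    for name in lambda_columns {
        column_index += 1;
        bindings.push(ColumnBinding {
            name: (*name).to_string(),
            index: column_index,
        });
    }
}

fn main() {
    // len() == 2, but index 2 is already taken by `payload`.
    let mut bindings = vec![
        ColumnBinding { name: "id".to_string(), index: 0 },
        ColumnBinding { name: "payload".to_string(), index: 2 },
    ];
    add_lambda_columns(&mut bindings, &["x"]);
    // The lambda argument `x` gets the fresh index 3, not the colliding index 2.
    assert_eq!(bindings.last().unwrap().index, 3);
    println!("{bindings:?}");
}

With length-based allocation the lambda argument in this example would have received index 2 and collided with an existing binding; this appears to be the kind of mis-addressing the regression test above guards against.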