From 03a24a912b2e5c90ffa03c006ea67189861085a5 Mon Sep 17 00:00:00 2001 From: sundyli <543950155@qq.com> Date: Thu, 9 Jan 2025 09:26:13 +0800 Subject: [PATCH] fix(query): fix register function working with nullable scalar (#17217) * fix(query): fix register function working with nullable scalar * fix(query): fix register function working with nullable scalar * fix(query): increase pool * Update 19_0005_fuzz_cte.sh * Update mysql_source.rs * fix(query): fix register function working with nullable scalar --- .../expression/src/register_vectorize.rs | 25 +++----- src/query/functions/tests/it/scalars/mod.rs | 40 ++++++++++++ .../suites/query/cte/basic_r_cte.test | 61 ++++++++++++++++++- .../19_fuzz/19_0005_fuzz_cte.result | 1 + .../0_stateless/19_fuzz/19_0005_fuzz_cte.sh | 25 ++++++++ 5 files changed, 135 insertions(+), 17 deletions(-) create mode 100755 tests/suites/0_stateless/19_fuzz/19_0005_fuzz_cte.result create mode 100755 tests/suites/0_stateless/19_fuzz/19_0005_fuzz_cte.sh diff --git a/src/query/expression/src/register_vectorize.rs b/src/query/expression/src/register_vectorize.rs index 794e2c1dc1ee4..a5332be3a084d 100755 --- a/src/query/expression/src/register_vectorize.rs +++ b/src/query/expression/src/register_vectorize.rs @@ -283,8 +283,7 @@ pub fn passthrough_nullable_1_arg( match out { Value::Column(out) => Value::Column(NullableColumn::new(out, args_validity)), - Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(Some(out)), - _ => Value::Scalar(None), + Value::Scalar(out) => Value::Scalar(Some(out)), } } _ => Value::Scalar(None), @@ -308,6 +307,7 @@ pub fn passthrough_nullable_2_arg( if let Some(validity) = ctx.validity.as_ref() { args_validity = &args_validity & validity; } + ctx.validity = Some(args_validity.clone()); match (arg1.value(), arg2.value()) { (Some(arg1), Some(arg2)) => { @@ -315,8 +315,7 @@ pub fn passthrough_nullable_2_arg( match out { Value::Column(out) => Value::Column(NullableColumn::new(out, args_validity)), - Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(Some(out)), - _ => Value::Scalar(None), + Value::Scalar(out) => Value::Scalar(Some(out)), } } _ => Value::Scalar(None), @@ -352,8 +351,7 @@ pub fn passthrough_nullable_3_arg Value::Column(NullableColumn::new(out, args_validity)), - Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(Some(out)), - _ => Value::Scalar(None), + Value::Scalar(out) => Value::Scalar(Some(out)), } } _ => Value::Scalar(None), @@ -397,8 +395,7 @@ pub fn passthrough_nullable_4_arg< match out { Value::Column(out) => Value::Column(NullableColumn::new(out, args_validity)), - Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(Some(out)), - _ => Value::Scalar(None), + Value::Scalar(out) => Value::Scalar(Some(out)), } } _ => Value::Scalar(None), @@ -427,8 +424,7 @@ pub fn combine_nullable_1_arg( out.column, &args_validity & &out.validity, )), - Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(out), - _ => Value::Scalar(None), + Value::Scalar(out) => Value::Scalar(out), } } _ => Value::Scalar(None), @@ -465,8 +461,7 @@ pub fn combine_nullable_2_arg( out.column, &args_validity & &out.validity, )), - Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(out), - _ => Value::Scalar(None), + Value::Scalar(out) => Value::Scalar(out), } } _ => Value::Scalar(None), @@ -505,8 +500,7 @@ pub fn combine_nullable_3_arg out.column, &args_validity & &out.validity, )), - Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(out), - _ => Value::Scalar(None), + Value::Scalar(out) => Value::Scalar(out), } } _ => Value::Scalar(None), @@ -552,8 +546,7 @@ pub fn combine_nullable_4_arg Value::Scalar(out), - _ => Value::Scalar(None), + Value::Scalar(out) => Value::Scalar(out), } } _ => Value::Scalar(None), diff --git a/src/query/functions/tests/it/scalars/mod.rs b/src/query/functions/tests/it/scalars/mod.rs index 7f8c5344df4c0..770961db74f04 100644 --- a/src/query/functions/tests/it/scalars/mod.rs +++ b/src/query/functions/tests/it/scalars/mod.rs @@ -271,3 +271,43 @@ fn list_all_builtin_functions() { fn check_ambiguity() { BUILTIN_FUNCTIONS.check_ambiguity() } + +#[test] +fn test_if_function() -> Result<()> { + use databend_common_expression::types::*; + use databend_common_expression::FromData; + use databend_common_expression::Scalar; + let raw_expr = parser::parse_raw_expr("if(eq(n,1), sum_sid + 1,100)", &[ + ("n", UInt8Type::data_type()), + ("sum_sid", Int32Type::data_type().wrap_nullable()), + ]); + let expr = type_check::check(&raw_expr, &BUILTIN_FUNCTIONS)?; + let block = DataBlock::new( + vec![ + BlockEntry { + data_type: UInt8Type::data_type(), + value: Value::Column(UInt8Type::from_data(vec![2_u8, 1])), + }, + BlockEntry { + data_type: Int32Type::data_type().wrap_nullable(), + value: Value::Scalar(Scalar::Number(NumberScalar::Int32(2400_i32))), + }, + ], + 2, + ); + let func_ctx = FunctionContext::default(); + let evaluator = Evaluator::new(&block, &func_ctx, &BUILTIN_FUNCTIONS); + let result = evaluator.run(&expr).unwrap(); + let result = result + .as_column() + .unwrap() + .clone() + .as_nullable() + .unwrap() + .clone(); + + let bm = Bitmap::from_iter([true, true]); + assert_eq!(result.validity, bm); + assert_eq!(result.column, Int64Type::from_data(vec![100, 2401])); + Ok(()) +} diff --git a/tests/sqllogictests/suites/query/cte/basic_r_cte.test b/tests/sqllogictests/suites/query/cte/basic_r_cte.test index 12cd5a84b74b1..1d4ce93efcd9d 100644 --- a/tests/sqllogictests/suites/query/cte/basic_r_cte.test +++ b/tests/sqllogictests/suites/query/cte/basic_r_cte.test @@ -227,5 +227,64 @@ select cte1.a from cte1; 8 9 + +statement ok +create table train( +train_id varchar(8) not null , +departure_station varchar(32) not null, +arrival_station varchar(32) not null, +seat_count int not null +); + +statement ok +create table passenger( +passenger_id varchar(16) not null, +departure_station varchar(32) not null, +arrival_station varchar(32) not null +); + +statement ok +create table city(city varchar(32)); + +statement ok +insert into city +with t as (select 1 n union select 2 union select 3 union select 4 union select 5) +,t1 as(select row_number()over() rn from t ,t t2,t t3) +select concat('城市',rn::varchar) city from t1 where rn<=5; + +statement ok +insert into train +select concat('G',row_number()over()::varchar),c1.city,c2.city, n from city c1, city c2, (select 600 n union select 800 union select 1200 union select 1600) a ; + +statement ok +insert into passenger +select concat('P',substr((100000000+row_number()over())::varchar,2)),c1.city,c2.city from city c1, city c2 ,city c3, city c4, city c5, +city c6, (select 1 n union select 2 union select 3 union select 4) c7,(select 1 n union select 2) c8; + + +query III +with +t0 as ( +select + train_id, + seat_count, + sum(seat_count) over ( + partition by departure_station, arrival_station order by train_id + ) ::int sum_sid +from + train +) +select + sum(case when n=1 then sum_sid+1 else 0 end::int), + sum(sum_sid), + sum(seat_count) +from + t0,(select 1 n union all select 2); +---- +261700 523200 210000 + +statement ok +use default; + statement ok -drop table t1; +drop database db; diff --git a/tests/suites/0_stateless/19_fuzz/19_0005_fuzz_cte.result b/tests/suites/0_stateless/19_fuzz/19_0005_fuzz_cte.result new file mode 100755 index 0000000000000..d86bac9de59ab --- /dev/null +++ b/tests/suites/0_stateless/19_fuzz/19_0005_fuzz_cte.result @@ -0,0 +1 @@ +OK diff --git a/tests/suites/0_stateless/19_fuzz/19_0005_fuzz_cte.sh b/tests/suites/0_stateless/19_fuzz/19_0005_fuzz_cte.sh new file mode 100755 index 0000000000000..8b3d81efb4a20 --- /dev/null +++ b/tests/suites/0_stateless/19_fuzz/19_0005_fuzz_cte.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +. "$CURDIR"/../../../shell_env.sh + + +times=256 + +echo "" > /tmp/fuzz_a.txt +echo "" > /tmp/fuzz_b.txt + +for i in `seq 1 ${times}`;do + echo """with t0(sum_sid) as (select sum(number) over(partition by number order by number) + from numbers(3)) select n, if(n =1, sum_sid +1, 0) from t0, (select 1 n union all select 2) order by 1,2; + """ | $BENDSQL_CLIENT_CONNECT >> /tmp/fuzz_a.txt +done + + +for i in `seq 1 ${times}`;do + echo """with t0(sum_sid) as (select sum(number) over(partition by number order by number) + from numbers(3)) select n, if(n =1, sum_sid +1, 0) from t0, (select 1 n union all select 2) order by 1,2; + """ | $BENDSQL_CLIENT_CONNECT >> /tmp/fuzz_b.txt +done + +diff /tmp/fuzz_a.txt /tmp/fuzz_b.txt && echo "OK"