Skip to content

Commit

Permalink
fix(query): fix register function working with nullable scalar (#17217)
Browse files Browse the repository at this point in the history
* fix(query): fix register function working with nullable scalar

* fix(query): fix register function working with nullable scalar

* fix(query): increase pool

* Update 19_0005_fuzz_cte.sh

* Update mysql_source.rs

* fix(query): fix register function working with nullable scalar
  • Loading branch information
sundy-li committed Jan 9, 2025
1 parent ddb8d0b commit 03a24a9
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 17 deletions.
25 changes: 9 additions & 16 deletions src/query/expression/src/register_vectorize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,7 @@ pub fn passthrough_nullable_1_arg<I1: ArgType, O: ArgType>(

match out {
Value::Column(out) => Value::Column(NullableColumn::new(out, args_validity)),
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(Some(out)),
_ => Value::Scalar(None),
Value::Scalar(out) => Value::Scalar(Some(out)),
}
}
_ => Value::Scalar(None),
Expand All @@ -308,15 +307,15 @@ pub fn passthrough_nullable_2_arg<I1: ArgType, I2: ArgType, O: ArgType>(
if let Some(validity) = ctx.validity.as_ref() {
args_validity = &args_validity & validity;
}

ctx.validity = Some(args_validity.clone());
match (arg1.value(), arg2.value()) {
(Some(arg1), Some(arg2)) => {
let out = func(arg1, arg2, ctx);

match out {
Value::Column(out) => Value::Column(NullableColumn::new(out, args_validity)),
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(Some(out)),
_ => Value::Scalar(None),
Value::Scalar(out) => Value::Scalar(Some(out)),
}
}
_ => Value::Scalar(None),
Expand Down Expand Up @@ -352,8 +351,7 @@ pub fn passthrough_nullable_3_arg<I1: ArgType, I2: ArgType, I3: ArgType, O: ArgT

match out {
Value::Column(out) => Value::Column(NullableColumn::new(out, args_validity)),
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(Some(out)),
_ => Value::Scalar(None),
Value::Scalar(out) => Value::Scalar(Some(out)),
}
}
_ => Value::Scalar(None),
Expand Down Expand Up @@ -397,8 +395,7 @@ pub fn passthrough_nullable_4_arg<

match out {
Value::Column(out) => Value::Column(NullableColumn::new(out, args_validity)),
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(Some(out)),
_ => Value::Scalar(None),
Value::Scalar(out) => Value::Scalar(Some(out)),
}
}
_ => Value::Scalar(None),
Expand Down Expand Up @@ -427,8 +424,7 @@ pub fn combine_nullable_1_arg<I1: ArgType, O: ArgType>(
out.column,
&args_validity & &out.validity,
)),
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(out),
_ => Value::Scalar(None),
Value::Scalar(out) => Value::Scalar(out),
}
}
_ => Value::Scalar(None),
Expand Down Expand Up @@ -465,8 +461,7 @@ pub fn combine_nullable_2_arg<I1: ArgType, I2: ArgType, O: ArgType>(
out.column,
&args_validity & &out.validity,
)),
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(out),
_ => Value::Scalar(None),
Value::Scalar(out) => Value::Scalar(out),
}
}
_ => Value::Scalar(None),
Expand Down Expand Up @@ -505,8 +500,7 @@ pub fn combine_nullable_3_arg<I1: ArgType, I2: ArgType, I3: ArgType, O: ArgType>
out.column,
&args_validity & &out.validity,
)),
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(out),
_ => Value::Scalar(None),
Value::Scalar(out) => Value::Scalar(out),
}
}
_ => Value::Scalar(None),
Expand Down Expand Up @@ -552,8 +546,7 @@ pub fn combine_nullable_4_arg<I1: ArgType, I2: ArgType, I3: ArgType, I4: ArgType
out.column,
&args_validity & &out.validity,
)),
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(out),
_ => Value::Scalar(None),
Value::Scalar(out) => Value::Scalar(out),
}
}
_ => Value::Scalar(None),
Expand Down
40 changes: 40 additions & 0 deletions src/query/functions/tests/it/scalars/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,3 +271,43 @@ fn list_all_builtin_functions() {
fn check_ambiguity() {
BUILTIN_FUNCTIONS.check_ambiguity()
}

#[test]
fn test_if_function() -> Result<()> {
use databend_common_expression::types::*;
use databend_common_expression::FromData;
use databend_common_expression::Scalar;
let raw_expr = parser::parse_raw_expr("if(eq(n,1), sum_sid + 1,100)", &[
("n", UInt8Type::data_type()),
("sum_sid", Int32Type::data_type().wrap_nullable()),
]);
let expr = type_check::check(&raw_expr, &BUILTIN_FUNCTIONS)?;
let block = DataBlock::new(
vec![
BlockEntry {
data_type: UInt8Type::data_type(),
value: Value::Column(UInt8Type::from_data(vec![2_u8, 1])),
},
BlockEntry {
data_type: Int32Type::data_type().wrap_nullable(),
value: Value::Scalar(Scalar::Number(NumberScalar::Int32(2400_i32))),
},
],
2,
);
let func_ctx = FunctionContext::default();
let evaluator = Evaluator::new(&block, &func_ctx, &BUILTIN_FUNCTIONS);
let result = evaluator.run(&expr).unwrap();
let result = result
.as_column()
.unwrap()
.clone()
.as_nullable()
.unwrap()
.clone();

let bm = Bitmap::from_iter([true, true]);
assert_eq!(result.validity, bm);
assert_eq!(result.column, Int64Type::from_data(vec![100, 2401]));
Ok(())
}
61 changes: 60 additions & 1 deletion tests/sqllogictests/suites/query/cte/basic_r_cte.test
Original file line number Diff line number Diff line change
Expand Up @@ -227,5 +227,64 @@ select cte1.a from cte1;
8
9


statement ok
create table train(
train_id varchar(8) not null ,
departure_station varchar(32) not null,
arrival_station varchar(32) not null,
seat_count int not null
);

statement ok
create table passenger(
passenger_id varchar(16) not null,
departure_station varchar(32) not null,
arrival_station varchar(32) not null
);

statement ok
create table city(city varchar(32));

statement ok
insert into city
with t as (select 1 n union select 2 union select 3 union select 4 union select 5)
,t1 as(select row_number()over() rn from t ,t t2,t t3)
select concat('城市',rn::varchar) city from t1 where rn<=5;

statement ok
insert into train
select concat('G',row_number()over()::varchar),c1.city,c2.city, n from city c1, city c2, (select 600 n union select 800 union select 1200 union select 1600) a ;

statement ok
insert into passenger
select concat('P',substr((100000000+row_number()over())::varchar,2)),c1.city,c2.city from city c1, city c2 ,city c3, city c4, city c5,
city c6, (select 1 n union select 2 union select 3 union select 4) c7,(select 1 n union select 2) c8;


query III
with
t0 as (
select
train_id,
seat_count,
sum(seat_count) over (
partition by departure_station, arrival_station order by train_id
) ::int sum_sid
from
train
)
select
sum(case when n=1 then sum_sid+1 else 0 end::int),
sum(sum_sid),
sum(seat_count)
from
t0,(select 1 n union all select 2);
----
261700 523200 210000

statement ok
use default;

statement ok
drop table t1;
drop database db;
1 change: 1 addition & 0 deletions tests/suites/0_stateless/19_fuzz/19_0005_fuzz_cte.result
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
OK
25 changes: 25 additions & 0 deletions tests/suites/0_stateless/19_fuzz/19_0005_fuzz_cte.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/usr/bin/env bash

CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
. "$CURDIR"/../../../shell_env.sh


times=256

echo "" > /tmp/fuzz_a.txt
echo "" > /tmp/fuzz_b.txt

for i in `seq 1 ${times}`;do
echo """with t0(sum_sid) as (select sum(number) over(partition by number order by number)
from numbers(3)) select n, if(n =1, sum_sid +1, 0) from t0, (select 1 n union all select 2) order by 1,2;
""" | $BENDSQL_CLIENT_CONNECT >> /tmp/fuzz_a.txt
done


for i in `seq 1 ${times}`;do
echo """with t0(sum_sid) as (select sum(number) over(partition by number order by number)
from numbers(3)) select n, if(n =1, sum_sid +1, 0) from t0, (select 1 n union all select 2) order by 1,2;
""" | $BENDSQL_CLIENT_CONNECT >> /tmp/fuzz_b.txt
done

diff /tmp/fuzz_a.txt /tmp/fuzz_b.txt && echo "OK"

0 comments on commit 03a24a9

Please sign in to comment.