diff --git a/src/query/catalog/src/plan/partition.rs b/src/query/catalog/src/plan/partition.rs index a6944fb1235b4..ca791c8e579a5 100644 --- a/src/query/catalog/src/plan/partition.rs +++ b/src/query/catalog/src/plan/partition.rs @@ -185,6 +185,9 @@ impl Default for Partitions { pub struct StealablePartitions { pub partitions: Arc>>>, pub ctx: Arc, + // In some cases, we need to disable steal. + // Such as topk queries, this is suitable that topk will respect all the pagecache and reduce false sharing between threads. + pub disable_steal: bool, } impl StealablePartitions { @@ -192,9 +195,14 @@ impl StealablePartitions { StealablePartitions { partitions: Arc::new(RwLock::new(partitions)), ctx, + disable_steal: false, } } + pub fn disable_steal(&mut self) { + self.disable_steal = true; + } + pub fn steal_one(&self, idx: usize) -> Option { let mut partitions = self.partitions.write(); if partitions.is_empty() { @@ -212,6 +220,10 @@ impl StealablePartitions { if !partitions[index].is_empty() { return partitions[index].pop_front(); } + + if self.disable_steal { + break; + } } drop(partitions); @@ -238,6 +250,10 @@ impl StealablePartitions { let size = ps.len().min(max_size); return ps.drain(..size).collect(); } + + if self.disable_steal { + break; + } } drop(partitions); diff --git a/src/query/functions/src/scalars/decimal.rs b/src/query/functions/src/scalars/decimal.rs index 80414875bb489..4bbf516ee81f7 100644 --- a/src/query/functions/src/scalars/decimal.rs +++ b/src/query/functions/src/scalars/decimal.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::cmp::Ord; use std::ops::*; use std::sync::Arc; @@ -52,6 +53,63 @@ macro_rules! op_decimal { _ => unreachable!(), } }; + ($a: expr, $b: expr, $return_type: expr, $op: ident) => { + match $return_type { + DataType::Decimal(d) => match d { + DecimalDataType::Decimal128(_) => { + compare_decimal!($a, $b, $op, Decimal128) + } + DecimalDataType::Decimal256(_) => { + compare_decimal!($a, $b, $op, Decimal256) + } + }, + _ => unreachable!(), + } + }; +} + +macro_rules! compare_decimal { + ($a: expr, $b: expr, $op: ident, $decimal_type: tt) => {{ + match ($a, $b) { + ( + ValueRef::Column(Column::Decimal(DecimalColumn::$decimal_type(buffer_a, _))), + ValueRef::Column(Column::Decimal(DecimalColumn::$decimal_type(buffer_b, _))), + ) => { + let result = buffer_a + .iter() + .zip(buffer_b.iter()) + .map(|(a, b)| a.cmp(b).$op()) + .collect(); + + Value::Column(Column::Boolean(result)) + } + + ( + ValueRef::Column(Column::Decimal(DecimalColumn::$decimal_type(buffer, _))), + ValueRef::Scalar(ScalarRef::Decimal(DecimalScalar::$decimal_type(b, _))), + ) => { + let result = buffer.iter().map(|a| a.cmp(b).$op()).collect(); + + Value::Column(Column::Boolean(result)) + } + + ( + ValueRef::Scalar(ScalarRef::Decimal(DecimalScalar::$decimal_type(a, _))), + ValueRef::Column(Column::Decimal(DecimalColumn::$decimal_type(buffer, _))), + ) => { + let result = buffer.iter().map(|b| a.cmp(b).$op()).collect(); + + Value::Column(Column::Boolean(result)) + } + + ( + ValueRef::Scalar(ScalarRef::Decimal(DecimalScalar::$decimal_type(a, _))), + ValueRef::Scalar(ScalarRef::Decimal(DecimalScalar::$decimal_type(b, _))), + ) => Value::Scalar(Scalar::Boolean(a.cmp(b).$op())), + + _ => unreachable!(), + } + }}; } macro_rules! binary_decimal { @@ -145,6 +203,65 @@ macro_rules! binary_decimal { }}; } +macro_rules! register_decimal_compare_op { + ($registry: expr, $name: expr, $op: ident) => { + $registry.register_function_factory($name, |_, args_type| { + if args_type.len() != 2 { + return None; + } + + let has_nullable = args_type.iter().any(|x| x.is_nullable_or_null()); + let args_type: Vec = args_type.iter().map(|x| x.remove_nullable()).collect(); + + // Only works for one of is decimal types + if !args_type[0].is_decimal() && !args_type[1].is_decimal() { + return None; + } + + // we use the max precision and scale for the result + let return_type = if args_type[0].is_decimal() && args_type[1].is_decimal() { + let lhs_type = args_type[0].as_decimal().unwrap(); + let rhs_type = args_type[1].as_decimal().unwrap(); + + DecimalDataType::binary_result_type(&lhs_type, &rhs_type, false, false, true) + } else if args_type[0].is_decimal() { + let lhs_type = args_type[0].as_decimal().unwrap(); + lhs_type.binary_upgrade_to_max_precision() + } else { + let rhs_type = args_type[1].as_decimal().unwrap(); + rhs_type.binary_upgrade_to_max_precision() + } + .ok()?; + + let function = Function { + signature: FunctionSignature { + name: $name.to_string(), + args_type: vec![ + DataType::Decimal(return_type.clone()), + DataType::Decimal(return_type.clone()), + ], + return_type: DataType::Decimal(return_type.clone()), + property: FunctionProperty::default(), + }, + calc_domain: Box::new(|_args_domain| FunctionDomain::Full), + eval: Box::new(move |args, _ctx| { + op_decimal!( + &args[0], + &args[1], + &DataType::Decimal(return_type.clone()), + $op + ) + }), + }; + if has_nullable { + Some(Arc::new(function.wrap_nullable())) + } else { + Some(Arc::new(function)) + } + }); + }; +} + macro_rules! register_decimal_binary_op { ($registry: expr, $name: expr, $op: ident) => { $registry.register_function_factory($name, |_, args_type| { @@ -234,6 +351,16 @@ pub fn register(registry: &mut FunctionRegistry) { register_decimal_binary_op!(registry, "divide", div); register_decimal_binary_op!(registry, "multiply", mul); + register_decimal_compare_op!(registry, "lt", is_lt); + register_decimal_compare_op!(registry, "eq", is_eq); + register_decimal_compare_op!(registry, "gt", is_gt); + + register_decimal_compare_op!(registry, "lte", is_le); + + register_decimal_compare_op!(registry, "gte", is_ge); + + register_decimal_compare_op!(registry, "ne", is_ne); + // int float to decimal registry.register_function_factory("to_decimal", |params, args_type| { if args_type.len() != 1 { diff --git a/src/query/functions/tests/it/scalars/testdata/function_list.txt b/src/query/functions/tests/it/scalars/testdata/function_list.txt index 1ef44928ef135..04333a1209c38 100644 --- a/src/query/functions/tests/it/scalars/testdata/function_list.txt +++ b/src/query/functions/tests/it/scalars/testdata/function_list.txt @@ -3190,6 +3190,7 @@ lte minus multi_if multiply +ne noteq plus point_in_ellipses diff --git a/src/query/storages/fuse/src/operations/read/fuse_source.rs b/src/query/storages/fuse/src/operations/read/fuse_source.rs index 678b7beb3aa53..e537c57c616a2 100644 --- a/src/query/storages/fuse/src/operations/read/fuse_source.rs +++ b/src/query/storages/fuse/src/operations/read/fuse_source.rs @@ -43,12 +43,21 @@ pub fn build_fuse_native_source_pipeline( ) -> Result<()> { (max_threads, max_io_requests) = adjust_threads_and_request(max_threads, max_io_requests, plan); + if topk.is_some() { + max_threads = max_threads.min(16); + max_io_requests = max_io_requests.min(16); + } + let mut source_builder = SourcePipeBuilder::create(); match block_reader.support_blocking_api() { true => { let partitions = dispatch_partitions(ctx.clone(), plan, max_threads); - let partitions = StealablePartitions::new(partitions, ctx.clone()); + let mut partitions = StealablePartitions::new(partitions, ctx.clone()); + + if topk.is_some() { + partitions.disable_steal(); + } for i in 0..max_threads { let output = OutputPort::create(); @@ -67,7 +76,11 @@ pub fn build_fuse_native_source_pipeline( } false => { let partitions = dispatch_partitions(ctx.clone(), plan, max_io_requests); - let partitions = StealablePartitions::new(partitions, ctx.clone()); + let mut partitions = StealablePartitions::new(partitions, ctx.clone()); + + if topk.is_some() { + partitions.disable_steal(); + } for i in 0..max_io_requests { let output = OutputPort::create(); diff --git a/tests/sqllogictests/suites/base/11_data_type/11_0006_data_type_decimal b/tests/sqllogictests/suites/base/11_data_type/11_0006_data_type_decimal index 2f27780a7ad59..d7be42ee2956f 100644 --- a/tests/sqllogictests/suites/base/11_data_type/11_0006_data_type_decimal +++ b/tests/sqllogictests/suites/base/11_data_type/11_0006_data_type_decimal @@ -163,6 +163,31 @@ SELECT ANY(CAST(2.34 AS DECIMAL(6, 2))) ---- 2.34 + +## compare + +query IIIII +select a > b, a < b, a = b, a <= b, a >= b from (select 3::Decimal(13,2) a , 3.1::Decimal(8,2) b); +---- +0 1 0 1 0 + + +query IIIII +select a > b, a < b, a = b, a <= b, a >= b from (select 3::Decimal(13,2) a , 3 b); +---- +0 0 1 1 1 + +query IIIII +select a > b, a < b, a = b, a <= b, a >= b from (select 3::Decimal(13,2) a , 3.1 b); +---- +0 1 0 1 0 + +query IIIII +select a > b, a < b, a = b, a <= b, a >= b from (select 3::Decimal(13,2) a , 2.9 b); +---- +1 0 0 0 1 + + ## insert statement ok