diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 59d6f44f59b1..d7c536ed2771 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -1842,6 +1842,46 @@ mod tests { Ok(()) } + #[tokio::test] + async fn aggregate_decimal_min() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("d_table", test::table_with_decimal()) + .unwrap(); + + let result = plan_and_collect(&mut ctx, "select min(c1) from d_table") + .await + .unwrap(); + let expected = vec![ + "+-----------------+", + "| MIN(d_table.c1) |", + "+-----------------+", + "| -100.009        |", + "+-----------------+", + ]; + assert_batches_sorted_eq!(expected, &result); + Ok(()) + } + + #[tokio::test] + async fn aggregate_decimal_max() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("d_table", test::table_with_decimal()) + .unwrap(); + + let result = plan_and_collect(&mut ctx, "select max(c1) from d_table") + .await + .unwrap(); + let expected = vec![ + "+-----------------+", + "| MAX(d_table.c1) |", + "+-----------------+", + "| 110.009         |", + "+-----------------+", + ]; + assert_batches_sorted_eq!(expected, &result); + Ok(()) + } + #[tokio::test] async fn aggregate() -> Result<()> { let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?; diff --git a/datafusion/src/physical_plan/expressions/min_max.rs b/datafusion/src/physical_plan/expressions/min_max.rs index 9e5b1e095cd6..2f6188169654 100644 --- a/datafusion/src/physical_plan/expressions/min_max.rs +++ b/datafusion/src/physical_plan/expressions/min_max.rs @@ -37,6 +37,8 @@ use arrow::{ }; use super::format_state_name; +use crate::arrow::array::Array; +use arrow::array::DecimalArray; // Min/max aggregation can take Dictionary encode input but always produces unpacked // (aka non Dictionary) output. We need to adjust the output data type to reflect this. @@ -129,11 +131,49 @@ macro_rules! 
typed_min_max_batch { }}; } +// TODO implement this in arrow-rs with simd +// https://github.com/apache/arrow-rs/issues/1010 +// Statically-typed version of min/max(array) -> ScalarValue for decimal types. +macro_rules! typed_min_max_batch_decimal128 { + ($VALUES:expr, $PRECISION:ident, $SCALE:ident, $OP:ident) => {{ + let null_count = $VALUES.null_count(); + if null_count == $VALUES.len() { + ScalarValue::Decimal128(None, *$PRECISION, *$SCALE) + } else { + let array = $VALUES.as_any().downcast_ref::<DecimalArray>().unwrap(); + if null_count == 0 { + // there is no null value + let mut result = array.value(0); + for i in 1..array.len() { + result = result.$OP(array.value(i)); + } + ScalarValue::Decimal128(Some(result), *$PRECISION, *$SCALE) + } else { + let mut result = 0_i128; + let mut has_value = false; + for i in 0..array.len() { + if !has_value && array.is_valid(i) { + has_value = true; + result = array.value(i); + } + if array.is_valid(i) { + result = result.$OP(array.value(i)); + } + } + ScalarValue::Decimal128(Some(result), *$PRECISION, *$SCALE) + } + } + }}; +} + // Statically-typed version of min/max(array) -> ScalarValue for non-string types. // this is a macro to support both operations (min and max). macro_rules! min_max_batch { ($VALUES:expr, $OP:ident) => {{ match $VALUES.data_type() { + DataType::Decimal(precision, scale) => { + typed_min_max_batch_decimal128!($VALUES, precision, scale, $OP) + } // all types that have a natural order DataType::Float64 => { typed_min_max_batch!($VALUES, Float64Array, Float64, $OP) @@ -208,6 +248,20 @@ fn max_batch(values: &ArrayRef) -> Result<ScalarValue> { _ => min_max_batch!(values, max), }) } +macro_rules! 
typed_min_max_decimal { + ($VALUE:expr, $DELTA:expr, $PRECISION:expr, $SCALE:expr, $SCALAR:ident, $OP:ident) => {{ + ScalarValue::$SCALAR( + match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(a.clone()), + (None, Some(b)) => Some(b.clone()), + (Some(a), Some(b)) => Some((*a).$OP(*b)), + }, + $PRECISION.clone(), + $SCALE.clone(), + ) + }}; +} // min/max of two non-string scalar values. macro_rules! typed_min_max { @@ -237,6 +291,16 @@ macro_rules! typed_min_max_string { macro_rules! min_max { ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ Ok(match ($VALUE, $DELTA) { + (ScalarValue::Decimal128(lhsv,lhsp,lhss), ScalarValue::Decimal128(rhsv,rhsp,rhss)) => { + if lhsp.eq(rhsp) && lhss.eq(rhss) { + typed_min_max_decimal!(lhsv, rhsv, lhsp, lhss, Decimal128, $OP) + } else { + return Err(DataFusionError::Internal(format!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (ScalarValue::Decimal128(*lhsv,*lhsp,*lhss),ScalarValue::Decimal128(*rhsv,*rhsp,*rhss)) + ))); + } + } (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { typed_min_max!(lhs, rhs, Float64, $OP) } @@ -411,6 +475,10 @@ impl AggregateExpr for Min { )) } + fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> { + Ok(Box::new(MinAccumulator::try_new(&self.data_type)?)) + } + fn state_fields(&self) -> Result<Vec<Field>> { Ok(vec![Field::new( &format_state_name(&self.name, "min"), @@ -423,10 +491,6 @@ impl AggregateExpr for Min { vec![self.expr.clone()] } - fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> { - Ok(Box::new(MinAccumulator::try_new(&self.data_type)?)) - } - fn name(&self) -> &str { &self.name } @@ -452,6 +516,12 @@ impl Accumulator for MinAccumulator { Ok(vec![self.min.clone()]) } + fn update(&mut self, values: &[ScalarValue]) -> Result<()> { + let value = &values[0]; + self.min = min(&self.min, value)?; + Ok(()) + } + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; let delta = &min_batch(values)?; @@ -459,12 +529,6 @@ impl Accumulator for 
MinAccumulator { Ok(()) } - fn update(&mut self, values: &[ScalarValue]) -> Result<()> { - let value = &values[0]; - self.min = min(&self.min, value)?; - Ok(()) - } - fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { self.update(states) } @@ -483,10 +547,180 @@ mod tests { use super::*; use crate::physical_plan::expressions::col; use crate::physical_plan::expressions::tests::aggregate; + use crate::scalar::ScalarValue::Decimal128; use crate::{error::Result, generic_test_op}; + use arrow::array::DecimalBuilder; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; + #[test] + fn min_decimal() -> Result<()> { + // min + let left = ScalarValue::Decimal128(Some(123), 10, 2); + let right = ScalarValue::Decimal128(Some(124), 10, 2); + let result = min(&left, &right)?; + assert_eq!(result, left); + + // min batch + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for i in 1..6 { + decimal_builder.append_value(i as i128)?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + + let result = min_batch(&array)?; + assert_eq!(result, ScalarValue::Decimal128(Some(1), 10, 0)); + // min batch without values + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + let array: ArrayRef = Arc::new(decimal_builder.finish()); + let result = min_batch(&array)?; + assert_eq!(ScalarValue::Decimal128(None, 10, 0), result); + + let mut decimal_builder = DecimalBuilder::new(0, 10, 0); + let array: ArrayRef = Arc::new(decimal_builder.finish()); + let result = min_batch(&array)?; + assert_eq!(ScalarValue::Decimal128(None, 10, 0), result); + + // min batch with agg + let mut decimal_builder = DecimalBuilder::new(6, 10, 0); + decimal_builder.append_null().unwrap(); + for i in 1..6 { + decimal_builder.append_value(i as i128)?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + generic_test_op!( + array, + DataType::Decimal(10, 0), + Min, + ScalarValue::Decimal128(Some(1), 10, 0), + DataType::Decimal(10, 0) + ) + } + + #[test] + fn 
min_decimal_all_nulls() -> Result<()> { + // min batch all nulls + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for _i in 1..6 { + decimal_builder.append_null()?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + generic_test_op!( + array, + DataType::Decimal(10, 0), + Min, + ScalarValue::Decimal128(None, 10, 0), + DataType::Decimal(10, 0) + ) + } + + #[test] + fn min_decimal_with_nulls() -> Result<()> { + // min batch with nulls + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for i in 1..6 { + if i == 2 { + decimal_builder.append_null()?; + } else { + decimal_builder.append_value(i as i128)?; + } + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + generic_test_op!( + array, + DataType::Decimal(10, 0), + Min, + ScalarValue::Decimal128(Some(1), 10, 0), + DataType::Decimal(10, 0) + ) + } + + #[test] + fn max_decimal() -> Result<()> { + // max + let left = ScalarValue::Decimal128(Some(123), 10, 2); + let right = ScalarValue::Decimal128(Some(124), 10, 2); + let result = max(&left, &right)?; + assert_eq!(result, right); + + let right = ScalarValue::Decimal128(Some(124), 10, 3); + let result = max(&left, &right); + let expect = DataFusionError::Internal(format!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (Decimal128(Some(123), 10, 2), Decimal128(Some(124), 10, 3)) + )); + assert_eq!(expect.to_string(), result.unwrap_err().to_string()); + + // max batch + let mut decimal_builder = DecimalBuilder::new(5, 10, 5); + for i in 1..6 { + decimal_builder.append_value(i as i128)?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + let result = max_batch(&array)?; + assert_eq!(result, ScalarValue::Decimal128(Some(5), 10, 5)); + // max batch without values + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + let array: ArrayRef = Arc::new(decimal_builder.finish()); + let result = max_batch(&array)?; + assert_eq!(ScalarValue::Decimal128(None, 10, 0), result); + + let mut 
decimal_builder = DecimalBuilder::new(0, 10, 0); + let array: ArrayRef = Arc::new(decimal_builder.finish()); + let result = max_batch(&array)?; + assert_eq!(ScalarValue::Decimal128(None, 10, 0), result); + // max batch with agg + let mut decimal_builder = DecimalBuilder::new(6, 10, 0); + decimal_builder.append_null().unwrap(); + for i in 1..6 { + decimal_builder.append_value(i as i128)?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + generic_test_op!( + array, + DataType::Decimal(10, 0), + Max, + ScalarValue::Decimal128(Some(5), 10, 0), + DataType::Decimal(10, 0) + ) + } + + #[test] + fn max_decimal_with_nulls() -> Result<()> { + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for i in 1..6 { + if i == 2 { + decimal_builder.append_null()?; + } else { + decimal_builder.append_value(i as i128)?; + } + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + generic_test_op!( + array, + DataType::Decimal(10, 0), + Max, + ScalarValue::Decimal128(Some(5), 10, 0), + DataType::Decimal(10, 0) + ) + } + + #[test] + fn max_decimal_all_nulls() -> Result<()> { + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for _i in 1..6 { + decimal_builder.append_null()?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + generic_test_op!( + array, + DataType::Decimal(10, 0), + Max, + ScalarValue::Decimal128(None, 10, 0), + DataType::Decimal(10, 0) + ) + } + #[test] fn max_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); diff --git a/datafusion/src/test/mod.rs b/datafusion/src/test/mod.rs index 16c1383c119f..39c9de1f6a5f 100644 --- a/datafusion/src/test/mod.rs +++ b/datafusion/src/test/mod.rs @@ -25,7 +25,7 @@ use array::{ Array, ArrayRef, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }; -use arrow::array::{self, Int32Array}; +use arrow::array::{self, DecimalBuilder, Int32Array}; use arrow::datatypes::{DataType, Field, 
Schema}; use arrow::record_batch::RecordBatch; use futures::{Future, FutureExt}; @@ -192,6 +192,27 @@ pub fn table_with_timestamps() -> Arc<dyn TableProvider> { Arc::new(MemTable::try_new(schema, partitions).unwrap()) } +/// Return a new table that provides a single decimal column +pub fn table_with_decimal() -> Arc<dyn TableProvider> { + let batch_decimal = make_decimal(); + let schema = batch_decimal.schema(); + let partitions = vec![vec![batch_decimal]]; + Arc::new(MemTable::try_new(schema, partitions).unwrap()) +} + +fn make_decimal() -> RecordBatch { + let mut decimal_builder = DecimalBuilder::new(20, 10, 3); + for i in 110000..110010 { + decimal_builder.append_value(i as i128).unwrap(); + } + for i in 100000..100010 { + decimal_builder.append_value(-i as i128).unwrap(); + } + let array = decimal_builder.finish(); + let schema = Schema::new(vec![Field::new("c1", array.data_type().clone(), true)]); + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap() +} + /// Return record batch with all of the supported timestamp types /// values ///