Skip to content

Commit

Permalink
support decimal for min/max agg (#1407)
Browse files Browse the repository at this point in the history
* support decimal for min/max agg

* add table/sql test for decimal min/max agg

* change decimal test case
  • Loading branch information
liukun4515 authored Dec 10, 2021
1 parent 7aad76f commit e89da30
Show file tree
Hide file tree
Showing 3 changed files with 306 additions and 11 deletions.
40 changes: 40 additions & 0 deletions datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1842,6 +1842,46 @@ mod tests {
Ok(())
}

// End-to-end check that `MIN` over a decimal column is planned and executed
// through SQL and yields the smallest value of the fixture table.
#[tokio::test]
async fn aggregate_decimal_min() -> Result<()> {
    let mut ctx = ExecutionContext::new();
    // Propagate failures with `?` (the test already returns `Result`) instead of
    // panicking, matching the neighbouring `aggregate` test.
    ctx.register_table("d_table", test::table_with_decimal())?;

    let result = plan_and_collect(&mut ctx, "select min(c1) from d_table").await?;
    // Every row of the pretty-printed table has the same width as the border,
    // so the value cell is padded out to the header's width.
    let expected = vec![
        "+-----------------+",
        "| MIN(d_table.c1) |",
        "+-----------------+",
        "| -100.009        |",
        "+-----------------+",
    ];
    assert_batches_sorted_eq!(expected, &result);
    Ok(())
}

// End-to-end check that `MAX` over a decimal column is planned and executed
// through SQL and yields the largest value of the fixture table.
#[tokio::test]
async fn aggregate_decimal_max() -> Result<()> {
    let mut ctx = ExecutionContext::new();
    // Propagate failures with `?` (the test already returns `Result`) instead of
    // panicking, matching the neighbouring `aggregate` test.
    ctx.register_table("d_table", test::table_with_decimal())?;

    let result = plan_and_collect(&mut ctx, "select max(c1) from d_table").await?;
    // Every row of the pretty-printed table has the same width as the border,
    // so the value cell is padded out to the header's width.
    let expected = vec![
        "+-----------------+",
        "| MAX(d_table.c1) |",
        "+-----------------+",
        "| 110.009         |",
        "+-----------------+",
    ];
    assert_batches_sorted_eq!(expected, &result);
    Ok(())
}

#[tokio::test]
async fn aggregate() -> Result<()> {
let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?;
Expand Down
254 changes: 244 additions & 10 deletions datafusion/src/physical_plan/expressions/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ use arrow::{
};

use super::format_state_name;
use crate::arrow::array::Array;
use arrow::array::DecimalArray;

// Min/max aggregation can take Dictionary-encoded input but always produces unpacked
// (i.e. non-Dictionary) output. We need to adjust the output data type to reflect this.
Expand Down Expand Up @@ -129,11 +131,49 @@ macro_rules! typed_min_max_batch {
}};
}

// TODO implement this in arrow-rs with simd
// https://github.com/apache/arrow-rs/issues/1010
//
// Statically-typed min/max(array) -> ScalarValue for decimal arrays.
//
// * `$VALUES`    - the array to reduce; it is downcast to `DecimalArray`.
// * `$PRECISION` - identifier of a reference to the precision, copied into the result.
// * `$SCALE`     - identifier of a reference to the scale, copied into the result.
// * `$OP`        - `min` or `max`.
//
// An array with no non-null entries (including an empty array) produces
// `Decimal128(None, ..)`; otherwise null slots are skipped and the extreme of
// the remaining values is returned.
macro_rules! typed_min_max_batch_decimal128 {
    ($VALUES:expr, $PRECISION:ident, $SCALE:ident, $OP:ident) => {{
        let nulls = $VALUES.null_count();
        let extreme = if nulls == $VALUES.len() {
            // Nothing to compare: every slot is null, or there are no slots.
            None
        } else {
            let decimals = $VALUES.as_any().downcast_ref::<DecimalArray>().unwrap();
            (0..decimals.len())
                // The validity lookup is only needed when nulls are present.
                .filter(|&idx| nulls == 0 || decimals.is_valid(idx))
                .map(|idx| decimals.value(idx))
                .$OP()
        };
        ScalarValue::Decimal128(extreme, *$PRECISION, *$SCALE)
    }};
}

// Statically-typed version of min/max(array) -> ScalarValue for non-string types.
// this is a macro to support both operations (min and max).
macro_rules! min_max_batch {
($VALUES:expr, $OP:ident) => {{
match $VALUES.data_type() {
DataType::Decimal(precision, scale) => {
typed_min_max_batch_decimal128!($VALUES, precision, scale, $OP)
}
// all types that have a natural order
DataType::Float64 => {
typed_min_max_batch!($VALUES, Float64Array, Float64, $OP)
Expand Down Expand Up @@ -208,6 +248,20 @@ fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
_ => min_max_batch!(values, max),
})
}
// min/max of two optional decimal scalar payloads.
//
// `$VALUE` and `$DELTA` are references to `Option`s of a `Copy` value. When
// both are present they are combined with `$OP`; when only one is present it
// is returned unchanged; when neither is, the result is `None`. The outcome
// is wrapped in `ScalarValue::$SCALAR` together with copies of `$PRECISION`
// and `$SCALE`.
macro_rules! typed_min_max_decimal {
    ($VALUE:expr, $DELTA:expr, $PRECISION:expr, $SCALE:expr, $SCALAR:ident, $OP:ident) => {{
        let combined = match ($VALUE, $DELTA) {
            (Some(lhs), Some(rhs)) => Some((*lhs).$OP(*rhs)),
            (Some(only), None) | (None, Some(only)) => Some(*only),
            (None, None) => None,
        };
        ScalarValue::$SCALAR(combined, $PRECISION.clone(), $SCALE.clone())
    }};
}

// min/max of two non-string scalar values.
macro_rules! typed_min_max {
Expand Down Expand Up @@ -237,6 +291,16 @@ macro_rules! typed_min_max_string {
macro_rules! min_max {
($VALUE:expr, $DELTA:expr, $OP:ident) => {{
Ok(match ($VALUE, $DELTA) {
(ScalarValue::Decimal128(lhsv,lhsp,lhss), ScalarValue::Decimal128(rhsv,rhsp,rhss)) => {
if lhsp.eq(rhsp) && lhss.eq(rhss) {
typed_min_max_decimal!(lhsv, rhsv, lhsp, lhss, Decimal128, $OP)
} else {
return Err(DataFusionError::Internal(format!(
"MIN/MAX is not expected to receive scalars of incompatible types {:?}",
(ScalarValue::Decimal128(*lhsv,*lhsp,*lhss),ScalarValue::Decimal128(*rhsv,*rhsp,*rhss))
)));
}
}
(ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
typed_min_max!(lhs, rhs, Float64, $OP)
}
Expand Down Expand Up @@ -411,6 +475,10 @@ impl AggregateExpr for Min {
))
}

/// Builds a fresh, boxed `MinAccumulator` for this expression's `data_type`.
///
/// Any error reported by `MinAccumulator::try_new` is returned to the caller.
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
    let accumulator = MinAccumulator::try_new(&self.data_type)?;
    Ok(Box::new(accumulator))
}

fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![Field::new(
&format_state_name(&self.name, "min"),
Expand All @@ -423,10 +491,6 @@ impl AggregateExpr for Min {
vec![self.expr.clone()]
}

/// Builds a fresh, boxed `MinAccumulator` for this expression's `data_type`.
///
/// Any error reported by `MinAccumulator::try_new` is returned to the caller.
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
    let accumulator = MinAccumulator::try_new(&self.data_type)?;
    Ok(Box::new(accumulator))
}

/// Returns the name stored on this expression, borrowed from `self`.
fn name(&self) -> &str {
&self.name
}
Expand All @@ -452,19 +516,19 @@ impl Accumulator for MinAccumulator {
Ok(vec![self.min.clone()])
}

/// Combines the first scalar of `values` with the running minimum via `min`
/// and stores the result back into `self.min`.
///
/// Only `values[0]` is read; indexing panics if `values` is empty. Errors
/// from `min` are returned to the caller.
fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
    self.min = min(&self.min, &values[0])?;
    Ok(())
}

/// Reduces the first array of `values` with `min_batch`, then combines that
/// result with the running minimum via `min`.
///
/// Only `values[0]` is read; indexing panics if `values` is empty. Errors
/// from `min_batch` or `min` are returned to the caller.
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
    let batch_min = min_batch(&values[0])?;
    self.min = min(&self.min, &batch_min)?;
    Ok(())
}

/// Combines the first scalar of `values` with the running minimum via `min`
/// and stores the result back into `self.min`.
///
/// Only `values[0]` is read; indexing panics if `values` is empty. Errors
/// from `min` are returned to the caller.
fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
    self.min = min(&self.min, &values[0])?;
    Ok(())
}

/// Merges partial states into this accumulator by forwarding `states`
/// unchanged to `update`.
fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
self.update(states)
}
Expand All @@ -483,10 +547,180 @@ mod tests {
use super::*;
use crate::physical_plan::expressions::col;
use crate::physical_plan::expressions::tests::aggregate;
use crate::scalar::ScalarValue::Decimal128;
use crate::{error::Result, generic_test_op};
use arrow::array::DecimalBuilder;
use arrow::datatypes::*;
use arrow::record_batch::RecordBatch;

// Covers decimal MIN at three levels: scalar-vs-scalar `min`, array-level
// `min_batch` (including arrays with no values), and the full aggregate
// expression via `generic_test_op!`.
#[test]
fn min_decimal() -> Result<()> {
    // min
    let left = ScalarValue::Decimal128(Some(123), 10, 2);
    let right = ScalarValue::Decimal128(Some(124), 10, 2);
    let result = min(&left, &right)?;
    assert_eq!(result, left);

    // min batch
    let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
    for i in 1..6 {
        decimal_builder.append_value(i as i128)?;
    }
    let array: ArrayRef = Arc::new(decimal_builder.finish());

    let result = min_batch(&array)?;
    assert_eq!(result, ScalarValue::Decimal128(Some(1), 10, 0));

    // min batch without values (capacity reserved, nothing appended)
    let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
    let array: ArrayRef = Arc::new(decimal_builder.finish());
    let result = min_batch(&array)?;
    assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);

    // min batch without values (zero capacity)
    let mut decimal_builder = DecimalBuilder::new(0, 10, 0);
    let array: ArrayRef = Arc::new(decimal_builder.finish());
    let result = min_batch(&array)?;
    assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);

    // min batch with agg
    let mut decimal_builder = DecimalBuilder::new(6, 10, 0);
    // Use `?` like every other fallible call in this test rather than `unwrap()`.
    decimal_builder.append_null()?;
    for i in 1..6 {
        decimal_builder.append_value(i as i128)?;
    }
    let array: ArrayRef = Arc::new(decimal_builder.finish());
    generic_test_op!(
        array,
        DataType::Decimal(10, 0),
        Min,
        ScalarValue::Decimal128(Some(1), 10, 0),
        DataType::Decimal(10, 0)
    )
}

// MIN over a decimal column holding only nulls must yield a null decimal.
#[test]
fn min_decimal_all_nulls() -> Result<()> {
    let mut builder = DecimalBuilder::new(5, 10, 0);
    for _ in 0..5 {
        builder.append_null()?;
    }
    let all_nulls: ArrayRef = Arc::new(builder.finish());

    generic_test_op!(
        all_nulls,
        DataType::Decimal(10, 0),
        Min,
        ScalarValue::Decimal128(None, 10, 0),
        DataType::Decimal(10, 0)
    )
}

// MIN over a decimal column that mixes values and a null must ignore the null.
#[test]
fn min_decimal_with_nulls() -> Result<()> {
    let mut builder = DecimalBuilder::new(5, 10, 0);
    // Values 1, 3, 4, 5 with a null where 2 would be.
    for value in 1..6_i128 {
        match value {
            2 => builder.append_null()?,
            present => builder.append_value(present)?,
        }
    }
    let with_null: ArrayRef = Arc::new(builder.finish());

    generic_test_op!(
        with_null,
        DataType::Decimal(10, 0),
        Min,
        ScalarValue::Decimal128(Some(1), 10, 0),
        DataType::Decimal(10, 0)
    )
}

// Covers decimal MAX at three levels: scalar-vs-scalar `max` (including the
// error for mismatched scale), array-level `max_batch` (including arrays with
// no values), and the full aggregate expression via `generic_test_op!`.
#[test]
fn max_decimal() -> Result<()> {
    // max
    let left = ScalarValue::Decimal128(Some(123), 10, 2);
    let right = ScalarValue::Decimal128(Some(124), 10, 2);
    let result = max(&left, &right)?;
    assert_eq!(result, right);

    // Same precision but different scale: the scalars are not comparable.
    let right = ScalarValue::Decimal128(Some(124), 10, 3);
    let result = max(&left, &right);
    let expect = DataFusionError::Internal(format!(
        "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
        (Decimal128(Some(123), 10, 2), Decimal128(Some(124), 10, 3))
    ));
    assert_eq!(expect.to_string(), result.unwrap_err().to_string());

    // max batch
    let mut decimal_builder = DecimalBuilder::new(5, 10, 5);
    for i in 1..6 {
        decimal_builder.append_value(i as i128)?;
    }
    let array: ArrayRef = Arc::new(decimal_builder.finish());
    let result = max_batch(&array)?;
    assert_eq!(result, ScalarValue::Decimal128(Some(5), 10, 5));

    // max batch without values (capacity reserved, nothing appended)
    let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
    let array: ArrayRef = Arc::new(decimal_builder.finish());
    let result = max_batch(&array)?;
    assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);

    // max batch without values (zero capacity)
    let mut decimal_builder = DecimalBuilder::new(0, 10, 0);
    let array: ArrayRef = Arc::new(decimal_builder.finish());
    let result = max_batch(&array)?;
    assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);

    // max batch with agg
    let mut decimal_builder = DecimalBuilder::new(6, 10, 0);
    // Use `?` like every other fallible call in this test rather than `unwrap()`.
    decimal_builder.append_null()?;
    for i in 1..6 {
        decimal_builder.append_value(i as i128)?;
    }
    let array: ArrayRef = Arc::new(decimal_builder.finish());
    generic_test_op!(
        array,
        DataType::Decimal(10, 0),
        Max,
        ScalarValue::Decimal128(Some(5), 10, 0),
        DataType::Decimal(10, 0)
    )
}

// MAX over a decimal column that mixes values and a null must ignore the null.
#[test]
fn max_decimal_with_nulls() -> Result<()> {
    let mut builder = DecimalBuilder::new(5, 10, 0);
    // Values 1, 3, 4, 5 with a null where 2 would be.
    for value in 1..6_i128 {
        match value {
            2 => builder.append_null()?,
            present => builder.append_value(present)?,
        }
    }
    let with_null: ArrayRef = Arc::new(builder.finish());

    generic_test_op!(
        with_null,
        DataType::Decimal(10, 0),
        Max,
        ScalarValue::Decimal128(Some(5), 10, 0),
        DataType::Decimal(10, 0)
    )
}

// MAX over a decimal column holding only nulls must yield a null decimal.
#[test]
fn max_decimal_all_nulls() -> Result<()> {
    let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
    for _i in 1..6 {
        decimal_builder.append_null()?;
    }
    let array: ArrayRef = Arc::new(decimal_builder.finish());
    // Exercise `Max` here: the `Min` all-null case has its own test
    // (`min_decimal_all_nulls`), and using `Min` again would leave the
    // all-null path of `Max` untested.
    generic_test_op!(
        array,
        DataType::Decimal(10, 0),
        Max,
        ScalarValue::Decimal128(None, 10, 0),
        DataType::Decimal(10, 0)
    )
}

#[test]
fn max_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
Expand Down
Loading

0 comments on commit e89da30

Please sign in to comment.