Support decimal for min and max aggregate #1407

Merged
merged 4 commits into from
Dec 10, 2021
40 changes: 40 additions & 0 deletions datafusion/src/execution/context.rs
@@ -1842,6 +1842,46 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn aggregate_decimal_min() -> Result<()> {
let mut ctx = ExecutionContext::new();
ctx.register_table("d_table", test::table_with_decimal())
.unwrap();

let result = plan_and_collect(&mut ctx, "select min(c1) from d_table")
.await
.unwrap();
let expected = vec![
"+-----------------+",
"| MIN(d_table.c1) |",
"+-----------------+",
"| -100.009 |",
"+-----------------+",
];
assert_batches_sorted_eq!(expected, &result);
Ok(())
}

#[tokio::test]
async fn aggregate_decimal_max() -> Result<()> {
Contributor

👍

let mut ctx = ExecutionContext::new();
ctx.register_table("d_table", test::table_with_decimal())
.unwrap();

let result = plan_and_collect(&mut ctx, "select max(c1) from d_table")
.await
.unwrap();
let expected = vec![
"+-----------------+",
"| MAX(d_table.c1) |",
"+-----------------+",
"| 110.009 |",
"+-----------------+",
];
assert_batches_sorted_eq!(expected, &result);
Ok(())
}
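
The expected values in these two tests (-100.009 and 110.009) are rendered from an unscaled 128-bit integer plus a scale. The sketch below illustrates that mapping. It assumes the c1 column has scale 3, which would make the stored integers -100009 and 110009; the definition of test::table_with_decimal is not shown in this diff, so that is an assumption. The helper name format_decimal is hypothetical and is not the display code used by arrow or DataFusion.

```rust
/// Hypothetical helper: render an unscaled decimal integer with `scale`
/// fractional digits. Assumes `scale <= 38` so that 10^scale fits in a u128.
fn format_decimal(value: i128, scale: u32) -> String {
    if scale == 0 {
        // No fractional digits: the unscaled integer is the value itself.
        return value.to_string();
    }
    let divisor = 10_u128.pow(scale);
    let sign = if value < 0 { "-" } else { "" };
    // unsigned_abs avoids the overflow that abs() would hit on i128::MIN.
    let magnitude = value.unsigned_abs();
    format!(
        "{}{}.{:0width$}",
        sign,
        magnitude / divisor, // integer part
        magnitude % divisor, // fractional part, zero-padded to `scale` digits
        width = scale as usize
    )
}
```

Splitting the sign from the magnitude keeps values between -1 and 0 correct: an unscaled -5 at scale 3 has an integer part of 0, so the sign has to be emitted separately.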

#[tokio::test]
async fn aggregate() -> Result<()> {
let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?;

254 changes: 244 additions & 10 deletions datafusion/src/physical_plan/expressions/min_max.rs
@@ -37,6 +37,8 @@ use arrow::{
};

use super::format_state_name;
use crate::arrow::array::Array;
use arrow::array::DecimalArray;

// Min/max aggregation can take Dictionary encode input but always produces unpacked
// (aka non Dictionary) output. We need to adjust the output data type to reflect this.
@@ -129,11 +131,49 @@ macro_rules! typed_min_max_batch {
}};
}

// TODO implement this in arrow-rs with simd
// https://github.com/apache/arrow-rs/issues/1010
Contributor

I have not reviewed the code in the DataFusion aggregate functions for a while, so I am not sure how much they use the arrow compute kernels. That said, the more we can reuse those kernels (and their SIMD specializations, where they exist), the better.

// Statically-typed version of min/max(array) -> ScalarValue for decimal types.
macro_rules! typed_min_max_batch_decimal128 {
($VALUES:expr, $PRECISION:ident, $SCALE:ident, $OP:ident) => {{
let null_count = $VALUES.null_count();
if null_count == $VALUES.len() {
ScalarValue::Decimal128(None, *$PRECISION, *$SCALE)
} else {
let array = $VALUES.as_any().downcast_ref::<DecimalArray>().unwrap();
if null_count == 0 {
// there is no null value
let mut result = array.value(0);
for i in 1..array.len() {
result = result.$OP(array.value(i));
}
ScalarValue::Decimal128(Some(result), *$PRECISION, *$SCALE)
} else {
let mut result = 0_i128;
Contributor

It might be more idiomatic to use let mut result: Option<i128> = 0; (i.e. use an Option rather than an explicit flag)

Then instead of code like

                    if !has_value && array.is_valid(i) {
                        has_value = true;
                        result = array.value(i);
                    }
                    if array.is_valid(i) {
                        result = result.$OP(array.value(i));
                    }

You could write code like the following (which I think saves at least one check of is_valid()):

                    if array.is_valid(i) {
                        let value = array.value(i);
                        result = result.$OP(result.unwrap_or(value));
                    }

This is just a style suggestion; I don't think it is needed.

Contributor Author @liukun4515, Dec 9, 2021

I plan to resolve it in the follow-up pull request.
You can merge it if it looks good to you.
@alamb

Contributor Author @liukun4515, Dec 10, 2021

I rechecked your suggestion and found a bug.
For example, suppose every value in the array is less than zero, say [-1, -3, -3], and we want the max of them.
If we initially set the result to Some(0) and follow your suggested code, we get 0 as the max, which is wrong.

I think we should just make this change:

let mut result = 0_i128; -> let mut result: i128 = 0;

For reference, unwrap_or is defined as:

    pub fn unwrap_or(self, default: T) -> T {
        match self {
            Some(x) => x,
            None => default,
        }
    }

let mut has_value = false;
for i in 0..array.len() {
if !has_value && array.is_valid(i) {
has_value = true;
result = array.value(i);
}
if array.is_valid(i) {
result = result.$OP(array.value(i));
}
}
ScalarValue::Decimal128(Some(result), *$PRECISION, *$SCALE)
}
}
}};
}
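
The null-handling question raised in the review thread above can be sidestepped by folding only over the valid values, so that no placeholder such as 0 ever seeds the result. The following is a minimal, std-only sketch and not the code merged in this PR. It models the decimal column as a slice of Option<i128>, which is an assumption standing in for DecimalArray, and the name fold_valid is hypothetical.

```rust
/// Hypothetical helper: reduce the non-null entries of a nullable i128 column
/// with `op` (for example `i128::min` or `i128::max`).
/// Returns `None` when the slice is empty or every entry is null.
fn fold_valid(values: &[Option<i128>], op: fn(i128, i128) -> i128) -> Option<i128> {
    values
        .iter()
        .flatten() // skip nulls; yields &i128 for each valid entry
        .copied()
        .reduce(op) // the first valid value seeds the accumulator
}
```

Because the first valid value is the seed, a column such as [-1, -3, -3] yields -1 for max, and an all-null column yields None without a separate null_count check.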

// Statically-typed version of min/max(array) -> ScalarValue for non-string types.
// this is a macro to support both operations (min and max).
macro_rules! min_max_batch {
($VALUES:expr, $OP:ident) => {{
match $VALUES.data_type() {
DataType::Decimal(precision, scale) => {
typed_min_max_batch_decimal128!($VALUES, precision, scale, $OP)
Contributor

I think if you added DecimalArray support to arrow::compute::kernels::min and arrow::compute::kernels::max, you might be able to use typed_min_max_batch! here. Work for the future, perhaps.

Contributor Author

Yes. Once I finish the aggregation work for the decimal data type, I will start on the basic operations for the decimal data type in arrow-rs.

}
// all types that have a natural order
DataType::Float64 => {
typed_min_max_batch!($VALUES, Float64Array, Float64, $OP)
@@ -208,6 +248,20 @@ fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
_ => min_max_batch!(values, max),
})
}
macro_rules! typed_min_max_decimal {
($VALUE:expr, $DELTA:expr, $PRECISION:expr, $SCALE:expr, $SCALAR:ident, $OP:ident) => {{
ScalarValue::$SCALAR(
match ($VALUE, $DELTA) {
(None, None) => None,
(Some(a), None) => Some(a.clone()),
(None, Some(b)) => Some(b.clone()),
(Some(a), Some(b)) => Some((*a).$OP(*b)),
},
$PRECISION.clone(),
$SCALE.clone(),
)
}};
}
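
The four-arm match in typed_min_max_decimal! treats a missing side as "no opinion": the operation is applied only when both sides hold a value. A minimal sketch of the same rule as a plain function, assuming bare Option<i128> values in place of ScalarValue (merge_nullable is a hypothetical name):

```rust
/// Hypothetical stand-in for the match inside `typed_min_max_decimal!`.
/// `op` is the binary operation, for example `i128::min` or `i128::max`.
fn merge_nullable(
    a: Option<i128>,
    b: Option<i128>,
    op: fn(i128, i128) -> i128,
) -> Option<i128> {
    match (a, b) {
        // Nothing seen on either side yet.
        (None, None) => None,
        // Only one side has a value: keep it unchanged.
        (Some(x), None) | (None, Some(x)) => Some(x),
        // Both sides have a value: apply the operation.
        (Some(x), Some(y)) => Some(op(x, y)),
    }
}
```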

// min/max of two non-string scalar values.
macro_rules! typed_min_max {
@@ -237,6 +291,16 @@ macro_rules! typed_min_max_string {
macro_rules! min_max {
($VALUE:expr, $DELTA:expr, $OP:ident) => {{
Ok(match ($VALUE, $DELTA) {
(ScalarValue::Decimal128(lhsv,lhsp,lhss), ScalarValue::Decimal128(rhsv,rhsp,rhss)) => {
if lhsp.eq(rhsp) && lhss.eq(rhss) {
typed_min_max_decimal!(lhsv, rhsv, lhsp, lhss, Decimal128, $OP)
} else {
return Err(DataFusionError::Internal(format!(
"MIN/MAX is not expected to receive scalars of incompatible types {:?}",
Contributor

👍

(ScalarValue::Decimal128(*lhsv,*lhsp,*lhss),ScalarValue::Decimal128(*rhsv,*rhsp,*rhss))
)));
}
}
(ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
typed_min_max!(lhs, rhs, Float64, $OP)
}
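
The precision and scale check in the Decimal128 arm matters because the stored integers are only comparable at equal scale: 124 at scale 3 represents 0.124, while 123 at scale 2 represents 1.23, so comparing the raw integers would pick the wrong maximum. The sketch below shows that guard in isolation. Decimal128Sketch and max_decimal are hypothetical names, and a String error stands in for DataFusionError::Internal.

```rust
/// Hypothetical stand-in for `ScalarValue::Decimal128(value, precision, scale)`.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Decimal128Sketch {
    value: Option<i128>,
    precision: usize,
    scale: usize,
}

/// Max of two decimal scalars; rejects operands whose precision or scale differ.
fn max_decimal(
    lhs: Decimal128Sketch,
    rhs: Decimal128Sketch,
) -> Result<Decimal128Sketch, String> {
    if lhs.precision != rhs.precision || lhs.scale != rhs.scale {
        return Err(format!(
            "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
            (lhs, rhs)
        ));
    }
    let value = match (lhs.value, rhs.value) {
        (Some(a), Some(b)) => Some(a.max(b)),
        // At most one side holds a value: keep whichever is present.
        (a, b) => a.or(b),
    };
    Ok(Decimal128Sketch { value, ..lhs })
}
```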
@@ -411,6 +475,10 @@ impl AggregateExpr for Min {
))
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(MinAccumulator::try_new(&self.data_type)?))
}

fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![Field::new(
&format_state_name(&self.name, "min"),
@@ -423,10 +491,6 @@ impl AggregateExpr for Min {
vec![self.expr.clone()]
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(MinAccumulator::try_new(&self.data_type)?))
}

fn name(&self) -> &str {
&self.name
}
@@ -452,19 +516,19 @@ impl Accumulator for MinAccumulator {
Ok(vec![self.min.clone()])
}

fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
let value = &values[0];
self.min = min(&self.min, value)?;
Ok(())
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = &values[0];
let delta = &min_batch(values)?;
self.min = min(&self.min, delta)?;
Ok(())
}

fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
let value = &values[0];
self.min = min(&self.min, value)?;
Ok(())
}

fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
self.update(states)
}
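
The accumulator methods in this hunk follow one pattern: update_batch reduces a whole batch to a single scalar (the delta) and folds it into the running state, and merge reuses update to combine partial states. A simplified, std-only sketch of that flow, assuming Option<i128> in place of ScalarValue and ArrayRef (MinAccumulatorSketch is a hypothetical type, not DataFusion's MinAccumulator):

```rust
/// Hypothetical, simplified model of a MIN accumulator over nullable i128 values.
#[derive(Debug, Default)]
struct MinAccumulatorSketch {
    min: Option<i128>,
}

impl MinAccumulatorSketch {
    /// Fold one (possibly null) value into the running minimum.
    fn update(&mut self, value: Option<i128>) {
        // Chain the two optional values and take the min of whatever is present.
        self.min = self.min.into_iter().chain(value).min();
    }

    /// Reduce a batch to its own minimum first, then fold that delta in.
    fn update_batch(&mut self, values: &[Option<i128>]) {
        let delta = values.iter().flatten().copied().min();
        self.update(delta);
    }

    /// Combine with a partial state produced by another accumulator.
    fn merge(&mut self, other: &MinAccumulatorSketch) {
        self.update(other.min);
    }

    /// Current result; `None` if no valid value has been seen.
    fn evaluate(&self) -> Option<i128> {
        self.min
    }
}
```

Since the state of a min aggregate is itself a single minimum, merging two partial states is the same operation as updating with one more value.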
@@ -483,10 +547,180 @@ mod tests {
use super::*;
use crate::physical_plan::expressions::col;
use crate::physical_plan::expressions::tests::aggregate;
use crate::scalar::ScalarValue::Decimal128;
use crate::{error::Result, generic_test_op};
use arrow::array::DecimalBuilder;
use arrow::datatypes::*;
use arrow::record_batch::RecordBatch;

#[test]
fn min_decimal() -> Result<()> {
// min
let left = ScalarValue::Decimal128(Some(123), 10, 2);
let right = ScalarValue::Decimal128(Some(124), 10, 2);
let result = min(&left, &right)?;
assert_eq!(result, left);

// min batch
let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
for i in 1..6 {
decimal_builder.append_value(i as i128)?;
}
let array: ArrayRef = Arc::new(decimal_builder.finish());

let result = min_batch(&array)?;
assert_eq!(result, ScalarValue::Decimal128(Some(1), 10, 0));
// min batch without values
let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
let array: ArrayRef = Arc::new(decimal_builder.finish());
let result = min_batch(&array)?;
assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);

let mut decimal_builder = DecimalBuilder::new(0, 10, 0);
let array: ArrayRef = Arc::new(decimal_builder.finish());
let result = min_batch(&array)?;
assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);

// min batch with agg
let mut decimal_builder = DecimalBuilder::new(6, 10, 0);
decimal_builder.append_null().unwrap();
for i in 1..6 {
decimal_builder.append_value(i as i128)?;
}
let array: ArrayRef = Arc::new(decimal_builder.finish());
generic_test_op!(
array,
DataType::Decimal(10, 0),
Min,
ScalarValue::Decimal128(Some(1), 10, 0),
DataType::Decimal(10, 0)
)
}

#[test]
fn min_decimal_all_nulls() -> Result<()> {
// min batch all nulls
let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
for _i in 1..6 {
decimal_builder.append_null()?;
}
let array: ArrayRef = Arc::new(decimal_builder.finish());
generic_test_op!(
array,
DataType::Decimal(10, 0),
Min,
ScalarValue::Decimal128(None, 10, 0),
DataType::Decimal(10, 0)
)
}

#[test]
fn min_decimal_with_nulls() -> Result<()> {
// min batch with nulls
let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
for i in 1..6 {
if i == 2 {
decimal_builder.append_null()?;
} else {
decimal_builder.append_value(i as i128)?;
}
}
let array: ArrayRef = Arc::new(decimal_builder.finish());
generic_test_op!(
array,
DataType::Decimal(10, 0),
Min,
ScalarValue::Decimal128(Some(1), 10, 0),
DataType::Decimal(10, 0)
)
}

#[test]
fn max_decimal() -> Result<()> {
// max
let left = ScalarValue::Decimal128(Some(123), 10, 2);
let right = ScalarValue::Decimal128(Some(124), 10, 2);
let result = max(&left, &right)?;
assert_eq!(result, right);

let right = ScalarValue::Decimal128(Some(124), 10, 3);
let result = max(&left, &right);
let expect = DataFusionError::Internal(format!(
"MIN/MAX is not expected to receive scalars of incompatible types {:?}",
(Decimal128(Some(123), 10, 2), Decimal128(Some(124), 10, 3))
));
assert_eq!(expect.to_string(), result.unwrap_err().to_string());

// max batch
let mut decimal_builder = DecimalBuilder::new(5, 10, 5);
for i in 1..6 {
decimal_builder.append_value(i as i128)?;
}
let array: ArrayRef = Arc::new(decimal_builder.finish());
let result = max_batch(&array)?;
assert_eq!(result, ScalarValue::Decimal128(Some(5), 10, 5));
// max batch without values
let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
let array: ArrayRef = Arc::new(decimal_builder.finish());
let result = max_batch(&array)?;
assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);

let mut decimal_builder = DecimalBuilder::new(0, 10, 0);
let array: ArrayRef = Arc::new(decimal_builder.finish());
let result = max_batch(&array)?;
assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);
// max batch with agg
let mut decimal_builder = DecimalBuilder::new(6, 10, 0);
decimal_builder.append_null().unwrap();
for i in 1..6 {
decimal_builder.append_value(i as i128)?;
}
let array: ArrayRef = Arc::new(decimal_builder.finish());
generic_test_op!(
array,
DataType::Decimal(10, 0),
Max,
ScalarValue::Decimal128(Some(5), 10, 0),
DataType::Decimal(10, 0)
)
}

#[test]
fn max_decimal_with_nulls() -> Result<()> {
let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
for i in 1..6 {
if i == 2 {
decimal_builder.append_null()?;
} else {
decimal_builder.append_value(i as i128)?;
}
}
let array: ArrayRef = Arc::new(decimal_builder.finish());
generic_test_op!(
array,
DataType::Decimal(10, 0),
Max,
ScalarValue::Decimal128(Some(5), 10, 0),
DataType::Decimal(10, 0)
)
}

#[test]
fn max_decimal_all_nulls() -> Result<()> {
let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
for _i in 1..6 {
decimal_builder.append_null()?;
}
let array: ArrayRef = Arc::new(decimal_builder.finish());
generic_test_op!(
array,
DataType::Decimal(10, 0),
Max,
ScalarValue::Decimal128(None, 10, 0),
DataType::Decimal(10, 0)
)
}

#[test]
fn max_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));