-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support decimal for min
and max
aggregate
#1407
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,6 +37,8 @@ use arrow::{ | |
}; | ||
|
||
use super::format_state_name; | ||
use crate::arrow::array::Array; | ||
use arrow::array::DecimalArray; | ||
|
||
// Min/max aggregation can take Dictionary encode input but always produces unpacked | ||
// (aka non Dictionary) output. We need to adjust the output data type to reflect this. | ||
|
@@ -129,11 +131,49 @@ macro_rules! typed_min_max_batch { | |
}}; | ||
} | ||
|
||
// TODO implement this in arrow-rs with simd | ||
// https://github.com/apache/arrow-rs/issues/1010 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have not reviewed the code in the datafusion aggregate functions for a while, so I am not familiar with how much they do / don't use the arrow compute kernels, but I think the more we can leverage / reuse those kernels (and their SIMD specializations, if they exist), the better |
||
// Statically-typed version of min/max(array) -> ScalarValue for decimal types. | ||
macro_rules! typed_min_max_batch_decimal128 { | ||
($VALUES:expr, $PRECISION:ident, $SCALE:ident, $OP:ident) => {{ | ||
let null_count = $VALUES.null_count(); | ||
if null_count == $VALUES.len() { | ||
ScalarValue::Decimal128(None, *$PRECISION, *$SCALE) | ||
} else { | ||
let array = $VALUES.as_any().downcast_ref::<DecimalArray>().unwrap(); | ||
if null_count == 0 { | ||
// there is no null value | ||
let mut result = array.value(0); | ||
for i in 1..array.len() { | ||
result = result.$OP(array.value(i)); | ||
} | ||
ScalarValue::Decimal128(Some(result), *$PRECISION, *$SCALE) | ||
} else { | ||
let mut result = 0_i128; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It might be more idiomatic to use an `Option` for `result`. Then, instead of code like if !has_value && array.is_valid(i) {
has_value = true;
result = array.value(i);
}
if array.is_valid(i) {
result = result.$OP(array.value(i));
} You could write code like the following (which I think saves at least one check of if array.is_valid(i) {
let value = array.value(i);
result = result.$OP(result.unwrap_or(value)
} This is just a style suggestion; I don't think it is needed. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I plan to resolve it in the follow-up pull request. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I rechecked your suggestion and found some bugs. I think we should just change it to that
In addition the
|
||
let mut has_value = false; | ||
for i in 0..array.len() { | ||
if !has_value && array.is_valid(i) { | ||
has_value = true; | ||
result = array.value(i); | ||
} | ||
if array.is_valid(i) { | ||
result = result.$OP(array.value(i)); | ||
} | ||
} | ||
ScalarValue::Decimal128(Some(result), *$PRECISION, *$SCALE) | ||
} | ||
} | ||
}}; | ||
} | ||
|
||
// Statically-typed version of min/max(array) -> ScalarValue for non-string types. | ||
// this is a macro to support both operations (min and max). | ||
macro_rules! min_max_batch { | ||
($VALUES:expr, $OP:ident) => {{ | ||
match $VALUES.data_type() { | ||
DataType::Decimal(precision, scale) => { | ||
typed_min_max_batch_decimal128!($VALUES, precision, scale, $OP) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think if you added There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, when I finish the task of aggregating with decimal data type, I will begin to do basic operations in arrow-rs for decimal data type. |
||
} | ||
// all types that have a natural order | ||
DataType::Float64 => { | ||
typed_min_max_batch!($VALUES, Float64Array, Float64, $OP) | ||
|
@@ -208,6 +248,20 @@ fn max_batch(values: &ArrayRef) -> Result<ScalarValue> { | |
_ => min_max_batch!(values, max), | ||
}) | ||
} | ||
// min/max of two optional decimal values, wrapped in `ScalarValue::$SCALAR`. | ||
// `$VALUE` and `$DELTA` are matched as a pair of `Option`s; `$OP` is the method (e.g. min or max) | ||
// applied when both are present. `$PRECISION` and `$SCALE` are cloned into the result unchanged; | ||
// this macro does not itself compare the precision/scale of the two inputs. | ||
macro_rules! typed_min_max_decimal { | ||
($VALUE:expr, $DELTA:expr, $PRECISION:expr, $SCALE:expr, $SCALAR:ident, $OP:ident) => {{ | ||
ScalarValue::$SCALAR( | ||
match ($VALUE, $DELTA) { | ||
// neither side has a value: the result stays null | ||
(None, None) => None, | ||
// only one side has a value: that value is the result | ||
(Some(a), None) => Some(a.clone()), | ||
(None, Some(b)) => Some(b.clone()), | ||
// both sides have a value: combine them with $OP | ||
(Some(a), Some(b)) => Some((*a).$OP(*b)), | ||
}, | ||
$PRECISION.clone(), | ||
$SCALE.clone(), | ||
) | ||
}}; | ||
} | ||
|
||
// min/max of two non-string scalar values. | ||
macro_rules! typed_min_max { | ||
|
@@ -237,6 +291,16 @@ macro_rules! typed_min_max_string { | |
macro_rules! min_max { | ||
($VALUE:expr, $DELTA:expr, $OP:ident) => {{ | ||
Ok(match ($VALUE, $DELTA) { | ||
(ScalarValue::Decimal128(lhsv,lhsp,lhss), ScalarValue::Decimal128(rhsv,rhsp,rhss)) => { | ||
if lhsp.eq(rhsp) && lhss.eq(rhss) { | ||
typed_min_max_decimal!(lhsv, rhsv, lhsp, lhss, Decimal128, $OP) | ||
} else { | ||
return Err(DataFusionError::Internal(format!( | ||
"MIN/MAX is not expected to receive scalars of incompatible types {:?}", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
(ScalarValue::Decimal128(*lhsv,*lhsp,*lhss),ScalarValue::Decimal128(*rhsv,*rhsp,*rhss)) | ||
))); | ||
} | ||
} | ||
(ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { | ||
typed_min_max!(lhs, rhs, Float64, $OP) | ||
} | ||
|
@@ -411,6 +475,10 @@ impl AggregateExpr for Min { | |
)) | ||
} | ||
|
||
// Builds a boxed `MinAccumulator` from this expression's `data_type`. | ||
// An error returned by `MinAccumulator::try_new` is propagated to the caller via `?`. | ||
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> { | ||
Ok(Box::new(MinAccumulator::try_new(&self.data_type)?)) | ||
} | ||
|
||
fn state_fields(&self) -> Result<Vec<Field>> { | ||
Ok(vec![Field::new( | ||
&format_state_name(&self.name, "min"), | ||
|
@@ -423,10 +491,6 @@ impl AggregateExpr for Min { | |
vec![self.expr.clone()] | ||
} | ||
|
||
// Builds a boxed `MinAccumulator` from this expression's `data_type`. | ||
// An error returned by `MinAccumulator::try_new` is propagated to the caller via `?`. | ||
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> { | ||
Ok(Box::new(MinAccumulator::try_new(&self.data_type)?)) | ||
} | ||
|
||
// Returns a borrowed view of the `name` field of this expression. | ||
fn name(&self) -> &str { | ||
&self.name | ||
} | ||
|
@@ -452,19 +516,19 @@ impl Accumulator for MinAccumulator { | |
Ok(vec![self.min.clone()]) | ||
} | ||
|
||
// Folds a single scalar into the running minimum. | ||
// Only `values[0]` is read; indexing panics if `values` is empty. | ||
// An error from `min` (e.g. incompatible scalar types) is returned and `self.min` is left as it was. | ||
fn update(&mut self, values: &[ScalarValue]) -> Result<()> { | ||
let value = &values[0]; | ||
self.min = min(&self.min, value)?; | ||
Ok(()) | ||
} | ||
|
||
// Folds a whole array into the running minimum. | ||
// Only `values[0]` is read; indexing panics if `values` is empty. | ||
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { | ||
let values = &values[0]; | ||
// reduce the array to a single scalar first, then merge it with the current minimum | ||
let delta = &min_batch(values)?; | ||
self.min = min(&self.min, delta)?; | ||
Ok(()) | ||
} | ||
|
||
// Folds a single scalar into the running minimum. | ||
// Only `values[0]` is read; indexing panics if `values` is empty. | ||
// An error from `min` (e.g. incompatible scalar types) is returned and `self.min` is left as it was. | ||
fn update(&mut self, values: &[ScalarValue]) -> Result<()> { | ||
let value = &values[0]; | ||
self.min = min(&self.min, value)?; | ||
Ok(()) | ||
} | ||
|
||
// Merges another accumulator's state by delegating to `update`, passing `states` through as-is. | ||
fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { | ||
self.update(states) | ||
} | ||
|
@@ -483,10 +547,180 @@ mod tests { | |
use super::*; | ||
use crate::physical_plan::expressions::col; | ||
use crate::physical_plan::expressions::tests::aggregate; | ||
use crate::scalar::ScalarValue::Decimal128; | ||
use crate::{error::Result, generic_test_op}; | ||
use arrow::array::DecimalBuilder; | ||
use arrow::datatypes::*; | ||
use arrow::record_batch::RecordBatch; | ||
|
||
// Exercises decimal MIN three ways: scalar `min`, array `min_batch`, and the `Min` aggregate | ||
// driven through `generic_test_op!`. | ||
#[test] | ||
fn min_decimal() -> Result<()> { | ||
// scalar min: both inputs are Decimal128 with precision 10 and scale 2; the smaller one wins | ||
let left = ScalarValue::Decimal128(Some(123), 10, 2); | ||
let right = ScalarValue::Decimal128(Some(124), 10, 2); | ||
let result = min(&left, &right)?; | ||
assert_eq!(result, left); | ||
|
||
// min batch: an array holding the values 1 through 5 with no nulls | ||
let mut decimal_builder = DecimalBuilder::new(5, 10, 0); | ||
for i in 1..6 { | ||
decimal_builder.append_value(i as i128)?; | ||
} | ||
let array: ArrayRef = Arc::new(decimal_builder.finish()); | ||
|
||
let result = min_batch(&array)?; | ||
assert_eq!(result, ScalarValue::Decimal128(Some(1), 10, 0)); | ||
// min batch without values: nothing is appended before `finish`, and the result must be a | ||
// null decimal with the builder's precision and scale | ||
let mut decimal_builder = DecimalBuilder::new(5, 10, 0); | ||
let array: ArrayRef = Arc::new(decimal_builder.finish()); | ||
let result = min_batch(&array)?; | ||
assert_eq!(ScalarValue::Decimal128(None, 10, 0), result); | ||
|
||
// same check with the builder's first argument set to 0 | ||
// (presumably the initial capacity -- confirm against `DecimalBuilder::new`) | ||
let mut decimal_builder = DecimalBuilder::new(0, 10, 0); | ||
let array: ArrayRef = Arc::new(decimal_builder.finish()); | ||
let result = min_batch(&array)?; | ||
assert_eq!(ScalarValue::Decimal128(None, 10, 0), result); | ||
|
||
// min batch with agg: one leading null followed by 1 through 5, run through the `Min` aggregate | ||
let mut decimal_builder = DecimalBuilder::new(6, 10, 0); | ||
decimal_builder.append_null().unwrap(); | ||
for i in 1..6 { | ||
decimal_builder.append_value(i as i128)?; | ||
} | ||
let array: ArrayRef = Arc::new(decimal_builder.finish()); | ||
generic_test_op!( | ||
array, | ||
DataType::Decimal(10, 0), | ||
Min, | ||
ScalarValue::Decimal128(Some(1), 10, 0), | ||
DataType::Decimal(10, 0) | ||
) | ||
} | ||
|
||
// The `Min` aggregate over a decimal array that contains only nulls must produce a null | ||
// Decimal128 carrying the same precision (10) and scale (0). | ||
#[test] | ||
fn min_decimal_all_nulls() -> Result<()> { | ||
// min batch all nulls: five null entries, no values | ||
let mut decimal_builder = DecimalBuilder::new(5, 10, 0); | ||
for _i in 1..6 { | ||
decimal_builder.append_null()?; | ||
} | ||
let array: ArrayRef = Arc::new(decimal_builder.finish()); | ||
generic_test_op!( | ||
array, | ||
DataType::Decimal(10, 0), | ||
Min, | ||
ScalarValue::Decimal128(None, 10, 0), | ||
DataType::Decimal(10, 0) | ||
) | ||
} | ||
|
||
// The `Min` aggregate must skip null entries: the input is 1, null, 3, 4, 5 and the expected | ||
// minimum is 1. | ||
#[test] | ||
fn min_decimal_with_nulls() -> Result<()> { | ||
// min batch with nulls: the entry at i == 2 is replaced by a null | ||
let mut decimal_builder = DecimalBuilder::new(5, 10, 0); | ||
for i in 1..6 { | ||
if i == 2 { | ||
decimal_builder.append_null()?; | ||
} else { | ||
decimal_builder.append_value(i as i128)?; | ||
} | ||
} | ||
let array: ArrayRef = Arc::new(decimal_builder.finish()); | ||
generic_test_op!( | ||
array, | ||
DataType::Decimal(10, 0), | ||
Min, | ||
ScalarValue::Decimal128(Some(1), 10, 0), | ||
DataType::Decimal(10, 0) | ||
) | ||
} | ||
|
||
// Exercises decimal MAX: scalar `max` (including the mismatched-scale error path), array | ||
// `max_batch`, and the `Max` aggregate driven through `generic_test_op!`. | ||
#[test] | ||
fn max_decimal() -> Result<()> { | ||
// scalar max: both inputs are Decimal128 with precision 10 and scale 2; the larger one wins | ||
let left = ScalarValue::Decimal128(Some(123), 10, 2); | ||
let right = ScalarValue::Decimal128(Some(124), 10, 2); | ||
let result = max(&left, &right)?; | ||
assert_eq!(result, right); | ||
|
||
// scalar max with differing scales (2 vs 3) must fail with an Internal error; the error | ||
// text is compared via `to_string()` on both sides | ||
let right = ScalarValue::Decimal128(Some(124), 10, 3); | ||
let result = max(&left, &right); | ||
let expect = DataFusionError::Internal(format!( | ||
"MIN/MAX is not expected to receive scalars of incompatible types {:?}", | ||
(Decimal128(Some(123), 10, 2), Decimal128(Some(124), 10, 3)) | ||
)); | ||
assert_eq!(expect.to_string(), result.unwrap_err().to_string()); | ||
|
||
// max batch: an array holding the values 1 through 5 with no nulls, precision 10 and scale 5 | ||
let mut decimal_builder = DecimalBuilder::new(5, 10, 5); | ||
for i in 1..6 { | ||
decimal_builder.append_value(i as i128)?; | ||
} | ||
let array: ArrayRef = Arc::new(decimal_builder.finish()); | ||
let result = max_batch(&array)?; | ||
assert_eq!(result, ScalarValue::Decimal128(Some(5), 10, 5)); | ||
// max batch without values: nothing is appended before `finish`, and the result must be a | ||
// null decimal with the builder's precision and scale | ||
let mut decimal_builder = DecimalBuilder::new(5, 10, 0); | ||
let array: ArrayRef = Arc::new(decimal_builder.finish()); | ||
let result = max_batch(&array)?; | ||
assert_eq!(ScalarValue::Decimal128(None, 10, 0), result); | ||
|
||
// same check with the builder's first argument set to 0 | ||
// (presumably the initial capacity -- confirm against `DecimalBuilder::new`) | ||
let mut decimal_builder = DecimalBuilder::new(0, 10, 0); | ||
let array: ArrayRef = Arc::new(decimal_builder.finish()); | ||
let result = max_batch(&array)?; | ||
assert_eq!(ScalarValue::Decimal128(None, 10, 0), result); | ||
// max batch with agg: one leading null followed by 1 through 5, run through the `Max` aggregate | ||
let mut decimal_builder = DecimalBuilder::new(6, 10, 0); | ||
decimal_builder.append_null().unwrap(); | ||
for i in 1..6 { | ||
decimal_builder.append_value(i as i128)?; | ||
} | ||
let array: ArrayRef = Arc::new(decimal_builder.finish()); | ||
generic_test_op!( | ||
array, | ||
DataType::Decimal(10, 0), | ||
Max, | ||
ScalarValue::Decimal128(Some(5), 10, 0), | ||
DataType::Decimal(10, 0) | ||
) | ||
} | ||
|
||
// The `Max` aggregate must skip null entries: the input is 1, null, 3, 4, 5 and the expected | ||
// maximum is 5. | ||
#[test] | ||
fn max_decimal_with_nulls() -> Result<()> { | ||
let mut decimal_builder = DecimalBuilder::new(5, 10, 0); | ||
for i in 1..6 { | ||
// the entry at i == 2 is replaced by a null | ||
if i == 2 { | ||
decimal_builder.append_null()?; | ||
} else { | ||
decimal_builder.append_value(i as i128)?; | ||
} | ||
} | ||
let array: ArrayRef = Arc::new(decimal_builder.finish()); | ||
generic_test_op!( | ||
array, | ||
DataType::Decimal(10, 0), | ||
Max, | ||
ScalarValue::Decimal128(Some(5), 10, 0), | ||
DataType::Decimal(10, 0) | ||
) | ||
} | ||
|
||
// The `Max` aggregate over a decimal array that contains only nulls must produce a null | ||
// Decimal128 carrying the same precision (10) and scale (0). | ||
#[test] | ||
fn max_decimal_all_nulls() -> Result<()> { | ||
// five null entries, no values | ||
let mut decimal_builder = DecimalBuilder::new(5, 10, 0); | ||
for _i in 1..6 { | ||
decimal_builder.append_null()?; | ||
} | ||
let array: ArrayRef = Arc::new(decimal_builder.finish()); | ||
generic_test_op!( | ||
array, | ||
DataType::Decimal(10, 0), | ||
// this test is named for MAX, so it must drive the `Max` aggregate rather than `Min` | ||
Max, | ||
ScalarValue::Decimal128(None, 10, 0), | ||
DataType::Decimal(10, 0) | ||
) | ||
} | ||
|
||
#[test] | ||
fn max_i32() -> Result<()> { | ||
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍