diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index 50bd24c487bfe..a6d5054ec170b 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -488,6 +488,20 @@ macro_rules! typed_min_max { }}; } +macro_rules! typed_min_max_float { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ + ScalarValue::$SCALAR(match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(*a), + (None, Some(b)) => Some(*b), + (Some(a), Some(b)) => match a.total_cmp(b) { + choose_min_max!($OP) => Some(*b), + _ => Some(*a), + }, + }) + }}; +} + // min/max of two scalar string values. macro_rules! typed_min_max_string { ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ @@ -500,7 +514,7 @@ macro_rules! typed_min_max_string { }}; } -macro_rules! interval_choose_min_max { +macro_rules! choose_min_max { (min) => { std::cmp::Ordering::Greater }; @@ -512,7 +526,7 @@ macro_rules! interval_choose_min_max { macro_rules! interval_min_max { ($OP:tt, $LHS:expr, $RHS:expr) => {{ match $LHS.partial_cmp(&$RHS) { - Some(interval_choose_min_max!($OP)) => $RHS.clone(), + Some(choose_min_max!($OP)) => $RHS.clone(), Some(_) => $LHS.clone(), None => { return internal_err!("Comparison error while computing interval min/max") @@ -555,10 +569,10 @@ macro_rules! min_max { typed_min_max!(lhs, rhs, Boolean, $OP) } (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { - typed_min_max!(lhs, rhs, Float64, $OP) + typed_min_max_float!(lhs, rhs, Float64, $OP) } (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => { - typed_min_max!(lhs, rhs, Float32, $OP) + typed_min_max_float!(lhs, rhs, Float32, $OP) } (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { typed_min_max!(lhs, rhs, UInt64, $OP) @@ -1103,3 +1117,41 @@ impl Accumulator for SlidingMinAccumulator { std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn float_min_max_with_nans() { + let pos_nan = f32::NAN; + let zero = 0_f32; + let neg_inf = f32::NEG_INFINITY; + + let check = |acc: &mut dyn Accumulator, values: &[&[f32]], expected: f32| { + for batch in values.iter() { + let batch = + Arc::new(Float32Array::from_iter_values(batch.iter().copied())); + acc.update_batch(&[batch]).unwrap(); + } + let result = acc.evaluate().unwrap(); + assert_eq!(result, ScalarValue::Float32(Some(expected))); + }; + + // This test checks both comparison between batches (which uses the min_max macro + // defined above) and within a batch (which uses the arrow min/max compute function + // and verifies both respect the total order comparison for floats) + + let min = || MinAccumulator::try_new(&DataType::Float32).unwrap(); + let max = || MaxAccumulator::try_new(&DataType::Float32).unwrap(); + + check(&mut min(), &[&[zero], &[pos_nan]], zero); + check(&mut min(), &[&[zero, pos_nan]], zero); + check(&mut min(), &[&[zero], &[neg_inf]], neg_inf); + check(&mut min(), &[&[zero, neg_inf]], neg_inf); + check(&mut max(), &[&[zero], &[pos_nan]], pos_nan); + check(&mut max(), &[&[zero, pos_nan]], pos_nan); + check(&mut max(), &[&[zero], &[neg_inf]], zero); + check(&mut max(), &[&[zero, neg_inf]], zero); + } +}