Skip to content

Commit

Permalink
Implement GroupsAccumulator for stddev and variance
Browse files Browse the repository at this point in the history
  • Loading branch information
eejbyfeldt committed Aug 21, 2024
1 parent 0324bb4 commit 847e122
Show file tree
Hide file tree
Showing 3 changed files with 433 additions and 121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,105 +134,14 @@ impl NullState {
T: ArrowPrimitiveType + Send,
F: FnMut(usize, T::Native) + Send,
{
let data: &[T::Native] = values.values();
assert_eq!(data.len(), group_indices.len());

// ensure the seen_values is big enough (start everything at
// "not seen" valid)
let seen_values =
initialize_builder(&mut self.seen_values, total_num_groups, false);

match (values.null_count() > 0, opt_filter) {
// no nulls, no filter,
(false, None) => {
let iter = group_indices.iter().zip(data.iter());
for (&group_index, &new_value) in iter {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
}
}
// nulls, no filter
(true, None) => {
let nulls = values.nulls().unwrap();
// This is based on (ahem, COPY/PASTE) arrow::compute::aggregate::sum
// iterate over in chunks of 64 bits for more efficient null checking
let group_indices_chunks = group_indices.chunks_exact(64);
let data_chunks = data.chunks_exact(64);
let bit_chunks = nulls.inner().bit_chunks();

let group_indices_remainder = group_indices_chunks.remainder();
let data_remainder = data_chunks.remainder();

group_indices_chunks
.zip(data_chunks)
.zip(bit_chunks.iter())
.for_each(|((group_index_chunk, data_chunk), mask)| {
// index_mask has value 1 << i in the loop
let mut index_mask = 1;
group_index_chunk.iter().zip(data_chunk.iter()).for_each(
|(&group_index, &new_value)| {
// valid bit was set, real value
let is_valid = (mask & index_mask) != 0;
if is_valid {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
}
index_mask <<= 1;
},
)
});

// handle any remaining bits (after the initial 64)
let remainder_bits = bit_chunks.remainder_bits();
group_indices_remainder
.iter()
.zip(data_remainder.iter())
.enumerate()
.for_each(|(i, (&group_index, &new_value))| {
let is_valid = remainder_bits & (1 << i) != 0;
if is_valid {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
}
});
}
// no nulls, but a filter
(false, Some(filter)) => {
assert_eq!(filter.len(), group_indices.len());
// The performance with a filter could be improved by
// iterating over the filter in chunks, rather than a single
// iterator. TODO file a ticket
group_indices
.iter()
.zip(data.iter())
.zip(filter.iter())
.for_each(|((&group_index, &new_value), filter_value)| {
if let Some(true) = filter_value {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value);
}
})
}
// both null values and filters
(true, Some(filter)) => {
assert_eq!(filter.len(), group_indices.len());
// The performance with a filter could be improved by
// iterating over the filter in chunks, rather than using
// iterators. TODO file a ticket
filter
.iter()
.zip(group_indices.iter())
.zip(values.iter())
.for_each(|((filter_value, &group_index), new_value)| {
if let Some(true) = filter_value {
if let Some(new_value) = new_value {
seen_values.set_bit(group_index, true);
value_fn(group_index, new_value)
}
}
})
}
}
accumulate(group_indices, values, opt_filter, |group_index, value| {
seen_values.set_bit(group_index, true);
value_fn(group_index, value);
});
}

/// Invokes `value_fn(group_index, value)` for each non null, non
Expand Down Expand Up @@ -351,6 +260,106 @@ impl NullState {
}
}

pub fn accumulate<T, F>(
group_indices: &[usize],
values: &PrimitiveArray<T>,
opt_filter: Option<&BooleanArray>,
mut value_fn: F,
) where
T: ArrowPrimitiveType + Send,
F: FnMut(usize, T::Native) + Send,
{
let data: &[T::Native] = values.values();
assert_eq!(data.len(), group_indices.len());

match (values.null_count() > 0, opt_filter) {
// no nulls, no filter,
(false, None) => {
let iter = group_indices.iter().zip(data.iter());
for (&group_index, &new_value) in iter {
value_fn(group_index, new_value);
}
}
// nulls, no filter
(true, None) => {
let nulls = values.nulls().unwrap();
// This is based on (ahem, COPY/PASTE) arrow::compute::aggregate::sum
// iterate over in chunks of 64 bits for more efficient null checking
let group_indices_chunks = group_indices.chunks_exact(64);
let data_chunks = data.chunks_exact(64);
let bit_chunks = nulls.inner().bit_chunks();

let group_indices_remainder = group_indices_chunks.remainder();
let data_remainder = data_chunks.remainder();

group_indices_chunks
.zip(data_chunks)
.zip(bit_chunks.iter())
.for_each(|((group_index_chunk, data_chunk), mask)| {
// index_mask has value 1 << i in the loop
let mut index_mask = 1;
group_index_chunk.iter().zip(data_chunk.iter()).for_each(
|(&group_index, &new_value)| {
// valid bit was set, real value
let is_valid = (mask & index_mask) != 0;
if is_valid {
value_fn(group_index, new_value);
}
index_mask <<= 1;
},
)
});

// handle any remaining bits (after the initial 64)
let remainder_bits = bit_chunks.remainder_bits();
group_indices_remainder
.iter()
.zip(data_remainder.iter())
.enumerate()
.for_each(|(i, (&group_index, &new_value))| {
let is_valid = remainder_bits & (1 << i) != 0;
if is_valid {
value_fn(group_index, new_value);
}
});
}
// no nulls, but a filter
(false, Some(filter)) => {
assert_eq!(filter.len(), group_indices.len());
// The performance with a filter could be improved by
// iterating over the filter in chunks, rather than a single
// iterator. TODO file a ticket
group_indices
.iter()
.zip(data.iter())
.zip(filter.iter())
.for_each(|((&group_index, &new_value), filter_value)| {
if let Some(true) = filter_value {
value_fn(group_index, new_value);
}
})
}
// both null values and filters
(true, Some(filter)) => {
assert_eq!(filter.len(), group_indices.len());
// The performance with a filter could be improved by
// iterating over the filter in chunks, rather than using
// iterators. TODO file a ticket
filter
.iter()
.zip(group_indices.iter())
.zip(values.iter())
.for_each(|((filter_value, &group_index), new_value)| {
if let Some(true) = filter_value {
if let Some(new_value) = new_value {
value_fn(group_index, new_value)
}
}
})
}
}
}

/// This function is called to update the accumulator state per row
/// when the value is not needed (e.g. COUNT)
///
Expand Down
83 changes: 81 additions & 2 deletions datafusion/functions-aggregate/src/stddev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,21 @@
use std::any::Any;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;

use arrow::array::Float64Array;
use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};

use datafusion_common::{internal_err, not_impl_err, Result};
use datafusion_common::{plan_err, ScalarValue};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
use datafusion_expr::{
Accumulator, AggregateUDFImpl, GroupsAccumulator, Signature, Volatility,
};
use datafusion_functions_aggregate_common::stats::StatsType;

use crate::variance::VarianceAccumulator;
use crate::variance::{VarianceAccumulator, VarianceGroupsAccumulator};

make_udaf_expr_and_func!(
Stddev,
Expand Down Expand Up @@ -118,6 +122,17 @@ impl AggregateUDFImpl for Stddev {
fn aliases(&self) -> &[String] {
&self.alias
}

fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
!acc_args.is_distinct
}

fn create_groups_accumulator(
&self,
_args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
Ok(Box::new(StddevGroupsAccumulator::new(StatsType::Sample)))
}
}

make_udaf_expr_and_func!(
Expand Down Expand Up @@ -201,6 +216,19 @@ impl AggregateUDFImpl for StddevPop {

Ok(DataType::Float64)
}

fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool {
!acc_args.is_distinct
}

fn create_groups_accumulator(
&self,
_args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
Ok(Box::new(StddevGroupsAccumulator::new(
StatsType::Population,
)))
}
}

/// An accumulator to compute the average
Expand Down Expand Up @@ -267,6 +295,57 @@ impl Accumulator for StddevAccumulator {
}
}

#[derive(Debug)]
pub struct StddevGroupsAccumulator {
variance: VarianceGroupsAccumulator,
}

impl StddevGroupsAccumulator {
pub fn new(s_type: StatsType) -> Self {
Self {
variance: VarianceGroupsAccumulator::new(s_type),
}
}
}

impl GroupsAccumulator for StddevGroupsAccumulator {
fn update_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&arrow::array::BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
self.variance
.update_batch(values, group_indices, opt_filter, total_num_groups)
}

fn merge_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&arrow::array::BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
self.variance
.merge_batch(values, group_indices, opt_filter, total_num_groups)
}

fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
let (mut variances, nulls) = self.variance.variance(emit_to);
variances.iter_mut().for_each(|v| *v = v.sqrt());
Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))
}

fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
self.variance.state(emit_to)
}

fn size(&self) -> usize {
self.variance.size()
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading

0 comments on commit 847e122

Please sign in to comment.