Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Speed up boolean comparison kernels (~3x) (#610)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dandandan authored Nov 16, 2021
1 parent ed0dee7 commit 70562fa
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 59 deletions.
29 changes: 7 additions & 22 deletions benches/comparison_kernels.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use criterion::{criterion_group, criterion_main, Criterion};

use arrow2::array::*;
use arrow2::scalar::*;
use arrow2::util::bench_util::*;
use arrow2::{compute::comparison::*, datatypes::DataType};

fn bench_op(arr_a: &dyn Array, arr_b: &dyn Array, op: Operator) {
compare(black_box(arr_a), black_box(arr_b), op).unwrap();
}

fn bench_op_scalar(arr_a: &dyn Array, value_b: &dyn Scalar, op: Operator) {
compare_scalar(black_box(arr_a), black_box(value_b), op).unwrap();
}

fn add_benchmark(c: &mut Criterion) {
(10..=20).step_by(2).for_each(|log2_size| {
let size = 2usize.pow(log2_size);
Expand All @@ -21,36 +12,30 @@ fn add_benchmark(c: &mut Criterion) {
let arr_b = create_primitive_array_with_seed::<f32>(size, DataType::Float32, 0.0, 43);

c.bench_function(&format!("f32 2^{}", log2_size), |b| {
b.iter(|| bench_op(&arr_a, &arr_b, Operator::Eq))
b.iter(|| eq(&arr_a, &arr_b))
});
c.bench_function(&format!("f32 scalar 2^{}", log2_size), |b| {
b.iter(|| {
bench_op_scalar(
&arr_a,
&PrimitiveScalar::<f32>::from(Some(0.5)),
Operator::Eq,
)
})
b.iter(|| eq_scalar(&arr_a, &PrimitiveScalar::<f32>::from(Some(0.5))))
});

let arr_a = create_boolean_array(size, 0.0, 0.1);
let arr_b = create_boolean_array(size, 0.0, 0.2);

c.bench_function(&format!("bool 2^{}", log2_size), |b| {
b.iter(|| bench_op(&arr_a, &arr_b, Operator::Eq))
b.iter(|| eq(&arr_a, &arr_b))
});
c.bench_function(&format!("bool scalar 2^{}", log2_size), |b| {
b.iter(|| bench_op_scalar(&arr_a, &BooleanScalar::from(Some(false)), Operator::Eq))
b.iter(|| eq_scalar(&arr_a, &BooleanScalar::from(Some(false))))
});

let arr_a = create_string_array::<i32>(size, 4, 0.1, 42);
let arr_b = create_string_array::<i32>(size, 4, 0.1, 43);
c.bench_function(&format!("utf8 2^{}", log2_size), |b| {
b.iter(|| bench_op(&arr_a, &arr_b, Operator::Eq))
b.iter(|| eq(&arr_a, &arr_b))
});

c.bench_function(&format!("utf8 2^{}", log2_size), |b| {
b.iter(|| bench_op_scalar(&arr_a, &Utf8Scalar::<i32>::from(Some("abc")), Operator::Eq))
b.iter(|| eq_scalar(&arr_a, &Utf8Scalar::<i32>::from(Some("abc"))))
});
})
}
Expand Down
43 changes: 6 additions & 37 deletions src/compute/comparison/boolean.rs
Original file line number Diff line number Diff line change
@@ -1,66 +1,35 @@
//! Comparison functions for [`BooleanArray`]
use crate::{
array::BooleanArray,
bitmap::{Bitmap, MutableBitmap},
buffer::MutableBuffer,
bitmap::{binary, unary, Bitmap},
datatypes::DataType,
};

use super::super::utils::combine_validities;

fn compare_values_op<F>(lhs: &Bitmap, rhs: &Bitmap, op: F) -> MutableBitmap
where
F: Fn(u8, u8) -> u8,
{
assert_eq!(lhs.len(), rhs.len());
let lhs_iter = lhs.chunks();
let rhs_iter = rhs.chunks();
let lhs_remainder = lhs_iter.remainder();
let rhs_remainder = rhs_iter.remainder();

let mut values = MutableBuffer::with_capacity((lhs.len() + 7) / 8);
let iter = lhs_iter.zip(rhs_iter).map(|(x, y)| op(x, y));
values.extend_from_trusted_len_iter(iter);

if lhs.len() % 8 != 0 {
values.push(op(lhs_remainder, rhs_remainder))
};

MutableBitmap::from_buffer(values, lhs.len())
}

/// Evaluate `op(lhs, rhs)` for [`BooleanArray`]s using a specified
/// comparison function.
fn compare_op<F>(lhs: &BooleanArray, rhs: &BooleanArray, op: F) -> BooleanArray
where
F: Fn(u8, u8) -> u8,
F: Fn(u64, u64) -> u64,
{
assert_eq!(lhs.len(), rhs.len());
let validity = combine_validities(lhs.validity(), rhs.validity());

let values = compare_values_op(lhs.values(), rhs.values(), op);
let values = binary(lhs.values(), rhs.values(), op);

BooleanArray::from_data(DataType::Boolean, values.into(), validity)
BooleanArray::from_data(DataType::Boolean, values, validity)
}

/// Evaluate `op(left, right)` for [`BooleanArray`] and scalar using
/// a specified comparison function.
pub fn compare_op_scalar<F>(lhs: &BooleanArray, rhs: bool, op: F) -> BooleanArray
where
F: Fn(u8, u8) -> u8,
F: Fn(u64, u64) -> u64,
{
let lhs_iter = lhs.values().chunks();
let lhs_remainder = lhs_iter.remainder();
let rhs = if rhs { 0b11111111 } else { 0 };

let mut values = MutableBuffer::with_capacity((lhs.len() + 7) / 8);
let iter = lhs_iter.map(|x| op(x, rhs));
values.extend_from_trusted_len_iter(iter);

if lhs.len() % 8 != 0 {
values.push(op(lhs_remainder, rhs))
};
let values = MutableBitmap::from_buffer(values, lhs.len()).into();
let values = unary(lhs.values(), |x| op(x, rhs));
BooleanArray::from_data(DataType::Boolean, values, lhs.validity().cloned())
}

Expand Down

0 comments on commit 70562fa

Please sign in to comment.