Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Added compare_scalar #317

Merged
merged 2 commits into from
Aug 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions benches/comparison_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ where
criterion::black_box(arr_a),
criterion::black_box(value_b),
op,
)
.unwrap();
);
}

fn add_benchmark(c: &mut Criterion) {
Expand Down
28 changes: 18 additions & 10 deletions src/compute/comparison/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

use crate::array::*;
use crate::bitmap::Bitmap;
use crate::scalar::{BooleanScalar, Scalar};
use crate::{
bitmap::MutableBitmap,
error::{ArrowError, Result},
Expand Down Expand Up @@ -56,14 +57,14 @@ where

/// Evaluate `op(left, right)` for [`BooleanArray`] and scalar using
/// a specified comparison function.
pub fn compare_op_scalar<F>(lhs: &BooleanArray, rhs: bool, op: F) -> Result<BooleanArray>
pub fn compare_op_scalar<F>(lhs: &BooleanArray, rhs: bool, op: F) -> BooleanArray
where
F: Fn(bool, bool) -> bool,
{
let lhs_iter = lhs.values().iter();

let values = Bitmap::from_trusted_len_iter(lhs_iter.map(|x| op(x, rhs)));
Ok(BooleanArray::from_data(values, lhs.validity().clone()))
BooleanArray::from_data(values, lhs.validity().clone())
}

/// Perform `lhs == rhs` operation on two arrays.
Expand All @@ -72,7 +73,7 @@ pub fn eq(lhs: &BooleanArray, rhs: &BooleanArray) -> Result<BooleanArray> {
}

/// Perform `left == right` operation on an array and a scalar value.
pub fn eq_scalar(lhs: &BooleanArray, rhs: bool) -> Result<BooleanArray> {
pub fn eq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray {
compare_op_scalar(lhs, rhs, |a, b| a == b)
}

Expand All @@ -82,7 +83,7 @@ pub fn neq(lhs: &BooleanArray, rhs: &BooleanArray) -> Result<BooleanArray> {
}

/// Perform `left != right` operation on an array and a scalar value.
pub fn neq_scalar(lhs: &BooleanArray, rhs: bool) -> Result<BooleanArray> {
pub fn neq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray {
compare_op_scalar(lhs, rhs, |a, b| a != b)
}

Expand All @@ -92,7 +93,7 @@ pub fn lt(lhs: &BooleanArray, rhs: &BooleanArray) -> Result<BooleanArray> {
}

/// Perform `left < right` operation on an array and a scalar value.
pub fn lt_scalar(lhs: &BooleanArray, rhs: bool) -> Result<BooleanArray> {
pub fn lt_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray {
compare_op_scalar(lhs, rhs, |a, b| !a & b)
}

Expand All @@ -103,7 +104,7 @@ pub fn lt_eq(lhs: &BooleanArray, rhs: &BooleanArray) -> Result<BooleanArray> {

/// Perform `left <= right` operation on an array and a scalar value.
/// Null values are less than non-null values.
pub fn lt_eq_scalar(lhs: &BooleanArray, rhs: bool) -> Result<BooleanArray> {
pub fn lt_eq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray {
compare_op_scalar(lhs, rhs, |a, b| a <= b)
}

Expand All @@ -115,7 +116,7 @@ pub fn gt(lhs: &BooleanArray, rhs: &BooleanArray) -> Result<BooleanArray> {

/// Perform `left > right` operation on an array and a scalar value.
/// Non-null values are greater than null values.
pub fn gt_scalar(lhs: &BooleanArray, rhs: bool) -> Result<BooleanArray> {
pub fn gt_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray {
compare_op_scalar(lhs, rhs, |a, b| a & !b)
}

Expand All @@ -127,7 +128,7 @@ pub fn gt_eq(lhs: &BooleanArray, rhs: &BooleanArray) -> Result<BooleanArray> {

/// Perform `left >= right` operation on an array and a scalar value.
/// Non-null values are greater than null values.
pub fn gt_eq_scalar(lhs: &BooleanArray, rhs: bool) -> Result<BooleanArray> {
pub fn gt_eq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray {
compare_op_scalar(lhs, rhs, |a, b| a >= b)
}

Expand All @@ -142,7 +143,14 @@ pub fn compare(lhs: &BooleanArray, rhs: &BooleanArray, op: Operator) -> Result<B
}
}

pub fn compare_scalar(lhs: &BooleanArray, rhs: bool, op: Operator) -> Result<BooleanArray> {
pub fn compare_scalar(lhs: &BooleanArray, rhs: &BooleanScalar, op: Operator) -> BooleanArray {
if !rhs.is_valid() {
return BooleanArray::new_null(lhs.len());
}
compare_scalar_non_null(lhs, rhs.value(), op)
}

pub fn compare_scalar_non_null(lhs: &BooleanArray, rhs: bool, op: Operator) -> BooleanArray {
match op {
Operator::Eq => eq_scalar(lhs, rhs),
Operator::Neq => neq_scalar(lhs, rhs),
Expand Down Expand Up @@ -180,7 +188,7 @@ mod tests {
macro_rules! cmp_bool_scalar {
($KERNEL:ident, $A_VEC:expr, $B:literal, $EXPECTED:expr) => {
let a = BooleanArray::from_slice($A_VEC);
let c = $KERNEL(&a, $B).unwrap();
let c = $KERNEL(&a, $B);
assert_eq!(BooleanArray::from_slice($EXPECTED), c);
};
}
Expand Down
Loading