From aa86aa93dff63076652692e6d081292610ef45e3 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sun, 11 Sep 2022 12:51:48 +0200 Subject: [PATCH] fixed `neq_and_validity` in the default case w/ nulls (#1244) fixed neq_and_validity in the default case w/ nulls --- src/compute/comparison/mod.rs | 9 ++++++--- tests/it/compute/comparison.rs | 15 ++++++++++++++- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/compute/comparison/mod.rs b/src/compute/comparison/mod.rs index ba6a84e49b8..9457e0adf05 100644 --- a/src/compute/comparison/mod.rs +++ b/src/compute/comparison/mod.rs @@ -612,9 +612,12 @@ fn finish_neq_validities( if both_sides_invalid.values().unset_bits() != both_sides_invalid.len() { // we use the `binary` kernel directly to save allocations // and apply `lhs & !rhs)` in one shot. - compute::boolean::binary_boolean_kernel(&lhs, &rhs, |lhs, rhs| { - binary(lhs, rhs, |lhs, rhs| (lhs & !rhs)) - }) + + compute::boolean::binary_boolean_kernel( + &or, + &both_sides_invalid, + |lhs, rhs| binary(lhs, rhs, |lhs, rhs| (lhs & !rhs)), + ) } else { or } diff --git a/tests/it/compute/comparison.rs b/tests/it/compute/comparison.rs index b077e996100..eb5058d337a 100644 --- a/tests/it/compute/comparison.rs +++ b/tests/it/compute/comparison.rs @@ -383,7 +383,7 @@ fn primitive_gt_eq() { } #[test] -#[cfg(any(feature = "compute_cast", feature = "compute_boolean_kleene"))] +#[cfg(all(feature = "compute_cast", feature = "compute_boolean_kleene"))] fn utf8_and_validity() { use arrow2::compute::cast::CastOptions; let a1 = Utf8Array::::from([Some("0"), Some("1"), None, Some("2")]); @@ -401,3 +401,16 @@ fn utf8_and_validity() { assert_eq!(utf8::neq_and_validity(&a1, &a1), expected); assert_eq!(utf8::neq_and_validity(&a1, a2), expected); } + +#[test] +#[cfg(feature = "compute_boolean_kleene")] +fn primitive_and_validity() { + let a1 = Int32Array::from([Some(0), None]); + let a2 = Int32Array::from([Some(10), None]); + + let expected = BooleanArray::from_slice([true, false]); + assert_eq!(primitive::neq_and_validity(&a1, &a2), expected); + + let expected = BooleanArray::from_slice([false, true]); + assert_eq!(primitive::eq_and_validity(&a1, &a2), expected); +}