diff --git a/src/array/utf8/mod.rs b/src/array/utf8/mod.rs index 4b0abc10039..2a49f51e7ed 100644 --- a/src/array/utf8/mod.rs +++ b/src/array/utf8/mod.rs @@ -178,7 +178,7 @@ impl Utf8Array { impl Utf8Array { /// Returns the length of this array #[inline] - fn len(&self) -> usize { + pub fn len(&self) -> usize { self.offsets.len() - 1 } diff --git a/src/compute/comparison/binary.rs b/src/compute/comparison/binary.rs index 9bf0a2c4f6a..461e9a3cb48 100644 --- a/src/compute/comparison/binary.rs +++ b/src/compute/comparison/binary.rs @@ -1,39 +1,20 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. +//! Comparison functions for [`BinaryArray`] +use crate::{ + array::{BinaryArray, BooleanArray, Offset}, + bitmap::Bitmap, + datatypes::DataType, +}; -use crate::datatypes::DataType; -use crate::error::{ArrowError, Result}; -use crate::scalar::{BinaryScalar, Scalar}; -use crate::{array::*, bitmap::Bitmap}; - -use super::{super::utils::combine_validities, Operator}; +use super::super::utils::combine_validities; /// Evaluate `op(lhs, rhs)` for [`BinaryArray`]s using a specified /// comparison function. 
-fn compare_op(lhs: &BinaryArray, rhs: &BinaryArray, op: F) -> Result +fn compare_op(lhs: &BinaryArray, rhs: &BinaryArray, op: F) -> BooleanArray where O: Offset, F: Fn(&[u8], &[u8]) -> bool, { - if lhs.len() != rhs.len() { - return Err(ArrowError::InvalidArgumentError( - "Cannot perform comparison operation on arrays of different length".to_string(), - )); - } + assert_eq!(lhs.len(), rhs.len()); let validity = combine_validities(lhs.validity(), rhs.validity()); @@ -43,7 +24,7 @@ where .map(|(lhs, rhs)| op(lhs, rhs)); let values = Bitmap::from_trusted_len_iter(values); - Ok(BooleanArray::from_data(DataType::Boolean, values, validity)) + BooleanArray::from_data(DataType::Boolean, values, validity) } /// Evaluate `op(lhs, rhs)` for [`BinaryArray`] and scalar using @@ -62,123 +43,74 @@ where } /// Perform `lhs == rhs` operation on [`BinaryArray`]. -fn eq(lhs: &BinaryArray, rhs: &BinaryArray) -> Result { +/// # Panic +/// iff the arrays do not have the same length. +pub fn eq(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { compare_op(lhs, rhs, |a, b| a == b) } /// Perform `lhs == rhs` operation on [`BinaryArray`] and a scalar. -fn eq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { +pub fn eq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { compare_op_scalar(lhs, rhs, |a, b| a == b) } /// Perform `lhs != rhs` operation on [`BinaryArray`]. -fn neq(lhs: &BinaryArray, rhs: &BinaryArray) -> Result { +/// # Panic +/// iff the arrays do not have the same length. +pub fn neq(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { compare_op(lhs, rhs, |a, b| a != b) } /// Perform `lhs != rhs` operation on [`BinaryArray`] and a scalar. -fn neq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { +pub fn neq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { compare_op_scalar(lhs, rhs, |a, b| a != b) } /// Perform `lhs < rhs` operation on [`BinaryArray`]. 
-fn lt(lhs: &BinaryArray, rhs: &BinaryArray) -> Result { +pub fn lt(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { compare_op(lhs, rhs, |a, b| a < b) } /// Perform `lhs < rhs` operation on [`BinaryArray`] and a scalar. -fn lt_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { +pub fn lt_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { compare_op_scalar(lhs, rhs, |a, b| a < b) } /// Perform `lhs <= rhs` operation on [`BinaryArray`]. -fn lt_eq(lhs: &BinaryArray, rhs: &BinaryArray) -> Result { +pub fn lt_eq(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { compare_op(lhs, rhs, |a, b| a <= b) } /// Perform `lhs <= rhs` operation on [`BinaryArray`] and a scalar. -fn lt_eq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { +pub fn lt_eq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { compare_op_scalar(lhs, rhs, |a, b| a <= b) } /// Perform `lhs > rhs` operation on [`BinaryArray`]. -fn gt(lhs: &BinaryArray, rhs: &BinaryArray) -> Result { +pub fn gt(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { compare_op(lhs, rhs, |a, b| a > b) } /// Perform `lhs > rhs` operation on [`BinaryArray`] and a scalar. -fn gt_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { +pub fn gt_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { compare_op_scalar(lhs, rhs, |a, b| a > b) } /// Perform `lhs >= rhs` operation on [`BinaryArray`]. -fn gt_eq(lhs: &BinaryArray, rhs: &BinaryArray) -> Result { +pub fn gt_eq(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { compare_op(lhs, rhs, |a, b| a >= b) } /// Perform `lhs >= rhs` operation on [`BinaryArray`] and a scalar. -fn gt_eq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { +pub fn gt_eq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { compare_op_scalar(lhs, rhs, |a, b| a >= b) } -/// Compare two [`BinaryArray`]s using the given [`Operator`]. -/// -/// # Errors -/// When the two arrays have different lengths. 
-/// -/// Check the [crate::compute::comparison](module documentation) for usage -/// examples. -pub fn compare( - lhs: &BinaryArray, - rhs: &BinaryArray, - op: Operator, -) -> Result { - match op { - Operator::Eq => eq(lhs, rhs), - Operator::Neq => neq(lhs, rhs), - Operator::Gt => gt(lhs, rhs), - Operator::GtEq => gt_eq(lhs, rhs), - Operator::Lt => lt(lhs, rhs), - Operator::LtEq => lt_eq(lhs, rhs), - } -} - -/// Compare a [`BinaryArray`] and a scalar value using the given -/// [`Operator`]. -/// -/// Check the [crate::compute::comparison](module documentation) for usage -/// examples. -pub fn compare_scalar( - lhs: &BinaryArray, - rhs: &BinaryScalar, - op: Operator, -) -> BooleanArray { - if !rhs.is_valid() { - return BooleanArray::new_null(DataType::Boolean, lhs.len()); - } - compare_scalar_non_null(lhs, rhs.value(), op) -} - -pub fn compare_scalar_non_null( - lhs: &BinaryArray, - rhs: &[u8], - op: Operator, -) -> BooleanArray { - match op { - Operator::Eq => eq_scalar(lhs, rhs), - Operator::Neq => neq_scalar(lhs, rhs), - Operator::Gt => gt_scalar(lhs, rhs), - Operator::GtEq => gt_eq_scalar(lhs, rhs), - Operator::Lt => lt_scalar(lhs, rhs), - Operator::LtEq => lt_eq_scalar(lhs, rhs), - } -} - #[cfg(test)] mod tests { use super::*; - fn test_generic, &BinaryArray) -> Result>( + fn test_generic, &BinaryArray) -> BooleanArray>( lhs: Vec<&[u8]>, rhs: Vec<&[u8]>, op: F, @@ -187,7 +119,7 @@ mod tests { let lhs = BinaryArray::::from_slice(lhs); let rhs = BinaryArray::::from_slice(rhs); let expected = BooleanArray::from_slice(expected); - assert_eq!(op(&lhs, &rhs).unwrap(), expected); + assert_eq!(op(&lhs, &rhs), expected); } fn test_generic_scalar, &[u8]) -> BooleanArray>( diff --git a/src/compute/comparison/boolean.rs b/src/compute/comparison/boolean.rs index f9eb74928a4..fb279e16a08 100644 --- a/src/compute/comparison/boolean.rs +++ b/src/compute/comparison/boolean.rs @@ -1,16 +1,14 @@ -use crate::array::*; -use crate::bitmap::Bitmap; -use crate::buffer::MutableBuffer; 
-use crate::datatypes::DataType; -use crate::scalar::{BooleanScalar, Scalar}; +//! Comparison functions for [`BooleanArray`] use crate::{ - bitmap::MutableBitmap, - error::{ArrowError, Result}, + array::BooleanArray, + bitmap::{Bitmap, MutableBitmap}, + buffer::MutableBuffer, + datatypes::DataType, }; -use super::{super::utils::combine_validities, Operator}; +use super::super::utils::combine_validities; -pub fn compare_values_op(lhs: &Bitmap, rhs: &Bitmap, op: F) -> MutableBitmap +fn compare_values_op(lhs: &Bitmap, rhs: &Bitmap, op: F) -> MutableBitmap where F: Fn(u8, u8) -> u8, { @@ -33,25 +31,16 @@ where /// Evaluate `op(lhs, rhs)` for [`BooleanArray`]s using a specified /// comparison function. -fn compare_op(lhs: &BooleanArray, rhs: &BooleanArray, op: F) -> Result +fn compare_op(lhs: &BooleanArray, rhs: &BooleanArray, op: F) -> BooleanArray where F: Fn(u8, u8) -> u8, { - if lhs.len() != rhs.len() { - return Err(ArrowError::InvalidArgumentError( - "Cannot perform comparison operation on arrays of different length".to_string(), - )); - } - + assert_eq!(lhs.len(), rhs.len()); let validity = combine_validities(lhs.validity(), rhs.validity()); let values = compare_values_op(lhs.values(), rhs.values(), op); - Ok(BooleanArray::from_data( - DataType::Boolean, - values.into(), - validity, - )) + BooleanArray::from_data(DataType::Boolean, values.into(), validity) } /// Evaluate `op(left, right)` for [`BooleanArray`] and scalar using @@ -75,12 +64,12 @@ where BooleanArray::from_data(DataType::Boolean, values, lhs.validity().cloned()) } -/// Perform `lhs == rhs` operation on two arrays. -pub fn eq(lhs: &BooleanArray, rhs: &BooleanArray) -> Result { +/// Perform `lhs == rhs` operation on two [`BooleanArray`]s. +pub fn eq(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { compare_op(lhs, rhs, |a, b| !(a ^ b)) } -/// Perform `left == right` operation on an array and a scalar value. +/// Perform `lhs == rhs` operation on a [`BooleanArray`] and a scalar value. 
pub fn eq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray { if rhs { lhs.clone() @@ -89,8 +78,8 @@ pub fn eq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray { } } -/// Perform `left != right` operation on two arrays. -pub fn neq(lhs: &BooleanArray, rhs: &BooleanArray) -> Result { +/// `lhs != rhs` for [`BooleanArray`] +pub fn neq(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { compare_op(lhs, rhs, |a, b| a ^ b) } @@ -100,7 +89,7 @@ pub fn neq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray { } /// Perform `left < right` operation on two arrays. -pub fn lt(lhs: &BooleanArray, rhs: &BooleanArray) -> Result { +pub fn lt(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { compare_op(lhs, rhs, |a, b| !a & b) } @@ -118,7 +107,7 @@ pub fn lt_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray { } /// Perform `left <= right` operation on two arrays. -pub fn lt_eq(lhs: &BooleanArray, rhs: &BooleanArray) -> Result { +pub fn lt_eq(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { compare_op(lhs, rhs, |a, b| !a | b) } @@ -134,7 +123,7 @@ pub fn lt_eq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray { /// Perform `left > right` operation on two arrays. Non-null values are greater than null /// values. -pub fn gt(lhs: &BooleanArray, rhs: &BooleanArray) -> Result { +pub fn gt(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { compare_op(lhs, rhs, |a, b| a & !b) } @@ -154,7 +143,7 @@ pub fn gt_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray { /// Perform `left >= right` operation on two arrays. Non-null values are greater than null /// values. -pub fn gt_eq(lhs: &BooleanArray, rhs: &BooleanArray) -> Result { +pub fn gt_eq(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { compare_op(lhs, rhs, |a, b| a | !b) } @@ -168,47 +157,6 @@ pub fn gt_eq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray { } } -/// Compare two [`BooleanArray`]s using the given [`Operator`]. 
-/// -/// # Errors -/// When the two arrays have different lengths. -/// -/// Check the [crate::compute::comparison](module documentation) for usage -/// examples. -pub fn compare(lhs: &BooleanArray, rhs: &BooleanArray, op: Operator) -> Result { - match op { - Operator::Eq => eq(lhs, rhs), - Operator::Neq => neq(lhs, rhs), - Operator::Gt => gt(lhs, rhs), - Operator::GtEq => gt_eq(lhs, rhs), - Operator::Lt => lt(lhs, rhs), - Operator::LtEq => lt_eq(lhs, rhs), - } -} - -/// Compare a [`BooleanArray`] and a scalar value using the given -/// [`Operator`]. -/// -/// Check the [crate::compute::comparison](module documentation) for usage -/// examples. -pub fn compare_scalar(lhs: &BooleanArray, rhs: &BooleanScalar, op: Operator) -> BooleanArray { - if !rhs.is_valid() { - return BooleanArray::new_null(DataType::Boolean, lhs.len()); - } - compare_scalar_non_null(lhs, rhs.value(), op) -} - -pub fn compare_scalar_non_null(lhs: &BooleanArray, rhs: bool, op: Operator) -> BooleanArray { - match op { - Operator::Eq => eq_scalar(lhs, rhs), - Operator::Neq => neq_scalar(lhs, rhs), - Operator::Gt => gt_scalar(lhs, rhs), - Operator::GtEq => gt_eq_scalar(lhs, rhs), - Operator::Lt => lt_scalar(lhs, rhs), - Operator::LtEq => lt_eq_scalar(lhs, rhs), - } -} - // disable wrapping inside literal vectors used for test data and assertions #[rustfmt::skip::macros(vec)] #[cfg(test)] @@ -219,7 +167,7 @@ mod tests { ($KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => { let a = BooleanArray::from_slice($A_VEC); let b = BooleanArray::from_slice($B_VEC); - let c = $KERNEL(&a, &b).unwrap(); + let c = $KERNEL(&a, &b); assert_eq!(BooleanArray::from_slice($EXPECTED), c); }; } @@ -228,7 +176,7 @@ mod tests { ($KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => { let a = BooleanArray::from($A_VEC); let b = BooleanArray::from($B_VEC); - let c = $KERNEL(&a, &b).unwrap(); + let c = $KERNEL(&a, &b); assert_eq!(BooleanArray::from($EXPECTED), c); }; } @@ -261,7 +209,7 @@ mod tests { let a = 
BooleanArray::from_slice(&[true, true, false]); let b = BooleanArray::from_slice(&[true, true, true, true, false]); let c = b.slice(2, 3); - let d = eq(&c, &a).unwrap(); + let d = eq(&c, &a); assert_eq!(d, BooleanArray::from_slice(&[true, true, true])); } diff --git a/src/compute/comparison/mod.rs b/src/compute/comparison/mod.rs index 8a3071c271c..a3471f3b285 100644 --- a/src/compute/comparison/mod.rs +++ b/src/compute/comparison/mod.rs @@ -1,359 +1,423 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - //! Basic comparison kernels. //! //! The module contains functions that compare either an array and a scalar //! or two arrays of the same [`DataType`]. The scalar-oriented functions are -//! suffixed with `_scalar`. In general, these comparison functions receive as -//! inputs the two items for comparison and an [`Operator`] which specifies the -//! type of comparison that will be conducted, such as `<=` ([`Operator::LtEq`]). -//! -//! Much like the parent module [`compute`](crate::compute), the comparison functions -//! have two variants - a statically typed one ([`primitive_compare`]) -//! which expects concrete types such as [`Int8Array`] and a dynamically typed -//! 
variant ([`compare`]) that compares values of type `&dyn Array` and errors -//! if the underlying concrete types mismsatch. The statically-typed functions -//! are appropriately prefixed with the concrete types they expect. +//! suffixed with `_scalar`. //! -//! Also note that the scalar-based comparison functions for the concrete types, -//! like [`primitive_compare_scalar`], are infallible and always return a -//! [`BooleanArray`] while the rest of the functions always wrap the returned -//! array in a [`Result`] due to their internal checks of the data types and -//! lengths. +//! The functions are organized in two variants: +//! * statically typed +//! * dynamically typed +//! The statically typed are available under each module of this module (e.g. [`primitive::eq`], [`primitive::lt_scalar`]) +//! The dynamically typed are available in this module (e.g. [`eq`] or [`lt_scalar`]). //! //! # Examples //! //! Compare two [`PrimitiveArray`]s: //! ``` -//! use arrow2::compute::comparison::{primitive_compare, Operator}; -//! # use arrow2::array::{BooleanArray, PrimitiveArray}; -//! # use arrow2::error::{ArrowError, Result}; +//! use arrow2::array::{BooleanArray, PrimitiveArray}; +//! use arrow2::compute::comparison::primitive::gt; //! //! let array1 = PrimitiveArray::::from([Some(1), None, Some(2)]); -//! let array2 = PrimitiveArray::::from([Some(1), None, Some(1)]); -//! let result = primitive_compare(&array1, &array2, Operator::Gt)?; +//! let array2 = PrimitiveArray::::from([Some(1), Some(3), Some(1)]); +//! let result = gt(&array1, &array2); //! assert_eq!(result, BooleanArray::from([Some(false), None, Some(true)])); -//! # Ok::<(), ArrowError>(()) //! ``` -//! Compare two dynamically-typed arrays (trait objects): +//! +//! Compare two dynamically-typed [`Array`]s (trait objects): //! ``` -//! use arrow2::compute::comparison::{compare, Operator}; -//! # use arrow2::array::{Array, BooleanArray, PrimitiveArray}; -//! # use arrow2::error::{ArrowError, Result}; +//! 
use arrow2::array::{Array, BooleanArray, PrimitiveArray}; +//! use arrow2::compute::comparison::eq; //! //! let array1: &dyn Array = &PrimitiveArray::::from(&[Some(10.0), None, Some(20.0)]); //! let array2: &dyn Array = &PrimitiveArray::::from(&[Some(10.0), None, Some(10.0)]); -//! let result = compare(array1, array2, Operator::LtEq)?; +//! let result = eq(array1, array2); //! assert_eq!(result, BooleanArray::from([Some(true), None, Some(false)])); -//! # Ok::<(), ArrowError>(()) //! ``` -//! Compare an array of strings to a "scalar", i.e a word (note that we use -//! [`Operator::Neq`]): +//! +//! Compare (not equal) a [`Utf8Array`] to a word: //! ``` -//! use arrow2::compute::comparison::{utf8_compare_scalar, Operator}; -//! # use arrow2::array::{Array, BooleanArray, Utf8Array}; -//! # use arrow2::scalar::Utf8Scalar; -//! # use arrow2::error::{ArrowError, Result}; +//! use arrow2::array::{BooleanArray, Utf8Array}; +//! use arrow2::compute::comparison::utf8::neq_scalar; //! //! let array = Utf8Array::::from([Some("compute"), None, Some("compare")]); -//! let word = Utf8Scalar::new(Some("compare")); -//! let result = utf8_compare_scalar(&array, &word, Operator::Neq); +//! let result = neq_scalar(&array, "compare"); //! assert_eq!(result, BooleanArray::from([Some(true), None, Some(false)])); -//! # Ok::<(), ArrowError>(()) //! 
``` use crate::array::*; use crate::datatypes::{DataType, IntervalUnit}; -use crate::error::{ArrowError, Result}; -use crate::scalar::Scalar; +use crate::scalar::*; -mod binary; -mod boolean; -mod primitive; -mod utf8; +pub mod binary; +pub mod boolean; +pub mod primitive; +pub mod utf8; mod simd; pub use simd::{Simd8, Simd8Lanes}; -pub use binary::compare as binary_compare; -pub use binary::compare_scalar as binary_compare_scalar; -pub use boolean::compare as boolean_compare; -pub use boolean::compare_scalar as boolean_compare_scalar; -pub use primitive::compare as primitive_compare; -pub use primitive::compare_scalar as primitive_compare_scalar; pub(crate) use primitive::compare_values_op as primitive_compare_values_op; -pub use utf8::compare as utf8_compare; -pub use utf8::compare_scalar as utf8_compare_scalar; -/// Comparison operators, such as `>` ([`Operator::Gt`]) -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum Operator { - /// Less than - Lt, - /// Less than or equal to - LtEq, - /// Greater than - Gt, - /// Greater than or equal to - GtEq, - /// Equal - Eq, - /// Not equal - Neq, +macro_rules! 
compare { + ($lhs:expr, $rhs:expr, $op:tt) => {{ + let lhs = $lhs; + let rhs = $rhs; + assert_eq!( + lhs.data_type().to_logical_type(), + rhs.data_type().to_logical_type() + ); + + use DataType::*; + let data_type = lhs.data_type().to_logical_type(); + match data_type { + Boolean => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + boolean::$op(lhs, rhs) + } + Int8 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + primitive::$op::(lhs, rhs) + } + Int16 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + primitive::$op::(lhs, rhs) + } + Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + primitive::$op::(lhs, rhs) + } + Int64 | Timestamp(_, _) | Date64 | Time64(_) | Duration(_) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + primitive::$op::(lhs, rhs) + } + UInt8 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + primitive::$op::(lhs, rhs) + } + UInt16 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + primitive::$op::(lhs, rhs) + } + UInt32 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + primitive::$op::(lhs, rhs) + } + UInt64 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + primitive::$op::(lhs, rhs) + } + Float16 => unreachable!(), + Float32 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + primitive::$op::(lhs, rhs) + } + Float64 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + 
primitive::$op::(lhs, rhs) + } + Utf8 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + utf8::$op::(lhs, rhs) + } + LargeUtf8 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + utf8::$op::(lhs, rhs) + } + Decimal(_, _) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + primitive::$op::(lhs, rhs) + } + Binary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary::$op::(lhs, rhs) + } + LargeBinary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary::$op::(lhs, rhs) + } + _ => todo!("Comparisons of {:?} are not yet supported", data_type), + } + }}; } -/// Compares each slot of `lhs` against each slot of `rhs`. -/// # Error -/// Errors iff: -/// * `lhs.data_type() != rhs.data_type()` or -/// * `lhs.len() != rhs.len()` or -/// * the datatype is not supported (use [`can_compare`] to tell whether it is supported) -pub fn compare(lhs: &dyn Array, rhs: &dyn Array, operator: Operator) -> Result { - let data_type = lhs.data_type(); - if data_type != rhs.data_type() { - return Err(ArrowError::InvalidArgumentError( - "Comparison is only supported for arrays of the same logical type".to_string(), - )); - } - match data_type { - DataType::Boolean => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - boolean::compare(lhs, rhs, operator) - } - DataType::Int8 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare::(lhs, rhs, operator) - } - DataType::Int16 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare::(lhs, rhs, operator) - } - DataType::Int32 - | DataType::Date32 - | DataType::Time32(_) - | 
DataType::Interval(IntervalUnit::YearMonth) => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare::(lhs, rhs, operator) - } - DataType::Int64 - | DataType::Timestamp(_, None) - | DataType::Date64 - | DataType::Time64(_) - | DataType::Duration(_) => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare::(lhs, rhs, operator) - } - DataType::UInt8 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare::(lhs, rhs, operator) - } - DataType::UInt16 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare::(lhs, rhs, operator) - } - DataType::UInt32 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare::(lhs, rhs, operator) - } - DataType::UInt64 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare::(lhs, rhs, operator) - } - DataType::Float16 => unreachable!(), - DataType::Float32 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare::(lhs, rhs, operator) - } - DataType::Float64 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare::(lhs, rhs, operator) - } - DataType::Utf8 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - utf8::compare::(lhs, rhs, operator) - } - DataType::LargeUtf8 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - utf8::compare::(lhs, rhs, operator) - } - DataType::Decimal(_, _) => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); 
- primitive::compare::(lhs, rhs, operator) - } - DataType::Binary => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - binary::compare::(lhs, rhs, operator) - } - DataType::LargeBinary => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - binary::compare::(lhs, rhs, operator) - } - _ => Err(ArrowError::NotYetImplemented(format!( - "Comparison between {:?} is not supported", - data_type - ))), - } +/// `==` between two [`Array`]s. +/// Use [`can_eq`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, eq) } -/// Compares all slots of `lhs` against `rhs`. -/// # Error -/// Errors iff: -/// * `lhs.data_type() != rhs.data_type()` or -/// * the datatype is not supported (use [`can_compare`] to tell whether it is supported) -pub fn compare_scalar( - lhs: &dyn Array, - rhs: &dyn Scalar, - operator: Operator, -) -> Result { - let data_type = lhs.data_type(); - if data_type != rhs.data_type() { - return Err(ArrowError::InvalidArgumentError( - "Comparison is only supported for the same logical type".to_string(), - )); - } - Ok(match data_type { - DataType::Boolean => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - boolean::compare_scalar(lhs, rhs, operator) - } - DataType::Int8 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare_scalar::(lhs, rhs, operator) - } - DataType::Int16 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare_scalar::(lhs, rhs, operator) - } - DataType::Int32 - | DataType::Date32 - |
DataType::Time32(_) - | DataType::Interval(IntervalUnit::YearMonth) => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare_scalar::(lhs, rhs, operator) - } - DataType::Int64 - | DataType::Timestamp(_, None) - | DataType::Date64 - | DataType::Time64(_) - | DataType::Duration(_) => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare_scalar::(lhs, rhs, operator) - } - DataType::UInt8 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare_scalar::(lhs, rhs, operator) - } - DataType::UInt16 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare_scalar::(lhs, rhs, operator) - } - DataType::UInt32 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare_scalar::(lhs, rhs, operator) - } - DataType::UInt64 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare_scalar::(lhs, rhs, operator) - } - DataType::Float16 => unreachable!(), - DataType::Float32 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare_scalar::(lhs, rhs, operator) - } - DataType::Float64 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare_scalar::(lhs, rhs, operator) - } - DataType::Decimal(_, _) => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - primitive::compare_scalar::(lhs, rhs, operator) - } - DataType::Utf8 => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - utf8::compare_scalar::(lhs, rhs, operator) - } - DataType::LargeUtf8 => { 
- let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - utf8::compare_scalar::(lhs, rhs, operator) - } - DataType::Binary => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - binary::compare_scalar::(lhs, rhs, operator) - } - DataType::LargeBinary => { - let lhs = lhs.as_any().downcast_ref().unwrap(); - let rhs = rhs.as_any().downcast_ref().unwrap(); - binary::compare_scalar::(lhs, rhs, operator) - } - _ => { - return Err(ArrowError::NotYetImplemented(format!( - "Comparison between {:?} is not supported", - data_type - ))) +/// `!=` between two [`Array`]s. +/// Use [`can_neq`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn neq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, neq) +} + +/// `<` between two [`Array`]s. +/// Use [`can_lt`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn lt(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, lt) +} + +/// `<=` between two [`Array`]s. +/// Use [`can_lt_eq`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn lt_eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, lt_eq) +} + +/// `>` between two [`Array`]s.
+/// Use [`can_gt`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn gt(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, gt) +} + +/// `>=` between two [`Array`]s. +/// Use [`can_gt_eq`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn gt_eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, gt_eq) +} + +macro_rules! compare_scalar { + ($lhs:expr, $rhs:expr, $op:tt) => {{ + let lhs = $lhs; + let rhs = $rhs; + assert_eq!( + lhs.data_type().to_logical_type(), + rhs.data_type().to_logical_type() + ); + if !rhs.is_valid() { + return BooleanArray::new_null(DataType::Boolean, lhs.len()); } - }) + + use DataType::*; + let data_type = lhs.data_type().to_logical_type(); + match data_type { + Boolean => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::().unwrap(); + boolean::$op(lhs, rhs.value()) + } + Int8 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::$op::(lhs, rhs.value()) + } + Int16 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::$op::(lhs, rhs.value()) + } + Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::$op::(lhs, rhs.value()) + } + Int64 | Timestamp(_, _) | Date64 | Time64(_) | Duration(_) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); +
primitive::$op::(lhs, rhs.value()) + } + UInt8 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::$op::(lhs, rhs.value()) + } + UInt16 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::$op::(lhs, rhs.value()) + } + UInt32 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::$op::(lhs, rhs.value()) + } + UInt64 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::$op::(lhs, rhs.value()) + } + Float16 => unreachable!(), + Float32 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::$op::(lhs, rhs.value()) + } + Float64 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::$op::(lhs, rhs.value()) + } + Utf8 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + utf8::$op::(lhs, rhs.value()) + } + LargeUtf8 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + utf8::$op::(lhs, rhs.value()) + } + Decimal(_, _) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs + .as_any() + .downcast_ref::>() + .unwrap(); + primitive::$op::(lhs, rhs.value()) + } + Binary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + binary::$op::(lhs, rhs.value()) + } + LargeBinary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + binary::$op::(lhs, rhs.value()) + } + _ => todo!("Comparisons of {:?} are not yet supported", data_type), + } + }}; +} + +/// `==` between an [`Array`] and a [`Scalar`]. 
+/// Use [`can_eq`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have have the same logical type +/// * the operation is not supported for the logical type +pub fn eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, eq_scalar) +} + +/// `!=` between an [`Array`] and a [`Scalar`]. +/// Use [`can_neq`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have have the same logical type +/// * the operation is not supported for the logical type +pub fn neq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, neq_scalar) +} + +/// `<` between an [`Array`] and a [`Scalar`]. +/// Use [`can_lt`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have have the same logical type +/// * the operation is not supported for the logical type +pub fn lt_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, lt_scalar) +} + +/// `<=` between an [`Array`] and a [`Scalar`]. +/// Use [`can_lt_eq`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have have the same logical type +/// * the operation is not supported for the logical type +pub fn lt_eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, lt_eq_scalar) +} + +/// `>` between an [`Array`] and a [`Scalar`]. +/// Use [`can_gt`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have have the same logical type +/// * the operation is not supported for the logical type +pub fn gt_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, gt_scalar) +} + +/// `>=` between an [`Array`] and a [`Scalar`]. 
+/// Use [`can_gt_eq`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have have the same logical type +/// * the operation is not supported for the logical type +pub fn gt_eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, gt_eq_scalar) +} + +/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison. +pub fn can_eq(data_type: &DataType) -> bool { + can_compare(data_type) +} + +/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison. +pub fn can_neq(data_type: &DataType) -> bool { + can_compare(data_type) } -/// Checks if an array of type `datatype` can be compared with another array of -/// the same type. -/// -/// # Examples -/// ``` -/// use arrow2::compute::comparison::can_compare; -/// use arrow2::datatypes::{DataType}; -/// -/// let data_type = DataType::Int8; -/// assert_eq!(can_compare(&data_type), true); -/// -/// let data_type = DataType::LargeBinary; -/// assert_eq!(can_compare(&data_type), true) -/// ``` -pub fn can_compare(data_type: &DataType) -> bool { + +/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison. +pub fn can_lt(data_type: &DataType) -> bool { + can_compare(data_type) +} + +/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison. +pub fn can_lt_eq(data_type: &DataType) -> bool { + can_compare(data_type) +} + +/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison. +pub fn can_gt(data_type: &DataType) -> bool { + can_compare(data_type) +} + +/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison. +pub fn can_gt_eq(data_type: &DataType) -> bool { + can_compare(data_type) +} + +// The list of operations currently supported. 
+fn can_compare(data_type: &DataType) -> bool { matches!( data_type, DataType::Boolean @@ -364,7 +428,7 @@ pub fn can_compare(data_type: &DataType) -> bool { | DataType::Time32(_) | DataType::Interval(_) | DataType::Int64 - | DataType::Timestamp(_, None) + | DataType::Timestamp(_, _) | DataType::Date64 | DataType::Time64(_) | DataType::Duration(_) diff --git a/src/compute/comparison/primitive.rs b/src/compute/comparison/primitive.rs index 34a00a2330e..5a427ecfd6a 100644 --- a/src/compute/comparison/primitive.rs +++ b/src/compute/comparison/primitive.rs @@ -1,32 +1,14 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::bitmap::Bitmap; -use crate::datatypes::DataType; -use crate::scalar::{PrimitiveScalar, Scalar}; -use crate::{array::*, types::NativeType}; +//! 
Comparison functions for [`PrimitiveArray`] use crate::{ - bitmap::MutableBitmap, + array::{BooleanArray, PrimitiveArray}, + bitmap::{Bitmap, MutableBitmap}, buffer::MutableBuffer, - error::{ArrowError, Result}, + datatypes::DataType, + types::NativeType, }; +use super::super::utils::combine_validities; use super::simd::{Simd8, Simd8Lanes}; -use super::{super::utils::combine_validities, Operator}; pub(crate) fn compare_values_op(lhs: &[T], rhs: &[T], op: F) -> MutableBitmap where @@ -58,26 +40,16 @@ where /// Evaluate `op(lhs, rhs)` for [`PrimitiveArray`]s using a specified /// comparison function. -fn compare_op(lhs: &PrimitiveArray, rhs: &PrimitiveArray, op: F) -> Result +fn compare_op(lhs: &PrimitiveArray, rhs: &PrimitiveArray, op: F) -> BooleanArray where T: NativeType + Simd8, F: Fn(T::Simd, T::Simd) -> u8, { - if lhs.len() != rhs.len() { - return Err(ArrowError::InvalidArgumentError( - "Cannot perform comparison operation on arrays of different length".to_string(), - )); - } - let validity = combine_validities(lhs.validity(), rhs.validity()); let values = compare_values_op(lhs.values(), rhs.values(), op); - Ok(BooleanArray::from_data( - DataType::Boolean, - values.into(), - validity, - )) + BooleanArray::from_data(DataType::Boolean, values.into(), validity) } /// Evaluate `op(left, right)` for [`PrimitiveArray`] and scalar using @@ -113,7 +85,7 @@ where } /// Perform `lhs == rhs` operation on two arrays. -pub fn eq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result +pub fn eq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray where T: NativeType + Simd8, { @@ -129,7 +101,7 @@ where } /// Perform `left != right` operation on two arrays. -pub fn neq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result +pub fn neq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray where T: NativeType + Simd8, { @@ -145,7 +117,7 @@ where } /// Perform `left < right` operation on two arrays. 
-pub fn lt(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result +pub fn lt(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray where T: NativeType + Simd8, { @@ -161,7 +133,7 @@ where } /// Perform `left <= right` operation on two arrays. -pub fn lt_eq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result +pub fn lt_eq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray where T: NativeType + Simd8, { @@ -179,7 +151,7 @@ where /// Perform `left > right` operation on two arrays. Non-null values are greater than null /// values. -pub fn gt(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result +pub fn gt(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray where T: NativeType + Simd8, { @@ -197,7 +169,7 @@ where /// Perform `left >= right` operation on two arrays. Non-null values are greater than null /// values. -pub fn gt_eq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Result +pub fn gt_eq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray where T: NativeType + Simd8, { @@ -213,64 +185,12 @@ where compare_op_scalar(lhs, rhs, |a, b| a.gt_eq(b)) } -/// Compare two [`PrimitiveArray`]s using the given [`Operator`]. -/// -/// # Errors -/// When the two arrays have different lengths. -/// -/// Check the [crate::compute::comparison](module documentation) for usage -/// examples. -pub fn compare( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, - op: Operator, -) -> Result { - match op { - Operator::Eq => eq(lhs, rhs), - Operator::Neq => neq(lhs, rhs), - Operator::Gt => gt(lhs, rhs), - Operator::GtEq => gt_eq(lhs, rhs), - Operator::Lt => lt(lhs, rhs), - Operator::LtEq => lt_eq(lhs, rhs), - } -} - -/// Compare a [`PrimitiveArray`] and a scalar value using the given -/// [`Operator`]. -/// -/// Check the [crate::compute::comparison](module documentation) for usage -/// examples. 
-pub fn compare_scalar( - lhs: &PrimitiveArray, - rhs: &PrimitiveScalar, - op: Operator, -) -> BooleanArray { - if !rhs.is_valid() { - return BooleanArray::new_null(DataType::Boolean, lhs.len()); - } - compare_scalar_non_null(lhs, rhs.value(), op) -} - -pub fn compare_scalar_non_null( - lhs: &PrimitiveArray, - rhs: T, - op: Operator, -) -> BooleanArray { - match op { - Operator::Eq => eq_scalar(lhs, rhs), - Operator::Neq => neq_scalar(lhs, rhs), - Operator::Gt => gt_scalar(lhs, rhs), - Operator::GtEq => gt_eq_scalar(lhs, rhs), - Operator::Lt => lt_scalar(lhs, rhs), - Operator::LtEq => lt_eq_scalar(lhs, rhs), - } -} - // disable wrapping inside literal vectors used for test data and assertions #[rustfmt::skip::macros(vec)] #[cfg(test)] mod tests { use super::*; + use crate::array::{Int64Array, Int8Array}; /// Evaluate `KERNEL` with two vectors as inputs and assert against the expected output. /// `A_VEC` and `B_VEC` can be of type `Vec` or `Vec>`. @@ -280,7 +200,7 @@ mod tests { ($KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => { let a = Int64Array::from_slice($A_VEC); let b = Int64Array::from_slice($B_VEC); - let c = $KERNEL(&a, &b).unwrap(); + let c = $KERNEL(&a, &b); assert_eq!(BooleanArray::from_slice($EXPECTED), c); }; } @@ -289,7 +209,7 @@ mod tests { ($KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => { let a = Int64Array::from($A_VEC); let b = Int64Array::from($B_VEC); - let c = $KERNEL(&a, &b).unwrap(); + let c = $KERNEL(&a, &b); assert_eq!(BooleanArray::from($EXPECTED), c); }; } @@ -339,7 +259,7 @@ mod tests { let a = Int64Array::from_slice(&[6, 7, 8, 8, 10]); let b = Int64Array::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); let c = b.slice(5, 5); - let d = eq(&c, &a).unwrap(); + let d = eq(&c, &a); assert_eq!( d, BooleanArray::from_slice(&vec![true, true, true, false, true]) @@ -572,7 +492,7 @@ mod tests { let a = a.slice(50, 50); let b = (100..200).map(Some).collect::>(); let b = b.slice(50, 50); - let actual = lt(&a, &b).unwrap(); + 
let actual = lt(&a, &b); let expected: BooleanArray = (0..50).map(|_| Some(true)).collect(); assert_eq!(expected, actual); } @@ -594,7 +514,7 @@ mod tests { let array_a = Int8Array::from_slice(&vec![1; item_count]); let array_b = Int8Array::from_slice(&vec![2; item_count]); let expected = BooleanArray::from_slice(&vec![false; item_count]); - let result = gt_eq(&array_a, &array_b).unwrap(); + let result = gt_eq(&array_a, &array_b); assert_eq!(result, expected) } diff --git a/src/compute/comparison/utf8.rs b/src/compute/comparison/utf8.rs index e6a66816890..6266e578f09 100644 --- a/src/compute/comparison/utf8.rs +++ b/src/compute/comparison/utf8.rs @@ -1,40 +1,20 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::datatypes::DataType; -use crate::error::{ArrowError, Result}; -use crate::scalar::{Scalar, Utf8Scalar}; -use crate::{array::*, bitmap::Bitmap}; - -use super::{super::utils::combine_validities, Operator}; +//! Comparison functions for [`Utf8Array`] +use crate::{ + array::{BooleanArray, Offset, Utf8Array}, + bitmap::Bitmap, + datatypes::DataType, +}; + +use super::super::utils::combine_validities; /// Evaluate `op(lhs, rhs)` for [`Utf8Array`]s using a specified /// comparison function. 
-fn compare_op(lhs: &Utf8Array, rhs: &Utf8Array, op: F) -> Result +fn compare_op(lhs: &Utf8Array, rhs: &Utf8Array, op: F) -> BooleanArray where O: Offset, F: Fn(&str, &str) -> bool, { - if lhs.len() != rhs.len() { - return Err(ArrowError::InvalidArgumentError( - "Cannot perform comparison operation on arrays of different length".to_string(), - )); - } - + assert_eq!(lhs.len(), rhs.len()); let validity = combine_validities(lhs.validity(), rhs.validity()); let values = lhs @@ -43,7 +23,7 @@ where .map(|(lhs, rhs)| op(lhs, rhs)); let values = Bitmap::from_trusted_len_iter(values); - Ok(BooleanArray::from_data(DataType::Boolean, values, validity)) + BooleanArray::from_data(DataType::Boolean, values, validity) } /// Evaluate `op(lhs, rhs)` for [`Utf8Array`] and scalar using @@ -62,123 +42,70 @@ where } /// Perform `lhs == rhs` operation on [`Utf8Array`]. -fn eq(lhs: &Utf8Array, rhs: &Utf8Array) -> Result { +pub fn eq(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { compare_op(lhs, rhs, |a, b| a == b) } /// Perform `lhs == rhs` operation on [`Utf8Array`] and a scalar. -fn eq_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { +pub fn eq_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { compare_op_scalar(lhs, rhs, |a, b| a == b) } /// Perform `lhs != rhs` operation on [`Utf8Array`]. -fn neq(lhs: &Utf8Array, rhs: &Utf8Array) -> Result { +pub fn neq(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { compare_op(lhs, rhs, |a, b| a != b) } /// Perform `lhs != rhs` operation on [`Utf8Array`] and a scalar. -fn neq_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { +pub fn neq_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { compare_op_scalar(lhs, rhs, |a, b| a != b) } /// Perform `lhs < rhs` operation on [`Utf8Array`]. -fn lt(lhs: &Utf8Array, rhs: &Utf8Array) -> Result { +pub fn lt(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { compare_op(lhs, rhs, |a, b| a < b) } /// Perform `lhs < rhs` operation on [`Utf8Array`] and a scalar. 
-fn lt_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { +pub fn lt_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { compare_op_scalar(lhs, rhs, |a, b| a < b) } /// Perform `lhs <= rhs` operation on [`Utf8Array`]. -fn lt_eq(lhs: &Utf8Array, rhs: &Utf8Array) -> Result { +pub fn lt_eq(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { compare_op(lhs, rhs, |a, b| a <= b) } /// Perform `lhs <= rhs` operation on [`Utf8Array`] and a scalar. -fn lt_eq_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { +pub fn lt_eq_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { compare_op_scalar(lhs, rhs, |a, b| a <= b) } /// Perform `lhs > rhs` operation on [`Utf8Array`]. -fn gt(lhs: &Utf8Array, rhs: &Utf8Array) -> Result { +pub fn gt(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { compare_op(lhs, rhs, |a, b| a > b) } /// Perform `lhs > rhs` operation on [`Utf8Array`] and a scalar. -fn gt_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { +pub fn gt_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { compare_op_scalar(lhs, rhs, |a, b| a > b) } /// Perform `lhs >= rhs` operation on [`Utf8Array`]. -fn gt_eq(lhs: &Utf8Array, rhs: &Utf8Array) -> Result { +pub fn gt_eq(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { compare_op(lhs, rhs, |a, b| a >= b) } /// Perform `lhs >= rhs` operation on [`Utf8Array`] and a scalar. -fn gt_eq_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { +pub fn gt_eq_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { compare_op_scalar(lhs, rhs, |a, b| a >= b) } -/// Compare two [`Utf8Array`]s using the given [`Operator`]. -/// -/// # Errors -/// When the two arrays have different lengths. -/// -/// Check the [crate::compute::comparison](module documentation) for usage -/// examples. 
-pub fn compare( - lhs: &Utf8Array, - rhs: &Utf8Array, - op: Operator, -) -> Result { - match op { - Operator::Eq => eq(lhs, rhs), - Operator::Neq => neq(lhs, rhs), - Operator::Gt => gt(lhs, rhs), - Operator::GtEq => gt_eq(lhs, rhs), - Operator::Lt => lt(lhs, rhs), - Operator::LtEq => lt_eq(lhs, rhs), - } -} - -/// Compare a [`Utf8Array`] and a scalar value using the given -/// [`Operator`]. -/// -/// Check the [crate::compute::comparison](module documentation) for usage -/// examples. -pub fn compare_scalar( - lhs: &Utf8Array, - rhs: &Utf8Scalar, - op: Operator, -) -> BooleanArray { - if !rhs.is_valid() { - return BooleanArray::new_null(DataType::Boolean, lhs.len()); - } - compare_scalar_non_null(lhs, rhs.value(), op) -} - -pub fn compare_scalar_non_null( - lhs: &Utf8Array, - rhs: &str, - op: Operator, -) -> BooleanArray { - match op { - Operator::Eq => eq_scalar(lhs, rhs), - Operator::Neq => neq_scalar(lhs, rhs), - Operator::Gt => gt_scalar(lhs, rhs), - Operator::GtEq => gt_eq_scalar(lhs, rhs), - Operator::Lt => lt_scalar(lhs, rhs), - Operator::LtEq => lt_eq_scalar(lhs, rhs), - } -} - #[cfg(test)] mod tests { use super::*; - fn test_generic, &Utf8Array) -> Result>( + fn test_generic, &Utf8Array) -> BooleanArray>( lhs: Vec<&str>, rhs: Vec<&str>, op: F, @@ -187,7 +114,7 @@ mod tests { let lhs = Utf8Array::::from_slice(lhs); let rhs = Utf8Array::::from_slice(rhs); let expected = BooleanArray::from_slice(expected); - assert_eq!(op(&lhs, &rhs).unwrap(), expected); + assert_eq!(op(&lhs, &rhs), expected); } fn test_generic_scalar, &str) -> BooleanArray>( diff --git a/src/compute/regex_match.rs b/src/compute/regex_match.rs index 0a36c031de9..3cdc94cfad1 100644 --- a/src/compute/regex_match.rs +++ b/src/compute/regex_match.rs @@ -6,9 +6,9 @@ use regex::Regex; use super::utils::{combine_validities, unary_utf8_boolean}; use crate::array::{BooleanArray, Offset, Utf8Array}; +use crate::bitmap::Bitmap; use crate::datatypes::DataType; use crate::error::{ArrowError, Result}; 
-use crate::{array::*, bitmap::Bitmap}; /// Regex matches pub fn regex_match(values: &Utf8Array, regex: &Utf8Array) -> Result { diff --git a/src/compute/sort/utf8.rs b/src/compute/sort/utf8.rs index 1776b6b022e..4d00dc42ecb 100644 --- a/src/compute/sort/utf8.rs +++ b/src/compute/sort/utf8.rs @@ -1,5 +1,5 @@ -use crate::array::{Array, Offset, PrimitiveArray, Utf8Array}; use crate::array::{DictionaryArray, DictionaryKey}; +use crate::array::{Offset, PrimitiveArray, Utf8Array}; use crate::types::Index; use super::common; diff --git a/tests/it/compute/comparison.rs b/tests/it/compute/comparison.rs index 02943ce6e25..4dbdff5fd24 100644 --- a/tests/it/compute/comparison.rs +++ b/tests/it/compute/comparison.rs @@ -1,5 +1,5 @@ use arrow2::array::new_null_array; -use arrow2::compute::comparison::{can_compare, compare, compare_scalar, Operator}; +use arrow2::compute::comparison::{can_eq, eq, eq_scalar}; use arrow2::datatypes::DataType::*; use arrow2::datatypes::TimeUnit; use arrow2::scalar::new_scalar; @@ -42,11 +42,8 @@ fn consistency() { // array <> array datatypes.clone().into_iter().for_each(|d1| { let array = new_null_array(d1.clone(), 10); - let op = Operator::Eq; - if can_compare(&d1) { - assert!(compare(array.as_ref(), array.as_ref(), op).is_ok()); - } else { - assert!(compare(array.as_ref(), array.as_ref(), op).is_err()); + if can_eq(&d1) { + eq(array.as_ref(), array.as_ref()); } }); @@ -54,11 +51,8 @@ fn consistency() { datatypes.into_iter().for_each(|d1| { let array = new_null_array(d1.clone(), 10); let scalar = new_scalar(array.as_ref(), 0); - let op = Operator::Eq; - if can_compare(&d1) { - assert!(compare_scalar(array.as_ref(), scalar.as_ref(), op).is_ok()); - } else { - assert!(compare_scalar(array.as_ref(), scalar.as_ref(), op).is_err()); + if can_eq(&d1) { + eq_scalar(array.as_ref(), scalar.as_ref()); } }); }