From 0d825c196e343805c7500bdc06af0c6e941e2577 Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Sat, 1 Jan 2022 07:06:08 -0500 Subject: [PATCH] Define eq_dyn_scalar API (#1074) * Squash * Cleanup error messages --- arrow/src/compute/kernels/comparison.rs | 374 +++++++++++++++++++++++- 1 file changed, 370 insertions(+), 4 deletions(-) diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index f78588efbca3..3e7a084cf334 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -21,22 +21,24 @@ //! detection is provided, you should enable the specific SIMD intrinsics using //! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. +//! use crate::array::*; use crate::buffer::{bitwise_bin_op_helper, buffer_unary_not, Buffer, MutableBuffer}; use crate::compute::binary_boolean_kernel; use crate::compute::util::combine_option_bitmap; use crate::datatypes::{ - ArrowNumericType, DataType, Date32Type, Date64Type, Float32Type, Float64Type, - Int16Type, Int32Type, Int64Type, Int8Type, TimeUnit, TimestampMicrosecondType, - TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, - UInt32Type, UInt64Type, UInt8Type, + ArrowNativeType, ArrowNumericType, DataType, Date32Type, Date64Type, Float32Type, + Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, TimeUnit, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use crate::error::{ArrowError, Result}; use crate::util::bit_util; use regex::{escape, Regex}; use std::any::type_name; use std::collections::HashMap; +use std::sync::Arc; /// Helper function to perform boolean lambda function on values from two arrays, this /// version does not attempt to use SIMD. 
@@ -888,6 +890,303 @@ pub fn gt_eq_utf8_scalar( compare_op_scalar!(left, right, |a, b| a >= b) } +macro_rules! dyn_compare_scalar { + // Applies `LEFT OP RIGHT` when `LEFT` is a `PrimitiveArray`: narrows the scalar to the native type selected by `LEFT.data_type()` and errors if it does not fit + ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{ + let right: i128 = $RIGHT.try_into().map_err(|_| { + ArrowError::ComputeError(String::from("Can not convert scalar to i128")) + })?; + match $LEFT.data_type() { + DataType::Int8 => { + let right: i8 = right.try_into().map_err(|_| { + ArrowError::ComputeError(String::from("Can not convert scalar to i8")) + })?; + let left = as_primitive_array::($LEFT); + $OP::(left, right) + } + DataType::Int16 => { + let right: i16 = right.try_into().map_err(|_| { + ArrowError::ComputeError(String::from( + "Can not convert scalar to i16", + )) + })?; + let left = as_primitive_array::($LEFT); + $OP::(left, right) + } + DataType::Int32 => { + let right: i32 = right.try_into().map_err(|_| { + ArrowError::ComputeError(String::from( + "Can not convert scalar to i32", + )) + })?; + let left = as_primitive_array::($LEFT); + $OP::(left, right) + } + DataType::Int64 => { + let right: i64 = right.try_into().map_err(|_| { + ArrowError::ComputeError(String::from( + "Can not convert scalar to i64", + )) + })?; + let left = as_primitive_array::($LEFT); + $OP::(left, right) + } + DataType::UInt8 => { + let right: u8 = right.try_into().map_err(|_| { + ArrowError::ComputeError(String::from("Can not convert scalar to u8")) + })?; + let left = as_primitive_array::($LEFT); + $OP::(left, right) + } + DataType::UInt16 => { + let right: u16 = right.try_into().map_err(|_| { + ArrowError::ComputeError(String::from( + "Can not convert scalar to u16", + )) + })?; + let left = as_primitive_array::($LEFT); + $OP::(left, right) + } + DataType::UInt32 => { + let right: u32 = right.try_into().map_err(|_| { + ArrowError::ComputeError(String::from( + "Can not convert scalar to u32", + )) + })?; + let left = as_primitive_array::($LEFT); + $OP::(left, right) + } + 
DataType::UInt64 => { + let right: u64 = right.try_into().map_err(|_| { + ArrowError::ComputeError(String::from( + "Can not convert scalar to u64", + )) + })?; + let left = as_primitive_array::($LEFT); + $OP::(left, right) + } + _ => Err(ArrowError::ComputeError(String::from( + "Unsupported data type", + ))), + } + }}; + // Applies `LEFT OP RIGHT` when `LEFT` is a `DictionaryArray` with keys of type `KT`: compares the dictionary values through the arm above, then maps that result back through the keys with `unpack_dict_comparison` + ($LEFT: expr, $RIGHT: expr, $KT: ident, $OP: ident) => {{ + let right: i128 = $RIGHT.try_into().map_err(|_| { + ArrowError::ComputeError(String::from("Can not convert scalar to i128")) + })?; + match $KT.as_ref() { + DataType::UInt8 => { + let left = as_dictionary_array::($LEFT); + unpack_dict_comparison( + left, + dyn_compare_scalar!(left.values(), right, $OP)?, + ) + } + DataType::UInt16 => { + let left = as_dictionary_array::($LEFT); + unpack_dict_comparison( + left, + dyn_compare_scalar!(left.values(), right, $OP)?, + ) + } + DataType::UInt32 => { + let left = as_dictionary_array::($LEFT); + unpack_dict_comparison( + left, + dyn_compare_scalar!(left.values(), right, $OP)?, + ) + } + DataType::UInt64 => { + let left = as_dictionary_array::($LEFT); + unpack_dict_comparison( + left, + dyn_compare_scalar!(left.values(), right, $OP)?, + ) + } + DataType::Int8 => { + let left = as_dictionary_array::($LEFT); + unpack_dict_comparison( + left, + dyn_compare_scalar!(left.values(), right, $OP)?, + ) + } + DataType::Int16 => { + let left = as_dictionary_array::($LEFT); + unpack_dict_comparison( + left, + dyn_compare_scalar!(left.values(), right, $OP)?, + ) + } + DataType::Int32 => { + let left = as_dictionary_array::($LEFT); + unpack_dict_comparison( + left, + dyn_compare_scalar!(left.values(), right, $OP)?, + ) + } + DataType::Int64 => { + let left = as_dictionary_array::($LEFT); + unpack_dict_comparison( + left, + dyn_compare_scalar!(left.values(), right, $OP)?, + ) + } + _ => Err(ArrowError::ComputeError(String::from("Unknown key type"))), + } + }}; +} + +macro_rules! 
dyn_compare_utf8_scalar { + ($LEFT: expr, $RIGHT: expr, $KT: ident, $OP: ident) => {{ + match $KT.as_ref() { + DataType::UInt8 => { + let left = as_dictionary_array::($LEFT); + let values = as_string_array(left.values()); + unpack_dict_comparison(left, $OP(values, $RIGHT)?) + } + DataType::UInt16 => { + let left = as_dictionary_array::($LEFT); + let values = as_string_array(left.values()); + unpack_dict_comparison(left, $OP(values, $RIGHT)?) + } + DataType::UInt32 => { + let left = as_dictionary_array::($LEFT); + let values = as_string_array(left.values()); + unpack_dict_comparison(left, $OP(values, $RIGHT)?) + } + DataType::UInt64 => { + let left = as_dictionary_array::($LEFT); + let values = as_string_array(left.values()); + unpack_dict_comparison(left, $OP(values, $RIGHT)?) + } + DataType::Int8 => { + let left = as_dictionary_array::($LEFT); + let values = as_string_array(left.values()); + unpack_dict_comparison(left, $OP(values, $RIGHT)?) + } + DataType::Int16 => { + let left = as_dictionary_array::($LEFT); + let values = as_string_array(left.values()); + unpack_dict_comparison(left, $OP(values, $RIGHT)?) + } + DataType::Int32 => { + let left = as_dictionary_array::($LEFT); + let values = as_string_array(left.values()); + unpack_dict_comparison(left, $OP(values, $RIGHT)?) + } + DataType::Int64 => { + let left = as_dictionary_array::($LEFT); + let values = as_string_array(left.values()); + unpack_dict_comparison(left, $OP(values, $RIGHT)?) + } + _ => Err(ArrowError::ComputeError(String::from("Unknown key type"))), + } + }}; +} + +/// Perform `left == right` operation on an array and a numeric scalar +/// value. 
Supports integer PrimitiveArrays (Int8 to Int64, UInt8 to UInt64), and DictionaryArrays whose values have one of those types; any other data type returns a ComputeError +pub fn eq_dyn_scalar(left: Arc, right: T) -> Result +where + T: TryInto + Copy + std::fmt::Debug, +{ + match left.data_type() { + DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => {dyn_compare_scalar!(&left, right, key_type, eq_scalar)} + _ => Err(ArrowError::ComputeError( + "Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), + )) + } + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => { + dyn_compare_scalar!(&left, right, eq_scalar) + } + _ => Err(ArrowError::ComputeError( + "Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), + )) + } +} + +/// Perform `left == right` operation on an array and a string scalar +/// value. Supports StringArrays, and DictionaryArrays that have string values. NOTE(review): the `LargeUtf8` cases are also routed through `as_string_array`; confirm that helper accepts large string arrays +pub fn eq_dyn_utf8_scalar(left: Arc, right: &str) -> Result { + let result = match left.data_type() { + DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Utf8 | DataType::LargeUtf8 => { + dyn_compare_utf8_scalar!(&left, right, key_type, eq_utf8_scalar) + } + _ => Err(ArrowError::ComputeError( + "Kernel only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), + )), + }, + DataType::Utf8 | DataType::LargeUtf8 => { + let left = as_string_array(&left); + eq_utf8_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "Kernel only supports Utf8 or LargeUtf8 arrays".to_string(), + )), + }; + result +} + +/// Perform `left == right` operation on an array and a boolean scalar +/// value. 
Supports BooleanArrays only; any other data type (DictionaryArrays included) returns a ComputeError +pub fn eq_dyn_bool_scalar(left: Arc, right: bool) -> Result { + let result = match left.data_type() { + DataType::Boolean => { + let left = as_boolean_array(&left); + eq_bool_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "Kernel only supports BooleanArray".to_string(), + )), + }; + result +} + +/// Expands `dict_comparison` (one boolean per entry of `dict.values()`) into one boolean per key of `dict`; a null key yields a null entry +/// +/// TODO add example +/// +fn unpack_dict_comparison( + dict: &DictionaryArray, + dict_comparison: BooleanArray, +) -> Result +where + K: ArrowNumericType, +{ + assert_eq!(dict_comparison.len(), dict.values().len()); + + let result: BooleanArray = dict + .keys() + .iter() + .map(|key| { + key.map(|key| unsafe { + // SAFETY: `dict_comparison` has one entry per dictionary value (asserted above); assumes every non-null key is a valid index into the values, which is not checked here (TODO confirm) + let key = key.to_usize().expect("Dictionary index not usize"); + dict_comparison.value_unchecked(key) + }) + }) + .collect(); + + Ok(result) +} + /// Helper function to perform boolean lambda function on values from two arrays using /// SIMD. 
#[cfg(feature = "simd")] @@ -2646,4 +2945,71 @@ mod tests { regexp_is_match_utf8_scalar, vec![true, true, false, false] ); + #[test] + fn test_eq_dyn_scalar() { + let array = Int32Array::from(vec![6, 7, 8, 8, 10]); + let array = Arc::new(array); + let a_eq = eq_dyn_scalar(array, 8).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from( + vec![Some(false), Some(false), Some(true), Some(true), Some(false)] + ) + ); + } + #[test] + fn test_eq_dyn_scalar_with_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(123).unwrap(); + builder.append_null().unwrap(); + builder.append(23).unwrap(); + let array = Arc::new(builder.finish()); + let a_eq = eq_dyn_scalar(array, 123).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(true), None, Some(false)]) + ); + } + #[test] + fn test_eq_dyn_utf8_scalar() { + let array = StringArray::from(vec!["abc", "def", "xyz"]); + let array = Arc::new(array); + let a_eq = eq_dyn_utf8_scalar(array, "xyz").unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(false), Some(false), Some(true)]) + ); + } + #[test] + fn test_eq_dyn_utf8_scalar_with_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = StringBuilder::new(100); + let mut builder = StringDictionaryBuilder::new(key_builder, value_builder); + builder.append("abc").unwrap(); + builder.append_null().unwrap(); + builder.append("def").unwrap(); + builder.append("def").unwrap(); + builder.append("abc").unwrap(); + let array = Arc::new(builder.finish()); + let a_eq = eq_dyn_utf8_scalar(array, "def").unwrap(); + assert_eq!( + a_eq, + BooleanArray::from( + vec![Some(false), None, Some(true), Some(true), Some(false)] + ) + ); + } + + #[test] + fn test_eq_dyn_bool_scalar() { + let array = BooleanArray::from(vec![true, false, true]); + let array = Arc::new(array); + let a_eq = eq_dyn_bool_scalar(array, 
false).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(false), Some(true), Some(false)]) + ); + } }