From 985a4041a44805a03fff032783a9c70eb176a021 Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Sun, 2 Jan 2022 08:02:22 -0500 Subject: [PATCH] Add neq dyn scalar kernel (#1118) * Add lt_dyn_scalar and tests * Add lt_eq_dyn_scalar kernel * Add gt_dyn_scalar kernel * Add gt_eq_dyn_scalar kernel * Add neq_dyn_scalar kernel * Add kernel to err message Co-authored-by: Andrew Lamb --- arrow/src/compute/kernels/comparison.rs | 65 +++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index 2f7ddec07948..7c842473a9b3 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -1265,6 +1265,42 @@ where } } +/// Perform `left != right` operation on an array and a numeric scalar +/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values +pub fn neq_dyn_scalar(left: Arc, right: T) -> Result +where + T: TryInto + Copy + std::fmt::Debug, +{ + match left.data_type() { + DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => {dyn_compare_scalar!(&left, right, key_type, neq_scalar)} + _ => Err(ArrowError::ComputeError( + "neq_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), + )) + } + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => { + dyn_compare_scalar!(&left, right, neq_scalar) + } + _ => Err(ArrowError::ComputeError( + "neq_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), + )) + } +} + /// Perform `left == right` operation on an array and a numeric scalar /// value. Supports StringArrays, and DictionaryArrays that have string values pub fn eq_dyn_utf8_scalar(left: Arc, right: &str) -> Result { @@ -3274,6 +3310,35 @@ mod tests { ); } + #[test] + fn test_neq_dyn_scalar() { + let array = Int32Array::from(vec![6, 7, 8, 8, 10]); + let array = Arc::new(array); + let a_eq = neq_dyn_scalar(array, 8).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from( + vec![Some(true), Some(true), Some(false), Some(false), Some(true)] + ) + ); + } + + #[test] + fn test_neq_dyn_scalar_with_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(22).unwrap(); + builder.append_null().unwrap(); + builder.append(23).unwrap(); + let array = Arc::new(builder.finish()); + let a_eq = neq_dyn_scalar(array, 23).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(true), None, Some(false)]) + ); + } + #[test] fn test_eq_dyn_utf8_scalar() { let array = StringArray::from(vec!["abc", "def", "xyz"]);