diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index 60ba634e733a..c4c9fa1c3d41 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -1193,6 +1193,42 @@ where } } +/// Perform `left > right` operation on an array and a numeric scalar +/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values +pub fn gt_dyn_scalar(left: Arc, right: T) -> Result +where + T: TryInto + Copy + std::fmt::Debug, +{ + match left.data_type() { + DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => {dyn_compare_scalar!(&left, right, key_type, gt_scalar)} + _ => Err(ArrowError::ComputeError( + "gt_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), + )) + } + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => { + dyn_compare_scalar!(&left, right, gt_scalar) + } + _ => Err(ArrowError::ComputeError( + "gt_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), + )) + } +} + /// Perform `left == right` operation on an array and a numeric scalar /// value. Supports StringArrays, and DictionaryArrays that have string values pub fn eq_dyn_utf8_scalar(left: Arc, right: &str) -> Result { @@ -3145,6 +3181,34 @@ mod tests { ); } + #[test] + fn test_gt_dyn_scalar() { + let array = Int32Array::from(vec![6, 7, 8, 8, 10]); + let array = Arc::new(array); + let a_eq = gt_dyn_scalar(array, 8).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(true)] + ) + ); + } + #[test] + fn test_gt_dyn_scalar_with_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(123).unwrap(); + builder.append_null().unwrap(); + builder.append(23).unwrap(); + let array = Arc::new(builder.finish()); + let a_eq = gt_dyn_scalar(array, 23).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(true), None, Some(false)]) + ); + } + #[test] fn test_eq_dyn_utf8_scalar() { let array = StringArray::from(vec!["abc", "def", "xyz"]);