diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs index 1145e833f4bd..dbeaab75f815 100644 --- a/arrow/src/array/mod.rs +++ b/arrow/src/array/mod.rs @@ -414,23 +414,38 @@ pub type UInt64BufferBuilder = BufferBuilder; pub type Float32BufferBuilder = BufferBuilder; pub type Float64BufferBuilder = BufferBuilder; -pub type TimestampSecondBufferBuilder = BufferBuilder; -pub type TimestampMillisecondBufferBuilder = BufferBuilder; -pub type TimestampMicrosecondBufferBuilder = BufferBuilder; -pub type TimestampNanosecondBufferBuilder = BufferBuilder; -pub type Date32BufferBuilder = BufferBuilder; -pub type Date64BufferBuilder = BufferBuilder; -pub type Time32SecondBufferBuilder = BufferBuilder; -pub type Time32MillisecondBufferBuilder = BufferBuilder; -pub type Time64MicrosecondBufferBuilder = BufferBuilder; -pub type Time64NanosecondBufferBuilder = BufferBuilder; -pub type IntervalYearMonthBufferBuilder = BufferBuilder; -pub type IntervalDayTimeBufferBuilder = BufferBuilder; -pub type IntervalMonthDayNanoBufferBuilder = BufferBuilder; -pub type DurationSecondBufferBuilder = BufferBuilder; -pub type DurationMillisecondBufferBuilder = BufferBuilder; -pub type DurationMicrosecondBufferBuilder = BufferBuilder; -pub type DurationNanosecondBufferBuilder = BufferBuilder; +pub type TimestampSecondBufferBuilder = + BufferBuilder<::Native>; +pub type TimestampMillisecondBufferBuilder = + BufferBuilder<::Native>; +pub type TimestampMicrosecondBufferBuilder = + BufferBuilder<::Native>; +pub type TimestampNanosecondBufferBuilder = + BufferBuilder<::Native>; +pub type Date32BufferBuilder = BufferBuilder<::Native>; +pub type Date64BufferBuilder = BufferBuilder<::Native>; +pub type Time32SecondBufferBuilder = + BufferBuilder<::Native>; +pub type Time32MillisecondBufferBuilder = + BufferBuilder<::Native>; +pub type Time64MicrosecondBufferBuilder = + BufferBuilder<::Native>; +pub type Time64NanosecondBufferBuilder = + BufferBuilder<::Native>; +pub type 
IntervalYearMonthBufferBuilder = + BufferBuilder<::Native>; +pub type IntervalDayTimeBufferBuilder = + BufferBuilder<::Native>; +pub type IntervalMonthDayNanoBufferBuilder = + BufferBuilder<::Native>; +pub type DurationSecondBufferBuilder = + BufferBuilder<::Native>; +pub type DurationMillisecondBufferBuilder = + BufferBuilder<::Native>; +pub type DurationMicrosecondBufferBuilder = + BufferBuilder<::Native>; +pub type DurationNanosecondBufferBuilder = + BufferBuilder<::Native>; pub use self::builder::ArrayBuilder; pub use self::builder::BinaryBuilder; @@ -506,3 +521,37 @@ pub use self::cast::{ // ------------------------------ C Data Interface --------------------------- pub use self::array::make_array_from_raw; + +#[cfg(test)] +mod tests { + use crate::array::*; + + #[test] + fn test_buffer_builder_availability() { + let _builder = Int8BufferBuilder::new(10); + let _builder = Int16BufferBuilder::new(10); + let _builder = Int32BufferBuilder::new(10); + let _builder = Int64BufferBuilder::new(10); + let _builder = UInt16BufferBuilder::new(10); + let _builder = UInt32BufferBuilder::new(10); + let _builder = Float32BufferBuilder::new(10); + let _builder = Float64BufferBuilder::new(10); + let _builder = TimestampSecondBufferBuilder::new(10); + let _builder = TimestampMillisecondBufferBuilder::new(10); + let _builder = TimestampMicrosecondBufferBuilder::new(10); + let _builder = TimestampNanosecondBufferBuilder::new(10); + let _builder = Date32BufferBuilder::new(10); + let _builder = Date64BufferBuilder::new(10); + let _builder = Time32SecondBufferBuilder::new(10); + let _builder = Time32MillisecondBufferBuilder::new(10); + let _builder = Time64MicrosecondBufferBuilder::new(10); + let _builder = Time64NanosecondBufferBuilder::new(10); + let _builder = IntervalYearMonthBufferBuilder::new(10); + let _builder = IntervalDayTimeBufferBuilder::new(10); + let _builder = IntervalMonthDayNanoBufferBuilder::new(10); + let _builder = DurationSecondBufferBuilder::new(10); + let 
_builder = DurationMillisecondBufferBuilder::new(10); + let _builder = DurationMicrosecondBufferBuilder::new(10); + let _builder = DurationNanosecondBufferBuilder::new(10); + } +} diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index eacdc318c794..7c842473a9b3 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -1102,7 +1102,7 @@ where | DataType::UInt32 | DataType::UInt64 => {dyn_compare_scalar!(&left, right, key_type, eq_scalar)} _ => Err(ArrowError::ComputeError( - "Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), + "eq_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), )) } DataType::Int8 @@ -1116,7 +1116,7 @@ where dyn_compare_scalar!(&left, right, eq_scalar) } _ => Err(ArrowError::ComputeError( - "Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), + "eq_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), )) } } @@ -1174,7 +1174,7 @@ where | DataType::UInt32 | DataType::UInt64 => {dyn_compare_scalar!(&left, right, key_type, lt_eq_scalar)} _ => Err(ArrowError::ComputeError( - "Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), + "lt_eq_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), )) } DataType::Int8 @@ -1188,7 +1188,7 @@ where dyn_compare_scalar!(&left, right, lt_eq_scalar) } _ => Err(ArrowError::ComputeError( - "Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), + "lt_eq_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), )) } } @@ -1210,7 +1210,7 @@ where | DataType::UInt32 | DataType::UInt64 => {dyn_compare_scalar!(&left, right, key_type, gt_scalar)} _ => Err(ArrowError::ComputeError( - "Kernel only supports PrimitiveArray or 
DictionaryArray with Primitive values".to_string(), + "gt_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), )) } DataType::Int8 @@ -1224,7 +1224,7 @@ where dyn_compare_scalar!(&left, right, gt_scalar) } _ => Err(ArrowError::ComputeError( - "Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), + "gt_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), )) } } @@ -1246,7 +1246,7 @@ where | DataType::UInt32 | DataType::UInt64 => {dyn_compare_scalar!(&left, right, key_type, gt_eq_scalar)} _ => Err(ArrowError::ComputeError( - "Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), + "gt_eq_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), )) } DataType::Int8 @@ -1260,7 +1260,7 @@ where dyn_compare_scalar!(&left, right, gt_eq_scalar) } _ => Err(ArrowError::ComputeError( - "Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), + "gt_eq_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(), )) } } @@ -1310,7 +1310,7 @@ pub fn eq_dyn_utf8_scalar(left: Arc, right: &str) -> Result Err(ArrowError::ComputeError( - "Kernel only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), + "eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), )), }, DataType::Utf8 | DataType::LargeUtf8 => { @@ -1318,7 +1318,53 @@ pub fn eq_dyn_utf8_scalar(left: Arc, right: &str) -> Result Err(ArrowError::ComputeError( - "Kernel only supports Utf8 or LargeUtf8 arrays".to_string(), + "eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), + )), + }; + result +} + +/// Perform `left < right` operation on an array and a string scalar +/// value. 
Supports StringArrays, and DictionaryArrays that have string values +pub fn lt_dyn_utf8_scalar(left: Arc, right: &str) -> Result { + let result = match left.data_type() { + DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Utf8 | DataType::LargeUtf8 => { + dyn_compare_utf8_scalar!(&left, right, key_type, lt_utf8_scalar) + } + _ => Err(ArrowError::ComputeError( + "lt_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), + )), + }, + DataType::Utf8 | DataType::LargeUtf8 => { + let left = as_string_array(&left); + lt_utf8_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "lt_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), + )), + }; + result +} + +/// Perform `left >= right` operation on an array and a string scalar +/// value. Supports StringArrays, and DictionaryArrays that have string values +pub fn gt_eq_dyn_utf8_scalar(left: Arc, right: &str) -> Result { + let result = match left.data_type() { + DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Utf8 | DataType::LargeUtf8 => { + dyn_compare_utf8_scalar!(&left, right, key_type, gt_eq_utf8_scalar) + } + _ => Err(ArrowError::ComputeError( + "gt_eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), + )), + }, + DataType::Utf8 | DataType::LargeUtf8 => { + let left = as_string_array(&left); + gt_eq_utf8_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "gt_eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), )), }; result @@ -3206,6 +3252,7 @@ mod tests { BooleanArray::from(vec![Some(false), None, Some(true)]) ); } + #[test] fn test_gt_dyn_scalar() { let array = Int32Array::from(vec![6, 7, 8, 8, 10]); @@ -3233,6 +3280,7 @@ mod tests { BooleanArray::from(vec![Some(true), None, Some(false)]) ); } + #[test] fn test_gt_eq_dyn_scalar() { let array = 
Int32Array::from(vec![6, 7, 8, 8, 10]); @@ -3245,6 +3293,7 @@ mod tests { ) ); } + #[test] fn test_gt_eq_dyn_scalar_with_dict() { let key_builder = PrimitiveBuilder::::new(3); @@ -3260,6 +3309,7 @@ mod tests { BooleanArray::from(vec![Some(false), None, Some(true)]) ); } + #[test] fn test_neq_dyn_scalar() { let array = Int32Array::from(vec![6, 7, 8, 8, 10]); @@ -3272,6 +3322,7 @@ mod tests { ) ); } + #[test] fn test_neq_dyn_scalar_with_dict() { let key_builder = PrimitiveBuilder::::new(3); @@ -3287,6 +3338,7 @@ mod tests { BooleanArray::from(vec![Some(true), None, Some(false)]) ); } + #[test] fn test_eq_dyn_utf8_scalar() { let array = StringArray::from(vec!["abc", "def", "xyz"]); @@ -3297,6 +3349,7 @@ mod tests { BooleanArray::from(vec![Some(false), Some(false), Some(true)]) ); } + #[test] fn test_eq_dyn_utf8_scalar_with_dict() { let key_builder = PrimitiveBuilder::::new(3); @@ -3316,6 +3369,65 @@ mod tests { ) ); } + #[test] + fn test_lt_dyn_utf8_scalar() { + let array = StringArray::from(vec!["abc", "def", "xyz"]); + let array = Arc::new(array); + let a_eq = lt_dyn_utf8_scalar(array, "xyz").unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(true), Some(true), Some(false)]) + ); + } + #[test] + fn test_lt_dyn_utf8_scalar_with_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = StringBuilder::new(100); + let mut builder = StringDictionaryBuilder::new(key_builder, value_builder); + builder.append("abc").unwrap(); + builder.append_null().unwrap(); + builder.append("def").unwrap(); + builder.append("def").unwrap(); + builder.append("abc").unwrap(); + let array = Arc::new(builder.finish()); + let a_eq = lt_dyn_utf8_scalar(array, "def").unwrap(); + assert_eq!( + a_eq, + BooleanArray::from( + vec![Some(true), None, Some(false), Some(false), Some(true)] + ) + ); + } + + #[test] + fn test_gt_eq_dyn_utf8_scalar() { + let array = StringArray::from(vec!["abc", "def", "xyz"]); + let array = Arc::new(array); + let a_eq = 
gt_eq_dyn_utf8_scalar(array, "def").unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(false), Some(true), Some(true)]) + ); + } + #[test] + fn test_gt_eq_dyn_utf8_scalar_with_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = StringBuilder::new(100); + let mut builder = StringDictionaryBuilder::new(key_builder, value_builder); + builder.append("abc").unwrap(); + builder.append_null().unwrap(); + builder.append("def").unwrap(); + builder.append("def").unwrap(); + builder.append("xyz").unwrap(); + let array = Arc::new(builder.finish()); + let a_eq = gt_eq_dyn_utf8_scalar(array, "def").unwrap(); + assert_eq!( + a_eq, + BooleanArray::from( + vec![Some(false), None, Some(true), Some(true), Some(true)] + ) + ); + } #[test] fn test_eq_dyn_bool_scalar() {