Skip to content

Commit

Permalink
Merge remote-tracking branch 'apache/master' into add_neq_dyn_scalar_…
Browse files Browse the repository at this point in the history
…kernel
  • Loading branch information
alamb committed Jan 2, 2022
2 parents b310c58 + 37b843b commit 00a622f
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 27 deletions.
83 changes: 66 additions & 17 deletions arrow/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,23 +414,38 @@ pub type UInt64BufferBuilder = BufferBuilder<u64>;
pub type Float32BufferBuilder = BufferBuilder<f32>;
pub type Float64BufferBuilder = BufferBuilder<f64>;

pub type TimestampSecondBufferBuilder = BufferBuilder<TimestampSecondType>;
pub type TimestampMillisecondBufferBuilder = BufferBuilder<TimestampMillisecondType>;
pub type TimestampMicrosecondBufferBuilder = BufferBuilder<TimestampMicrosecondType>;
pub type TimestampNanosecondBufferBuilder = BufferBuilder<TimestampNanosecondType>;
pub type Date32BufferBuilder = BufferBuilder<Date32Type>;
pub type Date64BufferBuilder = BufferBuilder<Date64Type>;
pub type Time32SecondBufferBuilder = BufferBuilder<Time32SecondType>;
pub type Time32MillisecondBufferBuilder = BufferBuilder<Time32MillisecondType>;
pub type Time64MicrosecondBufferBuilder = BufferBuilder<Time64MicrosecondType>;
pub type Time64NanosecondBufferBuilder = BufferBuilder<Time64NanosecondType>;
pub type IntervalYearMonthBufferBuilder = BufferBuilder<IntervalYearMonthType>;
pub type IntervalDayTimeBufferBuilder = BufferBuilder<IntervalDayTimeType>;
pub type IntervalMonthDayNanoBufferBuilder = BufferBuilder<IntervalMonthDayNanoType>;
pub type DurationSecondBufferBuilder = BufferBuilder<DurationSecondType>;
pub type DurationMillisecondBufferBuilder = BufferBuilder<DurationMillisecondType>;
pub type DurationMicrosecondBufferBuilder = BufferBuilder<DurationMicrosecondType>;
pub type DurationNanosecondBufferBuilder = BufferBuilder<DurationNanosecondType>;
pub type TimestampSecondBufferBuilder =
BufferBuilder<<TimestampSecondType as ArrowPrimitiveType>::Native>;
pub type TimestampMillisecondBufferBuilder =
BufferBuilder<<TimestampMillisecondType as ArrowPrimitiveType>::Native>;
pub type TimestampMicrosecondBufferBuilder =
BufferBuilder<<TimestampMicrosecondType as ArrowPrimitiveType>::Native>;
pub type TimestampNanosecondBufferBuilder =
BufferBuilder<<TimestampNanosecondType as ArrowPrimitiveType>::Native>;
pub type Date32BufferBuilder = BufferBuilder<<Date32Type as ArrowPrimitiveType>::Native>;
pub type Date64BufferBuilder = BufferBuilder<<Date64Type as ArrowPrimitiveType>::Native>;
pub type Time32SecondBufferBuilder =
BufferBuilder<<Time32SecondType as ArrowPrimitiveType>::Native>;
pub type Time32MillisecondBufferBuilder =
BufferBuilder<<Time32MillisecondType as ArrowPrimitiveType>::Native>;
pub type Time64MicrosecondBufferBuilder =
BufferBuilder<<Time64MicrosecondType as ArrowPrimitiveType>::Native>;
pub type Time64NanosecondBufferBuilder =
BufferBuilder<<Time64NanosecondType as ArrowPrimitiveType>::Native>;
pub type IntervalYearMonthBufferBuilder =
BufferBuilder<<IntervalYearMonthType as ArrowPrimitiveType>::Native>;
pub type IntervalDayTimeBufferBuilder =
BufferBuilder<<IntervalDayTimeType as ArrowPrimitiveType>::Native>;
pub type IntervalMonthDayNanoBufferBuilder =
BufferBuilder<<IntervalMonthDayNanoType as ArrowPrimitiveType>::Native>;
pub type DurationSecondBufferBuilder =
BufferBuilder<<DurationSecondType as ArrowPrimitiveType>::Native>;
pub type DurationMillisecondBufferBuilder =
BufferBuilder<<DurationMillisecondType as ArrowPrimitiveType>::Native>;
pub type DurationMicrosecondBufferBuilder =
BufferBuilder<<DurationMicrosecondType as ArrowPrimitiveType>::Native>;
pub type DurationNanosecondBufferBuilder =
BufferBuilder<<DurationNanosecondType as ArrowPrimitiveType>::Native>;

pub use self::builder::ArrayBuilder;
pub use self::builder::BinaryBuilder;
Expand Down Expand Up @@ -506,3 +521,37 @@ pub use self::cast::{
// ------------------------------ C Data Interface ---------------------------

pub use self::array::make_array_from_raw;

#[cfg(test)]
mod tests {
use crate::array::*;

#[test]
fn test_buffer_builder_availability() {
let _builder = Int8BufferBuilder::new(10);
let _builder = Int16BufferBuilder::new(10);
let _builder = Int32BufferBuilder::new(10);
let _builder = Int64BufferBuilder::new(10);
let _builder = UInt16BufferBuilder::new(10);
let _builder = UInt32BufferBuilder::new(10);
let _builder = Float32BufferBuilder::new(10);
let _builder = Float64BufferBuilder::new(10);
let _builder = TimestampSecondBufferBuilder::new(10);
let _builder = TimestampMillisecondBufferBuilder::new(10);
let _builder = TimestampMicrosecondBufferBuilder::new(10);
let _builder = TimestampNanosecondBufferBuilder::new(10);
let _builder = Date32BufferBuilder::new(10);
let _builder = Date64BufferBuilder::new(10);
let _builder = Time32SecondBufferBuilder::new(10);
let _builder = Time32MillisecondBufferBuilder::new(10);
let _builder = Time64MicrosecondBufferBuilder::new(10);
let _builder = Time64NanosecondBufferBuilder::new(10);
let _builder = IntervalYearMonthBufferBuilder::new(10);
let _builder = IntervalDayTimeBufferBuilder::new(10);
let _builder = IntervalMonthDayNanoBufferBuilder::new(10);
let _builder = DurationSecondBufferBuilder::new(10);
let _builder = DurationMillisecondBufferBuilder::new(10);
let _builder = DurationMicrosecondBufferBuilder::new(10);
let _builder = DurationNanosecondBufferBuilder::new(10);
}
}
132 changes: 122 additions & 10 deletions arrow/src/compute/kernels/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,7 @@ where
| DataType::UInt32
| DataType::UInt64 => {dyn_compare_scalar!(&left, right, key_type, eq_scalar)}
_ => Err(ArrowError::ComputeError(
"Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(),
"eq_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(),
))
}
DataType::Int8
Expand All @@ -1116,7 +1116,7 @@ where
dyn_compare_scalar!(&left, right, eq_scalar)
}
_ => Err(ArrowError::ComputeError(
"Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(),
"eq_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(),
))
}
}
Expand Down Expand Up @@ -1174,7 +1174,7 @@ where
| DataType::UInt32
| DataType::UInt64 => {dyn_compare_scalar!(&left, right, key_type, lt_eq_scalar)}
_ => Err(ArrowError::ComputeError(
"Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(),
"lt_eq_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(),
))
}
DataType::Int8
Expand All @@ -1188,7 +1188,7 @@ where
dyn_compare_scalar!(&left, right, lt_eq_scalar)
}
_ => Err(ArrowError::ComputeError(
"Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(),
"lt_eq_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(),
))
}
}
Expand All @@ -1210,7 +1210,7 @@ where
| DataType::UInt32
| DataType::UInt64 => {dyn_compare_scalar!(&left, right, key_type, gt_scalar)}
_ => Err(ArrowError::ComputeError(
"Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(),
"gt_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(),
))
}
DataType::Int8
Expand All @@ -1224,7 +1224,7 @@ where
dyn_compare_scalar!(&left, right, gt_scalar)
}
_ => Err(ArrowError::ComputeError(
"Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(),
"gt_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(),
))
}
}
Expand All @@ -1246,7 +1246,7 @@ where
| DataType::UInt32
| DataType::UInt64 => {dyn_compare_scalar!(&left, right, key_type, gt_eq_scalar)}
_ => Err(ArrowError::ComputeError(
"Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(),
"gt_eq_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(),
))
}
DataType::Int8
Expand All @@ -1260,7 +1260,7 @@ where
dyn_compare_scalar!(&left, right, gt_eq_scalar)
}
_ => Err(ArrowError::ComputeError(
"Kernel only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(),
"gt_eq_dyn_scalar only supports PrimitiveArray or DictionaryArray with Primitive values".to_string(),
))
}
}
Expand Down Expand Up @@ -1310,15 +1310,61 @@ pub fn eq_dyn_utf8_scalar(left: Arc<dyn Array>, right: &str) -> Result<BooleanAr
dyn_compare_utf8_scalar!(&left, right, key_type, eq_utf8_scalar)
}
_ => Err(ArrowError::ComputeError(
"Kernel only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(),
"eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(),
)),
},
DataType::Utf8 | DataType::LargeUtf8 => {
let left = as_string_array(&left);
eq_utf8_scalar(left, right)
}
_ => Err(ArrowError::ComputeError(
"Kernel only supports Utf8 or LargeUtf8 arrays".to_string(),
"eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(),
)),
};
result
}

/// Perform `left < right` operation on an array and a numeric scalar
/// value. Supports StringArrays, and DictionaryArrays that have string values
pub fn lt_dyn_utf8_scalar(left: Arc<dyn Array>, right: &str) -> Result<BooleanArray> {
let result = match left.data_type() {
DataType::Dictionary(key_type, value_type) => match value_type.as_ref() {
DataType::Utf8 | DataType::LargeUtf8 => {
dyn_compare_utf8_scalar!(&left, right, key_type, lt_utf8_scalar)
}
_ => Err(ArrowError::ComputeError(
"lt_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(),
)),
},
DataType::Utf8 | DataType::LargeUtf8 => {
let left = as_string_array(&left);
lt_utf8_scalar(left, right)
}
_ => Err(ArrowError::ComputeError(
"lt_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(),
)),
};
result
}

/// Perform `left >= right` operation on an array and a numeric scalar
/// value. Supports StringArrays, and DictionaryArrays that have string values
pub fn gt_eq_dyn_utf8_scalar(left: Arc<dyn Array>, right: &str) -> Result<BooleanArray> {
let result = match left.data_type() {
DataType::Dictionary(key_type, value_type) => match value_type.as_ref() {
DataType::Utf8 | DataType::LargeUtf8 => {
dyn_compare_utf8_scalar!(&left, right, key_type, gt_eq_utf8_scalar)
}
_ => Err(ArrowError::ComputeError(
"gt_eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(),
)),
},
DataType::Utf8 | DataType::LargeUtf8 => {
let left = as_string_array(&left);
gt_eq_utf8_scalar(left, right)
}
_ => Err(ArrowError::ComputeError(
"gt_eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(),
)),
};
result
Expand Down Expand Up @@ -3206,6 +3252,7 @@ mod tests {
BooleanArray::from(vec![Some(false), None, Some(true)])
);
}

#[test]
fn test_gt_dyn_scalar() {
let array = Int32Array::from(vec![6, 7, 8, 8, 10]);
Expand Down Expand Up @@ -3233,6 +3280,7 @@ mod tests {
BooleanArray::from(vec![Some(true), None, Some(false)])
);
}

#[test]
fn test_gt_eq_dyn_scalar() {
let array = Int32Array::from(vec![6, 7, 8, 8, 10]);
Expand All @@ -3245,6 +3293,7 @@ mod tests {
)
);
}

#[test]
fn test_gt_eq_dyn_scalar_with_dict() {
let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
Expand All @@ -3260,6 +3309,7 @@ mod tests {
BooleanArray::from(vec![Some(false), None, Some(true)])
);
}

#[test]
fn test_neq_dyn_scalar() {
let array = Int32Array::from(vec![6, 7, 8, 8, 10]);
Expand All @@ -3272,6 +3322,7 @@ mod tests {
)
);
}

#[test]
fn test_neq_dyn_scalar_with_dict() {
let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
Expand All @@ -3287,6 +3338,7 @@ mod tests {
BooleanArray::from(vec![Some(true), None, Some(false)])
);
}

#[test]
fn test_eq_dyn_utf8_scalar() {
let array = StringArray::from(vec!["abc", "def", "xyz"]);
Expand All @@ -3297,6 +3349,7 @@ mod tests {
BooleanArray::from(vec![Some(false), Some(false), Some(true)])
);
}

#[test]
fn test_eq_dyn_utf8_scalar_with_dict() {
let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
Expand All @@ -3316,6 +3369,65 @@ mod tests {
)
);
}
#[test]
fn test_lt_dyn_utf8_scalar() {
let array = StringArray::from(vec!["abc", "def", "xyz"]);
let array = Arc::new(array);
let a_eq = lt_dyn_utf8_scalar(array, "xyz").unwrap();
assert_eq!(
a_eq,
BooleanArray::from(vec![Some(true), Some(true), Some(false)])
);
}
#[test]
fn test_lt_dyn_utf8_scalar_with_dict() {
let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
let value_builder = StringBuilder::new(100);
let mut builder = StringDictionaryBuilder::new(key_builder, value_builder);
builder.append("abc").unwrap();
builder.append_null().unwrap();
builder.append("def").unwrap();
builder.append("def").unwrap();
builder.append("abc").unwrap();
let array = Arc::new(builder.finish());
let a_eq = lt_dyn_utf8_scalar(array, "def").unwrap();
assert_eq!(
a_eq,
BooleanArray::from(
vec![Some(true), None, Some(false), Some(false), Some(true)]
)
);
}

#[test]
fn test_gt_eq_dyn_utf8_scalar() {
let array = StringArray::from(vec!["abc", "def", "xyz"]);
let array = Arc::new(array);
let a_eq = gt_eq_dyn_utf8_scalar(array, "def").unwrap();
assert_eq!(
a_eq,
BooleanArray::from(vec![Some(false), Some(true), Some(true)])
);
}
#[test]
fn test_gt_eq_dyn_utf8_scalar_with_dict() {
let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
let value_builder = StringBuilder::new(100);
let mut builder = StringDictionaryBuilder::new(key_builder, value_builder);
builder.append("abc").unwrap();
builder.append_null().unwrap();
builder.append("def").unwrap();
builder.append("def").unwrap();
builder.append("xyz").unwrap();
let array = Arc::new(builder.finish());
let a_eq = gt_eq_dyn_utf8_scalar(array, "def").unwrap();
assert_eq!(
a_eq,
BooleanArray::from(
vec![Some(false), None, Some(true), Some(true), Some(true)]
)
);
}

#[test]
fn test_eq_dyn_bool_scalar() {
Expand Down

0 comments on commit 00a622f

Please sign in to comment.