diff --git a/src/array/binary/mutable.rs b/src/array/binary/mutable.rs index fc9adba37f..b52bfa8c83 100644 --- a/src/array/binary/mutable.rs +++ b/src/array/binary/mutable.rs @@ -144,9 +144,9 @@ impl MutableArray for MutableBinaryArray { fn data_type(&self) -> &DataType { if O::is_large() { - &DataType::LargeUtf8 + &DataType::LargeBinary } else { - &DataType::Utf8 + &DataType::Binary } } diff --git a/src/compute/contains.rs b/src/compute/contains.rs index d9452da03f..bdf82bf5bf 100644 --- a/src/compute/contains.rs +++ b/src/compute/contains.rs @@ -17,7 +17,7 @@ use crate::types::NativeType; use crate::{ - array::{Array, BooleanArray, ListArray, Offset, PrimitiveArray, Utf8Array}, + array::{Array, BooleanArray, ListArray, Offset, PrimitiveArray, Utf8Array, BinaryArray}, bitmap::Bitmap, }; use crate::{ @@ -96,6 +96,40 @@ where Ok(BooleanArray::from_data(values, validity)) } +/// Checks if a [`GenericListArray`] contains a value in the [`BinaryArray`] +fn contains_binary(list: &ListArray, values: &BinaryArray) -> Result +where + O: Offset, + OO: Offset, +{ + if list.len() != values.len() { + return Err(ArrowError::InvalidArgumentError( + "Contains requires arrays of the same length".to_string(), + )); + } + if list.values().data_type() != values.data_type() { + return Err(ArrowError::InvalidArgumentError( + "Contains requires the inner array to be of the same logical type".to_string(), + )); + } + + let validity = combine_validities(list.validity(), values.validity()); + + let values = list.iter().zip(values.iter()).map(|(list, values)| { + if list.is_none() | values.is_none() { + // validity takes care of this + return false; + }; + let list = list.unwrap(); + let list = list.as_any().downcast_ref::>().unwrap(); + let values = values.unwrap(); + list.iter().any(|x| x.map(|x| x == values).unwrap_or(false)) + }); + let values = Bitmap::from_trusted_len_iter(values); + + Ok(BooleanArray::from_data(values, validity)) +} + macro_rules! primitive { ($list:expr, $values:expr, $l_ty:ty, $r_ty:ty) => {{ let list = $list.as_any().downcast_ref::>().unwrap(); @@ -132,6 +166,26 @@ pub fn contains(list: &dyn Array, values: &dyn Array) -> Result { let values = values.as_any().downcast_ref::>().unwrap(); contains_utf8(list, values) } + (DataType::List(_), DataType::Binary) => { + let list = list.as_any().downcast_ref::>().unwrap(); + let values = values.as_any().downcast_ref::>().unwrap(); + contains_binary(list, values) + } + (DataType::List(_), DataType::LargeBinary) => { + let list = list.as_any().downcast_ref::>().unwrap(); + let values = values.as_any().downcast_ref::>().unwrap(); + contains_binary(list, values) + } + (DataType::LargeList(_), DataType::LargeBinary) => { + let list = list.as_any().downcast_ref::>().unwrap(); + let values = values.as_any().downcast_ref::>().unwrap(); + contains_binary(list, values) + } + (DataType::LargeList(_), DataType::Binary) => { + let list = list.as_any().downcast_ref::>().unwrap(); + let values = values.as_any().downcast_ref::>().unwrap(); + contains_binary(list, values) + } (DataType::List(_), DataType::Int8) => primitive!(list, values, i32, i8), (DataType::List(_), DataType::Int16) => primitive!(list, values, i32, i16), (DataType::List(_), DataType::Int32) => primitive!(list, values, i32, i32), @@ -196,4 +250,29 @@ mod tests { assert_eq!(result, expected); } + + #[test] + fn test_contains_binary() { + let data = vec![ + Some(vec![Some(b"a"), Some(b"b"), None]), + Some(vec![Some(b"a"), Some(b"b"), None]), + Some(vec![Some(b"a"), Some(b"b"), None]), + None, + ]; + let values = BinaryArray::::from(&[Some(b"a"), Some(b"c"), None, Some(b"a")]); + let expected = BooleanArray::from(vec![ + Some(true), + Some(false), + None, + None + ]); + + let mut a = MutableListArray::>::new(); + a.try_extend(data).unwrap(); + let a: ListArray = a.into(); + + let result = contains(&a, &values).unwrap(); + + assert_eq!(result, expected); + } }