Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Add support for binary contains
Browse files Browse the repository at this point in the history
  • Loading branch information
zhyass committed Aug 31, 2021
1 parent 9e5aef9 commit df546be
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/array/binary/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ impl<O: Offset> MutableArray for MutableBinaryArray<O> {

/// Returns the logical [`DataType`] of this array, selected by whether the offset type `O`
/// reports itself as large through `O::is_large()`.
// NOTE(review): this block is a rendered diff hunk whose +/- markers were stripped. The file
// header reports "2 additions & 2 deletions", so the `Utf8` lines are presumably the removed
// side and the `Binary` lines the added side — confirm against the repository.
fn data_type(&self) -> &DataType {
if O::is_large() {
&DataType::LargeUtf8
&DataType::LargeBinary
} else {
&DataType::Utf8
&DataType::Binary
}
}

Expand Down
81 changes: 80 additions & 1 deletion src/compute/contains.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use crate::types::NativeType;
use crate::{
array::{Array, BooleanArray, ListArray, Offset, PrimitiveArray, Utf8Array},
array::{Array, BooleanArray, ListArray, Offset, PrimitiveArray, Utf8Array, BinaryArray},
bitmap::Bitmap,
};
use crate::{
Expand Down Expand Up @@ -96,6 +96,40 @@ where
Ok(BooleanArray::from_data(values, validity))
}

/// Checks, slot by slot, whether each list of a [`ListArray`] contains the value found at the
/// same position of a [`BinaryArray`].
///
/// The validity of the result is the combination of the validities of `list` and `values`.
/// Null items inside a list never compare equal to the searched value.
///
/// # Errors
/// Returns [`ArrowError::InvalidArgumentError`] when the two arrays have different lengths, or
/// when the data type of the list's inner array differs from the data type of `values`.
fn contains_binary<O, OO>(list: &ListArray<O>, values: &BinaryArray<OO>) -> Result<BooleanArray>
where
    O: Offset,
    OO: Offset,
{
    if list.len() != values.len() {
        return Err(ArrowError::InvalidArgumentError(
            "Contains requires arrays of the same length".to_string(),
        ));
    }
    if list.values().data_type() != values.data_type() {
        return Err(ArrowError::InvalidArgumentError(
            "Contains requires the inner array to be of the same logical type".to_string(),
        ));
    }

    let validity = combine_validities(list.validity(), values.validity());

    let matches = list
        .iter()
        .zip(values.iter())
        .map(|(row, needle)| match (row, needle) {
            (Some(row), Some(needle)) => {
                // The inner array was checked above to have the same data type as `values`,
                // so it is expected to be a `BinaryArray<OO>`.
                let row = row.as_any().downcast_ref::<BinaryArray<OO>>().unwrap();
                row.iter()
                    .any(|item| item.map_or(false, |item| item == needle))
            }
            // A null list or a null value: validity takes care of this slot.
            _ => false,
        });
    let matches = Bitmap::from_trusted_len_iter(matches);

    Ok(BooleanArray::from_data(matches, validity))
}

macro_rules! primitive {
($list:expr, $values:expr, $l_ty:ty, $r_ty:ty) => {{
let list = $list.as_any().downcast_ref::<ListArray<$l_ty>>().unwrap();
Expand Down Expand Up @@ -132,6 +166,26 @@ pub fn contains(list: &dyn Array, values: &dyn Array) -> Result<BooleanArray> {
let values = values.as_any().downcast_ref::<Utf8Array<i32>>().unwrap();
contains_utf8(list, values)
}
(DataType::List(_), DataType::Binary) => {
let list = list.as_any().downcast_ref::<ListArray<i32>>().unwrap();
let values = values.as_any().downcast_ref::<BinaryArray<i32>>().unwrap();
contains_binary(list, values)
}
(DataType::List(_), DataType::LargeBinary) => {
let list = list.as_any().downcast_ref::<ListArray<i32>>().unwrap();
let values = values.as_any().downcast_ref::<BinaryArray<i64>>().unwrap();
contains_binary(list, values)
}
(DataType::LargeList(_), DataType::LargeBinary) => {
let list = list.as_any().downcast_ref::<ListArray<i64>>().unwrap();
let values = values.as_any().downcast_ref::<BinaryArray<i64>>().unwrap();
contains_binary(list, values)
}
(DataType::LargeList(_), DataType::Binary) => {
let list = list.as_any().downcast_ref::<ListArray<i64>>().unwrap();
let values = values.as_any().downcast_ref::<BinaryArray<i32>>().unwrap();
contains_binary(list, values)
}
(DataType::List(_), DataType::Int8) => primitive!(list, values, i32, i8),
(DataType::List(_), DataType::Int16) => primitive!(list, values, i32, i16),
(DataType::List(_), DataType::Int32) => primitive!(list, values, i32, i32),
Expand Down Expand Up @@ -196,4 +250,29 @@ mod tests {

assert_eq!(result, expected);
}

/// `contains` over a list-of-binary column: a hit, a miss, a null needle and a null list.
#[test]
fn test_contains_binary() {
    // Needles, one per list slot.
    let needles = BinaryArray::<i32>::from(&[Some(b"a"), Some(b"c"), None, Some(b"a")]);

    // Three identical lists `[a, b, null]` followed by a null list.
    let rows = vec![
        Some(vec![Some(b"a"), Some(b"b"), None]),
        Some(vec![Some(b"a"), Some(b"b"), None]),
        Some(vec![Some(b"a"), Some(b"b"), None]),
        None,
    ];
    let mut builder = MutableListArray::<i32, MutableBinaryArray<i32>>::new();
    builder.try_extend(rows).unwrap();
    let haystack: ListArray<i32> = builder.into();

    let actual = contains(&haystack, &needles).unwrap();

    // Slots 2 and 3 are null because the needle, respectively the list, is null.
    let wanted = BooleanArray::from(vec![Some(true), Some(false), None, None]);
    assert_eq!(actual, wanted);
}
}

0 comments on commit df546be

Please sign in to comment.