From 23682f033d6e473a8292e441414795e843b47188 Mon Sep 17 00:00:00 2001 From: zhyass <34016424+zhyass@users.noreply.github.com> Date: Thu, 26 Aug 2021 22:16:03 +0800 Subject: [PATCH] Add support for binary compute (#345) [summary] 1. comparison 2. sort 3. contains 4. like 5. cast 6. substring 7. min_max 8. cargo fmt --- src/array/binary/from.rs | 110 +++++----- src/array/binary/mutable.rs | 280 ++++++++++++++++++++++++- src/array/growable/binary.rs | 7 +- src/array/ord.rs | 8 + src/compute/aggregate/min_max.rs | 77 ++++++- src/compute/arithmetics/decimal/sub.rs | 96 +++------ src/compute/cast/binary_to.rs | 46 +++- src/compute/cast/boolean_to.rs | 13 +- src/compute/cast/mod.rs | 102 ++++++++- src/compute/cast/primitive_to.rs | 23 +- src/compute/comparison/binary.rs | 251 ++++++++++++++++++++++ src/compute/comparison/mod.rs | 26 ++- src/compute/contains.rs | 81 ++++++- src/compute/like.rs | 189 +++++++++++++++++ src/compute/merge_sort/mod.rs | 33 ++- src/compute/sort/binary.rs | 15 ++ src/compute/sort/mod.rs | 17 +- src/compute/substring.rs | 226 +++++++++++++++++++- src/io/parquet/read/binary/mod.rs | 2 +- tests/it/array/binary/mod.rs | 16 ++ tests/it/array/binary/mutable.rs | 11 + tests/it/compute/cast.rs | 40 ++++ tests/it/io/json/write.rs | 2 +- 23 files changed, 1512 insertions(+), 159 deletions(-) create mode 100644 src/compute/comparison/binary.rs create mode 100644 src/compute/sort/binary.rs create mode 100644 tests/it/array/binary/mutable.rs diff --git a/src/array/binary/from.rs b/src/array/binary/from.rs index 60fda084131..2154a22c16d 100644 --- a/src/array/binary/from.rs +++ b/src/array/binary/from.rs @@ -1,18 +1,13 @@ use std::iter::FromIterator; -use crate::{ - array::Offset, - bitmap::{Bitmap, MutableBitmap}, - buffer::{Buffer, MutableBuffer}, - trusted_len::TrustedLen, -}; +use crate::{array::Offset, trusted_len::TrustedLen}; use super::{BinaryArray, MutableBinaryArray}; impl BinaryArray { /// Creates a new [`BinaryArray`] from slices of `&[u8]`. pub fn from_slice, P: AsRef<[T]>>(slice: P) -> Self { - Self::from_iter(slice.as_ref().iter().map(Some)) + Self::from_trusted_len_values_iter(slice.as_ref().iter()) } /// Creates a new [`BinaryArray`] from a slice of optional `&[u8]`. @@ -23,15 +18,15 @@ impl BinaryArray { /// Creates a [`BinaryArray`] from an iterator of trusted length. #[inline] - pub fn from_trusted_len_iter(iterator: I) -> Self - where - P: AsRef<[u8]>, - I: TrustedLen>, - { - // soundness: I is `TrustedLen` - let (validity, offsets, values) = unsafe { trusted_len_unzip(iterator) }; + pub fn from_trusted_len_values_iter, I: TrustedLen>( + iterator: I, + ) -> Self { + MutableBinaryArray::::from_trusted_len_values_iter(iterator).into() + } - Self::from_data(Self::default_data_type(), offsets, values, validity) + /// Creates a new [`BinaryArray`] from a [`Iterator`] of `&str`. + pub fn from_iter_values, I: Iterator>(iterator: I) -> Self { + MutableBinaryArray::::from_iter_values(iterator).into() } } @@ -42,49 +37,52 @@ impl> FromIterator> for BinaryArray { } } -/// Creates [`Bitmap`] and two [`Buffer`]s from an iterator of `Option`. -/// The first buffer corresponds to a offset buffer, the second one -/// corresponds to a values buffer. -/// # Safety -/// The caller must ensure that `iterator` is `TrustedLen`. -#[inline] -pub unsafe fn trusted_len_unzip(iterator: I) -> (Option, Buffer, Buffer) -where - O: Offset, - P: AsRef<[u8]>, - I: Iterator>, -{ - let (_, upper) = iterator.size_hint(); - let len = upper.expect("trusted_len_unzip requires an upper limit"); - - let mut null = MutableBitmap::with_capacity(len); - let mut offsets = MutableBuffer::::with_capacity(len + 1); - let mut values = MutableBuffer::::new(); +impl BinaryArray { + /// Creates a [`BinaryArray`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: Iterator>, + { + MutableBinaryArray::::from_trusted_len_iter_unchecked(iterator).into() + } - let mut length = O::default(); - let mut dst = offsets.as_mut_ptr(); - std::ptr::write(dst, length); - dst = dst.add(1); - for item in iterator { - if let Some(item) = item { - null.push(true); - let s = item.as_ref(); - length += O::from_usize(s.len()).unwrap(); - values.extend_from_slice(s); - } else { - null.push(false); - values.extend_from_slice(b""); - }; + /// Creates a [`BinaryArray`] from an iterator of trusted length. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: TrustedLen>, + { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } - std::ptr::write(dst, length); - dst = dst.add(1); + /// Creates a [`BinaryArray`] from an falible iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked(iterator: I) -> Result + where + P: AsRef<[u8]>, + I: IntoIterator, E>>, + { + MutableBinaryArray::::try_from_trusted_len_iter_unchecked(iterator).map(|x| x.into()) } - assert_eq!( - dst.offset_from(offsets.as_ptr()) as usize, - len + 1, - "Trusted iterator length was not accurately reported" - ); - offsets.set_len(len + 1); - (null.into(), offsets.into(), values.into()) + /// Creates a [`BinaryArray`] from an fallible iterator of trusted length. + #[inline] + pub fn try_from_trusted_len_iter(iter: I) -> Result + where + P: AsRef<[u8]>, + I: TrustedLen, E>>, + { + // soundness: I: TrustedLen + unsafe { Self::try_from_trusted_len_iter_unchecked(iter) } + } } diff --git a/src/array/binary/mutable.rs b/src/array/binary/mutable.rs index 52ae198fa7e..f353af63db3 100644 --- a/src/array/binary/mutable.rs +++ b/src/array/binary/mutable.rs @@ -1,11 +1,12 @@ use std::{iter::FromIterator, sync::Arc}; use crate::{ - array::{Array, MutableArray, Offset, TryExtend, TryPush}, + array::{specification::check_offsets, Array, MutableArray, Offset, TryExtend, TryPush}, bitmap::MutableBitmap, buffer::MutableBuffer, datatypes::DataType, error::{ArrowError, Result}, + trusted_len::TrustedLen, }; use super::BinaryArray; @@ -47,6 +48,36 @@ impl MutableBinaryArray { Self::with_capacity(0) } + /// The canonical method to create a [`MutableBinaryArray`] out of low-end APIs. + /// # Panics + /// This function panics iff: + /// * The `offsets` and `values` are inconsistent + /// * The validity is not `None` and its length is different from `offsets`'s length minus one. + pub fn from_data( + data_type: DataType, + offsets: MutableBuffer, + values: MutableBuffer, + validity: Option, + ) -> Self { + check_offsets(&offsets, values.len()); + if let Some(ref validity) = validity { + assert_eq!(offsets.len() - 1, validity.len()); + } + if data_type.to_physical_type() != Self::default_data_type().to_physical_type() { + panic!("MutableBinaryArray can only be initialized with DataType::Binary or DataType::LargeBinary") + } + Self { + data_type, + offsets, + values, + validity, + } + } + + fn default_data_type() -> DataType { + BinaryArray::::default_data_type() + } + /// Creates a new [`MutableBinaryArray`] with capacity for `capacity` values. /// # Implementation /// This does not allocate the validity. @@ -148,6 +179,90 @@ impl> FromIterator> for MutableBinaryArray MutableBinaryArray { + /// Creates a [`MutableBinaryArray`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: Iterator>, + { + let (validity, offsets, values) = trusted_len_unzip(iterator); + + // soundness: P is `str` + Self::from_data(Self::default_data_type(), offsets, values, validity) + } + + /// Creates a [`MutableBinaryArray`] from an iterator of trusted length. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: TrustedLen>, + { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a new [`BinaryArray`] from a [`TrustedLen`] of `&str`. + #[inline] + pub fn from_trusted_len_values_iter, I: TrustedLen>( + iterator: I, + ) -> Self { + // soundness: I is `TrustedLen` + let (offsets, values) = unsafe { trusted_len_values_iter(iterator) }; + // soundness: T is AsRef<[u8]> + Self::from_data(Self::default_data_type(), offsets, values, None) + } + + /// Creates a [`MutableBinaryArray`] from an falible iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked( + iterator: I, + ) -> std::result::Result + where + P: AsRef<[u8]>, + I: IntoIterator, E>>, + { + let iterator = iterator.into_iter(); + + // soundness: assumed trusted len + let (validity, offsets, values) = try_trusted_len_unzip(iterator)?; + + // soundness: P is `str` + Ok(Self::from_data( + Self::default_data_type(), + offsets, + values, + validity, + )) + } + + /// Creates a [`MutableBinaryArray`] from an falible iterator of trusted length. + #[inline] + pub fn try_from_trusted_len_iter(iterator: I) -> std::result::Result + where + P: AsRef<[u8]>, + I: TrustedLen, E>>, + { + // soundness: I: TrustedLen + unsafe { Self::try_from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a new [`MutableBinaryArray`] from a [`Iterator`] of `&[u8]`. + pub fn from_iter_values, I: Iterator>(iterator: I) -> Self { + let (offsets, values) = values_iter(iterator); + // soundness: T: AsRef<[u8]> + Self::from_data(Self::default_data_type(), offsets, values, None) + } +} + impl> Extend> for MutableBinaryArray { fn extend>>(&mut self, iter: I) { self.try_extend(iter).unwrap(); @@ -191,3 +306,166 @@ impl> TryPush> for MutableBinaryArray { Ok(()) } } + +/// Creates [`MutableBitmap`] and two [`MutableBuffer`]s from an iterator of `Option`. +/// The first buffer corresponds to a offset buffer, the second one +/// corresponds to a values buffer. +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +unsafe fn trusted_len_unzip( + iterator: I, +) -> (Option, MutableBuffer, MutableBuffer) +where + O: Offset, + P: AsRef<[u8]>, + I: Iterator>, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut null = MutableBitmap::with_capacity(len); + let mut offsets = MutableBuffer::::with_capacity(len + 1); + let mut values = MutableBuffer::::new(); + + let mut length = O::default(); + let mut dst = offsets.as_mut_ptr(); + std::ptr::write(dst, length); + dst = dst.add(1); + for item in iterator { + if let Some(item) = item { + null.push(true); + let s = item.as_ref(); + length += O::from_usize(s.len()).unwrap(); + values.extend_from_slice(s); + } else { + null.push(false); + values.extend_from_slice(b""); + }; + + std::ptr::write(dst, length); + dst = dst.add(1); + } + assert_eq!( + dst.offset_from(offsets.as_ptr()) as usize, + len + 1, + "Trusted iterator length was not accurately reported" + ); + offsets.set_len(len + 1); + + (null.into(), offsets, values) +} + +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +#[allow(clippy::type_complexity)] +pub(crate) unsafe fn try_trusted_len_unzip( + iterator: I, +) -> std::result::Result<(Option, MutableBuffer, MutableBuffer), E> +where + O: Offset, + P: AsRef<[u8]>, + I: Iterator, E>>, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut null = MutableBitmap::with_capacity(len); + let mut offsets = MutableBuffer::::with_capacity(len + 1); + let mut values = MutableBuffer::::new(); + + let mut length = O::default(); + let mut dst = offsets.as_mut_ptr(); + std::ptr::write(dst, length); + dst = dst.add(1); + for item in iterator { + if let Some(item) = item? { + null.push(true); + let s = item.as_ref(); + length += O::from_usize(s.len()).unwrap(); + values.extend_from_slice(s); + } else { + null.push(false); + }; + + std::ptr::write(dst, length); + dst = dst.add(1); + } + assert_eq!( + dst.offset_from(offsets.as_ptr()) as usize, + len + 1, + "Trusted iterator length was not accurately reported" + ); + offsets.set_len(len + 1); + + Ok((null.into(), offsets, values)) +} + +/// Creates two [`Buffer`]s from an iterator of `&[u8]`. +/// The first buffer corresponds to a offset buffer, the second to a values buffer. +/// # Safety +/// The caller must ensure that `iterator` is [`TrustedLen`]. +#[inline] +pub(crate) unsafe fn trusted_len_values_iter( + iterator: I, +) -> (MutableBuffer, MutableBuffer) +where + O: Offset, + P: AsRef<[u8]>, + I: Iterator, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut offsets = MutableBuffer::::with_capacity(len + 1); + let mut values = MutableBuffer::::new(); + + let mut length = O::default(); + let mut dst = offsets.as_mut_ptr(); + std::ptr::write(dst, length); + dst = dst.add(1); + for item in iterator { + let s = item.as_ref(); + length += O::from_usize(s.len()).unwrap(); + values.extend_from_slice(s); + + std::ptr::write(dst, length); + dst = dst.add(1); + } + assert_eq!( + dst.offset_from(offsets.as_ptr()) as usize, + len + 1, + "Trusted iterator length was not accurately reported" + ); + offsets.set_len(len + 1); + + (offsets, values) +} + +/// Creates two [`MutableBuffer`]s from an iterator of `&[u8]`. +/// The first buffer corresponds to a offset buffer, the second to a values buffer. +#[inline] +fn values_iter(iterator: I) -> (MutableBuffer, MutableBuffer) +where + O: Offset, + P: AsRef<[u8]>, + I: Iterator, +{ + let (lower, _) = iterator.size_hint(); + + let mut offsets = MutableBuffer::::with_capacity(lower + 1); + let mut values = MutableBuffer::::new(); + + let mut length = O::default(); + offsets.push(length); + + for item in iterator { + let s = item.as_ref(); + length += O::from_usize(s.len()).unwrap(); + values.extend_from_slice(s); + + offsets.push(length) + } + (offsets, values) +} diff --git a/src/array/growable/binary.rs b/src/array/growable/binary.rs index f51dd90a5bd..caca72a272b 100644 --- a/src/array/growable/binary.rs +++ b/src/array/growable/binary.rs @@ -100,6 +100,11 @@ impl<'a, O: Offset> Growable<'a> for GrowableBinary<'a, O> { impl<'a, O: Offset> From> for BinaryArray { fn from(val: GrowableBinary<'a, O>) -> Self { - BinaryArray::::from_data(val.data_type, val.offsets.into(), val.values.into(), val.validity.into()) + BinaryArray::::from_data( + val.data_type, + val.offsets.into(), + val.values.into(), + val.validity.into(), + ) } } diff --git a/src/array/ord.rs b/src/array/ord.rs index ea9d0d69a36..3eb1a98ae53 100644 --- a/src/array/ord.rs +++ b/src/array/ord.rs @@ -92,6 +92,12 @@ fn compare_string<'a, O: Offset>(left: &'a dyn Array, right: &'a dyn Array) -> D Box::new(move |i, j| left.value(i).cmp(right.value(j))) } +fn compare_binary<'a, O: Offset>(left: &'a dyn Array, right: &'a dyn Array) -> DynComparator<'a> { + let left = left.as_any().downcast_ref::>().unwrap(); + let right = right.as_any().downcast_ref::>().unwrap(); + Box::new(move |i, j| left.value(i).cmp(right.value(j))) +} + fn compare_dict<'a, K>( left: &'a DictionaryArray, right: &'a DictionaryArray, @@ -178,6 +184,8 @@ pub fn build_compare<'a>(left: &'a dyn Array, right: &'a dyn Array) -> Result compare_f64(left, right), (Utf8, Utf8) => compare_string::(left, right), (LargeUtf8, LargeUtf8) => compare_string::(left, right), + (Binary, Binary) => compare_binary::(left, right), + (LargeBinary, LargeBinary) => compare_binary::(left, right), (Dictionary(key_type_lhs, _), Dictionary(key_type_rhs, _)) => { match (key_type_lhs.as_ref(), key_type_rhs.as_ref()) { (UInt8, UInt8) => dyn_dict!(u8, left, right), diff --git a/src/compute/aggregate/min_max.rs b/src/compute/aggregate/min_max.rs index 858a10e6eed..77f7be6c32b 100644 --- a/src/compute/aggregate/min_max.rs +++ b/src/compute/aggregate/min_max.rs @@ -5,7 +5,7 @@ use crate::scalar::*; use crate::types::simd::*; use crate::types::NativeType; use crate::{ - array::{Array, BooleanArray, Offset, PrimitiveArray, Utf8Array}, + array::{Array, BinaryArray, BooleanArray, Offset, PrimitiveArray, Utf8Array}, bitmap::Bitmap, }; @@ -20,6 +20,42 @@ pub trait SimdOrd { fn new_max() -> Self; } +/// Helper macro to perform min/max of binarys. +fn min_max_binary bool>( + array: &BinaryArray, + cmp: F, +) -> Option<&[u8]> { + let null_count = array.null_count(); + + if null_count == array.len() || array.len() == 0 { + return None; + } + let mut n; + if let Some(validity) = array.validity() { + n = "".as_bytes(); + let mut has_value = false; + + for i in 0..array.len() { + let item = array.value(i); + if validity.get_bit(i) && (!has_value || cmp(n, item)) { + has_value = true; + n = item; + } + } + } else { + // array.len() == 0 checked above + n = unsafe { array.value_unchecked(0) }; + for i in 1..array.len() { + // loop is up to `len`. + let item = unsafe { array.value_unchecked(i) }; + if cmp(n, item) { + n = item; + } + } + } + Some(n) +} + /// Helper macro to perform min/max of strings fn min_max_string bool>( array: &Utf8Array, @@ -224,6 +260,16 @@ where }) } +/// Returns the maximum value in the binary array, according to the natural order. +pub fn max_binary(array: &BinaryArray) -> Option<&[u8]> { + min_max_binary(array, |a, b| a < b) +} + +/// Returns the minimum value in the binary array, according to the natural order. +pub fn min_binary(array: &BinaryArray) -> Option<&[u8]> { + min_max_binary(array, |a, b| a > b) +} + /// Returns the maximum value in the string array, according to the natural order. pub fn max_string(array: &Utf8Array) -> Option<&str> { min_max_string(array, |a, b| a < b) @@ -329,6 +375,10 @@ pub fn max(array: &dyn Array) -> Result> { DataType::Float64 => dyn_primitive!(f64, array, max_primitive), DataType::Utf8 => dyn_generic!(Utf8Array, Utf8Scalar, array, max_string), DataType::LargeUtf8 => dyn_generic!(Utf8Array, Utf8Scalar, array, max_string), + DataType::Binary => dyn_generic!(BinaryArray, BinaryScalar, array, max_binary), + DataType::LargeBinary => { + dyn_generic!(BinaryArray, BinaryScalar, array, max_binary) + } _ => { return Err(ArrowError::InvalidArgumentError(format!( "The `max` operator does not support type `{}`", @@ -363,6 +413,10 @@ pub fn min(array: &dyn Array) -> Result> { DataType::Float64 => dyn_primitive!(f64, array, min_primitive), DataType::Utf8 => dyn_generic!(Utf8Array, Utf8Scalar, array, min_string), DataType::LargeUtf8 => dyn_generic!(Utf8Array, Utf8Scalar, array, min_string), + DataType::Binary => dyn_generic!(BinaryArray, BinaryScalar, array, min_binary), + DataType::LargeBinary => { + dyn_generic!(BinaryArray, BinaryScalar, array, min_binary) + } _ => { return Err(ArrowError::InvalidArgumentError(format!( "The `max` operator does not support type `{}`", @@ -538,4 +592,25 @@ mod tests { assert_eq!(Some(true), min_boolean(&a)); assert_eq!(Some(true), max_boolean(&a)); } + + #[test] + fn test_binary_min_max_with_nulls() { + let a = BinaryArray::::from(&[Some(b"b"), None, None, Some(b"a"), Some(b"c")]); + assert_eq!("a".as_bytes(), min_binary(&a).unwrap()); + assert_eq!("c".as_bytes(), max_binary(&a).unwrap()); + } + + #[test] + fn test_binary_min_max_all_nulls() { + let a = BinaryArray::::from(&[None::<&[u8]>, None]); + assert_eq!(None, min_binary(&a)); + assert_eq!(None, max_binary(&a)); + } + + #[test] + fn test_binary_min_max_1() { + let a = BinaryArray::::from(&[None, None, Some(b"b"), Some(b"a")]); + assert_eq!(Some("a".as_bytes()), min_binary(&a)); + assert_eq!(Some("b".as_bytes()), max_binary(&a)); + } } diff --git a/src/compute/arithmetics/decimal/sub.rs b/src/compute/arithmetics/decimal/sub.rs index eba1f92bf94..17d7668dae4 100644 --- a/src/compute/arithmetics/decimal/sub.rs +++ b/src/compute/arithmetics/decimal/sub.rs @@ -314,30 +314,16 @@ mod tests { #[test] fn test_subtract_normal() { - let a = PrimitiveArray::from([ - Some(11111i128), - Some(22200i128), - None, - Some(40000i128), - ]) - .to(DataType::Decimal(5, 2)); + let a = PrimitiveArray::from([Some(11111i128), Some(22200i128), None, Some(40000i128)]) + .to(DataType::Decimal(5, 2)); - let b = PrimitiveArray::from([ - Some(22222i128), - Some(11100i128), - None, - Some(11100i128), - ]) - .to(DataType::Decimal(5, 2)); + let b = PrimitiveArray::from([Some(22222i128), Some(11100i128), None, Some(11100i128)]) + .to(DataType::Decimal(5, 2)); let result = sub(&a, &b).unwrap(); - let expected = PrimitiveArray::from([ - Some(-11111i128), - Some(11100i128), - None, - Some(28900i128), - ]) - .to(DataType::Decimal(5, 2)); + let expected = + PrimitiveArray::from([Some(-11111i128), Some(11100i128), None, Some(28900i128)]) + .to(DataType::Decimal(5, 2)); assert_eq!(result, expected); @@ -367,30 +353,16 @@ mod tests { #[test] fn test_subtract_saturating() { - let a = PrimitiveArray::from([ - Some(11111i128), - Some(22200i128), - None, - Some(40000i128), - ]) - .to(DataType::Decimal(5, 2)); + let a = PrimitiveArray::from([Some(11111i128), Some(22200i128), None, Some(40000i128)]) + .to(DataType::Decimal(5, 2)); - let b = PrimitiveArray::from([ - Some(22222i128), - Some(11100i128), - None, - Some(11100i128), - ]) - .to(DataType::Decimal(5, 2)); + let b = PrimitiveArray::from([Some(22222i128), Some(11100i128), None, Some(11100i128)]) + .to(DataType::Decimal(5, 2)); let result = saturating_sub(&a, &b).unwrap(); - let expected = PrimitiveArray::from([ - Some(-11111i128), - Some(11100i128), - None, - Some(28900i128), - ]) - .to(DataType::Decimal(5, 2)); + let expected = + PrimitiveArray::from([Some(-11111i128), Some(11100i128), None, Some(28900i128)]) + .to(DataType::Decimal(5, 2)); assert_eq!(result, expected); @@ -435,30 +407,16 @@ mod tests { #[test] fn test_subtract_checked() { - let a = PrimitiveArray::from([ - Some(11111i128), - Some(22200i128), - None, - Some(40000i128), - ]) - .to(DataType::Decimal(5, 2)); + let a = PrimitiveArray::from([Some(11111i128), Some(22200i128), None, Some(40000i128)]) + .to(DataType::Decimal(5, 2)); - let b = PrimitiveArray::from([ - Some(22222i128), - Some(11100i128), - None, - Some(11100i128), - ]) - .to(DataType::Decimal(5, 2)); + let b = PrimitiveArray::from([Some(22222i128), Some(11100i128), None, Some(11100i128)]) + .to(DataType::Decimal(5, 2)); let result = checked_sub(&a, &b).unwrap(); - let expected = PrimitiveArray::from([ - Some(-11111i128), - Some(11100i128), - None, - Some(28900i128), - ]) - .to(DataType::Decimal(5, 2)); + let expected = + PrimitiveArray::from([Some(-11111i128), Some(11100i128), None, Some(28900i128)]) + .to(DataType::Decimal(5, 2)); assert_eq!(result, expected); @@ -469,8 +427,7 @@ mod tests { #[test] fn test_subtract_checked_overflow() { - let a = - PrimitiveArray::from([Some(4i128), Some(-99999i128)]).to(DataType::Decimal(5, 2)); + let a = PrimitiveArray::from([Some(4i128), Some(-99999i128)]).to(DataType::Decimal(5, 2)); let b = PrimitiveArray::from([Some(2i128), Some(1i128)]).to(DataType::Decimal(5, 2)); let result = checked_sub(&a, &b).unwrap(); let expected = PrimitiveArray::from([Some(2i128), None]).to(DataType::Decimal(5, 2)); @@ -487,8 +444,7 @@ mod tests { let b = PrimitiveArray::from([Some(11111_11i128)]).to(DataType::Decimal(7, 2)); let result = adaptive_sub(&a, &b).unwrap(); - let expected = - PrimitiveArray::from([Some(-11099_9989i128)]).to(DataType::Decimal(9, 4)); + let expected = PrimitiveArray::from([Some(-11099_9989i128)]).to(DataType::Decimal(9, 4)); assert_eq!(result, expected); assert_eq!(result.data_type(), &DataType::Decimal(9, 4)); @@ -501,8 +457,7 @@ mod tests { let b = PrimitiveArray::from([Some(1111i128)]).to(DataType::Decimal(5, 4)); let result = adaptive_sub(&a, &b).unwrap(); - let expected = - PrimitiveArray::from([Some(11110_8889i128)]).to(DataType::Decimal(9, 4)); + let expected = PrimitiveArray::from([Some(11110_8889i128)]).to(DataType::Decimal(9, 4)); assert_eq!(result, expected); assert_eq!(result.data_type(), &DataType::Decimal(9, 4)); @@ -515,8 +470,7 @@ mod tests { let b = PrimitiveArray::from([Some(11111_111i128)]).to(DataType::Decimal(8, 3)); let result = adaptive_sub(&a, &b).unwrap(); - let expected = - PrimitiveArray::from([Some(-00000_001i128)]).to(DataType::Decimal(8, 3)); + let expected = PrimitiveArray::from([Some(-00000_001i128)]).to(DataType::Decimal(8, 3)); assert_eq!(result, expected); assert_eq!(result.data_type(), &DataType::Decimal(8, 3)); diff --git a/src/compute/cast/binary_to.rs b/src/compute/cast/binary_to.rs index 79edc2af5da..3296ca3d715 100644 --- a/src/compute/cast/binary_to.rs +++ b/src/compute/cast/binary_to.rs @@ -1,8 +1,7 @@ use std::convert::TryFrom; -use crate::datatypes::DataType; use crate::error::{ArrowError, Result}; -use crate::{array::*, buffer::Buffer}; +use crate::{array::*, buffer::Buffer, datatypes::DataType, types::NativeType}; pub fn binary_to_large_binary(from: &BinaryArray, to_data_type: DataType) -> BinaryArray { let values = from.values().clone(); @@ -28,3 +27,46 @@ pub fn binary_large_to_binary( from.validity().clone(), )) } + +/// Casts a [`BinaryArray`] to a [`PrimitiveArray`], making any uncastable value a Null. +pub fn binary_to_primitive(from: &BinaryArray, to: &DataType) -> PrimitiveArray +where + T: NativeType + lexical_core::FromLexical, +{ + let iter = from + .iter() + .map(|x| x.and_then::(|x| lexical_core::parse(x).ok())); + + PrimitiveArray::::from_trusted_len_iter(iter).to(to.clone()) +} + +pub(super) fn binary_to_primitive_dyn( + from: &dyn Array, + to: &DataType, +) -> Result> +where + T: NativeType + lexical_core::FromLexical, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(binary_to_primitive::(from, to))) +} + +/// Cast [`BinaryArray`] to [`DictionaryArray`], also known as packing. +/// # Errors +/// This function errors if the maximum key is smaller than the number of distinct elements +/// in the array. +pub fn binary_to_dictionary( + from: &BinaryArray, +) -> Result> { + let mut array = MutableDictionaryArray::>::new(); + array.try_extend(from.iter())?; + + Ok(array.into()) +} + +pub(super) fn binary_to_dictionary_dyn( + from: &dyn Array, +) -> Result> { + let values = from.as_any().downcast_ref().unwrap(); + binary_to_dictionary::(values).map(|x| Box::new(x) as Box) +} diff --git a/src/compute/cast/boolean_to.rs b/src/compute/cast/boolean_to.rs index 2d6c2437053..b6206527cc1 100644 --- a/src/compute/cast/boolean_to.rs +++ b/src/compute/cast/boolean_to.rs @@ -4,7 +4,7 @@ use crate::{ types::{NativeType, NaturalDataType}, }; use crate::{ - array::{Offset, Utf8Array}, + array::{BinaryArray, Offset, Utf8Array}, error::Result, }; @@ -40,3 +40,14 @@ pub(super) fn boolean_to_utf8_dyn(array: &dyn Array) -> Result(array))) } + +/// Casts the [`BooleanArray`] to a [`BinaryArray`], casting trues to `"1"` and falses to `"0"` +pub fn boolean_to_binary(from: &BooleanArray) -> BinaryArray { + let iter = from.values().iter().map(|x| if x { b"1" } else { b"0" }); + BinaryArray::from_trusted_len_values_iter(iter) +} + +pub(super) fn boolean_to_binary_dyn(array: &dyn Array) -> Result> { + let array = array.as_any().downcast_ref().unwrap(); + Ok(Box::new(boolean_to_binary::(array))) +} diff --git a/src/compute/cast/mod.rs b/src/compute/cast/mod.rs index 451002b7b07..ccf8d9d379b 100644 --- a/src/compute/cast/mod.rs +++ b/src/compute/cast/mod.rs @@ -116,7 +116,13 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (_, Dictionary(_, value_type)) => can_cast_types(from_type, value_type), (_, Boolean) => is_numeric(from_type), - (Boolean, _) => is_numeric(to_type) || to_type == &Utf8 || to_type == &LargeUtf8, + (Boolean, _) => { + is_numeric(to_type) + || to_type == &Utf8 + || to_type == &LargeUtf8 + || to_type == &Binary + || to_type == &LargeBinary + } (Utf8, Date32) => true, (Utf8, Date64) => true, @@ -130,8 +136,11 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (LargeUtf8, _) => is_numeric(to_type), (_, Utf8) => is_numeric(from_type) || from_type == &Binary, (_, LargeUtf8) => is_numeric(from_type) || from_type == &Binary, - (Binary, LargeBinary) => true, - (LargeBinary, Binary) => true, + + (Binary, _) => is_numeric(to_type) || to_type == &LargeBinary, + (LargeBinary, _) => is_numeric(to_type) || to_type == &Binary, + (_, Binary) => is_numeric(from_type), + (_, LargeBinary) => is_numeric(from_type), // start numeric casts (UInt8, UInt16) => true, @@ -454,6 +463,8 @@ fn cast_with_options( Float64 => boolean_to_primitive_dyn::(array), Utf8 => boolean_to_utf8_dyn::(array), LargeUtf8 => boolean_to_utf8_dyn::(array), + Binary => boolean_to_binary_dyn::(array), + LargeBinary => boolean_to_binary_dyn::(array), _ => Err(ArrowError::NotYetImplemented(format!( "Casting from {:?} to {:?} not supported", from_type, to_type, @@ -560,14 +571,81 @@ fn cast_with_options( ))), }, - (Binary, LargeBinary) => Ok(Box::new(binary_to_large_binary( - array.as_any().downcast_ref().unwrap(), - to_type.clone(), - ))), - (LargeBinary, Binary) => { - binary_large_to_binary(array.as_any().downcast_ref().unwrap(), to_type.clone()) - .map(|x| Box::new(x) as Box) - } + (Binary, _) => match to_type { + UInt8 => binary_to_primitive_dyn::(array, to_type), + UInt16 => binary_to_primitive_dyn::(array, to_type), + UInt32 => binary_to_primitive_dyn::(array, to_type), + UInt64 => binary_to_primitive_dyn::(array, to_type), + Int8 => binary_to_primitive_dyn::(array, to_type), + Int16 => binary_to_primitive_dyn::(array, to_type), + Int32 => binary_to_primitive_dyn::(array, to_type), + Int64 => binary_to_primitive_dyn::(array, to_type), + Float32 => binary_to_primitive_dyn::(array, to_type), + Float64 => binary_to_primitive_dyn::(array, to_type), + LargeBinary => Ok(Box::new(binary_to_large_binary( + array.as_any().downcast_ref().unwrap(), + to_type.clone(), + ))), + _ => Err(ArrowError::NotYetImplemented(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + }, + + (LargeBinary, _) => match to_type { + UInt8 => binary_to_primitive_dyn::(array, to_type), + UInt16 => binary_to_primitive_dyn::(array, to_type), + UInt32 => binary_to_primitive_dyn::(array, to_type), + UInt64 => binary_to_primitive_dyn::(array, to_type), + Int8 => binary_to_primitive_dyn::(array, to_type), + Int16 => binary_to_primitive_dyn::(array, to_type), + Int32 => binary_to_primitive_dyn::(array, to_type), + Int64 => binary_to_primitive_dyn::(array, to_type), + Float32 => binary_to_primitive_dyn::(array, to_type), + Float64 => binary_to_primitive_dyn::(array, to_type), + Binary => { + binary_large_to_binary(array.as_any().downcast_ref().unwrap(), to_type.clone()) + .map(|x| Box::new(x) as Box) + } + _ => Err(ArrowError::NotYetImplemented(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + }, + + (_, Binary) => match from_type { + UInt8 => primitive_to_binary_dyn::(array), + UInt16 => primitive_to_binary_dyn::(array), + UInt32 => primitive_to_binary_dyn::(array), + UInt64 => primitive_to_binary_dyn::(array), + Int8 => primitive_to_binary_dyn::(array), + Int16 => primitive_to_binary_dyn::(array), + Int32 => primitive_to_binary_dyn::(array), + Int64 => primitive_to_binary_dyn::(array), + Float32 => primitive_to_binary_dyn::(array), + Float64 => primitive_to_binary_dyn::(array), + _ => Err(ArrowError::NotYetImplemented(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + }, + + (_, LargeBinary) => match from_type { + UInt8 => primitive_to_binary_dyn::(array), + UInt16 => primitive_to_binary_dyn::(array), + UInt32 => primitive_to_binary_dyn::(array), + UInt64 => primitive_to_binary_dyn::(array), + Int8 => primitive_to_binary_dyn::(array), + Int16 => primitive_to_binary_dyn::(array), + Int32 => primitive_to_binary_dyn::(array), + Int64 => primitive_to_binary_dyn::(array), + Float32 => primitive_to_binary_dyn::(array), + Float64 => primitive_to_binary_dyn::(array), + _ => Err(ArrowError::NotYetImplemented(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + }, // start numeric casts (UInt8, UInt16) => primitive_to_primitive_dyn::(array, to_type, as_options), @@ -755,6 +833,8 @@ fn cast_to_dictionary( DataType::UInt64 => primitive_to_dictionary_dyn::(array), DataType::Utf8 => utf8_to_dictionary_dyn::(array), DataType::LargeUtf8 => utf8_to_dictionary_dyn::(array), + DataType::Binary => binary_to_dictionary_dyn::(array), + DataType::LargeBinary => binary_to_dictionary_dyn::(array), _ => Err(ArrowError::NotYetImplemented(format!( "Unsupported output type for dictionary packing: {:?}", dict_value_type diff --git a/src/compute/cast/primitive_to.rs b/src/compute/cast/primitive_to.rs index ec8d02e850c..78171193e91 100644 --- a/src/compute/cast/primitive_to.rs +++ b/src/compute/cast/primitive_to.rs @@ -8,10 +8,31 @@ use crate::{ temporal_conversions::*, types::NativeType, }; -use crate::{error::Result, util::lexical_to_string}; +use crate::{ + error::Result, + util::{lexical_to_bytes, lexical_to_string}, +}; use super::CastOptions; +/// Returns a [`BinaryArray`] where every element is the binary representation of the number. +pub fn primitive_to_binary( + from: &PrimitiveArray, +) -> BinaryArray { + let iter = from.iter().map(|x| x.map(|x| lexical_to_bytes(*x))); + + BinaryArray::from_trusted_len_iter(iter) +} + +pub(super) fn primitive_to_binary_dyn(from: &dyn Array) -> Result> +where + O: Offset, + T: NativeType + lexical_core::ToLexical, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(primitive_to_binary::(from))) +} + /// Returns a [`BooleanArray`] where every element is different from zero. /// Validity is preserved. pub fn primitive_to_boolean( diff --git a/src/compute/comparison/binary.rs b/src/compute/comparison/binary.rs new file mode 100644 index 00000000000..eccc4189ccd --- /dev/null +++ b/src/compute/comparison/binary.rs @@ -0,0 +1,251 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::datatypes::DataType; +use crate::error::{ArrowError, Result}; +use crate::scalar::{BinaryScalar, Scalar}; +use crate::{array::*, bitmap::Bitmap}; + +use super::{super::utils::combine_validities, Operator}; + +/// Evaluate `op(lhs, rhs)` for [`BinaryArray`]s using a specified +/// comparison function. +fn compare_op(lhs: &BinaryArray, rhs: &BinaryArray, op: F) -> Result +where + O: Offset, + F: Fn(&[u8], &[u8]) -> bool, +{ + if lhs.len() != rhs.len() { + return Err(ArrowError::InvalidArgumentError( + "Cannot perform comparison operation on arrays of different length".to_string(), + )); + } + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + let values = lhs + .values_iter() + .zip(rhs.values_iter()) + .map(|(lhs, rhs)| op(lhs, rhs)); + let values = Bitmap::from_trusted_len_iter(values); + + Ok(BooleanArray::from_data(DataType::Boolean, values, validity)) +} + +/// Evaluate `op(lhs, rhs)` for [`BinaryArray`] and scalar using +/// a specified comparison function. +fn compare_op_scalar(lhs: &BinaryArray, rhs: &[u8], op: F) -> BooleanArray +where + O: Offset, + F: Fn(&[u8], &[u8]) -> bool, +{ + let validity = lhs.validity().clone(); + + let values = lhs.values_iter().map(|lhs| op(lhs, rhs)); + let values = Bitmap::from_trusted_len_iter(values); + + BooleanArray::from_data(DataType::Boolean, values, validity) +} + +/// Perform `lhs == rhs` operation on [`BinaryArray`]. +fn eq(lhs: &BinaryArray, rhs: &BinaryArray) -> Result { + compare_op(lhs, rhs, |a, b| a == b) +} + +/// Perform `lhs == rhs` operation on [`BinaryArray`] and a scalar. +fn eq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a == b) +} + +/// Perform `lhs != rhs` operation on [`BinaryArray`]. +fn neq(lhs: &BinaryArray, rhs: &BinaryArray) -> Result { + compare_op(lhs, rhs, |a, b| a != b) +} + +/// Perform `lhs != rhs` operation on [`BinaryArray`] and a scalar. +fn neq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a != b) +} + +/// Perform `lhs < rhs` operation on [`BinaryArray`]. +fn lt(lhs: &BinaryArray, rhs: &BinaryArray) -> Result { + compare_op(lhs, rhs, |a, b| a < b) +} + +/// Perform `lhs < rhs` operation on [`BinaryArray`] and a scalar. +fn lt_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a < b) +} + +/// Perform `lhs <= rhs` operation on [`BinaryArray`]. +fn lt_eq(lhs: &BinaryArray, rhs: &BinaryArray) -> Result { + compare_op(lhs, rhs, |a, b| a <= b) +} + +/// Perform `lhs <= rhs` operation on [`BinaryArray`] and a scalar. +fn lt_eq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a <= b) +} + +/// Perform `lhs > rhs` operation on [`BinaryArray`]. +fn gt(lhs: &BinaryArray, rhs: &BinaryArray) -> Result { + compare_op(lhs, rhs, |a, b| a > b) +} + +/// Perform `lhs > rhs` operation on [`BinaryArray`] and a scalar. +fn gt_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a > b) +} + +/// Perform `lhs >= rhs` operation on [`BinaryArray`]. +fn gt_eq(lhs: &BinaryArray, rhs: &BinaryArray) -> Result { + compare_op(lhs, rhs, |a, b| a >= b) +} + +/// Perform `lhs >= rhs` operation on [`BinaryArray`] and a scalar. +fn gt_eq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a >= b) +} + +pub fn compare( + lhs: &BinaryArray, + rhs: &BinaryArray, + op: Operator, +) -> Result { + match op { + Operator::Eq => eq(lhs, rhs), + Operator::Neq => neq(lhs, rhs), + Operator::Gt => gt(lhs, rhs), + Operator::GtEq => gt_eq(lhs, rhs), + Operator::Lt => lt(lhs, rhs), + Operator::LtEq => lt_eq(lhs, rhs), + } +} + +pub fn compare_scalar( + lhs: &BinaryArray, + rhs: &BinaryScalar, + op: Operator, +) -> BooleanArray { + if !rhs.is_valid() { + return BooleanArray::new_null(DataType::Boolean, lhs.len()); + } + compare_scalar_non_null(lhs, rhs.value(), op) +} + +pub fn compare_scalar_non_null( + lhs: &BinaryArray, + rhs: &[u8], + op: Operator, +) -> BooleanArray { + match op { + Operator::Eq => eq_scalar(lhs, rhs), + Operator::Neq => neq_scalar(lhs, rhs), + Operator::Gt => gt_scalar(lhs, rhs), + Operator::GtEq => gt_eq_scalar(lhs, rhs), + Operator::Lt => lt_scalar(lhs, rhs), + Operator::LtEq => lt_eq_scalar(lhs, rhs), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_generic, &BinaryArray) -> Result>( + lhs: Vec<&[u8]>, + rhs: Vec<&[u8]>, + op: F, + expected: Vec, + ) { + let lhs = BinaryArray::::from_slice(lhs); + let rhs = BinaryArray::::from_slice(rhs); + let expected = BooleanArray::from_slice(expected); + assert_eq!(op(&lhs, &rhs).unwrap(), expected); + } + + fn test_generic_scalar, &[u8]) -> BooleanArray>( + lhs: Vec<&[u8]>, + rhs: &[u8], + op: F, + expected: Vec, + ) { + let lhs = BinaryArray::::from_slice(lhs); + let expected = BooleanArray::from_slice(expected); + assert_eq!(op(&lhs, rhs), expected); + } + + #[test] + fn test_gt_eq() { + test_generic::( + vec![b"arrow", b"datafusion", b"flight", b"parquet"], + vec![b"flight", b"flight", b"flight", b"flight"], + gt_eq, + vec![false, false, true, true], + ) + } + + #[test] + fn test_gt_eq_scalar() { + test_generic_scalar::( + vec![b"arrow", b"datafusion", b"flight", b"parquet"], + b"flight", + gt_eq_scalar, + vec![false, false, true, true], + ) + } + + #[test] + fn test_eq() { + test_generic::( + vec![b"arrow", b"arrow", b"arrow", b"arrow"], + vec![b"arrow", b"parquet", b"datafusion", b"flight"], + eq, + vec![true, false, false, false], + ) + } + + #[test] + fn test_eq_scalar() { + test_generic_scalar::( + vec![b"arrow", b"parquet", b"datafusion", b"flight"], + b"arrow", + eq_scalar, + vec![true, false, false, false], + ) + } + + #[test] + fn test_neq() { + test_generic::( + vec![b"arrow", b"arrow", b"arrow", b"arrow"], + vec![b"arrow", b"parquet", b"datafusion", b"flight"], + neq, + vec![false, true, true, true], + ) + } + + #[test] + fn test_neq_scalar() { + test_generic_scalar::( + vec![b"arrow", b"parquet", b"datafusion", b"flight"], + b"arrow", + neq_scalar, + vec![false, true, true, true], + ) + } +} diff --git a/src/compute/comparison/mod.rs b/src/compute/comparison/mod.rs index ceaf548ca68..5c3514145c4 100644 --- a/src/compute/comparison/mod.rs +++ b/src/compute/comparison/mod.rs @@ -22,6 +22,7 @@ use crate::datatypes::{DataType, IntervalUnit}; use crate::error::{ArrowError, Result}; use crate::scalar::Scalar; +mod binary; mod boolean; mod primitive; mod utf8; @@ -131,6 +132,16 @@ pub fn compare(lhs: &dyn Array, rhs: &dyn Array, operator: Operator) -> Result(lhs, rhs, operator) } + DataType::Binary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary::compare::(lhs, rhs, operator) + } + DataType::LargeBinary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary::compare::(lhs, rhs, operator) + } _ => Err(ArrowError::NotYetImplemented(format!( "Comparison between {:?} is not supported", data_type @@ -233,6 +244,16 @@ pub fn compare_scalar( let rhs = rhs.as_any().downcast_ref().unwrap(); utf8::compare_scalar::(lhs, rhs, operator) } + DataType::Binary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary::compare_scalar::(lhs, rhs, operator) + } + DataType::LargeBinary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary::compare_scalar::(lhs, rhs, operator) + } _ => { return Err(ArrowError::NotYetImplemented(format!( "Comparison between {:?} is not supported", @@ -242,6 +263,7 @@ pub fn compare_scalar( }) } +pub use binary::compare_scalar_non_null as binary_compare_scalar; pub use boolean::compare_scalar_non_null as boolean_compare_scalar; pub use primitive::compare_scalar_non_null as primitive_compare_scalar; pub(crate) use primitive::compare_values_op as primitive_compare_values_op; @@ -259,7 +281,7 @@ pub use utf8::compare_scalar_non_null as utf8_compare_scalar; /// assert_eq!(can_compare(&data_type), true); /// /// let data_type = DataType::LargeBinary; -/// assert_eq!(can_compare(&data_type), false) +/// assert_eq!(can_compare(&data_type), true) /// ``` pub fn can_compare(data_type: &DataType) -> bool { matches!( @@ -285,6 +307,8 @@ pub fn can_compare(data_type: &DataType) -> bool { | DataType::Utf8 | DataType::LargeUtf8 | DataType::Decimal(_, _) + | DataType::Binary + | DataType::LargeBinary ) } diff --git a/src/compute/contains.rs b/src/compute/contains.rs index 44cc7a83639..8e534ad8ed1 100644 --- a/src/compute/contains.rs +++ b/src/compute/contains.rs @@ -17,7 +17,7 @@ use crate::types::NativeType; use crate::{ - array::{Array, BooleanArray, ListArray, Offset, PrimitiveArray, Utf8Array}, + array::{Array, BinaryArray, BooleanArray, ListArray, Offset, PrimitiveArray, Utf8Array}, bitmap::Bitmap, }; use crate::{ @@ -96,6 +96,40 @@ where Ok(BooleanArray::from_data(DataType::Boolean, values, validity)) } +/// Checks if a [`GenericListArray`] contains a value in the [`BinaryArray`] +fn contains_binary(list: &ListArray, values: &BinaryArray) -> Result +where + O: Offset, + OO: Offset, +{ + if list.len() != values.len() { + return Err(ArrowError::InvalidArgumentError( + "Contains requires arrays of the same length".to_string(), + )); + } + if list.values().data_type() != values.data_type() { + return Err(ArrowError::InvalidArgumentError( + "Contains requires the inner array to be of the same logical type".to_string(), + )); + } + + let validity = combine_validities(list.validity(), values.validity()); + + let values = list.iter().zip(values.iter()).map(|(list, values)| { + if list.is_none() | values.is_none() { + // validity takes care of this + return false; + }; + let list = list.unwrap(); + let list = list.as_any().downcast_ref::>().unwrap(); + let values = values.unwrap(); + list.iter().any(|x| x.map(|x| x == values).unwrap_or(false)) + }); + let values = Bitmap::from_trusted_len_iter(values); + + Ok(BooleanArray::from_data(DataType::Boolean, values, validity)) +} + macro_rules! primitive { ($list:expr, $values:expr, $l_ty:ty, $r_ty:ty) => {{ let list = $list.as_any().downcast_ref::>().unwrap(); @@ -132,6 +166,26 @@ pub fn contains(list: &dyn Array, values: &dyn Array) -> Result { let values = values.as_any().downcast_ref::>().unwrap(); contains_utf8(list, values) } + (DataType::List(_), DataType::Binary) => { + let list = list.as_any().downcast_ref::>().unwrap(); + let values = values.as_any().downcast_ref::>().unwrap(); + contains_binary(list, values) + } + (DataType::List(_), DataType::LargeBinary) => { + let list = list.as_any().downcast_ref::>().unwrap(); + let values = values.as_any().downcast_ref::>().unwrap(); + contains_binary(list, values) + } + (DataType::LargeList(_), DataType::LargeBinary) => { + let list = list.as_any().downcast_ref::>().unwrap(); + let values = values.as_any().downcast_ref::>().unwrap(); + contains_binary(list, values) + } + (DataType::LargeList(_), DataType::Binary) => { + let list = list.as_any().downcast_ref::>().unwrap(); + let values = values.as_any().downcast_ref::>().unwrap(); + contains_binary(list, values) + } (DataType::List(_), DataType::Int8) => primitive!(list, values, i32, i8), (DataType::List(_), DataType::Int16) => primitive!(list, values, i32, i16), (DataType::List(_), DataType::Int32) => primitive!(list, values, i32, i32), @@ -196,4 +250,29 @@ mod tests { assert_eq!(result, expected); } + + #[test] + fn test_contains_binary() { + let data = vec![ + Some(vec![Some(b"a"), Some(b"b"), None]), + Some(vec![Some(b"a"), Some(b"b"), None]), + Some(vec![Some(b"a"), Some(b"b"), None]), + None, + ]; + let values = BinaryArray::::from(&[Some(b"a"), Some(b"c"), None, Some(b"a")]); + let expected = BooleanArray::from(vec![ + Some(true), + Some(false), + None, + None + ]); + + let mut a = MutableListArray::>::new(); + a.try_extend(data).unwrap(); + let a: ListArray = a.into(); + + let result = contains(&a, &values).unwrap(); + + assert_eq!(result, expected); + } } diff --git a/src/compute/like.rs b/src/compute/like.rs index abda4f912e8..865e44c74cc 100644 --- a/src/compute/like.rs +++ b/src/compute/like.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use regex::bytes::Regex as BytesRegex; use regex::Regex; use crate::datatypes::DataType; @@ -148,3 +149,191 @@ pub fn like_utf8_scalar(lhs: &Utf8Array, rhs: &str) -> Result(lhs: &Utf8Array, rhs: &str) -> Result { a_like_utf8_scalar(lhs, rhs, |x| !x) } + +#[inline] +fn a_like_binary bool>( + lhs: &BinaryArray, + rhs: &BinaryArray, + op: F, +) -> Result { + if lhs.len() != rhs.len() { + return Err(ArrowError::InvalidArgumentError( + "Cannot perform comparison operation on arrays of different length".to_string(), + )); + } + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + let mut map = HashMap::new(); + + let values = + Bitmap::try_from_trusted_len_iter(lhs.iter().zip(rhs.iter()).map(|(lhs, rhs)| { + match (lhs, rhs) { + (Some(lhs), Some(pattern)) => { + let pattern = if let Some(pattern) = map.get(pattern) { + pattern + } else { + let re_pattern = std::str::from_utf8(pattern) + .unwrap() + .replace("%", ".*") + .replace("_", "."); + let re = BytesRegex::new(&format!("^{}$", re_pattern)).map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "Unable to build regex from LIKE pattern: {}", + e + )) + })?; + map.insert(pattern, re); + map.get(pattern).unwrap() + }; + Result::Ok(op(pattern.is_match(lhs))) + } + _ => Ok(false), + } + }))?; + + Ok(BooleanArray::from_data(DataType::Boolean, values, validity)) +} + +/// Returns `lhs LIKE rhs` operation on two [`BinaryArray`]. +/// +/// There are two wildcards supported: +/// +/// * `%` - The percent sign represents zero, one, or multiple characters +/// * `_` - The underscore represents a single character +/// +/// # Error +/// Errors iff: +/// * the arrays have a different length +/// * any of the patterns is not valid +/// # Example +/// ``` +/// use arrow2::array::{BinaryArray, BooleanArray}; +/// use arrow2::compute::like::like_binary; +/// +/// let strings = BinaryArray::::from_slice(&["Arrow", "Arrow", "Arrow", "Arrow", "Ar"]); +/// let patterns = BinaryArray::::from_slice(&["A%", "B%", "%r_ow", "A_", "A_"]); +/// +/// let result = like_binary(&strings, &patterns).unwrap(); +/// assert_eq!(result, BooleanArray::from_slice(&[true, false, true, false, true])); +/// ``` +pub fn like_binary(lhs: &BinaryArray, rhs: &BinaryArray) -> Result { + a_like_binary(lhs, rhs, |x| x) +} + +pub fn nlike_binary(lhs: &BinaryArray, rhs: &BinaryArray) -> Result { + a_like_binary(lhs, rhs, |x| !x) +} + +fn a_like_binary_scalar bool>( + lhs: &BinaryArray, + rhs: &[u8], + op: F, +) -> Result { + let validity = lhs.validity(); + + let pattern = std::str::from_utf8(rhs).unwrap(); + + let values = if !pattern.contains(is_like_pattern) { + Bitmap::from_trusted_len_iter(lhs.values_iter().map(|x| x == rhs)) + } else if pattern.ends_with('%') && !pattern[..pattern.len() - 1].contains(is_like_pattern) { + // fast path, can use starts_with + let starts_with = &rhs[..rhs.len() - 1]; + Bitmap::from_trusted_len_iter(lhs.values_iter().map(|x| op(x.starts_with(starts_with)))) + } else if pattern.starts_with('%') && !pattern[1..].contains(is_like_pattern) { + // fast path, can use ends_with + let ends_with = &rhs[1..]; + Bitmap::from_trusted_len_iter(lhs.values_iter().map(|x| op(x.ends_with(ends_with)))) + } else { + let re_pattern = pattern.replace("%", ".*").replace("_", "."); + let re = BytesRegex::new(&format!("^{}$", re_pattern)).map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "Unable to build regex from LIKE pattern: {}", + e + )) + })?; + Bitmap::from_trusted_len_iter(lhs.values_iter().map(|x| op(re.is_match(x)))) + }; + Ok(BooleanArray::from_data( + DataType::Boolean, + values, + validity.clone(), + )) +} + +/// Returns `lhs LIKE rhs` operation. +/// +/// There are two wildcards supported: +/// +/// * `%` - The percent sign represents zero, one, or multiple characters +/// * `_` - The underscore represents a single character +/// +/// # Error +/// Errors iff: +/// * the arrays have a different length +/// * any of the patterns is not valid +/// # Example +/// ``` +/// use arrow2::array::{BinaryArray, BooleanArray}; +/// use arrow2::compute::like::like_binary_scalar; +/// +/// let array = BinaryArray::::from_slice(&["Arrow", "Arrow", "Arrow", "BA"]); +/// +/// let result = like_binary_scalar(&array, &"A%").unwrap(); +/// assert_eq!(result, BooleanArray::from_slice(&[true, true, true, false])); +/// ``` +pub fn like_binary_scalar(lhs: &BinaryArray, rhs: &[u8]) -> Result { + a_like_binary_scalar(lhs, rhs, |x| x) +} + +pub fn nlike_binary_scalar(lhs: &BinaryArray, rhs: &[u8]) -> Result { + a_like_binary_scalar(lhs, rhs, |x| !x) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_like_binary() -> Result<()> { + let strings = BinaryArray::::from_slice(&["Arrow", "Arrow", "Arrow", "Arrow", "Ar"]); + let patterns = BinaryArray::::from_slice(&["A%", "B%", "%r_ow", "A_", "A_"]); + let result = like_binary(&strings, &patterns).unwrap(); + assert_eq!( + result, + BooleanArray::from_slice(&[true, false, true, false, true]) + ); + Ok(()) + } + + #[test] + fn test_nlike_binary() -> Result<()> { + let strings = BinaryArray::::from_slice(&["Arrow", "Arrow", "Arrow", "Arrow", "Ar"]); + let patterns = BinaryArray::::from_slice(&["A%", "B%", "%r_ow", "A_", "A_"]); + let result = nlike_binary(&strings, &patterns).unwrap(); + assert_eq!( + result, + BooleanArray::from_slice(&[false, true, false, true, false]) + ); + Ok(()) + } + + #[test] + fn test_like_binary_scalar() -> Result<()> { + let array = BinaryArray::::from_slice(&["Arrow", "Arrow", "Arrow", "BA"]); + let result = like_binary_scalar(&array, b"A%").unwrap(); + assert_eq!(result, BooleanArray::from_slice(&[true, true, true, false])); + Ok(()) + } + + #[test] + fn test_nlike_binary_scalar() -> Result<()> { + let array = BinaryArray::::from_slice(&["Arrow", "Arrow", "Arrow", "BA"]); + let result = nlike_binary_scalar(&array, "A%".as_bytes()).unwrap(); + assert_eq!( + result, + BooleanArray::from_slice(&[false, false, false, true]) + ); + Ok(()) + } +} diff --git a/src/compute/merge_sort/mod.rs b/src/compute/merge_sort/mod.rs index bf0b97734e8..66c4afabd95 100644 --- a/src/compute/merge_sort/mod.rs +++ b/src/compute/merge_sort/mod.rs @@ -532,7 +532,7 @@ pub fn build_comparator<'a>( #[cfg(test)] mod tests { - use crate::array::{Int32Array, Utf8Array}; + use crate::array::{BinaryArray, Int32Array, Utf8Array}; use crate::compute::sort::sort; use super::*; @@ -638,6 +638,37 @@ mod tests { Ok(()) } + #[test] + fn test_merge_binary() -> Result<()> { + let a0: &dyn Array = &BinaryArray::::from_slice(&[b"a", b"c", b"d", b"e"]); + let a1: &dyn Array = &BinaryArray::::from_slice(&[b"b", b"y", b"z", b"z"]); + + let options = SortOptions::default(); + let arrays = vec![a0, a1]; + let pairs = vec![(arrays.as_ref(), &options)]; + let comparator = build_comparator(&pairs)?; + + // (0, 0, 4) corresponds to slice ["a", "c", "d", "e"] of a0 + // (1, 0, 4) corresponds to slice ["b", "y", "z", "z"] of a1 + + let result = + merge_sort_slices(once(&(0, 0, 4)), once(&(1, 0, 4)), &comparator).collect::>(); + + // "a" (a0) , "b" (a1) , ["c", "d", "e"] (a0), ["y", "z", "z"] (a1) + // (0, 0, 1), (1, 0, 1), (0, 1, 3) , (1, 1, 3) + assert_eq!(result, vec![(0, 0, 1), (1, 0, 1), (0, 1, 3), (1, 1, 3)]); + + // (0, 1, 2) corresponds to slice ["c", "d"] of a0 + // (1, 0, 3) corresponds to slice ["b", "y", "z"] of a1 + let result = + merge_sort_slices(once(&(0, 1, 2)), once(&(1, 0, 3)), &comparator).collect::>(); + + // "b" (a1) , ["c", "d"] (a0) , ["y", "z"] + // (1, 0, 1), (0, 1, 2) , (1, 1, 2) + assert_eq!(result, vec![(1, 0, 1), (0, 1, 2), (1, 1, 2)]); + Ok(()) + } + #[test] fn test_merge_string() -> Result<()> { let a0: &dyn Array = &Utf8Array::::from_slice(&["a", "c", "d", "e"]); diff --git a/src/compute/sort/binary.rs b/src/compute/sort/binary.rs new file mode 100644 index 00000000000..109f88b8564 --- /dev/null +++ b/src/compute/sort/binary.rs @@ -0,0 +1,15 @@ +use crate::array::{Array, BinaryArray, Offset, PrimitiveArray}; +use crate::types::Index; + +use super::common; +use super::SortOptions; + +pub(super) fn indices_sorted_unstable_by( + array: &BinaryArray, + options: &SortOptions, + limit: Option, +) -> PrimitiveArray { + let get = |idx| unsafe { array.value_unchecked(idx as usize) }; + let cmp = |lhs: &&[u8], rhs: &&[u8]| lhs.cmp(rhs); + common::indices_sorted_unstable_by(array.validity(), get, cmp, array.len(), options, limit) +} diff --git a/src/compute/sort/mod.rs b/src/compute/sort/mod.rs index f0ad5002255..036bc8b4e84 100644 --- a/src/compute/sort/mod.rs +++ b/src/compute/sort/mod.rs @@ -11,6 +11,7 @@ use crate::{ use crate::buffer::MutableBuffer; +mod binary; mod boolean; mod common; mod lex_sort; @@ -141,6 +142,16 @@ pub fn sort_to_indices( options, limit, )), + DataType::Binary => Ok(binary::indices_sorted_unstable_by::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), + DataType::LargeBinary => Ok(binary::indices_sorted_unstable_by::( + values.as_any().downcast_ref().unwrap(), + options, + limit, + )), DataType::List(field) => { let (v, n) = partition_validity(values); match field.data_type() { @@ -243,7 +254,7 @@ fn sort_dict( /// assert_eq!(can_sort(&data_type), true); /// /// let data_type = DataType::LargeBinary; -/// assert_eq!(can_sort(&data_type), false) +/// assert_eq!(can_sort(&data_type), true) /// ``` pub fn can_sort(data_type: &DataType) -> bool { match data_type { @@ -266,7 +277,9 @@ pub fn can_sort(data_type: &DataType) -> bool { | DataType::Float32 | DataType::Float64 | DataType::Utf8 - | DataType::LargeUtf8 => true, + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary => true, DataType::List(field) | DataType::LargeList(field) | DataType::FixedSizeList(field, _) => { matches!( field.data_type(), diff --git a/src/compute/substring.rs b/src/compute/substring.rs index c82029b747f..aac841bbe82 100644 --- a/src/compute/substring.rs +++ b/src/compute/substring.rs @@ -68,11 +68,76 @@ fn utf8_substring(array: &Utf8Array, start: O, length: &Option) ) } +fn binary_substring( + array: &BinaryArray, + start: O, + length: &Option, +) -> BinaryArray { + let validity = array.validity(); + let offsets = array.offsets(); + let values = array.values(); + + let mut new_offsets = MutableBuffer::::with_capacity(array.len() + 1); + let mut new_values = MutableBuffer::::new(); // we have no way to estimate how much this will be. + + let mut length_so_far = O::zero(); + new_offsets.push(length_so_far); + + offsets.windows(2).for_each(|windows| { + let length_i: O = windows[1] - windows[0]; + + // compute where we should start slicing this entry + let start = windows[0] + + if start >= O::zero() { + start + } else { + length_i + start + }; + let start = start.max(windows[0]).min(windows[1]); + + let length: O = length + .unwrap_or(length_i) + // .max(0) is not needed as it is guaranteed + .min(windows[1] - start); // so we do not go beyond this entry + length_so_far += length; + new_offsets.push(length_so_far); + + // we need usize for ranges + let start = start.to_usize(); + let length = length.to_usize(); + + new_values.extend_from_slice(&values[start..start + length]); + }); + + BinaryArray::::from_data( + array.data_type().clone(), + new_offsets.into(), + new_values.into(), + validity.clone(), + ) +} + /// Returns an ArrayRef with a substring starting from `start` and with optional length `length` of each of the elements in `array`. /// `start` can be negative, in which case the start counts from the end of the string. /// this function errors when the passed array is not a \[Large\]String array. pub fn substring(array: &dyn Array, start: i64, length: &Option) -> Result> { match array.data_type() { + DataType::Binary => Ok(Box::new(binary_substring( + array + .as_any() + .downcast_ref::>() + .expect("A binary is expected"), + start as i32, + &length.map(|e| e as i32), + ))), + DataType::LargeBinary => Ok(Box::new(binary_substring( + array + .as_any() + .downcast_ref::>() + .expect("A large binary is expected"), + start, + &length.map(|e| e as i64), + ))), DataType::LargeUtf8 => Ok(Box::new(utf8_substring( array .as_any() @@ -110,14 +175,17 @@ pub fn substring(array: &dyn Array, start: i64, length: &Option) -> Result< /// assert_eq!(can_substring(&data_type), false); /// ``` pub fn can_substring(data_type: &DataType) -> bool { - matches!(data_type, DataType::LargeUtf8 | DataType::Utf8) + matches!( + data_type, + DataType::LargeUtf8 | DataType::Utf8 | DataType::LargeBinary | DataType::Binary + ) } #[cfg(test)] mod tests { use super::*; - fn with_nulls() -> Result<()> { + fn with_nulls_utf8() -> Result<()> { let cases = vec![ // identity ( @@ -174,15 +242,15 @@ mod tests { #[test] fn with_nulls_string() -> Result<()> { - with_nulls::() + with_nulls_utf8::() } #[test] fn with_nulls_large_string() -> Result<()> { - with_nulls::() + with_nulls_utf8::() } - fn without_nulls() -> Result<()> { + fn without_nulls_utf8() -> Result<()> { let cases = vec![ // increase start ( @@ -253,12 +321,156 @@ mod tests { #[test] fn without_nulls_string() -> Result<()> { - without_nulls::() + without_nulls_utf8::() } #[test] fn without_nulls_large_string() -> Result<()> { - without_nulls::() + without_nulls_utf8::() + } + + fn with_null_binarys() -> Result<()> { + let cases = vec![ + // identity + ( + vec![Some(b"hello"), None, Some(b"world")], + 0, + None, + vec![Some("hello"), None, Some("world")], + ), + // 0 length -> Nothing + ( + vec![Some(b"hello"), None, Some(b"world")], + 0, + Some(0), + vec![Some(""), None, Some("")], + ), + // high start -> Nothing + ( + vec![Some(b"hello"), None, Some(b"world")], + 1000, + Some(0), + vec![Some(""), None, Some("")], + ), + // high negative start -> identity + ( + vec![Some(b"hello"), None, Some(b"world")], + -1000, + None, + vec![Some("hello"), None, Some("world")], + ), + // high length -> identity + ( + vec![Some(b"hello"), None, Some(b"world")], + 0, + Some(1000), + vec![Some("hello"), None, Some("world")], + ), + ]; + + cases + .into_iter() + .try_for_each::<_, Result<()>>(|(array, start, length, expected)| { + let array = BinaryArray::::from(&array); + let result = substring(&array, start, &length)?; + assert_eq!(array.len(), result.len()); + + let result = result.as_any().downcast_ref::>().unwrap(); + let expected = BinaryArray::::from(&expected); + assert_eq!(&expected, result); + Ok(()) + })?; + + Ok(()) + } + + #[test] + fn with_nulls_binary() -> Result<()> { + with_null_binarys::() + } + + #[test] + fn with_nulls_large_binary() -> Result<()> { + with_null_binarys::() + } + + fn without_null_binarys() -> Result<()> { + let cases = vec![ + // increase start + ( + vec!["hello", "", "word"], + 0, + None, + vec!["hello", "", "word"], + ), + (vec!["hello", "", "word"], 1, None, vec!["ello", "", "ord"]), + (vec!["hello", "", "word"], 2, None, vec!["llo", "", "rd"]), + (vec!["hello", "", "word"], 3, None, vec!["lo", "", "d"]), + (vec!["hello", "", "word"], 10, None, vec!["", "", ""]), + // increase start negatively + (vec!["hello", "", "word"], -1, None, vec!["o", "", "d"]), + (vec!["hello", "", "word"], -2, None, vec!["lo", "", "rd"]), + (vec!["hello", "", "word"], -3, None, vec!["llo", "", "ord"]), + ( + vec!["hello", "", "word"], + -10, + None, + vec!["hello", "", "word"], + ), + // increase length + (vec!["hello", "", "word"], 1, Some(1), vec!["e", "", "o"]), + (vec!["hello", "", "word"], 1, Some(2), vec!["el", "", "or"]), + ( + vec!["hello", "", "word"], + 1, + Some(3), + vec!["ell", "", "ord"], + ), + ( + vec!["hello", "", "word"], + 1, + Some(4), + vec!["ello", "", "ord"], + ), + (vec!["hello", "", "word"], -3, Some(1), vec!["l", "", "o"]), + (vec!["hello", "", "word"], -3, Some(2), vec!["ll", "", "or"]), + ( + vec!["hello", "", "word"], + -3, + Some(3), + vec!["llo", "", "ord"], + ), + ( + vec!["hello", "", "word"], + -3, + Some(4), + vec!["llo", "", "ord"], + ), + ]; + + cases + .into_iter() + .try_for_each::<_, Result<()>>(|(array, start, length, expected)| { + let array = BinaryArray::::from_slice(&array); + let result = substring(&array, start, &length)?; + assert_eq!(array.len(), result.len()); + let result = result.as_any().downcast_ref::>().unwrap(); + let expected = BinaryArray::::from_slice(&expected); + assert_eq!(&expected, result); + Ok(()) + })?; + + Ok(()) + } + + #[test] + fn without_nulls_binary() -> Result<()> { + without_null_binarys::() + } + + #[test] + fn without_nulls_large_binary() -> Result<()> { + without_null_binarys::() } #[test] diff --git a/src/io/parquet/read/binary/mod.rs b/src/io/parquet/read/binary/mod.rs index 54da7b3f1db..1bfd2b04235 100644 --- a/src/io/parquet/read/binary/mod.rs +++ b/src/io/parquet/read/binary/mod.rs @@ -3,6 +3,6 @@ mod dictionary; mod nested; pub use basic::iter_to_array; -pub use dictionary::iter_to_array as iter_to_dict_array; pub use basic::stream_to_array; +pub use dictionary::iter_to_array as iter_to_dict_array; pub use nested::iter_to_array as iter_to_array_nested; diff --git a/tests/it/array/binary/mod.rs b/tests/it/array/binary/mod.rs index a7c64010348..5641a71ce95 100644 --- a/tests/it/array/binary/mod.rs +++ b/tests/it/array/binary/mod.rs @@ -4,6 +4,8 @@ use arrow2::{ datatypes::DataType, }; +mod mutable; + #[test] fn basics() { let data = vec![Some(b"hello".to_vec()), None, Some(b"hello2".to_vec())]; @@ -55,3 +57,17 @@ fn from() { let a = array.validity().as_ref().unwrap(); assert_eq!(a, &Bitmap::from([true, true, false])); } + +#[test] +fn from_trusted_len_iter() { + let iter = std::iter::repeat(b"hello").take(2).map(Some); + let a = BinaryArray::::from_trusted_len_iter(iter); + assert_eq!(a.len(), 2); +} + +#[test] +fn from_iter() { + let iter = std::iter::repeat(b"hello").take(2).map(Some); + let a: BinaryArray:: = iter.collect(); + assert_eq!(a.len(), 2); +} diff --git a/tests/it/array/binary/mutable.rs b/tests/it/array/binary/mutable.rs new file mode 100644 index 00000000000..e7788339140 --- /dev/null +++ b/tests/it/array/binary/mutable.rs @@ -0,0 +1,11 @@ +use arrow2::array::{Array, BinaryArray, MutableBinaryArray}; +use arrow2::bitmap::Bitmap; + +#[test] +fn push_null() { + let mut array = MutableBinaryArray::::new(); + array.push::<&str>(None); + + let array: BinaryArray = array.into(); + assert_eq!(array.validity(), &Some(Bitmap::from([false]))); +} diff --git a/tests/it/compute/cast.rs b/tests/it/compute/cast.rs index f979924f197..8836fc098a4 100644 --- a/tests/it/compute/cast.rs +++ b/tests/it/compute/cast.rs @@ -153,6 +153,26 @@ fn i32_to_list_f64_nullable_sliced() { assert_eq!(c, &expected); } +#[test] +fn i32_to_binary() { + let array = Int32Array::from_slice(&[5, 6, 7]); + let b = cast(&array, &DataType::Binary).unwrap(); + let expected = BinaryArray::::from(&[Some(b"5"), Some(b"6"), Some(b"7")]); + let c = b.as_any().downcast_ref::>().unwrap(); + assert_eq!(c, &expected); +} + +#[test] +fn binary_to_i32() { + let array = BinaryArray::::from_slice(&["5", "6", "seven", "8", "9.1"]); + let b = cast(&array, &DataType::Int32).unwrap(); + let c = b.as_any().downcast_ref::>().unwrap(); + + let expected = &[Some(5), Some(6), None, Some(8), None]; + let expected = Int32Array::from(expected); + assert_eq!(c, &expected); +} + #[test] fn utf8_to_i32() { let array = Utf8Array::::from_slice(&["5", "6", "seven", "8", "9.1"]); @@ -186,6 +206,26 @@ fn bool_to_f64() { assert_eq!(c, &expected); } +#[test] +fn bool_to_utf8() { + let array = BooleanArray::from(vec![Some(true), Some(false), None]); + let b = cast(&array, &DataType::Utf8).unwrap(); + let c = b.as_any().downcast_ref::>().unwrap(); + + let expected = Utf8Array::::from(&[Some("1"), Some("0"), Some("0")]); + assert_eq!(c, &expected); +} + +#[test] +fn bool_to_binary() { + let array = BooleanArray::from(vec![Some(true), Some(false), None]); + let b = cast(&array, &DataType::Binary).unwrap(); + let c = b.as_any().downcast_ref::>().unwrap(); + + let expected = BinaryArray::::from(&[Some("1"), Some("0"), Some("0")]); + assert_eq!(c, &expected); +} + #[test] #[should_panic(expected = "Casting from Int32 to Timestamp(Microsecond, None) not supported")] fn int32_to_timestamp() { diff --git a/tests/it/io/json/write.rs b/tests/it/io/json/write.rs index 93c6fb27f8d..f2131778503 100644 --- a/tests/it/io/json/write.rs +++ b/tests/it/io/json/write.rs @@ -196,7 +196,7 @@ fn write_list_of_struct() { vec![ Arc::new(Int32Array::from(&[Some(1), None, Some(5)])), Arc::new(StructArray::from_data( - DataType::Struct(inner), + DataType::Struct(inner), vec![Arc::new(Utf8Array::::from(&vec![ Some("e"), Some("f"),