diff --git a/src/compute/cast/binary_to.rs b/src/compute/cast/binary_to.rs index aee7c8fcabb..44224321b55 100644 --- a/src/compute/cast/binary_to.rs +++ b/src/compute/cast/binary_to.rs @@ -36,6 +36,37 @@ pub fn binary_large_to_binary( )) } +/// Conversion to utf8: hands clones of `from`'s offsets, values and validity to `Utf8Array::try_new` and returns its `Result` unchanged. NOTE(review): presumably the error case is values that are not valid utf8, as documented on `binary_to_large_utf8`; confirm. +pub fn binary_to_utf8( + from: &BinaryArray, + to_data_type: DataType, +) -> Result> { + Utf8Array::::try_new( + to_data_type, + from.offsets().clone(), + from.values().clone(), + from.validity().cloned(), + ) +} + +/// Conversion to large utf8: every offset of `from` is cast with `as i64` and collected before the array is rebuilt through `Utf8Array::try_new`; values and validity are cloned from `from`. +/// # Errors +/// This function errors if the values are not valid utf8 +pub fn binary_to_large_utf8( + from: &BinaryArray, + to_data_type: DataType, +) -> Result> { + let values = from.values().clone(); + let offsets = from + .offsets() + .iter() + .map(|x| *x as i64) + .collect::>() + .into(); + + Utf8Array::::try_new(to_data_type, offsets, values, from.validity().cloned()) +} + /// Casts a [`BinaryArray`] to a [`PrimitiveArray`] at best-effort using `lexical_core::parse_partial`, making any uncastable value as zero. 
pub fn partial_binary_to_primitive( from: &BinaryArray, diff --git a/src/compute/cast/mod.rs b/src/compute/cast/mod.rs index d3c37cd73fb..d58210fb9c9 100644 --- a/src/compute/cast/mod.rs +++ b/src/compute/cast/mod.rs @@ -106,23 +106,31 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { || to_type == &LargeBinary } - (Utf8, Date32) => true, - (Utf8, Date64) => true, - (Utf8, Timestamp(TimeUnit::Nanosecond, _)) => true, - (Utf8, LargeUtf8) => true, - (Utf8, _) => is_numeric(to_type), - (LargeUtf8, Date32) => true, - (LargeUtf8, Date64) => true, - (LargeUtf8, Timestamp(TimeUnit::Nanosecond, _)) => true, - (LargeUtf8, Utf8) => true, - (LargeUtf8, _) => is_numeric(to_type), + (Utf8, to_type) => { + is_numeric(to_type) + || matches!( + to_type, + LargeUtf8 | Binary | Date32 | Date64 | Timestamp(TimeUnit::Nanosecond, _) + ) + } + (LargeUtf8, to_type) => { + is_numeric(to_type) + || matches!( + to_type, + Utf8 | LargeBinary | Date32 | Date64 | Timestamp(TimeUnit::Nanosecond, _) + ) + } + + (Binary, to_type) => { + is_numeric(to_type) || matches!(to_type, LargeBinary | Utf8 | LargeUtf8) + } + (LargeBinary, to_type) => is_numeric(to_type) || matches!(to_type, Binary | LargeUtf8), + (Timestamp(_, _), Utf8) => true, (Timestamp(_, _), LargeUtf8) => true, (_, Utf8) => is_numeric(from_type) || from_type == &Binary, - (_, LargeUtf8) => is_numeric(from_type) || from_type == &Binary, + (_, LargeUtf8) => is_numeric(from_type) || from_type == &LargeBinary, - (Binary, _) => is_numeric(to_type) || to_type == &LargeBinary, - (LargeBinary, _) => is_numeric(to_type) || to_type == &Binary, (_, Binary) => is_numeric(from_type), (_, LargeBinary) => is_numeric(from_type), @@ -380,22 +388,18 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu )), (List(_), List(_)) => { cast_list::(array.as_any().downcast_ref().unwrap(), to_type, options) - .map(|x| Box::new(x) as Box) + .map(|x| x.boxed()) } (LargeList(_), LargeList(_)) => { 
cast_list::(array.as_any().downcast_ref().unwrap(), to_type, options) - .map(|x| Box::new(x) as Box) + .map(|x| x.boxed()) + } + (List(lhs), LargeList(rhs)) if lhs == rhs => { + Ok(cast_list_to_large_list(array.as_any().downcast_ref().unwrap(), to_type).boxed()) + } + (LargeList(lhs), List(rhs)) if lhs == rhs => { + Ok(cast_large_to_list(array.as_any().downcast_ref().unwrap(), to_type).boxed()) } - (List(lhs), LargeList(rhs)) if lhs == rhs => Ok(cast_list_to_large_list( - array.as_any().downcast_ref().unwrap(), - to_type, - )) - .map(|x| Box::new(x) as Box), - (LargeList(lhs), List(rhs)) if lhs == rhs => Ok(cast_large_to_list( - array.as_any().downcast_ref().unwrap(), - to_type, - )) - .map(|x| Box::new(x) as Box), (_, List(to)) => { // cast primitive to list's primitive @@ -467,6 +471,11 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu LargeUtf8 => Ok(Box::new(utf8_to_large_utf8( array.as_any().downcast_ref().unwrap(), ))), + Binary => Ok(utf8_to_binary::( + array.as_any().downcast_ref().unwrap(), + to_type.clone(), + ) + .boxed()), Timestamp(TimeUnit::Nanosecond, None) => utf8_to_naive_timestamp_ns_dyn::(array), Timestamp(TimeUnit::Nanosecond, Some(tz)) => { utf8_to_timestamp_ns_dyn::(array, tz.clone()) @@ -489,8 +498,12 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu Float64 => utf8_to_primitive_dyn::(array, to_type, options), Date32 => utf8_to_date32_dyn::(array), Date64 => utf8_to_date64_dyn::(array), - Utf8 => utf8_large_to_utf8(array.as_any().downcast_ref().unwrap()) - .map(|x| Box::new(x) as Box), + Utf8 => utf8_large_to_utf8(array.as_any().downcast_ref().unwrap()).map(|x| x.boxed()), + LargeBinary => Ok(utf8_to_binary::( + array.as_any().downcast_ref().unwrap(), + to_type.clone(), + ) + .boxed()), Timestamp(TimeUnit::Nanosecond, None) => utf8_to_naive_timestamp_ns_dyn::(array), Timestamp(TimeUnit::Nanosecond, Some(tz)) => { utf8_to_timestamp_ns_dyn::(array, tz.clone()) @@ -548,16 
+561,11 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu Int64 => primitive_to_utf8_dyn::(array), Float32 => primitive_to_utf8_dyn::(array), Float64 => primitive_to_utf8_dyn::(array), - Binary => { - let array = array.as_any().downcast_ref::>().unwrap(); - - // perf todo: the offsets are equal; we can speed-up this - let iter = array - .iter() - .map(|x| x.and_then(|x| simdutf8::basic::from_utf8(x).ok())); - - let array = Utf8Array::::from_trusted_len_iter(iter); - Ok(Box::new(array)) + Binary => binary_to_large_utf8(array.as_any().downcast_ref().unwrap(), to_type.clone()) + .map(|x| x.boxed()), + LargeBinary => { + binary_to_utf8::(array.as_any().downcast_ref().unwrap(), to_type.clone()) + .map(|x| x.boxed()) } Timestamp(from_unit, Some(tz)) => { let from = array.as_any().downcast_ref().unwrap(); @@ -588,6 +596,9 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu array.as_any().downcast_ref().unwrap(), to_type.clone(), ))), + Utf8 => binary_to_utf8::(array.as_any().downcast_ref().unwrap(), to_type.clone()) + .map(|x| x.boxed()), + _ => Err(Error::NotYetImplemented(format!( "Casting from {:?} to {:?} not supported", from_type, to_type, @@ -607,7 +618,11 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu Float64 => binary_to_primitive_dyn::(array, to_type, options), Binary => { binary_large_to_binary(array.as_any().downcast_ref().unwrap(), to_type.clone()) - .map(|x| Box::new(x) as Box) + .map(|x| x.boxed()) + } + LargeUtf8 => { + binary_to_utf8::(array.as_any().downcast_ref().unwrap(), to_type.clone()) + .map(|x| x.boxed()) } _ => Err(Error::NotYetImplemented(format!( "Casting from {:?} to {:?} not supported", diff --git a/src/compute/cast/utf8_to.rs b/src/compute/cast/utf8_to.rs index d6a235500e0..996889174a2 100644 --- a/src/compute/cast/utf8_to.rs +++ b/src/compute/cast/utf8_to.rs @@ -175,3 +175,15 @@ pub fn utf8_large_to_utf8(from: &Utf8Array) -> Result> { // 
Safety: sound because `offsets` fulfills the same invariants as `from.offsets()` Ok(unsafe { Utf8Array::::from_data_unchecked(data_type, offsets, values, validity) }) } + +/// Conversion to binary: calls `BinaryArray::new_unchecked` (inside an `unsafe` block) on clones of `from`'s offsets, values and validity and returns the array directly, not a `Result`. NOTE(review): presumably sound because the offsets of an existing utf8 array already satisfy the binary array's invariants; confirm and record that as a `// SAFETY:` comment on the block. +pub fn utf8_to_binary(from: &Utf8Array, to_data_type: DataType) -> BinaryArray { + unsafe { + BinaryArray::::new_unchecked( + to_data_type, + from.offsets().clone(), + from.values().clone(), + from.validity().cloned(), + ) + } +} diff --git a/tests/it/compute/cast.rs b/tests/it/compute/cast.rs index 9e9888926d3..a9a7c45e341 100644 --- a/tests/it/compute/cast.rs +++ b/tests/it/compute/cast.rs @@ -481,7 +481,10 @@ fn consistency() { if let Ok(result) = result { assert_eq!(result.data_type(), d2, "type not equal: {:?} {:?}", d1, d2); } else { - panic!("Cast should have not failed {:?} {:?}", d1, d2); + panic!( + "Cast should have not have failed {:?} {:?}: {:?}", + d1, d2, result + ); } } else if cast(array.as_ref(), d2, CastOptions::default()).is_ok() { panic!("Cast should have failed {:?} {:?}", d1, d2);