Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Add partial option to CastOptions (#561)
Browse files Browse the repository at this point in the history
  • Loading branch information
sundy-li authored Nov 4, 2021
1 parent ed8836f commit 3c10b16
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 103 deletions.
24 changes: 23 additions & 1 deletion src/compute/cast/binary_to.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use std::convert::TryFrom;
use crate::error::{ArrowError, Result};
use crate::{array::*, buffer::Buffer, datatypes::DataType, types::NativeType};

use super::CastOptions;

/// Conversion of binary
pub fn binary_to_large_binary(from: &BinaryArray<i32>, to_data_type: DataType) -> BinaryArray<i64> {
let values = from.values().clone();
Expand Down Expand Up @@ -30,6 +32,21 @@ pub fn binary_large_to_binary(
))
}

/// Best-effort cast of a [`BinaryArray`] to a [`PrimitiveArray`] using `lexical_core::parse_partial`,
/// making any uncastable value zero.
///
/// NOTE(review): the "zero" outcome relies on `parse_partial` returning `Ok` with a zero value for
/// input it cannot parse — confirm against the `lexical_core` documentation.
///
/// For each non-null slot, the parsed number is kept and the second tuple element returned by
/// `parse_partial` is discarded. Null slots, and slots for which `parse_partial` returns an error,
/// yield `None`. The resulting array is tagged with a clone of `to`.
pub fn partial_binary_to_primitive<O: Offset, T>(
    from: &BinaryArray<O>,
    to: &DataType,
) -> PrimitiveArray<T>
where
    T: NativeType + lexical_core::FromLexical,
{
    let parsed = from.iter().map(|slot| -> Option<T> {
        // A null input slot stays null.
        let bytes = slot?;
        match lexical_core::parse_partial(bytes) {
            Ok((value, _)) => Some(value),
            Err(_) => None,
        }
    });

    PrimitiveArray::<T>::from_trusted_len_iter(parsed).to(to.clone())
}

/// Casts a [`BinaryArray`] to a [`PrimitiveArray`], making any uncastable value a Null.
pub fn binary_to_primitive<O: Offset, T>(from: &BinaryArray<O>, to: &DataType) -> PrimitiveArray<T>
where
Expand All @@ -45,12 +62,17 @@ where
/// Type-erased wrapper that downcasts `from` and casts it to a boxed [`PrimitiveArray`] of `T`.
///
/// When `options.partial` is set the conversion goes through [`partial_binary_to_primitive`],
/// otherwise through [`binary_to_primitive`]. The result is always wrapped in `Ok`.
///
/// # Panics
/// Panics when `downcast_ref` on `from` returns `None`, because the result is `unwrap`ped.
pub(super) fn binary_to_primitive_dyn<O: Offset, T>(
from: &dyn Array,
to: &DataType,
options: CastOptions,
) -> Result<Box<dyn Array>>
where
T: NativeType + lexical_core::FromLexical,
{
// The concrete array type is inferred from the callee's parameter type.
let from = from.as_any().downcast_ref().unwrap();
// NOTE(review): this unconditional result and the `if`/`else` below look like the before/after
// lines of a rendered diff — confirm which form the real file keeps.
Ok(Box::new(binary_to_primitive::<O, T>(from, to)))
if options.partial {
Ok(Box::new(partial_binary_to_primitive::<O, T>(from, to)))
} else {
Ok(Box::new(binary_to_primitive::<O, T>(from, to)))
}
}

/// Cast [`BinaryArray`] to [`DictionaryArray`], also known as packing.
Expand Down
19 changes: 13 additions & 6 deletions src/compute/cast/dictionary_to.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{primitive_as_primitive, primitive_to_primitive, CastOptions};
use crate::{
array::{Array, DictionaryArray, DictionaryKey, PrimitiveArray},
compute::{cast::cast_with_options, take::take},
compute::{cast::cast, take::take},
datatypes::DataType,
error::{ArrowError, Result},
};
Expand Down Expand Up @@ -32,7 +32,7 @@ pub fn dictionary_to_dictionary_values<K: DictionaryKey>(
let keys = from.keys();
let values = from.values();

let values = cast_with_options(values.as_ref(), values_type, CastOptions::default())?.into();
let values = cast(values.as_ref(), values_type, CastOptions::default())?.into();
Ok(DictionaryArray::from_data(keys.clone(), values))
}

Expand All @@ -44,8 +44,15 @@ pub fn wrapping_dictionary_to_dictionary_values<K: DictionaryKey>(
let keys = from.keys();
let values = from.values();

let values =
cast_with_options(values.as_ref(), values_type, CastOptions { wrapped: true })?.into();
let values = cast(
values.as_ref(),
values_type,
CastOptions {
wrapped: true,
partial: false,
},
)?
.into();
Ok(DictionaryArray::from_data(keys.clone(), values))
}

Expand Down Expand Up @@ -104,7 +111,7 @@ pub(super) fn dictionary_cast_dyn<K: DictionaryKey>(

match to_type {
DataType::Dictionary(to_keys_type, to_values_type) => {
let values = cast_with_options(values.as_ref(), to_values_type, options)?.into();
let values = cast(values.as_ref(), to_values_type, options)?.into();

// create the appropriate array type
with_match_dictionary_key_type!(to_keys_type.as_ref(), |$T| {
Expand All @@ -127,7 +134,7 @@ where
{
// attempt to cast the dict values to the target type
// use the take kernel to expand out the dictionary
let values = cast_with_options(values, to_type, options)?;
let values = cast(values, to_type, options)?;

// take requires first casting i32
let indices = primitive_to_primitive::<_, i32>(keys, &DataType::Int32);
Expand Down
111 changes: 49 additions & 62 deletions src/compute/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ pub use utf8_to::*;

/// Options defining how the cast kernels behave.
#[derive(Clone, Copy, Debug, Default)]
struct CastOptions {
pub struct CastOptions {
/// Defaults to `false`.
/// Whether an overflowing cast should be converted to `None` (`false`, the default) or be wrapped (`true`, i.e. `256i16 as u8 = 0` vectorized).
/// Setting this to `true` is 5-6x faster for numeric types.
wrapped: bool,
pub wrapped: bool,
/// Defaults to `false`.
/// Whether binary/UTF-8 to primitive casts parse on a best-effort basis
/// (via the `partial_*_to_primitive` functions, which use `lexical_core::parse_partial`).
pub partial: bool,
}

impl CastOptions {
Expand Down Expand Up @@ -262,7 +265,7 @@ fn cast_list<O: Offset>(
options: CastOptions,
) -> Result<ListArray<O>> {
let values = array.values();
let new_values = cast_with_options(
let new_values = cast(
values.as_ref(),
ListArray::<O>::get_child_type(to_type),
options,
Expand Down Expand Up @@ -323,23 +326,7 @@ fn cast_large_to_list(array: &ListArray<i64>, to_type: &DataType) -> ListArray<i
/// * List to primitive
/// * Utf8 to boolean
/// * Interval and duration
pub fn cast(array: &dyn Array, to_type: &DataType) -> Result<Box<dyn Array>> {
cast_with_options(array, to_type, CastOptions { wrapped: false })
}

/// Similar to [`cast`], but overflowing cast is wrapped
/// Behavior:
/// * PrimitiveArray to PrimitiveArray: overflowing cast will be wrapped (i.e. `256i16 as u8 = 0` vectorized).
pub fn wrapping_cast(array: &dyn Array, to_type: &DataType) -> Result<Box<dyn Array>> {
cast_with_options(array, to_type, CastOptions { wrapped: true })
}

#[inline]
fn cast_with_options(
array: &dyn Array,
to_type: &DataType,
options: CastOptions,
) -> Result<Box<dyn Array>> {
pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Result<Box<dyn Array>> {
use DataType::*;
let from_type = array.data_type();

Expand Down Expand Up @@ -378,7 +365,7 @@ fn cast_with_options(

(_, List(to)) => {
// cast primitive to list's primitive
let values = cast_with_options(array, to.data_type(), options)?.into();
let values = cast(array, to.data_type(), options)?.into();
// create offsets, where if array.len() = 2, we have [0,1,2]
let offsets =
unsafe { Buffer::from_trusted_len_iter_unchecked(0..=array.len() as i32) };
Expand Down Expand Up @@ -451,16 +438,16 @@ fn cast_with_options(
},

(Utf8, _) => match to_type {
UInt8 => utf8_to_primitive_dyn::<i32, u8>(array, to_type),
UInt16 => utf8_to_primitive_dyn::<i32, u16>(array, to_type),
UInt32 => utf8_to_primitive_dyn::<i32, u32>(array, to_type),
UInt64 => utf8_to_primitive_dyn::<i32, u64>(array, to_type),
Int8 => utf8_to_primitive_dyn::<i32, i8>(array, to_type),
Int16 => utf8_to_primitive_dyn::<i32, i16>(array, to_type),
Int32 => utf8_to_primitive_dyn::<i32, i32>(array, to_type),
Int64 => utf8_to_primitive_dyn::<i32, i64>(array, to_type),
Float32 => utf8_to_primitive_dyn::<i32, f32>(array, to_type),
Float64 => utf8_to_primitive_dyn::<i32, f64>(array, to_type),
UInt8 => utf8_to_primitive_dyn::<i32, u8>(array, to_type, options),
UInt16 => utf8_to_primitive_dyn::<i32, u16>(array, to_type, options),
UInt32 => utf8_to_primitive_dyn::<i32, u32>(array, to_type, options),
UInt64 => utf8_to_primitive_dyn::<i32, u64>(array, to_type, options),
Int8 => utf8_to_primitive_dyn::<i32, i8>(array, to_type, options),
Int16 => utf8_to_primitive_dyn::<i32, i16>(array, to_type, options),
Int32 => utf8_to_primitive_dyn::<i32, i32>(array, to_type, options),
Int64 => utf8_to_primitive_dyn::<i32, i64>(array, to_type, options),
Float32 => utf8_to_primitive_dyn::<i32, f32>(array, to_type, options),
Float64 => utf8_to_primitive_dyn::<i32, f64>(array, to_type, options),
Date32 => utf8_to_date32_dyn::<i32>(array),
Date64 => utf8_to_date64_dyn::<i32>(array),
LargeUtf8 => Ok(Box::new(utf8_to_large_utf8(
Expand All @@ -476,16 +463,16 @@ fn cast_with_options(
))),
},
(LargeUtf8, _) => match to_type {
UInt8 => utf8_to_primitive_dyn::<i64, u8>(array, to_type),
UInt16 => utf8_to_primitive_dyn::<i64, u16>(array, to_type),
UInt32 => utf8_to_primitive_dyn::<i64, u32>(array, to_type),
UInt64 => utf8_to_primitive_dyn::<i64, u64>(array, to_type),
Int8 => utf8_to_primitive_dyn::<i64, i8>(array, to_type),
Int16 => utf8_to_primitive_dyn::<i64, i16>(array, to_type),
Int32 => utf8_to_primitive_dyn::<i64, i32>(array, to_type),
Int64 => utf8_to_primitive_dyn::<i64, i64>(array, to_type),
Float32 => utf8_to_primitive_dyn::<i64, f32>(array, to_type),
Float64 => utf8_to_primitive_dyn::<i64, f64>(array, to_type),
UInt8 => utf8_to_primitive_dyn::<i64, u8>(array, to_type, options),
UInt16 => utf8_to_primitive_dyn::<i64, u16>(array, to_type, options),
UInt32 => utf8_to_primitive_dyn::<i64, u32>(array, to_type, options),
UInt64 => utf8_to_primitive_dyn::<i64, u64>(array, to_type, options),
Int8 => utf8_to_primitive_dyn::<i64, i8>(array, to_type, options),
Int16 => utf8_to_primitive_dyn::<i64, i16>(array, to_type, options),
Int32 => utf8_to_primitive_dyn::<i64, i32>(array, to_type, options),
Int64 => utf8_to_primitive_dyn::<i64, i64>(array, to_type, options),
Float32 => utf8_to_primitive_dyn::<i64, f32>(array, to_type, options),
Float64 => utf8_to_primitive_dyn::<i64, f64>(array, to_type, options),
Date32 => utf8_to_date32_dyn::<i64>(array),
Date64 => utf8_to_date64_dyn::<i64>(array),
Utf8 => utf8_large_to_utf8(array.as_any().downcast_ref().unwrap())
Expand Down Expand Up @@ -573,16 +560,16 @@ fn cast_with_options(
},

(Binary, _) => match to_type {
UInt8 => binary_to_primitive_dyn::<i32, u8>(array, to_type),
UInt16 => binary_to_primitive_dyn::<i32, u16>(array, to_type),
UInt32 => binary_to_primitive_dyn::<i32, u32>(array, to_type),
UInt64 => binary_to_primitive_dyn::<i32, u64>(array, to_type),
Int8 => binary_to_primitive_dyn::<i32, i8>(array, to_type),
Int16 => binary_to_primitive_dyn::<i32, i16>(array, to_type),
Int32 => binary_to_primitive_dyn::<i32, i32>(array, to_type),
Int64 => binary_to_primitive_dyn::<i32, i64>(array, to_type),
Float32 => binary_to_primitive_dyn::<i32, f32>(array, to_type),
Float64 => binary_to_primitive_dyn::<i32, f64>(array, to_type),
UInt8 => binary_to_primitive_dyn::<i32, u8>(array, to_type, options),
UInt16 => binary_to_primitive_dyn::<i32, u16>(array, to_type, options),
UInt32 => binary_to_primitive_dyn::<i32, u32>(array, to_type, options),
UInt64 => binary_to_primitive_dyn::<i32, u64>(array, to_type, options),
Int8 => binary_to_primitive_dyn::<i32, i8>(array, to_type, options),
Int16 => binary_to_primitive_dyn::<i32, i16>(array, to_type, options),
Int32 => binary_to_primitive_dyn::<i32, i32>(array, to_type, options),
Int64 => binary_to_primitive_dyn::<i32, i64>(array, to_type, options),
Float32 => binary_to_primitive_dyn::<i32, f32>(array, to_type, options),
Float64 => binary_to_primitive_dyn::<i32, f64>(array, to_type, options),
LargeBinary => Ok(Box::new(binary_to_large_binary(
array.as_any().downcast_ref().unwrap(),
to_type.clone(),
Expand All @@ -594,16 +581,16 @@ fn cast_with_options(
},

(LargeBinary, _) => match to_type {
UInt8 => binary_to_primitive_dyn::<i64, u8>(array, to_type),
UInt16 => binary_to_primitive_dyn::<i64, u16>(array, to_type),
UInt32 => binary_to_primitive_dyn::<i64, u32>(array, to_type),
UInt64 => binary_to_primitive_dyn::<i64, u64>(array, to_type),
Int8 => binary_to_primitive_dyn::<i64, i8>(array, to_type),
Int16 => binary_to_primitive_dyn::<i64, i16>(array, to_type),
Int32 => binary_to_primitive_dyn::<i64, i32>(array, to_type),
Int64 => binary_to_primitive_dyn::<i64, i64>(array, to_type),
Float32 => binary_to_primitive_dyn::<i64, f32>(array, to_type),
Float64 => binary_to_primitive_dyn::<i64, f64>(array, to_type),
UInt8 => binary_to_primitive_dyn::<i64, u8>(array, to_type, options),
UInt16 => binary_to_primitive_dyn::<i64, u16>(array, to_type, options),
UInt32 => binary_to_primitive_dyn::<i64, u32>(array, to_type, options),
UInt64 => binary_to_primitive_dyn::<i64, u64>(array, to_type, options),
Int8 => binary_to_primitive_dyn::<i64, i8>(array, to_type, options),
Int16 => binary_to_primitive_dyn::<i64, i16>(array, to_type, options),
Int32 => binary_to_primitive_dyn::<i64, i32>(array, to_type, options),
Int64 => binary_to_primitive_dyn::<i64, i64>(array, to_type, options),
Float32 => binary_to_primitive_dyn::<i64, f32>(array, to_type, options),
Float64 => binary_to_primitive_dyn::<i64, f64>(array, to_type, options),
Binary => {
binary_large_to_binary(array.as_any().downcast_ref().unwrap(), to_type.clone())
.map(|x| Box::new(x) as Box<dyn Array>)
Expand Down Expand Up @@ -821,7 +808,7 @@ fn cast_to_dictionary<K: DictionaryKey>(
dict_value_type: &DataType,
options: CastOptions,
) -> Result<Box<dyn Array>> {
let array = cast_with_options(array, dict_value_type, options)?;
let array = cast(array, dict_value_type, options)?;
let array = array.as_ref();
match *dict_value_type {
DataType::Int8 => primitive_to_dictionary_dyn::<i8, K>(array),
Expand Down
24 changes: 23 additions & 1 deletion src/compute/cast/utf8_to.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use crate::{
},
};

use super::CastOptions;

// NOTE(review): named after RFC 3339; the `%` specifiers look like a chrono-style strftime pattern
// (date, time, fractional seconds, UTC offset) — confirm against the code that uses this constant.
const RFC3339: &str = "%Y-%m-%dT%H:%M:%S%.f%:z";

/// Casts a [`Utf8Array`] to a [`PrimitiveArray`], making any uncastable value a Null.
Expand All @@ -25,15 +27,35 @@ where
PrimitiveArray::<T>::from_trusted_len_iter(iter).to(to.clone())
}

/// Best-effort cast of a [`Utf8Array`] to a [`PrimitiveArray`] using `lexical_core::parse_partial`,
/// making any uncastable value zero.
///
/// NOTE(review): the "zero" outcome relies on `parse_partial` returning `Ok` with a zero value for
/// input it cannot parse — confirm against the `lexical_core` documentation.
///
/// For each non-null slot, the string's bytes are parsed; the parsed number is kept and the second
/// tuple element returned by `parse_partial` is discarded. Null slots, and slots for which
/// `parse_partial` returns an error, yield `None`. The resulting array is tagged with a clone of `to`.
pub fn partial_utf8_to_primitive<O: Offset, T>(
    from: &Utf8Array<O>,
    to: &DataType,
) -> PrimitiveArray<T>
where
    T: NativeType + lexical_core::FromLexical,
{
    let parsed = from.iter().map(|slot| -> Option<T> {
        // A null input slot stays null.
        let text = slot?;
        match lexical_core::parse_partial(text.as_bytes()) {
            Ok((value, _)) => Some(value),
            Err(_) => None,
        }
    });

    PrimitiveArray::<T>::from_trusted_len_iter(parsed).to(to.clone())
}

/// Type-erased wrapper that downcasts `from` and casts it to a boxed [`PrimitiveArray`] of `T`.
///
/// When `options.partial` is set the conversion goes through [`partial_utf8_to_primitive`],
/// otherwise through [`utf8_to_primitive`]. The result is always wrapped in `Ok`.
///
/// # Panics
/// Panics when `downcast_ref` on `from` returns `None`, because the result is `unwrap`ped.
pub(super) fn utf8_to_primitive_dyn<O: Offset, T>(
from: &dyn Array,
to: &DataType,
options: CastOptions,
) -> Result<Box<dyn Array>>
where
T: NativeType + lexical_core::FromLexical,
{
// The concrete array type is inferred from the callee's parameter type.
let from = from.as_any().downcast_ref().unwrap();
// NOTE(review): this unconditional result and the `if`/`else` below look like the before/after
// lines of a rendered diff — confirm which form the real file keeps.
Ok(Box::new(utf8_to_primitive::<O, T>(from, to)))
if options.partial {
Ok(Box::new(partial_utf8_to_primitive::<O, T>(from, to)))
} else {
Ok(Box::new(utf8_to_primitive::<O, T>(from, to)))
}
}

/// Casts a [`Utf8Array`] to a Date32 primitive, making any uncastable value a Null.
Expand Down
Loading

0 comments on commit 3c10b16

Please sign in to comment.