Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Compute: add partial option into CastOptions #561

Merged
merged 5 commits into from
Nov 4, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions src/compute/cast/binary_to.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use std::convert::TryFrom;
use crate::error::{ArrowError, Result};
use crate::{array::*, buffer::Buffer, datatypes::DataType, types::NativeType};

use super::CastOptions;

/// Conversion of binary
pub fn binary_to_large_binary(from: &BinaryArray<i32>, to_data_type: DataType) -> BinaryArray<i64> {
let values = from.values().clone();
Expand Down Expand Up @@ -31,26 +33,35 @@ pub fn binary_large_to_binary(
}

/// Casts a [`BinaryArray`] to a [`PrimitiveArray`], making any uncastable value a Null.
pub fn binary_to_primitive<O: Offset, T>(from: &BinaryArray<O>, to: &DataType) -> PrimitiveArray<T>
pub fn binary_to_primitive<O: Offset, T>(
from: &BinaryArray<O>,
to: &DataType,
options: CastOptions,
) -> PrimitiveArray<T>
where
T: NativeType + lexical_core::FromLexical,
{
let iter = from
.iter()
.map(|x| x.and_then::<T, _>(|x| lexical_core::parse(x).ok()));
let parse_fn = if !options.partial {
jorgecarleitao marked this conversation as resolved.
Show resolved Hide resolved
|x| lexical_core::parse(x).ok()
} else {
|x| lexical_core::parse_partial(x).ok().map(|x| x.0)
};

let iter = from.iter().map(|x| x.and_then::<T, _>(parse_fn));

PrimitiveArray::<T>::from_trusted_len_iter(iter).to(to.clone())
}

pub(super) fn binary_to_primitive_dyn<O: Offset, T>(
from: &dyn Array,
to: &DataType,
options: CastOptions,
) -> Result<Box<dyn Array>>
where
T: NativeType + lexical_core::FromLexical,
{
let from = from.as_any().downcast_ref().unwrap();
Ok(Box::new(binary_to_primitive::<O, T>(from, to)))
Ok(Box::new(binary_to_primitive::<O, T>(from, to, options)))
}

/// Cast [`BinaryArray`] to [`DictionaryArray`], also known as packing.
Expand Down
11 changes: 9 additions & 2 deletions src/compute/cast/dictionary_to.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,15 @@ pub fn wrapping_dictionary_to_dictionary_values<K: DictionaryKey>(
let keys = from.keys();
let values = from.values();

let values =
cast_with_options(values.as_ref(), values_type, CastOptions { wrapped: true })?.into();
let values = cast_with_options(
values.as_ref(),
values_type,
CastOptions {
wrapped: true,
partial: false,
},
)?
.into();
Ok(DictionaryArray::from_data(keys.clone(), values))
}

Expand Down
110 changes: 67 additions & 43 deletions src/compute/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ pub use utf8_to::*;

/// options defining how Cast kernels behave
#[derive(Clone, Copy, Debug, Default)]
struct CastOptions {
pub struct CastOptions {
/// default to false
/// whether an overflowing cast should be converted to `None` (default), or be wrapped (i.e. `256i16 as u8 = 0` vectorized).
/// Setting this to `true` is 5-6x faster for numeric types.
wrapped: bool,
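/// Defaults to `false`.
/// Whether utf8/binary values are parsed into numbers on a best-effort basis, i.e. the leading
/// numeric part of each value is parsed and any trailing bytes are ignored.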
partial: bool,
}

impl CastOptions {
Expand Down Expand Up @@ -324,14 +327,35 @@ fn cast_large_to_list(array: &ListArray<i64>, to_type: &DataType) -> ListArray<i
/// * Utf8 to boolean
/// * Interval and duration
pub fn cast(array: &dyn Array, to_type: &DataType) -> Result<Box<dyn Array>> {
cast_with_options(array, to_type, CastOptions { wrapped: false })
cast_with_options(array, to_type, CastOptions::default())
}

/// Similar to [`cast`], but overflowing cast is wrapped
/// Behavior:
/// * PrimitiveArray to PrimitiveArray: overflowing cast will be wrapped (i.e. `256i16 as u8 = 0` vectorized).
pub fn wrapping_cast(array: &dyn Array, to_type: &DataType) -> Result<Box<dyn Array>> {
cast_with_options(array, to_type, CastOptions { wrapped: true })
cast_with_options(
array,
to_type,
CastOptions {
wrapped: true,
partial: false,
},
)
}

/// Similar to [`cast`], but parses utf8/binary values into numbers on a best-effort basis.
/// Behavior:
/// * Utf8/Binary to PrimitiveArray: the leading numeric part of each value is parsed and any trailing bytes are ignored (i.e. `"123 abc"` is cast to `123`).
jorgecarleitao marked this conversation as resolved.
Show resolved Hide resolved
pub fn partial_cast(array: &dyn Array, to_type: &DataType) -> Result<Box<dyn Array>> {
    // Only the prefix-parsing flag is switched on; overflow wrapping stays disabled.
    let options = CastOptions {
        partial: true,
        wrapped: false,
    };
    cast_with_options(array, to_type, options)
}

#[inline]
Expand Down Expand Up @@ -451,16 +475,16 @@ fn cast_with_options(
},

(Utf8, _) => match to_type {
UInt8 => utf8_to_primitive_dyn::<i32, u8>(array, to_type),
UInt16 => utf8_to_primitive_dyn::<i32, u16>(array, to_type),
UInt32 => utf8_to_primitive_dyn::<i32, u32>(array, to_type),
UInt64 => utf8_to_primitive_dyn::<i32, u64>(array, to_type),
Int8 => utf8_to_primitive_dyn::<i32, i8>(array, to_type),
Int16 => utf8_to_primitive_dyn::<i32, i16>(array, to_type),
Int32 => utf8_to_primitive_dyn::<i32, i32>(array, to_type),
Int64 => utf8_to_primitive_dyn::<i32, i64>(array, to_type),
Float32 => utf8_to_primitive_dyn::<i32, f32>(array, to_type),
Float64 => utf8_to_primitive_dyn::<i32, f64>(array, to_type),
UInt8 => utf8_to_primitive_dyn::<i32, u8>(array, to_type, options),
UInt16 => utf8_to_primitive_dyn::<i32, u16>(array, to_type, options),
UInt32 => utf8_to_primitive_dyn::<i32, u32>(array, to_type, options),
UInt64 => utf8_to_primitive_dyn::<i32, u64>(array, to_type, options),
Int8 => utf8_to_primitive_dyn::<i32, i8>(array, to_type, options),
Int16 => utf8_to_primitive_dyn::<i32, i16>(array, to_type, options),
Int32 => utf8_to_primitive_dyn::<i32, i32>(array, to_type, options),
Int64 => utf8_to_primitive_dyn::<i32, i64>(array, to_type, options),
Float32 => utf8_to_primitive_dyn::<i32, f32>(array, to_type, options),
Float64 => utf8_to_primitive_dyn::<i32, f64>(array, to_type, options),
Date32 => utf8_to_date32_dyn::<i32>(array),
Date64 => utf8_to_date64_dyn::<i32>(array),
LargeUtf8 => Ok(Box::new(utf8_to_large_utf8(
Expand All @@ -476,16 +500,16 @@ fn cast_with_options(
))),
},
(LargeUtf8, _) => match to_type {
UInt8 => utf8_to_primitive_dyn::<i64, u8>(array, to_type),
UInt16 => utf8_to_primitive_dyn::<i64, u16>(array, to_type),
UInt32 => utf8_to_primitive_dyn::<i64, u32>(array, to_type),
UInt64 => utf8_to_primitive_dyn::<i64, u64>(array, to_type),
Int8 => utf8_to_primitive_dyn::<i64, i8>(array, to_type),
Int16 => utf8_to_primitive_dyn::<i64, i16>(array, to_type),
Int32 => utf8_to_primitive_dyn::<i64, i32>(array, to_type),
Int64 => utf8_to_primitive_dyn::<i64, i64>(array, to_type),
Float32 => utf8_to_primitive_dyn::<i64, f32>(array, to_type),
Float64 => utf8_to_primitive_dyn::<i64, f64>(array, to_type),
UInt8 => utf8_to_primitive_dyn::<i64, u8>(array, to_type, options),
UInt16 => utf8_to_primitive_dyn::<i64, u16>(array, to_type, options),
UInt32 => utf8_to_primitive_dyn::<i64, u32>(array, to_type, options),
UInt64 => utf8_to_primitive_dyn::<i64, u64>(array, to_type, options),
Int8 => utf8_to_primitive_dyn::<i64, i8>(array, to_type, options),
Int16 => utf8_to_primitive_dyn::<i64, i16>(array, to_type, options),
Int32 => utf8_to_primitive_dyn::<i64, i32>(array, to_type, options),
Int64 => utf8_to_primitive_dyn::<i64, i64>(array, to_type, options),
Float32 => utf8_to_primitive_dyn::<i64, f32>(array, to_type, options),
Float64 => utf8_to_primitive_dyn::<i64, f64>(array, to_type, options),
Date32 => utf8_to_date32_dyn::<i64>(array),
Date64 => utf8_to_date64_dyn::<i64>(array),
Utf8 => utf8_large_to_utf8(array.as_any().downcast_ref().unwrap())
Expand Down Expand Up @@ -573,16 +597,16 @@ fn cast_with_options(
},

(Binary, _) => match to_type {
UInt8 => binary_to_primitive_dyn::<i32, u8>(array, to_type),
UInt16 => binary_to_primitive_dyn::<i32, u16>(array, to_type),
UInt32 => binary_to_primitive_dyn::<i32, u32>(array, to_type),
UInt64 => binary_to_primitive_dyn::<i32, u64>(array, to_type),
Int8 => binary_to_primitive_dyn::<i32, i8>(array, to_type),
Int16 => binary_to_primitive_dyn::<i32, i16>(array, to_type),
Int32 => binary_to_primitive_dyn::<i32, i32>(array, to_type),
Int64 => binary_to_primitive_dyn::<i32, i64>(array, to_type),
Float32 => binary_to_primitive_dyn::<i32, f32>(array, to_type),
Float64 => binary_to_primitive_dyn::<i32, f64>(array, to_type),
UInt8 => binary_to_primitive_dyn::<i32, u8>(array, to_type, options),
UInt16 => binary_to_primitive_dyn::<i32, u16>(array, to_type, options),
UInt32 => binary_to_primitive_dyn::<i32, u32>(array, to_type, options),
UInt64 => binary_to_primitive_dyn::<i32, u64>(array, to_type, options),
Int8 => binary_to_primitive_dyn::<i32, i8>(array, to_type, options),
Int16 => binary_to_primitive_dyn::<i32, i16>(array, to_type, options),
Int32 => binary_to_primitive_dyn::<i32, i32>(array, to_type, options),
Int64 => binary_to_primitive_dyn::<i32, i64>(array, to_type, options),
Float32 => binary_to_primitive_dyn::<i32, f32>(array, to_type, options),
Float64 => binary_to_primitive_dyn::<i32, f64>(array, to_type, options),
LargeBinary => Ok(Box::new(binary_to_large_binary(
array.as_any().downcast_ref().unwrap(),
to_type.clone(),
Expand All @@ -594,16 +618,16 @@ fn cast_with_options(
},

(LargeBinary, _) => match to_type {
UInt8 => binary_to_primitive_dyn::<i64, u8>(array, to_type),
UInt16 => binary_to_primitive_dyn::<i64, u16>(array, to_type),
UInt32 => binary_to_primitive_dyn::<i64, u32>(array, to_type),
UInt64 => binary_to_primitive_dyn::<i64, u64>(array, to_type),
Int8 => binary_to_primitive_dyn::<i64, i8>(array, to_type),
Int16 => binary_to_primitive_dyn::<i64, i16>(array, to_type),
Int32 => binary_to_primitive_dyn::<i64, i32>(array, to_type),
Int64 => binary_to_primitive_dyn::<i64, i64>(array, to_type),
Float32 => binary_to_primitive_dyn::<i64, f32>(array, to_type),
Float64 => binary_to_primitive_dyn::<i64, f64>(array, to_type),
UInt8 => binary_to_primitive_dyn::<i64, u8>(array, to_type, options),
UInt16 => binary_to_primitive_dyn::<i64, u16>(array, to_type, options),
UInt32 => binary_to_primitive_dyn::<i64, u32>(array, to_type, options),
UInt64 => binary_to_primitive_dyn::<i64, u64>(array, to_type, options),
Int8 => binary_to_primitive_dyn::<i64, i8>(array, to_type, options),
Int16 => binary_to_primitive_dyn::<i64, i16>(array, to_type, options),
Int32 => binary_to_primitive_dyn::<i64, i32>(array, to_type, options),
Int64 => binary_to_primitive_dyn::<i64, i64>(array, to_type, options),
Float32 => binary_to_primitive_dyn::<i64, f32>(array, to_type, options),
Float64 => binary_to_primitive_dyn::<i64, f64>(array, to_type, options),
Binary => {
binary_large_to_binary(array.as_any().downcast_ref().unwrap(), to_type.clone())
.map(|x| Box::new(x) as Box<dyn Array>)
Expand Down
19 changes: 16 additions & 3 deletions src/compute/cast/utf8_to.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,42 @@ use crate::{
},
};

use super::CastOptions;

const RFC3339: &str = "%Y-%m-%dT%H:%M:%S%.f%:z";

/// Casts a [`Utf8Array`] to a [`PrimitiveArray`], making any uncastable value a Null.
pub fn utf8_to_primitive<O: Offset, T>(from: &Utf8Array<O>, to: &DataType) -> PrimitiveArray<T>
jorgecarleitao marked this conversation as resolved.
Show resolved Hide resolved
pub fn utf8_to_primitive<O: Offset, T>(
from: &Utf8Array<O>,
to: &DataType,
options: CastOptions,
) -> PrimitiveArray<T>
where
T: NativeType + lexical_core::FromLexical,
{
let parse_fn = if !options.partial {
jorgecarleitao marked this conversation as resolved.
Show resolved Hide resolved
|x| lexical_core::parse(x).ok()
} else {
|x| lexical_core::parse_partial(x).ok().map(|x| x.0)
};

let iter = from
.iter()
.map(|x| x.and_then::<T, _>(|x| lexical_core::parse(x.as_bytes()).ok()));
.map(|x| x.and_then::<T, _>(|x| parse_fn(x.as_bytes())));

PrimitiveArray::<T>::from_trusted_len_iter(iter).to(to.clone())
}

pub(super) fn utf8_to_primitive_dyn<O: Offset, T>(
from: &dyn Array,
to: &DataType,
options: CastOptions,
) -> Result<Box<dyn Array>>
where
T: NativeType + lexical_core::FromLexical,
{
let from = from.as_any().downcast_ref().unwrap();
Ok(Box::new(utf8_to_primitive::<O, T>(from, to)))
Ok(Box::new(utf8_to_primitive::<O, T>(from, to, options)))
}

/// Casts a [`Utf8Array`] to a Date32 primitive, making any uncastable value a Null.
Expand Down
13 changes: 12 additions & 1 deletion tests/it/compute/cast.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use arrow2::array::*;
use arrow2::compute::cast::{can_cast_types, cast, wrapping_cast};
use arrow2::compute::cast::{can_cast_types, cast, partial_cast, wrapping_cast};
use arrow2::datatypes::*;
use arrow2::types::NativeType;

Expand Down Expand Up @@ -173,6 +173,17 @@ fn binary_to_i32() {
assert_eq!(c, &expected);
}

#[test]
fn binary_to_i32_partial() {
    // Inputs mix clean integers, a number followed by text, pure text and a decimal.
    let input = BinaryArray::<i32>::from_slice(&["5", "6", "123 abseven", "aaa", "9.1"]);
    let casted = partial_cast(&input, &DataType::Int32).unwrap();
    let actual = casted
        .as_any()
        .downcast_ref::<PrimitiveArray<i32>>()
        .unwrap();

    // Every slot is expected to be non-null: the leading digits are kept, and "aaa" yields 0.
    let expected_values = &[Some(5), Some(6), Some(123), Some(0), Some(9)];
    let expected = Int32Array::from(expected_values);
    assert_eq!(actual, &expected);
}

#[test]
fn utf8_to_i32() {
let array = Utf8Array::<i32>::from_slice(&["5", "6", "seven", "8", "9.1"]);
Expand Down