Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Added support to cast decimal #761

Merged
merged 2 commits into from
Jan 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions src/compute/cast/decimal_to.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
use num_traits::{AsPrimitive, Float, NumCast};

use crate::error::Result;
use crate::types::NativeType;
use crate::{array::*, datatypes::DataType};

#[inline]
fn decimal_to_decimal_impl<F: Fn(i128) -> Option<i128>>(
from: &PrimitiveArray<i128>,
op: F,
to_precision: usize,
to_scale: usize,
) -> PrimitiveArray<i128> {
let min_for_precision = 9_i128
.saturating_pow(1 + to_precision as u32)
.saturating_neg();
let max_for_precision = 9_i128.saturating_pow(1 + to_precision as u32);

let values = from.iter().map(|x| {
x.and_then(|x| {
op(*x).and_then(|x| {
if x > max_for_precision || x < min_for_precision {
None
} else {
Some(x)
}
})
})
});
PrimitiveArray::<i128>::from_trusted_len_iter(values)
.to(DataType::Decimal(to_precision, to_scale))
}

/// Returns a [`PrimitiveArray<i128>`] with the casted values. Values are `None` on overflow
pub fn decimal_to_decimal(
from: &PrimitiveArray<i128>,
to_precision: usize,
to_scale: usize,
) -> PrimitiveArray<i128> {
let (from_precision, from_scale) =
if let DataType::Decimal(p, s) = from.data_type().to_logical_type() {
(*p, *s)
} else {
panic!("internal error: i128 is always a decimal")
};

if to_scale == from_scale && to_precision >= from_precision {
// fast path
return from.clone().to(DataType::Decimal(to_precision, to_scale));
}
// todo: other fast paths include increasing scale and precision by so that
// a number will never overflow (validity is preserved)

if from_scale > to_scale {
let factor = 10_i128.pow((from_scale - to_scale) as u32);
decimal_to_decimal_impl(
from,
|x: i128| x.checked_div(factor),
to_precision,
to_scale,
)
} else {
let factor = 10_i128.pow((to_scale - from_scale) as u32);
decimal_to_decimal_impl(
from,
|x: i128| x.checked_mul(factor),
to_precision,
to_scale,
)
}
}

pub(super) fn decimal_to_decimal_dyn(
from: &dyn Array,
to_precision: usize,
to_scale: usize,
) -> Result<Box<dyn Array>> {
let from = from.as_any().downcast_ref().unwrap();
Ok(Box::new(decimal_to_decimal(from, to_precision, to_scale)))
}

/// Returns a [`PrimitiveArray<i128>`] with the casted values. Values are `None` on overflow
pub fn decimal_to_float<T>(from: &PrimitiveArray<i128>) -> PrimitiveArray<T>
where
T: NativeType + Float,
f64: AsPrimitive<T>,
{
let (_, from_scale) = if let DataType::Decimal(p, s) = from.data_type().to_logical_type() {
(*p, *s)
} else {
panic!("internal error: i128 is always a decimal")
};

let div = 10_f64.powi(from_scale as i32);
let values = from
.values()
.iter()
.map(|x| (*x as f64 / div).as_())
.collect();

PrimitiveArray::<T>::from_data(T::PRIMITIVE.into(), values, from.validity().cloned())
}

pub(super) fn decimal_to_float_dyn<T>(from: &dyn Array) -> Result<Box<dyn Array>>
where
T: NativeType + Float,
f64: AsPrimitive<T>,
{
let from = from.as_any().downcast_ref().unwrap();
Ok(Box::new(decimal_to_float::<T>(from)))
}

/// Returns a [`PrimitiveArray<i128>`] with the casted values. Values are `None` on overflow
pub fn decimal_to_integer<T>(from: &PrimitiveArray<i128>) -> PrimitiveArray<T>
where
T: NativeType + NumCast,
{
let (_, from_scale) = if let DataType::Decimal(p, s) = from.data_type().to_logical_type() {
(*p, *s)
} else {
panic!("internal error: i128 is always a decimal")
};

let factor = 10_i128.pow(from_scale as u32);
let values = from.iter().map(|x| x.and_then(|x| T::from(*x / factor)));

PrimitiveArray::from_trusted_len_iter(values)
}

pub(super) fn decimal_to_integer_dyn<T>(from: &dyn Array) -> Result<Box<dyn Array>>
where
T: NativeType + NumCast,
{
let from = from.as_any().downcast_ref().unwrap();
Ok(Box::new(decimal_to_integer::<T>(from)))
}
61 changes: 55 additions & 6 deletions src/compute/cast/mod.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
//! Defines different casting operators such as [`cast`] or [`primitive_to_binary`].

use crate::{
array::*,
datatypes::*,
error::{ArrowError, Result},
};

mod binary_to;
mod boolean_to;
mod decimal_to;
mod dictionary_to;
mod primitive_to;
mod utf8_to;

pub use binary_to::*;
pub use boolean_to::*;
pub use decimal_to::*;
pub use dictionary_to::*;
pub use primitive_to::*;
pub use utf8_to::*;

use crate::{
array::*,
datatypes::*,
error::{ArrowError, Result},
};

/// options defining how Cast kernels behave
#[derive(Clone, Copy, Debug, Default)]
pub struct CastOptions {
Expand Down Expand Up @@ -143,6 +145,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(UInt8, Int64) => true,
(UInt8, Float32) => true,
(UInt8, Float64) => true,
(UInt8, Decimal(_, _)) => true,

(UInt16, UInt8) => true,
(UInt16, UInt32) => true,
Expand All @@ -153,6 +156,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(UInt16, Int64) => true,
(UInt16, Float32) => true,
(UInt16, Float64) => true,
(UInt16, Decimal(_, _)) => true,

(UInt32, UInt8) => true,
(UInt32, UInt16) => true,
Expand All @@ -163,6 +167,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(UInt32, Int64) => true,
(UInt32, Float32) => true,
(UInt32, Float64) => true,
(UInt32, Decimal(_, _)) => true,

(UInt64, UInt8) => true,
(UInt64, UInt16) => true,
Expand All @@ -173,6 +178,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(UInt64, Int64) => true,
(UInt64, Float32) => true,
(UInt64, Float64) => true,
(UInt64, Decimal(_, _)) => true,

(Int8, UInt8) => true,
(Int8, UInt16) => true,
Expand All @@ -183,6 +189,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Int8, Int64) => true,
(Int8, Float32) => true,
(Int8, Float64) => true,
(Int8, Decimal(_, _)) => true,

(Int16, UInt8) => true,
(Int16, UInt16) => true,
Expand All @@ -193,6 +200,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Int16, Int64) => true,
(Int16, Float32) => true,
(Int16, Float64) => true,
(Int16, Decimal(_, _)) => true,

(Int32, UInt8) => true,
(Int32, UInt16) => true,
Expand All @@ -203,6 +211,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Int32, Int64) => true,
(Int32, Float32) => true,
(Int32, Float64) => true,
(Int32, Decimal(_, _)) => true,

(Int64, UInt8) => true,
(Int64, UInt16) => true,
Expand All @@ -213,6 +222,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Int64, Int32) => true,
(Int64, Float32) => true,
(Int64, Float64) => true,
(Int64, Decimal(_, _)) => true,

(Float32, UInt8) => true,
(Float32, UInt16) => true,
Expand All @@ -223,6 +233,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Float32, Int32) => true,
(Float32, Int64) => true,
(Float32, Float64) => true,
(Float32, Decimal(_, _)) => true,

(Float64, UInt8) => true,
(Float64, UInt16) => true,
Expand All @@ -233,6 +244,22 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Float64, Int32) => true,
(Float64, Int64) => true,
(Float64, Float32) => true,
(Float64, Decimal(_, _)) => true,

(
Decimal(_, _),
UInt8
| UInt16
| UInt32
| UInt64
| Int8
| Int16
| Int32
| Int64
| Float32
| Float64
| Decimal(_, _),
) => true,
// end numeric casts

// temporal casts
Expand Down Expand Up @@ -649,6 +676,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(UInt8, Int64) => primitive_to_primitive_dyn::<u8, i64>(array, to_type, options),
(UInt8, Float32) => primitive_to_primitive_dyn::<u8, f32>(array, to_type, as_options),
(UInt8, Float64) => primitive_to_primitive_dyn::<u8, f64>(array, to_type, as_options),
(UInt8, Decimal(p, s)) => integer_to_decimal_dyn::<u8>(array, *p, *s),

(UInt16, UInt8) => primitive_to_primitive_dyn::<u16, u8>(array, to_type, options),
(UInt16, UInt32) => primitive_to_primitive_dyn::<u16, u32>(array, to_type, as_options),
Expand All @@ -659,6 +687,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(UInt16, Int64) => primitive_to_primitive_dyn::<u16, i64>(array, to_type, options),
(UInt16, Float32) => primitive_to_primitive_dyn::<u16, f32>(array, to_type, as_options),
(UInt16, Float64) => primitive_to_primitive_dyn::<u16, f64>(array, to_type, as_options),
(UInt16, Decimal(p, s)) => integer_to_decimal_dyn::<u16>(array, *p, *s),

(UInt32, UInt8) => primitive_to_primitive_dyn::<u32, u8>(array, to_type, options),
(UInt32, UInt16) => primitive_to_primitive_dyn::<u32, u16>(array, to_type, options),
Expand All @@ -669,6 +698,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(UInt32, Int64) => primitive_to_primitive_dyn::<u32, i64>(array, to_type, options),
(UInt32, Float32) => primitive_to_primitive_dyn::<u32, f32>(array, to_type, as_options),
(UInt32, Float64) => primitive_to_primitive_dyn::<u32, f64>(array, to_type, as_options),
(UInt32, Decimal(p, s)) => integer_to_decimal_dyn::<u32>(array, *p, *s),

(UInt64, UInt8) => primitive_to_primitive_dyn::<u64, u8>(array, to_type, options),
(UInt64, UInt16) => primitive_to_primitive_dyn::<u64, u16>(array, to_type, options),
Expand All @@ -679,6 +709,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(UInt64, Int64) => primitive_to_primitive_dyn::<u64, i64>(array, to_type, options),
(UInt64, Float32) => primitive_to_primitive_dyn::<u64, f32>(array, to_type, as_options),
(UInt64, Float64) => primitive_to_primitive_dyn::<u64, f64>(array, to_type, as_options),
(UInt64, Decimal(p, s)) => integer_to_decimal_dyn::<u64>(array, *p, *s),

(Int8, UInt8) => primitive_to_primitive_dyn::<i8, u8>(array, to_type, options),
(Int8, UInt16) => primitive_to_primitive_dyn::<i8, u16>(array, to_type, options),
Expand All @@ -689,6 +720,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(Int8, Int64) => primitive_to_primitive_dyn::<i8, i64>(array, to_type, as_options),
(Int8, Float32) => primitive_to_primitive_dyn::<i8, f32>(array, to_type, as_options),
(Int8, Float64) => primitive_to_primitive_dyn::<i8, f64>(array, to_type, as_options),
(Int8, Decimal(p, s)) => integer_to_decimal_dyn::<i8>(array, *p, *s),

(Int16, UInt8) => primitive_to_primitive_dyn::<i16, u8>(array, to_type, options),
(Int16, UInt16) => primitive_to_primitive_dyn::<i16, u16>(array, to_type, options),
Expand All @@ -699,6 +731,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(Int16, Int64) => primitive_to_primitive_dyn::<i16, i64>(array, to_type, as_options),
(Int16, Float32) => primitive_to_primitive_dyn::<i16, f32>(array, to_type, as_options),
(Int16, Float64) => primitive_to_primitive_dyn::<i16, f64>(array, to_type, as_options),
(Int16, Decimal(p, s)) => integer_to_decimal_dyn::<i16>(array, *p, *s),

(Int32, UInt8) => primitive_to_primitive_dyn::<i32, u8>(array, to_type, options),
(Int32, UInt16) => primitive_to_primitive_dyn::<i32, u16>(array, to_type, options),
Expand All @@ -709,6 +742,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(Int32, Int64) => primitive_to_primitive_dyn::<i32, i64>(array, to_type, as_options),
(Int32, Float32) => primitive_to_primitive_dyn::<i32, f32>(array, to_type, as_options),
(Int32, Float64) => primitive_to_primitive_dyn::<i32, f64>(array, to_type, as_options),
(Int32, Decimal(p, s)) => integer_to_decimal_dyn::<i32>(array, *p, *s),

(Int64, UInt8) => primitive_to_primitive_dyn::<i64, u8>(array, to_type, options),
(Int64, UInt16) => primitive_to_primitive_dyn::<i64, u16>(array, to_type, options),
Expand All @@ -719,6 +753,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(Int64, Int32) => primitive_to_primitive_dyn::<i64, i32>(array, to_type, options),
(Int64, Float32) => primitive_to_primitive_dyn::<i64, f32>(array, to_type, options),
(Int64, Float64) => primitive_to_primitive_dyn::<i64, f64>(array, to_type, as_options),
(Int64, Decimal(p, s)) => integer_to_decimal_dyn::<i64>(array, *p, *s),

(Float32, UInt8) => primitive_to_primitive_dyn::<f32, u8>(array, to_type, options),
(Float32, UInt16) => primitive_to_primitive_dyn::<f32, u16>(array, to_type, options),
Expand All @@ -729,6 +764,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(Float32, Int32) => primitive_to_primitive_dyn::<f32, i32>(array, to_type, options),
(Float32, Int64) => primitive_to_primitive_dyn::<f32, i64>(array, to_type, options),
(Float32, Float64) => primitive_to_primitive_dyn::<f32, f64>(array, to_type, as_options),
(Float32, Decimal(p, s)) => float_to_decimal_dyn::<f32>(array, *p, *s),

(Float64, UInt8) => primitive_to_primitive_dyn::<f64, u8>(array, to_type, options),
(Float64, UInt16) => primitive_to_primitive_dyn::<f64, u16>(array, to_type, options),
Expand All @@ -739,6 +775,19 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
(Float64, Int32) => primitive_to_primitive_dyn::<f64, i32>(array, to_type, options),
(Float64, Int64) => primitive_to_primitive_dyn::<f64, i64>(array, to_type, options),
(Float64, Float32) => primitive_to_primitive_dyn::<f64, f32>(array, to_type, options),
(Float64, Decimal(p, s)) => float_to_decimal_dyn::<f64>(array, *p, *s),

(Decimal(_, _), UInt8) => decimal_to_integer_dyn::<u8>(array),
(Decimal(_, _), UInt16) => decimal_to_integer_dyn::<u16>(array),
(Decimal(_, _), UInt32) => decimal_to_integer_dyn::<u32>(array),
(Decimal(_, _), UInt64) => decimal_to_integer_dyn::<u64>(array),
(Decimal(_, _), Int8) => decimal_to_integer_dyn::<i8>(array),
(Decimal(_, _), Int16) => decimal_to_integer_dyn::<i16>(array),
(Decimal(_, _), Int32) => decimal_to_integer_dyn::<i32>(array),
(Decimal(_, _), Int64) => decimal_to_integer_dyn::<i64>(array),
(Decimal(_, _), Float32) => decimal_to_float_dyn::<f32>(array),
(Decimal(_, _), Float64) => decimal_to_float_dyn::<f64>(array),
(Decimal(_, _), Decimal(to_p, to_s)) => decimal_to_decimal_dyn(array, *to_p, *to_s),
// end numeric casts

// temporal casts
Expand Down
Loading