From 4919a9d86e46c371bbd4b364c8807a9a58feb392 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Thu, 13 Jan 2022 06:11:51 +0000 Subject: [PATCH] Added support to cast decimal --- src/compute/cast/decimal_to.rs | 136 +++++++++++++++++++++++++++++++ src/compute/cast/mod.rs | 61 ++++++++++++-- src/compute/cast/primitive_to.rs | 90 ++++++++++++++++++++ tests/it/compute/cast.rs | 100 +++++++++++++++++++++++ 4 files changed, 381 insertions(+), 6 deletions(-) create mode 100644 src/compute/cast/decimal_to.rs diff --git a/src/compute/cast/decimal_to.rs b/src/compute/cast/decimal_to.rs new file mode 100644 index 00000000000..3434e2666e1 --- /dev/null +++ b/src/compute/cast/decimal_to.rs @@ -0,0 +1,136 @@ +use num_traits::{AsPrimitive, Float, NumCast}; + +use crate::error::Result; +use crate::types::NativeType; +use crate::{array::*, datatypes::DataType}; + +#[inline] +fn decimal_to_decimal_impl Option>( + from: &PrimitiveArray, + op: F, + to_precision: usize, + to_scale: usize, +) -> PrimitiveArray { + let min_for_precision = 9_i128 + .saturating_pow(1 + to_precision as u32) + .saturating_neg(); + let max_for_precision = 9_i128.saturating_pow(1 + to_precision as u32); + + let values = from.iter().map(|x| { + x.and_then(|x| { + op(*x).and_then(|x| { + if x > max_for_precision || x < min_for_precision { + None + } else { + Some(x) + } + }) + }) + }); + PrimitiveArray::::from_trusted_len_iter(values) + .to(DataType::Decimal(to_precision, to_scale)) +} + +/// Returns a [`PrimitiveArray`] with the casted values. Values are `None` on overflow +pub fn decimal_to_decimal( + from: &PrimitiveArray, + to_precision: usize, + to_scale: usize, +) -> PrimitiveArray { + let (from_precision, from_scale) = + if let DataType::Decimal(p, s) = from.data_type().to_logical_type() { + (*p, *s) + } else { + panic!("internal error: i128 is always a decimal") + }; + + if to_scale == from_scale && to_precision >= from_precision { + // fast path + return from.clone().to(DataType::Decimal(to_precision, to_scale)); + } + // todo: other fast paths include increasing scale and precision by so that + // a number will never overflow (validity is preserved) + + if from_scale > to_scale { + let factor = 10_i128.pow((from_scale - to_scale) as u32); + decimal_to_decimal_impl( + from, + |x: i128| x.checked_div(factor), + to_precision, + to_scale, + ) + } else { + let factor = 10_i128.pow((to_scale - from_scale) as u32); + decimal_to_decimal_impl( + from, + |x: i128| x.checked_mul(factor), + to_precision, + to_scale, + ) + } +} + +pub(super) fn decimal_to_decimal_dyn( + from: &dyn Array, + to_precision: usize, + to_scale: usize, +) -> Result> { + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(decimal_to_decimal(from, to_precision, to_scale))) +} + +/// Returns a [`PrimitiveArray`] with the casted values. Values are `None` on overflow +pub fn decimal_to_float(from: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + Float, + f64: AsPrimitive, +{ + let (_, from_scale) = if let DataType::Decimal(p, s) = from.data_type().to_logical_type() { + (*p, *s) + } else { + panic!("internal error: i128 is always a decimal") + }; + + let div = 10_f64.powi(from_scale as i32); + let values = from + .values() + .iter() + .map(|x| (*x as f64 / div).as_()) + .collect(); + + PrimitiveArray::::from_data(T::PRIMITIVE.into(), values, from.validity().cloned()) +} + +pub(super) fn decimal_to_float_dyn(from: &dyn Array) -> Result> +where + T: NativeType + Float, + f64: AsPrimitive, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(decimal_to_float::(from))) +} + +/// Returns a [`PrimitiveArray`] with the casted values. Values are `None` on overflow +pub fn decimal_to_integer(from: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + NumCast, +{ + let (_, from_scale) = if let DataType::Decimal(p, s) = from.data_type().to_logical_type() { + (*p, *s) + } else { + panic!("internal error: i128 is always a decimal") + }; + + let factor = 10_i128.pow(from_scale as u32); + let values = from.iter().map(|x| x.and_then(|x| T::from(*x / factor))); + + PrimitiveArray::from_trusted_len_iter(values) +} + +pub(super) fn decimal_to_integer_dyn(from: &dyn Array) -> Result> +where + T: NativeType + NumCast, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(decimal_to_integer::(from))) +} diff --git a/src/compute/cast/mod.rs b/src/compute/cast/mod.rs index 7b4231f9605..3e6adf5a30b 100644 --- a/src/compute/cast/mod.rs +++ b/src/compute/cast/mod.rs @@ -1,23 +1,25 @@ //! Defines different casting operators such as [`cast`] or [`primitive_to_binary`]. -use crate::{ - array::*, - datatypes::*, - error::{ArrowError, Result}, -}; - mod binary_to; mod boolean_to; +mod decimal_to; mod dictionary_to; mod primitive_to; mod utf8_to; pub use binary_to::*; pub use boolean_to::*; +pub use decimal_to::*; pub use dictionary_to::*; pub use primitive_to::*; pub use utf8_to::*; +use crate::{ + array::*, + datatypes::*, + error::{ArrowError, Result}, +}; + /// options defining how Cast kernels behave #[derive(Clone, Copy, Debug, Default)] pub struct CastOptions { @@ -143,6 +145,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (UInt8, Int64) => true, (UInt8, Float32) => true, (UInt8, Float64) => true, + (UInt8, Decimal(_, _)) => true, (UInt16, UInt8) => true, (UInt16, UInt32) => true, @@ -153,6 +156,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (UInt16, Int64) => true, (UInt16, Float32) => true, (UInt16, Float64) => true, + (UInt16, Decimal(_, _)) => true, (UInt32, UInt8) => true, (UInt32, UInt16) => true, @@ -163,6 +167,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (UInt32, Int64) => true, (UInt32, Float32) => true, (UInt32, Float64) => true, + (UInt32, Decimal(_, _)) => true, (UInt64, UInt8) => true, (UInt64, UInt16) => true, @@ -173,6 +178,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (UInt64, Int64) => true, (UInt64, Float32) => true, (UInt64, Float64) => true, + (UInt64, Decimal(_, _)) => true, (Int8, UInt8) => true, (Int8, UInt16) => true, @@ -183,6 +189,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Int8, Int64) => true, (Int8, Float32) => true, (Int8, Float64) => true, + (Int8, Decimal(_, _)) => true, (Int16, UInt8) => true, (Int16, UInt16) => true, @@ -193,6 +200,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Int16, Int64) => true, (Int16, Float32) => true, (Int16, Float64) => true, + (Int16, Decimal(_, _)) => true, (Int32, UInt8) => true, (Int32, UInt16) => true, @@ -203,6 +211,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Int32, Int64) => true, (Int32, Float32) => true, (Int32, Float64) => true, + (Int32, Decimal(_, _)) => true, (Int64, UInt8) => true, (Int64, UInt16) => true, @@ -213,6 +222,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Int64, Int32) => true, (Int64, Float32) => true, (Int64, Float64) => true, + (Int64, Decimal(_, _)) => true, (Float32, UInt8) => true, (Float32, UInt16) => true, @@ -223,6 +233,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Float32, Int32) => true, (Float32, Int64) => true, (Float32, Float64) => true, + (Float32, Decimal(_, _)) => true, (Float64, UInt8) => true, (Float64, UInt16) => true, @@ -233,6 +244,22 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Float64, Int32) => true, (Float64, Int64) => true, (Float64, Float32) => true, + (Float64, Decimal(_, _)) => true, + + ( + Decimal(_, _), + UInt8 + | UInt16 + | UInt32 + | UInt64 + | Int8 + | Int16 + | Int32 + | Int64 + | Float32 + | Float64 + | Decimal(_, _), + ) => true, // end numeric casts // temporal casts @@ -649,6 +676,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (UInt8, Int64) => primitive_to_primitive_dyn::(array, to_type, options), (UInt8, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), (UInt8, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt8, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), (UInt16, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), (UInt16, UInt32) => primitive_to_primitive_dyn::(array, to_type, as_options), @@ -659,6 +687,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (UInt16, Int64) => primitive_to_primitive_dyn::(array, to_type, options), (UInt16, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), (UInt16, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt16, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), (UInt32, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), (UInt32, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), @@ -669,6 +698,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (UInt32, Int64) => primitive_to_primitive_dyn::(array, to_type, options), (UInt32, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), (UInt32, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt32, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), (UInt64, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), (UInt64, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), @@ -679,6 +709,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (UInt64, Int64) => primitive_to_primitive_dyn::(array, to_type, options), (UInt64, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), (UInt64, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt64, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), (Int8, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), (Int8, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), @@ -689,6 +720,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (Int8, Int64) => primitive_to_primitive_dyn::(array, to_type, as_options), (Int8, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), (Int8, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int8, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), (Int16, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), (Int16, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), @@ -699,6 +731,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (Int16, Int64) => primitive_to_primitive_dyn::(array, to_type, as_options), (Int16, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), (Int16, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int16, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), (Int32, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), (Int32, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), @@ -709,6 +742,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (Int32, Int64) => primitive_to_primitive_dyn::(array, to_type, as_options), (Int32, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), (Int32, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int32, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), (Int64, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), (Int64, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), @@ -719,6 +753,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (Int64, Int32) => primitive_to_primitive_dyn::(array, to_type, options), (Int64, Float32) => primitive_to_primitive_dyn::(array, to_type, options), (Int64, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int64, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), (Float32, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), (Float32, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), @@ -729,6 +764,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (Float32, Int32) => primitive_to_primitive_dyn::(array, to_type, options), (Float32, Int64) => primitive_to_primitive_dyn::(array, to_type, options), (Float32, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Float32, Decimal(p, s)) => float_to_decimal_dyn::(array, *p, *s), (Float64, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), (Float64, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), @@ -739,6 +775,19 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (Float64, Int32) => primitive_to_primitive_dyn::(array, to_type, options), (Float64, Int64) => primitive_to_primitive_dyn::(array, to_type, options), (Float64, Float32) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Decimal(p, s)) => float_to_decimal_dyn::(array, *p, *s), + + (Decimal(_, _), UInt8) => decimal_to_integer_dyn::(array), + (Decimal(_, _), UInt16) => decimal_to_integer_dyn::(array), + (Decimal(_, _), UInt32) => decimal_to_integer_dyn::(array), + (Decimal(_, _), UInt64) => decimal_to_integer_dyn::(array), + (Decimal(_, _), Int8) => decimal_to_integer_dyn::(array), + (Decimal(_, _), Int16) => decimal_to_integer_dyn::(array), + (Decimal(_, _), Int32) => decimal_to_integer_dyn::(array), + (Decimal(_, _), Int64) => decimal_to_integer_dyn::(array), + (Decimal(_, _), Float32) => decimal_to_float_dyn::(array), + (Decimal(_, _), Float64) => decimal_to_float_dyn::(array), + (Decimal(_, _), Decimal(to_p, to_s)) => decimal_to_decimal_dyn(array, *to_p, *to_s), // end numeric casts // temporal casts diff --git a/src/compute/cast/primitive_to.rs b/src/compute/cast/primitive_to.rs index d2c77e4221e..4f8292d4407 100644 --- a/src/compute/cast/primitive_to.rs +++ b/src/compute/cast/primitive_to.rs @@ -1,5 +1,7 @@ use std::hash::Hash; +use num_traits::{AsPrimitive, Float}; + use crate::error::Result; use crate::{ array::*, @@ -153,6 +155,94 @@ where PrimitiveArray::::from_trusted_len_iter(iter).to(to_type.clone()) } +/// Returns a [`PrimitiveArray`] with the casted values. Values are `None` on overflow +pub fn integer_to_decimal>( + from: &PrimitiveArray, + to_precision: usize, + to_scale: usize, +) -> PrimitiveArray { + let multiplier = 10_i128.pow(to_scale as u32); + + let min_for_precision = 9_i128 + .saturating_pow(1 + to_precision as u32) + .saturating_neg(); + let max_for_precision = 9_i128.saturating_pow(1 + to_precision as u32); + + let values = from.iter().map(|x| { + x.and_then(|x| { + x.as_().checked_mul(multiplier).and_then(|x| { + if x > max_for_precision || x < min_for_precision { + None + } else { + Some(x) + } + }) + }) + }); + + PrimitiveArray::::from_trusted_len_iter(values) + .to(DataType::Decimal(to_precision, to_scale)) +} + +pub(super) fn integer_to_decimal_dyn( + from: &dyn Array, + precision: usize, + scale: usize, +) -> Result> +where + T: NativeType + AsPrimitive, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(integer_to_decimal::(from, precision, scale))) +} + +/// Returns a [`PrimitiveArray`] with the casted values. Values are `None` on overflow +pub fn float_to_decimal( + from: &PrimitiveArray, + to_precision: usize, + to_scale: usize, +) -> PrimitiveArray +where + T: NativeType + Float, + f64: AsPrimitive, + i128: From, + T: AsPrimitive, +{ + // 1.2 => 12 + let multiplier: T = (10_f64).powi(to_scale as i32).as_(); + + let min_for_precision = 9_i128 + .saturating_pow(1 + to_precision as u32) + .saturating_neg(); + let max_for_precision = 9_i128.saturating_pow(1 + to_precision as u32); + + let values = from.iter().map(|x| { + x.and_then(|x| { + let x = i128::from(*x * multiplier); + if x > max_for_precision || x < min_for_precision { + None + } else { + Some(x) + } + }) + }); + + PrimitiveArray::::from_trusted_len_iter(values) + .to(DataType::Decimal(to_precision, to_scale)) +} + +pub(super) fn float_to_decimal_dyn( + from: &dyn Array, + precision: usize, + scale: usize, +) -> Result> +where + T: NativeType + AsPrimitive, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(integer_to_decimal::(from, precision, scale))) +} + /// Cast [`PrimitiveArray`] as a [`PrimitiveArray`] /// Same as `number as to_number_type` in rust pub fn primitive_as_primitive( diff --git a/tests/it/compute/cast.rs b/tests/it/compute/cast.rs index d15d2e9e46f..5c8298c04e9 100644 --- a/tests/it/compute/cast.rs +++ b/tests/it/compute/cast.rs @@ -238,6 +238,104 @@ fn utf8_to_i32() { assert_eq!(c, &expected); } +#[test] +fn int32_to_decimal() { + // 10 and -10 can be represented with precision 1 and scale 0 + let array = Int32Array::from(&[Some(2), Some(10), Some(-2), Some(-10), None]); + + let b = cast(&array, &DataType::Decimal(1, 0), CastOptions::default()).unwrap(); + let c = b.as_any().downcast_ref::>().unwrap(); + + let expected = Int128Array::from(&[Some(2), Some(10), Some(-2), Some(-10), None]) + .to(DataType::Decimal(1, 0)); + assert_eq!(c, &expected) +} + +#[test] +fn float32_to_decimal() { + let array = Float32Array::from(&[ + Some(2.0), + Some(10.0), + Some(-2.0), + Some(-10.0), + Some(-100.0), // can't be represented in (1,0) + None, + ]); + + let b = cast(&array, &DataType::Decimal(1, 0), CastOptions::default()).unwrap(); + let c = b.as_any().downcast_ref::>().unwrap(); + + let expected = Int128Array::from(&[Some(2), Some(10), Some(-2), Some(-10), None, None]) + .to(DataType::Decimal(1, 0)); + assert_eq!(c, &expected) +} + +#[test] +fn int32_to_decimal_scaled() { + // 10 and -10 can't be represented with precision 1 and scale 1 + let array = Int32Array::from(&[Some(2), Some(10), Some(-2), Some(-10), None]); + + let b = cast(&array, &DataType::Decimal(1, 1), CastOptions::default()).unwrap(); + let c = b.as_any().downcast_ref::>().unwrap(); + + let expected = + Int128Array::from(&[Some(20), None, Some(-20), None, None]).to(DataType::Decimal(1, 1)); + assert_eq!(c, &expected) +} + +#[test] +fn decimal_to_decimal() { + // increase scale and precision + let array = Int128Array::from(&[Some(2), Some(10), Some(-2), Some(-10), None]) + .to(DataType::Decimal(1, 0)); + + let b = cast(&array, &DataType::Decimal(2, 1), CastOptions::default()).unwrap(); + let c = b.as_any().downcast_ref::>().unwrap(); + + let expected = Int128Array::from(&[Some(20), Some(100), Some(-20), Some(-100), None]) + .to(DataType::Decimal(2, 1)); + assert_eq!(c, &expected) +} + +#[test] +fn decimal_to_decimal_scaled() { + // decrease precision + // 10 and -10 can't be represented with precision 1 and scale 1 + let array = Int128Array::from(&[Some(2), Some(10), Some(-2), Some(-10), None]) + .to(DataType::Decimal(1, 0)); + + let b = cast(&array, &DataType::Decimal(1, 1), CastOptions::default()).unwrap(); + let c = b.as_any().downcast_ref::>().unwrap(); + + let expected = + Int128Array::from(&[Some(20), None, Some(-20), None, None]).to(DataType::Decimal(1, 1)); + assert_eq!(c, &expected) +} + +#[test] +fn decimal_to_float() { + let array = Int128Array::from(&[Some(2), Some(10), Some(-2), Some(-10), None]) + .to(DataType::Decimal(2, 1)); + + let b = cast(&array, &DataType::Float32, CastOptions::default()).unwrap(); + let c = b.as_any().downcast_ref::>().unwrap(); + + let expected = Float32Array::from(&[Some(0.2), Some(1.0), Some(-0.2), Some(-1.0), None]); + assert_eq!(c, &expected) +} + +#[test] +fn decimal_to_integer() { + let array = Int128Array::from(&[Some(2), Some(10), Some(-2), Some(-10), None, Some(2560)]) + .to(DataType::Decimal(2, 1)); + + let b = cast(&array, &DataType::Int8, CastOptions::default()).unwrap(); + let c = b.as_any().downcast_ref::>().unwrap(); + + let expected = Int8Array::from(&[Some(0), Some(1), Some(0), Some(-1), None, None]); + assert_eq!(c, &expected) +} + #[test] fn utf8_to_i32_partial() { let array = Utf8Array::::from_slice(&["5", "6", "seven", "8aa", "9.1aa"]); @@ -336,6 +434,8 @@ fn consistency() { Date32, Time32(TimeUnit::Second), Time32(TimeUnit::Millisecond), + Decimal(1, 2), + Decimal(2, 2), Date64, Utf8, LargeUtf8,