diff --git a/src/compute/cast/primitive_to.rs b/src/compute/cast/primitive_to.rs index 4f8292d4407..b7a7a310c5a 100644 --- a/src/compute/cast/primitive_to.rs +++ b/src/compute/cast/primitive_to.rs @@ -1,6 +1,6 @@ use std::hash::Hash; -use num_traits::{AsPrimitive, Float}; +use num_traits::{AsPrimitive, Float, ToPrimitive}; use crate::error::Result; use crate::{ @@ -203,10 +203,8 @@ pub fn float_to_decimal( to_scale: usize, ) -> PrimitiveArray where - T: NativeType + Float, + T: NativeType + Float + ToPrimitive, f64: AsPrimitive, - i128: From, - T: AsPrimitive, { // 1.2 => 12 let multiplier: T = (10_f64).powi(to_scale as i32).as_(); @@ -218,7 +216,7 @@ where let values = from.iter().map(|x| { x.and_then(|x| { - let x = i128::from(*x * multiplier); + let x = (*x * multiplier).to_i128().unwrap(); if x > max_for_precision || x < min_for_precision { None } else { @@ -237,10 +235,11 @@ pub(super) fn float_to_decimal_dyn( scale: usize, ) -> Result> where - T: NativeType + AsPrimitive, + T: NativeType + Float + ToPrimitive, + f64: AsPrimitive, { let from = from.as_any().downcast_ref().unwrap(); - Ok(Box::new(integer_to_decimal::(from, precision, scale))) + Ok(Box::new(float_to_decimal::(from, precision, scale))) } /// Cast [`PrimitiveArray`] as a [`PrimitiveArray`] diff --git a/tests/it/compute/cast.rs b/tests/it/compute/cast.rs index 0833d113608..9c3d16d7491 100644 --- a/tests/it/compute/cast.rs +++ b/tests/it/compute/cast.rs @@ -254,19 +254,28 @@ fn int32_to_decimal() { #[test] fn float32_to_decimal() { let array = Float32Array::from(&[ - Some(2.0), + Some(2.4), Some(10.0), + Some(1.123_456_8), Some(-2.0), Some(-10.0), - Some(-100.0), // can't be represented in (1,0) + Some(-100.01), // can't be represented in (1,0) None, ]); - let b = cast(&array, &DataType::Decimal(1, 0), CastOptions::default()).unwrap(); + let b = cast(&array, &DataType::Decimal(10, 2), CastOptions::default()).unwrap(); let c = b.as_any().downcast_ref::>().unwrap(); - let expected = Int128Array::from(&[Some(2), Some(10), Some(-2), Some(-10), None, None]) - .to(DataType::Decimal(1, 0)); + let expected = Int128Array::from(&[ + Some(240), + Some(1000), + Some(112), + Some(-200), + Some(-1000), + Some(-10001), + None, + ]) + .to(DataType::Decimal(10, 2)); assert_eq!(c, &expected) }