From 966ce678e90b30db7a849bf87449d6ea57922177 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 8 Nov 2022 10:12:52 -0800 Subject: [PATCH] Check overflow while casting between decimal types --- arrow-cast/src/cast.rs | 439 +++++++++++++++++++++++++++++++++-------- 1 file changed, 359 insertions(+), 80 deletions(-) diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index e394426bd682..d0132d430c60 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -562,16 +562,16 @@ pub fn cast_with_options( } match (from_type, to_type) { (Decimal128(_, s1), Decimal128(p2, s2)) => { - cast_decimal_to_decimal::<16, 16>(array, s1, p2, s2) + cast_decimal_to_decimal_with_option::<16, 16>(array, s1, p2, s2, cast_options) } (Decimal256(_, s1), Decimal256(p2, s2)) => { - cast_decimal_to_decimal::<32, 32>(array, s1, p2, s2) + cast_decimal_to_decimal_with_option::<32, 32>(array, s1, p2, s2, cast_options) } (Decimal128(_, s1), Decimal256(p2, s2)) => { - cast_decimal_to_decimal::<16, 32>(array, s1, p2, s2) + cast_decimal_to_decimal_with_option::<16, 32>(array, s1, p2, s2, cast_options) } (Decimal256(_, s1), Decimal128(p2, s2)) => { - cast_decimal_to_decimal::<32, 16>(array, s1, p2, s2) + cast_decimal_to_decimal_with_option::<32, 16>(array, s1, p2, s2, cast_options) } (Decimal128(_, scale), _) => { // cast decimal to other type @@ -1593,7 +1593,36 @@ const fn time_unit_multiple(unit: &TimeUnit) -> i64 { } /// Cast one type of decimal array to another type of decimal array -fn cast_decimal_to_decimal( +fn cast_decimal_to_decimal_with_option< + const BYTE_WIDTH1: usize, + const BYTE_WIDTH2: usize, +>( + array: &ArrayRef, + input_scale: &u8, + output_precision: &u8, + output_scale: &u8, + cast_options: &CastOptions, +) -> Result { + if cast_options.safe { + cast_decimal_to_decimal_safe::( + array, + input_scale, + output_precision, + output_scale, + ) + } else { + cast_decimal_to_decimal::( + array, + input_scale, + output_precision, + output_scale, + ) + } +} + +/// Cast one type of decimal array to another type of decimal array. Returning NULLs for +/// the array values when cast failures happen. +fn cast_decimal_to_decimal_safe( array: &ArrayRef, input_scale: &u8, output_precision: &u8, @@ -1605,54 +1634,50 @@ fn cast_decimal_to_decimal( let div = 10_i128.pow((input_scale - output_scale) as u32); if BYTE_WIDTH1 == 16 { let array = array.as_any().downcast_ref::().unwrap(); - let iter = array.iter().map(|v| v.map(|v| v.wrapping_div(div))); if BYTE_WIDTH2 == 16 { - let output_array = iter - .collect::() - .with_precision_and_scale(*output_precision, *output_scale)?; - - Ok(Arc::new(output_array)) + let iter = array + .iter() + .map(|v| v.and_then(|v| v.div_checked(div).ok())); + let casted_array = unsafe { + PrimitiveArray::::from_trusted_len_iter(iter) + }; + casted_array + .with_precision_and_scale(*output_precision, *output_scale) + .map(|a| Arc::new(a) as ArrayRef) } else { - let output_array = iter - .map(|v| v.map(i256::from_i128)) - .collect::() - .with_precision_and_scale(*output_precision, *output_scale)?; - - Ok(Arc::new(output_array)) + let iter = array.iter().map(|v| { + v.and_then(|v| v.div_checked(div).ok().map(i256::from_i128)) + }); + let casted_array = unsafe { + PrimitiveArray::::from_trusted_len_iter(iter) + }; + casted_array + .with_precision_and_scale(*output_precision, *output_scale) + .map(|a| Arc::new(a) as ArrayRef) } } else { let array = array.as_any().downcast_ref::().unwrap(); let div = i256::from_i128(div); - let iter = array.iter().map(|v| v.map(|v| v.wrapping_div(div))); if BYTE_WIDTH2 == 16 { - let values = iter - .map(|v| { - if v.is_none() { - Ok(None) - } else { - v.as_ref().and_then(|v| v.to_i128()) - .ok_or_else(|| { - ArrowError::InvalidArgumentError( - format!("{:?} cannot be casted to 128-bit integer for Decimal128", v), - ) - }) - .map(Some) - } - }) - .collect::, _>>()?; - - let output_array = values - .into_iter() - .collect::() - .with_precision_and_scale(*output_precision, *output_scale)?; - - Ok(Arc::new(output_array)) + let iter = array.iter().map(|v| { + v.and_then(|v| v.div_checked(div).ok().and_then(|v| v.to_i128())) + }); + let casted_array = unsafe { + PrimitiveArray::::from_trusted_len_iter(iter) + }; + casted_array + .with_precision_and_scale(*output_precision, *output_scale) + .map(|a| Arc::new(a) as ArrayRef) } else { - let output_array = iter - .collect::() - .with_precision_and_scale(*output_precision, *output_scale)?; - - Ok(Arc::new(output_array)) + let iter = array + .iter() + .map(|v| v.and_then(|v| v.div_checked(div).ok())); + let casted_array = unsafe { + PrimitiveArray::::from_trusted_len_iter(iter) + }; + casted_array + .with_precision_and_scale(*output_precision, *output_scale) + .map(|a| Arc::new(a) as ArrayRef) } } } else { @@ -1661,54 +1686,278 @@ fn cast_decimal_to_decimal( let mul = 10_i128.pow((output_scale - input_scale) as u32); if BYTE_WIDTH1 == 16 { let array = array.as_any().downcast_ref::().unwrap(); - let iter = array.iter().map(|v| v.map(|v| v.wrapping_mul(mul))); if BYTE_WIDTH2 == 16 { - let output_array = iter - .collect::() - .with_precision_and_scale(*output_precision, *output_scale)?; + let iter = array + .iter() + .map(|v| v.and_then(|v| v.mul_checked(mul).ok())); + let casted_array = unsafe { + PrimitiveArray::::from_trusted_len_iter(iter) + }; + casted_array + .with_precision_and_scale(*output_precision, *output_scale) + .map(|a| Arc::new(a) as ArrayRef) + } else { + let iter = array.iter().map(|v| { + v.and_then(|v| v.mul_checked(mul).ok().map(i256::from_i128)) + }); + let casted_array = unsafe { + PrimitiveArray::::from_trusted_len_iter(iter) + }; + casted_array + .with_precision_and_scale(*output_precision, *output_scale) + .map(|a| Arc::new(a) as ArrayRef) + } + } else { + let array = array.as_any().downcast_ref::().unwrap(); + let mul = i256::from_i128(mul); + if BYTE_WIDTH2 == 16 { + let iter = array.iter().map(|v| { + v.and_then(|v| v.mul_checked(mul).ok().and_then(|v| v.to_i128())) + }); + let casted_array = unsafe { + PrimitiveArray::::from_trusted_len_iter(iter) + }; + casted_array + .with_precision_and_scale(*output_precision, *output_scale) + .map(|a| Arc::new(a) as ArrayRef) + } else { + let iter = array + .iter() + .map(|v| v.and_then(|v| v.mul_checked(mul).ok())); + let casted_array = unsafe { + PrimitiveArray::::from_trusted_len_iter(iter) + }; + casted_array + .with_precision_and_scale(*output_precision, *output_scale) + .map(|a| Arc::new(a) as ArrayRef) + } + } + } +} + +/// Cast one type of decimal array to another type of decimal array. Returning `Err` if +/// cast failure happens. +fn cast_decimal_to_decimal( + array: &ArrayRef, + input_scale: &u8, + output_precision: &u8, + output_scale: &u8, +) -> Result { + if input_scale > output_scale { + // For example, input_scale is 4 and output_scale is 3; + // Original value is 11234_i128, and will be cast to 1123_i128. + let array = array.as_any().downcast_ref::().unwrap(); + if BYTE_WIDTH1 == 16 { + if BYTE_WIDTH2 == 16 { + let div = 10_i128 + .pow_checked((input_scale - output_scale) as u32) + .map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast. The scale {} causes overflow.", + *output_scale, + )) + })?; - Ok(Arc::new(output_array)) + array + .try_unary::<_, Decimal128Type, _>(|v| { + v.checked_div(div).ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast to {:?}({}, {}). Overflowing on {:?}", + Decimal128Type::PREFIX, + *output_precision, + *output_scale, + v + )) + }) + }) + .and_then(|a| { + a.with_precision_and_scale(*output_precision, *output_scale) + }) + .map(|a| Arc::new(a) as ArrayRef) } else { - let output_array = iter - .map(|v| v.map(i256::from_i128)) - .collect::() - .with_precision_and_scale(*output_precision, *output_scale)?; + let div = i256::from_i128(10_i128) + .pow_checked((input_scale - output_scale) as u32) + .map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast. The scale {} causes overflow.", + *output_scale, + )) + })?; - Ok(Arc::new(output_array)) + array + .try_unary::<_, Decimal256Type, _>(|v| { + i256::from_i128(v).checked_div(div).ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast to {:?}({}, {}). Overflowing on {:?}", + Decimal256Type::PREFIX, + *output_precision, + *output_scale, + v + )) + }) + }) + .and_then(|a| { + a.with_precision_and_scale(*output_precision, *output_scale) + }) + .map(|a| Arc::new(a) as ArrayRef) } } else { let array = array.as_any().downcast_ref::().unwrap(); - let mul = i256::from_i128(mul); - let iter = array.iter().map(|v| v.map(|v| v.wrapping_mul(mul))); + let div = i256::from_i128(10_i128) + .pow_checked((input_scale - output_scale) as u32) + .map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast. The scale {} causes overflow.", + *output_scale, + )) + })?; if BYTE_WIDTH2 == 16 { - let values = iter - .map(|v| { - if v.is_none() { - Ok(None) - } else { - v.as_ref().and_then(|v| v.to_i128()) - .ok_or_else(|| { - ArrowError::InvalidArgumentError( - format!("{:?} cannot be casted to 128-bit integer for Decimal128", v), - ) - }) - .map(Some) - } + array + .try_unary::<_, Decimal128Type, _>(|v| { + v.checked_div(div).ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast to {:?}({}, {}). Overflowing on {:?}", + Decimal128Type::PREFIX, + *output_precision, + *output_scale, + v + )) + }).and_then(|v| v.to_i128().ok_or_else(|| { + ArrowError::InvalidArgumentError( + format!("{:?} cannot be casted to 128-bit integer for Decimal128", v), + ) + })) + }) + .and_then(|a| { + a.with_precision_and_scale(*output_precision, *output_scale) }) - .collect::, _>>()?; + .map(|a| Arc::new(a) as ArrayRef) + } else { + array + .try_unary::<_, Decimal256Type, _>(|v| { + v.checked_div(div).ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast to {:?}({}, {}). Overflowing on {:?}", + Decimal256Type::PREFIX, + *output_precision, + *output_scale, + v + )) + }) + }) + .and_then(|a| { + a.with_precision_and_scale(*output_precision, *output_scale) + }) + .map(|a| Arc::new(a) as ArrayRef) + } + } + } else { + // For example, input_scale is 3 and output_scale is 4; + // Original value is 1123_i128, and will be cast to 11230_i128. + if BYTE_WIDTH1 == 16 { + let array = array.as_any().downcast_ref::().unwrap(); - let output_array = values - .into_iter() - .collect::() - .with_precision_and_scale(*output_precision, *output_scale)?; + if BYTE_WIDTH2 == 16 { + let mul = 10_i128 + .pow_checked((output_scale - input_scale) as u32) + .map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast. The scale {} causes overflow.", + *output_scale, + )) + })?; - Ok(Arc::new(output_array)) + array + .try_unary::<_, Decimal128Type, _>(|v| { + v.checked_mul(mul).ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast to {:?}({}, {}). Overflowing on {:?}", + Decimal128Type::PREFIX, + *output_precision, + *output_scale, + v + )) + }) + }) + .and_then(|a| { + a.with_precision_and_scale(*output_precision, *output_scale) + }) + .map(|a| Arc::new(a) as ArrayRef) } else { - let output_array = iter - .collect::() - .with_precision_and_scale(*output_precision, *output_scale)?; + let mul = i256::from_i128(10_i128) + .pow_checked((output_scale - input_scale) as u32) + .map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast. The scale {} causes overflow.", + *output_scale, + )) + })?; - Ok(Arc::new(output_array)) + array + .try_unary::<_, Decimal256Type, _>(|v| { + i256::from_i128(v).checked_mul(mul).ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast to {:?}({}, {}). Overflowing on {:?}", + Decimal256Type::PREFIX, + *output_precision, + *output_scale, + v + )) + }) + }) + .and_then(|a| { + a.with_precision_and_scale(*output_precision, *output_scale) + }) + .map(|a| Arc::new(a) as ArrayRef) + } + } else { + let array = array.as_any().downcast_ref::().unwrap(); + let mul = i256::from_i128(10_i128) + .pow_checked((output_scale - input_scale) as u32) + .map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast. The scale {} causes overflow.", + *output_scale, + )) + })?; + if BYTE_WIDTH2 == 16 { + array + .try_unary::<_, Decimal128Type, _>(|v| { + v.checked_mul(mul).ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast to {:?}({}, {}). Overflowing on {:?}", + Decimal128Type::PREFIX, + *output_precision, + *output_scale, + v + )) + }).and_then(|v| v.to_i128().ok_or_else(|| { + ArrowError::InvalidArgumentError( + format!("{:?} cannot be casted to 128-bit integer for Decimal128", v), + ) + })) + }) + .and_then(|a| { + a.with_precision_and_scale(*output_precision, *output_scale) + }) + .map(|a| Arc::new(a) as ArrayRef) + } else { + array + .try_unary::<_, Decimal256Type, _>(|v| { + v.checked_mul(mul).ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast to {:?}({}, {}). Overflowing on {:?}", + Decimal256Type::PREFIX, + *output_precision, + *output_scale, + v + )) + }) + }) + .and_then(|a| { + a.with_precision_and_scale(*output_precision, *output_scale) + }) + .map(|a| Arc::new(a) as ArrayRef) } } } @@ -3025,6 +3274,36 @@ mod tests { err.unwrap_err().to_string()); } + #[test] + fn test_cast_decimal128_to_decimal128_overflow() { + let input_type = DataType::Decimal128(38, 3); + let output_type = DataType::Decimal128(38, 38); + assert!(can_cast_types(&input_type, &output_type)); + + let array = vec![Some(i128::MAX)]; + let input_decimal_array = create_decimal_array(array, 38, 3).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + let result = + cast_with_options(&array, &output_type, &CastOptions { safe: false }); + assert_eq!("Cast error: Cannot cast to \"Decimal128\"(38, 38). Overflowing on 170141183460469231731687303715884105727", + result.unwrap_err().to_string()); + } + + #[test] + fn test_cast_decimal128_to_decimal256_overflow() { + let input_type = DataType::Decimal128(38, 3); + let output_type = DataType::Decimal256(76, 76); + assert!(can_cast_types(&input_type, &output_type)); + + let array = vec![Some(i128::MAX)]; + let input_decimal_array = create_decimal_array(array, 38, 3).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + let result = + cast_with_options(&array, &output_type, &CastOptions { safe: false }); + assert_eq!("Cast error: Cannot cast to \"Decimal256\"(76, 76). Overflowing on 170141183460469231731687303715884105727", + result.unwrap_err().to_string()); + } + #[test] fn test_cast_decimal128_to_decimal256() { let input_type = DataType::Decimal128(20, 3);