diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index f34c899e2265..bd68b9698ce9 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -1003,7 +1003,7 @@ impl PrimitiveArray { pub fn with_precision_and_scale( self, precision: u8, - scale: u8, + scale: i8, ) -> Result where Self: Sized, @@ -1024,7 +1024,7 @@ impl PrimitiveArray { fn validate_precision_scale( &self, precision: u8, - scale: u8, + scale: i8, ) -> Result<(), ArrowError> { if precision == 0 { return Err(ArrowError::InvalidArgumentError(format!( @@ -1046,7 +1046,14 @@ impl PrimitiveArray { T::MAX_SCALE ))); } - if scale > precision { + if scale < -T::MAX_SCALE { + return Err(ArrowError::InvalidArgumentError(format!( + "scale {} is smaller than min {}", + scale, + -T::MAX_SCALE + ))); + } + if scale > 0 && scale as u8 > precision { return Err(ArrowError::InvalidArgumentError(format!( "scale {} is greater than precision {}", scale, precision @@ -1102,7 +1109,7 @@ impl PrimitiveArray { } /// Returns the decimal scale of this array - pub fn scale(&self) -> u8 { + pub fn scale(&self) -> i8 { match T::BYTE_LENGTH { 16 => { if let DataType::Decimal128(_, s) = self.data().data_type() { diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs index dd4d1ba4292b..40d262e8ed72 100644 --- a/arrow-array/src/types.rs +++ b/arrow-array/src/types.rs @@ -491,15 +491,15 @@ pub trait DecimalType: { const BYTE_LENGTH: usize; const MAX_PRECISION: u8; - const MAX_SCALE: u8; - const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType; + const MAX_SCALE: i8; + const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType; const DEFAULT_TYPE: DataType; /// "Decimal128" or "Decimal256", for use in error messages const PREFIX: &'static str; /// Formats the decimal value with the provided precision and scale - fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String; + fn format_decimal(value: Self::Native, precision: u8, 
scale: i8) -> String; /// Validates that `value` contains no more than `precision` decimal digits fn validate_decimal_precision( @@ -515,14 +515,14 @@ pub struct Decimal128Type {} impl DecimalType for Decimal128Type { const BYTE_LENGTH: usize = 16; const MAX_PRECISION: u8 = DECIMAL128_MAX_PRECISION; - const MAX_SCALE: u8 = DECIMAL128_MAX_SCALE; - const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal128; + const MAX_SCALE: i8 = DECIMAL128_MAX_SCALE; + const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal128; const DEFAULT_TYPE: DataType = DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE); const PREFIX: &'static str = "Decimal128"; - fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String { - format_decimal_str(&value.to_string(), precision as usize, scale as usize) + fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String { + format_decimal_str(&value.to_string(), precision as usize, scale) } fn validate_decimal_precision(num: i128, precision: u8) -> Result<(), ArrowError> { @@ -543,14 +543,14 @@ pub struct Decimal256Type {} impl DecimalType for Decimal256Type { const BYTE_LENGTH: usize = 32; const MAX_PRECISION: u8 = DECIMAL256_MAX_PRECISION; - const MAX_SCALE: u8 = DECIMAL256_MAX_SCALE; - const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal256; + const MAX_SCALE: i8 = DECIMAL256_MAX_SCALE; + const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal256; const DEFAULT_TYPE: DataType = DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE); const PREFIX: &'static str = "Decimal256"; - fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String { - format_decimal_str(&value.to_string(), precision as usize, scale as usize) + fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String { + format_decimal_str(&value.to_string(), precision as usize, scale) } fn validate_decimal_precision(num: i256, precision: u8) -> 
Result<(), ArrowError> { @@ -564,7 +564,7 @@ impl ArrowPrimitiveType for Decimal256Type { const DATA_TYPE: DataType = ::DEFAULT_TYPE; } -fn format_decimal_str(value_str: &str, precision: usize, scale: usize) -> String { +fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String { let (sign, rest) = match value_str.strip_prefix('-') { Some(stripped) => ("-", stripped), None => ("", value_str), @@ -574,13 +574,16 @@ fn format_decimal_str(value_str: &str, precision: usize, scale: usize) -> String if scale == 0 { value_str.to_string() - } else if rest.len() > scale { + } else if scale < 0 { + let padding = value_str.len() + scale.unsigned_abs() as usize; + format!("{:0<width$}", value_str, width = padding) + } else if rest.len() > scale as usize { // Decimal separator is in the middle of the string - let (whole, decimal) = value_str.split_at(value_str.len() - scale); + let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize); format!("{}.{}", whole, decimal) } else { // String has to be padded - format!("{}0.{:0>width$}", sign, rest, width = scale) + format!("{}0.{:0>width$}", sign, rest, width = scale as usize) } } diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index 3bf97cf7ade4..61be2171b7c1 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -319,7 +319,7 @@ fn cast_integer_to_decimal< >( array: &PrimitiveArray, precision: u8, - scale: u8, + scale: i8, base: M, cast_options: &CastOptions, ) -> Result @@ -327,7 +327,7 @@ where ::Native: AsPrimitive, M: ArrowNativeTypeOp, { - let mul: M = base.pow_checked(scale as u32).map_err(|_| { + let mul_or_div: M = base.pow_checked(scale.unsigned_abs() as u32).map_err(|_| { ArrowError::CastError(format!( "Cannot cast to {:?}({}, {}). 
The scale causes overflow.", D::PREFIX, @@ -336,14 +336,26 @@ where )) })?; - if cast_options.safe { + if scale < 0 { + if cast_options.safe { + array + .unary_opt::<_, D>(|v| v.as_().div_checked(mul_or_div).ok()) + .with_precision_and_scale(precision, scale) + .map(|a| Arc::new(a) as ArrayRef) + } else { + array + .try_unary::<_, D, _>(|v| v.as_().div_checked(mul_or_div)) + .and_then(|a| a.with_precision_and_scale(precision, scale)) + .map(|a| Arc::new(a) as ArrayRef) + } + } else if cast_options.safe { array - .unary_opt::<_, D>(|v| v.as_().mul_checked(mul).ok()) + .unary_opt::<_, D>(|v| v.as_().mul_checked(mul_or_div).ok()) .with_precision_and_scale(precision, scale) .map(|a| Arc::new(a) as ArrayRef) } else { array - .try_unary::<_, D, _>(|v| v.as_().mul_checked(mul)) + .try_unary::<_, D, _>(|v| v.as_().mul_checked(mul_or_div)) .and_then(|a| a.with_precision_and_scale(precision, scale)) .map(|a| Arc::new(a) as ArrayRef) } @@ -352,7 +364,7 @@ where fn cast_floating_point_to_decimal128( array: &PrimitiveArray, precision: u8, - scale: u8, + scale: i8, cast_options: &CastOptions, ) -> Result where @@ -391,7 +403,7 @@ where fn cast_floating_point_to_decimal256( array: &PrimitiveArray, precision: u8, - scale: u8, + scale: i8, cast_options: &CastOptions, ) -> Result where @@ -437,7 +449,7 @@ fn cast_reinterpret_arrays< fn cast_decimal_to_integer( array: &ArrayRef, base: D::Native, - scale: u8, + scale: i8, cast_options: &CastOptions, ) -> Result where @@ -1921,9 +1933,9 @@ fn cast_decimal_to_decimal_with_option< const BYTE_WIDTH2: usize, >( array: &ArrayRef, - input_scale: &u8, + input_scale: &i8, output_precision: &u8, - output_scale: &u8, + output_scale: &i8, cast_options: &CastOptions, ) -> Result { if cast_options.safe { @@ -1947,9 +1959,9 @@ fn cast_decimal_to_decimal_with_option< /// the array values when cast failures happen. 
fn cast_decimal_to_decimal_safe( array: &ArrayRef, - input_scale: &u8, + input_scale: &i8, output_precision: &u8, - output_scale: &u8, + output_scale: &i8, ) -> Result { if input_scale > output_scale { // For example, input_scale is 4 and output_scale is 3; @@ -2062,9 +2074,9 @@ fn cast_decimal_to_decimal_safe( array: &ArrayRef, - input_scale: &u8, + input_scale: &i8, output_precision: &u8, - output_scale: &u8, + output_scale: &i8, ) -> Result { if input_scale > output_scale { // For example, input_scale is 4 and output_scale is 3; @@ -3540,7 +3552,7 @@ mod tests { fn create_decimal_array( array: Vec>, precision: u8, - scale: u8, + scale: i8, ) -> Result { array .into_iter() @@ -3551,7 +3563,7 @@ mod tests { fn create_decimal256_array( array: Vec>, precision: u8, - scale: u8, + scale: i8, ) -> Result { array .into_iter() @@ -7206,4 +7218,62 @@ mod tests { err ); } + + #[test] + fn test_cast_decimal128_to_decimal128_negative_scale() { + let input_type = DataType::Decimal128(20, 0); + let output_type = DataType::Decimal128(20, -1); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let input_decimal_array = create_decimal_array(array, 20, 0).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + generate_cast_test_case!( + &array, + Decimal128Array, + &output_type, + vec![ + Some(112345_i128), + Some(212345_i128), + Some(312345_i128), + None + ] + ); + + let casted_array = cast(&array, &output_type).unwrap(); + let decimal_arr = as_primitive_array::(&casted_array); + + assert_eq!("1123450", decimal_arr.value_as_string(0)); + assert_eq!("2123450", decimal_arr.value_as_string(1)); + assert_eq!("3123450", decimal_arr.value_as_string(2)); + } + + #[test] + fn test_cast_numeric_to_decimal128_negative() { + let decimal_type = DataType::Decimal128(38, -1); + let array = Arc::new(Int32Array::from(vec![ + Some(1123456), + Some(2123456), + Some(3123456), + ])) as ArrayRef; + + let 
casted_array = cast(&array, &decimal_type).unwrap(); + let decimal_arr = as_primitive_array::(&casted_array); + + assert_eq!("1123450", decimal_arr.value_as_string(0)); + assert_eq!("2123450", decimal_arr.value_as_string(1)); + assert_eq!("3123450", decimal_arr.value_as_string(2)); + + let array = Arc::new(Float32Array::from(vec![ + Some(1123.456), + Some(2123.456), + Some(3123.456), + ])) as ArrayRef; + + let casted_array = cast(&array, &decimal_type).unwrap(); + let decimal_arr = as_primitive_array::(&casted_array); + + assert_eq!("1120", decimal_arr.value_as_string(0)); + assert_eq!("2120", decimal_arr.value_as_string(1)); + assert_eq!("3120", decimal_arr.value_as_string(2)); + } } diff --git a/arrow-csv/src/reader.rs b/arrow-csv/src/reader.rs index 4200e9329c54..6432fb1b8017 100644 --- a/arrow-csv/src/reader.rs +++ b/arrow-csv/src/reader.rs @@ -721,7 +721,7 @@ fn build_decimal_array( rows: &[StringRecord], col_idx: usize, precision: u8, - scale: u8, + scale: i8, ) -> Result { let mut decimal_builder = Decimal128Builder::with_capacity(rows.len()); for row in rows { @@ -762,13 +762,13 @@ fn build_decimal_array( fn parse_decimal_with_parameter( s: &str, precision: u8, - scale: u8, + scale: i8, ) -> Result { if PARSE_DECIMAL_RE.is_match(s) { let mut offset = s.len(); let len = s.len(); let mut base = 1; - let scale_usize = usize::from(scale); + let scale_usize = usize::from(scale as u8); // handle the value after the '.' 
and meet the scale let delimiter_position = s.find('.'); diff --git a/arrow-data/src/decimal.rs b/arrow-data/src/decimal.rs index a6a08774941e..7011c40858c2 100644 --- a/arrow-data/src/decimal.rs +++ b/arrow-data/src/decimal.rs @@ -728,17 +728,17 @@ pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ pub const DECIMAL128_MAX_PRECISION: u8 = 38; /// The maximum scale for [arrow_schema::DataType::Decimal128] values -pub const DECIMAL128_MAX_SCALE: u8 = 38; +pub const DECIMAL128_MAX_SCALE: i8 = 38; /// The maximum precision for [arrow_schema::DataType::Decimal256] values pub const DECIMAL256_MAX_PRECISION: u8 = 76; /// The maximum scale for [arrow_schema::DataType::Decimal256] values -pub const DECIMAL256_MAX_SCALE: u8 = 76; +pub const DECIMAL256_MAX_SCALE: i8 = 76; /// The default scale for [arrow_schema::DataType::Decimal128] and /// [arrow_schema::DataType::Decimal256] values -pub const DECIMAL_DEFAULT_SCALE: u8 = 10; +pub const DECIMAL_DEFAULT_SCALE: i8 = 10; /// Validates that the specified `i128` value can be properly /// interpreted as a Decimal number with precision `precision` diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs index cf85902e4ce7..f74e2a24b04f 100644 --- a/arrow-schema/src/datatype.rs +++ b/arrow-schema/src/datatype.rs @@ -190,14 +190,14 @@ pub enum DataType { /// * scale is the number of digits past the decimal /// /// For example the number 123.45 has precision 5 and scale 2. - Decimal128(u8, u8), + Decimal128(u8, i8), /// Exact 256-bit width decimal value with precision and scale /// /// * precision is the total number of digits /// * scale is the number of digits past the decimal /// /// For example the number 123.45 has precision 5 and scale 2. 
- Decimal256(u8, u8), + Decimal256(u8, i8), /// A Map is a logical nested type that is represented as /// /// `List>` diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs index d498ae487c3e..857b6e3231ba 100644 --- a/arrow-select/src/take.rs +++ b/arrow-select/src/take.rs @@ -914,7 +914,7 @@ mod tests { options: Option, expected_data: Vec>, precision: &u8, - scale: &u8, + scale: &i8, ) -> Result<(), ArrowError> { let output = data .into_iter() @@ -1032,7 +1032,7 @@ mod tests { fn test_take_decimal128_non_null_indices() { let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]); let precision: u8 = 10; - let scale: u8 = 5; + let scale: i8 = 5; test_take_decimal_arrays( vec![None, Some(3), Some(5), Some(2), Some(3), None], &index, @@ -1048,7 +1048,7 @@ mod tests { fn test_take_decimal128() { let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]); let precision: u8 = 10; - let scale: u8 = 5; + let scale: i8 = 5; test_take_decimal_arrays( vec![Some(0), Some(1), Some(2), Some(3), Some(4)], &index, diff --git a/arrow/benches/cast_kernels.rs b/arrow/benches/cast_kernels.rs index e93c7860885c..7ef4d1d7e74a 100644 --- a/arrow/benches/cast_kernels.rs +++ b/arrow/benches/cast_kernels.rs @@ -84,7 +84,7 @@ fn build_utf8_date_time_array(size: usize, with_nulls: bool) -> ArrayRef { Arc::new(builder.finish()) } -fn build_decimal128_array(size: usize, precision: u8, scale: u8) -> ArrayRef { +fn build_decimal128_array(size: usize, precision: u8, scale: i8) -> ArrayRef { let mut rng = seedable_rng(); let mut builder = Decimal128Builder::with_capacity(size); @@ -99,7 +99,7 @@ fn build_decimal128_array(size: usize, precision: u8, scale: u8) -> ArrayRef { ) } -fn build_decimal256_array(size: usize, precision: u8, scale: u8) -> ArrayRef { +fn build_decimal256_array(size: usize, precision: u8, scale: i8) -> ArrayRef { let mut rng = seedable_rng(); let mut builder = Decimal256Builder::with_capacity(size); let mut bytes = [0; 32]; diff --git 
a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs index ef303dfdd1ff..41addf24fbc2 100644 --- a/arrow/src/datatypes/ffi.rs +++ b/arrow/src/datatypes/ffi.rs @@ -103,7 +103,7 @@ impl TryFrom<&FFI_ArrowSchema> for DataType { "The decimal type requires an integer precision".to_string(), ) })?; - let parsed_scale = scale.parse::<u8>().map_err(|_| { + let parsed_scale = scale.parse::<i8>().map_err(|_| { ArrowError::CDataInterface( "The decimal type requires an integer scale".to_string(), ) @@ -119,7 +119,7 @@ impl TryFrom<&FFI_ArrowSchema> for DataType { "The decimal type requires an integer precision".to_string(), ) })?; - let parsed_scale = scale.parse::<u8>().map_err(|_| { + let parsed_scale = scale.parse::<i8>().map_err(|_| { ArrowError::CDataInterface( "The decimal type requires an integer scale".to_string(), ) diff --git a/arrow/tests/array_transform.rs b/arrow/tests/array_transform.rs index 42f9ab277d40..3c08a592dd2c 100644 --- a/arrow/tests/array_transform.rs +++ b/arrow/tests/array_transform.rs @@ -31,7 +31,7 @@ use std::sync::Arc; fn create_decimal_array( array: Vec>, precision: u8, - scale: u8, + scale: i8, ) -> Decimal128Array { array .into_iter()