diff --git a/src/compute/cast/dictionary_to.rs b/src/compute/cast/dictionary_to.rs index 01d3325cf9d..101669f6442 100644 --- a/src/compute/cast/dictionary_to.rs +++ b/src/compute/cast/dictionary_to.rs @@ -7,7 +7,7 @@ use crate::{ }; macro_rules! key_cast { - ($keys:expr, $values:expr, $array:expr, $to_keys_type:expr, $to_type:ty) => {{ + ($keys:expr, $values:expr, $array:expr, $to_keys_type:expr, $to_type:ty, $to_datatype:expr) => {{ let cast_keys = primitive_to_primitive::<_, $to_type>($keys, $to_keys_type); // Failure to cast keys (because they don't fit in the @@ -15,7 +15,11 @@ macro_rules! key_cast { if cast_keys.null_count() > $keys.null_count() { return Err(Error::Overflow); } - DictionaryArray::try_new($array.data_type().clone(), $keys.clone(), $values.clone()) + // Safety: this is safe because given a type `T` that fits in a `usize`, casting it to type `P` either overflows or also fits in a `usize` + unsafe { + DictionaryArray::try_new_unchecked($to_datatype, cast_keys, $values.clone()) + } + .map(|x| x.boxed()) }}; } @@ -88,8 +92,8 @@ where Box::new(values.data_type().clone()), is_ordered, ); - // some of the values may not fit in `usize` and thus this needs to be checked - DictionaryArray::try_new(data_type, casted_keys, values.clone()) + // Safety: this is safe because given a type `T` that fits in a `usize`, casting it to type `P` either overflows or also fits in a `usize` + unsafe { DictionaryArray::try_new_unchecked(data_type, casted_keys, values.clone()) } } } @@ -134,11 +138,13 @@ pub(super) fn dictionary_cast_dyn( let values = cast(values.as_ref(), to_values_type, options)?; // create the appropriate array type - let data_type = (*to_keys_type).into(); + let to_key_type = (*to_keys_type).into(); + + // Safety: + // we return an error on overflow so the integers remain within bounds match_integer_type!(to_keys_type, |$T| { - key_cast!(keys, values, array, &data_type, $T) + key_cast!(keys, values, array, &to_key_type, $T, to_type.clone()) }) - 
.map(|x| x.boxed()) } _ => unpack_dictionary::<K>(keys, values.as_ref(), to_type, options), } diff --git a/tests/it/compute/cast.rs b/tests/it/compute/cast.rs index 79ce42c9904..9e9888926d3 100644 --- a/tests/it/compute/cast.rs +++ b/tests/it/compute/cast.rs @@ -805,3 +805,27 @@ fn utf8_to_date64() { assert_eq!(&expected, c); } + +#[test] +fn dict_keys() { + let mut array = MutableDictionaryArray::<u8, MutableUtf8Array<i32>>::new(); + array + .try_extend([Some("one"), None, Some("three"), Some("one")]) + .unwrap(); + let array: DictionaryArray<u8> = array.into(); + + let result = cast( + &array, + &DataType::Dictionary(IntegerType::Int64, Box::new(DataType::Utf8), false), + CastOptions::default(), + ) + .expect("cast failed"); + + let mut expected = MutableDictionaryArray::<i64, MutableUtf8Array<i32>>::new(); + expected + .try_extend([Some("one"), None, Some("three"), Some("one")]) + .unwrap(); + let expected: DictionaryArray<i64> = expected.into(); + + assert_eq!(expected, result.as_ref()); +}