From 0b6b8567e700f154bdbe869ccf79072ea7fb5c24 Mon Sep 17 00:00:00 2001
From: Li Yazhou
Date: Tue, 9 Nov 2021 11:13:40 +0800
Subject: [PATCH] allow null array to be casted to all other types

---
Review notes (this area is ignored by `git am`):

* Purpose: `can_cast_types` and `cast` previously special-cased only
  (Null, Int32). Both now share one arm covering Null <-> Boolean, the
  integer and float types, Date32/Date64, Time32, List and Dictionary.
  In `cast` that arm returns `new_null_array(<target type>, array.len())`.
* Fixed in review: the new `cast` arm passed `to_type` (a `&DataType`)
  straight to `new_null_array`; the arm it replaces passed
  `to_type.clone()`, so the clone is restored.
* Fixed in review: the two `primitive_to_same_primitive_dyn` context lines
  had lost their turbofish argument; restored as `::<i64>`.
  NOTE(review): inferred from the Int64/Duration arms - confirm.
* NOTE(review): the new test calls `cast` with two arguments, while the
  hunk context shows a third `options: CastOptions` parameter. It also
  relies on `as_primitive_array`, `as_null_array` and `*Type` names that
  are not visible in this patch - confirm the test compiles.
* NOTE(review): the subject says "all other types", but the arm lists a
  fixed set (no Utf8, Binary, Timestamp, ...) - confirm this is intended.
* NOTE(review): the blob id on the `index` line for
  src/compute/cast/mod.rs predates the `.clone()` fix.

 src/compute/cast/mod.rs  | 80 ++++++++++++++++++++++++++++++++++++++--
 tests/it/compute/cast.rs | 36 ++++++++++++++++++
 2 files changed, 112 insertions(+), 4 deletions(-)

diff --git a/src/compute/cast/mod.rs b/src/compute/cast/mod.rs
index 99c1542e792..2262c2534b4 100644
--- a/src/compute/cast/mod.rs
+++ b/src/compute/cast/mod.rs
@@ -78,6 +78,44 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
     }
 
     match (from_type, to_type) {
+        (
+            Null,
+            Boolean
+            | Int8
+            | UInt8
+            | Int16
+            | UInt16
+            | Int32
+            | UInt32
+            | Float32
+            | Date32
+            | Time32(_)
+            | Int64
+            | UInt64
+            | Float64
+            | Date64
+            | List(_)
+            | Dictionary(_, _),
+        )
+        | (
+            Boolean
+            | Int8
+            | UInt8
+            | Int16
+            | UInt16
+            | Int32
+            | UInt32
+            | Float32
+            | Date32
+            | Time32(_)
+            | Int64
+            | UInt64
+            | Float64
+            | Date64
+            | List(_)
+            | Dictionary(_, _),
+            Null,
+        ) => true,
         (Struct(_), _) => false,
         (_, Struct(_)) => false,
         (List(list_from), List(list_to)) => {
@@ -254,7 +292,6 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
         (Timestamp(_, _), Date64) => true,
         (Int64, Duration(_)) => true,
         (Duration(_), Int64) => true,
-        (Null, Int32) => true,
         (_, _) => false,
     }
 }
@@ -337,7 +374,44 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
     let as_options = options.with_wrapped(true);
 
     match (from_type, to_type) {
-        (Null, Int32) => Ok(new_null_array(to_type.clone(), array.len())),
+        (
+            Null,
+            Boolean
+            | Int8
+            | UInt8
+            | Int16
+            | UInt16
+            | Int32
+            | UInt32
+            | Float32
+            | Date32
+            | Time32(_)
+            | Int64
+            | UInt64
+            | Float64
+            | Date64
+            | List(_)
+            | Dictionary(_, _),
+        )
+        | (
+            Boolean
+            | Int8
+            | UInt8
+            | Int16
+            | UInt16
+            | Int32
+            | UInt32
+            | Float32
+            | Date32
+            | Time32(_)
+            | Int64
+            | UInt64
+            | Float64
+            | Date64
+            | List(_)
+            | Dictionary(_, _),
+            Null,
+        ) => Ok(new_null_array(to_type.clone(), array.len())),
         (Struct(_), _) => Err(ArrowError::NotYetImplemented(
             "Cannot cast from struct to other types".to_string(),
         )),
@@ -790,8 +864,6 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu
         (Int64, Duration(_)) => primitive_to_same_primitive_dyn::<i64>(array, to_type),
         (Duration(_), Int64) => primitive_to_same_primitive_dyn::<i64>(array, to_type),
 
-        // null to primitive/flat types
-        //(Null, Int32) => Ok(Box::new(Int32Array::from(vec![None; array.len()]))),
         (_, _) => Err(ArrowError::NotYetImplemented(format!(
             "Casting from {:?} to {:?} not supported",
             from_type, to_type,
diff --git a/tests/it/compute/cast.rs b/tests/it/compute/cast.rs
index c1156553ee5..e5f9659c243 100644
--- a/tests/it/compute/cast.rs
+++ b/tests/it/compute/cast.rs
@@ -597,6 +597,42 @@ fn naive_timestamp_to_utf8() {
     assert_eq!(expected, result.as_ref());
 }
 
+#[test]
+fn null_array_from_and_to_others() {
+    macro_rules! typed_test {
+        ($ARR_TYPE:ident, $DATATYPE:ident, $TYPE:tt) => {{
+            {
+                let array = Arc::new(NullArray::new(6)) as ArrayRef;
+                let expected = $ARR_TYPE::from(vec![None; 6]);
+                let cast_type = DataType::$DATATYPE;
+                let cast_array = cast(&array, &cast_type).expect("cast failed");
+                let cast_array = as_primitive_array::<$TYPE>(&cast_array);
+                assert_eq!(cast_array.data_type(), &cast_type);
+                assert_eq!(cast_array, &expected);
+            }
+            {
+                let array = Arc::new($ARR_TYPE::from(vec![None; 4])) as ArrayRef;
+                let expected = NullArray::new(4);
+                let cast_array = cast(&array, &DataType::Null).expect("cast failed");
+                let cast_array = as_null_array(&cast_array);
+                assert_eq!(cast_array.data_type(), &DataType::Null);
+                assert_eq!(cast_array, &expected);
+            }
+        }};
+    }
+
+    typed_test!(Int16Array, Int16, Int16Type);
+    typed_test!(Int32Array, Int32, Int32Type);
+    typed_test!(Int64Array, Int64, Int64Type);
+
+    typed_test!(UInt16Array, UInt16, UInt16Type);
+    typed_test!(UInt32Array, UInt32, UInt32Type);
+    typed_test!(UInt64Array, UInt64, UInt64Type);
+
+    typed_test!(Float32Array, Float32, Float32Type);
+    typed_test!(Float64Array, Float64, Float64Type);
+}
+
 /*
 #[test]
 fn dict_to_dict_bad_index_value_primitive() {