From 29edd98dc690efd140c31964a494860a89073eea Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sun, 7 Nov 2021 17:12:11 +0000 Subject: [PATCH] Simplified dictionary indexes. --- src/array/dictionary/mod.rs | 47 +++--- src/array/dictionary/mutable.rs | 10 +- src/array/display.rs | 32 ++-- src/array/equal/mod.rs | 2 +- src/array/ffi.rs | 2 +- src/array/growable/mod.rs | 36 ++--- src/array/mod.rs | 29 +--- src/array/ord.rs | 18 +-- src/compute/aggregate/memory.rs | 18 +-- src/compute/cast/dictionary_to.rs | 5 +- src/compute/cast/mod.rs | 31 +--- src/compute/sort/mod.rs | 32 +--- src/compute/take/mod.rs | 74 ++++------ src/datatypes/mod.rs | 27 ++-- src/datatypes/physical_type.rs | 6 +- src/ffi/array.rs | 2 +- src/ffi/bridge.rs | 2 +- src/ffi/schema.rs | 32 +++- src/io/avro/read/nested.rs | 2 +- src/io/avro/read/schema.rs | 2 +- src/io/csv/write/serialize.rs | 13 +- src/io/ipc/convert.rs | 35 +++-- src/io/ipc/read/deserialize.rs | 2 +- src/io/ipc/write/common.rs | 2 +- src/io/ipc/write/serialize.rs | 2 +- src/io/json/read/deserialize.rs | 2 +- src/io/json_integration/mod.rs | 41 +----- src/io/json_integration/read.rs | 2 +- src/io/json_integration/schema.rs | 234 +++++++----------------------- src/io/parquet/read/mod.rs | 14 +- src/io/parquet/write/mod.rs | 2 +- src/types/mod.rs | 1 - tests/it/compute/cast.rs | 8 +- tests/it/ffi.rs | 2 +- tests/it/io/avro/read/mod.rs | 2 +- tests/it/io/csv/write.rs | 2 +- tests/it/io/json/mod.rs | 2 +- tests/it/io/print.rs | 4 +- 38 files changed, 272 insertions(+), 507 deletions(-) diff --git a/src/array/dictionary/mod.rs b/src/array/dictionary/mod.rs index 012d937d6d9..9a8147a36a6 100644 --- a/src/array/dictionary/mod.rs +++ b/src/array/dictionary/mod.rs @@ -2,9 +2,9 @@ use std::sync::Arc; use crate::{ bitmap::Bitmap, - datatypes::DataType, + datatypes::{DataType, IntegerType}, scalar::{new_scalar, Scalar}, - types::{NativeType, NaturalDataType}, + types::NativeType, }; mod ffi; @@ -16,19 +16,35 @@ pub use mutable::*; use super::{new_empty_array, primitive::PrimitiveArray, Array}; /// Trait denoting [`NativeType`]s that can be used as keys of a dictionary. -pub trait DictionaryKey: - NativeType + NaturalDataType + num_traits::NumCast + num_traits::FromPrimitive -{ +pub trait DictionaryKey: NativeType + num_traits::NumCast + num_traits::FromPrimitive { + /// The corresponding [`IntegerType`] of this key + const KEY_TYPE: IntegerType; } -impl DictionaryKey for i8 {} -impl DictionaryKey for i16 {} -impl DictionaryKey for i32 {} -impl DictionaryKey for i64 {} -impl DictionaryKey for u8 {} -impl DictionaryKey for u16 {} -impl DictionaryKey for u32 {} -impl DictionaryKey for u64 {} +impl DictionaryKey for i8 { + const KEY_TYPE: IntegerType = IntegerType::Int8; +} +impl DictionaryKey for i16 { + const KEY_TYPE: IntegerType = IntegerType::Int16; +} +impl DictionaryKey for i32 { + const KEY_TYPE: IntegerType = IntegerType::Int32; +} +impl DictionaryKey for i64 { + const KEY_TYPE: IntegerType = IntegerType::Int64; +} +impl DictionaryKey for u8 { + const KEY_TYPE: IntegerType = IntegerType::UInt8; +} +impl DictionaryKey for u16 { + const KEY_TYPE: IntegerType = IntegerType::UInt16; +} +impl DictionaryKey for u32 { + const KEY_TYPE: IntegerType = IntegerType::UInt32; +} +impl DictionaryKey for u64 { + const KEY_TYPE: IntegerType = IntegerType::UInt64; +} /// An [`Array`] whose values are encoded by keys. This [`Array`] is useful when the cardinality of /// values is low compared to the length of the [`Array`]. @@ -59,10 +75,7 @@ impl DictionaryArray { /// The canonical method to create a new [`DictionaryArray`]. pub fn from_data(keys: PrimitiveArray, values: Arc) -> Self { - let data_type = DataType::Dictionary( - Box::new(keys.data_type().clone()), - Box::new(values.data_type().clone()), - ); + let data_type = DataType::Dictionary(K::KEY_TYPE, Box::new(values.data_type().clone())); Self { data_type, diff --git a/src/array/dictionary/mutable.rs b/src/array/dictionary/mutable.rs index 895bec0f1bb..623b6ad4ead 100644 --- a/src/array/dictionary/mutable.rs +++ b/src/array/dictionary/mutable.rs @@ -31,10 +31,7 @@ impl From> for D impl From for MutableDictionaryArray { fn from(values: M) -> Self { Self { - data_type: DataType::Dictionary( - Box::new(K::DATA_TYPE), - Box::new(values.data_type().clone()), - ), + data_type: DataType::Dictionary(K::KEY_TYPE, Box::new(values.data_type().clone())), keys: MutablePrimitiveArray::::new(), map: HashedMap::default(), values, @@ -47,10 +44,7 @@ impl MutableDictionaryArray { pub fn new() -> Self { let values = M::default(); Self { - data_type: DataType::Dictionary( - Box::new(K::DATA_TYPE), - Box::new(values.data_type().clone()), - ), + data_type: DataType::Dictionary(K::KEY_TYPE, Box::new(values.data_type().clone())), keys: MutablePrimitiveArray::::new(), map: HashedMap::default(), values, diff --git a/src/array/display.rs b/src/array/display.rs index ef65bb8aa95..a69b3342038 100644 --- a/src/array/display.rs +++ b/src/array/display.rs @@ -17,18 +17,6 @@ macro_rules! dyn_primitive { }}; } -macro_rules! dyn_dict { - ($array:expr, $ty:ty) => {{ - let a = $array - .as_any() - .downcast_ref::>() - .unwrap(); - let keys = a.keys(); - let display = get_display(a.values().as_ref()); - Box::new(move |row: usize| display(keys.value(row) as usize)) - }}; -} - /// Returns a function of index returning the string representation of the _value_ of `array`. /// This does not take nulls into account. pub fn get_value_display<'a>(array: &'a dyn Array) -> Box String + 'a> { @@ -170,17 +158,15 @@ pub fn get_value_display<'a>(array: &'a dyn Array) -> Box Strin }; dyn_display!(array, ListArray, f) } - Dictionary(key_type, _) => match key_type.as_ref() { - DataType::Int8 => dyn_dict!(array, i8), - DataType::Int16 => dyn_dict!(array, i16), - DataType::Int32 => dyn_dict!(array, i32), - DataType::Int64 => dyn_dict!(array, i64), - DataType::UInt8 => dyn_dict!(array, u8), - DataType::UInt16 => dyn_dict!(array, u16), - DataType::UInt32 => dyn_dict!(array, u32), - DataType::UInt64 => dyn_dict!(array, u64), - _ => unreachable!(), - }, + Dictionary(key_type, _) => match_integer_type!(key_type, |$T| { + let a = array + .as_any() + .downcast_ref::>() + .unwrap(); + let keys = a.keys(); + let display = get_display(a.values().as_ref()); + Box::new(move |row: usize| display(keys.value(row) as usize)) + }), Map(_, _) => todo!(), Struct(_) => { let a = array.as_any().downcast_ref::().unwrap(); diff --git a/src/array/equal/mod.rs b/src/array/equal/mod.rs index 7aa200a8b44..ed0dae28075 100644 --- a/src/array/equal/mod.rs +++ b/src/array/equal/mod.rs @@ -215,7 +215,7 @@ pub fn equal(lhs: &dyn Array, rhs: &dyn Array) -> bool { struct_::equal(lhs, rhs) } Dictionary(key_type) => { - with_match_physical_dictionary_key_type!(key_type, |$T| { + match_integer_type!(key_type, |$T| { let lhs = lhs.as_any().downcast_ref().unwrap(); let rhs = rhs.as_any().downcast_ref().unwrap(); dictionary::equal::<$T>(lhs, rhs) diff --git a/src/array/ffi.rs b/src/array/ffi.rs index 2fbca0493d5..a9d0a089074 100644 --- a/src/array/ffi.rs +++ b/src/array/ffi.rs @@ -74,7 +74,7 @@ pub fn offset_buffers_children_dictionary(array: &dyn Array) -> BuffersChildren Union => ffi_dyn!(array, UnionArray), Map => ffi_dyn!(array, MapArray), Dictionary(key_type) => { - with_match_physical_dictionary_key_type!(key_type, |$T| { + match_integer_type!(key_type, |$T| { let array = array.as_any().downcast_ref::>().unwrap(); ( array.offset().unwrap(), diff --git a/src/array/growable/mod.rs b/src/array/growable/mod.rs index e3796167968..c6d45356237 100644 --- a/src/array/growable/mod.rs +++ b/src/array/growable/mod.rs @@ -61,25 +61,6 @@ macro_rules! dyn_growable { }}; } -macro_rules! dyn_dict_growable { - ($ty:ty, $arrays:expr, $use_validity:expr, $capacity:expr) => {{ - let arrays = $arrays - .iter() - .map(|array| { - array - .as_any() - .downcast_ref::>() - .unwrap() - }) - .collect::>(); - Box::new(dictionary::GrowableDictionary::<$ty>::new( - &arrays, - $use_validity, - $capacity, - )) - }}; -} - /// Creates a new [`Growable`] from an arbitrary number of [`Array`]s. /// # Panics /// This function panics iff @@ -132,8 +113,21 @@ pub fn make_growable<'a>( ), Union | Map => todo!(), Dictionary(key_type) => { - with_match_physical_dictionary_key_type!(key_type, |$T| { - dyn_dict_growable!($T, arrays, use_validity, capacity) + match_integer_type!(key_type, |$T| { + let arrays = arrays + .iter() + .map(|array| { + array + .as_any() + .downcast_ref::>() + .unwrap() + }) + .collect::>(); + Box::new(dictionary::GrowableDictionary::<$T>::new( + &arrays, + use_validity, + capacity, + )) }) } } diff --git a/src/array/mod.rs b/src/array/mod.rs index fa8caf7002c..c9ecb890979 100644 --- a/src/array/mod.rs +++ b/src/array/mod.rs @@ -174,28 +174,11 @@ macro_rules! fmt_dyn { }}; } -macro_rules! with_match_dictionary_key_type {( +macro_rules! match_integer_type {( $key_type:expr, | $_:tt $T:ident | $($body:tt)* ) => ({ macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} - match $key_type { - DataType::Int8 => __with_ty__! { i8 }, - DataType::Int16 => __with_ty__! { i16 }, - DataType::Int32 => __with_ty__! { i32 }, - DataType::Int64 => __with_ty__! { i64 }, - DataType::UInt8 => __with_ty__! { u8 }, - DataType::UInt16 => __with_ty__! { u16 }, - DataType::UInt32 => __with_ty__! { u32 }, - DataType::UInt64 => __with_ty__! { u64 }, - _ => ::core::unreachable!("A dictionary key type can only be of integer types"), - } -})} - -macro_rules! with_match_physical_dictionary_key_type {( - $key_type:expr, | $_:tt $T:ident | $($body:tt)* -) => ({ - macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} - use crate::datatypes::DictionaryIndexType::*; + use crate::datatypes::IntegerType::*; match $key_type { Int8 => __with_ty__! { i8 }, Int16 => __with_ty__! { i16 }, @@ -251,7 +234,7 @@ impl Display for dyn Array { Struct => fmt_dyn!(self, StructArray, f), Union => fmt_dyn!(self, UnionArray, f), Dictionary(key_type) => { - with_match_physical_dictionary_key_type!(key_type, |$T| { + match_integer_type!(key_type, |$T| { fmt_dyn!(self, DictionaryArray::<$T>, f) }) } @@ -281,7 +264,7 @@ pub fn new_empty_array(data_type: DataType) -> Box { Union => Box::new(UnionArray::new_empty(data_type)), Map => Box::new(MapArray::new_empty(data_type)), Dictionary(key_type) => { - with_match_physical_dictionary_key_type!(key_type, |$T| { + match_integer_type!(key_type, |$T| { Box::new(DictionaryArray::<$T>::new_empty(data_type)) }) } @@ -311,7 +294,7 @@ pub fn new_null_array(data_type: DataType, length: usize) -> Box { Union => Box::new(UnionArray::new_null(data_type, length)), Map => Box::new(MapArray::new_null(data_type, length)), Dictionary(key_type) => { - with_match_physical_dictionary_key_type!(key_type, |$T| { + match_integer_type!(key_type, |$T| { Box::new(DictionaryArray::<$T>::new_null(data_type, length)) }) } @@ -349,7 +332,7 @@ pub fn clone(array: &dyn Array) -> Box { Union => clone_dyn!(array, UnionArray), Map => clone_dyn!(array, MapArray), Dictionary(key_type) => { - with_match_physical_dictionary_key_type!(key_type, |$T| { + match_integer_type!(key_type, |$T| { clone_dyn!(array, DictionaryArray::<$T>) }) } diff --git a/src/array/ord.rs b/src/array/ord.rs index 319af374ab1..31c36d770e5 100644 --- a/src/array/ord.rs +++ b/src/array/ord.rs @@ -216,15 +216,15 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result compare_binary::(left, right), (LargeBinary, LargeBinary) => compare_binary::(left, right), (Dictionary(key_type_lhs, _), Dictionary(key_type_rhs, _)) => { - match (key_type_lhs.as_ref(), key_type_rhs.as_ref()) { - (UInt8, UInt8) => dyn_dict!(u8, left, right), - (UInt16, UInt16) => dyn_dict!(u16, left, right), - (UInt32, UInt32) => dyn_dict!(u32, left, right), - (UInt64, UInt64) => dyn_dict!(u64, left, right), - (Int8, Int8) => dyn_dict!(i8, left, right), - (Int16, Int16) => dyn_dict!(i16, left, right), - (Int32, Int32) => dyn_dict!(i32, left, right), - (Int64, Int64) => dyn_dict!(i64, left, right), + match (key_type_lhs, key_type_rhs) { + (IntegerType::UInt8, IntegerType::UInt8) => dyn_dict!(u8, left, right), + (IntegerType::UInt16, IntegerType::UInt16) => dyn_dict!(u16, left, right), + (IntegerType::UInt32, IntegerType::UInt32) => dyn_dict!(u32, left, right), + (IntegerType::UInt64, IntegerType::UInt64) => dyn_dict!(u64, left, right), + (IntegerType::Int8, IntegerType::Int8) => dyn_dict!(i8, left, right), + (IntegerType::Int16, IntegerType::Int16) => dyn_dict!(i16, left, right), + (IntegerType::Int32, IntegerType::Int32) => dyn_dict!(i32, left, right), + (IntegerType::Int64, IntegerType::Int64) => dyn_dict!(i64, left, right), (lhs, _) => { return Err(ArrowError::InvalidArgumentError(format!( "Dictionaries do not support keys of type {:?}", diff --git a/src/compute/aggregate/memory.rs b/src/compute/aggregate/memory.rs index 036b2b5ebfe..8c8617d936c 100644 --- a/src/compute/aggregate/memory.rs +++ b/src/compute/aggregate/memory.rs @@ -16,16 +16,6 @@ macro_rules! dyn_binary { }}; } -macro_rules! dyn_dict { - ($array:expr, $ty:ty) => {{ - let array = $array - .as_any() - .downcast_ref::>() - .unwrap(); - estimated_bytes_size(array.keys()) + estimated_bytes_size(array.values().as_ref()) - }}; -} - /// Returns the total (heap) allocated size of the array in bytes. /// # Implementation /// This estimation is the sum of the size of its buffers, validity, including nested arrays. @@ -106,8 +96,12 @@ pub fn estimated_bytes_size(array: &dyn Array) -> usize { .sum::(); types + offsets + fields } - Dictionary(key_type) => with_match_physical_dictionary_key_type!(key_type, |$T| { - dyn_dict!(array, $T) + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + estimated_bytes_size(array.keys()) + estimated_bytes_size(array.values().as_ref()) }), Map => { let array = array.as_any().downcast_ref::().unwrap(); diff --git a/src/compute/cast/dictionary_to.rs b/src/compute/cast/dictionary_to.rs index 77c696ec915..375d69546d0 100644 --- a/src/compute/cast/dictionary_to.rs +++ b/src/compute/cast/dictionary_to.rs @@ -114,8 +114,9 @@ pub(super) fn dictionary_cast_dyn( let values = cast(values.as_ref(), to_values_type, options)?.into(); // create the appropriate array type - with_match_dictionary_key_type!(to_keys_type.as_ref(), |$T| { - key_cast!(keys, values, array, to_keys_type, $T) + let data_type = (*to_keys_type).into(); + match_integer_type!(to_keys_type, |$T| { + key_cast!(keys, values, array, &data_type, $T) }) } _ => unpack_dictionary::(keys, values.as_ref(), to_type, options), diff --git a/src/compute/cast/mod.rs b/src/compute/cast/mod.rs index 99c1542e792..04c4d898b42 100644 --- a/src/compute/cast/mod.rs +++ b/src/compute/cast/mod.rs @@ -375,31 +375,12 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu Ok(Box::new(list_array)) } - (Dictionary(index_type, _), _) => match **index_type { - DataType::Int8 => dictionary_cast_dyn::(array, to_type, options), - DataType::Int16 => dictionary_cast_dyn::(array, to_type, options), - DataType::Int32 => dictionary_cast_dyn::(array, to_type, options), - DataType::Int64 => dictionary_cast_dyn::(array, to_type, options), - DataType::UInt8 => dictionary_cast_dyn::(array, to_type, options), - DataType::UInt16 => dictionary_cast_dyn::(array, to_type, options), - DataType::UInt32 => dictionary_cast_dyn::(array, to_type, options), - DataType::UInt64 => dictionary_cast_dyn::(array, to_type, options), - _ => unreachable!(), - }, - (_, Dictionary(index_type, value_type)) => match **index_type { - DataType::Int8 => cast_to_dictionary::(array, value_type, options), - DataType::Int16 => cast_to_dictionary::(array, value_type, options), - DataType::Int32 => cast_to_dictionary::(array, value_type, options), - DataType::Int64 => cast_to_dictionary::(array, value_type, options), - DataType::UInt8 => cast_to_dictionary::(array, value_type, options), - DataType::UInt16 => cast_to_dictionary::(array, value_type, options), - DataType::UInt32 => cast_to_dictionary::(array, value_type, options), - DataType::UInt64 => cast_to_dictionary::(array, value_type, options), - _ => Err(ArrowError::NotYetImplemented(format!( - "Casting from type {:?} to dictionary type {:?} not supported", - from_type, to_type, - ))), - }, + (Dictionary(index_type, _), _) => match_integer_type!(index_type, |$T| { + dictionary_cast_dyn::<$T>(array, to_type, options) + }), + (_, Dictionary(index_type, value_type)) => match_integer_type!(index_type, |$T| { + cast_to_dictionary::<$T>(array, value_type, options) + }), (_, Boolean) => match from_type { UInt8 => primitive_to_boolean_dyn::(array, to_type.clone()), UInt16 => primitive_to_boolean_dyn::(array, to_type.clone()), diff --git a/src/compute/sort/mod.rs b/src/compute/sort/mod.rs index 62113946fa2..af41e27efe3 100644 --- a/src/compute/sort/mod.rs +++ b/src/compute/sort/mod.rs @@ -204,18 +204,8 @@ pub fn sort_to_indices( } } DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { - DataType::Utf8 => Ok(sort_dict::( - values, - key_type.as_ref(), - options, - limit, - )), - DataType::LargeUtf8 => Ok(sort_dict::( - values, - key_type.as_ref(), - options, - limit, - )), + DataType::Utf8 => Ok(sort_dict::(values, key_type, options, limit)), + DataType::LargeUtf8 => Ok(sort_dict::(values, key_type, options, limit)), t => Err(ArrowError::NotYetImplemented(format!( "Sort not supported for dictionary type with keys {:?}", t @@ -230,11 +220,11 @@ pub fn sort_to_indices( fn sort_dict( values: &dyn Array, - key_type: &DataType, + key_type: &IntegerType, options: &SortOptions, limit: Option, ) -> PrimitiveArray { - with_match_dictionary_key_type!(key_type, |$T| { + match_integer_type!(key_type, |$T| { utf8::indices_sorted_unstable_by_dictionary::( values.as_any().downcast_ref().unwrap(), options, @@ -293,18 +283,8 @@ pub fn can_sort(data_type: &DataType) -> bool { | DataType::UInt64 ) } - DataType::Dictionary(key_type, value_type) if *value_type.as_ref() == DataType::Utf8 => { - matches!( - key_type.as_ref(), - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - ) + DataType::Dictionary(_, value_type) => { + matches!(*value_type.as_ref(), DataType::Utf8 | DataType::LargeUtf8) } _ => false, } diff --git a/src/compute/take/mod.rs b/src/compute/take/mod.rs index 4f5cb273c83..d07a4873268 100644 --- a/src/compute/take/mod.rs +++ b/src/compute/take/mod.rs @@ -71,7 +71,7 @@ pub fn take(values: &dyn Array, indices: &PrimitiveArray) -> Result Ok(Box::new(binary::take::(values, indices))) } Dictionary(key_type) => { - with_match_physical_dictionary_key_type!(key_type, |$T| { + match_integer_type!(key_type, |$T| { let values = values.as_any().downcast_ref().unwrap(); Ok(Box::new(dict::take::<$T, _>(&values, indices))) }) @@ -103,46 +103,36 @@ pub fn take(values: &dyn Array, indices: &PrimitiveArray) -> Result /// assert_eq!(can_take(&data_type), true); /// ``` pub fn can_take(data_type: &DataType) -> bool { - match data_type { + matches!( + data_type, DataType::Null - | DataType::Boolean - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Date32 - | DataType::Time32(_) - | DataType::Interval(_) - | DataType::Int64 - | DataType::Date64 - | DataType::Time64(_) - | DataType::Duration(_) - | DataType::Timestamp(_, _) - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float16 - | DataType::Float32 - | DataType::Float64 - | DataType::Decimal(_, _) - | DataType::Utf8 - | DataType::LargeUtf8 - | DataType::Binary - | DataType::LargeBinary - | DataType::Struct(_) - | DataType::List(_) - | DataType::LargeList(_) => true, - DataType::Dictionary(key_type, _) => matches!( - key_type.as_ref(), - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - ), - _ => false, - } + | DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Date32 + | DataType::Time32(_) + | DataType::Interval(_) + | DataType::Int64 + | DataType::Date64 + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Timestamp(_, _) + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal(_, _) + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary + | DataType::Struct(_) + | DataType::List(_) + | DataType::LargeList(_) + | DataType::Dictionary(_, _) + ) } diff --git a/src/datatypes/mod.rs b/src/datatypes/mod.rs index 81ab5f71b96..8ab7ef579f8 100644 --- a/src/datatypes/mod.rs +++ b/src/datatypes/mod.rs @@ -128,7 +128,7 @@ pub enum DataType { /// /// This type mostly used to represent low cardinality string /// arrays or a limited set of primitive types as integers. - Dictionary(Box, Box), + Dictionary(IntegerType, Box), /// Decimal value with precision and scale /// precision is the number of digits in the number and /// scale is the number of decimal places. @@ -267,7 +267,7 @@ impl DataType { Struct(_) => PhysicalType::Struct, Union(_, _, _) => PhysicalType::Union, Map(_, _) => PhysicalType::Map, - Dictionary(key, _) => PhysicalType::Dictionary(to_dictionary_index_type(key.as_ref())), + Dictionary(key, _) => PhysicalType::Dictionary(*key), Extension(_, key, _) => key.to_physical_type(), } } @@ -284,17 +284,18 @@ impl DataType { } } -fn to_dictionary_index_type(data_type: &DataType) -> DictionaryIndexType { - match data_type { - DataType::Int8 => DictionaryIndexType::Int8, - DataType::Int16 => DictionaryIndexType::Int16, - DataType::Int32 => DictionaryIndexType::Int32, - DataType::Int64 => DictionaryIndexType::Int64, - DataType::UInt8 => DictionaryIndexType::UInt8, - DataType::UInt16 => DictionaryIndexType::UInt16, - DataType::UInt32 => DictionaryIndexType::UInt32, - DataType::UInt64 => DictionaryIndexType::UInt64, - _ => ::core::unreachable!("A dictionary key type can only be of integer types"), +impl From for DataType { + fn from(item: IntegerType) -> Self { + match item { + IntegerType::Int8 => DataType::Int8, + IntegerType::Int16 => DataType::Int16, + IntegerType::Int32 => DataType::Int32, + IntegerType::Int64 => DataType::Int64, + IntegerType::UInt8 => DataType::UInt8, + IntegerType::UInt16 => DataType::UInt16, + IntegerType::UInt32 => DataType::UInt32, + IntegerType::UInt64 => DataType::UInt64, + } } } diff --git a/src/datatypes/physical_type.rs b/src/datatypes/physical_type.rs index d950c0b426f..cb913429200 100644 --- a/src/datatypes/physical_type.rs +++ b/src/datatypes/physical_type.rs @@ -31,8 +31,8 @@ pub enum PhysicalType { Union, /// A nested type. Map, - /// A dictionary encoded array by `DictionaryIndexType`. - Dictionary(DictionaryIndexType), + /// A dictionary encoded array by `IntegerType`. + Dictionary(IntegerType), } /// The set of all (physical) primitive types. @@ -70,7 +70,7 @@ pub enum PrimitiveType { /// the set of valid indices types of a dictionary-encoded Array. /// Each type corresponds to a variant of [`crate::array::DictionaryArray`]. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum DictionaryIndexType { +pub enum IntegerType { /// A signed 8-bit integer. Int8, /// A signed 16-bit integer. diff --git a/src/ffi/array.rs b/src/ffi/array.rs index fb2c6cd4446..c87314dc436 100644 --- a/src/ffi/array.rs +++ b/src/ffi/array.rs @@ -27,7 +27,7 @@ pub unsafe fn try_from(array: A) -> Result> { FixedSizeList => Box::new(FixedSizeListArray::try_from_ffi(array)?), Struct => Box::new(StructArray::try_from_ffi(array)?), Dictionary(key_type) => { - with_match_physical_dictionary_key_type!(key_type, |$T| { + match_integer_type!(key_type, |$T| { Box::new(DictionaryArray::<$T>::try_from_ffi(array)?) }) } diff --git a/src/ffi/bridge.rs b/src/ffi/bridge.rs index 53b9167386f..1a51e56e780 100644 --- a/src/ffi/bridge.rs +++ b/src/ffi/bridge.rs @@ -33,7 +33,7 @@ pub fn align_to_c_data_interface(array: Arc) -> Arc { Union => ffi_dyn!(array, UnionArray), Map => ffi_dyn!(array, MapArray), Dictionary(key_type) => { - with_match_physical_dictionary_key_type!(key_type, |$T| { + match_integer_type!(key_type, |$T| { ffi_dyn!(array, DictionaryArray<$T>) }) } diff --git a/src/ffi/schema.rs b/src/ffi/schema.rs index ee5cb0fef1d..45fd6d9a587 100644 --- a/src/ffi/schema.rs +++ b/src/ffi/schema.rs @@ -1,7 +1,9 @@ use std::{collections::BTreeMap, convert::TryInto, ffi::CStr, ffi::CString, ptr}; use crate::{ - datatypes::{DataType, Extension, Field, IntervalUnit, Metadata, TimeUnit, UnionMode}, + datatypes::{ + DataType, Extension, Field, IntegerType, IntervalUnit, Metadata, TimeUnit, UnionMode, + }, error::{ArrowError, Result}, }; @@ -210,12 +212,9 @@ impl Drop for Ffi_ArrowSchema { pub(crate) unsafe fn to_field(schema: &Ffi_ArrowSchema) -> Result { let dictionary = schema.dictionary(); let data_type = if let Some(dictionary) = dictionary { - let indices_data_type = to_data_type(schema)?; + let indices = to_integer_type(schema.format())?; let values = to_field(dictionary)?; - DataType::Dictionary( - Box::new(indices_data_type), - Box::new(values.data_type().clone()), - ) + DataType::Dictionary(indices, Box::new(values.data_type().clone())) } else { to_data_type(schema)? }; @@ -232,6 +231,25 @@ pub(crate) unsafe fn to_field(schema: &Ffi_ArrowSchema) -> Result { Ok(field) } +fn to_integer_type(format: &str) -> Result { + use IntegerType::*; + Ok(match format { + "c" => Int8, + "C" => UInt8, + "s" => Int16, + "S" => UInt16, + "i" => Int32, + "I" => UInt32, + "l" => Int64, + "L" => UInt64, + _ => { + return Err(ArrowError::Ffi( + "Dictionary indices can only be integers".to_string(), + )) + } + }) +} + unsafe fn to_data_type(schema: &Ffi_ArrowSchema) -> Result { Ok(match schema.format() { "n" => DataType::Null, @@ -425,7 +443,7 @@ fn to_format(data_type: &DataType) -> String { r } DataType::Map(_, _) => "+m".to_string(), - DataType::Dictionary(index, _) => to_format(index.as_ref()), + DataType::Dictionary(index, _) => to_format(&(*index).into()), DataType::Extension(_, inner, _) => to_format(inner.as_ref()), } } diff --git a/src/io/avro/read/nested.rs b/src/io/avro/read/nested.rs index 3dc89a7c2d4..30d6496f4a5 100644 --- a/src/io/avro/read/nested.rs +++ b/src/io/avro/read/nested.rs @@ -137,7 +137,7 @@ impl FixedItemsUtf8Dictionary { pub fn with_capacity(values: Utf8Array, capacity: usize) -> Self { Self { data_type: DataType::Dictionary( - Box::new(DataType::Int32), + IntegerType::Int32, Box::new(values.data_type().clone()), ), keys: MutablePrimitiveArray::::with_capacity(capacity), diff --git a/src/io/avro/read/schema.rs b/src/io/avro/read/schema.rs index 7e0e8551c1a..fdb9dba57a2 100644 --- a/src/io/avro/read/schema.rs +++ b/src/io/avro/read/schema.rs @@ -169,7 +169,7 @@ fn schema_to_field( AvroSchema::Enum { .. } => { return Ok(Field::new( name.unwrap_or_default(), - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + DataType::Dictionary(IntegerType::Int32, Box::new(DataType::Utf8)), false, )) } diff --git a/src/io/csv/write/serialize.rs b/src/io/csv/write/serialize.rs index 0a3e6aeaf86..a8577a644cd 100644 --- a/src/io/csv/write/serialize.rs +++ b/src/io/csv/write/serialize.rs @@ -1,5 +1,6 @@ use lexical_core::ToLexical; +use crate::datatypes::IntegerType; use crate::temporal_conversions; use crate::types::{Index, NativeType}; use crate::util::lexical_to_bytes_mut; @@ -268,14 +269,14 @@ pub fn new_serializer<'a>( )) } DataType::Dictionary(keys_dt, values_dt) => match &**values_dt { - DataType::LargeUtf8 => match &**keys_dt { - DataType::UInt32 => serialize_utf8_dict::(array.as_any()), - DataType::UInt64 => serialize_utf8_dict::(array.as_any()), + DataType::LargeUtf8 => match *keys_dt { + IntegerType::UInt32 => serialize_utf8_dict::(array.as_any()), + IntegerType::UInt64 => serialize_utf8_dict::(array.as_any()), _ => todo!(), }, - DataType::Utf8 => match &**keys_dt { - DataType::UInt32 => serialize_utf8_dict::(array.as_any()), - DataType::UInt64 => serialize_utf8_dict::(array.as_any()), + DataType::Utf8 => match *keys_dt { + IntegerType::UInt32 => serialize_utf8_dict::(array.as_any()), + IntegerType::UInt64 => serialize_utf8_dict::(array.as_any()), _ => todo!(), }, _ => { diff --git a/src/io/ipc/convert.rs b/src/io/ipc/convert.rs index ac78bdf7a67..c3667baac04 100644 --- a/src/io/ipc/convert.rs +++ b/src/io/ipc/convert.rs @@ -28,7 +28,8 @@ mod ipc { } use crate::datatypes::{ - get_extension, DataType, Extension, Field, IntervalUnit, Metadata, Schema, TimeUnit, UnionMode, + get_extension, DataType, Extension, Field, IntegerType, IntervalUnit, Metadata, Schema, + TimeUnit, UnionMode, }; use crate::io::ipc::endianess::is_native_little_endian; @@ -142,18 +143,18 @@ fn get_data_type(field: ipc::Field, extension: Extension, may_be_dictionary: boo if may_be_dictionary { let int = dictionary.indexType().unwrap(); let index_type = match (int.bitWidth(), int.is_signed()) { - (8, true) => DataType::Int8, - (8, false) => DataType::UInt8, - (16, true) => DataType::Int16, - (16, false) => DataType::UInt16, - (32, true) => DataType::Int32, - (32, false) => DataType::UInt32, - (64, true) => DataType::Int64, - (64, false) => DataType::UInt64, + (8, true) => IntegerType::Int8, + (8, false) => IntegerType::UInt8, + (16, true) => IntegerType::Int16, + (16, false) => IntegerType::UInt16, + (32, true) => IntegerType::Int32, + (32, false) => IntegerType::UInt32, + (64, true) => IntegerType::Int64, + (64, false) => IntegerType::UInt64, _ => panic!("Unexpected bitwidth and signed"), }; return DataType::Dictionary( - Box::new(index_type), + index_type, Box::new(get_data_type(field, extension, false)), ); } @@ -740,28 +741,26 @@ pub(crate) fn get_fb_field_type<'a>( /// Create an IPC dictionary encoding pub(crate) fn get_fb_dictionary<'a>( - index_type: &DataType, + index_type: &IntegerType, dict_id: i64, dict_is_ordered: bool, fbb: &mut FlatBufferBuilder<'a>, ) -> WIPOffset> { - use DataType::*; + use IntegerType::*; // We assume that the dictionary index type (as an integer) has already been // validated elsewhere, and can safely assume we are dealing with integers let mut index_builder = ipc::IntBuilder::new(fbb); - match *index_type { + match index_type { Int8 | Int16 | Int32 | Int64 => index_builder.add_is_signed(true), UInt8 | UInt16 | UInt32 | UInt64 => index_builder.add_is_signed(false), - _ => {} } - match *index_type { + match index_type { Int8 | UInt8 => index_builder.add_bitWidth(8), Int16 | UInt16 => index_builder.add_bitWidth(16), Int32 | UInt32 => index_builder.add_bitWidth(32), Int64 | UInt64 => index_builder.add_bitWidth(64), - _ => {} } let index_builder = index_builder.finish(); @@ -908,14 +907,14 @@ mod tests { Field::new("struct<>", DataType::Struct(vec![]), true), Field::new_dict( "dictionary", - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + DataType::Dictionary(IntegerType::Int32, Box::new(DataType::Utf8)), true, 123, true, ), Field::new_dict( "dictionary", - DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::UInt32)), + DataType::Dictionary(IntegerType::UInt8, Box::new(DataType::UInt32)), true, 123, true, diff --git a/src/io/ipc/read/deserialize.rs b/src/io/ipc/read/deserialize.rs index e0fe6f7b755..1d243e5256d 100644 --- a/src/io/ipc/read/deserialize.rs +++ b/src/io/ipc/read/deserialize.rs @@ -164,7 +164,7 @@ pub fn read( ) .map(|x| Arc::new(x) as Arc), Dictionary(key_type) => { - with_match_physical_dictionary_key_type!(key_type, |$T| { + match_integer_type!(key_type, |$T| { read_dictionary::<$T, _>( field_nodes, buffers, diff --git a/src/io/ipc/write/common.rs b/src/io/ipc/write/common.rs index c557b5bf147..ae6a3986562 100644 --- a/src/io/ipc/write/common.rs +++ b/src/io/ipc/write/common.rs @@ -251,7 +251,7 @@ impl DictionaryTracker { pub fn insert(&mut self, dict_id: i64, array: &Arc) -> Result { let values = match array.data_type() { DataType::Dictionary(key_type, _) => { - with_match_dictionary_key_type!(key_type.as_ref(), |$T| { + match_integer_type!(key_type, |$T| { let array = array .as_any() .downcast_ref::>() diff --git a/src/io/ipc/write/serialize.rs b/src/io/ipc/write/serialize.rs index 6b1078e2224..291f1d9926a 100644 --- a/src/io/ipc/write/serialize.rs +++ b/src/io/ipc/write/serialize.rs @@ -471,7 +471,7 @@ pub fn write_dictionary( ) -> usize { match array.data_type() { DataType::Dictionary(key_type, _) => { - with_match_dictionary_key_type!(key_type.as_ref(), |$T| { + match_integer_type!(key_type, |$T| { _write_dictionary::<$T>( array, buffers, diff --git a/src/io/json/read/deserialize.rs b/src/io/json/read/deserialize.rs index 8ac22d6dca5..64da0d775fc 100644 --- a/src/io/json/read/deserialize.rs +++ b/src/io/json/read/deserialize.rs @@ -255,7 +255,7 @@ pub fn read(rows: &[&Value], data_type: DataType) -> Arc { DataType::LargeBinary => Arc::new(read_binary::(rows)), DataType::Struct(_) => Arc::new(read_struct(rows, data_type)), DataType::Dictionary(key_type, _) => { - with_match_dictionary_key_type!(key_type.as_ref(), |$T| { + match_integer_type!(key_type, |$T| { Arc::new(read_dictionary::<$T>(rows, data_type)) }) } diff --git a/src/io/json_integration/mod.rs b/src/io/json_integration/mod.rs index 29a015ce1b5..7f425c7b710 100644 --- a/src/io/json_integration/mod.rs +++ b/src/io/json_integration/mod.rs @@ -20,13 +20,10 @@ //! These utilities define structs that read the integration JSON format for integration testing purposes. use serde_derive::{Deserialize, Serialize}; -use serde_json::{Map, Value}; +use serde_json::Value; -use crate::datatypes::*; - -mod schema; -use schema::ToJson; mod read; +mod schema; mod write; pub use read::to_record_batch; pub use write::from_record_batch; @@ -64,47 +61,17 @@ pub struct ArrowJsonField { pub metadata: Option, } -impl From<&Field> for ArrowJsonField { - fn from(field: &Field) -> Self { - let metadata_value = match field.metadata() { - Some(kv_list) => { - let mut array = Vec::new(); - for (k, v) in kv_list { - let mut kv_map = Map::new(); - kv_map.insert(k.clone(), Value::String(v.clone())); - array.push(Value::Object(kv_map)); - } - if !array.is_empty() { - Some(Value::Array(array)) - } else { - None - } - } - _ => None, - }; - - Self { - name: field.name().to_string(), - field_type: field.data_type().to_json(), - nullable: field.is_nullable(), - children: vec![], - dictionary: None, // TODO: not enough info - metadata: metadata_value, - } - } -} - #[derive(Deserialize, Serialize, Debug)] pub struct ArrowJsonFieldDictionary { pub id: i64, #[serde(rename = "indexType")] - pub index_type: DictionaryIndexType, + pub index_type: IntegerType, #[serde(rename = "isOrdered")] pub is_ordered: bool, } #[derive(Deserialize, Serialize, Debug)] -pub struct DictionaryIndexType { +pub struct IntegerType { pub name: String, #[serde(rename = "isSigned")] pub is_signed: bool, diff --git a/src/io/json_integration/read.rs b/src/io/json_integration/read.rs index e24ae528266..856fdb9808c 100644 --- a/src/io/json_integration/read.rs +++ b/src/io/json_integration/read.rs @@ -364,7 +364,7 @@ pub fn to_array( Ok(Arc::new(array)) } Dictionary(key_type) => { - with_match_physical_dictionary_key_type!(key_type, |$T| { + match_integer_type!(key_type, |$T| { to_dictionary::<$T>(data_type, dict_id.unwrap(), json_col, dictionaries) }) } diff --git a/src/io/json_integration/schema.rs b/src/io/json_integration/schema.rs index b461408c3c1..70bcebfd414 100644 --- a/src/io/json_integration/schema.rs +++ b/src/io/json_integration/schema.rs @@ -21,139 +21,16 @@ use std::{ }; use serde_derive::Deserialize; -use serde_json::{json, Value}; +use serde_json::Value; use crate::{ datatypes::UnionMode, error::{ArrowError, Result}, }; -use crate::datatypes::{get_extension, DataType, Field, IntervalUnit, Schema, TimeUnit}; - -pub trait ToJson { - /// Generate a JSON representation - fn to_json(&self) -> Value; -} - -impl ToJson for DataType { - fn to_json(&self) -> Value { - match self { - DataType::Null => json!({"name": "null"}), - DataType::Boolean => json!({"name": "bool"}), - DataType::Int8 => json!({"name": "int", "bitWidth": 8, "isSigned": true}), - DataType::Int16 => json!({"name": "int", "bitWidth": 16, "isSigned": true}), - DataType::Int32 => json!({"name": "int", "bitWidth": 32, "isSigned": true}), - DataType::Int64 => json!({"name": "int", "bitWidth": 64, "isSigned": true}), - DataType::UInt8 => json!({"name": "int", "bitWidth": 8, "isSigned": false}), - DataType::UInt16 => json!({"name": "int", "bitWidth": 16, "isSigned": false}), - DataType::UInt32 => json!({"name": "int", "bitWidth": 32, "isSigned": false}), - DataType::UInt64 => json!({"name": "int", "bitWidth": 64, "isSigned": false}), - DataType::Float16 => json!({"name": "floatingpoint", "precision": "HALF"}), - DataType::Float32 => json!({"name": "floatingpoint", "precision": "SINGLE"}), - DataType::Float64 => json!({"name": "floatingpoint", "precision": "DOUBLE"}), - DataType::Utf8 => json!({"name": "utf8"}), - DataType::LargeUtf8 => json!({"name": "largeutf8"}), - DataType::Binary => json!({"name": "binary"}), - DataType::LargeBinary => json!({"name": "largebinary"}), - DataType::FixedSizeBinary(byte_width) => { - json!({"name": "fixedsizebinary", "byteWidth": byte_width}) - } - DataType::Struct(_) => json!({"name": "struct"}), - DataType::Union(_, _, _) => json!({"name": "union"}), - DataType::Map(_, _) => json!({"name": "map"}), - DataType::List(_) => json!({ "name": "list"}), - DataType::LargeList(_) => json!({ "name": "largelist"}), - DataType::FixedSizeList(_, length) => { - json!({"name":"fixedsizelist", "listSize": length}) - } - DataType::Time32(unit) => { - json!({"name": "time", "bitWidth": 32, "unit": match unit { - TimeUnit::Second => "SECOND", - TimeUnit::Millisecond => "MILLISECOND", - TimeUnit::Microsecond => "MICROSECOND", - TimeUnit::Nanosecond => "NANOSECOND", - }}) - } - DataType::Time64(unit) => { - json!({"name": "time", "bitWidth": 64, "unit": match unit { - TimeUnit::Second => "SECOND", - TimeUnit::Millisecond => "MILLISECOND", - TimeUnit::Microsecond => "MICROSECOND", - TimeUnit::Nanosecond => "NANOSECOND", - }}) - } - DataType::Date32 => { - json!({"name": "date", "unit": "DAY"}) - } - DataType::Date64 => { - json!({"name": "date", "unit": "MILLISECOND"}) - } - DataType::Timestamp(unit, None) => { - json!({"name": "timestamp", "unit": match unit { - TimeUnit::Second => "SECOND", - TimeUnit::Millisecond => "MILLISECOND", - TimeUnit::Microsecond => "MICROSECOND", - TimeUnit::Nanosecond => "NANOSECOND", - }}) - } - DataType::Timestamp(unit, Some(tz)) => { - json!({"name": "timestamp", "unit": match unit { - TimeUnit::Second => "SECOND", - TimeUnit::Millisecond => "MILLISECOND", - TimeUnit::Microsecond => "MICROSECOND", - TimeUnit::Nanosecond => "NANOSECOND", - }, "timezone": tz}) - } - DataType::Interval(unit) => json!({"name": "interval", "unit": match unit { - IntervalUnit::YearMonth => "YEAR_MONTH", - IntervalUnit::DayTime => "DAY_TIME", - IntervalUnit::MonthDayNano => "MONTH_DAY_NANO", - }}), - DataType::Duration(unit) => json!({"name": "duration", "unit": match unit { - TimeUnit::Second => "SECOND", - TimeUnit::Millisecond => "MILLISECOND", - TimeUnit::Microsecond => "MICROSECOND", - TimeUnit::Nanosecond => "NANOSECOND", - }}), - DataType::Dictionary(_, _) => json!({ "name": "dictionary"}), - DataType::Decimal(precision, scale) => { - json!({"name": "decimal", "precision": precision, "scale": scale}) - } - DataType::Extension(_, inner_data_type, _) => inner_data_type.to_json(), - } - } -} - -impl ToJson for Field { - fn to_json(&self) -> Value { - let children: Vec = match self.data_type() { - DataType::Struct(fields) => fields.iter().map(|f| f.to_json()).collect(), - DataType::List(field) => vec![field.to_json()], - DataType::LargeList(field) => vec![field.to_json()], - DataType::FixedSizeList(field, _) => vec![field.to_json()], - _ => vec![], - }; - match self.data_type() { - DataType::Dictionary(ref index_type, ref value_type) => json!({ - "name": self.name(), - "nullable": self.is_nullable(), - "type": value_type.to_json(), - "children": children, - "dictionary": { - "id": self.dict_id(), - "indexType": index_type.to_json(), - "isOrdered": self.dict_is_ordered() - } - }), - _ => json!({ - "name": self.name(), - "nullable": self.is_nullable(), - "type": self.data_type().to_json(), - "children": children - }), - } - } -} +use crate::datatypes::{ + get_extension, DataType, Field, IntegerType, IntervalUnit, Schema, TimeUnit, +}; fn to_time_unit(item: Option<&Value>) -> Result { match item { @@ -167,6 +44,52 @@ fn to_time_unit(item: Option<&Value>) -> Result { } } +fn to_int(item: &Value) -> Result { + Ok(match item.get("isSigned") { + Some(&Value::Bool(true)) => match item.get("bitWidth") { + Some(&Value::Number(ref n)) => match n.as_u64() { + Some(8) => IntegerType::Int8, + Some(16) => IntegerType::Int16, + Some(32) => IntegerType::Int32, + Some(64) => IntegerType::Int64, + _ => { + return Err(ArrowError::Schema( + "int bitWidth missing or invalid".to_string(), + )) + } + }, + _ => { + return Err(ArrowError::Schema( + "int bitWidth missing or invalid".to_string(), + )) + } + }, + Some(&Value::Bool(false)) => match item.get("bitWidth") { + Some(&Value::Number(ref n)) => match n.as_u64() { + Some(8) => IntegerType::UInt8, + Some(16) => IntegerType::UInt16, + Some(32) => IntegerType::UInt32, + Some(64) => IntegerType::UInt64, + _ => { + return Err(ArrowError::Schema( + "int bitWidth missing or invalid".to_string(), + )) + } + }, + _ => { + return Err(ArrowError::Schema( + "int bitWidth missing or invalid".to_string(), + )) + } + }, + _ => { + return Err(ArrowError::Schema( + "int signed missing or invalid".to_string(), + )) + } + }) +} + fn children(children: Option<&Value>) -> Result> { children .map(|x| { @@ -339,49 +262,7 @@ fn to_data_type(item: &Value, mut children: Vec) -> Result { )) } }, - "int" => match item.get("isSigned") { - Some(&Value::Bool(true)) => match item.get("bitWidth") { - Some(&Value::Number(ref n)) => match n.as_u64() { - Some(8) => DataType::Int8, - Some(16) => DataType::Int16, - Some(32) => DataType::Int32, - Some(64) => DataType::Int64, - _ => { - return Err(ArrowError::Schema( - "int bitWidth missing or invalid".to_string(), - )) - } - }, - _ => { - return Err(ArrowError::Schema( - "int bitWidth missing or invalid".to_string(), - )) - } - }, - Some(&Value::Bool(false)) => match item.get("bitWidth") { - Some(&Value::Number(ref n)) => match n.as_u64() { - Some(8) => DataType::UInt8, - Some(16) => DataType::UInt16, - Some(32) => DataType::UInt32, - Some(64) => DataType::UInt64, - _ => { - return Err(ArrowError::Schema( - "int bitWidth missing or invalid".to_string(), - )) - } - }, - _ => { - return Err(ArrowError::Schema( - "int bitWidth missing or invalid".to_string(), - )) - } - }, - _ => { - return Err(ArrowError::Schema( - "int signed missing or invalid".to_string(), - )) - } - }, + "int" => to_int(item).map(|x| x.into())?, "list" => DataType::List(Box::new(children.pop().unwrap())), "largelist" => DataType::LargeList(Box::new(children.pop().unwrap())), "fixedsizelist" => { @@ -474,14 +355,14 @@ impl TryFrom<&Value> for Field { let data_type = if let Some(dictionary) = map.get("dictionary") { let index_type = match dictionary.get("indexType") { - Some(t) => to_data_type(t, vec![])?, + Some(t) => to_int(t)?, _ => { return Err(ArrowError::Schema( "Field missing 'indexType' attribute".to_string(), )); } }; - DataType::Dictionary(Box::new(index_type), Box::new(data_type)) + DataType::Dictionary(index_type, Box::new(data_type)) } else { data_type }; @@ -518,15 +399,6 @@ impl TryFrom<&Value> for Field { } } -impl ToJson for Schema { - fn to_json(&self) -> Value { - json!({ - "fields": self.fields.iter().map(|field| field.to_json()).collect::>(), - "metadata": serde_json::to_value(&self.metadata).unwrap(), - }) - } -} - #[derive(Deserialize)] struct MetadataKeyValue { key: String, diff --git a/src/io/parquet/read/mod.rs b/src/io/parquet/read/mod.rs index 3e421c1f493..aa2b051849d 100644 --- a/src/io/parquet/read/mod.rs +++ b/src/io/parquet/read/mod.rs @@ -327,17 +327,9 @@ pub fn page_iter_to_array match key.as_ref() { - Int8 => dict_read::(iter, metadata, data_type), - Int16 => dict_read::(iter, metadata, data_type), - Int32 => dict_read::(iter, metadata, data_type), - Int64 => dict_read::(iter, metadata, data_type), - UInt8 => dict_read::(iter, metadata, data_type), - UInt16 => dict_read::(iter, metadata, data_type), - UInt32 => dict_read::(iter, metadata, data_type), - UInt64 => dict_read::(iter, metadata, data_type), - _ => unreachable!(), - }, + Dictionary(key_type, _) => match_integer_type!(key_type, |$T| { + dict_read::<$T, _>(iter, metadata, data_type) + }), other => Err(ArrowError::NotYetImplemented(format!( "Reading {:?} from parquet still not implemented", diff --git a/src/io/parquet/write/mod.rs b/src/io/parquet/write/mod.rs index 5048ea2d3f2..2f1b6f52dfe 100644 --- a/src/io/parquet/write/mod.rs +++ b/src/io/parquet/write/mod.rs @@ -117,7 +117,7 @@ pub fn array_to_pages( ) -> Result>> { match array.data_type() { DataType::Dictionary(key_type, _) => { - with_match_dictionary_key_type!(key_type.as_ref(), |$T| { + match_integer_type!(key_type, |$T| { dictionary::array_to_pages::<$T>( array.as_any().downcast_ref().unwrap(), descriptor, diff --git a/src/types/mod.rs b/src/types/mod.rs index d74c9d1c078..a20523d67f8 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -64,7 +64,6 @@ pub unsafe trait NativeType: + std::fmt::Display + PartialEq + Default - + Sized + 'static { /// Type denoting its representation as bytes. diff --git a/tests/it/compute/cast.rs b/tests/it/compute/cast.rs index c1156553ee5..206183c2129 100644 --- a/tests/it/compute/cast.rs +++ b/tests/it/compute/cast.rs @@ -483,7 +483,7 @@ fn utf8_to_dict() { let array = Utf8Array::::from(&[Some("one"), None, Some("three"), Some("one")]); // Cast to a dictionary (same value type, Utf8) - let cast_type = DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)); + let cast_type = DataType::Dictionary(u8::KEY_TYPE, Box::new(DataType::Utf8)); let result = cast(&array, &cast_type, CastOptions::default()).expect("cast failed"); let mut expected = MutableDictionaryArray::>::new(); @@ -514,7 +514,7 @@ fn i32_to_dict() { let array = Int32Array::from(&[Some(1), None, Some(3), Some(1)]); // Cast to a dictionary (same value type, Utf8) - let cast_type = DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Int32)); + let cast_type = DataType::Dictionary(u8::KEY_TYPE, Box::new(DataType::Int32)); let result = cast(&array, &cast_type, CastOptions::default()).expect("cast failed"); let mut expected = MutableDictionaryArray::>::new(); @@ -617,7 +617,7 @@ fn dict_to_dict_bad_index_value_primitive() { } let array: ArrayRef = Arc::new(builder.finish()); - let cast_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + let cast_type = Dictionary(i8::KEY_TYPE, Box::new(Utf8)); let res = cast(&array, &cast_type, CastOptions::default()); assert, CastOptions::default())!(res.is_err()); let actual_error = format!("{:?}", res); @@ -649,7 +649,7 @@ fn dict_to_dict_bad_index_value_utf8() { } let array: ArrayRef = Arc::new(builder.finish()); - let cast_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + let cast_type = Dictionary(i8::KEY_TYPE, Box::new(Utf8)); let res = cast(&array, &cast_type, CastOptions::default()); assert, CastOptions::default())!(res.is_err()); let actual_error = format!("{:?}", res); diff --git a/tests/it/ffi.rs b/tests/it/ffi.rs index 7a1b1664f74..cccd903319c 100644 --- a/tests/it/ffi.rs +++ b/tests/it/ffi.rs @@ -196,7 +196,7 @@ fn schema() -> Result<()> { let field = Field::new( "a", - DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)), + DataType::Dictionary(u32::KEY_TYPE, Box::new(DataType::Utf8)), true, ); test_round_trip_schema(field)?; diff --git a/tests/it/io/avro/read/mod.rs b/tests/it/io/avro/read/mod.rs index 8435820d17a..a17fb06387d 100644 --- a/tests/it/io/avro/read/mod.rs +++ b/tests/it/io/avro/read/mod.rs @@ -69,7 +69,7 @@ fn schema() -> (AvroSchema, Schema) { ), Field::new( "enum", - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + DataType::Dictionary(i32::KEY_TYPE, Box::new(DataType::Utf8)), false, ), Field::new( diff --git a/tests/it/io/csv/write.rs b/tests/it/io/csv/write.rs index 6098b60bbb2..8bcd10cd829 100644 --- a/tests/it/io/csv/write.rs +++ b/tests/it/io/csv/write.rs @@ -17,7 +17,7 @@ fn data() -> RecordBatch { Field::new("c6", DataType::Time32(TimeUnit::Second), false), Field::new( "c7", - DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)), + DataType::Dictionary(u32::KEY_TYPE, Box::new(DataType::Utf8)), false, ), ]); diff --git a/tests/it/io/json/mod.rs b/tests/it/io/json/mod.rs index 1bf3d0856cd..1b70e79d1c5 100644 --- a/tests/it/io/json/mod.rs +++ b/tests/it/io/json/mod.rs @@ -117,7 +117,7 @@ fn case_dict() -> (String, Schema, Vec>) { let data_type = DataType::List(Box::new(Field::new( "item", - DataType::Dictionary(Box::new(DataType::UInt64), Box::new(DataType::Utf8)), + DataType::Dictionary(u64::KEY_TYPE, Box::new(DataType::Utf8)), true, ))); diff --git a/tests/it/io/print.rs b/tests/it/io/print.rs index b8fa3f6451b..07a294942e1 100644 --- a/tests/it/io/print.rs +++ b/tests/it/io/print.rs @@ -87,7 +87,7 @@ fn write_null() -> Result<()> { #[test] fn write_dictionary() -> Result<()> { // define a schema. - let field_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let field_type = DataType::Dictionary(i32::KEY_TYPE, Box::new(DataType::Utf8)); let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); let mut array = MutableDictionaryArray::>::new(); @@ -119,7 +119,7 @@ fn write_dictionary() -> Result<()> { #[test] fn dictionary_validities() -> Result<()> { // define a schema. - let field_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Int32)); + let field_type = DataType::Dictionary(i32::KEY_TYPE, Box::new(DataType::Int32)); let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); let keys = PrimitiveArray::::from([Some(1), None, Some(0)]);