diff --git a/arrow-parquet-integration-testing/src/main.rs b/arrow-parquet-integration-testing/src/main.rs index 3a7e38c7ab3..28f161095ce 100644 --- a/arrow-parquet-integration-testing/src/main.rs +++ b/arrow-parquet-integration-testing/src/main.rs @@ -165,7 +165,7 @@ fn main() -> Result<()> { .fields() .iter() .map(|x| match x.data_type() { - DataType::Dictionary(_, _) => Encoding::RleDictionary, + DataType::Dictionary(..) => Encoding::RleDictionary, DataType::Utf8 | DataType::LargeUtf8 => { if utf8_encoding == "delta" { Encoding::DeltaLengthByteArray diff --git a/src/array/dictionary/mod.rs b/src/array/dictionary/mod.rs index 98b9d6bbff2..bd9514bcab4 100644 --- a/src/array/dictionary/mod.rs +++ b/src/array/dictionary/mod.rs @@ -79,7 +79,8 @@ impl DictionaryArray { /// The canonical method to create a new [`DictionaryArray`]. pub fn from_data(keys: PrimitiveArray, values: Arc) -> Self { - let data_type = DataType::Dictionary(K::KEY_TYPE, Box::new(values.data_type().clone())); + let data_type = + DataType::Dictionary(K::KEY_TYPE, Box::new(values.data_type().clone()), false); Self { data_type, @@ -165,7 +166,7 @@ impl DictionaryArray { impl DictionaryArray { pub(crate) fn get_child(data_type: &DataType) -> &DataType { match data_type { - DataType::Dictionary(_, values) => values.as_ref(), + DataType::Dictionary(_, values, _) => values.as_ref(), DataType::Extension(_, inner, _) => Self::get_child(inner), _ => panic!("DictionaryArray must be initialized with DataType::Dictionary"), } diff --git a/src/array/dictionary/mutable.rs b/src/array/dictionary/mutable.rs index 0527be6ad1b..bf152bd693b 100644 --- a/src/array/dictionary/mutable.rs +++ b/src/array/dictionary/mutable.rs @@ -31,7 +31,11 @@ impl From> for D impl From for MutableDictionaryArray { fn from(values: M) -> Self { Self { - data_type: DataType::Dictionary(K::KEY_TYPE, Box::new(values.data_type().clone())), + data_type: DataType::Dictionary( + K::KEY_TYPE, + Box::new(values.data_type().clone()), + false, + ), keys: MutablePrimitiveArray::::new(), map: HashedMap::default(), values, @@ -44,7 +48,11 @@ impl MutableDictionaryArray { pub fn new() -> Self { let values = M::default(); Self { - data_type: DataType::Dictionary(K::KEY_TYPE, Box::new(values.data_type().clone())), + data_type: DataType::Dictionary( + K::KEY_TYPE, + Box::new(values.data_type().clone()), + false, + ), keys: MutablePrimitiveArray::::new(), map: HashedMap::default(), values, diff --git a/src/array/display.rs b/src/array/display.rs index 12aac44b9ea..3ab8cc67d94 100644 --- a/src/array/display.rs +++ b/src/array/display.rs @@ -158,7 +158,7 @@ pub fn get_value_display<'a>(array: &'a dyn Array) -> Box Strin }; dyn_display!(array, ListArray, f) } - Dictionary(key_type, _) => match_integer_type!(key_type, |$T| { + Dictionary(key_type, ..) => match_integer_type!(key_type, |$T| { let a = array .as_any() .downcast_ref::>() diff --git a/src/array/ord.rs b/src/array/ord.rs index 31c36d770e5..614f47f5839 100644 --- a/src/array/ord.rs +++ b/src/array/ord.rs @@ -215,7 +215,7 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result compare_string::(left, right), (Binary, Binary) => compare_binary::(left, right), (LargeBinary, LargeBinary) => compare_binary::(left, right), - (Dictionary(key_type_lhs, _), Dictionary(key_type_rhs, _)) => { + (Dictionary(key_type_lhs, ..), Dictionary(key_type_rhs, ..)) => { match (key_type_lhs, key_type_rhs) { (IntegerType::UInt8, IntegerType::UInt8) => dyn_dict!(u8, left, right), (IntegerType::UInt16, IntegerType::UInt16) => dyn_dict!(u16, left, right), diff --git a/src/compute/arithmetics/mod.rs b/src/compute/arithmetics/mod.rs index b8e61c9fc27..34ab0bf1722 100644 --- a/src/compute/arithmetics/mod.rs +++ b/src/compute/arithmetics/mod.rs @@ -447,7 +447,7 @@ pub fn neg(array: &dyn Array) -> Box { /// Whether [`neg`] is supported for a given [`DataType`] pub fn can_neg(data_type: &DataType) -> bool { - if let DataType::Dictionary(_, values) = data_type.to_logical_type() { + if let DataType::Dictionary(_, values, _) = data_type.to_logical_type() { return can_neg(values.as_ref()); } diff --git a/src/compute/cast/dictionary_to.rs b/src/compute/cast/dictionary_to.rs index 36878931fa3..0f610931746 100644 --- a/src/compute/cast/dictionary_to.rs +++ b/src/compute/cast/dictionary_to.rs @@ -110,7 +110,7 @@ pub(super) fn dictionary_cast_dyn( let values = array.values(); match to_type { - DataType::Dictionary(to_keys_type, to_values_type) => { + DataType::Dictionary(to_keys_type, to_values_type, _) => { let values = cast(values.as_ref(), to_values_type, options)?.into(); // create the appropriate array type diff --git a/src/compute/cast/mod.rs b/src/compute/cast/mod.rs index a3d5e1b4acb..faf0805acbc 100644 --- a/src/compute/cast/mod.rs +++ b/src/compute/cast/mod.rs @@ -80,40 +80,12 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { match (from_type, to_type) { ( Null, - Boolean - | Int8 - | UInt8 - | Int16 - | UInt16 - | Int32 - | UInt32 - | Float32 - | Date32 - | Time32(_) - | Int64 - | UInt64 - | Float64 - | Date64 - | List(_) - | Dictionary(_, _), + Boolean | Int8 | UInt8 | Int16 | UInt16 | Int32 | UInt32 | Float32 | Date32 | Time32(_) + | Int64 | UInt64 | Float64 | Date64 | List(_) | Dictionary(..), ) | ( - Boolean - | Int8 - | UInt8 - | Int16 - | UInt16 - | Int32 - | UInt32 - | Float32 - | Date32 - | Time32(_) - | Int64 - | UInt64 - | Float64 - | Date64 - | List(_) - | Dictionary(_, _), + Boolean | Int8 | UInt8 | Int16 | UInt16 | Int32 | UInt32 | Float32 | Date32 | Time32(_) + | Int64 | UInt64 | Float64 | Date64 | List(_) | Dictionary(..), Null, ) => true, (Struct(_), _) => false, @@ -127,11 +99,11 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (List(list_from), LargeList(list_to)) if list_from == list_to => true, (LargeList(list_from), List(list_to)) if list_from == list_to => true, (_, List(list_to)) => can_cast_types(from_type, list_to.data_type()), - (Dictionary(_, from_value_type), Dictionary(_, to_value_type)) => { + (Dictionary(_, from_value_type, _), Dictionary(_, to_value_type, _)) => { can_cast_types(from_value_type, to_value_type) } - (Dictionary(_, value_type), _) => can_cast_types(value_type, to_type), - (_, Dictionary(_, value_type)) => can_cast_types(from_type, value_type), + (Dictionary(_, value_type, _), _) => can_cast_types(value_type, to_type), + (_, Dictionary(_, value_type, _)) => can_cast_types(from_type, value_type), (_, Boolean) => is_numeric(from_type), (Boolean, _) => { @@ -376,40 +348,12 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu match (from_type, to_type) { ( Null, - Boolean - | Int8 - | UInt8 - | Int16 - | UInt16 - | Int32 - | UInt32 - | Float32 - | Date32 - | Time32(_) - | Int64 - | UInt64 - | Float64 - | Date64 - | List(_) - | Dictionary(_, _), + Boolean | Int8 | UInt8 | Int16 | UInt16 | Int32 | UInt32 | Float32 | Date32 | Time32(_) + | Int64 | UInt64 | Float64 | Date64 | List(_) | Dictionary(..), ) | ( - Boolean - | Int8 - | UInt8 - | Int16 - | UInt16 - | Int32 - | UInt32 - | Float32 - | Date32 - | Time32(_) - | Int64 - | UInt64 - | Float64 - | Date64 - | List(_) - | Dictionary(_, _), + Boolean | Int8 | UInt8 | Int16 | UInt16 | Int32 | UInt32 | Float32 | Date32 | Time32(_) + | Int64 | UInt64 | Float64 | Date64 | List(_) | Dictionary(..), Null, ) => Ok(new_null_array(to_type.clone(), array.len())), (Struct(_), _) => Err(ArrowError::NotYetImplemented( @@ -449,10 +393,10 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu Ok(Box::new(list_array)) } - (Dictionary(index_type, _), _) => match_integer_type!(index_type, |$T| { + (Dictionary(index_type, ..), _) => match_integer_type!(index_type, |$T| { dictionary_cast_dyn::<$T>(array, to_type, options) }), - (_, Dictionary(index_type, value_type)) => match_integer_type!(index_type, |$T| { + (_, Dictionary(index_type, value_type, _)) => match_integer_type!(index_type, |$T| { cast_to_dictionary::<$T>(array, value_type, options) }), (_, Boolean) => match from_type { diff --git a/src/compute/sort/mod.rs b/src/compute/sort/mod.rs index 620374bd63c..6011e00ed3f 100644 --- a/src/compute/sort/mod.rs +++ b/src/compute/sort/mod.rs @@ -202,7 +202,7 @@ pub fn sort_to_indices( ))), } } - DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Dictionary(key_type, value_type, _) => match value_type.as_ref() { DataType::Utf8 => Ok(sort_dict::(values, key_type, options, limit)), DataType::LargeUtf8 => Ok(sort_dict::(values, key_type, options, limit)), t => Err(ArrowError::NotYetImplemented(format!( @@ -282,7 +282,7 @@ pub fn can_sort(data_type: &DataType) -> bool { | DataType::UInt64 ) } - DataType::Dictionary(_, value_type) => { + DataType::Dictionary(_, value_type, _) => { matches!(*value_type.as_ref(), DataType::Utf8 | DataType::LargeUtf8) } _ => false, diff --git a/src/compute/take/mod.rs b/src/compute/take/mod.rs index d07a4873268..b127ddd5c4a 100644 --- a/src/compute/take/mod.rs +++ b/src/compute/take/mod.rs @@ -133,6 +133,6 @@ pub fn can_take(data_type: &DataType) -> bool { | DataType::Struct(_) | DataType::List(_) | DataType::LargeList(_) - | DataType::Dictionary(_, _) + | DataType::Dictionary(..) ) } diff --git a/src/datatypes/field.rs b/src/datatypes/field.rs index 9526bb912a8..aa0d58f1b71 100644 --- a/src/datatypes/field.rs +++ b/src/datatypes/field.rs @@ -33,8 +33,6 @@ pub struct Field { pub nullable: bool, /// The dictionary id of this field (currently un-used) pub dict_id: i64, - /// Whether the dictionary's values are ordered - pub dict_is_ordered: bool, /// A map of key-value pairs containing additional custom meta data. pub metadata: Option>, } @@ -44,7 +42,6 @@ impl std::hash::Hash for Field { self.name.hash(state); self.data_type.hash(state); self.nullable.hash(state); - self.dict_is_ordered.hash(state); self.metadata.hash(state); } } @@ -54,7 +51,6 @@ impl PartialEq for Field { self.name == other.name && self.data_type == other.data_type && self.nullable == other.nullable - && self.dict_is_ordered == other.dict_is_ordered && self.metadata == other.metadata } } @@ -67,7 +63,6 @@ impl Field { data_type, nullable, dict_id: 0, - dict_is_ordered: false, metadata: None, } } @@ -78,14 +73,12 @@ impl Field { data_type: DataType, nullable: bool, dict_id: i64, - dict_is_ordered: bool, ) -> Self { Field { name: name.into(), data_type, nullable, dict_id, - dict_is_ordered, metadata: None, } } @@ -98,7 +91,6 @@ impl Field { data_type: self.data_type, nullable: self.nullable, dict_id: self.dict_id, - dict_is_ordered: self.dict_is_ordered, metadata: Some(metadata), } } @@ -143,16 +135,7 @@ impl Field { #[inline] pub const fn dict_id(&self) -> Option { match self.data_type { - DataType::Dictionary(_, _) => Some(self.dict_id), - _ => None, - } - } - - /// Returns whether this [`Field`]'s dictionary is ordered, if this is a dictionary type. - #[inline] - pub const fn dict_is_ordered(&self) -> Option { - match self.data_type { - DataType::Dictionary(_, _) => Some(self.dict_is_ordered), + DataType::Dictionary(_, _, _) => Some(self.dict_id), _ => None, } } @@ -197,11 +180,6 @@ impl Field { "Fail to merge schema Field due to conflicting dict_id".to_string(), )); } - if from.dict_is_ordered != self.dict_is_ordered { - return Err(ArrowError::InvalidArgumentError( - "Fail to merge schema Field due to conflicting dict_is_ordered".to_string(), - )); - } match &mut self.data_type { DataType::Struct(nested_fields) => match &from.data_type { DataType::Struct(from_nested_fields) => { @@ -270,7 +248,7 @@ impl Field { | DataType::Interval(_) | DataType::LargeList(_) | DataType::List(_) - | DataType::Dictionary(_, _) + | DataType::Dictionary(_, _, _) | DataType::FixedSizeList(_, _) | DataType::FixedSizeBinary(_) | DataType::Utf8 diff --git a/src/datatypes/mod.rs b/src/datatypes/mod.rs index 04b1f0d2c66..858595c2c08 100644 --- a/src/datatypes/mod.rs +++ b/src/datatypes/mod.rs @@ -128,7 +128,7 @@ pub enum DataType { /// /// This type mostly used to represent low cardinality string /// arrays or a limited set of primitive types as integers. - Dictionary(IntegerType, Box), + Dictionary(IntegerType, Box, bool), /// Decimal value with precision and scale /// precision is the number of digits in the number and /// scale is the number of decimal places. @@ -261,7 +261,7 @@ impl DataType { Struct(_) => PhysicalType::Struct, Union(_, _, _) => PhysicalType::Union, Map(_, _) => PhysicalType::Map, - Dictionary(key, _) => PhysicalType::Dictionary(*key), + Dictionary(key, _, _) => PhysicalType::Dictionary(*key), Extension(_, key, _) => key.to_physical_type(), } } diff --git a/src/ffi/ffi.rs b/src/ffi/ffi.rs index 24b06ae8e14..1d028d8d6d0 100644 --- a/src/ffi/ffi.rs +++ b/src/ffi/ffi.rs @@ -331,7 +331,7 @@ fn create_dictionary( field: &Field, parent: Arc, ) -> Result>> { - if let DataType::Dictionary(_, values) = field.data_type() { + if let DataType::Dictionary(_, values, _) = field.data_type() { let field = Field::new("", values.as_ref().clone(), true); assert!(!array.dictionary.is_null()); let array = unsafe { &*array.dictionary }; diff --git a/src/ffi/schema.rs b/src/ffi/schema.rs index bf84963bcb7..a02b61b33f2 100644 --- a/src/ffi/schema.rs +++ b/src/ffi/schema.rs @@ -92,8 +92,8 @@ impl Ffi_ArrowSchema { .collect::>(); let n_children = children_ptr.len() as i64; - let dictionary = if let DataType::Dictionary(_, values) = field.data_type() { - flags += field.dict_is_ordered().unwrap_or_default() as i64; + let dictionary = if let DataType::Dictionary(_, values, is_ordered) = field.data_type() { + flags += *is_ordered as i64; // we do not store field info in the dict values, so can't recover it all :( let field = Field::new("", values.as_ref().clone(), true); Some(Box::new(Ffi_ArrowSchema::new(&field))) @@ -214,7 +214,8 @@ pub(crate) unsafe fn to_field(schema: &Ffi_ArrowSchema) -> Result { let data_type = if let Some(dictionary) = dictionary { let indices = to_integer_type(schema.format())?; let values = to_field(dictionary)?; - DataType::Dictionary(indices, Box::new(values.data_type().clone())) + let is_ordered = schema.flags & 1 == 1; + DataType::Dictionary(indices, Box::new(values.data_type().clone()), is_ordered) } else { to_data_type(schema)? }; @@ -449,7 +450,7 @@ fn to_format(data_type: &DataType) -> String { r } DataType::Map(_, _) => "+m".to_string(), - DataType::Dictionary(index, _) => to_format(&(*index).into()), + DataType::Dictionary(index, _, _) => to_format(&(*index).into()), DataType::Extension(_, inner, _) => to_format(inner.as_ref()), } } diff --git a/src/io/avro/read/nested.rs b/src/io/avro/read/nested.rs index a72ddb08486..d0e2ebf81bf 100644 --- a/src/io/avro/read/nested.rs +++ b/src/io/avro/read/nested.rs @@ -138,6 +138,7 @@ impl FixedItemsUtf8Dictionary { data_type: DataType::Dictionary( IntegerType::Int32, Box::new(values.data_type().clone()), + false, ), keys: MutablePrimitiveArray::::with_capacity(capacity), values, diff --git a/src/io/avro/read/schema.rs b/src/io/avro/read/schema.rs index 420187988ef..38d2b0e6215 100644 --- a/src/io/avro/read/schema.rs +++ b/src/io/avro/read/schema.rs @@ -182,7 +182,7 @@ fn schema_to_field( AvroSchema::Enum { .. } => { return Ok(Field::new( name.unwrap_or_default(), - DataType::Dictionary(IntegerType::Int32, Box::new(DataType::Utf8)), + DataType::Dictionary(IntegerType::Int32, Box::new(DataType::Utf8), false), false, )) } diff --git a/src/io/csv/write/serialize.rs b/src/io/csv/write/serialize.rs index 6c37e0eb24f..660bde3da1b 100644 --- a/src/io/csv/write/serialize.rs +++ b/src/io/csv/write/serialize.rs @@ -383,7 +383,7 @@ pub fn new_serializer<'a>( vec![], )) } - DataType::Dictionary(keys_dt, values_dt) => match &**values_dt { + DataType::Dictionary(keys_dt, values_dt, _) => match &**values_dt { DataType::LargeUtf8 => match *keys_dt { IntegerType::UInt32 => serialize_utf8_dict::(array.as_any()), IntegerType::UInt64 => serialize_utf8_dict::(array.as_any()), diff --git a/src/io/ipc/convert.rs b/src/io/ipc/convert.rs index ecaef6bc58e..2c7e0974d1b 100644 --- a/src/io/ipc/convert.rs +++ b/src/io/ipc/convert.rs @@ -97,7 +97,6 @@ impl<'a> From> for Field { data_type, field.nullable(), dictionary.id(), - dictionary.isOrdered(), ) } else { Field::new(field.name().unwrap(), data_type, field.nullable()) @@ -156,6 +155,7 @@ fn get_data_type(field: ipc::Field, extension: Extension, may_be_dictionary: boo return DataType::Dictionary( index_type, Box::new(get_data_type(field, extension, false)), + dictionary.isOrdered(), ); } } @@ -377,23 +377,22 @@ pub(crate) fn build_field<'a>( let fb_field_name = fbb.create_string(field.name().as_str()); let field_type = get_fb_field_type(field.data_type(), field.is_nullable(), fbb); - let fb_dictionary = if let DataType::Dictionary(index_type, inner) = field.data_type() { - if let DataType::Extension(name, _, metadata) = inner.as_ref() { - write_extension(fbb, name, metadata, &mut kv_vec); - } - Some(get_fb_dictionary( - index_type, - field - .dict_id() - .expect("All Dictionary types have `dict_id`"), - field - .dict_is_ordered() - .expect("All Dictionary types have `dict_is_ordered`"), - fbb, - )) - } else { - None - }; + let fb_dictionary = + if let DataType::Dictionary(index_type, inner, is_ordered) = field.data_type() { + if let DataType::Extension(name, _, metadata) = inner.as_ref() { + write_extension(fbb, name, metadata, &mut kv_vec); + } + Some(get_fb_dictionary( + index_type, + field + .dict_id() + .expect("All Dictionary types have `dict_id`"), + *is_ordered, + fbb, + )) + } else { + None + }; if let Some(metadata) = field.metadata() { if !metadata.is_empty() { @@ -450,7 +449,7 @@ fn type_to_field_type(data_type: &DataType) -> ipc::Type { Union(_, _, _) => ipc::Type::Union, Map(_, _) => ipc::Type::Map, Struct(_) => ipc::Type::Struct_, - Dictionary(_, v) => type_to_field_type(v), + Dictionary(_, v, _) => type_to_field_type(v), Extension(_, v, _) => type_to_field_type(v), } } @@ -671,7 +670,7 @@ pub(crate) fn get_fb_field_type<'a>( children: Some(fbb.create_vector(&children[..])), } } - Dictionary(_, value_type) => { + Dictionary(_, value_type, _) => { // In this library, the dictionary "type" is a logical construct. Here we // pass through to the value type, as we've already captured the index // type in the DictionaryEncoding metadata in the parent field @@ -891,17 +890,15 @@ mod tests { Field::new("struct<>", DataType::Struct(vec![]), true), Field::new_dict( "dictionary", - DataType::Dictionary(IntegerType::Int32, Box::new(DataType::Utf8)), + DataType::Dictionary(IntegerType::Int32, Box::new(DataType::Utf8), true), true, 123, - true, ), Field::new_dict( "dictionary", - DataType::Dictionary(IntegerType::UInt8, Box::new(DataType::UInt32)), + DataType::Dictionary(IntegerType::UInt8, Box::new(DataType::UInt32), true), true, 123, - true, ), Field::new("decimal", DataType::Decimal(10, 6), false), ], diff --git a/src/io/ipc/read/common.rs b/src/io/ipc/read/common.rs index a7de2f599be..aa81e64b9b6 100644 --- a/src/io/ipc/read/common.rs +++ b/src/io/ipc/read/common.rs @@ -165,7 +165,7 @@ pub fn read_record_batch( fn find_first_dict_field_d(id: usize, data_type: &DataType) -> Option<&Field> { use DataType::*; match data_type { - Dictionary(_, inner) => find_first_dict_field_d(id, inner.as_ref()), + Dictionary(_, inner, _) => find_first_dict_field_d(id, inner.as_ref()), Map(field, _) => find_first_dict_field(id, field.as_ref()), List(field) => find_first_dict_field(id, field.as_ref()), LargeList(field) => find_first_dict_field(id, field.as_ref()), @@ -191,7 +191,7 @@ fn find_first_dict_field_d(id: usize, data_type: &DataType) -> Option<&Field> { } fn find_first_dict_field(id: usize, field: &Field) -> Option<&Field> { - if let DataType::Dictionary(_, _) = &field.data_type { + if let DataType::Dictionary(_, _, _) = &field.data_type { if field.dict_id as usize == id { return Some(field); } @@ -234,7 +234,7 @@ pub fn read_dictionary( // values array, we need to retrieve this from the schema. // Get an array representing this dictionary's values. let dictionary_values: ArrayRef = match first_field.data_type() { - DataType::Dictionary(_, ref value_type) => { + DataType::Dictionary(_, ref value_type, _) => { // Make a fake schema for the dictionary batch. let schema = Arc::new(Schema { fields: vec![Field::new("", value_type.as_ref().clone(), false)], diff --git a/src/io/ipc/write/common.rs b/src/io/ipc/write/common.rs index f5c11ff54af..238ba412c1f 100644 --- a/src/io/ipc/write/common.rs +++ b/src/io/ipc/write/common.rs @@ -381,7 +381,7 @@ impl DictionaryTracker { /// inserted. pub fn insert(&mut self, dict_id: i64, array: &Arc) -> Result { let values = match array.data_type() { - DataType::Dictionary(key_type, _) => { + DataType::Dictionary(key_type, _, _) => { match_integer_type!(key_type, |$T| { let array = array .as_any() diff --git a/src/io/ipc/write/serialize.rs b/src/io/ipc/write/serialize.rs index 5a261b060f8..114583d21f6 100644 --- a/src/io/ipc/write/serialize.rs +++ b/src/io/ipc/write/serialize.rs @@ -470,7 +470,7 @@ pub fn write_dictionary( write_keys: bool, ) -> usize { match array.data_type() { - DataType::Dictionary(key_type, _) => { + DataType::Dictionary(key_type, _, _) => { match_integer_type!(key_type, |$T| { _write_dictionary::<$T>( array, diff --git a/src/io/json/read/deserialize.rs b/src/io/json/read/deserialize.rs index a5b56c74c5c..a2a4f5fab86 100644 --- a/src/io/json/read/deserialize.rs +++ b/src/io/json/read/deserialize.rs @@ -238,7 +238,7 @@ fn _deserialize>(rows: &[A], data_type: DataType) -> Arc Arc::new(deserialize_binary::(rows)), DataType::LargeBinary => Arc::new(deserialize_binary::(rows)), DataType::Struct(_) => Arc::new(deserialize_struct(rows, data_type)), - DataType::Dictionary(key_type, _) => { + DataType::Dictionary(key_type, _, _) => { match_integer_type!(key_type, |$T| { Arc::new(deserialize_dictionary::<$T, _>(rows, data_type)) }) diff --git a/src/io/json_integration/mod.rs b/src/io/json_integration/mod.rs index 175c5068c78..ccb9b868d68 100644 --- a/src/io/json_integration/mod.rs +++ b/src/io/json_integration/mod.rs @@ -83,7 +83,7 @@ impl From<&Field> for ArrowJsonField { _ => None, }; - let dictionary = if let DataType::Dictionary(key_type, _) = &field.data_type { + let dictionary = if let DataType::Dictionary(key_type, _, is_ordered) = &field.data_type { use crate::datatypes::IntegerType::*; Some(ArrowJsonFieldDictionary { id: field.dict_id, @@ -100,7 +100,7 @@ impl From<&Field> for ArrowJsonField { UInt8 | UInt16 | UInt32 | UInt64 => false, }, }, - is_ordered: field.dict_is_ordered, + is_ordered: *is_ordered, }) } else { None diff --git a/src/io/json_integration/schema.rs b/src/io/json_integration/schema.rs index 8d55a76c676..607aca9474e 100644 --- a/src/io/json_integration/schema.rs +++ b/src/io/json_integration/schema.rs @@ -117,7 +117,7 @@ impl ToJson for DataType { TimeUnit::Microsecond => "MICROSECOND", TimeUnit::Nanosecond => "NANOSECOND", }}), - DataType::Dictionary(_, _) => json!({ "name": "dictionary"}), + DataType::Dictionary(_, _, _) => json!({ "name": "dictionary"}), DataType::Decimal(precision, scale) => { json!({"name": "decimal", "precision": precision, "scale": scale}) } @@ -136,7 +136,7 @@ impl ToJson for Field { _ => vec![], }; match self.data_type() { - DataType::Dictionary(ref index_type, ref value_type) => { + DataType::Dictionary(ref index_type, ref value_type, is_ordered) => { let index_type: DataType = (*index_type).into(); json!({ "name": self.name(), @@ -146,7 +146,7 @@ impl ToJson for Field { "dictionary": { "id": self.dict_id(), "indexType": index_type.to_json(), - "isOrdered": self.dict_is_ordered() + "isOrdered": is_ordered } }) } @@ -494,33 +494,32 @@ impl TryFrom<&Value> for Field { )); } }; - DataType::Dictionary(index_type, Box::new(data_type)) + let is_ordered = match dictionary.get("isOrdered") { + Some(&Value::Bool(n)) => n, + _ => { + return Err(ArrowError::OutOfSpec( + "Field missing 'isOrdered' attribute".to_string(), + )); + } + }; + DataType::Dictionary(index_type, Box::new(data_type), is_ordered) } else { data_type }; - let (dict_id, dict_is_ordered) = if let Some(dictionary) = map.get("dictionary") { - let dict_id = match dictionary.get("id") { + let dict_id = if let Some(dictionary) = map.get("dictionary") { + match dictionary.get("id") { Some(Value::Number(n)) => n.as_i64().unwrap(), _ => { return Err(ArrowError::OutOfSpec( "Field missing 'id' attribute".to_string(), )); } - }; - let dict_is_ordered = match dictionary.get("isOrdered") { - Some(&Value::Bool(n)) => n, - _ => { - return Err(ArrowError::OutOfSpec( - "Field missing 'isOrdered' attribute".to_string(), - )); - } - }; - (dict_id, dict_is_ordered) + } } else { - (0, false) + 0 }; - let mut f = Field::new_dict(&name, data_type, nullable, dict_id, dict_is_ordered); + let mut f = Field::new_dict(&name, data_type, nullable, dict_id); f.set_metadata(metadata); Ok(f) } diff --git a/src/io/parquet/read/mod.rs b/src/io/parquet/read/mod.rs index e3638304f4b..dd7eb3035e3 100644 --- a/src/io/parquet/read/mod.rs +++ b/src/io/parquet/read/mod.rs @@ -98,7 +98,7 @@ fn dict_read< data_type: DataType, ) -> Result> { use DataType::*; - let values_data_type = if let Dictionary(_, v) = &data_type { + let values_data_type = if let Dictionary(_, v, _) = &data_type { v.as_ref() } else { panic!() @@ -372,7 +372,7 @@ fn page_iter_to_array(iter, metadata, data_type, nested) } - Dictionary(key_type, _) => match_integer_type!(key_type, |$T| { + Dictionary(key_type, _, _) => match_integer_type!(key_type, |$T| { dict_read::<$T, _>(iter, metadata, data_type) }), diff --git a/src/io/parquet/read/primitive/mod.rs b/src/io/parquet/read/primitive/mod.rs index 357f6440a55..99c666b784f 100644 --- a/src/io/parquet/read/primitive/mod.rs +++ b/src/io/parquet/read/primitive/mod.rs @@ -49,7 +49,7 @@ where } let data_type = match data_type { - DataType::Dictionary(_, values) => values.as_ref().clone(), + DataType::Dictionary(_, values, _) => values.as_ref().clone(), _ => data_type, }; @@ -100,7 +100,7 @@ where } let data_type = match data_type { - DataType::Dictionary(_, values) => values.as_ref().clone(), + DataType::Dictionary(_, values, _) => values.as_ref().clone(), _ => data_type, }; diff --git a/src/io/parquet/write/mod.rs b/src/io/parquet/write/mod.rs index 8d0b5cc1215..12b12b17fb5 100644 --- a/src/io/parquet/write/mod.rs +++ b/src/io/parquet/write/mod.rs @@ -103,8 +103,8 @@ pub fn can_encode(data_type: &DataType, encoding: Encoding) -> bool { Encoding::DeltaLengthByteArray, DataType::Binary | DataType::LargeBinary | DataType::Utf8 | DataType::LargeUtf8, ) - | (Encoding::RleDictionary, DataType::Dictionary(_, _)) - | (Encoding::PlainDictionary, DataType::Dictionary(_, _)) + | (Encoding::RleDictionary, DataType::Dictionary(_, _, _)) + | (Encoding::PlainDictionary, DataType::Dictionary(_, _, _)) ) } @@ -116,7 +116,7 @@ pub fn array_to_pages( encoding: Encoding, ) -> Result>> { match array.data_type() { - DataType::Dictionary(key_type, _) => { + DataType::Dictionary(key_type, _, _) => { match_integer_type!(key_type, |$T| { dictionary::array_to_pages::<$T>( array.as_any().downcast_ref().unwrap(), diff --git a/src/io/parquet/write/schema.rs b/src/io/parquet/write/schema.rs index 88fed6505f3..521ec53aa88 100644 --- a/src/io/parquet/write/schema.rs +++ b/src/io/parquet/write/schema.rs @@ -277,7 +277,7 @@ pub fn to_parquet_type(field: &Field) -> Result { name, repetition, None, None, fields, None, )?) } - DataType::Dictionary(_, value) => { + DataType::Dictionary(_, value, _) => { let dict_field = Field::new(name.as_str(), value.as_ref().clone(), field.is_nullable()); to_parquet_type(&dict_field) } diff --git a/src/scalar/equal.rs b/src/scalar/equal.rs index 6c1ea3a1d56..583d36a5554 100644 --- a/src/scalar/equal.rs +++ b/src/scalar/equal.rs @@ -125,7 +125,7 @@ fn equal(lhs: &dyn Scalar, rhs: &dyn Scalar) -> bool { let rhs = rhs.as_any().downcast_ref::>().unwrap(); lhs == rhs } - DataType::Dictionary(key_type, _) => match_integer_type!(key_type, |$T| { + DataType::Dictionary(key_type, _, _) => match_integer_type!(key_type, |$T| { let lhs = lhs.as_any().downcast_ref::>().unwrap(); let rhs = rhs.as_any().downcast_ref::>().unwrap(); lhs == rhs diff --git a/tests/it/compute/cast.rs b/tests/it/compute/cast.rs index 977aae317f5..d15d2e9e46f 100644 --- a/tests/it/compute/cast.rs +++ b/tests/it/compute/cast.rs @@ -483,7 +483,7 @@ fn utf8_to_dict() { let array = Utf8Array::::from(&[Some("one"), None, Some("three"), Some("one")]); // Cast to a dictionary (same value type, Utf8) - let cast_type = DataType::Dictionary(u8::KEY_TYPE, Box::new(DataType::Utf8)); + let cast_type = DataType::Dictionary(u8::KEY_TYPE, Box::new(DataType::Utf8), false); let result = cast(&array, &cast_type, CastOptions::default()).expect("cast failed"); let mut expected = MutableDictionaryArray::>::new(); @@ -514,7 +514,7 @@ fn i32_to_dict() { let array = Int32Array::from(&[Some(1), None, Some(3), Some(1)]); // Cast to a dictionary (same value type, Utf8) - let cast_type = DataType::Dictionary(u8::KEY_TYPE, Box::new(DataType::Int32)); + let cast_type = DataType::Dictionary(u8::KEY_TYPE, Box::new(DataType::Int32), false); let result = cast(&array, &cast_type, CastOptions::default()).expect("cast failed"); let mut expected = MutableDictionaryArray::>::new(); diff --git a/tests/it/ffi.rs b/tests/it/ffi.rs index cccd903319c..dfc5b5102cb 100644 --- a/tests/it/ffi.rs +++ b/tests/it/ffi.rs @@ -196,7 +196,7 @@ fn schema() -> Result<()> { let field = Field::new( "a", - DataType::Dictionary(u32::KEY_TYPE, Box::new(DataType::Utf8)), + DataType::Dictionary(u32::KEY_TYPE, Box::new(DataType::Utf8), false), true, ); test_round_trip_schema(field)?; diff --git a/tests/it/io/avro/read.rs b/tests/it/io/avro/read.rs index d55e10332d3..bd75048efc4 100644 --- a/tests/it/io/avro/read.rs +++ b/tests/it/io/avro/read.rs @@ -61,7 +61,7 @@ pub(super) fn schema() -> (AvroSchema, Schema) { ), Field::new( "enum", - DataType::Dictionary(i32::KEY_TYPE, Box::new(DataType::Utf8)), + DataType::Dictionary(i32::KEY_TYPE, Box::new(DataType::Utf8), false), false, ), ]); diff --git a/tests/it/io/json/mod.rs b/tests/it/io/json/mod.rs index a4988bdb6f3..caff072c340 100644 --- a/tests/it/io/json/mod.rs +++ b/tests/it/io/json/mod.rs @@ -130,7 +130,7 @@ fn case_dict() -> (String, Schema, Vec>) { let data_type = DataType::List(Box::new(Field::new( "item", - DataType::Dictionary(u64::KEY_TYPE, Box::new(DataType::Utf8)), + DataType::Dictionary(u64::KEY_TYPE, Box::new(DataType::Utf8), false), true, ))); diff --git a/tests/it/io/parquet/mod.rs b/tests/it/io/parquet/mod.rs index 9924de91da8..9a722802571 100644 --- a/tests/it/io/parquet/mod.rs +++ b/tests/it/io/parquet/mod.rs @@ -637,7 +637,7 @@ fn integration_write(schema: &Schema, batches: &[RecordBatch]) -> Result .iter() .zip(descritors.clone()) .map(|(array, descriptor)| { - let encoding = if let DataType::Dictionary(_, _) = array.data_type() { + let encoding = if let DataType::Dictionary(..) = array.data_type() { Encoding::RleDictionary } else { Encoding::Plain diff --git a/tests/it/io/print.rs b/tests/it/io/print.rs index 1af0b471d48..b1fd5f5040c 100644 --- a/tests/it/io/print.rs +++ b/tests/it/io/print.rs @@ -87,7 +87,7 @@ fn write_null() -> Result<()> { #[test] fn write_dictionary() -> Result<()> { // define a schema. - let field_type = DataType::Dictionary(i32::KEY_TYPE, Box::new(DataType::Utf8)); + let field_type = DataType::Dictionary(i32::KEY_TYPE, Box::new(DataType::Utf8), false); let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); let mut array = MutableDictionaryArray::>::new(); @@ -119,7 +119,7 @@ fn write_dictionary() -> Result<()> { #[test] fn dictionary_validities() -> Result<()> { // define a schema. - let field_type = DataType::Dictionary(i32::KEY_TYPE, Box::new(DataType::Int32)); + let field_type = DataType::Dictionary(i32::KEY_TYPE, Box::new(DataType::Int32), false); let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); let keys = PrimitiveArray::::from([Some(1), None, Some(0)]);