diff --git a/integration-testing/src/flight_client_scenarios/integration_test.rs b/integration-testing/src/flight_client_scenarios/integration_test.rs index 27a06b307c7..dd65352df5b 100644 --- a/integration-testing/src/flight_client_scenarios/integration_test.rs +++ b/integration-testing/src/flight_client_scenarios/integration_test.rs @@ -33,7 +33,7 @@ use arrow_format::ipc::Message::MessageHeader; use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt}; use tonic::{Request, Streaming}; -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; type ArrayRef = Arc; type SchemaRef = Arc; @@ -199,10 +199,10 @@ async fn consume_flight_location( // first FlightData. Ignore this one. let _schema_again = resp.next().await.unwrap(); - let mut dictionaries_by_field = vec![None; schema.fields().len()]; + let mut dictionaries = Default::default(); for (counter, expected_batch) in expected_data.iter().enumerate() { - let data = receive_batch_flight_data(&mut resp, schema.clone(), &mut dictionaries_by_field) + let data = receive_batch_flight_data(&mut resp, schema.clone(), &mut dictionaries) .await .unwrap_or_else(|| { panic!( @@ -215,7 +215,7 @@ async fn consume_flight_location( let metadata = counter.to_string().into_bytes(); assert_eq!(metadata, data.app_metadata); - let actual_batch = deserialize_batch(&data, schema.clone(), true, &dictionaries_by_field) + let actual_batch = deserialize_batch(&data, schema.clone(), true, &dictionaries) .expect("Unable to convert flight data to Arrow batch"); assert_eq!(expected_batch.schema(), actual_batch.schema()); @@ -245,7 +245,7 @@ async fn consume_flight_location( async fn receive_batch_flight_data( resp: &mut Streaming, schema: SchemaRef, - dictionaries_by_field: &mut [Option], + dictionaries: &mut HashMap>, ) -> Option { let mut data = resp.next().await?.ok()?; let mut message = @@ -259,7 +259,7 @@ async fn receive_batch_flight_data( .expect("Error parsing dictionary"), &schema, true, - dictionaries_by_field, + 
dictionaries, &mut reader, 0, ) diff --git a/integration-testing/src/flight_server_scenarios/integration_test.rs b/integration-testing/src/flight_server_scenarios/integration_test.rs index c2f9d967be0..a91e2b7348e 100644 --- a/integration-testing/src/flight_server_scenarios/integration_test.rs +++ b/integration-testing/src/flight_server_scenarios/integration_test.rs @@ -275,7 +275,7 @@ async fn record_batch_from_message( message: Message<'_>, data_body: &[u8], schema_ref: Arc, - dictionaries_by_field: &[Option>], + dictionaries: &mut HashMap>, ) -> Result { let ipc_batch = message .header_as_record_batch() @@ -288,7 +288,7 @@ async fn record_batch_from_message( schema_ref, None, true, - dictionaries_by_field, + dictionaries, ArrowSchema::MetadataVersion::V5, &mut reader, 0, @@ -302,7 +302,7 @@ async fn dictionary_from_message( message: Message<'_>, data_body: &[u8], schema_ref: Arc, - dictionaries_by_field: &mut [Option>], + dictionaries: &mut HashMap>, ) -> Result<(), Status> { let ipc_batch = message .header_as_dictionary_batch() @@ -310,14 +310,8 @@ async fn dictionary_from_message( let mut reader = std::io::Cursor::new(data_body); - let dictionary_batch_result = ipc::read::read_dictionary( - ipc_batch, - &schema_ref, - true, - dictionaries_by_field, - &mut reader, - 0, - ); + let dictionary_batch_result = + ipc::read::read_dictionary(ipc_batch, &schema_ref, true, dictionaries, &mut reader, 0); dictionary_batch_result .map_err(|e| Status::internal(format!("Could not convert to Dictionary: {:?}", e))) } @@ -333,7 +327,7 @@ async fn save_uploaded_chunks( let mut chunks = vec![]; let mut uploaded_chunks = uploaded_chunks.lock().await; - let mut dictionaries_by_field = vec![None; schema_ref.fields().len()]; + let mut dictionaries = Default::default(); while let Some(Ok(data)) = input_stream.next().await { let message = root_as_message(&data.data_header[..]) @@ -352,7 +346,7 @@ async fn save_uploaded_chunks( message, &data.data_body, schema_ref.clone(), - 
&dictionaries_by_field, + &mut dictionaries, ) .await?; @@ -363,7 +357,7 @@ async fn save_uploaded_chunks( message, &data.data_body, schema_ref.clone(), - &mut dictionaries_by_field, + &mut dictionaries, ) .await?; } diff --git a/integration-testing/unskip.patch b/integration-testing/unskip.patch index 8809446ba00..436e2b5cfa2 100644 --- a/integration-testing/unskip.patch +++ b/integration-testing/unskip.patch @@ -1,8 +1,8 @@ diff --git a/dev/archery/archery/integration/datagen.py b/dev/archery/archery/integration/datagen.py -index 2d90d6c86..d5a0bc833 100644 +index 6a077a893..cab6ecd37 100644 --- a/dev/archery/archery/integration/datagen.py +++ b/dev/archery/archery/integration/datagen.py -@@ -1569,8 +1569,7 @@ def get_generated_json_files(tempdir=None): +@@ -1561,8 +1561,7 @@ def get_generated_json_files(tempdir=None): .skip_category('C#') .skip_category('JS'), # TODO(ARROW-7900) @@ -12,7 +12,7 @@ index 2d90d6c86..d5a0bc833 100644 generate_decimal256_case() .skip_category('Go') # TODO(ARROW-7948): Decimal + Go -@@ -1582,18 +1581,15 @@ def get_generated_json_files(tempdir=None): +@@ -1574,18 +1573,15 @@ def get_generated_json_files(tempdir=None): generate_interval_case() .skip_category('C#') @@ -34,7 +34,7 @@ index 2d90d6c86..d5a0bc833 100644 generate_non_canonical_map_case() .skip_category('C#') -@@ -1611,14 +1607,12 @@ def get_generated_json_files(tempdir=None): +@@ -1602,14 +1598,12 @@ def get_generated_json_files(tempdir=None): generate_nested_large_offsets_case() .skip_category('C#') .skip_category('Go') @@ -51,7 +51,14 @@ index 2d90d6c86..d5a0bc833 100644 generate_custom_metadata_case() .skip_category('C#') -@@ -1649,8 +1643,7 @@ def get_generated_json_files(tempdir=None): +@@ -1634,14 +1628,12 @@ def get_generated_json_files(tempdir=None): + .skip_category('C#') + .skip_category('Go') + .skip_category('Java') # TODO(ARROW-7779) +- .skip_category('JS') +- .skip_category('Rust'), ++ .skip_category('JS'), + generate_extension_case() .skip_category('C#') 
.skip_category('Go') # TODO(ARROW-3039): requires dictionaries diff --git a/src/array/dictionary/mod.rs b/src/array/dictionary/mod.rs index 8f4de0f4bd4..de45eff8bee 100644 --- a/src/array/dictionary/mod.rs +++ b/src/array/dictionary/mod.rs @@ -13,7 +13,8 @@ mod mutable; pub use iterator::*; pub use mutable::*; -use super::{new_empty_array, primitive::PrimitiveArray, Array}; +use super::display::get_value_display; +use super::{display_fmt, new_empty_array, primitive::PrimitiveArray, Array}; use crate::scalar::NullScalar; /// Trait denoting [`NativeType`]s that can be used as keys of a dictionary. @@ -196,9 +197,10 @@ where PrimitiveArray: std::fmt::Display, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - writeln!(f, "{:?}{{", self.data_type())?; - writeln!(f, "keys: {},", self.keys())?; - writeln!(f, "values: {},", self.values())?; - write!(f, "}}") + let display = get_value_display(self); + let new_lines = false; + let head = &format!("{}", self.data_type()); + let iter = self.iter().enumerate().map(|(i, x)| x.map(|_| display(i))); + display_fmt(iter, head, f, new_lines) } } diff --git a/src/array/display.rs b/src/array/display.rs index a69b3342038..12aac44b9ea 100644 --- a/src/array/display.rs +++ b/src/array/display.rs @@ -165,7 +165,13 @@ pub fn get_value_display<'a>(array: &'a dyn Array) -> Box Strin .unwrap(); let keys = a.keys(); let display = get_display(a.values().as_ref()); - Box::new(move |row: usize| display(keys.value(row) as usize)) + Box::new(move |row: usize| { + if keys.is_null(row) { + "".to_string() + }else { + display(keys.value(row) as usize) + } + }) }), Map(_, _) => todo!(), Struct(_) => { diff --git a/src/datatypes/field.rs b/src/datatypes/field.rs index 8fc544f1ffd..f6ba2f84e76 100644 --- a/src/datatypes/field.rs +++ b/src/datatypes/field.rs @@ -23,7 +23,7 @@ use super::DataType; /// A logical [`DataType`] and its associated metadata per /// [Arrow specification](https://arrow.apache.org/docs/cpp/api/datatype.html) 
-#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Eq)] pub struct Field { /// Its name pub name: String, @@ -39,6 +39,26 @@ pub struct Field { pub metadata: Option>, } +impl std::hash::Hash for Field { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.data_type.hash(state); + self.nullable.hash(state); + self.dict_is_ordered.hash(state); + self.metadata.hash(state); + } +} + +impl PartialEq for Field { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.data_type == other.data_type + && self.nullable == other.nullable + && self.dict_is_ordered == other.dict_is_ordered + && self.metadata == other.metadata + } +} + impl Field { /// Creates a new field pub fn new(name: &str, data_type: DataType, nullable: bool) -> Self { diff --git a/src/datatypes/schema.rs b/src/datatypes/schema.rs index afb656899b7..4ab3488a63d 100644 --- a/src/datatypes/schema.rs +++ b/src/datatypes/schema.rs @@ -164,14 +164,6 @@ impl Schema { Ok(&self.fields[self.index_of(name)?]) } - /// Returns all [`Field`]s with dictionary id `dict_id`. - pub fn fields_with_dict_id(&self, dict_id: i64) -> Vec<&Field> { - self.fields - .iter() - .filter(|f| f.dict_id() == Some(dict_id)) - .collect() - } - /// Find the index of the column with the given name. 
pub fn index_of(&self, name: &str) -> Result { for i in 0..self.fields.len() { diff --git a/src/io/flight/mod.rs b/src/io/flight/mod.rs index ac4f0773f34..75b270c7c68 100644 --- a/src/io/flight/mod.rs +++ b/src/io/flight/mod.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::convert::TryFrom; use std::sync::Arc; @@ -122,7 +123,7 @@ pub fn deserialize_batch( data: &FlightData, schema: Arc, is_little_endian: bool, - dictionaries_by_field: &[Option>], + dictionaries: &HashMap>, ) -> Result { // check that the data_header is a record batch message let message = ipc::Message::root_as_message(&data.data_header[..]) @@ -141,7 +142,7 @@ pub fn deserialize_batch( schema.clone(), None, is_little_endian, - dictionaries_by_field, + dictionaries, ipc::Schema::MetadataVersion::V5, &mut reader, 0, diff --git a/src/io/ipc/convert.rs b/src/io/ipc/convert.rs index c3667baac04..ecaef6bc58e 100644 --- a/src/io/ipc/convert.rs +++ b/src/io/ipc/convert.rs @@ -663,24 +663,8 @@ pub(crate) fn get_fb_field_type<'a>( } } Struct(fields) => { - // struct's fields are children - let mut children = vec![]; - for field in fields { - let inner_types = get_fb_field_type(field.data_type(), field.is_nullable(), fbb); - let field_name = fbb.create_string(field.name()); - children.push(ipc::Field::create( - fbb, - &ipc::FieldArgs { - name: Some(field_name), - nullable: field.is_nullable(), - type_type: inner_types.type_type, - type_: Some(inner_types.type_), - dictionary: None, - children: inner_types.children, - custom_metadata: None, - }, - )); - } + let children: Vec<_> = fields.iter().map(|field| build_field(fbb, field)).collect(); + FbFieldType { type_type, type_: ipc::Struct_Builder::new(fbb).finish().as_union_value(), diff --git a/src/io/ipc/read/array/binary.rs b/src/io/ipc/read/array/binary.rs index cae97b1a1eb..485e2451637 100644 --- a/src/io/ipc/read/array/binary.rs +++ b/src/io/ipc/read/array/binary.rs @@ -25,7 +25,7 @@ pub fn read_binary( where Vec: TryInto + TryInto<::Bytes>, { - 
let field_node = field_nodes.pop_front().unwrap().0; + let field_node = field_nodes.pop_front().unwrap(); let validity = read_validity( buffers, diff --git a/src/io/ipc/read/array/boolean.rs b/src/io/ipc/read/array/boolean.rs index d9ea780d170..bab7b4250ee 100644 --- a/src/io/ipc/read/array/boolean.rs +++ b/src/io/ipc/read/array/boolean.rs @@ -19,7 +19,7 @@ pub fn read_boolean( is_little_endian: bool, compression: Option, ) -> Result { - let field_node = field_nodes.pop_front().unwrap().0; + let field_node = field_nodes.pop_front().unwrap(); let length = field_node.length() as usize; let validity = read_validity( diff --git a/src/io/ipc/read/array/dictionary.rs b/src/io/ipc/read/array/dictionary.rs index 0357457d014..1d61b59b6f4 100644 --- a/src/io/ipc/read/array/dictionary.rs +++ b/src/io/ipc/read/array/dictionary.rs @@ -1,19 +1,24 @@ -use std::collections::VecDeque; +use std::collections::{HashMap, HashSet, VecDeque}; use std::convert::TryInto; use std::io::{Read, Seek}; +use std::sync::Arc; use arrow_format::ipc; -use crate::array::{DictionaryArray, DictionaryKey}; -use crate::error::Result; +use crate::array::{Array, DictionaryArray, DictionaryKey}; +use crate::datatypes::Field; +use crate::error::{ArrowError, Result}; use super::super::deserialize::Node; use super::{read_primitive, skip_primitive}; +#[allow(clippy::too_many_arguments)] pub fn read_dictionary( field_nodes: &mut VecDeque, + field: &Field, buffers: &mut VecDeque<&ipc::Schema::Buffer>, reader: &mut R, + dictionaries: &HashMap>, block_offset: u64, compression: Option, is_little_endian: bool, @@ -21,7 +26,17 @@ pub fn read_dictionary( where Vec: TryInto, { - let values = field_nodes.front().unwrap().1.as_ref().unwrap(); + let id = field.dict_id().unwrap() as usize; + let values = dictionaries + .get(&id) + .ok_or_else(|| { + let valid_ids = dictionaries.keys().collect::>(); + ArrowError::Ipc(format!( + "Dictionary id {} not found. Valid ids: {:?}", + id, valid_ids + )) + })? 
+ .clone(); let keys = read_primitive( field_nodes, @@ -33,7 +48,7 @@ where compression, )?; - Ok(DictionaryArray::::from_data(keys, values.clone())) + Ok(DictionaryArray::::from_data(keys, values)) } pub fn skip_dictionary( diff --git a/src/io/ipc/read/array/fixed_size_binary.rs b/src/io/ipc/read/array/fixed_size_binary.rs index 094b34d2f78..b2547016e4b 100644 --- a/src/io/ipc/read/array/fixed_size_binary.rs +++ b/src/io/ipc/read/array/fixed_size_binary.rs @@ -19,7 +19,7 @@ pub fn read_fixed_size_binary( is_little_endian: bool, compression: Option, ) -> Result { - let field_node = field_nodes.pop_front().unwrap().0; + let field_node = field_nodes.pop_front().unwrap(); let validity = read_validity( buffers, diff --git a/src/io/ipc/read/array/fixed_size_list.rs b/src/io/ipc/read/array/fixed_size_list.rs index 3213eecc296..274b50fe490 100644 --- a/src/io/ipc/read/array/fixed_size_list.rs +++ b/src/io/ipc/read/array/fixed_size_list.rs @@ -1,9 +1,10 @@ -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use std::io::{Read, Seek}; +use std::sync::Arc; use arrow_format::ipc; -use crate::array::FixedSizeListArray; +use crate::array::{Array, FixedSizeListArray}; use crate::datatypes::DataType; use crate::error::Result; @@ -16,12 +17,13 @@ pub fn read_fixed_size_list( data_type: DataType, buffers: &mut VecDeque<&ipc::Schema::Buffer>, reader: &mut R, + dictionaries: &HashMap>, block_offset: u64, is_little_endian: bool, compression: Option, version: ipc::Schema::MetadataVersion, ) -> Result { - let field_node = field_nodes.pop_front().unwrap().0; + let field_node = field_nodes.pop_front().unwrap(); let validity = read_validity( buffers, @@ -32,13 +34,14 @@ pub fn read_fixed_size_list( compression, )?; - let (value_data_type, _) = FixedSizeListArray::get_child_and_size(&data_type); + let (field, _) = FixedSizeListArray::get_child_and_size(&data_type); let values = read( field_nodes, - value_data_type.data_type().clone(), + field, buffers, reader, + 
dictionaries, block_offset, is_little_endian, compression, diff --git a/src/io/ipc/read/array/list.rs b/src/io/ipc/read/array/list.rs index ab55bb7bdba..65495cd3aa3 100644 --- a/src/io/ipc/read/array/list.rs +++ b/src/io/ipc/read/array/list.rs @@ -1,10 +1,11 @@ -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use std::convert::TryInto; use std::io::{Read, Seek}; +use std::sync::Arc; use arrow_format::ipc; -use crate::array::{ListArray, Offset}; +use crate::array::{Array, ListArray, Offset}; use crate::buffer::Buffer; use crate::datatypes::DataType; use crate::error::Result; @@ -18,6 +19,7 @@ pub fn read_list( data_type: DataType, buffers: &mut VecDeque<&ipc::Schema::Buffer>, reader: &mut R, + dictionaries: &HashMap>, block_offset: u64, is_little_endian: bool, compression: Option, @@ -26,7 +28,7 @@ pub fn read_list( where Vec: TryInto, { - let field_node = field_nodes.pop_front().unwrap().0; + let field_node = field_nodes.pop_front().unwrap(); let validity = read_validity( buffers, @@ -48,13 +50,14 @@ where // Older versions of the IPC format sometimes do not report an offset .or_else(|_| Result::Ok(Buffer::::from(&[O::default()])))?; - let value_data_type = ListArray::::get_child_type(&data_type).clone(); + let field = ListArray::::get_child_field(&data_type); let values = read( field_nodes, - value_data_type, + field, buffers, reader, + dictionaries, block_offset, is_little_endian, compression, diff --git a/src/io/ipc/read/array/map.rs b/src/io/ipc/read/array/map.rs index a11151d6ec9..bbd518e8d4b 100644 --- a/src/io/ipc/read/array/map.rs +++ b/src/io/ipc/read/array/map.rs @@ -1,9 +1,10 @@ -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use std::io::{Read, Seek}; +use std::sync::Arc; use arrow_format::ipc; -use crate::array::MapArray; +use crate::array::{Array, MapArray}; use crate::buffer::Buffer; use crate::datatypes::DataType; use crate::error::Result; @@ -17,12 +18,13 @@ pub fn read_map( data_type: 
DataType, buffers: &mut VecDeque<&ipc::Schema::Buffer>, reader: &mut R, + dictionaries: &HashMap>, block_offset: u64, is_little_endian: bool, compression: Option, version: ipc::Schema::MetadataVersion, ) -> Result { - let field_node = field_nodes.pop_front().unwrap().0; + let field_node = field_nodes.pop_front().unwrap(); let validity = read_validity( buffers, @@ -44,13 +46,14 @@ pub fn read_map( // Older versions of the IPC format sometimes do not report an offset .or_else(|_| Result::Ok(Buffer::::from(&[0i32])))?; - let value_data_type = MapArray::get_field(&data_type).data_type().clone(); + let field = MapArray::get_field(&data_type); let field = read( field_nodes, - value_data_type, + field, buffers, reader, + dictionaries, block_offset, is_little_endian, compression, diff --git a/src/io/ipc/read/array/null.rs b/src/io/ipc/read/array/null.rs index d1cdcb0e499..2885a620ffd 100644 --- a/src/io/ipc/read/array/null.rs +++ b/src/io/ipc/read/array/null.rs @@ -7,7 +7,7 @@ use super::super::deserialize::Node; pub fn read_null(field_nodes: &mut VecDeque, data_type: DataType) -> NullArray { NullArray::from_data( data_type, - field_nodes.pop_front().unwrap().0.length() as usize, + field_nodes.pop_front().unwrap().length() as usize, ) } diff --git a/src/io/ipc/read/array/primitive.rs b/src/io/ipc/read/array/primitive.rs index df851e2c4ff..f43eb8f2703 100644 --- a/src/io/ipc/read/array/primitive.rs +++ b/src/io/ipc/read/array/primitive.rs @@ -22,7 +22,7 @@ pub fn read_primitive( where Vec: TryInto, { - let field_node = field_nodes.pop_front().unwrap().0; + let field_node = field_nodes.pop_front().unwrap(); let validity = read_validity( buffers, diff --git a/src/io/ipc/read/array/struct_.rs b/src/io/ipc/read/array/struct_.rs index 312a68ea262..775774291d3 100644 --- a/src/io/ipc/read/array/struct_.rs +++ b/src/io/ipc/read/array/struct_.rs @@ -1,9 +1,10 @@ -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use std::io::{Read, Seek}; +use 
std::sync::Arc; use arrow_format::ipc; -use crate::array::StructArray; +use crate::array::{Array, StructArray}; use crate::datatypes::DataType; use crate::error::Result; @@ -16,12 +17,13 @@ pub fn read_struct( data_type: DataType, buffers: &mut VecDeque<&ipc::Schema::Buffer>, reader: &mut R, + dictionaries: &HashMap>, block_offset: u64, is_little_endian: bool, compression: Option, version: ipc::Schema::MetadataVersion, ) -> Result { - let field_node = field_nodes.pop_front().unwrap().0; + let field_node = field_nodes.pop_front().unwrap(); let validity = read_validity( buffers, @@ -39,9 +41,10 @@ pub fn read_struct( .map(|field| { read( field_nodes, - field.data_type().clone(), + field, buffers, reader, + dictionaries, block_offset, is_little_endian, compression, diff --git a/src/io/ipc/read/array/union.rs b/src/io/ipc/read/array/union.rs index 66cab751c42..87afdfb582e 100644 --- a/src/io/ipc/read/array/union.rs +++ b/src/io/ipc/read/array/union.rs @@ -1,9 +1,10 @@ -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use std::io::{Read, Seek}; +use std::sync::Arc; use arrow_format::ipc; -use crate::array::UnionArray; +use crate::array::{Array, UnionArray}; use crate::datatypes::DataType; use crate::datatypes::UnionMode::Dense; use crate::error::Result; @@ -17,12 +18,13 @@ pub fn read_union( data_type: DataType, buffers: &mut VecDeque<&ipc::Schema::Buffer>, reader: &mut R, + dictionaries: &HashMap>, block_offset: u64, is_little_endian: bool, compression: Option, version: ipc::Schema::MetadataVersion, ) -> Result { - let field_node = field_nodes.pop_front().unwrap().0; + let field_node = field_nodes.pop_front().unwrap(); if version != ipc::Schema::MetadataVersion::V5 { let _ = buffers.pop_front().unwrap(); @@ -61,9 +63,10 @@ pub fn read_union( .map(|field| { read( field_nodes, - field.data_type().clone(), + field, buffers, reader, + dictionaries, block_offset, is_little_endian, compression, diff --git a/src/io/ipc/read/array/utf8.rs 
b/src/io/ipc/read/array/utf8.rs index d11ae7a12fa..93d024b9f13 100644 --- a/src/io/ipc/read/array/utf8.rs +++ b/src/io/ipc/read/array/utf8.rs @@ -25,7 +25,7 @@ pub fn read_utf8( where Vec: TryInto + TryInto<::Bytes>, { - let field_node = field_nodes.pop_front().unwrap().0; + let field_node = field_nodes.pop_front().unwrap(); let validity = read_validity( buffers, diff --git a/src/io/ipc/read/common.rs b/src/io/ipc/read/common.rs index aa0fa7b6490..055b31b3705 100644 --- a/src/io/ipc/read/common.rs +++ b/src/io/ipc/read/common.rs @@ -98,7 +98,7 @@ pub fn read_record_batch( schema: Arc, projection: Option<(&[usize], Arc)>, is_little_endian: bool, - dictionaries: &[Option], + dictionaries: &HashMap>, version: MetadataVersion, reader: &mut R, block_offset: u64, @@ -111,13 +111,7 @@ pub fn read_record_batch( ArrowError::Ipc("Unable to get field nodes from IPC RecordBatch".to_string()) })?; - // This is a bug fix: we should have one dictionary per node, not schema field - let dictionaries = dictionaries.iter().chain(std::iter::repeat(&None)); - - let mut field_nodes = field_nodes - .iter() - .zip(dictionaries) - .collect::>(); + let mut field_nodes = field_nodes.iter().collect::>(); let (schema, columns) = if let Some(projection) = projection { let projected_schema = projection.1.clone(); @@ -128,9 +122,10 @@ pub fn read_record_batch( .map(|maybe_field| match maybe_field { ProjectionResult::Selected(field) => Some(read( &mut field_nodes, - field.data_type().clone(), + field, &mut buffers, reader, + dictionaries, block_offset, is_little_endian, batch.compression(), @@ -151,9 +146,10 @@ pub fn read_record_batch( .map(|field| { read( &mut field_nodes, - field.data_type().clone(), + field, &mut buffers, reader, + dictionaries, block_offset, is_little_endian, batch.compression(), @@ -166,13 +162,62 @@ pub fn read_record_batch( RecordBatch::try_new(schema, columns) } +fn find_first_dict_field_d(id: usize, data_type: &DataType) -> Option<&Field> { + use DataType::*; + match 
data_type { + Dictionary(_, inner) => find_first_dict_field_d(id, inner.as_ref()), + Map(field, _) => find_first_dict_field(id, field.as_ref()), + List(field) => find_first_dict_field(id, field.as_ref()), + LargeList(field) => find_first_dict_field(id, field.as_ref()), + FixedSizeList(field, _) => find_first_dict_field(id, field.as_ref()), + Union(fields, _, _) => { + for field in fields { + if let Some(f) = find_first_dict_field(id, field) { + return Some(f); + } + } + None + } + Struct(fields) => { + for field in fields { + if let Some(f) = find_first_dict_field(id, field) { + return Some(f); + } + } + None + } + _ => None, + } +} + +fn find_first_dict_field(id: usize, field: &Field) -> Option<&Field> { + if let DataType::Dictionary(_, _) = &field.data_type { + if field.dict_id as usize == id { + return Some(field); + } + } + find_first_dict_field_d(id, &field.data_type) +} + +fn first_dict_field(id: usize, fields: &[Field]) -> Result<&Field> { + for field in fields { + if let Some(field) = find_first_dict_field(id, field) { + return Ok(field); + } + } + Err(ArrowError::Schema(format!( + "dictionary id {} not found in schema", + id + ))) +} + /// Read the dictionary from the buffer and provided metadata, -/// updating the `dictionaries_by_field` with the resulting dictionary +/// updating the `dictionaries` with the resulting dictionary pub fn read_dictionary( batch: ipc::Message::DictionaryBatch, schema: &Schema, is_little_endian: bool, - dictionaries_by_field: &mut [Option], + dictionaries: &mut HashMap>, reader: &mut R, block_offset: u64, ) -> Result<()> { @@ -183,10 +228,7 @@ pub fn read_dictionary( } let id = batch.id(); - let fields_using_this_dictionary = schema.fields_with_dict_id(id); - let first_field = fields_using_this_dictionary.first().ok_or_else(|| { - ArrowError::InvalidArgumentError("dictionary id not found in schema".to_string()) - })?; + let first_field = first_dict_field(id as usize, &schema.fields)?; // As the dictionary batch does not 
contain the type of the // values array, we need to retrieve this from the schema. @@ -204,7 +246,7 @@ pub fn read_dictionary( schema, None, is_little_endian, - dictionaries_by_field, + dictionaries, MetadataVersion::V5, reader, block_offset, @@ -217,16 +259,7 @@ pub fn read_dictionary( ArrowError::InvalidArgumentError("dictionary id not found in schema".to_string()) })?; - // for all fields with this dictionary id, update the dictionaries vector - // in the reader. Note that a dictionary batch may be shared between many fields. - // We don't currently record the isOrdered field. This could be general - // attributes of arrays. - for (i, field) in schema.fields().iter().enumerate() { - if field.dict_id() == Some(id) { - // Add (possibly multiple) array refs to the dictionaries array. - dictionaries_by_field[i] = Some(dictionary_values.clone()); - } - } + dictionaries.insert(id as usize, dictionary_values); Ok(()) } diff --git a/src/io/ipc/read/deserialize.rs b/src/io/ipc/read/deserialize.rs index 1d243e5256d..f533a63af40 100644 --- a/src/io/ipc/read/deserialize.rs +++ b/src/io/ipc/read/deserialize.rs @@ -1,9 +1,4 @@ -//! Arrow IPC File and Stream Readers -//! -//! The `FileReader` and `StreamReader` have similar interfaces, -//! 
however the `FileReader` expects a reader that supports `Seek`ing - -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use std::{ io::{Read, Seek}, sync::Arc, @@ -13,25 +8,28 @@ use arrow_format::ipc; use arrow_format::ipc::{Message::BodyCompression, Schema::MetadataVersion}; use crate::array::*; -use crate::datatypes::{DataType, PhysicalType}; +use crate::datatypes::{DataType, Field, PhysicalType}; use crate::error::Result; use super::array::*; -pub type Node<'a> = (&'a ipc::Message::FieldNode, &'a Option>); +pub type Node<'a> = &'a ipc::Message::FieldNode; #[allow(clippy::too_many_arguments)] pub fn read( field_nodes: &mut VecDeque, - data_type: DataType, + field: &Field, buffers: &mut VecDeque<&ipc::Schema::Buffer>, reader: &mut R, + dictionaries: &HashMap>, block_offset: u64, is_little_endian: bool, compression: Option, version: MetadataVersion, ) -> Result> { use PhysicalType::*; + let data_type = field.data_type().clone(); + match data_type.to_physical_type() { Null => { let array = read_null(field_nodes, data_type); @@ -124,6 +122,7 @@ pub fn read( data_type, buffers, reader, + dictionaries, block_offset, is_little_endian, compression, @@ -135,6 +134,7 @@ pub fn read( data_type, buffers, reader, + dictionaries, block_offset, is_little_endian, compression, @@ -146,6 +146,7 @@ pub fn read( data_type, buffers, reader, + dictionaries, block_offset, is_little_endian, compression, @@ -157,6 +158,7 @@ pub fn read( data_type, buffers, reader, + dictionaries, block_offset, is_little_endian, compression, @@ -167,8 +169,10 @@ pub fn read( match_integer_type!(key_type, |$T| { read_dictionary::<$T, _>( field_nodes, + field, buffers, reader, + dictionaries, block_offset, compression, is_little_endian, @@ -181,6 +185,7 @@ pub fn read( data_type, buffers, reader, + dictionaries, block_offset, is_little_endian, compression, @@ -192,6 +197,7 @@ pub fn read( data_type, buffers, reader, + dictionaries, block_offset, is_little_endian, compression, diff 
--git a/src/io/ipc/read/reader.rs b/src/io/ipc/read/reader.rs index 6f2707a6412..0f025e4451c 100644 --- a/src/io/ipc/read/reader.rs +++ b/src/io/ipc/read/reader.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashMap; use std::io::{Read, Seek, SeekFrom}; use std::sync::Arc; @@ -30,8 +31,6 @@ use super::super::convert; use super::super::{ARROW_MAGIC, CONTINUATION_MARKER}; use super::common::*; -type ArrayRef = Arc; - #[derive(Debug, Clone)] pub struct FileMetadata { /// The schema that is read from the file header @@ -45,10 +44,8 @@ pub struct FileMetadata { /// The total number of blocks, which may contain record batches and other types total_blocks: usize, - /// Optional dictionaries for each schema field. - /// - /// Dictionaries may be appended to in the streaming format. - dictionaries_by_field: Vec>, + /// Dictionaries associated to each dict_id + dictionaries: HashMap>, /// FileMetadata version version: ipc::Schema::MetadataVersion, @@ -122,8 +119,8 @@ pub fn read_file_metadata(reader: &mut R) -> Result(reader: &mut R) -> Result(reader: &mut R) -> Result( metadata.schema.clone(), projection, metadata.is_little_endian, - &metadata.dictionaries_by_field, + &metadata.dictionaries, metadata.version, reader, block.offset() as u64 + block.metaDataLength() as u64, diff --git a/src/io/ipc/read/stream.rs b/src/io/ipc/read/stream.rs index bd412e7b213..d915d9fa79a 100644 --- a/src/io/ipc/read/stream.rs +++ b/src/io/ipc/read/stream.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+use std::collections::HashMap; use std::io::Read; use std::sync::Arc; @@ -30,8 +31,6 @@ use super::super::convert; use super::super::CONTINUATION_MARKER; use super::common::*; -type ArrayRef = Arc; - #[derive(Debug)] pub struct StreamMetadata { /// The schema that is read from the stream's first message @@ -113,7 +112,7 @@ impl StreamState { pub fn read_next( reader: &mut R, metadata: &StreamMetadata, - dictionaries_by_field: &mut Vec>, + dictionaries: &mut HashMap>, ) -> Result> { // determine metadata length let mut meta_size: [u8; 4] = [0; 4]; @@ -172,7 +171,7 @@ pub fn read_next( metadata.schema.clone(), None, metadata.is_little_endian, - dictionaries_by_field, + dictionaries, metadata.version, &mut reader, 0, @@ -193,13 +192,13 @@ pub fn read_next( batch, &metadata.schema, metadata.is_little_endian, - dictionaries_by_field, + dictionaries, &mut dict_reader, 0, )?; // read the next message until we encounter a RecordBatch - read_next(reader, metadata, dictionaries_by_field) + read_next(reader, metadata, dictionaries) } ipc::Message::MessageHeader::NONE => Ok(Some(StreamState::Waiting)), t => Err(ArrowError::Ipc(format!( @@ -218,7 +217,7 @@ pub fn read_next( pub struct StreamReader { reader: R, metadata: StreamMetadata, - dictionaries_by_field: Vec>, + dictionaries: HashMap>, finished: bool, } @@ -229,11 +228,10 @@ impl StreamReader { /// encounter a schema. 
/// To check if the reader is done, use `is_finished(self)` pub fn new(reader: R, metadata: StreamMetadata) -> Self { - let fields = metadata.schema.fields().len(); Self { reader, metadata, - dictionaries_by_field: vec![None; fields], + dictionaries: Default::default(), finished: false, } } @@ -252,11 +250,7 @@ impl StreamReader { if self.finished { return Ok(None); } - let batch = read_next( - &mut self.reader, - &self.metadata, - &mut self.dictionaries_by_field, - )?; + let batch = read_next(&mut self.reader, &self.metadata, &mut self.dictionaries)?; if batch.is_none() { self.finished = true; } diff --git a/src/io/ipc/write/common.rs b/src/io/ipc/write/common.rs index b6d2f4bd8ec..f5c11ff54af 100644 --- a/src/io/ipc/write/common.rs +++ b/src/io/ipc/write/common.rs @@ -1,32 +1,14 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. -//! Common utilities used to write to Arrow's IPC format. 
- use std::{collections::HashMap, sync::Arc}; use arrow_format::ipc; use arrow_format::ipc::flatbuffers::FlatBufferBuilder; use arrow_format::ipc::Message::CompressionType; -use crate::array::Array; +use crate::array::*; +use crate::datatypes::*; use crate::error::{ArrowError, Result}; use crate::io::ipc::endianess::is_native_little_endian; use crate::record_batch::RecordBatch; -use crate::{array::DictionaryArray, datatypes::*}; use super::{write, write_dictionary}; @@ -47,34 +29,183 @@ pub struct WriteOptions { pub compression: Option, } -pub fn encoded_batch( - batch: &RecordBatch, - dictionary_tracker: &mut DictionaryTracker, +fn encode_dictionary( + field: &Field, + array: &Arc, options: &WriteOptions, -) -> Result<(Vec, EncodedData)> { - // TODO: handle nested dictionaries - let schema = batch.schema(); - let mut encoded_dictionaries = Vec::with_capacity(schema.fields().len()); - - for (i, field) in schema.fields().iter().enumerate() { - let column = batch.column(i); - - if let DataType::Dictionary(_key_type, _value_type) = column.data_type() { + dictionary_tracker: &mut DictionaryTracker, + encoded_dictionaries: &mut Vec, +) -> Result<()> { + use PhysicalType::*; + match array.data_type().to_physical_type() { + Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null + | FixedSizeBinary => Ok(()), + Dictionary(key_type) => match_integer_type!(key_type, |$T| { let dict_id = field .dict_id() .expect("All Dictionary types have `dict_id`"); - let emit = dictionary_tracker.insert(dict_id, column)?; + let values = array.as_any().downcast_ref::>().unwrap().values(); + // todo: this won't work for Dict<Dict<...>>; + let field = Field::new("item", values.data_type().clone(), true); + encode_dictionary(&field, + values, + options, + dictionary_tracker, + encoded_dictionaries + )?; + + let emit = dictionary_tracker.insert(dict_id, array)?; if emit { encoded_dictionaries.push(dictionary_batch_to_bytes( dict_id, - column.as_ref(), + array.as_ref(), options,
is_native_little_endian(), )); - } + }; + Ok(()) + }), + Struct => { + let values = array + .as_any() + .downcast_ref::() + .unwrap() + .values(); + let fields = if let DataType::Struct(fields) = array.data_type() { + fields + } else { + unreachable!() + }; + fields + .iter() + .zip(values.iter()) + .try_for_each(|(field, values)| { + encode_dictionary( + field, + values, + options, + dictionary_tracker, + encoded_dictionaries, + ) + }) + } + List => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap() + .values(); + let field = if let DataType::List(field) = field.data_type() { + field.as_ref() + } else { + unreachable!() + }; + encode_dictionary( + field, + values, + options, + dictionary_tracker, + encoded_dictionaries, + ) + } + LargeList => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap() + .values(); + let field = if let DataType::LargeList(field) = field.data_type() { + field.as_ref() + } else { + unreachable!() + }; + encode_dictionary( + field, + values, + options, + dictionary_tracker, + encoded_dictionaries, + ) } + FixedSizeList => { + let values = array + .as_any() + .downcast_ref::() + .unwrap() + .values(); + let field = if let DataType::FixedSizeList(field, _) = field.data_type() { + field.as_ref() + } else { + unreachable!() + }; + encode_dictionary( + field, + values, + options, + dictionary_tracker, + encoded_dictionaries, + ) + } + Union => { + let values = array + .as_any() + .downcast_ref::() + .unwrap() + .fields(); + let fields = if let DataType::Union(fields, _, _) = field.data_type() { + fields + } else { + unreachable!() + }; + fields + .iter() + .zip(values.iter()) + .try_for_each(|(field, values)| { + encode_dictionary( + field, + values, + options, + dictionary_tracker, + encoded_dictionaries, + ) + }) + } + Map => { + let values = array.as_any().downcast_ref::().unwrap().field(); + let field = if let DataType::Map(field, _) = field.data_type() { + field.as_ref() + } else { + unreachable!() + }; + 
encode_dictionary( + field, + values, + options, + dictionary_tracker, + encoded_dictionaries, + ) + } + } +} + +pub fn encoded_batch( + batch: &RecordBatch, + dictionary_tracker: &mut DictionaryTracker, + options: &WriteOptions, +) -> Result<(Vec, EncodedData)> { + let schema = batch.schema(); + let mut encoded_dictionaries = Vec::with_capacity(schema.fields().len()); + + for (field, column) in schema.fields().iter().zip(batch.columns()) { + encode_dictionary( + field, + column, + options, + dictionary_tracker, + &mut encoded_dictionaries, + )?; } let encoded_message = record_batch_to_bytes(batch, options); diff --git a/src/io/ipc/write/stream.rs b/src/io/ipc/write/stream.rs index 83338942cc2..c6d4f9bd950 100644 --- a/src/io/ipc/write/stream.rs +++ b/src/io/ipc/write/stream.rs @@ -73,8 +73,7 @@ impl StreamWriter { } let (encoded_dictionaries, encoded_message) = - encoded_batch(batch, &mut self.dictionary_tracker, &self.write_options) - .expect("StreamWriter is configured to not error on dictionary replacement"); + encoded_batch(batch, &mut self.dictionary_tracker, &self.write_options)?; for encoded_dictionary in encoded_dictionaries { write_message(&mut self.writer, encoded_dictionary)?; diff --git a/src/io/json_integration/mod.rs b/src/io/json_integration/mod.rs index 70ffc597f51..175c5068c78 100644 --- a/src/io/json_integration/mod.rs +++ b/src/io/json_integration/mod.rs @@ -83,12 +83,35 @@ impl From<&Field> for ArrowJsonField { _ => None, }; + let dictionary = if let DataType::Dictionary(key_type, _) = &field.data_type { + use crate::datatypes::IntegerType::*; + Some(ArrowJsonFieldDictionary { + id: field.dict_id, + index_type: IntegerType { + name: "".to_string(), + bit_width: match key_type { + Int8 | UInt8 => 8, + Int16 | UInt16 => 16, + Int32 | UInt32 => 32, + Int64 | UInt64 => 64, + }, + is_signed: match key_type { + Int8 | Int16 | Int32 | Int64 => true, + UInt8 | UInt16 | UInt32 | UInt64 => false, + }, + }, + is_ordered: field.dict_is_ordered, + }) + } 
else { + None + }; + Self { name: field.name().to_string(), field_type: field.data_type().to_json(), nullable: field.is_nullable(), children: vec![], - dictionary: None, // TODO: not enough info + dictionary, metadata: metadata_value, } } diff --git a/src/io/json_integration/read.rs b/src/io/json_integration/read.rs index 856fdb9808c..f6aabfff9dc 100644 --- a/src/io/json_integration/read.rs +++ b/src/io/json_integration/read.rs @@ -357,7 +357,14 @@ pub fn to_array( let values = fields .iter() .zip(json_col.children.as_ref().unwrap()) - .map(|(field, col)| to_array(field.data_type().clone(), None, col, dictionaries)) + .map(|(field, col)| { + to_array( + field.data_type().clone(), + field.dict_id(), + col, + dictionaries, + ) + }) .collect::>>()?; let array = StructArray::from_data(data_type, values, validity); diff --git a/src/scalar/dictionary.rs b/src/scalar/dictionary.rs new file mode 100644 index 00000000000..450930e401c --- /dev/null +++ b/src/scalar/dictionary.rs @@ -0,0 +1,55 @@ +use std::any::Any; +use std::sync::Arc; + +use crate::{array::*, datatypes::DataType}; + +use super::Scalar; + +/// The [`DictionaryArray`] equivalent of [`Array`] for [`Scalar`]. 
+#[derive(Debug, Clone)] +pub struct DictionaryScalar { + value: Option>, + phantom: std::marker::PhantomData, + data_type: DataType, +} + +impl PartialEq for DictionaryScalar { + fn eq(&self, other: &Self) -> bool { + (self.data_type == other.data_type) && (self.value.as_ref() == other.value.as_ref()) + } +} + +impl DictionaryScalar { + /// returns a new [`DictionaryScalar`] + /// # Implementation + /// This constructor performs no validation and does not panic: + /// * `data_type` is stored as given; it is not checked to be a `Dictionary` type + /// * `value` is stored as given; `None` makes the scalar invalid (null) + #[inline] + pub fn new(data_type: DataType, value: Option>) -> Self { + Self { + value, + phantom: std::marker::PhantomData, + data_type, + } + } + + /// The values of the [`DictionaryScalar`] + pub fn value(&self) -> Option<&Arc> { + self.value.as_ref() + } +} + +impl Scalar for DictionaryScalar { + fn as_any(&self) -> &dyn Any { + self + } + + fn is_valid(&self) -> bool { + self.value.is_some() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/src/scalar/equal.rs b/src/scalar/equal.rs index 3ce92132305..63af3e0fcde 100644 --- a/src/scalar/equal.rs +++ b/src/scalar/equal.rs @@ -111,6 +111,16 @@ fn equal(lhs: &dyn Scalar, rhs: &dyn Scalar) -> bool { let rhs = rhs.as_any().downcast_ref::>().unwrap(); lhs == rhs } - _ => unimplemented!(), + DataType::Dictionary(key_type, _) => match_integer_type!(key_type, |$T| { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + lhs == rhs + }), + DataType::Struct(_) => { + let lhs = lhs.as_any().downcast_ref::().unwrap(); + let rhs = rhs.as_any().downcast_ref::().unwrap(); + lhs == rhs + } + other => unimplemented!("{}", other), } } diff --git a/src/scalar/mod.rs b/src/scalar/mod.rs index e0909438979..d5694e165dd 100644 --- a/src/scalar/mod.rs +++ b/src/scalar/mod.rs @@ -5,6 +5,8 @@ use std::any::Any; use crate::{array::*, datatypes::*}; +mod dictionary; +pub use dictionary::*; mod equal;
mod primitive; pub use primitive::*; @@ -121,6 +123,20 @@ pub fn new_scalar(array: &dyn Array, index: usize) -> Box { FixedSizeBinary => todo!(), FixedSizeList => todo!(), Union | Map => todo!(), - Dictionary(_) => todo!(), + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index).into()) + } else { + None + }; + Box::new(DictionaryScalar::<$T>::new( + array.data_type().clone(), + value, + )) + }), } } diff --git a/tests/it/io/ipc/read/file.rs b/tests/it/io/ipc/read/file.rs index c869b385857..602c4151a9a 100644 --- a/tests/it/io/ipc/read/file.rs +++ b/tests/it/io/ipc/read/file.rs @@ -134,6 +134,12 @@ fn read_generated_100_non_canonical_map() -> Result<()> { test_file("1.0.0-bigendian", "generated_map_non_canonical") } +#[test] +fn read_generated_100_nested_dictionary() -> Result<()> { + test_file("1.0.0-littleendian", "generated_nested_dictionary")?; + test_file("1.0.0-bigendian", "generated_nested_dictionary") +} + #[test] fn read_generated_017_union() -> Result<()> { test_file("0.17.1", "generated_union") diff --git a/tests/it/io/ipc/write/file.rs b/tests/it/io/ipc/write/file.rs index 4b3929c2c96..e00936203c4 100644 --- a/tests/it/io/ipc/write/file.rs +++ b/tests/it/io/ipc/write/file.rs @@ -250,6 +250,12 @@ fn write_100_map_non_canonical() -> Result<()> { test_file("1.0.0-bigendian", "generated_map_non_canonical", false) } +#[test] +fn write_100_nested_dictionary() -> Result<()> { + test_file("1.0.0-littleendian", "generated_nested_dictionary", false)?; + test_file("1.0.0-bigendian", "generated_nested_dictionary", false) +} + #[test] fn write_generated_017_union() -> Result<()> { test_file("0.17.1", "generated_union", false)