From b728cc6b8651dfb272ee829b9e93f50d50ab4c45 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sun, 7 Nov 2021 20:44:55 +0000 Subject: [PATCH] Added support to read dictionary from nested types. --- src/array/dictionary/mod.rs | 12 ++++---- src/array/display.rs | 8 ++++- src/array/equal/dictionary.rs | 11 +------ src/datatypes/field.rs | 22 +++++++++++++- src/datatypes/schema.rs | 8 ----- src/io/ipc/read/common.rs | 54 +++++++++++++++++++++++++++++++--- src/io/ipc/write/common.rs | 17 +---------- src/scalar/dictionary.rs | 55 +++++++++++++++++++++++++++++++++++ src/scalar/equal.rs | 12 +++++++- src/scalar/mod.rs | 18 +++++++++++- tests/it/io/ipc/read/file.rs | 9 +++++- 11 files changed, 178 insertions(+), 48 deletions(-) create mode 100644 src/scalar/dictionary.rs diff --git a/src/array/dictionary/mod.rs b/src/array/dictionary/mod.rs index 8f4de0f4bd4..de45eff8bee 100644 --- a/src/array/dictionary/mod.rs +++ b/src/array/dictionary/mod.rs @@ -13,7 +13,8 @@ mod mutable; pub use iterator::*; pub use mutable::*; -use super::{new_empty_array, primitive::PrimitiveArray, Array}; +use super::display::get_value_display; +use super::{display_fmt, new_empty_array, primitive::PrimitiveArray, Array}; use crate::scalar::NullScalar; /// Trait denoting [`NativeType`]s that can be used as keys of a dictionary. @@ -196,9 +197,10 @@ where PrimitiveArray: std::fmt::Display, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - writeln!(f, "{:?}{{", self.data_type())?; - writeln!(f, "keys: {},", self.keys())?; - writeln!(f, "values: {},", self.values())?; - write!(f, "}}") + let display = get_value_display(self); + let new_lines = false; + let head = &format!("{}", self.data_type()); + let iter = self.iter().enumerate().map(|(i, x)| x.map(|_| display(i))); + display_fmt(iter, head, f, new_lines) } } diff --git a/src/array/display.rs b/src/array/display.rs index a69b3342038..12aac44b9ea 100644 --- a/src/array/display.rs +++ b/src/array/display.rs @@ -165,7 +165,13 @@ pub fn get_value_display<'a>(array: &'a dyn Array) -> Box Strin .unwrap(); let keys = a.keys(); let display = get_display(a.values().as_ref()); - Box::new(move |row: usize| display(keys.value(row) as usize)) + Box::new(move |row: usize| { + if keys.is_null(row) { + "".to_string() + }else { + display(keys.value(row) as usize) + } + }) }), Map(_, _) => todo!(), Struct(_) => { diff --git a/src/array/equal/dictionary.rs b/src/array/equal/dictionary.rs index 8c879ff8370..d1e91fcc9e9 100644 --- a/src/array/equal/dictionary.rs +++ b/src/array/equal/dictionary.rs @@ -1,14 +1,5 @@ use crate::array::{Array, DictionaryArray, DictionaryKey}; pub(super) fn equal(lhs: &DictionaryArray, rhs: &DictionaryArray) -> bool { - if !(lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len()) { - return false; - }; - - // if x is not valid and y is but its child is not, the slots are equal. - lhs.iter().zip(rhs.iter()).all(|(x, y)| match (&x, &y) { - (None, Some(y)) => !y.is_valid(), - (Some(x), None) => !x.is_valid(), - _ => x == y, - }) + lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) } diff --git a/src/datatypes/field.rs b/src/datatypes/field.rs index 8fc544f1ffd..f6ba2f84e76 100644 --- a/src/datatypes/field.rs +++ b/src/datatypes/field.rs @@ -23,7 +23,7 @@ use super::DataType; /// A logical [`DataType`] and its associated metadata per /// [Arrow specification](https://arrow.apache.org/docs/cpp/api/datatype.html) -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Eq)] pub struct Field { /// Its name pub name: String, @@ -39,6 +39,26 @@ pub struct Field { pub metadata: Option>, } +impl std::hash::Hash for Field { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.data_type.hash(state); + self.nullable.hash(state); + self.dict_is_ordered.hash(state); + self.metadata.hash(state); + } +} + +impl PartialEq for Field { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.data_type == other.data_type + && self.nullable == other.nullable + && self.dict_is_ordered == other.dict_is_ordered + && self.metadata == other.metadata + } +} + impl Field { /// Creates a new field pub fn new(name: &str, data_type: DataType, nullable: bool) -> Self { diff --git a/src/datatypes/schema.rs b/src/datatypes/schema.rs index afb656899b7..4ab3488a63d 100644 --- a/src/datatypes/schema.rs +++ b/src/datatypes/schema.rs @@ -164,14 +164,6 @@ impl Schema { Ok(&self.fields[self.index_of(name)?]) } - /// Returns all [`Field`]s with dictionary id `dict_id`. - pub fn fields_with_dict_id(&self, dict_id: i64) -> Vec<&Field> { - self.fields - .iter() - .filter(|f| f.dict_id() == Some(dict_id)) - .collect() - } - /// Find the index of the column with the given name. pub fn index_of(&self, name: &str) -> Result { for i in 0..self.fields.len() { diff --git a/src/io/ipc/read/common.rs b/src/io/ipc/read/common.rs index b8846636168..055b31b3705 100644 --- a/src/io/ipc/read/common.rs +++ b/src/io/ipc/read/common.rs @@ -162,6 +162,55 @@ pub fn read_record_batch( RecordBatch::try_new(schema, columns) } +fn find_first_dict_field_d(id: usize, data_type: &DataType) -> Option<&Field> { + use DataType::*; + match data_type { + Dictionary(_, inner) => find_first_dict_field_d(id, inner.as_ref()), + Map(field, _) => find_first_dict_field(id, field.as_ref()), + List(field) => find_first_dict_field(id, field.as_ref()), + LargeList(field) => find_first_dict_field(id, field.as_ref()), + FixedSizeList(field, _) => find_first_dict_field(id, field.as_ref()), + Union(fields, _, _) => { + for field in fields { + if let Some(f) = find_first_dict_field(id, field) { + return Some(f); + } + } + None + } + Struct(fields) => { + for field in fields { + if let Some(f) = find_first_dict_field(id, field) { + return Some(f); + } + } + None + } + _ => None, + } +} + +fn find_first_dict_field(id: usize, field: &Field) -> Option<&Field> { + if let DataType::Dictionary(_, _) = &field.data_type { + if field.dict_id as usize == id { + return Some(field); + } + } + find_first_dict_field_d(id, &field.data_type) +} + +fn first_dict_field(id: usize, fields: &[Field]) -> Result<&Field> { + for field in fields { + if let Some(field) = find_first_dict_field(id, field) { + return Ok(field); + } + } + Err(ArrowError::Schema(format!( + "dictionary id {} not found in schema", + id + ))) +} + /// Read the dictionary from the buffer and provided metadata, /// updating the `dictionaries` with the resulting dictionary pub fn read_dictionary( @@ -179,10 +228,7 @@ pub fn read_dictionary( } let id = batch.id(); - let fields_using_this_dictionary = schema.fields_with_dict_id(id); - let first_field = fields_using_this_dictionary.first().ok_or_else(|| { - ArrowError::InvalidArgumentError("dictionary id not found in schema".to_string()) - })?; + let first_field = first_dict_field(id as usize, &schema.fields)?; // As the dictionary batch does not contain the type of the // values array, we need to retrieve this from the schema. diff --git a/src/io/ipc/write/common.rs b/src/io/ipc/write/common.rs index 150935f1f01..1b79d93b85a 100644 --- a/src/io/ipc/write/common.rs +++ b/src/io/ipc/write/common.rs @@ -29,18 +29,6 @@ pub struct WriteOptions { pub compression: Option, } -fn get_dict_values(array: &dyn Array) -> &Arc { - match array.data_type().to_physical_type() { - PhysicalType::Dictionary(key_type) => { - with_match_physical_dictionary_key_type!(key_type, |$T| { - let array = array.as_any().downcast_ref::>().unwrap(); - array.values() - }) - } - _ => unreachable!(), - } -} - fn encode_dictionary( field: &Field, array: &Arc, @@ -67,10 +55,7 @@ fn encode_dictionary( is_native_little_endian(), )); }; - - // todo: support nested dictionaries. Requires DataType::Dictionary to contain Field in values - let _ = get_dict_values(array.as_ref()); - todo!() + Ok(()) } Struct => { let values = array diff --git a/src/scalar/dictionary.rs b/src/scalar/dictionary.rs new file mode 100644 index 00000000000..450930e401c --- /dev/null +++ b/src/scalar/dictionary.rs @@ -0,0 +1,55 @@ +use std::any::Any; +use std::sync::Arc; + +use crate::{array::*, datatypes::DataType}; + +use super::Scalar; + +/// The [`DictionaryArray`] equivalent of [`Array`] for [`Scalar`]. +#[derive(Debug, Clone)] +pub struct DictionaryScalar { + value: Option>, + phantom: std::marker::PhantomData, + data_type: DataType, +} + +impl PartialEq for DictionaryScalar { + fn eq(&self, other: &Self) -> bool { + (self.data_type == other.data_type) && (self.value.as_ref() == other.value.as_ref()) + } +} + +impl DictionaryScalar { + /// returns a new [`DictionaryScalar`] + /// # Panics + /// iff + /// * the `data_type` is not `List` or `LargeList` (depending on this scalar's offset `O`) + /// * the child of the `data_type` is not equal to the `values` + #[inline] + pub fn new(data_type: DataType, value: Option>) -> Self { + Self { + value, + phantom: std::marker::PhantomData, + data_type, + } + } + + /// The values of the [`DictionaryScalar`] + pub fn value(&self) -> Option<&Arc> { + self.value.as_ref() + } +} + +impl Scalar for DictionaryScalar { + fn as_any(&self) -> &dyn Any { + self + } + + fn is_valid(&self) -> bool { + self.value.is_some() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/src/scalar/equal.rs b/src/scalar/equal.rs index 3ce92132305..63af3e0fcde 100644 --- a/src/scalar/equal.rs +++ b/src/scalar/equal.rs @@ -111,6 +111,16 @@ fn equal(lhs: &dyn Scalar, rhs: &dyn Scalar) -> bool { let rhs = rhs.as_any().downcast_ref::>().unwrap(); lhs == rhs } - _ => unimplemented!(), + DataType::Dictionary(key_type, _) => match_integer_type!(key_type, |$T| { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + lhs == rhs + }), + DataType::Struct(_) => { + let lhs = lhs.as_any().downcast_ref::().unwrap(); + let rhs = rhs.as_any().downcast_ref::().unwrap(); + lhs == rhs + } + other => unimplemented!("{}", other), } } diff --git a/src/scalar/mod.rs b/src/scalar/mod.rs index e0909438979..d5694e165dd 100644 --- a/src/scalar/mod.rs +++ b/src/scalar/mod.rs @@ -5,6 +5,8 @@ use std::any::Any; use crate::{array::*, datatypes::*}; +mod dictionary; +pub use dictionary::*; mod equal; mod primitive; pub use primitive::*; @@ -121,6 +123,20 @@ pub fn new_scalar(array: &dyn Array, index: usize) -> Box { FixedSizeBinary => todo!(), FixedSizeList => todo!(), Union | Map => todo!(), - Dictionary(_) => todo!(), + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index).into()) + } else { + None + }; + Box::new(DictionaryScalar::<$T>::new( + array.data_type().clone(), + value, + )) + }), } } diff --git a/tests/it/io/ipc/read/file.rs b/tests/it/io/ipc/read/file.rs index 602c4151a9a..aaec1038f1f 100644 --- a/tests/it/io/ipc/read/file.rs +++ b/tests/it/io/ipc/read/file.rs @@ -13,7 +13,9 @@ fn test_file(version: &str, file_name: &str) -> Result<()> { ))?; // read expected JSON output + println!("reading json"); let (schema, batches) = read_gzip_json(version, file_name)?; + println!("reading metadata"); let metadata = read_file_metadata(&mut file)?; let reader = FileReader::new(file, metadata, None); @@ -21,7 +23,12 @@ fn test_file(version: &str, file_name: &str) -> Result<()> { assert_eq!(&schema, reader.schema().as_ref()); batches.iter().zip(reader).try_for_each(|(lhs, rhs)| { - assert_eq!(lhs, &rhs?); + for (c1, c2) in lhs.columns().iter().zip(rhs?.columns().iter()) { + println!("{}", c1); + println!("{}", c2); + assert_eq!(c1, c2); + } + //assert_eq!(lhs, &rhs?); Result::Ok(()) })?; Ok(())