Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Added support to write dictionaries in nested types.
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Nov 8, 2021
1 parent 693a7c1 commit e572d3c
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 34 deletions.
11 changes: 10 additions & 1 deletion src/array/equal/dictionary.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
use crate::array::{Array, DictionaryArray, DictionaryKey};

pub(super) fn equal<K: DictionaryKey>(lhs: &DictionaryArray<K>, rhs: &DictionaryArray<K>) -> bool {
lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter())
if !(lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len()) {
return false;
};

// if x is not valid and y is but its child is not, the slots are equal.
lhs.iter().zip(rhs.iter()).all(|(x, y)| match (&x, &y) {
(None, Some(y)) => !y.is_valid(),
(Some(x), None) => !x.is_valid(),
_ => x == y,
})
}
20 changes: 2 additions & 18 deletions src/io/ipc/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -663,24 +663,8 @@ pub(crate) fn get_fb_field_type<'a>(
}
}
Struct(fields) => {
// struct's fields are children
let mut children = vec![];
for field in fields {
let inner_types = get_fb_field_type(field.data_type(), field.is_nullable(), fbb);
let field_name = fbb.create_string(field.name());
children.push(ipc::Field::create(
fbb,
&ipc::FieldArgs {
name: Some(field_name),
nullable: field.is_nullable(),
type_type: inner_types.type_type,
type_: Some(inner_types.type_),
dictionary: None,
children: inner_types.children,
custom_metadata: None,
},
));
}
let children: Vec<_> = fields.iter().map(|field| build_field(fbb, field)).collect();

FbFieldType {
type_type,
type_: ipc::Struct_Builder::new(fbb).finish().as_union_value(),
Expand Down
15 changes: 11 additions & 4 deletions src/io/ipc/read/array/dictionary.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::{HashMap, VecDeque};
use std::collections::{HashMap, HashSet, VecDeque};
use std::convert::TryInto;
use std::io::{Read, Seek};
use std::sync::Arc;
Expand All @@ -7,7 +7,7 @@ use arrow_format::ipc;

use crate::array::{Array, DictionaryArray, DictionaryKey};
use crate::datatypes::Field;
use crate::error::Result;
use crate::error::{ArrowError, Result};

use super::super::deserialize::Node;
use super::{read_primitive, skip_primitive};
Expand All @@ -26,9 +26,16 @@ pub fn read_dictionary<T: DictionaryKey, R: Read + Seek>(
where
Vec<u8>: TryInto<T::Bytes>,
{
let id = field.dict_id().unwrap() as usize;
let values = dictionaries
.get(&(field.dict_id().unwrap() as usize))
.unwrap()
.get(&id)
.ok_or_else(|| {
let valid_ids = dictionaries.keys().collect::<HashSet<_>>();
ArrowError::Ipc(format!(
"Dictionary id {} not found. Valid ids: {:?}",
id, valid_ids
))
})?
.clone();

let keys = read_primitive(
Expand Down
16 changes: 13 additions & 3 deletions src/io/ipc/write/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,21 @@ fn encode_dictionary(
match array.data_type().to_physical_type() {
Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null
| FixedSizeBinary => Ok(()),
Dictionary(_) => {
Dictionary(key_type) => match_integer_type!(key_type, |$T| {
let dict_id = field
.dict_id()
.expect("All Dictionary types have `dict_id`");

let values = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap().values();
// todo: this is won't work for Dict<Dict<...>>;
let field = Field::new("item", values.data_type().clone(), true);
encode_dictionary(&field,
values,
options,
dictionary_tracker,
encoded_dictionaries
)?;

let emit = dictionary_tracker.insert(dict_id, array)?;

if emit {
Expand All @@ -56,14 +66,14 @@ fn encode_dictionary(
));
};
Ok(())
}
}),
Struct => {
let values = array
.as_any()
.downcast_ref::<StructArray>()
.unwrap()
.values();
let fields = if let DataType::Struct(fields) = field.data_type() {
let fields = if let DataType::Struct(fields) = array.data_type() {
fields
} else {
unreachable!()
Expand Down
9 changes: 1 addition & 8 deletions tests/it/io/ipc/read/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,15 @@ fn test_file(version: &str, file_name: &str) -> Result<()> {
))?;

// read expected JSON output
println!("reading json");
let (schema, batches) = read_gzip_json(version, file_name)?;
println!("reading metadata");

let metadata = read_file_metadata(&mut file)?;
let reader = FileReader::new(file, metadata, None);

assert_eq!(&schema, reader.schema().as_ref());

batches.iter().zip(reader).try_for_each(|(lhs, rhs)| {
for (c1, c2) in lhs.columns().iter().zip(rhs?.columns().iter()) {
println!("{}", c1);
println!("{}", c2);
assert_eq!(c1, c2);
}
//assert_eq!(lhs, &rhs?);
assert_eq!(lhs, &rhs?);
Result::Ok(())
})?;
Ok(())
Expand Down
6 changes: 6 additions & 0 deletions tests/it/io/ipc/write/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,12 @@ fn write_100_map_non_canonical() -> Result<()> {
test_file("1.0.0-bigendian", "generated_map_non_canonical", false)
}

#[test]
fn write_100_nested_dictionary() -> Result<()> {
test_file("1.0.0-littleendian", "generated_nested_dictionary", false)?;
test_file("1.0.0-bigendian", "generated_nested_dictionary", false)
}

#[test]
fn write_generated_017_union() -> Result<()> {
test_file("0.17.1", "generated_union", false)
Expand Down

0 comments on commit e572d3c

Please sign in to comment.