Skip to content

Commit

Permalink
Improve performance of DictionaryArray::try_new()  (#1435)
Browse files Browse the repository at this point in the history
* improve `DictionaryArray::try_new()` #1313

* *: fix typo

* *: add cheap validate and unit test

* *: polish the error

* Add safety note

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
jackwener and alamb authored Mar 22, 2022
1 parent e778c10 commit 3ee2e67
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 47 deletions.
30 changes: 23 additions & 7 deletions arrow/src/array/array_dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ use super::{
make_array, Array, ArrayData, ArrayRef, PrimitiveArray, PrimitiveBuilder,
StringArray, StringBuilder, StringDictionaryBuilder,
};
use crate::datatypes::ArrowNativeType;
use crate::datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType, DataType};
use crate::datatypes::{
ArrowDictionaryKeyType, ArrowNativeType, ArrowPrimitiveType, DataType,
};
use crate::error::Result;

/// A dictionary array where each element is a single value indexed by an integer key.
Expand Down Expand Up @@ -96,8 +97,8 @@ impl<'a, K: ArrowPrimitiveType> DictionaryArray<K> {
Box::new(values.data_type().clone()),
);

// Note: This does more work than necessary by rebuilding /
// revalidating all the data
// Note: This use the ArrayDataBuilder::build_unchecked and afterwards
// call the new function which only validates that the keys are in bounds.
let mut data = ArrayData::builder(dict_data_type)
.len(keys.len())
.add_buffer(keys.data().buffers()[0].clone())
Expand All @@ -112,7 +113,14 @@ impl<'a, K: ArrowPrimitiveType> DictionaryArray<K> {
_ => data = data.null_count(0),
}

Ok(data.build()?.into())
// Safety: `validate` ensures key type is correct, and
// `validate_dictionary_offset` ensures all offsets are within range
let array = unsafe { data.build_unchecked() };

array.validate()?;
array.validate_dictionary_offset()?;

Ok(array.into())
}

/// Return an array view of the keys of this dictionary as a PrimitiveArray.
Expand Down Expand Up @@ -308,8 +316,8 @@ impl<T: ArrowPrimitiveType> fmt::Debug for DictionaryArray<T> {
mod tests {
use super::*;

use crate::array::Int8Array;
use crate::datatypes::Int16Type;
use crate::array::{Float32Array, Int8Array};
use crate::datatypes::{Float32Type, Int16Type};
use crate::{
array::Int16DictionaryArray, array::PrimitiveDictionaryBuilder,
datatypes::DataType,
Expand Down Expand Up @@ -574,4 +582,12 @@ mod tests {
let keys: Int32Array = [Some(-100)].into_iter().collect();
DictionaryArray::<Int32Type>::try_new(&keys, &values).unwrap();
}

#[test]
#[should_panic(expected = "Dictionary key type must be integer, but was Float32")]
fn test_try_wrong_dictionary_key_type() {
let values: StringArray = [Some("foo"), Some("bar")].into_iter().collect();
let keys: Float32Array = [Some(0_f32), None, Some(3_f32)].into_iter().collect();
DictionaryArray::<Float32Type>::try_new(&keys, &values).unwrap();
}
}
80 changes: 40 additions & 40 deletions arrow/src/array/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ impl ArrayData {
// At the moment, constructing a DictionaryArray will also check this
if !DataType::is_dictionary_key_type(key_type) {
return Err(ArrowError::InvalidArgumentError(format!(
"Dictionary values must be integer, but was {}",
"Dictionary key type must be integer, but was {}",
key_type
)));
}
Expand Down Expand Up @@ -926,8 +926,8 @@ impl ArrayData {
///
/// 1. Null count is correct
/// 2. All offsets are valid
/// 3. All String data is valid UTF-8
/// 3. All dictionary offsets are valid
/// 3. All String data is valid UTF-8
/// 4. All dictionary offsets are valid
///
/// Does not (yet) check
/// 1. Union type_ids are valid see [#85](https://github.com/apache/arrow-rs/issues/85)
Expand All @@ -949,68 +949,68 @@ impl ArrayData {
)));
}

self.validate_dictionary_offset()?;

// validate all children recursively
self.child_data
.iter()
.enumerate()
.try_for_each(|(i, child_data)| {
child_data.validate_full().map_err(|e| {
ArrowError::InvalidArgumentError(format!(
"{} child #{} invalid: {}",
self.data_type, i, e
))
})
})?;

Ok(())
}

pub fn validate_dictionary_offset(&self) -> Result<()> {
match &self.data_type {
DataType::Utf8 => {
self.validate_utf8::<i32>()?;
}
DataType::LargeUtf8 => {
self.validate_utf8::<i64>()?;
}
DataType::Binary => {
self.validate_offsets_full::<i32>(self.buffers[1].len())?;
}
DataType::Utf8 => self.validate_utf8::<i32>(),
DataType::LargeUtf8 => self.validate_utf8::<i64>(),
DataType::Binary => self.validate_offsets_full::<i32>(self.buffers[1].len()),
DataType::LargeBinary => {
self.validate_offsets_full::<i64>(self.buffers[1].len())?;
self.validate_offsets_full::<i64>(self.buffers[1].len())
}
DataType::List(_) | DataType::Map(_, _) => {
let child = &self.child_data[0];
self.validate_offsets_full::<i32>(child.len + child.offset)?;
self.validate_offsets_full::<i32>(child.len + child.offset)
}
DataType::LargeList(_) => {
let child = &self.child_data[0];
self.validate_offsets_full::<i64>(child.len + child.offset)?;
self.validate_offsets_full::<i64>(child.len + child.offset)
}
DataType::Union(_, _) => {
// Validate Union Array as part of implementing new Union semantics
// See comments in `ArrayData::validate()`
// https://github.com/apache/arrow-rs/issues/85
//
// TODO file follow on ticket for full union validation
Ok(())
}
DataType::Dictionary(key_type, _value_type) => {
let dictionary_length: i64 = self.child_data[0].len.try_into().unwrap();
let max_value = dictionary_length - 1;
match key_type.as_ref() {
DataType::UInt8 => self.check_bounds::<u8>(max_value)?,
DataType::UInt16 => self.check_bounds::<u16>(max_value)?,
DataType::UInt32 => self.check_bounds::<u32>(max_value)?,
DataType::UInt64 => self.check_bounds::<u64>(max_value)?,
DataType::Int8 => self.check_bounds::<i8>(max_value)?,
DataType::Int16 => self.check_bounds::<i16>(max_value)?,
DataType::Int32 => self.check_bounds::<i32>(max_value)?,
DataType::Int64 => self.check_bounds::<i64>(max_value)?,
DataType::UInt8 => self.check_bounds::<u8>(max_value),
DataType::UInt16 => self.check_bounds::<u16>(max_value),
DataType::UInt32 => self.check_bounds::<u32>(max_value),
DataType::UInt64 => self.check_bounds::<u64>(max_value),
DataType::Int8 => self.check_bounds::<i8>(max_value),
DataType::Int16 => self.check_bounds::<i16>(max_value),
DataType::Int32 => self.check_bounds::<i32>(max_value),
DataType::Int64 => self.check_bounds::<i64>(max_value),
_ => unreachable!(),
}
}
_ => {
// No extra validation check required for other types
Ok(())
}
};

// validate all children recursively
self.child_data
.iter()
.enumerate()
.try_for_each(|(i, child_data)| {
child_data.validate_full().map_err(|e| {
ArrowError::InvalidArgumentError(format!(
"{} child #{} invalid: {}",
self.data_type, i, e
))
})
})?;

Ok(())
}
}

/// Calls the `validate(item_index, range)` function for each of
Expand Down Expand Up @@ -1736,7 +1736,7 @@ mod tests {

// Test creating a dictionary with a non integer type
#[test]
#[should_panic(expected = "Dictionary values must be integer, but was Utf8")]
#[should_panic(expected = "Dictionary key type must be integer, but was Utf8")]
fn test_non_int_dictionary() {
let i32_buffer = Buffer::from_slice_ref(&[0i32, 2i32]);
let data_type =
Expand Down

0 comments on commit 3ee2e67

Please sign in to comment.