Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Improved performance of check_indexes #1313

Merged
merged 2 commits into from
Dec 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 43 additions & 10 deletions src/array/dictionary/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,19 @@ mod ffi;
pub(super) mod fmt;
mod iterator;
mod mutable;
use crate::array::specification::check_indexes_unchecked;
pub use iterator::*;
pub use mutable::*;

use super::{new_empty_array, primitive::PrimitiveArray, Array};
use super::{new_null_array, specification::check_indexes};

/// Trait denoting [`NativeType`]s that can be used as keys of a dictionary.
pub trait DictionaryKey: NativeType + TryInto<usize> + TryFrom<usize> {
/// # Safety
///
/// Any implementation of this trait must ensure that `always_fits_usize` only
/// returns `true` if all values succeeds on `value::try_into::<usize>().unwrap()`.
pub unsafe trait DictionaryKey: NativeType + TryInto<usize> + TryFrom<usize> {
/// The corresponding [`IntegerType`] of this key
const KEY_TYPE: IntegerType;

Expand All @@ -37,31 +42,53 @@ pub trait DictionaryKey: NativeType + TryInto<usize> + TryFrom<usize> {
Err(_) => unreachable_unchecked(),
}
}

/// If the key type always can be converted to `usize`.
fn always_fits_usize() -> bool {
false
}
}

impl DictionaryKey for i8 {
unsafe impl DictionaryKey for i8 {
const KEY_TYPE: IntegerType = IntegerType::Int8;
}
impl DictionaryKey for i16 {
unsafe impl DictionaryKey for i16 {
const KEY_TYPE: IntegerType = IntegerType::Int16;
}
impl DictionaryKey for i32 {
unsafe impl DictionaryKey for i32 {
const KEY_TYPE: IntegerType = IntegerType::Int32;
}
impl DictionaryKey for i64 {
unsafe impl DictionaryKey for i64 {
const KEY_TYPE: IntegerType = IntegerType::Int64;
}
impl DictionaryKey for u8 {
unsafe impl DictionaryKey for u8 {
const KEY_TYPE: IntegerType = IntegerType::UInt8;

fn always_fits_usize() -> bool {
true
}
}
impl DictionaryKey for u16 {
unsafe impl DictionaryKey for u16 {
const KEY_TYPE: IntegerType = IntegerType::UInt16;

fn always_fits_usize() -> bool {
true
}
}
impl DictionaryKey for u32 {
unsafe impl DictionaryKey for u32 {
const KEY_TYPE: IntegerType = IntegerType::UInt32;

fn always_fits_usize() -> bool {
true
}
}
impl DictionaryKey for u64 {
unsafe impl DictionaryKey for u64 {
const KEY_TYPE: IntegerType = IntegerType::UInt64;

#[cfg(target_pointer_width = "64")]
fn always_fits_usize() -> bool {
true
}
}

/// An [`Array`] whose values are stored as indices. This [`Array`] is useful when the cardinality of
Expand Down Expand Up @@ -120,7 +147,13 @@ impl<K: DictionaryKey> DictionaryArray<K> {
check_data_type(K::KEY_TYPE, &data_type, values.data_type())?;

if keys.null_count() != keys.len() {
check_indexes(keys.values(), values.len())?;
if K::always_fits_usize() {
// safety: we just checked that conversion to `usize` always
// succeeds
unsafe { check_indexes_unchecked(keys.values(), values.len()) }?;
} else {
check_indexes(keys.values(), values.len())?;
}
}

Ok(Self {
Expand Down
25 changes: 25 additions & 0 deletions src/array/specification.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::array::DictionaryKey;
use crate::error::{Error, Result};
use crate::offset::{Offset, Offsets, OffsetsBuffer};

Expand Down Expand Up @@ -107,6 +108,30 @@ pub(crate) fn try_check_utf8<O: Offset, C: OffsetsContainer<O>>(
}
}

/// Check dictionary indexes without checking usize conversion.
/// # Safety
/// The caller must ensure that `K::as_usize` always succeeds.
pub(crate) unsafe fn check_indexes_unchecked<K: DictionaryKey>(
keys: &[K],
len: usize,
) -> Result<()> {
let mut invalid = false;

// this loop is auto-vectorized
keys.iter().for_each(|k| {
if k.as_usize() > len {
invalid = true;
}
});

if invalid {
let key = keys.iter().map(|k| k.as_usize()).max().unwrap();
Err(Error::oos(format!("One of the dictionary keys is {} but it must be < than the length of the dictionary values, which is {}", key, len)))
} else {
Ok(())
}
}

pub fn check_indexes<K>(keys: &[K], len: usize) -> Result<()>
where
K: std::fmt::Debug + Copy + TryInto<usize>,
Expand Down