Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Improved iter for DictionaryArray #1288

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 12 additions & 42 deletions src/array/dictionary/iterator.rs
Original file line number Diff line number Diff line change
@@ -1,59 +1,29 @@
use crate::array::{ArrayAccessor, ArrayValuesIter};
use crate::bitmap::utils::{BitmapIter, ZipValidity};
use crate::scalar::Scalar;
use crate::trusted_len::TrustedLen;

use super::{DictionaryArray, DictionaryKey};

/// Iterator over the values of a [`DictionaryArray`].
pub struct DictionaryValuesIter<'a, K: DictionaryKey> {
// array whose values are yielded
array: &'a DictionaryArray<K>,
// position of the next value yielded from the front
index: usize,
// one past the position of the next value yielded from the back
end: usize,
}

impl<'a, K: DictionaryKey> DictionaryValuesIter<'a, K> {
    /// Builds an iterator that covers every position of `array`,
    /// from `0` up to (but excluding) `array.len()`.
    #[inline]
    pub fn new(array: &'a DictionaryArray<K>) -> Self {
        let end = array.len();
        Self {
            array,
            index: 0,
            end,
        }
    }
}

impl<'a, K: DictionaryKey> Iterator for DictionaryValuesIter<'a, K> {
unsafe impl<'a, K> ArrayAccessor<'a> for DictionaryArray<K>
where
K: DictionaryKey,
{
type Item = Box<dyn Scalar>;

#[inline]
fn next(&mut self) -> Option<Self::Item> {
if self.index == self.end {
return None;
}
let old = self.index;
self.index += 1;
Some(self.array.value(old))
unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item {
// safety: invariant of the trait
self.value_unchecked(index)
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
(self.end - self.index, Some(self.end - self.index))
fn len(&self) -> usize {
self.keys.len()
}
}

// SAFETY: `size_hint` reports `end - index` as both its lower and upper bound.
unsafe impl<'a, K: DictionaryKey> TrustedLen for DictionaryValuesIter<'a, K> {}

impl<'a, K: DictionaryKey> DoubleEndedIterator for DictionaryValuesIter<'a, K> {
    /// Yields the value just before `end`, shrinking the remaining range
    /// from the back; returns `None` once front and back have met.
    #[inline]
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.index == self.end {
            return None;
        }
        self.end -= 1;
        let last = self.end;
        Some(self.array.value(last))
    }
}
/// Iterator of values of a [`DictionaryArray`].
pub type DictionaryValuesIter<'a, K> = ArrayValuesIter<'a, DictionaryArray<K>>;

// Private shorthand for [`DictionaryValuesIter`].
type ValuesIter<'a, K> = DictionaryValuesIter<'a, K>;
// `ZipValidity` over the values iterator and a `BitmapIter`, parameterised on `Box<dyn Scalar>`.
// NOTE(review): presumably the bitmap is the array's validity — confirm at the use site.
type ZipIter<'a, K> = ZipValidity<Box<dyn Scalar>, ValuesIter<'a, K>, BitmapIter<'a>>;
Expand Down
18 changes: 16 additions & 2 deletions src/array/dictionary/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,23 @@ impl<K: DictionaryKey> DictionaryArray<K> {
/// # Panic
/// This function panics iff `index >= self.len()`
#[inline]
pub fn value(&self, index: usize) -> Box<dyn Scalar> {
pub fn value(&self, i: usize) -> Box<dyn Scalar> {
assert!(i < self.len());
unsafe { self.value_unchecked(i) }
}

/// Returns the value of the [`DictionaryArray`] at position `i`.
/// # Implementation
/// This function will allocate a new [`Scalar`] and is usually not performant.
/// Consider calling `keys` and `values`, downcasting `values`, and iterating over that.
/// # Safety
/// This function is safe iff `index < self.len()`
#[inline]
pub unsafe fn value_unchecked(&self, index: usize) -> Box<dyn Scalar> {
// safety - invariant of this function
let index = unsafe { self.keys.value_unchecked(index) };
// safety - invariant of this struct
let index = unsafe { self.keys.value(index).as_usize() };
let index = unsafe { index.as_usize() };
new_scalar(self.values.as_ref(), index)
}

Expand Down
35 changes: 33 additions & 2 deletions src/array/dictionary/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::{collections::hash_map::DefaultHasher, sync::Arc};

use hash_hasher::HashedMap;

use crate::array::ArrayAccessor;
use crate::{
array::{primitive::MutablePrimitiveArray, Array, MutableArray, TryExtend, TryPush},
bitmap::MutableBitmap,
Expand Down Expand Up @@ -109,8 +110,14 @@ impl<K: DictionaryKey, M: MutableArray> MutableDictionaryArray<K, M> {
}

/// pushes a null value
#[inline]
pub fn push_null(&mut self) {
self.keys.push(None)
if self.values.is_empty() {
// The default key value is 0. If `self.values` is empty, index 0
// would be out of bounds, so a null value is pushed first.
self.values.push_null()
}
self.keys.push(None);
}

/// returns a mutable reference to the inner values.
Expand Down Expand Up @@ -198,8 +205,9 @@ impl<K: DictionaryKey, M: 'static + MutableArray> MutableArray for MutableDictio
self
}

#[inline]
fn push_null(&mut self) {
self.keys.push(None)
self.push_null()
}

fn reserve(&mut self, additional: usize) {
Expand Down Expand Up @@ -249,3 +257,26 @@ where
}
}
}

unsafe impl<'a, K, M, T: 'a> ArrayAccessor<'a> for MutableDictionaryArray<K, M>
where
    K: DictionaryKey,
    M: MutableArray + ArrayAccessor<'a, Item = T>,
{
    type Item = T;

    /// Reads the key stored at `index` and returns the value that key points to.
    #[inline]
    unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item {
        // safety: guaranteed by the contract of the trait
        let key = self.keys.value_unchecked(index);
        // safety: both the conversion and the lookup rely on the invariant of the struct
        let position = key.as_usize();
        self.values.value_unchecked(position)
    }

    /// The length is the number of keys.
    #[inline]
    fn len(&self) -> usize {
        self.keys.len()
    }
}
2 changes: 1 addition & 1 deletion src/array/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod private {
impl<'a, T: super::ArrayAccessor<'a>> Sealed for T {}
}

/// Sealed trait representing assess to a value of an array.
/// Sealed trait representing random access to a value of an array.
/// # Safety
/// Implementers of this trait guarantee that
/// `value_unchecked` is safe when called up to `len`
Expand Down
10 changes: 2 additions & 8 deletions src/array/list/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,14 +264,8 @@ impl<O: Offset> ListArray<O> {
/// Returns the element at index `i`
#[inline]
pub fn value(&self, i: usize) -> Box<dyn Array> {
let offset = self.offsets[i];
let offset_1 = self.offsets[i + 1];
let length = (offset_1 - offset).to_usize();

// Safety:
// One of the invariants of the struct
// is that offsets are in bounds
unsafe { self.values.slice_unchecked(offset.to_usize(), length) }
assert!(i < self.len());
unsafe { self.value_unchecked(i) }
}

/// Returns the element at index `i` as &str
Expand Down
21 changes: 20 additions & 1 deletion src/array/primitive/mutable.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{iter::FromIterator, sync::Arc};

use crate::array::physical_binary::extend_validity;
use crate::array::TryExtendFromSelf;
use crate::array::{ArrayAccessor, TryExtendFromSelf};
use crate::bitmap::Bitmap;
use crate::{
array::{Array, MutableArray, TryExtend, TryPush},
Expand Down Expand Up @@ -288,6 +288,11 @@ impl<T: NativeType> MutablePrimitiveArray<T> {
pub fn capacity(&self) -> usize {
self.values.capacity()
}

/// Returns the number of values in this [`MutablePrimitiveArray`].
pub fn len(&self) -> usize {
self.values.len()
}
}

/// Accessors
Expand Down Expand Up @@ -667,3 +672,17 @@ impl<T: NativeType> TryExtendFromSelf for MutablePrimitiveArray<T> {
Ok(())
}
}

unsafe impl<'a, T: NativeType> ArrayAccessor<'a> for MutablePrimitiveArray<T> {
    type Item = T;

    /// Copies out the value stored at `index` without a bounds check.
    #[inline]
    unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item {
        let slot = self.values.get_unchecked(index);
        *slot
    }

    #[inline]
    fn len(&self) -> usize {
        // Resolves to the inherent `MutablePrimitiveArray::len`: inherent methods
        // take precedence over trait methods, so this does not recurse.
        self.len()
    }
}
10 changes: 10 additions & 0 deletions tests/it/array/dictionary/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,13 @@ fn push_utf8() {
.collect::<HashedMap<_, _>>();
assert_eq!(*new.map(), expected_map);
}

#[test]
fn iter() {
    // Two dictionary values, referenced in reverse order by the keys.
    let values = Utf8Array::<i32>::from_slice(&["a", "aa"]);
    let keys = PrimitiveArray::from_vec(vec![1, 0]);
    let array = DictionaryArray::try_from_keys(keys, values.boxed()).unwrap();

    // Two independent iterations over the same array must yield equal items.
    let (first, second) = (array.iter(), array.iter());
    assert!(first.eq(second))
}