diff --git a/src/array/dictionary/mod.rs b/src/array/dictionary/mod.rs index 8dc1bbd7c9a..46023209565 100644 --- a/src/array/dictionary/mod.rs +++ b/src/array/dictionary/mod.rs @@ -17,6 +17,9 @@ pub(super) mod fmt; mod iterator; mod mutable; use crate::array::specification::check_indexes_unchecked; +mod typed_iterator; + +use crate::array::dictionary::typed_iterator::{DictValue, DictionaryValuesIterTyped}; pub use iterator::*; pub use mutable::*; @@ -237,6 +240,38 @@ impl DictionaryArray { DictionaryValuesIter::new(self) } + /// Returns an iterator over the values [`V::IterValue`]. + /// + /// # Panics + /// + /// Panics if the keys of this [`DictionaryArray`] have any null values. + /// If they do, [`DictionaryArray::iter_typed`] should be called instead. + pub fn values_iter_typed( + &self, + ) -> Result, Error> { + let keys = &self.keys; + assert_eq!(keys.null_count(), 0); + let values = self.values.as_ref(); + let values = V::downcast_values(values)?; + Ok(unsafe { DictionaryValuesIterTyped::new(keys, values) }) + } + + /// Returns an iterator over the optional values of [`Option`]. 
+ /// + /// # Panics + /// + /// This function panics if the `values` array contains null values. + pub fn iter_typed( + &self, + ) -> Result, DictionaryValuesIterTyped, BitmapIter>, Error> + { + let keys = &self.keys; + let values = self.values.as_ref(); + let values = V::downcast_values(values)?; + let values_iter = unsafe { DictionaryValuesIterTyped::new(keys, values) }; + Ok(ZipValidity::new_with_validity(values_iter, self.validity())) + } + /// Returns the [`DataType`] of this [`DictionaryArray`] #[inline] pub fn data_type(&self) -> &DataType { diff --git a/src/array/dictionary/typed_iterator.rs b/src/array/dictionary/typed_iterator.rs new file mode 100644 index 00000000000..0e90a1cf4d8 --- /dev/null +++ b/src/array/dictionary/typed_iterator.rs @@ -0,0 +1,111 @@ +use crate::array::{Array, PrimitiveArray, Utf8Array}; +use crate::error::{Error, Result}; +use crate::trusted_len::TrustedLen; +use crate::types::Offset; + +use super::DictionaryKey; + +pub trait DictValue { + type IterValue<'this> + where + Self: 'this; + + /// # Safety + /// Will not do any bound checks but must check validity. + unsafe fn get_unchecked(&self, item: usize) -> Self::IterValue<'_>; + + /// Take a [`dyn Array`] and try to downcast it to the type of `DictValue`. + fn downcast_values(array: &dyn Array) -> Result<&Self> + where + Self: Sized; +} + +impl DictValue for Utf8Array { + type IterValue<'a> = &'a str; + + unsafe fn get_unchecked(&self, item: usize) -> Self::IterValue<'_> { + self.value_unchecked(item) + } + + fn downcast_values(array: &dyn Array) -> Result<&Self> + where + Self: Sized, + { + array + .as_any() + .downcast_ref::() + .ok_or(Error::InvalidArgumentError( + "could not convert array to dictionary value".into(), + )) + .map(|arr| { + assert_eq!( + arr.null_count(), + 0, + "null values in values not supported in iteration" + ); + arr + }) + } +} + +/// Iterator over the typed values of a `DictionaryArray`. 
+pub struct DictionaryValuesIterTyped<'a, K: DictionaryKey, V: DictValue> { + keys: &'a PrimitiveArray, + values: &'a V, + index: usize, + end: usize, +} + +impl<'a, K: DictionaryKey, V: DictValue> DictionaryValuesIterTyped<'a, K, V> { + pub(super) unsafe fn new(keys: &'a PrimitiveArray, values: &'a V) -> Self { + Self { + keys, + values, + index: 0, + end: keys.len(), + } + } +} + +impl<'a, K: DictionaryKey, V: DictValue> Iterator for DictionaryValuesIterTyped<'a, K, V> { + type Item = V::IterValue<'a>; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + unsafe { + let key = self.keys.value_unchecked(old); + let idx = key.as_usize(); + Some(self.values.get_unchecked(idx)) + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } +} + +unsafe impl<'a, K: DictionaryKey, V: DictValue> TrustedLen for DictionaryValuesIterTyped<'a, K, V> {} + +impl<'a, K: DictionaryKey, V: DictValue> DoubleEndedIterator + for DictionaryValuesIterTyped<'a, K, V> +{ + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + unsafe { + let key = self.keys.value_unchecked(self.end); + let idx = key.as_usize(); + Some(self.values.get_unchecked(idx)) + } + } + } +} diff --git a/tests/it/array/dictionary/mod.rs b/tests/it/array/dictionary/mod.rs index b139c94c532..95f2ba78ca6 100644 --- a/tests/it/array/dictionary/mod.rs +++ b/tests/it/array/dictionary/mod.rs @@ -165,3 +165,48 @@ fn keys_values_iter() { assert_eq!(array.keys_values_iter().collect::>(), vec![1, 0]); } + +#[test] +fn iter_values_typed() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0, 0]), values.boxed()) + .unwrap(); + + let mut iter = array.values_iter_typed::>().unwrap(); + assert_eq!(iter.size_hint(), (3, Some(3))); + 
assert_eq!(iter.collect::>(), vec!["aa", "a", "a"]); + + let mut iter = array.iter_typed::>().unwrap(); + assert_eq!(iter.size_hint(), (3, Some(3))); + assert_eq!( + iter.collect::>(), + vec![Some("aa"), Some("a"), Some("a")] + ); +} + +#[test] +#[should_panic] +fn iter_values_typed_panic() { + let values = Utf8Array::::from_iter([Some("a"), Some("aa"), None]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0, 0]), values.boxed()) + .unwrap(); + + // should not be iterating values + let mut iter = array.values_iter_typed::>().unwrap(); + let _ = iter.collect::>(); +} + +#[test] +#[should_panic] +fn iter_values_typed_panic_2() { + let values = Utf8Array::::from_iter([Some("a"), Some("aa"), None]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0, 0]), values.boxed()) + .unwrap(); + + // should not be iterating values + let mut iter = array.iter_typed::>().unwrap(); + let _ = iter.collect::>(); +}