diff --git a/src/array/struct_/ffi.rs b/src/array/struct_/ffi.rs
new file mode 100644
index 00000000000..0168520cdcb
--- /dev/null
+++ b/src/array/struct_/ffi.rs
@@ -0,0 +1,52 @@
+use std::sync::Arc;
+
+use super::super::{ffi::ToFfi, Array, FromFfi};
+use super::StructArray;
+use crate::{error::Result, ffi};
+
+unsafe impl ToFfi for StructArray {
+    /// Exposes a single buffer slot: the validity pointer, or `None` without validity.
+    fn buffers(&self) -> Vec<Option<std::ptr::NonNull<u8>>> {
+        vec![self.validity.as_ref().map(|x| x.as_ptr())]
+    }
+
+    /// Returns a clone of the child arrays held in `self.values`.
+    fn children(&self) -> Vec<Arc<dyn Array>> {
+        self.values.clone()
+    }
+
+    /// Returns the validity bitmap's offset, or the default (zero) when there is none.
+    fn offset(&self) -> Option<usize> {
+        Some(
+            self.validity
+                .as_ref()
+                .map(|bitmap| bitmap.offset())
+                .unwrap_or_default(),
+        )
+    }
+
+    /// Returns a clone of `self`; no realignment is performed here.
+    fn to_ffi_aligned(&self) -> Self {
+        self.clone()
+    }
+}
+
+impl<A: ffi::ArrowArrayRef> FromFfi<A> for StructArray {
+    /// Builds a [`StructArray`] from an FFI array: reads the validity, then imports one
+    /// child per field of the struct data type, failing on the first child that errors.
+    unsafe fn try_from_ffi(array: A) -> Result<Self> {
+        let data_type = array.field().data_type().clone();
+        let fields = Self::get_fields(&data_type);
+
+        // NOTE(review): soundness rests on the caller upholding this `unsafe fn`'s contract.
+        let validity = unsafe { array.validity() }?;
+        let values = (0..fields.len())
+            .map(|index| {
+                let child = array.child(index)?;
+                Ok(ffi::try_from(child)?.into())
+            })
+            .collect::<Result<Vec<Arc<dyn Array>>>>()?;
+
+        Ok(Self::from_data(data_type, values, validity))
+    }
+}
diff --git a/src/array/struct_/iterator.rs b/src/array/struct_/iterator.rs
new file mode 100644
index 00000000000..57ee095be98
--- /dev/null
+++ b/src/array/struct_/iterator.rs
@@ -0,0 +1,107 @@
+use crate::{
+    bitmap::utils::{zip_validity, ZipValidity},
+    scalar::{new_scalar, Scalar},
+    trusted_len::TrustedLen,
+};
+
+use super::StructArray;
+
+/// Iterator over the rows of a [`StructArray`] that ignores validity.
+///
+/// Each item holds one scalar per child array, in child order.
+pub struct StructValueIter<'a> {
+    array: &'a StructArray,
+    index: usize,
+    end: usize,
+}
+
+impl<'a> StructValueIter<'a> {
+    /// Creates an iterator over every row of `array`.
+    #[inline]
+    pub fn new(array: &'a StructArray) -> Self {
+        Self {
+            array,
+            index: 0,
+            end: array.len(),
+        }
+    }
+}
+
+impl<'a> Iterator for StructValueIter<'a> {
+    type Item = Vec<Box<dyn Scalar>>;
+
+    #[inline]
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.index == self.end {
+            return None;
+        }
+        let old = self.index;
+        self.index += 1;
+
+        // `old < self.end`, and `self.end` never exceeds the array length it was
+        // initialised with, so `old` is a valid row index.
+        Some(
+            self.array
+                .values()
+                .iter()
+                .map(|v| new_scalar(v.as_ref(), old))
+                .collect(),
+        )
+    }
+
+    #[inline]
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        (self.end - self.index, Some(self.end - self.index))
+    }
+}
+
+unsafe impl<'a> TrustedLen for StructValueIter<'a> {}
+
+impl<'a> DoubleEndedIterator for StructValueIter<'a> {
+    #[inline]
+    fn next_back(&mut self) -> Option<Self::Item> {
+        if self.index == self.end {
+            None
+        } else {
+            self.end -= 1;
+
+            // `self.end` was just decremented, so it is below the array length and
+            // is a valid row index.
+            Some(
+                self.array
+                    .values()
+                    .iter()
+                    .map(|v| new_scalar(v.as_ref(), self.end))
+                    .collect(),
+            )
+        }
+    }
+}
+
+type ValuesIter<'a> = StructValueIter<'a>;
+type ZipIter<'a> = ZipValidity<'a, Vec<Box<dyn Scalar>>, ValuesIter<'a>>;
+
+impl<'a> IntoIterator for &'a StructArray {
+    type Item = Option<Vec<Box<dyn Scalar>>>;
+    type IntoIter = ZipIter<'a>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.iter()
+    }
+}
+
+impl<'a> StructArray {
+    /// Returns an iterator of `Option<Vec<Box<dyn Scalar>>>`: one entry per row,
+    /// combined with the validity bitmap when there is one.
+    pub fn iter(&'a self) -> ZipIter<'a> {
+        zip_validity(
+            StructValueIter::new(self),
+            self.validity.as_ref().map(|x| x.iter()),
+        )
+    }
+
+    /// Returns an iterator of `Vec<Box<dyn Scalar>>` that ignores the validity bitmap.
+    pub fn values_iter(&'a self) -> ValuesIter<'a> {
+        StructValueIter::new(self)
+    }
+}
diff --git a/src/array/struct_.rs b/src/array/struct_/mod.rs
similarity index 84%
rename from src/array/struct_.rs
rename to src/array/struct_/mod.rs
index 10b5060b598..c54be876809 100644
--- a/src/array/struct_.rs
+++ b/src/array/struct_/mod.rs
@@ -3,11 +3,12 @@ use std::sync::Arc;
 use crate::{
     bitmap::Bitmap,
     datatypes::{DataType, Field},
-    error::Result,
-    ffi,
 };
 
-use super::{ffi::ToFfi, new_empty_array, new_null_array, Array, FromFfi};
+use super::{new_empty_array, new_null_array, Array};
+
+mod ffi;
+mod iterator;
 
 /// A [`StructArray`] is a nested [`Array`] with an optional validity representing
 /// multiple [`Array`] with the same number of rows.
@@ -222,43 +223,3 @@ impl std::fmt::Display for StructArray {
         write!(f, "}}")
     }
 }
-
-unsafe impl ToFfi for StructArray {
-    fn buffers(&self) -> Vec<Option<std::ptr::NonNull<u8>>> {
-        vec![self.validity.as_ref().map(|x| x.as_ptr())]
-    }
-
-    fn children(&self) -> Vec<Arc<dyn Array>> {
-        self.values.clone()
-    }
-
-    fn offset(&self) -> Option<usize> {
-        Some(
-            self.validity
-                .as_ref()
-                .map(|bitmap| bitmap.offset())
-                .unwrap_or_default(),
-        )
-    }
-
-    fn to_ffi_aligned(&self) -> Self {
-        self.clone()
-    }
-}
-
-impl<A: ffi::ArrowArrayRef> FromFfi<A> for StructArray {
-    unsafe fn try_from_ffi(array: A) -> Result<Self> {
-        let data_type = array.field().data_type().clone();
-        let fields = Self::get_fields(&data_type);
-
-        let validity = unsafe { array.validity() }?;
-        let values = (0..fields.len())
-            .map(|index| {
-                let child = array.child(index)?;
-                Ok(ffi::try_from(child)?.into())
-            })
-            .collect::<Result<Vec<Arc<dyn Array>>>>()?;
-
-        Ok(Self::from_data(data_type, values, validity))
-    }
-}
diff --git a/tests/it/array/mod.rs b/tests/it/array/mod.rs
index d9430594942..bfba8624b1c 100644
--- a/tests/it/array/mod.rs
+++ b/tests/it/array/mod.rs
@@ -8,6 +8,7 @@ mod growable;
 mod list;
 mod ord;
 mod primitive;
+mod struct_;
 mod union;
 mod utf8;
 
diff --git a/tests/it/array/struct_/iterator.rs b/tests/it/array/struct_/iterator.rs
new file mode 100644
index 00000000000..dac0c16b6ec
--- /dev/null
+++ b/tests/it/array/struct_/iterator.rs
@@ -0,0 +1,35 @@
+use arrow2::array::*;
+use arrow2::datatypes::*;
+use arrow2::scalar::new_scalar;
+
+/// Iterating a struct array without validity yields a `Some` row of per-child scalars
+/// for every index.
+#[test]
+fn test_simple_iter() {
+    use std::sync::Arc;
+    let boolean = Arc::new(BooleanArray::from_slice(&[false, false, true, true])) as Arc<dyn Array>;
+    let int = Arc::new(Int32Array::from_slice(&[42, 28, 19, 31])) as Arc<dyn Array>;
+
+    let fields = vec![
+        Field::new("b", DataType::Boolean, false),
+        Field::new("c", DataType::Int32, false),
+    ];
+
+    let array = StructArray::from_data(
+        DataType::Struct(fields),
+        vec![boolean.clone(), int.clone()],
+        None,
+    );
+
+    let mut rows = 0;
+    for (i, item) in array.iter().enumerate() {
+        let expected = Some(vec![
+            new_scalar(boolean.as_ref(), i),
+            new_scalar(int.as_ref(), i),
+        ]);
+        assert_eq!(expected, item);
+        rows += 1;
+    }
+    // Guard against a vacuous pass: the loop above asserts nothing if `iter()` is empty.
+    assert_eq!(rows, 4);
+}
diff --git a/tests/it/array/struct_/mod.rs b/tests/it/array/struct_/mod.rs
new file mode 100644
index 00000000000..27325c6a68f
--- /dev/null
+++ b/tests/it/array/struct_/mod.rs
@@ -0,0 +1 @@
+mod iterator;