This repository has been archived by the owner on Feb 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 224
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement iterator for StructArray (#613)
- Loading branch information
1 parent
2363b27
commit b3ed162
Showing
6 changed files
with
182 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
use std::sync::Arc; | ||
|
||
use super::super::{ffi::ToFfi, Array, FromFfi}; | ||
use super::StructArray; | ||
use crate::{error::Result, ffi}; | ||
|
||
unsafe impl ToFfi for StructArray { | ||
fn buffers(&self) -> Vec<Option<std::ptr::NonNull<u8>>> { | ||
vec![self.validity.as_ref().map(|x| x.as_ptr())] | ||
} | ||
|
||
fn children(&self) -> Vec<Arc<dyn Array>> { | ||
self.values.clone() | ||
} | ||
|
||
fn offset(&self) -> Option<usize> { | ||
Some( | ||
self.validity | ||
.as_ref() | ||
.map(|bitmap| bitmap.offset()) | ||
.unwrap_or_default(), | ||
) | ||
} | ||
|
||
fn to_ffi_aligned(&self) -> Self { | ||
self.clone() | ||
} | ||
} | ||
|
||
impl<A: ffi::ArrowArrayRef> FromFfi<A> for StructArray { | ||
unsafe fn try_from_ffi(array: A) -> Result<Self> { | ||
let data_type = array.field().data_type().clone(); | ||
let fields = Self::get_fields(&data_type); | ||
|
||
let validity = unsafe { array.validity() }?; | ||
let values = (0..fields.len()) | ||
.map(|index| { | ||
let child = array.child(index)?; | ||
Ok(ffi::try_from(child)?.into()) | ||
}) | ||
.collect::<Result<Vec<Arc<dyn Array>>>>()?; | ||
|
||
Ok(Self::from_data(data_type, values, validity)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
use crate::{ | ||
bitmap::utils::{zip_validity, ZipValidity}, | ||
scalar::{new_scalar, Scalar}, | ||
trusted_len::TrustedLen, | ||
}; | ||
|
||
use super::StructArray; | ||
|
||
pub struct StructValueIter<'a> { | ||
array: &'a StructArray, | ||
index: usize, | ||
end: usize, | ||
} | ||
|
||
impl<'a> StructValueIter<'a> { | ||
#[inline] | ||
pub fn new(array: &'a StructArray) -> Self { | ||
Self { | ||
array, | ||
index: 0, | ||
end: array.len(), | ||
} | ||
} | ||
} | ||
|
||
impl<'a> Iterator for StructValueIter<'a> { | ||
type Item = Vec<Box<dyn Scalar>>; | ||
|
||
#[inline] | ||
fn next(&mut self) -> Option<Self::Item> { | ||
if self.index == self.end { | ||
return None; | ||
} | ||
let old = self.index; | ||
self.index += 1; | ||
|
||
// Safety: | ||
// self.end is maximized by the length of the array | ||
Some( | ||
self.array | ||
.values() | ||
.iter() | ||
.map(|v| new_scalar(v.as_ref(), old)) | ||
.collect(), | ||
) | ||
} | ||
|
||
#[inline] | ||
fn size_hint(&self) -> (usize, Option<usize>) { | ||
(self.end - self.index, Some(self.end - self.index)) | ||
} | ||
} | ||
|
||
unsafe impl<'a> TrustedLen for StructValueIter<'a> {} | ||
|
||
impl<'a> DoubleEndedIterator for StructValueIter<'a> { | ||
#[inline] | ||
fn next_back(&mut self) -> Option<Self::Item> { | ||
if self.index == self.end { | ||
None | ||
} else { | ||
self.end -= 1; | ||
|
||
// Safety: | ||
// self.end is maximized by the length of the array | ||
Some( | ||
self.array | ||
.values() | ||
.iter() | ||
.map(|v| new_scalar(v.as_ref(), self.end)) | ||
.collect(), | ||
) | ||
} | ||
} | ||
} | ||
|
||
type ValuesIter<'a> = StructValueIter<'a>; | ||
type ZipIter<'a> = ZipValidity<'a, Vec<Box<dyn Scalar>>, ValuesIter<'a>>; | ||
|
||
impl<'a> IntoIterator for &'a StructArray { | ||
type Item = Option<Vec<Box<dyn Scalar>>>; | ||
type IntoIter = ZipIter<'a>; | ||
|
||
fn into_iter(self) -> Self::IntoIter { | ||
self.iter() | ||
} | ||
} | ||
|
||
impl<'a> StructArray { | ||
/// Returns an iterator of `Option<Box<dyn Array>>` | ||
pub fn iter(&'a self) -> ZipIter<'a> { | ||
zip_validity( | ||
StructValueIter::new(self), | ||
self.validity.as_ref().map(|x| x.iter()), | ||
) | ||
} | ||
|
||
/// Returns an iterator of `Box<dyn Array>` | ||
pub fn values_iter(&'a self) -> ValuesIter<'a> { | ||
StructValueIter::new(self) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ mod growable; | |
mod list; | ||
mod ord; | ||
mod primitive; | ||
mod struct_; | ||
mod union; | ||
mod utf8; | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
use arrow2::array::*; | ||
use arrow2::datatypes::*; | ||
use arrow2::scalar::new_scalar; | ||
|
||
#[test] | ||
fn test_simple_iter() { | ||
use std::sync::Arc; | ||
let boolean = Arc::new(BooleanArray::from_slice(&[false, false, true, true])) as Arc<dyn Array>; | ||
let int = Arc::new(Int32Array::from_slice(&[42, 28, 19, 31])) as Arc<dyn Array>; | ||
|
||
let fields = vec![ | ||
Field::new("b", DataType::Boolean, false), | ||
Field::new("c", DataType::Int32, false), | ||
]; | ||
|
||
let array = StructArray::from_data( | ||
DataType::Struct(fields), | ||
vec![boolean.clone(), int.clone()], | ||
None, | ||
); | ||
|
||
for (i, item) in array.iter().enumerate() { | ||
let expected = Some(vec![ | ||
new_scalar(boolean.as_ref(), i), | ||
new_scalar(int.as_ref(), i), | ||
]); | ||
assert_eq!(expected, item); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
mod iterator; |