Skip to content

Commit

Permalink
Implement PrimitiveArray conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
emilk committed Jan 9, 2025
1 parent 3933885 commit e27870c
Showing 1 changed file with 63 additions and 0 deletions.
63 changes: 63 additions & 0 deletions src/array/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,66 @@ impl<T: NativeType> Default for PrimitiveArray<T> {
PrimitiveArray::new(T::PRIMITIVE.into(), Default::default(), None)
}
}

/// arrow2 -> arrow1
#[cfg(feature = "arrow")]
impl<P, N> From<PrimitiveArray<N>> for arrow_array::PrimitiveArray<P>
where
P: arrow_array::ArrowPrimitiveType<Native = N>,
N: NativeType + arrow_buffer::ArrowNativeType,
{
fn from(array: PrimitiveArray<N>) -> Self {
let scalar_buffer: arrow_buffer::ScalarBuffer<N> = array.values().clone().into();
Self::new(scalar_buffer, array.validity().cloned().map(|v| v.into()))
}
}

// Conflicts with `impl<T: NativeType, Ptr: std::borrow::Borrow<Option<T>>> FromIterator<Ptr> for PrimitiveArray<T> {`
// /// arrow1 -> arrow2
// impl<P, N> From<arrow_array::PrimitiveArray<P>> for PrimitiveArray<N>
// where
// P: arrow_array::ArrowPrimitiveType<Native = N>,
// N: NativeType + arrow_buffer::ArrowNativeType,
// {
// fn from(array: arrow_array::PrimitiveArray<P>) -> Self {
// Self::new(
// P::DATA_TYPE.into(),
// array.values().clone().into(),
// array.validity().cloned().map(|v| v.into()),
// )
// }
// }

/// arrow1 -> arrow2
#[cfg(feature = "arrow")]
impl<N> PrimitiveArray<N>
where
N: NativeType + arrow_buffer::ArrowNativeType,
{
/// Convert from `arrow-rs` `PrimitiveArray`
pub fn from_arrow<P>(array: arrow_array::PrimitiveArray<P>) -> Self
where
P: arrow_array::ArrowPrimitiveType<Native = N>,
{
let (data_type, values, nulls) = array.into_parts();
Self::new(
data_type.into(),
values.into(),
nulls.map(Bitmap::from_arrow),
)
}
}

#[cfg(feature = "arrow")]
#[test]
fn test_primitive_array_arrow_conversion() {
let original = Float64Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]).sliced(1, 2);

let arrow: arrow_array::Float64Array = original.clone().into();
assert_eq!(arrow.len(), 2);
assert_eq!(arrow.value(0), 2.0);
assert_eq!(arrow.value(1), 3.0);

let roundtripped = Float64Array::from_arrow(arrow);
assert_eq!(roundtripped, original);
}

0 comments on commit e27870c

Please sign in to comment.