From 070d1b3ea322b8eac045a91cfdc4e896f4e2eab5 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Mon, 21 Feb 2022 19:49:01 +0000 Subject: [PATCH] Added support to write to Arrow Stream --- .../pyproject.toml | 2 +- .../src/c_stream.rs | 23 ++++ arrow-pyarrow-integration-testing/src/lib.rs | 6 ++ .../tests/test_c_stream.py | 11 +- src/ffi/mod.rs | 2 +- src/ffi/stream.rs | 102 +++++++++++++++++- tests/it/{ffi.rs => ffi/data.rs} | 0 tests/it/ffi/mod.rs | 2 + tests/it/ffi/stream.rs | 35 ++++++ 9 files changed, 174 insertions(+), 9 deletions(-) rename tests/it/{ffi.rs => ffi/data.rs} (100%) create mode 100644 tests/it/ffi/mod.rs create mode 100644 tests/it/ffi/stream.rs diff --git a/arrow-pyarrow-integration-testing/pyproject.toml b/arrow-pyarrow-integration-testing/pyproject.toml index 27480690e06..5c143694e18 100644 --- a/arrow-pyarrow-integration-testing/pyproject.toml +++ b/arrow-pyarrow-integration-testing/pyproject.toml @@ -16,5 +16,5 @@ # under the License. [build-system] -requires = ["maturin"] +requires = ["maturin>=0.12,<0.13"] build-backend = "maturin" diff --git a/arrow-pyarrow-integration-testing/src/c_stream.rs b/arrow-pyarrow-integration-testing/src/c_stream.rs index 9a4e4fabe8a..968ef95df22 100644 --- a/arrow-pyarrow-integration-testing/src/c_stream.rs +++ b/arrow-pyarrow-integration-testing/src/c_stream.rs @@ -3,6 +3,7 @@ use pyo3::ffi::Py_uintptr_t; use pyo3::prelude::*; +use arrow2::array::Int32Array; use arrow2::ffi; use super::*; @@ -26,3 +27,25 @@ pub fn to_rust_iterator(ob: PyObject, py: Python) -> PyResult> { } Ok(arrays) } + +pub fn from_rust_iterator(py: Python) -> PyResult { + let array = Int32Array::from(&[Some(2), None, Some(1), None]); + let field = Field::new("a", array.data_type().clone(), true); + + let array: Arc = Arc::new(array.clone()); + //let arrays = vec![array.clone(), array.clone(), array]; + let arrays: Vec> = vec![]; + + let iter = Box::new(arrays.clone().into_iter().map(Ok)) as _; + + let mut stream = Box::new(ffi::ArrowArrayStream::empty()); + unsafe { ffi::export_iterator(iter, field, &mut *stream) }; + + let pa = py.import("pyarrow.ipc")?; + let stream = pa.getattr("RecordBatchReader")?.call_method1( + "_import_from_c", + ((&*stream as *const ffi::ArrowArrayStream) as Py_uintptr_t,), + )?; + + Ok(stream.to_object(py)) +} diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs index 9c4e852d30a..9b18dab716c 100644 --- a/arrow-pyarrow-integration-testing/src/lib.rs +++ b/arrow-pyarrow-integration-testing/src/lib.rs @@ -159,10 +159,16 @@ pub fn to_rust_iterator(ob: PyObject, py: Python) -> PyResult> { c_stream::to_rust_iterator(ob, py) } +#[pyfunction] +pub fn from_rust_iterator(py: Python) -> PyResult { + c_stream::from_rust_iterator(py) +} + #[pymodule] fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(round_trip_array, m)?)?; m.add_function(wrap_pyfunction!(round_trip_field, m)?)?; m.add_function(wrap_pyfunction!(to_rust_iterator, m)?)?; + m.add_function(wrap_pyfunction!(from_rust_iterator, m)?)?; Ok(()) } diff --git a/arrow-pyarrow-integration-testing/tests/test_c_stream.py b/arrow-pyarrow-integration-testing/tests/test_c_stream.py index 168f72eadc1..9073ae2c213 100644 --- a/arrow-pyarrow-integration-testing/tests/test_c_stream.py +++ b/arrow-pyarrow-integration-testing/tests/test_c_stream.py @@ -7,9 +7,6 @@ class TestCase(unittest.TestCase): def test_rust_reads(self): - """ - Python -> Rust -> Python - """ schema = pyarrow.schema([pyarrow.field("aa", pyarrow.int32())]) a = pyarrow.array([1, None, 2], type=pyarrow.int32()) @@ -20,3 +17,11 @@ def test_rust_reads(self): array = arrays[0].field(0) assert array == a + + # see https://issues.apache.org/jira/browse/ARROW-15747 + def _test_pyarrow_reads(self): + stream = arrow_pyarrow_integration_testing.from_rust_iterator() + + arrays = [a for a in stream] + + assert False diff --git a/src/ffi/mod.rs b/src/ffi/mod.rs index ca561e2ee86..4678b567dd5 100644 --- a/src/ffi/mod.rs +++ b/src/ffi/mod.rs @@ -18,7 +18,7 @@ use crate::error::Result; use self::schema::to_field; pub use generated::{ArrowArray, ArrowArrayStream, ArrowSchema}; -pub use stream::ArrowArrayStreamReader; +pub use stream::{export_iterator, ArrowArrayStreamReader}; /// Exports an [`Arc`] to the C data interface. /// # Safety diff --git a/src/ffi/stream.rs b/src/ffi/stream.rs index d020db0b5ff..6219d33f755 100644 --- a/src/ffi/stream.rs +++ b/src/ffi/stream.rs @@ -1,8 +1,9 @@ -use std::ffi::CStr; +use std::ffi::{CStr, CString}; +use std::sync::Arc; use crate::{array::Array, datatypes::Field, error::ArrowError}; -use super::{import_array_from_c, import_field_from_c}; +use super::{export_array_to_c, export_field_to_c, import_array_from_c, import_field_from_c}; use super::{ArrowArray, ArrowArrayStream, ArrowSchema}; impl Drop for ArrowArrayStream { @@ -48,8 +49,7 @@ unsafe fn handle_error(iter: &mut ArrowArrayStream) -> ArrowError { ) } -/// Interface for the Arrow C stream interface. Implements an iterator of [`Array`]. -/// +/// Implements an iterator of [`Array`] consumed from the [C stream interface](https://arrow.apache.org/docs/format/CStreamInterface.html). pub struct ArrowArrayStreamReader { iter: Box, field: Field, @@ -127,3 +127,97 @@ impl ArrowArrayStreamReader { .transpose() } } + +struct PrivateData { + iter: Box, ArrowError>>>, + field: Field, + error: Option, +} + +unsafe extern "C" fn get_next(iter: *mut ArrowArrayStream, array: *mut ArrowArray) -> i32 { + if iter.is_null() { + return 2001; + } + let mut private = &mut *((*iter).private_data as *mut PrivateData); + + match private.iter.next() { + Some(Ok(item)) => { + // check that the array has the same data_type as field + let item_dt = item.data_type(); + let expected_dt = private.field.data_type(); + if item_dt != expected_dt { + private.error = Some(CString::new(format!("The iterator produced an item of data type {item_dt:?} but the producer expects data type {expected_dt:?}").as_bytes().to_vec()).unwrap()); + return 2001; // custom application specific error (since this is never a result of this interface) + } + + export_array_to_c(item, array); + private.error = None; + 0 + } + Some(Err(err)) => { + private.error = Some(CString::new(err.to_string().as_bytes().to_vec()).unwrap()); + 2001 // custom application specific error (since this is never a result of this interface) + } + None => { + *array = ArrowArray::empty(); + private.error = None; + 0 + } + } +} + +unsafe extern "C" fn get_schema(iter: *mut ArrowArrayStream, schema: *mut ArrowSchema) -> i32 { + println!("get_schema"); + if iter.is_null() { + return 2001; + } + let private = &mut *((*iter).private_data as *mut PrivateData); + + export_field_to_c(&private.field, schema); + 0 +} + +unsafe extern "C" fn get_last_error(iter: *mut ArrowArrayStream) -> *const ::std::os::raw::c_char { + if iter.is_null() { + return std::ptr::null(); + } + let private = &mut *((*iter).private_data as *mut PrivateData); + + private + .error + .as_ref() + .map(|x| x.as_ptr()) + .unwrap_or(std::ptr::null()) +} + +unsafe extern "C" fn release(iter: *mut ArrowArrayStream) { + if iter.is_null() { + return; + } + let _ = Box::from_raw((*iter).private_data as *mut PrivateData); + (*iter).release = None; + // private drops automatically +} + +/// Exports an iterator to the [C stream interface](https://arrow.apache.org/docs/format/CStreamInterface.html) +/// # Safety +/// The pointer `consumer` must be allocated +pub unsafe fn export_iterator( + iter: Box, ArrowError>>>, + field: Field, + consumer: *mut ArrowArrayStream, +) { + let private_data = Box::new(PrivateData { + iter, + field, + error: None, + }); + + *consumer = ArrowArrayStream { + get_schema: Some(get_schema), + get_next: Some(get_next), + get_last_error: Some(get_last_error), + release: Some(release), + private_data: Box::into_raw(private_data) as *mut ::std::os::raw::c_void, + } +} diff --git a/tests/it/ffi.rs b/tests/it/ffi/data.rs similarity index 100% rename from tests/it/ffi.rs rename to tests/it/ffi/data.rs diff --git a/tests/it/ffi/mod.rs b/tests/it/ffi/mod.rs new file mode 100644 index 00000000000..af49381f138 --- /dev/null +++ b/tests/it/ffi/mod.rs @@ -0,0 +1,2 @@ +mod data; +mod stream; diff --git a/tests/it/ffi/stream.rs b/tests/it/ffi/stream.rs new file mode 100644 index 00000000000..97075083143 --- /dev/null +++ b/tests/it/ffi/stream.rs @@ -0,0 +1,35 @@ +use std::collections::BTreeMap; +use std::sync::Arc; + +use arrow2::array::*; +use arrow2::bitmap::Bitmap; +use arrow2::datatypes::{DataType, Field, TimeUnit}; +use arrow2::{error::Result, ffi}; + +fn _test_round_trip(arrays: Vec>) -> Result<()> { + let field = Field::new("a", arrays[0].data_type().clone(), true); + let iter = Box::new(arrays.clone().into_iter().map(Ok)) as _; + + let mut stream = Box::new(ffi::ArrowArrayStream::empty()); + + unsafe { ffi::export_iterator(iter, field.clone(), &mut *stream) } + + let mut stream = unsafe { ffi::ArrowArrayStreamReader::try_new(stream)? }; + + let mut produced_arrays: Vec> = vec![]; + while let Some(array) = unsafe { stream.next() } { + produced_arrays.push(array?.into()); + } + + assert_eq!(produced_arrays, arrays); + assert_eq!(stream.field(), &field); + Ok(()) +} + +#[test] +fn round_trip() -> Result<()> { + let array = Int32Array::from(&[Some(2), None, Some(1), None]); + let array: Arc = Arc::new(array.clone()); + + _test_round_trip(vec![array.clone(), array.clone(), array]) +}