Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
ArrowArrayStreamReader::try_new(): Safeguard against released streams (
Browse files Browse the repository at this point in the history
  • Loading branch information
Qqwy authored Jul 27, 2023
1 parent de20f2d commit d5c78e7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
9 changes: 8 additions & 1 deletion src/ffi/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,21 @@ pub struct ArrowArrayStreamReader<Iter: DerefMut<Target = ArrowArrayStream>> {
impl<Iter: DerefMut<Target = ArrowArrayStream>> ArrowArrayStreamReader<Iter> {
/// Returns a new [`ArrowArrayStreamReader`]
/// # Error
/// Errors iff the [`ArrowArrayStream`] is out of specification
/// Errors iff the [`ArrowArrayStream`] is out of specification,
/// or was already released prior to calling this function.
/// # Safety
/// This method is intrinsically `unsafe` since it assumes that the `ArrowArrayStream`
/// contains a valid Arrow C stream interface.
/// In particular:
/// * The `ArrowArrayStream` fulfills the invariants of the C stream interface
/// * The schema `get_schema` produces fulfills the C data interface
pub unsafe fn try_new(mut iter: Iter) -> Result<Self, Error> {
if iter.release.is_none() {
return Err(Error::InvalidArgumentError(
"The C stream was already released".to_string(),
));
};

if iter.get_next.is_none() {
return Err(Error::OutOfSpec(
"The C stream MUST contain a non-null get_next".to_string(),
Expand Down
13 changes: 12 additions & 1 deletion tests/it/ffi/stream.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use arrow2::array::*;
use arrow2::datatypes::Field;
use arrow2::{error::Result, ffi};
use arrow2::{error::Error, error::Result, ffi};

fn _test_round_trip(arrays: Vec<Box<dyn Array>>) -> Result<()> {
let field = Field::new("a", arrays[0].data_type().clone(), true);
Expand Down Expand Up @@ -30,3 +30,14 @@ fn round_trip() -> Result<()> {

_test_round_trip(vec![array.clone(), array.clone(), array])
}

#[test]
fn stream_reader_try_new_invalid_argument_error_on_released_stream() {
let released_stream = Box::new(ffi::ArrowArrayStream::empty());
let reader = unsafe { ffi::ArrowArrayStreamReader::try_new(released_stream) };
// poor man's assert_matches:
match reader {
Err(Error::InvalidArgumentError(_)) => {}
_ => panic!("ArrowArrayStreamReader::try_new did not return an InvalidArgumentError"),
}
}

0 comments on commit d5c78e7

Please sign in to comment.