diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc
index 7dec039545253..2a22d2d6fac35 100644
--- a/cpp/src/arrow/c/bridge.cc
+++ b/cpp/src/arrow/c/bridge.cc
@@ -1618,9 +1618,7 @@ namespace {
 
 class ArrayStreamBatchReader : public RecordBatchReader {
  public:
-  explicit ArrayStreamBatchReader(struct ArrowArrayStream* stream) : stream_(stream) {
-    DCHECK(!ArrowArrayStreamIsReleased(stream_));
-  }
+  explicit ArrayStreamBatchReader(struct ArrowArrayStream* stream) : stream_(stream) {}
 
   ~ArrayStreamBatchReader() { ArrowArrayStreamRelease(stream_); }
 
@@ -1629,7 +1627,13 @@ class ArrayStreamBatchReader : public RecordBatchReader {
   Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
     struct ArrowArray c_array;
     RETURN_NOT_OK(StatusFromCError(stream_->get_next(stream_, &c_array)));
-    return ImportRecordBatch(&c_array, CacheSchema()).Value(batch);
+    if (ArrowArrayIsReleased(&c_array)) {
+      // End of stream
+      batch->reset();
+      return Status::OK();
+    } else {
+      return ImportRecordBatch(&c_array, CacheSchema()).Value(batch);
+    }
   }
 
  private:
@@ -1667,6 +1671,9 @@ class ArrayStreamBatchReader : public RecordBatchReader {
 
 Result<std::shared_ptr<RecordBatchReader>> ImportRecordBatchReader(
     struct ArrowArrayStream* stream) {
+  if (ArrowArrayStreamIsReleased(stream)) {
+    return Status::Invalid("Cannot import released ArrowArrayStream");
+  }
   // XXX should we call get_schema() here to avoid crashing on error?
   return std::make_shared<ArrayStreamBatchReader>(stream);
 }
diff --git a/python/pyarrow/_csv.pyx b/python/pyarrow/_csv.pyx
index 028ddc6f43c7a..34c6693c51e82 100644
--- a/python/pyarrow/_csv.pyx
+++ b/python/pyarrow/_csv.pyx
@@ -28,7 +28,7 @@ from collections.abc import Mapping
 from pyarrow.includes.common cimport *
 from pyarrow.includes.libarrow cimport *
 from pyarrow.lib cimport (check_status, Field, MemoryPool, Schema,
-                          _CRecordBatchReader, ensure_type,
+                          RecordBatchReader, ensure_type,
                           maybe_unbox_memory_pool, get_input_stream,
                           native_transcoding_input_stream,
                           pyarrow_wrap_schema, pyarrow_wrap_table,
@@ -633,7 +633,7 @@ cdef _get_convert_options(ConvertOptions convert_options,
     out[0] = convert_options.options
 
 
-cdef class CSVStreamingReader(_CRecordBatchReader):
+cdef class CSVStreamingReader(RecordBatchReader):
    """An object that reads record batches incrementally from a CSV file.
 
    Should not be instantiated directly by user code.
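For context (not part of the patch): a minimal sketch of what the end-of-stream convention in the `bridge.cc` hunk above means for a consumer driving the C interface by hand. `get_next()` keeps returning 0 but hands back a *released* `ArrowArray` once the stream is exhausted. It assumes a pyarrow build containing this change plus the `cffi` package, and relies on the `struct ArrowArrayStream` definition added to `pyarrow.cffi` below; the stream contents are arbitrary.

```python
import pyarrow as pa
from pyarrow.cffi import ffi

c_stream = ffi.new("struct ArrowArrayStream*")
ptr_stream = int(ffi.cast("uintptr_t", c_stream))

# Arbitrary stream contents for the example.
sink = pa.BufferOutputStream()
schema = pa.schema([("x", pa.int64())])
with pa.ipc.new_stream(sink, schema) as writer:
    writer.write(pa.record_batch([[1, 2, 3]], schema))
pa.ipc.open_stream(sink.getvalue())._export_to_c(ptr_stream)

c_schema = ffi.new("struct ArrowSchema*")
assert c_stream.get_schema(c_stream, c_schema) == 0
c_schema.release(c_schema)

c_array = ffi.new("struct ArrowArray*")
while True:
    assert c_stream.get_next(c_stream, c_array) == 0
    if c_array.release == ffi.NULL:
        break  # released ArrowArray => end of stream
    c_array.release(c_array)  # discard the batch without importing it
c_stream.release(c_stream)
```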
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index b7459d21da155..53771043374bf 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -35,7 +35,7 @@ from pyarrow.lib cimport *
 from pyarrow.lib import ArrowException, ArrowInvalid
 from pyarrow.lib import as_buffer, frombytes, tobytes
 from pyarrow.includes.libarrow_flight cimport *
-from pyarrow.ipc import _ReadPandasOption, _get_legacy_format_default
+from pyarrow.ipc import _get_legacy_format_default, _ReadPandasMixin
 import pyarrow.lib as lib
@@ -812,7 +812,7 @@ cdef class FlightStreamChunk(_Weakrefable):
             self.chunk.data != NULL, self.chunk.app_metadata != NULL)
 
 
-cdef class _MetadataRecordBatchReader(_Weakrefable):
+cdef class _MetadataRecordBatchReader(_Weakrefable, _ReadPandasMixin):
    """A reader for Flight streams."""
 
     # Needs to be separate class so the "real" class can subclass the
@@ -869,8 +869,7 @@ cdef class _MetadataRecordBatchReader(_Weakrefable):
             return chunk
 
 
-cdef class MetadataRecordBatchReader(_MetadataRecordBatchReader,
-                                     _ReadPandasOption):
+cdef class MetadataRecordBatchReader(_MetadataRecordBatchReader):
    """The virtual base class for readers for Flight streams."""
 
 
@@ -1365,7 +1364,7 @@ cdef class RecordBatchStream(FlightDataStream):
         data_source : RecordBatchReader or Table
         options : pyarrow.ipc.IpcWriteOptions, optional
         """
-        if (not isinstance(data_source, _CRecordBatchReader) and
+        if (not isinstance(data_source, RecordBatchReader) and
                 not isinstance(data_source, lib.Table)):
             raise TypeError("Expected RecordBatchReader or Table, "
                             "but got: {}".format(type(data_source)))
@@ -1375,8 +1374,8 @@ cdef class RecordBatchStream(FlightDataStream):
     cdef CFlightDataStream* to_stream(self) except *:
         cdef:
             shared_ptr[CRecordBatchReader] reader
-        if isinstance(self.data_source, _CRecordBatchReader):
-            reader = (<_CRecordBatchReader> self.data_source).reader
+        if isinstance(self.data_source, RecordBatchReader):
+            reader = (<RecordBatchReader> self.data_source).reader
         elif isinstance(self.data_source, lib.Table):
             table = (<Table> self.data_source).table
             reader.reset(new TableBatchReader(deref(table)))
@@ -1616,7 +1615,7 @@ cdef CStatus _data_stream_next(void* self, CFlightPayload* payload) except *:
     else:
         result, metadata = result, None
 
-    if isinstance(result, (Table, _CRecordBatchReader)):
+    if isinstance(result, (Table, RecordBatchReader)):
         if metadata:
             raise ValueError("Can only return metadata alongside a "
                              "RecordBatch.")
diff --git a/python/pyarrow/cffi.py b/python/pyarrow/cffi.py
index 8880c25a0357d..961b61dee59fd 100644
--- a/python/pyarrow/cffi.py
+++ b/python/pyarrow/cffi.py
@@ -52,6 +52,18 @@
         // Opaque producer-specific data
         void* private_data;
     };
+
+    struct ArrowArrayStream {
+        int (*get_schema)(struct ArrowArrayStream*, struct ArrowSchema* out);
+        int (*get_next)(struct ArrowArrayStream*, struct ArrowArray* out);
+
+        const char* (*get_last_error)(struct ArrowArrayStream*);
+
+        // Release callback
+        void (*release)(struct ArrowArrayStream*);
+        // Opaque producer-specific data
+        void* private_data;
+    };
     """

 # TODO use out-of-line mode for faster import and avoid C parsing
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 837bb05aadea7..57f38ce81c1c6 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -1390,7 +1390,7 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
             " arrow::ipc::RecordBatchStreamReader"(CRecordBatchReader):
         @staticmethod
         CResult[shared_ptr[CRecordBatchReader]] Open(
-            const CInputStream* stream, const CIpcReadOptions& options)
+            const shared_ptr[CInputStream], const CIpcReadOptions&)
 
         @staticmethod
         CResult[shared_ptr[CRecordBatchReader]] Open2" Open"(
@@ -2049,6 +2049,9 @@ cdef extern from 'arrow/c/abi.h':
     cdef struct ArrowArray:
         pass
 
+    cdef struct ArrowArrayStream:
+        pass
+
 cdef extern from 'arrow/c/bridge.h' namespace 'arrow' nogil:
     CStatus ExportType(CDataType&, ArrowSchema* out)
     CResult[shared_ptr[CDataType]] ImportType(ArrowSchema*)
@@ -2069,3 +2072,8 @@ cdef extern from 'arrow/c/bridge.h' namespace 'arrow' nogil:
                                              shared_ptr[CSchema])
     CResult[shared_ptr[CRecordBatch]] ImportRecordBatch(ArrowArray*,
                                                         ArrowSchema*)
+
+    CStatus ExportRecordBatchReader(shared_ptr[CRecordBatchReader],
+                                    ArrowArrayStream*)
+    CResult[shared_ptr[CRecordBatchReader]] ImportRecordBatchReader(
+        ArrowArrayStream*)
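These `bridge.h` declarations back the `_export_to_c()`/`_import_from_c()` methods added to `RecordBatchReader` in `ipc.pxi` below. A minimal round-trip sketch (not part of the patch; assumes this pyarrow build and the `cffi` package, with arbitrary stream contents):

```python
import pyarrow as pa
from pyarrow.cffi import ffi

c_stream = ffi.new("struct ArrowArrayStream*")
ptr_stream = int(ffi.cast("uintptr_t", c_stream))

sink = pa.BufferOutputStream()
schema = pa.schema([("x", pa.int64())])
with pa.ipc.new_stream(sink, schema) as writer:
    writer.write(pa.record_batch([[1, 2, 3]], schema))

reader = pa.ipc.open_stream(sink.getvalue())
reader._export_to_c(ptr_stream)    # ExportRecordBatchReader under the hood
del reader                         # the C struct now keeps the stream alive

reader_new = pa.ipc.RecordBatchReader._import_from_c(ptr_stream)
assert reader_new.schema == schema
print(reader_new.read_all())
```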
+ """ + cdef: + shared_ptr[CRecordBatchReader] c_reader + RecordBatchReader self + + with nogil: + c_reader = GetResultValue(ImportRecordBatchReader( + in_ptr)) + + self = RecordBatchReader.__new__(RecordBatchReader) + self.reader = c_reader + return self + + +cdef class _RecordBatchStreamReader(RecordBatchReader): cdef: shared_ptr[CInputStream] in_stream CIpcReadOptions options - cdef readonly: - Schema schema - def __cinit__(self): pass @@ -467,9 +541,7 @@ cdef class _RecordBatchStreamReader(_CRecordBatchReader): _get_input_stream(source, &self.in_stream) with nogil: self.reader = GetResultValue(CRecordBatchStreamReader.Open( - self.in_stream.get(), self.options)) - - self.schema = pyarrow_wrap_schema(self.reader.get().schema()) + self.in_stream, self.options)) cdef class _RecordBatchFileWriter(_RecordBatchStreamWriter): @@ -560,6 +632,8 @@ cdef class _RecordBatchFileReader(_Weakrefable): return pyarrow_wrap_table(table) + read_pandas = _ReadPandasMixin.read_pandas + def __enter__(self): return self diff --git a/python/pyarrow/ipc.py b/python/pyarrow/ipc.py index 19e80baa8dfa3..65325c483c4a8 100644 --- a/python/pyarrow/ipc.py +++ b/python/pyarrow/ipc.py @@ -22,6 +22,7 @@ import pyarrow as pa from pyarrow.lib import (IpcWriteOptions, Message, MessageReader, # noqa + RecordBatchReader, _ReadPandasMixin, MetadataVersion, read_message, read_record_batch, read_schema, read_tensor, write_tensor, @@ -29,28 +30,7 @@ import pyarrow.lib as lib -class _ReadPandasOption: - - def read_pandas(self, **options): - """ - Read contents of stream to a pandas.DataFrame. - - Read all record batches as a pyarrow.Table then convert it to a - pandas.DataFrame using Table.to_pandas. - - Parameters - ---------- - **options : arguments to forward to Table.to_pandas - - Returns - ------- - df : pandas.DataFrame - """ - table = self.read_all() - return table.to_pandas(**options) - - -class RecordBatchStreamReader(lib._RecordBatchStreamReader, _ReadPandasOption): +class RecordBatchStreamReader(lib._RecordBatchStreamReader): """ Reader for the Arrow streaming binary format. 
diff --git a/python/pyarrow/ipc.py b/python/pyarrow/ipc.py
index 19e80baa8dfa3..65325c483c4a8 100644
--- a/python/pyarrow/ipc.py
+++ b/python/pyarrow/ipc.py
@@ -22,6 +22,7 @@
 import pyarrow as pa
 
 from pyarrow.lib import (IpcWriteOptions, Message, MessageReader,  # noqa
+                         RecordBatchReader, _ReadPandasMixin,
                          MetadataVersion,
                          read_message, read_record_batch, read_schema,
                          read_tensor, write_tensor,
@@ -29,28 +30,7 @@
 import pyarrow.lib as lib
 
 
-class _ReadPandasOption:
-
-    def read_pandas(self, **options):
-        """
-        Read contents of stream to a pandas.DataFrame.
-
-        Read all record batches as a pyarrow.Table then convert it to a
-        pandas.DataFrame using Table.to_pandas.
-
-        Parameters
-        ----------
-        **options : arguments to forward to Table.to_pandas
-
-        Returns
-        -------
-        df : pandas.DataFrame
-        """
-        table = self.read_all()
-        return table.to_pandas(**options)
-
-
-class RecordBatchStreamReader(lib._RecordBatchStreamReader, _ReadPandasOption):
+class RecordBatchStreamReader(lib._RecordBatchStreamReader):
     """
     Reader for the Arrow streaming binary format.
 
@@ -97,7 +77,7 @@ def __init__(self, sink, schema, *, use_legacy_format=None, options=None):
         self._open(sink, schema, options=options)
 
 
-class RecordBatchFileReader(lib._RecordBatchFileReader, _ReadPandasOption):
+class RecordBatchFileReader(lib._RecordBatchFileReader):
     """
     Class for reading Arrow record batch data from the Arrow binary file
     format
diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd
index 8e06dcd1d9bcb..5b2958a06472d 100644
--- a/python/pyarrow/lib.pxd
+++ b/python/pyarrow/lib.pxd
@@ -483,7 +483,7 @@ cdef class _CRecordBatchWriter(_Weakrefable):
         shared_ptr[CRecordBatchWriter] writer
 
 
-cdef class _CRecordBatchReader(_Weakrefable):
+cdef class RecordBatchReader(_Weakrefable):
     cdef:
         shared_ptr[CRecordBatchReader] reader
diff --git a/python/pyarrow/tests/test_cffi.py b/python/pyarrow/tests/test_cffi.py
index bcf3c723950fc..f47fd5d5b316e 100644
--- a/python/pyarrow/tests/test_cffi.py
+++ b/python/pyarrow/tests/test_cffi.py
@@ -26,6 +26,13 @@
 import pytest
 
 
+try:
+    import pandas as pd
+    import pandas.testing as tm
+except ImportError:
+    pd = tm = None
+
+
 needs_cffi = pytest.mark.skipif(ffi is None,
                                 reason="test needs cffi package installed")
 
@@ -36,6 +43,34 @@
 assert_array_released = pytest.raises(
     ValueError, match="Cannot import released ArrowArray")
 
+assert_stream_released = pytest.raises(
+    ValueError, match="Cannot import released ArrowArrayStream")
+
+
+def make_schema():
+    return pa.schema([('ints', pa.list_(pa.int32()))],
+                     metadata={b'key1': b'value1'})
+
+
+def make_batch():
+    return pa.record_batch([[[1], [2, 42]]], make_schema())
+
+
+def make_batches():
+    schema = make_schema()
+    return [
+        pa.record_batch([[[1], [2, 42]]], schema),
+        pa.record_batch([[None, [], [5, 6]]], schema),
+    ]
+
+
+def make_serialized(schema, batches):
+    with pa.BufferOutputStream() as sink:
+        with pa.ipc.new_stream(sink, schema) as out:
+            for batch in batches:
+                out.write(batch)
+        return sink.getvalue()
+
 
 @needs_cffi
 def test_export_import_type():
@@ -120,10 +155,6 @@ def test_export_import_schema():
     c_schema = ffi.new("struct ArrowSchema*")
     ptr_schema = int(ffi.cast("uintptr_t", c_schema))
 
-    def make_schema():
-        return pa.schema([('ints', pa.list_(pa.int32()))],
-                         metadata={b'key1': b'value1'})
-
     gc.collect()  # Make sure no Arrow data dangles in a ref cycle
     old_allocated = pa.total_allocated_bytes()
 
@@ -156,13 +187,6 @@ def test_export_import_batch():
     c_array = ffi.new("struct ArrowArray*")
     ptr_array = int(ffi.cast("uintptr_t", c_array))
 
-    def make_schema():
-        return pa.schema([('ints', pa.list_(pa.int32()))],
-                         metadata={b'key1': b'value1'})
-
-    def make_batch():
-        return pa.record_batch([[[1], [2, 42]]], make_schema())
-
     gc.collect()  # Make sure no Arrow data dangles in a ref cycle
     old_allocated = pa.total_allocated_bytes()
 
@@ -172,7 +196,7 @@ def make_batch():
     py_value = batch.to_pydict()
     batch._export_to_c(ptr_array)
     assert pa.total_allocated_bytes() > old_allocated
-    # Delete recreate C++ object from exported pointer
+    # Delete and recreate C++ object from exported pointer
     del batch
     batch_new = pa.RecordBatch._import_from_c(ptr_array, schema)
     assert batch_new.to_pydict() == py_value
@@ -192,8 +216,6 @@ def make_batch():
     del batch
     batch_new = pa.RecordBatch._import_from_c(ptr_array, ptr_schema)
     assert batch_new.to_pydict() == py_value
-    print(batch_new.schema)
-    print(make_schema())
     assert batch_new.schema == make_schema()
     assert pa.total_allocated_bytes() > old_allocated
     del batch_new
@@ -211,3 +233,47 @@ def make_batch():
     # Now released
     with assert_schema_released:
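The same mixin also serves `_RecordBatchFileReader`, so `read_pandas()` keeps working on both reader kinds even though the Python-level `_ReadPandasOption` base class is gone. A sketch (not part of the patch; requires pandas, file contents arbitrary):

```python
import pyarrow as pa

sink = pa.BufferOutputStream()
schema = pa.schema([("x", pa.int64())])
with pa.ipc.new_file(sink, schema) as writer:
    writer.write(pa.record_batch([[1, 2, 3]], schema))

df = pa.ipc.open_file(sink.getvalue()).read_pandas()
print(df)
```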
         pa.RecordBatch._import_from_c(ptr_array, ptr_schema)
+
+
+def _export_import_batch_reader(ptr_stream):
+    # Prepare input
+    batches = make_batches()
+    schema = batches[0].schema
+
+    reader = pa.ipc.open_stream(make_serialized(schema, batches))
+    reader._export_to_c(ptr_stream)
+    # Delete and recreate C++ object from exported pointer
+    del reader
+    reader_new = pa.ipc.RecordBatchReader._import_from_c(ptr_stream)
+    assert reader_new.schema == schema
+    got_batches = list(reader_new)
+    del reader_new
+    assert batches == got_batches
+
+    # Test read_pandas()
+    if pd is not None:
+        reader = pa.ipc.open_stream(make_serialized(schema, batches))
+        reader._export_to_c(ptr_stream)
+        del reader
+        reader_new = pa.ipc.RecordBatchReader._import_from_c(ptr_stream)
+        expected_df = pa.Table.from_batches(batches).to_pandas()
+        got_df = reader_new.read_pandas()
+        del reader_new
+        tm.assert_frame_equal(expected_df, got_df)
+
+
+@needs_cffi
+def test_export_import_batch_reader():
+    c_stream = ffi.new("struct ArrowArrayStream*")
+    ptr_stream = int(ffi.cast("uintptr_t", c_stream))
+
+    gc.collect()  # Make sure no Arrow data dangles in a ref cycle
+    old_allocated = pa.total_allocated_bytes()
+
+    _export_import_batch_reader(ptr_stream)
+
+    assert pa.total_allocated_bytes() == old_allocated
+
+    # Now released
+    with assert_stream_released:
+        pa.ipc.RecordBatchReader._import_from_c(ptr_stream)
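A sketch (not part of the patch) of what the final assertion exercises: the imported reader takes ownership of the C struct and releases it when destroyed, so a second import of the same pointer fails cleanly instead of crashing. Assumes this pyarrow build and the `cffi` package; stream contents arbitrary.

```python
import pyarrow as pa
from pyarrow.cffi import ffi

c_stream = ffi.new("struct ArrowArrayStream*")
ptr_stream = int(ffi.cast("uintptr_t", c_stream))

sink = pa.BufferOutputStream()
schema = pa.schema([("x", pa.int64())])
with pa.ipc.new_stream(sink, schema) as writer:
    writer.write(pa.record_batch([[1, 2, 3]], schema))
pa.ipc.open_stream(sink.getvalue())._export_to_c(ptr_stream)

reader = pa.ipc.RecordBatchReader._import_from_c(ptr_stream)
reader.read_all()
del reader   # destructor releases the underlying ArrowArrayStream

try:
    pa.ipc.RecordBatchReader._import_from_c(ptr_stream)
except ValueError as exc:   # pa.ArrowInvalid subclasses ValueError
    print(exc)              # "Cannot import released ArrowArrayStream"
```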