# Patch summary (leading text before the first "diff --git" header is not part of any hunk).
#
# Purpose: expose a RecordBatchReader that is backed by a Python iterable of record batches.
#
# Files touched by the hunks below:
#   * cpp/src/arrow/python/CMakeLists.txt
#       - adds ipc.cc to ARROW_PYTHON_SRCS.
#   * cpp/src/arrow/python/extension_type.h
#       - adds the comment "Assumes that `typ` is borrowed" above PyExtensionType::FromClass.
#   * python/pyarrow/includes/libarrow.pxd
#       - declares CPyRecordBatchReader (C++ name arrow::py::PyRecordBatchReader) from
#         'arrow/python/ipc.h', with a static Make(shared_ptr[CSchema], object) that returns
#         CResult[shared_ptr[CRecordBatchReader]].
#   * python/pyarrow/ipc.pxi
#       - adds the static method RecordBatchReader.from_batches(schema, batches): it unwraps the
#         schema, calls CPyRecordBatchReader.Make, and stores the result in a new RecordBatchReader.
#   * python/pyarrow/tests/test_cffi.py
#       - _export_import_batch_reader now takes a reader_factory; the test is parametrized over
#         make_ipc_stream_reader and make_py_record_batch_reader; the input batches are deleted
#         before re-import and the expected values are rebuilt with make_batches().
#   * python/pyarrow/tests/test_ipc.py
#       - imports UserList and weakref and adds test_py_record_batch_reader, which checks (for an
#         iterable and for an iterator) that the input stays alive inside the `with` block and is
#         released afterwards (weakref becomes None).
#
# NOTE(review): ipc.cc / ipc.h are referenced by these hunks but their contents are not visible
#   in this excerpt - confirm they are part of the full change.
# NOTE(review): the line breaks and indentation of this patch appear to have been collapsed, and
#   template arguments look stripped in the extension_type.h context (e.g. "std::shared_ptr
#   storage_type"); it presumably will not apply as-is - regenerate from the original commit.
diff --git a/cpp/src/arrow/python/CMakeLists.txt b/cpp/src/arrow/python/CMakeLists.txt index b972b0d860618..960155703e1ed 100644 --- a/cpp/src/arrow/python/CMakeLists.txt +++ b/cpp/src/arrow/python/CMakeLists.txt @@ -38,6 +38,7 @@ set(ARROW_PYTHON_SRCS inference.cc init.cc io.cc + ipc.cc numpy_convert.cc numpy_to_arrow.cc python_to_arrow.cc diff --git a/cpp/src/arrow/python/extension_type.h b/cpp/src/arrow/python/extension_type.h index 0041c8af6a45e..f5b12e6d8430a 100644 --- a/cpp/src/arrow/python/extension_type.h +++ b/cpp/src/arrow/python/extension_type.h @@ -44,6 +44,7 @@ class ARROW_PYTHON_EXPORT PyExtensionType : public ExtensionType { std::string Serialize() const override; // For use from Cython + // Assumes that `typ` is borrowed static Status FromClass(const std::shared_ptr storage_type, const std::string extension_name, PyObject* typ, std::shared_ptr* out); diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 2049c52a72bf5..5d5800eec58f7 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1965,6 +1965,14 @@ cdef extern from 'arrow/python/inference.h' namespace 'arrow::py': c_bool IsPyFloat(object o) +cdef extern from 'arrow/python/ipc.h' namespace 'arrow::py': + cdef cppclass CPyRecordBatchReader" arrow::py::PyRecordBatchReader" \ + (CRecordBatchReader): + @staticmethod + CResult[shared_ptr[CRecordBatchReader]] Make(shared_ptr[CSchema], + object) + + cdef extern from 'arrow/extension_type.h' namespace 'arrow': cdef cppclass CExtensionTypeRegistry" arrow::ExtensionTypeRegistry": @staticmethod diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi index ae991b0940f99..74a81c60ecff7 100644 --- a/python/pyarrow/ipc.pxi +++ b/python/pyarrow/ipc.pxi @@ -530,6 +530,35 @@ cdef class RecordBatchReader(_Weakrefable): self.reader = c_reader return self + @staticmethod + def from_batches(schema, batches): + """ + Create RecordBatchReader from an iterable of batches. 
+ + Parameters + ---------- + schema : Schema + The shared schema of the record batches + batches : Iterable[RecordBatch] + The batches that this reader will return. + + Returns + ------- + reader : RecordBatchReader + """ + cdef: + shared_ptr[CSchema] c_schema + shared_ptr[CRecordBatchReader] c_reader + RecordBatchReader self + + c_schema = pyarrow_unwrap_schema(schema) + c_reader = GetResultValue(CPyRecordBatchReader.Make( + c_schema, batches)) + + self = RecordBatchReader.__new__(RecordBatchReader) + self.reader = c_reader + return self + cdef class _RecordBatchStreamReader(RecordBatchReader): cdef: diff --git a/python/pyarrow/tests/test_cffi.py b/python/pyarrow/tests/test_cffi.py index f47fd5d5b316e..5505e57164534 100644 --- a/python/pyarrow/tests/test_cffi.py +++ b/python/pyarrow/tests/test_cffi.py @@ -235,42 +235,58 @@ def test_export_import_batch(): pa.RecordBatch._import_from_c(ptr_array, ptr_schema) -def _export_import_batch_reader(ptr_stream): +def _export_import_batch_reader(ptr_stream, reader_factory): # Prepare input batches = make_batches() schema = batches[0].schema - reader = pa.ipc.open_stream(make_serialized(schema, batches)) + reader = reader_factory(schema, batches) reader._export_to_c(ptr_stream) # Delete and recreate C++ object from exported pointer - del reader + del reader, batches + reader_new = pa.ipc.RecordBatchReader._import_from_c(ptr_stream) assert reader_new.schema == schema got_batches = list(reader_new) del reader_new - assert batches == got_batches + assert got_batches == make_batches() # Test read_pandas() if pd is not None: - reader = pa.ipc.open_stream(make_serialized(schema, batches)) + batches = make_batches() + schema = batches[0].schema + expected_df = pa.Table.from_batches(batches).to_pandas() + + reader = reader_factory(schema, batches) reader._export_to_c(ptr_stream) - del reader + del reader, batches + reader_new = pa.ipc.RecordBatchReader._import_from_c(ptr_stream) - expected_df = 
pa.Table.from_batches(batches).to_pandas() got_df = reader_new.read_pandas() del reader_new tm.assert_frame_equal(expected_df, got_df) +def make_ipc_stream_reader(schema, batches): + return pa.ipc.open_stream(make_serialized(schema, batches)) + + +def make_py_record_batch_reader(schema, batches): + return pa.ipc.RecordBatchReader.from_batches(schema, batches) + + @needs_cffi -def test_export_import_batch_reader(): +@pytest.mark.parametrize('reader_factory', + [make_ipc_stream_reader, + make_py_record_batch_reader]) +def test_export_import_batch_reader(reader_factory): c_stream = ffi.new("struct ArrowArrayStream*") ptr_stream = int(ffi.cast("uintptr_t", c_stream)) gc.collect() # Make sure no Arrow data dangles in a ref cycle old_allocated = pa.total_allocated_bytes() - _export_import_batch_reader(ptr_stream) + _export_import_batch_reader(ptr_stream, reader_factory) assert pa.total_allocated_bytes() == old_allocated diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py index 44f8499e8346e..3d3e72e616533 100644 --- a/python/pyarrow/tests/test_ipc.py +++ b/python/pyarrow/tests/test_ipc.py @@ -15,11 +15,13 @@ # specific language governing permissions and limitations # under the License. 
+from collections import UserList import io import pytest import socket import sys import threading +import weakref import numpy as np @@ -858,3 +860,36 @@ def test_write_empty_ipc_file(): table = reader.read_all() assert len(table) == 0 assert table.schema.equals(schema) + + +def test_py_record_batch_reader(): + def make_schema(): + return pa.schema([('field', pa.int64())]) + + def make_batches(): + schema = make_schema() + batch1 = pa.record_batch([[1, 2, 3]], schema=schema) + batch2 = pa.record_batch([[4, 5]], schema=schema) + return [batch1, batch2] + + # With iterable + batches = UserList(make_batches()) # weakrefable + wr = weakref.ref(batches) + + with pa.ipc.RecordBatchReader.from_batches(make_schema(), + batches) as reader: + batches = None + assert wr() is not None + assert list(reader) == make_batches() + assert wr() is None + + # With iterator + batches = iter(UserList(make_batches())) # weakrefable + wr = weakref.ref(batches) + + with pa.ipc.RecordBatchReader.from_batches(make_schema(), + batches) as reader: + batches = None + assert wr() is not None + assert list(reader) == make_batches() + assert wr() is None