
Commit

Allow creating and exporting a RecordBatchReader from a Python iterable
pitrou committed Sep 30, 2020
1 parent 3badb80 · commit 9ade9a9
Showing 6 changed files with 99 additions and 9 deletions.
1 change: 1 addition & 0 deletions cpp/src/arrow/python/CMakeLists.txt
@@ -38,6 +38,7 @@ set(ARROW_PYTHON_SRCS
     inference.cc
     init.cc
     io.cc
+    ipc.cc
     numpy_convert.cc
     numpy_to_arrow.cc
     python_to_arrow.cc
1 change: 1 addition & 0 deletions cpp/src/arrow/python/extension_type.h
@@ -44,6 +44,7 @@ class ARROW_PYTHON_EXPORT PyExtensionType : public ExtensionType {
   std::string Serialize() const override;
 
   // For use from Cython
+  // Assumes that `typ` is borrowed
   static Status FromClass(const std::shared_ptr<DataType> storage_type,
                           const std::string extension_name, PyObject* typ,
                           std::shared_ptr<ExtensionType>* out);
8 changes: 8 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
@@ -1965,6 +1965,14 @@ cdef extern from 'arrow/python/inference.h' namespace 'arrow::py':
     c_bool IsPyFloat(object o)
 
 
+cdef extern from 'arrow/python/ipc.h' namespace 'arrow::py':
+    cdef cppclass CPyRecordBatchReader" arrow::py::PyRecordBatchReader" \
+            (CRecordBatchReader):
+        @staticmethod
+        CResult[shared_ptr[CRecordBatchReader]] Make(shared_ptr[CSchema],
+                                                     object)
+
+
 cdef extern from 'arrow/extension_type.h' namespace 'arrow':
     cdef cppclass CExtensionTypeRegistry" arrow::ExtensionTypeRegistry":
         @staticmethod
29 changes: 29 additions & 0 deletions python/pyarrow/ipc.pxi
@@ -530,6 +530,35 @@ cdef class RecordBatchReader(_Weakrefable):
         self.reader = c_reader
         return self
 
+    @staticmethod
+    def from_batches(schema, batches):
+        """
+        Create RecordBatchReader from an iterable of batches.
+
+        Parameters
+        ----------
+        schema : Schema
+            The shared schema of the record batches
+        batches : Iterable[RecordBatch]
+            The batches that this reader will return.
+
+        Returns
+        -------
+        reader : RecordBatchReader
+        """
+        cdef:
+            shared_ptr[CSchema] c_schema
+            shared_ptr[CRecordBatchReader] c_reader
+            RecordBatchReader self
+
+        c_schema = pyarrow_unwrap_schema(schema)
+        c_reader = GetResultValue(CPyRecordBatchReader.Make(
+            c_schema, batches))
+
+        self = RecordBatchReader.__new__(RecordBatchReader)
+        self.reader = c_reader
+        return self
+
 
 cdef class _RecordBatchStreamReader(RecordBatchReader):
     cdef:
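
For reference, a minimal usage sketch of the new factory (assuming a pyarrow build that includes this change). Any iterable works, including a lazy iterator or generator, and the resulting reader behaves like any other RecordBatchReader:

import pyarrow as pa

schema = pa.schema([('field', pa.int64())])

def gen_batches():
    # Batches are pulled from the iterable as the reader is consumed.
    yield pa.record_batch([[1, 2, 3]], schema=schema)
    yield pa.record_batch([[4, 5]], schema=schema)

reader = pa.ipc.RecordBatchReader.from_batches(schema, gen_batches())
table = reader.read_all()  # drains the iterable into a Table
assert table.num_rows == 5
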
34 changes: 25 additions & 9 deletions python/pyarrow/tests/test_cffi.py
@@ -235,42 +235,58 @@ def test_export_import_batch():
         pa.RecordBatch._import_from_c(ptr_array, ptr_schema)
 
 
-def _export_import_batch_reader(ptr_stream):
+def _export_import_batch_reader(ptr_stream, reader_factory):
     # Prepare input
     batches = make_batches()
     schema = batches[0].schema
 
-    reader = pa.ipc.open_stream(make_serialized(schema, batches))
+    reader = reader_factory(schema, batches)
     reader._export_to_c(ptr_stream)
     # Delete and recreate C++ object from exported pointer
-    del reader
+    del reader, batches
 
     reader_new = pa.ipc.RecordBatchReader._import_from_c(ptr_stream)
     assert reader_new.schema == schema
     got_batches = list(reader_new)
     del reader_new
-    assert batches == got_batches
+    assert got_batches == make_batches()
 
     # Test read_pandas()
     if pd is not None:
-        reader = pa.ipc.open_stream(make_serialized(schema, batches))
+        batches = make_batches()
+        schema = batches[0].schema
+        expected_df = pa.Table.from_batches(batches).to_pandas()
+
+        reader = reader_factory(schema, batches)
         reader._export_to_c(ptr_stream)
-        del reader
+        del reader, batches
 
         reader_new = pa.ipc.RecordBatchReader._import_from_c(ptr_stream)
-        expected_df = pa.Table.from_batches(batches).to_pandas()
        got_df = reader_new.read_pandas()
         del reader_new
         tm.assert_frame_equal(expected_df, got_df)
 
 
+def make_ipc_stream_reader(schema, batches):
+    return pa.ipc.open_stream(make_serialized(schema, batches))
+
+
+def make_py_record_batch_reader(schema, batches):
+    return pa.ipc.RecordBatchReader.from_batches(schema, batches)
+
+
 @needs_cffi
-def test_export_import_batch_reader():
+@pytest.mark.parametrize('reader_factory',
+                         [make_ipc_stream_reader,
+                          make_py_record_batch_reader])
+def test_export_import_batch_reader(reader_factory):
     c_stream = ffi.new("struct ArrowArrayStream*")
     ptr_stream = int(ffi.cast("uintptr_t", c_stream))
 
     gc.collect()  # Make sure no Arrow data dangles in a ref cycle
     old_allocated = pa.total_allocated_bytes()
 
-    _export_import_batch_reader(ptr_stream)
+    _export_import_batch_reader(ptr_stream, reader_factory)
 
     assert pa.total_allocated_bytes() == old_allocated
 
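
The parametrized test above drives the same C stream round trip for both reader flavours. A condensed sketch of that round trip, mirroring the test helpers (assuming pyarrow.cffi is available; _export_to_c and _import_from_c are the private hooks the test relies on):

import pyarrow as pa
from pyarrow.cffi import ffi

schema = pa.schema([('field', pa.int64())])
batches = [pa.record_batch([[1, 2, 3]], schema=schema)]
reader = pa.ipc.RecordBatchReader.from_batches(schema, batches)

c_stream = ffi.new("struct ArrowArrayStream*")
ptr_stream = int(ffi.cast("uintptr_t", c_stream))
reader._export_to_c(ptr_stream)  # hand the stream off through the C struct
del reader                       # the exported struct now owns the stream

reader_new = pa.ipc.RecordBatchReader._import_from_c(ptr_stream)
assert reader_new.read_all().num_rows == 3
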
35 changes: 35 additions & 0 deletions python/pyarrow/tests/test_ipc.py
@@ -15,11 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from collections import UserList
 import io
 import pytest
 import socket
 import sys
 import threading
+import weakref
 
 import numpy as np
 
@@ -858,3 +860,36 @@ def test_write_empty_ipc_file():
     table = reader.read_all()
     assert len(table) == 0
     assert table.schema.equals(schema)
+
+
+def test_py_record_batch_reader():
+    def make_schema():
+        return pa.schema([('field', pa.int64())])
+
+    def make_batches():
+        schema = make_schema()
+        batch1 = pa.record_batch([[1, 2, 3]], schema=schema)
+        batch2 = pa.record_batch([[4, 5]], schema=schema)
+        return [batch1, batch2]
+
+    # With iterable
+    batches = UserList(make_batches())  # weakrefable
+    wr = weakref.ref(batches)
+
+    with pa.ipc.RecordBatchReader.from_batches(make_schema(),
+                                               batches) as reader:
+        batches = None
+        assert wr() is not None
+        assert list(reader) == make_batches()
+    assert wr() is None
+
+    # With iterator
+    batches = iter(UserList(make_batches()))  # weakrefable
+    wr = weakref.ref(batches)
+
+    with pa.ipc.RecordBatchReader.from_batches(make_schema(),
+                                               batches) as reader:
+        batches = None
+        assert wr() is not None
+        assert list(reader) == make_batches()
+    assert wr() is None
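
The weakref assertions above pin down the intended lifetime behaviour: the reader keeps the source iterable (or iterator) alive only while it is open, and releases it once closed. A small sketch of the same property outside the test harness (assuming this build of pyarrow):

import weakref
from collections import UserList

import pyarrow as pa

schema = pa.schema([('field', pa.int64())])
batches = UserList([pa.record_batch([[1, 2, 3]], schema=schema)])  # weakrefable wrapper
wr = weakref.ref(batches)

with pa.ipc.RecordBatchReader.from_batches(schema, batches) as reader:
    batches = None              # the reader holds the only remaining reference
    assert wr() is not None
    assert reader.read_all().num_rows == 3
assert wr() is None             # closing the reader released the iterable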
