Add Python wrapper and tests
pitrou committed Aug 31, 2020
Parent: 593d1aa · Commit: 28cc39c
Showing 9 changed files with 208 additions and 62 deletions.
15 changes: 11 additions & 4 deletions cpp/src/arrow/c/bridge.cc
@@ -1618,9 +1618,7 @@ namespace {
 
 class ArrayStreamBatchReader : public RecordBatchReader {
  public:
-  explicit ArrayStreamBatchReader(struct ArrowArrayStream* stream) : stream_(stream) {
-    DCHECK(!ArrowArrayStreamIsReleased(stream_));
-  }
+  explicit ArrayStreamBatchReader(struct ArrowArrayStream* stream) : stream_(stream) {}
 
   ~ArrayStreamBatchReader() { ArrowArrayStreamRelease(stream_); }
 
@@ -1629,7 +1627,13 @@ class ArrayStreamBatchReader : public RecordBatchReader {
   Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
     struct ArrowArray c_array;
     RETURN_NOT_OK(StatusFromCError(stream_->get_next(stream_, &c_array)));
-    return ImportRecordBatch(&c_array, CacheSchema()).Value(batch);
+    if (ArrowArrayIsReleased(&c_array)) {
+      // End of stream
+      batch->reset();
+      return Status::OK();
+    } else {
+      return ImportRecordBatch(&c_array, CacheSchema()).Value(batch);
+    }
   }
 
  private:
@@ -1667,6 +1671,9 @@ class ArrayStreamBatchReader : public RecordBatchReader {
 
 Result<std::shared_ptr<RecordBatchReader>> ImportRecordBatchReader(
     struct ArrowArrayStream* stream) {
+  if (ArrowArrayStreamIsReleased(stream)) {
+    return Status::Invalid("Cannot import released ArrowArrayStream");
+  }
   // XXX should we call get_schema() here to avoid crashing on error?
   return std::make_shared<ArrayStreamBatchReader>(stream);
 }
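With this change, end of stream is signaled by get_next() returning a released ArrowArray, which ReadNext() maps to a null batch; the Python wrapper below then surfaces the null batch as StopIteration. A minimal consumer sketch in Python (the drain helper is illustrative, not part of this commit):

def drain(reader):
    """Collect batches from a pyarrow RecordBatchReader until end of stream.

    Once the producer's get_next() hands back a released ArrowArray,
    ReadNext() produces a null batch and read_next_batch() raises
    StopIteration.
    """
    batches = []
    while True:
        try:
            batches.append(reader.read_next_batch())
        except StopIteration:
            return batches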
4 changes: 2 additions & 2 deletions python/pyarrow/_csv.pyx
@@ -28,7 +28,7 @@ from collections.abc import Mapping
 from pyarrow.includes.common cimport *
 from pyarrow.includes.libarrow cimport *
 from pyarrow.lib cimport (check_status, Field, MemoryPool, Schema,
-                          _CRecordBatchReader, ensure_type,
+                          RecordBatchReader, ensure_type,
                           maybe_unbox_memory_pool, get_input_stream,
                           native_transcoding_input_stream,
                           pyarrow_wrap_schema, pyarrow_wrap_table,
@@ -633,7 +633,7 @@ cdef _get_convert_options(ConvertOptions convert_options,
     out[0] = convert_options.options
 
 
-cdef class CSVStreamingReader(_CRecordBatchReader):
+cdef class CSVStreamingReader(RecordBatchReader):
     """An object that reads record batches incrementally from a CSV file.
 
     Should not be instantiated directly by user code.
15 changes: 7 additions & 8 deletions python/pyarrow/_flight.pyx
Expand Up @@ -35,7 +35,7 @@ from pyarrow.lib cimport *
from pyarrow.lib import ArrowException, ArrowInvalid
from pyarrow.lib import as_buffer, frombytes, tobytes
from pyarrow.includes.libarrow_flight cimport *
from pyarrow.ipc import _ReadPandasOption, _get_legacy_format_default
from pyarrow.ipc import _get_legacy_format_default, _ReadPandasMixin
import pyarrow.lib as lib


@@ -812,7 +812,7 @@ cdef class FlightStreamChunk(_Weakrefable):
         self.chunk.data != NULL, self.chunk.app_metadata != NULL)
 
 
-cdef class _MetadataRecordBatchReader(_Weakrefable):
+cdef class _MetadataRecordBatchReader(_Weakrefable, _ReadPandasMixin):
     """A reader for Flight streams."""
 
     # Needs to be separate class so the "real" class can subclass the
@@ -869,8 +869,7 @@ cdef class _MetadataRecordBatchReader(_Weakrefable):
         return chunk
 
 
-cdef class MetadataRecordBatchReader(_MetadataRecordBatchReader,
-                                     _ReadPandasOption):
+cdef class MetadataRecordBatchReader(_MetadataRecordBatchReader):
     """The virtual base class for readers for Flight streams."""


@@ -1365,7 +1364,7 @@ cdef class RecordBatchStream(FlightDataStream):
         data_source : RecordBatchReader or Table
         options : pyarrow.ipc.IpcWriteOptions, optional
         """
-        if (not isinstance(data_source, _CRecordBatchReader) and
+        if (not isinstance(data_source, RecordBatchReader) and
                 not isinstance(data_source, lib.Table)):
             raise TypeError("Expected RecordBatchReader or Table, "
                             "but got: {}".format(type(data_source)))
@@ -1375,8 +1374,8 @@
     cdef CFlightDataStream* to_stream(self) except *:
         cdef:
             shared_ptr[CRecordBatchReader] reader
-        if isinstance(self.data_source, _CRecordBatchReader):
-            reader = (<_CRecordBatchReader> self.data_source).reader
+        if isinstance(self.data_source, RecordBatchReader):
+            reader = (<RecordBatchReader> self.data_source).reader
         elif isinstance(self.data_source, lib.Table):
             table = (<Table> self.data_source).table
             reader.reset(new TableBatchReader(deref(table)))
Expand Down Expand Up @@ -1616,7 +1615,7 @@ cdef CStatus _data_stream_next(void* self, CFlightPayload* payload) except *:
else:
result, metadata = result, None

if isinstance(result, (Table, _CRecordBatchReader)):
if isinstance(result, (Table, RecordBatchReader)):
if metadata:
raise ValueError("Can only return metadata alongside a "
"RecordBatch.")
12 changes: 12 additions & 0 deletions python/pyarrow/cffi.py
@@ -52,6 +52,18 @@
     // Opaque producer-specific data
     void* private_data;
   };
+
+  struct ArrowArrayStream {
+    int (*get_schema)(struct ArrowArrayStream*, struct ArrowSchema* out);
+    int (*get_next)(struct ArrowArrayStream*, struct ArrowArray* out);
+    const char* (*get_last_error)(struct ArrowArrayStream*);
+
+    // Release callback
+    void (*release)(struct ArrowArrayStream*);
+
+    // Opaque producer-specific data
+    void* private_data;
+  };
 """
 
 # TODO use out-of-line mode for faster import and avoid C parsing
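With ArrowArrayStream added to the cdef source, a cffi user can allocate a stream struct and pass its address around as a plain int, the form expected by the new export/import methods in ipc.pxi below. A small sketch:

from pyarrow.cffi import ffi

# Allocate an empty ArrowArrayStream struct; its address, as an int,
# is what RecordBatchReader._export_to_c/_import_from_c take.
c_stream = ffi.new("struct ArrowArrayStream*")
stream_ptr = int(ffi.cast("uintptr_t", c_stream))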
10 changes: 9 additions & 1 deletion python/pyarrow/includes/libarrow.pxd
@@ -1390,7 +1390,7 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
             " arrow::ipc::RecordBatchStreamReader"(CRecordBatchReader):
         @staticmethod
         CResult[shared_ptr[CRecordBatchReader]] Open(
-            const CInputStream* stream, const CIpcReadOptions& options)
+            const shared_ptr[CInputStream], const CIpcReadOptions&)
 
         @staticmethod
         CResult[shared_ptr[CRecordBatchReader]] Open2" Open"(
@@ -2049,6 +2049,9 @@ cdef extern from 'arrow/c/abi.h':
     cdef struct ArrowArray:
         pass
 
+    cdef struct ArrowArrayStream:
+        pass
+
 cdef extern from 'arrow/c/bridge.h' namespace 'arrow' nogil:
     CStatus ExportType(CDataType&, ArrowSchema* out)
     CResult[shared_ptr[CDataType]] ImportType(ArrowSchema*)
@@ -2069,3 +2072,8 @@ cdef extern from 'arrow/c/bridge.h' namespace 'arrow' nogil:
                                                  shared_ptr[CSchema])
     CResult[shared_ptr[CRecordBatch]] ImportRecordBatch(ArrowArray*,
                                                         ArrowSchema*)
+
+    CStatus ExportRecordBatchReader(shared_ptr[CRecordBatchReader],
+                                    ArrowArrayStream*)
+    CResult[shared_ptr[CRecordBatchReader]] ImportRecordBatchReader(
+        ArrowArrayStream*)
92 changes: 83 additions & 9 deletions python/pyarrow/ipc.pxi
@@ -398,8 +398,29 @@ cdef _get_input_stream(object source, shared_ptr[CInputStream]* out):
     get_input_stream(source, True, out)
 
 
-cdef class _CRecordBatchReader(_Weakrefable):
-    """The base RecordBatchReader wrapper.
+class _ReadPandasMixin:
+
+    def read_pandas(self, **options):
+        """
+        Read contents of stream to a pandas.DataFrame.
+
+        Read all record batches as a pyarrow.Table then convert it to a
+        pandas.DataFrame using Table.to_pandas.
+
+        Parameters
+        ----------
+        **options : arguments to forward to Table.to_pandas
+
+        Returns
+        -------
+        df : pandas.DataFrame
+        """
+        table = self.read_all()
+        return table.to_pandas(**options)
+
+
+cdef class RecordBatchReader(_Weakrefable):
+    """Base class for reading stream of record batches.
 
     Provides common implementations of convenience methods. Should not
     be instantiated directly by user code.
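The new _ReadPandasMixin replaces ipc.py's pure-Python _ReadPandasOption (removed further down) so that cdef reader classes can reuse read_pandas either as a base class or as an unbound method. A usage sketch, assuming pandas is installed:

import pyarrow as pa
import pyarrow.ipc

# Write one batch to an in-memory IPC stream, read it back, and go
# straight to pandas; keyword options are forwarded to Table.to_pandas.
batch = pa.record_batch([pa.array([1.0, 2.0, 3.0])], names=["x"])
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, batch.schema) as writer:
    writer.write_batch(batch)
df = pa.ipc.open_stream(sink.getvalue()).read_pandas(use_threads=False)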
@@ -411,6 +432,18 @@ cdef class _CRecordBatchReader(_Weakrefable):
         while True:
             yield self.read_next_batch()
 
+    @property
+    def schema(self):
+        """
+        Shared schema of the record batches in the stream.
+        """
+        cdef shared_ptr[CSchema] c_schema
+
+        with nogil:
+            c_schema = self.reader.get().schema()
+
+        return pyarrow_wrap_schema(c_schema)
+
     def get_next_batch(self):
         import warnings
         warnings.warn('Please use read_next_batch instead of '
@@ -445,31 +478,70 @@
             check_status(self.reader.get().ReadAll(&table))
         return pyarrow_wrap_table(table)
 
+    read_pandas = _ReadPandasMixin.read_pandas
+
     def __enter__(self):
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         pass
 
+    def _export_to_c(self, uintptr_t out_ptr):
+        """
+        Export to a C ArrowArrayStream struct, given its pointer.
+
+        Parameters
+        ----------
+        out_ptr: int
+            The raw pointer to a C ArrowArrayStream struct.
+
+        Be careful: if you don't pass the ArrowArrayStream struct to a
+        consumer, array memory will leak. This is a low-level function
+        intended for expert users.
+        """
+        with nogil:
+            check_status(ExportRecordBatchReader(
+                self.reader, <ArrowArrayStream*> out_ptr))
+
+    @staticmethod
+    def _import_from_c(uintptr_t in_ptr):
+        """
+        Import RecordBatchReader from a C ArrowArrayStream struct,
+        given its pointer.
+
+        Parameters
+        ----------
+        in_ptr: int
+            The raw pointer to a C ArrowArrayStream struct.
+
+        This is a low-level function intended for expert users.
+        """
+        cdef:
+            shared_ptr[CRecordBatchReader] c_reader
+            RecordBatchReader self
+
+        with nogil:
+            c_reader = GetResultValue(ImportRecordBatchReader(
+                <ArrowArrayStream*> in_ptr))
+
+        self = RecordBatchReader.__new__(RecordBatchReader)
+        self.reader = c_reader
+        return self
 
 
-cdef class _RecordBatchStreamReader(_CRecordBatchReader):
+cdef class _RecordBatchStreamReader(RecordBatchReader):
     cdef:
         shared_ptr[CInputStream] in_stream
         CIpcReadOptions options
 
-    cdef readonly:
-        Schema schema
-
     def __cinit__(self):
        pass
 
     def _open(self, source):
         _get_input_stream(source, &self.in_stream)
         with nogil:
             self.reader = GetResultValue(CRecordBatchStreamReader.Open(
-                self.in_stream.get(), self.options))
-
-        self.schema = pyarrow_wrap_schema(self.reader.get().schema())
+                self.in_stream, self.options))
 
 
 cdef class _RecordBatchFileWriter(_RecordBatchStreamWriter):
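Putting the new pieces together: a roundtrip sketch using the private, expert-level _export_to_c/_import_from_c methods added above, with the struct allocated via pyarrow.cffi. Ownership of the stream moves on export and again on import, so the original reader must not be used afterwards:

import pyarrow as pa
import pyarrow.ipc
from pyarrow.cffi import ffi

# Build a small RecordBatchReader from an in-memory IPC stream.
batch = pa.record_batch([pa.array([1, 2, 3])], names=["x"])
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, batch.schema) as writer:
    writer.write_batch(batch)
reader = pa.ipc.open_stream(sink.getvalue())

# Export into a C ArrowArrayStream, then import it back.
c_stream = ffi.new("struct ArrowArrayStream*")
ptr = int(ffi.cast("uintptr_t", c_stream))
reader._export_to_c(ptr)

imported = pa.ipc.RecordBatchReader._import_from_c(ptr)
assert imported.schema.equals(batch.schema)
assert imported.read_all() == pa.Table.from_batches([batch])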
@@ -560,6 +632,8 @@ cdef class _RecordBatchFileReader(_Weakrefable):
 
         return pyarrow_wrap_table(table)
 
+    read_pandas = _ReadPandasMixin.read_pandas
+
     def __enter__(self):
         return self
 
26 changes: 3 additions & 23 deletions python/pyarrow/ipc.py
@@ -22,35 +22,15 @@
 import pyarrow as pa
 
 from pyarrow.lib import (IpcWriteOptions, Message, MessageReader,  # noqa
+                         RecordBatchReader, _ReadPandasMixin,
                          MetadataVersion,
                          read_message, read_record_batch, read_schema,
                          read_tensor, write_tensor,
                          get_record_batch_size, get_tensor_size)
 import pyarrow.lib as lib
 
 
-class _ReadPandasOption:
-
-    def read_pandas(self, **options):
-        """
-        Read contents of stream to a pandas.DataFrame.
-
-        Read all record batches as a pyarrow.Table then convert it to a
-        pandas.DataFrame using Table.to_pandas.
-
-        Parameters
-        ----------
-        **options : arguments to forward to Table.to_pandas
-
-        Returns
-        -------
-        df : pandas.DataFrame
-        """
-        table = self.read_all()
-        return table.to_pandas(**options)
-
-
-class RecordBatchStreamReader(lib._RecordBatchStreamReader, _ReadPandasOption):
+class RecordBatchStreamReader(lib._RecordBatchStreamReader):
     """
     Reader for the Arrow streaming binary format.
 
@@ -97,7 +77,7 @@ def __init__(self, sink, schema, *, use_legacy_format=None, options=None):
         self._open(sink, schema, options=options)
 
 
-class RecordBatchFileReader(lib._RecordBatchFileReader, _ReadPandasOption):
+class RecordBatchFileReader(lib._RecordBatchFileReader):
     """
     Class for reading Arrow record batch data from the Arrow binary file format
 
2 changes: 1 addition & 1 deletion python/pyarrow/lib.pxd
@@ -483,7 +483,7 @@ cdef class _CRecordBatchWriter(_Weakrefable):
         shared_ptr[CRecordBatchWriter] writer
 
 
-cdef class _CRecordBatchReader(_Weakrefable):
+cdef class RecordBatchReader(_Weakrefable):
     cdef:
         shared_ptr[CRecordBatchReader] reader
 