GH-37429: [C++] Add arrow::ipc::StreamDecoder::Reset() #37970

Merged (4 commits) on Nov 2, 2023
37 changes: 37 additions & 0 deletions cpp/src/arrow/ipc/read_write_test.cc
@@ -2161,6 +2161,43 @@ TEST(TestRecordBatchStreamReader, MalformedInput) {
ASSERT_RAISES(Invalid, RecordBatchStreamReader::Open(&garbage_reader));
}

namespace {
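// A CollectListener that restarts the decoder whenever an end-of-stream
// marker is decoded, so the same decoder keeps decoding the next stream
// contained in the same input.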
class EndlessCollectListener : public CollectListener {
public:
EndlessCollectListener() : CollectListener(), decoder_(nullptr) {}

void SetDecoder(StreamDecoder* decoder) { decoder_ = decoder; }

arrow::Status OnEOS() override { return decoder_->Reset(); }

private:
StreamDecoder* decoder_;
};
}  // namespace

TEST(TestStreamDecoder, Reset) {
auto listener = std::make_shared<EndlessCollectListener>();
StreamDecoder decoder(listener);
listener->SetDecoder(&decoder);

std::shared_ptr<RecordBatch> batch;
ASSERT_OK(MakeIntRecordBatch(&batch));
StreamWriterHelper writer_helper;
ASSERT_OK(writer_helper.Init(batch->schema(), IpcWriteOptions::Defaults()));
ASSERT_OK(writer_helper.WriteBatch(batch));
ASSERT_OK(writer_helper.Finish());

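  // Build an input that contains the same complete IPC stream twice,
  // back to back: decoding it exercises Reset() at the first EOS marker.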
ASSERT_OK_AND_ASSIGN(auto all_buffer, ConcatenateBuffers({writer_helper.buffer_,
writer_helper.buffer_}));
// Consume by Buffer
ASSERT_OK(decoder.Consume(all_buffer));
ASSERT_EQ(2, listener->num_record_batches());

// Consume by raw data
ASSERT_OK(decoder.Consume(all_buffer->data(), all_buffer->size()));
ASSERT_EQ(4, listener->num_record_batches());
}

TEST(TestStreamDecoder, NextRequiredSize) {
auto listener = std::make_shared<CollectListener>();
StreamDecoder decoder(listener);
79 changes: 75 additions & 4 deletions cpp/src/arrow/ipc/reader.cc
@@ -909,14 +909,18 @@ class StreamDecoderInternal : public MessageDecoderListener {
return listener_->OnEOS();
}

std::shared_ptr<Listener> listener() const { return listener_; }

Listener* raw_listener() const { return listener_.get(); }

IpcReadOptions options() const { return options_; }

State state() const { return state_; }

std::shared_ptr<Schema> schema() const { return filtered_schema_; }

ReadStats stats() const { return stats_; }

int num_required_initial_dictionaries() const {
return num_required_initial_dictionaries_;
}
@@ -2016,6 +2020,8 @@ class StreamDecoder::StreamDecoderImpl : public StreamDecoderInternal {

int64_t next_required_size() const { return message_decoder_.next_required_size(); }

const MessageDecoder* message_decoder() const { return &message_decoder_; }

private:
MessageDecoder message_decoder_;
};
@@ -2027,10 +2033,75 @@ StreamDecoder::StreamDecoder(std::shared_ptr<Listener> listener, IpcReadOptions
StreamDecoder::~StreamDecoder() {}

Status StreamDecoder::Consume(const uint8_t* data, int64_t size) {
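  // Feed the input to the current impl_ in chunks of exactly
  // next_required_size() so that, if the listener calls Reset() while a
  // chunk is being handled (e.g. from OnEOS()), the remaining chunks are
  // decoded by the fresh internal decoder. Trailing bytes shorter than
  // the next required size are handed over as-is below.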
while (size > 0) {
const auto next_required_size = impl_->next_required_size();
if (next_required_size == 0) {
break;
}
if (size < next_required_size) {
break;
}
ARROW_RETURN_NOT_OK(impl_->Consume(data, next_required_size));
data += next_required_size;
size -= next_required_size;
}
if (size > 0) {
return impl_->Consume(data, size);
} else {
return arrow::Status::OK();
}
}

Status StreamDecoder::Consume(std::shared_ptr<Buffer> buffer) {
if (buffer->size() == 0) {
return arrow::Status::OK();
}
if (impl_->next_required_size() == 0 || buffer->size() <= impl_->next_required_size()) {
return impl_->Consume(std::move(buffer));
} else {
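    // The buffer covers more than the decoder's next unit: feed it to the
    // decoder slice by slice so that a Reset() triggered by the listener
    // takes effect for the remainder of the buffer.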
int64_t offset = 0;
while (true) {
const auto next_required_size = impl_->next_required_size();
if (next_required_size == 0) {
break;
}
if (buffer->size() - offset <= next_required_size) {
break;
}
if (buffer->is_cpu()) {
switch (impl_->message_decoder()->state()) {
case MessageDecoder::State::INITIAL:
case MessageDecoder::State::METADATA_LENGTH:
            // We don't need to pass a sliced buffer because
            // MessageDecoder doesn't keep a reference to the given
            // buffer in these states.
ARROW_RETURN_NOT_OK(
impl_->Consume(buffer->data() + offset, next_required_size));
break;
default:
ARROW_RETURN_NOT_OK(
impl_->Consume(SliceBuffer(buffer, offset, next_required_size)));
break;
}
} else {
ARROW_RETURN_NOT_OK(
impl_->Consume(SliceBuffer(buffer, offset, next_required_size)));
}
offset += next_required_size;
}
if (buffer->size() - offset == 0) {
return arrow::Status::OK();
} else if (offset == 0) {
return impl_->Consume(std::move(buffer));
} else {
return impl_->Consume(SliceBuffer(std::move(buffer), offset));
}
}
}

Status StreamDecoder::Reset() {
impl_ = std::make_unique<StreamDecoderImpl>(impl_->listener(), impl_->options());
return Status::OK();
}

std::shared_ptr<Schema> StreamDecoder::schema() const { return impl_->schema(); }
8 changes: 8 additions & 0 deletions cpp/src/arrow/ipc/reader.h
@@ -425,6 +425,14 @@ class ARROW_EXPORT StreamDecoder {
/// \return Status
Status Consume(std::shared_ptr<Buffer> buffer);

  /// \brief Reset the internal state.
  ///
  /// You can reuse this decoder for a new stream after calling
  /// this, e.g. when multiple IPC streams arrive back to back on
  /// the same connection.
///
/// \return Status
Status Reset();

/// \return the shared schema of the record batches in the stream
std::shared_ptr<Schema> schema() const;

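For reference, a minimal usage sketch of the new API (not part of this diff): it mirrors the EndlessCollectListener test above and shows how a caller might reuse one decoder for several IPC streams arriving back to back on one input. ResettingListener and DecodeConcatenatedStreams are hypothetical names.

#include <memory>

#include "arrow/buffer.h"
#include "arrow/ipc/reader.h"
#include "arrow/status.h"

// Restart the decoder at every end-of-stream marker so that the bytes of
// the following stream are decoded by a fresh state.
class ResettingListener : public arrow::ipc::CollectListener {
 public:
  void SetDecoder(arrow::ipc::StreamDecoder* decoder) { decoder_ = decoder; }

  arrow::Status OnEOS() override { return decoder_->Reset(); }

 private:
  arrow::ipc::StreamDecoder* decoder_ = nullptr;
};

arrow::Status DecodeConcatenatedStreams(std::shared_ptr<arrow::Buffer> bytes) {
  auto listener = std::make_shared<ResettingListener>();
  arrow::ipc::StreamDecoder decoder(listener);
  listener->SetDecoder(&decoder);
  // `bytes` may hold several complete IPC streams one after another;
  // listener->num_record_batches() afterwards counts the batches from all
  // of them.
  return decoder.Consume(std::move(bytes));
}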