Skip to content

Commit

Permalink
Redo MessageSerializer with unions. Still has bugs
Browse files (browse the repository at this point in the history)
Change-Id: Ib8beb014310219a7ab8263802ec94d2ea5af6805
  • Loading branch information
wesm committed Jan 20, 2017
1 parent 21854cc commit ba8db91
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 110 deletions.
11 changes: 6 additions & 5 deletions cpp/src/arrow/ipc/adapter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,12 @@ class RecordBatchWriter : public ArrayVisitor {
num_rows_, body_length, field_nodes_, buffer_meta_, &metadata_fb));

// Need to write 4 bytes (metadata size), the metadata, plus padding to
// fall on a 64-byte offset
int64_t padded_metadata_length =
BitUtil::RoundUpToMultipleOf64(metadata_fb->size() + 4);
// fall on an 8-byte offset
int64_t padded_metadata_length = BitUtil::CeilByte(metadata_fb->size() + 4);

// The returned metadata size includes the length prefix, the flatbuffer,
// plus padding
*metadata_length = padded_metadata_length;
*metadata_length = static_cast<int32_t>(padded_metadata_length);

// Write the flatbuffer size prefix
int32_t flatbuffer_size = metadata_fb->size();
Expand Down Expand Up @@ -604,7 +603,9 @@ Status ReadRecordBatchMetadata(int64_t offset, int32_t metadata_length,
return Status::Invalid(ss.str());
}

*metadata = std::make_shared<RecordBatchMetadata>(buffer, sizeof(int32_t));
std::shared_ptr<Message> message;
RETURN_NOT_OK(Message::Open(buffer, 4, &message));
*metadata = std::make_shared<RecordBatchMetadata>(message);
return Status::OK();
}

Expand Down
21 changes: 4 additions & 17 deletions cpp/src/arrow/ipc/metadata-internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,23 +320,10 @@ Status MessageBuilder::SetRecordBatch(int32_t length, int64_t body_length,
Status WriteRecordBatchMetadata(int32_t length, int64_t body_length,
const std::vector<flatbuf::FieldNode>& nodes,
const std::vector<flatbuf::Buffer>& buffers, std::shared_ptr<Buffer>* out) {
flatbuffers::FlatBufferBuilder fbb;

auto batch = flatbuf::CreateRecordBatch(
fbb, length, fbb.CreateVectorOfStructs(nodes), fbb.CreateVectorOfStructs(buffers));

fbb.Finish(batch);

int32_t size = fbb.GetSize();

auto result = std::make_shared<PoolBuffer>();
RETURN_NOT_OK(result->Resize(size));

uint8_t* dst = result->mutable_data();
memcpy(dst, fbb.GetBufferPointer(), size);

*out = result;
return Status::OK();
MessageBuilder builder;
RETURN_NOT_OK(builder.SetRecordBatch(length, body_length, nodes, buffers));
RETURN_NOT_OK(builder.Finish());
return builder.GetBuffer(out);
}

Status MessageBuilder::Finish() {
Expand Down
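153 changes: 65 additions & 88 deletions — Java MessageSerializer source (file path not captured; counts derived from the commit totals above)
Original file line number Diff line number Diff line change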
Expand Up @@ -70,33 +70,24 @@ public static int bytesToInt(byte[] bytes) {
*/
public static long serialize(WriteChannel out, Schema schema) throws IOException {
FlatBufferBuilder builder = new FlatBufferBuilder();
builder.finish(schema.getSchema(builder));
ByteBuffer serializedBody = builder.dataBuffer();
ByteBuffer serializedHeader =
serializeHeader(MessageHeader.Schema, serializedBody.remaining());

long size = out.writeIntLittleEndian(serializedHeader.remaining());
size += out.write(serializedHeader);
size += out.write(serializedBody);
int schemaOffset = schema.getSchema(builder);
ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.Schema, schemaOffset, 0);
long size = out.writeIntLittleEndian(serializedMessage.remaining());
size += out.write(serializedMessage);
return size;
}

/**
* Deserializes a schema object. Format is from serialize().
*/
public static Schema deserializeSchema(ReadChannel in) throws IOException {
Message header = deserializeHeader(in, MessageHeader.Schema);
if (header == null) {
Message message = deserializeMessage(in, MessageHeader.Schema);
if (message == null) {
throw new IOException("Unexpected end of input. Missing schema.");
}

// Now read the schema.
ByteBuffer buffer = ByteBuffer.allocate((int)header.bodyLength());
if (in.readFully(buffer) != header.bodyLength()) {
throw new IOException("Unexpected end of input trying to read schema.");
}
buffer.rewind();
return Schema.deserialize(buffer);
return Schema.convertSchema((org.apache.arrow.flatbuf.Schema)
message.header(new org.apache.arrow.flatbuf.Schema()));
}

/**
Expand All @@ -106,37 +97,22 @@ public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch)
throws IOException {
long start = out.getCurrentPosition();
int bodyLength = batch.computeBodyLength();
ByteBuffer metadata = WriteChannel.serialize(batch);

int messageLength = 4 + metadata.remaining() + bodyLength;
ByteBuffer serializedHeader =
serializeHeader(MessageHeader.RecordBatch, messageLength);

// Compute the required alignment. This is not a great way to do it. The issue is
// that we need to know the message size to serialize the message header but the
// size depends on the alignment, which depends on the message header.
// This will serialize the header again with the updated size alignment adjusted.
// TODO: We really just want sizeof(MessageHeader) from the serializeHeader() above.
// Is there a way to do this?
long bufferOffset = start + 4 + serializedHeader.remaining() + 4 + metadata.remaining();
if (bufferOffset % 8 != 0) {
messageLength += 8 - bufferOffset % 8;
serializedHeader = serializeHeader(MessageHeader.RecordBatch, messageLength);
}

// Write message header.
out.writeIntLittleEndian(serializedHeader.remaining());
out.write(serializedHeader);
FlatBufferBuilder builder = new FlatBufferBuilder();
int batchOffset = batch.writeTo(builder);

ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.RecordBatch,
batchOffset, bodyLength);

// Write batch header. with the 4 byte little endian prefix
out.writeIntLittleEndian(metadata.remaining());
int metadataSize = metadata.remaining();
long batchStart = out.getCurrentPosition();
out.write(metadata);
long metadataStart = out.getCurrentPosition();
out.writeIntLittleEndian(serializedMessage.remaining());
out.write(serializedMessage);

// Align the output to 8 byte boundary.
out.align();

long metadataSize = out.getCurrentPosition() - metadataStart;

long bufferStart = out.getCurrentPosition();
List<ArrowBuf> buffers = batch.getBuffers();
List<ArrowBuffer> buffersLayout = batch.getBuffersLayout();
Expand All @@ -154,31 +130,31 @@ public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch)
" != " + startPosition + layout.getSize());
}
}
return new ArrowBlock(batchStart, metadataSize, out.getCurrentPosition() - bufferStart);
return new ArrowBlock(start, (int) metadataSize, out.getCurrentPosition() - bufferStart);
}

/**
* Deserializes a RecordBatch
*/
public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in,
BufferAllocator alloc) throws IOException {
Message header = deserializeHeader(in, MessageHeader.RecordBatch);
if (header == null) return null;
Message message = deserializeMessage(in, MessageHeader.RecordBatch);
if (message == null) return null;

int messageLen = (int)header.bodyLength();
// Now read the buffer. This has the metadata followed by the data.
ArrowBuf buffer = alloc.buffer(messageLen);
long readPosition = in.getCurrentPositiion();
if (in.readFully(buffer, messageLen) != messageLen) {
throw new IOException("Unexpected end of input trying to read batch.");
if (message.bodyLength() > Integer.MAX_VALUE) {
throw new IOException("Cannot currently deserialize record batches over 2GB");
}

// Read the length of the metadata.
int metadataLen = buffer.readInt();
buffer = buffer.slice(4, messageLen - 4);
readPosition += 4;
messageLen -= 4;
return deserializeRecordBatch(buffer, readPosition, metadataLen, messageLen);
RecordBatch recordBatchFB = (RecordBatch) message.header(new RecordBatch());

int bodyLength = (int) message.bodyLength();

// Now read the record batch body
ArrowBuf buffer = alloc.buffer(bodyLength);
if (in.readFully(buffer, bodyLength) != bodyLength) {
throw new IOException("Unexpected end of input trying to read batch.");
}
return deserializeRecordBatch(recordBatchFB, buffer);
}

/**
Expand All @@ -188,41 +164,41 @@ public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in,
public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, ArrowBlock block,
BufferAllocator alloc) throws IOException {
long readPosition = in.getCurrentPositiion();

// Metadata length contains byte padding
long totalLen = block.getMetadataLength() + block.getBodyLength();
if ((readPosition + block.getMetadataLength()) % 8 != 0) {
// Compute padded size.
totalLen += (8 - (readPosition + block.getMetadataLength()) % 8);
}

if (totalLen > Integer.MAX_VALUE) {
throw new IOException("Cannot currently deserialize record batches over 2GB");
}


ArrowBuf buffer = alloc.buffer((int) totalLen);
if (in.readFully(buffer, (int) totalLen) != totalLen) {
throw new IOException("Unexpected end of input trying to read batch.");
}

return deserializeRecordBatch(buffer, readPosition, block.getMetadataLength(), (int) totalLen);
return deserializeRecordBatch(buffer, block.getMetadataLength(), (int) totalLen);
}

// Deserializes a record batch. Buffer should start at the RecordBatch and include
// all the bytes for the metadata and then data buffers.
private static ArrowRecordBatch deserializeRecordBatch(
ArrowBuf buffer, long readPosition, int metadataLen, int bufferLen) {
private static ArrowRecordBatch deserializeRecordBatch(ArrowBuf buffer, int metadataLen,
int bufferLen) {
// Read the metadata.
RecordBatch recordBatchFB =
RecordBatch.getRootAsRecordBatch(buffer.nioBuffer().asReadOnlyBuffer());

int bufferOffset = metadataLen;
readPosition += bufferOffset;
if (readPosition % 8 != 0) {
bufferOffset += (int)(8 - readPosition % 8);
}

// Now read the body
final ArrowBuf body = buffer.slice(bufferOffset, bufferLen - bufferOffset);
return deserializeRecordBatch(recordBatchFB, body);
}

// Deserializes a record batch given the Flatbuffer metadata and in-memory body
private static ArrowRecordBatch deserializeRecordBatch(RecordBatch recordBatchFB,
ArrowBuf body) {
// Now read the body
int nodesLength = recordBatchFB.nodesLength();
List<ArrowFieldNode> nodes = new ArrayList<>();
for (int i = 0; i < nodesLength; ++i) {
Expand All @@ -237,43 +213,44 @@ private static ArrowRecordBatch deserializeRecordBatch(
}
ArrowRecordBatch arrowRecordBatch =
new ArrowRecordBatch(recordBatchFB.length(), nodes, buffers);
buffer.release();
body.release();
return arrowRecordBatch;
}

/**
* Serializes a message header.
*/
private static ByteBuffer serializeHeader(byte headerType, int bodyLength) {
FlatBufferBuilder headerBuilder = new FlatBufferBuilder();
Message.startMessage(headerBuilder);
Message.addHeaderType(headerBuilder, headerType);
Message.addVersion(headerBuilder, MetadataVersion.V1);
Message.addBodyLength(headerBuilder, bodyLength);
headerBuilder.finish(Message.endMessage(headerBuilder));
return headerBuilder.dataBuffer();
private static ByteBuffer serializeMessage(FlatBufferBuilder builder, byte headerType,
int headerOffset, int bodyLength) {
Message.startMessage(builder);
Message.addHeaderType(builder, headerType);
Message.addHeader(builder, headerOffset);
Message.addVersion(builder, MetadataVersion.V1);
Message.addBodyLength(builder, bodyLength);
builder.finish(Message.endMessage(builder));
return builder.dataBuffer();
}

private static Message deserializeHeader(ReadChannel in, byte headerType) throws IOException {
// Read the header size. There is an i32 little endian prefix.
private static Message deserializeMessage(ReadChannel in, byte headerType) throws IOException {
// Read the message size. There is an i32 little endian prefix.
ByteBuffer buffer = ByteBuffer.allocate(4);
if (in.readFully(buffer) != 4) {
return null;
}

int headerLength = bytesToInt(buffer.array());
buffer = ByteBuffer.allocate(headerLength);
if (in.readFully(buffer) != headerLength) {
int messageLength = bytesToInt(buffer.array());
buffer = ByteBuffer.allocate(messageLength);
if (in.readFully(buffer) != messageLength) {
throw new IOException(
"Unexpected end of stream trying to read header.");
"Unexpected end of stream trying to read message.");
}
buffer.rewind();

Message header = Message.getRootAsMessage(buffer);
if (header.headerType() != headerType) {
Message message = Message.getRootAsMessage(buffer);
if (message.headerType() != headerType) {
throw new IOException("Invalid message: expecting " + headerType +
". Message contained: " + header.headerType());
". Message contained: " + message.headerType());
}
return header;
return message;
}
}

0 comments on commit ba8db91

Please sign in to comment.