Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Amortized intermediate allocations in IPC writer (#1362)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jan 16, 2023
1 parent d97526b commit 4ed6f26
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/io/flight/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ pub fn serialize_schema_to_info(
};

let mut schema = vec![];
write::common_sync::write_message(&mut schema, encoded_data)?;
write::common_sync::write_message(&mut schema, &encoded_data)?;
Ok(schema)
}

Expand Down
1 change: 1 addition & 0 deletions src/io/ipc/append/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ impl<R: Read + Seek + Write> FileWriter<R> {
dictionaries,
cannot_replace: true,
},
encoded_message: Default::default(),
})
}
}
4 changes: 2 additions & 2 deletions src/io/ipc/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ pub fn compress_zstd(input_buf: &[u8], output_buf: &mut Vec<u8>) -> Result<()> {
}

#[cfg(not(feature = "io_ipc_compression"))]
pub fn compress_lz4(_input_buf: &[u8], _output_buf: &mut Vec<u8>) -> Result<()> {
pub fn compress_lz4(_input_buf: &[u8], _output_buf: &[u8]) -> Result<()> {
use crate::error::Error;
Err(Error::OutOfSpec("The crate was compiled without IPC compression. Use `io_ipc_compression` to write compressed IPC.".to_string()))
}

#[cfg(not(feature = "io_ipc_compression"))]
pub fn compress_zstd(_input_buf: &[u8], _output_buf: &mut Vec<u8>) -> Result<()> {
pub fn compress_zstd(_input_buf: &[u8], _output_buf: &[u8]) -> Result<()> {
use crate::error::Error;
Err(Error::OutOfSpec("The crate was compiled without IPC compression. Use `io_ipc_compression` to write compressed IPC.".to_string()))
}
Expand Down
42 changes: 32 additions & 10 deletions src/io/ipc/write/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,25 @@ pub fn encode_chunk(
dictionary_tracker: &mut DictionaryTracker,
options: &WriteOptions,
) -> Result<(Vec<EncodedData>, EncodedData)> {
let mut encoded_message = EncodedData::default();
let encoded_dictionaries = encode_chunk_amortized(
chunk,
fields,
dictionary_tracker,
options,
&mut encoded_message,
)?;
Ok((encoded_dictionaries, encoded_message))
}

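/// Like [`encode_chunk`], but writes the encoded record batch into the caller-provided
/// `encoded_message`, so that its `arrow_data` buffer can be reused across calls.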
pub fn encode_chunk_amortized(
chunk: &Chunk<Box<dyn Array>>,
fields: &[IpcField],
dictionary_tracker: &mut DictionaryTracker,
options: &WriteOptions,
encoded_message: &mut EncodedData,
) -> Result<Vec<EncodedData>> {
let mut encoded_dictionaries = vec![];

for (field, array) in fields.iter().zip(chunk.as_ref()) {
Expand All @@ -189,9 +208,9 @@ pub fn encode_chunk(
)?;
}

let encoded_message = chunk_to_bytes(chunk, options);
chunk_to_bytes_amortized(chunk, options, encoded_message);

Ok((encoded_dictionaries, encoded_message))
Ok(encoded_dictionaries)
}

fn serialize_compression(
Expand All @@ -213,10 +232,16 @@ fn serialize_compression(

/// Write [`Chunk`] into two sets of bytes, one for the header (ipc::Schema::Message) and the
/// other for the batch's data
fn chunk_to_bytes(chunk: &Chunk<Box<dyn Array>>, options: &WriteOptions) -> EncodedData {
fn chunk_to_bytes_amortized(
chunk: &Chunk<Box<dyn Array>>,
options: &WriteOptions,
encoded_message: &mut EncodedData,
) {
let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
let mut arrow_data: Vec<u8> = vec![];
let mut arrow_data = std::mem::take(&mut encoded_message.arrow_data);
arrow_data.clear();

let mut offset = 0;
for array in chunk.arrays() {
write(
Expand Down Expand Up @@ -248,11 +273,8 @@ fn chunk_to_bytes(chunk: &Chunk<Box<dyn Array>>, options: &WriteOptions) -> Enco

let mut builder = Builder::new();
let ipc_message = builder.finish(&message, None);

EncodedData {
ipc_message: ipc_message.to_vec(),
arrow_data,
}
encoded_message.ipc_message = ipc_message.to_vec();
encoded_message.arrow_data = arrow_data
}

/// Write dictionary values into two sets of bytes, one for the header (ipc::Schema::Message) and the
Expand Down Expand Up @@ -360,7 +382,7 @@ impl DictionaryTracker {
}

/// Stores the encoded data, which is an ipc::Schema::Message, and optional Arrow data
#[derive(Debug)]
#[derive(Debug, Default)]
pub struct EncodedData {
/// An encoded ipc::Schema::Message
pub ipc_message: Vec<u8>,
Expand Down
10 changes: 6 additions & 4 deletions src/io/ipc/write/common_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ use super::common::pad_to_64;
use super::common::EncodedData;

/// Write a message's IPC data and buffers, returning metadata and buffer data lengths written
pub fn write_message<W: Write>(writer: &mut W, encoded: EncodedData) -> Result<(usize, usize)> {
pub fn write_message<W: Write>(writer: &mut W, encoded: &EncodedData) -> Result<(usize, usize)> {
let arrow_data_len = encoded.arrow_data.len();

let a = 8 - 1;
let buffer = encoded.ipc_message;
let buffer = &encoded.ipc_message;
let flatbuf_size = buffer.len();
let prefix_size = 8;
let aligned_size = (flatbuf_size + prefix_size + a) & !a;
Expand All @@ -21,10 +21,12 @@ pub fn write_message<W: Write>(writer: &mut W, encoded: EncodedData) -> Result<(

// write the flatbuf
if flatbuf_size > 0 {
writer.write_all(&buffer)?;
writer.write_all(buffer)?;
}
// write padding
writer.write_all(&vec![0; padding_bytes])?;
// the message is aligned to an 8-byte boundary, so the padding never exceeds 8 bytes
const PADDING_MAX: [u8; 8] = [0u8; 8];
writer.write_all(&PADDING_MAX[..padding_bytes])?;

// write arrow data
let body_len = if arrow_data_len > 0 {
Expand Down
6 changes: 3 additions & 3 deletions src/io/ipc/write/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl<W: Write> StreamWriter<W> {
ipc_message: schema_to_bytes(schema, self.ipc_fields.as_ref().unwrap()),
arrow_data: vec![],
};
write_message(&mut self.writer, encoded_message)?;
write_message(&mut self.writer, &encoded_message)?;
Ok(())
}

Expand Down Expand Up @@ -91,10 +91,10 @@ impl<W: Write> StreamWriter<W> {
)?;

for encoded_dictionary in encoded_dictionaries {
write_message(&mut self.writer, encoded_dictionary)?;
write_message(&mut self.writer, &encoded_dictionary)?;
}

write_message(&mut self.writer, encoded_message)?;
write_message(&mut self.writer, &encoded_message)?;
Ok(())
}

Expand Down
27 changes: 21 additions & 6 deletions src/io/ipc/write/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use arrow_format::ipc::planus::Builder;
use super::{
super::IpcField,
super::ARROW_MAGIC,
common::{encode_chunk, DictionaryTracker, EncodedData, WriteOptions},
common::{DictionaryTracker, EncodedData, WriteOptions},
common_sync::{write_continuation, write_message},
default_ipc_fields, schema, schema_to_bytes,
};
Expand All @@ -14,6 +14,7 @@ use crate::array::Array;
use crate::chunk::Chunk;
use crate::datatypes::*;
use crate::error::{Error, Result};
use crate::io::ipc::write::common::encode_chunk_amortized;

#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) enum State {
Expand Down Expand Up @@ -41,6 +42,8 @@ pub struct FileWriter<W: Write> {
pub(crate) state: State,
/// Keeps track of dictionaries that have been written
pub(crate) dictionary_tracker: DictionaryTracker,
/// Buffer/scratch that is reused between writes
pub(crate) encoded_message: EncodedData,
}

impl<W: Write> FileWriter<W> {
Expand Down Expand Up @@ -83,6 +86,7 @@ impl<W: Write> FileWriter<W> {
dictionaries: Default::default(),
cannot_replace: true,
},
encoded_message: Default::default(),
}
}

Expand All @@ -91,6 +95,17 @@ impl<W: Write> FileWriter<W> {
self.writer
}

/// Get the inner memory scratches so they can be reused in a new writer.
/// This can be utilized to save memory allocations for performance reasons.
pub fn get_scratches(&mut self) -> EncodedData {
std::mem::take(&mut self.encoded_message)
}
/// Set the inner memory scratches so they can be reused in a new writer.
/// This can be utilized to save memory allocations for performance reasons.
pub fn set_scratches(&mut self, scratches: EncodedData) {
self.encoded_message = scratches;
}

/// Writes the header and first (schema) message to the file.
/// # Errors
/// Errors if the file has been started or has finished.
Expand All @@ -109,7 +124,7 @@ impl<W: Write> FileWriter<W> {
arrow_data: vec![],
};

let (meta, data) = write_message(&mut self.writer, encoded_message)?;
let (meta, data) = write_message(&mut self.writer, &encoded_message)?;
self.block_offsets += meta + data + 8; // 8 <=> arrow magic + 2 bytes for alignment
self.state = State::Started;
Ok(())
Expand All @@ -132,17 +147,17 @@ impl<W: Write> FileWriter<W> {
} else {
self.ipc_fields.as_ref()
};

let (encoded_dictionaries, encoded_message) = encode_chunk(
let encoded_dictionaries = encode_chunk_amortized(
chunk,
ipc_fields,
&mut self.dictionary_tracker,
&self.options,
&mut self.encoded_message,
)?;

// add all dictionaries
for encoded_dictionary in encoded_dictionaries {
let (meta, data) = write_message(&mut self.writer, encoded_dictionary)?;
let (meta, data) = write_message(&mut self.writer, &encoded_dictionary)?;

let block = arrow_format::ipc::Block {
offset: self.block_offsets as i64,
Expand All @@ -153,7 +168,7 @@ impl<W: Write> FileWriter<W> {
self.block_offsets += meta + data;
}

let (meta, data) = write_message(&mut self.writer, encoded_message)?;
let (meta, data) = write_message(&mut self.writer, &self.encoded_message)?;
// add a record block for the footer
let block = arrow_format::ipc::Block {
offset: self.block_offsets as i64,
Expand Down
3 changes: 1 addition & 2 deletions src/temporal_conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,7 @@ fn chrono_tz_utf_to_timestamp_ns<O: Offset>(
timezone: String,
) -> Result<PrimitiveArray<i64>> {
Err(Error::InvalidArgumentError(format!(
"timezone \"{}\" cannot be parsed (feature chrono-tz is not active)",
timezone
"timezone \"{timezone}\" cannot be parsed (feature chrono-tz is not active)",
)))
}

Expand Down

0 comments on commit 4ed6f26

Please sign in to comment.