From c7a802aaf13880e5c19eb804b2e86d207216c5d4 Mon Sep 17 00:00:00 2001 From: Jorge Leitao Date: Sat, 6 Nov 2021 05:23:38 +0100 Subject: [PATCH] Added support to write Arrow IPC streams asynchronously (#577) * Added `async` writer of the Arrow stream. * Added test. --- Cargo.toml | 3 ++ README.md | 1 + src/io/flight/mod.rs | 2 +- src/io/ipc/write/common.rs | 59 ------------------------ src/io/ipc/write/common_async.rs | 74 ++++++++++++++++++++++++++++++ src/io/ipc/write/common_sync.rs | 64 ++++++++++++++++++++++++++ src/io/ipc/write/mod.rs | 10 ++++- src/io/ipc/write/stream.rs | 5 +-- src/io/ipc/write/stream_async.rs | 77 ++++++++++++++++++++++++++++++++ src/io/ipc/write/writer.rs | 6 +-- tests/it/io/ipc/mod.rs | 3 ++ tests/it/io/ipc/write_async.rs | 51 +++++++++++++++++++++ 12 files changed, 287 insertions(+), 68 deletions(-) create mode 100644 src/io/ipc/write/common_async.rs create mode 100644 src/io/ipc/write/common_sync.rs create mode 100644 src/io/ipc/write/stream_async.rs create mode 100644 tests/it/io/ipc/write_async.rs diff --git a/Cargo.toml b/Cargo.toml index d67581de15f..986c30a0f9c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -101,6 +101,7 @@ full = [ "io_json", "io_ipc", "io_flight", + "io_ipc_write_async", "io_ipc_compression", "io_json_integration", "io_print", @@ -121,6 +122,7 @@ io_csv_read_async = ["csv-async", "lexical-core", "futures"] io_csv_write = ["csv", "streaming-iterator", "lexical-core"] io_json = ["serde", "serde_json", "indexmap"] io_ipc = ["arrow-format"] +io_ipc_write_async = ["io_ipc", "futures"] io_ipc_compression = ["lz4", "zstd"] io_flight = ["io_ipc", "arrow-format/flight-data"] io_parquet_compression = [ @@ -163,6 +165,7 @@ skip_feature_sets = [ ["io_json"], ["io_flight"], ["io_ipc"], + ["io_ipc_write_async"], ["io_parquet"], ["io_json_integration"], # this does not change the public API diff --git a/README.md b/README.md index 7d52d0e8ffa..565fe2313bc 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,7 @@ we also use the `0.x.y` versioning, since we are iterating over the API. * Read and write of delta-encoded utf8 to and from parquet * parquet roundtrip of all supported arrow types. +* Async writer of the Arrow stream format ## Features in pyarrow not in this crate diff --git a/src/io/flight/mod.rs b/src/io/flight/mod.rs index 1b772c0512b..ac4f0773f34 100644 --- a/src/io/flight/mod.rs +++ b/src/io/flight/mod.rs @@ -64,7 +64,7 @@ pub fn serialize_schema_to_info(schema: &Schema) -> Result> { let encoded_data = schema_as_encoded_data(schema); let mut schema = vec![]; - write::common::write_message(&mut schema, encoded_data)?; + write::common_sync::write_message(&mut schema, encoded_data)?; Ok(schema) } diff --git a/src/io/ipc/write/common.rs b/src/io/ipc/write/common.rs index 7f92c140b72..c557b5bf147 100644 --- a/src/io/ipc/write/common.rs +++ b/src/io/ipc/write/common.rs @@ -16,7 +16,6 @@ // under the License. //! Common utilities used to write to Arrow's IPC format. -use std::io::Write; use std::{collections::HashMap, sync::Arc}; use arrow_format::ipc; @@ -29,7 +28,6 @@ use crate::io::ipc::endianess::is_native_little_endian; use crate::record_batch::RecordBatch; use crate::{array::DictionaryArray, datatypes::*}; -use super::super::CONTINUATION_MARKER; use super::{write, write_dictionary}; /// Compression codec @@ -292,63 +290,6 @@ pub struct EncodedData { pub arrow_data: Vec, } -/// Write a message's IPC data and buffers, returning metadata and buffer data lengths written -pub fn write_message(writer: &mut W, encoded: EncodedData) -> Result<(usize, usize)> { - let arrow_data_len = encoded.arrow_data.len(); - if arrow_data_len % 8 != 0 { - return Err(ArrowError::Ipc("Arrow data not aligned".to_string())); - } - - let a = 8 - 1; - let buffer = encoded.ipc_message; - let flatbuf_size = buffer.len(); - let prefix_size = 8; - let aligned_size = (flatbuf_size + prefix_size + a) & !a; - let padding_bytes = aligned_size - flatbuf_size - prefix_size; - - write_continuation(writer, (aligned_size - prefix_size) as i32)?; - - // write the flatbuf - if flatbuf_size > 0 { - writer.write_all(&buffer)?; - } - // write padding - writer.write_all(&vec![0; padding_bytes])?; - - // write arrow data - let body_len = if arrow_data_len > 0 { - write_body_buffers(writer, &encoded.arrow_data)? - } else { - 0 - }; - - Ok((aligned_size, body_len)) -} - -fn write_body_buffers(mut writer: W, data: &[u8]) -> Result { - let len = data.len(); - let pad_len = pad_to_8(data.len()); - let total_len = len + pad_len; - - // write body buffer - writer.write_all(data)?; - if pad_len > 0 { - writer.write_all(&vec![0u8; pad_len][..])?; - } - - writer.flush()?; - Ok(total_len) -} - -/// Write a record batch to the writer, writing the message size before the message -/// if the record batch is being written to a stream -pub fn write_continuation(writer: &mut W, total_len: i32) -> Result { - writer.write_all(&CONTINUATION_MARKER)?; - writer.write_all(&total_len.to_le_bytes()[..])?; - writer.flush()?; - Ok(8) -} - /// Calculate an 8-byte boundary and return the number of bytes needed to pad to 8 bytes #[inline] pub(crate) fn pad_to_8(len: usize) -> usize { diff --git a/src/io/ipc/write/common_async.rs b/src/io/ipc/write/common_async.rs new file mode 100644 index 00000000000..5880fce32ea --- /dev/null +++ b/src/io/ipc/write/common_async.rs @@ -0,0 +1,74 @@ +use futures::AsyncWrite; +use futures::AsyncWriteExt; + +use crate::error::{ArrowError, Result}; + +use super::super::CONTINUATION_MARKER; +use super::common::pad_to_8; +use super::common::EncodedData; + +/// Write a message's IPC data and buffers, returning metadata and buffer data lengths written +pub async fn write_message( + writer: &mut W, + encoded: EncodedData, +) -> Result<(usize, usize)> { + let arrow_data_len = encoded.arrow_data.len(); + if arrow_data_len % 8 != 0 { + return Err(ArrowError::Ipc("Arrow data not aligned".to_string())); + } + + let a = 8 - 1; + let buffer = encoded.ipc_message; + let flatbuf_size = buffer.len(); + let prefix_size = 8; + let aligned_size = (flatbuf_size + prefix_size + a) & !a; + let padding_bytes = aligned_size - flatbuf_size - prefix_size; + + write_continuation(writer, (aligned_size - prefix_size) as i32).await?; + + // write the flatbuf + if flatbuf_size > 0 { + writer.write_all(&buffer).await?; + } + // write padding + writer.write_all(&vec![0; padding_bytes]).await?; + + // write arrow data + let body_len = if arrow_data_len > 0 { + write_body_buffers(writer, &encoded.arrow_data).await? + } else { + 0 + }; + + Ok((aligned_size, body_len)) +} + +/// Write a record batch to the writer, writing the message size before the message +/// if the record batch is being written to a stream +pub async fn write_continuation( + writer: &mut W, + total_len: i32, +) -> Result { + writer.write_all(&CONTINUATION_MARKER).await?; + writer.write_all(&total_len.to_le_bytes()[..]).await?; + writer.flush().await?; + Ok(8) +} + +async fn write_body_buffers( + mut writer: W, + data: &[u8], +) -> Result { + let len = data.len(); + let pad_len = pad_to_8(data.len()); + let total_len = len + pad_len; + + // write body buffer + writer.write_all(data).await?; + if pad_len > 0 { + writer.write_all(&vec![0u8; pad_len][..]).await?; + } + + writer.flush().await?; + Ok(total_len) +} diff --git a/src/io/ipc/write/common_sync.rs b/src/io/ipc/write/common_sync.rs new file mode 100644 index 00000000000..9e80d446eb1 --- /dev/null +++ b/src/io/ipc/write/common_sync.rs @@ -0,0 +1,64 @@ +use std::io::Write; + +use crate::error::{ArrowError, Result}; + +use super::super::CONTINUATION_MARKER; +use super::common::pad_to_8; +use super::common::EncodedData; + +/// Write a message's IPC data and buffers, returning metadata and buffer data lengths written +pub fn write_message(writer: &mut W, encoded: EncodedData) -> Result<(usize, usize)> { + let arrow_data_len = encoded.arrow_data.len(); + if arrow_data_len % 8 != 0 { + return Err(ArrowError::Ipc("Arrow data not aligned".to_string())); + } + + let a = 8 - 1; + let buffer = encoded.ipc_message; + let flatbuf_size = buffer.len(); + let prefix_size = 8; + let aligned_size = (flatbuf_size + prefix_size + a) & !a; + let padding_bytes = aligned_size - flatbuf_size - prefix_size; + + write_continuation(writer, (aligned_size - prefix_size) as i32)?; + + // write the flatbuf + if flatbuf_size > 0 { + writer.write_all(&buffer)?; + } + // write padding + writer.write_all(&vec![0; padding_bytes])?; + + // write arrow data + let body_len = if arrow_data_len > 0 { + write_body_buffers(writer, &encoded.arrow_data)? + } else { + 0 + }; + + Ok((aligned_size, body_len)) +} + +fn write_body_buffers(mut writer: W, data: &[u8]) -> Result { + let len = data.len(); + let pad_len = pad_to_8(data.len()); + let total_len = len + pad_len; + + // write body buffer + writer.write_all(data)?; + if pad_len > 0 { + writer.write_all(&vec![0u8; pad_len][..])?; + } + + writer.flush()?; + Ok(total_len) +} + +/// Write a record batch to the writer, writing the message size before the message +/// if the record batch is being written to a stream +pub fn write_continuation(writer: &mut W, total_len: i32) -> Result { + writer.write_all(&CONTINUATION_MARKER)?; + writer.write_all(&total_len.to_le_bytes()[..])?; + writer.flush()?; + Ok(8) +} diff --git a/src/io/ipc/write/mod.rs b/src/io/ipc/write/mod.rs index ba823de9100..5e75961884a 100644 --- a/src/io/ipc/write/mod.rs +++ b/src/io/ipc/write/mod.rs @@ -1,5 +1,5 @@ //! APIs to write to Arrow's IPC format. -pub mod common; +pub(crate) mod common; mod schema; mod serialize; mod stream; @@ -10,3 +10,11 @@ pub use schema::schema_to_bytes; pub use serialize::{write, write_dictionary}; pub use stream::StreamWriter; pub use writer::FileWriter; + +pub(crate) mod common_sync; + +#[cfg(feature = "io_ipc_write_async")] +mod common_async; +#[cfg(feature = "io_ipc_write_async")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_write_async")))] +pub mod stream_async; diff --git a/src/io/ipc/write/stream.rs b/src/io/ipc/write/stream.rs index cfd51ef312b..83338942cc2 100644 --- a/src/io/ipc/write/stream.rs +++ b/src/io/ipc/write/stream.rs @@ -22,9 +22,8 @@ use std::io::Write; -use super::common::{ - encoded_batch, write_continuation, write_message, DictionaryTracker, EncodedData, WriteOptions, -}; +use super::common::{encoded_batch, DictionaryTracker, EncodedData, WriteOptions}; +use super::common_sync::{write_continuation, write_message}; use super::schema_to_bytes; use crate::datatypes::*; diff --git a/src/io/ipc/write/stream_async.rs b/src/io/ipc/write/stream_async.rs new file mode 100644 index 00000000000..34395724e2b --- /dev/null +++ b/src/io/ipc/write/stream_async.rs @@ -0,0 +1,77 @@ +//! `async` writing of arrow streams +use futures::AsyncWrite; + +pub use super::common::WriteOptions; +use super::common::{encoded_batch, DictionaryTracker, EncodedData}; +use super::common_async::{write_continuation, write_message}; +use super::schema_to_bytes; + +use crate::datatypes::*; +use crate::error::{ArrowError, Result}; +use crate::record_batch::RecordBatch; + +/// An `async` writer to the Apache Arrow stream format. +pub struct StreamWriter { + /// The object to write to + writer: W, + /// IPC write options + write_options: WriteOptions, + /// Whether the writer footer has been written, and the writer is finished + finished: bool, + /// Keeps track of dictionaries that have been written + dictionary_tracker: DictionaryTracker, +} + +impl StreamWriter { + /// Creates a new [`StreamWriter`] + pub fn new(writer: W, write_options: WriteOptions) -> Self { + Self { + writer, + write_options, + finished: false, + dictionary_tracker: DictionaryTracker::new(false), + } + } + + /// Starts the stream + pub async fn start(&mut self, schema: &Schema) -> Result<()> { + let encoded_message = EncodedData { + ipc_message: schema_to_bytes(schema), + arrow_data: vec![], + }; + write_message(&mut self.writer, encoded_message).await?; + Ok(()) + } + + /// Writes a [`RecordBatch`] to the stream + pub async fn write(&mut self, batch: &RecordBatch) -> Result<()> { + if self.finished { + return Err(ArrowError::Ipc( + "Cannot write record batch to stream writer as it is closed".to_string(), + )); + } + + // todo: move this out of the `async` since this is blocking. + let (encoded_dictionaries, encoded_message) = + encoded_batch(batch, &mut self.dictionary_tracker, &self.write_options)?; + + for encoded_dictionary in encoded_dictionaries { + write_message(&mut self.writer, encoded_dictionary).await?; + } + + write_message(&mut self.writer, encoded_message).await?; + Ok(()) + } + + /// Finishes the stream + pub async fn finish(&mut self) -> Result<()> { + write_continuation(&mut self.writer, 0).await?; + self.finished = true; + Ok(()) + } + + /// Consumes itself, returning the inner writer. + pub fn into_inner(self) -> W { + self.writer + } +} diff --git a/src/io/ipc/write/writer.rs b/src/io/ipc/write/writer.rs index 88448df7509..10d58d063ca 100644 --- a/src/io/ipc/write/writer.rs +++ b/src/io/ipc/write/writer.rs @@ -28,10 +28,8 @@ use arrow_format::ipc::flatbuffers::FlatBufferBuilder; use super::super::ARROW_MAGIC; use super::{ super::convert, - common::{ - encoded_batch, write_continuation, write_message, DictionaryTracker, EncodedData, - WriteOptions, - }, + common::{encoded_batch, DictionaryTracker, EncodedData, WriteOptions}, + common_sync::{write_continuation, write_message}, schema_to_bytes, }; diff --git a/tests/it/io/ipc/mod.rs b/tests/it/io/ipc/mod.rs index fd285864a7a..6dafdaed47c 100644 --- a/tests/it/io/ipc/mod.rs +++ b/tests/it/io/ipc/mod.rs @@ -3,3 +3,6 @@ mod read; mod write; pub use common::read_gzip_json; + +#[cfg(feature = "io_ipc_write_async")] +mod write_async; diff --git a/tests/it/io/ipc/write_async.rs b/tests/it/io/ipc/write_async.rs new file mode 100644 index 00000000000..31f19d37d1e --- /dev/null +++ b/tests/it/io/ipc/write_async.rs @@ -0,0 +1,51 @@ +use std::io::Cursor; + +use arrow2::error::Result; +use arrow2::io::ipc::read::read_stream_metadata; +use arrow2::io::ipc::read::StreamReader; +use arrow2::io::ipc::write::stream_async::{StreamWriter, WriteOptions}; +use futures::io::Cursor as AsyncCursor; + +use crate::io::ipc::common::read_arrow_stream; +use crate::io::ipc::common::read_gzip_json; + +async fn test_file(version: &str, file_name: &str) -> Result<()> { + let (schema, batches) = read_arrow_stream(version, file_name); + + let mut result = AsyncCursor::new(Vec::::new()); + + // write IPC version 5 + { + let options = WriteOptions { compression: None }; + let mut writer = StreamWriter::new(&mut result, options); + writer.start(&schema).await?; + for batch in batches { + writer.write(&batch).await?; + } + writer.finish().await?; + } + let result = result.into_inner(); + + let mut reader = Cursor::new(result); + let metadata = read_stream_metadata(&mut reader)?; + let reader = StreamReader::new(reader, metadata); + + let schema = reader.schema().clone(); + + // read expected JSON output + let (expected_schema, expected_batches) = read_gzip_json(version, file_name).unwrap(); + + assert_eq!(schema.as_ref(), &expected_schema); + + let batches = reader + .map(|x| x.map(|x| x.unwrap())) + .collect::>>()?; + + assert_eq!(batches, expected_batches); + Ok(()) +} + +#[tokio::test] +async fn write_async() -> Result<()> { + test_file("1.0.0-littleendian", "generated_primitive").await +}