From dda052f5d622cce5aef50e90f476e3d9ceb1f49b Mon Sep 17 00:00:00 2001
From: Dexter Duckworth
Date: Sat, 5 Mar 2022 09:52:31 -0500
Subject: [PATCH] IPC sink types and IPC file stream (#878)

---
 Cargo.toml                                    |   2 +-
 src/io/ipc/read/file_async.rs                 | 275 ++++++++++++++++++
 src/io/ipc/read/mod.rs                        |   4 +
 src/io/ipc/read/reader.rs                     |   6 +-
 src/io/ipc/write/common.rs                    |  53 ++++
 src/io/ipc/write/common_async.rs              |   6 +-
 src/io/ipc/write/file_async.rs                | 248 ++++++++++++++++
 src/io/ipc/write/mod.rs                       |   6 +-
 src/io/ipc/write/stream_async.rs              | 223 +++++++++-----
 tests/it/io/ipc/mod.rs                        |   8 +-
 tests/it/io/ipc/read_file_async.rs            |  45 +++
 tests/it/io/ipc/write_file_async.rs           |  63 ++++
 .../{write_async.rs => write_stream_async.rs} |  10 +-
 13 files changed, 862 insertions(+), 87 deletions(-)
 create mode 100644 src/io/ipc/read/file_async.rs
 create mode 100644 src/io/ipc/write/file_async.rs
 create mode 100644 tests/it/io/ipc/read_file_async.rs
 create mode 100644 tests/it/io/ipc/write_file_async.rs
 rename tests/it/io/ipc/{write_async.rs => write_stream_async.rs} (86%)

diff --git a/Cargo.toml b/Cargo.toml
index d88a6c340f9..d5d9a832d66 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -134,7 +134,7 @@ io_csv_write = ["csv", "csv-core", "streaming-iterator", "lexical-core"]
 io_json = ["serde", "serde_json", "streaming-iterator", "fallible-streaming-iterator", "indexmap", "lexical-core"]
 io_ipc = ["arrow-format"]
 io_ipc_write_async = ["io_ipc", "futures"]
-io_ipc_read_async = ["io_ipc", "futures"]
+io_ipc_read_async = ["io_ipc", "futures", "async-stream"]
 io_ipc_compression = ["lz4", "zstd"]
 io_flight = ["io_ipc", "arrow-format/flight-data"]
 # base64 + io_ipc because arrow schemas are stored as base64-encoded ipc format.
diff --git a/src/io/ipc/read/file_async.rs b/src/io/ipc/read/file_async.rs
new file mode 100644
index 00000000000..e7bc04eb845
--- /dev/null
+++ b/src/io/ipc/read/file_async.rs
@@ -0,0 +1,275 @@
+//! Async reader for Arrow IPC files
+use std::io::SeekFrom;
+use std::sync::Arc;
+
+use arrow_format::ipc::{
+    planus::{ReadAsRoot, Vector},
+    BlockRef, FooterRef, MessageHeaderRef, MessageRef,
+};
+use futures::{
+    stream::BoxStream, AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, Stream, StreamExt,
+};
+
+use crate::array::*;
+use crate::chunk::Chunk;
+use crate::datatypes::{Field, Schema};
+use crate::error::{ArrowError, Result};
+use crate::io::ipc::{IpcSchema, ARROW_MAGIC, CONTINUATION_MARKER};
+
+use super::common::{read_dictionary, read_record_batch};
+use super::reader::get_serialized_batch;
+use super::schema::fb_to_schema;
+use super::Dictionaries;
+use super::FileMetadata;
+
+/// Async reader for Arrow IPC files
+pub struct FileStream<'a> {
+    stream: BoxStream<'a, Result<Chunk<Arc<dyn Array>>>>,
+    metadata: FileMetadata,
+    schema: Schema,
+}
+
+impl<'a> FileStream<'a> {
+    /// Create a new IPC file reader.
+    ///
+    /// # Examples
+    /// See [`FileSink`](crate::io::ipc::write::file_async::FileSink).
+    pub fn new<R>(reader: R, metadata: FileMetadata, projection: Option<Vec<usize>>) -> Self
+    where
+        R: AsyncRead + AsyncSeek + Unpin + Send + 'a,
+    {
+        let schema = if let Some(projection) = projection.as_ref() {
+            projection.windows(2).for_each(|x| {
+                assert!(
+                    x[0] < x[1],
+                    "IPC projection must be ordered and non-overlapping",
+                )
+            });
+            let fields = projection
+                .iter()
+                .map(|&x| metadata.schema.fields[x].clone())
+                .collect::<Vec<_>>();
+            Schema {
+                fields,
+                metadata: metadata.schema.metadata.clone(),
+            }
+        } else {
+            metadata.schema.clone()
+        };
+
+        let stream = Self::stream(reader, metadata.clone(), projection);
+        Self {
+            stream,
+            metadata,
+            schema,
+        }
+    }
+
+    /// Get the metadata from the IPC file.
+    pub fn metadata(&self) -> &FileMetadata {
+        &self.metadata
+    }
+
+    /// Get the projected schema from the IPC file.
+    pub fn schema(&self) -> &Schema {
+        &self.schema
+    }
+
+    fn stream<R>(
+        mut reader: R,
+        metadata: FileMetadata,
+        projection: Option<Vec<usize>>,
+    ) -> BoxStream<'a, Result<Chunk<Arc<dyn Array>>>>
+    where
+        R: AsyncRead + AsyncSeek + Unpin + Send + 'a,
+    {
+        async_stream::try_stream! {
+            let mut meta_buffer = vec![];
+            let mut block_buffer = vec![];
+            for block in 0..metadata.blocks.len() {
+                let chunk = read_batch(
+                    &mut reader,
+                    &metadata,
+                    projection.as_deref(),
+                    block,
+                    &mut meta_buffer,
+                    &mut block_buffer,
+                ).await?;
+                yield chunk;
+            }
+        }
+        .boxed()
+    }
+}
+
+impl<'a> Stream for FileStream<'a> {
+    type Item = Result<Chunk<Arc<dyn Array>>>;
+
+    fn poll_next(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        self.get_mut().stream.poll_next_unpin(cx)
+    }
+}
+
+/// Read the metadata from an IPC file.
+pub async fn read_file_metadata_async<R>(mut reader: R) -> Result<FileMetadata>
+where
+    R: AsyncRead + AsyncSeek + Unpin,
+{
+    // Check header
+    let mut magic = [0; 6];
+    reader.read_exact(&mut magic).await?;
+    if magic != ARROW_MAGIC {
+        return Err(ArrowError::OutOfSpec(
+            "file does not contain correct Arrow header".to_string(),
+        ));
+    }
+    // Check footer
+    reader.seek(SeekFrom::End(-6)).await?;
+    reader.read_exact(&mut magic).await?;
+    if magic != ARROW_MAGIC {
+        return Err(ArrowError::OutOfSpec(
+            "file does not contain correct Arrow footer".to_string(),
+        ));
+    }
+    // Get footer size
+    let mut footer_size = [0; 4];
+    reader.seek(SeekFrom::End(-10)).await?;
+    reader.read_exact(&mut footer_size).await?;
+    let footer_size = i32::from_le_bytes(footer_size);
+    // Read footer
+    let mut footer = vec![0; footer_size as usize];
+    reader.seek(SeekFrom::End(-10 - footer_size as i64)).await?;
+    reader.read_exact(&mut footer).await?;
+    let footer = FooterRef::read_as_root(&footer[..])
+        .map_err(|err| ArrowError::OutOfSpec(format!("unable to get root as footer: {:?}", err)))?;
+
+    let blocks = footer.record_batches()?.ok_or_else(|| {
+        ArrowError::OutOfSpec("unable to get record batches from footer".to_string())
+    })?;
+    let schema = footer
+        .schema()?
+        .ok_or_else(|| ArrowError::OutOfSpec("unable to get schema from footer".to_string()))?;
+    let (schema, ipc_schema) = fb_to_schema(schema)?;
+    let dictionary_blocks = footer.dictionaries()?;
+    let dictionaries = if let Some(blocks) = dictionary_blocks {
+        read_dictionaries(reader, &schema.fields[..], &ipc_schema, blocks).await?
+    } else {
+        Default::default()
+    };
+
+    Ok(FileMetadata {
+        schema,
+        ipc_schema,
+        blocks: blocks
+            .iter()
+            .map(|block| Ok(block.try_into()?))
+            .collect::<Result<Vec<_>>>()?,
+        dictionaries,
+    })
+}
+
+async fn read_dictionaries<R>(
+    mut reader: R,
+    fields: &[Field],
+    ipc_schema: &IpcSchema,
+    blocks: Vector<'_, BlockRef<'_>>,
+) -> Result<Dictionaries>
+where
+    R: AsyncRead + AsyncSeek + Unpin,
+{
+    let mut dictionaries = Default::default();
+    let mut data = vec![];
+    let mut buffer = vec![];
+
+    for block in blocks {
+        let offset = block.offset() as u64;
+        read_dictionary_message(&mut reader, offset, &mut data).await?;
+
+        let message = MessageRef::read_as_root(&data).map_err(|err| {
+            ArrowError::OutOfSpec(format!("unable to get root as message: {:?}", err))
+        })?;
+        let header = message
+            .header()?
+            .ok_or_else(|| ArrowError::oos("message must have a header"))?;
+        match header {
+            MessageHeaderRef::DictionaryBatch(batch) => {
+                buffer.clear();
+                buffer.resize(block.body_length() as usize, 0);
+                reader.read_exact(&mut buffer).await?;
+                let mut cursor = std::io::Cursor::new(&mut buffer);
+                read_dictionary(batch, fields, ipc_schema, &mut dictionaries, &mut cursor, 0)?;
+            }
+            other => {
+                return Err(ArrowError::OutOfSpec(format!(
+                    "expected DictionaryBatch in dictionary blocks, found {:?}",
+                    other,
+                )))
+            }
+        }
+    }
+    Ok(dictionaries)
+}
+
+async fn read_dictionary_message<R>(mut reader: R, offset: u64, data: &mut Vec<u8>) -> Result<()>
+where
+    R: AsyncRead + AsyncSeek + Unpin,
+{
+    let mut message_size = [0; 4];
+    reader.seek(SeekFrom::Start(offset)).await?;
+    reader.read_exact(&mut message_size).await?;
+    if message_size == CONTINUATION_MARKER {
+        reader.read_exact(&mut message_size).await?;
+    }
+    let footer_size = i32::from_le_bytes(message_size);
+    data.clear();
+    data.resize(footer_size as usize, 0);
+    reader.read_exact(data).await?;
+
+    Ok(())
+}
+
+async fn read_batch<R>(
+    mut reader: R,
+    metadata: &FileMetadata,
+    projection: Option<&[usize]>,
+    block: usize,
+    meta_buffer: &mut Vec<u8>,
+    block_buffer: &mut Vec<u8>,
+) -> Result<Chunk<Arc<dyn Array>>>
+where
+    R: AsyncRead + AsyncSeek + Unpin,
+{
+    let block = metadata.blocks[block];
+    reader.seek(SeekFrom::Start(block.offset as u64)).await?;
+    let mut meta_buf = [0; 4];
+    reader.read_exact(&mut meta_buf).await?;
+    if meta_buf == CONTINUATION_MARKER {
+        reader.read_exact(&mut meta_buf).await?;
+    }
+    let meta_len = i32::from_le_bytes(meta_buf) as usize;
+    meta_buffer.clear();
+    meta_buffer.resize(meta_len, 0);
+    reader.read_exact(meta_buffer).await?;
+
+    let message = MessageRef::read_as_root(&meta_buffer[..])
+        .map_err(|err| ArrowError::oos(format!("unable to parse message: {:?}", err)))?;
+    let batch = get_serialized_batch(&message)?;
+    block_buffer.clear();
+    block_buffer.resize(message.body_length()? as usize, 0);
+    reader.read_exact(block_buffer).await?;
+    let mut cursor = std::io::Cursor::new(block_buffer);
+    let chunk = read_record_batch(
+        batch,
+        &metadata.schema.fields,
+        &metadata.ipc_schema,
+        projection,
+        &metadata.dictionaries,
+        message.version()?,
+        &mut cursor,
+        0,
+    )?;
+    Ok(chunk)
+}
diff --git a/src/io/ipc/read/mod.rs b/src/io/ipc/read/mod.rs
index 3a45d4ecac6..207c33329fc 100644
--- a/src/io/ipc/read/mod.rs
+++ b/src/io/ipc/read/mod.rs
@@ -20,6 +20,10 @@ mod stream;
 #[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_read_async")))]
 pub mod stream_async;
 
+#[cfg(feature = "io_ipc_read_async")]
+#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_read_async")))]
+pub mod file_async;
+
 pub use common::{read_dictionary, read_record_batch};
 pub use reader::{read_file_metadata, FileMetadata, FileReader};
 pub use schema::deserialize_schema;
diff --git a/src/io/ipc/read/reader.rs b/src/io/ipc/read/reader.rs
index 3254eb4a7ef..d31e975d666 100644
--- a/src/io/ipc/read/reader.rs
+++ b/src/io/ipc/read/reader.rs
@@ -26,10 +26,10 @@ pub struct FileMetadata {
     /// The blocks in the file
     ///
     /// A block indicates the regions in the file to read to get data
-    blocks: Vec<arrow_format::ipc::Block>,
+    pub(super) blocks: Vec<arrow_format::ipc::Block>,
 
     /// Dictionaries associated to each dict_id
-    dictionaries: Dictionaries,
+    pub(super) dictionaries: Dictionaries,
 }
 
 /// Arrow File reader
@@ -166,7 +166,7 @@ pub fn read_file_metadata<R: Read + Seek>(reader: &mut R) -> Result<FileMetadata> {
-fn get_serialized_batch<'a>(
+pub(super) fn get_serialized_batch<'a>(
     message: &'a arrow_format::ipc::MessageRef,
 ) -> Result<arrow_format::ipc::RecordBatchRef<'a>> {
     let header = message.header()?.ok_or_else(|| {
diff --git a/src/io/ipc/write/common.rs b/src/io/ipc/write/common.rs
index 2db05ee4afc..14107ed9ac2 100644
--- a/src/io/ipc/write/common.rs
+++ b/src/io/ipc/write/common.rs
@@ -1,3 +1,4 @@
+use std::borrow::{Borrow, Cow};
 use std::sync::Arc;
 
 use arrow_format::ipc::planus::Builder;
@@ -379,3 +380,55 @@ pub struct EncodedData {
 pub(crate) fn pad_to_8(len: usize) -> usize {
     (((len + 7) & !7) - len) as usize
 }
+
+/// An array [`Chunk`] with optional accompanying IPC fields.
+#[derive(Debug, Clone, PartialEq)]
+pub struct Record<'a> {
+    columns: Cow<'a, Chunk<Arc<dyn Array>>>,
+    fields: Option<Cow<'a, [IpcField]>>,
+}
+
+impl<'a> Record<'a> {
+    /// Get the IPC fields for this record.
+    pub fn fields(&self) -> Option<&[IpcField]> {
+        self.fields.as_deref()
+    }
+
+    /// Get the Arrow columns in this record.
+    pub fn columns(&self) -> &Chunk<Arc<dyn Array>> {
+        self.columns.borrow()
+    }
+}
+
+impl From<Chunk<Arc<dyn Array>>> for Record<'static> {
+    fn from(columns: Chunk<Arc<dyn Array>>) -> Self {
+        Self {
+            columns: Cow::Owned(columns),
+            fields: None,
+        }
+    }
+}
+
+impl<'a, F> From<(Chunk<Arc<dyn Array>>, Option<F>)> for Record<'a>
+where
+    F: Into<Cow<'a, [IpcField]>>,
+{
+    fn from((columns, fields): (Chunk<Arc<dyn Array>>, Option<F>)) -> Self {
+        Self {
+            columns: Cow::Owned(columns),
+            fields: fields.map(|f| f.into()),
+        }
+    }
+}
+
+impl<'a, F> From<(&'a Chunk<Arc<dyn Array>>, Option<F>)> for Record<'a>
+where
+    F: Into<Cow<'a, [IpcField]>>,
+{
+    fn from((columns, fields): (&'a Chunk<Arc<dyn Array>>, Option<F>)) -> Self {
+        Self {
+            columns: Cow::Borrowed(columns),
+            fields: fields.map(|f| f.into()),
+        }
+    }
+}
diff --git a/src/io/ipc/write/common_async.rs b/src/io/ipc/write/common_async.rs
index cade8331dc1..8792a9bd0bf 100644
--- a/src/io/ipc/write/common_async.rs
+++ b/src/io/ipc/write/common_async.rs
@@ -9,7 +9,7 @@ use super::common::EncodedData;
 
 /// Write a message's IPC data and buffers, returning metadata and buffer data lengths written
 pub async fn write_message<W: AsyncWrite + Unpin + Send>(
-    writer: &mut W,
+    mut writer: W,
     encoded: EncodedData,
 ) -> Result<(usize, usize)> {
     let arrow_data_len = encoded.arrow_data.len();
@@ -21,7 +21,7 @@ pub async fn write_message<W: AsyncWrite + Unpin + Send>(
     let aligned_size = (flatbuf_size + prefix_size + a) & !a;
     let padding_bytes = aligned_size - flatbuf_size - prefix_size;
 
-    write_continuation(writer, (aligned_size - prefix_size) as i32).await?;
+    write_continuation(&mut writer, (aligned_size - prefix_size) as i32).await?;
 
     // write the flatbuf
     if flatbuf_size > 0 {
@@ -43,7 +43,7 @@ pub async fn write_message<W: AsyncWrite + Unpin + Send>(
 /// Write a record batch to the writer, writing the message size before the message
 /// if the record batch is being written to a stream
 pub async fn write_continuation<W: AsyncWrite + Unpin + Send>(
-    writer: &mut W,
+    mut writer: W,
     total_len: i32,
 ) -> Result<usize> {
     writer.write_all(&CONTINUATION_MARKER).await?;
diff --git a/src/io/ipc/write/file_async.rs b/src/io/ipc/write/file_async.rs
new file mode 100644
index 00000000000..930bc4d5c5b
--- /dev/null
+++ b/src/io/ipc/write/file_async.rs
@@ -0,0 +1,248 @@
+//! Async writer for IPC files.
+
+use std::task::Poll;
+
+use arrow_format::ipc::{planus::Builder, Block, Footer, MetadataVersion};
+use futures::{future::BoxFuture, AsyncWrite, AsyncWriteExt, FutureExt, Sink};
+
+use super::common::{encode_chunk, DictionaryTracker, EncodedData, WriteOptions};
+use super::common_async::{write_continuation, write_message};
+use super::schema::serialize_schema;
+use super::{default_ipc_fields, schema_to_bytes, Record};
+use crate::datatypes::*;
+use crate::error::{ArrowError, Result};
+use crate::io::ipc::{IpcField, ARROW_MAGIC};
+
+type WriteOutput<W> = (usize, Option<Block>, Vec<Block>, Option<W>);
+
+/// Sink that writes array [`chunks`](Chunk) as an IPC file.
+///
+/// The file header is automatically written before writing the first chunk, and the file footer is
+/// automatically written when the sink is closed.
+///
+/// # Examples
+///
+/// ```
+/// use std::sync::Arc;
+/// use futures::{SinkExt, TryStreamExt, io::Cursor};
+/// use arrow2::array::{Array, Int32Array};
+/// use arrow2::datatypes::{DataType, Field, Schema};
+/// use arrow2::chunk::Chunk;
+/// use arrow2::io::ipc::write::file_async::FileSink;
+/// use arrow2::io::ipc::read::file_async::{read_file_metadata_async, FileStream};
+/// # futures::executor::block_on(async move {
+/// let schema = Schema::from(vec![
+///     Field::new("values", DataType::Int32, true),
+/// ]);
+///
+/// let mut buffer = Cursor::new(vec![]);
+/// let mut sink = FileSink::new(
+///     &mut buffer,
+///     &schema,
+///     None,
+///     Default::default(),
+/// );
+///
+/// // Write chunks to file
+/// for i in 0..3 {
+///     let values = Int32Array::from(&[Some(i), None]);
+///     let chunk = Chunk::new(vec![Arc::new(values) as Arc<dyn Array>]);
+///     sink.feed(chunk.into()).await?;
+/// }
+/// sink.close().await?;
+/// drop(sink);
+///
+/// // Read chunks from file
+/// buffer.set_position(0);
+/// let metadata = read_file_metadata_async(&mut buffer).await?;
+/// let mut stream = FileStream::new(buffer, metadata, None);
+/// let chunks = stream.try_collect::<Vec<_>>().await?;
+/// # arrow2::error::Result::Ok(())
+/// # }).unwrap();
+/// ```
+pub struct FileSink<'a, W: AsyncWrite + Unpin + Send + 'a> {
+    writer: Option<W>,
+    task: Option<BoxFuture<'a, Result<WriteOutput<W>>>>,
+    options: WriteOptions,
+    dictionary_tracker: DictionaryTracker,
+    offset: usize,
+    fields: Vec<IpcField>,
+    record_blocks: Vec<Block>,
+    dictionary_blocks: Vec<Block>,
+    schema: Schema,
+}
+
+impl<'a, W> FileSink<'a, W>
+where
+    W: AsyncWrite + Unpin + Send + 'a,
+{
+    /// Create a new file writer.
+    pub fn new(
+        writer: W,
+        schema: &Schema,
+        ipc_fields: Option<Vec<IpcField>>,
+        options: WriteOptions,
+    ) -> Self {
+        let fields = ipc_fields.unwrap_or_else(|| default_ipc_fields(&schema.fields));
+        let encoded = EncodedData {
+            ipc_message: schema_to_bytes(schema, &fields),
+            arrow_data: vec![],
+        };
+        let task = Some(Self::start(writer, encoded).boxed());
+        Self {
+            writer: None,
+            task,
+            options,
+            fields,
+            offset: 0,
+            schema: schema.clone(),
+            dictionary_tracker: DictionaryTracker::new(true),
+            record_blocks: vec![],
+            dictionary_blocks: vec![],
+        }
+    }
+
+    async fn start(mut writer: W, encoded: EncodedData) -> Result<WriteOutput<W>> {
+        writer.write_all(&ARROW_MAGIC[..]).await?;
+        writer.write_all(&[0, 0]).await?;
+        let (meta, data) = write_message(&mut writer, encoded).await?;
+
+        Ok((meta + data + 8, None, vec![], Some(writer)))
+    }
+
+    async fn write(
+        mut writer: W,
+        mut offset: usize,
+        record: EncodedData,
+        dictionaries: Vec<EncodedData>,
+    ) -> Result<WriteOutput<W>> {
+        let mut dict_blocks = vec![];
+        for dict in dictionaries {
+            let (meta, data) = write_message(&mut writer, dict).await?;
+            let block = Block {
+                offset: offset as i64,
+                meta_data_length: meta as i32,
+                body_length: data as i64,
+            };
+            dict_blocks.push(block);
+            offset += meta + data;
+        }
+        let (meta, data) = write_message(&mut writer, record).await?;
+        let block = Block {
+            offset: offset as i64,
+            meta_data_length: meta as i32,
+            body_length: data as i64,
+        };
+        offset += meta + data;
+        Ok((offset, Some(block), dict_blocks, Some(writer)))
+    }
+
+    async fn finish(mut writer: W, footer: Footer) -> Result<WriteOutput<W>> {
+        write_continuation(&mut writer, 0).await?;
+        let footer = {
+            let mut builder = Builder::new();
+            builder.finish(&footer, None).to_owned()
+        };
+        writer.write_all(&footer[..]).await?;
+        writer
+            .write_all(&(footer.len() as i32).to_le_bytes())
+            .await?;
+        writer.write_all(&ARROW_MAGIC).await?;
+        writer.close().await?;
+
+        Ok((0, None, vec![], None))
+    }
+
+    fn poll_write(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<()>> {
+        if let Some(task) = &mut self.task {
+            match futures::ready!(task.poll_unpin(cx)) {
+                Ok((offset, record, mut dictionaries, writer)) => {
+                    self.task = None;
+                    self.writer = writer;
+                    self.offset = offset;
+                    if let Some(block) = record {
+                        self.record_blocks.push(block);
+                    }
+                    self.dictionary_blocks.append(&mut dictionaries);
+                    Poll::Ready(Ok(()))
+                }
+                Err(error) => {
+                    self.task = None;
+                    Poll::Ready(Err(error))
+                }
+            }
+        } else {
+            Poll::Ready(Ok(()))
+        }
+    }
+}
+
+impl<'a, W> Sink<Record<'_>> for FileSink<'a, W>
+where
+    W: AsyncWrite + Unpin + Send + 'a,
+{
+    type Error = ArrowError;
+
+    fn poll_ready(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Result<()>> {
+        self.get_mut().poll_write(cx)
+    }
+
+    fn start_send(self: std::pin::Pin<&mut Self>, item: Record<'_>) -> Result<()> {
+        let this = self.get_mut();
+
+        if let Some(writer) = this.writer.take() {
+            let fields = item.fields().unwrap_or_else(|| &this.fields[..]);
+
+            let (dictionaries, record) = encode_chunk(
+                item.columns(),
+                fields,
+                &mut this.dictionary_tracker,
+                &this.options,
+            )?;
+
+            this.task = Some(Self::write(writer, this.offset, record, dictionaries).boxed());
+            Ok(())
+        } else {
+            Err(ArrowError::Io(std::io::Error::new(
+                std::io::ErrorKind::UnexpectedEof,
+                "writer is closed",
+            )))
+        }
+    }
+
+    fn poll_flush(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Result<()>> {
+        self.get_mut().poll_write(cx)
+    }
+
+    fn poll_close(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Result<()>> {
+        let this = self.get_mut();
+        match futures::ready!(this.poll_write(cx)) {
+            Ok(()) => {
+                if let Some(writer) = this.writer.take() {
+                    let schema = serialize_schema(&this.schema, &this.fields);
+                    let footer = Footer {
+                        version: MetadataVersion::V5,
+                        schema: Some(Box::new(schema)),
+                        dictionaries: Some(std::mem::take(&mut this.dictionary_blocks)),
+                        record_batches: Some(std::mem::take(&mut this.record_blocks)),
+                        custom_metadata: None,
+                    };
+                    this.task = Some(Self::finish(writer, footer).boxed());
+                    this.poll_write(cx)
+                } else {
+                    Poll::Ready(Ok(()))
+                }
+            }
+            Err(error) => Poll::Ready(Err(error)),
+        }
+    }
+}
diff --git a/src/io/ipc/write/mod.rs b/src/io/ipc/write/mod.rs
index 6331e9dc9a0..f4afb0e451f 100644
--- a/src/io/ipc/write/mod.rs
+++ b/src/io/ipc/write/mod.rs
@@ -5,7 +5,7 @@ mod serialize;
 mod stream;
 mod writer;
 
-pub use common::{Compression, WriteOptions};
+pub use common::{Compression, Record, WriteOptions};
 pub use schema::schema_to_bytes;
 pub use serialize::{write, write_dictionary};
 pub use stream::StreamWriter;
@@ -19,6 +19,10 @@ mod common_async;
 #[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_write_async")))]
 pub mod stream_async;
 
+#[cfg(feature = "io_ipc_write_async")]
+#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_write_async")))]
+pub mod file_async;
+
 use crate::datatypes::{DataType, Field};
 
 use super::IpcField;
diff --git a/src/io/ipc/write/stream_async.rs b/src/io/ipc/write/stream_async.rs
index 7da157c1f4b..033fa06038b 100644
--- a/src/io/ipc/write/stream_async.rs
+++ b/src/io/ipc/write/stream_async.rs
@@ -1,108 +1,183 @@
 //! `async` writing of arrow streams
-use std::sync::Arc;
 
-use futures::AsyncWrite;
+use std::{pin::Pin, task::Poll};
+
+use futures::{future::BoxFuture, AsyncWrite, FutureExt, Sink};
 
 use super::super::IpcField;
 pub use super::common::WriteOptions;
 use super::common::{encode_chunk, DictionaryTracker, EncodedData};
 use super::common_async::{write_continuation, write_message};
-use super::{default_ipc_fields, schema_to_bytes};
+use super::{default_ipc_fields, schema_to_bytes, Record};
 
-use crate::array::Array;
-use crate::chunk::Chunk;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-/// An `async` writer to the Apache Arrow stream format.
-pub struct StreamWriter<W: AsyncWrite + Unpin + Send> {
-    /// The object to write to
-    writer: W,
-    /// IPC write options
-    write_options: WriteOptions,
-    /// Whether the stream has been finished
-    finished: bool,
-    /// Keeps track of dictionaries that have been written
+/// A sink that writes array [`chunks`](Chunk) as an IPC stream.
+///
+/// The stream header is automatically written before writing the first chunk.
+///
+/// # Examples
+///
+/// ```
+/// use std::sync::Arc;
+/// use futures::SinkExt;
+/// use arrow2::array::{Array, Int32Array};
+/// use arrow2::datatypes::{DataType, Field, Schema};
+/// use arrow2::chunk::Chunk;
+/// # use arrow2::io::ipc::write::stream_async::StreamSink;
+/// # futures::executor::block_on(async move {
+/// let schema = Schema::from(vec![
+///     Field::new("values", DataType::Int32, true),
+/// ]);
+///
+/// let mut buffer = vec![];
+/// let mut sink = StreamSink::new(
+///     &mut buffer,
+///     &schema,
+///     None,
+///     Default::default(),
+/// );
+///
+/// for i in 0..3 {
+///     let values = Int32Array::from(&[Some(i), None]);
+///     let chunk = Chunk::new(vec![Arc::new(values) as Arc<dyn Array>]);
+///     sink.feed(chunk.into()).await?;
+/// }
+/// sink.close().await?;
+/// # arrow2::error::Result::Ok(())
+/// # }).unwrap();
+/// ```
+pub struct StreamSink<'a, W: AsyncWrite + Unpin + Send + 'a> {
+    writer: Option<W>,
+    task: Option<BoxFuture<'a, Result<Option<W>>>>,
+    options: WriteOptions,
     dictionary_tracker: DictionaryTracker,
+    fields: Vec<IpcField>,
 }
 
-impl<W: AsyncWrite + Unpin + Send> StreamWriter<W> {
-    /// Creates a new [`StreamWriter`]
-    pub fn new(writer: W, write_options: WriteOptions) -> Self {
+impl<'a, W> StreamSink<'a, W>
+where
+    W: AsyncWrite + Unpin + Send + 'a,
+{
+    /// Create a new [`StreamSink`].
+    pub fn new(
+        writer: W,
+        schema: &Schema,
+        ipc_fields: Option<Vec<IpcField>>,
+        write_options: WriteOptions,
+    ) -> Self {
+        let fields = ipc_fields.unwrap_or_else(|| default_ipc_fields(&schema.fields));
+        let task = Some(Self::start(writer, schema, &fields[..]));
         Self {
-            writer,
-            write_options,
-            finished: false,
+            writer: None,
+            task,
+            fields,
             dictionary_tracker: DictionaryTracker::new(false),
+            options: write_options,
         }
     }
 
-    /// Starts the stream
-    pub async fn start(&mut self, schema: &Schema, ipc_fields: Option<&[IpcField]>) -> Result<()> {
-        let encoded_message = if let Some(ipc_fields) = ipc_fields {
-            EncodedData {
-                ipc_message: schema_to_bytes(schema, ipc_fields),
-                arrow_data: vec![],
-            }
-        } else {
-            let ipc_fields = default_ipc_fields(&schema.fields);
-            EncodedData {
-                ipc_message: schema_to_bytes(schema, &ipc_fields),
-                arrow_data: vec![],
-            }
+    fn start(
+        mut writer: W,
+        schema: &Schema,
+        ipc_fields: &[IpcField],
+    ) -> BoxFuture<'a, Result<Option<W>>> {
+        let message = EncodedData {
+            ipc_message: schema_to_bytes(schema, ipc_fields),
+            arrow_data: vec![],
         };
-        write_message(&mut self.writer, encoded_message).await?;
-        Ok(())
+        async move {
+            write_message(&mut writer, message).await?;
+            Ok(Some(writer))
+        }
+        .boxed()
     }
 
-    /// Writes [`Chunk`] to the stream
-    pub async fn write(
-        &mut self,
-        columns: &Chunk<Arc<dyn Array>>,
-        schema: &Schema,
-        ipc_fields: Option<&[IpcField]>,
-    ) -> Result<()> {
-        if self.finished {
-            return Err(ArrowError::Io(std::io::Error::new(
+    fn write(&mut self, record: Record<'_>) -> Result<()> {
+        let fields = record.fields().unwrap_or(&self.fields[..]);
+        let (dictionaries, message) = encode_chunk(
+            record.columns(),
+            fields,
+            &mut self.dictionary_tracker,
+            &self.options,
+        )?;
+
+        if let Some(mut writer) = self.writer.take() {
+            self.task = Some(
+                async move {
+                    for d in dictionaries {
+                        write_message(&mut writer, d).await?;
+                    }
+                    write_message(&mut writer, message).await?;
+                    Ok(Some(writer))
+                }
+                .boxed(),
+            );
+            Ok(())
+        } else {
+            Err(ArrowError::Io(std::io::Error::new(
                 std::io::ErrorKind::UnexpectedEof,
-                "Cannot write to a finished stream".to_string(),
-            )));
+                "writer closed".to_string(),
+            )))
         }
+    }
 
-        let (encoded_dictionaries, encoded_message) = if let Some(ipc_fields) = ipc_fields {
-            encode_chunk(
-                columns,
-                ipc_fields,
-                &mut self.dictionary_tracker,
-                &self.write_options,
-            )?
+    fn poll_complete(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<()>> {
+        if let Some(task) = &mut self.task {
+            match futures::ready!(task.poll_unpin(cx)) {
+                Ok(writer) => {
+                    self.writer = writer;
+                    self.task = None;
+                    Poll::Ready(Ok(()))
+                }
+                Err(error) => {
+                    self.task = None;
+                    Poll::Ready(Err(error))
+                }
+            }
         } else {
-            let ipc_fields = default_ipc_fields(&schema.fields);
-            encode_chunk(
-                columns,
-                &ipc_fields,
-                &mut self.dictionary_tracker,
-                &self.write_options,
-            )?
-        };
-
-        for encoded_dictionary in encoded_dictionaries {
-            write_message(&mut self.writer, encoded_dictionary).await?;
+            Poll::Ready(Ok(()))
         }
+    }
+}
 
-        write_message(&mut self.writer, encoded_message).await?;
-        Ok(())
+impl<'a, W> Sink<Record<'_>> for StreamSink<'a, W>
+where
+    W: AsyncWrite + Unpin + Send,
+{
+    type Error = ArrowError;
+
+    fn poll_ready(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<()>> {
+        self.get_mut().poll_complete(cx)
+    }
+
+    fn start_send(self: Pin<&mut Self>, item: Record<'_>) -> Result<()> {
+        self.get_mut().write(item)
     }
 
-    /// Finishes the stream
-    pub async fn finish(&mut self) -> Result<()> {
-        write_continuation(&mut self.writer, 0).await?;
-        self.finished = true;
-        Ok(())
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<()>> {
+        self.get_mut().poll_complete(cx)
     }
 
-    /// Consumes itself, returning the inner writer.
-    pub fn into_inner(self) -> W {
-        self.writer
+    fn poll_close(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<()>> {
+        let this = self.get_mut();
+        match this.poll_complete(cx) {
+            Poll::Ready(Ok(())) => {
+                if let Some(mut writer) = this.writer.take() {
+                    this.task = Some(
+                        async move {
+                            write_continuation(&mut writer, 0).await?;
+                            Ok(None)
+                        }
+                        .boxed(),
+                    );
+                    this.poll_complete(cx)
+                } else {
+                    Poll::Ready(Ok(()))
+                }
+            }
+            res => res,
+        }
     }
 }
diff --git a/tests/it/io/ipc/mod.rs b/tests/it/io/ipc/mod.rs
index e1a5f79b4f0..63cce5b2e8f 100644
--- a/tests/it/io/ipc/mod.rs
+++ b/tests/it/io/ipc/mod.rs
@@ -5,7 +5,13 @@ mod write;
 pub use common::read_gzip_json;
 
 #[cfg(feature = "io_ipc_write_async")]
-mod write_async;
+mod write_stream_async;
+
+#[cfg(feature = "io_ipc_write_async")]
+mod write_file_async;
 
 #[cfg(feature = "io_ipc_read_async")]
 mod read_stream_async;
+
+#[cfg(feature = "io_ipc_read_async")]
+mod read_file_async;
diff --git a/tests/it/io/ipc/read_file_async.rs b/tests/it/io/ipc/read_file_async.rs
new file mode 100644
index 00000000000..af619604608
--- /dev/null
+++ b/tests/it/io/ipc/read_file_async.rs
@@ -0,0 +1,45 @@
+use futures::StreamExt;
+use tokio::fs::File;
+use tokio_util::compat::*;
+
+use arrow2::error::Result;
+use arrow2::io::ipc::read::file_async::*;
+
+use crate::io::ipc::common::read_gzip_json;
+
+async fn test_file(version: &str, file_name: &str) -> Result<()> {
+    let testdata = crate::test_util::arrow_test_data();
+    let mut file = File::open(format!(
+        "{}/arrow-ipc-stream/integration/{}/{}.arrow_file",
+        testdata, version, file_name
+    ))
+    .await?
+    .compat();
+
+    let metadata = read_file_metadata_async(&mut file).await?;
+    let mut reader = FileStream::new(file, metadata, None);
+
+    // read expected JSON output
+    let (schema, ipc_fields, batches) = read_gzip_json(version, file_name)?;
+
+    assert_eq!(&schema, &reader.metadata().schema);
+    assert_eq!(&ipc_fields, &reader.metadata().ipc_schema.fields);
+
+    let mut items = vec![];
+    while let Some(item) = reader.next().await {
+        items.push(item?);
+    }
+
+    batches
+        .iter()
+        .zip(items.into_iter())
+        .for_each(|(lhs, rhs)| {
+            assert_eq!(lhs, &rhs);
+        });
+    Ok(())
+}
+
+#[tokio::test]
+async fn read_async() -> Result<()> {
+    test_file("1.0.0-littleendian", "generated_primitive").await
+}
diff --git a/tests/it/io/ipc/write_file_async.rs b/tests/it/io/ipc/write_file_async.rs
new file mode 100644
index 00000000000..ffa6530feed
--- /dev/null
+++ b/tests/it/io/ipc/write_file_async.rs
@@ -0,0 +1,63 @@
+use std::io::Cursor;
+use std::sync::Arc;
+
+use arrow2::array::Array;
+use arrow2::chunk::Chunk;
+use arrow2::datatypes::Schema;
+use arrow2::error::Result;
+use arrow2::io::ipc::read;
+use arrow2::io::ipc::write::file_async::FileSink;
+use arrow2::io::ipc::write::WriteOptions;
+use arrow2::io::ipc::IpcField;
+use futures::io::Cursor as AsyncCursor;
+use futures::SinkExt;
+
+use crate::io::ipc::common::read_arrow_stream;
+use crate::io::ipc::common::read_gzip_json;
+
+async fn write_(
+    schema: &Schema,
+    ipc_fields: &[IpcField],
+    batches: &[Chunk<Arc<dyn Array>>],
+) -> Result<Vec<u8>> {
+    let mut result = AsyncCursor::new(vec![]);
+
+    let options = WriteOptions { compression: None };
+    let mut sink = FileSink::new(&mut result, schema, Some(ipc_fields.to_vec()), options);
+    for batch in batches {
+        sink.feed((batch, Some(ipc_fields)).into()).await?;
+    }
+    sink.close().await?;
+    drop(sink);
+    Ok(result.into_inner())
+}
+
+async fn test_file(version: &str, file_name: &str) -> Result<()> {
+    let (schema, ipc_fields, batches) = read_arrow_stream(version, file_name);
+
+    let result = write_(&schema, &ipc_fields, &batches).await?;
+
+    let mut reader = Cursor::new(result);
+    let metadata = read::read_file_metadata(&mut reader)?;
+    let reader = read::FileReader::new(reader, metadata, None);
+
+    let schema = &reader.metadata().schema;
+    let ipc_fields = reader.metadata().ipc_schema.fields.clone();
+
+    // read expected JSON output
+    let (expected_schema, expected_ipc_fields, expected_batches) =
+        read_gzip_json(version, file_name).unwrap();
+
+    assert_eq!(schema, &expected_schema);
+    assert_eq!(ipc_fields, expected_ipc_fields);
+
+    let batches = reader.collect::<Result<Vec<_>>>()?;
+
+    assert_eq!(batches, expected_batches);
+    Ok(())
+}
+
+#[tokio::test]
+async fn write_async() -> Result<()> {
+    test_file("1.0.0-littleendian", "generated_primitive").await
+}
diff --git a/tests/it/io/ipc/write_async.rs b/tests/it/io/ipc/write_stream_async.rs
similarity index 86%
rename from tests/it/io/ipc/write_async.rs
rename to tests/it/io/ipc/write_stream_async.rs
index 9f27eea8808..2840692a75f 100644
--- a/tests/it/io/ipc/write_async.rs
+++ b/tests/it/io/ipc/write_stream_async.rs
@@ -7,8 +7,10 @@ use arrow2::datatypes::Schema;
 use arrow2::error::Result;
 use arrow2::io::ipc::read;
 use arrow2::io::ipc::write::stream_async;
+use arrow2::io::ipc::write::stream_async::StreamSink;
 use arrow2::io::ipc::IpcField;
 use futures::io::Cursor as AsyncCursor;
+use futures::SinkExt;
 
 use crate::io::ipc::common::read_arrow_stream;
 use crate::io::ipc::common::read_gzip_json;
@@ -21,12 +23,12 @@ async fn write_(
     let mut result = AsyncCursor::new(vec![]);
 
     let options = stream_async::WriteOptions { compression: None };
-    let mut writer = stream_async::StreamWriter::new(&mut result, options);
-    writer.start(schema, Some(ipc_fields)).await?;
+    let mut sink = StreamSink::new(&mut result, schema, Some(ipc_fields.to_vec()), options);
     for batch in batches {
-        writer.write(batch, schema, Some(ipc_fields)).await?;
+        sink.feed((batch, Some(ipc_fields)).into()).await?;
     }
-    writer.finish().await?;
+    sink.close().await?;
+    drop(sink);
     Ok(result.into_inner())
 }
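
The in-tree doc examples above drive both sinks with `futures::executor::block_on` over in-memory cursors. As a rough companion sketch (not part of the patch), here is the same round trip over actual files, using the `tokio` and `tokio-util` compat shims the new tests already depend on. The output path `example.arrow` and the Cargo features (`tokio` with `rt`, `macros`, and `fs`; `tokio-util` with `compat`) are assumptions for illustration only.

```rust
use std::sync::Arc;

use futures::{SinkExt, TryStreamExt};
use tokio::fs::File;
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};

use arrow2::array::{Array, Int32Array};
use arrow2::chunk::Chunk;
use arrow2::datatypes::{DataType, Field, Schema};
use arrow2::error::Result;
use arrow2::io::ipc::read::file_async::{read_file_metadata_async, FileStream};
use arrow2::io::ipc::write::file_async::FileSink;
use arrow2::io::ipc::write::WriteOptions;

#[tokio::main]
async fn main() -> Result<()> {
    let schema = Schema::from(vec![Field::new("values", DataType::Int32, true)]);

    // Write: FileSink emits the header on the first `feed` and the footer on `close`.
    // `compat_write` adapts tokio's AsyncWrite to the futures traits the sink expects.
    let file = File::create("example.arrow").await?.compat_write();
    let mut sink = FileSink::new(file, &schema, None, WriteOptions { compression: None });
    let values = Int32Array::from(&[Some(1), None, Some(3)]);
    let chunk = Chunk::new(vec![Arc::new(values) as Arc<dyn Array>]);
    sink.feed(chunk.into()).await?;
    sink.close().await?;
    drop(sink);

    // Read: fetch the footer metadata first, then stream the record batches back.
    let mut file = File::open("example.arrow").await?.compat();
    let metadata = read_file_metadata_async(&mut file).await?;
    let chunks: Vec<_> = FileStream::new(file, metadata, None).try_collect().await?;
    assert_eq!(chunks.len(), 1);
    Ok(())
}
```

Feeding `chunk.into()` relies on the `From` conversions on the new `Record` type: an owned `Chunk` yields a `Record<'static>` with default IPC fields, while `(batch, Some(ipc_fields)).into()`, as used in the tests, borrows the chunk and carries explicit fields.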