diff --git a/Cargo.toml b/Cargo.toml index d88a6c340f9..d5d9a832d66 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -134,7 +134,7 @@ io_csv_write = ["csv", "csv-core", "streaming-iterator", "lexical-core"] io_json = ["serde", "serde_json", "streaming-iterator", "fallible-streaming-iterator", "indexmap", "lexical-core"] io_ipc = ["arrow-format"] io_ipc_write_async = ["io_ipc", "futures"] -io_ipc_read_async = ["io_ipc", "futures"] +io_ipc_read_async = ["io_ipc", "futures", "async-stream"] io_ipc_compression = ["lz4", "zstd"] io_flight = ["io_ipc", "arrow-format/flight-data"] # base64 + io_ipc because arrow schemas are stored as base64-encoded ipc format. diff --git a/src/io/ipc/read/file_async.rs b/src/io/ipc/read/file_async.rs new file mode 100644 index 00000000000..a67170e51a0 --- /dev/null +++ b/src/io/ipc/read/file_async.rs @@ -0,0 +1,275 @@ +//! Async reader for Arrow IPC files +use std::io::SeekFrom; +use std::sync::Arc; + +use arrow_format::ipc::{ + planus::{ReadAsRoot, Vector}, + BlockRef, FooterRef, MessageHeaderRef, MessageRef, +}; +use futures::{ + stream::BoxStream, AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, Stream, StreamExt, +}; + +use crate::array::*; +use crate::chunk::Chunk; +use crate::datatypes::{Field, Schema}; +use crate::error::{ArrowError, Result}; +use crate::io::ipc::{IpcSchema, ARROW_MAGIC, CONTINUATION_MARKER}; + +use super::common::{read_dictionary, read_record_batch}; +use super::reader::get_serialized_batch; +use super::schema::fb_to_schema; +use super::Dictionaries; +use super::FileMetadata; + +/// Async reader for Arrow IPC files +pub struct FileStream<'a> { + stream: BoxStream<'a, Result>>>, + metadata: FileMetadata, + schema: Schema, +} + +impl<'a> FileStream<'a> { + /// Create a new IPC file reader. + /// + /// # Examples + /// See [`FileSink`](crate::io::ipc::write::file_async::FileSink). + pub fn new(reader: R, metadata: FileMetadata, projection: Option>) -> Self + where + R: AsyncRead + AsyncSeek + Unpin + Send + 'a, + { + let schema = if let Some(projection) = projection.as_ref() { + projection.windows(2).for_each(|x| { + assert!( + x[0] < x[1], + "IPC projection must be ordered and non-overlapping", + ) + }); + let fields = projection + .iter() + .map(|&x| metadata.schema.fields[x].clone()) + .collect::>(); + Schema { + fields, + metadata: metadata.schema.metadata.clone(), + } + } else { + metadata.schema.clone() + }; + + let stream = Self::stream(reader, metadata.clone(), projection); + Self { + stream, + metadata, + schema, + } + } + + /// Get the metadata from the IPC file. + pub fn metadata(&self) -> &FileMetadata { + &self.metadata + } + + /// Get the projected schema from the IPC file. + pub fn schema(&self) -> &Schema { + &self.schema + } + + fn stream( + mut reader: R, + metadata: FileMetadata, + projection: Option>, + ) -> BoxStream<'a, Result>>> + where + R: AsyncRead + AsyncSeek + Unpin + Send + 'a, + { + async_stream::try_stream! { + let mut meta_buffer = vec![]; + let mut block_buffer = vec![]; + for block in 0..metadata.blocks.len() { + let chunk = read_batch( + &mut reader, + &metadata, + projection.as_deref(), + block, + &mut meta_buffer, + &mut block_buffer, + ).await?; + yield chunk; + } + } + .boxed() + } +} + +impl<'a> Stream for FileStream<'a> { + type Item = Result>>; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.get_mut().stream.poll_next_unpin(cx) + } +} + +/// Read the metadata from an IPC file. +pub async fn read_file_metadata_async(reader: &mut R) -> Result +where + R: AsyncRead + AsyncSeek + Unpin, +{ + // Check header + let mut magic = [0; 6]; + reader.read_exact(&mut magic).await?; + if magic != ARROW_MAGIC { + return Err(ArrowError::OutOfSpec( + "file does not contain correct Arrow header".to_string(), + )); + } + // Check footer + reader.seek(SeekFrom::End(-6)).await?; + reader.read_exact(&mut magic).await?; + if magic != ARROW_MAGIC { + return Err(ArrowError::OutOfSpec( + "file does not contain correct Arrow footer".to_string(), + )); + } + // Get footer size + let mut footer_size = [0; 4]; + reader.seek(SeekFrom::End(-10)).await?; + reader.read_exact(&mut footer_size).await?; + let footer_size = i32::from_le_bytes(footer_size); + // Read footer + let mut footer = vec![0; footer_size as usize]; + reader.seek(SeekFrom::End(-10 - footer_size as i64)).await?; + reader.read_exact(&mut footer).await?; + let footer = FooterRef::read_as_root(&footer[..]) + .map_err(|err| ArrowError::OutOfSpec(format!("unable to get root as footer: {:?}", err)))?; + + let blocks = footer.record_batches()?.ok_or_else(|| { + ArrowError::OutOfSpec("unable to get record batches from footer".to_string()) + })?; + let schema = footer + .schema()? + .ok_or_else(|| ArrowError::OutOfSpec("unable to get schema from footer".to_string()))?; + let (schema, ipc_schema) = fb_to_schema(schema)?; + let dictionary_blocks = footer.dictionaries()?; + let dictionaries = if let Some(blocks) = dictionary_blocks { + read_dictionaries(reader, &schema.fields[..], &ipc_schema, blocks).await? + } else { + Default::default() + }; + + Ok(FileMetadata { + schema, + ipc_schema, + blocks: blocks + .iter() + .map(|block| Ok(block.try_into()?)) + .collect::>>()?, + dictionaries, + }) +} + +async fn read_dictionaries( + reader: &mut R, + fields: &[Field], + ipc_schema: &IpcSchema, + blocks: Vector<'_, BlockRef<'_>>, +) -> Result +where + R: AsyncRead + AsyncSeek + Unpin, +{ + let mut dictionaries = Default::default(); + let mut data = vec![]; + let mut buffer = vec![]; + + for block in blocks { + let offset = block.offset() as u64; + read_dictionary_message(reader, offset, &mut data).await?; + + let message = MessageRef::read_as_root(&data).map_err(|err| { + ArrowError::OutOfSpec(format!("unable to get root as message: {:?}", err)) + })?; + let header = message + .header()? + .ok_or_else(|| ArrowError::oos("message must have a header"))?; + match header { + MessageHeaderRef::DictionaryBatch(batch) => { + buffer.clear(); + buffer.resize(block.body_length() as usize, 0); + reader.read_exact(&mut buffer).await?; + let mut cursor = std::io::Cursor::new(&mut buffer); + read_dictionary(batch, fields, ipc_schema, &mut dictionaries, &mut cursor, 0)?; + } + other => { + return Err(ArrowError::OutOfSpec(format!( + "expected DictionaryBatch in dictionary blocks, found {:?}", + other, + ))) + } + } + } + Ok(dictionaries) +} + +async fn read_dictionary_message(reader: &mut R, offset: u64, data: &mut Vec) -> Result<()> +where + R: AsyncRead + AsyncSeek + Unpin, +{ + let mut message_size = [0; 4]; + reader.seek(SeekFrom::Start(offset)).await?; + reader.read_exact(&mut message_size).await?; + if message_size == CONTINUATION_MARKER { + reader.read_exact(&mut message_size).await?; + } + let footer_size = i32::from_le_bytes(message_size); + data.clear(); + data.resize(footer_size as usize, 0); + reader.read_exact(data).await?; + + Ok(()) +} + +async fn read_batch( + reader: &mut R, + metadata: &FileMetadata, + projection: Option<&[usize]>, + block: usize, + meta_buffer: &mut Vec, + block_buffer: &mut Vec, +) -> Result>> +where + R: AsyncRead + AsyncSeek + Unpin, +{ + let block = metadata.blocks[block]; + reader.seek(SeekFrom::Start(block.offset as u64)).await?; + let mut meta_buf = [0; 4]; + reader.read_exact(&mut meta_buf).await?; + if meta_buf == CONTINUATION_MARKER { + reader.read_exact(&mut meta_buf).await?; + } + let meta_len = i32::from_le_bytes(meta_buf) as usize; + meta_buffer.clear(); + meta_buffer.resize(meta_len, 0); + reader.read_exact(meta_buffer).await?; + + let message = MessageRef::read_as_root(&meta_buffer[..]) + .map_err(|err| ArrowError::oos(format!("unable to parse message: {:?}", err)))?; + let batch = get_serialized_batch(&message)?; + block_buffer.clear(); + block_buffer.resize(message.body_length()? as usize, 0); + reader.read_exact(block_buffer).await?; + let mut cursor = std::io::Cursor::new(block_buffer); + let chunk = read_record_batch( + batch, + &metadata.schema.fields, + &metadata.ipc_schema, + projection, + &metadata.dictionaries, + message.version()?, + &mut cursor, + 0, + )?; + Ok(chunk) +} diff --git a/src/io/ipc/read/mod.rs b/src/io/ipc/read/mod.rs index 3a45d4ecac6..207c33329fc 100644 --- a/src/io/ipc/read/mod.rs +++ b/src/io/ipc/read/mod.rs @@ -20,6 +20,10 @@ mod stream; #[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_read_async")))] pub mod stream_async; +#[cfg(feature = "io_ipc_read_async")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_read_async")))] +pub mod file_async; + pub use common::{read_dictionary, read_record_batch}; pub use reader::{read_file_metadata, FileMetadata, FileReader}; pub use schema::deserialize_schema; diff --git a/src/io/ipc/read/reader.rs b/src/io/ipc/read/reader.rs index 3254eb4a7ef..d31e975d666 100644 --- a/src/io/ipc/read/reader.rs +++ b/src/io/ipc/read/reader.rs @@ -26,10 +26,10 @@ pub struct FileMetadata { /// The blocks in the file /// /// A block indicates the regions in the file to read to get data - blocks: Vec, + pub(super) blocks: Vec, /// Dictionaries associated to each dict_id - dictionaries: Dictionaries, + pub(super) dictionaries: Dictionaries, } /// Arrow File reader @@ -166,7 +166,7 @@ pub fn read_file_metadata(reader: &mut R) -> Result( +pub(super) fn get_serialized_batch<'a>( message: &'a arrow_format::ipc::MessageRef, ) -> Result> { let header = message.header()?.ok_or_else(|| { diff --git a/src/io/ipc/write/file_async.rs b/src/io/ipc/write/file_async.rs new file mode 100644 index 00000000000..bc0f187ad79 --- /dev/null +++ b/src/io/ipc/write/file_async.rs @@ -0,0 +1,300 @@ +//! Async writer for IPC files. + +use super::common::{encode_chunk, DictionaryTracker, EncodedData, WriteOptions}; +use super::common_async::{write_continuation, write_message}; +use super::schema::serialize_schema; +use super::{default_ipc_fields, schema_to_bytes}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::*; +use crate::error::{ArrowError, Result}; +use crate::io::ipc::{IpcField, ARROW_MAGIC}; +use arrow_format::ipc::{planus::Builder, Block, Footer, MetadataVersion}; +use futures::{future::BoxFuture, AsyncWrite, AsyncWriteExt, FutureExt, Sink}; +use std::sync::Arc; +use std::task::Poll; + +type WriteOutput = (usize, Option, Vec, Option); + +/// Sink that writes array [`chunks`](Chunk) as an IPC file. +/// +/// The file header is automatically written before writing the first chunk, and the file footer is +/// automatically written when the sink is closed. +/// +/// The sink uses the same `ipc_fields` projection and `write_options` for each chunk. +/// +/// # Examples +/// +/// ``` +/// use std::sync::Arc; +/// use futures::{SinkExt, TryStreamExt, io::Cursor}; +/// use arrow2::array::{Array, Int32Array}; +/// use arrow2::datatypes::{DataType, Field, Schema}; +/// use arrow2::chunk::Chunk; +/// use arrow2::io::ipc::write::file_async::FileSink; +/// use arrow2::io::ipc::read::file_async::{read_file_metadata_async, FileStream}; +/// # futures::executor::block_on(async move { +/// let schema = Schema::from(vec![ +/// Field::new("values", DataType::Int32, true), +/// ]); +/// +/// let mut buffer = Cursor::new(vec![]); +/// let mut sink = FileSink::new( +/// &mut buffer, +/// schema, +/// None, +/// Default::default(), +/// ); +/// +/// // Write chunks to file +/// for i in 0..3 { +/// let values = Int32Array::from(&[Some(i), None]); +/// let chunk = Chunk::new(vec![Arc::new(values) as Arc]); +/// sink.feed(chunk).await?; +/// } +/// sink.close().await?; +/// drop(sink); +/// +/// // Read chunks from file +/// buffer.set_position(0); +/// let metadata = read_file_metadata_async(&mut buffer).await?; +/// let mut stream = FileStream::new(buffer, metadata, None); +/// let chunks = stream.try_collect::>().await?; +/// # arrow2::error::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub struct FileSink<'a, W: AsyncWrite + Unpin + Send + 'a> { + writer: Option, + task: Option>>>, + options: WriteOptions, + dictionary_tracker: DictionaryTracker, + offset: usize, + fields: Vec, + record_blocks: Vec, + dictionary_blocks: Vec, + schema: Schema, +} + +impl<'a, W> FileSink<'a, W> +where + W: AsyncWrite + Unpin + Send + 'a, +{ + /// Create a new file writer. + pub fn new( + writer: W, + schema: Schema, + ipc_fields: Option>, + options: WriteOptions, + ) -> Self { + let fields = ipc_fields.unwrap_or_else(|| default_ipc_fields(&schema.fields)); + let encoded = EncodedData { + ipc_message: schema_to_bytes(&schema, &fields), + arrow_data: vec![], + }; + let task = Some(Self::start(writer, encoded).boxed()); + Self { + writer: None, + task, + options, + fields, + offset: 0, + schema, + dictionary_tracker: DictionaryTracker::new(true), + record_blocks: vec![], + dictionary_blocks: vec![], + } + } + + async fn start(mut writer: W, encoded: EncodedData) -> Result> { + writer.write_all(&ARROW_MAGIC[..]).await?; + writer.write_all(&[0, 0]).await?; + let (meta, data) = write_message(&mut writer, encoded).await?; + + Ok((meta + data + 8, None, vec![], Some(writer))) + } + + async fn write( + mut writer: W, + mut offset: usize, + record: EncodedData, + dictionaries: Vec, + ) -> Result> { + let mut dict_blocks = vec![]; + for dict in dictionaries { + let (meta, data) = write_message(&mut writer, dict).await?; + let block = Block { + offset: offset as i64, + meta_data_length: meta as i32, + body_length: data as i64, + }; + dict_blocks.push(block); + offset += meta + data; + } + let (meta, data) = write_message(&mut writer, record).await?; + let block = Block { + offset: offset as i64, + meta_data_length: meta as i32, + body_length: data as i64, + }; + offset += meta + data; + Ok((offset, Some(block), dict_blocks, Some(writer))) + } + + async fn finish(mut writer: W, footer: Footer) -> Result> { + write_continuation(&mut writer, 0).await?; + let footer = { + let mut builder = Builder::new(); + builder.finish(&footer, None).to_owned() + }; + writer.write_all(&footer[..]).await?; + writer + .write_all(&(footer.len() as i32).to_le_bytes()) + .await?; + writer.write_all(&ARROW_MAGIC).await?; + writer.close().await?; + + Ok((0, None, vec![], None)) + } + + fn poll_write(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + if let Some(task) = &mut self.task { + match futures::ready!(task.poll_unpin(cx)) { + Ok((offset, record, mut dictionaries, writer)) => { + self.task = None; + self.writer = writer; + self.offset = offset; + if let Some(block) = record { + self.record_blocks.push(block); + } + self.dictionary_blocks.append(&mut dictionaries); + Poll::Ready(Ok(())) + } + Err(error) => { + self.task = None; + Poll::Ready(Err(error)) + } + } + } else { + Poll::Ready(Ok(())) + } + } +} + +impl<'a, W> Sink>> for FileSink<'a, W> +where + W: AsyncWrite + Unpin + Send + 'a, +{ + type Error = ArrowError; + + fn poll_ready( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.get_mut().poll_write(cx) + } + + fn start_send(self: std::pin::Pin<&mut Self>, item: Chunk>) -> Result<()> { + let this = self.get_mut(); + + if let Some(writer) = this.writer.take() { + let (dictionaries, record) = encode_chunk( + &item, + &this.fields[..], + &mut this.dictionary_tracker, + &this.options, + )?; + + this.task = Some(Self::write(writer, this.offset, record, dictionaries).boxed()); + Ok(()) + } else { + Err(ArrowError::Io(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "writer is closed", + ))) + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.get_mut().poll_write(cx) + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.get_mut(); + match futures::ready!(this.poll_write(cx)) { + Ok(()) => { + if let Some(writer) = this.writer.take() { + let schema = serialize_schema(&this.schema, &this.fields); + let footer = Footer { + version: MetadataVersion::V5, + schema: Some(Box::new(schema)), + dictionaries: Some(std::mem::take(&mut this.dictionary_blocks)), + record_batches: Some(std::mem::take(&mut this.record_blocks)), + custom_metadata: None, + }; + this.task = Some(Self::finish(writer, footer).boxed()); + this.poll_write(cx) + } else { + Poll::Ready(Ok(())) + } + } + Err(error) => Poll::Ready(Err(error)), + } + } +} + +#[cfg(test)] +mod tests { + use super::FileSink; + use crate::{ + array::{Array, Float32Array, Int32Array}, + chunk::Chunk, + datatypes::{DataType, Field, Schema}, + io::ipc::read::file_async::{read_file_metadata_async, FileStream}, + }; + use futures::{io::Cursor, SinkExt, TryStreamExt}; + use std::sync::Arc; + + // Verify round trip data integrity when using async read + write. + #[test] + fn test_file_async_roundtrip() { + futures::executor::block_on(async move { + let mut data = vec![]; + for i in 0..5 { + let a1 = Int32Array::from(&[Some(i), None, Some(i + 1)]); + let a2 = Float32Array::from(&[None, Some(i as f32), None]); + let chunk = Chunk::new(vec![ + Arc::new(a1) as Arc, + Arc::new(a2) as Arc, + ]); + data.push(chunk); + } + let schema = Schema::from(vec![ + Field::new("a1", DataType::Int32, true), + Field::new("a2", DataType::Float32, true), + ]); + + let mut buffer = Cursor::new(Vec::new()); + let mut sink = FileSink::new(&mut buffer, schema.clone(), None, Default::default()); + for chunk in &data { + sink.feed(chunk.clone()).await.unwrap(); + } + sink.close().await.unwrap(); + drop(sink); + + buffer.set_position(0); + let metadata = read_file_metadata_async(&mut buffer).await.unwrap(); + assert_eq!(schema, metadata.schema); + let stream = FileStream::new(buffer, metadata, None); + let out = stream.try_collect::>().await.unwrap(); + for i in 0..5 { + assert_eq!(data[i], out[i]); + } + }) + } +} diff --git a/src/io/ipc/write/mod.rs b/src/io/ipc/write/mod.rs index 6331e9dc9a0..41abe67ebf2 100644 --- a/src/io/ipc/write/mod.rs +++ b/src/io/ipc/write/mod.rs @@ -19,6 +19,10 @@ mod common_async; #[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_write_async")))] pub mod stream_async; +#[cfg(feature = "io_ipc_write_async")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_write_async")))] +pub mod file_async; + use crate::datatypes::{DataType, Field}; use super::IpcField; diff --git a/src/io/ipc/write/stream_async.rs b/src/io/ipc/write/stream_async.rs index 7da157c1f4b..85feecf8ecc 100644 --- a/src/io/ipc/write/stream_async.rs +++ b/src/io/ipc/write/stream_async.rs @@ -1,14 +1,15 @@ //! `async` writing of arrow streams use std::sync::Arc; -use futures::AsyncWrite; - use super::super::IpcField; pub use super::common::WriteOptions; use super::common::{encode_chunk, DictionaryTracker, EncodedData}; use super::common_async::{write_continuation, write_message}; use super::{default_ipc_fields, schema_to_bytes}; +use futures::{future::BoxFuture, AsyncWrite, FutureExt, Sink}; +use std::{pin::Pin, task::Poll}; + use crate::array::Array; use crate::chunk::Chunk; use crate::datatypes::*; @@ -106,3 +107,206 @@ impl StreamWriter { self.writer } } + +/// A sink that writes array [`chunks`](Chunk) as an IPC stream. +/// +/// The stream header is automatically written before writing the first chunk. +/// +/// The sink uses the same `ipc_fields` projection and `write_options` for each chunk. +/// For more fine-grained control over those parameters, see [`StreamWriter`]. +/// +/// # Examples +/// +/// ``` +/// use std::sync::Arc; +/// use futures::SinkExt; +/// use arrow2::array::{Array, Int32Array}; +/// use arrow2::datatypes::{DataType, Field, Schema}; +/// use arrow2::chunk::Chunk; +/// # use arrow2::io::ipc::write::stream_async::StreamSink; +/// # futures::executor::block_on(async move { +/// let schema = Schema::from(vec![ +/// Field::new("values", DataType::Int32, true), +/// ]); +/// +/// let mut buffer = vec![]; +/// let mut sink = StreamSink::new( +/// &mut buffer, +/// schema, +/// None, +/// Default::default(), +/// ); +/// +/// for i in 0..3 { +/// let values = Int32Array::from(&[Some(i), None]); +/// let chunk = Chunk::new(vec![Arc::new(values) as Arc]); +/// sink.feed(chunk).await?; +/// } +/// sink.close().await?; +/// # arrow2::error::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub struct StreamSink<'a, W: AsyncWrite + Unpin + Send + 'a> { + sink: Option>, + task: Option>>>>, + schema: Arc, + ipc_fields: Arc>>, +} + +impl<'a, W> StreamSink<'a, W> +where + W: AsyncWrite + Unpin + Send + 'a, +{ + /// Create a new [`StreamSink`]. + pub fn new( + writer: W, + schema: Schema, + ipc_fields: Option<&[IpcField]>, + write_options: WriteOptions, + ) -> Self { + let mut sink = StreamWriter::new(writer, write_options); + let schema = Arc::new(schema); + let s = schema.clone(); + let ipc_fields = Arc::new(ipc_fields.map(|f| f.to_vec())); + let f = ipc_fields.clone(); + let task = Some( + async move { + sink.start(&s, f.as_deref()).await?; + Ok(Some(sink)) + } + .boxed(), + ); + Self { + sink: None, + task, + schema, + ipc_fields, + } + } + + fn poll_complete(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + if let Some(task) = &mut self.task { + match futures::ready!(task.poll_unpin(cx)) { + Ok(sink) => { + self.sink = sink; + self.task = None; + Poll::Ready(Ok(())) + } + Err(error) => { + self.task = None; + Poll::Ready(Err(error)) + } + } + } else { + Poll::Ready(Ok(())) + } + } +} + +impl<'a, W> Sink>> for StreamSink<'a, W> +where + W: AsyncWrite + Unpin + Send, +{ + type Error = ArrowError; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + self.get_mut().poll_complete(cx) + } + + fn start_send(self: Pin<&mut Self>, item: Chunk>) -> Result<()> { + let this = self.get_mut(); + if let Some(mut sink) = this.sink.take() { + let schema = this.schema.clone(); + let fields = this.ipc_fields.clone(); + this.task = Some( + async move { + sink.write(&item, &schema, fields.as_deref()).await?; + Ok(Some(sink)) + } + .boxed(), + ); + Ok(()) + } else { + Err(ArrowError::Io(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "writer closed".to_string(), + ))) + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + self.get_mut().poll_complete(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let this = self.get_mut(); + match this.poll_complete(cx) { + Poll::Ready(Ok(())) => { + if let Some(mut sink) = this.sink.take() { + this.task = Some( + async move { + sink.finish().await?; + Ok(None) + } + .boxed(), + ); + this.poll_complete(cx) + } else { + Poll::Ready(Ok(())) + } + } + res => res, + } + } +} + +#[cfg(test)] +mod tests { + use super::StreamSink; + use crate::{ + array::{Array, Float32Array, Int32Array}, + chunk::Chunk, + datatypes::{DataType, Field, Schema}, + io::ipc::read::stream_async::{read_stream_metadata_async, AsyncStreamReader}, + }; + use futures::{SinkExt, TryStreamExt}; + use std::sync::Arc; + + // Verify round trip data integrity when using async read + write. + #[test] + fn test_stream_async_roundtrip() { + futures::executor::block_on(async move { + let mut data = vec![]; + for i in 0..5 { + let a1 = Int32Array::from(&[Some(i), None, Some(i + 1)]); + let a2 = Float32Array::from(&[None, Some(i as f32), None]); + let chunk = Chunk::new(vec![ + Arc::new(a1) as Arc, + Arc::new(a2) as Arc, + ]); + data.push(chunk); + } + let schema = Schema::from(vec![ + Field::new("a1", DataType::Int32, true), + Field::new("a2", DataType::Float32, true), + ]); + + let mut buffer = vec![]; + let mut sink = StreamSink::new(&mut buffer, schema.clone(), None, Default::default()); + for chunk in &data { + sink.feed(chunk.clone()).await.unwrap(); + } + sink.close().await.unwrap(); + drop(sink); + + let mut reader = &buffer[..]; + let metadata = read_stream_metadata_async(&mut reader).await.unwrap(); + assert_eq!(schema, metadata.schema); + let stream = AsyncStreamReader::new(reader, metadata); + let out = stream.try_collect::>().await.unwrap(); + for i in 0..5 { + assert_eq!(data[i], out[i]); + } + }) + } +} diff --git a/src/io/parquet/write/mod.rs b/src/io/parquet/write/mod.rs index 5fec264cfd3..599657e6bd4 100644 --- a/src/io/parquet/write/mod.rs +++ b/src/io/parquet/write/mod.rs @@ -8,6 +8,7 @@ mod levels; mod primitive; mod row_group; mod schema; +mod sink; mod stream; mod utf8; mod utils; @@ -39,6 +40,7 @@ pub use parquet2::{ pub use file::FileWriter; pub use row_group::{row_group_iter, RowGroupIterator}; pub use schema::to_parquet_type; +pub use sink::FileSink; pub use stream::FileStreamer; pub(self) fn decimal_length_from_precision(precision: usize) -> usize { diff --git a/src/io/parquet/write/sink.rs b/src/io/parquet/write/sink.rs new file mode 100644 index 00000000000..c5e251290ba --- /dev/null +++ b/src/io/parquet/write/sink.rs @@ -0,0 +1,283 @@ +use crate::{ + array::Array, + chunk::Chunk, + datatypes::Schema, + error::ArrowError, + io::parquet::write::{Encoding, FileStreamer, SchemaDescriptor, WriteOptions}, +}; +use futures::{future::BoxFuture, AsyncWrite, FutureExt, Sink, TryFutureExt}; +use parquet2::metadata::KeyValue; +use std::{collections::HashMap, pin::Pin, sync::Arc, task::Poll}; + +/// Sink that writes array [`chunks`](Chunk) as a Parquet file. +/// +/// # Examples +/// +/// ``` +/// use std::sync::Arc; +/// use futures::SinkExt; +/// use arrow2::array::{Array, Int32Array}; +/// use arrow2::datatypes::{DataType, Field, Schema}; +/// use arrow2::chunk::Chunk; +/// use arrow2::io::parquet::write::{Encoding, WriteOptions, Compression, Version}; +/// # use arrow2::io::parquet::write::FileSink; +/// # futures::executor::block_on(async move { +/// let schema = Schema::from(vec![ +/// Field::new("values", DataType::Int32, true), +/// ]); +/// let encoding = vec![Encoding::Plain]; +/// let options = WriteOptions { +/// write_statistics: true, +/// compression: Compression::Uncompressed, +/// version: Version::V2, +/// }; +/// +/// let mut buffer = vec![]; +/// let mut sink = FileSink::try_new( +/// &mut buffer, +/// schema, +/// encoding, +/// options, +/// )?; +/// +/// for i in 0..3 { +/// let values = Int32Array::from(&[Some(i), None]); +/// let chunk = Chunk::new(vec![Arc::new(values) as Arc]); +/// sink.feed(chunk).await?; +/// } +/// sink.metadata.insert(String::from("key"), Some(String::from("value"))); +/// sink.close().await?; +/// # arrow2::error::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub struct FileSink<'a, W: AsyncWrite + Send + Unpin> { + writer: Option>, + task: Option>, ArrowError>>>, + options: WriteOptions, + encoding: Vec, + schema: SchemaDescriptor, + /// Key-value metadata that will be written to the file on close. + pub metadata: HashMap>, +} + +impl<'a, W> FileSink<'a, W> +where + W: AsyncWrite + Send + Unpin + 'a, +{ + /// Create a new sink that writes arrays to the provided `writer`. + /// + /// # Error + /// If the Arrow schema can't be converted to a valid Parquet schema. + pub fn try_new( + writer: W, + schema: Schema, + encoding: Vec, + options: WriteOptions, + ) -> Result { + let mut writer = FileStreamer::try_new(writer, schema.clone(), options)?; + let schema = crate::io::parquet::write::to_parquet_schema(&schema)?; + let task = Some( + async move { + writer.start().await?; + Ok(Some(writer)) + } + .boxed(), + ); + Ok(Self { + writer: None, + task, + options, + schema, + encoding, + metadata: HashMap::default(), + }) + } + + fn poll_complete( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + if let Some(task) = &mut self.task { + match futures::ready!(task.poll_unpin(cx)) { + Ok(writer) => { + self.task = None; + self.writer = writer; + Poll::Ready(Ok(())) + } + Err(error) => { + self.task = None; + Poll::Ready(Err(error)) + } + } + } else { + Poll::Ready(Ok(())) + } + } +} + +impl<'a, W> Sink>> for FileSink<'a, W> +where + W: AsyncWrite + Send + Unpin + 'a, +{ + type Error = ArrowError; + + fn start_send(self: Pin<&mut Self>, item: Chunk>) -> Result<(), Self::Error> { + let this = self.get_mut(); + if let Some(mut writer) = this.writer.take() { + let count = item.len(); + let rows = crate::io::parquet::write::row_group_iter( + item, + this.encoding.clone(), + this.schema.columns().to_vec(), + this.options, + ); + this.task = Some(Box::pin(async move { + writer.write(rows, count).await?; + Ok(Some(writer)) + })); + Ok(()) + } else { + Err(ArrowError::Io(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "writer closed".to_string(), + ))) + } + } + + fn poll_ready( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.get_mut().poll_complete(cx) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.get_mut().poll_complete(cx) + } + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.get_mut(); + match futures::ready!(this.poll_complete(cx)) { + Ok(()) => { + let writer = this.writer.take(); + if let Some(writer) = writer { + let meta = std::mem::take(&mut this.metadata); + let metadata = if meta.is_empty() { + None + } else { + Some( + meta.into_iter() + .map(|(k, v)| KeyValue::new(k, v)) + .collect::>(), + ) + }; + + this.task = Some(writer.end(metadata).map_ok(|_| None).boxed()); + this.poll_complete(cx) + } else { + Poll::Ready(Ok(())) + } + } + Err(error) => Poll::Ready(Err(error)), + } + } +} + +#[cfg(test)] +mod tests { + use futures::{future::BoxFuture, io::Cursor, SinkExt}; + use parquet2::{ + compression::Compression, + write::{Version, WriteOptions}, + }; + use std::{collections::HashMap, sync::Arc}; + + use crate::{ + array::{Array, Float32Array, Int32Array}, + chunk::Chunk, + datatypes::{DataType, Field, Schema}, + error::Result, + io::parquet::{ + read::{ + infer_schema, read_columns_many_async, read_metadata_async, RowGroupDeserializer, + }, + write::Encoding, + }, + }; + + use super::FileSink; + + #[test] + fn test_parquet_async_roundtrip() { + futures::executor::block_on(async move { + let mut data = vec![]; + for i in 0..5 { + let a1 = Int32Array::from(&[Some(i), None, Some(i + 1)]); + let a2 = Float32Array::from(&[None, Some(i as f32), None]); + let chunk = Chunk::new(vec![ + Arc::new(a1) as Arc, + Arc::new(a2) as Arc, + ]); + data.push(chunk); + } + let schema = Schema::from(vec![ + Field::new("a1", DataType::Int32, true), + Field::new("a2", DataType::Float32, true), + ]); + let encoding = vec![Encoding::Plain, Encoding::Plain]; + let options = WriteOptions { + write_statistics: true, + compression: Compression::Uncompressed, + version: Version::V2, + }; + + let mut buffer = Cursor::new(Vec::new()); + let mut sink = + FileSink::try_new(&mut buffer, schema.clone(), encoding, options).unwrap(); + sink.metadata + .insert(String::from("key"), Some("value".to_string())); + for chunk in &data { + sink.feed(chunk.clone()).await.unwrap(); + } + sink.close().await.unwrap(); + drop(sink); + + buffer.set_position(0); + let metadata = read_metadata_async(&mut buffer).await.unwrap(); + let kv = HashMap::>::from_iter( + metadata + .key_value_metadata() + .to_owned() + .unwrap() + .into_iter() + .map(|kv| (kv.key, kv.value)), + ); + assert_eq!(kv.get("key").unwrap(), &Some("value".to_string())); + let read_schema = infer_schema(&metadata).unwrap(); + assert_eq!(read_schema, schema); + let factory = || Box::pin(futures::future::ready(Ok(buffer.clone()))) as BoxFuture<_>; + + let mut out = vec![]; + for group in &metadata.row_groups { + let column_chunks = + read_columns_many_async(factory, group, schema.fields.clone(), None) + .await + .unwrap(); + let chunks = + RowGroupDeserializer::new(column_chunks, group.num_rows() as usize, None); + let mut chunks = chunks.collect::>>().unwrap(); + out.append(&mut chunks); + } + + for i in 0..5 { + assert_eq!(data[i], out[i]); + } + }) + } +}