diff --git a/src/io/ipc/write/stream_async.rs b/src/io/ipc/write/stream_async.rs
index 8a2485b4dfa..85feecf8ecc 100644
--- a/src/io/ipc/write/stream_async.rs
+++ b/src/io/ipc/write/stream_async.rs
@@ -292,7 +292,7 @@ mod tests {
         ]);
 
         let mut buffer = vec![];
-        let mut sink = StreamSink::new(&mut buffer, schema, None, Default::default());
+        let mut sink = StreamSink::new(&mut buffer, schema.clone(), None, Default::default());
         for chunk in &data {
             sink.feed(chunk.clone()).await.unwrap();
         }
@@ -301,6 +301,7 @@ mod tests {
 
         let mut reader = &buffer[..];
         let metadata = read_stream_metadata_async(&mut reader).await.unwrap();
+        assert_eq!(schema, metadata.schema);
         let stream = AsyncStreamReader::new(reader, metadata);
         let out = stream.try_collect::<Vec<_>>().await.unwrap();
         for i in 0..5 {
diff --git a/src/io/parquet/write/mod.rs b/src/io/parquet/write/mod.rs
index 863777bee32..58f7931d821 100644
--- a/src/io/parquet/write/mod.rs
+++ b/src/io/parquet/write/mod.rs
@@ -41,7 +41,7 @@ pub use file::FileWriter;
 pub use row_group::{row_group_iter, RowGroupIterator};
 pub use schema::to_parquet_type;
 pub use stream::FileStreamer;
-pub use sink::ParquetSink;
+pub use sink::FileSink;
 
 pub(self) fn decimal_length_from_precision(precision: usize) -> usize {
     // digits = floor(log_10(2^(8*n - 1) - 1))
diff --git a/src/io/parquet/write/sink.rs b/src/io/parquet/write/sink.rs
index 77c96cf5071..c5e251290ba 100644
--- a/src/io/parquet/write/sink.rs
+++ b/src/io/parquet/write/sink.rs
@@ -6,22 +6,66 @@ use crate::{
     io::parquet::write::{Encoding, FileStreamer, SchemaDescriptor, WriteOptions},
 };
 use futures::{future::BoxFuture, AsyncWrite, FutureExt, Sink, TryFutureExt};
-use std::{pin::Pin, sync::Arc, task::Poll};
+use parquet2::metadata::KeyValue;
+use std::{collections::HashMap, pin::Pin, sync::Arc, task::Poll};
 
-/// Sink that writes array [`chunks`](Chunk) to an async writer.
-pub struct ParquetSink<'a, W: AsyncWrite + Send + Unpin> {
+/// Sink that writes array [`chunks`](Chunk) as a Parquet file.
+///
+/// # Examples
+///
+/// ```
+/// use std::sync::Arc;
+/// use futures::SinkExt;
+/// use arrow2::array::{Array, Int32Array};
+/// use arrow2::datatypes::{DataType, Field, Schema};
+/// use arrow2::chunk::Chunk;
+/// use arrow2::io::parquet::write::{Encoding, WriteOptions, Compression, Version};
+/// # use arrow2::io::parquet::write::FileSink;
+/// # futures::executor::block_on(async move {
+/// let schema = Schema::from(vec![
+///     Field::new("values", DataType::Int32, true),
+/// ]);
+/// let encoding = vec![Encoding::Plain];
+/// let options = WriteOptions {
+///     write_statistics: true,
+///     compression: Compression::Uncompressed,
+///     version: Version::V2,
+/// };
+///
+/// let mut buffer = vec![];
+/// let mut sink = FileSink::try_new(
+///     &mut buffer,
+///     schema,
+///     encoding,
+///     options,
+/// )?;
+///
+/// for i in 0..3 {
+///     let values = Int32Array::from(&[Some(i), None]);
+///     let chunk = Chunk::new(vec![Arc::new(values) as Arc<dyn Array>]);
+///     sink.feed(chunk).await?;
+/// }
+/// sink.metadata.insert(String::from("key"), Some(String::from("value")));
+/// sink.close().await?;
+/// # arrow2::error::Result::Ok(())
+/// # }).unwrap();
+/// ```
+pub struct FileSink<'a, W: AsyncWrite + Send + Unpin> {
     writer: Option<FileStreamer<W>>,
     task: Option<BoxFuture<'a, Result<Option<FileStreamer<W>>, ArrowError>>>,
     options: WriteOptions,
     encoding: Vec<Encoding>,
     schema: SchemaDescriptor,
+    /// Key-value metadata that will be written to the file on close.
+    pub metadata: HashMap<String, Option<String>>,
 }
 
-impl<'a, W> ParquetSink<'a, W>
+impl<'a, W> FileSink<'a, W>
 where
     W: AsyncWrite + Send + Unpin + 'a,
 {
     /// Create a new sink that writes arrays to the provided `writer`.
+    ///
     /// # Error
     /// If the Arrow schema can't be converted to a valid Parquet schema.
     pub fn try_new(
@@ -45,6 +89,7 @@ where
             options,
             schema,
             encoding,
+            metadata: HashMap::default(),
         })
     }
 
@@ -70,7 +115,7 @@ where
     }
 }
 
-impl<'a, W> Sink<Chunk<Arc<dyn Array>>> for ParquetSink<'a, W>
+impl<'a, W> Sink<Chunk<Arc<dyn Array>>> for FileSink<'a, W>
 where
     W: AsyncWrite + Send + Unpin + 'a,
 {
@@ -122,7 +167,18 @@ where
             Ok(()) => {
                 let writer = this.writer.take();
                 if let Some(writer) = writer {
-                    this.task = Some(writer.end(None).map_ok(|_| None).boxed());
+                    let meta = std::mem::take(&mut this.metadata);
+                    let metadata = if meta.is_empty() {
+                        None
+                    } else {
+                        Some(
+                            meta.into_iter()
+                                .map(|(k, v)| KeyValue::new(k, v))
+                                .collect::<Vec<KeyValue>>(),
+                        )
+                    };
+
+                    this.task = Some(writer.end(metadata).map_ok(|_| None).boxed());
                     this.poll_complete(cx)
                 } else {
                     Poll::Ready(Ok(()))
@@ -132,3 +188,96 @@ where
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use futures::{future::BoxFuture, io::Cursor, SinkExt};
+    use parquet2::{
+        compression::Compression,
+        write::{Version, WriteOptions},
+    };
+    use std::{collections::HashMap, sync::Arc};
+
+    use crate::{
+        array::{Array, Float32Array, Int32Array},
+        chunk::Chunk,
+        datatypes::{DataType, Field, Schema},
+        error::Result,
+        io::parquet::{
+            read::{
+                infer_schema, read_columns_many_async, read_metadata_async, RowGroupDeserializer,
+            },
+            write::Encoding,
+        },
+    };
+
+    use super::FileSink;
+
+    #[test]
+    fn test_parquet_async_roundtrip() {
+        futures::executor::block_on(async move {
+            let mut data = vec![];
+            for i in 0..5 {
+                let a1 = Int32Array::from(&[Some(i), None, Some(i + 1)]);
+                let a2 = Float32Array::from(&[None, Some(i as f32), None]);
+                let chunk = Chunk::new(vec![
+                    Arc::new(a1) as Arc<dyn Array>,
+                    Arc::new(a2) as Arc<dyn Array>,
+                ]);
+                data.push(chunk);
+            }
+            let schema = Schema::from(vec![
+                Field::new("a1", DataType::Int32, true),
+                Field::new("a2", DataType::Float32, true),
+            ]);
+            let encoding = vec![Encoding::Plain, Encoding::Plain];
+            let options = WriteOptions {
+                write_statistics: true,
+                compression: Compression::Uncompressed,
+                version: Version::V2,
+            };
+
+            let mut buffer = Cursor::new(Vec::new());
+            let mut sink =
+                FileSink::try_new(&mut buffer, schema.clone(), encoding, options).unwrap();
+            sink.metadata
+                .insert(String::from("key"), Some("value".to_string()));
+            for chunk in &data {
+                sink.feed(chunk.clone()).await.unwrap();
+            }
+            sink.close().await.unwrap();
+            drop(sink);
+
+            buffer.set_position(0);
+            let metadata = read_metadata_async(&mut buffer).await.unwrap();
+            let kv = HashMap::<String, Option<String>>::from_iter(
+                metadata
+                    .key_value_metadata()
+                    .to_owned()
+                    .unwrap()
+                    .into_iter()
+                    .map(|kv| (kv.key, kv.value)),
+            );
+            assert_eq!(kv.get("key").unwrap(), &Some("value".to_string()));
+            let read_schema = infer_schema(&metadata).unwrap();
+            assert_eq!(read_schema, schema);
+            let factory = || Box::pin(futures::future::ready(Ok(buffer.clone()))) as BoxFuture<_>;
+
+            let mut out = vec![];
+            for group in &metadata.row_groups {
+                let column_chunks =
+                    read_columns_many_async(factory, group, schema.fields.clone(), None)
+                        .await
+                        .unwrap();
+                let chunks =
+                    RowGroupDeserializer::new(column_chunks, group.num_rows() as usize, None);
+                let mut chunks = chunks.collect::<Result<Vec<_>>>().unwrap();
+                out.append(&mut chunks);
+            }
+
+            for i in 0..5 {
+                assert_eq!(data[i], out[i]);
+            }
+        })
+    }
+}