From 6ac08656a8c773b1e10cbb7abbfcfcb9622f3055 Mon Sep 17 00:00:00 2001 From: Dexter Duckworth Date: Thu, 3 Mar 2022 15:56:23 -0500 Subject: [PATCH] Updated IPC Record type to support reference types. --- src/io/ipc/write/common.rs | 49 ++++++++++++++++++++------- src/io/ipc/write/file_async.rs | 19 ++++++----- src/io/ipc/write/stream_async.rs | 26 ++++++++------ tests/it/io/ipc/write_file_async.rs | 3 +- tests/it/io/ipc/write_stream_async.rs | 3 +- 5 files changed, 63 insertions(+), 37 deletions(-) diff --git a/src/io/ipc/write/common.rs b/src/io/ipc/write/common.rs index 19034247eb7..14107ed9ac2 100644 --- a/src/io/ipc/write/common.rs +++ b/src/io/ipc/write/common.rs @@ -1,3 +1,4 @@ +use std::borrow::{Borrow, Cow}; use std::sync::Arc; use arrow_format::ipc::planus::Builder; @@ -382,30 +383,52 @@ pub(crate) fn pad_to_8(len: usize) -> usize { /// An array [`Chunk`] with optional accompanying IPC fields. #[derive(Debug, Clone, PartialEq)] -pub struct Record { - /// Chunk of Arrow columns to be written in IPC format. - pub columns: Chunk>, - /// Optional IPC field list used to map Arrow columns to IPC dictionaries. - pub fields: Option>, +pub struct Record<'a> { + columns: Cow<'a, Chunk>>, + fields: Option>, } -impl From>> for Record { +impl<'a> Record<'a> { + /// Get the IPC fields for this record. + pub fn fields(&self) -> Option<&[IpcField]> { + self.fields.as_deref() + } + + /// Get the Arrow columns in this record. + pub fn columns(&self) -> &Chunk> { + self.columns.borrow() + } +} + +impl From>> for Record<'static> { fn from(columns: Chunk>) -> Self { Self { - columns, + columns: Cow::Owned(columns), fields: None, } } } -impl From<(Chunk>, Option>)> for Record { - fn from((columns, fields): (Chunk>, Option>)) -> Self { - Self { columns, fields } +impl<'a, F> From<(Chunk>, Option)> for Record<'a> +where + F: Into>, +{ + fn from((columns, fields): (Chunk>, Option)) -> Self { + Self { + columns: Cow::Owned(columns), + fields: fields.map(|f| f.into()), + } } } -impl From for (Chunk>, Option>) { - fn from(record: Record) -> Self { - (record.columns, record.fields) +impl<'a, F> From<(&'a Chunk>, Option)> for Record<'a> +where + F: Into>, +{ + fn from((columns, fields): (&'a Chunk>, Option)) -> Self { + Self { + columns: Cow::Borrowed(columns), + fields: fields.map(|f| f.into()), + } } } diff --git a/src/io/ipc/write/file_async.rs b/src/io/ipc/write/file_async.rs index 12b66fb1602..930bc4d5c5b 100644 --- a/src/io/ipc/write/file_async.rs +++ b/src/io/ipc/write/file_async.rs @@ -1,5 +1,10 @@ //! Async writer for IPC files. +use std::task::Poll; + +use arrow_format::ipc::{planus::Builder, Block, Footer, MetadataVersion}; +use futures::{future::BoxFuture, AsyncWrite, AsyncWriteExt, FutureExt, Sink}; + use super::common::{encode_chunk, DictionaryTracker, EncodedData, WriteOptions}; use super::common_async::{write_continuation, write_message}; use super::schema::serialize_schema; @@ -7,9 +12,6 @@ use super::{default_ipc_fields, schema_to_bytes, Record}; use crate::datatypes::*; use crate::error::{ArrowError, Result}; use crate::io::ipc::{IpcField, ARROW_MAGIC}; -use arrow_format::ipc::{planus::Builder, Block, Footer, MetadataVersion}; -use futures::{future::BoxFuture, AsyncWrite, AsyncWriteExt, FutureExt, Sink}; -use std::task::Poll; type WriteOutput = (usize, Option, Vec, Option); @@ -175,7 +177,7 @@ where } } -impl<'a, W> Sink for FileSink<'a, W> +impl<'a, W> Sink> for FileSink<'a, W> where W: AsyncWrite + Unpin + Send + 'a, { @@ -188,16 +190,15 @@ where self.get_mut().poll_write(cx) } - fn start_send(self: std::pin::Pin<&mut Self>, item: Record) -> Result<()> { + fn start_send(self: std::pin::Pin<&mut Self>, item: Record<'_>) -> Result<()> { let this = self.get_mut(); if let Some(writer) = this.writer.take() { - let Record { columns, fields } = item; - let fields = fields.unwrap_or_else(|| this.fields.clone()); + let fields = item.fields().unwrap_or_else(|| &this.fields[..]); let (dictionaries, record) = encode_chunk( - &columns, - &fields[..], + item.columns(), + fields, &mut this.dictionary_tracker, &this.options, )?; diff --git a/src/io/ipc/write/stream_async.rs b/src/io/ipc/write/stream_async.rs index ff48c915096..033fa06038b 100644 --- a/src/io/ipc/write/stream_async.rs +++ b/src/io/ipc/write/stream_async.rs @@ -1,14 +1,15 @@ //! `async` writing of arrow streams +use std::{pin::Pin, task::Poll}; + +use futures::{future::BoxFuture, AsyncWrite, FutureExt, Sink}; + use super::super::IpcField; pub use super::common::WriteOptions; use super::common::{encode_chunk, DictionaryTracker, EncodedData}; use super::common_async::{write_continuation, write_message}; use super::{default_ipc_fields, schema_to_bytes, Record}; -use futures::{future::BoxFuture, AsyncWrite, FutureExt, Sink}; -use std::{pin::Pin, task::Poll}; - use crate::datatypes::*; use crate::error::{ArrowError, Result}; @@ -93,11 +94,14 @@ where .boxed() } - fn write(&mut self, record: &Record) -> Result<()> { - let Record { columns, fields } = record; - let fields = fields.as_ref().unwrap_or(&self.fields); - let (dictionaries, message) = - encode_chunk(columns, fields, &mut self.dictionary_tracker, &self.options)?; + fn write(&mut self, record: Record<'_>) -> Result<()> { + let fields = record.fields().unwrap_or(&self.fields[..]); + let (dictionaries, message) = encode_chunk( + record.columns(), + fields, + &mut self.dictionary_tracker, + &self.options, + )?; if let Some(mut writer) = self.writer.take() { self.task = Some( @@ -138,7 +142,7 @@ where } } -impl<'a, W> Sink for StreamSink<'a, W> +impl<'a, W> Sink> for StreamSink<'a, W> where W: AsyncWrite + Unpin + Send, { @@ -148,8 +152,8 @@ where self.get_mut().poll_complete(cx) } - fn start_send(self: Pin<&mut Self>, item: Record) -> Result<()> { - self.get_mut().write(&item) + fn start_send(self: Pin<&mut Self>, item: Record<'_>) -> Result<()> { + self.get_mut().write(item) } fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { diff --git a/tests/it/io/ipc/write_file_async.rs b/tests/it/io/ipc/write_file_async.rs index d907f116175..ffa6530feed 100644 --- a/tests/it/io/ipc/write_file_async.rs +++ b/tests/it/io/ipc/write_file_async.rs @@ -25,8 +25,7 @@ async fn write_( let options = WriteOptions { compression: None }; let mut sink = FileSink::new(&mut result, schema, Some(ipc_fields.to_vec()), options); for batch in batches { - sink.feed((batch.clone(), Some(ipc_fields.to_vec())).into()) - .await?; + sink.feed((batch, Some(ipc_fields)).into()).await?; } sink.close().await?; drop(sink); diff --git a/tests/it/io/ipc/write_stream_async.rs b/tests/it/io/ipc/write_stream_async.rs index c5ebceb27ae..2840692a75f 100644 --- a/tests/it/io/ipc/write_stream_async.rs +++ b/tests/it/io/ipc/write_stream_async.rs @@ -25,8 +25,7 @@ async fn write_( let options = stream_async::WriteOptions { compression: None }; let mut sink = StreamSink::new(&mut result, schema, Some(ipc_fields.to_vec()), options); for batch in batches { - sink.feed((batch.clone(), Some(ipc_fields.to_vec())).into()) - .await?; + sink.feed((batch, Some(ipc_fields)).into()).await?; } sink.close().await?; drop(sink);