diff --git a/src/io/ipc/read/mod.rs b/src/io/ipc/read/mod.rs index adc2b41e477..a374df53287 100644 --- a/src/io/ipc/read/mod.rs +++ b/src/io/ipc/read/mod.rs @@ -13,5 +13,5 @@ mod reader; mod stream; pub use common::{read_dictionary, read_record_batch}; -pub use reader::{read_file_metadata, FileMetadata, FileReader}; +pub use reader::{read_file_metadata, FileMetadata, FileReader, LimitRows}; pub use stream::{read_stream_metadata, StreamMetadata, StreamReader, StreamState}; diff --git a/src/io/ipc/read/reader.rs b/src/io/ipc/read/reader.rs index 8c864957530..9066542743e 100644 --- a/src/io/ipc/read/reader.rs +++ b/src/io/ipc/read/reader.rs @@ -169,7 +169,26 @@ pub fn read_file_metadata(reader: &mut R) -> Result( + message: &'a ipc::Message::Message, +) -> Result> { + match message.header_type() { + ipc::Message::MessageHeader::Schema => Err(ArrowError::Ipc( + "Not expecting a schema when messages are read".to_string(), + )), + ipc::Message::MessageHeader::RecordBatch => { + message.header_as_record_batch().ok_or_else(|| { + ArrowError::Ipc("Unable to read IPC message as record batch".to_string()) + }) + } + t => Err(ArrowError::Ipc(format!( + "Reading types other than record batches not yet supported, unable to read {:?}", + t + ))), + } +} + +/// Read a batch from the reader. 
pub fn read_batch( reader: &mut R, metadata: &FileMetadata, @@ -202,30 +221,18 @@ pub fn read_batch( )); } - match message.header_type() { - ipc::Message::MessageHeader::Schema => Err(ArrowError::Ipc( - "Not expecting a schema when messages are read".to_string(), - )), - ipc::Message::MessageHeader::RecordBatch => { - let batch = message.header_as_record_batch().ok_or_else(|| { - ArrowError::Ipc("Unable to read IPC message as record batch".to_string()) - })?; - read_record_batch( - batch, - metadata.schema.clone(), - projection, - metadata.is_little_endian, - &metadata.dictionaries, - metadata.version, - reader, - block.offset() as u64 + block.metaDataLength() as u64, - ) - } - t => Err(ArrowError::Ipc(format!( - "Reading types other than record batches not yet supported, unable to read {:?}", - t - ))), - } + let batch = get_serialized_batch(&message)?; + + read_record_batch( + batch, + metadata.schema.clone(), + projection, + metadata.is_little_endian, + &metadata.dictionaries, + metadata.version, + reader, + block.offset() as u64 + block.metaDataLength() as u64, + ) } impl FileReader { @@ -299,3 +306,64 @@ impl RecordBatchReader for FileReader { self.schema().as_ref() } } + +fn limit_batch(batch: RecordBatch, limit: usize) -> RecordBatch { + if batch.num_rows() > limit { + let RecordBatch { schema, columns } = batch; + let columns = columns + .into_iter() + .map(|x| x.slice(0, limit).into()) + .collect(); + RecordBatch { schema, columns } + } else { + batch + } +} + +/// Iterator adapter that limits the number of rows read by a +/// fallible [`Iterator`] of [`RecordBatch`]es. +/// # Implementation +/// Tracks the number of remaining rows and slices the last [`RecordBatch`] to fit exactly. +pub struct LimitRows>> { + iterator: I, + limit: Option, +} + +impl>> LimitRows { + /// Creates a new [`LimitRows`]. If `limit` is [`None`], it does not limit the iterator. 
+ pub fn new(iterator: I, limit: Option) -> Self { + Self { iterator, limit } + } +} + +impl>> Iterator for LimitRows { + type Item = Result; + + fn next(&mut self) -> Option { + if let Some(limit) = self.limit { + // no more rows required => finish + if limit == 0 { + return None; + } + }; + + let batch = self.iterator.next(); + + if let Some(Ok(batch)) = batch { + Some(Ok(if let Some(ref mut limit) = self.limit { + // slice the last batch if it is too large + let batch = if batch.num_rows() > *limit { + limit_batch(batch, *limit) + } else { + batch + }; + *limit -= batch.num_rows(); + batch + } else { + batch + })) + } else { + batch + } + } +} diff --git a/src/record_batch.rs b/src/record_batch.rs index 02182af41df..6ead0ac213c 100644 --- a/src/record_batch.rs +++ b/src/record_batch.rs @@ -11,8 +11,8 @@ use crate::error::{ArrowError, Result}; /// Cloning is `O(C)` where `C` is the number of columns. #[derive(Clone, Debug, PartialEq)] pub struct RecordBatch { - schema: Arc, - columns: Vec>, + pub(crate) schema: Arc, + pub(crate) columns: Vec>, } impl RecordBatch {