Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Added LimitRows iterator adapter.
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Nov 16, 2021
1 parent 110d889 commit 70db412
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/io/ipc/read/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ mod reader;
mod stream;

pub use common::{read_dictionary, read_record_batch};
pub use reader::{read_file_metadata, FileMetadata, FileReader};
pub use reader::{read_file_metadata, FileMetadata, FileReader, LimitRows};
pub use stream::{read_stream_metadata, StreamMetadata, StreamReader, StreamState};
118 changes: 93 additions & 25 deletions src/io/ipc/read/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,26 @@ pub fn read_file_metadata<R: Read + Seek>(reader: &mut R) -> Result<FileMetadata
})
}

/// Extracts the record-batch header from an IPC `message`.
///
/// # Errors
/// Returns [`ArrowError::Ipc`] when the message header is a schema, when it is
/// any other non-record-batch header, or when the header is tagged as a record
/// batch but cannot be read as one.
fn get_serialized_batch<'a>(
    message: &'a ipc::Message::Message,
) -> Result<ipc::Message::RecordBatch<'a>> {
    let header_type = message.header_type();
    match header_type {
        ipc::Message::MessageHeader::RecordBatch => match message.header_as_record_batch() {
            Some(serialized) => Ok(serialized),
            None => Err(ArrowError::Ipc(
                "Unable to read IPC message as record batch".to_string(),
            )),
        },
        ipc::Message::MessageHeader::Schema => {
            let reason = "Not expecting a schema when messages are read".to_string();
            Err(ArrowError::Ipc(reason))
        }
        other => {
            let reason = format!(
                "Reading types other than record batches not yet supported, unable to read {:?}",
                other
            );
            Err(ArrowError::Ipc(reason))
        }
    }
}

/// Read a batch from the reader.
pub fn read_batch<R: Read + Seek>(
reader: &mut R,
metadata: &FileMetadata,
Expand Down Expand Up @@ -202,30 +221,18 @@ pub fn read_batch<R: Read + Seek>(
));
}

match message.header_type() {
ipc::Message::MessageHeader::Schema => Err(ArrowError::Ipc(
"Not expecting a schema when messages are read".to_string(),
)),
ipc::Message::MessageHeader::RecordBatch => {
let batch = message.header_as_record_batch().ok_or_else(|| {
ArrowError::Ipc("Unable to read IPC message as record batch".to_string())
})?;
read_record_batch(
batch,
metadata.schema.clone(),
projection,
metadata.is_little_endian,
&metadata.dictionaries,
metadata.version,
reader,
block.offset() as u64 + block.metaDataLength() as u64,
)
}
t => Err(ArrowError::Ipc(format!(
"Reading types other than record batches not yet supported, unable to read {:?}",
t
))),
}
let batch = get_serialized_batch(&message)?;

read_record_batch(
batch,
metadata.schema.clone(),
projection,
metadata.is_little_endian,
&metadata.dictionaries,
metadata.version,
reader,
block.offset() as u64 + block.metaDataLength() as u64,
)
}

impl<R: Read + Seek> FileReader<R> {
Expand Down Expand Up @@ -299,3 +306,64 @@ impl<R: Read + Seek> RecordBatchReader for FileReader<R> {
self.schema().as_ref()
}
}

/// Truncates `batch` so that it holds at most `limit` rows.
///
/// A batch with `limit` rows or fewer is returned untouched. A longer batch
/// has every column sliced to its first `limit` rows; the schema is reused.
fn limit_batch(batch: RecordBatch, limit: usize) -> RecordBatch {
    // Only slice when there are rows to drop. Slicing a column to a length
    // greater than its own is not a valid request, so batches that already
    // fit must be passed through unchanged.
    if batch.num_rows() <= limit {
        return batch;
    }

    let RecordBatch { schema, columns } = batch;
    let columns = columns
        .into_iter()
        .map(|column| column.slice(0, limit).into())
        .collect();
    RecordBatch { schema, columns }
}

/// Adapter over a fallible [`Iterator`] of [`RecordBatch`]es that caps the
/// total number of rows it yields.
///
/// # Implementation
/// Keeps a running count of the rows still allowed and slices the last
/// [`RecordBatch`] so the total fits exactly.
pub struct LimitRows<I>
where
    I: Iterator<Item = Result<RecordBatch>>,
{
    /// The wrapped source of batches.
    iterator: I,
    /// Rows that may still be yielded; [`None`] means no limit is applied.
    limit: Option<usize>,
}

impl<I: Iterator<Item = Result<RecordBatch>>> LimitRows<I> {
/// Creates a new [`LimitRows`]. If `limit` is [`None`], it does not limit the iterator.
pub fn new(iterator: I, limit: Option<usize>) -> Self {
Self { iterator, limit }
}
}

impl<I> Iterator for LimitRows<I>
where
    I: Iterator<Item = Result<RecordBatch>>,
{
    type Item = Result<RecordBatch>;

    /// Yields the next batch, sliced if needed so the row limit is not exceeded.
    ///
    /// Errors and the end of the wrapped iterator are forwarded unchanged and
    /// do not consume any of the remaining row budget.
    fn next(&mut self) -> Option<Self::Item> {
        // The row budget is spent: stop without polling the wrapped iterator.
        if matches!(self.limit, Some(0)) {
            return None;
        }

        // Anything that is not a successfully read batch is passed straight through.
        let batch = match self.iterator.next() {
            Some(Ok(batch)) => batch,
            other => return other,
        };

        // With no limit configured there is nothing to track.
        let remaining = match self.limit.as_mut() {
            Some(remaining) => remaining,
            None => return Some(Ok(batch)),
        };

        // Slice the last batch if it is too large for what is left of the budget.
        let batch = if batch.num_rows() > *remaining {
            limit_batch(batch, *remaining)
        } else {
            batch
        };
        *remaining -= batch.num_rows();
        Some(Ok(batch))
    }
}
4 changes: 2 additions & 2 deletions src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use crate::error::{ArrowError, Result};
/// Cloning is `O(C)` where `C` is the number of columns.
#[derive(Clone, Debug, PartialEq)]
pub struct RecordBatch {
schema: Arc<Schema>,
columns: Vec<Arc<dyn Array>>,
pub(crate) schema: Arc<Schema>,
pub(crate) columns: Vec<Arc<dyn Array>>,
}

impl RecordBatch {
Expand Down

0 comments on commit 70db412

Please sign in to comment.