Skip to content

Commit

Permalink
Implement interleave_record_batch (#6731)
Browse files Browse the repository at this point in the history
Signed-off-by: Ruihang Xia <[email protected]>
  • Loading branch information
waynexia authored Nov 16, 2024
1 parent eab3326 commit f955193
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions arrow-select/src/interleave.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,67 @@ fn interleave_fallback(
Ok(make_array(array_data.freeze()))
}

/// Interleave rows by index from multiple [`RecordBatch`] instances and return a new [`RecordBatch`].
///
/// This function will call [`interleave`] on each array of the [`RecordBatch`] instances and assemble a new [`RecordBatch`].
///
/// # Example
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::{StringArray, Int32Array, RecordBatch, UInt32Array};
/// # use arrow_schema::{DataType, Field, Schema};
/// # use arrow_select::interleave::interleave_record_batch;
///
/// let schema = Arc::new(Schema::new(vec![
/// Field::new("a", DataType::Int32, true),
/// Field::new("b", DataType::Utf8, true),
/// ]));
///
/// let batch1 = RecordBatch::try_new(
/// schema.clone(),
/// vec![
/// Arc::new(Int32Array::from(vec![0, 1, 2])),
/// Arc::new(StringArray::from(vec!["a", "b", "c"])),
/// ],
/// ).unwrap();
///
/// let batch2 = RecordBatch::try_new(
/// schema.clone(),
/// vec![
/// Arc::new(Int32Array::from(vec![3, 4, 5])),
/// Arc::new(StringArray::from(vec!["d", "e", "f"])),
/// ],
/// ).unwrap();
///
/// let indices = vec![(0, 1), (1, 2), (0, 0), (1, 1)];
/// let interleaved = interleave_record_batch(&[&batch1, &batch2], &indices).unwrap();
///
/// let expected = RecordBatch::try_new(
/// schema,
/// vec![
/// Arc::new(Int32Array::from(vec![1, 5, 0, 4])),
/// Arc::new(StringArray::from(vec!["b", "f", "a", "e"])),
/// ],
/// ).unwrap();
/// assert_eq!(interleaved, expected);
/// ```
pub fn interleave_record_batch(
record_batches: &[&RecordBatch],
indices: &[(usize, usize)],
) -> Result<RecordBatch, ArrowError> {
let schema = record_batches[0].schema();
let columns = (0..schema.fields().len())
.map(|i| {
let column_values: Vec<&dyn Array> = record_batches
.iter()
.map(|batch| batch.column(i).as_ref())
.collect();
interleave(&column_values, indices)
})
.collect::<Result<Vec<_>, _>>()?;
RecordBatch::try_new(schema, columns)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit f955193

Please sign in to comment.