Skip to content

Commit

Permalink
perf: improve performance of SortPreservingMergeExec operator (#722)
Browse files Browse the repository at this point in the history
* perf: re-use Array comparators

This commit stores built Arrow comparators for two arrays on each of the sort key cursors, resulting in a significant reduction in the cost associated with merging record batches using the `SortPreservingMerge` operator.

Benchmarks improved as follows:

```
⇒  critcmp master pr
group                               master                                 pr
-----                               ------                                 --
interleave_batches                  1.83   623.8±12.41µs        ? ?/sec    1.00    341.2±6.98µs        ? ?/sec
merge_batches_no_overlap_large      1.56    400.6±4.94µs        ? ?/sec    1.00    256.3±6.57µs        ? ?/sec
merge_batches_no_overlap_small      1.63   425.1±24.88µs        ? ?/sec    1.00    261.1±7.46µs        ? ?/sec
merge_batches_small_into_large      1.18    228.0±3.95µs        ? ?/sec    1.00    193.6±2.86µs        ? ?/sec
merge_batches_some_overlap_large    1.68   505.4±10.27µs        ? ?/sec    1.00    301.3±6.63µs        ? ?/sec
merge_batches_some_overlap_small    1.64    515.7±5.21µs        ? ?/sec    1.00   314.6±12.66µs        ? ?/sec
```

* test: test more than two partitions
  • Loading branch information
e-dard authored Jul 19, 2021
1 parent afe29bd commit bd3ee23
Showing 1 changed file with 145 additions and 37 deletions.
182 changes: 145 additions & 37 deletions datafusion/src/physical_plan/sort_preserving_merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use arrow::array::DynComparator;
use arrow::{
array::{make_array as make_arrow_array, ArrayRef, MutableArrayData},
compute::SortOptions,
Expand All @@ -35,6 +36,7 @@ use async_trait::async_trait;
use futures::channel::mpsc;
use futures::stream::FusedStream;
use futures::{Stream, StreamExt};
use hashbrown::HashMap;

use crate::error::{DataFusionError, Result};
use crate::physical_plan::{
Expand Down Expand Up @@ -176,34 +178,60 @@ impl ExecutionPlan for SortPreservingMergeExec {
}
}

/// A `SortKeyCursor` is created from a `RecordBatch`, and a set of `PhysicalExpr` that when
/// evaluated on the `RecordBatch` yield the sort keys.
/// A `SortKeyCursor` is created from a `RecordBatch`, and a set of
/// `PhysicalExpr` that when evaluated on the `RecordBatch` yield the sort keys.
///
/// Additionally it maintains a row cursor that can be advanced through the rows
/// of the provided `RecordBatch`
///
/// `SortKeyCursor::compare` can then be used to compare the sort key pointed to by this
/// row cursor, with that of another `SortKeyCursor`
#[derive(Debug, Clone)]
/// `SortKeyCursor::compare` can then be used to compare the sort key pointed to
/// by this row cursor, with that of another `SortKeyCursor`. A cursor stores
/// a row comparator for each other cursor that it is compared to.
struct SortKeyCursor {
columns: Vec<ArrayRef>,
batch: RecordBatch,
cur_row: usize,
num_rows: usize,

// An index uniquely identifying the record batch scanned by this cursor.
batch_idx: usize,
batch: RecordBatch,

// A collection of comparators that compare rows in this cursor's batch to
// the cursors in other batches. Other batches are uniquely identified by
// their batch_idx.
batch_comparators: HashMap<usize, Vec<DynComparator>>,
}

/// Manual `Debug` implementation: `batch_comparators` stores boxed
/// comparator closures (`DynComparator`), which do not implement `Debug`,
/// so the impl cannot be derived.
///
/// Note: the original declared an unused lifetime parameter (`impl<'a>`);
/// it constrained nothing and is removed here.
impl std::fmt::Debug for SortKeyCursor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SortKeyCursor")
            .field("columns", &self.columns)
            .field("cur_row", &self.cur_row)
            .field("num_rows", &self.num_rows)
            .field("batch_idx", &self.batch_idx)
            .field("batch", &self.batch)
            // comparators are opaque function objects; print a placeholder
            .field("batch_comparators", &"<FUNC>")
            .finish()
    }
}

impl SortKeyCursor {
fn new(batch: RecordBatch, sort_key: &[Arc<dyn PhysicalExpr>]) -> Result<Self> {
fn new(
batch_idx: usize,
batch: RecordBatch,
sort_key: &[Arc<dyn PhysicalExpr>],
) -> Result<Self> {
let columns = sort_key
.iter()
.map(|expr| Ok(expr.evaluate(&batch)?.into_array(batch.num_rows())))
.collect::<Result<_>>()?;

Ok(Self {
cur_row: 0,
num_rows: batch.num_rows(),
columns,
batch,
batch_idx,
batch_comparators: HashMap::new(),
})
}

Expand All @@ -220,7 +248,7 @@ impl SortKeyCursor {

/// Compares the sort key pointed to by this instance's row cursor with that of another
fn compare(
&self,
&mut self,
other: &SortKeyCursor,
options: &[SortOptions],
) -> Result<Ordering> {
Expand All @@ -246,7 +274,19 @@ impl SortKeyCursor {
.zip(other.columns.iter())
.zip(options.iter());

for ((l, r), sort_options) in zipped {
// Recall or initialise a collection of comparators for comparing
// columnar arrays of this cursor and "other".
let cmp = self
.batch_comparators
.entry(other.batch_idx)
.or_insert_with(|| Vec::with_capacity(other.columns.len()));

for (i, ((l, r), sort_options)) in zipped.enumerate() {
if i >= cmp.len() {
// initialise comparators as potentially needed
cmp.push(arrow::array::build_compare(l.as_ref(), r.as_ref())?);
}

match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) {
(false, true) if sort_options.nulls_first => return Ok(Ordering::Less),
(false, true) => return Ok(Ordering::Greater),
Expand All @@ -255,15 +295,11 @@ impl SortKeyCursor {
}
(true, false) => return Ok(Ordering::Less),
(false, false) => {}
(true, true) => {
// TODO: Building the predicate each time is sub-optimal
let c = arrow::array::build_compare(l.as_ref(), r.as_ref())?;
match c(self.cur_row, other.cur_row) {
Ordering::Equal => {}
o if sort_options.descending => return Ok(o.reverse()),
o => return Ok(o),
}
}
(true, true) => match cmp[i](self.cur_row, other.cur_row) {
Ordering::Equal => {}
o if sort_options.descending => return Ok(o.reverse()),
o => return Ok(o),
},
}
}

Expand Down Expand Up @@ -304,6 +340,9 @@ struct SortPreservingMergeStream {
target_batch_size: usize,
/// If the stream has encountered an error
aborted: bool,

/// An index to uniquely identify the input stream batch
next_batch_index: usize,
}

impl SortPreservingMergeStream {
Expand All @@ -313,15 +352,21 @@ impl SortPreservingMergeStream {
expressions: &[PhysicalSortExpr],
target_batch_size: usize,
) -> Self {
let cursors = (0..streams.len())
.into_iter()
.map(|_| VecDeque::new())
.collect();

Self {
schema,
cursors: vec![Default::default(); streams.len()],
cursors,
streams,
column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(),
sort_options: expressions.iter().map(|x| x.options).collect(),
target_batch_size,
aborted: false,
in_progress: vec![],
next_batch_index: 0,
}
}

Expand Down Expand Up @@ -352,12 +397,17 @@ impl SortPreservingMergeStream {
return Poll::Ready(Err(e));
}
Some(Ok(batch)) => {
let cursor = match SortKeyCursor::new(batch, &self.column_expressions) {
let cursor = match SortKeyCursor::new(
self.next_batch_index, // assign this batch an ID
batch,
&self.column_expressions,
) {
Ok(cursor) => cursor,
Err(e) => {
return Poll::Ready(Err(ArrowError::ExternalError(Box::new(e))));
}
};
self.next_batch_index += 1;
self.cursors[idx].push_back(cursor)
}
}
Expand All @@ -367,17 +417,17 @@ impl SortPreservingMergeStream {

/// Returns the index of the next stream to pull a row from, or None
/// if all cursors for all streams are exhausted
fn next_stream_idx(&self) -> Result<Option<usize>> {
let mut min_cursor: Option<(usize, &SortKeyCursor)> = None;
for (idx, candidate) in self.cursors.iter().enumerate() {
if let Some(candidate) = candidate.back() {
fn next_stream_idx(&mut self) -> Result<Option<usize>> {
let mut min_cursor: Option<(usize, &mut SortKeyCursor)> = None;
for (idx, candidate) in self.cursors.iter_mut().enumerate() {
if let Some(candidate) = candidate.back_mut() {
if candidate.is_finished() {
continue;
}

match min_cursor {
None => min_cursor = Some((idx, candidate)),
Some((_, min)) => {
Some((_, ref mut min)) => {
if min.compare(candidate, &self.sort_options)?
== Ordering::Greater
{
Expand Down Expand Up @@ -599,8 +649,7 @@ mod tests {
let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

_test_merge(
b1,
b2,
&[vec![b1], vec![b2]],
&[
"+----+---+-------------------------------+",
"| a | b | c |",
Expand Down Expand Up @@ -646,8 +695,7 @@ mod tests {
let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

_test_merge(
b1,
b2,
&[vec![b1], vec![b2]],
&[
"+-----+---+-------------------------------+",
"| a | b | c |",
Expand Down Expand Up @@ -693,8 +741,7 @@ mod tests {
let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

_test_merge(
b1,
b2,
&[vec![b1], vec![b2]],
&[
"+----+---+-------------------------------+",
"| a | b | c |",
Expand All @@ -715,8 +762,71 @@ mod tests {
.await;
}

async fn _test_merge(b1: RecordBatch, b2: RecordBatch, exp: &[&str]) {
let schema = b1.schema();
#[tokio::test]
async fn test_merge_three_partitions() {
    // Each partition is built in its own scope so the column bindings
    // (`a`, `b`, `c`) never leak or shadow across partitions.

    // Partition 1: sort keys (b, c) are "a".."f" with small timestamps.
    let b1 = {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("f"),
        ]));
        let c: ArrayRef =
            Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap()
    };

    // Partition 2: keys "e".."j" overlap partitions 1 and 3, forcing the
    // merge to interleave rows from all three inputs.
    let b2 = {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("e"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef =
            Arc::new(TimestampNanosecondArray::from(vec![40, 60, 20, 20, 60]));
        RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap()
    };

    // Partition 3: same string keys as partition 2 but smaller timestamps,
    // so ties on `b` are broken by the secondary key `c`.
    let b3 = {
        let a: ArrayRef =
            Arc::new(Int32Array::from(vec![100, 200, 700, 900, 300]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("f"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef =
            Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap()
    };

    _test_merge(
        &[vec![b1], vec![b2], vec![b3]],
        &[
            "+-----+---+-------------------------------+",
            "| a   | b | c                             |",
            "+-----+---+-------------------------------+",
            "| 1 | a | 1970-01-01 00:00:00.000000008 |",
            "| 2 | b | 1970-01-01 00:00:00.000000007 |",
            "| 7 | c | 1970-01-01 00:00:00.000000006 |",
            "| 9 | d | 1970-01-01 00:00:00.000000005 |",
            "| 10 | e | 1970-01-01 00:00:00.000000040 |",
            "| 100 | f | 1970-01-01 00:00:00.000000004 |",
            "| 3 | f | 1970-01-01 00:00:00.000000008 |",
            "| 200 | g | 1970-01-01 00:00:00.000000006 |",
            "| 20 | g | 1970-01-01 00:00:00.000000060 |",
            "| 700 | h | 1970-01-01 00:00:00.000000002 |",
            "| 70 | h | 1970-01-01 00:00:00.000000020 |",
            "| 900 | i | 1970-01-01 00:00:00.000000002 |",
            "| 90 | i | 1970-01-01 00:00:00.000000020 |",
            "| 300 | j | 1970-01-01 00:00:00.000000006 |",
            "| 30 | j | 1970-01-01 00:00:00.000000060 |",
            "+-----+---+-------------------------------+",
        ],
    )
    .await;
}

async fn _test_merge(partitions: &[Vec<RecordBatch>], exp: &[&str]) {
let schema = partitions[0][0].schema();
let sort = vec![
PhysicalSortExpr {
expr: col("b", &schema).unwrap(),
Expand All @@ -727,12 +837,10 @@ mod tests {
options: Default::default(),
},
];
let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap();
let exec = MemoryExec::try_new(partitions, schema, None).unwrap();
let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec), 1024));

let collected = collect(merge).await.unwrap();
assert_eq!(collected.len(), 1);

assert_batches_eq!(exp, collected.as_slice());
}

Expand Down

0 comments on commit bd3ee23

Please sign in to comment.