From 8a0c76eb12c381cf3a037156c40861fe45e765b8 Mon Sep 17 00:00:00 2001 From: sundyli <543950155@qq.com> Date: Fri, 30 Jul 2021 19:26:54 +0800 Subject: [PATCH] Add support to merge sort with a limit (#222) --- src/compute/merge_sort/mod.rs | 61 ++++++++++++++++++++++++++++------- 1 file changed, 50 insertions(+), 11 deletions(-) diff --git a/src/compute/merge_sort/mod.rs b/src/compute/merge_sort/mod.rs index 5c4fddfaaca..3dad0afdcb2 100644 --- a/src/compute/merge_sort/mod.rs +++ b/src/compute/merge_sort/mod.rs @@ -38,7 +38,7 @@ //! let slices2 = merge_sort_slices(a2, a3); //! let slices = merge_sort_slices(slices1, slices2); //! -//! let array = take_arrays(&[a0, a1, a2, a3], slices); +//! let array = take_arrays(&[a0, a1, a2, a3], slices, None); //! ``` //! //! A common operation in query engines is to merge multiple fields based on the @@ -49,8 +49,8 @@ //! ```rust,ignore //! // `slices` computed before-hand //! // in parallel -//! let array1 = take_arrays(&[a0, a1, a2, a3], slices); -//! let array2 = take_arrays(&[b0, b1, b2, b3], slices); +//! let array1 = take_arrays(&[a0, a1, a2, a3], slices, None); +//! let array2 = take_arrays(&[b0, b1, b2, b3], slices, None); //! ``` //! //! To serialize slices, e.g. for checkpointing or transfer via Arrow's IPC, you can store @@ -89,14 +89,32 @@ type MergeSlice = (usize, usize, usize); pub fn take_arrays>( arrays: &[&dyn Array], slices: I, + limit: Option, ) -> Box { let slices = slices.into_iter(); let len = arrays.iter().map(|array| array.len()).sum(); - let mut growable = make_growable(arrays, false, len); - for (index, start, len) in slices { - growable.extend(index, start, len) + let limit = limit.unwrap_or(len); + let limit = limit.min(len); + let mut growable = make_growable(arrays, false, limit); + + if limit != len { + let mut current_len = 0; + for (index, start, len) in slices { + if len + current_len >= limit { + growable.extend(index, start, limit - current_len); + break; + } else { + growable.extend(index, start, len); + current_len += len; + } + } + } else { + for (index, start, len) in slices { + growable.extend(index, start, len); + } } + growable.as_box() } @@ -114,7 +132,7 @@ pub fn take_arrays>( /// # fn main() -> Result<()> { /// let a = Int32Array::from_slice(&[2, 4, 6]); /// let b = Int32Array::from_slice(&[0, 1, 3]); -/// let sorted = merge_sort(&a, &b, &SortOptions::default())?; +/// let sorted = merge_sort(&a, &b, &SortOptions::default(), None)?; /// let expected = Int32Array::from_slice(&[0, 1, 2, 3, 4, 6]); /// assert_eq!(expected, sorted.as_ref()); /// # Ok(()) @@ -124,6 +142,7 @@ pub fn merge_sort( lhs: &dyn Array, rhs: &dyn Array, options: &SortOptions, + limit: Option, ) -> Result> { let arrays = &[lhs, rhs]; @@ -133,7 +152,7 @@ pub fn merge_sort( let lhs = (0, 0, lhs.len()); let rhs = (1, 0, rhs.len()); let slices = merge_sort_slices(once(&lhs), once(&rhs), &comparator); - Ok(take_arrays(arrays, slices)) + Ok(take_arrays(arrays, slices, limit)) } /// Returns a vector of slices from different sorted arrays that can be used to create sorted arrays. @@ -519,6 +538,26 @@ mod tests { Ok(()) } + #[test] + fn test_merge_with_limit() -> Result<()> { + let a0: &dyn Array = &Int32Array::from_slice(&[0, 2, 4, 6, 8]); + let a1: &dyn Array = &Int32Array::from_slice(&[1, 3, 5, 7, 9]); + + let options = SortOptions::default(); + let arrays = vec![a0, a1]; + let pairs = vec![(arrays.as_ref(), &options)]; + let comparator = build_comparator(&pairs)?; + + let slices = merge_sort_slices(once(&(0, 0, 5)), once(&(1, 0, 5)), &comparator); + // thus, they can be used to take from the arrays + let array = take_arrays(&arrays, slices, Some(5)); + + let expected = Int32Array::from_slice(&[0, 1, 2, 3, 4]); + // values are right + assert_eq!(expected, array.as_ref()); + Ok(()) + } + #[test] fn test_merge_4_i32() -> Result<()> { let a0: &dyn Array = &Int32Array::from_slice(&[0, 1]); @@ -546,7 +585,7 @@ mod tests { ); // thus, they can be used to take from the arrays - let array = take_arrays(&arrays, slices); + let array = take_arrays(&arrays, slices, None); let expected = Int32Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7]); @@ -616,7 +655,7 @@ mod tests { let pairs = vec![(arrays0.as_ref(), &options), (arrays1.as_ref(), &options)]; let slices = slices(&pairs)?; - let array = take_arrays(&[array0, array1], slices); + let array = take_arrays(&[array0, array1], slices, None); assert_eq!(expected, array.as_ref()); Ok(()) @@ -641,7 +680,7 @@ mod tests { let a1 = sort(a1, &options, None)?; // merge then. If multiple arrays, this can be applied in parallel. - let result = merge_sort(a0.as_ref(), a1.as_ref(), &options)?; + let result = merge_sort(a0.as_ref(), a1.as_ref(), &options, None)?; assert_eq!(expected, result.as_ref()); Ok(())