Skip to content

Commit

Permalink
Merge pull request #4271 from platoneko/simd-select
Browse files Browse the repository at this point in the history
ISSUE-4264: simd selection
  • Loading branch information
sundy-li authored Mar 7, 2022
2 parents ca55802 + bd08a7c commit 3a66496
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 11 deletions.
81 changes: 70 additions & 11 deletions common/datavalues/src/columns/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ use std::sync::Arc;

use common_arrow::arrow::array::Array;
use common_arrow::arrow::array::PrimitiveArray;
use common_arrow::arrow::bitmap::utils::BitChunkIterExact;
use common_arrow::arrow::bitmap::utils::BitChunksExact;
use common_arrow::arrow::bitmap::Bitmap;
use common_arrow::arrow::buffer::Buffer;
use common_arrow::arrow::compute::arity::unary;
Expand Down Expand Up @@ -185,21 +187,78 @@ impl<T: PrimitiveType> Column for PrimitiveColumn<T> {
}

fn filter(&self, filter: &BooleanColumn) -> ColumnRef {
let length = filter.values().len() - filter.values().null_count();
if length == self.len() {
assert_eq!(self.len(), filter.values().len());

let selected = filter.values().len() - filter.values().null_count();
if selected == self.len() {
return Arc::new(self.clone());
}
let iter = self
.values()

let mut new = Vec::<T>::with_capacity(selected);
let mut dst = new.as_mut_ptr();

let (mut slice, offset, mut length) = filter.values().as_slice();
let mut values = self.values();
if offset > 0 {
// Consume the offset
let n = 8 - offset;
values
.iter()
.zip(filter.values().iter())
.take(n)
.for_each(|(value, is_selected)| {
if is_selected {
unsafe {
dst.write(*value);
dst = dst.add(1);
}
}
});
slice = &slice[1..];
length -= n;
values = &values[n..];
}

const CHUNK_SIZE: usize = 64;
let mut chunks = values.chunks_exact(CHUNK_SIZE);
let mut mask_chunks = BitChunksExact::<u64>::new(slice, length);

chunks
.by_ref()
.zip(mask_chunks.by_ref())
.for_each(|(chunk, mut mask)| {
if mask == u64::MAX {
unsafe {
std::ptr::copy(chunk.as_ptr(), dst, CHUNK_SIZE);
dst = dst.add(CHUNK_SIZE);
}
} else {
while mask != 0 {
let n = mask.trailing_zeros() as usize;
unsafe {
dst.write(chunk[n]);
dst = dst.add(1);
}
mask = mask & (mask - 1);
}
}
});

chunks
.remainder()
.iter()
.zip(filter.values().iter())
.filter(|(_, f)| *f)
.map(|(v, _)| *v);
.zip(mask_chunks.remainder_iter())
.for_each(|(value, is_selected)| {
if is_selected {
unsafe {
dst.write(*value);
dst = dst.add(1);
}
}
});

let values: Vec<T> = iter.collect();
let col = PrimitiveColumn {
values: values.into(),
};
unsafe { new.set_len(selected) };
let col = PrimitiveColumn { values: new.into() };

Arc::new(col)
}
Expand Down
62 changes: 62 additions & 0 deletions common/datavalues/tests/it/columns/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,65 @@ fn test_const_column() {
let c = ConstColumn::new(Series::from_data(vec![PI]), 24).arc();
println!("{:?}", c);
}

#[test]
fn test_filter_column() {
const N: usize = 1000;
let it = (0..N).map(|i| i as i32);
let data_column: PrimitiveColumn<i32> = Int32Column::from_iterator(it);

struct Test {
filter: BooleanColumn,
expect: Vec<i32>,
}

let mut tests: Vec<Test> = vec![
Test {
filter: BooleanColumn::from_iterator((0..N).map(|_| true)),
expect: (0..N).map(|i| i as i32).collect(),
},
Test {
filter: BooleanColumn::from_iterator((0..N).map(|_| false)),
expect: vec![],
},
Test {
filter: BooleanColumn::from_iterator((0..N).map(|i| i % 10 == 0)),
expect: (0..N).map(|i| i as i32).filter(|i| i % 10 == 0).collect(),
},
Test {
filter: BooleanColumn::from_iterator((0..N).map(|i| !(100..=800).contains(&i))),
expect: (0..N)
.map(|i| i as i32)
.filter(|&i| !(100..=800).contains(&i))
.collect(),
},
];

let offset = 10;
let filter = BooleanColumn::from_iterator(
(0..N + offset).map(|i| !(100 + offset..=800 + offset).contains(&i)),
)
.slice(offset, N)
.as_any()
.downcast_ref::<BooleanColumn>()
.unwrap()
.clone();
tests.push(Test {
filter,
expect: (0..N)
.map(|i| i as i32)
.filter(|&i| !(100..=800).contains(&i))
.collect(),
});

for test in tests {
let res = data_column.filter(&test.filter);
assert_eq!(
res.as_any()
.downcast_ref::<PrimitiveColumn<i32>>()
.unwrap()
.values(),
test.expect
);
}
}

0 comments on commit 3a66496

Please sign in to comment.