diff --git a/common/datavalues/src/columns/primitive/mod.rs b/common/datavalues/src/columns/primitive/mod.rs
index c987aa874b7ae..136398c3c6544 100644
--- a/common/datavalues/src/columns/primitive/mod.rs
+++ b/common/datavalues/src/columns/primitive/mod.rs
@@ -20,6 +20,8 @@ use std::sync::Arc;
 
 use common_arrow::arrow::array::Array;
 use common_arrow::arrow::array::PrimitiveArray;
+use common_arrow::arrow::bitmap::utils::BitChunkIterExact;
+use common_arrow::arrow::bitmap::utils::BitChunksExact;
 use common_arrow::arrow::bitmap::Bitmap;
 use common_arrow::arrow::buffer::Buffer;
 use common_arrow::arrow::compute::arity::unary;
@@ -185,21 +187,100 @@ impl<T: PrimitiveType> Column for PrimitiveColumn<T> {
     }
 
+    /// Returns a new column holding only the values whose corresponding bit in
+    /// `filter` is set.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `filter` does not have the same length as `self`.
     fn filter(&self, filter: &BooleanColumn) -> ColumnRef {
-        let length = filter.values().len() - filter.values().null_count();
-        if length == self.len() {
+        assert_eq!(self.len(), filter.values().len());
+
+        // Number of set bits in the filter, i.e. how many values are kept.
+        let selected = filter.values().len() - filter.values().null_count();
+        if selected == self.len() {
             return Arc::new(self.clone());
         }
-        let iter = self
-            .values()
+
+        // Every selected value is written through `dst`; `new` only gets its
+        // length set once all `selected` slots have been initialised.
+        let mut new = Vec::<T>::with_capacity(selected);
+        let mut dst = new.as_mut_ptr();
+
+        let (mut slice, offset, mut length) = filter.values().as_slice();
+        let mut values = self.values();
+        if offset > 0 {
+            // The bitmap does not start on a byte boundary: handle the bits left
+            // in the first byte one by one. A short filter may end inside that
+            // byte, so never consume more than `length` bits; otherwise
+            // `length - n` underflows and `values[n..]` is out of bounds.
+            let n = std::cmp::min(8 - offset, length);
+            values
+                .iter()
+                .zip(filter.values().iter())
+                .take(n)
+                .for_each(|(value, is_selected)| {
+                    if is_selected {
+                        // SAFETY: at most `selected` values are written and
+                        // `new` was allocated with that capacity.
+                        unsafe {
+                            dst.write(*value);
+                            dst = dst.add(1);
+                        }
+                    }
+                });
+            slice = &slice[1..];
+            length -= n;
+            values = &values[n..];
+        }
+
+        const CHUNK_SIZE: usize = 64;
+        let mut chunks = values.chunks_exact(CHUNK_SIZE);
+        let mut mask_chunks = BitChunksExact::<u64>::new(slice, length);
+
+        chunks
+            .by_ref()
+            .zip(mask_chunks.by_ref())
+            .for_each(|(chunk, mut mask)| {
+                if mask == u64::MAX {
+                    // All 64 values are selected: copy the chunk in one go.
+                    // SAFETY: `dst` points into the fresh allocation of `new`,
+                    // which cannot overlap `chunk`, and has room for the chunk.
+                    unsafe {
+                        std::ptr::copy_nonoverlapping(chunk.as_ptr(), dst, CHUNK_SIZE);
+                        dst = dst.add(CHUNK_SIZE);
+                    }
+                } else {
+                    while mask != 0 {
+                        // Index of the lowest set bit, i.e. the next selected value.
+                        let n = mask.trailing_zeros() as usize;
+                        // SAFETY: see above; one slot is used per set bit.
+                        unsafe {
+                            dst.write(chunk[n]);
+                            dst = dst.add(1);
+                        }
+                        // Clear the lowest set bit.
+                        mask = mask & (mask - 1);
+                    }
+                }
+            });
+
+        chunks
+            .remainder()
             .iter()
-            .zip(filter.values().iter())
-            .filter(|(_, f)| *f)
-            .map(|(v, _)| *v);
+            .zip(mask_chunks.remainder_iter())
+            .for_each(|(value, is_selected)| {
+                if is_selected {
+                    // SAFETY: see above; one slot is used per selected value.
+                    unsafe {
+                        dst.write(*value);
+                        dst = dst.add(1);
+                    }
+                }
+            });
 
-        let values: Vec<T> = iter.collect();
-        let col = PrimitiveColumn {
-            values: values.into(),
-        };
+        // SAFETY: exactly `selected` (the number of set bits) values were written.
+        unsafe { new.set_len(selected) };
+        let col = PrimitiveColumn { values: new.into() };
         Arc::new(col)
     }
 
diff --git a/common/datavalues/tests/it/columns/primitive.rs b/common/datavalues/tests/it/columns/primitive.rs
index bd076e208965f..98d4399131fb3 100644
--- a/common/datavalues/tests/it/columns/primitive.rs
+++ b/common/datavalues/tests/it/columns/primitive.rs
@@ -54,3 +54,87 @@ fn test_const_column() {
     let c = ConstColumn::new(Series::from_data(vec![PI]), 24).arc();
     println!("{:?}", c);
 }
+
+#[test]
+fn test_filter_column() {
+    const N: usize = 1000;
+    let it = (0..N).map(|i| i as i32);
+    let data_column: PrimitiveColumn<i32> = Int32Column::from_iterator(it);
+
+    struct Test {
+        filter: BooleanColumn,
+        expect: Vec<i32>,
+    }
+
+    let mut tests: Vec<Test> = vec![
+        Test {
+            filter: BooleanColumn::from_iterator((0..N).map(|_| true)),
+            expect: (0..N).map(|i| i as i32).collect(),
+        },
+        Test {
+            filter: BooleanColumn::from_iterator((0..N).map(|_| false)),
+            expect: vec![],
+        },
+        Test {
+            filter: BooleanColumn::from_iterator((0..N).map(|i| i % 10 == 0)),
+            expect: (0..N).map(|i| i as i32).filter(|i| i % 10 == 0).collect(),
+        },
+        Test {
+            filter: BooleanColumn::from_iterator((0..N).map(|i| !(100..=800).contains(&i))),
+            expect: (0..N)
+                .map(|i| i as i32)
+                .filter(|&i| !(100..=800).contains(&i))
+                .collect(),
+        },
+    ];
+
+    let offset = 10;
+    let filter = BooleanColumn::from_iterator(
+        (0..N + offset).map(|i| !(100 + offset..=800 + offset).contains(&i)),
+    )
+    .slice(offset, N)
+    .as_any()
+    .downcast_ref::<BooleanColumn>()
+    .unwrap()
+    .clone();
+    tests.push(Test {
+        filter,
+        expect: (0..N)
+            .map(|i| i as i32)
+            .filter(|&i| !(100..=800).contains(&i))
+            .collect(),
+    });
+
+    for test in tests {
+        let res = data_column.filter(&test.filter);
+        assert_eq!(
+            res.as_any()
+                .downcast_ref::<PrimitiveColumn<i32>>()
+                .unwrap()
+                .values(),
+            test.expect
+        );
+    }
+}
+
+#[test]
+fn test_filter_column_short_sliced_filter() {
+    // The filter starts inside a byte (bit offset 2) and ends before that byte
+    // does, so the unaligned prefix must not consume more bits than it has.
+    let data_column: PrimitiveColumn<i32> = Int32Column::from_iterator(0..3_i32);
+    let filter = BooleanColumn::from_iterator((0..5).map(|i| i % 2 == 0))
+        .slice(2, 3)
+        .as_any()
+        .downcast_ref::<BooleanColumn>()
+        .unwrap()
+        .clone();
+
+    let res = data_column.filter(&filter);
+    assert_eq!(
+        res.as_any()
+            .downcast_ref::<PrimitiveColumn<i32>>()
+            .unwrap()
+            .values(),
+        vec![0_i32, 2]
+    );
+}