Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed send and sync for mutex and rwlock #1705

Merged
merged 4 commits into from
Sep 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions common/infallible/src/mutex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@ use parking_lot::MutexGuard;
#[derive(Debug)]
pub struct Mutex<T>(ParkingMutex<T>);

/// Mutex is Send
unsafe impl<T> Send for Mutex<T> {}

/// Mutex is Sync
unsafe impl<T> Sync for Mutex<T> {}
unsafe impl<T> Send for Mutex<T> where ParkingMutex<T>: Send {}
unsafe impl<T> Sync for Mutex<T> where ParkingMutex<T>: Sync {}

impl<T> Mutex<T> {
/// creates mutex
Expand Down
4 changes: 2 additions & 2 deletions common/infallible/src/rwlock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ use parking_lot::RwLockWriteGuard;
#[derive(Debug, Default)]
pub struct RwLock<T>(ParkingRwLock<T>);

unsafe impl<T> Send for RwLock<T> {}
unsafe impl<T> Sync for RwLock<T> {}
unsafe impl<T> Send for RwLock<T> where ParkingRwLock<T>: Send {}
unsafe impl<T> Sync for RwLock<T> where ParkingRwLock<T>: Sync {}

impl<T> RwLock<T> {
/// creates a read-write lock
Expand Down
14 changes: 6 additions & 8 deletions common/streams/src/sources/source_csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

use std::io;
use std::sync::Arc;

use common_arrow::arrow::io::csv::read::ByteRecord;
use common_arrow::arrow::io::csv::read::Reader;
Expand All @@ -23,25 +22,24 @@ use common_datavalues::DataSchemaRef;
use common_exception::ErrorCode;
use common_exception::Result;
use common_exception::ToErrorCode;
use common_infallible::RwLock;

use crate::Source;

pub struct CsvSource<R> {
reader: Arc<RwLock<Reader<R>>>,
reader: Reader<R>,
schema: DataSchemaRef,
block_size: usize,
rows: usize,
}

impl<R> CsvSource<R>
where R: io::Read
where R: io::Read + Sync + Send
{
pub fn new(reader: R, schema: DataSchemaRef, block_size: usize) -> Self {
let reader = ReaderBuilder::new().has_headers(false).from_reader(reader);

Self {
reader: Arc::new(RwLock::new(reader)),
reader,
block_size,
schema,
rows: 0,
Expand All @@ -50,10 +48,9 @@ where R: io::Read
}

impl<R> Source for CsvSource<R>
where R: io::Read
where R: io::Read + Sync + Send
{
fn read(&mut self) -> Result<Option<DataBlock>> {
let mut reader = self.reader.write();
let mut record = ByteRecord::new();
let mut desers = self
.schema
Expand All @@ -63,7 +60,8 @@ where R: io::Read
.collect::<Result<Vec<_>>>()?;

for row in 0..self.block_size {
let v = reader
let v = self
.reader
.read_byte_record(&mut record)
.map_err_to_code(ErrorCode::BadBytes, || {
format!("Parse csv error at line {}", self.rows)
Expand Down
12 changes: 5 additions & 7 deletions common/streams/src/sources/source_values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,28 @@

use std::io;
use std::io::BufReader;
use std::sync::Arc;

use common_datablocks::DataBlock;
use common_datavalues::DataSchemaRef;
use common_exception::Result;
use common_infallible::RwLock;
use common_io::prelude::*;

use crate::Source;

pub struct ValueSource<R> {
reader: Arc<RwLock<BufReader<R>>>,
reader: BufReader<R>,
schema: DataSchemaRef,
block_size: usize,
rows: usize,
}

impl<R> ValueSource<R>
where R: io::Read
where R: io::Read + Send + Sync
{
pub fn new(reader: R, schema: DataSchemaRef, block_size: usize) -> Self {
let reader = BufReader::new(reader);
Self {
reader: Arc::new(RwLock::new(reader)),
reader,
block_size,
schema,
rows: 0,
Expand All @@ -46,10 +44,10 @@ where R: io::Read
}

impl<R> Source for ValueSource<R>
where R: io::Read
where R: io::Read + Send + Sync
{
fn read(&mut self) -> Result<Option<DataBlock>> {
let mut reader = self.reader.write();
let reader = &mut self.reader;
let mut buf = Vec::new();
let mut temp = Vec::new();

Expand Down
46 changes: 25 additions & 21 deletions query/src/pipelines/transforms/group_by/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ use common_datavalues::DataSchemaRefExt;
use common_exception::Result;
use common_functions::aggregates::StateAddr;
use common_functions::aggregates::StateAddrs;
use common_infallible::Mutex;
use common_io::prelude::BytesMut;
use common_streams::DataBlockStream;
use common_streams::SendableDataBlockStream;
Expand Down Expand Up @@ -53,33 +52,40 @@ impl<Method: HashMethod + PolymorphicKeysHelper<Method>> Aggregator<Method> {
&self,
group_cols: Vec<String>,
mut stream: SendableDataBlockStream,
) -> Result<Mutex<Method::State>> {
) -> Result<Method::State> {
// This may be confusing
// It will help us improve performance ~10% when we declare local references for them.
let hash_method = &self.method;
let aggregator_params = self.params.as_ref();

let aggregate_state = Mutex::new(hash_method.aggregate_state());
let mut state = hash_method.aggregate_state();

while let Some(block) = stream.next().await {
let block = block?;
let mut groups = aggregate_state.lock();
match aggregator_params.aggregate_functions.is_empty() {
true => {
while let Some(block) = stream.next().await {
let block = block?;

// 1.1 and 1.2.
let group_columns = Self::group_columns(&group_cols, &block)?;
let group_keys = hash_method.build_keys(&group_columns, block.num_rows())?;
// 1.1 and 1.2.
let group_columns = Self::group_columns(&group_cols, &block)?;
let group_keys = hash_method.build_keys(&group_columns, block.num_rows())?;
self.lookup_key(group_keys, &mut state);
}
}
false => {
while let Some(block) = stream.next().await {
let block = block?;

// TODO: This can be moved outside the while
// In fact, the rust compiler will help us do this(optimize the while match to match while),
// but we need to ensure that the match is simple enough(otherwise there will be performance degradation).
let places: StateAddrs = match aggregator_params.aggregate_functions.is_empty() {
true => self.lookup_key(group_keys, &mut groups),
false => self.lookup_state(group_keys, &mut groups),
};
// 1.1 and 1.2.
let group_columns = Self::group_columns(&group_cols, &block)?;
let group_keys = hash_method.build_keys(&group_columns, block.num_rows())?;

Self::execute(aggregator_params, &block, &places)?;
let places = self.lookup_state(group_keys, &mut state);
Self::execute(aggregator_params, &block, &places)?;
}
}
}
Ok(aggregate_state)

Ok(state)
}

#[inline(always)]
Expand All @@ -105,13 +111,11 @@ impl<Method: HashMethod + PolymorphicKeysHelper<Method>> Aggregator<Method> {
}

#[inline(always)]
fn lookup_key(&self, keys: Vec<Method::HashKey>, state: &mut Method::State) -> StateAddrs {
fn lookup_key(&self, keys: Vec<Method::HashKey>, state: &mut Method::State) {
let mut inserted = true;
for key in keys.iter() {
state.entity(key, &mut inserted);
}

vec![0_usize.into(); keys.len()]
}

/// Allocate aggregation function state for each key(the same key can always get the same state)
Expand Down
70 changes: 48 additions & 22 deletions query/src/pipelines/transforms/group_by/aggregator_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use crate::pipelines::transforms::group_by::AggregatorParams;
/// - Aggregate data(HashMap or MergeSort set in future)
/// - Aggregate function state data memory pool
/// - Group by key data memory pool (if necessary)
pub trait AggregatorState<Method: HashMethod> {
pub trait AggregatorState<Method: HashMethod>: Sync + Send {
type Key;
type Entity: StateEntity<Self::Key>;
type Iterator: Iterator<Item = *mut Self::Entity>;
Expand All @@ -62,6 +62,16 @@ pub struct ShortFixedKeysAggregatorState<T: ShortFixedKeyable> {
data: *mut ShortFixedKeysStateEntity<T>,
}

// TODO:(Winter) Hack:
// The *mut ShortFixedKeysStateEntity needs to be used externally, but we can ensure that the
// *mut ShortFixedKeysStateEntity is never used from multiple async tasks at the same time, so ShortFixedKeysAggregatorState is Send
unsafe impl<T: ShortFixedKeyable + Send> Send for ShortFixedKeysAggregatorState<T> {}

// TODO:(Winter) Hack:
// The *mut ShortFixedKeysStateEntity needs to be used externally, but we can ensure that references to the
// *mut ShortFixedKeysStateEntity are never used from multiple async tasks at the same time, so ShortFixedKeysAggregatorState is Sync
unsafe impl<T: ShortFixedKeyable + Sync> Sync for ShortFixedKeysAggregatorState<T> {}

impl<T: ShortFixedKeyable> ShortFixedKeysAggregatorState<T> {
pub fn create(max_size: usize) -> Self {
unsafe {
Expand Down Expand Up @@ -150,6 +160,16 @@ pub struct LongerFixedKeysAggregatorState<T: HashTableKeyable> {
pub data: HashMap<T, usize>,
}

// TODO:(Winter) Hack:
// The *mut KeyValueEntity needs to be used externally, but we can ensure that the *mut KeyValueEntity
// is never used from multiple async tasks at the same time, so LongerFixedKeysAggregatorState is Send
unsafe impl<T: HashTableKeyable + Send> Send for LongerFixedKeysAggregatorState<T> {}

// TODO:(Winter) Hack:
// The *mut KeyValueEntity needs to be used externally, but we can ensure that references to the *mut KeyValueEntity
// are never used from multiple async tasks at the same time, so LongerFixedKeysAggregatorState is Sync
unsafe impl<T: HashTableKeyable + Sync> Sync for LongerFixedKeysAggregatorState<T> {}

impl<T> AggregatorState<HashMethodFixedKeys<T>> for LongerFixedKeysAggregatorState<T>
where
T: DFPrimitiveType,
Expand All @@ -170,11 +190,6 @@ where
self.data.iter()
}

#[inline(always)]
fn entity(&mut self, key: &Self::Key, inserted: &mut bool) -> *mut Self::Entity {
self.data.insert_key(key, inserted)
}

#[inline(always)]
fn alloc_layout(&self, params: &AggregatorParams) -> StateAddr {
let place: StateAddr = self.area.alloc_layout(params.layout).into();
Expand All @@ -187,6 +202,11 @@ where

place
}

#[inline(always)]
fn entity(&mut self, key: &Self::Key, inserted: &mut bool) -> *mut Self::Entity {
self.data.insert_key(key, inserted)
}
}

pub struct SerializedKeysAggregatorState {
Expand All @@ -195,6 +215,16 @@ pub struct SerializedKeysAggregatorState {
pub data_state_map: HashMap<KeysRef, usize>,
}

// TODO:(Winter) Hack:
// The *mut KeyValueEntity needs to be used externally, but we can ensure that the *mut KeyValueEntity
// is never used from multiple async tasks at the same time, so SerializedKeysAggregatorState is Send
unsafe impl Send for SerializedKeysAggregatorState {}

// TODO:(Winter) Hack:
// The *mut KeyValueEntity needs to be used externally, but we can ensure that &*mut KeyValueEntity
// will not be used multiple async, so KeyValueEntity is Sync
unsafe impl Sync for SerializedKeysAggregatorState {}

impl AggregatorState<HashMethodSerializer> for SerializedKeysAggregatorState {
type Key = KeysRef;
type Entity = KeyValueEntity<KeysRef, usize>;
Expand All @@ -208,9 +238,18 @@ impl AggregatorState<HashMethodSerializer> for SerializedKeysAggregatorState {
self.data_state_map.iter()
}

// fn alloc_layout(&self, memory_layout: &AggregatorLayout) -> NonNull<u8> {
// self.state_area.alloc_layout(memory_layout.layout)
// }
#[inline(always)]
fn alloc_layout(&self, params: &AggregatorParams) -> StateAddr {
let place: StateAddr = self.state_area.alloc_layout(params.layout).into();

for idx in 0..params.offsets_aggregate_states.len() {
let aggr_state = params.offsets_aggregate_states[idx];
let aggr_state_place = place.next(aggr_state);
params.aggregate_functions[idx].init_state(aggr_state_place);
}

place
}

fn entity(&mut self, keys: &Vec<u8>, inserted: &mut bool) -> *mut Self::Entity {
let mut keys_ref = KeysRef::create(keys.as_ptr() as usize, keys.len());
Expand All @@ -229,17 +268,4 @@ impl AggregatorState<HashMethodSerializer> for SerializedKeysAggregatorState {

state_entity
}

#[inline(always)]
fn alloc_layout(&self, params: &AggregatorParams) -> StateAddr {
let place: StateAddr = self.state_area.alloc_layout(params.layout).into();

for idx in 0..params.offsets_aggregate_states.len() {
let aggr_state = params.offsets_aggregate_states[idx];
let aggr_state_place = place.next(aggr_state);
params.aggregate_functions[idx].init_state(aggr_state_place);
}

place
}
}
5 changes: 2 additions & 3 deletions query/src/pipelines/transforms/transform_group_by_partial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,13 @@ impl GroupByPartialTransform {
let aggregator_params = AggregatorParams::try_create(schema, aggr_exprs)?;

let aggregator = Aggregator::create(method, aggregator_params);
let groups_locker = aggregator.aggregate(group_cols, stream).await?;
let state = aggregator.aggregate(group_cols, stream).await?;

let delta = start.elapsed();
tracing::debug!("Group by partial cost: {:?}", delta);

let groups = groups_locker.lock();
let finalized_schema = self.schema.clone();
aggregator.aggregate_finalized(&groups, finalized_schema)
aggregator.aggregate_finalized(&state, finalized_schema)
}
}

Expand Down