Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(streaming): use table catalog in hash join #3707

Merged
merged 19 commits into from
Jul 8, 2022
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions proto/stream_plan.proto
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,11 @@ message HashJoinNode {
catalog.Table left_table = 6;
// Used for internal table states.
catalog.Table right_table = 7;
repeated uint32 dist_key_l = 8;
repeated uint32 dist_key_r = 9;
// It is true when the input is append-only
bool is_append_only = 10;
bool is_append_only = 8;
// Whether to optimize for append only stream.
yuhao-su marked this conversation as resolved.
Show resolved Hide resolved
// the output indices of current node
repeated uint32 output_indices = 11;
repeated uint32 output_indices = 9;
}

// Delta join with two indexes. This is a pseudo plan node generated on frontend. On meta
Expand Down
158 changes: 71 additions & 87 deletions src/frontend/src/optimizer/plan_node/stream_hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::fmt;

use itertools::Itertools;
use risingwave_common::catalog::{ColumnDesc, DatabaseId, SchemaId, TableId};
use risingwave_common::catalog::{DatabaseId, Field, SchemaId};
use risingwave_common::types::DataType;
use risingwave_common::util::sort_util::OrderType;
use risingwave_pb::plan_common::JoinType;
use risingwave_pb::stream_plan::stream_node::NodeBody;
use risingwave_pb::stream_plan::HashJoinNode;

use super::utils::TableCatalogBuilder;
use super::{LogicalJoin, PlanBase, PlanRef, PlanTreeNodeBinary, StreamDeltaJoin, ToStreamProst};
use crate::catalog::column_catalog::ColumnCatalog;
use crate::catalog::table_catalog::TableCatalog;
use crate::expr::Expr;
use crate::optimizer::plan_node::EqJoinPredicate;
use crate::optimizer::property::{Direction, Distribution, FieldOrder};
use crate::optimizer::property::Distribution;
use crate::utils::ColIndexMapping;

/// [`StreamHashJoin`] implements [`super::LogicalJoin`] with hash table. It builds a hash table
Expand All @@ -46,8 +47,6 @@ pub struct StreamHashJoin {
/// only. Will remove after we have fully support shared state and index.
is_delta: bool,

dist_key_l: Distribution,
dist_key_r: Distribution,
/// Whether can optimize for append-only stream.
/// It is true if input of both side is append-only
is_append_only: bool,
Expand All @@ -70,9 +69,6 @@ impl StreamHashJoin {
.composite(&logical.i2o_col_mapping()),
);

let dist_l = logical.left().distribution().clone();
let dist_r = logical.right().distribution().clone();

let force_delta = ctx.inner().session_ctx.config().get_delta_join();

// TODO: derive from input
Expand All @@ -89,8 +85,6 @@ impl StreamHashJoin {
logical,
eq_join_predicate,
is_delta: force_delta,
dist_key_l: dist_l,
dist_key_r: dist_r,
is_append_only: append_only,
}
}
Expand Down Expand Up @@ -181,46 +175,48 @@ impl_plan_tree_node_for_binary! { StreamHashJoin }

impl ToStreamProst for StreamHashJoin {
fn to_stream_prost_body(&self) -> NodeBody {
let left_key_indices_prost = self
.eq_join_predicate
.left_eq_indexes()
.iter()
.map(|v| *v as i32)
.collect_vec();
let right_key_indices_prost = self
.eq_join_predicate
.right_eq_indexes()
.iter()
.map(|v| *v as i32)
.collect_vec();
let left_key_indices = left_key_indices_prost
yuhao-su marked this conversation as resolved.
Show resolved Hide resolved
.iter()
.map(|idx| *idx as usize)
.collect_vec();
let right_key_indices = right_key_indices_prost
.iter()
.map(|idx| *idx as usize)
.collect_vec();
NodeBody::HashJoin(HashJoinNode {
join_type: self.logical.join_type() as i32,
left_key: self
.eq_join_predicate
.left_eq_indexes()
.iter()
.map(|v| *v as i32)
.collect(),
right_key: self
.eq_join_predicate
.right_eq_indexes()
.iter()
.map(|v| *v as i32)
.collect(),
left_key: left_key_indices_prost,
right_key: right_key_indices_prost,
condition: self
.eq_join_predicate
.other_cond()
.as_expr_unless_true()
.map(|x| x.to_expr_proto()),
dist_key_l: self
.dist_key_l
.dist_column_indices()
.iter()
.map(|idx| *idx as u32)
.collect_vec(),
dist_key_r: self
.dist_key_r
.dist_column_indices()
.iter()
.map(|idx| *idx as u32)
.collect_vec(),
is_delta_join: self.is_delta,
left_table: Some(infer_internal_table_catalog(self.left()).to_prost(
SchemaId::placeholder() as u32,
DatabaseId::placeholder() as u32,
)),
right_table: Some(infer_internal_table_catalog(self.right()).to_prost(
SchemaId::placeholder() as u32,
DatabaseId::placeholder() as u32,
)),
left_table: Some(
infer_internal_table_catalog(self.left(), left_key_indices).to_prost(
SchemaId::placeholder() as u32,
DatabaseId::placeholder() as u32,
),
),
Comment on lines +195 to +200
Copy link
Contributor

@BowenXiao1999 BowenXiao1999 Jul 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about impl infer_internal_table_catalog as a method in LogicalJoin, just like Agg?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should remain in StreamHashJoin since it is only used in streaming.

right_table: Some(
infer_internal_table_catalog(self.right(), right_key_indices).to_prost(
SchemaId::placeholder() as u32,
DatabaseId::placeholder() as u32,
),
),
output_indices: self
.logical
.output_indices()
Expand All @@ -232,51 +228,39 @@ impl ToStreamProst for StreamHashJoin {
}
}

fn infer_internal_table_catalog(input: PlanRef) -> TableCatalog {
fn infer_internal_table_catalog(input: PlanRef, join_key_indices: Vec<usize>) -> TableCatalog {
let base = input.plan_base();
let schema = &base.schema;
let pk_indices = &base.pk_indices;
let mut col_names = HashMap::new();
// FIXME: temp fix, use TableCatalogBuilder to avoid_duplicate_col_name in the future (https://github.com/singularity-data/risingwave/issues/3657)
let columns = schema
.fields()
.iter()
.enumerate()
.map(|(i, field)| {
let mut c = ColumnCatalog {
column_desc: ColumnDesc::from_field_with_column_id(field, i as i32),
is_hidden: false,
};
c.column_desc.name = match col_names.try_insert(field.name.clone(), 0) {
Ok(_) => field.name.clone(),
Err(mut err) => {
let cnt = err.entry.get_mut();
*cnt += 1;
field.name.clone() + "#" + &cnt.to_string()
}
};
c
})
.collect_vec();
let mut order_desc = vec![];
for &index in pk_indices {
order_desc.push(FieldOrder {
index,
direct: Direction::Asc,
});
}
TableCatalog {
id: TableId::placeholder(),
associated_source_id: None,
name: String::new(),
columns,
order_key: order_desc,
pk: pk_indices.clone(),
distribution_key: base.dist.dist_column_indices().to_vec(),
is_index_on: None,
appendonly: input.append_only(),
owner: risingwave_common::catalog::DEFAULT_SUPPER_USER.to_string(),
vnode_mapping: None,
properties: HashMap::default(),

let append_only = input.append_only();
let dist_keys = base.dist.dist_column_indices().to_vec();

// The pk of the hash join internal table should be join_key + input_pk.
yuhao-su marked this conversation as resolved.
Show resolved Hide resolved
let mut pk_indices = join_key_indices;
// TODO(yuhao): dedupe the dist key and pk.
yuhao-su marked this conversation as resolved.
Show resolved Hide resolved
yuhao-su marked this conversation as resolved.
Show resolved Hide resolved
pk_indices.extend(&base.pk_indices);

let mut columns_fields = schema.fields().to_vec();

// The join degree is stored as the last column of the internal table.
let degree_column_field = Field {
yuhao-su marked this conversation as resolved.
Show resolved Hide resolved
data_type: DataType::Int64,
name: "_degree".to_string(),
sub_fields: vec![],
type_name: "".to_string(),
};
columns_fields.push(degree_column_field);

let mut internal_table_catalog_builder = TableCatalogBuilder::new();

for (idx, field) in columns_fields.iter().enumerate() {
let order_type = if pk_indices.contains(&idx) {
Some(OrderType::Ascending)
yuhao-su marked this conversation as resolved.
Show resolved Hide resolved
} else {
None
};
internal_table_catalog_builder.add_column_desc_from_field(order_type, field)
}

internal_table_catalog_builder.build(dist_keys, append_only)
}
52 changes: 52 additions & 0 deletions src/storage/src/table/state_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@ use std::borrow::Cow;
use std::cmp::Ordering;
use std::marker::PhantomData;
use std::ops::Index;
use std::sync::Arc;

use futures::{pin_mut, Stream, StreamExt};
use futures_async_stream::try_stream;
use risingwave_common::array::Row;
use risingwave_common::buffer::Bitmap;
use risingwave_common::catalog::{ColumnDesc, TableId};
use risingwave_common::util::ordered::{serialize_pk, OrderedRowSerializer};
use risingwave_common::util::sort_util::OrderType;
use risingwave_hummock_sdk::key::range_of_prefix;
use risingwave_pb::catalog::Table;

use super::mem_table::{MemTable, RowOp};
use super::storage_table::{StorageTableBase, READ_WRITE};
Expand Down Expand Up @@ -221,6 +224,55 @@ impl<S: StateStore> StateTable<S> {

Ok(StateTableRowIter::new(mem_table_iter, storage_table_iter).into_stream())
}

/// Create state table from table catalog and store.
pub fn from_table_catalog(
table_catalog: &Table,
store: S,
vnodes: Option<Arc<Bitmap>>,
) -> Self {
let table_columns = table_catalog
.columns
.iter()
.map(|col| col.column_desc.as_ref().unwrap().into())
.collect();
let order_types = table_catalog
.order_key
.iter()
.map(|col_order| {
OrderType::from_prost(
&risingwave_pb::plan_common::OrderType::from_i32(col_order.order_type).unwrap(),
)
})
.collect();
let dist_key_indices = table_catalog
.distribution_key
.iter()
.map(|dist_index| *dist_index as usize)
.collect();
let pk_indices = table_catalog
.order_key
.iter()
.map(|col_order| col_order.index as usize)
.collect();
let distribution = match vnodes {
// Hash Agg
Some(vnodes) => Distribution {
dist_key_indices,
vnodes,
},
// Simple Agg
yuhao-su marked this conversation as resolved.
Show resolved Hide resolved
None => Distribution::fallback(),
};
StateTable::new_with_distribution(
store,
TableId::new(table_catalog.id),
table_columns,
order_types,
pk_indices,
distribution,
)
}
}

pub type RowStream<'a, S: StateStore> = impl Stream<Item = StorageResult<Cow<'a, Row>>>;
Expand Down
49 changes: 2 additions & 47 deletions src/stream/src/executor/aggregation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,9 @@ use risingwave_common::buffer::Bitmap;
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::hash::HashCode;
use risingwave_common::types::{DataType, Datum};
use risingwave_common::util::sort_util::OrderType;
use risingwave_expr::expr::AggKind;
use risingwave_expr::*;
use risingwave_storage::table::state_table::StateTable;
use risingwave_storage::table::Distribution;
use risingwave_storage::StateStore;
pub use row_count::*;
use static_assertions::const_assert_eq;
Expand Down Expand Up @@ -406,51 +404,8 @@ pub fn generate_state_tables_from_proto<S: StateStore>(

for table_catalog in internal_tables {
// Parse info from proto and create state table.
let state_table = {
let columns = table_catalog
.columns
.iter()
.map(|col| col.column_desc.as_ref().unwrap().into())
.collect();
let order_types = table_catalog
.order_key
.iter()
.map(|order_key| {
OrderType::from_prost(
&risingwave_pb::plan_common::OrderType::from_i32(order_key.order_type)
.unwrap(),
)
})
.collect();
let dist_key_indices = table_catalog
.distribution_key
.iter()
.map(|dist_index| *dist_index as usize)
.collect();
let pk_indices = table_catalog
.pk
.iter()
.map(|pk_index| *pk_index as usize)
.collect();
let distribution = match vnodes.clone() {
// Hash Agg
Some(vnodes) => Distribution {
dist_key_indices,
vnodes,
},
// Simple Agg
None => Distribution::fallback(),
};
StateTable::new_with_distribution(
store.clone(),
risingwave_common::catalog::TableId::new(table_catalog.id),
columns,
order_types,
pk_indices,
distribution,
)
};

let state_table =
StateTable::from_table_catalog(table_catalog, store.clone(), vnodes.clone());
state_tables.push(state_table)
}
state_tables
Expand Down
Loading