feat(scheduler): exchange node rewrite in serialization (#525)
* feat: exchange node rewrite in serialization

* resolve all comments
BowenXiao1999 authored Mar 14, 2022
1 parent 2ac31c6 commit 4a80212
Showing 8 changed files with 296 additions and 23 deletions.
12 changes: 11 additions & 1 deletion rust/common/src/catalog/schema.rs
@@ -1,6 +1,7 @@
use std::ops::Index;

-use risingwave_pb::plan::Field as ProstField;
#[allow(unused_imports)]
use risingwave_pb::plan::{ColumnDesc, ExchangeInfo, Field as ProstField};

use crate::array::ArrayBuilderImpl;
use crate::error::Result;
@@ -19,6 +20,15 @@ impl std::fmt::Debug for Field {
}
}

impl Field {
pub fn to_prost(&self) -> Result<ProstField> {
Ok(ProstField {
data_type: Some(self.data_type.to_protobuf()?),
name: self.name.to_string(),
})
}
}

/// the schema of the executor's return data
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Schema {
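For reference, a minimal usage sketch of the new Field::to_prost helper, assuming the crate context shown in this diff (Field::unnamed and DataType both appear in the test changes further down):

use risingwave_common::catalog::Field;
use risingwave_common::error::Result;
use risingwave_common::types::DataType;

// Sketch: convert a frontend Field into its protobuf counterpart,
// propagating any data-type conversion error with `?`.
fn field_to_prost_example() -> Result<()> {
    let field = Field::unnamed(DataType::Int32);
    let prost_field = field.to_prost()?; // risingwave_pb::plan::Field
    println!("serialized field name: {}", prost_field.name);
    Ok(())
}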
11 changes: 10 additions & 1 deletion rust/frontend/src/optimizer/plan_node/batch_exchange.rs
@@ -1,6 +1,8 @@
use std::fmt;

use risingwave_common::catalog::Schema;
use risingwave_pb::plan::plan_node::NodeBody;
use risingwave_pb::plan::ExchangeNode;

use super::{BatchBase, PlanRef, PlanTreeNodeUnary, ToBatchProst, ToDistributedBatch};
use crate::optimizer::property::{Distribution, Order, WithDistribution, WithOrder, WithSchema};
@@ -60,4 +62,11 @@ impl ToDistributedBatch for BatchExchange {
}
}

-impl ToBatchProst for BatchExchange {}
/// The serialization of `BatchExchange` is left as the default because it will be rewritten in the scheduler.
impl ToBatchProst for BatchExchange {
fn to_batch_prost_body(&self) -> NodeBody {
NodeBody::Exchange(ExchangeNode {
..Default::default()
})
}
}
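As the doc comment above says, the frontend deliberately emits an empty ExchangeNode and leaves the real wiring to the scheduler. A rough sketch of that placeholder, with the field names taken from the test assertions at the end of this commit:

use risingwave_pb::plan::ExchangeNode;

// Sketch: the body produced by BatchExchange::to_batch_prost_body is all
// defaults; sources, source_stage_id and input_schema stay unset until the
// scheduler rewrites the node (see the plan_fragmenter test below).
fn empty_exchange_placeholder() -> ExchangeNode {
    let node = ExchangeNode { ..Default::default() };
    debug_assert!(node.sources.is_empty());
    debug_assert!(node.source_stage_id.is_none());
    node
}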
30 changes: 29 additions & 1 deletion rust/frontend/src/optimizer/plan_node/batch_hash_join.rs
@@ -1,6 +1,8 @@
use std::fmt;

use risingwave_common::catalog::Schema;
use risingwave_pb::plan::plan_node::NodeBody;
use risingwave_pb::plan::HashJoinNode;

use super::{
BatchBase, EqJoinPredicate, LogicalJoin, PlanRef, PlanTreeNodeBinary, ToBatchProst,
@@ -91,4 +93,30 @@ impl ToDistributedBatch for BatchHashJoin {
}
}

-impl ToBatchProst for BatchHashJoin {}
impl ToBatchProst for BatchHashJoin {
fn to_batch_prost_body(&self) -> NodeBody {
NodeBody::HashJoin(HashJoinNode {
join_type: self.logical.join_type() as i32,
left_key: self
.eq_join_predicate
.left_eq_indexes()
.into_iter()
.map(|a| a as i32)
.collect(),
right_key: self
.eq_join_predicate
.right_eq_indexes()
.into_iter()
.map(|a| a as i32)
.collect(),
left_output: (0..self.logical.left().schema().len())
.into_iter()
.map(|a| a as i32)
.collect(),
right_output: (0..self.logical.right().schema().len())
.into_iter()
.map(|a| a as i32)
.collect(),
})
}
}
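The output columns above default to every column from each join side, mapped to i32 indexes. A tiny worked example of that mapping (plain Rust, no crate dependencies):

// Sketch: with 2 columns on the left and 3 on the right, the serialized
// output index vectors are simply 0..len cast to i32, as in the code above.
fn output_index_example() {
    let left_output: Vec<i32> = (0..2usize).map(|a| a as i32).collect();
    let right_output: Vec<i32> = (0..3usize).map(|a| a as i32).collect();
    assert_eq!(left_output, vec![0, 1]);
    assert_eq!(right_output, vec![0, 1, 2]);
}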
15 changes: 14 additions & 1 deletion rust/frontend/src/optimizer/plan_node/batch_seq_scan.rs
@@ -1,6 +1,8 @@
use std::fmt;

use risingwave_common::catalog::Schema;
use risingwave_pb::plan::plan_node::NodeBody;
use risingwave_pb::plan::{CellBasedTableDesc, RowSeqScanNode};

use super::{BatchBase, PlanRef, ToBatchProst, ToDistributedBatch};
use crate::optimizer::plan_node::LogicalScan;
@@ -49,4 +51,15 @@ impl ToDistributedBatch for BatchSeqScan {
}
}

-impl ToBatchProst for BatchSeqScan {}
impl ToBatchProst for BatchSeqScan {
fn to_batch_prost_body(&self) -> NodeBody {
// TODO(Bowen): Fix this serialization.
NodeBody::RowSeqScan(RowSeqScanNode {
table_desc: Some(CellBasedTableDesc {
table_id: self.logical.table_id(),
pk: vec![],
}),
..Default::default()
})
}
}
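Per the TODO above, the scan serialization is still a placeholder: only the table id is carried and the primary-key columns are left empty. A small sketch of exactly what gets built, using only the field names visible in this diff:

use risingwave_pb::plan::CellBasedTableDesc;

// Sketch: the table descriptor currently carries just the table id; pk is
// deliberately empty until the serialization is finished.
fn placeholder_table_desc(table_id: u32) -> CellBasedTableDesc {
    CellBasedTableDesc { table_id, pk: vec![] }
}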
4 changes: 4 additions & 0 deletions rust/frontend/src/optimizer/plan_node/logical_scan.rs
@@ -62,6 +62,10 @@ impl LogicalScan {
f.field("table", &self.table_name)
.field("columns", &columns);
}

pub fn table_id(&self) -> u32 {
self.table_id.table_id
}
}

impl_plan_tree_node_for_leaf! {LogicalScan}
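The new accessor simply exposes the raw u32 stored inside TableId, which BatchSeqScan above feeds into CellBasedTableDesc. A trivial sketch (TableId::default() is the same constructor the test below uses):

use risingwave_common::catalog::TableId;

// Sketch: TableId wraps a plain u32; LogicalScan::table_id() returns that
// inner value for protobuf serialization.
fn table_id_example() {
    let id = TableId::default();
    let raw: u32 = id.table_id; // the same field the new accessor reads
    println!("table id = {}", raw);
}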
29 changes: 29 additions & 0 deletions rust/frontend/src/optimizer/property/distribution.rs
@@ -1,4 +1,9 @@
use paste::paste;
use risingwave_pb::plan::exchange_info::hash_info::HashMethod;
use risingwave_pb::plan::exchange_info::{
BroadcastInfo, Distribution as DistributionProst, DistributionMode, HashInfo,
};
use risingwave_pb::plan::ExchangeInfo;

use super::super::plan_node::*;
use crate::optimizer::property::{Convention, Order};
@@ -18,6 +23,30 @@ static ANY_DISTRIBUTION: Distribution = Distribution::Any;

#[allow(dead_code)]
impl Distribution {
pub fn to_prost(&self, output_count: u32) -> ExchangeInfo {
ExchangeInfo {
mode: match self {
Distribution::Single => DistributionMode::Single,
Distribution::Broadcast => DistributionMode::Broadcast,
Distribution::HashShard(_keys) => DistributionMode::Hash,
// TODO: Should panic if AnyShard or Any
_ => DistributionMode::Hash,
} as i32,
distribution: match self {
Distribution::Single => None,
Distribution::Broadcast => Some(DistributionProst::BroadcastInfo(BroadcastInfo {
count: output_count,
})),
Distribution::HashShard(keys) => Some(DistributionProst::HashInfo(HashInfo {
output_count,
keys: keys.iter().map(|num| *num as u32).collect(),
hash_method: HashMethod::Crc32 as i32,
})),
// TODO: Should panic if AnyShard or Any
Distribution::AnyShard | Distribution::Any => None,
},
}
}
pub fn enforce_if_not_satisfies(&self, plan: PlanRef, required_order: &Order) -> PlanRef {
if !plan.distribution().satisfies(self) {
self.enforce(plan, required_order)
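A usage sketch of the new Distribution::to_prost, assuming the frontend crate context of this diff (DistributionMode and HashShard both appear above):

use risingwave_pb::plan::exchange_info::DistributionMode;

use crate::optimizer::property::Distribution;

// Sketch: a hash-shard distribution on columns [0, 1], serialized for an
// exchange with 4 downstream outputs.
fn distribution_to_prost_example() {
    let info = Distribution::HashShard(vec![0, 1]).to_prost(4);
    assert_eq!(info.mode, DistributionMode::Hash as i32);
    // For Single the distribution payload is None; for HashShard it carries
    // HashInfo with output_count, the key columns, and Crc32 as hash method.
}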
99 changes: 83 additions & 16 deletions rust/frontend/src/scheduler/plan_fragmenter.rs
@@ -28,7 +28,7 @@ impl BatchPlanFragmenter {
/// Contains the connection info of each stage.
pub(crate) struct Query {
/// Query id should always be unique.
-query_id: Uuid,
pub(crate) query_id: Uuid,
pub(crate) stage_graph: StageGraph,
}

@@ -71,7 +71,7 @@ pub(crate) struct StageGraph {
parent_edges: HashMap<StageId, HashSet<StageId>>,
/// Indicates which stage the exchange executor is running on.
/// Look up child stage for exchange source so that parent stage knows where to pull data.
-exchange_id_to_stage: HashMap<u64, StageId>,
pub(crate) exchange_id_to_stage: HashMap<i32, StageId>,
}

impl StageGraph {
@@ -88,7 +88,7 @@ struct StageGraphBuilder {
stages: HashMap<StageId, QueryStageRef>,
child_edges: HashMap<StageId, HashSet<StageId>>,
parent_edges: HashMap<StageId, HashSet<StageId>>,
-exchange_id_to_stage: HashMap<u64, StageId>,
exchange_id_to_stage: HashMap<i32, StageId>,
}

impl StageGraphBuilder {
@@ -127,7 +127,7 @@ impl StageGraphBuilder {
/// # Arguments
///
/// * `exchange_id` - The operator id of exchange executor.
-pub fn link_to_child(&mut self, parent_id: StageId, exchange_id: u64, child_id: StageId) {
pub fn link_to_child(&mut self, parent_id: StageId, exchange_id: i32, child_id: StageId) {
let child_ids = self.child_edges.get_mut(&parent_id);
// If the parent id does not exist, create a new set containing the child ids. Otherwise
// just insert.
@@ -204,9 +204,11 @@ impl BatchPlanFragmenter {
// link with current stage.
let child_query_stage =
self.new_query_stage(child_node.clone(), child_node.distribution().clone());
-// TODO(Bowen): replace mock exchange id 0 to real operator id (#67).
-self.stage_graph_builder
-.link_to_child(cur_stage.id, 0, child_query_stage.id);
self.stage_graph_builder.link_to_child(
cur_stage.id,
node.id().0,
child_query_stage.id,
);
self.build_stage(&child_query_stage, child_node);
}
} else {
Expand All @@ -224,8 +226,12 @@ mod tests {
use std::rc::Rc;
use std::sync::Arc;

-use risingwave_common::catalog::{Schema, TableId};
-use risingwave_pb::common::{ParallelUnit, ParallelUnitType, WorkerNode, WorkerType};
use risingwave_common::catalog::{Field, Schema, TableId};
use risingwave_common::types::DataType;
use risingwave_pb::common::{
HostAddress, ParallelUnit, ParallelUnitType, WorkerNode, WorkerType,
};
use risingwave_pb::plan::exchange_info::DistributionMode;
use risingwave_pb::plan::JoinType;

use crate::optimizer::plan_node::{
@@ -249,24 +255,30 @@
// Scan Scan
//
let ctx = Rc::new(RefCell::new(QueryContext::mock().await));
let fields = vec![
Field::unnamed(DataType::Int32),
Field::unnamed(DataType::Float64),
];
let batch_plan_node: PlanRef = BatchSeqScan::new(LogicalScan::new(
"".to_string(),
TableId::default(),
vec![],
-Schema::default(),
Schema {
fields: fields.clone(),
},
ctx,
))
.into();
let batch_exchange_node1: PlanRef = BatchExchange::new(
batch_plan_node.clone(),
Order::default(),
-Distribution::AnyShard,
Distribution::HashShard(vec![0, 1, 2]),
)
.into();
let batch_exchange_node2: PlanRef = BatchExchange::new(
batch_plan_node.clone(),
Order::default(),
-Distribution::AnyShard,
Distribution::HashShard(vec![0, 1, 2]),
)
.into();
let hash_join_node: PlanRef = BatchHashJoin::new(
@@ -288,7 +300,7 @@

// Break the plan node into fragments.
let fragmenter = BatchPlanFragmenter::new();
-let query = fragmenter.split(batch_exchange_node3).unwrap();
let query = fragmenter.split(batch_exchange_node3.clone()).unwrap();

assert_eq!(query.stage_graph.id, 0);
assert_eq!(query.stage_graph.stages.len(), 4);
@@ -319,21 +331,30 @@
let worker1 = WorkerNode {
id: 0,
r#type: WorkerType::ComputeNode as i32,
-host: None,
host: Some(HostAddress {
host: "127.0.0.1".to_string(),
port: 5687,
}),
state: risingwave_pb::common::worker_node::State::Running as i32,
parallel_units: generate_parallel_units(0),
};
let worker2 = WorkerNode {
id: 1,
r#type: WorkerType::ComputeNode as i32,
-host: None,
host: Some(HostAddress {
host: "127.0.0.1".to_string(),
port: 5688,
}),
state: risingwave_pb::common::worker_node::State::Running as i32,
parallel_units: generate_parallel_units(8),
};
let worker3 = WorkerNode {
id: 2,
r#type: WorkerType::ComputeNode as i32,
-host: None,
host: Some(HostAddress {
host: "127.0.0.1".to_string(),
port: 5689,
}),
state: risingwave_pb::common::worker_node::State::Running as i32,
parallel_units: generate_parallel_units(16),
};
@@ -370,6 +391,52 @@
assert_eq!(scan_node_2.assignments.get(&0).unwrap(), &worker1);
assert_eq!(scan_node_2.assignments.get(&1).unwrap(), &worker2);
assert_eq!(scan_node_2.assignments.get(&2).unwrap(), &worker3);

// Check that the serialized exchange source node has been filled with correct info.
let prost_node_root = root.augmented_stage.to_prost(0, &query).unwrap();
assert_eq!(
prost_node_root.exchange_info.unwrap().mode,
DistributionMode::Single as i32
);
assert_eq!(prost_node_root.root.clone().unwrap().children.len(), 0);
if let risingwave_pb::plan::plan_node::NodeBody::Exchange(exchange) =
prost_node_root.root.unwrap().node_body.unwrap()
{
assert_eq!(exchange.source_stage_id.unwrap().stage_id, 1);
assert_eq!(exchange.sources.len(), 3);
assert_eq!(exchange.input_schema.len(), 4);
} else {
panic!("The root node should be exchange single");
}

let prost_join_node = join_node.augmented_stage.to_prost(0, &query).unwrap();
assert_eq!(prost_join_node.root.as_ref().unwrap().children.len(), 2);
assert_eq!(
prost_join_node.exchange_info.unwrap().mode,
DistributionMode::Hash as i32
);
if let risingwave_pb::plan::plan_node::NodeBody::HashJoin(_) = prost_join_node
.root
.as_ref()
.unwrap()
.node_body
.as_ref()
.unwrap()
{
} else {
panic!("The node should be hash join node");
}

let exchange_1 = prost_join_node.root.as_ref().unwrap().children[0].clone();
if let risingwave_pb::plan::plan_node::NodeBody::Exchange(exchange) =
exchange_1.node_body.unwrap()
{
assert_eq!(exchange.source_stage_id.unwrap().stage_id, 2);
assert_eq!(exchange.sources.len(), 3);
assert_eq!(exchange.input_schema.len(), 2);
} else {
panic!("The node should be exchange node");
}
}

fn generate_parallel_units(start_id: u32) -> Vec<ParallelUnit> {
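Tying it together: the fragmenter now records each exchange operator's real plan-node id (an i32) instead of the old mocked 0, so the scheduler can look up which child stage an exchange pulls its data from. A crate-internal sketch of that lookup, relying on the pub(crate) fields introduced above:

// Sketch: given a serialized exchange's operator id, find the stage that
// produces its input. Query and StageGraph expose these fields as pub(crate).
fn child_stage_of_exchange(query: &Query, exchange_id: i32) -> Option<&StageId> {
    query.stage_graph.exchange_id_to_stage.get(&exchange_id)
}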