Skip to content

Commit

Permalink
fix: Frontend doesn't handle left semi join correctly. (#2363)
Browse files Browse the repository at this point in the history
  • Loading branch information
jon-chuang authored May 9, 2022
1 parent b4779d8 commit 27fc53b
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 44 deletions.
134 changes: 112 additions & 22 deletions src/frontend/src/optimizer/plan_node/logical_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,12 @@ impl LogicalJoin {
Self::new(left, right, join_type, Condition::with_expr(on_clause)).into()
}

/// Returns the number of columns a join of the given type exposes in its output.
///
/// * Inner and outer joins emit the columns of both inputs, i.e. `left_len + right_len`.
/// * `LeftSemi` / `LeftAnti` only emit the columns of the left input.
/// * `RightSemi` / `RightAnti` only emit the columns of the right input.
///
/// Every `JoinType` variant is listed exactly once so that adding a new variant is a
/// compile error here rather than a silently wrong column count.
pub fn out_column_num(left_len: usize, right_len: usize, join_type: JoinType) -> usize {
    match join_type {
        JoinType::Inner | JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter => {
            left_len + right_len
        }
        JoinType::LeftSemi | JoinType::LeftAnti => left_len,
        JoinType::RightSemi | JoinType::RightAnti => right_len,
    }
}
Expand All @@ -104,14 +100,11 @@ impl LogicalJoin {
join_type: JoinType,
) -> ColIndexMapping {
match join_type {
JoinType::LeftSemi
| JoinType::Inner
| JoinType::LeftOuter
| JoinType::RightOuter
| JoinType::FullOuter => {
JoinType::Inner | JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter => {
ColIndexMapping::identity_or_none(left_len + right_len, left_len)
}
JoinType::LeftAnti => ColIndexMapping::identity(left_len),

JoinType::LeftSemi | JoinType::LeftAnti => ColIndexMapping::identity(left_len),
JoinType::RightSemi | JoinType::RightAnti => ColIndexMapping::empty(right_len),
}
}
Expand All @@ -123,14 +116,10 @@ impl LogicalJoin {
join_type: JoinType,
) -> ColIndexMapping {
match join_type {
JoinType::LeftSemi
| JoinType::Inner
| JoinType::LeftOuter
| JoinType::RightOuter
| JoinType::FullOuter => {
JoinType::Inner | JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter => {
ColIndexMapping::with_shift_offset(left_len + right_len, -(left_len as isize))
}
JoinType::LeftAnti => ColIndexMapping::empty(left_len),
JoinType::LeftSemi | JoinType::LeftAnti => ColIndexMapping::empty(left_len),
JoinType::RightSemi | JoinType::RightAnti => ColIndexMapping::identity(right_len),
}
}
Expand Down Expand Up @@ -237,6 +226,14 @@ impl LogicalJoin {
/// Builds a new join over clones of this join's inputs and with the same join type,
/// using `cond` as the join condition in place of the current one.
pub fn clone_with_cond(&self, cond: Condition) -> Self {
    let left = self.left.clone();
    let right = self.right.clone();
    Self::new(left, right, self.join_type, cond)
}

/// Returns `true` when this join is a `LeftSemi` or `LeftAnti` join.
///
/// Note that, despite the name, `LeftOuter` is not matched by this predicate.
pub fn is_left_join(&self) -> bool {
    let join_type = self.join_type();
    matches!(join_type, JoinType::LeftAnti | JoinType::LeftSemi)
}

/// Returns `true` when this join is a `RightSemi` or `RightAnti` join.
///
/// Note that, despite the name, `RightOuter` is not matched by this predicate.
pub fn is_right_join(&self) -> bool {
    let join_type = self.join_type();
    matches!(join_type, JoinType::RightAnti | JoinType::RightSemi)
}
}

impl PlanTreeNodeBinary for LogicalJoin {
Expand Down Expand Up @@ -293,7 +290,18 @@ impl ColPrunable for LogicalJoin {

let left_len = self.left.schema().fields.len();

let mut visitor = CollectInputRef::new(required_cols.clone());
let total_len = self.left().schema().len() + self.right().schema().len();
let mut resized_required_cols = FixedBitSet::with_capacity(total_len);

required_cols.ones().for_each(|i| {
if self.is_right_join() {
resized_required_cols.insert(left_len + i);
} else {
resized_required_cols.insert(i);
}
});

let mut visitor = CollectInputRef::new(resized_required_cols);
self.on.visit_expr(&mut visitor);
let left_right_required_cols = visitor.collect();

Expand All @@ -319,9 +327,17 @@ impl ColPrunable for LogicalJoin {
on,
);

if required_cols == &left_right_required_cols {
let required_inputs_in_output = if self.is_left_join() {
left_required_cols
} else if self.is_right_join() {
right_required_cols
} else {
left_right_required_cols
};
if required_cols == &required_inputs_in_output {
join.into()
} else {
let mapping = ColIndexMapping::with_remaining_columns(&required_inputs_in_output);
let mut remaining_columns = FixedBitSet::with_capacity(join.schema().fields().len());
remaining_columns.extend(required_cols.ones().map(|i| mapping.map(i)));
LogicalProject::with_mapping(
Expand Down Expand Up @@ -508,6 +524,80 @@ mod tests {
assert_eq!(right.schema().fields(), &fields[3..4]);
}

/// Regression test: `prune_col` used to panic on semi and anti joins.
///
/// Both inputs have three columns (`v1..v3` on the left, `v4..v6` on the right) and the
/// join condition is `input_ref(1) = input_ref(4)`.
#[tokio::test]
async fn test_prune_semi_join() {
    let ty = DataType::Int32;
    let ctx = OptimizerContext::mock().await;
    let fields: Vec<Field> = (1..7)
        .map(|i| Field::with_name(ty.clone(), format!("v{}", i)))
        .collect();

    let left_schema = Schema {
        fields: fields[..3].to_vec(),
    };
    let right_schema = Schema {
        fields: fields[3..6].to_vec(),
    };
    let left = LogicalValues::new(vec![], left_schema, ctx.clone());
    let right = LogicalValues::new(vec![], right_schema, ctx);

    let left_key = ExprImpl::InputRef(Box::new(InputRef::new(1, ty.clone())));
    let right_key = ExprImpl::InputRef(Box::new(InputRef::new(4, ty)));
    let equality = FunctionCall::new(Type::Equal, vec![left_key, right_key]).unwrap();
    let on: ExprImpl = ExprImpl::FunctionCall(Box::new(equality));

    let semi_and_anti_joins = [
        JoinType::LeftSemi,
        JoinType::RightSemi,
        JoinType::LeftAnti,
        JoinType::RightAnti,
    ];
    for join_type in semi_and_anti_joins {
        let join = LogicalJoin::new(
            left.clone().into(),
            right.clone().into(),
            join_type,
            Condition::with_expr(on.clone()),
        );

        // The expected output fields start at `fields[3]` for the right variants and
        // at `fields[0]` otherwise.
        let offset = if join.is_right_join() { 3 } else { 0 };

        // Case 1: require only output column 0, which the join condition never reads
        // (it only reads column 1 of each input). This must not panic.
        let mut first_col_only = FixedBitSet::with_capacity(3);
        first_col_only.insert(0);
        let pruned = join.prune_col(&first_col_only);
        // The join has to be wrapped in a projection exposing the single column.
        let project = pruned.as_logical_project().unwrap();
        assert_eq!(project.schema().fields().len(), 1);
        assert_eq!(project.schema().fields()[0], fields[offset]);

        // Case 2: require every output column. This must not panic either.
        let mut every_col = FixedBitSet::with_capacity(3);
        every_col.extend(0..3);
        let pruned = join.prune_col(&every_col);
        // Nothing can be dropped, so the join is returned without a projection.
        let kept_join = pruned.as_logical_join().unwrap();
        assert_eq!(kept_join.schema().fields().len(), 3);
        for i in 0..3 {
            assert_eq!(kept_join.schema().fields()[i], fields[offset + i]);
        }
    }
}

/// Pruning
/// ```text
/// Join(on: input_ref(1)=input_ref(3))
Expand Down
1 change: 0 additions & 1 deletion src/frontend/test_runner/tests/testdata/basic_query.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,3 @@
create table t (v1 int);
select coalesce(1,'a') from t;
binder_error: 'Bind error: Coalesce function cannot match types Int32 and Varchar'

Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,10 @@
LogicalFilter { predicate: ($2 = 100:Int32) AND (CorrelatedInputRef { index: 1, depth: 1 } = $1) AND ($1 = 1000:Int32) AND (CorrelatedInputRef { index: 2, depth: 1 } = $2) }
LogicalScan { table: t2, columns: [_row_id#0, x, y] }
optimized_logical_plan: |
LogicalProject { exprs: [$0, $1], expr_alias: [x, y] }
LogicalJoin { type: LeftSemi, on: ($0 = $2) AND ($1 = $3) }
LogicalScan { table: t1, columns: [x, y] }
LogicalFilter { predicate: ($1 = 100:Int32) AND ($0 = 1000:Int32) }
LogicalScan { table: t2, columns: [x, y] }
LogicalJoin { type: LeftSemi, on: ($0 = $2) AND ($1 = $3) }
LogicalScan { table: t1, columns: [x, y] }
LogicalFilter { predicate: ($1 = 100:Int32) AND ($0 = 1000:Int32) }
LogicalScan { table: t2, columns: [x, y] }
- sql: |
create table t1(x int, y int);
create table t2(x int, y int);
Expand Down
32 changes: 16 additions & 16 deletions src/frontend/test_runner/tests/testdata/tpch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@
StreamMaterialize { columns: [o_orderpriority, order_count], pk_columns: [o_orderpriority] }
StreamProject { exprs: [$0, $2], expr_alias: [o_orderpriority, order_count] }
StreamHashAgg { group_keys: [$0], aggs: [count, count] }
StreamProject { exprs: [$1, $2, $4], expr_alias: [ , , ] }
StreamProject { exprs: [$1, $2], expr_alias: [ , ] }
StreamExchange { dist: HashShard([1]) }
StreamHashJoin { type: LeftSemi, predicate: $0 = $3 }
StreamProject { exprs: [$0, $2, $3], expr_alias: [ , , ] }
Expand Down Expand Up @@ -1466,7 +1466,7 @@
StreamExchange { dist: Single }
StreamProject { exprs: [$0, $1, $2, $3, $4, $6], expr_alias: [c_name, c_custkey, o_orderkey, o_orderdate, o_totalprice, quantity] }
StreamHashAgg { group_keys: [$0, $1, $2, $3, $4], aggs: [count, sum($5)] }
StreamProject { exprs: [$1, $0, $2, $4, $3, $5, $6, $7, $8, $9], expr_alias: [ , , , , , , , , , ] }
StreamProject { exprs: [$1, $0, $2, $4, $3, $5, $6, $7, $8], expr_alias: [ , , , , , , , , ] }
StreamExchange { dist: HashShard([1, 0, 2, 4, 3]) }
StreamHashJoin { type: LeftSemi, predicate: $2 = $9 }
StreamProject { exprs: [$0, $1, $2, $3, $4, $8, $5, $6, $9], expr_alias: [ , , , , , , , , ] }
Expand Down Expand Up @@ -1611,11 +1611,11 @@
BatchProject { exprs: [$0], expr_alias: [ps_suppkey] }
BatchExchange { order: [], dist: HashShard([0]) }
BatchFilter { predicate: ($1 > (0.5:Decimal * $2)) }
BatchProject { exprs: [$2, $3, $7], expr_alias: [ , , ] }
BatchHashAgg { group_keys: [$0, $1, $2, $3, $4, $5, $6], aggs: [sum($7)] }
BatchProject { exprs: [$0, $1, $2, $3, $4, $5, $6, $7], expr_alias: [ , , , , , , , ] }
BatchExchange { order: [], dist: HashShard([0, 1, 2, 3, 4, 5, 6]) }
BatchHashJoin { type: LeftOuter, predicate: $1 = $8AND $2 = $9 }
BatchProject { exprs: [$2, $3, $6], expr_alias: [ , , ] }
BatchHashAgg { group_keys: [$0, $1, $2, $3, $4, $5], aggs: [sum($6)] }
BatchProject { exprs: [$0, $1, $2, $3, $4, $5, $6], expr_alias: [ , , , , , , ] }
BatchExchange { order: [], dist: HashShard([0, 1, 2, 3, 4, 5]) }
BatchHashJoin { type: LeftOuter, predicate: $1 = $7AND $2 = $8 }
BatchExchange { order: [], dist: HashShard([1, 2]) }
BatchHashJoin { type: LeftSemi, predicate: $1 = $6 }
BatchExchange { order: [], dist: HashShard([1]) }
Expand All @@ -1629,9 +1629,9 @@
BatchFilter { predicate: ($3 >= '1994-01-01':Varchar::Date) AND ($3 < ('1994-01-01':Varchar::Date + '1 year 00:00:00':Interval)) }
BatchScan { table: lineitem, columns: [l_partkey, l_suppkey, l_quantity, l_shipdate] }
stream_plan: |
StreamMaterialize { columns: [s_name, s_address, _row_id#0(hidden), _row_id#1(hidden), _row_id#2(hidden), ps_partkey(hidden), ps_suppkey(hidden), ps_availqty(hidden), ps_supplycost(hidden), ps_comment(hidden), p_partkey(hidden)], pk_columns: [_row_id#0, _row_id#1, _row_id#2, ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment, p_partkey], order_descs: [s_name, _row_id#0, _row_id#1, _row_id#2, ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment, p_partkey] }
StreamExchange { dist: HashShard([2, 3, 4, 5, 6, 7, 8, 9, 10]) }
StreamProject { exprs: [$1, $2, $3, $4, $6, $7, $5, $8, $9, $10, $11], expr_alias: [s_name, s_address, , , , , , , , , ] }
StreamMaterialize { columns: [s_name, s_address, _row_id#0(hidden), _row_id#1(hidden)], pk_columns: [_row_id#0, _row_id#1], order_descs: [s_name, _row_id#0, _row_id#1] }
StreamExchange { dist: HashShard([2, 3]) }
StreamProject { exprs: [$1, $2, $3, $4], expr_alias: [s_name, s_address, , ] }
StreamHashJoin { type: LeftSemi, predicate: $0 = $5 }
StreamProject { exprs: [$0, $1, $2, $4, $6], expr_alias: [ , , , , ] }
StreamExchange { dist: HashShard([0]) }
Expand All @@ -1642,13 +1642,13 @@
StreamExchange { dist: HashShard([0]) }
StreamFilter { predicate: ($1 = 'KENYA':Varchar) }
StreamTableScan { table: nation, columns: [n_nationkey, n_name, _row_id#0], pk_indices: [2] }
StreamProject { exprs: [$0, $3, $4, $1, $5, $6, $7], expr_alias: [ps_suppkey, , , , , , ] }
StreamProject { exprs: [$0, $3, $4, $1, $5, $6], expr_alias: [ps_suppkey, , , , , ] }
StreamFilter { predicate: ($1 > (0.5:Decimal * $2)) }
StreamProject { exprs: [$2, $3, $8, $0, $1, $4, $5, $6], expr_alias: [ , , , , , , , ] }
StreamHashAgg { group_keys: [$0, $1, $2, $3, $4, $5, $6], aggs: [count, sum($7)] }
StreamProject { exprs: [$0, $1, $2, $3, $4, $5, $6, $8, $7, $11], expr_alias: [ , , , , , , , , , ] }
StreamExchange { dist: HashShard([0, 1, 2, 3, 4, 5, 6]) }
StreamHashJoin { type: LeftOuter, predicate: $1 = $9AND $2 = $10 }
StreamProject { exprs: [$2, $3, $7, $0, $1, $4, $5], expr_alias: [ , , , , , , ] }
StreamHashAgg { group_keys: [$0, $1, $2, $3, $4, $5], aggs: [count, sum($6)] }
StreamProject { exprs: [$0, $1, $2, $3, $4, $5, $6, $9], expr_alias: [ , , , , , , , ] }
StreamExchange { dist: HashShard([0, 1, 2, 3, 4, 5]) }
StreamHashJoin { type: LeftOuter, predicate: $1 = $7AND $2 = $8 }
StreamExchange { dist: HashShard([1, 2]) }
StreamHashJoin { type: LeftSemi, predicate: $1 = $6 }
StreamExchange { dist: HashShard([1]) }
Expand Down

0 comments on commit 27fc53b

Please sign in to comment.