Skip to content

Commit

Permalink
fix(expr, executor): Remove RowExpression: use eval_row and `unwr…
Browse files Browse the repository at this point in the history
…ap_or(false)` on `Datum`, not panic on Null/None (#3587)

* eval_row and unwrap_or(false), not panic

* fix

* fmt

* fmt

* fmt

* remove `RowExpression`

* fmt

* fix
  • Loading branch information
jon-chuang authored Jul 6, 2022
1 parent 8757d90 commit 3850d5e
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 46 deletions.
22 changes: 0 additions & 22 deletions src/expr/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ pub mod expr_unary;
mod template;

use std::convert::TryFrom;
use std::slice;
use std::sync::Arc;

pub use agg::AggKind;
Expand Down Expand Up @@ -125,26 +124,5 @@ pub fn build_from_prost(prost: &ExprNode) -> Result<BoxedExpression> {
}
}

/// Simply wrap a row level expression as an array level expression
#[derive(Debug)]
pub struct RowExpression {
expr: BoxedExpression,
}

impl RowExpression {
pub fn new(expr: BoxedExpression) -> Self {
Self { expr }
}

pub fn eval(&mut self, row: &Row, data_types: &[DataType]) -> Result<ArrayRef> {
let input = DataChunk::from_rows(slice::from_ref(row), data_types)?;
self.expr.eval_checked(&input)
}

pub fn return_type(&self) -> DataType {
self.expr.return_type()
}
}

mod test_utils;
pub use test_utils::*;
34 changes: 13 additions & 21 deletions src/stream/src/executor/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ use futures::StreamExt;
use futures_async_stream::try_stream;
use itertools::Itertools;
use madsim::collections::HashSet;
use risingwave_common::array::{Array, ArrayRef, Op, Row, RowRef, StreamChunk};
use risingwave_common::array::{Op, Row, RowRef, StreamChunk};
use risingwave_common::bail;
use risingwave_common::catalog::{Schema, TableId};
use risingwave_common::hash::HashKey;
use risingwave_common::types::{DataType, ToOwnedDatum};
use risingwave_expr::expr::RowExpression;
use risingwave_expr::expr::BoxedExpression;
use risingwave_storage::StateStore;

use super::barrier_align::*;
Expand Down Expand Up @@ -182,7 +182,7 @@ pub struct HashJoinExecutor<K: HashKey, S: StateStore, const T: JoinTypePrimitiv
/// The parameters of the right join executor
side_r: JoinSide<K, S>,
/// Optional non-equi join conditions
cond: Option<RowExpression>,
cond: Option<BoxedExpression>,
/// Identity string
identity: String,
/// Epoch
Expand Down Expand Up @@ -380,7 +380,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
output_indices: Vec<usize>,
actor_id: u64,
executor_id: u64,
cond: Option<RowExpression>,
cond: Option<BoxedExpression>,
op_info: String,
store_l: S,
table_id_l: TableId,
Expand Down Expand Up @@ -603,22 +603,12 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
Row(new_row)
}

fn bool_from_array_ref(array_ref: ArrayRef) -> bool {
let bool_array = array_ref.as_ref().as_bool();
bool_array.value_at(0).unwrap_or_else(|| {
panic!(
"Some thing wrong with the expression result. Bool array: {:?}",
bool_array
)
})
}

#[try_stream(ok = Message, error = StreamExecutorError)]
async fn eq_join_oneside<'a, const SIDE: SideTypePrimitive>(
mut side_l: &'a mut JoinSide<K, S>,
mut side_r: &'a mut JoinSide<K, S>,
output_data_types: &'a [DataType],
cond: &'a mut Option<RowExpression>,
cond: &'a mut Option<BoxedExpression>,
chunk: StreamChunk,
append_only_optimize: bool,
) {
Expand Down Expand Up @@ -658,7 +648,10 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
let new_row =
Self::row_concat(row_update, update_start_pos, row_matched, matched_start_pos);

cond_match = Self::bool_from_array_ref(cond.eval(&new_row, output_data_types)?);
cond_match = cond
.eval_row(&new_row)?
.map(|s| *s.as_bool())
.unwrap_or(false);
}
Ok(cond_match)
};
Expand Down Expand Up @@ -784,24 +777,23 @@ mod tests {
use risingwave_common::catalog::{Field, Schema, TableId};
use risingwave_common::hash::{Key128, Key64};
use risingwave_expr::expr::expr_binary_nonnull::new_binary_expr;
use risingwave_expr::expr::{InputRefExpression, RowExpression};
use risingwave_expr::expr::InputRefExpression;
use risingwave_pb::expr::expr_node::Type;
use risingwave_storage::memory::MemoryStateStore;

use super::*;
use crate::executor::test_utils::{MessageSender, MockSource};
use crate::executor::{Barrier, Epoch, Message};

fn create_cond() -> RowExpression {
fn create_cond() -> BoxedExpression {
let left_expr = InputRefExpression::new(DataType::Int64, 1);
let right_expr = InputRefExpression::new(DataType::Int64, 3);
let cond = new_binary_expr(
new_binary_expr(
Type::LessThan,
DataType::Boolean,
Box::new(left_expr),
Box::new(right_expr),
);
RowExpression::new(cond)
)
}

fn create_executor<const T: JoinTypePrimitive>(
Expand Down
6 changes: 3 additions & 3 deletions src/stream/src/from_proto/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::sync::Arc;

use risingwave_common::catalog::TableId;
use risingwave_common::hash::{calc_hash_key_kind, HashKey, HashKeyDispatcher, HashKeyKind};
use risingwave_expr::expr::{build_from_prost, RowExpression};
use risingwave_expr::expr::{build_from_prost, BoxedExpression};
use risingwave_pb::plan_common::JoinType as JoinTypeProto;

use super::*;
Expand Down Expand Up @@ -66,7 +66,7 @@ impl ExecutorBuilder for HashJoinExecutorBuilder {
.collect_vec();

let condition = match node.get_condition() {
Ok(cond_prost) => Some(RowExpression::new(build_from_prost(cond_prost)?)),
Ok(cond_prost) => Some(build_from_prost(cond_prost)?),
Err(_) => None,
};
trace!("Join non-equi condition: {:?}", condition);
Expand Down Expand Up @@ -146,7 +146,7 @@ struct HashJoinExecutorDispatcherArgs<S: StateStore> {
pk_indices: PkIndices,
output_indices: Vec<usize>,
executor_id: u64,
cond: Option<RowExpression>,
cond: Option<BoxedExpression>,
op_info: String,
store_l: S,
left_table_id: TableId,
Expand Down

0 comments on commit 3850d5e

Please sign in to comment.