Skip to content

Commit

Permalink
feat: impl batch operators to pb (#986)
Browse files Browse the repository at this point in the history
  • Loading branch information
neverchanje authored Mar 17, 2022
1 parent d9eb1fc commit 95b9d6f
Show file tree
Hide file tree
Showing 20 changed files with 325 additions and 80 deletions.
4 changes: 2 additions & 2 deletions Makefile.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ set -e
for out_file in ${PREFIX_LOG}/*.log
do
if grep "panic" "$out_file" -C 10; then
echo "\"panic\" found in $out_file, please check if there's any bugs in this PR."
if grep "panicked at" "$out_file" -C 10; then
echo "\"panicked at\" found in $out_file, please check if there's any bugs in this PR."
echo "You may find \"risedev-logs\" artifacts and download logs after all workflows finish."
exit 1
fi
Expand Down
13 changes: 13 additions & 0 deletions e2e_test/v2/basic.slt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,16 @@ create table t (v1 int not null);

statement ok
drop table t;

query I
values(1);
----
1

query I
values(1+2*3)
----
7

statement error
values(CAST('abc' AS BOOLEAN))
19 changes: 9 additions & 10 deletions rust/batch/src/task/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,13 @@ pub(in crate) enum TaskState {
Failed,
}

impl TryFrom<&ProstTaskId> for TaskId {
type Error = RwError;
fn try_from(prost: &ProstTaskId) -> Result<Self> {
Ok(TaskId {
impl From<&ProstTaskId> for TaskId {
fn from(prost: &ProstTaskId) -> Self {
TaskId {
task_id: prost.task_id,
stage_id: prost.stage_id,
query_id: prost.query_id.clone(),
})
}
}
}

Expand All @@ -69,7 +68,7 @@ impl TryFrom<&ProstOutputId> for TaskOutputId {
type Error = RwError;
fn try_from(prost: &ProstOutputId) -> Result<Self> {
Ok(TaskOutputId {
task_id: TaskId::try_from(prost.get_task_id()?)?,
task_id: TaskId::from(prost.get_task_id()?),
output_id: prost.get_output_id(),
})
}
Expand Down Expand Up @@ -171,7 +170,7 @@ impl BatchTaskExecution {
epoch: u64,
) -> Result<Self> {
Ok(BatchTaskExecution {
task_id: TaskId::try_from(prost_tid)?,
task_id: TaskId::from(prost_tid),
plan,
state: Mutex::new(TaskStatus::Pending),
receivers: Mutex::new(Vec::new()),
Expand Down Expand Up @@ -252,7 +251,7 @@ impl BatchTaskExecution {
}

pub fn get_task_output(&self, output_id: &ProstOutputId) -> Result<TaskOutput> {
let task_id = TaskId::try_from(output_id.get_task_id()?)?;
let task_id = TaskId::from(output_id.get_task_id()?);
let receiver = self.receivers.lock().unwrap()[output_id.get_output_id() as usize]
.take()
.ok_or_else(|| {
Expand All @@ -270,8 +269,8 @@ impl BatchTaskExecution {
Ok(task_output)
}

pub fn get_error(&self) -> Result<Option<RwError>> {
Ok(self.failure.lock().unwrap().clone())
pub fn get_error(&self) -> Option<RwError> {
self.failure.lock().unwrap().clone()
}

pub fn check_if_running(&self) -> Result<()> {
Expand Down
66 changes: 53 additions & 13 deletions rust/batch/src/task/task_manager.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::HashMap;
use std::collections::{hash_map, HashMap};
use std::sync::{Arc, Mutex};

use risingwave_common::error::ErrorCode::TaskNotFound;
use risingwave_common::error::ErrorCode::{self, TaskNotFound};
use risingwave_common::error::{Result, RwError};
use risingwave_pb::plan::{PlanFragment, TaskId as ProstTaskId, TaskOutputId as ProstOutputId};

Expand Down Expand Up @@ -31,18 +31,23 @@ impl BatchManager {
epoch: u64,
) -> Result<()> {
let task = BatchTaskExecution::new(tid, plan, env, epoch)?;
let task_id = task.get_task_id().clone();

task.async_execute()?;
self.tasks
.lock()
.unwrap()
.entry(task.get_task_id().clone())
.or_insert_with(|| Box::new(task));
Ok(())
if let hash_map::Entry::Vacant(e) = self.tasks.lock().unwrap().entry(task_id.clone()) {
e.insert(Box::new(task));
Ok(())
} else {
Err(ErrorCode::InternalError(format!(
"can not create duplicate task with the same id: {:?}",
task_id,
))
.into())
}
}

pub fn take_output(&self, output_id: &ProstOutputId) -> Result<TaskOutput> {
let task_id = TaskId::try_from(output_id.get_task_id()?)?;
let task_id = TaskId::from(output_id.get_task_id()?);
self.tasks
.lock()
.unwrap()
Expand All @@ -53,13 +58,14 @@ impl BatchManager {

#[cfg(test)]
pub fn remove_task(&self, sid: &ProstTaskId) -> Result<Option<Box<BatchTaskExecution>>> {
let task_id = TaskId::try_from(sid)?;
let task_id = TaskId::from(sid);
match self.tasks.lock().unwrap().remove(&task_id) {
Some(t) => Ok(Some(t)),
None => Err(TaskNotFound.into()),
}
}

/// Returns error if task is not running.
pub fn check_if_task_running(&self, task_id: &TaskId) -> Result<()> {
match self.tasks.lock().unwrap().get(task_id) {
Some(task) => task.check_if_running(),
Expand All @@ -68,13 +74,13 @@ impl BatchManager {
}

pub fn get_error(&self, task_id: &TaskId) -> Result<Option<RwError>> {
return self
Ok(self
.tasks
.lock()
.unwrap()
.get(task_id)
.ok_or(TaskNotFound)?
.get_error();
.get_error())
}
}

Expand All @@ -86,10 +92,12 @@ impl Default for BatchManager {

#[cfg(test)]
mod tests {
use risingwave_pb::plan::exchange_info::DistributionMode;
use risingwave_pb::plan::plan_node::NodeBody;
use risingwave_pb::plan::TaskOutputId as ProstTaskOutputId;
use tonic::Code;

use crate::task::{BatchManager, TaskId};
use crate::task::{BatchEnvironment, BatchManager, TaskId};

#[test]
fn test_task_not_found() {
Expand Down Expand Up @@ -122,4 +130,36 @@ mod tests {
Ok(_) => unreachable!(),
};
}

#[tokio::test]
async fn test_task_id_conflict() {
use risingwave_pb::plan::*;

let manager = BatchManager::new();
let plan = PlanFragment {
root: Some(PlanNode {
children: vec![],
identity: "".to_string(),
node_body: Some(NodeBody::Values(ValuesNode {
tuples: vec![],
fields: vec![],
})),
}),
exchange_info: Some(ExchangeInfo {
mode: DistributionMode::Single as i32,
distribution: None,
}),
};
let env = BatchEnvironment::for_test();
let task_id = TaskId {
..Default::default()
};
manager
.fire_task(env.clone(), &task_id, plan.clone(), 0)
.unwrap();
let err = manager.fire_task(env, &task_id, plan, 0).unwrap_err();
assert!(err
.to_string()
.contains("can not create duplicate task with the same id"));
}
}
4 changes: 3 additions & 1 deletion rust/common/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,10 @@ impl<T, E: ToErrorStr> ToRwResult<T, E> for std::result::Result<T, E> {
}

impl ToErrorStr for tonic::Status {
/// [`tonic::Status`] means no transportation error but only application-level failure.
/// In this case we focus on the message rather than other fields.
fn to_error_str(self) -> String {
format!("grpc tonic error: {}", self)
self.message().to_string()
}
}

Expand Down
4 changes: 2 additions & 2 deletions rust/frontend/src/expr/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ pub trait ExprRewriter {
}
}
fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
let (func_type, inputs) = func_call.decompose();
let (func_type, inputs, ret) = func_call.decompose();
let inputs = inputs
.into_iter()
.map(|expr| self.rewrite_expr(expr))
.collect();
FunctionCall::new(func_type, inputs).unwrap().into()
FunctionCall::new_with_return_type(func_type, inputs, ret).into()
}
fn rewrite_agg_call(&mut self, agg_call: AggCall) -> ExprImpl {
let (func_type, inputs) = agg_call.decompose();
Expand Down
4 changes: 2 additions & 2 deletions rust/frontend/src/expr/function_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ impl FunctionCall {
}
}

pub fn decompose(self) -> (ExprType, Vec<ExprImpl>) {
(self.func_type, self.inputs)
pub fn decompose(self) -> (ExprType, Vec<ExprImpl>, DataType) {
(self.func_type, self.inputs, self.return_type)
}
pub fn decompose_as_binary(self) -> (ExprType, ExprImpl, ExprImpl) {
assert_eq!(self.inputs.len(), 2);
Expand Down
1 change: 1 addition & 0 deletions rust/frontend/src/expr/input_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ impl InputRef {
pub fn new(index: usize, data_type: DataType) -> Self {
InputRef { index, data_type }
}

/// Returns the expression type of this node, which is always
/// [`ExprType::InputRef`].
pub fn get_expr_type(&self) -> ExprType {
ExprType::InputRef
}
Expand Down
63 changes: 62 additions & 1 deletion rust/frontend/src/expr/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use risingwave_common::types::{DataType, Scalar};
use risingwave_common::types::{DataType, Datum, Scalar, ScalarImpl};
mod input_ref;
pub use input_ref::*;
mod literal;
Expand All @@ -18,6 +18,9 @@ mod expr_visitor;
pub use expr_visitor::*;
pub type ExprType = risingwave_pb::expr::expr_node::Type;

use risingwave_pb::expr::expr_node::RexNode;
use risingwave_pb::expr::{ConstantValue, InputRefExpr};

/// The trait of bound expressions.
pub trait Expr: Into<ExprImpl> {
fn return_type(&self) -> DataType;
Expand Down Expand Up @@ -60,6 +63,64 @@ impl ExprImpl {
DataType::Boolean,
)))
}

/// Serialize this expression into its protobuf [`ExprNode`] form.
///
/// Input references, literals and function calls are mapped to the matching
/// `RexNode`; the arguments of a function call are serialized recursively.
///
/// # Panics
///
/// Panics when called on an [`ExprImpl::AggCall`].
pub fn to_protobuf(&self) -> ExprNode {
    use risingwave_pb::expr::FunctionCall as ProstFunctionCall;

    // Each arm yields the three `ExprNode` fields, evaluated left to right.
    let (expr_type, return_type, rex_node) = match self {
        ExprImpl::InputRef(input_ref) => (
            input_ref.get_expr_type() as i32,
            Some(input_ref.return_type().to_protobuf()),
            Some(RexNode::InputRef(InputRefExpr {
                column_idx: input_ref.index() as i32,
            })),
        ),
        ExprImpl::Literal(literal) => (
            literal.get_expr_type() as i32,
            Some(literal.return_type().to_protobuf()),
            literal_to_protobuf(literal.get_data()),
        ),
        ExprImpl::FunctionCall(func_call) => (
            func_call.get_expr_type() as i32,
            Some(func_call.return_type().to_protobuf()),
            Some(RexNode::FuncCall(ProstFunctionCall {
                children: func_call
                    .inputs()
                    .iter()
                    .map(|input| input.to_protobuf())
                    .collect(),
            })),
        ),
        // Serialization happens at the physical planning step; by then every
        // `ExprImpl::AggCall` must already have been rewritten to aggregate
        // operators, so reaching this arm is a planner bug.
        ExprImpl::AggCall(agg_call) => {
            panic!(
                "AggCall {:?} has not been rewritten to physical aggregate operators",
                agg_call
            )
        }
    };

    ExprNode {
        expr_type,
        return_type,
        rex_node,
    }
}
}

/// Convert a literal value (datum) into its protobuf constant form.
///
/// Returns `None` when the datum itself is `None`. Otherwise the scalar is
/// encoded into the `body` bytes of a [`ConstantValue`]:
///
/// - `Int16` / `Int32` / `Int64` / `Float32` / `Float64`: big-endian bytes.
/// - `Utf8`: the bytes of the string.
/// - `Bool`: a single byte, the value cast to `i8`.
/// - `Decimal`: the bytes of its `to_string()` representation.
///
/// # Panics
///
/// Panics (via `todo!()`) for `Interval`, `NaiveDate`, `NaiveDateTime`,
/// `NaiveTime` and `Struct`, which are not serialized yet.
fn literal_to_protobuf(d: &Datum) -> Option<RexNode> {
    // `?` short-circuits with `None` for an absent datum, replacing the
    // former `is_none()` check followed by `unwrap()`.
    let body = match d.as_ref()? {
        ScalarImpl::Int16(v) => v.to_be_bytes().to_vec(),
        ScalarImpl::Int32(v) => v.to_be_bytes().to_vec(),
        ScalarImpl::Int64(v) => v.to_be_bytes().to_vec(),
        ScalarImpl::Float32(v) => v.to_be_bytes().to_vec(),
        ScalarImpl::Float64(v) => v.to_be_bytes().to_vec(),
        ScalarImpl::Utf8(s) => s.as_bytes().to_vec(),
        ScalarImpl::Bool(v) => (*v as i8).to_be_bytes().to_vec(),
        // `into_bytes` reuses the freshly allocated string buffer instead of
        // copying it a second time.
        ScalarImpl::Decimal(v) => v.to_string().into_bytes(),
        ScalarImpl::Interval(_) => todo!(),
        ScalarImpl::NaiveDate(_) => todo!(),
        ScalarImpl::NaiveDateTime(_) => todo!(),
        ScalarImpl::NaiveTime(_) => todo!(),
        ScalarImpl::Struct(_) => todo!(),
    };
    Some(RexNode::Constant(ConstantValue { body }))
}

impl Expr for ExprImpl {
Expand Down
2 changes: 1 addition & 1 deletion rust/frontend/src/expr/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::expr::ExprType;
fn to_conjunctions_inner(expr: ExprImpl, rets: &mut Vec<ExprImpl>) {
match expr {
ExprImpl::FunctionCall(func_call) if func_call.get_expr_type() == ExprType::And => {
let (_, exprs) = func_call.decompose();
let (_, exprs, _) = func_call.decompose();
for expr in exprs.into_iter() {
to_conjunctions_inner(expr, rets);
}
Expand Down
9 changes: 4 additions & 5 deletions rust/frontend/src/handler/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ use risingwave_common::error::ErrorCode::InternalError;
use risingwave_common::error::{ErrorCode, Result, RwError, ToRwResult};
use risingwave_pb::hummock::{HummockSnapshot, PinSnapshotRequest, UnpinSnapshotRequest};
use risingwave_pb::plan::{TaskId, TaskOutputId};
use risingwave_rpc_client::{ComputeClient, ExchangeSource, GrpcExchangeSource};
use risingwave_rpc_client::{ComputeClient, ExchangeSource};
use risingwave_sqlparser::ast::Statement;
use uuid::Uuid;

use crate::binder::Binder;
use crate::handler::util::{get_pg_field_descs, to_pg_rows};
Expand Down Expand Up @@ -48,7 +49,7 @@ pub async fn handle_query(context: QueryContext, stmt: Statement) -> Result<PgRe

// Build task id and task sink id
let task_id = TaskId {
query_id: "".to_string(),
query_id: Uuid::new_v4().to_string(),
stage_id: 0,
task_id: 0,
};
Expand All @@ -74,9 +75,7 @@ pub async fn handle_query(context: QueryContext, stmt: Statement) -> Result<PgRe
compute_client
.create_task(task_id.clone(), plan, epoch)
.await?;
let mut source =
GrpcExchangeSource::create_with_client(compute_client.clone(), task_sink_id.clone())
.await?;
let mut source = compute_client.get_data(task_sink_id.clone()).await?;
while let Some(chunk) = source.take_data().await? {
rows.append(&mut to_pg_rows(chunk));
}
Expand Down
Loading

0 comments on commit 95b9d6f

Please sign in to comment.