Skip to content

Commit

Permalink
feat: support ROUND, NOT LIKE and EXTRACT (#1498)
Browse files Browse the repository at this point in the history
  • Loading branch information
neverchanje authored Apr 1, 2022
1 parent cb5ec02 commit b128de3
Show file tree
Hide file tree
Showing 10 changed files with 133 additions and 29 deletions.
5 changes: 5 additions & 0 deletions e2e_test/v2/batch/decimal.slt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,8 @@ select round(v1, 2), round(v2, 1), round(v1, -1) from t

statement ok
drop table t

query T
values(round(42.4382));
----
42
4 changes: 4 additions & 0 deletions e2e_test/v2/batch/time.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
query T
values(extract(hour from timestamp '2001-02-16 20:38:40'));
----
20
18 changes: 18 additions & 0 deletions e2e_test/v2/tpch/tpch.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
include ../../tpch/create_tables.slt.part

include ../../tpch/insert_customer.slt.part
include ../../tpch/insert_lineitem.slt.part
include ../../tpch/insert_nation.slt.part
include ../../tpch/insert_orders.slt.part
include ../../tpch/insert_part.slt.part
include ../../tpch/insert_partsupp.slt.part
include ../../tpch/insert_supplier.slt.part
include ../../tpch/insert_region.slt.part

include ../../batch/tpch/q1.slt.part
include ../../batch/tpch/q5.slt.part
include ../../batch/tpch/q6.slt.part
include ../../batch/tpch/q7.slt.part
include ../../batch/tpch/q13.slt.part

include ../../tpch/drop_tables.slt.part
30 changes: 26 additions & 4 deletions rust/frontend/src/binder/expr/binary_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::error::{ErrorCode, Result, RwError};
use risingwave_sqlparser::ast::{BinaryOperator, Expr};

use crate::binder::Binder;
use crate::expr::{Expr as _, ExprType, FunctionCall};
use crate::expr::{Expr as _, ExprImpl, ExprType, FunctionCall};

impl Binder {
pub(super) fn bind_binary_op(
Expand All @@ -42,15 +42,37 @@ impl Binder {
BinaryOperator::And => ExprType::And,
BinaryOperator::Or => ExprType::Or,
BinaryOperator::Like => ExprType::Like,
BinaryOperator::NotLike => return self.bind_not_like(bound_left, bound_right),
_ => return Err(ErrorCode::NotImplementedError(format!("{:?}", op)).into()),
};
FunctionCall::new_or_else(func_type, vec![bound_left, bound_right], |inputs| {
Self::err_unsupported_binary_op(op, inputs)
})
}

/// Apply a NOT on top of LIKE.
fn bind_not_like(&mut self, left: ExprImpl, right: ExprImpl) -> Result<FunctionCall> {
Ok(FunctionCall::new(
ExprType::Not,
vec![
FunctionCall::new_or_else(ExprType::Like, vec![left, right], |inputs| {
Self::err_unsupported_binary_op(BinaryOperator::NotLike, inputs)
})?
.into(),
],
)
.unwrap())
}

fn err_unsupported_binary_op(op: BinaryOperator, inputs: &[ExprImpl]) -> RwError {
let bound_left = inputs.get(0).unwrap();
let bound_right = inputs.get(1).unwrap();
let desc = format!(
"{:?} {:?} {:?}",
bound_left.return_type(),
op,
bound_right.return_type(),
);
FunctionCall::new(func_type, vec![bound_left, bound_right])
.ok_or_else(|| ErrorCode::NotImplementedError(desc).into())
ErrorCode::NotImplementedError(desc).into()
}
}
32 changes: 28 additions & 4 deletions rust/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
// limitations under the License.

use itertools::Itertools;
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::error::{ErrorCode, Result, RwError};
use risingwave_common::expr::AggKind;
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::{Function, FunctionArg, FunctionArgExpr};

use crate::binder::bind_context::Clause;
use crate::binder::Binder;
use crate::expr::{AggCall, ExprImpl, ExprType, FunctionCall};
use crate::expr::{AggCall, Expr, ExprImpl, ExprType, FunctionCall, Literal};

impl Binder {
pub(super) fn bind_function(&mut self, f: Function) -> Result<ExprImpl> {
Expand Down Expand Up @@ -75,7 +75,10 @@ impl Binder {
.into())
}
};
Ok(FunctionCall::try_new(function_name, function_type, inputs)?.into())
Ok(FunctionCall::new_or_else(function_type, inputs, |args| {
Self::err_unsupported_func(function_name, args)
})?
.into())
} else {
Err(
ErrorCode::NotImplementedError(format!("unsupported function: {:?}", f.name))
Expand All @@ -84,10 +87,31 @@ impl Binder {
}
}

fn err_unsupported_func(function_name: &str, inputs: &[ExprImpl]) -> RwError {
let args = inputs
.iter()
.map(|i| format!("{:?}", i.return_type()))
.join(",");
ErrorCode::NotImplementedError(format!(
"function {}({}) doesn't exist",
function_name, args
))
.into()
}

/// Rewrite the arguments to be consistent with the `round` signature:
/// - round(Decimal, Int32) -> Decimal
/// - round(Decimal) -> Decimal
fn rewrite_round_args(mut inputs: Vec<ExprImpl>) -> Vec<ExprImpl> {
if inputs.len() == 2 {
if inputs.len() == 1 {
// Rewrite round(Decimal) to round(Decimal, 0).
let input = inputs.pop().unwrap();
if input.return_type() == DataType::Decimal {
vec![input, Literal::new(Some(0.into()), DataType::Int32).into()]
} else {
vec![input]
}
} else if inputs.len() == 2 {
let digits = inputs.pop().unwrap();
let input = inputs.pop().unwrap();
vec![
Expand Down
22 changes: 21 additions & 1 deletion rust/frontend/src/binder/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use itertools::zip_eq;
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::{
BinaryOperator, DataType as AstDataType, Expr, TrimWhereField, UnaryOperator,
BinaryOperator, DataType as AstDataType, DateTimeField, Expr, TrimWhereField, UnaryOperator,
};

use crate::binder::Binder;
Expand Down Expand Up @@ -92,12 +92,32 @@ impl Binder {
} => Ok(ExprImpl::FunctionCall(Box::new(
self.bind_between(*expr, negated, *low, *high)?,
))),
Expr::Extract { field, expr } => self.bind_extract(field, *expr),
_ => Err(
ErrorCode::NotImplementedError(format!("unsupported expression {:?}", expr)).into(),
),
}
}

pub(super) fn bind_extract(&mut self, field: DateTimeField, expr: Expr) -> Result<ExprImpl> {
Ok(FunctionCall::new_or_else(
ExprType::Extract,
vec![
self.bind_string(field.to_string())?.into(),
self.bind_expr(expr)?,
],
|inputs| {
ErrorCode::NotImplementedError(format!(
"function extract({} from {:?}) doesn't exist",
field,
inputs[1].return_type()
))
.into()
},
)?
.into())
}

pub(super) fn bind_unary_expr(&mut self, op: UnaryOperator, expr: Expr) -> Result<ExprImpl> {
let func_type = match op {
UnaryOperator::Not => ExprType::Not,
Expand Down
24 changes: 6 additions & 18 deletions rust/frontend/src/expr/function_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use itertools::Itertools;
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::error::{Result, RwError};
use risingwave_common::types::DataType;

use super::{infer_type, Expr, ExprImpl};
Expand Down Expand Up @@ -87,26 +86,15 @@ impl std::fmt::Debug for FunctionCall {

impl FunctionCall {
/// Returns error if the function call is not valid.
pub fn try_new(
function_name: &str,
func_type: ExprType,
inputs: Vec<ExprImpl>,
) -> Result<Self> {
pub fn new_or_else<F>(func_type: ExprType, inputs: Vec<ExprImpl>, err_f: F) -> Result<Self>
where
F: FnOnce(&Vec<ExprImpl>) -> RwError,
{
infer_type(
func_type,
inputs.iter().map(|expr| expr.return_type()).collect(),
)
.ok_or_else(|| {
let args = inputs
.iter()
.map(|i| format!("{:?}", i.return_type()))
.join(",");
ErrorCode::NotImplementedError(format!(
"function {}({}) doesn't exist",
function_name, args
))
.into()
})
.ok_or_else(|| err_f(&inputs))
.map(|return_type| Self::new_with_return_type(func_type, inputs, return_type))
}

Expand Down
8 changes: 7 additions & 1 deletion rust/frontend/src/expr/type_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,13 @@ fn build_type_derive_map() -> HashMap<FuncSign, DataTypeName> {
&[T::Int32],
T::Decimal,
);

build_binary_funcs(
&mut map,
&[E::Extract],
&[T::Varchar], // Time field, "YEAR", "DAY", etc
&[T::Timestamp, T::Time, T::Date],
T::Decimal,
);
map
}
lazy_static::lazy_static! {
Expand Down
3 changes: 2 additions & 1 deletion rust/frontend/src/handler/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ lazy_static::lazy_static! {
/// If `RW_IMPLICIT_FLUSH` is on, then every INSERT/UPDATE/DELETE statement will block
/// until the entire dataflow is refreshed. In other words, every related table & MV will
/// be able to see the write.
/// TODO: Use session config to set this.
static ref IMPLICIT_FLUSH: bool =
std::env::var("RW_IMPLICIT_FLUSH").unwrap_or_else(|_| { "true".to_string() }).parse().unwrap();
std::env::var("RW_IMPLICIT_FLUSH").unwrap_or_else(|_| { "1".to_string() }) == "1";
}

pub async fn handle_query(context: OptimizerContext, stmt: Statement) -> Result<PgResponse> {
Expand Down
16 changes: 16 additions & 0 deletions rust/frontend/test_runner/tests/testdata/basic_query_1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,25 @@
values(round(42.4382, 2));
batch_plan: |
BatchValues { rows: [[RoundDigit(42.4382:Decimal, 2:Int32)]] }
- sql: |
values(round(42.4382));
batch_plan: |
BatchValues { rows: [[RoundDigit(42.4382:Decimal, 0:Int32)]] }
- sql: |
values(round('abc'));
binder_error: 'Feature is not yet implemented: function round(Varchar) doesn''t exist'
- sql: |
values('Postgres' not like 'Post%');
batch_plan: |
BatchValues { rows: [[Not(Like('Postgres':Varchar, 'Post%':Varchar))]] }
- sql: |
values(1 not like 1.23);
binder_error: |
Feature is not yet implemented: Int32 NotLike Decimal
- sql: |
values(extract(hour from timestamp '2001-02-16 20:38:40'));
batch_plan: |
BatchValues { rows: [[Extract('HOUR':Varchar, '2001-02-16 20:38:40':Varchar::Timestamp)]] }
- sql: |
create table t (v1 int);
select (case when v1=1 then 1 when v1=2 then 2 else 0.0 end) from t;
Expand Down

0 comments on commit b128de3

Please sign in to comment.