Skip to content

Commit

Permalink
split expr type and null info to be expr-schemable
Browse files Browse the repository at this point in the history
  • Loading branch information
jimexist committed Feb 8, 2022
1 parent 86dcb09 commit fb71054
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 155 deletions.
1 change: 1 addition & 0 deletions datafusion/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use crate::datasource::{
MemTable, TableProvider,
};
use crate::error::{DataFusionError, Result};
use crate::logical_plan::expr_schema::ExprSchemable;
use crate::logical_plan::plan::{
Aggregate, Analyze, EmptyRelation, Explain, Filter, Join, Projection, Sort,
TableScan, ToStringifiedPlan, Union, Window,
Expand Down
152 changes: 2 additions & 150 deletions datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,10 @@
pub use super::Operator;
use crate::error::{DataFusionError, Result};
use crate::field_util::get_indexed_field;
use crate::logical_plan::ExprSchemable;
use crate::logical_plan::{window_frames, DFField, DFSchema};
use crate::physical_plan::functions::Volatility;
use crate::physical_plan::{
aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF,
window_functions,
};
use crate::physical_plan::{aggregates, functions, udf::ScalarUDF, window_functions};
use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue};
use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction};
use arrow::{compute::can_cast_types, datatypes::DataType};
Expand Down Expand Up @@ -251,151 +248,6 @@ impl PartialOrd for Expr {
}

impl Expr {
/// Returns the [arrow::datatypes::DataType] of the expression
/// based on [ExprSchema]
///
/// Note: [DFSchema] implements [ExprSchema].
///
/// # Errors
///
/// This function errors when it is not possible to compute its
/// [arrow::datatypes::DataType]. This happens when e.g. the
/// expression refers to a column that does not exist in the
/// schema, when the expression is incorrectly typed
/// (e.g. `[utf8] + [bool]`), when it is a wildcard, or when a `CASE`
/// expression has no `WHEN ... THEN` branch to take its type from.
pub fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType> {
    match self {
        // Wrappers that do not change the type of the inner expression.
        Expr::Alias(expr, _) | Expr::Sort { expr, .. } | Expr::Negative(expr) => {
            expr.get_type(schema)
        }
        Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
        Expr::ScalarVariable(_) => Ok(DataType::Utf8),
        Expr::Literal(l) => Ok(l.get_datatype()),
        // The type of a CASE is taken from its first THEN expression.
        // Use `first()` rather than indexing so that a malformed CASE
        // without any branch yields an error instead of a panic.
        Expr::Case { when_then_expr, .. } => match when_then_expr.first() {
            Some((_, then_expr)) => then_expr.get_type(schema),
            None => Err(DataFusionError::Internal(
                "CASE expression must have at least one WHEN ... THEN branch"
                    .to_owned(),
            )),
        },
        Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => {
            Ok(data_type.clone())
        }
        Expr::ScalarUDF { fun, args } => {
            let data_types = args
                .iter()
                .map(|e| e.get_type(schema))
                .collect::<Result<Vec<_>>>()?;
            Ok((fun.return_type)(&data_types)?.as_ref().clone())
        }
        Expr::ScalarFunction { fun, args } => {
            let data_types = args
                .iter()
                .map(|e| e.get_type(schema))
                .collect::<Result<Vec<_>>>()?;
            functions::return_type(fun, &data_types)
        }
        Expr::WindowFunction { fun, args, .. } => {
            let data_types = args
                .iter()
                .map(|e| e.get_type(schema))
                .collect::<Result<Vec<_>>>()?;
            window_functions::return_type(fun, &data_types)
        }
        Expr::AggregateFunction { fun, args, .. } => {
            let data_types = args
                .iter()
                .map(|e| e.get_type(schema))
                .collect::<Result<Vec<_>>>()?;
            aggregates::return_type(fun, &data_types)
        }
        Expr::AggregateUDF { fun, args, .. } => {
            let data_types = args
                .iter()
                .map(|e| e.get_type(schema))
                .collect::<Result<Vec<_>>>()?;
            Ok((fun.return_type)(&data_types)?.as_ref().clone())
        }
        // Predicates always evaluate to a boolean.
        Expr::Not(_)
        | Expr::IsNull(_)
        | Expr::Between { .. }
        | Expr::InList { .. }
        | Expr::IsNotNull(_) => Ok(DataType::Boolean),
        Expr::BinaryExpr { left, right, op } => binary_operator_data_type(
            &left.get_type(schema)?,
            op,
            &right.get_type(schema)?,
        ),
        Expr::Wildcard => Err(DataFusionError::Internal(
            "Wildcard expressions are not valid in a logical query plan".to_owned(),
        )),
        Expr::GetIndexedField { expr, key } => {
            let data_type = expr.get_type(schema)?;

            get_indexed_field(&data_type, key).map(|x| x.data_type().clone())
        }
    }
}

/// Returns the nullability of the expression based on [ExprSchema].
///
/// Note: [DFSchema] implements [ExprSchema].
///
/// # Errors
///
/// This function errors when it is not possible to compute its
/// nullability. This happens when the expression refers to a
/// column that does not exist in the schema, or when the expression
/// is a wildcard.
pub fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool> {
    match self {
        // These variants report the nullability of their primary input.
        Expr::Alias(expr, _)
        | Expr::Not(expr)
        | Expr::Negative(expr)
        | Expr::Sort { expr, .. }
        | Expr::Between { expr, .. }
        | Expr::InList { expr, .. } => expr.nullable(input_schema),
        Expr::Column(c) => input_schema.nullable(c),
        Expr::Literal(value) => Ok(value.is_null()),
        Expr::Case {
            when_then_expr,
            else_expr,
            ..
        } => {
            // this expression is nullable if any of the input expressions are nullable
            let then_nullable = when_then_expr
                .iter()
                .map(|(_, t)| t.nullable(input_schema))
                .collect::<Result<Vec<_>>>()?;
            if then_nullable.contains(&true) {
                Ok(true)
            } else if let Some(e) = else_expr {
                e.nullable(input_schema)
            } else {
                // Without an ELSE branch, CASE evaluates to NULL whenever
                // none of the WHEN conditions match, so it is nullable.
                Ok(true)
            }
        }
        Expr::Cast { expr, .. } => expr.nullable(input_schema),
        // Conservatively treated as nullable.
        Expr::ScalarVariable(_)
        | Expr::TryCast { .. }
        | Expr::ScalarFunction { .. }
        | Expr::ScalarUDF { .. }
        | Expr::WindowFunction { .. }
        | Expr::AggregateFunction { .. }
        | Expr::AggregateUDF { .. } => Ok(true),
        Expr::IsNull(_) | Expr::IsNotNull(_) => Ok(false),
        Expr::BinaryExpr { left, right, .. } => {
            Ok(left.nullable(input_schema)? || right.nullable(input_schema)?)
        }
        Expr::Wildcard => Err(DataFusionError::Internal(
            "Wildcard expressions are not valid in a logical query plan".to_owned(),
        )),
        Expr::GetIndexedField { expr, key } => {
            let data_type = expr.get_type(input_schema)?;
            get_indexed_field(&data_type, key).map(|x| x.is_nullable())
        }
    }
}

/// Returns the name of this expression based on [crate::logical_plan::DFSchema].
///
/// This represents how a column with this expression is named when no alias is chosen
Expand Down
180 changes: 180 additions & 0 deletions datafusion/src/logical_plan/expr_schema.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use super::Expr;
use crate::field_util::get_indexed_field;
use crate::physical_plan::{
aggregates, expressions::binary_operator_data_type, functions, window_functions,
};
use arrow::datatypes::DataType;
use datafusion_common::{DataFusionError, ExprSchema, Result};

/// Trait that allows an expression to be typed with respect to a schema:
/// given an [ExprSchema], it reports the expression's data type and
/// nullability.
pub trait ExprSchemable {
/// Given a schema, returns the [DataType] of the expression, or an
/// error if the type cannot be computed.
fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType>;

/// Given a schema, returns whether the expression is nullable, or an
/// error if the nullability cannot be computed.
fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool>;
}

impl ExprSchemable for Expr {
/// Returns the [arrow::datatypes::DataType] of the expression
/// based on [ExprSchema]
///
/// Note: [DFSchema] implements [ExprSchema].
///
/// # Errors
///
/// This function errors when it is not possible to compute its
/// [arrow::datatypes::DataType]. This happens when e.g. the
/// expression refers to a column that does not exist in the
/// schema, or when the expression is incorrectly typed
/// (e.g. `[utf8] + [bool]`).
fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType> {
match self {
Expr::Alias(expr, _) | Expr::Sort { expr, .. } | Expr::Negative(expr) => {
expr.get_type(schema)
}
Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
Expr::ScalarVariable(_) => Ok(DataType::Utf8),
Expr::Literal(l) => Ok(l.get_datatype()),
Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema),
Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => {
Ok(data_type.clone())
}
Expr::ScalarUDF { fun, args } => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
Ok((fun.return_type)(&data_types)?.as_ref().clone())
}
Expr::ScalarFunction { fun, args } => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
functions::return_type(fun, &data_types)
}
Expr::WindowFunction { fun, args, .. } => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
window_functions::return_type(fun, &data_types)
}
Expr::AggregateFunction { fun, args, .. } => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
aggregates::return_type(fun, &data_types)
}
Expr::AggregateUDF { fun, args, .. } => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
Ok((fun.return_type)(&data_types)?.as_ref().clone())
}
Expr::Not(_)
| Expr::IsNull(_)
| Expr::Between { .. }
| Expr::InList { .. }
| Expr::IsNotNull(_) => Ok(DataType::Boolean),
Expr::BinaryExpr {
ref left,
ref right,
ref op,
} => binary_operator_data_type(
&left.get_type(schema)?,
op,
&right.get_type(schema)?,
),
Expr::Wildcard => Err(DataFusionError::Internal(
"Wildcard expressions are not valid in a logical query plan".to_owned(),
)),
Expr::GetIndexedField { ref expr, key } => {
let data_type = expr.get_type(schema)?;

get_indexed_field(&data_type, key).map(|x| x.data_type().clone())
}
}
}

/// Returns the nullability of the expression based on [ExprSchema].
///
/// Note: [DFSchema] implements [ExprSchema].
///
/// # Errors
///
/// This function errors when it is not possible to compute its
/// nullability. This happens when the expression refers to a
/// column that does not exist in the schema.
fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool> {
match self {
Expr::Alias(expr, _)
| Expr::Not(expr)
| Expr::Negative(expr)
| Expr::Sort { expr, .. }
| Expr::Between { expr, .. }
| Expr::InList { expr, .. } => expr.nullable(input_schema),
Expr::Column(c) => input_schema.nullable(c),
Expr::Literal(value) => Ok(value.is_null()),
Expr::Case {
when_then_expr,
else_expr,
..
} => {
// this expression is nullable if any of the input expressions are nullable
let then_nullable = when_then_expr
.iter()
.map(|(_, t)| t.nullable(input_schema))
.collect::<Result<Vec<_>>>()?;
if then_nullable.contains(&true) {
Ok(true)
} else if let Some(e) = else_expr {
e.nullable(input_schema)
} else {
Ok(false)
}
}
Expr::Cast { expr, .. } => expr.nullable(input_schema),
Expr::ScalarVariable(_)
| Expr::TryCast { .. }
| Expr::ScalarFunction { .. }
| Expr::ScalarUDF { .. }
| Expr::WindowFunction { .. }
| Expr::AggregateFunction { .. }
| Expr::AggregateUDF { .. } => Ok(true),
Expr::IsNull(_) | Expr::IsNotNull(_) => Ok(false),
Expr::BinaryExpr {
ref left,
ref right,
..
} => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?),
Expr::Wildcard => Err(DataFusionError::Internal(
"Wildcard expressions are not valid in a logical query plan".to_owned(),
)),
Expr::GetIndexedField { ref expr, key } => {
let data_type = expr.get_type(input_schema)?;
get_indexed_field(&data_type, key).map(|x| x.is_nullable())
}
}
}
}
2 changes: 2 additions & 0 deletions datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ mod dfschema;
mod display;
mod expr;
mod expr_rewriter;
mod expr_schema;
mod expr_simplier;
mod expr_visitor;
mod extension;
Expand Down Expand Up @@ -54,6 +55,7 @@ pub use expr_rewriter::{
normalize_col, normalize_cols, replace_col, rewrite_sort_cols_by_aggs,
unnormalize_col, unnormalize_cols, ExprRewritable, ExprRewriter, RewriteRecursion,
};
pub use expr_schema::ExprSchemable;
pub use expr_simplier::{ExprSimplifiable, SimplifyInfo};
pub use expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion};
pub use extension::UserDefinedLogicalNode;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/optimizer/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::logical_plan::plan::{Filter, Projection, Window};
use crate::logical_plan::{
col,
plan::{Aggregate, Sort},
DFField, DFSchema, Expr, ExprRewritable, ExprRewriter, ExprVisitable,
DFField, DFSchema, Expr, ExprRewritable, ExprRewriter, ExprSchemable, ExprVisitable,
ExpressionVisitor, LogicalPlan, Recursion, RewriteRecursion,
};
use crate::optimizer::optimizer::OptimizerRule;
Expand Down
Loading

0 comments on commit fb71054

Please sign in to comment.