Skip to content

Commit

Permalink
API to get Expr's type and nullability without a DFSchema (#1726)
Browse files Browse the repository at this point in the history
* API to get Expr type and nullability without a `DFSchema`

* Add test

* publically export

* Improve docs
  • Loading branch information
alamb authored Feb 3, 2022
1 parent 78c30b6 commit b2eaee3
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 14 deletions.
123 changes: 111 additions & 12 deletions datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,20 +392,59 @@ impl PartialOrd for Expr {
}
}

/// Provides schema information needed by [Expr] methods such as
/// [Expr::nullable] and [Expr::data_type].
///
/// Note that this trait is implemented for &[DFSchema] which is
/// widely used in the DataFusion codebase.
pub trait ExprSchema {
/// Is this column reference nullable?
fn nullable(&self, col: &Column) -> Result<bool>;

/// What is the datatype of this column?
fn data_type(&self, col: &Column) -> Result<&DataType>;
}

// Implement `ExprSchema` for `Arc<DFSchema>`
impl<P: AsRef<DFSchema>> ExprSchema for P {
fn nullable(&self, col: &Column) -> Result<bool> {
self.as_ref().nullable(col)
}

fn data_type(&self, col: &Column) -> Result<&DataType> {
self.as_ref().data_type(col)
}
}

impl ExprSchema for DFSchema {
fn nullable(&self, col: &Column) -> Result<bool> {
Ok(self.field_from_column(col)?.is_nullable())
}

fn data_type(&self, col: &Column) -> Result<&DataType> {
Ok(self.field_from_column(col)?.data_type())
}
}

impl Expr {
/// Returns the [arrow::datatypes::DataType] of the expression based on [arrow::datatypes::Schema].
/// Returns the [arrow::datatypes::DataType] of the expression
/// based on [ExprSchema]
///
/// Note: [DFSchema] implements [ExprSchema].
///
/// # Errors
///
/// This function errors when it is not possible to compute its [arrow::datatypes::DataType].
/// This happens when e.g. the expression refers to a column that does not exist in the schema, or when
/// the expression is incorrectly typed (e.g. `[utf8] + [bool]`).
pub fn get_type(&self, schema: &DFSchema) -> Result<DataType> {
/// This function errors when it is not possible to compute its
/// [arrow::datatypes::DataType]. This happens when e.g. the
/// expression refers to a column that does not exist in the
/// schema, or when the expression is incorrectly typed
/// (e.g. `[utf8] + [bool]`).
pub fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType> {
match self {
Expr::Alias(expr, _) | Expr::Sort { expr, .. } | Expr::Negative(expr) => {
expr.get_type(schema)
}
Expr::Column(c) => Ok(schema.field_from_column(c)?.data_type().clone()),
Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
Expr::ScalarVariable(_) => Ok(DataType::Utf8),
Expr::Literal(l) => Ok(l.get_datatype()),
Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema),
Expand Down Expand Up @@ -472,21 +511,24 @@ impl Expr {
}
}

/// Returns the nullability of the expression based on [arrow::datatypes::Schema].
/// Returns the nullability of the expression based on [ExprSchema].
///
/// Note: [DFSchema] implements [ExprSchema].
///
/// # Errors
///
/// This function errors when it is not possible to compute its nullability.
/// This happens when the expression refers to a column that does not exist in the schema.
pub fn nullable(&self, input_schema: &DFSchema) -> Result<bool> {
/// This function errors when it is not possible to compute its
/// nullability. This happens when the expression refers to a
/// column that does not exist in the schema.
pub fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool> {
match self {
Expr::Alias(expr, _)
| Expr::Not(expr)
| Expr::Negative(expr)
| Expr::Sort { expr, .. }
| Expr::Between { expr, .. }
| Expr::InList { expr, .. } => expr.nullable(input_schema),
Expr::Column(c) => Ok(input_schema.field_from_column(c)?.is_nullable()),
Expr::Column(c) => input_schema.nullable(c),
Expr::Literal(value) => Ok(value.is_null()),
Expr::Case {
when_then_expr,
Expand Down Expand Up @@ -561,7 +603,11 @@ impl Expr {
///
/// This function errors when it is impossible to cast the
/// expression to the target [arrow::datatypes::DataType].
pub fn cast_to(self, cast_to_type: &DataType, schema: &DFSchema) -> Result<Expr> {
pub fn cast_to<S: ExprSchema>(
self,
cast_to_type: &DataType,
schema: &S,
) -> Result<Expr> {
// TODO(kszucs): most of the operations do not validate the type correctness
// like all of the binary expressions below. Perhaps Expr should track the
// type of the expression?
Expand Down Expand Up @@ -2557,4 +2603,57 @@ mod tests {
combine_filters(&[filter1.clone(), filter2.clone(), filter3.clone()]);
assert_eq!(result, Some(and(and(filter1, filter2), filter3)));
}

#[test]
fn expr_schema_nullability() {
let expr = col("foo").eq(lit(1));
assert!(!expr.nullable(&MockExprSchema::new()).unwrap());
assert!(expr
.nullable(&MockExprSchema::new().with_nullable(true))
.unwrap());
}

#[test]
fn expr_schema_data_type() {
let expr = col("foo");
assert_eq!(
DataType::Utf8,
expr.get_type(&MockExprSchema::new().with_data_type(DataType::Utf8))
.unwrap()
);
}

struct MockExprSchema {
nullable: bool,
data_type: DataType,
}

impl MockExprSchema {
fn new() -> Self {
Self {
nullable: false,
data_type: DataType::Null,
}
}

fn with_nullable(mut self, nullable: bool) -> Self {
self.nullable = nullable;
self
}

fn with_data_type(mut self, data_type: DataType) -> Self {
self.data_type = data_type;
self
}
}

impl ExprSchema for MockExprSchema {
fn nullable(&self, _col: &Column) -> Result<bool> {
Ok(self.nullable)
}

fn data_type(&self, _col: &Column) -> Result<&DataType> {
Ok(&self.data_type)
}
}
}
4 changes: 2 additions & 2 deletions datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ pub use expr::{
rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, sha384, sha512,
signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex,
translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when,
Column, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, RewriteRecursion,
SimplifyInfo,
Column, Expr, ExprRewriter, ExprSchema, ExpressionVisitor, Literal, Recursion,
RewriteRecursion, SimplifyInfo,
};
pub use extension::UserDefinedLogicalNode;
pub use operators::Operator;
Expand Down

0 comments on commit b2eaee3

Please sign in to comment.