Skip to content

Commit

Permalink
Add another method to collect referenced columns from an expression (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ygf11 authored Nov 10, 2022
1 parent 36890b8 commit 509c82c
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 35 deletions.
7 changes: 2 additions & 5 deletions datafusion/core/src/physical_optimizer/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ use arrow::{
use datafusion_common::{downcast_value, ScalarValue};
use datafusion_expr::expr::{BinaryExpr, Cast};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
use datafusion_expr::utils::expr_to_columns;
use datafusion_expr::{binary_expr, cast, try_cast, ExprSchemable};
use datafusion_physical_expr::create_physical_expr;
use log::trace;
Expand Down Expand Up @@ -445,10 +444,8 @@ impl<'a> PruningExpressionBuilder<'a> {
required_columns: &'a mut RequiredStatColumns,
) -> Result<Self> {
// find column name; input could be a more complicated expression
let mut left_columns = HashSet::<Column>::new();
expr_to_columns(left, &mut left_columns)?;
let mut right_columns = HashSet::<Column>::new();
expr_to_columns(right, &mut right_columns)?;
let left_columns = left.to_columns()?;
let right_columns = right.to_columns()?;
let (column_expr, scalar_expr, columns, correct_operator) =
match (left_columns.len(), right_columns.len()) {
(1, 0) => (left, right, left_columns, op),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
use arrow::array::{BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array};
use arrow::{array::ArrayRef, datatypes::SchemaRef, error::ArrowError};
use datafusion_common::{Column, DataFusionError, Result};
use datafusion_expr::utils::expr_to_columns;
use datafusion_optimizer::utils::split_conjunction;
use log::{debug, error, trace};
use parquet::{
Expand All @@ -32,7 +31,7 @@ use parquet::{
},
format::PageLocation,
};
use std::collections::{HashSet, VecDeque};
use std::collections::VecDeque;
use std::sync::Arc;

use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics};
Expand Down Expand Up @@ -286,8 +285,7 @@ fn extract_page_index_push_down_predicates(
predicates
.into_iter()
.try_for_each::<_, Result<()>>(|predicate| {
let mut columns: HashSet<Column> = HashSet::new();
expr_to_columns(predicate, &mut columns)?;
let columns = predicate.to_columns()?;
if columns.len() == 1 {
one_col_expr.push(predicate);
}
Expand Down
7 changes: 3 additions & 4 deletions datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ use datafusion_expr::expr::{
Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, Like,
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::utils::{expand_wildcard, expr_to_columns};
use datafusion_expr::utils::expand_wildcard;
use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use datafusion_optimizer::utils::unalias;
use datafusion_physical_expr::expressions::Literal;
Expand All @@ -72,7 +72,7 @@ use futures::future::BoxFuture;
use futures::{FutureExt, StreamExt, TryStreamExt};
use itertools::Itertools;
use log::{debug, trace};
use std::collections::{HashMap, HashSet};
use std::collections::HashMap;
use std::fmt::Write;
use std::sync::Arc;

Expand Down Expand Up @@ -875,8 +875,7 @@ impl DefaultPhysicalPlanner {
let join_filter = match filter {
Some(expr) => {
// Extract columns from filter expression
let mut cols = HashSet::new();
expr_to_columns(expr, &mut cols)?;
let cols = expr.to_columns()?;

// Collect left & right field indices
let left_field_indices = cols.iter()
Expand Down
33 changes: 33 additions & 0 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::aggregate_function;
use crate::built_in_function;
use crate::expr_fn::binary_expr;
use crate::logical_plan::Subquery;
use crate::utils::expr_to_columns;
use crate::window_frame;
use crate::window_function;
use crate::AggregateUDF;
Expand All @@ -30,6 +31,7 @@ use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_common::{plan_err, Column};
use datafusion_common::{DataFusionError, ScalarValue};
use std::collections::HashSet;
use std::fmt;
use std::fmt::{Display, Formatter, Write};
use std::hash::{BuildHasher, Hash, Hasher};
Expand Down Expand Up @@ -685,6 +687,14 @@ impl Expr {
_ => plan_err!(format!("Could not coerce '{}' into Column!", self)),
}
}

/// Collect the set of columns referenced by this expression.
///
/// The traversal itself is performed by [`expr_to_columns`]; any error it
/// reports is propagated to the caller unchanged.
pub fn to_columns(&self) -> Result<HashSet<Column>> {
    let mut referenced = HashSet::new();
    expr_to_columns(self, &mut referenced)?;
    Ok(referenced)
}
}

impl Not for Expr {
Expand Down Expand Up @@ -1277,6 +1287,7 @@ mod test {
use crate::expr_fn::col;
use crate::{case, lit, Expr};
use arrow::datatypes::DataType;
use datafusion_common::Column;
use datafusion_common::{Result, ScalarValue};

#[test]
Expand Down Expand Up @@ -1327,4 +1338,26 @@ mod test {
assert!(exp2 > exp3);
assert!(exp3 < exp2);
}

#[test]
fn test_collect_expr() -> Result<()> {
    // A cast wrapped around one column reference reports exactly that column.
    let expr = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64));
    let columns = expr.to_columns()?;
    assert_eq!(1, columns.len());
    assert!(columns.contains(&Column::from_name("a")));

    // Two column references plus a literal: only the two columns are reported.
    let expr = col("a") + col("b") + lit(1);
    let columns = expr.to_columns()?;
    assert_eq!(2, columns.len());
    assert!(columns.contains(&Column::from_name("a")));
    assert!(columns.contains(&Column::from_name("b")));

    Ok(())
}
}
10 changes: 3 additions & 7 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use crate::{
Window,
},
utils::{
can_hash, expand_qualified_wildcard, expand_wildcard, expr_to_columns,
can_hash, expand_qualified_wildcard, expand_wildcard,
group_window_expr_by_sort_keys,
},
Expr, ExprSchemable, TableSource,
Expand All @@ -43,10 +43,7 @@ use datafusion_common::{
};
use std::any::Any;
use std::convert::TryFrom;
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use std::{collections::HashMap, sync::Arc};

/// Default table name for unnamed table
pub const UNNAMED_TABLE: &str = "?table?";
Expand Down Expand Up @@ -378,8 +375,7 @@ impl LogicalPlanBuilder {
.clone()
.into_iter()
.try_for_each::<_, Result<()>>(|expr| {
let mut columns: HashSet<Column> = HashSet::new();
expr_to_columns(&expr, &mut columns)?;
let columns = expr.to_columns()?;

columns.into_iter().for_each(|c| {
if schema.field_from_column(&c).is_err() {
Expand Down
15 changes: 4 additions & 11 deletions datafusion/optimizer/src/filter_push_down.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,7 @@ fn extract_or_clauses_for_join(
// If nothing can be extracted from any sub clauses, do nothing for this OR clause.
if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
let predicate = or(left_expr, right_expr);
let mut columns: HashSet<Column> = HashSet::new();
expr_to_columns(&predicate, &mut columns).ok().unwrap();
let columns = predicate.to_columns().ok().unwrap();

exprs.push(predicate);
expr_columns.push(columns);
Expand Down Expand Up @@ -388,8 +387,7 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Ex
}
}
_ => {
let mut columns: HashSet<Column> = HashSet::new();
expr_to_columns(expr, &mut columns).ok().unwrap();
let columns = expr.to_columns().ok().unwrap();

if schema_columns
.intersection(&columns)
Expand Down Expand Up @@ -541,8 +539,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
utils::split_conjunction_owned(predicate)
.into_iter()
.try_for_each::<_, Result<()>>(|predicate| {
let mut columns: HashSet<Column> = HashSet::new();
expr_to_columns(&predicate, &mut columns)?;
let columns = predicate.to_columns()?;
state.filters.push((predicate, columns));
Ok(())
})?;
Expand Down Expand Up @@ -664,11 +661,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {

predicates
.into_iter()
.map(|e| {
let mut accum = HashSet::new();
expr_to_columns(e, &mut accum)?;
Ok((e.clone(), accum))
})
.map(|e| Ok((e.clone(), e.to_columns()?)))
.collect::<Result<Vec<_>>>()
})
.unwrap_or_else(|| Ok(vec![]))?;
Expand Down
6 changes: 2 additions & 4 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ use datafusion_expr::logical_plan::{
};
use datafusion_expr::utils::{
can_hash, expand_qualified_wildcard, expand_wildcard, expr_as_column_expr,
expr_to_columns, find_aggregate_exprs, find_column_exprs, find_window_exprs,
COUNT_STAR_EXPANSION,
find_aggregate_exprs, find_column_exprs, find_window_exprs, COUNT_STAR_EXPANSION,
};
use datafusion_expr::{
and, col, lit, AggregateFunction, AggregateUDF, Expr, ExprSchemable, GetIndexedField,
Expand Down Expand Up @@ -690,8 +689,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let join_filter = filter
.into_iter()
.map(|expr| {
let mut using_columns = HashSet::new();
expr_to_columns(&expr, &mut using_columns)?;
let using_columns = expr.to_columns()?;

normalize_col_with_schemas(
expr,
Expand Down

0 comments on commit 509c82c

Please sign in to comment.