Skip to content

Commit

Permalink
Improve TableScan with filters pushdown unparsing (joins) (#13132)
Browse files Browse the repository at this point in the history
* Improve TableScan with filters pushdown unparsing (joins)

* Fix formatting

* Add test with filters before and after join
  • Loading branch information
sgrebnov authored Oct 29, 2024
1 parent 1fd6116 commit 0b45b9a
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 17 deletions.
77 changes: 66 additions & 11 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ use super::{
},
utils::{
find_agg_node_within_select, find_unnest_node_within_select,
find_window_nodes_within_select, unproject_sort_expr, unproject_unnest_expr,
unproject_window_exprs,
find_window_nodes_within_select, try_transform_to_simple_table_scan_with_filters,
unproject_sort_expr, unproject_unnest_expr, unproject_window_exprs,
},
Unparser,
};
Expand All @@ -39,8 +39,8 @@ use datafusion_common::{
Column, DataFusionError, Result, TableReference,
};
use datafusion_expr::{
expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan,
LogicalPlanBuilder, Projection, SortExpr, TableScan,
expr::Alias, BinaryExpr, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan,
LogicalPlanBuilder, Operator, Projection, SortExpr, TableScan,
};
use sqlparser::ast::{self, Ident, SetExpr};
use std::sync::Arc;
Expand Down Expand Up @@ -468,22 +468,77 @@ impl Unparser<'_> {
self.select_to_sql_recursively(input, query, select, relation)
}
LogicalPlan::Join(join) => {
let join_constraint = self.join_constraint_to_sql(
join.join_constraint,
&join.on,
join.filter.as_ref(),
let mut table_scan_filters = vec![];

let left_plan =
match try_transform_to_simple_table_scan_with_filters(&join.left)? {
Some((plan, filters)) => {
table_scan_filters.extend(filters);
Arc::new(plan)
}
None => Arc::clone(&join.left),
};

self.select_to_sql_recursively(
left_plan.as_ref(),
query,
select,
relation,
)?;

let right_plan =
match try_transform_to_simple_table_scan_with_filters(&join.right)? {
Some((plan, filters)) => {
table_scan_filters.extend(filters);
Arc::new(plan)
}
None => Arc::clone(&join.right),
};

let mut right_relation = RelationBuilder::default();

self.select_to_sql_recursively(
join.left.as_ref(),
right_plan.as_ref(),
query,
select,
relation,
&mut right_relation,
)?;

let join_filters = if table_scan_filters.is_empty() {
join.filter.clone()
} else {
// Combine `table_scan_filters` into a single filter using `AND`
let Some(combined_filters) =
table_scan_filters.into_iter().reduce(|acc, filter| {
Expr::BinaryExpr(BinaryExpr {
left: Box::new(acc),
op: Operator::And,
right: Box::new(filter),
})
})
else {
return internal_err!("Failed to combine TableScan filters");
};

// Combine `join.filter` with `combined_filters` using `AND`
match &join.filter {
Some(filter) => Some(Expr::BinaryExpr(BinaryExpr {
left: Box::new(filter.clone()),
op: Operator::And,
right: Box::new(combined_filters),
})),
None => Some(combined_filters),
}
};

let join_constraint = self.join_constraint_to_sql(
join.join_constraint,
&join.on,
join_filters.as_ref(),
)?;

self.select_to_sql_recursively(
join.right.as_ref(),
right_plan.as_ref(),
query,
select,
&mut right_relation,
Expand Down
93 changes: 87 additions & 6 deletions datafusion/sql/src/unparser/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@
// specific language governing permissions and limitations
// under the License.

use std::cmp::Ordering;
use std::{cmp::Ordering, sync::Arc, vec};

use datafusion_common::{
internal_err,
tree_node::{Transformed, TreeNode},
Column, Result, ScalarValue,
tree_node::{Transformed, TransformedResult, TreeNode},
Column, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::{
expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Projection,
SortExpr, Unnest, Window,
expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan,
LogicalPlanBuilder, Projection, SortExpr, Unnest, Window,
};
use sqlparser::ast;

use super::{dialect::DateFieldExtractStyle, Unparser};
use super::{dialect::DateFieldExtractStyle, rewrite::TableAliasRewriter, Unparser};

/// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists
/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
Expand Down Expand Up @@ -288,6 +288,87 @@ pub(crate) fn unproject_sort_expr(
Ok(sort_expr)
}

/// Iterates through the children of a [LogicalPlan] to find a TableScan node before encountering
/// a Projection or any unexpected node that indicates the presence of a Projection (SELECT) in the plan.
/// If a TableScan node is found, returns the TableScan node without filters, along with the collected filters separately.
/// If the plan contains a Projection, returns None.
///
/// Note: If a table alias is present, TableScan filters are rewritten to reference the alias.
///
/// LogicalPlan example:
/// Filter: ta.j1_id < 5
/// Alias: ta
/// TableScan: j1, j1_id > 10
///
/// Will return LogicalPlan below:
/// Alias: ta
/// TableScan: j1
/// And filters: [ta.j1_id < 5, ta.j1_id > 10]
pub(crate) fn try_transform_to_simple_table_scan_with_filters(
plan: &LogicalPlan,
) -> Result<Option<(LogicalPlan, Vec<Expr>)>> {
let mut filters: Vec<Expr> = vec![];
let mut plan_stack = vec![plan];
let mut table_alias = None;

while let Some(current_plan) = plan_stack.pop() {
match current_plan {
LogicalPlan::SubqueryAlias(alias) => {
table_alias = Some(alias.alias.clone());
plan_stack.push(alias.input.as_ref());
}
LogicalPlan::Filter(filter) => {
filters.push(filter.predicate.clone());
plan_stack.push(filter.input.as_ref());
}
LogicalPlan::TableScan(table_scan) => {
let table_schema = table_scan.source.schema();
// optional rewriter if table has an alias
let mut filter_alias_rewriter =
table_alias.as_ref().map(|alias_name| TableAliasRewriter {
table_schema: &table_schema,
alias_name: alias_name.clone(),
});

// rewrite filters to use table alias if present
let table_scan_filters = table_scan
.filters
.iter()
.cloned()
.map(|expr| {
if let Some(ref mut rewriter) = filter_alias_rewriter {
expr.rewrite(rewriter).data()
} else {
Ok(expr)
}
})
.collect::<Result<Vec<_>, DataFusionError>>()?;

filters.extend(table_scan_filters);

let mut builder = LogicalPlanBuilder::scan(
table_scan.table_name.clone(),
Arc::clone(&table_scan.source),
None,
)?;

if let Some(alias) = table_alias.take() {
builder = builder.alias(alias)?;
}

let plan = builder.build()?;

return Ok(Some((plan, filters)));
}
_ => {
return Ok(None);
}
}
}

Ok(None)
}

/// Converts a date_part function to SQL, tailoring it to the supported date field extraction style.
pub(crate) fn date_part_to_sql(
unparser: &Unparser,
Expand Down
87 changes: 87 additions & 0 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,93 @@ fn test_sort_with_push_down_fetch() -> Result<()> {
Ok(())
}

#[test]
fn test_join_with_table_scan_filters() -> Result<()> {
let schema_left = Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new("name", DataType::Utf8, false),
]);

let schema_right = Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new("age", DataType::Utf8, false),
]);

let left_plan = table_scan_with_filters(
Some("left_table"),
&schema_left,
None,
vec![col("name").like(lit("some_name"))],
)?
.alias("left")?
.build()?;

let right_plan = table_scan_with_filters(
Some("right_table"),
&schema_right,
None,
vec![col("age").gt(lit(10))],
)?
.build()?;

let join_plan_with_filter = LogicalPlanBuilder::from(left_plan.clone())
.join(
right_plan.clone(),
datafusion_expr::JoinType::Inner,
(vec!["left.id"], vec!["right_table.id"]),
Some(col("left.id").gt(lit(5))),
)?
.build()?;

let sql = plan_to_sql(&join_plan_with_filter)?;

let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND ("left"."name" LIKE 'some_name' AND (age > 10)))"#;

assert_eq!(sql.to_string(), expected_sql);

let join_plan_no_filter = LogicalPlanBuilder::from(left_plan.clone())
.join(
right_plan,
datafusion_expr::JoinType::Inner,
(vec!["left.id"], vec!["right_table.id"]),
None,
)?
.build()?;

let sql = plan_to_sql(&join_plan_no_filter)?;

let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND ("left"."name" LIKE 'some_name' AND (age > 10))"#;

assert_eq!(sql.to_string(), expected_sql);

let right_plan_with_filter = table_scan_with_filters(
Some("right_table"),
&schema_right,
None,
vec![col("age").gt(lit(10))],
)?
.filter(col("right_table.name").eq(lit("before_join_filter_val")))?
.build()?;

let join_plan_multiple_filters = LogicalPlanBuilder::from(left_plan.clone())
.join(
right_plan_with_filter,
datafusion_expr::JoinType::Inner,
(vec!["left.id"], vec!["right_table.id"]),
Some(col("left.id").gt(lit(5))),
)?
.filter(col("left.name").eq(lit("after_join_filter_val")))?
.build()?;

let sql = plan_to_sql(&join_plan_multiple_filters)?;

let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND (("left"."name" LIKE 'some_name' AND (right_table."name" = 'before_join_filter_val')) AND (age > 10))) WHERE ("left"."name" = 'after_join_filter_val')"#;

assert_eq!(sql.to_string(), expected_sql);

Ok(())
}

#[test]
fn test_interval_lhs_eq() {
sql_round_trip(
Expand Down

0 comments on commit 0b45b9a

Please sign in to comment.