Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve TableScan with filters pushdown unparsing (joins) #13132

Merged
merged 5 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 66 additions & 10 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ use super::{
},
utils::{
find_agg_node_within_select, find_window_nodes_within_select,
unproject_sort_expr, unproject_window_exprs,
try_transform_to_simple_table_scan_with_filters, unproject_sort_expr,
unproject_window_exprs,
},
Unparser,
};
Expand All @@ -38,8 +39,8 @@ use datafusion_common::{
Column, DataFusionError, Result, TableReference,
};
use datafusion_expr::{
expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan,
LogicalPlanBuilder, Projection, SortExpr, TableScan,
expr::Alias, BinaryExpr, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan,
LogicalPlanBuilder, Operator, Projection, SortExpr, TableScan,
};
use sqlparser::ast::{self, Ident, SetExpr};
use std::sync::Arc;
Expand Down Expand Up @@ -459,22 +460,77 @@ impl Unparser<'_> {
self.select_to_sql_recursively(input, query, select, relation)
}
LogicalPlan::Join(join) => {
let join_constraint = self.join_constraint_to_sql(
join.join_constraint,
&join.on,
join.filter.as_ref(),
let mut table_scan_filters = vec![];

let left_plan =
match try_transform_to_simple_table_scan_with_filters(&join.left)? {
Some((plan, filters)) => {
table_scan_filters.extend(filters);
Arc::new(plan)
}
None => Arc::clone(&join.left),
};

self.select_to_sql_recursively(
left_plan.as_ref(),
query,
select,
relation,
)?;

let right_plan =
match try_transform_to_simple_table_scan_with_filters(&join.right)? {
Some((plan, filters)) => {
table_scan_filters.extend(filters);
Arc::new(plan)
}
None => Arc::clone(&join.right),
};

let mut right_relation = RelationBuilder::default();

self.select_to_sql_recursively(
join.left.as_ref(),
right_plan.as_ref(),
query,
select,
relation,
&mut right_relation,
)?;

let join_filters = if table_scan_filters.is_empty() {
join.filter.clone()
} else {
// Combine `table_scan_filters` into a single filter using `AND`
let Some(combined_filters) =
table_scan_filters.into_iter().reduce(|acc, filter| {
Expr::BinaryExpr(BinaryExpr {
left: Box::new(acc),
op: Operator::And,
right: Box::new(filter),
})
})
else {
return internal_err!("Failed to combine TableScan filters");
};

// Combine `join.filter` with `combined_filters` using `AND`
match &join.filter {
Some(filter) => Some(Expr::BinaryExpr(BinaryExpr {
left: Box::new(filter.clone()),
op: Operator::And,
right: Box::new(combined_filters),
})),
None => Some(combined_filters),
}
};

let join_constraint = self.join_constraint_to_sql(
join.join_constraint,
&join.on,
join_filters.as_ref(),
)?;

self.select_to_sql_recursively(
join.right.as_ref(),
right_plan.as_ref(),
query,
select,
&mut right_relation,
Expand Down
93 changes: 87 additions & 6 deletions datafusion/sql/src/unparser/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@
// specific language governing permissions and limitations
// under the License.

use std::cmp::Ordering;
use std::{cmp::Ordering, sync::Arc, vec};

use datafusion_common::{
internal_err,
tree_node::{Transformed, TreeNode},
Column, Result, ScalarValue,
tree_node::{Transformed, TransformedResult, TreeNode},
Column, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::{
utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Projection, SortExpr,
Window,
utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, LogicalPlanBuilder,
Projection, SortExpr, Window,
};
use sqlparser::ast;

use super::{dialect::DateFieldExtractStyle, Unparser};
use super::{dialect::DateFieldExtractStyle, rewrite::TableAliasRewriter, Unparser};

/// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists
/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
Expand Down Expand Up @@ -239,6 +239,87 @@ pub(crate) fn unproject_sort_expr(
Ok(sort_expr)
}

/// Walks down from `plan` looking for a TableScan node, passing only through `SubqueryAlias`
/// and `Filter` nodes. Any other node (for example a Projection, which indicates a SELECT in
/// the plan) stops the search and the function returns `Ok(None)`.
///
/// If a TableScan node is found, returns a freshly built TableScan plan without filters
/// (re-wrapped in the alias when one was seen), along with the collected filters separately.
/// The returned filters are the predicates of the traversed `Filter` nodes, in the order they
/// were encountered, followed by the filters stored on the TableScan itself.
///
/// Note: If a table alias is present, the filters stored on the TableScan are rewritten to
/// reference the alias. Predicates taken from `Filter` nodes are collected as-is.
/// When several `SubqueryAlias` nodes are traversed, the last one encountered (the one closest
/// to the TableScan) is the alias that is used.
///
/// LogicalPlan example:
/// Filter: ta.j1_id < 5
/// Alias: ta
/// TableScan: j1, j1_id > 10
///
/// Will return LogicalPlan below:
/// Alias: ta
/// TableScan: j1
/// And filters: [ta.j1_id < 5, ta.j1_id > 10]
///
/// # Errors
///
/// Propagates any error raised while rewriting the TableScan filters or while building the
/// replacement scan / alias plan.
pub(crate) fn try_transform_to_simple_table_scan_with_filters(
plan: &LogicalPlan,
) -> Result<Option<(LogicalPlan, Vec<Expr>)>> {
let mut filters: Vec<Expr> = vec![];
// Every arm below either pushes exactly one child or returns, so this stack never holds
// more than one plan; it is just an iterative way of descending the single-input chain.
let mut plan_stack = vec![plan];
let mut table_alias = None;

while let Some(current_plan) = plan_stack.pop() {
match current_plan {
LogicalPlan::SubqueryAlias(alias) => {
// Remember the alias (overwriting any outer one) and keep descending.
table_alias = Some(alias.alias.clone());
plan_stack.push(alias.input.as_ref());
}
LogicalPlan::Filter(filter) => {
// Collect the predicate unchanged and keep descending.
filters.push(filter.predicate.clone());
plan_stack.push(filter.input.as_ref());
}
LogicalPlan::TableScan(table_scan) => {
let table_schema = table_scan.source.schema();
// optional rewriter if table has an alias
let mut filter_alias_rewriter =
table_alias.as_ref().map(|alias_name| TableAliasRewriter {
table_schema: &table_schema,
alias_name: alias_name.clone(),
});

// rewrite filters to use table alias if present
let table_scan_filters = table_scan
sgrebnov marked this conversation as resolved.
Show resolved Hide resolved
.filters
.iter()
.cloned()
.map(|expr| {
if let Some(ref mut rewriter) = filter_alias_rewriter {
expr.rewrite(rewriter).data()
} else {
Ok(expr)
}
})
.collect::<Result<Vec<_>, DataFusionError>>()?;

// Scan-level filters go after the predicates gathered from Filter nodes.
filters.extend(table_scan_filters);

// Rebuild the scan from the table name and source only, so the new scan carries
// none of the original filters.
// NOTE(review): the third argument is presumably the projection, in which case
// `None` also drops any projection the original scan had; confirm against
// `LogicalPlanBuilder::scan`.
let mut builder = LogicalPlanBuilder::scan(
table_scan.table_name.clone(),
Arc::clone(&table_scan.source),
None,
)?;

// Re-apply the alias on top of the rebuilt scan, if one was seen.
if let Some(alias) = table_alias.take() {
builder = builder.alias(alias)?;
}

let plan = builder.build()?;

return Ok(Some((plan, filters)));
}
_ => {
// Any other node type means this is not a simple (aliased/filtered) table scan.
return Ok(None);
}
}
}

// Only reached if the stack empties without hitting a TableScan or an unsupported node;
// with the arms above that does not happen, since each one pushes a child or returns.
Ok(None)
}

/// Converts a date_part function to SQL, tailoring it to the supported date field extraction style.
pub(crate) fn date_part_to_sql(
unparser: &Unparser,
Expand Down
62 changes: 62 additions & 0 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -991,6 +991,68 @@ fn test_sort_with_push_down_fetch() -> Result<()> {
Ok(())
}

#[test]
fn test_join_with_table_scan_filters() -> Result<()> {
let schema_left = Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new("name", DataType::Utf8, false),
]);

let schema_right = Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new("age", DataType::Utf8, false),
]);

let left_plan = table_scan_with_filters(
Some("left_table"),
&schema_left,
None,
vec![col("name").like(lit("some_name"))],
)?
.alias("left")?
.build()?;

let right_plan = table_scan_with_filters(
Some("right_table"),
&schema_right,
None,
vec![col("age").gt(lit(10))],
)?
.build()?;

let join_plan_with_filter = LogicalPlanBuilder::from(left_plan.clone())
.join(
right_plan.clone(),
datafusion_expr::JoinType::Inner,
(vec!["left.id"], vec!["right_table.id"]),
Some(col("left.id").gt(lit(5))),
)?
.build()?;
goldmedal marked this conversation as resolved.
Show resolved Hide resolved

let sql = plan_to_sql(&join_plan_with_filter)?;

let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND ("left"."name" LIKE 'some_name' AND (age > 10)))"#;

assert_eq!(sql.to_string(), expected_sql);

let join_plan_no_filter = LogicalPlanBuilder::from(left_plan)
.join(
right_plan,
datafusion_expr::JoinType::Inner,
(vec!["left.id"], vec!["right_table.id"]),
None,
)?
.build()?;

let sql = plan_to_sql(&join_plan_no_filter)?;

let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND ("left"."name" LIKE 'some_name' AND (age > 10))"#;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed you put the pushdown condition in the join condition instead of the WHERE. In my opinion, the SQL plan will be different if we sometimes put the condition in a different place.

I did some tests for different join type and different place (join condition or filter) in DataFusion

    let join_type = vec![
        "inner join", "left join", "right join", "full join"
    ];

    for join in join_type {
        println!("-----------------{join}-------------------");
        println!("###### predicate in filter ######");
        let sql = format!("select o_orderkey from orders {join} customer on o_custkey = c_custkey where c_name = 'Customer#000000001'");
        println!("SQL: {}", sql);
        match ctx.sql(&sql).await?.into_optimized_plan() {
            Ok(plan) => {println!("{plan}")},
            Err(e) => eprintln!("Error: {}", e),
        }
        println!("###### predicate in join condition ######");
        let sql = format!("select o_orderkey from orders {join} customer on o_custkey = c_custkey and c_name = 'Customer#000000001'");
        println!("SQL: {}", sql);
        match ctx.sql(&sql).await?.into_optimized_plan() {
            Ok(plan) => {println!("{plan}")},
            Err(e) => eprintln!("Error: {}", e),
        }
    }

The result is

-----------------inner join-------------------
###### predicate in filter ######
SQL: select o_orderkey from orders inner join customer on o_custkey = c_custkey where c_name = 'Customer#000000001'
Projection: orders.o_orderkey
  Inner Join: orders.o_custkey = customer.c_custkey
    TableScan: orders projection=[o_orderkey, o_custkey]
    Projection: customer.c_custkey
      Filter: customer.c_name = Utf8("Customer#000000001")
        TableScan: customer projection=[c_custkey, c_name], partial_filters=[customer.c_name = Utf8("Customer#000000001")]
###### predicate in join condition ######
SQL: select o_orderkey from orders inner join customer on o_custkey = c_custkey and c_name = 'Customer#000000001'
Projection: orders.o_orderkey
  Inner Join: orders.o_custkey = customer.c_custkey
    TableScan: orders projection=[o_orderkey, o_custkey]
    Projection: customer.c_custkey
      Filter: customer.c_name = Utf8("Customer#000000001")
        TableScan: customer projection=[c_custkey, c_name], partial_filters=[customer.c_name = Utf8("Customer#000000001")]
-----------------left join-------------------
###### predicate in filter ######
SQL: select o_orderkey from orders left join customer on o_custkey = c_custkey where c_name = 'Customer#000000001'
Projection: orders.o_orderkey
  Inner Join: orders.o_custkey = customer.c_custkey
    TableScan: orders projection=[o_orderkey, o_custkey]
    Projection: customer.c_custkey
      Filter: customer.c_name = Utf8("Customer#000000001")
        TableScan: customer projection=[c_custkey, c_name], partial_filters=[customer.c_name = Utf8("Customer#000000001")]
###### predicate in join condition ######
SQL: select o_orderkey from orders left join customer on o_custkey = c_custkey and c_name = 'Customer#000000001'
Projection: orders.o_orderkey
  Left Join: orders.o_custkey = customer.c_custkey
    TableScan: orders projection=[o_orderkey, o_custkey]
    Projection: customer.c_custkey
      Filter: customer.c_name = Utf8("Customer#000000001")
        TableScan: customer projection=[c_custkey, c_name], partial_filters=[customer.c_name = Utf8("Customer#000000001")]
-----------------right join-------------------
###### predicate in filter ######
SQL: select o_orderkey from orders right join customer on o_custkey = c_custkey where c_name = 'Customer#000000001'
Projection: orders.o_orderkey
  Right Join: orders.o_custkey = customer.c_custkey
    TableScan: orders projection=[o_orderkey, o_custkey]
    Projection: customer.c_custkey
      Filter: customer.c_name = Utf8("Customer#000000001")
        TableScan: customer projection=[c_custkey, c_name], partial_filters=[customer.c_name = Utf8("Customer#000000001")]
###### predicate in join condition ######
SQL: select o_orderkey from orders right join customer on o_custkey = c_custkey and c_name = 'Customer#000000001'
Projection: orders.o_orderkey
  Right Join: orders.o_custkey = customer.c_custkey Filter: customer.c_name = Utf8("Customer#000000001")
    TableScan: orders projection=[o_orderkey, o_custkey]
    TableScan: customer projection=[c_custkey, c_name]
-----------------full join-------------------
###### predicate in filter ######
SQL: select o_orderkey from orders full join customer on o_custkey = c_custkey where c_name = 'Customer#000000001'
Projection: orders.o_orderkey
  Right Join: orders.o_custkey = customer.c_custkey
    TableScan: orders projection=[o_orderkey, o_custkey]
    Projection: customer.c_custkey
      Filter: customer.c_name = Utf8("Customer#000000001")
        TableScan: customer projection=[c_custkey, c_name], partial_filters=[customer.c_name = Utf8("Customer#000000001")]
###### predicate in join condition ######
SQL: select o_orderkey from orders full join customer on o_custkey = c_custkey and c_name = 'Customer#000000001'
Projection: orders.o_orderkey
  Full Join: orders.o_custkey = customer.c_custkey Filter: customer.c_name = Utf8("Customer#000000001")
    TableScan: orders projection=[o_orderkey, o_custkey]
    TableScan: customer projection=[c_custkey, c_name]

We can find the plan is the same in inner join and left join. The filter pushdown works fine. However, in right join and full join cases, if we put the predicate in the join condition, the filter pushdown doesn't work.
In the DataFusion case, the filter pushdown always works when putting the filter in WHERE.

I'm not sure whether this is a common rule for other databases (that putting the predicate in WHERE is better). However, in the DataFusion case, it's better to put them in WHERE.

By the way, this PR is ok for me now. I think it can be improved by a follow-up PR if we care about the performance of the generated SQL.

cc @alamb @phillipleblanc

Copy link
Member Author

@sgrebnov sgrebnov Oct 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@goldmedal - thank you for the deep review. I suspect that the filters were not fully pushed down by DF during optimization for the right join and full join cases in the samples above because the two test queries are not exactly equivalent in how records are filtered:

`on o_custkey = c_custkey where c_name = 'Customer#000000001'` <-- filter is applied after the join and fully filters out non-matching records
`on o_custkey = c_custkey and c_name = 'Customer#000000001'` <-- filtering is done during join, will join/include NULL

It seems in all examples above the original WHERE was moved inside Join by optimizer (all cases), so optimized plan should be unparsed as below for right join, for example

select o_orderkey from orders right join (select c_custkey from customer where c_name = 'Customer#000000001') on o_custkey = c_custkey

and we will translate it to

select o_orderkey from orders right join customer on o_custkey = c_custkey and c_name = 'Customer#000000001'

@goldmedal - Is my understanding correct that the main concern is that the first option is preferred because it could be executed more efficiently by the target engine?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@goldmedal - Is my understanding correct that the main concern is that the first option is preferred because it could be executed more efficiently by the target engine?

Yes, if we can push down the predicate to the table scan, it usually means it will perform better.

I tried the subquery pattern:

-----------------inner join-------------------
###### predicate in filter ######
SQL: select o_orderkey from orders inner join (select c_custkey  from customer where c_name = 'Customer#000000001') on o_custkey = c_custkey
Projection: orders.o_orderkey
  Inner Join: orders.o_custkey = customer.c_custkey
    TableScan: orders projection=[o_orderkey, o_custkey]
    Projection: customer.c_custkey
      Filter: customer.c_name = Utf8("Customer#000000001")
        TableScan: customer projection=[c_custkey, c_name], partial_filters=[customer.c_name = Utf8("Customer#000000001")]
-----------------left join-------------------
###### predicate in filter ######
SQL: select o_orderkey from orders left join (select c_custkey  from customer where c_name = 'Customer#000000001') on o_custkey = c_custkey
Projection: orders.o_orderkey
  Left Join: orders.o_custkey = customer.c_custkey
    TableScan: orders projection=[o_orderkey, o_custkey]
    Projection: customer.c_custkey
      Filter: customer.c_name = Utf8("Customer#000000001")
        TableScan: customer projection=[c_custkey, c_name], partial_filters=[customer.c_name = Utf8("Customer#000000001")]
-----------------right join-------------------
###### predicate in filter ######
SQL: select o_orderkey from orders right join (select c_custkey  from customer where c_name = 'Customer#000000001') on o_custkey = c_custkey
Projection: orders.o_orderkey
  Right Join: orders.o_custkey = customer.c_custkey
    TableScan: orders projection=[o_orderkey, o_custkey]
    Projection: customer.c_custkey
      Filter: customer.c_name = Utf8("Customer#000000001")
        TableScan: customer projection=[c_custkey, c_name], partial_filters=[customer.c_name = Utf8("Customer#000000001")]
-----------------full join-------------------
###### predicate in filter ######
SQL: select o_orderkey from orders full join (select c_custkey  from customer where c_name = 'Customer#000000001') on o_custkey = c_custkey
Projection: orders.o_orderkey
  Full Join: orders.o_custkey = customer.c_custkey
    TableScan: orders projection=[o_orderkey, o_custkey]
    Projection: customer.c_custkey
      Filter: customer.c_name = Utf8("Customer#000000001")
        TableScan: customer projection=[c_custkey, c_name], partial_filters=[customer.c_name = Utf8("Customer#000000001")]

Every predicate is pushed down to the table scan. It's better 👍
I haven't checked the planner of other databases but I think they're similar.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sgrebnov Do you want to improve it in this PR? or we can do it in the follow-up PR (maybe file an issue). WDYT?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@goldmedal - I would prefer the incremental approach with a follow-up PR. Thank you!


assert_eq!(sql.to_string(), expected_sql);

Ok(())
}

#[test]
fn test_interval_lhs_eq() {
sql_round_trip(
Expand Down