diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 7c9054656b94..2c38a1d36c1e 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -27,8 +27,8 @@ use super::{ }, utils::{ find_agg_node_within_select, find_unnest_node_within_select, - find_window_nodes_within_select, unproject_sort_expr, unproject_unnest_expr, - unproject_window_exprs, + find_window_nodes_within_select, try_transform_to_simple_table_scan_with_filters, + unproject_sort_expr, unproject_unnest_expr, unproject_window_exprs, }, Unparser, }; @@ -39,8 +39,8 @@ use datafusion_common::{ Column, DataFusionError, Result, TableReference, }; use datafusion_expr::{ - expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, - LogicalPlanBuilder, Projection, SortExpr, TableScan, + expr::Alias, BinaryExpr, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, + LogicalPlanBuilder, Operator, Projection, SortExpr, TableScan, }; use sqlparser::ast::{self, Ident, SetExpr}; use std::sync::Arc; @@ -468,22 +468,77 @@ impl Unparser<'_> { self.select_to_sql_recursively(input, query, select, relation) } LogicalPlan::Join(join) => { - let join_constraint = self.join_constraint_to_sql( - join.join_constraint, - &join.on, - join.filter.as_ref(), + let mut table_scan_filters = vec![]; + + let left_plan = + match try_transform_to_simple_table_scan_with_filters(&join.left)? { + Some((plan, filters)) => { + table_scan_filters.extend(filters); + Arc::new(plan) + } + None => Arc::clone(&join.left), + }; + + self.select_to_sql_recursively( + left_plan.as_ref(), + query, + select, + relation, )?; + let right_plan = + match try_transform_to_simple_table_scan_with_filters(&join.right)? 
{ + Some((plan, filters)) => { + table_scan_filters.extend(filters); + Arc::new(plan) + } + None => Arc::clone(&join.right), + }; + let mut right_relation = RelationBuilder::default(); self.select_to_sql_recursively( - join.left.as_ref(), + right_plan.as_ref(), query, select, - relation, + &mut right_relation, )?; + + let join_filters = if table_scan_filters.is_empty() { + join.filter.clone() + } else { + // Combine `table_scan_filters` into a single filter using `AND` + let Some(combined_filters) = + table_scan_filters.into_iter().reduce(|acc, filter| { + Expr::BinaryExpr(BinaryExpr { + left: Box::new(acc), + op: Operator::And, + right: Box::new(filter), + }) + }) + else { + return internal_err!("Failed to combine TableScan filters"); + }; + + // Combine `join.filter` with `combined_filters` using `AND` + match &join.filter { + Some(filter) => Some(Expr::BinaryExpr(BinaryExpr { + left: Box::new(filter.clone()), + op: Operator::And, + right: Box::new(combined_filters), + })), + None => Some(combined_filters), + } + }; + + let join_constraint = self.join_constraint_to_sql( + join.join_constraint, + &join.on, + join_filters.as_ref(), + )?; + self.select_to_sql_recursively( - join.right.as_ref(), + right_plan.as_ref(), query, select, &mut right_relation, diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index d3d1bf351384..284956cef195 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -15,20 +15,20 @@ // specific language governing permissions and limitations // under the License. 
-use std::cmp::Ordering; +use std::{cmp::Ordering, sync::Arc, vec}; use datafusion_common::{ internal_err, - tree_node::{Transformed, TreeNode}, - Column, Result, ScalarValue, + tree_node::{Transformed, TransformedResult, TreeNode}, + Column, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ - expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Projection, - SortExpr, Unnest, Window, + expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, + LogicalPlanBuilder, Projection, SortExpr, Unnest, Window, }; use sqlparser::ast; -use super::{dialect::DateFieldExtractStyle, Unparser}; +use super::{dialect::DateFieldExtractStyle, rewrite::TableAliasRewriter, Unparser}; /// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). @@ -288,6 +288,87 @@ pub(crate) fn unproject_sort_expr( Ok(sort_expr) } +/// Iterates through the children of a [LogicalPlan] to find a TableScan node before encountering +/// a Projection or any unexpected node that indicates the presence of a Projection (SELECT) in the plan. +/// If a TableScan node is found, returns the TableScan node without filters, along with the collected filters separately. +/// If the plan contains a Projection, returns None. +/// +/// Note: If a table alias is present, TableScan filters are rewritten to reference the alias. 
+/// +/// LogicalPlan example: +/// Filter: ta.j1_id < 5 +/// Alias: ta +/// TableScan: j1, j1_id > 10 +/// +/// Will return LogicalPlan below: +/// Alias: ta +/// TableScan: j1 +/// And filters: [ta.j1_id < 5, ta.j1_id > 10] +pub(crate) fn try_transform_to_simple_table_scan_with_filters( + plan: &LogicalPlan, +) -> Result<Option<(LogicalPlan, Vec<Expr>)>> { + let mut filters: Vec<Expr> = vec![]; + let mut plan_stack = vec![plan]; + let mut table_alias = None; + + while let Some(current_plan) = plan_stack.pop() { + match current_plan { + LogicalPlan::SubqueryAlias(alias) => { + table_alias = Some(alias.alias.clone()); + plan_stack.push(alias.input.as_ref()); + } + LogicalPlan::Filter(filter) => { + filters.push(filter.predicate.clone()); + plan_stack.push(filter.input.as_ref()); + } + LogicalPlan::TableScan(table_scan) => { + let table_schema = table_scan.source.schema(); + // optional rewriter if table has an alias + let mut filter_alias_rewriter = + table_alias.as_ref().map(|alias_name| TableAliasRewriter { + table_schema: &table_schema, + alias_name: alias_name.clone(), + }); + + // rewrite filters to use table alias if present + let table_scan_filters = table_scan + .filters + .iter() + .cloned() + .map(|expr| { + if let Some(ref mut rewriter) = filter_alias_rewriter { + expr.rewrite(rewriter).data() + } else { + Ok(expr) + } + }) + .collect::<Result<Vec<_>, DataFusionError>>()?; + + filters.extend(table_scan_filters); + + let mut builder = LogicalPlanBuilder::scan( + table_scan.table_name.clone(), + Arc::clone(&table_scan.source), + None, + )?; + + if let Some(alias) = table_alias.take() { + builder = builder.alias(alias)?; + } + + let plan = builder.build()?; + + return Ok(Some((plan, filters))); + } + _ => { + return Ok(None); + } + } + } + + Ok(None) +} + /// Converts a date_part function to SQL, tailoring it to the supported date field extraction style.
pub(crate) fn date_part_to_sql( unparser: &Unparser, diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 16941c5d9164..ea0ccb8e4b43 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -1008,6 +1008,93 @@ fn test_sort_with_push_down_fetch() -> Result<()> { Ok(()) } +#[test] +fn test_join_with_table_scan_filters() -> Result<()> { + let schema_left = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("name", DataType::Utf8, false), + ]); + + let schema_right = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("age", DataType::Utf8, false), + ]); + + let left_plan = table_scan_with_filters( + Some("left_table"), + &schema_left, + None, + vec![col("name").like(lit("some_name"))], + )? + .alias("left")? + .build()?; + + let right_plan = table_scan_with_filters( + Some("right_table"), + &schema_right, + None, + vec![col("age").gt(lit(10))], + )? + .build()?; + + let join_plan_with_filter = LogicalPlanBuilder::from(left_plan.clone()) + .join( + right_plan.clone(), + datafusion_expr::JoinType::Inner, + (vec!["left.id"], vec!["right_table.id"]), + Some(col("left.id").gt(lit(5))), + )? + .build()?; + + let sql = plan_to_sql(&join_plan_with_filter)?; + + let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND ("left"."name" LIKE 'some_name' AND (age > 10)))"#; + + assert_eq!(sql.to_string(), expected_sql); + + let join_plan_no_filter = LogicalPlanBuilder::from(left_plan.clone()) + .join( + right_plan, + datafusion_expr::JoinType::Inner, + (vec!["left.id"], vec!["right_table.id"]), + None, + )? 
+ .build()?; + + let sql = plan_to_sql(&join_plan_no_filter)?; + + let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND ("left"."name" LIKE 'some_name' AND (age > 10))"#; + + assert_eq!(sql.to_string(), expected_sql); + + let right_plan_with_filter = table_scan_with_filters( + Some("right_table"), + &schema_right, + None, + vec![col("age").gt(lit(10))], + )? + .filter(col("right_table.name").eq(lit("before_join_filter_val")))? + .build()?; + + let join_plan_multiple_filters = LogicalPlanBuilder::from(left_plan.clone()) + .join( + right_plan_with_filter, + datafusion_expr::JoinType::Inner, + (vec!["left.id"], vec!["right_table.id"]), + Some(col("left.id").gt(lit(5))), + )? + .filter(col("left.name").eq(lit("after_join_filter_val")))? + .build()?; + + let sql = plan_to_sql(&join_plan_multiple_filters)?; + + let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND (("left"."name" LIKE 'some_name' AND (right_table."name" = 'before_join_filter_val')) AND (age > 10))) WHERE ("left"."name" = 'after_join_filter_val')"#; + + assert_eq!(sql.to_string(), expected_sql); + + Ok(()) +} + #[test] fn test_interval_lhs_eq() { sql_round_trip(