From 03d6167f039d24677f4478e61af07988e1159fc3 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 27 Aug 2024 14:57:02 +0200 Subject: [PATCH] Remove Box from Sort `expr::Sort` had `Box` because Sort was also an expression (via `expr::Expr::Sort`). This has been removed, obsoleting need to use a `Box`. --- datafusion/core/src/datasource/mod.rs | 2 +- datafusion/core/tests/dataframe/mod.rs | 20 +++++++++---------- .../tests/user_defined/user_defined_plan.rs | 2 +- datafusion/expr/src/expr.rs | 6 +++--- datafusion/expr/src/expr_rewriter/mod.rs | 4 ++-- datafusion/expr/src/expr_rewriter/order_by.rs | 2 +- datafusion/expr/src/logical_plan/builder.rs | 8 ++++---- datafusion/expr/src/logical_plan/plan.rs | 2 +- datafusion/expr/src/logical_plan/tree_node.rs | 2 +- datafusion/expr/src/tree_node.rs | 8 ++++---- datafusion/expr/src/utils.rs | 16 +++++++-------- .../src/analyzer/count_wildcard_rule.rs | 2 +- .../optimizer/src/common_subexpr_eliminate.rs | 3 +-- .../proto/src/logical_plan/from_proto.rs | 7 +------ datafusion/proto/src/logical_plan/to_proto.rs | 2 +- datafusion/sql/src/expr/function.rs | 2 +- datafusion/sql/src/expr/order_by.rs | 2 +- datafusion/sql/src/unparser/expr.rs | 2 +- datafusion/sql/src/unparser/rewrite.rs | 2 +- .../substrait/src/logical_plan/consumer.rs | 2 +- .../substrait/src/logical_plan/producer.rs | 9 +-------- 21 files changed, 46 insertions(+), 59 deletions(-) diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 55e88e572be1..529bb799e23d 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -63,7 +63,7 @@ fn create_ordering( // Construct PhysicalSortExpr objects from Expr objects: let mut sort_exprs = vec![]; for sort in exprs { - match sort.expr.as_ref() { + match &sort.expr { Expr::Column(col) => match expressions::col(&col.name, schema) { Ok(expr) => { sort_exprs.push(PhysicalSortExpr { diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index c5b9db7588e9..19ce9294cfad 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -184,7 +184,7 @@ async fn test_count_wildcard_on_window() -> Result<()> { WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], )) - .order_by(vec![Sort::new(Box::new(col("a")), false, true)]) + .order_by(vec![Sort::new(col("a"), false, true)]) .window_frame(WindowFrame::new_bounds( WindowFrameUnits::Range, WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), @@ -352,7 +352,7 @@ async fn sort_on_unprojected_columns() -> Result<()> { .unwrap() .select(vec![col("a")]) .unwrap() - .sort(vec![Sort::new(Box::new(col("b")), false, true)]) + .sort(vec![Sort::new(col("b"), false, true)]) .unwrap(); let results = df.collect().await.unwrap(); @@ -396,7 +396,7 @@ async fn sort_on_distinct_columns() -> Result<()> { .unwrap() .distinct() .unwrap() - .sort(vec![Sort::new(Box::new(col("a")), false, true)]) + .sort(vec![Sort::new(col("a"), false, true)]) .unwrap(); let results = df.collect().await.unwrap(); @@ -435,7 +435,7 @@ async fn sort_on_distinct_unprojected_columns() -> Result<()> { .await? .select(vec![col("a")])? .distinct()? - .sort(vec![Sort::new(Box::new(col("b")), false, true)]) + .sort(vec![Sort::new(col("b"), false, true)]) .unwrap_err(); assert_eq!(err.strip_backtrace(), "Error during planning: For SELECT DISTINCT, ORDER BY expressions b must appear in select list"); Ok(()) @@ -599,8 +599,8 @@ async fn test_grouping_sets() -> Result<()> { .await? .aggregate(vec![grouping_set_expr], vec![count(col("a"))])? .sort(vec![ - Sort::new(Box::new(col("a")), false, true), - Sort::new(Box::new(col("b")), false, true), + Sort::new(col("a"), false, true), + Sort::new(col("b"), false, true), ])?; let results = df.collect().await?; @@ -640,8 +640,8 @@ async fn test_grouping_sets_count() -> Result<()> { .await? .aggregate(vec![grouping_set_expr], vec![count(lit(1))])? .sort(vec![ - Sort::new(Box::new(col("c1")), false, true), - Sort::new(Box::new(col("c2")), false, true), + Sort::new(col("c1"), false, true), + Sort::new(col("c2"), false, true), ])?; let results = df.collect().await?; @@ -687,8 +687,8 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { ], )? .sort(vec![ - Sort::new(Box::new(col("c1")), false, true), - Sort::new(Box::new(col("c2")), false, true), + Sort::new(col("c1"), false, true), + Sort::new(col("c2"), false, true), ])?; let results = df.collect().await?; diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index da27cf8869d1..56edeab443c7 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -419,7 +419,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { } fn expressions(&self) -> Vec { - vec![self.expr.expr.as_ref().clone()] + vec![self.expr.expr.clone()] } /// For example: `TopK: k=10` diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index b81c02ccd0b7..8914214d084f 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -602,7 +602,7 @@ impl TryCast { #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Sort { /// The expression to sort on - pub expr: Box, + pub expr: Expr, /// The direction of the sort pub asc: bool, /// Whether to put Nulls before all other data values @@ -611,7 +611,7 @@ pub struct Sort { impl Sort { /// Create a new Sort expression - pub fn new(expr: Box, asc: bool, nulls_first: bool) -> Self { + pub fn new(expr: Expr, asc: bool, nulls_first: bool) -> Self { Self { expr, asc, @@ -1368,7 +1368,7 @@ impl Expr { /// let sort_expr = col("foo").sort(true, true); // SORT ASC NULLS_FIRST /// ``` pub fn sort(self, asc: bool, nulls_first: bool) -> Sort { - Sort::new(Box::new(self), asc, nulls_first) + Sort::new(self, asc, nulls_first) } /// Return `IsTrue(Box(self))` diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 5e7fedb4cbd8..61b0f6d9bb2b 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -125,8 +125,8 @@ pub fn normalize_sorts( .into_iter() .map(|e| { let sort = e.into(); - normalize_col(*sort.expr, plan) - .map(|expr| Sort::new(Box::new(expr), sort.asc, sort.nulls_first)) + normalize_col(sort.expr, plan) + .map(|expr| Sort::new(expr, sort.asc, sort.nulls_first)) }) .collect() } diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index af5b8c4f9177..48d380cd59e2 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -35,7 +35,7 @@ pub fn rewrite_sort_cols_by_aggs( .map(|e| { let sort = e.into(); Ok(Sort::new( - Box::new(rewrite_sort_col_by_aggs(*sort.expr, plan)?), + rewrite_sort_col_by_aggs(sort.expr, plan)?, sort.asc, sort.nulls_first, )) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index f5770167861b..fc961b83f7b5 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1720,8 +1720,8 @@ mod tests { let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))? .sort(vec![ - expr::Sort::new(Box::new(col("state")), true, true), - expr::Sort::new(Box::new(col("salary")), false, false), + expr::Sort::new(col("state"), true, true), + expr::Sort::new(col("salary"), false, false), ])? .build()?; @@ -2147,8 +2147,8 @@ mod tests { let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))? .sort(vec![ - expr::Sort::new(Box::new(col("state")), true, true), - expr::Sort::new(Box::new(col("salary")), false, false), + expr::Sort::new(col("state"), true, true), + expr::Sort::new(col("salary"), false, false), ])? .build()?; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 8e6ec762f549..5bd6ab10331a 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2616,7 +2616,7 @@ impl DistinctOn { // Check that the left-most sort expressions are the same as the `ON` expressions. let mut matched = true; for (on, sort) in self.on_expr.iter().zip(sort_expr.iter()) { - if on != &*sort.expr { + if on != &sort.expr { matched = false; break; } diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 29a99a8e8886..0964fb601879 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -509,7 +509,7 @@ impl LogicalPlan { })) => on_expr .iter() .chain(select_expr.iter()) - .chain(sort_expr.iter().flatten().map(|sort| &*sort.expr)) + .chain(sort_expr.iter().flatten().map(|sort| &sort.expr)) .apply_until_stop(f), // plans without expressions LogicalPlan::EmptyRelation(_) diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 90d61bf63763..c7c498dd3f01 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -97,7 +97,7 @@ impl TreeNode for Expr { expr_vec.push(f.as_ref()); } if let Some(order_by) = order_by { - expr_vec.extend(order_by.iter().map(|sort| sort.expr.as_ref())); + expr_vec.extend(order_by.iter().map(|sort| &sort.expr)); } expr_vec } @@ -109,7 +109,7 @@ impl TreeNode for Expr { }) => { let mut expr_vec = args.iter().collect::>(); expr_vec.extend(partition_by); - expr_vec.extend(order_by.iter().map(|sort| sort.expr.as_ref())); + expr_vec.extend(order_by.iter().map(|sort| &sort.expr)); expr_vec } Expr::InList(InList { expr, list, .. }) => { @@ -395,7 +395,7 @@ pub fn transform_sort_vec Result>>( ) -> Result>> { Ok(sorts .iter() - .map(|sort| (*sort.expr).clone()) + .map(|sort| sort.expr.clone()) .map_until_stop_and_collect(&mut f)? .update_data(|transformed_exprs| { replace_sort_expressions(sorts, transformed_exprs) @@ -413,7 +413,7 @@ pub fn replace_sort_expressions(sorts: Vec, new_expr: Vec) -> Vec Sort { Sort { - expr: Box::new(new_expr), + expr: new_expr, ..sort } } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index b6b1b5660a81..c4c6b076e5ba 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1401,9 +1401,9 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys() -> Result<()> { - let age_asc = expr::Sort::new(Box::new(col("age")), true, true); - let name_desc = expr::Sort::new(Box::new(col("name")), false, true); - let created_at_desc = expr::Sort::new(Box::new(col("created_at")), false, true); + let age_asc = expr::Sort::new(col("age"), true, true); + let name_desc = expr::Sort::new(col("name"), false, true); + let created_at_desc = expr::Sort::new(col("created_at"), false, true); let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], @@ -1463,12 +1463,12 @@ mod tests { for nulls_first_ in nulls_first_or_last { let order_by = &[ Sort { - expr: Box::new(col("age")), + expr: col("age"), asc: asc_, nulls_first: nulls_first_, }, Sort { - expr: Box::new(col("name")), + expr: col("name"), asc: asc_, nulls_first: nulls_first_, }, @@ -1477,7 +1477,7 @@ mod tests { let expected = vec![ ( Sort { - expr: Box::new(col("age")), + expr: col("age"), asc: asc_, nulls_first: nulls_first_, }, @@ -1485,7 +1485,7 @@ mod tests { ), ( Sort { - expr: Box::new(col("name")), + expr: col("name"), asc: asc_, nulls_first: nulls_first_, }, @@ -1493,7 +1493,7 @@ mod tests { ), ( Sort { - expr: Box::new(col("created_at")), + expr: col("created_at"), asc: true, nulls_first: false, }, diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 35d4f91e3b6f..0036f6df43f6 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -229,7 +229,7 @@ mod tests { WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], )) - .order_by(vec![Sort::new(Box::new(col("a")), false, true)]) + .order_by(vec![Sort::new(col("a"), false, true)]) .window_frame(WindowFrame::new_bounds( WindowFrameUnits::Range, WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 25bef7e2d0e4..22e9d220d324 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -328,8 +328,7 @@ impl CommonSubexprEliminate { ) -> Result> { let Sort { expr, input, fetch } = sort; let input = Arc::unwrap_or_clone(input); - let sort_expressions = - expr.iter().map(|sort| sort.expr.as_ref().clone()).collect(); + let sort_expressions = expr.iter().map(|sort| sort.expr.clone()).collect(); let new_sort = self .try_unary_plan(sort_expressions, input, config)? .update_data(|(new_expr, new_input)| { diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 3ba1cb945e9c..893255ccc8ce 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -645,12 +645,7 @@ pub fn parse_sort( codec: &dyn LogicalExtensionCodec, ) -> Result { Ok(Sort::new( - Box::new(parse_required_expr( - sort.expr.as_ref(), - registry, - "expr", - codec, - )?), + parse_required_expr(sort.expr.as_ref(), registry, "expr", codec)?, sort.asc, sort.nulls_first, )) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index b937c03f79d9..63d1a007c1e5 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -637,7 +637,7 @@ where nulls_first, } = sort; Ok(protobuf::SortExprNode { - expr: Some(serialize_expr(expr.as_ref(), codec)?), + expr: Some(serialize_expr(expr, codec)?), asc: *asc, nulls_first: *nulls_first, }) diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 9c768eb73c2e..190a7e918928 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -282,7 +282,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let func_deps = schema.functional_dependencies(); // Find whether ties are possible in the given ordering let is_ordering_strict = order_by.iter().find_map(|orderby_expr| { - if let Expr::Column(col) = orderby_expr.expr.as_ref() { + if let Expr::Column(col) = &orderby_expr.expr { let idx = schema.index_of_column(col).ok()?; return if func_deps.iter().any(|dep| { dep.source_indices == vec![idx] && dep.mode == Dependency::Single diff --git a/datafusion/sql/src/expr/order_by.rs b/datafusion/sql/src/expr/order_by.rs index cdaa787cedd0..6a3a4d6ccbb7 100644 --- a/datafusion/sql/src/expr/order_by.rs +++ b/datafusion/sql/src/expr/order_by.rs @@ -100,7 +100,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; let asc = asc.unwrap_or(true); expr_vec.push(Sort::new( - Box::new(expr), + expr, asc, // when asc is true, by default nulls last to be consistent with postgres // postgres rule: https://www.postgresql.org/docs/current/queries-order.html diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 9a3f139fdee8..549635a31aef 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1761,7 +1761,7 @@ mod tests { fun: WindowFunctionDefinition::AggregateUDF(count_udaf()), args: vec![wildcard()], partition_by: vec![], - order_by: vec![Sort::new(Box::new(col("a")), false, true)], + order_by: vec![Sort::new(col("a"), false, true)], window_frame: WindowFrame::new_bounds( datafusion_expr::WindowFrameUnits::Range, datafusion_expr::WindowFrameBound::Preceding( diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index 522a08af8546..2529385849e0 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -158,7 +158,7 @@ pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( let mut collects = p.expr.clone(); for sort in &sort.expr { - collects.push(sort.expr.as_ref().clone()); + collects.push(sort.expr.clone()); } // Compare outer collects Expr::to_string with inner collected transformed values diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 05903bb56cfe..21bef3c2c98e 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -936,7 +936,7 @@ pub async fn from_substrait_sorts( }; let (asc, nulls_first) = asc_nullfirst.unwrap(); sorts.push(Sort { - expr: Box::new(expr), + expr, asc, nulls_first, }); diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 592390a285ba..e71cf04cd341 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -16,7 +16,6 @@ // under the License. use itertools::Itertools; -use std::ops::Deref; use std::sync::Arc; use arrow_buffer::ToByteSlice; @@ -819,13 +818,7 @@ fn to_substrait_sort_field( (false, false) => SortDirection::DescNullsLast, }; Ok(SortField { - expr: Some(to_substrait_rex( - ctx, - sort.expr.deref(), - schema, - 0, - extensions, - )?), + expr: Some(to_substrait_rex(ctx, &sort.expr, schema, 0, extensions)?), sort_kind: Some(SortKind::Direction(sort_kind.into())), }) }