From 6d38e6a3c037aa5be000ad4ae599071bc3786075 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Sat, 24 Aug 2024 11:39:47 +0200 Subject: [PATCH] WIP --- datafusion/core/src/dataframe/mod.rs | 6 +- .../src/datasource/file_format/options.rs | 14 +- .../core/src/datasource/listing/helpers.rs | 3 +- .../core/src/datasource/listing/table.rs | 8 +- datafusion/core/src/datasource/memory.rs | 3 +- datafusion/core/src/datasource/mod.rs | 41 ++-- .../physical_plan/file_scan_config.rs | 2 +- datafusion/core/src/datasource/stream.rs | 6 +- datafusion/core/src/physical_planner.rs | 32 ++- datafusion/core/src/test_util/mod.rs | 4 +- datafusion/expr/src/expr.rs | 113 +++++------ datafusion/expr/src/expr_fn.rs | 28 +-- datafusion/expr/src/expr_rewriter/mod.rs | 16 +- datafusion/expr/src/expr_rewriter/order_by.rs | 38 ++-- datafusion/expr/src/expr_schema.rs | 15 +- datafusion/expr/src/logical_plan/builder.rs | 32 ++- datafusion/expr/src/logical_plan/ddl.rs | 7 +- datafusion/expr/src/logical_plan/plan.rs | 46 ++--- datafusion/expr/src/logical_plan/tree_node.rs | 18 +- datafusion/expr/src/tree_node.rs | 57 +++++- datafusion/expr/src/utils.rs | 189 +++++++----------- datafusion/expr/src/window_frame.rs | 4 +- .../functions-aggregate/src/first_last.rs | 4 +- .../src/analyzer/count_wildcard_rule.rs | 2 +- .../optimizer/src/analyzer/type_coercion.rs | 7 +- .../optimizer/src/common_subexpr_eliminate.rs | 14 +- .../src/eliminate_duplicated_expr.rs | 51 ++--- datafusion/optimizer/src/eliminate_limit.rs | 6 +- datafusion/optimizer/src/push_down_filter.rs | 3 +- datafusion/optimizer/src/push_down_limit.rs | 4 +- .../simplify_expressions/expr_simplifier.rs | 1 - .../src/single_distinct_to_groupby.rs | 2 +- datafusion/proto/src/generated/prost.rs | 4 +- datafusion/sql/src/expr/function.rs | 25 +-- datafusion/sql/src/expr/order_by.rs | 8 +- datafusion/sql/src/query.rs | 3 +- datafusion/sql/src/select.rs | 4 +- datafusion/sql/src/statement.rs | 13 +- datafusion/sql/src/unparser/expr.rs | 114 +++-------- datafusion/sql/src/unparser/mod.rs | 2 - datafusion/sql/src/unparser/plan.rs | 17 +- datafusion/sql/src/unparser/rewrite.rs | 64 ++++-- 42 files changed, 452 insertions(+), 578 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index a38e7f45a6f17..b71d8e383dc32 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -51,7 +51,7 @@ use datafusion_common::config::{CsvOptions, JsonOptions}; use datafusion_common::{ plan_err, Column, DFSchema, DataFusionError, ParamValues, SchemaError, UnnestOptions, }; -use datafusion_expr::{case, is_null, lit}; +use datafusion_expr::{case, is_null, lit, SortExpr}; use datafusion_expr::{ utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, }; @@ -576,7 +576,7 @@ impl DataFrame { self, on_expr: Vec, select_expr: Vec, - sort_expr: Option>, + sort_expr: Option>, ) -> Result { let plan = LogicalPlanBuilder::from(self.plan) .distinct_on(on_expr, select_expr, sort_expr)? @@ -796,7 +796,7 @@ impl DataFrame { /// # Ok(()) /// # } /// ``` - pub fn sort(self, expr: Vec) -> Result { + pub fn sort(self, expr: Vec) -> Result { let plan = LogicalPlanBuilder::from(self.plan).sort(expr)?.build()?; Ok(DataFrame { session_state: self.session_state, diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index 552977baba17b..db90262edbf8a 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -31,7 +31,6 @@ use crate::datasource::{ }; use crate::error::Result; use crate::execution::context::{SessionConfig, SessionState}; -use crate::logical_expr::Expr; use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion_common::config::TableOptions; @@ -41,6 +40,7 @@ use datafusion_common::{ }; use async_trait::async_trait; +use datafusion_expr::SortExpr; /// Options that control the reading of CSV files. /// @@ -84,7 +84,7 @@ pub struct CsvReadOptions<'a> { /// File compression type pub file_compression_type: FileCompressionType, /// Indicates how the file is sorted - pub file_sort_order: Vec>, + pub file_sort_order: Vec>, } impl<'a> Default for CsvReadOptions<'a> { @@ -199,7 +199,7 @@ impl<'a> CsvReadOptions<'a> { } /// Configure if file has known sort order - pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { + pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { self.file_sort_order = file_sort_order; self } @@ -231,7 +231,7 @@ pub struct ParquetReadOptions<'a> { /// based on data in file. pub schema: Option<&'a Schema>, /// Indicates how the file is sorted - pub file_sort_order: Vec>, + pub file_sort_order: Vec>, } impl<'a> Default for ParquetReadOptions<'a> { @@ -278,7 +278,7 @@ impl<'a> ParquetReadOptions<'a> { } /// Configure if file has known sort order - pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { + pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { self.file_sort_order = file_sort_order; self } @@ -397,7 +397,7 @@ pub struct NdJsonReadOptions<'a> { /// Flag indicating whether this file may be unbounded (as in a FIFO file). pub infinite: bool, /// Indicates how the file is sorted - pub file_sort_order: Vec>, + pub file_sort_order: Vec>, } impl<'a> Default for NdJsonReadOptions<'a> { @@ -452,7 +452,7 @@ impl<'a> NdJsonReadOptions<'a> { } /// Configure if file has known sort order - pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { + pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { self.file_sort_order = file_sort_order; self } diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index b5dd2dd12e10a..0d9ded22a9a98 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -102,11 +102,10 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { } // TODO other expressions are not handled yet: - // - AGGREGATE, WINDOW and SORT should not end up in filter conditions, except maybe in some edge cases + // - AGGREGATE and WINDOW should not end up in filter conditions, except maybe in some edge cases // - Can `Wildcard` be considered as a `Literal`? // - ScalarVariable could be `applicable`, but that would require access to the context Expr::AggregateFunction { .. } - | Expr::Sort { .. } | Expr::WindowFunction { .. } | Expr::Wildcard { .. } | Expr::Unnest { .. } diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 89066d8234acc..6925d5b3717d9 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -33,8 +33,8 @@ use crate::datasource::{ use crate::execution::context::SessionState; use datafusion_catalog::TableProvider; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::TableType; use datafusion_expr::{utils::conjunction, Expr, TableProviderFilterPushDown}; +use datafusion_expr::{SortExpr, TableType}; use datafusion_physical_plan::{empty::EmptyExec, ExecutionPlan, Statistics}; use arrow::datatypes::{DataType, Field, SchemaBuilder, SchemaRef}; @@ -222,7 +222,7 @@ pub struct ListingOptions { /// ordering (encapsulated by a `Vec`). If there aren't /// multiple equivalent orderings, the outer `Vec` will have a /// single element. - pub file_sort_order: Vec>, + pub file_sort_order: Vec>, } impl ListingOptions { @@ -385,7 +385,7 @@ impl ListingOptions { /// /// assert_eq!(listing_options.file_sort_order, file_sort_order); /// ``` - pub fn with_file_sort_order(mut self, file_sort_order: Vec>) -> Self { + pub fn with_file_sort_order(mut self, file_sort_order: Vec>) -> Self { self.file_sort_order = file_sort_order; self } @@ -1162,7 +1162,7 @@ mod tests { (vec![], Ok(vec![])), // not a sort expr ( - vec![vec![col("string_col")]], + vec![vec![col("string_col").sort(true, false)]], Err("Expected Expr::Sort in output_ordering, but got string_col"), ), // sort expr, but non column diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 44e01e71648a0..bb3af3e70138d 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -48,6 +48,7 @@ use log::debug; use parking_lot::Mutex; use tokio::sync::RwLock; use tokio::task::JoinSet; +use datafusion_expr::SortExpr; /// Type alias for partition data pub type PartitionData = Arc>>; @@ -64,7 +65,7 @@ pub struct MemTable { column_defaults: HashMap, /// Optional pre-known sort order(s). Must be `SortExpr`s. /// inserting data into this table removes the order - pub sort_order: Arc>>>, + pub sort_order: Arc>>>, } impl MemTable { diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 1c9924735735d..68d2d9133a625 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -50,38 +50,37 @@ pub use statistics::get_statistics_with_limit; use arrow_schema::{Schema, SortOptions}; use datafusion_common::{plan_err, Result}; -use datafusion_expr::Expr; +use datafusion_expr::{Expr, SortExpr}; use datafusion_physical_expr::{expressions, LexOrdering, PhysicalSortExpr}; fn create_ordering( schema: &Schema, - sort_order: &[Vec], + sort_order: &[Vec], ) -> Result> { let mut all_sort_orders = vec![]; for exprs in sort_order { // Construct PhysicalSortExpr objects from Expr objects: let mut sort_exprs = vec![]; - for expr in exprs { - match expr { - Expr::Sort(sort) => match sort.expr.as_ref() { - Expr::Column(col) => match expressions::col(&col.name, schema) { - Ok(expr) => { - sort_exprs.push(PhysicalSortExpr { - expr, - options: SortOptions { - descending: !sort.asc, - nulls_first: sort.nulls_first, - }, - }); - } - // Cannot find expression in the projected_schema, stop iterating - // since rest of the orderings are violated - Err(_) => break, + for sort in exprs { + match sort.expr.as_ref() { + Expr::Column(col) => match expressions::col(&col.name, schema) { + Ok(expr) => { + sort_exprs.push(PhysicalSortExpr { + expr, + options: SortOptions { + descending: !sort.asc, + nulls_first: sort.nulls_first, + }, + }); } - expr => return plan_err!("Expected single column references in output_ordering, got {expr}"), - } - expr => return plan_err!("Expected Expr::Sort in output_ordering, but got {expr}"), + // Cannot find expression in the projected_schema, stop iterating + // since rest of the orderings are violated + Err(_) => break, + }, + expr => return plan_err!( + "Expected single column references in output_ordering, got {expr}" + ), } } if !sort_exprs.is_empty() { diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index 34fb6226c1a26..763dd59ab523d 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -979,7 +979,7 @@ mod tests { name: &'static str, file_schema: Schema, files: Vec, - sort: Vec, + sort: Vec, expected_result: Result>, &'static str>, } diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs index b53fe86631785..ef6d195cdaff9 100644 --- a/datafusion/core/src/datasource/stream.rs +++ b/datafusion/core/src/datasource/stream.rs @@ -33,7 +33,7 @@ use arrow_schema::SchemaRef; use datafusion_common::{config_err, plan_err, Constraints, DataFusionError, Result}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_expr::{CreateExternalTable, Expr, TableType}; +use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; use datafusion_physical_plan::insert::{DataSink, DataSinkExec}; use datafusion_physical_plan::metrics::MetricsSet; use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder; @@ -248,7 +248,7 @@ impl StreamProvider for FileStreamProvider { #[derive(Debug)] pub struct StreamConfig { source: Arc, - order: Vec>, + order: Vec>, constraints: Constraints, } @@ -263,7 +263,7 @@ impl StreamConfig { } /// Specify a sort order for the stream - pub fn with_order(mut self, order: Vec>) -> Self { + pub fn with_order(mut self, order: Vec>) -> Self { self.order = order; self } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 8d6c5089fa34d..4a2e501543b66 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -73,13 +73,13 @@ use datafusion_common::{ }; use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ - self, physical_name, AggregateFunction, Alias, GroupingSet, WindowFunction, + physical_name, AggregateFunction, Alias, GroupingSet, WindowFunction, }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, StringifiedPlan, - WindowFrame, WindowFrameBound, WriteOp, + DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, SortExpr, + StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr::LexOrdering; @@ -1641,31 +1641,27 @@ pub fn create_aggregate_expr_and_maybe_filter( /// Create a physical sort expression from a logical expression pub fn create_physical_sort_expr( - e: &Expr, + e: &SortExpr, input_dfschema: &DFSchema, execution_props: &ExecutionProps, ) -> Result { - if let Expr::Sort(expr::Sort { + let SortExpr { expr, asc, nulls_first, - }) = e - { - Ok(PhysicalSortExpr { - expr: create_physical_expr(expr, input_dfschema, execution_props)?, - options: SortOptions { - descending: !asc, - nulls_first: *nulls_first, - }, - }) - } else { - internal_err!("Expects a sort expression") - } + } = e; + Ok(PhysicalSortExpr { + expr: create_physical_expr(expr, input_dfschema, execution_props)?, + options: SortOptions { + descending: !asc, + nulls_first: *nulls_first, + }, + }) } /// Create vector of physical sort expression from a vector of logical expression pub fn create_physical_sort_exprs( - exprs: &[Expr], + exprs: &[SortExpr], input_dfschema: &DFSchema, execution_props: &ExecutionProps, ) -> Result { diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index ca8376fdec0a8..4f93418d0c044 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -46,7 +46,7 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::TableReference; use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_expr::{CreateExternalTable, Expr, TableType}; +use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::{ expressions, AggregateExpr, EquivalenceProperties, PhysicalExpr, @@ -362,7 +362,7 @@ pub fn register_unbounded_file_with_ordering( schema: SchemaRef, file_path: &Path, table_name: &str, - file_sort_order: Vec>, + file_sort_order: Vec>, ) -> Result<()> { let source = FileStreamProvider::new_file(schema, file_path.into()); let config = StreamConfig::new(Arc::new(source)).with_order(file_sort_order); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 85ba80396c8e8..b81c02ccd0b75 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -289,10 +289,6 @@ pub enum Expr { /// Casts the expression to a given type and will return a null value if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. TryCast(TryCast), - /// A sort expression, that can be used to sort values. - /// - /// See [Expr::sort] for more details - Sort(Sort), /// Represents the call of a scalar function with a set of arguments. ScalarFunction(ScalarFunction), /// Calls an aggregate function with arguments, and optional @@ -633,6 +629,23 @@ impl Sort { } } +impl Display for Sort { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.expr)?; + if self.asc { + write!(f, " ASC")?; + } else { + write!(f, " DESC")?; + } + if self.nulls_first { + write!(f, " NULLS FIRST")?; + } else { + write!(f, " NULLS LAST")?; + } + Ok(()) + } +} + /// Aggregate function /// /// See also [`ExprFunctionExt`] to set these fields on `Expr` @@ -649,7 +662,7 @@ pub struct AggregateFunction { /// Optional filter pub filter: Option>, /// Optional ordering - pub order_by: Option>, + pub order_by: Option>, pub null_treatment: Option, } @@ -660,7 +673,7 @@ impl AggregateFunction { args: Vec, distinct: bool, filter: Option>, - order_by: Option>, + order_by: Option>, null_treatment: Option, ) -> Self { Self { @@ -785,7 +798,7 @@ pub struct WindowFunction { /// List of partition by expressions pub partition_by: Vec, /// List of order by expressions - pub order_by: Vec, + pub order_by: Vec, /// Window frame pub window_frame: window_frame::WindowFrame, /// Specifies how NULL value is treated: ignore or respect @@ -1141,7 +1154,6 @@ impl Expr { Expr::ScalarFunction(..) => "ScalarFunction", Expr::ScalarSubquery { .. } => "ScalarSubquery", Expr::ScalarVariable(..) => "ScalarVariable", - Expr::Sort { .. } => "Sort", Expr::TryCast { .. } => "TryCast", Expr::WindowFunction { .. } => "WindowFunction", Expr::Wildcard { .. } => "Wildcard", @@ -1227,14 +1239,9 @@ impl Expr { Expr::Like(Like::new(true, Box::new(self), Box::new(other), None, true)) } - /// Return the name to use for the specific Expr, recursing into - /// `Expr::Sort` as appropriate + /// Return the name to use for the specific Expr pub fn name_for_alias(&self) -> Result { - match self { - // call Expr::display_name() on a Expr::Sort will throw an error - Expr::Sort(Sort { expr, .. }) => expr.name_for_alias(), - expr => Ok(expr.schema_name().to_string()), - } + Ok(self.schema_name().to_string()) } /// Ensure `expr` has the name as `original_name` by adding an @@ -1250,14 +1257,7 @@ impl Expr { /// Return `self AS name` alias expression pub fn alias(self, name: impl Into) -> Expr { - match self { - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => Expr::Sort(Sort::new(Box::new(expr.alias(name)), asc, nulls_first)), - _ => Expr::Alias(Alias::new(self, None::<&str>, name.into())), - } + Expr::Alias(Alias::new(self, None::<&str>, name.into())) } /// Return `self AS name` alias expression with a specific qualifier @@ -1266,18 +1266,7 @@ impl Expr { relation: Option>, name: impl Into, ) -> Expr { - match self { - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => Expr::Sort(Sort::new( - Box::new(expr.alias_qualified(relation, name)), - asc, - nulls_first, - )), - _ => Expr::Alias(Alias::new(self, relation, name.into())), - } + Expr::Alias(Alias::new(self, relation, name.into())) } /// Remove an alias from an expression if one exists. @@ -1372,14 +1361,14 @@ impl Expr { Expr::IsNotNull(Box::new(self)) } - /// Create a sort expression from an existing expression. + /// Create a sort configuration from an existing expression. /// /// ``` /// # use datafusion_expr::col; /// let sort_expr = col("foo").sort(true, true); // SORT ASC NULLS_FIRST /// ``` - pub fn sort(self, asc: bool, nulls_first: bool) -> Expr { - Expr::Sort(Sort::new(Box::new(self), asc, nulls_first)) + pub fn sort(self, asc: bool, nulls_first: bool) -> Sort { + Sort::new(Box::new(self), asc, nulls_first) } /// Return `IsTrue(Box(self))` @@ -1655,7 +1644,6 @@ impl Expr { | Expr::Wildcard { .. } | Expr::WindowFunction(..) | Expr::Literal(..) - | Expr::Sort(..) | Expr::Placeholder(..) => false, } } @@ -1752,14 +1740,6 @@ impl Expr { }) => { data_type.hash(hasher); } - Expr::Sort(Sort { - expr: _expr, - asc, - nulls_first, - }) => { - asc.hash(hasher); - nulls_first.hash(hasher); - } Expr::ScalarFunction(ScalarFunction { func, args: _args }) => { func.hash(hasher); } @@ -1871,7 +1851,6 @@ impl<'a> Display for SchemaDisplay<'a> { Expr::Column(_) | Expr::Literal(_) | Expr::ScalarVariable(..) - | Expr::Sort(_) | Expr::OuterReferenceColumn(..) | Expr::Placeholder(_) | Expr::Wildcard { .. } => write!(f, "{}", self.0), @@ -1901,7 +1880,7 @@ impl<'a> Display for SchemaDisplay<'a> { }; if let Some(order_by) = order_by { - write!(f, " ORDER BY [{}]", schema_name_from_exprs(order_by)?)?; + write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?; }; Ok(()) @@ -2107,7 +2086,7 @@ impl<'a> Display for SchemaDisplay<'a> { } if !order_by.is_empty() { - write!(f, " ORDER BY [{}]", schema_name_from_exprs(order_by)?)?; + write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?; }; write!(f, " {window_frame}") @@ -2144,6 +2123,24 @@ fn schema_name_from_exprs_inner(exprs: &[Expr], sep: &str) -> Result Result { + let mut s = String::new(); + for (i, e) in sorts.iter().enumerate() { + if i > 0 { + write!(&mut s, ", ")?; + } + let ordering = if e.asc { "ASC" } else { "DESC" }; + let nulls_ordering = if e.nulls_first { + "NULLS FIRST" + } else { + "NULLS LAST" + }; + write!(&mut s, "{} {} {}", e.expr, ordering, nulls_ordering)?; + } + + Ok(s) +} + /// Format expressions for display as part of a logical plan. In many cases, this will produce /// similar output to `Expr.name()` except that column names will be prefixed with '#'. impl fmt::Display for Expr { @@ -2203,22 +2200,6 @@ impl fmt::Display for Expr { }) => write!(f, "{expr} IN ({subquery:?})"), Expr::ScalarSubquery(subquery) => write!(f, "({subquery:?})"), Expr::BinaryExpr(expr) => write!(f, "{expr}"), - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - if *asc { - write!(f, "{expr} ASC")?; - } else { - write!(f, "{expr} DESC")?; - } - if *nulls_first { - write!(f, " NULLS FIRST") - } else { - write!(f, " NULLS LAST") - } - } Expr::ScalarFunction(fun) => { fmt_function(f, fun.name(), false, &fun.args, true) } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 4e6022399653b..825874bccb842 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -26,9 +26,9 @@ use crate::function::{ StateFieldsArgs, }; use crate::{ - conditional_expressions::CaseBuilder, logical_plan::Subquery, AggregateUDF, Expr, - LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, Signature, - Volatility, + conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery, + AggregateUDF, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, + Signature, Volatility, }; use crate::{ AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl, @@ -723,9 +723,7 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { /// ``` pub trait ExprFunctionExt { /// Add `ORDER BY ` - /// - /// Note: `order_by` must be [`Expr::Sort`] - fn order_by(self, order_by: Vec) -> ExprFuncBuilder; + fn order_by(self, order_by: Vec) -> ExprFuncBuilder; /// Add `FILTER ` fn filter(self, filter: Expr) -> ExprFuncBuilder; /// Add `DISTINCT` @@ -753,7 +751,7 @@ pub enum ExprFuncKind { #[derive(Debug, Clone)] pub struct ExprFuncBuilder { fun: Option, - order_by: Option>, + order_by: Option>, filter: Option, distinct: bool, null_treatment: Option, @@ -798,16 +796,6 @@ impl ExprFuncBuilder { ); }; - if let Some(order_by) = &order_by { - for expr in order_by.iter() { - if !matches!(expr, Expr::Sort(_)) { - return plan_err!( - "ORDER BY expressions must be Expr::Sort, found {expr:?}" - ); - } - } - } - let fun_expr = match fun { ExprFuncKind::Aggregate(mut udaf) => { udaf.order_by = order_by; @@ -833,9 +821,7 @@ impl ExprFuncBuilder { impl ExprFunctionExt for ExprFuncBuilder { /// Add `ORDER BY ` - /// - /// Note: `order_by` must be [`Expr::Sort`] - fn order_by(mut self, order_by: Vec) -> ExprFuncBuilder { + fn order_by(mut self, order_by: Vec) -> ExprFuncBuilder { self.order_by = Some(order_by); self } @@ -873,7 +859,7 @@ impl ExprFunctionExt for ExprFuncBuilder { } impl ExprFunctionExt for Expr { - fn order_by(self, order_by: Vec) -> ExprFuncBuilder { + fn order_by(self, order_by: Vec) -> ExprFuncBuilder { let mut builder = match self { Expr::AggregateFunction(udaf) => { ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 768c4aabc840d..375ef1edf49a1 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -21,7 +21,7 @@ use std::collections::HashMap; use std::collections::HashSet; use std::sync::Arc; -use crate::expr::{Alias, Unnest}; +use crate::expr::{Alias, Sort, Unnest}; use crate::logical_plan::Projection; use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; @@ -117,6 +117,20 @@ pub fn normalize_cols( .collect() } +pub fn normalize_sorts( + sorts: impl IntoIterator>, + plan: &LogicalPlan, +) -> Result> { + sorts + .into_iter() + .map(|e| { + let sort = e.into(); + normalize_col(*sort.expr, plan) + .map(|expr| Sort::new(Box::new(expr), sort.asc, sort.nulls_first)) + }) + .collect() +} + /// Recursively replace all [`Column`] expressions in a given expression tree with /// `Column` expressions provided by the hash map argument. pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index bbb855801c3ea..af5b8c4f9177a 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -17,9 +17,9 @@ //! Rewrite for order by expressions -use crate::expr::{Alias, Sort}; +use crate::expr::Alias; use crate::expr_rewriter::normalize_col; -use crate::{Cast, Expr, ExprSchemable, LogicalPlan, TryCast}; +use crate::{expr::Sort, Cast, Expr, ExprSchemable, LogicalPlan, TryCast}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Column, Result}; @@ -27,28 +27,18 @@ use datafusion_common::{Column, Result}; /// Rewrite sort on aggregate expressions to sort on the column of aggregate output /// For example, `max(x)` is written to `col("max(x)")` pub fn rewrite_sort_cols_by_aggs( - exprs: impl IntoIterator>, + sorts: impl IntoIterator>, plan: &LogicalPlan, -) -> Result> { - exprs +) -> Result> { + sorts .into_iter() .map(|e| { - let expr = e.into(); - match expr { - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - let sort = Expr::Sort(Sort::new( - Box::new(rewrite_sort_col_by_aggs(*expr, plan)?), - asc, - nulls_first, - )); - Ok(sort) - } - expr => Ok(expr), - } + let sort = e.into(); + Ok(Sort::new( + Box::new(rewrite_sort_col_by_aggs(*sort.expr, plan)?), + sort.asc, + sort.nulls_first, + )) }) .collect() } @@ -289,8 +279,8 @@ mod test { struct TestCase { desc: &'static str, - input: Expr, - expected: Expr, + input: Sort, + expected: Sort, } impl TestCase { @@ -332,7 +322,7 @@ mod test { .unwrap() } - fn sort(expr: Expr) -> Expr { + fn sort(expr: Expr) -> Sort { let asc = true; let nulls_first = true; expr.sort(asc, nulls_first) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 10ec10e61239f..cb364f5309299 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -18,7 +18,7 @@ use super::{Between, Expr, Like}; use crate::expr::{ AggregateFunction, Alias, BinaryExpr, Cast, InList, InSubquery, Placeholder, - ScalarFunction, Sort, TryCast, Unnest, WindowFunction, + ScalarFunction, TryCast, Unnest, WindowFunction, }; use crate::type_coercion::binary::get_result_type; use crate::type_coercion::functions::{ @@ -107,7 +107,7 @@ impl ExprSchemable for Expr { }, _ => expr.get_type(schema), }, - Expr::Sort(Sort { expr, .. }) | Expr::Negative(expr) => expr.get_type(schema), + Expr::Negative(expr) => expr.get_type(schema), Expr::Column(c) => Ok(schema.data_type(c)?.clone()), Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), Expr::ScalarVariable(ty, _) => Ok(ty.clone()), @@ -280,10 +280,9 @@ impl ExprSchemable for Expr { /// column that does not exist in the schema. fn nullable(&self, input_schema: &dyn ExprSchema) -> Result { match self { - Expr::Alias(Alias { expr, .. }) - | Expr::Not(expr) - | Expr::Negative(expr) - | Expr::Sort(Sort { expr, .. }) => expr.nullable(input_schema), + Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::Negative(expr) => { + expr.nullable(input_schema) + } Expr::InList(InList { expr, list, .. }) => { // Avoid inspecting too many expressions. @@ -422,9 +421,7 @@ impl ExprSchemable for Expr { }, _ => expr.data_type_and_nullable(schema), }, - Expr::Sort(Sort { expr, .. }) | Expr::Negative(expr) => { - expr.data_type_and_nullable(schema) - } + Expr::Negative(expr) => expr.data_type_and_nullable(schema), Expr::Column(c) => schema .data_type_and_nullable(c) .map(|(d, n)| (d.clone(), n)), diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 559908bcfdfa4..5e06041a6f0a9 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -23,10 +23,10 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; use crate::dml::CopyTo; -use crate::expr::Alias; +use crate::expr::{Alias, Sort as SortExpr}; use crate::expr_rewriter::{ coerce_plan_expr_for_schema, normalize_col, - normalize_col_with_schemas_and_ambiguity_check, normalize_cols, + normalize_col_with_schemas_and_ambiguity_check, normalize_cols, normalize_sorts, rewrite_sort_cols_by_aggs, }; use crate::logical_plan::{ @@ -542,22 +542,34 @@ impl LogicalPlanBuilder { plan_err!("For SELECT DISTINCT, ORDER BY expressions {missing_col_names} must appear in select list") } + /// Apply a sort by provided expressions with default direction + pub fn sort_by( + self, + expr: impl IntoIterator> + Clone, + ) -> Result { + self.sort( + expr.into_iter() + .map(|e| e.into().sort(true, false)) + .collect::>(), + ) + } + /// Apply a sort pub fn sort( self, - exprs: impl IntoIterator> + Clone, + sorts: impl IntoIterator> + Clone, ) -> Result { - let exprs = rewrite_sort_cols_by_aggs(exprs, &self.plan)?; + let sorts = rewrite_sort_cols_by_aggs(sorts, &self.plan)?; let schema = self.plan.schema(); // Collect sort columns that are missing in the input plan's schema let mut missing_cols: Vec = vec![]; - exprs + sorts .clone() .into_iter() - .try_for_each::<_, Result<()>>(|expr| { - let columns = expr.column_refs(); + .try_for_each::<_, Result<()>>(|sort| { + let columns = sort.expr.column_refs(); columns.into_iter().for_each(|c| { if !schema.has_column(c) { @@ -570,7 +582,7 @@ impl LogicalPlanBuilder { if missing_cols.is_empty() { return Ok(Self::new(LogicalPlan::Sort(Sort { - expr: normalize_cols(exprs, &self.plan)?, + expr: normalize_sorts(sorts, &self.plan)?, input: self.plan, fetch: None, }))); @@ -583,7 +595,7 @@ impl LogicalPlanBuilder { let plan = Self::add_missing_columns(unwrap_arc(self.plan), &missing_cols, is_distinct)?; let sort_plan = LogicalPlan::Sort(Sort { - expr: normalize_cols(exprs, &plan)?, + expr: normalize_sorts(sorts, &plan)?, input: Arc::new(plan), fetch: None, }); @@ -619,7 +631,7 @@ impl LogicalPlanBuilder { self, on_expr: Vec, select_expr: Vec, - sort_expr: Option>, + sort_expr: Option>, ) -> Result { Ok(Self::new(LogicalPlan::Distinct(Distinct::On( DistinctOn::try_new(on_expr, select_expr, sort_expr, self.plan)?, diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index ad0fcd2d47712..3fc43200efe64 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -22,8 +22,9 @@ use std::{ hash::{Hash, Hasher}, }; -use crate::{Expr, LogicalPlan, Volatility}; +use crate::{Expr, LogicalPlan, SortExpr, Volatility}; +use crate::expr::Sort; use arrow::datatypes::DataType; use datafusion_common::{Constraints, DFSchemaRef, SchemaReference, TableReference}; use sqlparser::ast::Ident; @@ -204,7 +205,7 @@ pub struct CreateExternalTable { /// SQL used to create the table, if available pub definition: Option, /// Order expressions supplied by user - pub order_exprs: Vec>, + pub order_exprs: Vec>, /// Whether the table is an infinite streams pub unbounded: bool, /// Table(provider) specific options @@ -365,7 +366,7 @@ pub struct CreateIndex { pub name: Option, pub table: TableReference, pub using: Option, - pub columns: Vec, + pub columns: Vec, pub unique: bool, pub if_not_exists: bool, pub schema: DFSchemaRef, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 3ede7f25b753b..74f860fa91b4c 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -26,7 +26,9 @@ use super::dml::CopyTo; use super::DdlStatement; use crate::builder::{change_redundant_column, unnest_with_options}; use crate::expr::{Placeholder, Sort as SortExpr, WindowFunction}; -use crate::expr_rewriter::{create_col_from_scalar_expr, normalize_cols, NamePreserver}; +use crate::expr_rewriter::{ + create_col_from_scalar_expr, normalize_cols, normalize_sorts, NamePreserver, +}; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; @@ -885,11 +887,15 @@ impl LogicalPlan { Aggregate::try_new(Arc::new(inputs.swap_remove(0)), expr, agg_expr) .map(LogicalPlan::Aggregate) } - LogicalPlan::Sort(Sort { fetch, .. }) => Ok(LogicalPlan::Sort(Sort { - expr, - input: Arc::new(inputs.swap_remove(0)), - fetch: *fetch, - })), + LogicalPlan::Sort(Sort { /*fetch,*/ .. }) => { + // TODO FIXME + // LogicalPlan::Sort(Sort { fetch, .. }) => Ok(LogicalPlan::Sort(Sort { + // expr, + // input: Arc::new(inputs.swap_remove(0)), + // fetch: *fetch, + // })), + internal_err!("with_ew_exprs not implemented for Sort") + }, LogicalPlan::Join(Join { join_type, join_constraint, @@ -1015,14 +1021,11 @@ impl LogicalPlan { }) => { let sort_expr = expr.split_off(on_expr.len() + select_expr.len()); let select_expr = expr.split_off(on_expr.len()); + assert!(sort_expr.is_empty(), "with_new_exprs for Distinct does not support sort expressions"); Distinct::On(DistinctOn::try_new( expr, select_expr, - if !sort_expr.is_empty() { - Some(sort_expr) - } else { - None - }, + None, // no sort expressions accepted Arc::new(inputs.swap_remove(0)), )?) } @@ -2563,7 +2566,7 @@ pub struct DistinctOn { /// The `ORDER BY` clause, whose initial expressions must match those of the `ON` clause when /// present. Note that those matching expressions actually wrap the `ON` expressions with /// additional info pertaining to the sorting procedure (i.e. ASC/DESC, and NULLS FIRST/LAST). - pub sort_expr: Option>, + pub sort_expr: Option>, /// The logical plan that is being DISTINCT'd pub input: Arc, /// The schema description of the DISTINCT ON output @@ -2575,7 +2578,7 @@ impl DistinctOn { pub fn try_new( on_expr: Vec, select_expr: Vec, - sort_expr: Option>, + sort_expr: Option>, input: Arc, ) -> Result { if on_expr.is_empty() { @@ -2610,20 +2613,15 @@ impl DistinctOn { /// Try to update `self` with a new sort expressions. /// /// Validates that the sort expressions are a super-set of the `ON` expressions. - pub fn with_sort_expr(mut self, sort_expr: Vec) -> Result { - let sort_expr = normalize_cols(sort_expr, self.input.as_ref())?; + pub fn with_sort_expr(mut self, sort_expr: Vec) -> Result { + let sort_expr = normalize_sorts(sort_expr, self.input.as_ref())?; // Check that the left-most sort expressions are the same as the `ON` expressions. let mut matched = true; for (on, sort) in self.on_expr.iter().zip(sort_expr.iter()) { - match sort { - Expr::Sort(SortExpr { expr, .. }) => { - if on != &**expr { - matched = false; - break; - } - } - _ => return plan_err!("Not a sort expression: {sort}"), + if on != &*sort.expr { + matched = false; + break; } } @@ -2837,7 +2835,7 @@ fn calc_func_dependencies_for_project( #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Sort { /// The sort expressions - pub expr: Vec, + pub expr: Vec, /// The incoming logical plan pub input: Arc, /// Optional fetch limit diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 539cb1cf5fb22..297e8e4a01c02 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -46,7 +46,7 @@ use crate::{ use std::sync::Arc; use crate::expr::{Exists, InSubquery}; -use crate::tree_node::transform_option_vec; +use crate::tree_node::{transform_sort_option_vec, transform_sort_vec}; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, @@ -490,7 +490,9 @@ impl LogicalPlan { .apply_until_stop(|e| f(&e))? .visit_sibling(|| filter.iter().apply_until_stop(f)) } - LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().apply_until_stop(f), + LogicalPlan::Sort(Sort { expr, .. }) => { + expr.iter().apply_until_stop(|sort| f(&*sort.expr)) + } LogicalPlan::Extension(extension) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs @@ -516,7 +518,7 @@ impl LogicalPlan { })) => on_expr .iter() .chain(select_expr.iter()) - .chain(sort_expr.iter().flatten()) + .chain(sort_expr.iter().flatten().map(|sort| &*sort.expr)) .apply_until_stop(f), // plans without expressions LogicalPlan::EmptyRelation(_) @@ -667,10 +669,10 @@ impl LogicalPlan { null_equals_null, }) }), - LogicalPlan::Sort(Sort { expr, input, fetch }) => expr - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })), + LogicalPlan::Sort(Sort { expr, input, fetch }) => { + transform_sort_vec(expr, &mut f)? + .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })) + } LogicalPlan::Extension(Extension { node }) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs @@ -718,7 +720,7 @@ impl LogicalPlan { select_expr, select_expr.into_iter().map_until_stop_and_collect(&mut f), sort_expr, - transform_option_vec(sort_expr, &mut f) + transform_sort_option_vec(sort_expr, &mut f) )? .update_data(|(on_expr, select_expr, sort_expr)| { LogicalPlan::Distinct(Distinct::On(DistinctOn { diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 450ebb6c22752..472e3089cecd5 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -48,7 +48,6 @@ impl TreeNode for Expr { | Expr::Negative(expr) | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) - | Expr::Sort(Sort { expr, .. }) | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref()], Expr::GroupingSet(GroupingSet::Rollup(exprs)) | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter().collect(), @@ -98,7 +97,7 @@ impl TreeNode for Expr { expr_vec.push(f.as_ref()); } if let Some(order_by) = order_by { - expr_vec.extend(order_by); + expr_vec.extend(order_by.into_iter().map(|sort| sort.expr.as_ref())); } expr_vec } @@ -110,7 +109,7 @@ impl TreeNode for Expr { }) => { let mut expr_vec = args.iter().collect::>(); expr_vec.extend(partition_by); - expr_vec.extend(order_by); + expr_vec.extend(order_by.into_iter().map(|sort| sort.expr.as_ref())); expr_vec } Expr::InList(InList { expr, list, .. }) => { @@ -265,12 +264,6 @@ impl TreeNode for Expr { .update_data(|be| Expr::Cast(Cast::new(be, data_type))), Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr, &mut f)? .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))), - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => transform_box(expr, &mut f)? - .update_data(|be| Expr::Sort(Sort::new(be, asc, nulls_first))), Expr::ScalarFunction(ScalarFunction { func, args }) => { transform_vec(args, &mut f)?.map_data(|new_args| { Ok(Expr::ScalarFunction(ScalarFunction::new_udf( @@ -290,7 +283,7 @@ impl TreeNode for Expr { partition_by, transform_vec(partition_by, &mut f), order_by, - transform_vec(order_by, &mut f) + transform_sort_vec(order_by, &mut f) )? .update_data(|(new_args, new_partition_by, new_order_by)| { Expr::WindowFunction(WindowFunction::new(fun, new_args)) @@ -313,7 +306,7 @@ impl TreeNode for Expr { filter, transform_option_box(filter, &mut f), order_by, - transform_option_vec(order_by, &mut f) + transform_sort_option_vec(order_by, &mut f) )? .map_data(|(new_args, new_filter, new_order_by)| { Ok(Expr::AggregateFunction(AggregateFunction::new_udf( @@ -386,3 +379,45 @@ fn transform_vec Result>>( ) -> Result>> { ve.into_iter().map_until_stop_and_collect(f) } + +pub fn transform_sort_option_vec Result>>( + sorts_option: Option>, + f: &mut F, +) -> Result>>> { + sorts_option.map_or(Ok(Transformed::no(None)), |sorts| { + Ok(transform_sort_vec(sorts, f)?.update_data(Some)) + }) +} + +pub fn transform_sort_vec Result>>( + sorts: Vec, + mut f: &mut F, +) -> Result>> { + Ok(sorts + .iter() + .map(|sort| (*sort.expr).clone()) + .map_until_stop_and_collect(&mut f)? + .update_data(|transformed_exprs| { + replace_sort_expressions(sorts, transformed_exprs) + })) +} + +pub fn replace_sort_expressions(sorts: Vec, new_expr: Vec) -> Vec { + if sorts.len() != new_expr.len() { + panic!( + "Incorrect number of new_expr, expected {}, got {}", + sorts.len(), + new_expr.len() + ); + } + + let mut new_sorts = Vec::with_capacity(sorts.len()); + for (i, expr) in new_expr.into_iter().enumerate() { + new_sorts.push(Sort { + expr: Box::new(expr), + asc: sorts[i].asc, + nulls_first: sorts[i].nulls_first, + }); + } + new_sorts +} diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 5a1e476841636..7818ce587588a 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -296,7 +296,6 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::Case { .. } | Expr::Cast { .. } | Expr::TryCast { .. } - | Expr::Sort { .. } | Expr::ScalarFunction(..) | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } @@ -461,22 +460,20 @@ pub fn expand_qualified_wildcard( /// (expr, "is the SortExpr for window (either comes from PARTITION BY or ORDER BY columns)") /// if bool is true SortExpr comes from `PARTITION BY` column, if false comes from `ORDER BY` column -type WindowSortKey = Vec<(Expr, bool)>; +type WindowSortKey = Vec<(Sort, bool)>; /// Generate a sort key for a given window expr's partition_by and order_bu expr pub fn generate_sort_key( partition_by: &[Expr], - order_by: &[Expr], + order_by: &[Sort], ) -> Result { let normalized_order_by_keys = order_by .iter() - .map(|e| match e { - Expr::Sort(Sort { expr, .. }) => { - Ok(Expr::Sort(Sort::new(expr.clone(), true, false))) - } - _ => plan_err!("Order by only accepts sort expressions"), + .map(|e| { + let Sort { expr, .. } = e; + Sort::new(expr.clone(), true, false) }) - .collect::>>()?; + .collect::>(); let mut final_sort_keys = vec![]; let mut is_partition_flag = vec![]; @@ -512,65 +509,61 @@ pub fn generate_sort_key( /// Compare the sort expr as PostgreSQL's common_prefix_cmp(): /// pub fn compare_sort_expr( - sort_expr_a: &Expr, - sort_expr_b: &Expr, + sort_expr_a: &Sort, + sort_expr_b: &Sort, schema: &DFSchemaRef, ) -> Ordering { - match (sort_expr_a, sort_expr_b) { - ( - Expr::Sort(Sort { - expr: expr_a, - asc: asc_a, - nulls_first: nulls_first_a, - }), - Expr::Sort(Sort { - expr: expr_b, - asc: asc_b, - nulls_first: nulls_first_b, - }), - ) => { - let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema); - let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema); - for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) { - match idx_a.cmp(idx_b) { - Ordering::Less => { - return Ordering::Less; - } - Ordering::Greater => { - return Ordering::Greater; - } - Ordering::Equal => {} - } + let Sort { + expr: expr_a, + asc: asc_a, + nulls_first: nulls_first_a, + } = sort_expr_a; + + let Sort { + expr: expr_b, + asc: asc_b, + nulls_first: nulls_first_b, + } = sort_expr_b; + + let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema); + let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema); + for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) { + match idx_a.cmp(idx_b) { + Ordering::Less => { + return Ordering::Less; } - match ref_indexes_a.len().cmp(&ref_indexes_b.len()) { - Ordering::Less => return Ordering::Greater, - Ordering::Greater => { - return Ordering::Less; - } - Ordering::Equal => {} + Ordering::Greater => { + return Ordering::Greater; } - match (asc_a, asc_b) { - (true, false) => { - return Ordering::Greater; - } - (false, true) => { - return Ordering::Less; - } - _ => {} - } - match (nulls_first_a, nulls_first_b) { - (true, false) => { - return Ordering::Less; - } - (false, true) => { - return Ordering::Greater; - } - _ => {} - } - Ordering::Equal + Ordering::Equal => {} } - _ => panic!("Sort expressions must be of type Sort"), } + match ref_indexes_a.len().cmp(&ref_indexes_b.len()) { + Ordering::Less => return Ordering::Greater, + Ordering::Greater => { + return Ordering::Less; + } + Ordering::Equal => {} + } + match (asc_a, asc_b) { + (true, false) => { + return Ordering::Greater; + } + (false, true) => { + return Ordering::Less; + } + _ => {} + } + match (nulls_first_a, nulls_first_b) { + (true, false) => { + return Ordering::Less; + } + (false, true) => { + return Ordering::Greater; + } + _ => {} + } + Ordering::Equal } /// group a slice of window expression expr by their order by expressions @@ -606,14 +599,6 @@ pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec { }) } -/// Collect all deeply nested `Expr::Sort`. They are returned in order of occurrence -/// (depth first), with duplicates omitted. -pub fn find_sort_exprs(exprs: &[Expr]) -> Vec { - find_exprs_in_exprs(exprs, &|nested_expr| { - matches!(nested_expr, Expr::Sort { .. }) - }) -} - /// Collect all deeply nested `Expr::WindowFunction`. They are returned in order of occurrence /// (depth first), with duplicates omitted. pub fn find_window_exprs(exprs: &[Expr]) -> Vec { @@ -1417,10 +1402,9 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys() -> Result<()> { - let age_asc = Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)); - let name_desc = Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)); - let created_at_desc = - Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)); + let age_asc = expr::Sort::new(Box::new(col("age")), true, true); + let name_desc = expr::Sort::new(Box::new(col("name")), false, true); + let created_at_desc = expr::Sort::new(Box::new(col("created_at")), false, true); let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], @@ -1471,43 +1455,6 @@ mod tests { Ok(()) } - #[test] - fn test_find_sort_exprs() -> Result<()> { - let exprs = &[ - Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateUDF(max_udaf()), - vec![col("name")], - )) - .order_by(vec![ - Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - ]) - .window_frame(WindowFrame::new(Some(false))) - .build() - .unwrap(), - Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateUDF(sum_udaf()), - vec![col("age")], - )) - .order_by(vec![ - Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)), - ]) - .window_frame(WindowFrame::new(Some(false))) - .build() - .unwrap(), - ]; - let expected = vec![ - Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)), - ]; - let result = find_sort_exprs(exprs); - assert_eq!(expected, result); - Ok(()) - } - #[test] fn avoid_generate_duplicate_sort_keys() -> Result<()> { let asc_or_desc = [true, false]; @@ -1516,41 +1463,41 @@ mod tests { for asc_ in asc_or_desc { for nulls_first_ in nulls_first_or_last { let order_by = &[ - Expr::Sort(Sort { + Sort { expr: Box::new(col("age")), asc: asc_, nulls_first: nulls_first_, - }), - Expr::Sort(Sort { + }, + Sort { expr: Box::new(col("name")), asc: asc_, nulls_first: nulls_first_, - }), + }, ]; let expected = vec![ ( - Expr::Sort(Sort { + Sort { expr: Box::new(col("age")), asc: asc_, nulls_first: nulls_first_, - }), + }, true, ), ( - Expr::Sort(Sort { + Sort { expr: Box::new(col("name")), asc: asc_, nulls_first: nulls_first_, - }), + }, true, ), ( - Expr::Sort(Sort { + Sort { expr: Box::new(col("created_at")), asc: true, nulls_first: false, - }), + }, true, ), ]; diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 0e1d917419f8d..6c935cdcd1217 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -26,7 +26,7 @@ use std::fmt::{self, Formatter}; use std::hash::Hash; -use crate::{lit, Expr}; +use crate::{expr::Sort, lit}; use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue}; use sqlparser::ast; @@ -247,7 +247,7 @@ impl WindowFrame { } /// Regularizes the ORDER BY clause of the window frame. - pub fn regularize_order_bys(&self, order_by: &mut Vec) -> Result<()> { + pub fn regularize_order_bys(&self, order_by: &mut Vec) -> Result<()> { match self.units { // Normally, RANGE frames require an ORDER BY clause with exactly // one column. However, an ORDER BY clause may be absent or have diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 2162442f054ed..30f5d5b07561b 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -32,7 +32,7 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Expr, ExprFunctionExt, - Signature, TypeSignature, Volatility, + Signature, SortExpr, TypeSignature, Volatility, }; use datafusion_functions_aggregate_common::utils::get_sort_options; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; @@ -40,7 +40,7 @@ use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; create_func!(FirstValue, first_value_udaf); /// Returns the first value in a group of values. -pub fn first_value(expression: Expr, order_by: Option>) -> Expr { +pub fn first_value(expression: Expr, order_by: Option>) -> Expr { if let Some(order_by) = order_by { first_value_udaf() .call(vec![expression]) diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index e114efb99960e..35d4f91e3b6f0 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -229,7 +229,7 @@ mod tests { WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], )) - .order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]) + .order_by(vec![Sort::new(Box::new(col("a")), false, true)]) .window_frame(WindowFrame::new_bounds( WindowFrameUnits::Range, WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 68ab2e13005f3..344c9ca1e1052 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -33,7 +33,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::{ self, Alias, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, - ScalarFunction, WindowFunction, + ScalarFunction, Sort, WindowFunction, }; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::expr_schema::cast_subquery; @@ -491,7 +491,6 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { | Expr::Negative(_) | Expr::Cast(_) | Expr::TryCast(_) - | Expr::Sort(_) | Expr::Wildcard { .. } | Expr::GroupingSet(_) | Expr::Placeholder(_) @@ -578,12 +577,12 @@ fn coerce_frame_bound( fn coerce_window_frame( window_frame: WindowFrame, schema: &DFSchema, - expressions: &[Expr], + expressions: &[Sort], ) -> Result { let mut window_frame = window_frame; let current_types = expressions .iter() - .map(|e| e.get_type(schema)) + .map(|s| s.expr.get_type(schema)) .collect::>>()?; let target_type = match window_frame.units { WindowFrameUnits::Range => { diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index feccf5679efbc..7d9f1eedc13a7 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -37,6 +37,7 @@ use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; +use datafusion_expr::tree_node::replace_sort_expressions; use datafusion_expr::{col, BinaryExpr, Case, Expr, ExprSchemable, Operator}; use indexmap::IndexMap; @@ -328,15 +329,17 @@ impl CommonSubexprEliminate { ) -> Result> { let Sort { expr, input, fetch } = sort; let input = unwrap_arc(input); - let new_sort = self.try_unary_plan(expr, input, config)?.update_data( - |(new_expr, new_input)| { + let sort_expressions = + expr.iter().map(|sort| sort.expr.as_ref().clone()).collect(); + let new_sort = self + .try_unary_plan(sort_expressions, input, config)? + .update_data(|(new_expr, new_input)| { LogicalPlan::Sort(Sort { - expr: new_expr, + expr: replace_sort_expressions(expr, new_expr), input: Arc::new(new_input), fetch, }) - }, - ); + }); Ok(new_sort) } @@ -900,7 +903,6 @@ impl ExprMask { | Expr::Column(..) | Expr::ScalarVariable(..) | Expr::Alias(..) - | Expr::Sort { .. } | Expr::Wildcard { .. } ); diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index e9d091d52b00b..201aea6bf2b8f 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -22,9 +22,8 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::Result; use datafusion_expr::logical_plan::LogicalPlan; -use datafusion_expr::{Aggregate, Expr, Sort}; -use indexmap::IndexSet; -use std::hash::{Hash, Hasher}; +use datafusion_expr::{Aggregate, Expr, Sort, SortExpr}; +use indexmap::{IndexMap, IndexSet}; /// Optimization rule that eliminate duplicated expr. #[derive(Default)] pub struct EliminateDuplicatedExpr; @@ -35,33 +34,6 @@ impl EliminateDuplicatedExpr { Self {} } } -// use this structure to avoid initial clone -#[derive(Eq, Clone, Debug)] -struct SortExprWrapper { - expr: Expr, -} -impl PartialEq for SortExprWrapper { - fn eq(&self, other: &Self) -> bool { - match (&self.expr, &other.expr) { - (Expr::Sort(own_sort), Expr::Sort(other_sort)) => { - own_sort.expr == other_sort.expr - } - _ => self.expr == other.expr, - } - } -} -impl Hash for SortExprWrapper { - fn hash(&self, state: &mut H) { - match &self.expr { - Expr::Sort(sort) => { - sort.expr.hash(state); - } - _ => { - self.expr.hash(state); - } - } - } -} impl OptimizerRule for EliminateDuplicatedExpr { fn apply_order(&self) -> Option { Some(ApplyOrder::TopDown) @@ -79,14 +51,15 @@ impl OptimizerRule for EliminateDuplicatedExpr { match plan { LogicalPlan::Sort(sort) => { let len = sort.expr.len(); - let unique_exprs: Vec<_> = sort - .expr - .into_iter() - .map(|e| SortExprWrapper { expr: e }) - .collect::>() - .into_iter() - .map(|wrapper| wrapper.expr) - .collect(); + let mut first_sort_by_expr: IndexMap = + IndexMap::default(); + for s in &sort.expr { + first_sort_by_expr + .entry(s.expr.as_ref().clone()) + .or_insert(s.clone()); + } + let unique_exprs: Vec = + first_sort_by_expr.into_values().collect(); let transformed = if len != unique_exprs.len() { Transformed::yes @@ -146,7 +119,7 @@ mod tests { fn eliminate_sort_expr() -> Result<()> { let table_scan = test_table_scan().unwrap(); let plan = LogicalPlanBuilder::from(table_scan) - .sort(vec![col("a"), col("a"), col("b"), col("c")])? + .sort_by(vec![col("a"), col("a"), col("b"), col("c")])? .limit(5, Some(10))? .build()?; let expected = "Limit: skip=5, fetch=10\ diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index a42fe6a6f95b7..516a7afb738b7 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -180,7 +180,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("a")], vec![sum(col("b"))])? .limit(0, Some(2))? - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .limit(2, Some(1))? .build()?; @@ -200,7 +200,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("a")], vec![sum(col("b"))])? .limit(0, Some(2))? - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .limit(0, Some(1))? .build()?; @@ -218,7 +218,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("a")], vec![sum(col("b"))])? .limit(2, Some(1))? - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .limit(3, Some(1))? .build()?; diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 8455919c35a83..58e67990dabe9 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -285,8 +285,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::TryCast(_) | Expr::InList { .. } | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue), - Expr::Sort(_) - | Expr::AggregateFunction(_) + Expr::AggregateFunction(_) | Expr::WindowFunction(_) | Expr::Wildcard { .. } | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"), diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 290b893577b82..fb68a79c70c2e 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -348,7 +348,7 @@ mod test { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .limit(0, Some(10))? .build()?; @@ -365,7 +365,7 @@ mod test { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .limit(5, Some(10))? .build()?; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index c45df74a564da..8997a711aace7 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -591,7 +591,6 @@ impl<'a> ConstEvaluator<'a> { | Expr::InSubquery(_) | Expr::ScalarSubquery(_) | Expr::WindowFunction { .. } - | Expr::Sort { .. } | Expr::GroupingSet(_) | Expr::Wildcard { .. } | Expr::Placeholder(_) => false, diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 30cae17eaf9f8..0ffa075016763 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -624,7 +624,7 @@ mod tests { vec![col("a")], false, None, - Some(vec![col("a")]), + Some(vec![col("a").sort(true, false)]), None, )); let plan = LogicalPlanBuilder::from(table_scan) diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 875d2af75dd78..aff374d1b87b7 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -521,8 +521,6 @@ pub mod logical_expr_node { Case(::prost::alloc::boxed::Box), #[prost(message, tag = "11")] Cast(::prost::alloc::boxed::Box), - #[prost(message, tag = "12")] - Sort(::prost::alloc::boxed::Box), #[prost(message, tag = "13")] Negative(::prost::alloc::boxed::Box), #[prost(message, tag = "14")] @@ -762,7 +760,7 @@ pub struct WindowExprNode { #[prost(message, repeated, tag = "5")] pub partition_by: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "6")] - pub order_by: ::prost::alloc::vec::Vec, + pub order_by: ::prost::alloc::vec::Vec, /// repeated LogicalExprNode filter = 7; #[prost(message, optional, tag = "8")] pub window_frame: ::core::option::Option, diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 72e08e4b8fb52..cfc10fbc39d1a 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -283,22 +283,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let func_deps = schema.functional_dependencies(); // Find whether ties are possible in the given ordering let is_ordering_strict = order_by.iter().find_map(|orderby_expr| { - if let Expr::Sort(sort_expr) = orderby_expr { - if let Expr::Column(col) = sort_expr.expr.as_ref() { - let idx = schema.index_of_column(col).ok()?; - return if func_deps.iter().any(|dep| { - dep.source_indices == vec![idx] - && dep.mode == Dependency::Single - }) { - Some(true) - } else { - Some(false) - }; - } - Some(false) - } else { - panic!("order_by expression must be of type Sort"); + if let Expr::Column(col) = orderby_expr.expr.as_ref() { + let idx = schema.index_of_column(col).ok()?; + return if func_deps.iter().any(|dep| { + dep.source_indices == vec![idx] && dep.mode == Dependency::Single + }) { + Some(true) + } else { + Some(false) + }; } + Some(false) }); let window_frame = window diff --git a/datafusion/sql/src/expr/order_by.rs b/datafusion/sql/src/expr/order_by.rs index 7fb32f714cfa6..cdaa787cedd0d 100644 --- a/datafusion/sql/src/expr/order_by.rs +++ b/datafusion/sql/src/expr/order_by.rs @@ -20,7 +20,7 @@ use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, Result, }; use datafusion_expr::expr::Sort; -use datafusion_expr::Expr; +use datafusion_expr::{Expr, SortExpr}; use sqlparser::ast::{Expr as SQLExpr, OrderByExpr, Value}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -44,7 +44,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context: &mut PlannerContext, literal_to_column: bool, additional_schema: Option<&DFSchema>, - ) -> Result> { + ) -> Result> { if exprs.is_empty() { return Ok(vec![]); } @@ -99,13 +99,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }; let asc = asc.unwrap_or(true); - expr_vec.push(Expr::Sort(Sort::new( + expr_vec.push(Sort::new( Box::new(expr), asc, // when asc is true, by default nulls last to be consistent with postgres // postgres rule: https://www.postgresql.org/docs/current/queries-order.html nulls_first.unwrap_or(!asc), - ))) + )) } Ok(expr_vec) } diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index ba2b41bb6ecff..71328cfd018c3 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{not_impl_err, plan_err, Constraints, Result, ScalarValue}; +use datafusion_expr::expr::Sort; use datafusion_expr::{ CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, Operator, @@ -119,7 +120,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn order_by( &self, plan: LogicalPlan, - order_by: Vec, + order_by: Vec, ) -> Result { if order_by.is_empty() { return Ok(plan); diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 45fda094557b0..6f7c65ffa7f87 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -31,7 +31,7 @@ use datafusion_common::UnnestOptions; use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_expr::expr::{Alias, PlannedReplaceSelectItem, WildcardOptions}; use datafusion_expr::expr_rewriter::{ - normalize_col, normalize_col_with_schemas_and_ambiguity_check, normalize_cols, + normalize_col, normalize_col_with_schemas_and_ambiguity_check, normalize_sorts, }; use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::utils::{ @@ -108,7 +108,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { true, Some(base_plan.schema().as_ref()), )?; - let order_by_rex = normalize_cols(order_by_rex, &projected_plan)?; + let order_by_rex = normalize_sorts(order_by_rex, &projected_plan)?; // this alias map is resolved and looked up in both having exprs and group by exprs let alias_map = extract_aliases(&select_exprs); diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index e75a96e78d483..3dfc379b039a8 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -48,9 +48,10 @@ use datafusion_expr::{ CreateIndex as PlanCreateIndex, CreateMemoryTable, CreateView, DescribeTable, DmlStatement, DropCatalogSchema, DropFunction, DropTable, DropView, EmptyRelation, Explain, Expr, ExprSchemable, Filter, LogicalPlan, LogicalPlanBuilder, - OperateFunctionArg, PlanType, Prepare, SetVariable, Statement as PlanStatement, - ToStringifiedPlan, TransactionAccessMode, TransactionConclusion, TransactionEnd, - TransactionIsolationLevel, TransactionStart, Volatility, WriteOp, + OperateFunctionArg, PlanType, Prepare, SetVariable, SortExpr, + Statement as PlanStatement, ToStringifiedPlan, TransactionAccessMode, + TransactionConclusion, TransactionEnd, TransactionIsolationLevel, TransactionStart, + Volatility, WriteOp, }; use sqlparser::ast; use sqlparser::ast::{ @@ -952,7 +953,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { order_exprs: Vec, schema: &DFSchemaRef, planner_context: &mut PlannerContext, - ) -> Result>> { + ) -> Result>> { // Ask user to provide a schema if schema is empty. if !order_exprs.is_empty() && schema.fields().is_empty() { return plan_err!( @@ -966,8 +967,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let expr_vec = self.order_by_to_sort_expr(expr, schema, planner_context, true, None)?; // Verify that columns of all SortExprs exist in the schema: - for expr in expr_vec.iter() { - for column in expr.column_refs().iter() { + for sort in expr_vec.iter() { + for column in sort.expr.column_refs().iter() { if !schema.has_column(column) { // Return an error if any column is not in the schema: return plan_err!("Column {column} is not in schema"); diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 9ce627aecc760..cde301227f537 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -use core::fmt; - use datafusion_expr::ScalarUDF; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ @@ -24,7 +22,7 @@ use sqlparser::ast::{ ObjectName, TimezoneInfo, UnaryOperator, }; use std::sync::Arc; -use std::{fmt::Display, vec}; +use std::vec; use super::dialect::{DateFieldExtractStyle, IntervalStyle}; use super::Unparser; @@ -46,33 +44,6 @@ use datafusion_expr::{ Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Operator, TryCast, }; -/// DataFusion's Exprs can represent either an `Expr` or an `OrderByExpr` -pub enum Unparsed { - // SQL Expression - Expr(ast::Expr), - // SQL ORDER BY expression (e.g. `col ASC NULLS FIRST`) - OrderByExpr(ast::OrderByExpr), -} - -impl Unparsed { - pub fn into_order_by_expr(self) -> Result { - if let Unparsed::OrderByExpr(order_by_expr) = self { - Ok(order_by_expr) - } else { - internal_err!("Expected Sort expression to be converted an OrderByExpr") - } - } -} - -impl Display for Unparsed { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Unparsed::Expr(expr) => write!(f, "{}", expr), - Unparsed::OrderByExpr(order_by_expr) => write!(f, "{}", order_by_expr), - } - } -} - /// Convert a DataFusion [`Expr`] to [`ast::Expr`] /// /// This function is the opposite of [`SqlToRel::sql_to_expr`] and can be used @@ -106,13 +77,9 @@ pub fn expr_to_sql(expr: &Expr) -> Result { unparser.expr_to_sql(expr) } -/// Convert a DataFusion [`Expr`] to [`Unparsed`] -/// -/// This function is similar to expr_to_sql, but it supports converting more [`Expr`] types like -/// `Sort` expressions to `OrderByExpr` expressions. -pub fn expr_to_unparsed(expr: &Expr) -> Result { +pub fn sort_to_sql(sort: &Sort) -> Result { let unparser = Unparser::default(); - unparser.expr_to_unparsed(expr) + unparser.sort_to_sql(sort) } const LOWEST: &BinaryOperator = &BinaryOperator::Or; @@ -286,7 +253,7 @@ impl Unparser<'_> { }; let order_by: Vec = order_by .iter() - .map(|expr| expr_to_unparsed(expr)?.into_order_by_expr()) + .map(|expr| sort_to_sql(expr)) .collect::>>()?; let start_bound = self.convert_bound(&window_frame.start_bound)?; @@ -413,11 +380,6 @@ impl Unparser<'_> { negated: *negated, }) } - Expr::Sort(Sort { - expr: _, - asc: _, - nulls_first: _, - }) => plan_err!("Sort expression should be handled by expr_to_unparsed"), Expr::IsNull(expr) => { Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql_inner(expr)?))) } @@ -534,36 +496,26 @@ impl Unparser<'_> { } } - /// This function can convert more [`Expr`] types than `expr_to_sql`, - /// returning an [`Unparsed`] like `Sort` expressions to `OrderByExpr` - /// expressions. - pub fn expr_to_unparsed(&self, expr: &Expr) -> Result { - match expr { - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - let sql_parser_expr = self.expr_to_sql(expr)?; + pub fn sort_to_sql(&self, sort: &Sort) -> Result { + let Sort { + expr, + asc, + nulls_first, + } = sort; + let sql_parser_expr = self.expr_to_sql(expr)?; - let nulls_first = if self.dialect.supports_nulls_first_in_sort() { - Some(*nulls_first) - } else { - None - }; + let nulls_first = if self.dialect.supports_nulls_first_in_sort() { + Some(*nulls_first) + } else { + None + }; - Ok(Unparsed::OrderByExpr(ast::OrderByExpr { - expr: sql_parser_expr, - asc: Some(*asc), - nulls_first, - with_fill: None, - })) - } - _ => { - let sql_parser_expr = self.expr_to_sql(expr)?; - Ok(Unparsed::Expr(sql_parser_expr)) - } - } + Ok(ast::OrderByExpr { + expr: sql_parser_expr, + asc: Some(*asc), + nulls_first, + with_fill: None, + }) } fn scalar_function_to_sql_overrides( @@ -1941,24 +1893,6 @@ mod tests { Ok(()) } - #[test] - fn expr_to_unparsed_ok() -> Result<()> { - let tests: Vec<(Expr, &str)> = vec![ - ((col("a") + col("b")).gt(lit(4)), r#"((a + b) > 4)"#), - (col("a").sort(true, true), r#"a ASC NULLS FIRST"#), - ]; - - for (expr, expected) in tests { - let ast = expr_to_unparsed(&expr)?; - - let actual = format!("{}", ast); - - assert_eq!(actual, expected); - } - - Ok(()) - } - #[test] fn custom_dialect_with_identifier_quote_style() -> Result<()> { let dialect = CustomDialectBuilder::new() @@ -2047,7 +1981,7 @@ mod tests { #[test] fn customer_dialect_support_nulls_first_in_ort() -> Result<()> { - let tests: Vec<(Expr, &str, bool)> = vec![ + let tests: Vec<(Sort, &str, bool)> = vec![ (col("a").sort(true, true), r#"a ASC NULLS FIRST"#, true), (col("a").sort(true, true), r#"a ASC"#, false), ]; @@ -2057,7 +1991,7 @@ mod tests { .with_supports_nulls_first_in_sort(supports_nulls_first_in_sort) .build(); let unparser = Unparser::new(&dialect); - let ast = unparser.expr_to_unparsed(&expr)?; + let ast = unparser.sort_to_sql(&expr)?; let actual = format!("{}", ast); diff --git a/datafusion/sql/src/unparser/mod.rs b/datafusion/sql/src/unparser/mod.rs index b2fd32566aa84..83ae64ba238b0 100644 --- a/datafusion/sql/src/unparser/mod.rs +++ b/datafusion/sql/src/unparser/mod.rs @@ -29,8 +29,6 @@ pub use plan::plan_to_sql; use self::dialect::{DefaultDialect, Dialect}; pub mod dialect; -pub use expr::Unparsed; - /// Convert a DataFusion [`Expr`] to [`sqlparser::ast::Expr`] /// /// See [`expr_to_sql`] for background. `Unparser` allows greater control of diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 106705c322fc1..509c5dd52cd4a 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -15,11 +15,10 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{ - internal_err, not_impl_err, plan_err, Column, DataFusionError, Result, -}; +use datafusion_common::{internal_err, not_impl_err, Column, DataFusionError, Result}; use datafusion_expr::{ expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, Projection, + SortExpr, }; use sqlparser::ast::{self, Ident, SetExpr}; @@ -318,7 +317,7 @@ impl Unparser<'_> { return self.derive(plan, relation); } if let Some(query_ref) = query { - query_ref.order_by(self.sort_to_sql(sort.expr.clone())?); + query_ref.order_by(self.sorts_to_sql(sort.expr.clone())?); } else { return internal_err!( "Sort operator only valid in a statement context." @@ -361,7 +360,7 @@ impl Unparser<'_> { .collect::>>()?; if let Some(sort_expr) = &on.sort_expr { if let Some(query_ref) = query { - query_ref.order_by(self.sort_to_sql(sort_expr.clone())?); + query_ref.order_by(self.sorts_to_sql(sort_expr.clone())?); } else { return internal_err!( "Sort operator only valid in a statement context." @@ -525,14 +524,10 @@ impl Unparser<'_> { } } - fn sort_to_sql(&self, sort_exprs: Vec) -> Result> { + fn sorts_to_sql(&self, sort_exprs: Vec) -> Result> { sort_exprs .iter() - .map(|expr: &Expr| { - self.expr_to_unparsed(expr)? - .into_order_by_expr() - .or(plan_err!("Expecting Sort expr")) - }) + .map(|sort_expr| self.sort_to_sql(sort_expr)) .collect::>>() } diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index ffa1a60fe62ae..715d3aca68044 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -21,10 +21,10 @@ use std::{ }; use datafusion_common::{ - tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeIterator}, + tree_node::{Transformed, TransformedResult, TreeNode}, Result, }; -use datafusion_expr::{Expr, LogicalPlan, Projection, Sort}; +use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr}; use sqlparser::ast::Ident; /// Normalize the schema of a union plan to remove qualifiers from the schema fields and sort expressions. @@ -86,22 +86,48 @@ pub(super) fn normalize_union_schema(plan: &LogicalPlan) -> Result } /// Rewrite sort expressions that have a UNION plan as their input to remove the table reference. -fn rewrite_sort_expr_for_union(exprs: Vec) -> Result> { - let sort_exprs: Vec = exprs - .into_iter() - .map_until_stop_and_collect(|expr| { - expr.transform_up(|expr| { - if let Expr::Column(mut col) = expr { - col.relation = None; - Ok(Transformed::yes(Expr::Column(col))) - } else { - Ok(Transformed::no(expr)) - } - }) +fn rewrite_sort_expr_for_union(exprs: Vec) -> Result> { + // TODO FIXME + // this impl does't compile + /* + let sort_exprs = transform_sort_vec(exprs, |expr| { + expr.transform_up(|expr| { + if let Expr::Column(mut col) = expr { + col.relation = None; + Ok(Transformed::yes(Expr::Column(col))) + } else { + Ok(Transformed::no(expr)) + } }) - .data()?; + }) + .data()?; Ok(sort_exprs) + */ + + /* + original code: + + fn rewrite_sort_expr_for_union(exprs: Vec) -> Result> { + let sort_exprs: Vec = exprs + .into_iter() + .map_until_stop_and_collect(|expr| { + expr.transform_up(|expr| { + if let Expr::Column(mut col) = expr { + col.relation = None; + Ok(Transformed::yes(Expr::Column(col))) + } else { + Ok(Transformed::no(expr)) + } + }) + }) + .data()?; + + Ok(sort_exprs) + } + */ + + panic!("rewrite_sort_expr_for_union to be implemented"); } // Rewrite logic plan for query that order by columns are not in projections @@ -161,12 +187,8 @@ pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( .collect::>(); let mut collects = p.expr.clone(); - for expr in &sort.expr { - if let Expr::Sort(s) = expr { - collects.push(s.expr.as_ref().clone()); - } else { - panic!("sort expression must be of type Sort"); - } + for sort in &sort.expr { + collects.push(sort.expr.as_ref().clone()); } // Compare outer collects Expr::to_string with inner collected transformed values