Skip to content

Commit

Permalink
feat: support arbitrary expressions in LIMIT plan (#13028)
Browse files Browse the repository at this point in the history
* feat: support arbitrary expressions in `LIMIT` clause

* restore test

* Fix doc

* Update datafusion/optimizer/src/eliminate_limit.rs

Co-authored-by: Jax Liu <[email protected]>

* Update datafusion/expr/src/expr_rewriter/mod.rs

Co-authored-by: Jax Liu <[email protected]>

* Fix clippy

* Disallow non-integer types

---------

Co-authored-by: Jax Liu <[email protected]>
  • Loading branch information
jonahgao and goldmedal authored Oct 24, 2024
1 parent 18b2aaa commit 3f3a0cf
Show file tree
Hide file tree
Showing 19 changed files with 376 additions and 208 deletions.
25 changes: 18 additions & 7 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,12 @@ use crate::error::{DataFusionError, Result};
use crate::execution::context::{ExecutionProps, SessionState};
use crate::logical_expr::utils::generate_sort_key;
use crate::logical_expr::{
Aggregate, EmptyRelation, Join, Projection, Sort, TableScan, Unnest, Window,
Aggregate, EmptyRelation, Join, Projection, Sort, TableScan, Unnest, Values, Window,
};
use crate::logical_expr::{
Expr, LogicalPlan, Partitioning as LogicalPartitioning, PlanType, Repartition,
UserDefinedLogicalNode,
};
use crate::logical_expr::{Limit, Values};
use crate::physical_expr::{create_physical_expr, create_physical_exprs};
use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use crate::physical_plan::analyze::AnalyzeExec;
Expand Down Expand Up @@ -78,8 +77,8 @@ use datafusion_expr::expr::{
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
use datafusion_expr::{
DescribeTable, DmlStatement, Extension, Filter, JoinType, RecursiveQuery, SortExpr,
StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp,
DescribeTable, DmlStatement, Extension, FetchType, Filter, JoinType, RecursiveQuery,
SkipType, SortExpr, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp,
};
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
use datafusion_physical_expr::expressions::Literal;
Expand Down Expand Up @@ -796,8 +795,20 @@ impl DefaultPhysicalPlanner {
}
LogicalPlan::Subquery(_) => todo!(),
LogicalPlan::SubqueryAlias(_) => children.one()?,
LogicalPlan::Limit(Limit { skip, fetch, .. }) => {
LogicalPlan::Limit(limit) => {
let input = children.one()?;
let SkipType::Literal(skip) = limit.get_skip_type()? else {
return not_impl_err!(
"Unsupported OFFSET expression: {:?}",
limit.skip
);
};
let FetchType::Literal(fetch) = limit.get_fetch_type()? else {
return not_impl_err!(
"Unsupported LIMIT expression: {:?}",
limit.fetch
);
};

// GlobalLimitExec requires a single partition for input
let input = if input.output_partitioning().partition_count() == 1 {
Expand All @@ -806,13 +817,13 @@ impl DefaultPhysicalPlanner {
// Apply a LocalLimitExec to each partition. The optimizer will also insert
// a CoalescePartitionsExec between the GlobalLimitExec and LocalLimitExec
if let Some(fetch) = fetch {
Arc::new(LocalLimitExec::new(input, *fetch + skip))
Arc::new(LocalLimitExec::new(input, fetch + skip))
} else {
input
}
};

Arc::new(GlobalLimitExec::new(input, *skip, *fetch))
Arc::new(GlobalLimitExec::new(input, skip, fetch))
}
LogicalPlan::Unnest(Unnest {
list_type_columns,
Expand Down
44 changes: 22 additions & 22 deletions datafusion/core/tests/user_defined/user_defined_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ use datafusion::{
runtime_env::RuntimeEnv,
},
logical_expr::{
Expr, Extension, Limit, LogicalPlan, Sort, UserDefinedLogicalNode,
Expr, Extension, LogicalPlan, Sort, UserDefinedLogicalNode,
UserDefinedLogicalNodeCore,
},
optimizer::{OptimizerConfig, OptimizerRule},
Expand All @@ -98,7 +98,7 @@ use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::ScalarValue;
use datafusion_expr::tree_node::replace_sort_expression;
use datafusion_expr::{Projection, SortExpr};
use datafusion_expr::{FetchType, Projection, SortExpr};
use datafusion_optimizer::optimizer::ApplyOrder;
use datafusion_optimizer::AnalyzerRule;

Expand Down Expand Up @@ -361,28 +361,28 @@ impl OptimizerRule for TopKOptimizerRule {
// Note: this code simply looks for the pattern of a Limit followed by a
// Sort and replaces it by a TopK node. It does not handle many
// edge cases (e.g multiple sort columns, sort ASC / DESC), etc.
if let LogicalPlan::Limit(Limit {
fetch: Some(fetch),
input,
let LogicalPlan::Limit(ref limit) = plan else {
return Ok(Transformed::no(plan));
};
let FetchType::Literal(Some(fetch)) = limit.get_fetch_type()? else {
return Ok(Transformed::no(plan));
};

if let LogicalPlan::Sort(Sort {
ref expr,
ref input,
..
}) = &plan
}) = limit.input.as_ref()
{
if let LogicalPlan::Sort(Sort {
ref expr,
ref input,
..
}) = **input
{
if expr.len() == 1 {
// we found a sort with a single sort expr, replace with a a TopK
return Ok(Transformed::yes(LogicalPlan::Extension(Extension {
node: Arc::new(TopKPlanNode {
k: *fetch,
input: input.as_ref().clone(),
expr: expr[0].clone(),
}),
})));
}
if expr.len() == 1 {
// we found a sort with a single sort expr, replace with a TopK
return Ok(Transformed::yes(LogicalPlan::Extension(Extension {
node: Arc::new(TopKPlanNode {
k: fetch,
input: input.as_ref().clone(),
expr: expr[0].clone(),
}),
})));
}
}

Expand Down
9 changes: 6 additions & 3 deletions datafusion/expr/src/expr_rewriter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,11 +306,14 @@ impl NamePreserver {
/// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan
pub fn new(plan: &LogicalPlan) -> Self {
Self {
// The schema of Filter, Join and TableScan nodes comes from their inputs rather than
// their expressions, so there is no need to use aliases to preserve expression names.
// The expressions of these plans do not contribute to their output schema,
// so there is no need to preserve expression names to prevent a schema change.
use_alias: !matches!(
plan,
LogicalPlan::Filter(_) | LogicalPlan::Join(_) | LogicalPlan::TableScan(_)
LogicalPlan::Filter(_)
| LogicalPlan::Join(_)
| LogicalPlan::TableScan(_)
| LogicalPlan::Limit(_)
),
}
}
Expand Down
19 changes: 16 additions & 3 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use crate::utils::{
find_valid_equijoin_key_pair, group_window_expr_by_sort_keys,
};
use crate::{
and, binary_expr, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery,
and, binary_expr, lit, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery,
TableProviderFilterPushDown, TableSource, WriteOp,
};

Expand Down Expand Up @@ -512,9 +512,22 @@ impl LogicalPlanBuilder {
/// `fetch` - Maximum number of rows to fetch, after skipping `skip` rows,
/// if specified.
pub fn limit(self, skip: usize, fetch: Option<usize>) -> Result<Self> {
    // A zero offset is modelled as "no skip expression" rather than a literal 0.
    let skip_expr = (skip != 0).then(|| lit(skip as i64));
    // The row cap, when given, is carried as an Int64 literal expression.
    let fetch_expr = fetch.map(|count| lit(count as i64));
    // Delegate the actual plan construction to the expression-based builder.
    self.limit_by_expr(skip_expr, fetch_expr)
}

/// Limit the number of rows returned
///
/// Similar to `limit` but uses expressions for `skip` and `fetch`
pub fn limit_by_expr(self, skip: Option<Expr>, fetch: Option<Expr>) -> Result<Self> {
Ok(Self::new(LogicalPlan::Limit(Limit {
skip,
fetch,
skip: skip.map(Box::new),
fetch: fetch.map(Box::new),
input: self.plan,
})))
}
Expand Down
6 changes: 4 additions & 2 deletions datafusion/expr/src/logical_plan/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -549,11 +549,13 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> {
let mut object = serde_json::json!(
{
"Node Type": "Limit",
"Skip": skip,
}
);
if let Some(s) = skip {
object["Skip"] = s.to_string().into()
};
if let Some(f) = fetch {
object["Fetch"] = serde_json::Value::Number((*f).into());
object["Fetch"] = f.to_string().into()
};
object
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ pub use ddl::{
pub use dml::{DmlStatement, WriteOp};
pub use plan::{
projection_schema, Aggregate, Analyze, ColumnUnnestList, CrossJoin, DescribeTable,
Distinct, DistinctOn, EmptyRelation, Explain, Extension, Filter, Join,
Distinct, DistinctOn, EmptyRelation, Explain, Extension, FetchType, Filter, Join,
JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare,
Projection, RecursiveQuery, Repartition, Sort, StringifiedPlan, Subquery,
Projection, RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery,
SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window,
};
pub use statement::{
Expand Down
105 changes: 90 additions & 15 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
use datafusion_common::{
aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints,
DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence,
FunctionalDependencies, ParamValues, Result, TableReference, UnnestOptions,
FunctionalDependencies, ParamValues, Result, ScalarValue, TableReference,
UnnestOptions,
};
use indexmap::IndexSet;

Expand Down Expand Up @@ -960,11 +961,21 @@ impl LogicalPlan {
.map(LogicalPlan::SubqueryAlias)
}
LogicalPlan::Limit(Limit { skip, fetch, .. }) => {
self.assert_no_expressions(expr)?;
let old_expr_len = skip.iter().chain(fetch.iter()).count();
if old_expr_len != expr.len() {
return internal_err!(
"Invalid number of new Limit expressions: expected {}, got {}",
old_expr_len,
expr.len()
);
}
// Pop order is same as the order returned by `LogicalPlan::expressions()`
let new_skip = skip.as_ref().and(expr.pop());
let new_fetch = fetch.as_ref().and(expr.pop());
let input = self.only_input(inputs)?;
Ok(LogicalPlan::Limit(Limit {
skip: *skip,
fetch: *fetch,
skip: new_skip.map(Box::new),
fetch: new_fetch.map(Box::new),
input: Arc::new(input),
}))
}
Expand Down Expand Up @@ -1339,7 +1350,10 @@ impl LogicalPlan {
LogicalPlan::RecursiveQuery(_) => None,
LogicalPlan::Subquery(_) => None,
LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => input.max_rows(),
LogicalPlan::Limit(Limit { fetch, .. }) => *fetch,
LogicalPlan::Limit(limit) => match limit.get_fetch_type() {
Ok(FetchType::Literal(s)) => s,
_ => None,
},
LogicalPlan::Distinct(
Distinct::All(input) | Distinct::On(DistinctOn { input, .. }),
) => input.max_rows(),
Expand Down Expand Up @@ -1909,16 +1923,20 @@ impl LogicalPlan {
)
}
},
LogicalPlan::Limit(Limit {
ref skip,
ref fetch,
..
}) => {
LogicalPlan::Limit(limit) => {
// Attempt to display `skip` and `fetch` as literals if possible, otherwise as expressions.
let skip_str = match limit.get_skip_type() {
Ok(SkipType::Literal(n)) => n.to_string(),
_ => limit.skip.as_ref().map_or_else(|| "None".to_string(), |x| x.to_string()),
};
let fetch_str = match limit.get_fetch_type() {
Ok(FetchType::Literal(Some(n))) => n.to_string(),
Ok(FetchType::Literal(None)) => "None".to_string(),
_ => limit.fetch.as_ref().map_or_else(|| "None".to_string(), |x| x.to_string())
};
write!(
f,
"Limit: skip={}, fetch={}",
skip,
fetch.map_or_else(|| "None".to_string(), |x| x.to_string())
"Limit: skip={}, fetch={}", skip_str,fetch_str,
)
}
LogicalPlan::Subquery(Subquery { .. }) => {
Expand Down Expand Up @@ -2778,14 +2796,71 @@ impl PartialOrd for Extension {
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct Limit {
/// Number of rows to skip before fetch
pub skip: usize,
pub skip: Option<Box<Expr>>,
/// Maximum number of rows to fetch,
/// None means fetching all rows
pub fetch: Option<usize>,
pub fetch: Option<Box<Expr>>,
/// The logical plan
pub input: Arc<LogicalPlan>,
}

/// Classification of the `skip` (OFFSET) expression of a `Limit` plan.
///
/// The common traits are derived so that values can be printed with `{:?}`,
/// compared in assertions, and freely copied (the payload is a plain `usize`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SkipType {
    /// The skip expression is a literal value: the number of rows to skip.
    Literal(usize),
    /// The skip expression is not a literal.
    ///
    /// Currently only expressions that can be folded into constants are
    /// supported, so this variant marks an expression that cannot be used yet.
    UnsupportedExpr,
}

/// Classification of the `fetch` (LIMIT) expression of a `Limit` plan.
///
/// The common traits are derived so that values can be printed with `{:?}`,
/// compared in assertions, and freely copied (the payload is an `Option<usize>`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FetchType {
    /// The fetch expression is a literal value: the maximum number of rows.
    /// `Literal(None)` means the fetch expression is not provided.
    Literal(Option<usize>),
    /// The fetch expression is not a literal.
    ///
    /// Currently only expressions that can be folded into constants are
    /// supported, so this variant marks an expression that cannot be used yet.
    UnsupportedExpr,
}

impl Limit {
    /// Classifies the `skip` expression of this limit plan.
    ///
    /// Returns `SkipType::Literal(n)` when the expression is a non-negative
    /// `Int64` literal. An absent expression or a NULL `Int64` literal is
    /// treated as `0`. Any other expression yields `SkipType::UnsupportedExpr`.
    ///
    /// # Errors
    ///
    /// Returns a planning error when the literal is negative.
    pub fn get_skip_type(&self) -> Result<SkipType> {
        let Some(expr) = self.skip.as_deref() else {
            // `skip = None` is equivalent to `skip = 0`
            return Ok(SkipType::Literal(0));
        };
        let Expr::Literal(ScalarValue::Int64(value)) = expr else {
            return Ok(SkipType::UnsupportedExpr);
        };

        // `skip = NULL` is equivalent to `skip = 0`
        let offset = value.unwrap_or(0);
        if offset < 0 {
            return plan_err!("OFFSET must be >=0, '{}' was provided", offset);
        }
        Ok(SkipType::Literal(offset as usize))
    }

    /// Classifies the `fetch` expression of this limit plan.
    ///
    /// Returns `FetchType::Literal(Some(n))` when the expression is a
    /// non-negative `Int64` literal, and `FetchType::Literal(None)` when the
    /// expression is absent or a NULL `Int64` literal. Any other expression
    /// yields `FetchType::UnsupportedExpr`.
    ///
    /// # Errors
    ///
    /// Returns a planning error when the literal is negative.
    pub fn get_fetch_type(&self) -> Result<FetchType> {
        let Some(expr) = self.fetch.as_deref() else {
            return Ok(FetchType::Literal(None));
        };

        match expr {
            Expr::Literal(ScalarValue::Int64(None)) => Ok(FetchType::Literal(None)),
            Expr::Literal(ScalarValue::Int64(Some(count))) if *count >= 0 => {
                Ok(FetchType::Literal(Some(*count as usize)))
            }
            Expr::Literal(ScalarValue::Int64(Some(count))) => {
                plan_err!("LIMIT must be >= 0, '{}' was provided", count)
            }
            _ => Ok(FetchType::UnsupportedExpr),
        }
    }
}

/// Removes duplicate rows from the input
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum Distinct {
Expand Down
Loading

0 comments on commit 3f3a0cf

Please sign in to comment.