Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support arbitrary expressions in LIMIT plan #13028

Merged
merged 10 commits into from
Oct 24, 2024
25 changes: 18 additions & 7 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,12 @@ use crate::error::{DataFusionError, Result};
use crate::execution::context::{ExecutionProps, SessionState};
use crate::logical_expr::utils::generate_sort_key;
use crate::logical_expr::{
Aggregate, EmptyRelation, Join, Projection, Sort, TableScan, Unnest, Window,
Aggregate, EmptyRelation, Join, Projection, Sort, TableScan, Unnest, Values, Window,
};
use crate::logical_expr::{
Expr, LogicalPlan, Partitioning as LogicalPartitioning, PlanType, Repartition,
UserDefinedLogicalNode,
};
use crate::logical_expr::{Limit, Values};
use crate::physical_expr::{create_physical_expr, create_physical_exprs};
use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use crate::physical_plan::analyze::AnalyzeExec;
Expand Down Expand Up @@ -78,8 +77,8 @@ use datafusion_expr::expr::{
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
use datafusion_expr::{
DescribeTable, DmlStatement, Extension, Filter, JoinType, RecursiveQuery, SortExpr,
StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp,
DescribeTable, DmlStatement, Extension, FetchType, Filter, JoinType, RecursiveQuery,
SkipType, SortExpr, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp,
};
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
use datafusion_physical_expr::expressions::Literal;
Expand Down Expand Up @@ -796,8 +795,20 @@ impl DefaultPhysicalPlanner {
}
LogicalPlan::Subquery(_) => todo!(),
LogicalPlan::SubqueryAlias(_) => children.one()?,
LogicalPlan::Limit(Limit { skip, fetch, .. }) => {
LogicalPlan::Limit(limit) => {
let input = children.one()?;
let SkipType::Literal(skip) = limit.get_skip_type()? else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

return not_impl_err!(
"Unsupported OFFSET expression: {:?}",
limit.skip
);
};
let FetchType::Literal(fetch) = limit.get_fetch_type()? else {
return not_impl_err!(
"Unsupported LIMIT expression: {:?}",
limit.fetch
);
};

// GlobalLimitExec requires a single partition for input
let input = if input.output_partitioning().partition_count() == 1 {
Expand All @@ -806,13 +817,13 @@ impl DefaultPhysicalPlanner {
// Apply a LocalLimitExec to each partition. The optimizer will also insert
// a CoalescePartitionsExec between the GlobalLimitExec and LocalLimitExec
if let Some(fetch) = fetch {
Arc::new(LocalLimitExec::new(input, *fetch + skip))
Arc::new(LocalLimitExec::new(input, fetch + skip))
} else {
input
}
};

Arc::new(GlobalLimitExec::new(input, *skip, *fetch))
Arc::new(GlobalLimitExec::new(input, skip, fetch))
}
LogicalPlan::Unnest(Unnest {
list_type_columns,
Expand Down
44 changes: 22 additions & 22 deletions datafusion/core/tests/user_defined/user_defined_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ use datafusion::{
runtime_env::RuntimeEnv,
},
logical_expr::{
Expr, Extension, Limit, LogicalPlan, Sort, UserDefinedLogicalNode,
Expr, Extension, LogicalPlan, Sort, UserDefinedLogicalNode,
UserDefinedLogicalNodeCore,
},
optimizer::{OptimizerConfig, OptimizerRule},
Expand All @@ -98,7 +98,7 @@ use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::ScalarValue;
use datafusion_expr::tree_node::replace_sort_expression;
use datafusion_expr::{Projection, SortExpr};
use datafusion_expr::{FetchType, Projection, SortExpr};
use datafusion_optimizer::optimizer::ApplyOrder;
use datafusion_optimizer::AnalyzerRule;

Expand Down Expand Up @@ -361,28 +361,28 @@ impl OptimizerRule for TopKOptimizerRule {
// Note: this code simply looks for the pattern of a Limit followed by a
// Sort and replaces it by a TopK node. It does not handle many
// edge cases (e.g multiple sort columns, sort ASC / DESC), etc.
if let LogicalPlan::Limit(Limit {
fetch: Some(fetch),
input,
let LogicalPlan::Limit(ref limit) = plan else {
return Ok(Transformed::no(plan));
};
let FetchType::Literal(Some(fetch)) = limit.get_fetch_type()? else {
return Ok(Transformed::no(plan));
};

if let LogicalPlan::Sort(Sort {
ref expr,
ref input,
..
}) = &plan
}) = limit.input.as_ref()
{
if let LogicalPlan::Sort(Sort {
ref expr,
ref input,
..
}) = **input
{
if expr.len() == 1 {
// we found a sort with a single sort expr, replace with a a TopK
return Ok(Transformed::yes(LogicalPlan::Extension(Extension {
node: Arc::new(TopKPlanNode {
k: *fetch,
input: input.as_ref().clone(),
expr: expr[0].clone(),
}),
})));
}
if expr.len() == 1 {
// we found a sort with a single sort expr, replace with a a TopK
return Ok(Transformed::yes(LogicalPlan::Extension(Extension {
node: Arc::new(TopKPlanNode {
k: fetch,
input: input.as_ref().clone(),
expr: expr[0].clone(),
}),
})));
}
}

Expand Down
9 changes: 6 additions & 3 deletions datafusion/expr/src/expr_rewriter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,11 +306,14 @@ impl NamePreserver {
/// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan
pub fn new(plan: &LogicalPlan) -> Self {
Self {
// The schema of Filter, Join and TableScan nodes comes from their inputs rather than
// their expressions, so there is no need to use aliases to preserve expression names.
// The expressions of these plans do not contribute to their output schema,
// so there is no need to preserve expression names to prevent a schema change.
use_alias: !matches!(
plan,
LogicalPlan::Filter(_) | LogicalPlan::Join(_) | LogicalPlan::TableScan(_)
LogicalPlan::Filter(_)
| LogicalPlan::Join(_)
| LogicalPlan::TableScan(_)
| LogicalPlan::Limit(_)
),
}
}
Expand Down
19 changes: 16 additions & 3 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use crate::utils::{
find_valid_equijoin_key_pair, group_window_expr_by_sort_keys,
};
use crate::{
and, binary_expr, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery,
and, binary_expr, lit, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery,
TableProviderFilterPushDown, TableSource, WriteOp,
};

Expand Down Expand Up @@ -512,9 +512,22 @@ impl LogicalPlanBuilder {
/// `fetch` - Maximum number of rows to fetch, after skipping `skip` rows,
/// if specified.
pub fn limit(self, skip: usize, fetch: Option<usize>) -> Result<Self> {
let skip_expr = if skip == 0 {
None
} else {
Some(lit(skip as i64))
};
let fetch_expr = fetch.map(|f| lit(f as i64));
self.limit_by_expr(skip_expr, fetch_expr)
}

/// Limit the number of rows returned
///
/// Similar to `limit` but uses expressions for `skip` and `fetch`
pub fn limit_by_expr(self, skip: Option<Expr>, fetch: Option<Expr>) -> Result<Self> {
Ok(Self::new(LogicalPlan::Limit(Limit {
skip,
fetch,
skip: skip.map(Box::new),
fetch: fetch.map(Box::new),
input: self.plan,
})))
}
Expand Down
6 changes: 4 additions & 2 deletions datafusion/expr/src/logical_plan/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -549,11 +549,13 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> {
let mut object = serde_json::json!(
{
"Node Type": "Limit",
"Skip": skip,
}
);
if let Some(s) = skip {
object["Skip"] = s.to_string().into()
};
if let Some(f) = fetch {
object["Fetch"] = serde_json::Value::Number((*f).into());
object["Fetch"] = f.to_string().into()
};
object
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ pub use ddl::{
pub use dml::{DmlStatement, WriteOp};
pub use plan::{
projection_schema, Aggregate, Analyze, ColumnUnnestList, CrossJoin, DescribeTable,
Distinct, DistinctOn, EmptyRelation, Explain, Extension, Filter, Join,
Distinct, DistinctOn, EmptyRelation, Explain, Extension, FetchType, Filter, Join,
JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare,
Projection, RecursiveQuery, Repartition, Sort, StringifiedPlan, Subquery,
Projection, RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery,
SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window,
};
pub use statement::{
Expand Down
105 changes: 90 additions & 15 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
use datafusion_common::{
aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints,
DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence,
FunctionalDependencies, ParamValues, Result, TableReference, UnnestOptions,
FunctionalDependencies, ParamValues, Result, ScalarValue, TableReference,
UnnestOptions,
};
use indexmap::IndexSet;

Expand Down Expand Up @@ -960,11 +961,21 @@ impl LogicalPlan {
.map(LogicalPlan::SubqueryAlias)
}
LogicalPlan::Limit(Limit { skip, fetch, .. }) => {
self.assert_no_expressions(expr)?;
let old_expr_len = skip.iter().chain(fetch.iter()).count();
if old_expr_len != expr.len() {
return internal_err!(
"Invalid number of new Limit expressions: expected {}, got {}",
old_expr_len,
expr.len()
);
}
// Pop order is same as the order returned by `LogicalPlan::expressions()`
let new_skip = skip.as_ref().and(expr.pop());
let new_fetch = fetch.as_ref().and(expr.pop());
let input = self.only_input(inputs)?;
Ok(LogicalPlan::Limit(Limit {
skip: *skip,
fetch: *fetch,
skip: new_skip.map(Box::new),
fetch: new_fetch.map(Box::new),
input: Arc::new(input),
}))
}
Expand Down Expand Up @@ -1339,7 +1350,10 @@ impl LogicalPlan {
LogicalPlan::RecursiveQuery(_) => None,
LogicalPlan::Subquery(_) => None,
LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => input.max_rows(),
LogicalPlan::Limit(Limit { fetch, .. }) => *fetch,
LogicalPlan::Limit(limit) => match limit.get_fetch_type() {
Ok(FetchType::Literal(s)) => s,
_ => None,
},
LogicalPlan::Distinct(
Distinct::All(input) | Distinct::On(DistinctOn { input, .. }),
) => input.max_rows(),
Expand Down Expand Up @@ -1909,16 +1923,20 @@ impl LogicalPlan {
)
}
},
LogicalPlan::Limit(Limit {
ref skip,
ref fetch,
..
}) => {
LogicalPlan::Limit(limit) => {
// Attempt to display `skip` and `fetch` as literals if possible, otherwise as expressions.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Display literals as before to avoid breaking too many tests. Maybe we could display them in the expr-style through a follow-up PR. For example, 1 -> Int64(1).

let skip_str = match limit.get_skip_type() {
Ok(SkipType::Literal(n)) => n.to_string(),
_ => limit.skip.as_ref().map_or_else(|| "None".to_string(), |x| x.to_string()),
};
let fetch_str = match limit.get_fetch_type() {
Ok(FetchType::Literal(Some(n))) => n.to_string(),
Ok(FetchType::Literal(None)) => "None".to_string(),
_ => limit.fetch.as_ref().map_or_else(|| "None".to_string(), |x| x.to_string())
};
write!(
f,
"Limit: skip={}, fetch={}",
skip,
fetch.map_or_else(|| "None".to_string(), |x| x.to_string())
"Limit: skip={}, fetch={}", skip_str,fetch_str,
)
}
LogicalPlan::Subquery(Subquery { .. }) => {
Expand Down Expand Up @@ -2778,14 +2796,71 @@ impl PartialOrd for Extension {
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct Limit {
/// Number of rows to skip before fetch
pub skip: usize,
pub skip: Option<Box<Expr>>,
Copy link
Member Author

@jonahgao jonahgao Oct 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use Box here to prevent increasing the size of LogicalPlan, that isstd::mem::size_of::<LogicalPlan>(); otherwise it will cause a stack overflow in one of the array_ndims test

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the constraints on the expression that can be used here?
For example, can it have any column references?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to do the constant folding when building Limit node, so that logical plan structure remains intact? See also #12723

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the constraints on the expression that can be used here? For example, can it have any column references?

I think it can be any integer expression, and can also contain column references. Both PostgreSQL and DuckDB support select 1 limit (select max(col0) from t).

v1.1.1-dev319 af39bd0dcf
D create table t as values(1);
D select 1 limit (select max(col0) from t);
┌───────┐
│   1   │
│ int32 │
├───────┤
│     1 │
└───────┘

/// Maximum number of rows to fetch,
/// None means fetching all rows
pub fetch: Option<usize>,
pub fetch: Option<Box<Expr>>,
/// The logical plan
pub input: Arc<LogicalPlan>,
}

/// Different types of skip expression in Limit plan.
pub enum SkipType {
/// The skip expression is a literal value.
Literal(usize),
/// Currently only supports expressions that can be folded into constants.
UnsupportedExpr,
Comment on lines +2811 to +2812
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Limit is a relational operator, so this will always need to be constant-foldable.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but some expressions like limit (select count(*) from t) can't be folded at the planning stage , so we need to keep them for conversion into physical expressions later.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jonahgao good point. I also thought about this, but ignored, assuming we don't plan to support this.

We could have Limit node as-is for now (and do trivial constant folding when building the plan), and introduce expressions in the plan when we add support for limit (<uncorrelated subquery>). WDYT?

Copy link
Member Author

@jonahgao jonahgao Oct 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Constant folding requires the TypeCoercion and SimplifyExpressions rule. We can't directly invoke them during building for now and need to defer them to optimizer.

Another reason the Limit node needs to contain expressions is to support Prepare statements. Issue #12294 requires the Limit node to use Expr::Placeholder.

}

/// Different types of fetch expression in Limit plan.
pub enum FetchType {
/// The fetch expression is a literal value.
/// `Literal(None)` means the fetch expression is not provided.
Literal(Option<usize>),
/// Currently only supports expressions that can be folded into constants.
UnsupportedExpr,
}

impl Limit {
/// Get the skip type from the limit plan.
pub fn get_skip_type(&self) -> Result<SkipType> {
match self.skip.as_deref() {
Some(expr) => match *expr {
Expr::Literal(ScalarValue::Int64(s)) => {
// `skip = NULL` is equivalent to `skip = 0`
let s = s.unwrap_or(0);
if s >= 0 {
Ok(SkipType::Literal(s as usize))
} else {
plan_err!("OFFSET must be >=0, '{}' was provided", s)
}
}
_ => Ok(SkipType::UnsupportedExpr),
},
// `skip = None` is equivalent to `skip = 0`
None => Ok(SkipType::Literal(0)),
}
}

/// Get the fetch type from the limit plan.
pub fn get_fetch_type(&self) -> Result<FetchType> {
match self.fetch.as_deref() {
Some(expr) => match *expr {
Expr::Literal(ScalarValue::Int64(Some(s))) => {
if s >= 0 {
Ok(FetchType::Literal(Some(s as usize)))
} else {
plan_err!("LIMIT must be >= 0, '{}' was provided", s)
}
}
Expr::Literal(ScalarValue::Int64(None)) => Ok(FetchType::Literal(None)),
_ => Ok(FetchType::UnsupportedExpr),
},
None => Ok(FetchType::Literal(None)),
}
}
}

/// Removes duplicate rows from the input
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum Distinct {
Expand Down
Loading