Skip to content

Commit

Permalink
feat: RewriteCycle API for short-circuiting optimizer loops
Browse files Browse the repository at this point in the history
  • Loading branch information
erratic-pattern committed May 15, 2024
1 parent e859426 commit a0b8397
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 55 deletions.
21 changes: 14 additions & 7 deletions datafusion/core/tests/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,27 +508,34 @@ fn test_simplify(input_expr: Expr, expected_expr: Expr) {
"Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}"
);
}
fn test_simplify_with_cycle_count(
fn test_simplify_with_cycle_info(
input_expr: Expr,
expected_expr: Expr,
expected_count: u32,
expected_cycle_count: usize,
expected_iteration_count: usize,
) {
let info: MyInfo = MyInfo {
schema: expr_test_schema(),
execution_props: ExecutionProps::new(),
};
let simplifier = ExprSimplifier::new(info);
let (simplified_expr, count) = simplifier
.simplify_with_cycle_count(input_expr.clone())
let (simplified_expr, info) = simplifier
.simplify_with_cycle_info(input_expr.clone())
.expect("successfully evaluated");

let total_iterations = info.total_iterations();
let completed_cycles = info.completed_cycles();
assert_eq!(
simplified_expr, expected_expr,
"Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}"
);
assert_eq!(
count, expected_count,
"Mismatch simplifier cycle count\n Expected: {expected_count}\n Got:{count}"
completed_cycles, expected_cycle_count,
"Mismatch simplifier cycle count\n Expected: {expected_cycle_count}\n Got:{completed_cycles}"
);
assert_eq!(
total_iterations, expected_iteration_count,
"Mismatch simplifier cycle count\n Expected: {expected_iteration_count}\n Got:{total_iterations}"
);
}

Expand Down Expand Up @@ -687,5 +694,5 @@ fn test_simplify_cycles() {
let expr = cast(now(), DataType::Int64)
.lt(cast(to_timestamp(vec![lit(0)]), DataType::Int64) + lit(i64::MAX));
let expected = lit(true);
test_simplify_with_cycle_count(expr, expected, 3);
test_simplify_with_cycle_info(expr, expected, 2, 7);
}
176 changes: 128 additions & 48 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
use std::borrow::Cow;
use std::collections::HashSet;
use std::ops::Not;
use std::ops::{ControlFlow, Not};

use arrow::{
array::{new_null_array, AsArray},
Expand Down Expand Up @@ -92,11 +92,11 @@ pub struct ExprSimplifier<S> {
/// true
canonicalize: bool,
/// Maximum number of simplifier cycles
max_simplifier_cycles: u32,
max_simplifier_cycles: usize,
}

pub const THRESHOLD_INLINE_INLIST: usize = 3;
pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: u32 = 3;
pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: usize = 3;

impl<S: SimplifyInfo> ExprSimplifier<S> {
/// Create a new `ExprSimplifier` with the given `info` such as an
Expand Down Expand Up @@ -175,7 +175,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
/// assert_eq!(expr, b_lt_2);
/// ```
pub fn simplify(&self, expr: Expr) -> Result<Expr> {
Ok(self.simplify_with_cycle_count(expr)?.0)
Ok(self.simplify_with_cycle_info(expr)?.0)
}

/// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating
Expand All @@ -185,36 +185,27 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
///
/// See [Self::simplify] for details and usage examples.
///
pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> {
pub fn simplify_with_cycle_info(
&self,
mut expr: Expr,
) -> Result<(Expr, RewriteCycle)> {
let mut simplifier = Simplifier::new(&self.info);
let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;
let mut shorten_in_list_simplifier = ShortenInListSimplifier::new();
// let mut shorten_in_list_simplifier = ShortenInListSimplifier::new();
let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees);

if self.canonicalize {
expr = expr.rewrite(&mut Canonicalizer::new()).data()?
}

// Evaluating constants can enable new simplifications and
// simplifications can enable new constant evaluation
// see `Self::with_max_cycles`
let mut num_cycles = 0;
loop {
let Transformed {
data, transformed, ..
} = expr
.rewrite(&mut const_evaluator)?
.transform_data(|expr| expr.rewrite(&mut simplifier))?
.transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?;
expr = data;
num_cycles += 1;
if !transformed || num_cycles >= self.max_simplifier_cycles {
break;
}
}
// shorten inlist should be started after other inlist rules are applied
expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?;
Ok((expr, num_cycles))
let (mut expr, info) =
rewrite_cycle(expr, self.max_simplifier_cycles, |cycle, mut expr| {
expr = cycle.rewrite(expr, &mut const_evaluator)?;
expr = cycle.rewrite(expr, &mut simplifier)?;
expr = cycle.rewrite(expr, &mut guarantee_rewriter)?;
ControlFlow::Continue(expr)
})?;
expr = expr.rewrite(&mut ShortenInListSimplifier::new()).data()?;
Ok((expr, info))
}

/// Apply type coercion to an [`Expr`] so that it can be
Expand Down Expand Up @@ -378,21 +369,15 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
/// // Expression: a IS NOT NULL
/// let expr = col("a").is_not_null();
///
/// // When using default maximum cycles, 2 cycles will be performed.
/// let (simplified_expr, count) = simplifier.simplify_with_cycle_count(expr.clone()).unwrap();
/// assert_eq!(simplified_expr, lit(true));
/// // 2 cycles were executed, but only 1 was needed
/// assert_eq!(count, 2);
///
/// // Only 1 simplification pass is necessary here, so we can set the maximum cycles to 1.
/// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count(expr.clone()).unwrap();
/// let (simplified_expr, info) = simplifier.with_max_cycles(1).simplify_with_cycle_info(expr.clone()).unwrap();
/// // Expression has been rewritten to: (c = a AND b = 1)
/// assert_eq!(simplified_expr, lit(true));
/// // Only 1 cycle was executed
/// assert_eq!(count, 1);
/// assert_eq!(info.completed_cycles(), 1);
///
/// ```
pub fn with_max_cycles(mut self, max_simplifier_cycles: u32) -> Self {
pub fn with_max_cycles(mut self, max_simplifier_cycles: usize) -> Self {
self.max_simplifier_cycles = max_simplifier_cycles;
self
}
Expand Down Expand Up @@ -1755,6 +1740,96 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
Ok(Expr::InList(l1))
}

pub struct RewriteCycle {
consecutive_unchanged_count: usize,
total_iterations: usize,
num_rewriters: usize,
}
pub type RewriteCycleResult = ControlFlow<Result<Expr>, Expr>;

impl RewriteCycle {
fn new() -> Self {
RewriteCycle {
// use usize::MAX as default to avoid checking null in is_done() comparison
// real value is set later by record_num_rewriters
num_rewriters: usize::MAX,
consecutive_unchanged_count: 0,
total_iterations: 0,
}
}

pub fn completed_cycles(&self) -> usize {
// default value indicates we have not completed a cycle
if self.num_rewriters == usize::MAX {
0
} else {
self.total_iterations / self.num_rewriters
}
}

pub fn total_iterations(&self) -> usize {
self.total_iterations
}

fn record_num_rewriters(&mut self) {
self.num_rewriters = self.total_iterations;
}

fn is_done(&self) -> bool {
self.consecutive_unchanged_count >= self.num_rewriters
}

pub fn rewrite<R: TreeNodeRewriter<Node = Expr>>(
&mut self,
node: Expr,
rewriter: &mut R,
) -> RewriteCycleResult {
match node.rewrite(rewriter) {
Err(e) => ControlFlow::Break(Err(e)),
Ok(Transformed {
data: node,
transformed,
..
}) => {
self.total_iterations += 1;
if transformed {
self.consecutive_unchanged_count = 0;
} else {
self.consecutive_unchanged_count += 1;
}
if self.is_done() {
ControlFlow::Break(Ok(node))
} else {
ControlFlow::Continue(node)
}
}
}
}
}

pub fn rewrite_cycle<F: FnMut(&mut RewriteCycle, Expr) -> RewriteCycleResult>(
node: Expr,
max_cycles: usize,
mut f: F,
) -> Result<(Expr, RewriteCycle)> {
let mut cycle = RewriteCycle::new();
// run first cycle then record number of rewriters
let node = match f(&mut cycle, node) {
ControlFlow::Break(result) => return result.map(|n| (n, cycle)),
ControlFlow::Continue(node) => node,
};
cycle.record_num_rewriters();
if cycle.is_done() {
return Ok((node, cycle));
}
// run remaining cycles
let node = match (1..max_cycles).try_fold(node, |node, _| f(&mut cycle, node)) {
ControlFlow::Break(result) => result?,
ControlFlow::Continue(node) => node,
};
Ok((node, cycle))
}

#[cfg(test)]
mod tests {
use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
Expand Down Expand Up @@ -2954,17 +3029,17 @@ mod tests {
try_simplify(expr).unwrap()
}

fn try_simplify_with_cycle_count(expr: Expr) -> Result<(Expr, u32)> {
fn try_simplify_with_cycle_info(expr: Expr) -> Result<(Expr, RewriteCycle)> {
let schema = expr_test_schema();
let execution_props = ExecutionProps::new();
let simplifier = ExprSimplifier::new(
SimplifyContext::new(&execution_props).with_schema(schema),
);
simplifier.simplify_with_cycle_count(expr)
simplifier.simplify_with_cycle_info(expr)
}

fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) {
try_simplify_with_cycle_count(expr).unwrap()
fn simplify_with_cycle_info(expr: Expr) -> (Expr, RewriteCycle) {
try_simplify_with_cycle_info(expr).unwrap()
}

fn simplify_with_guarantee(
Expand Down Expand Up @@ -3680,24 +3755,27 @@ mod tests {
// TRUE
let expr = lit(true);
let expected = lit(true);
let (expr, num_iter) = simplify_with_cycle_count(expr);
let (expr, info) = simplify_with_cycle_info(expr);
assert_eq!(expr, expected);
assert_eq!(num_iter, 1);
assert_eq!(info.completed_cycles(), 1);
assert_eq!(info.total_iterations(), 3);

// (true != NULL) OR (5 > 10)
let expr = lit(true).not_eq(lit_bool_null()).or(lit(5).gt(lit(10)));
let expected = lit_bool_null();
let (expr, num_iter) = simplify_with_cycle_count(expr);
let (expr, info) = simplify_with_cycle_info(expr);
assert_eq!(expr, expected);
assert_eq!(num_iter, 2);
assert_eq!(info.completed_cycles(), 1);
assert_eq!(info.total_iterations(), 4);

// NOTE: this currently does not simplify
// (((c4 - 10) + 10) *100) / 100
let expr = (((col("c4") - lit(10)) + lit(10)) * lit(100)) / lit(100);
let expected = expr.clone();
let (expr, num_iter) = simplify_with_cycle_count(expr);
let (expr, info) = simplify_with_cycle_info(expr);
assert_eq!(expr, expected);
assert_eq!(num_iter, 1);
assert_eq!(info.completed_cycles(), 1);
assert_eq!(info.total_iterations(), 3);

// ((c4<1 or c3<2) and c3_non_null<3) and false
let expr = col("c4")
Expand All @@ -3706,10 +3784,12 @@ mod tests {
.and(col("c3_non_null").lt(lit(3)))
.and(lit(false));
let expected = lit(false);
let (expr, num_iter) = simplify_with_cycle_count(expr);
let (expr, info) = simplify_with_cycle_info(expr);
assert_eq!(expr, expected);
assert_eq!(num_iter, 2);
assert_eq!(info.completed_cycles(), 1);
assert_eq!(info.total_iterations(), 5);
}

#[test]
fn test_simplify_udaf() {
let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify());
Expand Down

0 comments on commit a0b8397

Please sign in to comment.