diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs index 7127a8fa94d6..d4774ebd766e 100644 --- a/datafusion/src/optimizer/simplify_expressions.rs +++ b/datafusion/src/optimizer/simplify_expressions.rs @@ -662,6 +662,54 @@ impl<'a> ExprRewriter for Simplifier<'a> { _ => unreachable!(), }, + // + // Rules for Case + // + + // CASE + // WHEN X THEN A + // WHEN Y THEN B + // ... + // ELSE Q + // END + // + // ---> (X AND A) OR (Y AND B AND NOT X) OR ... (NOT (X OR Y) AND Q) + // + // Note: the rationale for this rewrite is that the expr can then be further + // simplified using the existing rules for AND/OR + Case { + expr: None, + when_then_expr, + else_expr, + } if !when_then_expr.is_empty() + && when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number + && self.is_boolean_type(&when_then_expr[0].1) => + { + // The disjunction of all the when predicates encountered so far + let mut filter_expr = lit(false); + // The disjunction of all the cases + let mut out_expr = lit(false); + + for (when, then) in when_then_expr { + let case_expr = when + .as_ref() + .clone() + .and(filter_expr.clone().not()) + .and(*then); + + out_expr = out_expr.or(case_expr); + filter_expr = filter_expr.or(*when); + } + + if let Some(else_expr) = else_expr { + let case_expr = filter_expr.not().and(*else_expr); + out_expr = out_expr.or(case_expr); + } + + // Do a first pass at simplification + out_expr.rewrite(self)? + } + expr => { // no additional rewrites possible expr @@ -1175,6 +1223,8 @@ mod tests { .expect("expected to simplify") .rewrite(&mut const_evaluator) .expect("expected to const evaluate") + .rewrite(&mut rewriter) + .expect("expected to simplify") } fn expr_test_schema() -> DFSchemaRef { @@ -1291,6 +1341,11 @@ mod tests { #[test] fn simplify_expr_case_when_then_else() { + // CASE WHERE c2 != false THEN "ok" == "not_ok" ELSE c2 == true + // --> + // CASE WHERE c2 THEN false ELSE c2 + // --> + // false assert_eq!( simplify(Expr::Case { expr: None, @@ -1300,11 +1355,73 @@ mod tests { )], else_expr: Some(Box::new(col("c2").eq(lit(true)))), }), - Expr::Case { + col("c2").not().and(col("c2")) // #1716 + ); + + // CASE WHERE c2 != false THEN "ok" == "ok" ELSE c2 + // --> + // CASE WHERE c2 THEN true ELSE c2 + // --> + // c2 + assert_eq!( + simplify(Expr::Case { expr: None, - when_then_expr: vec![(Box::new(col("c2")), Box::new(lit(false)))], + when_then_expr: vec![( + Box::new(col("c2").not_eq(lit(false))), + Box::new(lit("ok").eq(lit("ok"))), + )], + else_expr: Some(Box::new(col("c2").eq(lit(true)))), + }), + col("c2").or(col("c2").not().and(col("c2"))) // #1716 + ); + + // CASE WHERE ISNULL(c2) THEN true ELSE c2 + // --> + // ISNULL(c2) OR c2 + assert_eq!( + simplify(Expr::Case { + expr: None, + when_then_expr: vec![( + Box::new(col("c2").is_null()), + Box::new(lit(true)), + )], else_expr: Some(Box::new(col("c2"))), - } + }), + col("c2") + .is_null() + .or(col("c2").is_null().not().and(col("c2"))) + ); + + // CASE WHERE c1 then true WHERE c2 then false ELSE true + // --> c1 OR (NOT(c1) AND c2 AND FALSE) OR (NOT(c1 OR c2) AND TRUE) + // --> c1 OR (NOT(c1 OR c2)) + // --> NOT(c1) AND c2 + assert_eq!( + simplify(Expr::Case { + expr: None, + when_then_expr: vec![ + (Box::new(col("c1")), Box::new(lit(true)),), + (Box::new(col("c2")), Box::new(lit(false)),) + ], + else_expr: Some(Box::new(lit(true))), + }), + col("c1").or(col("c1").or(col("c2")).not()) + ); + + // CASE WHERE c1 then true WHERE c2 then true ELSE false + // --> c1 OR (NOT(c1) AND c2 AND TRUE) OR (NOT(c1 OR c2) AND FALSE) + // --> c1 OR (NOT(c1) AND c2) + // --> c1 OR c2 + assert_eq!( + simplify(Expr::Case { + expr: None, + when_then_expr: vec![ + (Box::new(col("c1")), Box::new(lit(true)),), + (Box::new(col("c2")), Box::new(lit(false)),) + ], + else_expr: Some(Box::new(lit(true))), + }), + col("c1").or(col("c1").or(col("c2")).not()) ); }