Skip to content

Commit

Permalink
chore: clippy and update pruning predicates in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
appletreeisyellow committed Feb 15, 2024
1 parent e9a7592 commit 23b03fa
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 35 deletions.
5 changes: 5 additions & 0 deletions datafusion-examples/examples/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,11 @@ impl PruningStatistics for MyCatalog {
None
}

fn row_counts(&self, _column: &Column) -> Option<ArrayRef> {
// In this example, we know nothing about the number of rows in each file
None
}

fn contained(
&self,
_column: &Column,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1862,7 +1862,7 @@ mod tests {

assert_contains!(
&display,
"pruning_predicate=c1_min@0 != bar OR bar != c1_max@1"
"pruning_predicate=CASE WHEN c1_null_count@2 = c1_row_count@3 THEN false ELSE c1_min@0 != bar OR bar != c1_max@1 END"
);

assert_contains!(&display, r#"predicate=c1@0 != bar"#);
Expand Down
191 changes: 157 additions & 34 deletions datafusion/core/src/physical_optimizer/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1386,31 +1386,31 @@ fn build_statistics_expr(
///
/// For example:
///
/// `x_min <= 10 AND x_max >= 10`
/// `x_min <= 10 AND 10 <= x_max`
///
/// will become
///
/// ```
/// ```sql
/// CASE
/// WHEN y_null_count = y_row_count THEN false
/// ELSE x_min <= 10 AND x_max >= 10
/// WHEN x_null_count = x_row_count THEN false
/// ELSE x_min <= 10 AND 10 <= x_max
/// END
/// ```
/// ````
fn wrap_case_expr(
statistics_expr: Arc<dyn PhysicalExpr>,
expr_builder: &mut PruningExpressionBuilder,
) -> Result<Arc<dyn PhysicalExpr>> {
let nul_count_row_count_comparesion = Arc::new(phys_expr::BinaryExpr::new(
let when_null_count_eq_row_count = Arc::new(phys_expr::BinaryExpr::new(
expr_builder.null_count_column_expr()?,
Operator::Eq,
expr_builder.row_count_column_expr()?,
));
let then = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(false))));

// CASE WHEN x_null_count = x_row_count THEN false ELSE <statistics_expr> END
Ok(Arc::new(phys_expr::CaseExpr::try_new(
Some(nul_count_row_count_comparesion),
vec![(
Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))),
Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(false)))),
)],
None,
vec![(when_null_count_eq_row_count, then)],
Some(statistics_expr),
)?))
}
Expand Down Expand Up @@ -1999,7 +1999,7 @@ mod tests {
#[test]
fn row_group_predicate_eq() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
let expected_expr = "c1_min@0 <= 1 AND 1 <= c1_max@1";
let expected_expr = "CASE WHEN c1_null_count@2 = c1_row_count@3 THEN false ELSE c1_min@0 <= 1 AND 1 <= c1_max@1 END";

// test column on the left
let expr = col("c1").eq(lit(1));
Expand All @@ -2019,7 +2019,7 @@ mod tests {
#[test]
fn row_group_predicate_not_eq() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
let expected_expr = "c1_min@0 != 1 OR 1 != c1_max@1";
let expected_expr = "CASE WHEN c1_null_count@2 = c1_row_count@3 THEN false ELSE c1_min@0 != 1 OR 1 != c1_max@1 END";

// test column on the left
let expr = col("c1").not_eq(lit(1));
Expand All @@ -2039,7 +2039,8 @@ mod tests {
#[test]
fn row_group_predicate_gt() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
let expected_expr = "c1_max@0 > 1";
let expected_expr =
"CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_max@0 > 1 END";

// test column on the left
let expr = col("c1").gt(lit(1));
Expand All @@ -2059,7 +2060,7 @@ mod tests {
#[test]
fn row_group_predicate_gt_eq() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
let expected_expr = "c1_max@0 >= 1";
let expected_expr = "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_max@0 >= 1 END";

// test column on the left
let expr = col("c1").gt_eq(lit(1));
Expand All @@ -2078,7 +2079,8 @@ mod tests {
#[test]
fn row_group_predicate_lt() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
let expected_expr = "c1_min@0 < 1";
let expected_expr =
"CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 < 1 END";

// test column on the left
let expr = col("c1").lt(lit(1));
Expand All @@ -2098,7 +2100,7 @@ mod tests {
#[test]
fn row_group_predicate_lt_eq() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
let expected_expr = "c1_min@0 <= 1";
let expected_expr = "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 <= 1 END";

// test column on the left
let expr = col("c1").lt_eq(lit(1));
Expand All @@ -2123,7 +2125,8 @@ mod tests {
]);
// test AND operator joining supported c1 < 1 expression and unsupported c2 > c3 expression
let expr = col("c1").lt(lit(1)).and(col("c2").lt(col("c3")));
let expected_expr = "c1_min@0 < 1";
let expected_expr =
"CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 < 1 END";
let predicate_expr =
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
assert_eq!(predicate_expr.to_string(), expected_expr);
Expand Down Expand Up @@ -2189,7 +2192,7 @@ mod tests {
#[test]
fn row_group_predicate_lt_bool() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Boolean, false)]);
let expected_expr = "c1_min@0 < true";
let expected_expr = "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 < true END";

// DF doesn't support arithmetic on boolean columns so
// this predicate will error when evaluated
Expand All @@ -2212,7 +2215,21 @@ mod tests {
let expr = col("c1")
.lt(lit(1))
.and(col("c2").eq(lit(2)).or(col("c2").eq(lit(3))));
let expected_expr = "c1_min@0 < 1 AND (c2_min@1 <= 2 AND 2 <= c2_max@2 OR c2_min@1 <= 3 AND 3 <= c2_max@2)";
let expected_expr = "\
CASE \
WHEN c1_null_count@1 = c1_row_count@2 THEN false \
ELSE c1_min@0 < 1 \
END \
AND (\
CASE \
WHEN c2_null_count@5 = c2_row_count@6 THEN false \
ELSE c2_min@3 <= 2 AND 2 <= c2_max@4 \
END \
OR CASE \
WHEN c2_null_count@5 = c2_row_count@6 THEN false \
ELSE c2_min@3 <= 3 AND 3 <= c2_max@4 \
END\
)";
let predicate_expr =
test_build_predicate_expression(&expr, &schema, &mut required_columns);
assert_eq!(predicate_expr.to_string(), expected_expr);
Expand All @@ -2226,10 +2243,30 @@ mod tests {
c1_min_field.with_nullable(true) // could be nullable if stats are not present
)
);
// c1 < 1 should add c1_null_count
let c1_null_count_field = Field::new("c1_null_count", DataType::Int32, false);
assert_eq!(
required_columns.columns[1],
(
phys_expr::Column::new("c1", 0),
StatisticsType::NullCount,
c1_null_count_field.with_nullable(true) // could be nullable if stats are not present
)
);
// c1 < 1 should add c1_row_count
let c1_row_count_field = Field::new("c1_row_count", DataType::Int32, false);
assert_eq!(
required_columns.columns[2],
(
phys_expr::Column::new("c1", 0),
StatisticsType::RowCount,
c1_row_count_field.with_nullable(true) // could be nullable if stats are not present
)
);
// c2 = 2 should add c2_min and c2_max
let c2_min_field = Field::new("c2_min", DataType::Int32, false);
assert_eq!(
required_columns.columns[1],
required_columns.columns[3],
(
phys_expr::Column::new("c2", 1),
StatisticsType::Min,
Expand All @@ -2238,15 +2275,35 @@ mod tests {
);
let c2_max_field = Field::new("c2_max", DataType::Int32, false);
assert_eq!(
required_columns.columns[2],
required_columns.columns[4],
(
phys_expr::Column::new("c2", 1),
StatisticsType::Max,
c2_max_field.with_nullable(true) // could be nullable if stats are not present
)
);
// c2 = 2 should add c2_null_count
let c2_null_count_field = Field::new("c2_null_count", DataType::Int32, false);
assert_eq!(
required_columns.columns[5],
(
phys_expr::Column::new("c2", 1),
StatisticsType::NullCount,
c2_null_count_field.with_nullable(true) // could be nullable if stats are not present
)
);
// c2 = 2 should add c2_row_count
let c2_row_count_field = Field::new("c2_row_count", DataType::Int32, false);
assert_eq!(
required_columns.columns[6],
(
phys_expr::Column::new("c2", 1),
StatisticsType::RowCount,
c2_row_count_field.with_nullable(true) // could be nullable if stats are not present
)
);
// c2 = 3 shouldn't add any new statistics fields
assert_eq!(required_columns.columns.len(), 3);
assert_eq!(required_columns.columns.len(), 7);

Ok(())
}
Expand All @@ -2263,7 +2320,18 @@ mod tests {
vec![lit(1), lit(2), lit(3)],
false,
));
let expected_expr = "c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_min@0 <= 2 AND 2 <= c1_max@1 OR c1_min@0 <= 3 AND 3 <= c1_max@1";
let expected_expr = "CASE \
WHEN c1_null_count@2 = c1_row_count@3 THEN false \
ELSE c1_min@0 <= 1 AND 1 <= c1_max@1 \
END \
OR CASE \
WHEN c1_null_count@2 = c1_row_count@3 THEN false \
ELSE c1_min@0 <= 2 AND 2 <= c1_max@1 \
END \
OR CASE \
WHEN c1_null_count@2 = c1_row_count@3 THEN false \
ELSE c1_min@0 <= 3 AND 3 <= c1_max@1 \
END";
let predicate_expr =
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
assert_eq!(predicate_expr.to_string(), expected_expr);
Expand Down Expand Up @@ -2299,9 +2367,19 @@ mod tests {
vec![lit(1), lit(2), lit(3)],
true,
));
let expected_expr = "(c1_min@0 != 1 OR 1 != c1_max@1) \
AND (c1_min@0 != 2 OR 2 != c1_max@1) \
AND (c1_min@0 != 3 OR 3 != c1_max@1)";
let expected_expr = "\
CASE \
WHEN c1_null_count@2 = c1_row_count@3 THEN false \
ELSE c1_min@0 != 1 OR 1 != c1_max@1 \
END \
AND CASE \
WHEN c1_null_count@2 = c1_row_count@3 THEN false \
ELSE c1_min@0 != 2 OR 2 != c1_max@1 \
END \
AND CASE \
WHEN c1_null_count@2 = c1_row_count@3 THEN false \
ELSE c1_min@0 != 3 OR 3 != c1_max@1 \
END";
let predicate_expr =
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
assert_eq!(predicate_expr.to_string(), expected_expr);
Expand Down Expand Up @@ -2347,7 +2425,24 @@ mod tests {
// test c1 in(1, 2) and c2 BETWEEN 4 AND 5
let expr3 = expr1.and(expr2);

let expected_expr = "(c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_min@0 <= 2 AND 2 <= c1_max@1) AND c2_max@2 >= 4 AND c2_min@3 <= 5";
let expected_expr = "\
(\
CASE \
WHEN c1_null_count@2 = c1_row_count@3 THEN false \
ELSE c1_min@0 <= 1 AND 1 <= c1_max@1 \
END \
OR CASE \
WHEN c1_null_count@2 = c1_row_count@3 THEN false \
ELSE c1_min@0 <= 2 AND 2 <= c1_max@1 \
END\
) AND CASE \
WHEN c2_null_count@5 = c2_row_count@6 THEN false \
ELSE c2_max@4 >= 4 \
END \
AND CASE \
WHEN c2_null_count@5 = c2_row_count@6 THEN false \
ELSE c2_min@7 <= 5 \
END";
let predicate_expr =
test_build_predicate_expression(&expr3, &schema, &mut RequiredColumns::new());
assert_eq!(predicate_expr.to_string(), expected_expr);
Expand Down Expand Up @@ -2375,7 +2470,10 @@ mod tests {
fn row_group_predicate_cast() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
let expected_expr =
"CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64)";
"CASE \
WHEN CAST(c1_null_count@2 AS Int64) = CAST(c1_row_count@3 AS Int64) THEN false \
ELSE CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) \
END";

// test column on the left
let expr = cast(col("c1"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(1))));
Expand All @@ -2389,7 +2487,11 @@ mod tests {
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
assert_eq!(predicate_expr.to_string(), expected_expr);

let expected_expr = "TRY_CAST(c1_max@0 AS Int64) > 1";
let expected_expr =
"CASE \
WHEN TRY_CAST(c1_null_count@1 AS Int64) = TRY_CAST(c1_row_count@2 AS Int64) THEN false \
ELSE TRY_CAST(c1_max@0 AS Int64) > 1 \
END";

// test column on the left
let expr =
Expand Down Expand Up @@ -2421,7 +2523,19 @@ mod tests {
],
false,
));
let expected_expr = "CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) OR CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) OR CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64)";
let expected_expr =
"CASE \
WHEN CAST(c1_null_count@2 AS Int64) = CAST(c1_row_count@3 AS Int64) THEN false \
ELSE CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) \
END \
OR CASE \
WHEN CAST(c1_null_count@2 AS Int64) = CAST(c1_row_count@3 AS Int64) THEN false \
ELSE CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) \
END \
OR CASE \
WHEN CAST(c1_null_count@2 AS Int64) = CAST(c1_row_count@3 AS Int64) THEN false \
ELSE CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64) \
END";
let predicate_expr =
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
assert_eq!(predicate_expr.to_string(), expected_expr);
Expand All @@ -2436,9 +2550,18 @@ mod tests {
true,
));
let expected_expr =
"(CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64)) \
AND (CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64)) \
AND (CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64))";
"CASE \
WHEN CAST(c1_null_count@2 AS Int64) = CAST(c1_row_count@3 AS Int64) THEN false \
ELSE CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64) \
END \
AND CASE \
WHEN CAST(c1_null_count@2 AS Int64) = CAST(c1_row_count@3 AS Int64) THEN false \
ELSE CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64) \
END \
AND CASE \
WHEN CAST(c1_null_count@2 AS Int64) = CAST(c1_row_count@3 AS Int64) THEN false \
ELSE CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64) \
END";
let predicate_expr =
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
assert_eq!(predicate_expr.to_string(), expected_expr);
Expand Down

0 comments on commit 23b03fa

Please sign in to comment.