Skip to content

Commit

Permalink
Use f64::total_cmp instead of OrderedFloat (#4133)
Browse files Browse the repository at this point in the history
* Replace OrderedFloat with f64

* clippy

* Adding hasher

* fixed comments

* fixed comments

* fmt

* comments fixed

* removed ordered_flost from toml

* changed cargo.lock
  • Loading branch information
comphead authored Nov 10, 2022
1 parent 509c82c commit 5883e43
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 217 deletions.
14 changes: 1 addition & 13 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion datafusion/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ arrow = { version = "26.0.0", default-features = false }
chrono = { version = "0.4", default-features = false }
cranelift-module = { version = "0.89.0", optional = true }
object_store = { version = "0.5.0", default-features = false, optional = true }
ordered-float = "3.0"
parquet = { version = "26.0.0", default-features = false, optional = true }
pyo3 = { version = "0.17.1", optional = true }
sqlparser = "0.26"
79 changes: 40 additions & 39 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ use std::ops::{Add, Sub};
use std::str::FromStr;
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};

use crate::cast::as_struct_array;
use crate::delta::shift_months;
use crate::error::{DataFusionError, Result};
use arrow::{
array::*,
compute::kernels::cast::{cast, cast_with_options, CastOptions},
Expand All @@ -37,11 +40,6 @@ use arrow::{
},
};
use chrono::{Datelike, Duration, NaiveDate, NaiveDateTime};
use ordered_float::OrderedFloat;

use crate::cast::as_struct_array;
use crate::delta::shift_months;
use crate::error::{DataFusionError, Result};

/// Represents a dynamically typed, nullable single value.
/// This is the single-valued counter-part of arrow's `Array`.
Expand Down Expand Up @@ -116,8 +114,7 @@ pub enum ScalarValue {
Dictionary(Box<DataType>, Box<ScalarValue>),
}

// manual implementation of `PartialEq` that uses OrderedFloat to
// get defined behavior for floating point
// manual implementation of `PartialEq`
impl PartialEq for ScalarValue {
fn eq(&self, other: &Self) -> bool {
use ScalarValue::*;
Expand All @@ -131,17 +128,15 @@ impl PartialEq for ScalarValue {
(Decimal128(_, _, _), _) => false,
(Boolean(v1), Boolean(v2)) => v1.eq(v2),
(Boolean(_), _) => false,
(Float32(v1), Float32(v2)) => {
let v1 = v1.map(OrderedFloat);
let v2 = v2.map(OrderedFloat);
v1.eq(&v2)
}
(Float32(v1), Float32(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(),
_ => v1.eq(v2),
},
(Float32(_), _) => false,
(Float64(v1), Float64(v2)) => {
let v1 = v1.map(OrderedFloat);
let v2 = v2.map(OrderedFloat);
v1.eq(&v2)
}
(Float64(v1), Float64(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(),
_ => v1.eq(v2),
},
(Float64(_), _) => false,
(Int8(v1), Int8(v2)) => v1.eq(v2),
(Int8(_), _) => false,
Expand Down Expand Up @@ -201,8 +196,7 @@ impl PartialEq for ScalarValue {
}
}

// manual implementation of `PartialOrd` that uses OrderedFloat to
// get defined behavior for floating point
// manual implementation of `PartialOrd`
impl PartialOrd for ScalarValue {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
use ScalarValue::*;
Expand All @@ -221,17 +215,15 @@ impl PartialOrd for ScalarValue {
(Decimal128(_, _, _), _) => None,
(Boolean(v1), Boolean(v2)) => v1.partial_cmp(v2),
(Boolean(_), _) => None,
(Float32(v1), Float32(v2)) => {
let v1 = v1.map(OrderedFloat);
let v2 = v2.map(OrderedFloat);
v1.partial_cmp(&v2)
}
(Float32(v1), Float32(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => Some(f1.total_cmp(f2)),
_ => v1.partial_cmp(v2),
},
(Float32(_), _) => None,
(Float64(v1), Float64(v2)) => {
let v1 = v1.map(OrderedFloat);
let v2 = v2.map(OrderedFloat);
v1.partial_cmp(&v2)
}
(Float64(v1), Float64(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => Some(f1.total_cmp(f2)),
_ => v1.partial_cmp(v2),
},
(Float64(_), _) => None,
(Int8(v1), Int8(v2)) => v1.partial_cmp(v2),
(Int8(_), _) => None,
Expand Down Expand Up @@ -625,8 +617,23 @@ where
intermediate.add(Duration::milliseconds(ms as i64))
}

// manual implementation of `Hash` that uses OrderedFloat to
// get defined behavior for floating point
//Float wrapper over f32/f64. Just because we cannot build std::hash::Hash for floats directly we have to do it through type wrapper
struct Fl<T>(T);

macro_rules! hash_float_value {
($(($t:ty, $i:ty)),+) => {
$(impl std::hash::Hash for Fl<$t> {
#[inline]
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
state.write(&<$i>::from_ne_bytes(self.0.to_ne_bytes()).to_ne_bytes())
}
})+
};
}

hash_float_value!((f64, u64), (f32, u32));

// manual implementation of `Hash`
impl std::hash::Hash for ScalarValue {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
use ScalarValue::*;
Expand All @@ -637,14 +644,8 @@ impl std::hash::Hash for ScalarValue {
s.hash(state)
}
Boolean(v) => v.hash(state),
Float32(v) => {
let v = v.map(OrderedFloat);
v.hash(state)
}
Float64(v) => {
let v = v.map(OrderedFloat);
v.hash(state)
}
Float32(v) => v.map(Fl).hash(state),
Float64(v) => v.map(Fl).hash(state),
Int8(v) => v.hash(state),
Int16(v) => v.hash(state),
Int32(v) => v.hash(state),
Expand Down
1 change: 0 additions & 1 deletion datafusion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ log = "^0.4"
num-traits = { version = "0.2", optional = true }
num_cpus = "1.13.0"
object_store = "0.5.0"
ordered-float = "3.0"
parking_lot = "0.12"
parquet = { version = "26.0.0", features = ["arrow", "async"] }
paste = "^1.0"
Expand Down
1 change: 0 additions & 1 deletion datafusion/physical-expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ itertools = { version = "0.10", features = ["use_std"] }
lazy_static = { version = "^1.4.0" }
md-5 = { version = "^0.10.0", optional = true }
num-traits = { version = "0.2", default-features = false }
ordered-float = "3.0"
paste = "^1.0"
rand = "0.8"
regex = { version = "^1.4.3", optional = true }
Expand Down
10 changes: 3 additions & 7 deletions datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use crate::aggregate::tdigest::TryIntoOrderedF64;
use crate::aggregate::tdigest::TryIntoF64;
use crate::aggregate::tdigest::{TDigest, DEFAULT_MAX_SIZE};
use crate::expressions::{format_state_name, Literal};
use crate::{AggregateExpr, PhysicalExpr};
Expand All @@ -30,7 +30,6 @@ use datafusion_common::DataFusionError;
use datafusion_common::Result;
use datafusion_common::{downcast_value, ScalarValue};
use datafusion_expr::{Accumulator, AggregateState};
use ordered_float::OrderedFloat;
use std::{any::Any, iter, sync::Arc};

/// APPROX_PERCENTILE_CONT aggregate expression
Expand Down Expand Up @@ -267,9 +266,7 @@ impl ApproxPercentileAccumulator {
self.digest = TDigest::merge_digests(digests);
}

pub(crate) fn convert_to_ordered_float(
values: &ArrayRef,
) -> Result<Vec<OrderedFloat<f64>>> {
pub(crate) fn convert_to_float(values: &ArrayRef) -> Result<Vec<f64>> {
match values.data_type() {
DataType::Float64 => {
let array = downcast_value!(values, Float64Array);
Expand Down Expand Up @@ -371,8 +368,7 @@ impl Accumulator for ApproxPercentileAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = &values[0];
let sorted_values = &arrow::compute::sort(values, None)?;
let sorted_values =
ApproxPercentileAccumulator::convert_to_ordered_float(sorted_values)?;
let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?;
self.digest = self.digest.merge_sorted_f64(&sorted_values);
Ok(())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator {
weights.len(),
"invalid number of values in means and weights"
);
let means_f64 = ApproxPercentileAccumulator::convert_to_ordered_float(means)?;
let weights_f64 = ApproxPercentileAccumulator::convert_to_ordered_float(weights)?;
let means_f64 = ApproxPercentileAccumulator::convert_to_float(means)?;
let weights_f64 = ApproxPercentileAccumulator::convert_to_float(weights)?;
let mut digests: Vec<TDigest> = vec![];
for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) {
digests.push(TDigest::new_with_centroid(
Expand Down
7 changes: 3 additions & 4 deletions datafusion/physical-expr/src/aggregate/count_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,6 @@ mod tests {

macro_rules! test_count_distinct_update_batch_floating_point {
($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{
use ordered_float::OrderedFloat;
let values: Vec<Option<$PRIM_TYPE>> = vec![
Some(<$PRIM_TYPE>::INFINITY),
Some(<$PRIM_TYPE>::NAN),
Expand All @@ -437,10 +436,10 @@ mod tests {

let mut state_vec =
state_to_vec!(&states[0], $DATA_TYPE, $PRIM_TYPE).unwrap();

dbg!(&state_vec);
state_vec.sort_by(|a, b| match (a, b) {
(Some(lhs), Some(rhs)) => {
OrderedFloat::from(*lhs).cmp(&OrderedFloat::from(*rhs))
}
(Some(lhs), Some(rhs)) => lhs.total_cmp(rhs),
_ => a.partial_cmp(b).unwrap(),
});

Expand Down
Loading

0 comments on commit 5883e43

Please sign in to comment.