diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 493fb97b82b1..aa7b6a9f900f 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -169,6 +169,10 @@ enum AggregateFunction { COUNT = 4; APPROX_DISTINCT = 5; ARRAY_AGG = 6; + VARIANCE=7; + VARIANCE_POP=8; + STDDEV=9; + STDDEV_POP=10; } message AggregateExprNode { diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index c8ec304fbcde..01428d9ba7a7 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1026,6 +1026,14 @@ impl TryInto for &Expr { AggregateFunction::Sum => protobuf::AggregateFunction::Sum, AggregateFunction::Avg => protobuf::AggregateFunction::Avg, AggregateFunction::Count => protobuf::AggregateFunction::Count, + AggregateFunction::Variance => protobuf::AggregateFunction::Variance, + AggregateFunction::VariancePop => { + protobuf::AggregateFunction::VariancePop + } + AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, + AggregateFunction::StddevPop => { + protobuf::AggregateFunction::StddevPop + } }; let arg = &args[0]; @@ -1256,6 +1264,10 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Count => Self::Count, AggregateFunction::ApproxDistinct => Self::ApproxDistinct, AggregateFunction::ArrayAgg => Self::ArrayAgg, + AggregateFunction::Variance => Self::Variance, + AggregateFunction::VariancePop => Self::VariancePop, + AggregateFunction::Stddev => Self::Stddev, + AggregateFunction::StddevPop => Self::StddevPop, } } } diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index f5442c40e660..fd3b57b3deda 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -119,6 +119,10 @@ impl From for AggregateFunction { AggregateFunction::ApproxDistinct } 
protobuf::AggregateFunction::ArrayAgg => AggregateFunction::ArrayAgg, + protobuf::AggregateFunction::Variance => AggregateFunction::Variance, + protobuf::AggregateFunction::VariancePop => AggregateFunction::VariancePop, + protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev, + protobuf::AggregateFunction::StddevPop => AggregateFunction::StddevPop, } } } diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs index 7040d345aece..7445c9067981 100644 --- a/datafusion/src/optimizer/simplify_expressions.rs +++ b/datafusion/src/optimizer/simplify_expressions.rs @@ -359,7 +359,7 @@ impl ConstEvaluator { } /// Internal helper to evaluates an Expr - fn evaluate_to_scalar(&self, expr: Expr) -> Result { + pub(crate) fn evaluate_to_scalar(&self, expr: Expr) -> Result { if let Expr::Literal(s) = expr { return Ok(s); } diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index e9f9696a56e8..07b0ff8b33b2 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -35,7 +35,9 @@ use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs, coerce_t use crate::physical_plan::distinct_expressions; use crate::physical_plan::expressions; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; -use expressions::{avg_return_type, sum_return_type}; +use expressions::{ + avg_return_type, stddev_return_type, sum_return_type, variance_return_type, +}; use std::{fmt, str::FromStr, sync::Arc}; /// the implementation of an aggregate function @@ -64,6 +66,14 @@ pub enum AggregateFunction { ApproxDistinct, /// array_agg ArrayAgg, + /// Variance (Sample) + Variance, + /// Variance (Population) + VariancePop, + /// Standard Deviation (Sample) + Stddev, + /// Standard Deviation (Population) + StddevPop, } impl fmt::Display for AggregateFunction { @@ -84,6 +94,12 @@ impl FromStr for AggregateFunction { "sum" => 
AggregateFunction::Sum, "approx_distinct" => AggregateFunction::ApproxDistinct, "array_agg" => AggregateFunction::ArrayAgg, + "var" => AggregateFunction::Variance, + "var_samp" => AggregateFunction::Variance, + "var_pop" => AggregateFunction::VariancePop, + "stddev" => AggregateFunction::Stddev, + "stddev_samp" => AggregateFunction::Stddev, + "stddev_pop" => AggregateFunction::StddevPop, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -116,6 +132,10 @@ pub fn return_type( Ok(coerced_data_types[0].clone()) } AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]), + AggregateFunction::Variance => variance_return_type(&coerced_data_types[0]), + AggregateFunction::VariancePop => variance_return_type(&coerced_data_types[0]), + AggregateFunction::Stddev => stddev_return_type(&coerced_data_types[0]), + AggregateFunction::StddevPop => stddev_return_type(&coerced_data_types[0]), AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new( "item", @@ -212,6 +232,48 @@ pub fn create_aggregate_expr( "AVG(DISTINCT) aggregations are not available".to_string(), )); } + (AggregateFunction::Variance, false) => Arc::new(expressions::Variance::new( + coerced_phy_exprs[0].clone(), + name, + return_type, + )), + (AggregateFunction::Variance, true) => { + return Err(DataFusionError::NotImplemented( + "VAR(DISTINCT) aggregations are not available".to_string(), + )); + } + (AggregateFunction::VariancePop, false) => { + Arc::new(expressions::VariancePop::new( + coerced_phy_exprs[0].clone(), + name, + return_type, + )) + } + (AggregateFunction::VariancePop, true) => { + return Err(DataFusionError::NotImplemented( + "VAR_POP(DISTINCT) aggregations are not available".to_string(), + )); + } + (AggregateFunction::Stddev, false) => Arc::new(expressions::Stddev::new( + coerced_phy_exprs[0].clone(), + name, + return_type, + )), + (AggregateFunction::Stddev, true) => 
{ + return Err(DataFusionError::NotImplemented( + "STDDEV(DISTINCT) aggregations are not available".to_string(), + )); + } + (AggregateFunction::StddevPop, false) => Arc::new(expressions::StddevPop::new( + coerced_phy_exprs[0].clone(), + name, + return_type, + )), + (AggregateFunction::StddevPop, true) => { + return Err(DataFusionError::NotImplemented( + "STDDEV_POP(DISTINCT) aggregations are not available".to_string(), + )); + } }) } @@ -256,7 +318,12 @@ pub fn signature(fun: &AggregateFunction) -> Signature { .collect::>(); Signature::uniform(1, valid, Volatility::Immutable) } - AggregateFunction::Avg | AggregateFunction::Sum => { + AggregateFunction::Avg + | AggregateFunction::Sum + | AggregateFunction::Variance + | AggregateFunction::VariancePop + | AggregateFunction::Stddev + | AggregateFunction::StddevPop => { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } } @@ -267,7 +334,7 @@ mod tests { use super::*; use crate::error::Result; use crate::physical_plan::expressions::{ - ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Sum, + ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Stddev, Sum, Variance, }; #[test] @@ -450,6 +517,158 @@ mod tests { Ok(()) } + #[test] + fn test_variance_expr() -> Result<()> { + let funcs = vec![AggregateFunction::Variance]; + let data_types = vec![ + DataType::UInt32, + DataType::UInt64, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + if fun == AggregateFunction::Variance { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + 
Field::new("c1", DataType::Float64, true), + result_agg_phy_exprs.field().unwrap() + ) + } + } + } + Ok(()) + } + + #[test] + fn test_var_pop_expr() -> Result<()> { + let funcs = vec![AggregateFunction::VariancePop]; + let data_types = vec![ + DataType::UInt32, + DataType::UInt64, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + if fun == AggregateFunction::VariancePop { + assert!(result_agg_phy_exprs.as_any().is::<expressions::VariancePop>()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::Float64, true), + result_agg_phy_exprs.field().unwrap() + ) + } + } + } + Ok(()) + } + + #[test] + fn test_stddev_expr() -> Result<()> { + let funcs = vec![AggregateFunction::Stddev]; + let data_types = vec![ + DataType::UInt32, + DataType::UInt64, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + if fun == AggregateFunction::Stddev { + assert!(result_agg_phy_exprs.as_any().is::<Stddev>()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::Float64, true), + result_agg_phy_exprs.field().unwrap() + ) + } + } + } + Ok(()) + } + + #[test] + fn test_stddev_pop_expr() -> Result<()> { + let funcs = 
vec![AggregateFunction::StddevPop]; + let data_types = vec![ + DataType::UInt32, + DataType::UInt64, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + if fun == AggregateFunction::StddevPop { + assert!(result_agg_phy_exprs.as_any().is::<expressions::StddevPop>()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::Float64, true), + result_agg_phy_exprs.field().unwrap() + ) + } + } + } + Ok(()) + } + #[test] fn test_min_max() -> Result<()> { let observed = return_type(&AggregateFunction::Min, &[DataType::Utf8])?; @@ -544,4 +763,56 @@ mod tests { let observed = return_type(&AggregateFunction::Avg, &[DataType::Utf8]); assert!(observed.is_err()); } + + #[test] + fn test_variance_return_type() -> Result<()> { + let observed = return_type(&AggregateFunction::Variance, &[DataType::Float32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Variance, &[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Variance, &[DataType::Int32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Variance, &[DataType::UInt32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Variance, &[DataType::Int64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_variance_no_utf8() { + let observed = return_type(&AggregateFunction::Variance, &[DataType::Utf8]); + assert!(observed.is_err()); + } + + #[test] + fn test_stddev_return_type() 
-> Result<()> { + let observed = return_type(&AggregateFunction::Stddev, &[DataType::Float32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Stddev, &[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Stddev, &[DataType::Int32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Stddev, &[DataType::UInt32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Stddev, &[DataType::Int64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_stddev_no_utf8() { + let observed = return_type(&AggregateFunction::Stddev, &[DataType::Utf8]); + assert!(observed.is_err()); + } } diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs index e76e4a6b023e..d74b4e465c89 100644 --- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -21,7 +21,8 @@ use crate::arrow::datatypes::Schema; use crate::error::{DataFusionError, Result}; use crate::physical_plan::aggregates::AggregateFunction; use crate::physical_plan::expressions::{ - is_avg_support_arg_type, is_sum_support_arg_type, try_cast, + is_avg_support_arg_type, is_stddev_support_arg_type, is_sum_support_arg_type, + is_variance_support_arg_type, try_cast, }; use crate::physical_plan::functions::{Signature, TypeSignature}; use crate::physical_plan::PhysicalExpr; @@ -86,6 +87,42 @@ pub(crate) fn coerce_types( } Ok(input_types.to_vec()) } + AggregateFunction::Variance => { + if !is_variance_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + AggregateFunction::VariancePop => { + if 
!is_variance_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + AggregateFunction::Stddev => { + if !is_stddev_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + AggregateFunction::StddevPop => { + if !is_stddev_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } } } diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 134c6d89ac4f..a85d86708557 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -50,8 +50,11 @@ mod nth_value; mod nullif; mod rank; mod row_number; +mod stats; +mod stddev; mod sum; mod try_cast; +mod variance; /// Module with some convenient methods used in expression building pub mod helpers { @@ -84,9 +87,16 @@ pub use nth_value::NthValue; pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; pub use rank::{dense_rank, percent_rank, rank}; pub use row_number::RowNumber; +pub use stats::StatsType; +pub(crate) use stddev::{ + is_stddev_support_arg_type, stddev_return_type, Stddev, StddevPop, +}; pub(crate) use sum::is_sum_support_arg_type; pub use sum::{sum_return_type, Sum}; pub use try_cast::{try_cast, TryCastExpr}; +pub(crate) use variance::{ + is_variance_support_arg_type, variance_return_type, Variance, VariancePop, +}; /// returns the name of the state pub fn format_state_name(name: &str, state_name: &str) -> String { diff --git a/datafusion/src/physical_plan/expressions/stats.rs b/datafusion/src/physical_plan/expressions/stats.rs new file mode 
100644 index 000000000000..3f2d266622de --- /dev/null +++ b/datafusion/src/physical_plan/expressions/stats.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Enum used for differenciating population and sample for statistical functions +#[derive(Debug, Clone, Copy)] +pub enum StatsType { + /// Population + Population, + /// Sample + Sample, +} diff --git a/datafusion/src/physical_plan/expressions/stddev.rs b/datafusion/src/physical_plan/expressions/stddev.rs new file mode 100644 index 000000000000..d6e28f18d355 --- /dev/null +++ b/datafusion/src/physical_plan/expressions/stddev.rs @@ -0,0 +1,421 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; +use std::sync::Arc; + +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{ + expressions::variance::VarianceAccumulator, Accumulator, AggregateExpr, PhysicalExpr, +}; +use crate::scalar::ScalarValue; +use arrow::datatypes::DataType; +use arrow::datatypes::Field; + +use super::{format_state_name, StatsType}; + +/// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression +#[derive(Debug)] +pub struct Stddev { + name: String, + expr: Arc, +} + +/// STDDEV_POP population aggregate expression +#[derive(Debug)] +pub struct StddevPop { + name: String, + expr: Arc, +} + +/// function return type of standard deviation +pub(crate) fn stddev_return_type(arg_type: &DataType) -> Result { + match arg_type { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 => Ok(DataType::Float64), + other => Err(DataFusionError::Plan(format!( + "STDDEV does not support {:?}", + other + ))), + } +} + +pub(crate) fn is_stddev_support_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ) +} + +impl Stddev { + /// Create a new STDDEV aggregate function + pub fn new( + expr: 
Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + // the result of stddev just support FLOAT64 and Decimal data type. + assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + expr, + } + } +} + +impl AggregateExpr for Stddev { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, DataType::Float64, true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + &format_state_name(&self.name, "mean"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "m2"), + DataType::Float64, + true, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +impl StddevPop { + /// Create a new STDDEV aggregate function + pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + // the result of stddev just support FLOAT64 and Decimal data type. 
+ assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + expr, + } + } +} + +impl AggregateExpr for StddevPop { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, DataType::Float64, true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + &format_state_name(&self.name, "mean"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "m2"), + DataType::Float64, + true, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} +/// An accumulator to compute the average +#[derive(Debug)] +pub struct StddevAccumulator { + variance: VarianceAccumulator, +} + +impl StddevAccumulator { + /// Creates a new `StddevAccumulator` + pub fn try_new(s_type: StatsType) -> Result { + Ok(Self { + variance: VarianceAccumulator::try_new(s_type)?, + }) + } +} + +impl Accumulator for StddevAccumulator { + fn state(&self) -> Result> { + Ok(vec![ + ScalarValue::from(self.variance.get_count()), + self.variance.get_mean(), + self.variance.get_m2(), + ]) + } + + fn update(&mut self, values: &[ScalarValue]) -> Result<()> { + self.variance.update(values) + } + + fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { + self.variance.merge(states) + } + + fn evaluate(&self) -> Result { + let variance = self.variance.evaluate()?; + match variance { + ScalarValue::Float64(e) => { + if e == None { + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64(e.map(|f| f.sqrt()))) + } + } + _ => Err(DataFusionError::Internal( + "Variance should be f64".to_string(), + )), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + 
use crate::physical_plan::expressions::col; + use crate::{error::Result, generic_test_op}; + use arrow::record_batch::RecordBatch; + use arrow::{array::*, datatypes::*}; + + #[test] + fn stddev_f64_1() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64])); + generic_test_op!( + a, + DataType::Float64, + StddevPop, + ScalarValue::from(0.5_f64), + DataType::Float64 + ) + } + + #[test] + fn stddev_f64_2() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); + generic_test_op!( + a, + DataType::Float64, + StddevPop, + ScalarValue::from(0.7760297817881877), + DataType::Float64 + ) + } + + #[test] + fn stddev_f64_3() -> Result<()> { + let a: ArrayRef = + Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); + generic_test_op!( + a, + DataType::Float64, + StddevPop, + ScalarValue::from(std::f64::consts::SQRT_2), + DataType::Float64 + ) + } + + #[test] + fn stddev_f64_4() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); + generic_test_op!( + a, + DataType::Float64, + Stddev, + ScalarValue::from(0.9504384952922168), + DataType::Float64 + ) + } + + #[test] + fn stddev_i32() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); + generic_test_op!( + a, + DataType::Int32, + StddevPop, + ScalarValue::from(std::f64::consts::SQRT_2), + DataType::Float64 + ) + } + + #[test] + fn stddev_u32() -> Result<()> { + let a: ArrayRef = + Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); + generic_test_op!( + a, + DataType::UInt32, + StddevPop, + ScalarValue::from(std::f64::consts::SQRT_2), + DataType::Float64 + ) + } + + #[test] + fn stddev_f32() -> Result<()> { + let a: ArrayRef = + Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); + generic_test_op!( + a, + DataType::Float32, + StddevPop, + ScalarValue::from(std::f64::consts::SQRT_2), + DataType::Float64 + ) + } + + #[test] + fn 
test_stddev_return_data_type() -> Result<()> { + let data_type = DataType::Float64; + let result_type = stddev_return_type(&data_type)?; + assert_eq!(DataType::Float64, result_type); + + let data_type = DataType::Decimal(36, 10); + assert!(stddev_return_type(&data_type).is_err()); + Ok(()) + } + + #[test] + fn test_stddev_1_input() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64])); + let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + + let agg = Arc::new(Stddev::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + let actual = aggregate(&batch, agg); + assert!(actual.is_err()); + + Ok(()) + } + + #[test] + fn stddev_i32_with_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(3), + Some(4), + Some(5), + ])); + generic_test_op!( + a, + DataType::Int32, + StddevPop, + ScalarValue::from(1.479019945774904), + DataType::Float64 + ) + } + + #[test] + fn stddev_i32_all_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + + let agg = Arc::new(Stddev::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + let actual = aggregate(&batch, agg); + assert!(actual.is_err()); + + Ok(()) + } + + fn aggregate( + batch: &RecordBatch, + agg: Arc, + ) -> Result { + let mut accum = agg.create_accumulator()?; + let expr = agg.expressions(); + let values = expr + .iter() + .map(|e| e.evaluate(batch)) + .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .collect::>>()?; + accum.update_batch(&values)?; + accum.evaluate() + } +} diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs new file mode 100644 index 
000000000000..3f592b00fd4e --- /dev/null +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -0,0 +1,530 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; +use std::sync::Arc; + +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; +use crate::scalar::ScalarValue; +use arrow::datatypes::DataType; +use arrow::datatypes::Field; + +use super::{format_state_name, StatsType}; + +/// VAR and VAR_SAMP aggregate expression +#[derive(Debug)] +pub struct Variance { + name: String, + expr: Arc, +} + +/// VAR_POP aggregate expression +#[derive(Debug)] +pub struct VariancePop { + name: String, + expr: Arc, +} + +/// function return type of variance +pub(crate) fn variance_return_type(arg_type: &DataType) -> Result { + match arg_type { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 => Ok(DataType::Float64), + other => Err(DataFusionError::Plan(format!( + "VARIANCE does not support {:?}", + other 
+ ))), + } +} + +pub(crate) fn is_variance_support_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ) +} + +impl Variance { + /// Create a new VARIANCE aggregate function + pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + // the result of variance just support FLOAT64 data type. + assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + expr, + } + } +} + +impl AggregateExpr for Variance { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, DataType::Float64, true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + &format_state_name(&self.name, "mean"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "m2"), + DataType::Float64, + true, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +impl VariancePop { + /// Create a new VAR_POP aggregate function + pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + // the result of variance just support FLOAT64 data type. 
+ assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + expr, + } + } +} + +impl AggregateExpr for VariancePop { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, DataType::Float64, true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(VarianceAccumulator::try_new( + StatsType::Population, + )?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + &format_state_name(&self.name, "mean"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "m2"), + DataType::Float64, + true, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +/// An accumulator to compute variance +/// The algrithm used is an online implementation and numerically stable. It is based on this paper: +/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". +/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577. +/// +/// The algorithm has been analyzed here: +/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances". +/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154. 
+ +#[derive(Debug)] +pub struct VarianceAccumulator { + m2: ScalarValue, + mean: ScalarValue, + count: u64, + stats_type: StatsType, +} + +impl VarianceAccumulator { + /// Creates a new `VarianceAccumulator` + pub fn try_new(s_type: StatsType) -> Result { + Ok(Self { + m2: ScalarValue::from(0 as f64), + mean: ScalarValue::from(0 as f64), + count: 0, + stats_type: s_type, + }) + } + + pub fn get_count(&self) -> u64 { + self.count + } + + pub fn get_mean(&self) -> ScalarValue { + self.mean.clone() + } + + pub fn get_m2(&self) -> ScalarValue { + self.m2.clone() + } +} + +impl Accumulator for VarianceAccumulator { + fn state(&self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + self.mean.clone(), + self.m2.clone(), + ]) + } + + fn update(&mut self, values: &[ScalarValue]) -> Result<()> { + let values = &values[0]; + let is_empty = values.is_null(); + + if !is_empty { + let new_count = self.count + 1; + let delta1 = ScalarValue::add(values, &self.mean.arithmetic_negate())?; + let new_mean = ScalarValue::add( + &ScalarValue::div(&delta1, &ScalarValue::from(new_count as f64))?, + &self.mean, + )?; + let delta2 = ScalarValue::add(values, &new_mean.arithmetic_negate())?; + let tmp = ScalarValue::mul(&delta1, &delta2)?; + + let new_m2 = ScalarValue::add(&self.m2, &tmp)?; + self.count += 1; + self.mean = new_mean; + self.m2 = new_m2; + } + + Ok(()) + } + + fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { + let count = &states[0]; + let mean = &states[1]; + let m2 = &states[2]; + let mut new_count: u64 = self.count; + + // counts are summed + if let ScalarValue::UInt64(Some(c)) = count { + if *c == 0_u64 { + return Ok(()); + } + + if self.count == 0 { + self.count = *c; + self.mean = mean.clone(); + self.m2 = m2.clone(); + return Ok(()); + } + new_count += c + } else { + unreachable!() + }; + + let new_mean = ScalarValue::div( + &ScalarValue::add(&self.mean, mean)?, + &ScalarValue::from(2_f64), + )?; + let delta = 
ScalarValue::add(&mean.arithmetic_negate(), &self.mean)?; + let delta_sqrt = ScalarValue::mul(&delta, &delta)?; + let new_m2 = ScalarValue::add( + &ScalarValue::add( + &ScalarValue::mul( + &delta_sqrt, + &ScalarValue::div( + &ScalarValue::mul(&ScalarValue::from(self.count), count)?, + &ScalarValue::from(new_count as f64), + )?, + )?, + &self.m2, + )?, + m2, + )?; + + self.count = new_count; + self.mean = new_mean; + self.m2 = new_m2; + + Ok(()) + } + + fn evaluate(&self) -> Result { + let count = match self.stats_type { + StatsType::Population => self.count, + StatsType::Sample => { + if self.count > 0 { + self.count - 1 + } else { + self.count + } + } + }; + + if count <= 1 { + return Err(DataFusionError::Internal( + "At least two values are needed to calculate variance".to_string(), + )); + } + + match self.m2 { + ScalarValue::Float64(e) => { + if self.count == 0 { + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64(e.map(|f| f / count as f64))) + } + } + _ => Err(DataFusionError::Internal( + "M2 should be f64 for variance".to_string(), + )), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::physical_plan::expressions::col; + use crate::{error::Result, generic_test_op}; + use arrow::record_batch::RecordBatch; + use arrow::{array::*, datatypes::*}; + + #[test] + fn variance_f64_1() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64])); + generic_test_op!( + a, + DataType::Float64, + VariancePop, + ScalarValue::from(0.25_f64), + DataType::Float64 + ) + } + + #[test] + fn variance_f64_2() -> Result<()> { + let a: ArrayRef = + Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); + generic_test_op!( + a, + DataType::Float64, + VariancePop, + ScalarValue::from(2_f64), + DataType::Float64 + ) + } + + #[test] + fn variance_f64_3() -> Result<()> { + let a: ArrayRef = + Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); + generic_test_op!( + a, + DataType::Float64, + 
Variance, + ScalarValue::from(2.5_f64), + DataType::Float64 + ) + } + + #[test] + fn variance_f64_4() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); + generic_test_op!( + a, + DataType::Float64, + Variance, + ScalarValue::from(0.9033333333333333_f64), + DataType::Float64 + ) + } + + #[test] + fn variance_i32() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); + generic_test_op!( + a, + DataType::Int32, + VariancePop, + ScalarValue::from(2_f64), + DataType::Float64 + ) + } + + #[test] + fn variance_u32() -> Result<()> { + let a: ArrayRef = + Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); + generic_test_op!( + a, + DataType::UInt32, + VariancePop, + ScalarValue::from(2.0f64), + DataType::Float64 + ) + } + + #[test] + fn variance_f32() -> Result<()> { + let a: ArrayRef = + Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); + generic_test_op!( + a, + DataType::Float32, + VariancePop, + ScalarValue::from(2_f64), + DataType::Float64 + ) + } + + #[test] + fn test_variance_return_data_type() -> Result<()> { + let data_type = DataType::Float64; + let result_type = variance_return_type(&data_type)?; + assert_eq!(DataType::Float64, result_type); + + let data_type = DataType::Decimal(36, 10); + assert!(variance_return_type(&data_type).is_err()); + Ok(()) + } + + #[test] + fn test_variance_1_input() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64])); + let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + + let agg = Arc::new(Variance::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + let actual = aggregate(&batch, agg); + assert!(actual.is_err()); + + Ok(()) + } + + #[test] + fn variance_i32_with_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(3), + 
Some(4), + Some(5), + ])); + generic_test_op!( + a, + DataType::Int32, + VariancePop, + ScalarValue::from(2.1875f64), + DataType::Float64 + ) + } + + #[test] + fn variance_i32_all_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + + let agg = Arc::new(Variance::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + let actual = aggregate(&batch, agg); + assert!(actual.is_err()); + + Ok(()) + } + + fn aggregate( + batch: &RecordBatch, + agg: Arc, + ) -> Result { + let mut accum = agg.create_accumulator()?; + let expr = agg.expressions(); + let values = expr + .iter() + .map(|e| e.evaluate(batch)) + .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .collect::>>()?; + accum.update_batch(&values)?; + accum.evaluate() + } +} diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index cdcf11eccea2..cf6e8a1ac1c2 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -526,6 +526,301 @@ macro_rules! 
eq_array_primitive { } impl ScalarValue { + /// Return true if the value is numeric + pub fn is_numeric(&self) -> bool { + matches!( + self, + ScalarValue::Float32(_) + | ScalarValue::Float64(_) + | ScalarValue::Decimal128(_, _, _) + | ScalarValue::Int8(_) + | ScalarValue::Int16(_) + | ScalarValue::Int32(_) + | ScalarValue::Int64(_) + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) + ) + } + + /// Add two numeric ScalarValues + pub fn add(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { + if !lhs.is_numeric() || !rhs.is_numeric() { + return Err(DataFusionError::Internal(format!( + "Addition only supports numeric types, \ + here has {:?} and {:?}", + lhs.get_datatype(), + rhs.get_datatype() + ))); + } + + if lhs.is_null() || rhs.is_null() { + return Err(DataFusionError::Internal( + "Addition does not support empty values".to_string(), + )); + } + + // TODO: Finding a good way to support operation between different types without + // writing a huge match block.
+ // TODO: Add support for decimal types + match (lhs, rhs) { + (ScalarValue::Decimal128(_, _, _), _) | + (_, ScalarValue::Decimal128(_, _, _)) => { + Err(DataFusionError::Internal( + "Addition with Decimals are not supported for now".to_string() + )) + }, + // f64 / _ + (ScalarValue::Float64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() + f2.unwrap()))) + }, + // f32 / _ + (ScalarValue::Float32(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::Float32(f1), ScalarValue::Float32(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap() as f64))) + }, + // i64 / _ + (ScalarValue::Int64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::Int64(f1), ScalarValue::Int64(f2)) => { + Ok(ScalarValue::Int64(Some(f1.unwrap() + f2.unwrap()))) + }, + // i32 / _ + (ScalarValue::Int32(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::Int32(f1), ScalarValue::Int32(f2)) => { + Ok(ScalarValue::Int64(Some(f1.unwrap() as i64 + f2.unwrap() as i64))) + }, + // i16 / _ + (ScalarValue::Int16(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::Int16(f1), ScalarValue::Int16(f2)) => { + Ok(ScalarValue::Int32(Some(f1.unwrap() as i32 + f2.unwrap() as i32))) + }, + // i8 / _ + (ScalarValue::Int8(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::Int8(f1), ScalarValue::Int8(f2)) => { + Ok(ScalarValue::Int16(Some(f1.unwrap() as i16 + f2.unwrap() as i16))) + }, + // u64 / _ + (ScalarValue::UInt64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::UInt64(f1), ScalarValue::UInt64(f2)) => { + 
Ok(ScalarValue::UInt64(Some(f1.unwrap() as u64 + f2.unwrap() as u64))) + }, + // u32 / _ + (ScalarValue::UInt32(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::UInt32(f1), ScalarValue::UInt32(f2)) => { + Ok(ScalarValue::UInt64(Some(f1.unwrap() as u64 + f2.unwrap() as u64))) + }, + // u16 / _ + (ScalarValue::UInt16(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::UInt16(f1), ScalarValue::UInt16(f2)) => { + Ok(ScalarValue::UInt32(Some(f1.unwrap() as u32 + f2.unwrap() as u32))) + }, + // u8 / _ + (ScalarValue::UInt8(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::UInt8(f1), ScalarValue::UInt8(f2)) => { + Ok(ScalarValue::UInt16(Some(f1.unwrap() as u16 + f2.unwrap() as u16))) + }, + _ => Err(DataFusionError::Internal( + format!( + "Addition only support calculation with the same type or f64 as one of the numbers for now, here has {:?} and {:?}", + lhs.get_datatype(), rhs.get_datatype() + ))), + } + } + + /// Multiply two numeric ScalarValues + pub fn mul(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { + if !lhs.is_numeric() || !rhs.is_numeric() { + return Err(DataFusionError::Internal(format!( + "Multiplication is only supported on numeric types, \ + here has {:?} and {:?}", + lhs.get_datatype(), + rhs.get_datatype() + ))); + } + + if lhs.is_null() || rhs.is_null() { + return Err(DataFusionError::Internal( + "Multiplication does not support empty values".to_string(), + )); + } + + // TODO: Finding a good way to support operation between different types without + // writing a huge match block.
+ // TODO: Add support for decimal type + match (lhs, rhs) { + (ScalarValue::Decimal128(_, _, _), _) + | (_, ScalarValue::Decimal128(_, _, _)) => Err(DataFusionError::Internal( + "Multiplication with Decimals are not supported for now".to_string(), + )), + // f64 / _ + (ScalarValue::Float64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() * f2.unwrap()))) + } + // f32 / _ + (ScalarValue::Float32(f1), ScalarValue::Float32(f2)) => Ok( + ScalarValue::Float64(Some(f1.unwrap() as f64 * f2.unwrap() as f64)), + ), + // i64 / _ + (ScalarValue::Int64(f1), ScalarValue::Int64(f2)) => { + Ok(ScalarValue::Int64(Some(f1.unwrap() * f2.unwrap()))) + } + // i32 / _ + (ScalarValue::Int32(f1), ScalarValue::Int32(f2)) => Ok(ScalarValue::Int64( + Some(f1.unwrap() as i64 * f2.unwrap() as i64), + )), + // i16 / _ + (ScalarValue::Int16(f1), ScalarValue::Int16(f2)) => Ok(ScalarValue::Int32( + Some(f1.unwrap() as i32 * f2.unwrap() as i32), + )), + // i8 / _ + (ScalarValue::Int8(f1), ScalarValue::Int8(f2)) => Ok(ScalarValue::Int16( + Some(f1.unwrap() as i16 * f2.unwrap() as i16), + )), + // u64 / _ + (ScalarValue::UInt64(f1), ScalarValue::UInt64(f2)) => Ok( + ScalarValue::UInt64(Some(f1.unwrap() as u64 * f2.unwrap() as u64)), + ), + // u32 / _ + (ScalarValue::UInt32(f1), ScalarValue::UInt32(f2)) => Ok( + ScalarValue::UInt64(Some(f1.unwrap() as u64 * f2.unwrap() as u64)), + ), + // u16 / _ + (ScalarValue::UInt16(f1), ScalarValue::UInt16(f2)) => Ok( + ScalarValue::UInt32(Some(f1.unwrap() as u32 * f2.unwrap() as u32)), + ), + // u8 / _ + (ScalarValue::UInt8(f1), ScalarValue::UInt8(f2)) => Ok(ScalarValue::UInt16( + Some(f1.unwrap() as u16 * f2.unwrap() as u16), + )), + _ => Err(DataFusionError::Internal(format!( + "Multiplication only support f64 for now, here has {:?} and {:?}", + lhs.get_datatype(), + rhs.get_datatype() + ))), + } + } + + /// Division between two numeric ScalarValues + pub fn div(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { + if 
!lhs.is_numeric() || !rhs.is_numeric() { + return Err(DataFusionError::Internal(format!( + "Division is only supported on numeric types, \ + here has {:?} and {:?}", + lhs.get_datatype(), + rhs.get_datatype() + ))); + } + + if lhs.is_null() || rhs.is_null() { + return Err(DataFusionError::Internal( + "Division does not support empty values".to_string(), + )); + } + + // TODO: Finding a good way to support operation between different types without + // writing a huge match block. + // TODO: Add support for decimal types + match (lhs, rhs) { + (ScalarValue::Decimal128(_, _, _), _) | + (_, ScalarValue::Decimal128(_, _, _)) => { + Err(DataFusionError::Internal( + "Division with Decimals are not supported for now".to_string() + )) + }, + // f64 / _ + (ScalarValue::Float64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() / f2.unwrap()))) + }, + // f32 / _ + (ScalarValue::Float32(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64/ f2.unwrap()))) + }, + (ScalarValue::Float32(f1), ScalarValue::Float32(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64/ f2.unwrap() as f64))) + }, + // i64 / _ + (ScalarValue::Int64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::Int64(f1), ScalarValue::Int64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // i32 / _ + (ScalarValue::Int32(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::Int32(f1), ScalarValue::Int32(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // i16 / _ + (ScalarValue::Int16(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::Int16(f1), ScalarValue::Int16(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + //
i8 / _ + (ScalarValue::Int8(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::Int8(f1), ScalarValue::Int8(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // u64 / _ + (ScalarValue::UInt64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::UInt64(f1), ScalarValue::UInt64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // u32 / _ + (ScalarValue::UInt32(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::UInt32(f1), ScalarValue::UInt32(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // u16 / _ + (ScalarValue::UInt16(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::UInt16(f1), ScalarValue::UInt16(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // u8 / _ + (ScalarValue::UInt8(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::UInt8(f1), ScalarValue::UInt8(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + _ => Err(DataFusionError::Internal( + format!( + "Division only support calculation with the same type or f64 as denominator for now, here has {:?} and {:?}", + lhs.get_datatype(), rhs.get_datatype() + ))), + } + } + /// Create a decimal Scalar from value/precision and scale. pub fn try_new_decimal128( value: i128, @@ -3081,4 +3376,245 @@ mod tests { DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".to_owned())) ); } + + macro_rules! 
test_scalar_op { + ($OP:ident, $LHS:expr, $LHS_TYPE:ident, $RHS:expr, $RHS_TYPE:ident, $RESULT:expr, $RESULT_TYPE:ident) => {{ + let v1 = &ScalarValue::from($LHS as $LHS_TYPE); + let v2 = &ScalarValue::from($RHS as $RHS_TYPE); + assert_eq!( + ScalarValue::$OP(v1, v2).unwrap(), + ScalarValue::from($RESULT as $RESULT_TYPE) + ); + }}; + } + + macro_rules! test_scalar_op_err { + ($OP:ident, $LHS:expr, $LHS_TYPE:ident, $RHS:expr, $RHS_TYPE:ident) => {{ + let v1 = &ScalarValue::from($LHS as $LHS_TYPE); + let v2 = &ScalarValue::from($RHS as $RHS_TYPE); + let actual = ScalarValue::$OP(v1, v2).is_err(); + assert!(actual); + }}; + } + + #[test] + fn scalar_addition() { + test_scalar_op!(add, 1, f64, 2, f64, 3, f64); + test_scalar_op!(add, 1, f32, 2, f32, 3, f64); + test_scalar_op!(add, 1, i64, 2, i64, 3, i64); + test_scalar_op!(add, 100, i64, -32, i64, 68, i64); + test_scalar_op!(add, -102, i64, 32, i64, -70, i64); + test_scalar_op!(add, 1, i32, 2, i32, 3, i64); + test_scalar_op!( + add, + std::i32::MAX, + i32, + std::i32::MAX, + i32, + std::i32::MAX as i64 * 2, + i64 + ); + test_scalar_op!(add, 1, i16, 2, i16, 3, i32); + test_scalar_op!( + add, + std::i16::MAX, + i16, + std::i16::MAX, + i16, + std::i16::MAX as i32 * 2, + i32 + ); + test_scalar_op!(add, 1, i8, 2, i8, 3, i16); + test_scalar_op!( + add, + std::i8::MAX, + i8, + std::i8::MAX, + i8, + std::i8::MAX as i16 * 2, + i16 + ); + test_scalar_op!(add, 1, u64, 2, u64, 3, u64); + test_scalar_op!(add, 1, u32, 2, u32, 3, u64); + test_scalar_op!( + add, + std::u32::MAX, + u32, + std::u32::MAX, + u32, + std::u32::MAX as u64 * 2, + u64 + ); + test_scalar_op!(add, 1, u16, 2, u16, 3, u32); + test_scalar_op!( + add, + std::u16::MAX, + u16, + std::u16::MAX, + u16, + std::u16::MAX as u32 * 2, + u32 + ); + test_scalar_op!(add, 1, u8, 2, u8, 3, u16); + test_scalar_op!( + add, + std::u8::MAX, + u8, + std::u8::MAX, + u8, + std::u8::MAX as u16 * 2, + u16 + ); + test_scalar_op_err!(add, 1, i32, 2, u16); + test_scalar_op_err!(add, 1, i32, 
2, u16); + + let v1 = &ScalarValue::from(1); + let v2 = &ScalarValue::Decimal128(Some(2), 0, 0); + assert!(ScalarValue::add(v1, v2).is_err()); + + let v1 = &ScalarValue::Decimal128(Some(1), 0, 0); + let v2 = &ScalarValue::from(2); + assert!(ScalarValue::add(v1, v2).is_err()); + + let v1 = &ScalarValue::Float32(None); + let v2 = &ScalarValue::from(2); + assert!(ScalarValue::add(v1, v2).is_err()); + + let v2 = &ScalarValue::Float32(None); + let v1 = &ScalarValue::from(2); + assert!(ScalarValue::add(v1, v2).is_err()); + + let v1 = &ScalarValue::Float32(None); + let v2 = &ScalarValue::Float32(None); + assert!(ScalarValue::add(v1, v2).is_err()); + } + + #[test] + fn scalar_multiplication() { + test_scalar_op!(mul, 1, f64, 2, f64, 2, f64); + test_scalar_op!(mul, 1, f32, 2, f32, 2, f64); + test_scalar_op!(mul, 15, i64, 2, i64, 30, i64); + test_scalar_op!(mul, 100, i64, -32, i64, -3200, i64); + test_scalar_op!(mul, -1.1, f64, 2, f64, -2.2, f64); + test_scalar_op!(mul, 1, i32, 2, i32, 2, i64); + test_scalar_op!( + mul, + std::i32::MAX, + i32, + std::i32::MAX, + i32, + std::i32::MAX as i64 * std::i32::MAX as i64, + i64 + ); + test_scalar_op!(mul, 1, i16, 2, i16, 2, i32); + test_scalar_op!( + mul, + std::i16::MAX, + i16, + std::i16::MAX, + i16, + std::i16::MAX as i32 * std::i16::MAX as i32, + i32 + ); + test_scalar_op!(mul, 1, i8, 2, i8, 2, i16); + test_scalar_op!( + mul, + std::i8::MAX, + i8, + std::i8::MAX, + i8, + std::i8::MAX as i16 * std::i8::MAX as i16, + i16 + ); + test_scalar_op!(mul, 1, u64, 2, u64, 2, u64); + test_scalar_op!(mul, 1, u32, 2, u32, 2, u64); + test_scalar_op!( + mul, + std::u32::MAX, + u32, + std::u32::MAX, + u32, + std::u32::MAX as u64 * std::u32::MAX as u64, + u64 + ); + test_scalar_op!(mul, 1, u16, 2, u16, 2, u32); + test_scalar_op!( + mul, + std::u16::MAX, + u16, + std::u16::MAX, + u16, + std::u16::MAX as u32 * std::u16::MAX as u32, + u32 + ); + test_scalar_op!(mul, 1, u8, 2, u8, 2, u16); + test_scalar_op!( + mul, + std::u8::MAX, + u8, + 
std::u8::MAX, + u8, + std::u8::MAX as u16 * std::u8::MAX as u16, + u16 + ); + test_scalar_op_err!(mul, 1, i32, 2, u16); + test_scalar_op_err!(mul, 1, i32, 2, u16); + + let v1 = &ScalarValue::from(1); + let v2 = &ScalarValue::Decimal128(Some(2), 0, 0); + assert!(ScalarValue::mul(v1, v2).is_err()); + + let v1 = &ScalarValue::Decimal128(Some(1), 0, 0); + let v2 = &ScalarValue::from(2); + assert!(ScalarValue::mul(v1, v2).is_err()); + + let v1 = &ScalarValue::Float32(None); + let v2 = &ScalarValue::from(2); + assert!(ScalarValue::mul(v1, v2).is_err()); + + let v2 = &ScalarValue::Float32(None); + let v1 = &ScalarValue::from(2); + assert!(ScalarValue::mul(v1, v2).is_err()); + + let v1 = &ScalarValue::Float32(None); + let v2 = &ScalarValue::Float32(None); + assert!(ScalarValue::mul(v1, v2).is_err()); + } + + #[test] + fn scalar_division() { + test_scalar_op!(div, 1, f64, 2, f64, 0.5, f64); + test_scalar_op!(div, 1, f32, 2, f32, 0.5, f64); + test_scalar_op!(div, 15, i64, 2, i64, 7.5, f64); + test_scalar_op!(div, 100, i64, -2, i64, -50, f64); + test_scalar_op!(div, 1, i32, 2, i32, 0.5, f64); + test_scalar_op!(div, 1, i16, 2, i16, 0.5, f64); + test_scalar_op!(div, 1, i8, 2, i8, 0.5, f64); + test_scalar_op!(div, 1, u64, 2, u64, 0.5, f64); + test_scalar_op!(div, 1, u32, 2, u32, 0.5, f64); + test_scalar_op!(div, 1, u16, 2, u16, 0.5, f64); + test_scalar_op!(div, 1, u8, 2, u8, 0.5, f64); + test_scalar_op_err!(div, 1, i32, 2, u16); + test_scalar_op_err!(div, 1, i32, 2, u16); + + let v1 = &ScalarValue::from(1); + let v2 = &ScalarValue::Decimal128(Some(2), 0, 0); + assert!(ScalarValue::div(v1, v2).is_err()); + + let v1 = &ScalarValue::Decimal128(Some(1), 0, 0); + let v2 = &ScalarValue::from(2); + assert!(ScalarValue::div(v1, v2).is_err()); + + let v1 = &ScalarValue::Float32(None); + let v2 = &ScalarValue::from(2); + assert!(ScalarValue::div(v1, v2).is_err()); + + let v2 = &ScalarValue::Float32(None); + let v1 = &ScalarValue::from(2); + assert!(ScalarValue::div(v1, v2).is_err()); + + 
let v1 = &ScalarValue::Float32(None); + let v2 = &ScalarValue::Float32(None); + assert!(ScalarValue::div(v1, v2).is_err()); + } } diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index 8073862c8d6e..edf530be8b7d 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -49,6 +49,138 @@ async fn csv_query_avg() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_variance_1() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT var_pop(c2) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["1.8675"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_variance_2() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT var_pop(c6) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["26156334342021890000000000000000000000"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_variance_3() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT var_pop(c12) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["0.09234223721582163"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_variance_4() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT var(c2) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["1.8863636363636365"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_variance_5() -> Result<()> { + let mut ctx 
= ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT var_samp(c2) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["1.8863636363636365"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_stddev_1() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT stddev_pop(c2) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["1.3665650368716449"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_stddev_2() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT stddev_pop(c6) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["5114326382039172000"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_stddev_3() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT stddev_pop(c12) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["0.30387865541334363"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_stddev_4() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT stddev(c12) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["0.3054095399405338"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_stddev_5() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT stddev_samp(c12) FROM 
aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["0.3054095399405338"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_stddev_6() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "select stddev(sq.column1) from (values (1.1), (2.0), (3.0)) as sq"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["0.9504384952922168"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + #[tokio::test] async fn csv_query_external_table_count() { let mut ctx = ExecutionContext::new();