diff --git a/lib/explorer/backend/lazy_series.ex b/lib/explorer/backend/lazy_series.ex index d97ea6587..cb955afdb 100644 --- a/lib/explorer/backend/lazy_series.ex +++ b/lib/explorer/backend/lazy_series.ex @@ -99,8 +99,8 @@ defmodule Explorer.Backend.LazySeries do median: 1, mode: 1, n_distinct: 1, - variance: 1, - standard_deviation: 1, + variance: 2, + standard_deviation: 2, quantile: 2, rank: 4, product: 1, @@ -152,8 +152,6 @@ defmodule Explorer.Backend.LazySeries do :mean, :median, :mode, - :variance, - :standard_deviation, :count, :product, :nil_count, @@ -511,6 +509,22 @@ defmodule Explorer.Backend.LazySeries do Backend.Series.new(data, :float) end + @impl true + def variance(%Series{} = s, ddof \\ 1) do + args = [series_or_lazy_series!(s), ddof] + data = new(:variance, args, :float, true) + + Backend.Series.new(data, :float) + end + + @impl true + def standard_deviation(%Series{} = s, ddof \\ 1) do + args = [series_or_lazy_series!(s), ddof] + data = new(:standard_deviation, args, :float, true) + + Backend.Series.new(data, :float) + end + @impl true def coalesce(%Series{} = left, %Series{} = right) do args = [series_or_lazy_series!(left), series_or_lazy_series!(right)] diff --git a/lib/explorer/backend/series.ex b/lib/explorer/backend/series.ex index 6434194c6..86e2274a0 100644 --- a/lib/explorer/backend/series.ex +++ b/lib/explorer/backend/series.ex @@ -81,8 +81,9 @@ defmodule Explorer.Backend.Series do @callback mean(s) :: float() | non_finite() | lazy_s() | nil @callback median(s) :: float() | non_finite() | lazy_s() | nil @callback mode(s) :: s | lazy_s() - @callback variance(s) :: float() | non_finite() | lazy_s() | nil - @callback standard_deviation(s) :: float() | non_finite() | lazy_s() | nil + @callback variance(s, ddof :: non_neg_integer()) :: float() | non_finite() | lazy_s() | nil + @callback standard_deviation(s, ddof :: non_neg_integer()) :: + float() | non_finite() | lazy_s() | nil @callback quantile(s, float()) :: number() | non_finite() | 
Date.t() | NaiveDateTime.t() | lazy_s() | nil @callback nil_count(s) :: number() | lazy_s() diff --git a/lib/explorer/polars_backend/expression.ex b/lib/explorer/polars_backend/expression.ex index 27dfb9d70..6150b44ef 100644 --- a/lib/explorer/polars_backend/expression.ex +++ b/lib/explorer/polars_backend/expression.ex @@ -70,11 +70,11 @@ defmodule Explorer.PolarsBackend.Expression do asin: 1, acos: 1, atan: 1, - standard_deviation: 1, + standard_deviation: 2, subtract: 2, sum: 1, unordered_distinct: 1, - variance: 1, + variance: 2, skew: 2, covariance: 2 ] diff --git a/lib/explorer/polars_backend/native.ex b/lib/explorer/polars_backend/native.ex index 50c030a7a..f2b2a1227 100644 --- a/lib/explorer/polars_backend/native.ex +++ b/lib/explorer/polars_backend/native.ex @@ -350,7 +350,7 @@ defmodule Explorer.PolarsBackend.Native do def s_slice_by_indices(_s, _indices), do: err() def s_slice_by_series(_s, _series), do: err() def s_sort(_s, _descending?, _nils_last?), do: err() - def s_standard_deviation(_s), do: err() + def s_standard_deviation(_s, _ddof), do: err() def s_strip(_s, _string), do: err() def s_subtract(_s, _other), do: err() def s_sum(_s), do: err() @@ -371,7 +371,7 @@ defmodule Explorer.PolarsBackend.Native do def s_qcut(_s, _quantiles, _labels, _break_point_label, _category_label), do: err() - def s_variance(_s), do: err() + def s_variance(_s, _ddof), do: err() def s_window_max(_s, _window_size, _weight, _ignore_null, _min_periods), do: err() def s_window_mean(_s, _window_size, _weight, _ignore_null, _min_periods), do: err() def s_window_median(_s, _window_size, _weight, _ignore_null, _min_periods), do: err() diff --git a/lib/explorer/polars_backend/series.ex b/lib/explorer/polars_backend/series.ex index e1432b185..f0cab49db 100644 --- a/lib/explorer/polars_backend/series.ex +++ b/lib/explorer/polars_backend/series.ex @@ -202,10 +202,11 @@ defmodule Explorer.PolarsBackend.Series do def mode(series), do: Shared.apply_series(series, :s_mode) @impl true - 
def variance(series), do: Shared.apply_series(series, :s_variance) + def variance(series, ddof), do: Shared.apply_series(series, :s_variance, [ddof]) @impl true - def standard_deviation(series), do: Shared.apply_series(series, :s_standard_deviation) + def standard_deviation(series, ddof), + do: Shared.apply_series(series, :s_standard_deviation, [ddof]) @impl true def quantile(series, quantile), diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index a5d77b142..8247db60a 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -2225,6 +2225,10 @@ defmodule Explorer.Series do @doc """ Gets the variance of the series. + By default, this is the sample variance. This function also takes an optional + delta degrees of freedom (ddof). Setting this to zero corresponds to the population + variance. + ## Supported dtypes * `:integer` @@ -2245,15 +2249,21 @@ defmodule Explorer.Series do ** (ArgumentError) Explorer.Series.variance/1 not implemented for dtype {:datetime, :microsecond}. Valid dtypes are [:integer, :float] """ @doc type: :aggregation - @spec variance(series :: Series.t()) :: float() | non_finite() | nil - def variance(%Series{dtype: dtype} = series) when is_numeric_dtype(dtype), - do: apply_series(series, :variance) + @spec variance(series :: Series.t(), ddof :: non_neg_integer()) :: float() | non_finite() | nil + def variance(series, ddof \\ 1) + + def variance(%Series{dtype: dtype} = series, ddof) when is_numeric_dtype(dtype), + do: apply_series(series, :variance, [ddof]) - def variance(%Series{dtype: dtype}), do: dtype_error("variance/1", dtype, [:integer, :float]) + def variance(%Series{dtype: dtype}, _), do: dtype_error("variance/1", dtype, [:integer, :float]) @doc """ Gets the standard deviation of the series. + By default, this is the sample standard deviation. This function also takes an optional + delta degrees of freedom (ddof). Setting this to zero corresponds to the population + standard deviation.
+ ## Supported dtypes * `:integer` @@ -2274,11 +2284,14 @@ defmodule Explorer.Series do ** (ArgumentError) Explorer.Series.standard_deviation/1 not implemented for dtype :string. Valid dtypes are [:integer, :float] """ @doc type: :aggregation - @spec standard_deviation(series :: Series.t()) :: float() | non_finite() | nil - def standard_deviation(%Series{dtype: dtype} = series) when is_numeric_dtype(dtype), - do: apply_series(series, :standard_deviation) + @spec standard_deviation(series :: Series.t(), ddof :: non_neg_integer()) :: + float() | non_finite() | nil + def standard_deviation(series, ddof \\ 1) + + def standard_deviation(%Series{dtype: dtype} = series, ddof) when is_numeric_dtype(dtype), + do: apply_series(series, :standard_deviation, [ddof]) - def standard_deviation(%Series{dtype: dtype}), + def standard_deviation(%Series{dtype: dtype}, _), do: dtype_error("standard_deviation/1", dtype, [:integer, :float]) @doc """ diff --git a/native/explorer/src/expressions.rs b/native/explorer/src/expressions.rs index 989d2b473..a009dee22 100644 --- a/native/explorer/src/expressions.rs +++ b/native/explorer/src/expressions.rs @@ -490,17 +490,17 @@ pub fn expr_abs(expr: ExExpr) -> ExExpr { } #[rustler::nif] -pub fn expr_variance(expr: ExExpr) -> ExExpr { +pub fn expr_variance(expr: ExExpr, ddof: u8) -> ExExpr { let expr = expr.clone_inner(); - ExExpr::new(expr.var(1)) + ExExpr::new(expr.var(ddof)) } #[rustler::nif] -pub fn expr_standard_deviation(expr: ExExpr) -> ExExpr { +pub fn expr_standard_deviation(expr: ExExpr, ddof: u8) -> ExExpr { let expr = expr.clone_inner(); - ExExpr::new(expr.std(1)) + ExExpr::new(expr.std(ddof)) } #[rustler::nif] diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index eb0a36edc..20d2c05b8 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -955,19 +955,19 @@ pub fn s_product(s: ExSeries) -> Result { } #[rustler::nif(schedule = "DirtyCpu")] -pub fn s_variance(env: Env, s: ExSeries) -> 
Result { +pub fn s_variance(env: Env, s: ExSeries, ddof: u8) -> Result { match s.dtype() { - DataType::Int64 => Ok(s.i64()?.var(1).encode(env)), - DataType::Float64 => Ok(term_from_optional_float(s.f64()?.var(1), env)), + DataType::Int64 => Ok(s.i64()?.var(ddof).encode(env)), + DataType::Float64 => Ok(term_from_optional_float(s.f64()?.var(ddof), env)), dt => panic!("var/1 not implemented for {dt:?}"), } } #[rustler::nif(schedule = "DirtyCpu")] -pub fn s_standard_deviation(env: Env, s: ExSeries) -> Result { +pub fn s_standard_deviation(env: Env, s: ExSeries, ddof: u8) -> Result { match s.dtype() { - DataType::Int64 => Ok(s.i64()?.std(1).encode(env)), - DataType::Float64 => Ok(term_from_optional_float(s.f64()?.std(1), env)), + DataType::Int64 => Ok(s.i64()?.std(ddof).encode(env)), + DataType::Float64 => Ok(term_from_optional_float(s.f64()?.std(ddof), env)), dt => panic!("std/1 not implemented for {dt:?}"), } }