Skip to content

Commit

Permalink
Expose DDOF for variance and std dev
Browse files Browse the repository at this point in the history
  • Loading branch information
cigrainger committed Nov 23, 2023
1 parent 3619f77 commit 548979b
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 30 deletions.
22 changes: 18 additions & 4 deletions lib/explorer/backend/lazy_series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ defmodule Explorer.Backend.LazySeries do
median: 1,
mode: 1,
n_distinct: 1,
variance: 1,
standard_deviation: 1,
variance: 2,
standard_deviation: 2,
quantile: 2,
rank: 4,
product: 1,
Expand Down Expand Up @@ -152,8 +152,6 @@ defmodule Explorer.Backend.LazySeries do
:mean,
:median,
:mode,
:variance,
:standard_deviation,
:count,
:product,
:nil_count,
Expand Down Expand Up @@ -511,6 +509,22 @@ defmodule Explorer.Backend.LazySeries do
Backend.Series.new(data, :float)
end

@impl true
def variance(%Series{} = s, ddof \\ 1) do
args = [series_or_lazy_series!(s), ddof]
data = new(:variance, args, :float, true)

Backend.Series.new(data, :float)
end

@impl true
def standard_deviation(%Series{} = s, ddof \\ 1) do
args = [series_or_lazy_series!(s), ddof]
data = new(:standard_deviation, args, :float, true)

Backend.Series.new(data, :float)
end

@impl true
def coalesce(%Series{} = left, %Series{} = right) do
args = [series_or_lazy_series!(left), series_or_lazy_series!(right)]
Expand Down
5 changes: 3 additions & 2 deletions lib/explorer/backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ defmodule Explorer.Backend.Series do
@callback mean(s) :: float() | non_finite() | lazy_s() | nil
@callback median(s) :: float() | non_finite() | lazy_s() | nil
@callback mode(s) :: s | lazy_s()
@callback variance(s) :: float() | non_finite() | lazy_s() | nil
@callback standard_deviation(s) :: float() | non_finite() | lazy_s() | nil
@callback variance(s, ddof :: non_neg_integer()) :: float() | non_finite() | lazy_s() | nil
@callback standard_deviation(s, ddof :: non_neg_integer()) ::
float() | non_finite() | lazy_s() | nil
@callback quantile(s, float()) ::
number() | non_finite() | Date.t() | NaiveDateTime.t() | lazy_s() | nil
@callback nil_count(s) :: number() | lazy_s()
Expand Down
4 changes: 2 additions & 2 deletions lib/explorer/polars_backend/expression.ex
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ defmodule Explorer.PolarsBackend.Expression do
asin: 1,
acos: 1,
atan: 1,
standard_deviation: 1,
standard_deviation: 2,
subtract: 2,
sum: 1,
unordered_distinct: 1,
variance: 1,
variance: 2,
skew: 2,
covariance: 2
]
Expand Down
4 changes: 2 additions & 2 deletions lib/explorer/polars_backend/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ defmodule Explorer.PolarsBackend.Native do
def s_slice_by_indices(_s, _indices), do: err()
def s_slice_by_series(_s, _series), do: err()
def s_sort(_s, _descending?, _nils_last?), do: err()
def s_standard_deviation(_s), do: err()
def s_standard_deviation(_s, _ddof), do: err()
def s_strip(_s, _string), do: err()
def s_subtract(_s, _other), do: err()
def s_sum(_s), do: err()
Expand All @@ -371,7 +371,7 @@ defmodule Explorer.PolarsBackend.Native do
def s_qcut(_s, _quantiles, _labels, _break_point_label, _category_label),
do: err()

def s_variance(_s), do: err()
def s_variance(_s, _ddof), do: err()
def s_window_max(_s, _window_size, _weight, _ignore_null, _min_periods), do: err()
def s_window_mean(_s, _window_size, _weight, _ignore_null, _min_periods), do: err()
def s_window_median(_s, _window_size, _weight, _ignore_null, _min_periods), do: err()
Expand Down
5 changes: 3 additions & 2 deletions lib/explorer/polars_backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,11 @@ defmodule Explorer.PolarsBackend.Series do
def mode(series), do: Shared.apply_series(series, :s_mode)

@impl true
def variance(series), do: Shared.apply_series(series, :s_variance)
def variance(series, ddof), do: Shared.apply_series(series, :s_variance, [ddof])

@impl true
def standard_deviation(series), do: Shared.apply_series(series, :s_standard_deviation)
def standard_deviation(series, ddof),
do: Shared.apply_series(series, :s_standard_deviation, [ddof])

@impl true
def quantile(series, quantile),
Expand Down
29 changes: 21 additions & 8 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2225,6 +2225,10 @@ defmodule Explorer.Series do
@doc """
Gets the variance of the series.
By default, this is the sample variance. This function also takes an optional
delta degrees of freedom (ddof). Setting this to zero corresponds to the population
variance.
## Supported dtypes
* `:integer`
Expand All @@ -2245,15 +2249,21 @@ defmodule Explorer.Series do
** (ArgumentError) Explorer.Series.variance/1 not implemented for dtype {:datetime, :microsecond}. Valid dtypes are [:integer, :float]
"""
@doc type: :aggregation
@spec variance(series :: Series.t()) :: float() | non_finite() | nil
def variance(%Series{dtype: dtype} = series) when is_numeric_dtype(dtype),
do: apply_series(series, :variance)
@spec variance(series :: Series.t(), ddof :: non_neg_integer()) :: float() | non_finite() | nil
def variance(series, ddof \\ 1)

def variance(%Series{dtype: dtype} = series, ddof) when is_numeric_dtype(dtype),
do: apply_series(series, :variance, [ddof])

def variance(%Series{dtype: dtype}), do: dtype_error("variance/1", dtype, [:integer, :float])
def variance(%Series{dtype: dtype}, _), do: dtype_error("variance/1", dtype, [:integer, :float])

@doc """
Gets the standard deviation of the series.
By default, this is the sample standard deviation. This function also takes an optional
delta degrees of freedom (ddof). Setting this to zero corresponds to the population
sample standard deviation.
## Supported dtypes
* `:integer`
Expand All @@ -2274,11 +2284,14 @@ defmodule Explorer.Series do
** (ArgumentError) Explorer.Series.standard_deviation/1 not implemented for dtype :string. Valid dtypes are [:integer, :float]
"""
@doc type: :aggregation
@spec standard_deviation(series :: Series.t()) :: float() | non_finite() | nil
def standard_deviation(%Series{dtype: dtype} = series) when is_numeric_dtype(dtype),
do: apply_series(series, :standard_deviation)
@spec standard_deviation(series :: Series.t(), ddof :: non_neg_integer()) ::
float() | non_finite() | nil
def standard_deviation(series, ddof \\ 1)

def standard_deviation(%Series{dtype: dtype} = series, ddof) when is_numeric_dtype(dtype),
do: apply_series(series, :standard_deviation, [ddof])

def standard_deviation(%Series{dtype: dtype}),
def standard_deviation(%Series{dtype: dtype}, _),
do: dtype_error("standard_deviation/1", dtype, [:integer, :float])

@doc """
Expand Down
8 changes: 4 additions & 4 deletions native/explorer/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -490,17 +490,17 @@ pub fn expr_abs(expr: ExExpr) -> ExExpr {
}

#[rustler::nif]
pub fn expr_variance(expr: ExExpr) -> ExExpr {
pub fn expr_variance(expr: ExExpr, ddof: u8) -> ExExpr {
let expr = expr.clone_inner();

ExExpr::new(expr.var(1))
ExExpr::new(expr.var(ddof))
}

#[rustler::nif]
pub fn expr_standard_deviation(expr: ExExpr) -> ExExpr {
pub fn expr_standard_deviation(expr: ExExpr, ddof: u8) -> ExExpr {
let expr = expr.clone_inner();

ExExpr::new(expr.std(1))
ExExpr::new(expr.std(ddof))
}

#[rustler::nif]
Expand Down
12 changes: 6 additions & 6 deletions native/explorer/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -955,19 +955,19 @@ pub fn s_product(s: ExSeries) -> Result<ExSeries, ExplorerError> {
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_variance(env: Env, s: ExSeries) -> Result<Term, ExplorerError> {
pub fn s_variance(env: Env, s: ExSeries, ddof: u8) -> Result<Term, ExplorerError> {
match s.dtype() {
DataType::Int64 => Ok(s.i64()?.var(1).encode(env)),
DataType::Float64 => Ok(term_from_optional_float(s.f64()?.var(1), env)),
DataType::Int64 => Ok(s.i64()?.var(ddof).encode(env)),
DataType::Float64 => Ok(term_from_optional_float(s.f64()?.var(ddof), env)),
dt => panic!("var/1 not implemented for {dt:?}"),
}
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_standard_deviation(env: Env, s: ExSeries) -> Result<Term, ExplorerError> {
pub fn s_standard_deviation(env: Env, s: ExSeries, ddof: u8) -> Result<Term, ExplorerError> {
match s.dtype() {
DataType::Int64 => Ok(s.i64()?.std(1).encode(env)),
DataType::Float64 => Ok(term_from_optional_float(s.f64()?.std(1), env)),
DataType::Int64 => Ok(s.i64()?.std(ddof).encode(env)),
DataType::Float64 => Ok(term_from_optional_float(s.f64()?.std(ddof), env)),
dt => panic!("std/1 not implemented for {dt:?}"),
}
}
Expand Down

0 comments on commit 548979b

Please sign in to comment.