Skip to content
This repository has been archived by the owner on Nov 19, 2020. It is now read-only.

Commit

Permalink
Adding extension methods to simplify how distributions can be estimat…
Browse files Browse the repository at this point in the history
…ed from the data (without requiring the distribution to be created first).
  • Loading branch information
cesarsouza committed Jun 18, 2017
1 parent 588e100 commit 998c75d
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 4 deletions.
117 changes: 117 additions & 0 deletions Sources/Accord.Statistics/Measures/Tools.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ namespace Accord.Statistics
using Accord.Math.Decompositions;
using Accord.Statistics.Kernels;
using AForge;
using Accord.Statistics.Distributions;
using Accord.Statistics.Distributions.Fitting;

/// <summary>
/// Set of statistics functions.
Expand Down Expand Up @@ -441,6 +443,121 @@ public static double[][] Whitening(double[][] value, out double[][] transformMat
return Matrix.Dot(value, transformMatrix);
}

/// <summary>
/// Creates a new distribution that has been fit to a given set of observations.
/// </summary>
///
/// <param name="observations">The array of observations to fit the model against. The array
/// elements can be either of type double (for univariate data) or
/// type double[] (for multivariate data).</param>
/// <param name="weights">The weight vector containing the weight for each of the samples.</param>
///
public static TDistribution Fit<TDistribution>(this double[] observations, double[] weights = null)
where TDistribution : IFittable<double>, new()
{
var dist = new TDistribution();
dist.Fit(observations, weights);
return dist;
}

/// <summary>
/// Creates a new distribution that has been fit to a given set of observations.
/// </summary>
///
/// <param name="observations">The array of observations to fit the model against. The array
/// elements can be either of type double (for univariate data) or
/// type double[] (for multivariate data).</param>
/// <param name="weights">The weight vector containing the weight for each of the samples.</param>
///
public static TDistribution Fit<TDistribution>(this double[][] observations, double[] weights = null)
where TDistribution : IFittable<double[]>, new()
{
var dist = new TDistribution();
dist.Fit(observations, weights);
return dist;
}

/// <summary>
/// Creates a new distribution that has been fit to a given set of observations.
/// </summary>
///
/// <param name="observations">The array of observations to fit the model against. The array
/// elements can be either of type double (for univariate data) or
/// type double[] (for multivariate data).</param>
/// <param name="weights">The weight vector containing the weight for each of the samples.</param>
/// <param name="options">Optional arguments which may be used during fitting, such
/// as regularization constants and additional parameters.</param>
///
public static TDistribution Fit<TDistribution, TOptions>(this double[] observations, TOptions options, double[] weights = null)
where TDistribution : IFittable<double, TOptions>, new()
where TOptions : class, IFittingOptions
{
var dist = new TDistribution();
dist.Fit(observations, weights, options);
return dist;
}

/// <summary>
/// Creates a new distribution that has been fit to a given set of observations.
/// </summary>
///
/// <param name="observations">The array of observations to fit the model against. The array
/// elements can be either of type double (for univariate data) or
/// type double[] (for multivariate data).</param>
/// <param name="weights">The weight vector containing the weight for each of the samples.</param>
/// <param name="options">Optional arguments which may be used during fitting, such
/// as regularization constants and additional parameters.</param>
///
public static TDistribution Fit<TDistribution, TOptions>(this double[][] observations, TOptions options, double[] weights = null)
where TDistribution : IFittable<double[], TOptions>, new()
where TOptions : class, IFittingOptions
{
var dist = new TDistribution();
dist.Fit(observations, weights, options);
return dist;
}

/// <summary>
/// Creates a new distribution that has been fit to a given set of observations.
/// </summary>
///
/// <param name="distribution">The distribution whose parameters should be fitted to the samples.</param>
/// <param name="observations">The array of observations to fit the model against. The array
/// elements can be either of type double (for univariate data) or
/// type double[] (for multivariate data).</param>
/// <param name="weights">The weight vector containing the weight for each of the samples.</param>
///
public static TDistribution FitNew<TDistribution, TObservations>(
this TDistribution distribution, TObservations[] observations, double[] weights = null)
where TDistribution : IFittable<TObservations>, ICloneable
{
var clone = (TDistribution)distribution.Clone();
clone.Fit(observations, weights);
return clone;
}

/// <summary>
/// Creates a new distribution that has been fit to a given set of observations.
/// </summary>
///
/// <param name="distribution">The distribution whose parameters should be fitted to the samples.</param>
/// <param name="observations">The array of observations to fit the model against. The array
/// elements can be either of type double (for univariate data) or
/// type double[] (for multivariate data).</param>
/// <param name="weights">The weight vector containing the weight for each of the samples.</param>
/// <param name="options">Optional arguments which may be used during fitting, such
/// as regularization constants and additional parameters.</param>
///
public static TDistribution FitNew<TDistribution, TObservations, TOptions>(
this TDistribution distribution, TObservations[] observations, TOptions options, double[] weights = null)
where TDistribution : IFittable<TObservations, TOptions>, ICloneable
where TOptions : class, IFittingOptions
{
var clone = (TDistribution)distribution.Clone();
clone.Fit(observations, weights, options);
return clone;
}

}
}

Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ namespace Accord.Tests.Statistics
using Accord.Statistics;
using Accord.Statistics.Distributions.Multivariate;
using Accord.Statistics.Distributions.Univariate;
using NUnit.Framework;

using NUnit.Framework;
using Accord.Statistics.Distributions.Fitting;

[TestFixture]
public class NormalDistributionTest
{
Expand Down Expand Up @@ -107,8 +108,65 @@ public void FitTest()
target.Fit(observations2, weights2);

Assert.AreEqual(expectedMean, target.Mean);
}

}


[Test]
public void FitExtensionTest_options()
{
NormalDistribution target = new NormalDistribution();
double[] observations = { 0.10, 0.40, 2.00, 2.00 };
double[] weights = { 0.25, 0.25, 0.25, 0.25 };
target.Fit(observations, weights);
NormalDistribution same = observations.Fit<NormalDistribution, NormalOptions>(new NormalOptions()
{
Regularization = 10
}, weights);
Assert.AreNotSame(same, target);
Assert.AreEqual(same.ToString(), target.ToString());

NormalDistribution copy = target.FitNew(observations, new NormalOptions()
{
Regularization = 10
}, weights);
Assert.AreNotSame(copy, target);
Assert.AreEqual(copy.ToString(), target.ToString());
}


[Test]
public void FitExtensionTest_weights()
{
NormalDistribution target = new NormalDistribution();
double[] observations = { 0.10, 0.40, 2.00, 2.00 };
double[] weights = { 0.25, 0.25, 0.25, 0.25 };
target.Fit(observations, weights);
NormalDistribution same = observations.Fit<NormalDistribution>(weights);
Assert.AreNotSame(same, target);
Assert.AreEqual(same.ToString(), target.ToString());

NormalDistribution copy = target.FitNew(observations, weights);
Assert.AreNotSame(copy, target);
Assert.AreEqual(copy.ToString(), target.ToString());
}

[Test]
public void FitExtensionTest()
{
NormalDistribution target = new NormalDistribution();
double[] observations = { 0.10, 0.40, 2.00, 2.00 };
target.Fit(observations);
NormalDistribution same = observations.Fit<NormalDistribution>();
Assert.AreNotSame(same, target);
Assert.AreEqual(same.ToString(), target.ToString());

NormalDistribution copy = target.FitNew(observations);
Assert.AreNotSame(copy, target);
Assert.AreEqual(copy.ToString(), target.ToString());
}



[Test]
public void FitTest2()
{
Expand Down

0 comments on commit 998c75d

Please sign in to comment.