diff --git a/Sources/Accord.Statistics/Measures/Tools.cs b/Sources/Accord.Statistics/Measures/Tools.cs index 252355d87..9c3b9668c 100644 --- a/Sources/Accord.Statistics/Measures/Tools.cs +++ b/Sources/Accord.Statistics/Measures/Tools.cs @@ -28,6 +28,8 @@ namespace Accord.Statistics using Accord.Math.Decompositions; using Accord.Statistics.Kernels; using AForge; + using Accord.Statistics.Distributions; + using Accord.Statistics.Distributions.Fitting; /// /// Set of statistics functions. @@ -441,6 +443,121 @@ public static double[][] Whitening(double[][] value, out double[][] transformMat return Matrix.Dot(value, transformMatrix); } + /// + /// Creates a new distribution that has been fit to a given set of observations. + /// + /// + /// The array of observations to fit the model against. The array + /// elements can be either of type double (for univariate data) or + /// type double[] (for multivariate data). + /// The weight vector containing the weight for each of the samples. + /// + public static TDistribution Fit(this double[] observations, double[] weights = null) + where TDistribution : IFittable, new() + { + var dist = new TDistribution(); + dist.Fit(observations, weights); + return dist; + } + + /// + /// Creates a new distribution that has been fit to a given set of observations. + /// + /// + /// The array of observations to fit the model against. The array + /// elements can be either of type double (for univariate data) or + /// type double[] (for multivariate data). + /// The weight vector containing the weight for each of the samples. + /// + public static TDistribution Fit(this double[][] observations, double[] weights = null) + where TDistribution : IFittable, new() + { + var dist = new TDistribution(); + dist.Fit(observations, weights); + return dist; + } + + /// + /// Creates a new distribution that has been fit to a given set of observations. + /// + /// + /// The array of observations to fit the model against. The array + /// elements can be either of type double (for univariate data) or + /// type double[] (for multivariate data). + /// The weight vector containing the weight for each of the samples. + /// Optional arguments which may be used during fitting, such + /// as regularization constants and additional parameters. + /// + public static TDistribution Fit(this double[] observations, TOptions options, double[] weights = null) + where TDistribution : IFittable, new() + where TOptions : class, IFittingOptions + { + var dist = new TDistribution(); + dist.Fit(observations, weights, options); + return dist; + } + + /// + /// Creates a new distribution that has been fit to a given set of observations. + /// + /// + /// The array of observations to fit the model against. The array + /// elements can be either of type double (for univariate data) or + /// type double[] (for multivariate data). + /// The weight vector containing the weight for each of the samples. + /// Optional arguments which may be used during fitting, such + /// as regularization constants and additional parameters. + /// + public static TDistribution Fit(this double[][] observations, TOptions options, double[] weights = null) + where TDistribution : IFittable, new() + where TOptions : class, IFittingOptions + { + var dist = new TDistribution(); + dist.Fit(observations, weights, options); + return dist; + } + + /// + /// Creates a new distribution that has been fit to a given set of observations. + /// + /// + /// The distribution whose parameters should be fitted to the samples. + /// The array of observations to fit the model against. The array + /// elements can be either of type double (for univariate data) or + /// type double[] (for multivariate data). + /// The weight vector containing the weight for each of the samples. + /// + public static TDistribution FitNew( + this TDistribution distribution, TObservations[] observations, double[] weights = null) + where TDistribution : IFittable, ICloneable + { + var clone = (TDistribution)distribution.Clone(); + clone.Fit(observations, weights); + return clone; + } + + /// + /// Creates a new distribution that has been fit to a given set of observations. + /// + /// + /// The distribution whose parameters should be fitted to the samples. + /// The array of observations to fit the model against. The array + /// elements can be either of type double (for univariate data) or + /// type double[] (for multivariate data). + /// The weight vector containing the weight for each of the samples. + /// Optional arguments which may be used during fitting, such + /// as regularization constants and additional parameters. + /// + public static TDistribution FitNew( + this TDistribution distribution, TObservations[] observations, TOptions options, double[] weights = null) + where TDistribution : IFittable, ICloneable + where TOptions : class, IFittingOptions + { + var clone = (TDistribution)distribution.Clone(); + clone.Fit(observations, weights, options); + return clone; + } + } } diff --git a/Unit Tests/Accord.Tests.Statistics/Distributions/Univariate/Continuous/NormalDistributionTest.cs b/Unit Tests/Accord.Tests.Statistics/Distributions/Univariate/Continuous/NormalDistributionTest.cs index 2f675075f..27e71ad16 100644 --- a/Unit Tests/Accord.Tests.Statistics/Distributions/Univariate/Continuous/NormalDistributionTest.cs +++ b/Unit Tests/Accord.Tests.Statistics/Distributions/Univariate/Continuous/NormalDistributionTest.cs @@ -27,8 +27,9 @@ namespace Accord.Tests.Statistics using Accord.Statistics; using Accord.Statistics.Distributions.Multivariate; using Accord.Statistics.Distributions.Univariate; - using NUnit.Framework; - + using NUnit.Framework; + using Accord.Statistics.Distributions.Fitting; + [TestFixture] public class NormalDistributionTest { @@ -107,8 +108,65 @@ public void FitTest() target.Fit(observations2, weights2); Assert.AreEqual(expectedMean, target.Mean); - } - + } + + + [Test] + public void FitExtensionTest_options() + { + NormalDistribution target = new NormalDistribution(); + double[] observations = { 0.10, 0.40, 2.00, 2.00 }; + double[] weights = { 0.25, 0.25, 0.25, 0.25 }; + target.Fit(observations, weights); + NormalDistribution same = observations.Fit(new NormalOptions() + { + Regularization = 10 + }, weights); + Assert.AreNotSame(same, target); + Assert.AreEqual(same.ToString(), target.ToString()); + + NormalDistribution copy = target.FitNew(observations, new NormalOptions() + { + Regularization = 10 + }, weights); + Assert.AreNotSame(copy, target); + Assert.AreEqual(copy.ToString(), target.ToString()); + } + + + [Test] + public void FitExtensionTest_weights() + { + NormalDistribution target = new NormalDistribution(); + double[] observations = { 0.10, 0.40, 2.00, 2.00 }; + double[] weights = { 0.25, 0.25, 0.25, 0.25 }; + target.Fit(observations, weights); + NormalDistribution same = observations.Fit(weights); + Assert.AreNotSame(same, target); + Assert.AreEqual(same.ToString(), target.ToString()); + + NormalDistribution copy = target.FitNew(observations, weights); + Assert.AreNotSame(copy, target); + Assert.AreEqual(copy.ToString(), target.ToString()); + } + + [Test] + public void FitExtensionTest() + { + NormalDistribution target = new NormalDistribution(); + double[] observations = { 0.10, 0.40, 2.00, 2.00 }; + target.Fit(observations); + NormalDistribution same = observations.Fit(); + Assert.AreNotSame(same, target); + Assert.AreEqual(same.ToString(), target.ToString()); + + NormalDistribution copy = target.FitNew(observations); + Assert.AreNotSame(copy, target); + Assert.AreEqual(copy.ToString(), target.ToString()); + } + + + [Test] public void FitTest2() {