Skip to content
This repository has been archived by the owner on Nov 19, 2020. It is now read-only.

Commit

Permalink
GH-390 : MachineLearning.KMeans: Balanced clustering.
Browse files Browse the repository at this point in the history
  • Loading branch information
cesarsouza committed Jan 23, 2017
1 parent 1bf65bb commit 87205d0
Show file tree
Hide file tree
Showing 13 changed files with 1,900 additions and 67 deletions.
28 changes: 28 additions & 0 deletions Copyright.txt
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,31 @@ NLopt Numerical Optimization Library - Copyright (c) 2008-2014 Steven G. Johnson
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


Munkres algorithm - The Hungarian method for solving the assignment problem

The MIT License (MIT)

Copyright (c) 2000 Robert A. Pilgrim
Murray State University
Dept. of Computer Science & Information Systems
Murray,Kentucky

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
58 changes: 30 additions & 28 deletions Sources/Accord.MachineLearning/Accord.MachineLearning.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,36 @@
<RootNamespace>Accord.MachineLearning</RootNamespace>
<AssemblyName>Accord.MachineLearning</AssemblyName>
</PropertyGroup>
<Import Project="$(SolutionDir)Accord.NET.targets" />
<Import Project="$(SolutionDir)Accord.NET.targets" />
<PropertyGroup Condition=" '$(Configuration)' == 'Debug' ">
<DebugType>Full</DebugType>
<Optimize>False</Optimize>
<CheckForOverflowUnderflow>True</CheckForOverflowUnderflow>
<DefineConstants>DEBUG;TRACE</DefineConstants>
<OutputPath>$(SolutionDir)..\Debug\</OutputPath>
<DocumentationFile></DocumentationFile>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'net35|AnyCPU' ">
<DefineConstants>TRACE;NET35</DefineConstants>
<TargetFrameworkVersion>v3.5</TargetFrameworkVersion>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'net40|AnyCPU' ">
<DefineConstants>TRACE;NET40</DefineConstants>
<TargetFrameworkVersion>v4.0</TargetFrameworkVersion>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'net45|AnyCPU'">
<DefineConstants>TRACE;NET45</DefineConstants>
<TargetFrameworkVersion>v4.5</TargetFrameworkVersion>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'net46|AnyCPU'">
<DefineConstants>TRACE;NET46</DefineConstants>
<TargetFrameworkVersion>v4.6</TargetFrameworkVersion>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'mono|AnyCPU'">
<DefineConstants>TRACE;MONO</DefineConstants>
<TargetFrameworkVersion>v4.5</TargetFrameworkVersion>
</PropertyGroup>
<DebugType>Full</DebugType>
<Optimize>False</Optimize>
<CheckForOverflowUnderflow>True</CheckForOverflowUnderflow>
<DefineConstants>DEBUG;TRACE</DefineConstants>
<OutputPath>$(SolutionDir)..\Debug\</OutputPath>
<DocumentationFile>
</DocumentationFile>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'net35|AnyCPU' ">
<DefineConstants>TRACE;NET35</DefineConstants>
<TargetFrameworkVersion>v3.5</TargetFrameworkVersion>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'net40|AnyCPU' ">
<DefineConstants>TRACE;NET40</DefineConstants>
<TargetFrameworkVersion>v4.0</TargetFrameworkVersion>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'net45|AnyCPU'">
<DefineConstants>TRACE;NET45</DefineConstants>
<TargetFrameworkVersion>v4.5</TargetFrameworkVersion>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'net46|AnyCPU'">
<DefineConstants>TRACE;NET46</DefineConstants>
<TargetFrameworkVersion>v4.6</TargetFrameworkVersion>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'mono|AnyCPU'">
<DefineConstants>TRACE;MONO</DefineConstants>
<TargetFrameworkVersion>v4.5</TargetFrameworkVersion>
</PropertyGroup>
<ItemGroup>
<Reference Include="System" />
<Reference Include="System.Core">
Expand All @@ -47,6 +48,7 @@
<Compile Include="..\Accord.Core\Properties\VersionInfo.cs">
<Link>Properties\VersionInfo.cs</Link>
</Compile>
<Compile Include="Clustering\KMeans\BalancedKMeans.cs" />
<Compile Include="DecisionTrees\DecisionTreeHelper.cs" />
<Compile Include="Representations\BagOfWords.cs" />
<Compile Include="Rules\AssociationRule.cs" />
Expand Down
244 changes: 244 additions & 0 deletions Sources/Accord.MachineLearning/Clustering/KMeans/BalancedKMeans.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
// Accord Machine Learning Library
// The Accord.NET Framework
// http://accord-framework.net
//
// Copyright © César Souza, 2009-2017
// cesarsouza at gmail.com
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 2.1 of the License, or (at your option) any later version.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
//

namespace Accord.MachineLearning
{
using System;
using System.Collections.Generic;
using Accord.Math;
using Accord.Math.Distances;
using Math.Optimization;
using Statistics;
using System.Threading.Tasks;

/// <summary>
/// Balanced K-Means algorithm.
/// </summary>
///
/// <remarks>
/// The Balanced k-Means algorithm attempts to find a clustering where each cluster
/// has approximately the same number of data points. The Balanced k-Means implementation
/// used in the framework uses the <see cref="Munkres"/> algorithm to solve the assignment
/// problem thus enforcing balance between the clusters.
/// </remarks>
///
/// <example>
/// How to perform clustering with Balanced K-Means.
///
/// <code source="Unit Tests\Accord.Tests.MachineLearning\Clustering\BalancedKMeansTest.cs" region="doc_learn" />
/// </example>
///
/// <seealso cref="KMeans"/>
/// <seealso cref="BinarySplit"/>
/// <seealso cref="GaussianMixtureModel"/>
///
[Serializable]
public class BalancedKMeans : KMeans
{

/// <summary>
/// Gets the labels assigned for each data point in the last
/// call to <see cref="Learn(double[][], double[])"/>.
/// </summary>
///
/// <value>The labels.</value>
///
public int[] Labels { get; private set; }

/// <summary>
/// Initializes a new instance of the Balanced K-Means algorithm.
/// </summary>
///
/// <param name="k">The number of clusters to divide the input data into.</param>
/// <param name="distance">The distance function to use. Default is to
/// use the <see cref="Accord.Math.Distance.SquareEuclidean(double[], double[])"/> distance.</param>
///
public BalancedKMeans(int k, IDistance<double[]> distance)
: base(k, distance)
{
}

/// <summary>
/// Initializes a new instance of the Balanced K-Means algorithm.
/// </summary>
///
/// <param name="k">The number of clusters to divide the input data into.</param>
///
public BalancedKMeans(int k)
: base(k) { }


/// <summary>
/// Learns a model that can map the given inputs to the desired outputs.
/// </summary>
/// <param name="x">The model inputs.</param>
/// <param name="weights">The weight of importance for each input sample.</param>
/// <returns>A model that has learned how to produce suitable outputs
/// given the input data <paramref name="x" />.</returns>
public override KMeansClusterCollection Learn(double[][] x, double[] weights = null)
{
// Initial argument checking
if (x == null)
throw new ArgumentNullException("x");

if (x.Length < K)
throw new ArgumentException("Not enough points. There should be more points than the number K of clusters.");

if (weights == null)
{
weights = Vector.Ones(x.Length);
}
else
{
if (x.Length != weights.Length)
throw new ArgumentException("Data weights vector must be the same length as data samples.");
}

double weightSum = weights.Sum();
if (weightSum <= 0)
throw new ArgumentException("Not enough points. There should be more points than the number K of clusters.");

if (!x.IsRectangular())
throw new DimensionMismatchException("data", "The points matrix should be rectangular. The vector at position {} has a different length than previous ones.");

int k = this.K;
int rows = x.Length;
int cols = x[0].Length;

// Perform a random initialization of the clusters
// if the algorithm has not been initialized before.
//
if (this.Clusters.Centroids[0] == null)
{
Randomize(x);
}

// Initial variables
int[] labels = new int[rows];
double[] count = new double[k];
double[][] centroids = Clusters.Centroids;
double[][] newCentroids = new double[k][];
for (int i = 0; i < newCentroids.Length; i++)
newCentroids[i] = new double[cols];

Object[] syncObjects = new Object[K];
for (int i = 0; i < syncObjects.Length; i++)
syncObjects[i] = new Object();

Iterations = 0;

bool shouldStop = false;

var m = new Munkres(x.Length, x.Length);
double[][] costMatrix = m.CostMatrix;

while (!shouldStop) // Main loop
{
Array.Clear(count, 0, count.Length);
for (int i = 0; i < newCentroids.Length; i++)
Array.Clear(newCentroids[i], 0, newCentroids[i].Length);
for (int i = 0; i < labels.Length; i++)
labels[i] = -1;

// Set the cost matrix for Munkres algorithm
for (int i = 0; i < costMatrix.Length; i++)
for (int j = 0; j < costMatrix[i].Length; j++)
costMatrix[i][j] = Distance.Distance(x[j], centroids[i % k]);

//string str = costMatrix.ToCSharp();

m.Minimize(); // solve the assignment problem

for (int i = 0; i < x.Length; i++)
{
if (m.Solution[i] >= 0)
labels[(int)m.Solution[i]] = i % k;
}

// For each point in the data set,
Parallel.For(0, x.Length, ParallelOptions, i =>
{
// Get the point
double[] point = x[i];
double weight = weights[i];
// Get the nearest cluster centroid
int c = labels[i];
if (c >= 0)
{
// Get the closest cluster centroid
double[] centroid = newCentroids[c];
lock (syncObjects[c])
{
// Increase the cluster's sample counter
count[c] += weight;
// Accumulate in the cluster centroid
for (int j = 0; j < point.Length; j++)
centroid[j] += point[j] * weight;
}
}
});

// Next we will compute each cluster's new centroid
// by dividing the accumulated sums by the number of
// samples in each cluster, thus averaging its members.
Parallel.For(0, newCentroids.Length, ParallelOptions, i =>
{
double sum = count[i];
if (sum > 0)
{
for (int j = 0; j < newCentroids[i].Length; j++)
newCentroids[i][j] /= sum;
}
});

// The algorithm stops when there is no further change in the
// centroids (relative difference is less than the threshold).
shouldStop = converged(centroids, newCentroids);

// go to next generation
Parallel.For(0, centroids.Length, ParallelOptions, i =>
{
for (int j = 0; j < centroids[i].Length; j++)
centroids[i][j] = newCentroids[i][j];
});
}

for (int i = 0; i < Clusters.Centroids.Length; i++)
{
// Compute the proportion of samples in the cluster
Clusters.Proportions[i] = count[i] / weightSum;
}

this.Labels = labels;

ComputeInformation(x, labels);

return Clusters;
}

}
}
Loading

0 comments on commit 87205d0

Please sign in to comment.