Skip to content

Commit

Permalink
Gaussian Mixture Model implementation (#369)
Browse files Browse the repository at this point in the history
* Starting work on GMM.

* More work on GMM training method.

* Modernizing KMeansTrainer.

* Moving logsumexp from ChainHelper to a method on DenseVector.

* Filling out GaussianMixtureModel class.

* Filling out GMMTrainer.train.

* Implementing MultivariateNormalDistribution.logProbability.

* Working on covariance calculation.

* Code compiles for GMM. Inference still isn't quite right though.

* Fix bugs in MultivariateNormalDistribution, Cholesky.determinant, LU.determinant, SparseVector.subtract.

* Fixing bugs in GMM.

* Small tidy ups to Math.

* Fixing diagonal and spherical covariance estimation.

* Adding a mixture distribution and a distribution interface.

* Fixing parallel reduction by converting it into collect.

* PR comments.

* Fixing merge conflict.
  • Loading branch information
Craigacp authored Dec 19, 2024
1 parent 81605a2 commit ff2172a
Show file tree
Hide file tree
Showing 27 changed files with 4,150 additions and 122 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -178,36 +178,7 @@ public static ChainViterbiResults viterbi(ChainCliqueValues scores) {
* @return log sum exp input[i].
*/
public static double sumLogProbs(DenseVector input) {
// Computes log(sum_i exp(input[i])) by factoring out the largest element, so that
// every exponent passed to Math.exp is <= 0.
// Elements more than LOG_TOLERANCE below the maximum are skipped in the summation.
double LOG_TOLERANCE = 30.0;

// First pass: find the largest element and remember its index.
double maxValue = input.get(0);
int maxIdx = 0;
for (int i = 1; i < input.size(); i++) {
double value = input.get(i);
if (value > maxValue) {
maxValue = value;
maxIdx = i;
}
}
// If the largest element is negative infinity, return it directly rather than
// evaluating (value - maxValue) in the second pass.
if (maxValue == Double.NEGATIVE_INFINITY) {
return maxValue;
}

// Second pass: accumulate exp(value - maxValue) for every element other than the
// maximum itself that is finite and within LOG_TOLERANCE of the maximum.
boolean anyAdded = false;
double intermediate = 0.0;
double cutoff = maxValue - LOG_TOLERANCE;
for (int i = 0; i < input.size(); i++) {
double value = input.get(i);
if (value >= cutoff && i != maxIdx && !Double.isInfinite(value)) {
anyAdded = true;
intermediate += Math.exp(value - maxValue);
}
}
// The maximum's own term is exp(0) = 1, which is the "1" supplied by log1p.
if (anyAdded) {
return maxValue + Math.log1p(intermediate);
} else {
return maxValue;
}
// NOTE(review): both branches above return, so the statement below is unreachable as
// written. This looks like a flattened diff in which the single delegating call replaces
// the whole body above - confirm against the committed file.
return input.logSumExp();
}

/**
Expand Down
98 changes: 98 additions & 0 deletions Clustering/GMM/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
~
~ Licensed under the Apache License, Version 2.0 (the "License");
~ you may not use this file except in compliance with the License.
~ You may obtain a copy of the License at
~
~ http://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->

<!-- Maven module descriptor for the tribuo-clustering-gmm jar (display name "Clustering-GMM"). -->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<!-- Inherits from the tribuo-clustering parent POM located one directory up. -->
<parent>
<groupId>org.tribuo</groupId>
<artifactId>tribuo-clustering</artifactId>
<version>5.0.0-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<name>Clustering-GMM</name>
<artifactId>tribuo-clustering-gmm</artifactId>
<packaging>jar</packaging>
<!-- Compiles this module against Java release 17. -->
<properties>
<maven.compiler.release>17</maven.compiler.release>
</properties>

<!-- NOTE(review): dependencies declared without a <version> presumably get it from
     dependencyManagement in a parent POM - confirm. -->
<dependencies>
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>tribuo-core</artifactId>
</dependency>
<!-- NOTE(review): this entry spells out org.tribuo while its siblings use ${project.groupId};
     presumably these are the same value - confirm. -->
<dependency>
<groupId>org.tribuo</groupId>
<artifactId>tribuo-data</artifactId>
</dependency>
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>tribuo-math</artifactId>
</dependency>
<!-- Version is pinned explicitly to this project's own version. -->
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>tribuo-clustering-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.oracle.labs.olcut</groupId>
<artifactId>olcut-core</artifactId>
</dependency>
<!-- test time dependencies -->
<!-- The test-jar of tribuo-core, available on the test classpath only. -->
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>tribuo-core</artifactId>
<version>${project.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
</dependencies>

<build>
<plugins>
<!-- Runs the assembly plugin's "single" goal in the package phase using the
     jar-with-dependencies descriptor. -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<configuration>
<descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef>
</descriptorRefs>
</configuration>
<executions>
<execution>
<id>make-assembly</id> <!-- this is used for inheritance merges -->
<phase>package</phase> <!-- bind to the packaging phase -->
<goals>
<goal>single</goal>
</goals>
</execution>
</executions>
</plugin>
<!-- Declared with no local configuration.
     NOTE(review): presumably configured in a parent POM - confirm. -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
</plugin>
</plugins>
</build>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tribuo.clustering.gmm;

import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import org.tribuo.Trainer;
import org.tribuo.clustering.gmm.GMMTrainer.Initialisation;
import org.tribuo.math.distributions.MultivariateNormalDistribution.CovarianceType;
import org.tribuo.math.distributions.MultivariateNormalDistribution;

import java.util.logging.Logger;

/**
 * OLCUT {@link Options} for the GMM implementation.
 * <p>
 * Each public field is exposed as a command line option, and {@link #getTrainer()}
 * builds a {@link GMMTrainer} from the current field values.
 */
public class GMMOptions implements Options {
    private static final Logger logger = Logger.getLogger(GMMOptions.class.getName());

    // NOTE(review): the long name below is spelled "gmm-interations" (sic). It is left unchanged
    // here because it is the command line interface; consider renaming it to "gmm-iterations".
    /**
     * Iterations of the GMM algorithm. Defaults to 10.
     */
    @Option(longName = "gmm-interations", usage = "Iterations of the GMM algorithm. Defaults to 10.")
    public int iterations = 10;
    /**
     * Number of centroids/Gaussians in GMM. Defaults to 10.
     */
    @Option(longName = "gmm-num-centroids", usage = "Number of centroids in GMM. Defaults to 10.")
    public int centroids = 10;
    /**
     * The covariance type of the Gaussians. Defaults to DIAGONAL.
     */
    @Option(charName = 'v', longName = "covariance-type", usage = "Set the covariance type.")
    public CovarianceType covarianceType = MultivariateNormalDistribution.CovarianceType.DIAGONAL;
    /**
     * Initialisation function in GMM. Defaults to RANDOM.
     */
    @Option(longName = "gmm-initialisation", usage = "Initialisation function in GMM. Defaults to RANDOM.")
    public Initialisation initialisation = GMMTrainer.Initialisation.RANDOM;
    /**
     * Convergence tolerance to terminate EM early. Defaults to 1e-3.
     */
    // A double literal is used deliberately: the float literal 1e-3f widens to
    // 0.0010000000474974513, not 0.001.
    @Option(longName = "gmm-tolerance", usage = "The convergence threshold.")
    public double tolerance = 1e-3;
    /**
     * Number of computation threads in GMM. Defaults to 4.
     */
    @Option(longName = "gmm-num-threads", usage = "Number of computation threads in GMM. Defaults to 4.")
    public int numThreads = 4;
    /**
     * The RNG seed. Defaults to {@link Trainer#DEFAULT_SEED}.
     */
    @Option(longName = "gmm-seed", usage = "Sets the random seed for GMM.")
    public long seed = Trainer.DEFAULT_SEED;

    /**
     * Gets the configured GMMTrainer using the options in this object.
     * <p>
     * Logs an info message and passes the current field values straight to the
     * {@link GMMTrainer} constructor.
     * @return A GMMTrainer.
     */
    public GMMTrainer getTrainer() {
        logger.info("Configuring GMM Trainer");
        return new GMMTrainer(centroids, iterations, covarianceType, initialisation, tolerance, numThreads, seed);
    }
}
Loading

0 comments on commit ff2172a

Please sign in to comment.