Skip to content

Commit

Permalink
added the 'full_covariance_matrices' parameter to the 'GMM' function,…
Browse files Browse the repository at this point in the history
… related to #48
  • Loading branch information
mlampros committed Apr 2, 2023
1 parent 4576540 commit 2ef5eb4
Show file tree
Hide file tree
Showing 10 changed files with 198 additions and 88 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Package: ClusterR
Type: Package
Title: Gaussian Mixture Models, K-Means, Mini-Batch-Kmeans, K-Medoids and Affinity Propagation Clustering
Version: 1.3.1
Date: 2023-02-01
Date: 2023-04-02
Authors@R: c( person(given = "Lampros", family = "Mouselimis", email = "[email protected]", role = c("aut", "cre"), comment = c(ORCID = "https://orcid.org/0000-0002-8024-1546")), person(given = "Conrad", family = "Sanderson", role = "cph", comment = "Author of the C++ Armadillo library"), person(given = "Ryan", family = "Curtin", role = "cph", comment = "Author of the C++ Armadillo library"), person(given = "Siddharth", family = "Agrawal", role = "cph", comment = "Author of the C code of the Mini-Batch-Kmeans algorithm (https://github.com/siddharth-agrawal/Mini-Batch-K-Means)"), person(given = "Brendan", family = "Frey", email = "[email protected]", role = "cph", comment = "Author of the matlab code of the Affinity propagation algorithm (for commercial use please contact the author of the matlab code)"), person(given = "Delbert", family = "Dueck", role = "cph", comment = "Author of the matlab code of the Affinity propagation algorithm"), person(given = "Vitalie", family = "Spinu", email = "[email protected]", role = "ctb", comment = c(Github = "Github Contributor")) )
Maintainer: Lampros Mouselimis <[email protected]>
BugReports: https://github.com/mlampros/ClusterR/issues
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
## Cluster 1.3.1

* I fixed a mistake related to a potential warning of the *'Optimal_Clusters_GMM()'* function (see issue: https://github.com/mlampros/ClusterR/issues/45)
* I modified the *'GMM()'* function by adding the *'full_covariance_matrices'* parameter (see issue: https://github.com/mlampros/ClusterR/issues/48)


## Cluster 1.3.0
Expand Down
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ Predict_mini_batch_kmeans <- function(data, CENTROIDS, fuzzy = FALSE, eps = 1.0e
.Call(`_ClusterR_Predict_mini_batch_kmeans`, data, CENTROIDS, fuzzy, eps)
}

GMM_arma <- function(data, gaussian_comps, dist_mode, seed_mode, km_iter, em_iter, verbose, var_floor = 1e-10, seed = 1L) {
.Call(`_ClusterR_GMM_arma`, data, gaussian_comps, dist_mode, seed_mode, km_iter, em_iter, verbose, var_floor, seed)
GMM_arma <- function(data, gaussian_comps, dist_mode, seed_mode, km_iter, em_iter, verbose, var_floor = 1e-10, seed = 1L, full_covariance_matrices = FALSE) {
.Call(`_ClusterR_GMM_arma`, data, gaussian_comps, dist_mode, seed_mode, km_iter, em_iter, verbose, var_floor, seed, full_covariance_matrices)
}

predict_MGausDPDF <- function(data, CENTROIDS, COVARIANCE, WEIGHTS, eps = 1.0e-8) {
Expand Down
40 changes: 35 additions & 5 deletions R/clustering_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,27 @@ utils::globalVariables(c("x", "y")) # to avoid the following NOTE when
#'
#' @keywords internal

tryCatch_GMM <- function(data, gaussian_comps, dist_mode, seed_mode, km_iter, em_iter, verbose, var_floor, seed) {

Error = tryCatch(GMM_arma(data, gaussian_comps, dist_mode, seed_mode, km_iter, em_iter, verbose, var_floor, seed),
tryCatch_GMM <- function(data,
gaussian_comps,
dist_mode,
seed_mode,
km_iter,
em_iter,
verbose,
var_floor,
seed,
full_covariance_matrices) {

Error = tryCatch(GMM_arma(data,
gaussian_comps,
dist_mode,
seed_mode,
km_iter,
em_iter,
verbose,
var_floor,
seed,
full_covariance_matrices),

error = function(e) e)

Expand All @@ -34,6 +52,7 @@ tryCatch_GMM <- function(data, gaussian_comps, dist_mode, seed_mode, km_iter, em
#' @param verbose either TRUE or FALSE; enable or disable printing of progress during the k-means and EM algorithms
#' @param var_floor the variance floor (smallest allowed value) for the diagonal covariances
#' @param seed integer value for random number generator (RNG)
#' @param full_covariance_matrices a boolean. If FALSE, "diagonal" covariance matrices are returned (i.e. in each covariance matrix, all entries outside the main diagonal are assumed to be zero); otherwise "full" covariance matrices are returned. Be aware that in the case of "full" covariance matrices the output "covariance_matrices" value will be a cube (3-dimensional) rather than a matrix.
#' @return a list consisting of the centroids, the covariance matrices (by default a matrix where each row represents a diagonal covariance matrix; if 'full_covariance_matrices' is TRUE, a 3-dimensional cube of "full" covariance matrices), the weights and the log-likelihoods for each gaussian component. In case of Error it returns the error message and the possible causes.
#' @details
#' This function is an R implementation of the 'gmm_diag' class of the Armadillo library. The only exception is that user defined parameter settings are not supported, such as seed_mode = 'keep_existing'.
Expand Down Expand Up @@ -64,7 +83,8 @@ GMM = function(data,
em_iter = 5,
verbose = FALSE,
var_floor = 1e-10,
seed = 1) {
seed = 1,
full_covariance_matrices = FALSE) {

if ('data.frame' %in% class(data)) data = as.matrix(data)
if (!inherits(data, 'matrix')) stop('data should be either a matrix or a data frame')
Expand All @@ -75,12 +95,22 @@ GMM = function(data,
if (em_iter < 0 ) stop('the em_iter parameter can not be negative')
if (!is.logical(verbose)) stop('the verbose parameter should be either TRUE or FALSE')
if (var_floor < 0 ) stop('the var_floor parameter can not be negative')
if (!inherits(full_covariance_matrices, 'logical')) stop('The full_covariance_matrices parameter must be a boolean!')

flag_non_finite = check_NaN_Inf(data)

if (!flag_non_finite) stop("the data includes NaN's or +/- Inf values")

res = tryCatch_GMM(data, gaussian_comps, dist_mode, seed_mode, km_iter, em_iter, verbose, var_floor, seed)
res = tryCatch_GMM(data,
gaussian_comps,
dist_mode,
seed_mode,
km_iter,
em_iter,
verbose,
var_floor,
seed,
full_covariance_matrices)

if ('Error' %in% names(res)) {

Expand Down
Loading

0 comments on commit 2ef5eb4

Please sign in to comment.