diff --git a/Project.toml b/Project.toml index 7ca3fd3..1277d11 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "UMAP" uuid = "c4f8c510-2410-5be4-91d7-4fbaeb39457e" authors = ["Dillon Daudert "] -version = "0.1.3" +version = "0.1.4" [deps] Arpack = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97" diff --git a/src/UMAP.jl b/src/UMAP.jl index 7633e63..181f608 100644 --- a/src/UMAP.jl +++ b/src/UMAP.jl @@ -5,7 +5,7 @@ using NearestNeighborDescent: DescentGraph using LsqFit: curve_fit using SparseArrays: SparseMatrixCSC, sparse, dropzeros, nzrange, rowvals, nonzeros using LinearAlgebra: Symmetric, Diagonal, issymmetric, I -using Arpack: eigs +using Arpack: eigs, ARPACKException include("utils.jl") include("umap_.jl") diff --git a/src/umap_.jl b/src/umap_.jl index e3a6c24..7738936 100644 --- a/src/umap_.jl +++ b/src/umap_.jl @@ -6,7 +6,7 @@ struct UMAP_{S} graph::AbstractMatrix{S} embedding::AbstractMatrix{S} - function UMAP_(graph::AbstractMatrix{S}, embedding::AbstractMatrix{S}) where {S<:AbstractFloat} + function UMAP_(graph::AbstractMatrix{S}, embedding::AbstractMatrix{S}) where {S<:Real} issymmetric(graph) || isapprox(graph, graph') || error("UMAP_ constructor expected graph to be a symmetric matrix") new{S}(graph, embedding) end @@ -23,18 +23,18 @@ how many neighbors to consider as locally connected. # Keyword Arguments - `n_neighbors::Integer = 15`: the number of neighbors to consider as locally connected. Larger values capture more global structure in the data, while small values capture more local structure. -- `metric::SemiMetric = Euclidean()`: the metric to calculate distance in the input space. It is also possible to pass `metric = :precomputed` to treat `X` like a precomputed distance matrix. +- `metric::{SemiMetric, Symbol} = Euclidean()`: the metric to calculate distance in the input space. It is also possible to pass `metric = :precomputed` to treat `X` like a precomputed distance matrix. - `n_epochs::Integer = 300`: the number of training epochs for embedding optimization -- `learning_rate::AbstractFloat = 1.`: the initial learning rate during optimization +- `learning_rate::Real = 1`: the initial learning rate during optimization - `init::Symbol = :spectral`: how to initialize the output embedding; valid options are `:spectral` and `:random` -- `min_dist::AbstractFloat = 0.1`: the minimum spacing of points in the output embedding -- `spread::AbstractFloat = 1.0`: the effective scale of embedded points. Determines how clustered embedded points are in combination with `min_dist`. -- `set_operation_ratio::AbstractFloat = 1.0`: interpolates between fuzzy set union and fuzzy set intersection when constructing the UMAP graph (global fuzzy simplicial set). The value of this parameter should be between 1.0 and 0.0: 1.0 indicates pure fuzzy union, while 0.0 indicates pure fuzzy intersection. +- `min_dist::Real = 0.1`: the minimum spacing of points in the output embedding +- `spread::Real = 1`: the effective scale of embedded points. Determines how clustered embedded points are in combination with `min_dist`. +- `set_operation_ratio::Real = 1`: interpolates between fuzzy set union and fuzzy set intersection when constructing the UMAP graph (global fuzzy simplicial set). The value of this parameter should be between 1.0 and 0.0: 1.0 indicates pure fuzzy union, while 0.0 indicates pure fuzzy intersection. - `local_connectivity::Integer = 1`: the number of nearest neighbors that should be assumed to be locally connected. The higher this value, the more connected the manifold becomes. This should not be set higher than the intrinsic dimension of the manifold. -- `repulsion_strength::AbstractFloat = 1.0`: the weighting of negative samples during the optimization process. +- `repulsion_strength::Real = 1`: the weighting of negative samples during the optimization process. - `neg_sample_rate::Integer = 5`: the number of negative samples to select for each positive sample. Higher values will increase computational cost but result in slightly more accuracy. -- `a::AbstractFloat = nothing`: this controls the embedding. By default, this is determined automatically by `min_dist` and `spread`. -- `b::AbstractFloat = nothing`: this controls the embedding. By default, this is determined automatically by `min_dist` and `spread`. +- `a = nothing`: this controls the embedding. By default, this is determined automatically by `min_dist` and `spread`. +- `b = nothing`: this controls the embedding. By default, this is determined automatically by `min_dist` and `spread`. """ function umap(args...; kwargs...) # this is just a convenience function for now @@ -47,17 +47,17 @@ function UMAP_(X::AbstractMatrix{S}, n_neighbors::Integer = 15, metric::Union{SemiMetric, Symbol} = Euclidean(), n_epochs::Integer = 300, - learning_rate::AbstractFloat = 1., + learning_rate::Real = 1, init::Symbol = :spectral, - min_dist::AbstractFloat = 0.1, - spread::AbstractFloat = 1.0, - set_operation_ratio::AbstractFloat = 1.0, + min_dist::Real = 1//10, + spread::Real = 1, + set_operation_ratio::Real = 1, local_connectivity::Integer = 1, - repulsion_strength::AbstractFloat = 1.0, + repulsion_strength::Real = 1, neg_sample_rate::Integer = 5, - a::Union{AbstractFloat, Nothing} = nothing, - b::Union{AbstractFloat, Nothing} = nothing - ) where {S <: AbstractFloat} + a::Union{Real, Nothing} = nothing, + b::Union{Real, Nothing} = nothing + ) where {S<:Real} # argument checking size(X, 2) > n_neighbors > 0|| throw(ArgumentError("size(X, 2) must be greater than n_neighbors and n_neighbors must be greater than 0")) size(X, 1) > n_components > 1 || throw(ArgumentError("size(X, 1) must be greater than n_components and n_components must be greater than 1")) @@ -112,7 +112,7 @@ function smooth_knn_dists(knn_dists::AbstractMatrix{S}, k::Integer, local_connectivity::Real; niter::Integer=64, - bandwidth::AbstractFloat=1., + bandwidth::Real=1, ktol = 1e-5) where {S <: Real} nonzero_dists(dists) = @view dists[dists .> 0.] @@ -191,16 +191,22 @@ function compute_membership_strengths(knns::AbstractMatrix{S}, return rows, cols, vals end -function initialize_embedding(graph, n_components, ::Val{:spectral}) - embed = spectral_layout(graph, n_components) - # expand - expansion = 10. / maximum(embed) - @. embed = (embed*expansion) + randn(size(embed))*0.0001 +function initialize_embedding(graph::AbstractMatrix{T}, n_components, ::Val{:spectral}) where {T} + local embed + try + embed = spectral_layout(graph, n_components) + # expand + expansion = 10 / maximum(embed) + @. embed = (embed*expansion) + (1//10000)*randn(T, size(embed)) + catch e + print("Error encountered in spectral_layout; defaulting to random layout\n") + embed = initialize_embedding(graph, n_components, Val(:random)) + end return embed end -function initialize_embedding(graph, n_components, ::Val{:random}) - return 20. .* rand(n_components, size(graph, 1)) .- 10. +function initialize_embedding(graph::AbstractMatrix{T}, n_components, ::Val{:random}) where {T} + return 20 .* rand(T, n_components, size(graph, 1)) .- 10 end """ @@ -298,7 +304,7 @@ end Initialize the graph layout with spectral embedding. """ function spectral_layout(graph::SparseMatrixCSC{T}, - embed_dim::Integer) where {T<:AbstractFloat} + embed_dim::Integer) where {T<:Real} D_ = Diagonal(dropdims(sum(graph; dims=2); dims=2)) D = inv(sqrt(D_)) # normalized laplacian @@ -307,21 +313,13 @@ function spectral_layout(graph::SparseMatrixCSC{T}, k = embed_dim+1 num_lanczos_vectors = max(2k+1, round(Int, sqrt(size(L, 1)))) - local layout - try - # get the 2nd - embed_dim+1th smallest eigenvectors - eigenvals, eigenvecs = eigs(L; nev=k, - ncv=num_lanczos_vectors, - which=:SM, - tol=1e-4, - v0=ones(T, size(L, 1)), - maxiter=size(L, 1)*5) - layout = permutedims(eigenvecs[:, 2:k])::Array{T, 2} - catch e - print("\n", e, "\n") - print("Error occured in spectral_layout; - falling back to random layout.\n") - layout = 20 .* rand(T, embed_dim, size(L, 1)) .- 10 - end + # get the 2nd - embed_dim+1th smallest eigenvectors + eigenvals, eigenvecs = eigs(L; nev=k, + ncv=num_lanczos_vectors, + which=:SR, + tol=1e-4, + v0=ones(T, size(L, 1)), + maxiter=size(L, 1)*5) + layout = permutedims(eigenvecs[:, 2:k])::Array{T, 2} return layout end diff --git a/src/utils.jl b/src/utils.jl index 3dcd465..5b3266e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -29,7 +29,7 @@ end function knn_search(X::AbstractMatrix{S}, k, metric, - ::Val{:pairwise}) where {S <: AbstractFloat} + ::Val{:pairwise}) where {S <: Real} num_points = size(X, 2) dist_mat = Array{S}(undef, num_points, num_points) pairwise!(dist_mat, metric, X, dims=2) @@ -41,12 +41,12 @@ end function knn_search(X::AbstractMatrix{S}, k, metric, - ::Val{:approximate}) where {S <: AbstractFloat} + ::Val{:approximate}) where {S <: Real} knngraph = DescentGraph(X, k, metric) return knngraph.indices, knngraph.distances end -function _knn_from_dists(dist_mat::AbstractMatrix{S}, k) where {S <: AbstractFloat} +function _knn_from_dists(dist_mat::AbstractMatrix{S}, k) where {S <: Real} knns_ = [partialsortperm(view(dist_mat, :, i), 2:k+1) for i in 1:size(dist_mat, 1)] dists_ = [dist_mat[:, i][knns_[i]] for i in eachindex(knns_)] knns = hcat(knns_...)::Matrix{Int} @@ -57,7 +57,7 @@ end # combine local fuzzy simplicial sets @inline function combine_fuzzy_sets(fs_set::AbstractMatrix{T}, - set_op_ratio::T) where {T <: AbstractFloat} + set_op_ratio) where {T} return set_op_ratio .* fuzzy_set_union(fs_set) .+ (one(T) - set_op_ratio) .* fuzzy_set_intersection(fs_set) end @@ -68,4 +68,4 @@ end @inline function fuzzy_set_intersection(fs_set::AbstractMatrix) return fs_set .* fs_set' -end \ No newline at end of file +end diff --git a/test/umap_tests.jl b/test/umap_tests.jl index f6074b6..45d5b82 100644 --- a/test/umap_tests.jl +++ b/test/umap_tests.jl @@ -22,7 +22,7 @@ @test size(umap_.embedding) == (2, 100) data = rand(Float32, 5, 100) - @test_skip UMAP_(data) isa UMAP_{Float32} + @test UMAP_(data) isa UMAP_{Float32} end @testset "fuzzy_simpl_set" begin @@ -105,7 +105,8 @@ layout = spectral_layout(B, 5) @test layout isa Array{Float64, 2} @inferred spectral_layout(B, 5) - @test_skip spectral_layout(convert(SparseMatrixCSC{Float32}, B), 5) + layout32 = spectral_layout(convert(SparseMatrixCSC{Float32}, B), 5) + @test layout32 isa Array{Float32, 2} end end