Functionality
While the primary advertised use-case of UMAP is as a dimension reduction technique (projecting vectors from R^n to R^d, with d << n), it is a much more generic algorithm. At a high level, UMAP projects points in one topological (source) space to another topological (target) space. Both the source and target spaces must be differentiable manifolds with suitable metrics; UMAP approximates the source manifold from the input data. This ultimately allows for embedding high-dimensional, heterogeneous data in more useful spaces, for example R^2 (if we wish to visualize our data). To do this, UMAP performs the following:
Assumptions:
- The points in the source space are uniformly distributed on a Riemannian manifold
- The Riemannian metric is locally constant (or approximately so)
- The manifold is locally connected
Procedure:
- Construct a fuzzy topological representation of the data in the source space:
  - Approximate the manifold on which the data is assumed to lie
  - Construct a fuzzy (simplicial) set representation of the approximated manifold
- Construct a fuzzy topological representation of the data in the target space:
  - Initialize a representation of the data in the desired target space
  - Construct a fuzzy set representation of the target manifold
- Optimize the projection by minimizing the cross entropy between the two fuzzy set representations
The first stage finds the distances from each point to its k-nearest neighbors. We use NearestNeighborDescent.jl to find the approximate neighbors of the input data, which has the benefit of working on a variety of input data types and distance functions.
For fitting, this looks like a function
# Returns a weighted, directed graph with outgoing edges from each point
# to its k nearest neighbor points (both in data)
knn_search(data, n_neighbors, metric) -> knn_graph
And for transforming, a function like
# Returns a weighted, directed *bipartite* graph with outgoing edges from
# each point in new_data to its k nearest neighbor points in data
knn_search(data, new_data, n_neighbors, metric) -> knn_graph
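For concreteness, a minimal sketch of the fitting-stage search using NearestNeighborDescent.jl might look like the following (the data shape and parameter values are illustrative):

using NearestNeighborDescent, Distances

data = [rand(20) for _ in 1:100]    # 100 points in R^20
n_neighbors = 15

# Build the approximate knn graph, then extract neighbor indices and
# distances as (n_neighbors, n_points) matrices
graph = nndescent(data, n_neighbors, Euclidean())
indices, distances = knn_matrices(graph)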
Alternatively, the user can calculate distances beforehand by some other method and pass them directly to skip this step.
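For instance, such distances might be precomputed with Distances.jl (how they are then passed in depends on the package's API; this snippet only shows the precomputation):

using Distances

X = rand(20, 100)                         # 100 points in R^20, one per column
dists = pairwise(Euclidean(), X; dims=2)  # 100×100 matrix of pairwise distances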
Each point and its nearest neighbors are converted into a (local) fuzzy simplicial set; these local sets are then combined.
Rather than using the knn distances directly, they are smoothed in order to "fix the cardinality of the fuzzy set to a specific value", in this case log2(k).
To achieve this, for each point x_i in the data we define rho_i and sigma_i such that:
- rho_i = min{ d(x_i, x_i_j) | 1 <= j <= k, d(x_i, x_i_j) > 0 }, and
- sigma_i is the value (in practice found by binary search) satisfying sum_{j=1:k} exp( -max(0, d(x_i, x_i_j) - rho_i) / sigma_i ) = log2(k).
Then the weights of the graph are defined as
- w((x_i, x_i_j)) = exp( -max(0, d(x_i, x_i_j) - rho_i) / sigma_i )
The effect of this is that the edge from each point to its nearest neighbor with nonzero distance gets weight 1 (neighbors at distance 0 also get weight 1), with farther neighbors receiving decreasing positive weights. That is, all edge weights are normalized to lie in (0, 1]. This creates a local fuzzy simplicial set for each point.
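As a concrete illustration, here is a minimal sketch of this smoothing for a single point, assuming dists holds its knn distances (with at least one nonzero entry); the function name and tolerances are illustrative, and sigma_i is found by binary search:

# A sketch of the smoothing step for one point
function smooth_weights(dists::AbstractVector{<:Real}; n_iters=64, tol=1e-5)
    k = length(dists)
    target = log2(k)
    rho = minimum(d for d in dists if d > 0)   # nearest nonzero neighbor
    lo, hi, sigma = 0.0, Inf, 1.0
    for _ in 1:n_iters
        s = sum(exp(-max(0.0, d - rho) / sigma) for d in dists)
        abs(s - target) < tol && break
        if s > target
            hi = sigma                         # sigma too large; shrink
            sigma = (lo + hi) / 2
        else
            lo = sigma                         # sigma too small; grow
            sigma = isinf(hi) ? 2sigma : (lo + hi) / 2
        end
    end
    # The nearest nonzero neighbor gets weight 1; farther ones decay
    return [exp(-max(0.0, d - rho) / sigma) for d in dists]
end

smooth_weights([0.0, 1.2, 1.5, 2.0, 3.1])   # first two weights are 1.0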
The local sets are combined by taking their fuzzy union. In effect, whereas the local fuzzy sets together constituted a directed, weighted KNN graph (exactly num_points * num_neighbors edges), the global fuzzy simplicial set will have an arbitrary number of edges. This is referred to as the (undirected, weighted) UMAP graph.
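A sketch of this combination, assuming the directed knn weights have been assembled into matrices as above; the fuzzy union of memberships a and b is the probabilistic t-conorm a + b - a*b:

using SparseArrays

# A sketch of assembling the UMAP graph, assuming indices[j, i] and
# weights[j, i] hold point i's j-th neighbor and its membership weight
function umap_graph(indices::AbstractMatrix{<:Integer},
                    weights::AbstractMatrix{<:Real}, n_points::Integer)
    k = size(indices, 1)
    I = repeat(1:n_points; inner=k)   # source of each directed edge
    J = vec(indices)                  # destination (neighbor index)
    A = sparse(I, J, vec(weights), n_points, n_points)
    # Fuzzy union via the probabilistic t-conorm a + b - a*b; a fuzzy
    # intersection would instead be the elementwise product A .* A'
    return A + A' - A .* A'
end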
(NOTE - Heterogeneous data: it seems the general mechanism for handling heterogeneous data is to construct multiple fuzzy simplicial sets, one for each "feature" and its associated metric (e.g. vectors, ordinal/categorical data, etc.), and then combine them. The combination might be parameterized by weighting union vs. intersection, and also by weighting each fuzzy simplicial set individually.
Each data type (or prediction variable in the supervised case) can be seen as an alternative view of the underlying structure, each with a different associated metric: for example, categorical data may use Jaccard or Dice distance, while ordinal data might use Manhattan distance. Each view and metric can be used to independently generate a fuzzy simplicial set, and these can then be intersected together to create a single fuzzy simplicial set for embedding.)
Given the UMAP graph, which contains the topological information of the data in the source space, we can compute a new representation (embedding) of the data in the target space. The target space can be fairly generic; however, since UMAP optimizes the embedding via stochastic gradient descent, we need a differentiable metric on the target space.
Since we know the target manifold and metric, we can compute the fuzzy simplicial set representation directly. (NOTE: since we assume local connectedness, we provide a parameter min_dist, which defines the desired distance to each point's nearest neighbor in the target space.)
Each data point is initialized through some process (e.g. randomly or via spectral embedding of the UMAP graph). This is a function like
# The particulars can vary, but a collection of points in the target space must be returned
init_embedding(...) -> embedding
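For example, a minimal sketch of spectral initialization, assuming a connected UMAP graph and using a dense eigendecomposition for simplicity:

using LinearAlgebra

# Spectral initialization via the symmetric normalized Laplacian
function init_embedding(graph::AbstractMatrix{<:Real}, n_components::Integer)
    d = vec(sum(graph; dims=2))             # weighted degrees
    Dinv = Diagonal(1 ./ sqrt.(d))
    L = I - Dinv * Matrix(graph) * Dinv     # symmetric normalized Laplacian
    eig = eigen(Symmetric(L))
    # Drop the trivial constant eigenvector; the next n_components
    # eigenvectors give the initial coordinates (one row per point)
    return eig.vectors[:, 2:(n_components + 1)]
end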
We optimize the embedding by minimizing the fuzzy set cross entropy. Recall that we constructed an undirected, weighted graph whose edges constitute a fuzzy simplicial set. We can represent this fuzzy set as a tuple (E, f), where E is the set of edges and f: E -> [0, 1] is a membership function returning the weight of an edge in E. We also have a fuzzy simplicial set for the embedding, (E, g). Here, the membership function g: E -> [0, 1] is approximated by choosing a family of differentiable functions:
"""
To optimize this loss with gradient descent, g(e) must be differentiable. A
smooth approximation for the membership strength of a 1-simplex between two
points x, y, can be given by the following, with differentiable dissimilarity
function `σ` and constants `α`, `β`:
"""
ϕ(x, y, σ, α, β) = (1 + α*(σ(x, y))^β)^(-1)
"""
The approximation parameters `α`, `β` are chosen by non-linear least squares
fitting of the following function ψ:
"""
ψ(x, y, σ, min_dist) = 1 if σ(x, y) ≤ min_dist else exp(-(σ(x, y) - min_dist))
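As an illustration, α and β might be fit with LsqFit.jl, assuming σ is the Euclidean distance in the target space so that ϕ and ψ reduce to functions of a scalar distance (the grid and initial guess are illustrative):

using LsqFit

min_dist = 0.1
xs = collect(range(1e-3, 3.0; length=300))                     # grid of target-space distances
ys = [x <= min_dist ? 1.0 : exp(-(x - min_dist)) for x in xs]  # ψ sampled on the grid

# ϕ as a function of distance x, with parameters p = [α, β]
model(x, p) = @. 1 / (1 + p[1] * x^p[2])

fit = curve_fit(model, xs, ys, [1.0, 1.0])   # initial guess α = β = 1
α, β = fit.param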
The loss function to minimize is therefore (dropping terms that are constant during optimization):
C(f, g) = - Σ_{e ∈ E} ( f(e) * log(g(e)) + (1 - f(e)) * log(1 - g(e)) )
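To make this concrete, here is a sketch of one epoch of the stochastic optimization, assuming a Euclidean target space with g given by ϕ above and negative sampling to approximate the repulsive (1 - f(e)) terms (a standard trick, since summing over all non-edges is intractable); names and hyperparameter values are illustrative:

using LinearAlgebra

# Y is an (n_components, n_points) embedding; edges holds (i, j, f) triples
# from the UMAP graph with membership weights f
function optimize_epoch!(Y, edges; α=1.6, β=1.8, lr=1.0, neg_samples=5, eps=1e-3)
    n = size(Y, 2)
    for (i, j, f) in edges
        rand() > f && continue              # sample each edge with probability f(e)
        diff = Y[:, i] - Y[:, j]
        d = max(norm(diff), eps)
        # Attractive term: gradient of -log(g(e)) pulls i and j together
        c = (α * β * d^(β - 2)) / (1 + α * d^β)
        Y[:, i] .-= lr * c .* diff
        Y[:, j] .+= lr * c .* diff
        for _ in 1:neg_samples              # repulsive -log(1 - g(e)) terms
            k = rand(1:n)
            k == i && continue
            diff = Y[:, i] - Y[:, k]
            d = max(norm(diff), eps)
            c = β / (d^2 * (1 + α * d^β))   # pushes i away from the sample k
            Y[:, i] .+= lr * c .* diff
        end
    end
    return Y
end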