Details (DRAFT)
Notes on API, supported functionality, polymorphism, etc.
```julia
"""
A struct collecting the parameters of the UMAP fit/transform algorithm(s).
"""
struct UMAPConfig{K, S, T, G, O}
    knn_params::K  # n_neighbors, metric(s), precomputed dists, ...
    src_params::S  # local_connectivity, set_operation_ratio, bandwidth
    tgt_params::T  # initialization, metrics, fit_ab params (min_dist, spread), ...
    glb_params::G  # global fuzzy set params (for merging views)
    opt_params::O  # n_epochs, learning_rate, repulsion_strength, neg_sample_rate, ...
end

"""
The result returned by `fit(data, ...)`; passed to `transform(new_data, umap_result, ...)`.
"""
struct UMAPResult{DS, DT, C, V}
    data::DS
    embedding::DT
    config::C
    # ... internal data, things we might want to cache
    views::V  # intermediate results (e.g. knns / dists matrices) for each view
end
```
Internally, UMAP can take as input data represented as a Tables.jl table. For the first half of the algorithm, we need to construct a weighted, undirected graph (the "UMAP graph") G that approximates the distances between the original (potentially heterogeneous) points in a new metric space. Each column in the table is considered a different "view" of the same dataset, so we need to combine these views by:
- Finding (or accepting precomputed) approximate KNNs of each point with respect to a particular view (column) and metric
- Constructing fuzzy simplicial sets in the usual way for each (view, metric) combination
- Combining (e.g. intersecting) the view fuzzy sets into a single fuzzy set, which is the final UMAP graph
If the input data is a table, the KNN search step needs to compute an (approximate) KNN graph for each (metric, columns) combination. Several special cases need to be handled: one metric for all columns; one precomputed distance matrix (or KNN graph); and a collection of metrics, each with its associated columns, together with a set of precomputed distance matrices.
- NOTE: In the current API, a precomputed distance matrix replaces the input data. We could also accept the data alongside it, so that at transform time the user could pass data instead of another precomputed distance matrix for that view; however, this seems like a niche case and adds considerable complexity. For now it is simpler to assume that if the user passes a precomputed distance matrix for a particular view of the data, they can do the same for the transform.
```julia
# Top level: dispatch over multiple knn params
# NOTE: figure out where is best to index data[cols]
function knn_search(data, knn_params::Sequence{KNNParams})
    return knn_search.((data,), knn_params)
end

# ... and for transform
function knn_search(data, queries, knn_params::Sequence{KNNParams})
    return knn_search.((data,), (queries,), knn_params)
end

# Basic KNN params where we have one metric
function knn_search(data, knn_params::KNNParams{<:PreMetric, ...}) end
function knn_search(data, queries, knn_params::KNNParams{<:PreMetric, ...}) end

# Basic KNN params where we have one precomputed dist matrix
function knn_search(data, knn_params::KNNParams{<:AbstractMatrix, ...}) end
function knn_search(data, queries, knn_params::KNNParams{<:AbstractMatrix, ...}) end
```
From these, it's clear that we need KNNParams with at least the following parameterizations:
```julia
# NNDescent case; `cols` can be `nothing`, a single column index, or a vector of column indices
struct KNNParams{M<:PreMetric, V, P}
    n_neighbors::Int
    metric::M
    cols::V    # which columns fall under this metric
    params::P
end

# Precomputed distance matrix case
struct KNNParams{M<:AbstractMatrix, P}
    n_neighbors::Int
    dists::M
    params::P
end
```
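To make the dispatch concrete, here is a minimal brute-force sketch of `knn_search` for the two parameterizations. It assumes the columns of the data matrix are points, and it replaces `KNNParams` with two hypothetical types: `MetricKNN` (any callable metric) and `PrecomputedKNN`. The helpers `knn_from_dists` and `pairwise_dists` are also illustrative. A real implementation would use an approximate search such as NNDescent rather than computing all pairwise distances.

```julia
# Hypothetical stand-ins for the two KNNParams parameterizations. A metric here is
# any callable (u, v) -> distance; a Distances.jl metric object is callable too.
struct MetricKNN{F}
    n_neighbors::Int
    metric::F
end

struct PrecomputedKNN{M<:AbstractMatrix}
    n_neighbors::Int
    dists::M   # dists[i, j]: distance from reference point i to query point j
end

# Brute-force KNN from a distance matrix: returns (indices, distances), both
# n_neighbors × n_queries, with each column sorted by increasing distance.
function knn_from_dists(D::AbstractMatrix, k::Int)
    k <= size(D, 1) || throw(ArgumentError("n_neighbors exceeds the number of reference points"))
    n_q = size(D, 2)
    idxs = Matrix{Int}(undef, k, n_q)
    dists = Matrix{eltype(D)}(undef, k, n_q)
    for j in 1:n_q
        nearest = partialsortperm(view(D, :, j), 1:k)
        idxs[:, j] = nearest
        dists[:, j] = D[nearest, j]
    end
    return idxs, dists
end

# All distances between columns of X (rows of the result) and columns of Q (columns).
pairwise_dists(metric, X, Q) =
    [metric(view(X, :, i), view(Q, :, j)) for i in axes(X, 2), j in axes(Q, 2)]

# fit, one metric: each point is its own nearest neighbor at distance 0 (absent duplicates)
knn_search(X::AbstractMatrix, p::MetricKNN) =
    knn_from_dists(pairwise_dists(p.metric, X, X), p.n_neighbors)
# fit, precomputed: the distance matrix replaces the data
knn_search(data, p::PrecomputedKNN) = knn_from_dists(p.dists, p.n_neighbors)
# transform: distances from every reference point to every query point
knn_search(X::AbstractMatrix, Q::AbstractMatrix, p::MetricKNN) =
    knn_from_dists(pairwise_dists(p.metric, X, Q), p.n_neighbors)
knn_search(data, queries, p::PrecomputedKNN) = knn_from_dists(p.dists, p.n_neighbors)
# several views: one KNN result per params object
knn_search(data, ps::AbstractVector) = [knn_search(data, p) for p in ps]

# usage
euclid(u, v) = sqrt(sum(abs2, u .- v))
X = [0.0 1.0 3.0 7.0]                           # 1 feature × 4 points
idxs, ds = knn_search(X, MetricKNN(2, euclid))
```

Because the two params types are unrelated, a vector mixing them is a `Vector{Any}` and each element is dispatched at runtime, which is cheap compared with the search itself.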
We convert each KNN graph constructed in step 2.1 into a global fuzzy simplicial set for the corresponding view of the data. This function always returns a sparse matrix representing a weighted, undirected graph whose edge weights are the membership probabilities of those edges in the fuzzy simplicial set.
- NOTE: When initially fitting a dataset, the adjacency matrix of this graph is square and symmetric. When transforming new data according to a fitted UMAP result, the result contains only a submatrix of the complete adjacency matrix (taking the graph's nodes to be the union of the points in the original dataset and the new points being transformed). Since we transform new points by finding their nearest neighbors in the original dataset, every edge in this submatrix has its source at a point in the new data and its destination at a point in the original data.
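A sketch of how that transform submatrix could be assembled with the `SparseArrays` standard library, assuming the KNN results are stored with one column per new point (as `k × n_new` index and weight matrices). The function name `transform_graph` and the `n_ref × n_new` orientation (column `j` holds the edges leaving new point `j`) are assumptions for illustration, not a settled layout.

```julia
using SparseArrays

# Build the n_ref × n_new block of the UMAP graph from KNN results computed against
# the reference data: column j holds the weighted edges from new point j to its
# nearest reference points.
function transform_graph(knn_idxs::AbstractMatrix{<:Integer},
                         weights::AbstractMatrix{<:Real}, n_ref::Integer)
    k, n_new = size(knn_idxs)
    size(weights) == (k, n_new) ||
        throw(DimensionMismatch("knn_idxs and weights must have the same size"))
    rows = vec(knn_idxs)                 # reference endpoint of each edge
    cols = repeat(1:n_new, inner = k)    # new-point endpoint of each edge
    return sparse(rows, cols, vec(weights), n_ref, n_new)
end

G = transform_graph([1 3; 2 4], [1.0 0.5; 0.25 1.0], 5)   # 5 reference points, 2 new points
```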
```julia
# Handle multiple views of the data; src_params may be a single value or one per view
function fuzzy_simplicial_set(knn_graphs::Sequence{KNNGraph}, knn_params::Sequence{KNNParams}, src_params, global_src_params)
    view_fuzzy_sets = fuzzy_simplicial_set.(knn_graphs, knn_params, src_params; ...)
    return merge_view_fuzzy_simplicial_sets(view_fuzzy_sets, global_src_params)
end

# Handle a single view
function fuzzy_simplicial_set(knn_graph, knn_params, src_params; ...)
    # smooth_knn_dists ...
    # compute_membership_strengths ...
    # combine_fuzzy_sets ...
end
```
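The three steps in the single-view `fuzzy_simplicial_set` can be sketched as follows. This assumes `local_connectivity = 1` (so `rho` is the distance to the nearest neighbor) and neighbor lists that exclude the point itself. The binary search picks each `sigma` so that the memberships sum to `log2(k)` scaled by `bandwidth`, as in the UMAP paper, and `combine_fuzzy_sets` uses the union/intersection mix that umap-learn applies with its `set_op_mix_ratio`. The KNN-shaped weights would have to be scattered into an `n × n` sparse matrix before symmetrizing; all function names here are hypothetical.

```julia
# For one point: `dists` are its k neighbor distances, excluding the point itself.
# Assumes local_connectivity = 1, so rho is the smallest positive distance.
function smooth_knn_dist(dists::AbstractVector{<:Real}; bandwidth = 1.0, n_iter = 64, tol = 1e-5)
    target = log2(length(dists)) * bandwidth
    positive = filter(>(0), dists)
    rho = isempty(positive) ? 0.0 : Float64(minimum(positive))
    lo, hi, sigma = 0.0, Inf, 1.0
    for _ in 1:n_iter
        psum = sum(d -> exp(-max(d - rho, 0.0) / sigma), dists)
        abs(psum - target) < tol && break
        if psum > target       # too much total membership: shrink sigma
            hi = sigma
        else                   # too little: grow sigma
            lo = sigma
        end
        sigma = isinf(hi) ? 2sigma : (lo + hi) / 2
    end
    return rho, sigma
end

# Directed membership strengths: column j of knn_dists (k × n) holds point j's neighbor
# distances; the result has the same layout, with the nearest neighbor at strength 1.
function membership_strengths(knn_dists::AbstractMatrix{<:Real}; bandwidth = 1.0)
    weights = similar(knn_dists, Float64)
    for j in axes(knn_dists, 2)
        col = view(knn_dists, :, j)
        rho, sigma = smooth_knn_dist(col; bandwidth = bandwidth)
        weights[:, j] .= exp.(-max.(col .- rho, 0.0) ./ sigma)
    end
    return weights
end

# Symmetrize the directed graph W (W[i, j] = strength of edge i -> j): ratio = 1 gives
# the fuzzy union W + W' - W .* W', ratio = 0 the fuzzy intersection W .* W'.
function combine_fuzzy_sets(W::AbstractMatrix, set_operation_ratio::Real)
    P = W .* W'
    return set_operation_ratio .* (W .+ W' .- P) .+ (1 - set_operation_ratio) .* P
end
```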
Then we need the following parameterizations of the src_params:
```julia
struct SourceParams  # SourceViewParams
    local_connectivity
    set_operation_ratio
    bandwidth
end
```
The idea behind this first SourceParams is that it applies per view, so it is broadcast across all the views (either one instance for all views, or one for each view).
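In Julia, broadcasting over a custom struct fails unless the type declares how it broadcasts, because the fallback `Base.broadcastable` tries to `collect` it. One way to let a single params object apply to every view, while a vector supplies one per view, is to make the type behave as a scalar under broadcasting. `ViewParams` and `view_bandwidth` are hypothetical names used only for illustration.

```julia
# Hypothetical per-view source parameters (mirrors the SourceParams sketch above).
struct ViewParams
    local_connectivity::Float64
    set_operation_ratio::Float64
    bandwidth::Float64
end

# Treat one ViewParams as a scalar when broadcasting, so it is reused for every view.
Base.broadcastable(p::ViewParams) = Ref(p)

# Stand-in for the per-view work done by fuzzy_simplicial_set.
view_bandwidth(view_id, p::ViewParams) = (view_id, p.bandwidth)

shared = ViewParams(1.0, 1.0, 1.0)
view_bandwidth.([:color, :shape], shared)                               # one params for all views
view_bandwidth.([:color, :shape], [shared, ViewParams(1.0, 0.5, 2.0)])  # one params per view
```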
However, some parameters govern how all the view fuzzy sets are combined into a single fuzzy set covering every view. We might therefore rename the previous struct SourceViewParams and add the following struct for the global fuzzy simplicial set parameters:
```julia
struct SourceParams
    sets_mix_weight
    sets_operation_ratio  # typically set to 100% intersection
    # weights for each view specifically?
end
```
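One simple way `merge_view_fuzzy_simplicial_sets` could fold the per-view graphs (all `n × n`, same node order) into the global UMAP graph is to combine them pairwise with the product t-norm for intersection and the probabilistic sum for union, mixed by a ratio where 0 is pure intersection (the typical setting noted above) and 1 is pure union. This is an assumption for illustration: umap-learn's multi-view intersection is more involved, and per-view weights are not modeled. The fused broadcast keeps `SparseMatrixCSC` inputs sparse.

```julia
using SparseArrays

# Hypothetical merge of per-view UMAP graphs; ratio = 0 is pure intersection,
# ratio = 1 is pure union.
function merge_view_graphs(graphs::AbstractVector{<:AbstractMatrix}; ratio::Real = 0.0)
    return reduce(graphs) do A, B
        P = A .* B                                   # fuzzy intersection (product t-norm)
        ratio .* (A .+ B .- P) .+ (1 - ratio) .* P   # mix with fuzzy union
    end
end

G1 = [0.0 1.0; 1.0 0.0]
G2 = [0.0 0.5; 0.5 0.0]
merge_view_graphs(sparse.([G1, G2]))
```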
What we're left with after the view fuzzy sets are merged is a single, global UMAP graph. This encodes the topological information of the input data in the source space. Our next step is to initialize a new representation of the data in the target space and perform optimization.
First, we need a way to create new points on the target manifold. In the simplest case we choose this manifold to be R^d, where vectors are easy to generate (e.g. randomly, via a spectral embedding, or, when transforming new data, as a weighted average of neighbors). It should be possible to embed into many other kinds of manifolds, provided we have a differentiable membership function for the target embedding that can be optimized via gradient descent.
The function `initialize_embedding` returns a vector of points in the target space.
```julia
# fit
function initialize_embedding(umap_graph, init_params) end
# transform
function initialize_embedding(umap_graph, ref_embedding, init_params) end
```
The initialization parameters depend on the target manifold: essentially, they hold whatever data and functions are needed to generate one or more points in the target space.
```julia
"""
This initializes points in Euclidean space of a specified dimension according to the given method.
"""
struct EuclideanInitialization{I}
    method::I
    n_components::Int  # this might be moved into the enclosing TargetParams
end
```
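A sketch of both `initialize_embedding` methods for Euclidean space, using a hypothetical `RandomEuclideanInit` in place of `EuclideanInitialization`. The fit method draws each coordinate uniformly from [-10, 10) (the range is an assumption). The transform method places each new point at the weighted average of the reference points it is connected to, assuming the transform graph is stored as an `n_ref × n_new` matrix with one column per new point.

```julia
using Random

# Hypothetical stand-in for EuclideanInitialization with a "random" method.
struct RandomEuclideanInit
    n_components::Int
end

# fit: one uniformly random point in R^d per node of the UMAP graph
function initialize_embedding(umap_graph::AbstractMatrix, init::RandomEuclideanInit;
                              rng::AbstractRNG = Random.default_rng())
    n = size(umap_graph, 1)
    return [20 .* rand(rng, init.n_components) .- 10 for _ in 1:n]
end

# transform: weighted average of the reference points each new point is connected to
function initialize_embedding(umap_graph::AbstractMatrix, ref_embedding::AbstractVector,
                              ::RandomEuclideanInit)
    size(umap_graph, 1) == length(ref_embedding) ||
        throw(DimensionMismatch("graph rows must match the reference points"))
    return map(axes(umap_graph, 2)) do j
        w = umap_graph[:, j]
        total = sum(w)
        total > 0 || throw(ArgumentError("new point $j has no edges to the reference data"))
        sum(w[i] .* ref_embedding[i] for i in eachindex(ref_embedding)) ./ total
    end
end
```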
Other initialization parameterizations are possible; they may not be provided by default (or perhaps only as a single example).
As stated above, we need a differentiable membership function for the simplices in the target space. This is a binary function that takes two points in the target space (more generally it could take any l-simplex, but for now we only care about 1-simplices) and returns a value between 0 and 1.
By default, this membership function comes from a family of smooth curves, parameterized by a and b, that are fit to approximate an exponential decay offset by `min_dist`.
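For reference, the default family can be written out. Following the UMAP paper, with umap-learn's `spread` scaling the decay, the smooth membership $\Phi$ is fit to the target curve $\Psi$:

$$
\Phi(x, y) = \left(1 + a \lVert x - y \rVert_2^{2b}\right)^{-1},
\qquad
\Psi(x, y) =
\begin{cases}
1 & \text{if } \lVert x - y \rVert_2 \le \mathrm{min\_dist}, \\
\exp\!\left(-\dfrac{\lVert x - y \rVert_2 - \mathrm{min\_dist}}{\mathrm{spread}}\right) & \text{otherwise.}
\end{cases}
$$

Here `min_dist` and `spread` are the `TargetParams` fields used to fit $a$ and $b$ by least squares; when `a` and `b` are provided directly, the fit is skipped.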
```julia
"""
Returns a function `membership_fn(distance, x, y)` that takes as input a distance function
and two points in the target space.
"""
function fit_membership_fn(tgt_params) end  # returns membership_fn
```
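A minimal sketch of `fit_membership_fn`, split into two hypothetical helpers. `fit_ab` fits $a$ and $b$ by minimizing the squared error against $\Psi$ on 300 distances in $[0, 3\cdot\mathrm{spread}]$, and `make_membership_fn` returns a closure with the documented `membership_fn(distance, x, y)` signature. A shrinking grid search stands in for a proper nonlinear least-squares solver (umap-learn uses SciPy's `curve_fit`); the initial search box is an assumption.

```julia
# Fit (a, b) so that 1 / (1 + a d^(2b)) approximates the offset exponential target.
function fit_ab(min_dist::Real, spread::Real; n_samples = 300, n_grid = 41, n_rounds = 6)
    ds = collect(range(0.0, 3spread; length = n_samples))
    target = [d <= min_dist ? 1.0 : exp(-(d - min_dist) / spread) for d in ds]
    loss(p, q) = sum(abs2, 1 ./ (1 .+ p .* ds .^ (2q)) .- target)
    alo, ahi, blo, bhi = 1e-3, 10.0, 0.1, 3.0    # assumed initial search box
    a, b = 1.0, 1.0
    for _ in 1:n_rounds
        as = range(alo, ahi; length = n_grid)
        bs = range(blo, bhi; length = n_grid)
        losses = [loss(p, q) for p in as, q in bs]
        i, j = Tuple(argmin(losses))
        a, b = as[i], bs[j]
        # shrink the box to two grid steps around the current best point
        alo, ahi = max(a - 2step(as), 1e-6), a + 2step(as)
        blo, bhi = max(b - 2step(bs), 1e-3), b + 2step(bs)
    end
    return a, b
end

# Closure with the membership_fn(distance, x, y) signature described above.
make_membership_fn(a::Real, b::Real) = (distance, x, y) -> 1 / (1 + a * distance(x, y)^(2b))

a, b = fit_ab(0.1, 1.0)
membership_fn = make_membership_fn(a, b)
```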
It should also be possible to provide a custom membership function as long as it meets the requirements and has the same signature.
In either case, we can use automatic differentiation (probably via Zygote.jl) to get functions returning the gradient of the membership function, although a custom adjoint will likely be written for the default. For efficiency, the loss function is simplified: we drop any constants and decompose it into positive and negative terms:
```julia
using Zygote: gradient

# we sample positive edges according to their weights in the UMAP graph
pos_grad(x, y) = gradient((_x, _y) -> log(membership_fn(distance, _x, _y)), x, y)
# we sample some number of negative edges for each positive edge sampled
neg_grad(x, y) = gradient((_x, _y) -> log(1 - membership_fn(distance, _x, _y)), x, y)
```
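Since the default membership depends on $x$ and $y$ only through $s = \lVert x - y\rVert_2^2$, both gradients have closed forms, which is what a custom adjoint could return: $\nabla_x \log\Phi = -\frac{2ab\,s^{b-1}}{1 + a s^b}(x - y)$ and $\nabla_x \log(1 - \Phi) = \frac{2b}{s\,(1 + a s^b)}(x - y)$, and the gradient with respect to $y$ is the negation. The sketch below assumes a Euclidean target metric; the names `attractive_grad` and `repulsive_grad` are hypothetical, and the `eps` term guarding the repulsive denominator (umap-learn uses 0.001) is an implementation choice, with `eps = 0` giving the exact gradient.

```julia
# Gradient of log(Φ(x, y)) with respect to x, for Φ = 1 / (1 + a s^b), s = ‖x - y‖².
function attractive_grad(x, y, a, b)
    s = sum(abs2, x .- y)
    s > 0 || return zero(x)      # coincident points: s^(b - 1) would be undefined
    return (-2a * b * s^(b - 1) / (1 + a * s^b)) .* (x .- y)
end

# Gradient of log(1 - Φ(x, y)) with respect to x; eps keeps the denominator positive.
function repulsive_grad(x, y, a, b; eps = 1e-3)
    s = sum(abs2, x .- y)
    return (2b / ((eps + s) * (1 + a * s^b))) .* (x .- y)
end
```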
```julia
# fit a and b
struct TargetParams{I, M}
    init::I  # e.g. DefaultInitialization
    metric::M
    min_dist
    spread
end

# provide a and b
struct TargetParams{I, M}
    init::I
    metric::M
    a
    b
end
```
The target embedding is optimized via stochastic gradient descent. This procedure is relatively straightforward at the moment, so there is not much polymorphism needed here.
```julia
function optimize_embedding!(embedding, ref_embedding, umap_graph, tgt_params, opt_params; move_ref)
    # NOTE: try to compile gradient of target metric here?
    # NOTE: add epochs_per_sample calculation here? This seems useful (in particular dividing the
    # weights of the edges by the maximum weight, so we don't waste any epochs)
    lr = opt_params.learning_rate
    for epoch in 1:opt_params.n_epochs
        _optimize_embedding!(embedding, ref_embedding, umap_graph, tgt_params, lr, opt_params; move_ref=move_ref)
        lr = ...
    end
end

struct OptimizationParams
    n_epochs
    learning_rate
    repulsion_strength
    neg_sample_rate
end
```
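To tie the pieces together, here is a minimal fit-time sketch of the SGD loop, assuming a Euclidean target, a symmetric sparse UMAP graph, and the closed-form gradients of the default membership function. The hypothetical `SGDParams` mirrors `OptimizationParams`. The linear learning-rate decay, the `epochs_per_sample = max_weight ./ weights` schedule suggested in the NOTE, and clipping gradient components to [-4, 4] follow umap-learn, while drawing exactly `neg_sample_rate` negatives per positive sample is a simplification.

```julia
using Random, SparseArrays

# Hypothetical mirror of OptimizationParams, with assumed defaults.
Base.@kwdef struct SGDParams
    n_epochs::Int = 200
    learning_rate::Float64 = 1.0
    repulsion_strength::Float64 = 1.0
    neg_sample_rate::Int = 5
end

clip(g) = clamp.(g, -4.0, 4.0)   # assumed per-component gradient clipping

# Fit case: both endpoints of a sampled edge move (i.e. move_ref = true).
function optimize_embedding_sketch!(embedding::Vector{Vector{Float64}}, graph::SparseMatrixCSC,
                                    a::Real, b::Real, p::SGDParams;
                                    rng::AbstractRNG = Random.default_rng())
    heads, tails, weights = findnz(graph)
    epochs_per_sample = maximum(weights) ./ weights   # heaviest edge is sampled every epoch
    next_sample = copy(epochs_per_sample)
    n = length(embedding)
    for epoch in 1:p.n_epochs
        lr = p.learning_rate * (1 - (epoch - 1) / p.n_epochs)   # assumed linear decay
        for e in eachindex(heads)
            next_sample[e] <= epoch || continue
            x, y = embedding[heads[e]], embedding[tails[e]]
            # attraction along the sampled edge: gradient ascent on log(membership)
            s = sum(abs2, x .- y)
            if s > 0
                g = clip((-2a * b * s^(b - 1) / (1 + a * s^b)) .* (x .- y))
                x .+= lr .* g
                y .-= lr .* g
            end
            # repulsion from uniformly sampled points: ascent on log(1 - membership)
            for _ in 1:p.neg_sample_rate
                k = rand(rng, 1:n)
                k == heads[e] && continue
                z = embedding[k]
                s = sum(abs2, x .- z)
                c = 2 * p.repulsion_strength * b / ((1e-3 + s) * (1 + a * s^b))
                x .+= lr .* clip(c .* (x .- z))
            end
            next_sample[e] += epochs_per_sample[e]
        end
    end
    return embedding
end
```

The transform case would differ mainly in keeping the reference points fixed (only `x` moves) and sampling negatives from the reference embedding.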