
Details (DRAFT)

Dillon Daudert edited this page May 25, 2020 · 10 revisions

Notes on API, supported functionality, polymorphism, etc.

1. General

"""
A struct encompassing the ways to parameterize the UMAP fit/transform algorithm(s).
"""
struct UMAPConfig{K, S, T, G, O}
    knn_params::K # n_neighbors, metric(s), precomputed dists, ...
    src_params::S # local_connectivity, set_operation_ratio, bandwidth
    tgt_params::T # initialization, metrics, fit_ab params (min_dist, spread), ...
    glb_params::G # global fuzzy set params (for merging views)
    opt_params::O # n_epochs, learning_rate, repulsion_strength, neg_sample_rate, ...
end

"""
The result returned by `fit(data, ...)`; it is passed to `transform(new_data, umap_result, ...)`.
"""
struct UMAPResult{DS, DT, C, V}
    data::DS
    embedding::DT
    config::C
    # ... internal data, things we might want to cache
    views::V # intermediate results (e.g. knns / dists matrices) for each view
end 

2. Construct a global UMAP graph (i.e. fuzzy simplicial set)

Internally, UMAP can take input data represented as a Tables.jl table. In the first half of the algorithm, we construct a weighted, undirected graph (the "UMAP graph") G that approximates the distances between the original (potentially heterogeneous) points in a new metric space. Each column in the table is treated as a different "view" of the same dataset, so we combine these views by

  1. Finding (or providing) the (approximate) knns of each point, with respect to a particular view (column) and metric
  2. Constructing fuzzy simplicial sets in the usual way for each (view, metric) combination
  3. Combining (e.g. intersecting) the fuzzy set views into a single view for the final UMAP graph

2.1 KNN Search

If the input data is a table, the KNN search step needs to compute an (approximate) KNN graph for each (metric, columns) combination. Several special cases need to be handled: a single metric for all columns; a single precomputed distance matrix (or KNN graph); and a collection of metrics, each with its associated columns, together with a set of precomputed distance matrices.

  • NOTE: In the current API, a precomputed distance matrix replaces the input data. We could also let the user pass the data in alongside it, so that for the transform they could pass data instead of another precomputed distance matrix for that view. However, this seems like a niche case and adds a lot of complexity. For now it is simpler to assume that if the user passes a precomputed distance matrix for a particular view of the data, they can do the same for the transform.
# Top level to dispatch over multiple knn params
# NOTE: figure out where best to index data[cols]
function knn_search(data, knn_params::Sequence{KNNParams})
    return knn_search.((data,), knn_params)
end
# and for transform
function knn_search(data, queries, knn_params::Sequence{KNNParams})
    return knn_search.((data,), (queries,), knn_params)
end

# Basic KNN params where we have one metric
function knn_search(data, knn_params::KNNParams{<:PreMetric, ...}) end
function knn_search(data, queries, knn_params::KNNParams{<:PreMetric, ...}) end

# Basic KNN params where we have one precomputed dist matrix
function knn_search(data, knn_params::KNNParams{<:AbstractMatrix, ...}) end
function knn_search(data, queries, knn_params::KNNParams{<:AbstractMatrix, ...}) end
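
For the precomputed-matrix case, an exact KNN search reduces to a partial sort of each column. Below is a minimal sketch in plain Julia using Base's `partialsortperm`; `knn_from_dists` is a hypothetical helper name, and the sketch assumes column `j` of the matrix holds the distances from point `j` to every reference point. For `fit` the matrix is square and each point is excluded from its own neighbors; for `transform` the rows are reference points and the columns are queries, so nothing is excluded.

```julia
"""
    knn_from_dists(dists, k; exclude_self=true)

Exact k-nearest neighbors from a precomputed distance matrix (hypothetical helper).
Column `j` of `dists` holds the distances from point `j` to every reference point.
Returns `(knns, knn_dists)`, both `k × n_points`, sorted by increasing distance.
"""
function knn_from_dists(dists::AbstractMatrix{T}, k::Integer;
                        exclude_self::Bool=true) where {T<:Real}
    n_ref, n_pts = size(dists)
    n_cand = exclude_self ? k + 1 : k  # one extra candidate in case we drop the point itself
    n_cand <= n_ref || throw(ArgumentError("k is too large for $n_ref reference points"))
    knns = Matrix{Int}(undef, k, n_pts)
    knn_dists = Matrix{T}(undef, k, n_pts)
    for j in 1:n_pts
        col = view(dists, :, j)
        cand = partialsortperm(col, 1:n_cand)  # indices of the n_cand smallest distances
        if exclude_self
            cand = filter(!=(j), cand)[1:k]    # fit case: a point is not its own neighbor
        end
        knns[:, j] = cand
        knn_dists[:, j] = col[cand]
    end
    return knns, knn_dists
end
```

The fit case would call `knn_from_dists(D, n_neighbors)`, and the transform case `knn_from_dists(Q, n_neighbors; exclude_self=false)` with an `n_data × n_queries` matrix `Q`.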

Types

From these, it's clear that we need KNNParams with at least the following parameterizations:

# NNDescent case; cols can be nothing, a single column index, or a vector of column indices
struct KNNParams{M<:PreMetric,V,P}
    n_neighbors::Int
    metric::M
    cols::V # which columns fall under this metric
    params::P
end

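# precomputed distance matrix case; Julia cannot define two structs both named KNNParams,
# so real code needs either a distinct name here or a single parametric type
struct KNNParams{M<:AbstractMatrix,P}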
    n_neighbors::Int
    dists::M
    params::P
end
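
Julia does not allow two struct definitions with the same name, so the two parameterizations above cannot coexist as written. One option, sketched here under our own assumptions, is a single parametric KNNParams whose first type parameter selects the search strategy by dispatch. To keep the example self-contained, a plain callable stands in for Distances.jl's `PreMetric`, and `search_kind` is a hypothetical placeholder for the real `knn_search` methods.

```julia
# A single parametric type: M is either a metric or a precomputed distance matrix.
struct KNNParams{M, V, P}
    n_neighbors::Int
    metric::M   # a callable metric (stand-in for Distances.PreMetric) or an AbstractMatrix
    cols::V     # nothing, a single column index, or a vector of column indices
    params::P   # extra search parameters, e.g. for NNDescent
end

# Convenience constructor for precomputed distances: there are no columns to select.
KNNParams(n_neighbors::Int, dists::AbstractMatrix) =
    KNNParams(n_neighbors, dists, nothing, nothing)

# Dispatch on the first type parameter to pick a search strategy.
search_kind(::KNNParams{<:AbstractMatrix}) = :precomputed
search_kind(::KNNParams{<:Function}) = :metric
```

The real `knn_search(data, knn_params::KNNParams{<:AbstractMatrix})` and `knn_search(data, knn_params::KNNParams{<:PreMetric})` methods would follow the same pattern.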

2.2 Fuzzy Simplicial Sets of Views

For each KNN graph constructed in step 2.1, we construct a global fuzzy simplicial set for the respective view of the data. The function always returns a sparse matrix representing a weighted, undirected graph whose edge weights are the membership probabilities of those edges in the fuzzy simplicial set.

  • NOTE: When initially fitting a dataset, the adjacency matrix of this graph will be square and symmetric. When transforming new data according to a fitted UMAP result, the result will technically contain only a submatrix of the complete adjacency matrix (assuming a graph whose nodes are the union of the points in the original dataset and the points in the new data). Since we transform new points by finding their nearest neighbors in the original dataset, every edge in this submatrix has its source at a point in the new data and its destination at a point in the original data.
# Handle multiple views of the data, this should handle both singular and sequences of src_params 
function fuzzy_simplicial_set(knn_graphs::Sequence{KNNGraph}, knn_params::Sequence{KNNParams}, src_params, global_src_params)
    view_fuzzy_sets = fuzzy_simplicial_set.(knn_graphs, knn_params, src_params; ...)
    return merge_view_fuzzy_simplicial_sets(view_fuzzy_sets, global_src_params)
end

# handle a single view
function fuzzy_simplicial_set(knn_graph, knn_params, src_params; ...)
    # smooth_knn_dists ...
    # compute_membership_strengths ...
    # combine_fuzzy_sets ...
end
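
A single-view `fuzzy_simplicial_set` might look roughly like the following sketch, which follows the three commented steps in the usual UMAP way: binary-search a per-point bandwidth σ so that the memberships sum to log2(k) times `bandwidth`, compute memberships exp(-(d - ρ)/σ), then mix the fuzzy union and intersection of the directed graph with its transpose according to `set_operation_ratio`. It is simplified: `local_connectivity` is treated as an integer (no interpolation), σ is not clamped from below, and the names `smooth_knn_dists` and `fuzzy_simplicial_set_sketch` are hypothetical. It assumes `k × n` knns/dists matrices with each point's self excluded.

```julia
using SparseArrays

# Hypothetical helper: per-point ρ (distance to the local_connectivity-th nonzero
# neighbor) and σ (binary-searched bandwidth) for a k × n distance matrix.
function smooth_knn_dists(dists::AbstractMatrix{<:Real}, local_connectivity::Int=1,
                          bandwidth::Real=1.0; n_iters::Int=64, tol::Real=1e-5)
    k, n = size(dists)
    target = log2(k) * bandwidth
    ρs, σs = zeros(n), ones(n)
    for j in 1:n
        d = dists[:, j]
        nonzero = sort(filter(>(0), d))
        ρs[j] = isempty(nonzero) ? 0.0 : nonzero[min(local_connectivity, length(nonzero))]
        lo, hi, σ = 0.0, Inf, 1.0
        for _ in 1:n_iters
            psum = sum(exp(-max(x - ρs[j], 0.0) / σ) for x in d)
            abs(psum - target) < tol && break
            if psum > target      # σ too large: memberships sum above the target
                hi = σ; σ = (lo + hi) / 2
            else                  # σ too small: double it until bracketed, then bisect
                lo = σ; σ = isinf(hi) ? 2σ : (lo + hi) / 2
            end
        end
        σs[j] = σ
    end
    return ρs, σs
end

# Hypothetical single-view fuzzy simplicial set for the fit case (n × n result).
function fuzzy_simplicial_set_sketch(knns::AbstractMatrix{<:Integer}, dists::AbstractMatrix{<:Real};
                                     local_connectivity::Int=1, set_operation_ratio::Real=1.0,
                                     bandwidth::Real=1.0)
    k, n = size(knns)
    ρs, σs = smooth_knn_dists(dists, local_connectivity, bandwidth)
    rows, cols, vals = Int[], Int[], Float64[]
    for j in 1:n, i in 1:k
        push!(rows, knns[i, j]); push!(cols, j)
        push!(vals, exp(-max(dists[i, j] - ρs[j], 0.0) / σs[j]))  # membership of edge j → knns[i, j]
    end
    A = sparse(rows, cols, vals, n, n)
    At = sparse(transpose(A))
    P = A .* At  # fuzzy intersection (product t-norm)
    # mix the fuzzy union (A + Aᵀ - A∘Aᵀ) and the intersection by set_operation_ratio
    return set_operation_ratio .* (A .+ At .- P) .+ (1 - set_operation_ratio) .* P
end
```

Because each point's nearest neighbor sits at distance ρ, that edge always gets membership 1, which is what gives UMAP its local connectivity guarantee.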

Types

Then we need the following parameterizations of the src_params:

struct SourceParams # SourceViewParams
    local_connectivity
    set_operation_ratio
    bandwidth
end

The first SourceParams struct is per-view, so it is broadcast across all the views (either one instance for all views, or one for each view). However, some parameters control how the per-view fuzzy sets are combined into a single fuzzy set covering all views. We might instead rename the previous struct SourceViewParams and add the following struct for the global fuzzy simplicial set parameters:

struct SourceParams
    sets_mix_weight
    sets_operation_ratio # typically set to 100% intersection
    # weights for each view specifically?
end
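
A possible `merge_view_fuzzy_simplicial_sets`, assuming all views share the same point set (so their graphs have the same `n × n` shape): take the intersection with the product t-norm and the union with the probabilistic sum, then mix the two with `sets_operation_ratio` (1.0 meaning pure intersection). The choice of t-norm and the name `merge_view_fuzzy_sets_sketch` are our assumptions, not a settled design.

```julia
# Hypothetical merge of per-view fuzzy sets that share the same n × n shape.
# Intersection: product t-norm; union: probabilistic sum (a + b - a*b).
function merge_view_fuzzy_sets_sketch(views::AbstractVector{<:AbstractMatrix},
                                      sets_operation_ratio::Real=1.0)
    inter = reduce((A, B) -> A .* B, views)
    uni = reduce((A, B) -> A .+ B .- A .* B, views)
    # sets_operation_ratio = 1.0 gives pure intersection, 0.0 pure union
    return sets_operation_ratio .* inter .+ (1 - sets_operation_ratio) .* uni
end
```

Edges absent from one view have membership 0 there, so a pure intersection drops them entirely; if that proves too aggressive, missing edges could be treated more leniently.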

What we're left with after the view fuzzy sets are merged is a single, global UMAP graph. This encodes the topological information of the input data in the source space. Our next step is to initialize a new representation of the data in the target space and perform optimization.

3. Optimize the target embedding

We need a way to initialize the target embedding and then perform optimization.

3.1 Initialize target embedding

We need a way to create a new point on the target manifold. In the simplest case, we choose this manifold to be R^d and can easily generate vectors (e.g. randomly, spectral embedding, weighted averaging of neighbors in the case of transforming new data). It should be possible to embed into many different kinds of manifolds provided we have a differentiable membership function for the target embedding that we can optimize via gradient descent.

The function initialize_embedding returns a vector of points in the target space.

# fit
function initialize_embedding(umap_graph, init_params) end
# transform
function initialize_embedding(umap_graph, ref_embedding, init_params) end
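
Two of the initialization strategies mentioned above, sketched for the R^d case with points stored as a vector of vectors. For `fit`, points are drawn uniformly from [-10, 10]^d (a simple stand-in for spectral initialization); for `transform`, each new point is the membership-weighted average of the reference embedding. The sketch assumes the transform graph has one row per reference point and one column per new point; the function name and the `n_components` argument (used instead of `init_params`) are hypothetical.

```julia
using Random

# fit: uniform random points in [-10, 10)^n_components, one per column of the graph
function initialize_embedding_sketch(umap_graph::AbstractMatrix, n_components::Int;
                                     rng=Random.default_rng())
    n = size(umap_graph, 2)
    return [20 .* rand(rng, n_components) .- 10 for _ in 1:n]
end

# transform: each new point (column j) is the membership-weighted average of the
# reference points it is connected to
function initialize_embedding_sketch(umap_graph::AbstractMatrix,
                                     ref_embedding::AbstractVector{<:AbstractVector})
    map(axes(umap_graph, 2)) do j
        w = view(umap_graph, :, j)
        sum(w[i] .* ref_embedding[i] for i in eachindex(w)) ./ sum(w)
    end
end
```

A new point with no edges to the reference data would have a zero weight sum (and NaN coordinates), so real code needs a fallback for that case.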

Types

The initialization depends on the target manifold we're in. Essentially, these are whatever data and functions we need to generate a point or points in the target space.

"""
This initializes points in Euclidean space of a specified dimension according to the given method.
"""
struct EuclideanInitialization{I}
    method::I
    n_components::Int # this might be extracted out into the encompassing TargetParams
end

There could be other ways to parameterize the initialization that are not provided by default (or perhaps only one is provided, as an example).

3.2 Initialize the target membership function and derivative

As previously stated, we need a differentiable membership function for the simplices in the target space. This is a binary function that takes two points (more generically, it could take any l-simplex, but we only care about 1-simplices for the moment) in the target space and returns a value between 0 and 1.

By default, this membership function is a smooth curve from a two-parameter family (parameters a and b), fit so that it approximates an offset exponential determined by min_dist and spread.
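
For reference, in umap-learn the default curve is w(x, y) below, and a and b are fit by nonlinear least squares so that, over 300 evenly spaced distances in [0, 3·spread], it matches the offset exponential Ψ determined by min_dist and spread:

```latex
w(x, y) = \frac{1}{1 + a \, \lVert x - y \rVert_2^{2b}},
\qquad
\Psi(d) =
\begin{cases}
1, & d < \mathrm{min\_dist}, \\
\exp\!\left( -\dfrac{d - \mathrm{min\_dist}}{\mathrm{spread}} \right), & d \ge \mathrm{min\_dist},
\end{cases}

(a, b) = \arg\min_{a,\, b} \sum_{i=1}^{300} \left( \frac{1}{1 + a\, d_i^{2b}} - \Psi(d_i) \right)^2,
\qquad d_i \text{ evenly spaced in } [0,\ 3\,\mathrm{spread}].
```

For example, min_dist = 0.1 and spread = 1.0 give a ≈ 1.58 and b ≈ 0.90.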

"""
Returns a function `membership_fn(distance, x, y)` that takes as input a distance function
and two points in the target space.
"""
function fit_membership_fn(tgt_params)
    # ... returns membership_fn
end

It should also be possible to provide a custom membership function as long as it meets the requirements and has the same signature.
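
A sketch of `fit_membership_fn` under simplifying assumptions: it fits a and b by brute-force grid search over arbitrary ranges (in place of the nonlinear least-squares solver umap-learn uses, SciPy's `curve_fit`), takes `min_dist` and `spread` directly rather than a `tgt_params` struct, and also returns the fitted values. The function name and grid ranges are hypothetical.

```julia
"""
Hypothetical `fit_membership_fn` sketch. Fits `a`, `b` of 1 / (1 + a*d^(2b)) to the
offset exponential defined by `min_dist` and `spread` by brute-force grid search, and
returns `(membership_fn, a, b)` where `membership_fn(distance, x, y)` takes a distance
function and two points.
"""
function fit_membership_fn_sketch(min_dist::Real, spread::Real;
                                  a_grid=0.1:0.01:3.0, b_grid=0.5:0.005:1.5)
    ds = range(0, 3 * spread; length=300)
    # target curve: 1 below min_dist, exponential decay with scale `spread` above it
    ψ = [d < min_dist ? 1.0 : exp(-(d - min_dist) / spread) for d in ds]
    curve(a, b, d) = 1 / (1 + a * d^(2b))
    best_err, a_fit, b_fit = Inf, first(a_grid), first(b_grid)
    for a in a_grid, b in b_grid
        err = sum(abs2(curve(a, b, d) - y) for (d, y) in zip(ds, ψ))
        if err < best_err
            best_err, a_fit, b_fit = err, a, b
        end
    end
    a_opt, b_opt = a_fit, b_fit  # fixed copies captured by the closure below
    membership_fn(distance, x, y) = curve(a_opt, b_opt, distance(x, y))
    return membership_fn, a_opt, b_opt
end
```

The grid only bounds the fit to its step size; a proper solver would be both faster and more precise, but the returned closure has the same shape either way.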

In either case, we can use automatic differentiation (probably via Zygote.jl) to get functions returning the gradient of the membership function, although a custom adjoint will likely be written for the default. For efficiency, the loss function is simplified: we drop any constants and decompose it into positive and negative terms:

# we sample positive edges according to their weights in the UMAP graph
pos_grad(x, y) = gradient((_x, _y) -> log(membership_fn(distance, _x, _y)), x, y)
# we sample some number of negative edges for each positive edge sampled
neg_grad(x, y) = gradient((_x, _y) -> log(1 - membership_fn(distance, _x, _y)), x, y)
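
For the default curve these gradients have a simple closed form, which is what a custom adjoint would compute. Writing s = ‖x − y‖², the gradient of log w with respect to x is −2ab·s^(b−1)/(1 + a·s^b)·(x − y), the gradient of log(1 − w) is 2b/(s·(1 + a·s^b))·(x − y), and the gradients with respect to y are the negatives. A sketch in plain Julia (function names hypothetical); the small `eps_dist` added to s in the repulsive term keeps it finite as points coincide, as umap-learn does, and `eps_dist = 0.0` gives the exact gradient.

```julia
# Hypothetical closed-form gradients for w(x, y) = 1 / (1 + a * s^b), s = ‖x - y‖².
# Each returns (gradient w.r.t. x, gradient w.r.t. y).

"Gradient of log w(x, y) (attractive term for a sampled positive edge)."
function pos_grad_default(x, y, a, b)
    diff = x .- y
    s = sum(abs2, diff)
    # at s == 0 the direction x - y vanishes; return zeros to avoid 0 * Inf when b < 1
    s == 0 && return zero(diff), zero(diff)
    coeff = -2a * b * s^(b - 1) / (1 + a * s^b)
    return coeff .* diff, -coeff .* diff
end

"Gradient of log(1 - w(x, y)) (repulsive term for a negative sample)."
function neg_grad_default(x, y, a, b; eps_dist=1e-3)
    diff = x .- y
    s = sum(abs2, diff)
    # eps_dist keeps the coefficient finite as s → 0; eps_dist = 0 gives the exact gradient
    coeff = 2b / ((eps_dist + s) * (1 + a * s^b))
    return coeff .* diff, -coeff .* diff
end
```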

Types

# fit a and b
struct TargetParams{I, M}
    init::I # e.g. DefaultInitialization
    metric::M
    min_dist
    spread
end
# provide a and b directly (needs its own name, or a shared parametric type, in real code)
struct TargetParams{I, M}
    init::I
    metric::M
    a
    b
end

3.3 Optimize the embedding

The target embedding is optimized via stochastic gradient descent. This procedure is relatively straightforward at the moment, so there is not much polymorphism needed here.

function optimize_embedding!(embedding, ref_embedding, umap_graph, tgt_params, opt_params; move_ref)
    # NOTE: try to compile gradient of target metric here?
    # NOTE: add epochs_per_sample calculation here? This seems useful (in particular, dividing the
    #       edge weights by the maximum weight so we don't waste any epochs)
    lr = opt_params.learning_rate
    for epoch in 1:opt_params.n_epochs
        _optimize_embedding!(embedding, ref_embedding, umap_graph, tgt_params, lr, opt_params; move_ref=move_ref)
        lr = ... # e.g. linear decay, learning_rate * (1 - epoch / n_epochs), as in umap-learn
    end
    return embedding
end
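
One epoch of `_optimize_embedding!` for the `fit` case could look like the following sketch, using the closed-form coefficients for the default curve: each sampled positive edge pulls its endpoints together, and `neg_sample_rate` uniformly sampled points push the head point away, with per-coordinate gradient clipping at ±4 as in umap-learn. Assumptions: the edges for this epoch are already sampled, the embedding is a `Vector{Vector{Float64}}`, both endpoints of a positive edge move (the `move_ref = true` case), and the name `sgd_epoch_sketch!` is hypothetical.

```julia
using Random

# Hypothetical single SGD epoch (fit case, move_ref = true) for the default curve
# w = 1 / (1 + a * s^b), s = ‖x - y‖². `edges` are (i, j) pairs already sampled
# for this epoch; `embedding` is updated in place.
function sgd_epoch_sketch!(embedding::Vector{Vector{Float64}}, edges, a, b, lr;
                           neg_sample_rate::Int=5, repulsion_strength=1.0,
                           clip=4.0, rng=Random.default_rng())
    n = length(embedding)
    for (i, j) in edges
        x, y = embedding[i], embedding[j]  # references: in-place updates mutate embedding
        s = sum(abs2, x .- y)
        if s > 0
            # attractive step: ascend log w(x, y), moving both endpoints
            coeff = -2a * b * s^(b - 1) / (1 + a * s^b)
            g = clamp.(coeff .* (x .- y), -clip, clip)
            x .+= lr .* g
            y .-= lr .* g
        end
        for _ in 1:neg_sample_rate
            k = rand(rng, 1:n)
            k == i && continue
            z = embedding[k]
            s = sum(abs2, x .- z)
            # repulsive step: ascend log(1 - w(x, z)), moving only x
            coeff = 2 * repulsion_strength * b / ((1e-3 + s) * (1 + a * s^b))
            x .+= lr .* clamp.(coeff .* (x .- z), -clip, clip)
        end
    end
    return embedding
end
```

In the transform case, only the new points would move and negative samples would be drawn from the reference embedding.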

Types

struct OptimizationParams
    n_epochs
    learning_rate
    repulsion_strength
    neg_sample_rate
end
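
The epochs_per_sample idea from the NOTE in `optimize_embedding!` can be sketched as follows: dividing the edge weights by the maximum weight means the strongest edge is sampled every epoch and weaker edges proportionally less often, so no epoch passes without updates. Zero-weight edges get -1 and are never sampled. Both function names are hypothetical; the schedule mirrors umap-learn's `make_epochs_per_sample`.

```julia
# Hypothetical helper: epochs between samples of each edge. The strongest edge is
# sampled every epoch, an edge with half its weight every 2 epochs; weight 0 → -1 (never).
function epochs_per_sample_sketch(weights::AbstractVector{<:Real}, n_epochs::Integer)
    n_samples = n_epochs .* (weights ./ maximum(weights))
    return [ns > 0 ? n_epochs / ns : -1.0 for ns in n_samples]
end

# How many times each edge would be processed over n_epochs under that schedule.
function edge_sample_counts(weights::AbstractVector{<:Real}, n_epochs::Integer)
    per = epochs_per_sample_sketch(weights, n_epochs)
    next_sample = copy(per)
    counts = zeros(Int, length(weights))
    for epoch in 1:n_epochs, k in eachindex(weights)
        if per[k] > 0 && next_sample[k] <= epoch
            counts[k] += 1            # the real loop would update edge k here
            next_sample[k] += per[k]
        end
    end
    return counts
end
```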