add transform data functionality #26
Conversation
…ct to a reference embedding
- add new initialize_embedding method
- add new optimize_embedding method and refactor old method to use new one
- add corresponding tests

- exposes initialization and optimization w.r.t. reference embedding
- add corresponding unit test
I also considered a different approach to the high level API, which was letting the …
Codecov Report

```
@@            Coverage Diff            @@
##           master      #26     +/-  ##
=========================================
+ Coverage   89.36%   90.12%   +0.76%
=========================================
  Files           4        4
  Lines         141      162      +21
=========================================
+ Hits          126      146      +20
- Misses         15       16       +1
```

Continue to review full report at Codecov.
Thanks for the PR, this looks great! I will give it a review in the next couple of days.
fixed docstring typos
I took an initial look through and this seems mostly on the right track. I need to spend more time with this, so I'll have a more thorough review forthcoming.

There is one potential issue here when optimizing a new embedding with respect to the reference. Since the knn graph is constructed from the union of the reference and query data, it leaves open the possibility that some query point has no neighbors (either incoming or outgoing) in the reference data. During optimization, this means that its embedding will never be updated (the condition at line 155 in embeddings.jl will always be false). I wonder if you've thought about this and how the Python implementation handles it, if it does.
Ah you're right, I didn't consider this. Looks like the Python implementation may actually build the graph by only searching for nearest neighbors in the reference data. (Searching for neighbors here, which uses this, which computes distances between original (reference) data and query data here.) I'll look into this more and update the PR within a few days.
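To illustrate the fix being discussed, here is a minimal sketch of restricting the k-nearest-neighbor search for query points to the reference data only, so every query point is guaranteed k reference neighbors and will always receive updates during optimization. (Python/numpy is used for illustration only; the package itself is Julia, and the function name here is hypothetical.)

```python
import numpy as np

def knn_query_to_ref(ref, query, k):
    """Brute-force kNN: for each query point (a column), find its k nearest
    neighbors among the reference points (columns) only.

    Returns (indices, distances), each of shape (k, n_query)."""
    # squared Euclidean distances between every ref and query column
    d2 = ((ref[:, :, None] - query[:, None, :]) ** 2).sum(axis=0)  # (n_ref, n_query)
    idx = np.argsort(d2, axis=0)[:k, :]  # k nearest reference indices per query
    dist = np.sqrt(np.take_along_axis(d2, idx, axis=0))
    return idx, dist

# toy data: 3 reference points and 2 query points in 2-D
ref = np.array([[0.0, 1.0, 10.0],
                [0.0, 0.0,  0.0]])
query = np.array([[0.2, 9.0],
                  [0.0, 0.0]])
idx, dist = knn_query_to_ref(ref, query, k=2)
# every query point has k neighbors in the reference set by construction
```

The contrast is with building one graph over the union of reference and query data, where a tight cluster of query points far from the reference could end up with no reference neighbors at all.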
As I am looking through the Python implementation and the codebase here and in NearestNeighborDescent.jl, I'm having a bit of a change of heart about the API. I'm now thinking that what may make the most sense is this use case:

```julia
# easy option: quick data --> embedding (the current api)
embedding = umap(X, n_components, kwargs...)

# and more exposed option, allowing transforms of new data into preexisting embedded space:
model = UMAP_(X, n_components, kwargs...)
embedding = model.embedding
test_embedding = umap_transform(Xtest, model, fewer_kwargs...)
```

The upsides are:
The only drawback of this from the user POV is that you can't transform new data w.r.t. a dataset and embedding without its neighbor graph, e.g. if they were obtained externally, but I feel like this won't be a use case for most people, and even if so, recomputing the neighbor graph of the data won't be prohibitive in most cases. The main changes then are just exposing the neighbor graph made by the …

I will update the PR with this version soon but let me know if you have any objections to this approach! :)
I agree that we'll need to store the state from the initial fit procedure in order to use it for the transform. This will require some changes to the …

We'll also need to add the initial data used as a field so we can calculate the distances during the transform. Then there will need to be a new method of …

So I'm alright with a bit of refactoring as needed, especially in the …
- add umap_transform function
- refactor fuzzy_simplicial_set to search for nearest neighbors for query data
- add knn_search method to reconstruct ApproximateKNNGraph
I've gone through and implemented the things you suggested, and changed the transform function (now called …).

One thing I discovered, that may need to factor into this PR or maybe a next PR: it seems like the python implementation has been optimized a lot recently. I ran the notebook in the repo, and according to the README, python umap was a bit slower than this package. Now, this package has similar timing to the README (~150 sec) while the python umap runs in only ~50 seconds on my laptop. It seems like one of the major areas of optimization was their improvement of the nearest neighbor search, which can be looked at later.

The side effect, though, is that I was comparing the python functionality for transforming new data against this PR, and got the following results on mnist test data, after fitting to training data:

So using the hyperparameters, etc. in this PR, by an eye check it seems like the Python transform is more accurate. It also runs in about 10 seconds while the julia function takes ~25 seconds, so I was not sure how to compare results while keeping a similar level of nearest neighbor approximation complexity, since the python has a faster baseline as of now. Increasing the number of neighbors used in the ApproximateKNNGraph made the julia transform plot more closely resemble the python plot, while taking a nonnegligible amount of time longer.
Hi - thanks for the update, I'll take another look over the weekend! You're right about pynndescent being faster now than when I created the README; my impression was that it was mostly due to parallelization. I looked into adding the same to NNDescent.jl but I remember I wasn't satisfied with the approach they took.
At a high level, this looks pretty good. I've made comments for things I think should be changed, and I'm willing to discuss any of them in more detail.
With regards to the type instability I pointed out for `transform`, if you've got an idea for how to fix it, go ahead and implement it (I recommend adding a test using `Test.@inferred`). Otherwise we can leave it for when I do a more thorough refactoring of `UMAP_` and the type annotations more generally.

I would like one thing: could you do some benchmarking comparing this PR to master for just the `umap` procedure? Just to make sure there are no performance regressions.
src/umap_.jl
Outdated
```julia
function UMAP_{S, M, N}(graph, embedding) where {S<:Real,
                                                 M<:AbstractMatrix{S},
                                                 N<:AbstractMatrix{S}}
    issymmetric(graph) || isapprox(graph, graph') || error("UMAP_ constructor expected graph to be a symmetric matrix")
    new(graph, embedding)
    UMAP_{S, M, N}(graph, embedding, Matrix(undef, 0, 0), Matrix{Int}(undef, 0, 0), Matrix{S}(undef, 0, 0))
```
I don't think it makes sense to allow constructing a `UMAP_` struct without all of the required fields, so we don't need these constructors. In other words, I think there should be two general cases: inner constructor(s) that require all the fields needed to represent a "complete" struct, and the outer constructor(s) which build the struct by running the algorithm.
One thing I was trying to do was leave the old constructors, since `UMAP_` had been exported before, so these changes wouldn't break old code. But since the struct didn't have much of a use before, this may not be a big concern.

The other thing I was considering was that I have had use cases where I have data and its corresponding embedding, although the embedding was not computed from the same package. I then wanted to use some `transform` sort of functionality, and recomputing the graph of the old data would be redundant. Constructing a partially defined struct would let a user wrap data and its embedding manually, and use these for transforming new data.

(As is, this wouldn't save a whole lot of computation since the current knn_search would recompute the ApproximateKNNGraph for large matrices anyways, but this would be an internal detail that could change later.) But let me know your thoughts on this!
Maybe you could give me an example of the workflow when you'd need a partial struct? I'm not sure I understand the use case entirely.
I wouldn't worry too much about "breaking" API changes - since this package is pre-1.0, any minor version bump technically permits breaking changes. There's still a ways to go before a 1.0 release is ready and the API could be made stable.
Ok, I won't worry about the API changes 👍

The use case was that a third party published data and its embedding, and I had some related data of interest that I wanted to embed along with the published data/embedding. Reconsidering this (and since we are changing the struct API anyways), I think a separate `transform` method would be much better than the partial struct, e.g. one with some signature like `transform(X, embedding, Q; <kwargs>)`.

But this use case is much more unusual than the typical workflow, so it's definitely not an important addition; if you don't think it's important then I'll leave it out.
src/utils.jl
Outdated
```julia
                    metric,
                    ::Val{:pairwise}) where {S <: Real}
    num_points = size(X, 2)
    dist_mat = Array{S}(undef, num_points, num_points)
    if S <: AbstractFloat
```
What use-case is this additional logic needed for?
src/embeddings.jl
Outdated
```julia
Initialize an embedding of points corresponding to the columns of the `graph`, by taking weighted means of
the columns of `ref_embedding`, where weights are values from the rows of the `graph`.

The resulting embedding will have shape `(size(ref_embedding, 1), length(query_inds))`, where `size(ref_embedding, 1)`
```
Should replace `length(query_inds)` with `size(graph, 2)`, since query_inds isn't an argument that's passed to this. Also note that the returned value of this function isn't a matrix, but a vector of vectors (intentionally).
ah mb, forgot to update this documentation
src/embeddings.jl
Outdated
```julia
samples in the resulting embedding. Its elements will have type T.
"""
function initialize_embedding(graph::AbstractMatrix{<:Real}, ref_embedding::AbstractMatrix{T}) where {T<:AbstractFloat}
    embed = [zeros(T, size(ref_embedding, 1)) for _ in 1:size(graph, 2)]
```
You never use these zero vectors, so this is a bit unnecessary. To create an empty array with the correct element type, you could do something like `embed = Vector{T}[]`.
src/embeddings.jl
Outdated
```julia
"""
function initialize_embedding(graph::AbstractMatrix{<:Real}, ref_embedding::AbstractMatrix{T}) where {T<:AbstractFloat}
    embed = [zeros(T, size(ref_embedding, 1)) for _ in 1:size(graph, 2)]
    for (i, col) in enumerate(eachcol(graph))
```
You can simplify this expression, since this is just matrix multiplication: `embed = collect(eachcol((ref_embedding * graph) ./ (sum(graph, dims=1) .+ eps(T))))`. Double check it, but it should be equivalent. Also, you'll want `eps(T)` over `eps(zero(T))`.
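As a quick sanity check of the suggested simplification, here is a small numpy sketch (Python used purely for illustration; the package itself is Julia) showing that the column-by-column weighted mean and the single matrix multiplication agree:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_ref, n_query = 2, 5, 3
ref_embedding = rng.normal(size=(d, n_ref))  # columns are embedded reference points
graph = rng.random(size=(n_ref, n_query))    # column j holds query j's membership weights
eps = np.finfo(ref_embedding.dtype).eps

# loop form: each query column is the weighted mean of the reference columns
loop = np.zeros((d, n_query))
for j in range(n_query):
    w = graph[:, j]
    loop[:, j] = (ref_embedding * w).sum(axis=1) / (w.sum() + eps)

# matrix form from the review: one multiplication handles all columns at once
matmul = (ref_embedding @ graph) / (graph.sum(axis=0) + eps)
```

The `eps` in the denominator guards against a query column whose weights sum to zero, mirroring the `eps(T)` in the Julia expression.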
good catch
src/embeddings.jl
Outdated
```julia
- `initial_alpha`: the initial learning rate
- `gamma`: the repulsive strength of negative samples
- `neg_sample_rate::Integer`: the number of negative samples per positive sample
"""
function optimize_embedding(graph,
                            embedding,
                            n_epochs,
                            embedding::AbstractVector{<:AbstractVector{<:AbstractFloat}},
```
I'm not sure we need these type constraints. I actually think this package is overly type constrained (that's mostly my doing), when it could be much more generic.
So one issue was having the `optimize_embedding` methods with both one embedding argument and with two (for query and ref); the call would not be properly dispatched because of the default arguments to `_a` and `_b`, and I was resolving this with type constraints. But I can remove the default arguments or remove the overloaded methods, if you have a preference towards either of those?
If you have no preference, I will just remove the original `optimize_embedding` method so all calls will have to pass two embeddings (or the same one twice); I'm starting to think the overloading is more unnecessary/messy than good.
- add parameterization to UMAP_ struct
- remove partial UMAP_ struct constructors
- clean knn_search code
- add type safety to initialize_embedding
- switch model and Q in transform API
I believe I addressed all of your comments and made the changes. The major ones were: …

Here are the relative benchmarks (the data is the typical mnist data, (28*28, 60k)) for …

Please let me know if I missed anything 😃
Thanks for the quick work! I'll aim to get a final review by the weekend, and hopefully we can get this merged soon after that.
With the exception of the small change to the README, this looks good to me!
It'll be good to create a new release with this functionality... There are some improvements and QoL changes this package really needs that I'd like to implement beforehand, though.
cool, if you'd like to open some issues I can help out!
Co-authored-by: Sanjay Mohan <>
This adds functionality to transform data with UMAP with respect to previously fit data.
See the README additions for usage overview.
First, a skeleton is built for the manifold containing both the old fit data and the new data to transform. This yields a graph whose first R rows/columns correspond to the old fit data (for which an embedding has already been obtained), and the remaining rows/columns correspond to the new data to be transformed. This uses the preexisting `fuzzy_simplicial_set` function as-is.

Next, an embedding for the data to transform is initialized as a weighted mean of points from the reference embedding. This is the new method overload for the `initialize_embedding` function. For comparison, here is the python implementation (the weights are obtained from the graph entries, which you can see here).

Finally, the `optimize_embedding` function operates as before, but only computes attractions/repulsions between the new embedding and the reference embedding. For the python version, see here. They use `head` to refer to the new embedding, which I designated with `query_inds`, and `tail` to refer to the reference embedding, which I designated with `ref_inds`.

When using this via a high level `umap` call, the reference (fit) data is assumed to be the first R samples of the input data. However, the `optimize_embedding` function can refer to arbitrary indices of the input graph / input embedding as the reference data, and arbitrary indices as the query data (data to be optimized/transformed). These indices don't necessarily have to cover all indices of the graph/embedding (the remaining entries will be ignored entirely). They can also overlap (which means that data may be optimized by computing attractions/repulsions to other data that is also intended to be optimized).

I also added a few basic tests to cover these new features.
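To make the "first R rows/columns are the reference" layout concrete, here is a hypothetical numpy sketch (the package is Julia, and all names here are illustrative, not the PR's actual code) of slicing a combined reference-plus-query graph and initializing the query embedding from the reference one:

```python
import numpy as np

R, Q = 4, 2      # counts of reference (fit) and query (transform) points
n = R + Q
rng = np.random.default_rng(1)

# stand-in for the fuzzy simplicial set over reference + query data:
# a symmetric (n, n) weight matrix whose first R rows/columns are the
# reference points (in the PR this comes from fuzzy_simplicial_set)
A = rng.random((n, n))
graph = (A + A.T) / 2

# 1-D embedding of the R reference points (already computed by the fit)
ref_embedding = rng.normal(size=(1, R))

# weights of each query point w.r.t. the reference points only:
# reference rows 0..R-1, query columns R..n-1
w = graph[:R, R:]                               # shape (R, Q)

# initialize each query point as the weighted mean of reference points;
# optimization then refines these while holding ref_embedding fixed
eps = np.finfo(float).eps
query_embedding = (ref_embedding @ w) / (w.sum(axis=0) + eps)
```

Since each query point starts as a convex combination of reference points, its initial position always lies within the span of the reference embedding.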
Let me know what you think about this addition, or if there are other parts I can clarify (either in this PR or documentation)!