Skip to content

Commit

Permalink
#88: Can't save, load then re-save a model
Browse files Browse the repository at this point in the history
  • Loading branch information
jlmelville committed Dec 9, 2021
1 parent ef41fbb commit bcb0df5
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 3 deletions.
8 changes: 8 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
# uwot 0.1.11.9000

## Bug fixes and minor improvements

* Models couldn't be re-saved after loading. Thank you to
[ilyakorsunsky](https://github.com/ilyakorsunsky) for reporting this
(<https://github.com/jlmelville/uwot/issues/88>).

# uwot 0.1.11

## New features
Expand Down
8 changes: 5 additions & 3 deletions R/uwot.R
Original file line number Diff line number Diff line change
Expand Up @@ -1762,7 +1762,7 @@ uwot <- function(X, n_neighbors = 15, n_components = 2, metric = "euclidean",
init <- match.arg(tolower(init), c(
"spectral", "random", "lvrandom", "normlaplacian",
"laplacian", "spca", "pca", "inormlaplacian", "ispectral",
"agspectral"
"agspectral", "normlaplaciantsvd"
))

if (init_is_spectral(init)) {
Expand Down Expand Up @@ -1816,6 +1816,7 @@ uwot <- function(X, n_neighbors = 15, n_components = 2, metric = "euclidean",
n_neg_nbrs = negative_sample_rate,
ndim = n_components, verbose = verbose
),
normlaplaciantsvd = normalized_laplacian_init_tsvd(V, ndim = n_components, verbose = verbose),
stop("Unknown initialization method: '", init, "'")
)
}
Expand Down Expand Up @@ -2213,11 +2214,12 @@ load_uwot <- function(file, verbose = FALSE) {
}
ann <- create_ann(annoy_metric, ndim = ndim)
ann$load(nn_fname)

if (n_metrics == 1) {
model$nn_index <- ann
model$nn_index <- list(ann = ann, type = "annoyv1", metric = annoy_metric)
}
else {
model$nn_index[[i]] <- ann
model$nn_index[[i]] <- list(ann = ann, type = "annoyv1", metric = annoy_metric)
}
}
model$mod_dir <- mod_dir
Expand Down
23 changes: 23 additions & 0 deletions tests/testthat/test_saveload.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,26 @@ test_that("unloading a model on save", {
unload_uwot(modelload, cleanup = TRUE)
expect_false(file.exists(modelload$mod_dir))
})

# #88
test_that("save-load-save", {
set.seed(1337)
X <- matrix(rnorm(100), 10, 10)

model <- uwot::umap(X, n_neighbors = 4, ret_model = TRUE)
model_file <- tempfile(tmpdir = tempdir())
model <- uwot::save_uwot(model, file = model_file)
model2 <- uwot::load_uwot(file = model_file)
new_file <- tempfile(tmpdir = tempdir())
uwot::save_uwot(model2, file = new_file)
expect_true(file.exists(new_file))

modelm <- uwot::umap(X, n_neighbors = 4, metric = list("euclidean" = 1:5, "euclidean" = 6:10), ret_model = TRUE)
modelm_file <- tempfile(tmpdir = tempdir())
modelm <- uwot::save_uwot(modelm, file = modelm_file)
modelm2 <- uwot::load_uwot(file = modelm_file)
new_filem <- tempfile(tmpdir = tempdir())
uwot::save_uwot(modelm2, file = new_filem)
expect_true(file.exists(new_filem))

})

0 comments on commit bcb0df5

Please sign in to comment.