Skip to content

Commit

Permalink
add save/load test
Browse files Browse the repository at this point in the history
  • Loading branch information
jlmelville committed Mar 10, 2024
1 parent 144c878 commit 62862e1
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 2 deletions.
17 changes: 15 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ iris_nn <- hnsw_search(irism[101:150, ], ann, k = 5)

## Class Example

As noted in the "Do not use named parameters" section below, you should avoid
using named parameters when calling class methods. However, I do use them in a
few places below to document the names of the parameters that the positional
arguments correspond to.

```R
library(RcppHNSW)
data <- as.matrix(iris[, -5])
Expand Down Expand Up @@ -153,7 +158,7 @@ res <- ann$getNNsList(data[1, ], k = 4, include_distances = TRUE)
ann2 <- new(HnswL2, dim, nitems, M, ef)
ann2$addItems(data)
# Retrieve the 4 nearest neighbors for every item in data
res2 <- ann2$getAllNNsList(data, k = 4, include_distances = TRUE)
res2 <- ann2$getAllNNsList(data, 4, TRUE)
# labels of the data are in res2$item, distances in res2$distance

# If you are able to store your data column-wise, then the overhead of copying
Expand All @@ -162,10 +167,18 @@ data_by_col <- t(data)
ann3 <- new(HnswL2, dim, nitems, M, ef)
ann3$addItemsCol(data_by_col)
# Retrieve the 4 nearest neighbors for every item in data_by_col
res3 <- ann3$getAllNNsListCol(data_by_col, k = 4, include_distances = TRUE)
res3 <- ann3$getAllNNsListCol(data_by_col, 4, TRUE)
# The nearest neighbor data matrices are also returned column-wise
all(res2$item == t(res3$item) & res2$distance == t(res3$distance))

# Save the index
ann$save("iris.hnsw")

# load it back in: you do need to know the dimension of the original data
ann4 <- new(HnswL2, dim, "iris.hnsw")
# new index should behave like the original
all(ann$getNNs(data[1, ], 4) == ann4$getNNs(data[1, ], 4))

# other distance classes:
# Cosine: HnswCosine
# Inner Product: HnswIP
Expand Down
47 changes: 47 additions & 0 deletions tests/testthat/test_save_load.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
library(RcppHNSW)
context("Save/load index")

# Query `index` for the `k` nearest neighbors of every row of `X`, one row at
# a time.
#
# Returns a list with two matrices, each with one row per row of `X` and `k`
# columns: `idx` (the `item` component of each query result) and `dist` (the
# `distance` component).
#
# Method arguments are passed by position, as the package README recommends
# for class methods.
nn_by_row <- function(index, X, k = 4) {
  n <- nrow(X)
  idx <- matrix(0L, nrow = n, ncol = k)
  dist <- matrix(0.0, nrow = n, ncol = k)
  # seq_len() rather than 1:n so that an empty `X` yields zero iterations
  for (i in seq_len(n)) {
    res <- index$getNNsList(X[i, ], k, TRUE)
    idx[i, ] <- res$item
    dist[i, ] <- res$distance
  }
  list(idx = idx, dist = dist)
}

test_that("neighbors are unchanged by saving and by reloading the index", {
  num_elements <- nrow(uirism)
  dim <- ncol(uirism)

  M <- 16
  ef_construction <- 10
  p <- new(HnswL2, dim, num_elements, M, ef_construction)
  for (i in seq_len(num_elements)) {
    p$addItem(uirism[i, ])
  }

  # Reference results from the freshly built index
  nn_before <- nn_by_row(p, uirism)

  temp_file <- tempfile()
  # Registered inside the test block so that cleanup is tied to this test
  # finishing rather than to how the test file happens to be evaluated.
  on.exit(unlink(temp_file), add = TRUE)
  p$save(temp_file)

  # Saving must not disturb the in-memory index
  nn_aftersave <- nn_by_row(p, uirism)
  expect_equal(nn_before$idx, nn_aftersave$idx)
  expect_equal(nn_before$dist, nn_aftersave$dist)

  # An index constructed from the saved file must give the same results
  pload <- new(HnswL2, dim, temp_file)
  nn_afterload <- nn_by_row(pload, uirism)
  expect_equal(nn_before$idx, nn_afterload$idx)
  expect_equal(nn_before$dist, nn_afterload$dist)
})

0 comments on commit 62862e1

Please sign in to comment.