diff --git a/DESCRIPTION b/DESCRIPTION index e695d61..2756590 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,23 +1,30 @@ Package: RcppHNSW -Title: 'Rcpp' Bindings for 'hnswlib', a Library for Approximate Nearest Neighbors -Version: 0.5.9000 -Authors@R: c(person("James", "Melville", email = "jlmelville@gmail.com", - role = c("aut", "cre")), - person("Aaron", "Lun", role = "ctb"), - person("Samuel", "Granjeaud", role = "ctb"), - person("Dmitriy", "Selivanov", role = "ctb"), - person("Yuxing", "Liao", role = "ctb")) -Description: 'Hnswlib' is a C++ library for Approximate Nearest Neighbors. This - package provides a minimal R interface by relying on the 'Rcpp' package. See - <https://github.com/nmslib/hnswlib> for more on 'hnswlib'. 'hnswlib' is - released under Version 2.0 of the Apache License. +Title: 'Rcpp' Bindings for 'hnswlib', a Library for Approximate Nearest + Neighbors +Version: 0.6.0 +Authors@R: c( + person("James", "Melville", , "jlmelville@gmail.com", role = c("aut", "cre", "cph")), + person("Aaron", "Lun", role = "ctb"), + person("Samuel", "Granjeaud", role = "ctb"), + person("Dmitriy", "Selivanov", role = "ctb"), + person("Yuxing", "Liao", role = "ctb") + ) +Description: 'Hnswlib' is a C++ library for Approximate Nearest Neighbors. + This package provides a minimal R interface by relying on the 'Rcpp' + package. See <https://github.com/nmslib/hnswlib> for more on + 'hnswlib'. 'hnswlib' is released under Version 2.0 of the Apache + License. 
License: GPL (>= 3) URL: https://github.com/jlmelville/rcpphnsw BugReports: https://github.com/jlmelville/rcpphnsw/issues +Imports: + methods, + Rcpp (>= 0.11.3) +Suggests: + covr, + testthat +LinkingTo: + Rcpp Encoding: UTF-8 -Imports: methods, Rcpp (>= 0.11.3) -LinkingTo: Rcpp -RoxygenNote: 7.2.3 -Suggests: testthat, - covr Roxygen: list(markdown = TRUE) +RoxygenNote: 7.3.1 diff --git a/NEWS.md b/NEWS.md index 4d161e1..12ecefc 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,4 +1,8 @@ -# RcppHNSW 0.5.9000 +# RcppHNSW 0.6.0 + +## New features + +* Updated hnswlib to [version 0.8.0](https://github.com/nmslib/hnswlib/releases/tag/v0.8.0). # RcppHNSW 0.5.0 diff --git a/cran-comments.md b/cran-comments.md index 8504823..c521f9a 100644 --- a/cran-comments.md +++ b/cran-comments.md @@ -1,20 +1,22 @@ ## Release Summary -This is a patch release to fix various CRAN check errors. +This is a patch release for a new version of the underlying hnswlib library. ## Test environments -* ubuntu 22.04 (on github actions), R 4.2.3, R 4.3.1, devel -* local ubuntu 23.04 R 4.2.2 +* ubuntu 22.04 (on github actions), R 4.2.3, R 4.3.2, devel +* local ubuntu 23.10 R 4.3.1 * Debian Linux, R-devel, GCC ASAN/UBSAN (via rhub) * Debian Linux, R-release, GCC (via rhub) +* Debian Linux, R-release, GCC valgrind (via rhub) * Ubuntu Linux 20.04.1 LTS, R-release, GCC (via rhub) * Fedora Linux, R-devel, clang, gfortran (via rhub) -* Windows Server 2022 (on github actions), R 4.2.3, R 4.3.1 +* Windows Server 2022 (on github actions), R 4.2.3, R 4.3.2 * Windows Server 2022, R-devel, 64 bit (via rhub) -* local Windows 11 build, R 4.3.1 +* local Windows 11 build, R 4.3.2 * win-builder (devel) -* mac OS X Monterey (on github actions) R 4.3.1 +* local mac OS X Sonoma R 4.3.2 +* mac OS X Monterey (on github actions) R 4.3.2 ## R CMD check results @@ -22,43 +24,23 @@ There were no ERRORs or WARNINGs. There was one NOTE: -N checking installed package size ... 
- installed size is 6.6Mb - sub-directories of 1Mb or more: - libs 6.3Mb +* checking installed package size ... NOTE + installed size is 6.7Mb + sub-directories of 1Mb or more: + libs 6.4Mb This is expected due to the use of C++ templates in hnswlib. -This is spelled correctly. - ## CRAN checks There are no ERRORs or WARNINGs. -There is a NOTE: - -Check: C++ specification -Result: NOTE - Specified C++11: please drop specification unless essential - -This submission fixes this. - -There is a NOTE: - -Check: Rd metadata -Result: NOTE - Invalid package aliases in Rd file ‘RcppHnsw-package.Rd’: - ‘RcppHnsw-package’ - -This submissions fixes this. - -There are four flavors with NOTEs about installed package size (r-release-macos-arm64, -r-release-macos-x86_64, r-oldrel-macos-arm64, r-oldrel-macos-x86_64). This is expected and won't be -fixed. +There are three flavors with NOTEs about installed package size (r-release-macos-arm64, +r-release-macos-x86_64, r-oldrel-macos-arm64). This is expected and won't be fixed. ## Downstream dependencies -We checked 2 reverse dependencies (0 from CRAN + 2 from Bioconductor), comparing R CMD check +We checked 3 reverse dependencies (1 from CRAN + 2 from Bioconductor), comparing R CMD check results across CRAN and dev versions of this package. 
* We saw 0 new problems diff --git a/inst/include/bruteforce.h b/inst/include/bruteforce.h index 30b33ae..8727cc8 100644 --- a/inst/include/bruteforce.h +++ b/inst/include/bruteforce.h @@ -84,10 +84,16 @@ class BruteforceSearch : public AlgorithmInterface { void removePoint(labeltype cur_external) { - size_t cur_c = dict_external_to_internal[cur_external]; + std::unique_lock lock(index_lock); - dict_external_to_internal.erase(cur_external); + auto found = dict_external_to_internal.find(cur_external); + if (found == dict_external_to_internal.end()) { + return; + } + + size_t cur_c = found->second; /* read the slot before erase() invalidates the iterator */ + dict_external_to_internal.erase(found); labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_)); dict_external_to_internal[label] = cur_c; memcpy(data_ + size_per_element_ * cur_c, @@ -106,7 +112,7 @@ class BruteforceSearch : public AlgorithmInterface { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); if ((!isIdAllowed) || (*isIdAllowed)(label)) { - topResults.push(std::pair(dist, label)); + topResults.emplace(dist, label); } } dist_t lastdist = topResults.empty() ? 
std::numeric_limits::max() : topResults.top().first; @@ -115,7 +121,7 @@ class BruteforceSearch : public AlgorithmInterface { if (dist <= lastdist) { labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); if ((!isIdAllowed) || (*isIdAllowed)(label)) { - topResults.push(std::pair(dist, label)); + topResults.emplace(dist, label); } if (topResults.size() > k) topResults.pop(); diff --git a/inst/include/hnswalg.h b/inst/include/hnswalg.h index e498e15..e269ae6 100644 --- a/inst/include/hnswalg.h +++ b/inst/include/hnswalg.h @@ -8,6 +8,7 @@ #include #include #include +#include namespace hnswlib { typedef unsigned int tableint; @@ -33,7 +34,7 @@ class HierarchicalNSW : public AlgorithmInterface { double mult_{0.0}, revSize_{0.0}; int maxlevel_{0}; - VisitedListPool *visited_list_pool_{nullptr}; + std::unique_ptr visited_list_pool_{nullptr}; // Locks operations with element by label value mutable std::vector label_op_locks_; @@ -101,7 +102,13 @@ class HierarchicalNSW : public AlgorithmInterface { data_size_ = s->get_data_size(); fstdistfunc_ = s->get_dist_func(); dist_func_param_ = s->get_dist_func_param(); - M_ = M; + if ( M <= 10000 ) { + M_ = M; + } else { + HNSWERR << "warning: M parameter exceeds 10000 which may lead to adverse effects." << std::endl; + HNSWERR << " Cap to 10000 will be applied for the rest of the processing." 
<< std::endl; + M_ = 10000; + } maxM_ = M_; maxM0_ = M_ * 2; ef_construction_ = std::max(ef_construction, M_); @@ -122,7 +129,7 @@ class HierarchicalNSW : public AlgorithmInterface { cur_element_count = 0; - visited_list_pool_ = new VisitedListPool(1, max_elements); + visited_list_pool_ = std::unique_ptr(new VisitedListPool(1, max_elements)); // initializations for special treatment of the first node enterpoint_node_ = -1; @@ -138,13 +145,20 @@ class HierarchicalNSW : public AlgorithmInterface { ~HierarchicalNSW() { + clear(); + } + + void clear() { free(data_level0_memory_); + data_level0_memory_ = nullptr; for (tableint i = 0; i < cur_element_count; i++) { if (element_levels_[i] > 0) free(linkLists_[i]); } free(linkLists_); - delete visited_list_pool_; + linkLists_ = nullptr; + cur_element_count = 0; + visited_list_pool_.reset(nullptr); } @@ -291,9 +305,15 @@ class HierarchicalNSW : public AlgorithmInterface { } - template + // bare_bone_search means there is no check for deletions and stop condition is ignored in return of extra performance + template std::priority_queue, std::vector>, CompareByFirst> - searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const { + searchBaseLayerST( + tableint ep_id, + const void *data_point, + size_t ef, + BaseFilterFunctor* isIdAllowed = nullptr, + BaseSearchStopCondition* stop_condition = nullptr) const { VisitedList *vl = visited_list_pool_->getFreeVisitedList(); vl_type *visited_array = vl->mass; vl_type visited_array_tag = vl->curV; @@ -302,10 +322,15 @@ class HierarchicalNSW : public AlgorithmInterface { std::priority_queue, std::vector>, CompareByFirst> candidate_set; dist_t lowerBound; - if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) { - dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); + if (bare_bone_search || + (!isMarkedDeleted(ep_id) && ((!isIdAllowed) || 
(*isIdAllowed)(getExternalLabel(ep_id))))) { + char* ep_data = getDataByInternalId(ep_id); + dist_t dist = fstdistfunc_(data_point, ep_data, dist_func_param_); lowerBound = dist; top_candidates.emplace(dist, ep_id); + if (!bare_bone_search && stop_condition) { + stop_condition->add_point_to_result(getExternalLabel(ep_id), ep_data, dist); + } candidate_set.emplace(-dist, ep_id); } else { lowerBound = std::numeric_limits::max(); @@ -316,9 +341,19 @@ class HierarchicalNSW : public AlgorithmInterface { while (!candidate_set.empty()) { std::pair current_node_pair = candidate_set.top(); + dist_t candidate_dist = -current_node_pair.first; - if ((-current_node_pair.first) > lowerBound && - (top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) { + bool flag_stop_search; + if (bare_bone_search) { + flag_stop_search = candidate_dist > lowerBound; + } else { + if (stop_condition) { + flag_stop_search = stop_condition->should_stop_search(candidate_dist, lowerBound); + } else { + flag_stop_search = candidate_dist > lowerBound && top_candidates.size() == ef; + } + } + if (flag_stop_search) { break; } candidate_set.pop(); @@ -353,7 +388,14 @@ class HierarchicalNSW : public AlgorithmInterface { char *currObj1 = (getDataByInternalId(candidate_id)); dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); - if (top_candidates.size() < ef || lowerBound > dist) { + bool flag_consider_candidate; + if (!bare_bone_search && stop_condition) { + flag_consider_candidate = stop_condition->should_consider_candidate(dist, lowerBound); + } else { + flag_consider_candidate = top_candidates.size() < ef || lowerBound > dist; + } + + if (flag_consider_candidate) { candidate_set.emplace(-dist, candidate_id); #ifdef USE_SSE _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + @@ -361,11 +403,30 @@ class HierarchicalNSW : public AlgorithmInterface { _MM_HINT_T0); //////////////////////// #endif - if ((!has_deletions || 
!isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id)))) + if (bare_bone_search || + (!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) { top_candidates.emplace(dist, candidate_id); + if (!bare_bone_search && stop_condition) { + stop_condition->add_point_to_result(getExternalLabel(candidate_id), currObj1, dist); + } + } - if (top_candidates.size() > ef) + bool flag_remove_extra = false; + if (!bare_bone_search && stop_condition) { + flag_remove_extra = stop_condition->should_remove_extra(); + } else { + flag_remove_extra = top_candidates.size() > ef; + } + while (flag_remove_extra) { + tableint id = top_candidates.top().second; top_candidates.pop(); + if (!bare_bone_search && stop_condition) { + stop_condition->remove_point_from_result(getExternalLabel(id), getDataByInternalId(id), dist); + flag_remove_extra = stop_condition->should_remove_extra(); + } else { + flag_remove_extra = top_candidates.size() > ef; + } + } if (!top_candidates.empty()) lowerBound = top_candidates.top().first; @@ -380,8 +441,8 @@ class HierarchicalNSW : public AlgorithmInterface { void getNeighborsByHeuristic2( - std::priority_queue, std::vector>, CompareByFirst> &top_candidates, - const size_t M) { + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + const size_t M) { if (top_candidates.size() < M) { return; } @@ -573,8 +634,7 @@ class HierarchicalNSW : public AlgorithmInterface { if (new_max_elements < cur_element_count) throw std::runtime_error("Cannot resize, max element is less than the current number of elements"); - delete visited_list_pool_; - visited_list_pool_ = new VisitedListPool(1, new_max_elements); + visited_list_pool_.reset(new VisitedListPool(1, new_max_elements)); element_levels_.resize(new_max_elements); @@ -595,6 +655,32 @@ class HierarchicalNSW : public AlgorithmInterface { max_elements_ = new_max_elements; } + size_t indexFileSize() const { + size_t 
size = 0; + size += sizeof(offsetLevel0_); + size += sizeof(max_elements_); + size += sizeof(cur_element_count); + size += sizeof(size_data_per_element_); + size += sizeof(label_offset_); + size += sizeof(offsetData_); + size += sizeof(maxlevel_); + size += sizeof(enterpoint_node_); + size += sizeof(maxM_); + + size += sizeof(maxM0_); + size += sizeof(M_); + size += sizeof(mult_); + size += sizeof(ef_construction_); + + size += cur_element_count * size_data_per_element_; + + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; + size += sizeof(linkListSize); + size += linkListSize; + } + return size; + } void saveIndex(const std::string &location) { std::ofstream output(location, std::ios::binary); @@ -633,6 +719,7 @@ class HierarchicalNSW : public AlgorithmInterface { if (!input.is_open()) throw std::runtime_error("Cannot open file"); + clear(); // get file size: input.seekg(0, input.end); std::streampos total_filesize = input.tellg(); @@ -698,7 +785,7 @@ class HierarchicalNSW : public AlgorithmInterface { std::vector(max_elements).swap(link_list_locks_); std::vector(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_); - visited_list_pool_ = new VisitedListPool(1, max_elements); + visited_list_pool_.reset(new VisitedListPool(1, max_elements)); linkLists_ = (char **) malloc(sizeof(void *) * max_elements); if (linkLists_ == nullptr) @@ -1216,11 +1303,12 @@ class HierarchicalNSW : public AlgorithmInterface { } std::priority_queue, std::vector>, CompareByFirst> top_candidates; - if (num_deleted_) { - top_candidates = searchBaseLayerST( + bool bare_bone_search = !num_deleted_ && !isIdAllowed; + if (bare_bone_search) { + top_candidates = searchBaseLayerST( currObj, query_data, std::max(ef_, k), isIdAllowed); } else { - top_candidates = searchBaseLayerST( + top_candidates = searchBaseLayerST( currObj, query_data, std::max(ef_, k), isIdAllowed); } @@ -1236,6 +1324,60 @@ class 
HierarchicalNSW : public AlgorithmInterface { } + std::vector> + searchStopConditionClosest( + const void *query_data, + BaseSearchStopCondition& stop_condition, + BaseFilterFunctor* isIdAllowed = nullptr) const { + std::vector> result; + if (cur_element_count == 0) return result; + + tableint currObj = enterpoint_node_; + dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + + for (int level = maxlevel_; level > 0; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + + data = (unsigned int *) get_linklist(currObj, level); + int size = getListCount(data); + metric_hops++; + metric_distance_computations+=size; + + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + top_candidates = searchBaseLayerST(currObj, query_data, 0, isIdAllowed, &stop_condition); + + size_t sz = top_candidates.size(); + result.resize(sz); + while (!top_candidates.empty()) { + result[--sz] = top_candidates.top(); + top_candidates.pop(); + } + + stop_condition.filter_results(result); + + return result; + } + + void checkIntegrity() { int connections_checked = 0; std::vector inbound_connections_num(cur_element_count, 0); @@ -1246,7 +1388,6 @@ class HierarchicalNSW : public AlgorithmInterface { tableint *data = (tableint *) (ll_cur + 1); std::unordered_set s; for (int j = 0; j < size; j++) { - assert(data[j] > 0); assert(data[j] < cur_element_count); assert(data[j] != i); inbound_connections_num[data[j]]++; diff --git a/inst/include/hnswlib.h b/inst/include/hnswlib.h index fb7118f..7ccfbba 100644 --- a/inst/include/hnswlib.h +++ 
b/inst/include/hnswlib.h @@ -1,4 +1,13 @@ #pragma once + +// https://github.com/nmslib/hnswlib/pull/508 +// This allows others to provide their own error stream (e.g. RcppHNSW) +#ifndef HNSWLIB_ERR_OVERRIDE + #define HNSWERR std::cerr +#else + #define HNSWERR HNSWLIB_ERR_OVERRIDE +#endif + #ifndef NO_MANUAL_VECTORIZATION #if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64)) #define USE_SSE @@ -15,7 +24,7 @@ #ifdef _MSC_VER #include #include -void cpuid(int32_t out[4], int32_t eax, int32_t ecx) { +static void cpuid(int32_t out[4], int32_t eax, int32_t ecx) { __cpuidex(out, eax, ecx); } static __int64 xgetbv(unsigned int x) { @@ -119,6 +128,25 @@ typedef size_t labeltype; class BaseFilterFunctor { public: virtual bool operator()(hnswlib::labeltype id) { return true; } + virtual ~BaseFilterFunctor() {}; +}; + +template +class BaseSearchStopCondition { + public: + virtual void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) = 0; + + virtual void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) = 0; + + virtual bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) = 0; + + virtual bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) = 0; + + virtual bool should_remove_extra() = 0; + + virtual void filter_results(std::vector> &candidates) = 0; + + virtual ~BaseSearchStopCondition() {} }; template @@ -195,5 +223,6 @@ AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t #include "space_l2.h" #include "space_ip.h" +#include "stop_condition.h" #include "bruteforce.h" #include "hnswalg.h" diff --git a/inst/include/space_ip.h b/inst/include/space_ip.h index 2b1c359..0e6834c 100644 --- a/inst/include/space_ip.h +++ b/inst/include/space_ip.h @@ -157,19 +157,44 @@ InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void __m512 sum512 = _mm512_set1_ps(0); - while (pVect1 < pEnd1) { - //_mm_prefetch((char*)(pVect2 + 16), 
_MM_HINT_T0); - + size_t loop = qty16 / 4; + + while (loop--) { __m512 v1 = _mm512_loadu_ps(pVect1); - pVect1 += 16; __m512 v2 = _mm512_loadu_ps(pVect2); + pVect1 += 16; + pVect2 += 16; + + __m512 v3 = _mm512_loadu_ps(pVect1); + __m512 v4 = _mm512_loadu_ps(pVect2); + pVect1 += 16; + pVect2 += 16; + + __m512 v5 = _mm512_loadu_ps(pVect1); + __m512 v6 = _mm512_loadu_ps(pVect2); + pVect1 += 16; pVect2 += 16; - sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(v1, v2)); + + __m512 v7 = _mm512_loadu_ps(pVect1); + __m512 v8 = _mm512_loadu_ps(pVect2); + pVect1 += 16; + pVect2 += 16; + + sum512 = _mm512_fmadd_ps(v1, v2, sum512); + sum512 = _mm512_fmadd_ps(v3, v4, sum512); + sum512 = _mm512_fmadd_ps(v5, v6, sum512); + sum512 = _mm512_fmadd_ps(v7, v8, sum512); } - _mm512_store_ps(TmpRes, sum512); - float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + TmpRes[13] + TmpRes[14] + TmpRes[15]; + while (pVect1 < pEnd1) { + __m512 v1 = _mm512_loadu_ps(pVect1); + __m512 v2 = _mm512_loadu_ps(pVect2); + pVect1 += 16; + pVect2 += 16; + sum512 = _mm512_fmadd_ps(v1, v2, sum512); + } + float sum = _mm512_reduce_add_ps(sum512); return sum; } diff --git a/inst/include/stop_condition.h b/inst/include/stop_condition.h new file mode 100644 index 0000000..acc80eb --- /dev/null +++ b/inst/include/stop_condition.h @@ -0,0 +1,276 @@ +#pragma once +#include "space_l2.h" +#include "space_ip.h" +#include +#include + +namespace hnswlib { + +template +class BaseMultiVectorSpace : public SpaceInterface { + public: + virtual DOCIDTYPE get_doc_id(const void *datapoint) = 0; + + virtual void set_doc_id(void *datapoint, DOCIDTYPE doc_id) = 0; +}; + + +template +class MultiVectorL2Space : public BaseMultiVectorSpace { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t vector_size_; + size_t dim_; + + public: + MultiVectorL2Space(size_t dim) { + fstdistfunc_ = L2Sqr; +#if defined(USE_SSE) || 
defined(USE_AVX) || defined(USE_AVX512) + #if defined(USE_AVX512) + if (AVX512Capable()) + L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512; + else if (AVXCapable()) + L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; + #elif defined(USE_AVX) + if (AVXCapable()) + L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; + #endif + + if (dim % 16 == 0) + fstdistfunc_ = L2SqrSIMD16Ext; + else if (dim % 4 == 0) + fstdistfunc_ = L2SqrSIMD4Ext; + else if (dim > 16) + fstdistfunc_ = L2SqrSIMD16ExtResiduals; + else if (dim > 4) + fstdistfunc_ = L2SqrSIMD4ExtResiduals; +#endif + dim_ = dim; + vector_size_ = dim * sizeof(float); + data_size_ = vector_size_ + sizeof(DOCIDTYPE); + } + + size_t get_data_size() override { + return data_size_; + } + + DISTFUNC get_dist_func() override { + return fstdistfunc_; + } + + void *get_dist_func_param() override { + return &dim_; + } + + DOCIDTYPE get_doc_id(const void *datapoint) override { + return *(DOCIDTYPE *)((char *)datapoint + vector_size_); + } + + void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override { + *(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id; + } + + ~MultiVectorL2Space() {} +}; + + +template +class MultiVectorInnerProductSpace : public BaseMultiVectorSpace { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t vector_size_; + size_t dim_; + + public: + MultiVectorInnerProductSpace(size_t dim) { + fstdistfunc_ = InnerProductDistance; +#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512) + #if defined(USE_AVX512) + if (AVX512Capable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512; + } else if (AVXCapable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; + } + #elif defined(USE_AVX) + if (AVXCapable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; + } + #endif + #if defined(USE_AVX) + if (AVXCapable()) { + 
InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX; + InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX; + } + #endif + + if (dim % 16 == 0) + fstdistfunc_ = InnerProductDistanceSIMD16Ext; + else if (dim % 4 == 0) + fstdistfunc_ = InnerProductDistanceSIMD4Ext; + else if (dim > 16) + fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals; + else if (dim > 4) + fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals; +#endif + dim_ = dim; vector_size_ = dim * sizeof(float); /* dim_ must be set: get_dist_func_param() returns &dim_ */ + data_size_ = vector_size_ + sizeof(DOCIDTYPE); + } + + size_t get_data_size() override { + return data_size_; + } + + DISTFUNC get_dist_func() override { + return fstdistfunc_; + } + + void *get_dist_func_param() override { + return &dim_; + } + + DOCIDTYPE get_doc_id(const void *datapoint) override { + return *(DOCIDTYPE *)((char *)datapoint + vector_size_); + } + + void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override { + *(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id; + } + + ~MultiVectorInnerProductSpace() {} +}; + + +template +class MultiVectorSearchStopCondition : public BaseSearchStopCondition { + size_t curr_num_docs_; + size_t num_docs_to_search_; + size_t ef_collection_; + std::unordered_map doc_counter_; + std::priority_queue> search_results_; + BaseMultiVectorSpace& space_; + + public: + MultiVectorSearchStopCondition( + BaseMultiVectorSpace& space, + size_t num_docs_to_search, + size_t ef_collection = 10) + : space_(space) { + curr_num_docs_ = 0; + num_docs_to_search_ = num_docs_to_search; + ef_collection_ = std::max(ef_collection, num_docs_to_search); + } + + void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override { + DOCIDTYPE doc_id = space_.get_doc_id(datapoint); + if (doc_counter_[doc_id] == 0) { + curr_num_docs_ += 1; + } + search_results_.emplace(dist, doc_id); + doc_counter_[doc_id] += 1; + } + + void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override { + DOCIDTYPE doc_id = 
space_.get_doc_id(datapoint); + doc_counter_[doc_id] -= 1; + if (doc_counter_[doc_id] == 0) { + curr_num_docs_ -= 1; + } + search_results_.pop(); + } + + bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override { + bool stop_search = candidate_dist > lowerBound && curr_num_docs_ == ef_collection_; + return stop_search; + } + + bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override { + bool flag_consider_candidate = curr_num_docs_ < ef_collection_ || lowerBound > candidate_dist; + return flag_consider_candidate; + } + + bool should_remove_extra() override { + bool flag_remove_extra = curr_num_docs_ > ef_collection_; + return flag_remove_extra; + } + + void filter_results(std::vector> &candidates) override { + while (curr_num_docs_ > num_docs_to_search_) { + dist_t dist_cand = candidates.back().first; + dist_t dist_res = search_results_.top().first; + assert(dist_cand == dist_res); + DOCIDTYPE doc_id = search_results_.top().second; + doc_counter_[doc_id] -= 1; + if (doc_counter_[doc_id] == 0) { + curr_num_docs_ -= 1; + } + search_results_.pop(); + candidates.pop_back(); + } + } + + ~MultiVectorSearchStopCondition() {} +}; + + +template +class EpsilonSearchStopCondition : public BaseSearchStopCondition { + float epsilon_; + size_t min_num_candidates_; + size_t max_num_candidates_; + size_t curr_num_items_; + + public: + EpsilonSearchStopCondition(float epsilon, size_t min_num_candidates, size_t max_num_candidates) { + assert(min_num_candidates <= max_num_candidates); + epsilon_ = epsilon; + min_num_candidates_ = min_num_candidates; + max_num_candidates_ = max_num_candidates; + curr_num_items_ = 0; + } + + void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override { + curr_num_items_ += 1; + } + + void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override { + curr_num_items_ -= 1; + } + + bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override { 
+ if (candidate_dist > lowerBound && curr_num_items_ == max_num_candidates_) { + // new candidate can't improve found results + return true; + } + if (candidate_dist > epsilon_ && curr_num_items_ >= min_num_candidates_) { + // new candidate is out of epsilon region and + // minimum number of candidates is checked + return true; + } + return false; + } + + bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override { + bool flag_consider_candidate = curr_num_items_ < max_num_candidates_ || lowerBound > candidate_dist; + return flag_consider_candidate; + } + + bool should_remove_extra() override { + bool flag_remove_extra = curr_num_items_ > max_num_candidates_; + return flag_remove_extra; + } + + void filter_results(std::vector> &candidates) override { + while (!candidates.empty() && candidates.back().first > epsilon_) { + candidates.pop_back(); + } + while (candidates.size() > max_num_candidates_) { + candidates.pop_back(); + } + } + + ~EpsilonSearchStopCondition() {} +}; +} // namespace hnswlib diff --git a/man/RcppHnsw-package.Rd b/man/RcppHnsw-package.Rd index b05762c..daab253 100644 --- a/man/RcppHnsw-package.Rd +++ b/man/RcppHnsw-package.Rd @@ -2,8 +2,8 @@ % Please edit documentation in R/rcpphnsw-package.R \docType{package} \name{RcppHnsw-package} +\alias{RcppHNSW} \alias{RcppHnsw-package} -\alias{_PACKAGE} \alias{HnswL2} \alias{Rcpp_HnswL2-class} \alias{HnswCosine} diff --git a/src/hnsw.cpp b/src/hnsw.cpp index 65433f9..019beb3 100644 --- a/src/hnsw.cpp +++ b/src/hnsw.cpp @@ -26,7 +26,7 @@ #include -#include "hnswlib.h" +#include "rcpphnsw.h" #include "RcppPerpendicular/RcppPerpendicular.h" diff --git a/src/rcpphnsw.h b/src/rcpphnsw.h new file mode 100644 index 0000000..10bece4 --- /dev/null +++ b/src/rcpphnsw.h @@ -0,0 +1,30 @@ +// RcppHNSW -- Rcpp bindings to hnswlib library for Approximate Nearest +// Neighbors +// +// Copyright (C) 2023 James Melville +// +// This file is part of RcppHNSW +// +// This program is free software: you can 
redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see <https://www.gnu.org/licenses/>. + +#ifndef RCPP_RCPPHNSW_H +#define RCPP_RCPPHNSW_H + +#include + +#define HNSWLIB_ERR_OVERRIDE Rcpp::Rcerr + +#include "hnswlib.h" + +#endif // RCPP_RCPPHNSW_H