diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 03e84c62..0eb741cc 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -67,6 +67,8 @@ jobs: ./example_mt_search ./example_mt_filter ./example_mt_replace_deleted + ./example_multivector_search + ./example_epsilon_search ./searchKnnCloserFirst_test ./searchKnnWithFilter_test ./multiThreadLoad_test @@ -74,4 +76,6 @@ jobs: ./test_updates ./test_updates update ./repair_test + ./multivector_search_test + ./epsilon_search_test shell: bash diff --git a/.gitignore b/.gitignore index 48f74604..d46c9890 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ var/ .vscode/ .vs/ **.DS_Store +*.pyc diff --git a/CMakeLists.txt b/CMakeLists.txt index 53f4f1ee..5e8aa08d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -57,6 +57,12 @@ if(HNSWLIB_EXAMPLES) add_executable(example_search examples/cpp/example_search.cpp) target_link_libraries(example_search hnswlib) + add_executable(example_epsilon_search examples/cpp/example_epsilon_search.cpp) + target_link_libraries(example_epsilon_search hnswlib) + + add_executable(example_multivector_search examples/cpp/example_multivector_search.cpp) + target_link_libraries(example_multivector_search hnswlib) + add_executable(example_filter examples/cpp/example_filter.cpp) target_link_libraries(example_filter hnswlib) @@ -73,6 +79,12 @@ if(HNSWLIB_EXAMPLES) target_link_libraries(example_mt_replace_deleted hnswlib) # tests + add_executable(multivector_search_test tests/cpp/multivector_search_test.cpp) + target_link_libraries(multivector_search_test hnswlib) + + add_executable(epsilon_search_test tests/cpp/epsilon_search_test.cpp) + target_link_libraries(epsilon_search_test hnswlib) + add_executable(test_updates tests/cpp/updates_test.cpp) target_link_libraries(test_updates hnswlib) diff --git a/README.md b/README.md index 3ed466a7..5b972e0b 100644 --- a/README.md +++ b/README.md @@ -79,7 +79,7 @@ For other spaces use the nmslib library 
https://github.com/nmslib/nmslib. * `set_num_threads(num_threads)` set the default number of cpu threads used during data insertion/querying. -* `get_items(ids)` - returns a numpy array (shape:`N*dim`) of vectors that have integer identifiers specified in `ids` numpy vector (shape:`N`). Note that for cosine similarity it currently returns **normalized** vectors. +* `get_items(ids, return_type = 'numpy')` - returns a numpy array (shape:`N*dim`) of vectors that have integer identifiers specified in `ids` numpy vector (shape:`N`). If `return_type` is `list`, returns a list of lists instead. Note that for cosine similarity it currently returns **normalized** vectors. * `get_ids_list()` - returns a list of all elements' ids. @@ -229,6 +229,8 @@ print("Recall for two batches:", np.mean(labels.reshape(-1) == np.arange(len(dat * filtering during the search with a boolean function * deleting the elements and reusing the memory of the deleted elements for newly added elements * multithreaded usage +* multivector search +* epsilon search ### Bindings installation diff --git a/examples/cpp/EXAMPLES.md b/examples/cpp/EXAMPLES.md index 3af603d4..5f9adc30 100644 --- a/examples/cpp/EXAMPLES.md +++ b/examples/cpp/EXAMPLES.md @@ -182,4 +182,8 @@ int main() { Multithreaded examples: * Creating index, inserting elements, searching [example_mt_search.cpp](example_mt_search.cpp) * Filtering during the search with a boolean function [example_mt_filter.cpp](example_mt_filter.cpp) -* Reusing the memory of the deleted elements when new elements are being added [example_mt_replace_deleted.cpp](example_mt_replace_deleted.cpp) \ No newline at end of file +* Reusing the memory of the deleted elements when new elements are being added [example_mt_replace_deleted.cpp](example_mt_replace_deleted.cpp) + +More examples: +* Multivector search [example_multivector_search.cpp](example_multivector_search.cpp) +* Epsilon search [example_epsilon_search.cpp](example_epsilon_search.cpp) \ No newline at end of file diff 
--git a/examples/cpp/example_epsilon_search.cpp b/examples/cpp/example_epsilon_search.cpp new file mode 100644 index 00000000..49eec408 --- /dev/null +++ b/examples/cpp/example_epsilon_search.cpp @@ -0,0 +1,66 @@ +#include "../../hnswlib/hnswlib.h" + +typedef unsigned int docidtype; +typedef float dist_t; + +int main() { + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + int min_num_candidates = 100; // Minimum number of candidates to search in the epsilon region + // this parameter is similar to ef + + int num_queries = 5; + float epsilon2 = 2.0; // Squared distance to query + + // Initing index + hnswlib::L2Space space(dim); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + + size_t data_point_size = space.get_data_size(); + char* data = new char[data_point_size * max_elements]; + for (int i = 0; i < max_elements; i++) { + char* point_data = data + i * data_point_size; + for (int j = 0; j < dim; j++) { + char* vec_data = point_data + j * sizeof(float); + float value = distrib_real(rng); + *(float*)vec_data = value; + } + } + + // Add data to index + for (int i = 0; i < max_elements; i++) { + hnswlib::labeltype label = i; + char* point_data = data + i * data_point_size; + alg_hnsw->addPoint(point_data, label); + } + + // Query random vectors + for (int i = 0; i < num_queries; i++) { + char* query_data = new char[data_point_size]; + for (int j = 0; j < dim; j++) { + size_t offset = j * sizeof(float); + char* vec_data = query_data + offset; + float value = distrib_real(rng); + *(float*)vec_data = value; + } + 
std::cout << "Query #" << i << "\n"; + hnswlib::EpsilonSearchStopCondition stop_condition(epsilon2, min_num_candidates, max_elements); + std::vector> result = + alg_hnsw->searchStopConditionClosest(query_data, stop_condition); + size_t num_vectors = result.size(); + std::cout << "Found " << num_vectors << " vectors\n"; + delete[] query_data; + } + + delete[] data; + delete alg_hnsw; + return 0; +} diff --git a/examples/cpp/example_multivector_search.cpp b/examples/cpp/example_multivector_search.cpp new file mode 100644 index 00000000..06aafe0b --- /dev/null +++ b/examples/cpp/example_multivector_search.cpp @@ -0,0 +1,83 @@ +#include "../../hnswlib/hnswlib.h" + +typedef unsigned int docidtype; +typedef float dist_t; + +int main() { + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + + int num_queries = 5; + int num_docs = 5; // Number of documents to search + int ef_collection = 6; // Number of candidate documents during the search + // Controlls the recall: higher ef leads to better accuracy, but slower search + docidtype min_doc_id = 0; + docidtype max_doc_id = 9; + + // Initing index + hnswlib::MultiVectorL2Space space(dim); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + std::uniform_int_distribution distrib_docid(min_doc_id, max_doc_id); + + size_t data_point_size = space.get_data_size(); + char* data = new char[data_point_size * max_elements]; + for (int i = 0; i < max_elements; i++) { + // set vector value + char* point_data = data + i * data_point_size; + for (int j = 0; j < dim; j++) { + char* vec_data = 
point_data + j * sizeof(float); + float value = distrib_real(rng); + *(float*)vec_data = value; + } + // set document id + docidtype doc_id = distrib_docid(rng); + space.set_doc_id(point_data, doc_id); + } + + // Add data to index + std::unordered_map label_docid_lookup; + for (int i = 0; i < max_elements; i++) { + hnswlib::labeltype label = i; + char* point_data = data + i * data_point_size; + alg_hnsw->addPoint(point_data, label); + label_docid_lookup[label] = space.get_doc_id(point_data); + } + + // Query random vectors + size_t query_size = dim * sizeof(float); + for (int i = 0; i < num_queries; i++) { + char* query_data = new char[query_size]; + for (int j = 0; j < dim; j++) { + size_t offset = j * sizeof(float); + char* vec_data = query_data + offset; + float value = distrib_real(rng); + *(float*)vec_data = value; + } + std::cout << "Query #" << i << "\n"; + hnswlib::MultiVectorSearchStopCondition stop_condition(space, num_docs, ef_collection); + std::vector> result = + alg_hnsw->searchStopConditionClosest(query_data, stop_condition); + size_t num_vectors = result.size(); + + std::unordered_map doc_counter; + for (auto pair: result) { + hnswlib::labeltype label = pair.second; + docidtype doc_id = label_docid_lookup[label]; + doc_counter[doc_id] += 1; + } + std::cout << "Found " << doc_counter.size() << " documents, " << num_vectors << " vectors\n"; + delete[] query_data; + } + + delete[] data; + delete alg_hnsw; + return 0; +} diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 73069fad..db4dba87 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -107,8 +107,8 @@ class HierarchicalNSW : public AlgorithmInterface { if ( M <= 10000 ) { M_ = M; } else { - std::cerr << "warning: M parameter exceeds 10000 which may lead to adverse effects." << std::endl; - std::cerr << " Cap to 10000 will be applied for the rest of the processing." << std::endl; + HNSWERR << "warning: M parameter exceeds 10000 which may lead to adverse effects." 
<< std::endl; + HNSWERR << " Cap to 10000 will be applied for the rest of the processing." << std::endl; M_ = 10000; } maxM_ = M_; @@ -301,9 +301,15 @@ class HierarchicalNSW : public AlgorithmInterface { } - template + // bare_bone_search means there is no check for deletions and stop condition is ignored in return of extra performance + template std::priority_queue, std::vector>, CompareByFirst> - searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const { + searchBaseLayerST( + tableint ep_id, + const void *data_point, + size_t ef, + BaseFilterFunctor* isIdAllowed = nullptr, + BaseSearchStopCondition* stop_condition = nullptr) const { VisitedList *vl = visited_list_pool_->getFreeVisitedList(); vl_type *visited_array = vl->mass; vl_type visited_array_tag = vl->curV; @@ -312,10 +318,15 @@ class HierarchicalNSW : public AlgorithmInterface { std::priority_queue, std::vector>, CompareByFirst> candidate_set; dist_t lowerBound; - if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) { - dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); + if (bare_bone_search || + (!isMarkedDeleted(ep_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id))))) { + char* ep_data = getDataByInternalId(ep_id); + dist_t dist = fstdistfunc_(data_point, ep_data, dist_func_param_); lowerBound = dist; top_candidates.emplace(dist, ep_id); + if (!bare_bone_search && stop_condition) { + stop_condition->add_point_to_result(getExternalLabel(ep_id), ep_data, dist); + } candidate_set.emplace(-dist, ep_id); } else { lowerBound = std::numeric_limits::max(); @@ -326,9 +337,19 @@ class HierarchicalNSW : public AlgorithmInterface { while (!candidate_set.empty()) { std::pair current_node_pair = candidate_set.top(); + dist_t candidate_dist = -current_node_pair.first; - if ((-current_node_pair.first) > lowerBound && - (top_candidates.size() == ef || 
(!isIdAllowed && !has_deletions))) { + bool flag_stop_search; + if (bare_bone_search) { + flag_stop_search = candidate_dist > lowerBound; + } else { + if (stop_condition) { + flag_stop_search = stop_condition->should_stop_search(candidate_dist, lowerBound); + } else { + flag_stop_search = candidate_dist > lowerBound && top_candidates.size() == ef; + } + } + if (flag_stop_search) { break; } candidate_set.pop(); @@ -363,7 +384,14 @@ class HierarchicalNSW : public AlgorithmInterface { char *currObj1 = (getDataByInternalId(candidate_id)); dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); - if (top_candidates.size() < ef || lowerBound > dist) { + bool flag_consider_candidate; + if (!bare_bone_search && stop_condition) { + flag_consider_candidate = stop_condition->should_consider_candidate(dist, lowerBound); + } else { + flag_consider_candidate = top_candidates.size() < ef || lowerBound > dist; + } + + if (flag_consider_candidate) { candidate_set.emplace(-dist, candidate_id); #ifdef USE_SSE _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + @@ -371,11 +399,30 @@ class HierarchicalNSW : public AlgorithmInterface { _MM_HINT_T0); //////////////////////// #endif - if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id)))) + if (bare_bone_search || + (!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) { top_candidates.emplace(dist, candidate_id); + if (!bare_bone_search && stop_condition) { + stop_condition->add_point_to_result(getExternalLabel(candidate_id), currObj1, dist); + } + } - if (top_candidates.size() > ef) + bool flag_remove_extra = false; + if (!bare_bone_search && stop_condition) { + flag_remove_extra = stop_condition->should_remove_extra(); + } else { + flag_remove_extra = top_candidates.size() > ef; + } + while (flag_remove_extra) { + tableint id = top_candidates.top().second; 
top_candidates.pop(); + if (!bare_bone_search && stop_condition) { + stop_condition->remove_point_from_result(getExternalLabel(id), getDataByInternalId(id), dist); + flag_remove_extra = stop_condition->should_remove_extra(); + } else { + flag_remove_extra = top_candidates.size() > ef; + } + } if (!top_candidates.empty()) lowerBound = top_candidates.top().first; @@ -390,8 +437,8 @@ class HierarchicalNSW : public AlgorithmInterface { void getNeighborsByHeuristic2( - std::priority_queue, std::vector>, CompareByFirst> &top_candidates, - const size_t M) { + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + const size_t M) { if (top_candidates.size() < M) { return; } @@ -779,7 +826,7 @@ class HierarchicalNSW : public AlgorithmInterface { size_t dim = *((size_t *) dist_func_param_); std::vector data; data_t* data_ptr = (data_t*) data_ptrv; - for (int i = 0; i < dim; i++) { + for (size_t i = 0; i < dim; i++) { data.push_back(*data_ptr); data_ptr += 1; } @@ -1244,11 +1291,12 @@ class HierarchicalNSW : public AlgorithmInterface { } std::priority_queue, std::vector>, CompareByFirst> top_candidates; - if (num_deleted_) { - top_candidates = searchBaseLayerST( + bool bare_bone_search = !num_deleted_ && !isIdAllowed; + if (bare_bone_search) { + top_candidates = searchBaseLayerST( currObj, query_data, std::max(ef_, k), isIdAllowed); } else { - top_candidates = searchBaseLayerST( + top_candidates = searchBaseLayerST( currObj, query_data, std::max(ef_, k), isIdAllowed); } @@ -1264,6 +1312,60 @@ class HierarchicalNSW : public AlgorithmInterface { } + std::vector> + searchStopConditionClosest( + const void *query_data, + BaseSearchStopCondition& stop_condition, + BaseFilterFunctor* isIdAllowed = nullptr) const { + std::vector> result; + if (cur_element_count == 0) return result; + + tableint currObj = enterpoint_node_; + dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + + for (int level = maxlevel_; level > 0; 
level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + + data = (unsigned int *) get_linklist(currObj, level); + int size = getListCount(data); + metric_hops++; + metric_distance_computations+=size; + + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + top_candidates = searchBaseLayerST(currObj, query_data, 0, isIdAllowed, &stop_condition); + + size_t sz = top_candidates.size(); + result.resize(sz); + while (!top_candidates.empty()) { + result[--sz] = top_candidates.top(); + top_candidates.pop(); + } + + stop_condition.filter_results(result); + + return result; + } + + void checkIntegrity() { int connections_checked = 0; std::vector inbound_connections_num(cur_element_count, 0); diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index 402c6d8c..7ccfbba5 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -1,4 +1,13 @@ #pragma once + +// https://github.com/nmslib/hnswlib/pull/508 +// This allows others to provide their own error stream (e.g. 
RcppHNSW) +#ifndef HNSWLIB_ERR_OVERRIDE + #define HNSWERR std::cerr +#else + #define HNSWERR HNSWLIB_ERR_OVERRIDE +#endif + #ifndef NO_MANUAL_VECTORIZATION #if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64)) #define USE_SSE @@ -122,6 +131,24 @@ class BaseFilterFunctor { virtual ~BaseFilterFunctor() {}; }; +template +class BaseSearchStopCondition { + public: + virtual void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) = 0; + + virtual void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) = 0; + + virtual bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) = 0; + + virtual bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) = 0; + + virtual bool should_remove_extra() = 0; + + virtual void filter_results(std::vector> &candidates) = 0; + + virtual ~BaseSearchStopCondition() {} +}; + template class pairGreater { public: @@ -196,5 +223,6 @@ AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t #include "space_l2.h" #include "space_ip.h" +#include "stop_condition.h" #include "bruteforce.h" #include "hnswalg.h" diff --git a/hnswlib/stop_condition.h b/hnswlib/stop_condition.h new file mode 100644 index 00000000..acc80ebe --- /dev/null +++ b/hnswlib/stop_condition.h @@ -0,0 +1,276 @@ +#pragma once +#include "space_l2.h" +#include "space_ip.h" +#include +#include + +namespace hnswlib { + +template +class BaseMultiVectorSpace : public SpaceInterface { + public: + virtual DOCIDTYPE get_doc_id(const void *datapoint) = 0; + + virtual void set_doc_id(void *datapoint, DOCIDTYPE doc_id) = 0; +}; + + +template +class MultiVectorL2Space : public BaseMultiVectorSpace { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t vector_size_; + size_t dim_; + + public: + MultiVectorL2Space(size_t dim) { + fstdistfunc_ = L2Sqr; +#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) + #if defined(USE_AVX512) + if (AVX512Capable()) + L2SqrSIMD16Ext 
= L2SqrSIMD16ExtAVX512; + else if (AVXCapable()) + L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; + #elif defined(USE_AVX) + if (AVXCapable()) + L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; + #endif + + if (dim % 16 == 0) + fstdistfunc_ = L2SqrSIMD16Ext; + else if (dim % 4 == 0) + fstdistfunc_ = L2SqrSIMD4Ext; + else if (dim > 16) + fstdistfunc_ = L2SqrSIMD16ExtResiduals; + else if (dim > 4) + fstdistfunc_ = L2SqrSIMD4ExtResiduals; +#endif + dim_ = dim; + vector_size_ = dim * sizeof(float); + data_size_ = vector_size_ + sizeof(DOCIDTYPE); + } + + size_t get_data_size() override { + return data_size_; + } + + DISTFUNC get_dist_func() override { + return fstdistfunc_; + } + + void *get_dist_func_param() override { + return &dim_; + } + + DOCIDTYPE get_doc_id(const void *datapoint) override { + return *(DOCIDTYPE *)((char *)datapoint + vector_size_); + } + + void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override { + *(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id; + } + + ~MultiVectorL2Space() {} +}; + + +template +class MultiVectorInnerProductSpace : public BaseMultiVectorSpace { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t vector_size_; + size_t dim_; + + public: + MultiVectorInnerProductSpace(size_t dim) { + fstdistfunc_ = InnerProductDistance; +#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512) + #if defined(USE_AVX512) + if (AVX512Capable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512; + } else if (AVXCapable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; + } + #elif defined(USE_AVX) + if (AVXCapable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; + } + #endif + #if defined(USE_AVX) + if (AVXCapable()) { + InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX; + InnerProductDistanceSIMD4Ext = 
InnerProductDistanceSIMD4ExtAVX; + } + #endif + + if (dim % 16 == 0) + fstdistfunc_ = InnerProductDistanceSIMD16Ext; + else if (dim % 4 == 0) + fstdistfunc_ = InnerProductDistanceSIMD4Ext; + else if (dim > 16) + fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals; + else if (dim > 4) + fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals; +#endif + vector_size_ = dim * sizeof(float); + data_size_ = vector_size_ + sizeof(DOCIDTYPE); + } + + size_t get_data_size() override { + return data_size_; + } + + DISTFUNC get_dist_func() override { + return fstdistfunc_; + } + + void *get_dist_func_param() override { + return &dim_; + } + + DOCIDTYPE get_doc_id(const void *datapoint) override { + return *(DOCIDTYPE *)((char *)datapoint + vector_size_); + } + + void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override { + *(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id; + } + + ~MultiVectorInnerProductSpace() {} +}; + + +template +class MultiVectorSearchStopCondition : public BaseSearchStopCondition { + size_t curr_num_docs_; + size_t num_docs_to_search_; + size_t ef_collection_; + std::unordered_map doc_counter_; + std::priority_queue> search_results_; + BaseMultiVectorSpace& space_; + + public: + MultiVectorSearchStopCondition( + BaseMultiVectorSpace& space, + size_t num_docs_to_search, + size_t ef_collection = 10) + : space_(space) { + curr_num_docs_ = 0; + num_docs_to_search_ = num_docs_to_search; + ef_collection_ = std::max(ef_collection, num_docs_to_search); + } + + void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override { + DOCIDTYPE doc_id = space_.get_doc_id(datapoint); + if (doc_counter_[doc_id] == 0) { + curr_num_docs_ += 1; + } + search_results_.emplace(dist, doc_id); + doc_counter_[doc_id] += 1; + } + + void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override { + DOCIDTYPE doc_id = space_.get_doc_id(datapoint); + doc_counter_[doc_id] -= 1; + if (doc_counter_[doc_id] == 0) { + 
curr_num_docs_ -= 1; + } + search_results_.pop(); + } + + bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override { + bool stop_search = candidate_dist > lowerBound && curr_num_docs_ == ef_collection_; + return stop_search; + } + + bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override { + bool flag_consider_candidate = curr_num_docs_ < ef_collection_ || lowerBound > candidate_dist; + return flag_consider_candidate; + } + + bool should_remove_extra() override { + bool flag_remove_extra = curr_num_docs_ > ef_collection_; + return flag_remove_extra; + } + + void filter_results(std::vector> &candidates) override { + while (curr_num_docs_ > num_docs_to_search_) { + dist_t dist_cand = candidates.back().first; + dist_t dist_res = search_results_.top().first; + assert(dist_cand == dist_res); + DOCIDTYPE doc_id = search_results_.top().second; + doc_counter_[doc_id] -= 1; + if (doc_counter_[doc_id] == 0) { + curr_num_docs_ -= 1; + } + search_results_.pop(); + candidates.pop_back(); + } + } + + ~MultiVectorSearchStopCondition() {} +}; + + +template +class EpsilonSearchStopCondition : public BaseSearchStopCondition { + float epsilon_; + size_t min_num_candidates_; + size_t max_num_candidates_; + size_t curr_num_items_; + + public: + EpsilonSearchStopCondition(float epsilon, size_t min_num_candidates, size_t max_num_candidates) { + assert(min_num_candidates <= max_num_candidates); + epsilon_ = epsilon; + min_num_candidates_ = min_num_candidates; + max_num_candidates_ = max_num_candidates; + curr_num_items_ = 0; + } + + void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override { + curr_num_items_ += 1; + } + + void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override { + curr_num_items_ -= 1; + } + + bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override { + if (candidate_dist > lowerBound && curr_num_items_ == max_num_candidates_) { + // new 
candidate can't improve found results + return true; + } + if (candidate_dist > epsilon_ && curr_num_items_ >= min_num_candidates_) { + // new candidate is out of epsilon region and + // minimum number of candidates is checked + return true; + } + return false; + } + + bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override { + bool flag_consider_candidate = curr_num_items_ < max_num_candidates_ || lowerBound > candidate_dist; + return flag_consider_candidate; + } + + bool should_remove_extra() { + bool flag_remove_extra = curr_num_items_ > max_num_candidates_; + return flag_remove_extra; + } + + void filter_results(std::vector> &candidates) override { + while (!candidates.empty() && candidates.back().first > epsilon_) { + candidates.pop_back(); + } + while (candidates.size() > max_num_candidates_) { + candidates.pop_back(); + } + } + + ~EpsilonSearchStopCondition() {} +}; +} // namespace hnswlib diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index 9ed23655..dd09e80a 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -304,7 +304,11 @@ class Index { } - std::vector> getDataReturnList(py::object ids_ = py::none()) { + py::object getData(py::object ids_ = py::none(), std::string return_type = "numpy") { + std::vector return_types{"numpy", "list"}; + if (std::find(std::begin(return_types), std::end(return_types), return_type) == std::end(return_types)) { + throw std::invalid_argument("return_type should be \"numpy\" or \"list\""); + } std::vector ids; if (!ids_.is_none()) { py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_); @@ -325,7 +329,12 @@ class Index { for (auto id : ids) { data.push_back(appr_alg->template getDataByLabel(id)); } - return data; + if (return_type == "list") { + return py::cast(data); + } + if (return_type == "numpy") { + return py::array_t< data_t, py::array::c_style | py::array::forcecast >(py::cast(data)); + } } @@ -636,7 +645,7 @@ class 
Index { (void*)items.data(row), k, p_idFilter); if (result.size() != k) throw std::runtime_error( - "Cannot return the results in a contigious 2D array. Probably ef or M is too small"); + "Cannot return the results in a contiguous 2D array. Probably ef or M is too small"); for (int i = k - 1; i >= 0; i--) { auto& result_tuple = result.top(); data_numpy_d[row * k + i] = result_tuple.first; @@ -656,7 +665,7 @@ class Index { (void*)(norm_array.data() + start_idx), k, p_idFilter); if (result.size() != k) throw std::runtime_error( - "Cannot return the results in a contigious 2D array. Probably ef or M is too small"); + "Cannot return the results in a contiguous 2D array. Probably ef or M is too small"); for (int i = k - 1; i >= 0; i--) { auto& result_tuple = result.top(); data_numpy_d[row * k + i] = result_tuple.first; @@ -925,7 +934,7 @@ PYBIND11_PLUGIN(hnswlib) { py::arg("ids") = py::none(), py::arg("num_threads") = -1, py::arg("replace_deleted") = false) - .def("get_items", &Index::getDataReturnList, py::arg("ids") = py::none()) + .def("get_items", &Index::getData, py::arg("ids") = py::none(), py::arg("return_type") = "numpy") .def("get_ids_list", &Index::getIdsList) .def("set_ef", &Index::set_ef, py::arg("ef")) .def("set_num_threads", &Index::set_num_threads, py::arg("num_threads")) diff --git a/setup.py b/setup.py index c630e876..062b66ce 100644 --- a/setup.py +++ b/setup.py @@ -73,15 +73,19 @@ def cpp_flag(compiler): class BuildExt(build_ext): """A custom build extension for adding compiler-specific options.""" + compiler_flag_native = '-march=native' c_opts = { 'msvc': ['/EHsc', '/openmp', '/O2'], - 'unix': ['-O3', '-march=native'], # , '-w' + 'unix': ['-O3', compiler_flag_native], # , '-w' } link_opts = { 'unix': [], 'msvc': [], } + if os.environ.get("HNSWLIB_NO_NATIVE"): + c_opts['unix'].remove(compiler_flag_native) + if sys.platform == 'darwin': c_opts['unix'] += ['-stdlib=libc++', '-mmacosx-version-min=10.7'] link_opts['unix'] += ['-stdlib=libc++', 
'-mmacosx-version-min=10.7'] @@ -91,35 +95,35 @@ class BuildExt(build_ext): def build_extensions(self): ct = self.compiler.compiler_type - opts = self.c_opts.get(ct, []) + opts = BuildExt.c_opts.get(ct, []) if ct == 'unix': opts.append('-DVERSION_INFO="%s"' % self.distribution.get_version()) opts.append(cpp_flag(self.compiler)) if has_flag(self.compiler, '-fvisibility=hidden'): opts.append('-fvisibility=hidden') - # check that native flag is available - native_flag = '-march=native' - print('checking avalability of flag:', native_flag) - if not has_flag(self.compiler, native_flag): - print('removing unsupported compiler flag:', native_flag) - opts.remove(native_flag) - # for macos add apple-m1 flag if it's available - if sys.platform == 'darwin': - m1_flag = '-mcpu=apple-m1' - print('checking avalability of flag:', m1_flag) - if has_flag(self.compiler, m1_flag): - print('adding flag:', m1_flag) - opts.append(m1_flag) - else: - print(f'flag: {m1_flag} is not available') - else: - print(f'flag: {native_flag} is available') + if not os.environ.get("HNSWLIB_NO_NATIVE"): + # check that native flag is available + print('checking avalability of flag:', BuildExt.compiler_flag_native) + if not has_flag(self.compiler, BuildExt.compiler_flag_native): + print('removing unsupported compiler flag:', BuildExt.compiler_flag_native) + opts.remove(BuildExt.compiler_flag_native) + # for macos add apple-m1 flag if it's available + if sys.platform == 'darwin': + m1_flag = '-mcpu=apple-m1' + print('checking avalability of flag:', m1_flag) + if has_flag(self.compiler, m1_flag): + print('adding flag:', m1_flag) + opts.append(m1_flag) + else: + print(f'flag: {m1_flag} is not available') + else: + print(f'flag: {BuildExt.compiler_flag_native} is available') elif ct == 'msvc': opts.append('/DVERSION_INFO=\\"%s\\"' % self.distribution.get_version()) for ext in self.extensions: ext.extra_compile_args.extend(opts) - ext.extra_link_args.extend(self.link_opts.get(ct, [])) + 
ext.extra_link_args.extend(BuildExt.link_opts.get(ct, [])) build_ext.build_extensions(self)
diff --git a/tests/cpp/epsilon_search_test.cpp b/tests/cpp/epsilon_search_test.cpp
new file mode 100644
index 00000000..38df6246
--- /dev/null
+++ b/tests/cpp/epsilon_search_test.cpp
@@ -0,0 +1,117 @@
+#include "assert.h"
+#include "../../hnswlib/hnswlib.h"
+
+typedef unsigned int docidtype;
+typedef float dist_t;
+
+// Checks epsilon search on an HNSW index against a brute-force index:
+// every returned distance must be within [0, epsilon2], recall must exceed 0.95,
+// and querying with a stored element must return a closest distance of 0.
+int main() {
+    int dim = 16;               // Dimension of the elements
+    int max_elements = 10000;   // Maximum number of elements, should be known beforehand
+    int M = 16;                 // Tightly connected with internal dimensionality of the data
+                                // strongly affects the memory consumption
+    int ef_construction = 200;  // Controls index search speed/build speed tradeoff
+
+    int num_queries = 100;
+    float epsilon2 = 1.0;                   // Squared distance to query
+    int max_num_candidates = max_elements;  // Upper bound on the number of returned elements in the epsilon region
+    int min_num_candidates = 2000;          // Minimum number of candidates to search in the epsilon region
+                                            // this parameter is similar to ef
+
+    // Initing index
+    hnswlib::L2Space space(dim);
+    hnswlib::BruteforceSearch<dist_t>* alg_brute = new hnswlib::BruteforceSearch<dist_t>(&space, max_elements);
+    hnswlib::HierarchicalNSW<dist_t>* alg_hnsw = new hnswlib::HierarchicalNSW<dist_t>(&space, max_elements, M, ef_construction);
+
+    // Generate random data
+    std::mt19937 rng;
+    rng.seed(47);
+    std::uniform_real_distribution<> distrib_real;
+
+    float* data = new float[dim * max_elements];
+    for (int i = 0; i < dim * max_elements; i++) {
+        data[i] = distrib_real(rng);
+    }
+
+    // Add data to index
+    std::cout << "Building index ...\n";
+    for (int i = 0; i < max_elements; i++) {
+        hnswlib::labeltype label = i;
+        float* point_data = data + i * dim;
+        alg_hnsw->addPoint(point_data, label);
+        alg_brute->addPoint(point_data, label);
+    }
+    std::cout << "Index is ready\n";
+
+    // Query random vectors
+    for (int i = 0; i < num_queries; i++) {
+        // Vector-owned buffer: released on every path out of the iteration,
+        // including the early `continue` below.
+        std::vector<float> query_data(dim);
+        for (int j = 0; j < dim; j++) {
+            query_data[j] = distrib_real(rng);
+        }
+        hnswlib::EpsilonSearchStopCondition<dist_t> stop_condition(epsilon2, min_num_candidates, max_num_candidates);
+        std::vector<std::pair<float, hnswlib::labeltype>> result_hnsw =
+            alg_hnsw->searchStopConditionClosest(query_data.data(), stop_condition);
+
+        // check that returned results are in epsilon region
+        std::unordered_set<hnswlib::labeltype> hnsw_labels;
+        for (const auto& pair : result_hnsw) {
+            float dist = pair.first;
+            hnswlib::labeltype label = pair.second;
+            hnsw_labels.insert(label);
+            assert(dist >= 0 && dist <= epsilon2);
+        }
+        std::priority_queue<std::pair<float, hnswlib::labeltype>> result_brute =
+            alg_brute->searchKnn(query_data.data(), max_elements);
+
+        // check recall
+        std::unordered_set<hnswlib::labeltype> gt_labels;
+        while (!result_brute.empty()) {
+            float dist = result_brute.top().first;
+            hnswlib::labeltype label = result_brute.top().second;
+            if (dist < epsilon2) {
+                gt_labels.insert(label);
+            }
+            result_brute.pop();
+        }
+        float correct = 0;
+        for (const auto& hnsw_label : hnsw_labels) {
+            if (gt_labels.find(hnsw_label) != gt_labels.end()) {
+                correct += 1;
+            }
+        }
+        if (gt_labels.empty()) {
+            // Empty ground truth: recall would divide by zero, so skip this query
+            assert(correct == 0);
+            continue;
+        }
+        float recall = correct / gt_labels.size();
+        assert(recall > 0.95);
+    }
+    std::cout << "Recall is OK\n";
+
+    // Query the elements for themselves and check that query can be found
+    float epsilon2_small = 0.0001f;
+    int min_candidates_small = 500;
+    for (size_t i = 0; i < static_cast<size_t>(max_elements); i++) {
+        hnswlib::EpsilonSearchStopCondition<dist_t> stop_condition(epsilon2_small, min_candidates_small, max_num_candidates);
+        std::vector<std::pair<float, hnswlib::labeltype>> result =
+            alg_hnsw->searchStopConditionClosest(alg_hnsw->getDataByInternalId(i), stop_condition);
+        // get closest distance
+        float dist = -1;
+        if (!result.empty()) {
+            dist = result[0].first;
+        }
+        assert(dist == 0);
+    }
+    std::cout << "Small epsilon search is OK\n";
+
+    delete[] data;
+    delete alg_brute;
+    delete alg_hnsw;
+    return 0;
+}
diff --git
a/tests/cpp/multivector_search_test.cpp b/tests/cpp/multivector_search_test.cpp
new file mode 100644
index 00000000..be783176
--- /dev/null
+++ b/tests/cpp/multivector_search_test.cpp
@@ -0,0 +1,130 @@
+#include <assert.h>
+#include "../../hnswlib/hnswlib.h"
+
+typedef unsigned int docidtype;
+typedef float dist_t;
+
+// Checks multi-vector search on an HNSW index against a brute-force index:
+// each random query must return labels covering exactly num_docs distinct
+// document ids, and label recall must exceed 0.95 for random queries and
+// 0.99 when the stored elements themselves are used as queries.
+int main() {
+    int dim = 16;               // Dimension of the elements
+    int max_elements = 1000;    // Maximum number of elements, should be known beforehand
+    int M = 16;                 // Tightly connected with internal dimensionality of the data
+                                // strongly affects the memory consumption
+    int ef_construction = 200;  // Controls index search speed/build speed tradeoff
+
+    int num_queries = 100;
+    int num_docs = 10;          // Number of documents to search
+    int ef_collection = 15;     // Number of candidate documents during the search
+                                // Controls the recall: higher ef leads to better accuracy, but slower search
+    docidtype min_doc_id = 0;
+    docidtype max_doc_id = 49;
+
+    // Initing index
+    hnswlib::MultiVectorL2Space<docidtype> space(dim);
+    hnswlib::BruteforceSearch<dist_t>* alg_brute = new hnswlib::BruteforceSearch<dist_t>(&space, max_elements);
+    hnswlib::HierarchicalNSW<dist_t>* alg_hnsw = new hnswlib::HierarchicalNSW<dist_t>(&space, max_elements, M, ef_construction);
+
+    // Generate random data
+    std::mt19937 rng;
+    rng.seed(47);
+    std::uniform_real_distribution<> distrib_real;
+    std::uniform_int_distribution<docidtype> distrib_docid(min_doc_id, max_doc_id);
+
+    size_t data_point_size = space.get_data_size();
+    char* data = new char[data_point_size * max_elements];
+    for (int i = 0; i < max_elements; i++) {
+        // set vector value
+        char* point_data = data + i * data_point_size;
+        for (int j = 0; j < dim; j++) {
+            char* vec_data = point_data + j * sizeof(float);
+            float value = distrib_real(rng);
+            *(float*)vec_data = value;
+        }
+        // set document id
+        docidtype doc_id = distrib_docid(rng);
+        space.set_doc_id(point_data, doc_id);
+    }
+
+    // Add data to index
+    std::unordered_map<hnswlib::labeltype, docidtype> label_docid_lookup;
+    for (int i = 0; i < max_elements; i++) {
+        hnswlib::labeltype label = i;
+        char* point_data = data + i * data_point_size;
+        alg_hnsw->addPoint(point_data, label);
+        alg_brute->addPoint(point_data, label);
+        label_docid_lookup[label] = space.get_doc_id(point_data);
+    }
+
+    // Query random vectors and check overall recall
+    float correct = 0;
+    float total_num_elements = 0;
+    size_t query_size = dim * sizeof(float);
+    for (int i = 0; i < num_queries; i++) {
+        char* query_data = new char[query_size];
+        for (int j = 0; j < dim; j++) {
+            size_t offset = j * sizeof(float);
+            char* vec_data = query_data + offset;
+            float value = distrib_real(rng);
+            *(float*)vec_data = value;
+        }
+        hnswlib::MultiVectorSearchStopCondition<docidtype, dist_t> stop_condition(space, num_docs, ef_collection);
+        std::vector<std::pair<float, hnswlib::labeltype>> hnsw_results =
+            alg_hnsw->searchStopConditionClosest(query_data, stop_condition);
+
+        // check number of found documents
+        std::unordered_set<docidtype> hnsw_docs;
+        std::unordered_set<hnswlib::labeltype> hnsw_labels;
+        for (const auto& pair : hnsw_results) {
+            hnswlib::labeltype label = pair.second;
+            hnsw_labels.emplace(label);
+            docidtype doc_id = label_docid_lookup[label];
+            hnsw_docs.emplace(doc_id);
+        }
+        assert(hnsw_docs.size() == static_cast<size_t>(num_docs));
+
+        // Check overall recall
+        std::vector<std::pair<float, hnswlib::labeltype>> gt_results =
+            alg_brute->searchKnnCloserFirst(query_data, max_elements);
+        std::unordered_set<docidtype> gt_docs;
+        // `k` rather than `i`: keeps the outer query index unshadowed
+        for (size_t k = 0; k < gt_results.size(); k++) {
+            if (gt_docs.size() == static_cast<size_t>(num_docs)) {
+                break;
+            }
+            hnswlib::labeltype gt_label = gt_results[k].second;
+            if (hnsw_labels.find(gt_label) != hnsw_labels.end()) {
+                correct += 1;
+            }
+            docidtype gt_doc_id = label_docid_lookup[gt_label];
+            gt_docs.emplace(gt_doc_id);
+            total_num_elements += 1;
+        }
+        delete[] query_data;
+    }
+    float recall = correct / total_num_elements;
+    std::cout << "random elements search recall : " << recall << "\n";
+    assert(recall > 0.95);
+
+    // Query the elements for themselves and measure recall
+    correct = 0;
+    for (int i = 0; i < max_elements; i++) {
+        hnswlib::MultiVectorSearchStopCondition<docidtype, dist_t> stop_condition(space, num_docs, ef_collection);
+        std::vector<std::pair<float, hnswlib::labeltype>> result =
+            alg_hnsw->searchStopConditionClosest(data + i * data_point_size, stop_condition);
+        // An empty result never counts as a hit
+        if (!result.empty() && result[0].second == static_cast<hnswlib::labeltype>(i)) {
+            correct++;
+        }
+    }
+    recall = correct / max_elements;
+    std::cout << "same elements search recall : " << recall << "\n";
+    assert(recall > 0.99);
+
+    delete[] data;
+    delete alg_brute;
+    delete alg_hnsw;
+    return 0;
+}
diff --git a/tests/python/bindings_test_getdata.py b/tests/python/bindings_test_getdata.py index 515ecebd..3e16f9b9 100644 --- a/tests/python/bindings_test_getdata.py +++ b/tests/python/bindings_test_getdata.py @@ -45,5 +45,11 @@ def testGettingItems(self): self.assertRaises(ValueError, lambda: p.get_items(labels[0])) # After adding them, all labels should be retrievable - returned_items = p.get_items(labels) - self.assertSequenceEqual(data.tolist(), returned_items) + returned_items_np = p.get_items(labels) + self.assertTrue((data == returned_items_np).all()) + + # check returned type of get_items + self.assertTrue(isinstance(returned_items_np, np.ndarray)) + returned_items_list = p.get_items(labels, return_type="list") + self.assertTrue(isinstance(returned_items_list, list)) + self.assertTrue(isinstance(returned_items_list[0], list)) diff --git a/tests/python/bindings_test_replace.py b/tests/python/bindings_test_replace.py index 80003a3a..09c1299e 100644 --- a/tests/python/bindings_test_replace.py +++ b/tests/python/bindings_test_replace.py @@ -94,10 +94,10 @@ def testRandomSelf(self): remaining_data = comb_data[remaining_labels_list] returned_items = hnsw_index.get_items(remaining_labels_list) - self.assertSequenceEqual(remaining_data.tolist(), returned_items) + self.assertTrue((remaining_data == returned_items).all()) returned_items = hnsw_index.get_items(labels3_tr) - self.assertSequenceEqual(data3_tr.tolist(), returned_items) + self.assertTrue((data3_tr == returned_items).all()) # Check index serialization #
Delete batch 3