Skip to content

Commit

Permalink
[hotfix] Quantization patch; fix semantic_search_faiss/semantic_sea…
Browse files Browse the repository at this point in the history
…rch_usearch rescoring (#2558)

* Correctly update the scores after rescoring

* Sort scores from high to low in recommended example
  • Loading branch information
tomaarsen authored Mar 26, 2024
1 parent 85810ea commit a46251f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def search(query, top_k: int = 10, rescore_multiplier: int = 4):

# 6. Sort the scores and return the top_k
start_time = time.time()
indices = scores.argsort()[:top_k]
indices = (-scores).argsort()[:top_k]
top_k_indices = binary_ids[indices]
top_k_scores = scores[indices]
sort_time = time.time() - start_time
Expand Down
4 changes: 2 additions & 2 deletions sentence_transformers/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def semantic_search_faiss(
rescored_scores = np.einsum("ij,ikj->ik", rescore_embeddings, top_k_embeddings)
rescored_indices = np.argsort(-rescored_scores)[:, :top_k]
indices = indices[np.arange(len(query_embeddings))[:, None], rescored_indices]
scores = rescored_scores[:, :top_k]
scores = rescored_scores[np.arange(len(query_embeddings))[:, None], rescored_indices]

delta_t = time.time() - start_t

Expand Down Expand Up @@ -293,7 +293,7 @@ def semantic_search_usearch(
rescored_scores = np.einsum("ij,ikj->ik", rescore_embeddings, top_k_embeddings)
rescored_indices = np.argsort(-rescored_scores)[:, :top_k]
indices = indices[np.arange(len(query_embeddings))[:, None], rescored_indices]
scores = rescored_scores[:, :top_k]
scores = rescored_scores[np.arange(len(query_embeddings))[:, None], rescored_indices]

delta_t = time.time() - start_t

Expand Down

0 comments on commit a46251f

Please sign in to comment.