Skip to content

Commit

Permalink
Add BaseKnnVectorsFormatTestCase.testRecall() and fix old codecs (apa…
Browse files Browse the repository at this point in the history
…che#13910)

* Add BaseKnnVectorsFormatTestCase.testRecall() and fix map ord to doc in Lucene90HnswVectorsReader
  • Loading branch information
msokolov authored Oct 17, 2024
1 parent 1faf33a commit 3983fa2
Show file tree
Hide file tree
Showing 8 changed files with 886 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
int node = results.topNode();
float minSimilarity = results.topScore();
results.pop();
knnCollector.collect(node, minSimilarity);
knnCollector.collect(vectorValues.ordToDoc(node), minSimilarity);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ private void popToScratch(HnswGraphBuilder.GraphBuilderKnnCollector candidates)
// extract all the Neighbors from the queue into an array; these will now be
// sorted from worst to best
for (int i = 0; i < candidateCount; i++) {
float similarity = candidates.minCompetitiveSimilarity();
float similarity = candidates.minimumScore();
scratch.add(candidates.popNode(), similarity);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ public int[] popUntilNearestKNodes() {
return queue.nodes();
}

float minimumScore() {
public float minimumScore() {
return queue.topScore();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ public void testSearch() throws Exception {
}
}

@Override
public void testRecall() {
// Intentionally a no-op override of the inherited recall test: this class always returns no
// results from search, so there is nothing to compare against exact search.
}

public void testQuantizedVectorsWriteAndRead() throws Exception {
// create lucene directory with codec
int numVectors = 1 + random().nextInt(50);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,6 @@ private void add(
String idString = Integer.toString(id);
doc.add(new StringField("id", idString, Field.Store.YES));
doc.add(new SortedDocValuesField("id", new BytesRef(idString)));
// XSSystem.out.println("add " + idString + " " + Arrays.toString(vector));
iw.updateDocument(new Term("id", idString), doc);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,21 @@
*/
package org.apache.lucene.tests.index;

import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
Expand Down Expand Up @@ -70,6 +76,10 @@
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
Expand Down Expand Up @@ -1906,4 +1916,168 @@ public void testMismatchedFields() throws Exception {

IOUtils.close(reader, w2, dir1, dir2);
}

/**
 * Checks, for each similarity function listed below, that approximate kNN search is a viable
 * stand-in for exact search. The bounds passed to {@link #assertRecall} are deliberately loose:
 * the goal is to catch gross failures, not to pin down the true expected recall.
 */
public void testRecall() throws IOException {
  List<VectorSimilarityFunction> similarities =
      List.of(
          VectorSimilarityFunction.EUCLIDEAN,
          VectorSimilarityFunction.COSINE,
          VectorSimilarityFunction.DOT_PRODUCT,
          VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT);
  for (VectorSimilarityFunction fn : similarities) {
    // at least half, and at most all, of the exact top-k hits must be recalled
    assertRecall(fn, 0.5, 1.0);
  }
}

/**
 * Builds the known index with the given similarity function, runs a fixed set of text-derived
 * queries through both the approximate kNN query and an exact (filter-forced) kNN query, and
 * asserts that the number of exact top-k hits that the approximate search also returned, summed
 * over all queries, lies within the requested bounds.
 *
 * @param similarity similarity function used when indexing the vectors
 * @param min lower bound, as a fraction of {@code testQueries.length * topK}; the resulting
 *     threshold is truncated to an int
 * @param max upper bound, as a fraction of {@code testQueries.length * topK}; the resulting
 *     threshold is truncated to an int
 * @throws IOException if building, opening or searching the index fails
 */
protected void assertRecall(VectorSimilarityFunction similarity, double min, double max)
    throws IOException {
  int dim = 16;
  int topK = 10;
  // indexed 421 lines from LICENSE.txt
  // indexed 157 lines from NOTICE.txt
  // => 578 documents carry a vector
  int numVectorDocs = 578;
  int recalled = 0;
  try (Directory indexStore = getKnownIndexStore("field", dim, similarity);
      IndexReader reader = DirectoryReader.open(indexStore)) {
    IndexSearcher searcher = newSearcher(reader);
    // reused for every query; computeLineEmbedding overwrites it in place
    float[] queryEmbedding = new float[dim];
    String[] testQueries = {
      "Apache Lucene",
      "Apache License",
      "TERMS AND CONDITIONS",
      "Copyright 2001",
      "Permission is hereby",
      "Copyright © 2003",
      "The dictionary comes from Morfologik project",
      "The levenshtein automata tables"
    };
    for (String queryString : testQueries) {
      computeLineEmbedding(queryString, queryEmbedding);

      // pass match-all "filter" to force full traversal, bypassing graph
      KnnFloatVectorQuery exactQuery =
          new KnnFloatVectorQuery("field", queryEmbedding, 1000, new MatchAllDocsQuery());
      // k (1000) exceeds the number of vectors, so every vector document must match
      assertEquals(numVectorDocs, searcher.count(exactQuery));

      KnnFloatVectorQuery query = new KnnFloatVectorQuery("field", queryEmbedding, topK);
      // Expect a full page of results (no early termination / timeout)
      assertEquals(topK, searcher.count(query));

      TopDocs results = searcher.search(query, topK);
      printRecallHits("result", reader, results.scoreDocs);
      Set<Integer> resultDocs = new HashSet<>();
      for (ScoreDoc scoreDoc : results.scoreDocs) {
        resultDocs.add(scoreDoc.doc);
      }

      TopDocs expected = searcher.search(exactQuery, topK);
      printRecallHits("expected", reader, expected.scoreDocs);
      for (ScoreDoc scoreDoc : expected.scoreDocs) {
        if (resultDocs.contains(scoreDoc.doc)) {
          ++recalled;
        }
      }
    }

    int totalResults = testQueries.length * topK;
    // Compute each threshold once so the failure message reports exactly what is compared.
    int minRecalled = (int) (totalResults * min);
    int maxRecalled = (int) (totalResults * max);
    assertTrue(
        "Average recall for "
            + similarity
            + " should be at least "
            + minRecalled
            + " / "
            + totalResults
            + ", got "
            + recalled,
        recalled >= minRecalled);
    assertTrue(
        "Average recall for "
            + similarity
            + " should be no more than "
            + maxRecalled
            + " / "
            + totalResults
            + ", got "
            + recalled,
        recalled <= maxRecalled);
  }
}

/** When {@code VERBOSE} is set, prints each hit with its rank, stored document and score doc. */
private void printRecallHits(String label, IndexReader reader, ScoreDoc[] hits)
    throws IOException {
  if (VERBOSE == false) {
    return;
  }
  for (int i = 0; i < hits.length; i++) {
    System.out.println(
        label + " " + i + ": " + reader.storedFields().document(hits[i].doc) + " " + hits[i]);
  }
}

/**
 * Creates a new directory and indexes the non-blank lines of the {@code LICENSE.txt} and {@code
 * NOTICE.txt} classpath resources into it. Each such line becomes one document holding the
 * line's embedding in {@code field}, the stripped text in a stored {@code text} field and an
 * {@code id} of the form {@code <file>.<lineNo>}. Documents without a vector are randomly
 * interleaved, and five documents with neither vector nor id are appended at the end
 * (presumably so that vector ordinals do not line up with doc ids -- confirm).
 *
 * <p>The caller owns the returned directory and must close it. If indexing fails, the directory
 * is closed here before the exception propagates.
 *
 * @param field name of the kNN vector field
 * @param dimension dimension of the indexed vectors
 * @param vectorSimilarityFunction similarity function the vector field is indexed with
 * @return the directory containing the committed index
 * @throws IOException if reading the resources or writing the index fails
 * @throws IllegalStateException if one of the text resources cannot be found
 */
Directory getKnownIndexStore(
    String field, int dimension, VectorSimilarityFunction vectorSimilarityFunction)
    throws IOException {
  Directory indexStore = newDirectory(random());
  try {
    // try-with-resources guarantees the writer is closed even when indexing throws
    try (IndexWriter writer = new IndexWriter(indexStore, newIndexWriterConfig())) {
      // reused for every line; each document is added before the next line overwrites it
      float[] scratch = new float[dimension];
      for (String file : List.of("LICENSE.txt", "NOTICE.txt")) {
        InputStream in = BaseKnnVectorsFormatTestCase.class.getResourceAsStream(file);
        if (in == null) {
          // getResourceAsStream signals "not found" with null; fail with a readable message
          throw new IllegalStateException("test resource not found on classpath: " + file);
        }
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(in, UTF_8))) {
          String line;
          int lineNo = -1; // counts non-blank lines only, restarting for each file
          while ((line = reader.readLine()) != null) {
            line = line.strip();
            if (line.isEmpty()) {
              continue;
            }
            ++lineNo;
            Document doc = new Document();
            doc.add(
                new KnnFloatVectorField(
                    field, computeLineEmbedding(line, scratch), vectorSimilarityFunction));
            doc.add(new StoredField("text", line));
            doc.add(new StringField("id", file + "." + lineNo, Field.Store.YES));
            writer.addDocument(doc);
            if (random().nextBoolean()) {
              // Add some documents without a vector
              addDocuments(writer, "id" + lineNo + ".", randomIntBetween(1, 5));
            }
          }
        }
      }
      // Add some documents without a vector nor an id
      addDocuments(writer, null, 5);
    }
    return indexStore;
  } catch (Throwable t) {
    // do not leak the directory when the index could not be built
    IOUtils.closeWhileHandlingException(indexStore);
    throw t;
  }
}

/**
 * Overwrites {@code vector} with an embedding derived from the characters of {@code line} and
 * returns that same array. Each character is added to the slot {@code position % vector.length},
 * divided by {@code (position + 1) / vector.length}, so earlier characters weigh more. The
 * result is then passed through {@code VectorUtil.l2normalize(vector, false)}.
 *
 * @param line text to embed
 * @param vector output buffer; its previous contents are discarded
 * @return {@code vector}, for call chaining
 */
private float[] computeLineEmbedding(String line, float[] vector) {
  final int dim = vector.length;
  final int length = line.length();
  Arrays.fill(vector, 0f);
  for (int pos = 0; pos < length; pos++) {
    float divisor = (float) (pos + 1) / dim;
    vector[pos % dim] += line.charAt(pos) / divisor;
  }
  VectorUtil.l2normalize(vector, false);
  return vector;
}

/**
 * Adds {@code count} documents that carry no vector field. Every document gets an unstored
 * {@code other=value} field; when {@code idBase} is non-null it also gets a stored {@code id}
 * field whose value is {@code idBase} followed by the document's index within this call.
 *
 * @param writer writer the documents are added to
 * @param idBase prefix for the generated ids, or {@code null} to add documents without an id
 * @param count number of documents to add
 * @throws IOException if the writer fails to add a document
 */
private void addDocuments(IndexWriter writer, String idBase, int count) throws IOException {
  final boolean withId = idBase != null;
  int n = 0;
  while (n < count) {
    Document filler = new Document();
    filler.add(new StringField("other", "value", Field.Store.NO));
    if (withId) {
      filler.add(new StringField("id", idBase + n, Field.Store.YES));
    }
    writer.addDocument(filler);
    n++;
  }
}
}
Loading

0 comments on commit 3983fa2

Please sign in to comment.