Skip to content

Commit

Permalink
LUCENE-9614: add KnnVectorQuery implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
msokolov authored Aug 13, 2021
1 parent a9fb5a9 commit 624560a
Show file tree
Hide file tree
Showing 4 changed files with 645 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ public TopDocs search(String field, float[] target, int k) throws IOException {
float score = results.topScore();
results.pop();
if (reversed) {
score = (float) Math.exp(-score / target.length);
score = 1 / (1 + score);
}
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], score);
}
Expand Down
307 changes: 307 additions & 0 deletions lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,307 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;

/** Uses {@link KnnVectorsReader#search} to perform nearest Neighbour search. */
public class KnnVectorQuery extends Query {

private static final TopDocs NO_RESULTS =
new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);

private final String field;
private final float[] target;
private final int k;

/**
* Find the <code>k</code> nearest documents to the target vector according to the vectors in the
* given field. <code>target</code> vector.
*
* @param field a field that has been indexed as a {@link KnnVectorField}.
* @param target the target of the search
* @param k the number of documents to find
* @throws IllegalArgumentException if <code>k</code> is less than 1
*/
public KnnVectorQuery(String field, float[] target, int k) {
this.field = field;
this.target = target;
this.k = k;
if (k < 1) {
throw new IllegalArgumentException("k must be at least 1, got: " + k);
}
}

@Override
public Query rewrite(IndexReader reader) throws IOException {
TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
for (LeafReaderContext ctx : reader.leaves()) {
perLeafResults[ctx.ord] = searchLeaf(ctx, Math.min(k, reader.numDocs()));
}
// Merge sort the results
TopDocs topK = TopDocs.merge(k, perLeafResults);
if (topK.scoreDocs.length == 0) {
return new MatchNoDocsQuery();
}
return createRewrittenQuery(reader, topK);
}

private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
if (results == null) {
return NO_RESULTS;
}
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
scoreDoc.doc += ctx.docBase;
}
}
return results;
}

private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
int len = topK.scoreDocs.length;
Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
int[] docs = new int[len];
float[] scores = new float[len];
for (int i = 0; i < len; i++) {
docs[i] = topK.scoreDocs[i].doc;
scores[i] = topK.scoreDocs[i].score;
}
int[] segmentStarts = findSegmentStarts(reader, docs);
return new DocAndScoreQuery(k, docs, scores, segmentStarts, reader.hashCode());
}

private int[] findSegmentStarts(IndexReader reader, int[] docs) {
int[] starts = new int[reader.leaves().size() + 1];
starts[starts.length - 1] = docs.length;
if (starts.length == 2) {
return starts;
}
int resultIndex = 0;
for (int i = 1; i < starts.length - 1; i++) {
int upper = reader.leaves().get(i).docBase;
resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
if (resultIndex < 0) {
resultIndex = -1 - resultIndex;
}
starts[i] = resultIndex;
}
return starts;
}

@Override
public String toString(String field) {
return "<vector:" + this.field + "[" + target[0] + ",...][" + k + "]>";
}

@Override
public void visit(QueryVisitor visitor) {
if (visitor.acceptField(field)) {
visitor.visitLeaf(this);
}
}

@Override
public boolean equals(Object obj) {
return obj instanceof KnnVectorQuery
&& ((KnnVectorQuery) obj).k == k
&& ((KnnVectorQuery) obj).field.equals(field)
&& Arrays.equals(((KnnVectorQuery) obj).target, target);
}

@Override
public int hashCode() {
return Objects.hash(field, k, Arrays.hashCode(target));
}

/** Caches the results of a KnnVector search: a list of docs and their scores */
static class DocAndScoreQuery extends Query {

private final int k;
private final int[] docs;
private final float[] scores;
private final int[] segmentStarts;
private final int readerHash;

/**
* Constructor
*
* @param k the number of documents requested
* @param docs the global docids of documents that match, in ascending order
* @param scores the scores of the matching documents
* @param segmentStarts the indexes in docs and scores corresponding to the first matching
* document in each segment. If a segment has no matching documents, it should be assigned
* the index of the next segment that does. There should be a final entry that is always
* docs.length-1.
* @param readerHash a hash code identifying the IndexReader used to create this query
*/
DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, int readerHash) {
this.k = k;
this.docs = docs;
this.scores = scores;
this.segmentStarts = segmentStarts;
this.readerHash = readerHash;
}

@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
throws IOException {
if (searcher.getIndexReader().hashCode() != readerHash) {
throw new IllegalStateException("This DocAndScore query was created by a different reader");
}
return new Weight(this) {
@Override
public Explanation explain(LeafReaderContext context, int doc) {
int found = Arrays.binarySearch(docs, doc);
if (found < 0) {
return Explanation.noMatch("not in top " + k);
}
return Explanation.match(scores[found], "within top " + k);
}

@Override
public Scorer scorer(LeafReaderContext context) {

return new Scorer(this) {
final int lower = segmentStarts[context.ord];
final int upper = segmentStarts[context.ord + 1];
int upTo = -1;

@Override
public DocIdSetIterator iterator() {
return new DocIdSetIterator() {
@Override
public int docID() {
return docIdNoShadow();
}

@Override
public int nextDoc() {
if (upTo == -1) {
upTo = lower;
} else {
++upTo;
}
return docIdNoShadow();
}

@Override
public int advance(int target) throws IOException {
return slowAdvance(target);
}

@Override
public long cost() {
return upper - lower;
}
};
}

@Override
public float getMaxScore(int docid) {
docid += context.docBase;
float maxScore = 0;
for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docid; idx++) {
maxScore = Math.max(maxScore, scores[idx]);
}
return maxScore;
}

@Override
public float score() {
return scores[upTo];
}

@Override
public int advanceShallow(int docid) {
int start = Math.max(upTo, lower);
int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase);
if (docidIndex < 0) {
docidIndex = -1 - docidIndex;
}
if (docidIndex >= upper) {
return NO_MORE_DOCS;
}
return docs[docidIndex];
}

/**
* move the implementation of docID() into a differently-named method so we can call it
* from DocIDSetIterator.docID() even though this class is anonymous
*
* @return the current docid
*/
private int docIdNoShadow() {
if (upTo == -1) {
return -1;
}
if (upTo >= upper) {
return NO_MORE_DOCS;
}
return docs[upTo] - context.docBase;
}

@Override
public int docID() {
return docIdNoShadow();
}
};
}

@Override
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}
};
}

@Override
public String toString(String field) {
return "DocAndScore[" + k + "]";
}

@Override
public void visit(QueryVisitor visitor) {
visitor.visitLeaf(this);
}

@Override
public boolean equals(Object obj) {
if (obj instanceof DocAndScoreQuery == false) {
return false;
}
return Arrays.equals(docs, ((DocAndScoreQuery) obj).docs)
&& Arrays.equals(scores, ((DocAndScoreQuery) obj).scores);
}

@Override
public int hashCode() {
return Objects.hash(
DocAndScoreQuery.class.hashCode(), Arrays.hashCode(docs), Arrays.hashCode(scores));
}
}
}
Loading

0 comments on commit 624560a

Please sign in to comment.