Skip to content

Commit

Permalink
A few small tweaks to VectorUtil#findNextGEQ: (apache#13972)
Browse files Browse the repository at this point in the history
1. Rearrange/rename the parameters to be more idiomatic (e.g., follow conventions of Arrays#... methods)
2. Add assert to ensure expected sortedness we may rely on in the future (so we're not trappy)
3. Migrate PostingsReader to call VectorUtil instead of VectorUtilSupport (so it benefits from the common assert)
  • Loading branch information
gsmiller authored Nov 5, 2024
1 parent 1328527 commit a3d56ea
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ public int advance(int target) throws IOException {
}
}

int next = findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
int next = findNextGEQ(docBuffer, target, docBufferUpto, docBufferSize);
this.doc = (int) docBuffer[next];
docBufferUpto = next + 1;
return doc;
Expand Down Expand Up @@ -937,7 +937,7 @@ public int advance(int target) throws IOException {
refillDocs();
}

int next = findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
int next = findNextGEQ(docBuffer, target, docBufferUpto, docBufferSize);
posPendingCount += sumOverRange(freqBuffer, docBufferUpto, next + 1);
this.freq = (int) freqBuffer[next];
this.docBufferUpto = next + 1;
Expand Down Expand Up @@ -1423,7 +1423,7 @@ public int advance(int target) throws IOException {
needsRefilling = false;
}

int next = findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
int next = findNextGEQ(docBuffer, target, docBufferUpto, docBufferSize);
this.doc = (int) docBuffer[next];
docBufferUpto = next + 1;
return doc;
Expand Down Expand Up @@ -1654,7 +1654,7 @@ public int advance(int target) throws IOException {
needsRefilling = false;
}

int next = findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
int next = findNextGEQ(docBuffer, target, docBufferUpto, docBufferSize);
posPendingCount += sumOverRange(freqBuffer, docBufferUpto, next + 1);
freq = (int) freqBuffer[next];
docBufferUpto = next + 1;
Expand Down Expand Up @@ -1755,13 +1755,13 @@ static long readVLong15(DataInput in) throws IOException {
}
}

private static int findNextGEQ(long[] buffer, int length, long target, int from) {
for (int i = from; i < length; ++i) {
private static int findNextGEQ(long[] buffer, long target, int from, int to) {
for (int i = from; i < to; ++i) {
if (buffer[i] >= target) {
return i;
}
}
return length;
return to;
}

private static void prefetchPostings(IndexInput docIn, IntBlockTermState state)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,13 @@ private static int linearSearch(int[] values, long target, int startIndex) {
@Benchmark
public void vectorUtilSearch() {
for (int i = 0; i < startIndexes.length; ++i) {
VectorUtil.findNextGEQ(values, 128, targets[i], startIndexes[i]);
VectorUtil.findNextGEQ(values, targets[i], startIndexes[i], 128);
}
}

@CompilerControl(CompilerControl.Mode.DONT_INLINE)
private static int vectorUtilSearch(int[] values, int target, int startIndex) {
return VectorUtil.findNextGEQ(values, 128, target, startIndex);
return VectorUtil.findNextGEQ(values, target, startIndex, 128);
}

private static void assertEquals(int expected, int actual) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SlowImpactsEnum;
import org.apache.lucene.internal.vectorization.PostingDecodingUtil;
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
import org.apache.lucene.internal.vectorization.VectorizationProvider;
import org.apache.lucene.store.ByteArrayDataInput;
import org.apache.lucene.store.ChecksumIndexInput;
Expand All @@ -57,6 +56,7 @@
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.VectorUtil;

/**
* Concrete class that reads docId(maybe frq,pos,offset,payloads) list with postings format.
Expand All @@ -66,8 +66,6 @@
public final class Lucene101PostingsReader extends PostingsReaderBase {

static final VectorizationProvider VECTORIZATION_PROVIDER = VectorizationProvider.getInstance();
private static final VectorUtilSupport VECTOR_SUPPORT =
VECTORIZATION_PROVIDER.getVectorUtilSupport();
// Dummy impacts, composed of the maximum possible term frequency and the lowest possible
// (unsigned) norm value. This is typically used on tail blocks, which don't actually record
// impacts as the storage overhead would not be worth any query evaluation speedup, since there's
Expand Down Expand Up @@ -601,7 +599,7 @@ public int advance(int target) throws IOException {
}
}

int next = VECTOR_SUPPORT.findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
int next = VectorUtil.findNextGEQ(docBuffer, target, docBufferUpto, docBufferSize);
this.doc = docBuffer[next];
docBufferUpto = next + 1;
return doc;
Expand Down Expand Up @@ -950,7 +948,7 @@ public int advance(int target) throws IOException {
refillDocs();
}

int next = VECTOR_SUPPORT.findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
int next = VectorUtil.findNextGEQ(docBuffer, target, docBufferUpto, docBufferSize);
posPendingCount += sumOverRange(freqBuffer, docBufferUpto, next + 1);
this.freq = freqBuffer[next];
this.docBufferUpto = next + 1;
Expand Down Expand Up @@ -1437,7 +1435,7 @@ public int advance(int target) throws IOException {
needsRefilling = false;
}

int next = VECTOR_SUPPORT.findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
int next = VectorUtil.findNextGEQ(docBuffer, target, docBufferUpto, docBufferSize);
this.doc = docBuffer[next];
docBufferUpto = next + 1;
return doc;
Expand Down Expand Up @@ -1670,7 +1668,7 @@ public int advance(int target) throws IOException {
needsRefilling = false;
}

int next = VECTOR_SUPPORT.findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
int next = VectorUtil.findNextGEQ(docBuffer, target, docBufferUpto, docBufferSize);
posPendingCount += sumOverRange(freqBuffer, docBufferUpto, next + 1);
freq = freqBuffer[next];
docBufferUpto = next + 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,12 @@ public int squareDistance(byte[] a, byte[] b) {
}

@Override
public int findNextGEQ(int[] buffer, int length, int target, int from) {
for (int i = from; i < length; ++i) {
public int findNextGEQ(int[] buffer, int target, int from, int to) {
for (int i = from; i < to; ++i) {
if (buffer[i] >= target) {
return i;
}
}
return length;
return to;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ public interface VectorUtilSupport {
int squareDistance(byte[] a, byte[] b);

/**
* Given an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code
* length} exclusive, find the first array index whose value is greater than or equal to {@code
* target}. This index is guaranteed to be at least {@code from}. If there is no such array index,
* {@code length} is returned.
* Given an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code to}
* exclusive, find the first array index whose value is greater than or equal to {@code target}.
* This index is guaranteed to be at least {@code from}. If there is no such array index, {@code
* to} is returned.
*/
int findNextGEQ(int[] buffer, int length, int target, int from);
int findNextGEQ(int[] buffer, int target, int from, int to);
}
14 changes: 8 additions & 6 deletions lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.lucene.util;

import java.util.stream.IntStream;
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
import org.apache.lucene.internal.vectorization.VectorizationProvider;

Expand Down Expand Up @@ -309,12 +310,13 @@ public static float[] checkFinite(float[] v) {
}

/**
* Given an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code
* length} exclusive, find the first array index whose value is greater than or equal to {@code
* target}. This index is guaranteed to be at least {@code from}. If there is no such array index,
* {@code length} is returned.
* Given an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code to}
* exclusive, find the first array index whose value is greater than or equal to {@code target}.
* This index is guaranteed to be at least {@code from}. If there is no such array index, {@code
* to} is returned.
*/
public static int findNextGEQ(int[] buffer, int length, int target, int from) {
return IMPL.findNextGEQ(buffer, length, target, from);
public static int findNextGEQ(int[] buffer, int target, int from, int to) {
assert IntStream.range(0, to - 1).noneMatch(i -> buffer[i] > buffer[i + 1]);
return IMPL.findNextGEQ(buffer, target, from, to);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -769,21 +769,21 @@ private static int squareDistanceBody128(MemorySegment a, MemorySegment b, int l
private static final boolean ENABLE_FIND_NEXT_GEQ_VECTOR_OPTO = INT_SPECIES.length() >= 8;

@Override
public int findNextGEQ(int[] buffer, int length, int target, int from) {
public int findNextGEQ(int[] buffer, int target, int from, int to) {
if (ENABLE_FIND_NEXT_GEQ_VECTOR_OPTO) {
for (; from + INT_SPECIES.length() < length; from += INT_SPECIES.length() + 1) {
for (; from + INT_SPECIES.length() < to; from += INT_SPECIES.length() + 1) {
if (buffer[from + INT_SPECIES.length()] >= target) {
IntVector vector = IntVector.fromArray(INT_SPECIES, buffer, from);
VectorMask<Integer> mask = vector.compare(VectorOperators.LT, target);
return from + mask.trueCount();
}
}
}
for (int i = from; i < length; ++i) {
for (int i = from; i < to; ++i) {
if (buffer[i] >= target) {
return i;
}
}
return length;
return to;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ public void testFindNextGEQ() {
- 5;
assertEquals(
slowFindNextGEQ(values, 128, target, from),
VectorUtil.findNextGEQ(values, 128, target, from));
VectorUtil.findNextGEQ(values, target, from, 128));
}
}

Expand Down

0 comments on commit a3d56ea

Please sign in to comment.