Speed up advancing within a block, take 2. #13958

Merged: 15 commits, Oct 30, 2024
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
@@ -80,6 +80,8 @@ Optimizations
* GITHUB#13963: Speed up nextDoc() implementations in Lucene912PostingsReader.
(Adrien Grand)

* GITHUB#13958: Speed up advancing within a block. (Adrien Grand)

Bug Fixes
---------------------
* GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended
New file: AdvanceBenchmark.java (org.apache.lucene.benchmark.jmh)
@@ -0,0 +1,180 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.benchmark.jmh;

import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.VectorUtil;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.CompilerControl;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
@Warmup(iterations = 5, time = 1)
@Measurement(iterations = 5, time = 1)
@Fork(
value = 3,
jvmArgsAppend = {
"-Xmx1g",
"-Xms1g",
"-XX:+AlwaysPreTouch",
"--add-modules",
"jdk.incubator.vector"
})
public class AdvanceBenchmark {

private final long[] values = new long[129];
private final int[] startIndexes = new int[1_000];
private final long[] targets = new long[startIndexes.length];

@Setup(Level.Trial)
public void setup() throws Exception {
for (int i = 0; i < 128; ++i) {
values[i] = i;
}
values[128] = DocIdSetIterator.NO_MORE_DOCS;
Random r = new Random(0);
for (int i = 0; i < startIndexes.length; ++i) {
startIndexes[i] = r.nextInt(64);
targets[i] = startIndexes[i] + 1 + r.nextInt(1 << r.nextInt(7));
}
}

@Benchmark
public void binarySearch() {
for (int i = 0; i < startIndexes.length; ++i) {
binarySearch(values, targets[i], startIndexes[i]);
}
}

@CompilerControl(CompilerControl.Mode.DONT_INLINE)
private static int binarySearch(long[] values, long target, int startIndex) {
// Standard binary search
int i = Arrays.binarySearch(values, startIndex, values.length, target);
if (i < 0) {
i = -1 - i;
}
return i;
}

@Benchmark
public void inlinedBranchlessBinarySearch() {
for (int i = 0; i < targets.length; ++i) {
inlinedBranchlessBinarySearch(values, targets[i]);
}
}

@CompilerControl(CompilerControl.Mode.DONT_INLINE)
private static int inlinedBranchlessBinarySearch(long[] values, long target) {
// This compiles to cmov instructions.
int start = 0;

if (values[63] < target) {
start += 64;
}
if (values[start + 31] < target) {
start += 32;
}
if (values[start + 15] < target) {
start += 16;
}
if (values[start + 7] < target) {
start += 8;
}
if (values[start + 3] < target) {
start += 4;
}
if (values[start + 1] < target) {
start += 2;
}
if (values[start] < target) {
start += 1;
}

return start;
}

@Benchmark
public void linearSearch() {
for (int i = 0; i < startIndexes.length; ++i) {
linearSearch(values, targets[i], startIndexes[i]);
}
}

@CompilerControl(CompilerControl.Mode.DONT_INLINE)
private static int linearSearch(long[] values, long target, int startIndex) {
// Naive linear search.
for (int i = startIndex; i < values.length; ++i) {
if (values[i] >= target) {
return i;
}
}
return values.length;
}

@Benchmark
public void vectorUtilSearch() {
for (int i = 0; i < startIndexes.length; ++i) {
VectorUtil.findNextGEQ(values, 128, targets[i], startIndexes[i]);
}
}

@CompilerControl(CompilerControl.Mode.DONT_INLINE)
private static int vectorUtilSearch(long[] values, long target, int startIndex) {
return VectorUtil.findNextGEQ(values, 128, target, startIndex);
}

private static void assertEquals(int expected, int actual) {
if (expected != actual) {
throw new AssertionError("Expected: " + expected + ", got " + actual);
}
}

public static void main(String[] args) {
// For testing purposes
long[] values = new long[129];
for (int i = 0; i < 128; ++i) {
values[i] = i;
}
values[128] = DocIdSetIterator.NO_MORE_DOCS;
for (int start = 0; start < 128; ++start) {
for (int targetIndex = start; targetIndex < 128; ++targetIndex) {
int actualIndex = binarySearch(values, values[targetIndex], start);
assertEquals(targetIndex, actualIndex);
actualIndex = inlinedBranchlessBinarySearch(values, values[targetIndex]);
assertEquals(targetIndex, actualIndex);
actualIndex = linearSearch(values, values[targetIndex], start);
assertEquals(targetIndex, actualIndex);
actualIndex = vectorUtilSearch(values, values[targetIndex], start);
assertEquals(targetIndex, actualIndex);
}
}
}
}
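
For context, the Lucene JMH benchmarks are normally run through the benchmark module's own launcher; the snippet below is only an illustrative sketch of starting this class through the plain JMH Runner API (Runner and OptionsBuilder from jmh-core), assuming the compiled benchmark and jmh-core are on the classpath. It is not part of the PR.

// Illustrative launcher, not part of this PR.
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;

public class RunAdvanceBenchmark {
  public static void main(String[] args) throws RunnerException {
    Options opts =
        new OptionsBuilder()
            .include(AdvanceBenchmark.class.getSimpleName()) // regex matched against benchmark names
            .build();
    new Runner(opts).run(); // runs all four @Benchmark methods with the annotated fork/warmup settings
  }
}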
Lucene912PostingsReader.java
@@ -46,6 +46,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SlowImpactsEnum;
import org.apache.lucene.internal.vectorization.PostingDecodingUtil;
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
import org.apache.lucene.internal.vectorization.VectorizationProvider;
import org.apache.lucene.store.ByteArrayDataInput;
import org.apache.lucene.store.ChecksumIndexInput;
@@ -65,6 +66,8 @@
public final class Lucene912PostingsReader extends PostingsReaderBase {

static final VectorizationProvider VECTORIZATION_PROVIDER = VectorizationProvider.getInstance();
private static final VectorUtilSupport VECTOR_SUPPORT =
VECTORIZATION_PROVIDER.getVectorUtilSupport();
// Dummy impacts, composed of the maximum possible term frequency and the lowest possible
// (unsigned) norm value. This is typically used on tail blocks, which don't actually record
// impacts as the storage overhead would not be worth any query evaluation speedup, since there's
@@ -215,15 +218,6 @@ static void prefixSum(long[] buffer, int count, long base) {
}
}

static int findFirstGreater(long[] buffer, int target, int from) {
for (int i = from; i < BLOCK_SIZE; ++i) {
if (buffer[i] >= target) {
return i;
}
}
return BLOCK_SIZE;
}

@Override
public BlockTermState newTermState() {
return new IntBlockTermState();
@@ -357,6 +351,7 @@ private abstract class AbstractPostingsEnum extends PostingsEnum {
protected int docCountUpto; // number of docs in or before the current block
protected long prevDocID; // last doc ID of the previous block

protected int docBufferSize;
protected int docBufferUpto;

protected IndexInput docIn;
@@ -402,6 +397,7 @@ protected PostingsEnum resetIdsAndLevelParams(IntBlockTermState termState) throw
level1DocEndFP = termState.docStartFP;
}
level1DocCountUpto = 0;
docBufferSize = BLOCK_SIZE;
docBufferUpto = BLOCK_SIZE;
return this;
}
@@ -487,7 +483,7 @@ private void refillFullBlock() throws IOException {
docCountUpto += BLOCK_SIZE;
prevDocID = docBuffer[BLOCK_SIZE - 1];
docBufferUpto = 0;
assert docBuffer[BLOCK_SIZE] == NO_MORE_DOCS;
assert docBuffer[docBufferSize] == NO_MORE_DOCS;
}

private void refillRemainder() throws IOException {
@@ -508,6 +504,7 @@ private void refillRemainder() throws IOException {
docCountUpto += left;
}
docBufferUpto = 0;
docBufferSize = left;
freqFP = -1;
}

@@ -604,7 +601,7 @@ public int advance(int target) throws IOException {
}
}

int next = findFirstGreater(docBuffer, target, docBufferUpto);
int next = VECTOR_SUPPORT.findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
this.doc = (int) docBuffer[next];
docBufferUpto = next + 1;
return doc;
@@ -782,16 +779,18 @@ private void refillDocs() throws IOException {
freqBuffer[0] = totalTermFreq;
docBuffer[1] = NO_MORE_DOCS;
docCountUpto++;
docBufferSize = 1;
} else {
// Read vInts:
PostingsUtil.readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreq, true);
prefixSum(docBuffer, left, prevDocID);
docBuffer[left] = NO_MORE_DOCS;
docCountUpto += left;
docBufferSize = left;
}
prevDocID = docBuffer[BLOCK_SIZE - 1];
docBufferUpto = 0;
assert docBuffer[BLOCK_SIZE] == NO_MORE_DOCS;
assert docBuffer[docBufferSize] == NO_MORE_DOCS;
}

private void skipLevel1To(int target) throws IOException {
@@ -951,7 +950,7 @@ public int advance(int target) throws IOException {
refillDocs();
}

int next = findFirstGreater(docBuffer, target, docBufferUpto);
int next = VECTOR_SUPPORT.findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
posPendingCount += sumOverRange(freqBuffer, docBufferUpto, next + 1);
this.freq = (int) freqBuffer[next];
this.docBufferUpto = next + 1;
@@ -1155,6 +1154,7 @@ private abstract class BlockImpactsEnum extends ImpactsEnum {
protected int docCountUpto; // number of docs in or before the current block
protected int doc = -1; // doc we last read
protected long prevDocID = -1; // last doc ID of the previous block
protected int docBufferSize = BLOCK_SIZE;
protected int docBufferUpto = BLOCK_SIZE;

// true if we shallow-advanced to a new block that we have not decoded yet
@@ -1306,10 +1306,11 @@ private void refillDocs() throws IOException {
docBuffer[left] = NO_MORE_DOCS;
freqFP = -1;
docCountUpto += left;
docBufferSize = left;
}
prevDocID = docBuffer[BLOCK_SIZE - 1];
docBufferUpto = 0;
assert docBuffer[BLOCK_SIZE] == NO_MORE_DOCS;
assert docBuffer[docBufferSize] == NO_MORE_DOCS;
}

private void skipLevel1To(int target) throws IOException {
@@ -1437,7 +1438,7 @@ public int advance(int target) throws IOException {
needsRefilling = false;
}

int next = findFirstGreater(docBuffer, target, docBufferUpto);
int next = VECTOR_SUPPORT.findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
this.doc = (int) docBuffer[next];
docBufferUpto = next + 1;
return doc;
Expand Down Expand Up @@ -1535,10 +1536,11 @@ private void refillDocs() throws IOException {
prefixSum(docBuffer, left, prevDocID);
docBuffer[left] = NO_MORE_DOCS;
docCountUpto += left;
docBufferSize = left;
}
prevDocID = docBuffer[BLOCK_SIZE - 1];
docBufferUpto = 0;
assert docBuffer[BLOCK_SIZE] == NO_MORE_DOCS;
assert docBuffer[docBufferSize] == NO_MORE_DOCS;
}

private void skipLevel1To(int target) throws IOException {
@@ -1669,7 +1671,7 @@ public int advance(int target) throws IOException {
needsRefilling = false;
}

int next = findFirstGreater(docBuffer, target, docBufferUpto);
int next = VECTOR_SUPPORT.findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
posPendingCount += sumOverRange(freqBuffer, docBufferUpto, next + 1);
freq = (int) freqBuffer[next];
docBufferUpto = next + 1;
Default (scalar) VectorUtilSupport implementation
@@ -197,4 +197,14 @@ public int squareDistance(byte[] a, byte[] b) {
}
return squareSum;
}

@Override
public int findNextGEQ(long[] buffer, int length, long target, int from) {
for (int i = from; i < length; ++i) {
if (buffer[i] >= target) {
return i;
}
}
return length;
}
}
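
The loop above is the portable scalar fallback; routing the call through VectorUtilSupport is what lets the Panama vectorization provider substitute a SIMD version. As a purely hypothetical sketch (not code from this PR), a jdk.incubator.vector implementation of the same contract could look roughly like this, using the preferred LongVector species plus a scalar tail:

// Hypothetical sketch of a vectorized findNextGEQ; illustrative only.
import jdk.incubator.vector.LongVector;
import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

final class FindNextGEQSketch {
  private static final VectorSpecies<Long> SPECIES = LongVector.SPECIES_PREFERRED;

  static int findNextGEQ(long[] buffer, int length, long target, int from) {
    int i = from;
    // Compare a full lane-width of values at a time; the first set lane of the
    // mask is the first index whose value is >= target.
    for (int bound = from + SPECIES.loopBound(length - from); i < bound; i += SPECIES.length()) {
      LongVector vec = LongVector.fromArray(SPECIES, buffer, i);
      VectorMask<Long> geq = vec.compare(VectorOperators.GE, target);
      if (geq.anyTrue()) {
        return i + geq.firstTrue();
      }
    }
    // Scalar tail for the remaining values.
    for (; i < length; ++i) {
      if (buffer[i] >= target) {
        return i;
      }
    }
    return length;
  }
}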
VectorUtilSupport.java
@@ -44,4 +44,12 @@ public interface VectorUtilSupport {

/** Returns the sum of squared differences of the two byte vectors. */
int squareDistance(byte[] a, byte[] b);

/**
* Given an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code
* length} exclusive, find the first array index whose value is greater than or equal to {@code
* target}. This index is guaranteed to be at least {@code from}. If there is no such array index,
* {@code length} is returned.
*/
int findNextGEQ(long[] buffer, int length, long target, int from);
}
@@ -106,6 +106,10 @@ public int advance(int target) throws IOException {

@Override
public int nextDoc() throws IOException {
DocIdSetIterator in = this.in;
if (in.docID() < upTo) {
return in.nextDoc();
}
return advance(in.docID() + 1);
}

10 changes: 10 additions & 0 deletions lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
@@ -307,4 +307,14 @@ public static float[] checkFinite(float[] v) {
}
return v;
}

/**
* Given an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code
* length} exclusive, find the first array index whose value is greater than or equal to {@code
* target}. This index is guaranteed to be at least {@code from}. If there is no such array index,
* {@code length} is returned.
*/
public static int findNextGEQ(long[] buffer, int length, long target, int from) {
return IMPL.findNextGEQ(buffer, length, target, from);
}
}
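
As a quick illustration of the contract documented above, here is a small usage sketch with made-up values (not from the PR): the array is sorted in [0, length), and the returned index is either the first position at or after from holding a value greater than or equal to target, or length when no such value exists.

// Illustrative usage of VectorUtil.findNextGEQ with made-up data.
long[] block = {3, 7, 7, 12, 25, Long.MAX_VALUE};
int hit = VectorUtil.findNextGEQ(block, 5, 10, 1);  // returns 3: block[3] == 12 is the first value >= 10
int miss = VectorUtil.findNextGEQ(block, 5, 99, 0); // returns 5 (== length): no value in [0, 5) is >= 99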