Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix IntegerOverflow exception in postings encoding as group-varint #13376

Merged
merged 10 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,8 @@ Bug Fixes

* GITHUB#13374: Fix bug in SQ when just a single vector present in a segment (Chris Hegarty)

* GITHUB#13376: Fix integer overflow exception in postings encoding as group-varint. (Zhang Chao, Guo Feng)

Build
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,31 +140,6 @@ public void init(IndexInput termsIn, SegmentReadState state) throws IOException
}
}

/**
 * Decodes {@code num} values that were stored with variable-length encoding rather than
 * bit-packing.
 *
 * <p>The values are first read into {@code docBuffer} with {@code readGroupVInts}. If the index
 * has freqs, each value is shifted right by one bit; when freqs are also decoded, the dropped low
 * bit is stored in {@code freqBuffer}, and a zero bit means the freq follows as a separate vInt.
 *
 * @param docIn input to read from
 * @param docBuffer receives the decoded values
 * @param freqBuffer receives the decoded freqs; only written when both flags are set
 * @param num number of values to read
 * @param indexHasFreq whether a freq flag bit is folded into each value
 * @param decodeFreq whether freqs should be materialized into {@code freqBuffer}
 * @throws IOException if reading from {@code docIn} fails
 */
static void readVIntBlock(
    IndexInput docIn,
    long[] docBuffer,
    long[] freqBuffer,
    int num,
    boolean indexHasFreq,
    boolean decodeFreq)
    throws IOException {
  docIn.readGroupVInts(docBuffer, num);

  if (!indexHasFreq) {
    // Nothing is folded into the values; they are used exactly as read.
    return;
  }

  if (!decodeFreq) {
    // Freqs are not wanted: just drop the low flag bit of every value.
    for (int idx = 0; idx < num; idx++) {
      docBuffer[idx] >>= 1;
    }
    return;
  }

  for (int idx = 0; idx < num; idx++) {
    // Keep the low bit as the freq, then strip it from the value.
    freqBuffer[idx] = docBuffer[idx] & 0x01;
    docBuffer[idx] >>= 1;
    if (freqBuffer[idx] == 0) {
      // Flag bit was clear: the freq is stored separately as a vInt.
      freqBuffer[idx] = docIn.readVInt();
    }
  }
}

static void prefixSum(long[] buffer, int count, long base) {
buffer[0] += base;
for (int i = 1; i < count; ++i) {
Expand Down Expand Up @@ -475,7 +450,7 @@ private void refillDocs() throws IOException {
blockUpto++;
} else {
// Read vInts:
readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreq, needsFreq);
PostingsUtil.readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreq, needsFreq);
prefixSum(docBuffer, left, accum);
docBuffer[left] = NO_MORE_DOCS;
blockUpto += left;
Expand Down Expand Up @@ -768,7 +743,7 @@ private void refillDocs() throws IOException {
docBuffer[1] = NO_MORE_DOCS;
blockUpto++;
} else {
readVIntBlock(docIn, docBuffer, freqBuffer, left, true, true);
PostingsUtil.readVIntBlock(docIn, docBuffer, freqBuffer, left, true, true);
prefixSum(docBuffer, left, accum);
docBuffer[left] = NO_MORE_DOCS;
blockUpto += left;
Expand Down Expand Up @@ -1155,7 +1130,7 @@ private void refillDocs() throws IOException {
}
blockUpto += BLOCK_SIZE;
} else {
readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreqs, true);
PostingsUtil.readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreqs, true);
prefixSum(docBuffer, left, accum);
docBuffer[left] = NO_MORE_DOCS;
blockUpto += left;
Expand Down Expand Up @@ -1363,7 +1338,7 @@ private void refillDocs() throws IOException {
forDeltaUtil.decodeAndPrefixSum(docIn, accum, docBuffer);
pforUtil.decode(docIn, freqBuffer);
} else {
readVIntBlock(docIn, docBuffer, freqBuffer, left, true, true);
PostingsUtil.readVIntBlock(docIn, docBuffer, freqBuffer, left, true, true);
prefixSum(docBuffer, left, accum);
docBuffer[left] = NO_MORE_DOCS;
}
Expand Down Expand Up @@ -1753,7 +1728,7 @@ private void refillDocs() throws IOException {
false; // freq block will be loaded lazily when necessary, we don't load it here
}
} else {
readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreq, true);
PostingsUtil.readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreq, true);
prefixSum(docBuffer, left, accum);
docBuffer[left] = NO_MORE_DOCS;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,20 +371,7 @@ public void finishTerm(BlockTermState _state) throws IOException {
} else {
singletonDocID = -1;
// Group vInt encode the remaining doc deltas and freqs:
if (writeFreqs) {
for (int i = 0; i < docBufferUpto; i++) {
docDeltaBuffer[i] = (docDeltaBuffer[i] << 1) | (freqBuffer[i] == 1 ? 1 : 0);
}
}
docOut.writeGroupVInts(docDeltaBuffer, docBufferUpto);
if (writeFreqs) {
for (int i = 0; i < docBufferUpto; i++) {
final int freq = (int) freqBuffer[i];
if (freq != 1) {
docOut.writeVInt(freq);
}
}
}
PostingsUtil.writeVIntBlock(docOut, docDeltaBuffer, freqBuffer, docBufferUpto, writeFreqs);
}

final long lastPosBlockOffset;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;

import java.io.IOException;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;

/**
 * Utility class to encode/decode postings block.
 *
 * <p>All members are static; the class is not meant to be instantiated.
 */
final class PostingsUtil {

  /** Static helpers only: prevent instantiation. */
  private PostingsUtil() {}

  /**
   * Read values that have been written using variable-length encoding and group-varint encoding
   * instead of bit-packing.
   *
   * <p>{@code num} values are first read into {@code docBuffer} with {@code readGroupVInts}. If
   * {@code indexHasFreq} is set, every value is shifted right by one bit. If {@code decodeFreq}
   * is also set, the dropped low bit is stored in {@code freqBuffer}; a zero bit means the freq
   * follows as a separate vInt, which is then read from {@code docIn}.
   *
   * @param docIn input to read from
   * @param docBuffer receives the decoded values
   * @param freqBuffer receives the decoded freqs; only written when both {@code indexHasFreq} and
   *     {@code decodeFreq} are set
   * @param num number of values to read
   * @param indexHasFreq whether a freq flag bit is folded into each value
   * @param decodeFreq whether freqs should be materialized into {@code freqBuffer}
   * @throws IOException if reading from {@code docIn} fails
   */
  static void readVIntBlock(
      IndexInput docIn,
      long[] docBuffer,
      long[] freqBuffer,
      int num,
      boolean indexHasFreq,
      boolean decodeFreq)
      throws IOException {
    docIn.readGroupVInts(docBuffer, num);
    if (indexHasFreq && decodeFreq) {
      for (int i = 0; i < num; ++i) {
        // The low bit is the "freq == 1" flag written by writeVIntBlock.
        freqBuffer[i] = docBuffer[i] & 0x01;
        docBuffer[i] >>= 1;
        if (freqBuffer[i] == 0) {
          // Flag clear: the freq was written separately as a vInt.
          freqBuffer[i] = docIn.readVInt();
        }
      }
    } else if (indexHasFreq) {
      // Freqs are not wanted: only strip the flag bit from each value.
      for (int i = 0; i < num; ++i) {
        docBuffer[i] >>= 1;
      }
    }
  }

  /**
   * Write freq buffer with variable-length encoding and doc buffer with group-varint encoding.
   *
   * <p>When {@code writeFreqs} is set, the first {@code num} entries of {@code docBuffer} are
   * modified in place: each is shifted left by one bit and its low bit is set when the matching
   * freq equals 1. Freqs other than 1 are appended afterwards as vInts, after being narrowed to
   * {@code int}.
   *
   * @param docOut output to write to
   * @param docBuffer values to write; overwritten with the flagged values when {@code writeFreqs}
   *     is set
   * @param freqBuffer freqs matching {@code docBuffer}; only read when {@code writeFreqs} is set
   * @param num number of entries to write
   * @param writeFreqs whether freqs are encoded alongside the values
   * @throws IOException if writing to {@code docOut} fails
   */
  static void writeVIntBlock(
      IndexOutput docOut, long[] docBuffer, long[] freqBuffer, int num, boolean writeFreqs)
      throws IOException {
    if (writeFreqs) {
      for (int i = 0; i < num; i++) {
        // Fold a "freq == 1" flag into the low bit so that such freqs need no extra vInt.
        docBuffer[i] = (docBuffer[i] << 1) | (freqBuffer[i] == 1 ? 1 : 0);
      }
    }
    docOut.writeGroupVInts(docBuffer, num);
    if (writeFreqs) {
      for (int i = 0; i < num; i++) {
        final int freq = (int) freqBuffer[i];
        if (freq != 1) {
          docOut.writeVInt(freq);
        }
      }
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public final void readGroupVInts(long[] dst, int limit) throws IOException {
readGroupVInt(dst, i);
}
for (; i < limit; ++i) {
dst[i] = readVInt();
dst[i] = readVInt() & 0xFFFFFFFFL;
}
}

Expand Down
29 changes: 19 additions & 10 deletions lucene/core/src/java/org/apache/lucene/util/GroupVIntUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
public final class GroupVIntUtil {
// the maximum length of a single group-varint is 4 integers + 1 byte flag.
public static final int MAX_LENGTH_PER_GROUP = 17;
private static final int[] MASKS = new int[] {0xFF, 0xFFFF, 0xFFFFFF, 0xFFFFFFFF};

// We use a long array instead of an int array so that negative ints are read back as positive longs.
private static final long[] MASKS = new long[] {0xFFL, 0xFFFFL, 0xFFFFFFL, 0xFFFFFFFFL};

/**
* Default implementation of read single group, for optimal performance, you should use {@link
Expand Down Expand Up @@ -118,6 +120,13 @@ private static int numBytes(int v) {
return Integer.BYTES - (Integer.numberOfLeadingZeros(v | 1) >> 3);
}

private static int toInt(long value) {
if ((Long.compareUnsigned(value, 0xFFFFFFFFL) > 0)) {
throw new ArithmeticException("integer overflow");
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe use Long.compareUnsigned? (if (Long.compareUnsigned(value, 0xFFFFFFFFL) > 0))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great!

return (int) value;
}

/**
* The implementation for group-varint encoding, It uses a maximum of {@link
* #MAX_LENGTH_PER_GROUP} bytes scratch buffer.
Expand All @@ -129,27 +138,27 @@ public static void writeGroupVInts(DataOutput out, byte[] scratch, long[] values
// encode each group
while ((limit - readPos) >= 4) {
int writePos = 0;
final int n1Minus1 = numBytes(Math.toIntExact(values[readPos])) - 1;
final int n2Minus1 = numBytes(Math.toIntExact(values[readPos + 1])) - 1;
final int n3Minus1 = numBytes(Math.toIntExact(values[readPos + 2])) - 1;
final int n4Minus1 = numBytes(Math.toIntExact(values[readPos + 3])) - 1;
final int n1Minus1 = numBytes(toInt(values[readPos])) - 1;
final int n2Minus1 = numBytes(toInt(values[readPos + 1])) - 1;
final int n3Minus1 = numBytes(toInt(values[readPos + 2])) - 1;
final int n4Minus1 = numBytes(toInt(values[readPos + 3])) - 1;
int flag = (n1Minus1 << 6) | (n2Minus1 << 4) | (n3Minus1 << 2) | (n4Minus1);
scratch[writePos++] = (byte) flag;
BitUtil.VH_LE_INT.set(scratch, writePos, Math.toIntExact(values[readPos++]));
BitUtil.VH_LE_INT.set(scratch, writePos, (int) (values[readPos++]));
writePos += n1Minus1 + 1;
BitUtil.VH_LE_INT.set(scratch, writePos, Math.toIntExact(values[readPos++]));
BitUtil.VH_LE_INT.set(scratch, writePos, (int) (values[readPos++]));
writePos += n2Minus1 + 1;
BitUtil.VH_LE_INT.set(scratch, writePos, Math.toIntExact(values[readPos++]));
BitUtil.VH_LE_INT.set(scratch, writePos, (int) (values[readPos++]));
writePos += n3Minus1 + 1;
BitUtil.VH_LE_INT.set(scratch, writePos, Math.toIntExact(values[readPos++]));
BitUtil.VH_LE_INT.set(scratch, writePos, (int) (values[readPos++]));
writePos += n4Minus1 + 1;

out.writeBytes(scratch, writePos);
}

// tail vints
for (; readPos < limit; readPos++) {
out.writeVInt(Math.toIntExact(values[readPos]));
out.writeVInt(toInt(values[readPos]));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;

import java.io.IOException;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.tests.util.LuceneTestCase;

/** Round-trip checks for the vInt block helpers in {@link PostingsUtil}. */
public class TestPostingsUtil extends LuceneTestCase {

  // checks for bug described in https://github.com/apache/lucene/issues/13373
  public void testIntegerOverflow() throws IOException {
    final int numValues = random().nextInt(1, ForUtil.BLOCK_SIZE);
    final long[] deltasToWrite = new long[numValues];
    final long[] freqsToWrite = new long[numValues];
    final long[] decodedDeltas = new long[numValues];
    final long[] decodedFreqs = new long[numValues];

    // writeVIntBlock shifts each delta left by one bit when freqs are written, so this value
    // ends up as 1 << 31, which no longer fits in a signed int.
    final int largeDelta = 1 << 30;
    deltasToWrite[0] = largeDelta;

    try (Directory directory = newDirectory()) {
      try (IndexOutput output = directory.createOutput("test", IOContext.DEFAULT)) {
        // The previous implementation failed here with an integer overflow exception.
        PostingsUtil.writeVIntBlock(output, deltasToWrite, freqsToWrite, numValues, true);
      }
      try (IndexInput input = directory.openInput("test", IOContext.DEFAULT)) {
        PostingsUtil.readVIntBlock(input, decodedDeltas, decodedFreqs, numValues, true, true);
      }
      assertEquals(largeDelta, decodedDeltas[0]);
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -1465,6 +1465,37 @@ public void testDataTypes() throws IOException {
}
}

/**
 * Checks that group-varint encoding round-trips values that overflow a signed int but still fit
 * in 32 bits, and that a value needing more than 32 bits is rejected with an {@link
 * ArithmeticException}.
 */
public void testGroupVIntOverflow() throws IOException {
  try (Directory dir = getDirectory(createTempDir("testGroupVIntOverflow"))) {
    final int size = 32;
    final long[] values = new long[size];
    final long[] restore = new long[size];
    values[0] = 1L << 31; // values[0] = 2147483648 as long, but as int it is -2147483648

    for (int i = 0; i < size; i++) {
      if (random().nextBoolean()) {
        values[i] = values[0];
      }
    }

    // A smaller limit covers the default implementation of readGroupVInts,
    // and a bigger limit covers the faster implementation.
    final int limit = random().nextInt(1, size);
    // try-with-resources: the output is closed even if the write itself fails, so a write
    // failure is not masked by a leaked open file when the directory is closed.
    try (IndexOutput out = dir.createOutput("test", IOContext.DEFAULT)) {
      out.writeGroupVInts(values, limit);
    }
    try (IndexInput in = dir.openInput("test", IOContext.DEFAULT)) {
      in.readGroupVInts(restore, limit);
      for (int i = 0; i < limit; i++) {
        assertEquals(values[i], restore[i]);
      }
    }

    // 2^32 does not fit in 32 bits and must be rejected. Use a fresh, still-open output so the
    // check does not depend on how an already-closed output behaves.
    values[0] = 0xFFFFFFFFL + 1;
    try (IndexOutput overflowOut = dir.createOutput("test_overflow", IOContext.DEFAULT)) {
      assertThrows(ArithmeticException.class, () -> overflowOut.writeGroupVInts(values, 4));
    }
  }
}

public void testGroupVInt() throws IOException {
try (Directory dir = getDirectory(createTempDir("testGroupVInt"))) {
// test fallback to default implementation of readGroupVInt
Expand Down