Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix IntegerOverflow exception in postings encoding as group-varint #13376

Merged
merged 10 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,8 @@ Bug Fixes

* GITHUB#13374: Fix bug in SQ when just a single vector present in a segment (Chris Hegarty)

* GITHUB#13376: Fix integer overflow exception in postings encoding as group-varint. (Zhang Chao, Guo Feng)

Build
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,31 +140,6 @@ public void init(IndexInput termsIn, SegmentReadState state) throws IOException
}
}

/**
* Read values that have been written using variable-length encoding instead of bit-packing.
*
* <p>The doc values are read with {@code readGroupVInts}. When the index has freqs, the lowest bit
* of each value read is a flag and the remaining bits are the doc value: a set flag means the
* freq is 1, a cleared flag means the freq follows as a separate VInt.
*
* @param docIn input to read from
* @param docBuffer receives the {@code num} doc values, with the flag bit shifted out when {@code
*     indexHasFreq} is true
* @param freqBuffer receives the {@code num} freqs; only written when both {@code indexHasFreq}
*     and {@code decodeFreq} are true
* @param num number of values to read
* @param indexHasFreq whether each doc value carries the freq flag in its lowest bit
* @param decodeFreq whether freqs should be decoded into {@code freqBuffer}
* @throws IOException if reading from {@code docIn} fails
*/
static void readVIntBlock(
IndexInput docIn,
long[] docBuffer,
long[] freqBuffer,
int num,
boolean indexHasFreq,
boolean decodeFreq)
throws IOException {
docIn.readGroupVInts(docBuffer, num);
if (indexHasFreq && decodeFreq) {
for (int i = 0; i < num; ++i) {
// Lowest bit set means freq == 1; otherwise the freq is read as a VInt below.
freqBuffer[i] = docBuffer[i] & 0x01;
// Drop the flag bit to recover the doc value.
docBuffer[i] >>= 1;
if (freqBuffer[i] == 0) {
freqBuffer[i] = docIn.readVInt();
}
}
} else if (indexHasFreq) {
// Freqs are not wanted: only strip the flag bit from each doc value.
for (int i = 0; i < num; ++i) {
docBuffer[i] >>= 1;
}
}
}

static void prefixSum(long[] buffer, int count, long base) {
buffer[0] += base;
for (int i = 1; i < count; ++i) {
Expand Down Expand Up @@ -475,7 +450,7 @@ private void refillDocs() throws IOException {
blockUpto++;
} else {
// Read vInts:
readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreq, needsFreq);
PostingsUtil.readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreq, needsFreq);
prefixSum(docBuffer, left, accum);
docBuffer[left] = NO_MORE_DOCS;
blockUpto += left;
Expand Down Expand Up @@ -768,7 +743,7 @@ private void refillDocs() throws IOException {
docBuffer[1] = NO_MORE_DOCS;
blockUpto++;
} else {
readVIntBlock(docIn, docBuffer, freqBuffer, left, true, true);
PostingsUtil.readVIntBlock(docIn, docBuffer, freqBuffer, left, true, true);
prefixSum(docBuffer, left, accum);
docBuffer[left] = NO_MORE_DOCS;
blockUpto += left;
Expand Down Expand Up @@ -1155,7 +1130,7 @@ private void refillDocs() throws IOException {
}
blockUpto += BLOCK_SIZE;
} else {
readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreqs, true);
PostingsUtil.readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreqs, true);
prefixSum(docBuffer, left, accum);
docBuffer[left] = NO_MORE_DOCS;
blockUpto += left;
Expand Down Expand Up @@ -1363,7 +1338,7 @@ private void refillDocs() throws IOException {
forDeltaUtil.decodeAndPrefixSum(docIn, accum, docBuffer);
pforUtil.decode(docIn, freqBuffer);
} else {
readVIntBlock(docIn, docBuffer, freqBuffer, left, true, true);
PostingsUtil.readVIntBlock(docIn, docBuffer, freqBuffer, left, true, true);
prefixSum(docBuffer, left, accum);
docBuffer[left] = NO_MORE_DOCS;
}
Expand Down Expand Up @@ -1753,7 +1728,7 @@ private void refillDocs() throws IOException {
false; // freq block will be loaded lazily when necessary, we don't load it here
}
} else {
readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreq, true);
PostingsUtil.readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreq, true);
prefixSum(docBuffer, left, accum);
docBuffer[left] = NO_MORE_DOCS;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,20 +371,7 @@ public void finishTerm(BlockTermState _state) throws IOException {
} else {
singletonDocID = -1;
// Group vInt encode the remaining doc deltas and freqs:
if (writeFreqs) {
for (int i = 0; i < docBufferUpto; i++) {
docDeltaBuffer[i] = (docDeltaBuffer[i] << 1) | (freqBuffer[i] == 1 ? 1 : 0);
}
}
docOut.writeGroupVInts(docDeltaBuffer, docBufferUpto);
if (writeFreqs) {
for (int i = 0; i < docBufferUpto; i++) {
final int freq = (int) freqBuffer[i];
if (freq != 1) {
docOut.writeVInt(freq);
}
}
}
PostingsUtil.writeVIntBlock(docOut, docDeltaBuffer, freqBuffer, docBufferUpto, writeFreqs);
}

final long lastPosBlockOffset;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// This file has been automatically generated, DO NOT EDIT

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;

import java.io.IOException;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;

/**
 * Utility class to encode/decode a postings block whose doc values are written with group-varint
 * encoding and whose freqs are written with variable-length encoding.
 *
 * <p>When freqs are written, each doc value is shifted left by one bit and its lowest bit flags
 * whether the freq equals 1; freqs different from 1 are appended afterwards as VInts.
 */
final class PostingsUtil {

  /** Static helpers only; not meant to be instantiated. */
  private PostingsUtil() {}

  /**
   * Read values that have been written using variable-length encoding and group-varint encoding
   * instead of bit-packing.
   *
   * @param docIn input to read from
   * @param docBuffer receives the {@code num} doc values, with the flag bit shifted out when
   *     {@code indexHasFreq} is true
   * @param freqBuffer receives the {@code num} freqs; only written when both {@code indexHasFreq}
   *     and {@code decodeFreq} are true
   * @param num number of values to read
   * @param indexHasFreq whether each doc value carries the freq flag in its lowest bit
   * @param decodeFreq whether freqs should be decoded into {@code freqBuffer}
   * @throws IOException if reading from {@code docIn} fails
   */
  static void readVIntBlock(
      IndexInput docIn,
      long[] docBuffer,
      long[] freqBuffer,
      int num,
      boolean indexHasFreq,
      boolean decodeFreq)
      throws IOException {
    docIn.readGroupVInts(docBuffer, num);
    if (indexHasFreq && decodeFreq) {
      for (int i = 0; i < num; ++i) {
        // Lowest bit set means freq == 1; otherwise the freq is read as a VInt below.
        freqBuffer[i] = docBuffer[i] & 0x01;
        // Narrow to int first, then shift unsigned: the writer stores the shifted value as an
        // int, so a large doc value comes back with the int sign bit set and must not be
        // sign-extended while dropping the flag bit.
        docBuffer[i] = ((int) docBuffer[i]) >>> 1;
        if (freqBuffer[i] == 0) {
          freqBuffer[i] = docIn.readVInt();
        }
      }
    } else if (indexHasFreq) {
      // Freqs are not wanted: only strip the flag bit from each doc value.
      for (int i = 0; i < num; ++i) {
        docBuffer[i] = ((int) docBuffer[i]) >>> 1;
      }
    }
  }

  /**
   * Write freq buffer with variable-length encoding and doc buffer with group-varint encoding.
   *
   * <p>Note that when {@code writeFreqs} is true the first {@code num} entries of {@code
   * docBuffer} are overwritten in place with their encoded form.
   *
   * @param docOut output to write to
   * @param docBuffer the doc values to write
   * @param freqBuffer the freqs to write; only read when {@code writeFreqs} is true
   * @param num number of values to write
   * @param writeFreqs whether freqs are encoded along with the doc values
   * @throws IOException if writing to {@code docOut} fails
   */
  static void writeVIntBlock(
      IndexOutput docOut, long[] docBuffer, long[] freqBuffer, int num, boolean writeFreqs)
      throws IOException {
    if (writeFreqs) {
      for (int i = 0; i < num; i++) {
        final int freqIsOne = freqBuffer[i] == 1 ? 1 : 0;
        // Narrow the shifted value to int before storing it back: a doc value of 2^30 or more
        // then becomes a negative long instead of a positive long above Integer.MAX_VALUE.
        docBuffer[i] = ((int) (docBuffer[i] << 1)) | freqIsOne;
      }
    }
    docOut.writeGroupVInts(docBuffer, num);
    if (writeFreqs) {
      for (int i = 0; i < num; i++) {
        final int freq = (int) freqBuffer[i];
        // A freq of 1 is already carried by the flag bit of the doc value.
        if (freq != 1) {
          docOut.writeVInt(freq);
        }
      }
    }
  }
}
7 changes: 5 additions & 2 deletions lucene/core/src/java/org/apache/lucene/store/DataOutput.java
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,12 @@ public void writeSetOfStrings(Set<String> set) throws IOException {
/**
* Encode integers using group-varint. It uses {@link DataOutput#writeVInt VInt} to encode tail
* values that are not enough for a group. We need a long[] because this is what postings are
* using, all longs are actually required to be integers.
* using, all longs are actually required to be integers. Negative numbers are supported, but
* should be avoided.
*
* @param values the values to write
* @param values the values to write. Note: a negative integer must be stored as a negative long,
* not as a positive long greater than Integer.MAX_VALUE; otherwise it will cause an
* integer overflow exception in {@link Math#toIntExact(long)}.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not mention this implementation detail.

Suggested change
* integer overflow exception in {@link Math#toIntExact(long)}.
* integer overflow exception.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change has been reverted; there is no change to DataOutput in the current fix approach.

* @param limit the number of values to write.
* @lucene.experimental
*/
Expand Down
10 changes: 5 additions & 5 deletions lucene/core/src/java/org/apache/lucene/util/GroupVIntUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,16 @@ public static void readGroupVInt(DataInput in, long[] dst, int offset) throws IO
dst[offset + 3] = readLongInGroup(in, n4Minus1);
}

private static long readLongInGroup(DataInput in, int numBytesMinus1) throws IOException {
private static int readLongInGroup(DataInput in, int numBytesMinus1) throws IOException {
switch (numBytesMinus1) {
case 0:
return in.readByte() & 0xFFL;
return in.readByte() & 0xFF;
case 1:
return in.readShort() & 0xFFFFL;
return in.readShort() & 0xFFFF;
case 2:
return (in.readShort() & 0xFFFFL) | ((in.readByte() & 0xFFL) << 16);
return (in.readShort() & 0xFFFF) | ((in.readByte() & 0xFF) << 16);
default:
return in.readInt() & 0xFFFFFFFFL;
return in.readInt();
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;

import java.io.IOException;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.tests.util.LuceneTestCase;

/** Tests for {@link PostingsUtil}. */
public class TestPostingUtil extends LuceneTestCase {

  /**
   * Regression test for the bug described in https://github.com/apache/lucene/issues/13373: a
   * large doc delta must survive a write/read round trip through {@link PostingsUtil}.
   */
  public void testIntegerOverflow() throws IOException {
    final int blockSize = ForUtil.BLOCK_SIZE;
    final int largeDelta = 1 << 30;

    final long[] deltas = new long[blockSize];
    final long[] freqs = new long[blockSize];
    deltas[0] = largeDelta;

    try (Directory directory = newDirectory()) {
      try (IndexOutput output = directory.createOutput("test", IOContext.DEFAULT)) {
        // The old implementation failed here with an integer overflow exception.
        PostingsUtil.writeVIntBlock(output, deltas, freqs, blockSize, true);
      }

      final long[] decodedDocs = new long[blockSize];
      final long[] decodedFreqs = new long[blockSize];
      try (IndexInput input = directory.openInput("test", IOContext.DEFAULT)) {
        PostingsUtil.readVIntBlock(input, decodedDocs, decodedFreqs, blockSize, true, true);
      }

      assertEquals(largeDelta, decodedDocs[0]);
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ protected Directory getDirectory(Path path) throws IOException {
public void testGroupVIntMultiBlocks() throws IOException {
final int maxChunkSize = random().nextInt(64, 512);
try (Directory dir = getDirectory(createTempDir(), maxChunkSize)) {
doTestGroupVInt(dir, 10, 1, 31, 1024);
doTestGroupVInt(dir, 10, 0, 1, 31, 1024);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1442,6 +1442,19 @@ public void testListAllIsSorted() throws IOException {
}
}

public void testGroupVIntOverflow() throws IOException {
try (Directory dir = getDirectory(createTempDir("testGroupVIntOverflow"))) {
final int v = 1 << 30;
final long[] values = new long[4];
values[0] = v;
values[0] <<= 1; // values[0] = 2147483648 as long, but as int it is -2147483648
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not set values[0] = 1L << 31 directly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed


IndexOutput out = dir.createOutput("test", IOContext.DEFAULT);
assertThrows(ArithmeticException.class, () -> out.writeGroupVInts(values, 4));
out.close();
}
}

public void testDataTypes() throws IOException {
final long[] values = new long[] {43, 12345, 123456, 1234567890};
try (Directory dir = getDirectory(createTempDir("testDataTypes"))) {
Expand All @@ -1468,18 +1481,21 @@ public void testDataTypes() throws IOException {
public void testGroupVInt() throws IOException {
try (Directory dir = getDirectory(createTempDir("testGroupVInt"))) {
// test fallback to default implementation of readGroupVInt
doTestGroupVInt(dir, 5, 1, 6, 8);
doTestGroupVInt(dir, 5, 0, 1, 6, 8);
doTestGroupVInt(dir, 5, -8, 3, 3, 8);

// use more iterations to cover all bpv
doTestGroupVInt(dir, atLeast(100), 1, 31, 128);
doTestGroupVInt(dir, atLeast(100), 0, 1, 31, 128);
doTestGroupVInt(dir, 5, Integer.MIN_VALUE, 31, 31, 128);

// we use BaseChunkedDirectoryTestCase#testGroupVIntMultiBlocks to cover multiple blocks for
// ByteBuffersDataInput and MMapDirectory
}
}

protected void doTestGroupVInt(
Directory dir, int iterations, int minBpv, int maxBpv, int maxNumValues) throws IOException {
Directory dir, int iterations, int minValue, int minBpv, int maxBpv, int maxNumValues)
throws IOException {
long[] values = new long[maxNumValues];
int[] numValuesArray = new int[iterations];
IndexOutput groupVIntOut = dir.createOutput("group-varint", IOContext.DEFAULT);
Expand All @@ -1490,7 +1506,8 @@ protected void doTestGroupVInt(
final int bpv = TestUtil.nextInt(random(), minBpv, maxBpv);
numValuesArray[iter] = TestUtil.nextInt(random(), 1, maxNumValues);
for (int j = 0; j < numValuesArray[iter]; j++) {
values[j] = RandomNumbers.randomIntBetween(random(), 0, (int) PackedInts.maxValue(bpv));
values[j] =
RandomNumbers.randomIntBetween(random(), minValue, (int) PackedInts.maxValue(bpv));
vIntOut.writeVInt((int) values[j]);
}
groupVIntOut.writeGroupVInts(values, numValuesArray[iter]);
Expand Down