From 22ddd481f081d5f01e72014865fb4a0c76ea2b54 Mon Sep 17 00:00:00 2001 From: Zhang Chao <80152403@qq.com> Date: Fri, 17 May 2024 23:43:23 +0800 Subject: [PATCH] Fix IntegerOverflow exception in postings encoding as group-varint (#13376) The exception happened because the tail postings list block, which is encoded with GroupVInt, had a docID delta that was >= 1<<30 while the postings were also storing freqs. --- lucene/CHANGES.txt | 2 + .../lucene99/Lucene99PostingsReader.java | 35 ++------- .../lucene99/Lucene99PostingsWriter.java | 15 +--- .../lucene/codecs/lucene99/PostingsUtil.java | 73 +++++++++++++++++++ .../org/apache/lucene/store/DataInput.java | 2 +- .../org/apache/lucene/util/GroupVIntUtil.java | 29 +++++--- .../codecs/lucene99/TestPostingsUtil.java | 49 +++++++++++++ .../tests/store/BaseDirectoryTestCase.java | 31 ++++++++ 8 files changed, 181 insertions(+), 55 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene99/PostingsUtil.java create mode 100644 lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestPostingsUtil.java diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 8afa946ca723..6c2b916859e8 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -374,6 +374,8 @@ Bug Fixes * GITHUB#13374: Fix bug in SQ when just a single vector present in a segment (Chris Hegarty) +* GITHUB#13376: Fix integer overflow exception in postings encoding as group-varint. 
(Zhang Chao, Guo Feng) + Build --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99PostingsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99PostingsReader.java index 48c093b7570e..13353c426034 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99PostingsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99PostingsReader.java @@ -143,31 +143,6 @@ public void init(IndexInput termsIn, SegmentReadState state) throws IOException } } - /** Read values that have been written using variable-length encoding instead of bit-packing. */ - static void readVIntBlock( - IndexInput docIn, - long[] docBuffer, - long[] freqBuffer, - int num, - boolean indexHasFreq, - boolean decodeFreq) - throws IOException { - docIn.readGroupVInts(docBuffer, num); - if (indexHasFreq && decodeFreq) { - for (int i = 0; i < num; ++i) { - freqBuffer[i] = docBuffer[i] & 0x01; - docBuffer[i] >>= 1; - if (freqBuffer[i] == 0) { - freqBuffer[i] = docIn.readVInt(); - } - } - } else if (indexHasFreq) { - for (int i = 0; i < num; ++i) { - docBuffer[i] >>= 1; - } - } - } - static void prefixSum(long[] buffer, int count, long base) { buffer[0] += base; for (int i = 1; i < count; ++i) { @@ -480,7 +455,7 @@ private void refillDocs() throws IOException { blockUpto++; } else { // Read vInts: - readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreq, needsFreq); + PostingsUtil.readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreq, needsFreq); prefixSum(docBuffer, left, accum); docBuffer[left] = NO_MORE_DOCS; blockUpto += left; @@ -783,7 +758,7 @@ private void refillDocs() throws IOException { docBuffer[1] = NO_MORE_DOCS; blockUpto++; } else { - readVIntBlock(docIn, docBuffer, freqBuffer, left, true, true); + PostingsUtil.readVIntBlock(docIn, docBuffer, freqBuffer, left, true, true); prefixSum(docBuffer, left, accum); docBuffer[left] = NO_MORE_DOCS; blockUpto += left; @@ -1179,7 
+1154,7 @@ private void refillDocs() throws IOException { } blockUpto += BLOCK_SIZE; } else { - readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreqs, true); + PostingsUtil.readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreqs, true); prefixSum(docBuffer, left, accum); docBuffer[left] = NO_MORE_DOCS; blockUpto += left; @@ -1388,7 +1363,7 @@ private void refillDocs() throws IOException { forDeltaUtil.decodeAndPrefixSum(docIn, accum, docBuffer); pforUtil.decode(docIn, freqBuffer); } else { - readVIntBlock(docIn, docBuffer, freqBuffer, left, true, true); + PostingsUtil.readVIntBlock(docIn, docBuffer, freqBuffer, left, true, true); prefixSum(docBuffer, left, accum); docBuffer[left] = NO_MORE_DOCS; } @@ -1779,7 +1754,7 @@ private void refillDocs() throws IOException { false; // freq block will be loaded lazily when necessary, we don't load it here } } else { - readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreq, true); + PostingsUtil.readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreq, true); prefixSum(docBuffer, left, accum); docBuffer[left] = NO_MORE_DOCS; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99PostingsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99PostingsWriter.java index a001bea210c3..61949cfe2270 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99PostingsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99PostingsWriter.java @@ -371,20 +371,7 @@ public void finishTerm(BlockTermState _state) throws IOException { } else { singletonDocID = -1; // Group vInt encode the remaining doc deltas and freqs: - if (writeFreqs) { - for (int i = 0; i < docBufferUpto; i++) { - docDeltaBuffer[i] = (docDeltaBuffer[i] << 1) | (freqBuffer[i] == 1 ? 
1 : 0); - } - } - docOut.writeGroupVInts(docDeltaBuffer, docBufferUpto); - if (writeFreqs) { - for (int i = 0; i < docBufferUpto; i++) { - final int freq = (int) freqBuffer[i]; - if (freq != 1) { - docOut.writeVInt(freq); - } - } - } + PostingsUtil.writeVIntBlock(docOut, docDeltaBuffer, freqBuffer, docBufferUpto, writeFreqs); } final long lastPosBlockOffset; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/PostingsUtil.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/PostingsUtil.java new file mode 100644 index 000000000000..678754047b6c --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/PostingsUtil.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene99; + +import java.io.IOException; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; + +/** Utility class to encode/decode postings block. */ +final class PostingsUtil { + + /** + * Read values that have been written using variable-length encoding and group-varint encoding + * instead of bit-packing. 
+ */ + static void readVIntBlock( + IndexInput docIn, + long[] docBuffer, + long[] freqBuffer, + int num, + boolean indexHasFreq, + boolean decodeFreq) + throws IOException { + docIn.readGroupVInts(docBuffer, num); + if (indexHasFreq && decodeFreq) { + for (int i = 0; i < num; ++i) { + freqBuffer[i] = docBuffer[i] & 0x01; + docBuffer[i] >>= 1; + if (freqBuffer[i] == 0) { + freqBuffer[i] = docIn.readVInt(); + } + } + } else if (indexHasFreq) { + for (int i = 0; i < num; ++i) { + docBuffer[i] >>= 1; + } + } + } + + /** Write freq buffer with variable-length encoding and doc buffer with group-varint encoding. */ + static void writeVIntBlock( + IndexOutput docOut, long[] docBuffer, long[] freqBuffer, int num, boolean writeFreqs) + throws IOException { + if (writeFreqs) { + for (int i = 0; i < num; i++) { + docBuffer[i] = (docBuffer[i] << 1) | (freqBuffer[i] == 1 ? 1 : 0); + } + } + docOut.writeGroupVInts(docBuffer, num); + if (writeFreqs) { + for (int i = 0; i < num; i++) { + final int freq = (int) freqBuffer[i]; + if (freq != 1) { + docOut.writeVInt(freq); + } + } + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/store/DataInput.java b/lucene/core/src/java/org/apache/lucene/store/DataInput.java index 781066f02ab5..427e81f2df24 100644 --- a/lucene/core/src/java/org/apache/lucene/store/DataInput.java +++ b/lucene/core/src/java/org/apache/lucene/store/DataInput.java @@ -113,7 +113,7 @@ public final void readGroupVInts(long[] dst, int limit) throws IOException { readGroupVInt(dst, i); } for (; i < limit; ++i) { - dst[i] = readVInt(); + dst[i] = readVInt() & 0xFFFFFFFFL; } } diff --git a/lucene/core/src/java/org/apache/lucene/util/GroupVIntUtil.java b/lucene/core/src/java/org/apache/lucene/util/GroupVIntUtil.java index 2465e8a2a777..e1b5466342a0 100644 --- a/lucene/core/src/java/org/apache/lucene/util/GroupVIntUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/GroupVIntUtil.java @@ -28,7 +28,9 @@ public final class GroupVIntUtil { // the maximum 
length of a single group-varint is 4 integers + 1 byte flag. public static final int MAX_LENGTH_PER_GROUP = 17; - private static final int[] MASKS = new int[] {0xFF, 0xFFFF, 0xFFFFFF, 0xFFFFFFFF}; + + // we use long array instead of int array to make negative integer to be read as positive long. + private static final long[] MASKS = new long[] {0xFFL, 0xFFFFL, 0xFFFFFFL, 0xFFFFFFFFL}; /** * Default implementation of read single group, for optimal performance, you should use {@link @@ -118,6 +120,13 @@ private static int numBytes(int v) { return Integer.BYTES - (Integer.numberOfLeadingZeros(v | 1) >> 3); } + private static int toInt(long value) { + if ((Long.compareUnsigned(value, 0xFFFFFFFFL) > 0)) { + throw new ArithmeticException("integer overflow"); + } + return (int) value; + } + /** * The implementation for group-varint encoding, It uses a maximum of {@link * #MAX_LENGTH_PER_GROUP} bytes scratch buffer. @@ -129,19 +138,19 @@ public static void writeGroupVInts(DataOutput out, byte[] scratch, long[] values // encode each group while ((limit - readPos) >= 4) { int writePos = 0; - final int n1Minus1 = numBytes(Math.toIntExact(values[readPos])) - 1; - final int n2Minus1 = numBytes(Math.toIntExact(values[readPos + 1])) - 1; - final int n3Minus1 = numBytes(Math.toIntExact(values[readPos + 2])) - 1; - final int n4Minus1 = numBytes(Math.toIntExact(values[readPos + 3])) - 1; + final int n1Minus1 = numBytes(toInt(values[readPos])) - 1; + final int n2Minus1 = numBytes(toInt(values[readPos + 1])) - 1; + final int n3Minus1 = numBytes(toInt(values[readPos + 2])) - 1; + final int n4Minus1 = numBytes(toInt(values[readPos + 3])) - 1; int flag = (n1Minus1 << 6) | (n2Minus1 << 4) | (n3Minus1 << 2) | (n4Minus1); scratch[writePos++] = (byte) flag; - BitUtil.VH_LE_INT.set(scratch, writePos, Math.toIntExact(values[readPos++])); + BitUtil.VH_LE_INT.set(scratch, writePos, (int) (values[readPos++])); writePos += n1Minus1 + 1; - BitUtil.VH_LE_INT.set(scratch, writePos, 
Math.toIntExact(values[readPos++])); + BitUtil.VH_LE_INT.set(scratch, writePos, (int) (values[readPos++])); writePos += n2Minus1 + 1; - BitUtil.VH_LE_INT.set(scratch, writePos, Math.toIntExact(values[readPos++])); + BitUtil.VH_LE_INT.set(scratch, writePos, (int) (values[readPos++])); writePos += n3Minus1 + 1; - BitUtil.VH_LE_INT.set(scratch, writePos, Math.toIntExact(values[readPos++])); + BitUtil.VH_LE_INT.set(scratch, writePos, (int) (values[readPos++])); writePos += n4Minus1 + 1; out.writeBytes(scratch, writePos); @@ -149,7 +158,7 @@ public static void writeGroupVInts(DataOutput out, byte[] scratch, long[] values // tail vints for (; readPos < limit; readPos++) { - out.writeVInt(Math.toIntExact(values[readPos])); + out.writeVInt(toInt(values[readPos])); } } } diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestPostingsUtil.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestPostingsUtil.java new file mode 100644 index 000000000000..b9cb3f20dbaa --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestPostingsUtil.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.codecs.lucene99; + +import java.io.IOException; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.tests.util.LuceneTestCase; + +public class TestPostingsUtil extends LuceneTestCase { + + // checks for bug described in https://github.com/apache/lucene/issues/13373 + public void testIntegerOverflow() throws IOException { + final int size = random().nextInt(1, ForUtil.BLOCK_SIZE); + final long[] docDeltaBuffer = new long[size]; + final long[] freqBuffer = new long[size]; + + final int delta = 1 << 30; + docDeltaBuffer[0] = delta; + try (Directory dir = newDirectory()) { + try (IndexOutput out = dir.createOutput("test", IOContext.DEFAULT)) { + // In old implementation, this would cause integer overflow exception. + PostingsUtil.writeVIntBlock(out, docDeltaBuffer, freqBuffer, size, true); + } + long[] restoredDocs = new long[size]; + long[] restoredFreqs = new long[size]; + try (IndexInput in = dir.openInput("test", IOContext.DEFAULT)) { + PostingsUtil.readVIntBlock(in, restoredDocs, restoredFreqs, size, true, true); + } + assertEquals(delta, restoredDocs[0]); + } + } +} diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/store/BaseDirectoryTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/store/BaseDirectoryTestCase.java index bb252982c259..9cc271a9d618 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/store/BaseDirectoryTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/store/BaseDirectoryTestCase.java @@ -1465,6 +1465,37 @@ public void testDataTypes() throws IOException { } } + public void testGroupVIntOverflow() throws IOException { + try (Directory dir = getDirectory(createTempDir("testGroupVIntOverflow"))) { + final int size = 32; + final long[] values = new long[size]; + final long[] restore = new 
long[size]; + values[0] = 1L << 31; // values[0] = 2147483648 as long, but as int it is -2147483648 + + for (int i = 0; i < size; i++) { + if (random().nextBoolean()) { + values[i] = values[0]; + } + } + + // a smaller limit value cover default implementation of readGroupVInts + // and a bigger limit value cover the faster implementation. + final int limit = random().nextInt(1, size); + IndexOutput out = dir.createOutput("test", IOContext.DEFAULT); + out.writeGroupVInts(values, limit); + out.close(); + try (IndexInput in = dir.openInput("test", IOContext.DEFAULT)) { + in.readGroupVInts(restore, limit); + for (int i = 0; i < limit; i++) { + assertEquals(values[i], restore[i]); + } + } + + values[0] = 0xFFFFFFFFL + 1; + assertThrows(ArithmeticException.class, () -> out.writeGroupVInts(values, 4)); + } + } + public void testGroupVInt() throws IOException { try (Directory dir = getDirectory(createTempDir("testGroupVInt"))) { // test fallback to default implementation of readGroupVInt