diff --git a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java index 224ced5292ab..81da41778336 100644 --- a/java/core/src/main/java/com/google/protobuf/CodedInputStream.java +++ b/java/core/src/main/java/com/google/protobuf/CodedInputStream.java @@ -2278,6 +2278,9 @@ public String readString() throws IOException { if (size == 0) { return ""; } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } if (size <= bufferSize) { refillBuffer(size); String result = new String(buffer, pos, size, UTF_8); @@ -2302,6 +2305,8 @@ public String readStringRequireUtf8() throws IOException { tempPos = oldPos; } else if (size == 0) { return ""; + } else if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); } else if (size <= bufferSize) { refillBuffer(size); bytes = buffer; @@ -2396,6 +2401,9 @@ public ByteString readBytes() throws IOException { if (size == 0) { return ByteString.EMPTY; } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } return readBytesSlowPath(size); } @@ -2408,6 +2416,8 @@ public byte[] readByteArray() throws IOException { final byte[] result = Arrays.copyOfRange(buffer, pos, pos + size); pos += size; return result; + } else if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); } else { // Slow path: Build a byte array first then copy it. // TODO: Do we want to protect from malicious input streams here? @@ -2427,6 +2437,9 @@ public ByteBuffer readByteBuffer() throws IOException { if (size == 0) { return Internal.EMPTY_BYTE_BUFFER; } + if (size < 0) { + throw InvalidProtocolBufferException.negativeSize(); + } // Slow path: Build a byte array first then copy it. // We must copy as the byte array was handed off to the InputStream and a malicious diff --git a/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java b/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java index 5162888d0720..ff700587a160 100644 --- a/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java +++ b/java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java @@ -10,6 +10,7 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertThrows; import protobuf_unittest.UnittestProto.BoolMessage; import protobuf_unittest.UnittestProto.Int32Message; import protobuf_unittest.UnittestProto.Int64Message; @@ -534,6 +535,86 @@ public void testReadMaliciouslyLargeBlob() throws Exception { } } + @Test + public void testReadStringWithSizeOverflow_throwsInvalidProtocolBufferException() + throws Exception { + ByteString.Output rawOutput = ByteString.newOutput(); + CodedOutputStream output = CodedOutputStream.newInstance(rawOutput); + + output.writeUInt32NoTag(0xFFFFFFFF); // Larger than Integer.MAX_VALUE. + output.writeRawBytes(new byte[32]); // Pad with a few random bytes. + output.flush(); + byte[] data = rawOutput.toByteString().toByteArray(); + for (InputType inputType : InputType.values()) { + CodedInputStream input = inputType.newDecoder(data); + assertThrows(InvalidProtocolBufferException.class, input::readString); + } + } + + @Test + public void testReadStringRequireUtf8WithSizeOverflow_throwsInvalidProtocolBufferException() + throws Exception { + ByteString.Output rawOutput = ByteString.newOutput(); + CodedOutputStream output = CodedOutputStream.newInstance(rawOutput); + + output.writeUInt32NoTag(0xFFFFFFFF); // Larger than Integer.MAX_VALUE. + output.writeRawBytes(new byte[32]); // Pad with a few random bytes. + output.flush(); + byte[] data = rawOutput.toByteString().toByteArray(); + for (InputType inputType : InputType.values()) { + CodedInputStream input = inputType.newDecoder(data); + assertThrows(InvalidProtocolBufferException.class, input::readStringRequireUtf8); + } + } + + @Test + public void testReadBytesWithHugeSizeOverflow_throwsInvalidProtocolBufferException() + throws Exception { + ByteString.Output rawOutput = ByteString.newOutput(); + CodedOutputStream output = CodedOutputStream.newInstance(rawOutput); + + output.writeUInt32NoTag(0xFFFFFFFF); // Larger than Integer.MAX_VALUE. + output.writeRawBytes(new byte[32]); // Pad with a few random bytes. + output.flush(); + byte[] data = rawOutput.toByteString().toByteArray(); + for (InputType inputType : InputType.values()) { + CodedInputStream input = inputType.newDecoder(data); + assertThrows(InvalidProtocolBufferException.class, input::readBytes); + } + } + + @Test + public void testReadByteArrayWithHugeSizeOverflow_throwsInvalidProtocolBufferException() + throws Exception { + ByteString.Output rawOutput = ByteString.newOutput(); + CodedOutputStream output = CodedOutputStream.newInstance(rawOutput); + + output.writeUInt32NoTag(0xFFFFFFFF); // Larger than Integer.MAX_VALUE. + output.writeRawBytes(new byte[32]); // Pad with a few random bytes. + output.flush(); + byte[] data = rawOutput.toByteString().toByteArray(); + for (InputType inputType : InputType.values()) { + CodedInputStream input = inputType.newDecoder(data); + assertThrows(InvalidProtocolBufferException.class, input::readByteArray); + } + } + + @Test + public void testReadByteBufferWithSizeOverflow_throwsInvalidProtocolBufferException() + throws Exception { + ByteString.Output rawOutput = ByteString.newOutput(); + CodedOutputStream output = CodedOutputStream.newInstance(rawOutput); + + output.writeUInt32NoTag(0xFFFFFFFF); // Larger than Integer.MAX_VALUE. + output.writeRawBytes(new byte[32]); // Pad with a few random bytes. + output.flush(); + byte[] data = rawOutput.toByteString().toByteArray(); + for (InputType inputType : InputType.values()) { + CodedInputStream input = inputType.newDecoder(data); + assertThrows(InvalidProtocolBufferException.class, input::readByteBuffer); + } + } + /** * Test we can do messages that are up to CodedInputStream#DEFAULT_SIZE_LIMIT in size (2G or * Integer#MAX_SIZE).