Skip to content

Commit

Permalink
[SPARK-47172] Addressing reviewer comments for encryption
Browse files Browse the repository at this point in the history
  • Loading branch information
sweisdb committed Jun 7, 2024
1 parent 81a3ca0 commit 730a0bb
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,6 @@ public boolean release(int decrement) {

@Override
public long transferTo(WritableByteChannel target, long position) throws IOException {
Preconditions.checkArgument(position == transferred(),
"Invalid position.");
int transferredThisCall = 0;
// If the header has is not empty, try to write it out to the target.
if (headerByteBuffer.hasRemaining()) {
Expand All @@ -214,24 +212,22 @@ public long transferTo(WritableByteChannel target, long position) throws IOExcep
}
while (bytesRead < bytesToRead) {
long readableBytes = getReadableBytes();
boolean lastSegment = readableBytes <= plaintextBuffer.capacity();
plaintextBuffer.clear();
int readLimit =
(int) Math.min(readableBytes, plaintextBuffer.capacity());
plaintextBuffer.limit(readLimit);
(int) Math.min(readableBytes, plaintextBuffer.remaining());
if (plaintextMessage instanceof ByteBuf byteBuf) {
byteBuf.readBytes(plaintextBuffer);
} else if (plaintextMessage instanceof AbstractFileRegion fileRegion) {
ByteBufferWriteableChannel plaintextChannel =
new ByteBufferWriteableChannel(plaintextBuffer);
long transferred =
long plaintextRead =
fileRegion.transferTo(plaintextChannel, fileRegion.transferred());
if (transferred < readLimit) {
if (plaintextRead < readLimit) {
// If we do not read a full plaintext buffer or all the available
// readable bytes, return what was transferred this call.
return transferredThisCall;
}
}
boolean lastSegment = getReadableBytes() == 0;
plaintextBuffer.flip();
bytesRead += plaintextBuffer.remaining();
ciphertextBuffer.clear();
Expand All @@ -240,6 +236,7 @@ public long transferTo(WritableByteChannel target, long position) throws IOExcep
} catch (GeneralSecurityException e) {
throw new IllegalStateException("GeneralSecurityException from encrypter", e);
}
plaintextBuffer.clear();
ciphertextBuffer.flip();
int written = target.write(ciphertextBuffer);
transferredThisCall += written;
Expand Down Expand Up @@ -283,7 +280,7 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter {
private final AesGcmHkdfStreaming aesGcmHkdfStreaming;
private final StreamSegmentDecrypter decrypter;
private boolean decrypterInit = false;
private boolean lastSegment = false;
private boolean completed = false;
private int segmentNumber = 0;
private long expectedLength = -1;
private long ciphertextRead = 0;
Expand All @@ -300,7 +297,7 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter {
}

private boolean initalizeExpectedLength(ByteBuf ciphertextNettyBuf) {
if (expectedLength < 0 && expectedLengthBuffer.hasRemaining()) {
if (expectedLength < 0) {
ciphertextNettyBuf.readBytes(expectedLengthBuffer);
if (expectedLengthBuffer.hasRemaining()) {
// We did not read enough bytes to initialize the expected length.
Expand All @@ -320,7 +317,7 @@ private boolean initalizeDecrypter(ByteBuf ciphertextNettyBuf)
throws GeneralSecurityException {
// Check if the ciphertext header has been read. This contains
// the IV and other internal metadata.
if (!decrypterInit && headerBuffer.hasRemaining()) {
if (!decrypterInit) {
ciphertextNettyBuf.readBytes(headerBuffer);
if (headerBuffer.hasRemaining()) {
// We did not read enough bytes to initialize the header.
Expand All @@ -331,6 +328,10 @@ private boolean initalizeDecrypter(ByteBuf ciphertextNettyBuf)
decrypter.init(headerBuffer, lengthAad);
decrypterInit = true;
ciphertextRead += aesGcmHkdfStreaming.getHeaderLength();
if (expectedLength == ciphertextRead) {
// If the expected length is just the header, the ciphertext is 0 length.
completed = true;
}
}
return true;
}
Expand All @@ -350,15 +351,11 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage)
return;
}
if (!initalizeDecrypter(ciphertextNettyBuf)) {
// We nave not read enough bytes to initalize a header, needed to
// We have not read enough bytes to initialize a header, needed to
// initialize a decrypter.
return;
}
if (expectedLength == ciphertextRead) {
// If the expected length is just the header, the ciphertext is 0 length.
lastSegment = true;
}
while (ciphertextNettyBuf.readableBytes() > 0 && !lastSegment) {
while (ciphertextNettyBuf.readableBytes() > 0 && !completed) {
ciphertextBuffer.clear();
// Read the ciphertext into the local buffer
int readableBytes = Integer.min(
Expand All @@ -375,7 +372,7 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage)
ciphertextRead += bytesToRead;
// Check if this is the last segment
if (ciphertextRead == expectedLength) {
lastSegment = true;
completed = true;
} else if (ciphertextRead > expectedLength) {
throw new IllegalStateException("Read more ciphertext than expected.");
}
Expand All @@ -385,7 +382,7 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage)
decrypter.decryptSegment(
ciphertextBuffer,
segmentNumber,
lastSegment,
completed,
plaintextBuffer);
segmentNumber++;
plaintextBuffer.flip();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@

package org.apache.spark.network.crypto;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;
import java.security.GeneralSecurityException;
import java.util.Arrays;
import java.util.Map;
import java.util.Random;

Expand All @@ -31,10 +33,8 @@
import io.netty.channel.ChannelPromise;
import io.netty.channel.FileRegion;
import org.apache.spark.network.crypto.GcmTransportCipher.ByteBufferWriteableChannel;
import org.apache.spark.network.util.ByteArrayWritableChannel;
import org.apache.spark.network.util.ConfigProvider;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;
import org.apache.spark.network.util.*;

import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -266,16 +266,6 @@ public void testCtrEncryptedMessage() throws Exception {
}
}

private ChannelHandlerContext mockChannelHandlerContext() {
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
when(ctx.fireChannelRead(any())).thenAnswer(in -> {
ByteBuf buf = (ByteBuf) in.getArguments()[0];
buf.release();
return null;
});
return ctx;
}

@Test
public void testGcmEncryptedMessage() throws Exception {
TransportConf gcmConf = getConf(2, false);
Expand Down Expand Up @@ -333,6 +323,136 @@ public void testGcmEncryptedMessage() throws Exception {
}
}


static class FakeRegion extends AbstractFileRegion {
private final ByteBuffer[] source;
private int sourcePosition;
private final long count;

FakeRegion(ByteBuffer... source) {
this.source = source;
sourcePosition = 0;
count = remaining();
}

private long remaining() {
long remaining = 0;
for (ByteBuffer buffer : source) {
remaining += buffer.remaining();
}
return remaining;
}

@Override
public long position() {
return 0;
}

@Override
public long transferred() {
return count - remaining();
}

@Override
public long count() {
return count;
}

@Override
public long transferTo(WritableByteChannel target, long position) throws IOException {
if (sourcePosition < source.length) {
ByteBuffer currentBuffer = source[sourcePosition];
long written = target.write(currentBuffer);
if (!currentBuffer.hasRemaining()) {
sourcePosition++;
}
return written;
} else {
return 0;
}
}

@Override
protected void deallocate() {
}
}

private static ByteBuffer getTestByteBuf(int size, byte fill) {
byte[] data = new byte[size];
Arrays.fill(data, fill);
return ByteBuffer.wrap(data);
}

@Test
public void testGcmEncryptedMessageFileRegion() throws Exception {
TransportConf gcmConf = getConf(2, false);
try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf);
AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) {
AuthMessage clientChallenge = client.challenge();
AuthMessage serverResponse = server.response(clientChallenge);
client.deriveSessionCipher(clientChallenge, serverResponse);
TransportCipher clientCipher = server.sessionCipher();
// Verify that it derives a GcmTransportCipher
assert (clientCipher instanceof GcmTransportCipher);
GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher;
GcmTransportCipher.EncryptionHandler encryptionHandler =
gcmTransportCipher.getEncryptionHandler();
GcmTransportCipher.DecryptionHandler decryptionHandler =
gcmTransportCipher.getDecryptionHandler();
// Allocating 1.5x the buffer size to test multiple segments and a fractional segment.
int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16;
int halfSegmentSize = plaintextSegmentSize / 2;
int totalSize = plaintextSegmentSize + halfSegmentSize;

// Set up some fragmented segments to test
ByteBuffer halfSegment = getTestByteBuf(halfSegmentSize, (byte) 'a');
int smallFragmentSize = 128;
ByteBuffer smallFragment = getTestByteBuf(smallFragmentSize, (byte) 'b');
int remainderSize = totalSize - halfSegmentSize - smallFragmentSize;
ByteBuffer remainder = getTestByteBuf(remainderSize, (byte) 'c');
FakeRegion fakeRegion = new FakeRegion(halfSegment, smallFragment, remainder);
assertEquals(totalSize, fakeRegion.count());

ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
ChannelPromise promise = mock(ChannelPromise.class);
ArgumentCaptor<GcmTransportCipher.GcmEncryptedMessage> captorWrappedEncrypted =
ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class);
encryptionHandler.write(ctx, fakeRegion, promise);
verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise));

// Get the encrypted value and pass it to the decryption handler
GcmTransportCipher.GcmEncryptedMessage encrypted =
captorWrappedEncrypted.getValue();
ByteBuffer ciphertextBuffer =
ByteBuffer.allocate((int) encrypted.count());
ByteBufferWriteableChannel channel =
new ByteBufferWriteableChannel(ciphertextBuffer);

// We'll simulate the FileRegion only transferring half a segment.
// The encrypted message should buffer the partial segment plaintext.
long ciphertextTransferred = 0;
while (ciphertextTransferred < encrypted.count()) {
long chunkTransferred = encrypted.transferTo(channel, 0);
ciphertextTransferred += chunkTransferred;
}
assertEquals(encrypted.count(), ciphertextTransferred);

ciphertextBuffer.flip();
ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer);

// Capture the decrypted values and verify them
ArgumentCaptor<ByteBuf> captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class);
decryptionHandler.channelRead(ctx, ciphertext);
verify(ctx, times(2)).fireChannelRead(captorPlaintext.capture());
ByteBuf plaintext = captorPlaintext.getValue();
// We expect this to be the last partial plaintext segment
int expectedLength = totalSize % plaintextSegmentSize;
assertEquals(expectedLength, plaintext.readableBytes());
// This will be the "remainder" segment that is filled to 'c'
assertEquals('c', plaintext.getByte(0));
}
}

@Test
public void testCorruptGcmEncryptedMessage() throws Exception {
TransportConf gcmConf = getConf(2, false);
Expand Down

0 comments on commit 730a0bb

Please sign in to comment.