Skip to content

Commit

Permalink
[SPARK-47172] Updating GcmTransportCipher to use buffered streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
sweisdb committed May 15, 2024
1 parent 0bfe017 commit b9baf22
Show file tree
Hide file tree
Showing 5 changed files with 300 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,8 @@ private TransportCipher generateTransportCipher(
return new GcmTransportCipher(sessionKey);
} else {
throw new IllegalArgumentException(
"Unsupported cipher transformation: " + conf.cipherTransformation());
String.format("Unsupported cipher mode: %s. %s and %s are supported.",
conf.cipherTransformation(), CIPHER_ALGORITHM, LEGACY_CIPHER_ALGORITHM));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
package org.apache.spark.network.crypto;

import com.google.common.annotations.VisibleForTesting;
import com.google.crypto.tink.subtle.AesGcmJce;
import com.google.common.base.Preconditions;
import com.google.crypto.tink.subtle.AesGcmHkdfStreaming;
import com.google.crypto.tink.subtle.StreamSegmentEncrypter;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.util.ReferenceCounted;
import org.apache.spark.network.util.AbstractFileRegion;
import io.netty.buffer.ByteBuf;

Expand All @@ -33,80 +36,284 @@

public class GcmTransportCipher implements TransportCipher {
private static final byte[] DEFAULT_AAD = new byte[0];

private static final int LENGTH_HEADER_BYTES = 8;
@VisibleForTesting
static final int CIPHERTEXT_BUFFER_SIZE = 1024;
private final SecretKeySpec aesKey;

public GcmTransportCipher(SecretKeySpec aesKey) {
this.aesKey = aesKey;
}

@VisibleForTesting
EncryptionHandler getEncryptionHandler() {
EncryptionHandler getEncryptionHandler() throws GeneralSecurityException {
return new EncryptionHandler();
}

@VisibleForTesting
DecryptionHandler getDecryptionHandler() {
DecryptionHandler getDecryptionHandler() throws GeneralSecurityException {
return new DecryptionHandler();
}

public void addToChannel(Channel ch) {
public void addToChannel(Channel ch) throws GeneralSecurityException {
ch.pipeline()
.addFirst("GcmTransportEncryption", getEncryptionHandler())
.addFirst("GcmTransportDecryption", getDecryptionHandler());
}

@VisibleForTesting
class EncryptionHandler extends ChannelOutboundHandlerAdapter {
private final ByteBuffer plaintextBuffer;
private final ByteBuffer ciphertextBuffer;
private final AesGcmHkdfStreaming aesGcmHkdfStreaming;

EncryptionHandler() throws GeneralSecurityException {
aesGcmHkdfStreaming = new AesGcmHkdfStreaming(
aesKey.getEncoded(),
"HmacSha256",
aesKey.getEncoded().length,
CIPHERTEXT_BUFFER_SIZE,
0);
plaintextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize());
ciphertextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize());
}

@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
ByteBuffer inputBuffer;
int bytesToRead;
if (msg instanceof ByteBuf byteBuf) {
bytesToRead = byteBuf.readableBytes();
// This is allocating a buffer that is the size of the input
inputBuffer = ByteBuffer.allocate(bytesToRead);
// This will block while copying
while (inputBuffer.position() < bytesToRead) {
byteBuf.readBytes(inputBuffer);
GcmEncryptedMessage encryptedMessage = new GcmEncryptedMessage(
aesGcmHkdfStreaming,
msg,
plaintextBuffer,
ciphertextBuffer);
ctx.write(encryptedMessage, promise);
}
}

static class GcmEncryptedMessage extends AbstractFileRegion {
private final Object plaintextMessage;
private final ByteBuffer plaintextBuffer;
private final ByteBuffer ciphertextBuffer;
private final long bytesToRead;
private long bytesRead = 0;
private final StreamSegmentEncrypter encrypter;
private boolean headerWritten = false;
private long transferred = 0;
private final long encryptedCount;

GcmEncryptedMessage(AesGcmHkdfStreaming aesGcmHkdfStreaming,
Object plaintextMessage,
ByteBuffer plaintextBuffer,
ByteBuffer ciphertextBuffer) throws GeneralSecurityException {
Preconditions.checkArgument(
plaintextMessage instanceof ByteBuf || plaintextMessage instanceof FileRegion,
"Unrecognized message type: %s", plaintextMessage.getClass().getName());
this.plaintextMessage = plaintextMessage;
this.bytesToRead = getReadableBytes();
this.plaintextBuffer = plaintextBuffer;
this.plaintextBuffer.clear();
this.ciphertextBuffer = ciphertextBuffer;
this.ciphertextBuffer.clear();
this.encrypter = aesGcmHkdfStreaming.newStreamSegmentEncrypter(DEFAULT_AAD);
this.encryptedCount =
LENGTH_HEADER_BYTES + aesGcmHkdfStreaming.expectedCiphertextSize(bytesToRead);
}

@Override
public long position() {
return 0;
}

@Override
public long transferred() {
return transferred;
}

@Override
public long count() {
return encryptedCount;
}

@Override
public long transferTo(WritableByteChannel target, long position) throws IOException {
Preconditions.checkArgument(position == transferred(),
"Invalid position.");
int transferredThisCall = 0;
// The format of the output is:
// [8 byte length][Internal IV and header][Ciphertext][Auth Tag]
if (!headerWritten) {
ByteBuffer expectedLength = ByteBuffer
.allocate(LENGTH_HEADER_BYTES)
.putLong(encryptedCount)
.flip();
target.write(expectedLength);
int headerWritten = LENGTH_HEADER_BYTES + target.write(encrypter.getHeader());
transferredThisCall += headerWritten;
this.transferred += headerWritten;
this.headerWritten = true;
}

while (bytesRead < bytesToRead) {
long readableBytes = getReadableBytes();
boolean lastSegment = readableBytes <= plaintextBuffer.capacity();
plaintextBuffer.clear();
int readLimit =
(int) Math.min(readableBytes, plaintextBuffer.capacity());
plaintextBuffer.limit(readLimit);
if (plaintextMessage instanceof ByteBuf byteBuf) {
byteBuf.readBytes(plaintextBuffer);
long inputBytesRead = readableBytes - byteBuf.readableBytes();
bytesRead += inputBytesRead;
} else if (plaintextMessage instanceof AbstractFileRegion fileRegion) {
ByteBufferWriteableChannel plaintextChannel =
new ByteBufferWriteableChannel(plaintextBuffer);
long transferred =
fileRegion.transferTo(plaintextChannel, fileRegion.transferred());
bytesRead += transferred;
if (transferred == 0) {
// File regions may return 0 if they are not ready to transfer
// more data. In that case, we'll return with the expectation
// that this transferTo() is called again.
return transferredThisCall;
}
}
} else if (msg instanceof AbstractFileRegion fileRegion) {
bytesToRead = (int) fileRegion.count();
// This is allocating a buffer that is the size of the input
inputBuffer = ByteBuffer.allocate(bytesToRead);
ByteBufferWriteableChannel writeableChannel =
new ByteBufferWriteableChannel(inputBuffer);
long transferred = 0;
// This will block while copying
while (transferred < bytesToRead) {
transferred +=
fileRegion.transferTo(writeableChannel, fileRegion.transferred());
plaintextBuffer.flip();
ciphertextBuffer.clear();
try {
encrypter.encryptSegment(plaintextBuffer, lastSegment, ciphertextBuffer);
} catch (GeneralSecurityException e) {
throw new RuntimeException(e);
}
ciphertextBuffer.flip();
int outputRemaining = ciphertextBuffer.remaining();
while (ciphertextBuffer.hasRemaining()) {
target.write(ciphertextBuffer);
}
transferredThisCall += outputRemaining;
transferred += outputRemaining;
}
return transferredThisCall;
}

private long getReadableBytes() {
if (plaintextMessage instanceof ByteBuf byteBuf) {
return byteBuf.readableBytes();
} else if (plaintextMessage instanceof AbstractFileRegion fileRegion) {
return fileRegion.count() - fileRegion.transferred();
} else {
throw new IllegalArgumentException("Unsupported message type: " + msg.getClass());
throw new IllegalArgumentException("Unsupported message type: " +
plaintextMessage.getClass().getName());
}
AesGcmJce cipher = new AesGcmJce(aesKey.getEncoded());
byte[] encrypted = cipher.encrypt(inputBuffer.array(), DEFAULT_AAD);
ByteBuf wrappedEncrypted = Unpooled.wrappedBuffer(encrypted);
ctx.write(wrappedEncrypted, promise);
}

@Override
protected void deallocate() {
if (plaintextMessage instanceof ReferenceCounted referenceCounted) {
referenceCounted.release();
}
plaintextBuffer.clear();
ciphertextBuffer.clear();
}
}

@VisibleForTesting
class DecryptionHandler extends ChannelInboundHandlerAdapter {
private final ByteBuffer ciphertextBuffer;
private final ByteBuffer plaintextBuffer;
private final AesGcmHkdfStreaming aesGcmHkdfStreaming;
private final StreamSegmentDecrypter decrypter;
private boolean decrypterInit = false;
private int segmentNumber = 0;
private long expectedLength = -1;
private long ciphertextRead = 0;

DecryptionHandler() throws GeneralSecurityException {
aesGcmHkdfStreaming = new AesGcmHkdfStreaming(
aesKey.getEncoded(),
"HmacSha256",
aesKey.getEncoded().length,
CIPHERTEXT_BUFFER_SIZE,
0);
plaintextBuffer =
ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize());
ciphertextBuffer =
ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize());
decrypter = aesGcmHkdfStreaming.newStreamSegmentDecrypter();
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg)
public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage)
throws GeneralSecurityException {
if (msg instanceof ByteBuf byteBuf) {
// This is allocating a buffer that is the size of the input
byte[] encrypted = new byte[byteBuf.readableBytes()];
byteBuf.readBytes(encrypted);
AesGcmJce cipher = new AesGcmJce(aesKey.getEncoded());
byte[] decrypted = cipher.decrypt(encrypted, DEFAULT_AAD);
ByteBuf wrappedDecrypted = Unpooled.wrappedBuffer(decrypted);
ctx.fireChannelRead(wrappedDecrypted);
} else {
throw new IllegalArgumentException("Unsupported message type: " + msg.getClass());
Preconditions.checkArgument(ciphertextMessage instanceof ByteBuf,
"Unrecognized message type: %s",
ciphertextMessage.getClass().getName());
ByteBuf ciphertextNettyBuf = (ByteBuf) ciphertextMessage;
// The format of the output is:
// [8 byte length][Internal IV and header][Ciphertext][Auth Tag]
try {
while (ciphertextNettyBuf.readableBytes() > 0) {
// Check if the expected ciphertext length has been read.
if (expectedLength < 0 &&
ciphertextNettyBuf.readableBytes() >= LENGTH_HEADER_BYTES) {
expectedLength = ciphertextNettyBuf.readLong();
if (expectedLength < 0) {
throw new IllegalStateException("Invalid expected ciphertext length.");
}
ciphertextRead += LENGTH_HEADER_BYTES;
}
int headerLength = aesGcmHkdfStreaming.getHeaderLength();
// Check if the ciphertext header has been read. This contains
// the IV and other internal metadata.
if (!decrypterInit &&
ciphertextNettyBuf.readableBytes() >= headerLength) {
ByteBuffer headerBuffer = ByteBuffer.allocate(headerLength);
ciphertextNettyBuf.readBytes(headerBuffer);
headerBuffer.flip();
decrypter.init(headerBuffer, DEFAULT_AAD);
decrypterInit = true;
ciphertextRead += headerLength;
}
// This may occur if there weren't enough readable bytes to read the header.
if (!decrypterInit) {
return;
}
// This may occur if the expected length is just the header.
if (expectedLength == ciphertextRead) {
return;
}
ciphertextBuffer.clear();
// Read the ciphertext into the local buffer
int readableBytes = Integer.min(
ciphertextNettyBuf.readableBytes(),
ciphertextBuffer.remaining());
if (readableBytes == 0) {
return;
}
// The smallest ciphertext size is 16 bytes for the auth tag
ciphertextBuffer.limit(readableBytes);
ciphertextNettyBuf.readBytes(ciphertextBuffer);
ciphertextRead += readableBytes;
// Check if this is the last segment
boolean lastSegment = false;
if (ciphertextRead == expectedLength) {
lastSegment = true;
} else if (ciphertextRead > expectedLength) {
throw new IllegalStateException("Read more ciphertext than expected.");
}
plaintextBuffer.clear();
ciphertextBuffer.flip();

decrypter.decryptSegment(
ciphertextBuffer,
segmentNumber,
lastSegment,
plaintextBuffer);
segmentNumber++;
plaintextBuffer.flip();
ctx.fireChannelRead(Unpooled.wrappedBuffer(plaintextBuffer));
}
} finally {
ciphertextNettyBuf.release();
}
}
}
Expand All @@ -125,8 +332,9 @@ public int write(ByteBuffer src) throws IOException {
throw new ClosedChannelException();
}
int bytesToWrite = Math.min(src.remaining(), destination.remaining());
// Destination buffer is full
if (bytesToWrite == 0) {
return 0; // Destination buffer is full
return 0;
}
ByteBuffer temp = src.slice().limit(bytesToWrite);
destination.put(temp);
Expand Down
Loading

0 comments on commit b9baf22

Please sign in to comment.