Skip to content

Commit

Permalink
[SPARK-47172] Addressing reviewer comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sweisdb committed Jun 4, 2024
1 parent 5f8cd7f commit 81a3ca0
Showing 1 changed file with 60 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -221,27 +221,24 @@ public long transferTo(WritableByteChannel target, long position) throws IOExcep
plaintextBuffer.limit(readLimit);
if (plaintextMessage instanceof ByteBuf byteBuf) {
byteBuf.readBytes(plaintextBuffer);
long inputBytesRead = readableBytes - byteBuf.readableBytes();
bytesRead += inputBytesRead;
} else if (plaintextMessage instanceof AbstractFileRegion fileRegion) {
ByteBufferWriteableChannel plaintextChannel =
new ByteBufferWriteableChannel(plaintextBuffer);
long transferred =
fileRegion.transferTo(plaintextChannel, fileRegion.transferred());
bytesRead += transferred;
if (transferred == 0) {
// File regions may return 0 if they are not ready to transfer
// more data. In that case, we'll return with the expectation
// that this transferTo() is called again.
if (transferred < readLimit) {
// If we do not read a full plaintext buffer or all the available
// readable bytes, return what was transferred this call.
return transferredThisCall;
}
}
plaintextBuffer.flip();
bytesRead += plaintextBuffer.remaining();
ciphertextBuffer.clear();
try {
encrypter.encryptSegment(plaintextBuffer, lastSegment, ciphertextBuffer);
} catch (GeneralSecurityException e) {
throw new RuntimeException(e);
throw new IllegalStateException("GeneralSecurityException from encrypter", e);
}
ciphertextBuffer.flip();
int written = target.write(ciphertextBuffer);
Expand Down Expand Up @@ -279,24 +276,65 @@ protected void deallocate() {

@VisibleForTesting
class DecryptionHandler extends ChannelInboundHandlerAdapter {
private final ByteBuffer expectedLengthBuffer;
private final ByteBuffer headerBuffer;
private final ByteBuffer ciphertextBuffer;
private final ByteBuffer plaintextBuffer;
private final AesGcmHkdfStreaming aesGcmHkdfStreaming;
private final StreamSegmentDecrypter decrypter;
private boolean decrypterInit = false;
private boolean lastSegment = false;
private int segmentNumber = 0;
private long expectedLength = -1;
private long ciphertextRead = 0;

DecryptionHandler() throws GeneralSecurityException {
aesGcmHkdfStreaming = getAesGcmHkdfStreaming();
expectedLengthBuffer = ByteBuffer.allocate(LENGTH_HEADER_BYTES);
headerBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getHeaderLength());
plaintextBuffer =
ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize());
ciphertextBuffer =
ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize());
decrypter = aesGcmHkdfStreaming.newStreamSegmentDecrypter();
}

private boolean initalizeExpectedLength(ByteBuf ciphertextNettyBuf) {
if (expectedLength < 0 && expectedLengthBuffer.hasRemaining()) {
ciphertextNettyBuf.readBytes(expectedLengthBuffer);
if (expectedLengthBuffer.hasRemaining()) {
// We did not read enough bytes to initialize the expected length.
return false;
}
expectedLengthBuffer.flip();
expectedLength = expectedLengthBuffer.getLong();
if (expectedLength < 0) {
throw new IllegalStateException("Invalid expected ciphertext length.");
}
ciphertextRead += LENGTH_HEADER_BYTES;
}
return true;
}

private boolean initalizeDecrypter(ByteBuf ciphertextNettyBuf)
throws GeneralSecurityException {
// Check if the ciphertext header has been read. This contains
// the IV and other internal metadata.
if (!decrypterInit && headerBuffer.hasRemaining()) {
ciphertextNettyBuf.readBytes(headerBuffer);
if (headerBuffer.hasRemaining()) {
// We did not read enough bytes to initialize the header.
return false;
}
headerBuffer.flip();
byte[] lengthAad = Longs.toByteArray(expectedLength);
decrypter.init(headerBuffer, lengthAad);
decrypterInit = true;
ciphertextRead += aesGcmHkdfStreaming.getHeaderLength();
}
return true;
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage)
throws GeneralSecurityException {
Expand All @@ -307,37 +345,20 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage)
// The format of the output is:
// [8 byte length][Internal IV and header][Ciphertext][Auth Tag]
try {
while (ciphertextNettyBuf.readableBytes() > 0) {
// Check if the expected ciphertext length has been read.
if (expectedLength < 0 &&
ciphertextNettyBuf.readableBytes() >= LENGTH_HEADER_BYTES) {
expectedLength = ciphertextNettyBuf.readLong();
if (expectedLength < 0) {
throw new IllegalStateException("Invalid expected ciphertext length.");
}
ciphertextRead += LENGTH_HEADER_BYTES;
}
int headerLength = aesGcmHkdfStreaming.getHeaderLength();
// Check if the ciphertext header has been read. This contains
// the IV and other internal metadata.
if (!decrypterInit &&
ciphertextNettyBuf.readableBytes() >= headerLength) {
ByteBuffer headerBuffer = ByteBuffer.allocate(headerLength);
ciphertextNettyBuf.readBytes(headerBuffer);
headerBuffer.flip();
byte[] lengthAad = Longs.toByteArray(expectedLength);
decrypter.init(headerBuffer, lengthAad);
decrypterInit = true;
ciphertextRead += headerLength;
}
// This may occur if there weren't enough readable bytes to read the header.
if (!decrypterInit) {
return;
}
// This may occur if the expected length is just the header.
if (expectedLength == ciphertextRead) {
return;
}
if (!initalizeExpectedLength(ciphertextNettyBuf)) {
// We have not read enough bytes to initialize the expected length.
return;
}
if (!initalizeDecrypter(ciphertextNettyBuf)) {
// We nave not read enough bytes to initalize a header, needed to
// initialize a decrypter.
return;
}
if (expectedLength == ciphertextRead) {
// If the expected length is just the header, the ciphertext is 0 length.
lastSegment = true;
}
while (ciphertextNettyBuf.readableBytes() > 0 && !lastSegment) {
ciphertextBuffer.clear();
// Read the ciphertext into the local buffer
int readableBytes = Integer.min(
Expand All @@ -353,7 +374,6 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage)
ciphertextNettyBuf.readBytes(ciphertextBuffer);
ciphertextRead += bytesToRead;
// Check if this is the last segment
boolean lastSegment = false;
if (ciphertextRead == expectedLength) {
lastSegment = true;
} else if (ciphertextRead > expectedLength) {
Expand Down

0 comments on commit 81a3ca0

Please sign in to comment.