Skip to content

Commit

Permalink
[SPARK-47172] Addressing reviewer comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sweisdb committed Jun 7, 2024
1 parent 81a3ca0 commit 5bc503a
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ public GcmEncryptedMessage touch(Object o) {
super.touch(o);
if (plaintextMessage instanceof ByteBuf byteBuf) {
byteBuf.touch(o);
} else if (plaintextMessage instanceof AbstractFileRegion fileRegion) {
} else if (plaintextMessage instanceof FileRegion fileRegion) {
fileRegion.touch(o);
}
return this;
Expand All @@ -173,7 +173,7 @@ public GcmEncryptedMessage retain(int increment) {
super.retain(increment);
if (plaintextMessage instanceof ByteBuf byteBuf) {
byteBuf.retain(increment);
} else if (plaintextMessage instanceof AbstractFileRegion fileRegion) {
} else if (plaintextMessage instanceof FileRegion fileRegion) {
fileRegion.retain(increment);
}
return this;
Expand All @@ -183,16 +183,14 @@ public GcmEncryptedMessage retain(int increment) {
public boolean release(int decrement) {
if (plaintextMessage instanceof ByteBuf byteBuf) {
byteBuf.release(decrement);
} else if (plaintextMessage instanceof AbstractFileRegion fileRegion) {
} else if (plaintextMessage instanceof FileRegion fileRegion) {
fileRegion.release(decrement);
}
return super.release(decrement);
}

@Override
public long transferTo(WritableByteChannel target, long position) throws IOException {
Preconditions.checkArgument(position == transferred(),
"Invalid position.");
int transferredThisCall = 0;
// If the header has is not empty, try to write it out to the target.
if (headerByteBuffer.hasRemaining()) {
Expand All @@ -214,24 +212,23 @@ public long transferTo(WritableByteChannel target, long position) throws IOExcep
}
while (bytesRead < bytesToRead) {
long readableBytes = getReadableBytes();
boolean lastSegment = readableBytes <= plaintextBuffer.capacity();
plaintextBuffer.clear();
int readLimit =
(int) Math.min(readableBytes, plaintextBuffer.capacity());
plaintextBuffer.limit(readLimit);
(int) Math.min(readableBytes, plaintextBuffer.remaining());
if (plaintextMessage instanceof ByteBuf byteBuf) {
plaintextBuffer.limit(readLimit);
byteBuf.readBytes(plaintextBuffer);
} else if (plaintextMessage instanceof AbstractFileRegion fileRegion) {
} else if (plaintextMessage instanceof FileRegion fileRegion) {
ByteBufferWriteableChannel plaintextChannel =
new ByteBufferWriteableChannel(plaintextBuffer);
long transferred =
long plaintextRead =
fileRegion.transferTo(plaintextChannel, fileRegion.transferred());
if (transferred < readLimit) {
if (plaintextRead < readLimit) {
// If we do not read a full plaintext buffer or all the available
// readable bytes, return what was transferred this call.
return transferredThisCall;
}
}
boolean lastSegment = getReadableBytes() == 0;
plaintextBuffer.flip();
bytesRead += plaintextBuffer.remaining();
ciphertextBuffer.clear();
Expand All @@ -240,6 +237,7 @@ public long transferTo(WritableByteChannel target, long position) throws IOExcep
} catch (GeneralSecurityException e) {
throw new IllegalStateException("GeneralSecurityException from encrypter", e);
}
plaintextBuffer.clear();
ciphertextBuffer.flip();
int written = target.write(ciphertextBuffer);
transferredThisCall += written;
Expand All @@ -256,7 +254,7 @@ public long transferTo(WritableByteChannel target, long position) throws IOExcep
private long getReadableBytes() {
if (plaintextMessage instanceof ByteBuf byteBuf) {
return byteBuf.readableBytes();
} else if (plaintextMessage instanceof AbstractFileRegion fileRegion) {
} else if (plaintextMessage instanceof FileRegion fileRegion) {
return fileRegion.count() - fileRegion.transferred();
} else {
throw new IllegalArgumentException("Unsupported message type: " +
Expand All @@ -283,7 +281,7 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter {
private final AesGcmHkdfStreaming aesGcmHkdfStreaming;
private final StreamSegmentDecrypter decrypter;
private boolean decrypterInit = false;
private boolean lastSegment = false;
private boolean completed = false;
private int segmentNumber = 0;
private long expectedLength = -1;
private long ciphertextRead = 0;
Expand All @@ -300,7 +298,7 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter {
}

private boolean initalizeExpectedLength(ByteBuf ciphertextNettyBuf) {
if (expectedLength < 0 && expectedLengthBuffer.hasRemaining()) {
if (expectedLength < 0) {
ciphertextNettyBuf.readBytes(expectedLengthBuffer);
if (expectedLengthBuffer.hasRemaining()) {
// We did not read enough bytes to initialize the expected length.
Expand All @@ -320,7 +318,7 @@ private boolean initalizeDecrypter(ByteBuf ciphertextNettyBuf)
throws GeneralSecurityException {
// Check if the ciphertext header has been read. This contains
// the IV and other internal metadata.
if (!decrypterInit && headerBuffer.hasRemaining()) {
if (!decrypterInit) {
ciphertextNettyBuf.readBytes(headerBuffer);
if (headerBuffer.hasRemaining()) {
// We did not read enough bytes to initialize the header.
Expand All @@ -331,6 +329,10 @@ private boolean initalizeDecrypter(ByteBuf ciphertextNettyBuf)
decrypter.init(headerBuffer, lengthAad);
decrypterInit = true;
ciphertextRead += aesGcmHkdfStreaming.getHeaderLength();
if (expectedLength == ciphertextRead) {
// If the expected length is just the header, the ciphertext is 0 length.
completed = true;
}
}
return true;
}
Expand All @@ -350,46 +352,51 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage)
return;
}
if (!initalizeDecrypter(ciphertextNettyBuf)) {
// We nave not read enough bytes to initalize a header, needed to
// We have not read enough bytes to initialize a header, needed to
// initialize a decrypter.
return;
}
if (expectedLength == ciphertextRead) {
// If the expected length is just the header, the ciphertext is 0 length.
lastSegment = true;
}
while (ciphertextNettyBuf.readableBytes() > 0 && !lastSegment) {
ciphertextBuffer.clear();
int nettyBufReadableBytes = ciphertextNettyBuf.readableBytes();
while (nettyBufReadableBytes > 0 && !completed) {
// Read the ciphertext into the local buffer
int readableBytes = Integer.min(
ciphertextNettyBuf.readableBytes(),
nettyBufReadableBytes,
ciphertextBuffer.remaining());
if (readableBytes == 0) {
return;
}
int expectedRemaining = (int) (expectedLength - ciphertextRead);
int bytesToRead = Integer.min(readableBytes, expectedRemaining);
// The smallest ciphertext size is 16 bytes for the auth tag
ciphertextBuffer.limit(bytesToRead);
ciphertextBuffer.limit(ciphertextBuffer.position() + bytesToRead);
ciphertextNettyBuf.readBytes(ciphertextBuffer);
ciphertextRead += bytesToRead;
// Check if this is the last segment
if (ciphertextRead == expectedLength) {
lastSegment = true;
completed = true;
} else if (ciphertextRead > expectedLength) {
throw new IllegalStateException("Read more ciphertext than expected.");
}
plaintextBuffer.clear();
ciphertextBuffer.flip();

decrypter.decryptSegment(
ciphertextBuffer,
segmentNumber,
lastSegment,
plaintextBuffer);
segmentNumber++;
plaintextBuffer.flip();
ctx.fireChannelRead(Unpooled.wrappedBuffer(plaintextBuffer));
// If the ciphertext buffer is full, or this is the last segment,
// then decrypt it and fire a read.
if (ciphertextBuffer.limit() == ciphertextBuffer.capacity() || completed) {
plaintextBuffer.clear();
ciphertextBuffer.flip();
decrypter.decryptSegment(
ciphertextBuffer,
segmentNumber,
completed,
plaintextBuffer);
segmentNumber++;
// Clear the ciphertext buffer because it's been read
ciphertextBuffer.clear();
plaintextBuffer.flip();
ctx.fireChannelRead(Unpooled.wrappedBuffer(plaintextBuffer));
} else {
// Set the ciphertext buffer up to read the next chunk
ciphertextBuffer.limit(ciphertextBuffer.capacity());
}
nettyBufReadableBytes = ciphertextNettyBuf.readableBytes();
}
} finally {
ciphertextNettyBuf.release();
Expand Down
Loading

0 comments on commit 5bc503a

Please sign in to comment.