Skip to content

Commit

Permalink
[SPARK-47172] Addressing comments & Refactoring AuthEngine test suite…
Browse files Browse the repository at this point in the history
… into two classes
  • Loading branch information
sweisdb committed Jun 11, 2024
1 parent 4ceeee3 commit 4df0367
Show file tree
Hide file tree
Showing 7 changed files with 607 additions and 475 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;
import java.security.GeneralSecurityException;
import java.util.Properties;
import javax.crypto.spec.SecretKeySpec;
import javax.crypto.spec.IvParameterSpec;
Expand Down Expand Up @@ -64,6 +65,14 @@ public CtrTransportCipher(
this.outIv = outIv;
}

/*
* This method is for testing purposes only.
*/
@VisibleForTesting
public String getKeyId() throws GeneralSecurityException {
return TransportCipherUtil.getKeyId(key);
}

@VisibleForTesting
SecretKeySpec getKey() {
return key;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Longs;
import com.google.crypto.tink.subtle.AesGcmHkdfStreaming;
import com.google.crypto.tink.subtle.StreamSegmentDecrypter;
import com.google.crypto.tink.subtle.StreamSegmentEncrypter;
import com.google.crypto.tink.subtle.*;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
Expand Down Expand Up @@ -57,6 +55,14 @@ AesGcmHkdfStreaming getAesGcmHkdfStreaming() throws InvalidAlgorithmParameterExc
0);
}

/*
* This method is for testing purposes only.
*/
@VisibleForTesting
public String getKeyId() throws GeneralSecurityException {
return TransportCipherUtil.getKeyId(aesKey);
}

@VisibleForTesting
EncryptionHandler getEncryptionHandler() throws GeneralSecurityException {
return new EncryptionHandler();
Expand Down Expand Up @@ -215,8 +221,10 @@ public long transferTo(WritableByteChannel target, long position) throws IOExcep
int readLimit =
(int) Math.min(readableBytes, plaintextBuffer.remaining());
if (plaintextMessage instanceof ByteBuf byteBuf) {
Preconditions.checkState(0 == plaintextBuffer.position());
plaintextBuffer.limit(readLimit);
byteBuf.readBytes(plaintextBuffer);
Preconditions.checkState(readLimit == plaintextBuffer.position());
} else if (plaintextMessage instanceof FileRegion fileRegion) {
ByteBufferWriteableChannel plaintextChannel =
new ByteBufferWriteableChannel(plaintextBuffer);
Expand Down Expand Up @@ -277,9 +285,9 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter {
private final ByteBuffer expectedLengthBuffer;
private final ByteBuffer headerBuffer;
private final ByteBuffer ciphertextBuffer;
private final ByteBuffer plaintextBuffer;
private final AesGcmHkdfStreaming aesGcmHkdfStreaming;
private final StreamSegmentDecrypter decrypter;
private final int plaintextSegmentSize;
private boolean decrypterInit = false;
private boolean completed = false;
private int segmentNumber = 0;
Expand All @@ -290,11 +298,10 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter {
aesGcmHkdfStreaming = getAesGcmHkdfStreaming();
expectedLengthBuffer = ByteBuffer.allocate(LENGTH_HEADER_BYTES);
headerBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getHeaderLength());
plaintextBuffer =
ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize());
ciphertextBuffer =
ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize());
decrypter = aesGcmHkdfStreaming.newStreamSegmentDecrypter();
plaintextSegmentSize = aesGcmHkdfStreaming.getPlaintextSegmentSize();
}

private boolean initalizeExpectedLength(ByteBuf ciphertextNettyBuf) {
Expand Down Expand Up @@ -362,9 +369,6 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage)
int readableBytes = Integer.min(
nettyBufReadableBytes,
ciphertextBuffer.remaining());
if (readableBytes == 0) {
return;
}
int expectedRemaining = (int) (expectedLength - ciphertextRead);
int bytesToRead = Integer.min(readableBytes, expectedRemaining);
// The smallest ciphertext size is 16 bytes for the auth tag
Expand All @@ -380,7 +384,7 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage)
// If the ciphertext buffer is full, or this is the last segment,
// then decrypt it and fire a read.
if (ciphertextBuffer.limit() == ciphertextBuffer.capacity() || completed) {
plaintextBuffer.clear();
ByteBuffer plaintextBuffer = ByteBuffer.allocate(plaintextSegmentSize);
ciphertextBuffer.flip();
decrypter.decryptSegment(
ciphertextBuffer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,32 @@

package org.apache.spark.network.crypto;

import com.google.common.annotations.VisibleForTesting;
import com.google.crypto.tink.subtle.Hex;
import com.google.crypto.tink.subtle.Hkdf;
import io.netty.channel.Channel;

import javax.crypto.spec.SecretKeySpec;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;

interface TransportCipher {
String getKeyId() throws GeneralSecurityException;
void addToChannel(Channel channel) throws IOException, GeneralSecurityException;
}

class TransportCipherUtil {
/*
* This method is used for testing to verify key derivation.
*/
@VisibleForTesting
static String getKeyId(SecretKeySpec key) throws GeneralSecurityException {
byte[] keyIdBytes = Hkdf.computeHkdf("HmacSha256",
key.getEncoded(),
null,
"keyID".getBytes(StandardCharsets.UTF_8),
32);
return Hex.encode(keyIdBytes);
}
}
Loading

0 comments on commit 4df0367

Please sign in to comment.