diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/CtrTransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/CtrTransportCipher.java index 95636e8b1028e..85b893751b39c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/CtrTransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/CtrTransportCipher.java @@ -21,6 +21,7 @@ import java.nio.ByteBuffer; import java.nio.channels.ReadableByteChannel; import java.nio.channels.WritableByteChannel; +import java.security.GeneralSecurityException; import java.util.Properties; import javax.crypto.spec.SecretKeySpec; import javax.crypto.spec.IvParameterSpec; @@ -64,6 +65,14 @@ public CtrTransportCipher( this.outIv = outIv; } + /* + * This method is for testing purposes only. + */ + @VisibleForTesting + public String getKeyId() throws GeneralSecurityException { + return TransportCipherUtil.getKeyId(key); + } + @VisibleForTesting SecretKeySpec getKey() { return key; diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java index 376b6b354d4ac..c3540838bef09 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java @@ -20,9 +20,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.primitives.Longs; -import com.google.crypto.tink.subtle.AesGcmHkdfStreaming; -import com.google.crypto.tink.subtle.StreamSegmentDecrypter; -import com.google.crypto.tink.subtle.StreamSegmentEncrypter; +import com.google.crypto.tink.subtle.*; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.*; @@ -57,6 +55,14 @@ AesGcmHkdfStreaming getAesGcmHkdfStreaming() throws InvalidAlgorithmParameterExc 0); } + /* + * This method is for testing purposes only. + */ + @VisibleForTesting + public String getKeyId() throws GeneralSecurityException { + return TransportCipherUtil.getKeyId(aesKey); + } + @VisibleForTesting EncryptionHandler getEncryptionHandler() throws GeneralSecurityException { return new EncryptionHandler(); @@ -215,8 +221,10 @@ public long transferTo(WritableByteChannel target, long position) throws IOExcep int readLimit = (int) Math.min(readableBytes, plaintextBuffer.remaining()); if (plaintextMessage instanceof ByteBuf byteBuf) { + Preconditions.checkState(0 == plaintextBuffer.position()); plaintextBuffer.limit(readLimit); byteBuf.readBytes(plaintextBuffer); + Preconditions.checkState(readLimit == plaintextBuffer.position()); } else if (plaintextMessage instanceof FileRegion fileRegion) { ByteBufferWriteableChannel plaintextChannel = new ByteBufferWriteableChannel(plaintextBuffer); @@ -277,9 +285,9 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter { private final ByteBuffer expectedLengthBuffer; private final ByteBuffer headerBuffer; private final ByteBuffer ciphertextBuffer; - private final ByteBuffer plaintextBuffer; private final AesGcmHkdfStreaming aesGcmHkdfStreaming; private final StreamSegmentDecrypter decrypter; + private final int plaintextSegmentSize; private boolean decrypterInit = false; private boolean completed = false; private int segmentNumber = 0; @@ -290,11 +298,10 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter { aesGcmHkdfStreaming = getAesGcmHkdfStreaming(); expectedLengthBuffer = ByteBuffer.allocate(LENGTH_HEADER_BYTES); headerBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getHeaderLength()); - plaintextBuffer = - ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize()); ciphertextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize()); decrypter = aesGcmHkdfStreaming.newStreamSegmentDecrypter(); + plaintextSegmentSize = aesGcmHkdfStreaming.getPlaintextSegmentSize(); } private boolean initalizeExpectedLength(ByteBuf ciphertextNettyBuf) { @@ -362,9 +369,6 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) int readableBytes = Integer.min( nettyBufReadableBytes, ciphertextBuffer.remaining()); - if (readableBytes == 0) { - return; - } int expectedRemaining = (int) (expectedLength - ciphertextRead); int bytesToRead = Integer.min(readableBytes, expectedRemaining); // The smallest ciphertext size is 16 bytes for the auth tag @@ -380,7 +384,7 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) // If the ciphertext buffer is full, or this is the last segment, // then decrypt it and fire a read. if (ciphertextBuffer.limit() == ciphertextBuffer.capacity() || completed) { - plaintextBuffer.clear(); + ByteBuffer plaintextBuffer = ByteBuffer.allocate(plaintextSegmentSize); ciphertextBuffer.flip(); decrypter.decryptSegment( ciphertextBuffer, diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java index d699018ca1816..355c552720185 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java @@ -17,11 +17,32 @@ package org.apache.spark.network.crypto; +import com.google.common.annotations.VisibleForTesting; +import com.google.crypto.tink.subtle.Hex; +import com.google.crypto.tink.subtle.Hkdf; import io.netty.channel.Channel; +import javax.crypto.spec.SecretKeySpec; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.security.GeneralSecurityException; interface TransportCipher { + String getKeyId() throws GeneralSecurityException; void addToChannel(Channel channel) throws IOException, GeneralSecurityException; } + +class TransportCipherUtil { + /* + * This method is used for testing to verify key derivation. + */ + @VisibleForTesting + static String getKeyId(SecretKeySpec key) throws GeneralSecurityException { + byte[] keyIdBytes = Hkdf.computeHkdf("HmacSha256", + key.getEncoded(), + null, + "keyID".getBytes(StandardCharsets.UTF_8), + 32); + return Hex.encode(keyIdBytes); + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java index 38f7ee691421e..628de9e780337 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java @@ -17,56 +17,36 @@ package org.apache.spark.network.crypto; -import java.io.IOException; import java.nio.ByteBuffer; -import java.nio.channels.WritableByteChannel; import java.security.GeneralSecurityException; -import java.util.Arrays; import java.util.Map; -import java.util.Random; import com.google.common.collect.ImmutableMap; import com.google.crypto.tink.subtle.Hex; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPromise; -import io.netty.channel.FileRegion; import org.apache.spark.network.util.*; import static org.junit.jupiter.api.Assertions.*; -import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; -import static org.mockito.Mockito.*; -import org.mockito.ArgumentCaptor; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; - -import javax.crypto.AEADBadTagException; - -public class AuthEngineSuite { - - private static final String clientPrivate = - "efe6b68b3fce92158e3637f6ef9d937e75558928dd4b401de04b43d300a73186"; - private static final String clientChallengeHex = - "fb00000005617070496400000010890b6e960f48e998777267a7e4e623220000003c48ad7dc7ec9466da9" + - "3bda9f11488dc9404050e02c661d87d67c782444944c6e369b27e0a416c30845a2d9e64271511ca98b41d" + - "65f8c426e18ff380f6"; - private static final String serverResponseHex = - "fb00000005617070496400000010708451c9dd2792c97c1ca66e6df449ef0000003c64fe899ecdaf458d4" + - "e25e9d5c5a380b8e6d1a184692fac065ed84f8592c18e9629f9c636809dca2ffc041f20346eb53db78738" + - "08ecad08b46b5ee3ff"; - private static final String derivedKey = "2d6e7a9048c8265c33a8f3747bfcc84c"; +abstract class AuthEngineSuite { + static final String clientPrivate = + "efe6b68b3fce92158e3637f6ef9d937e75558928dd4b401de04b43d300a73186"; + static final String clientChallengeHex = + "fb00000005617070496400000010890b6e960f48e998777267a7e4e623220000003c48ad7dc7ec9466da9" + + "3bda9f11488dc9404050e02c661d87d67c782444944c6e369b27e0a416c30845a2d9e64271511ca98b41d" + + "65f8c426e18ff380f6"; + static final String serverResponseHex = + "fb00000005617070496400000010708451c9dd2792c97c1ca66e6df449ef0000003c64fe899ecdaf458d4" + + "e25e9d5c5a380b8e6d1a184692fac065ed84f8592c18e9629f9c636809dca2ffc041f20346eb53db78738" + + "08ecad08b46b5ee3ff"; + static final String derivedKeyId = + "de04fd52d71040ed9d260579dacfdf4f5695f991ce8ddb1dde05a7335880906e"; // This key would have been derived for version 1.0 protocol that did not run a final HKDF round. - private static final String unsafeDerivedKey = - "31963f15a320d5c90333f7ecf5cf3a31c7eaf151de07fef8494663a9f47cfd31"; + static final String unsafeDerivedKey = + "31963f15a320d5c90333f7ecf5cf3a31c7eaf151de07fef8494663a9f47cfd31"; + static TransportConf conf; - private static final String inputIv = "fc6a5dc8b90a9dad8f54f08b51a59ed2"; - private static final String outputIv = "a72709baf00785cad6329ce09f631f71"; - private static TransportConf conf; - - private static TransportConf getConf(int authEngineVerison, boolean useCtr) { + static TransportConf getConf(int authEngineVerison, boolean useCtr) { String authEngineVersion = (authEngineVerison == 1) ? "1" : "2"; String mode = useCtr ? "AES/CTR/NoPadding" : "AES/GCM/NoPadding"; Map confMap = ImmutableMap.of( @@ -78,17 +58,6 @@ private static TransportConf getConf(int authEngineVerison, boolean useCtr) { return new TransportConf("rpc", v2Provider); } - @BeforeAll - public static void setUp() { - Map confMap = ImmutableMap.of( - "spark.network.crypto.enabled", "true", - "spark.network.crypto.authEngineVersion", "2", - "spark.network.crypto.cipher", "AES/CTR/NoPadding" - ); - ConfigProvider v2Provider = new MapConfigProvider(confMap); - conf = getConf(2, true); - } - @Test public void testAuthEngine() throws Exception { try (AuthEngine client = new AuthEngine("appId", "secret", conf); @@ -96,16 +65,24 @@ public void testAuthEngine() throws Exception { AuthMessage clientChallenge = client.challenge(); AuthMessage serverResponse = server.response(clientChallenge); client.deriveSessionCipher(clientChallenge, serverResponse); - TransportCipher serverCipher = server.sessionCipher(); TransportCipher clientCipher = client.sessionCipher(); - assert(clientCipher instanceof CtrTransportCipher); - assert(serverCipher instanceof CtrTransportCipher); - CtrTransportCipher ctrClient = (CtrTransportCipher) clientCipher; - CtrTransportCipher ctrServer = (CtrTransportCipher) serverCipher; - assertArrayEquals(ctrServer.getInputIv(), ctrClient.getOutputIv()); - assertArrayEquals(ctrServer.getOutputIv(), ctrClient.getInputIv()); - assertEquals(ctrServer.getKey(), ctrClient.getKey()); + assertEquals(clientCipher.getKeyId(), serverCipher.getKeyId()); + } + } + + @Test + public void testFixedChallengeResponse() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf)) { + byte[] clientPrivateKey = Hex.decode(clientPrivate); + client.setClientPrivateKey(clientPrivateKey); + AuthMessage clientChallenge = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); + AuthMessage serverResponse = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex))); + // Verify that the client will accept an old transcript. + client.deriveSessionCipher(clientChallenge, serverResponse); + assertEquals(client.sessionCipher().getKeyId(), derivedKeyId); } } @@ -161,7 +138,7 @@ public void testCorruptResponseSalt() throws Exception { AuthMessage serverResponse = server.response(clientChallenge); serverResponse.salt()[0] ^= 1; assertThrows(GeneralSecurityException.class, - () -> client.deriveSessionCipher(clientChallenge, serverResponse)); + () -> client.deriveSessionCipher(clientChallenge, serverResponse)); } } @@ -188,47 +165,6 @@ public void testFixedChallenge() throws Exception { } } - @Test - public void testFixedChallengeResponse() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf)) { - byte[] clientPrivateKey = Hex.decode(clientPrivate); - client.setClientPrivateKey(clientPrivateKey); - AuthMessage clientChallenge = - AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); - AuthMessage serverResponse = - AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex))); - // Verify that the client will accept an old transcript. - client.deriveSessionCipher(clientChallenge, serverResponse); - TransportCipher clientCipher = client.sessionCipher(); - assert(clientCipher instanceof CtrTransportCipher); - CtrTransportCipher ctrTransportCipher = (CtrTransportCipher) clientCipher; - assertEquals(Hex.encode(ctrTransportCipher.getKey().getEncoded()), derivedKey); - assertEquals(Hex.encode(ctrTransportCipher.getInputIv()), inputIv); - assertEquals(Hex.encode(ctrTransportCipher.getOutputIv()), outputIv); - } - } - - @Test - public void testFixedChallengeResponseUnsafeVersion() throws Exception { - TransportConf v1Conf = getConf(1, true); - try (AuthEngine client = new AuthEngine("appId", "secret", v1Conf)) { - byte[] clientPrivateKey = Hex.decode(clientPrivate); - client.setClientPrivateKey(clientPrivateKey); - AuthMessage clientChallenge = - AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); - AuthMessage serverResponse = - AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex))); - // Verify that the client will accept an old transcript. - client.deriveSessionCipher(clientChallenge, serverResponse); - TransportCipher clientCipher = client.sessionCipher(); - assert(clientCipher instanceof CtrTransportCipher); - CtrTransportCipher ctrTransportCipher = (CtrTransportCipher) clientCipher; - assertEquals(Hex.encode(ctrTransportCipher.getKey().getEncoded()), unsafeDerivedKey); - assertEquals(Hex.encode(ctrTransportCipher.getInputIv()), inputIv); - assertEquals(Hex.encode(ctrTransportCipher.getOutputIv()), outputIv); - } - } - @Test public void testMismatchedSecret() throws Exception { try (AuthEngine client = new AuthEngine("appId", "secret", conf); @@ -237,368 +173,4 @@ public void testMismatchedSecret() throws Exception { assertThrows(GeneralSecurityException.class, () -> server.response(clientChallenge)); } } - - @Test - public void testCtrEncryptedMessage() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); - AuthEngine server = new AuthEngine("appId", "secret", conf)) { - AuthMessage clientChallenge = client.challenge(); - AuthMessage serverResponse = server.response(clientChallenge); - client.deriveSessionCipher(clientChallenge, serverResponse); - - TransportCipher clientCipher = server.sessionCipher(); - assert(clientCipher instanceof CtrTransportCipher); - CtrTransportCipher ctrTransportCipher = (CtrTransportCipher) clientCipher; - CtrTransportCipher.EncryptionHandler handler = - new CtrTransportCipher.EncryptionHandler(ctrTransportCipher); - - byte[] data = new byte[CtrTransportCipher.STREAM_BUFFER_SIZE + 1]; - new Random().nextBytes(data); - ByteBuf buf = Unpooled.wrappedBuffer(data); - - ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length); - CtrTransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(buf); - while (emsg.transferred() < emsg.count()) { - emsg.transferTo(channel, emsg.transferred()); - } - assertEquals(data.length, channel.length()); - } - } - - @Test - public void testGcmEncryptedMessage() throws Exception { - TransportConf gcmConf = getConf(2, false); - try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); - AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { - AuthMessage clientChallenge = client.challenge(); - AuthMessage serverResponse = server.response(clientChallenge); - client.deriveSessionCipher(clientChallenge, serverResponse); - TransportCipher clientCipher = server.sessionCipher(); - // Verify that it derives a GcmTransportCipher - assert (clientCipher instanceof GcmTransportCipher); - GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher; - GcmTransportCipher.EncryptionHandler encryptionHandler = - gcmTransportCipher.getEncryptionHandler(); - GcmTransportCipher.DecryptionHandler decryptionHandler = - gcmTransportCipher.getDecryptionHandler(); - // Allocating 1.5x the buffer size to test multiple segments and a fractional segment. - int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16; - byte[] data = new byte[plaintextSegmentSize + (plaintextSegmentSize / 2)]; - // Just writing some bytes. - data[0] = 'a'; - data[data.length / 2] = 'b'; - data[data.length - 10] = 'c'; - ByteBuf buf = Unpooled.wrappedBuffer(data); - - // Mock the context and capture the arguments passed to it - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - ChannelPromise promise = mock(ChannelPromise.class); - ArgumentCaptor captorWrappedEncrypted = - ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); - encryptionHandler.write(ctx, buf, promise); - verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise)); - - // Get the encrypted value and pass it to the decryption handler - GcmTransportCipher.GcmEncryptedMessage encrypted = - captorWrappedEncrypted.getValue(); - ByteBuffer ciphertextBuffer = - ByteBuffer.allocate((int) encrypted.count()); - ByteBufferWriteableChannel channel = - new ByteBufferWriteableChannel(ciphertextBuffer); - encrypted.transferTo(channel, 0); - ciphertextBuffer.flip(); - ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); - - // Capture the decrypted values and verify them - ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); - decryptionHandler.channelRead(ctx, ciphertext); - verify(ctx, times(2)) - .fireChannelRead(captorPlaintext.capture()); - ByteBuf lastPlaintextSegment = captorPlaintext.getValue(); - assertEquals(plaintextSegmentSize/2, - lastPlaintextSegment.readableBytes()); - assertEquals('c', - lastPlaintextSegment.getByte((plaintextSegmentSize/2) - 10)); - } - } - - static class FakeRegion extends AbstractFileRegion { - private final ByteBuffer[] source; - private int sourcePosition; - private final long count; - - FakeRegion(ByteBuffer... source) { - this.source = source; - sourcePosition = 0; - count = remaining(); - } - - private long remaining() { - long remaining = 0; - for (ByteBuffer buffer : source) { - remaining += buffer.remaining(); - } - return remaining; - } - - @Override - public long position() { - return 0; - } - - @Override - public long transferred() { - return count - remaining(); - } - - @Override - public long count() { - return count; - } - - @Override - public long transferTo(WritableByteChannel target, long position) throws IOException { - if (sourcePosition < source.length) { - ByteBuffer currentBuffer = source[sourcePosition]; - long written = target.write(currentBuffer); - if (!currentBuffer.hasRemaining()) { - sourcePosition++; - } - return written; - } else { - return 0; - } - } - - @Override - protected void deallocate() { - } - } - - private static ByteBuffer getTestByteBuf(int size, byte fill) { - byte[] data = new byte[size]; - Arrays.fill(data, fill); - return ByteBuffer.wrap(data); - } - - @Test - public void testGcmEncryptedMessageFileRegion() throws Exception { - TransportConf gcmConf = getConf(2, false); - try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); - AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { - AuthMessage clientChallenge = client.challenge(); - AuthMessage serverResponse = server.response(clientChallenge); - client.deriveSessionCipher(clientChallenge, serverResponse); - TransportCipher clientCipher = server.sessionCipher(); - // Verify that it derives a GcmTransportCipher - assert (clientCipher instanceof GcmTransportCipher); - GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher; - GcmTransportCipher.EncryptionHandler encryptionHandler = - gcmTransportCipher.getEncryptionHandler(); - GcmTransportCipher.DecryptionHandler decryptionHandler = - gcmTransportCipher.getDecryptionHandler(); - // Allocating 1.5x the buffer size to test multiple segments and a fractional segment. - int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16; - int halfSegmentSize = plaintextSegmentSize / 2; - int totalSize = plaintextSegmentSize + halfSegmentSize; - - // Set up some fragmented segments to test - ByteBuffer halfSegment = getTestByteBuf(halfSegmentSize, (byte) 'a'); - int smallFragmentSize = 128; - ByteBuffer smallFragment = getTestByteBuf(smallFragmentSize, (byte) 'b'); - int remainderSize = totalSize - halfSegmentSize - smallFragmentSize; - ByteBuffer remainder = getTestByteBuf(remainderSize, (byte) 'c'); - FakeRegion fakeRegion = new FakeRegion(halfSegment, smallFragment, remainder); - assertEquals(totalSize, fakeRegion.count()); - - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - ChannelPromise promise = mock(ChannelPromise.class); - ArgumentCaptor captorWrappedEncrypted = - ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); - encryptionHandler.write(ctx, fakeRegion, promise); - verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise)); - - // Get the encrypted value and pass it to the decryption handler - GcmTransportCipher.GcmEncryptedMessage encrypted = - captorWrappedEncrypted.getValue(); - ByteBuffer ciphertextBuffer = - ByteBuffer.allocate((int) encrypted.count()); - ByteBufferWriteableChannel channel = - new ByteBufferWriteableChannel(ciphertextBuffer); - - // We'll simulate the FileRegion only transferring half a segment. - // The encrypted message should buffer the partial segment plaintext. - long ciphertextTransferred = 0; - while (ciphertextTransferred < encrypted.count()) { - long chunkTransferred = encrypted.transferTo(channel, 0); - ciphertextTransferred += chunkTransferred; - } - assertEquals(encrypted.count(), ciphertextTransferred); - - ciphertextBuffer.flip(); - ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); - - // Capture the decrypted values and verify them - ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); - decryptionHandler.channelRead(ctx, ciphertext); - verify(ctx, times(2)).fireChannelRead(captorPlaintext.capture()); - ByteBuf plaintext = captorPlaintext.getValue(); - // We expect this to be the last partial plaintext segment - int expectedLength = totalSize % plaintextSegmentSize; - assertEquals(expectedLength, plaintext.readableBytes()); - // This will be the "remainder" segment that is filled to 'c' - assertEquals('c', plaintext.getByte(0)); - } - } - - - @Test - public void testGcmUnalignedDecryption() throws Exception { - TransportConf gcmConf = getConf(2, false); - try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); - AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { - AuthMessage clientChallenge = client.challenge(); - AuthMessage serverResponse = server.response(clientChallenge); - client.deriveSessionCipher(clientChallenge, serverResponse); - TransportCipher clientCipher = server.sessionCipher(); - // Verify that it derives a GcmTransportCipher - assert (clientCipher instanceof GcmTransportCipher); - GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher; - GcmTransportCipher.EncryptionHandler encryptionHandler = - gcmTransportCipher.getEncryptionHandler(); - GcmTransportCipher.DecryptionHandler decryptionHandler = - gcmTransportCipher.getDecryptionHandler(); - // Allocating 1.5x the buffer size to test multiple segments and a fractional segment. - int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16; - int plaintextSize = plaintextSegmentSize + (plaintextSegmentSize / 2); - byte[] data = new byte[plaintextSize]; - Arrays.fill(data, (byte) 'x'); - ByteBuf buf = Unpooled.wrappedBuffer(data); - - // Mock the context and capture the arguments passed to it - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - ChannelPromise promise = mock(ChannelPromise.class); - ArgumentCaptor captorWrappedEncrypted = - ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); - encryptionHandler.write(ctx, buf, promise); - verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise)); - - // Get the encrypted value and pass it to the decryption handler - GcmTransportCipher.GcmEncryptedMessage encrypted = - captorWrappedEncrypted.getValue(); - ByteBuffer ciphertextBuffer = - ByteBuffer.allocate((int) encrypted.count()); - ByteBufferWriteableChannel channel = - new ByteBufferWriteableChannel(ciphertextBuffer); - encrypted.transferTo(channel, 0); - ciphertextBuffer.flip(); - ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); - - // Split up the ciphertext into some different sized chunks - int firstChunkSize = plaintextSize / 2; - ByteBuf mockCiphertext = spy(ciphertext); - when(mockCiphertext.readableBytes()) - .thenReturn(firstChunkSize, firstChunkSize).thenCallRealMethod(); - - // Capture the decrypted values and verify them - ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); - decryptionHandler.channelRead(ctx, mockCiphertext); - verify(ctx, times(2)).fireChannelRead(captorPlaintext.capture()); - ByteBuf lastPlaintextSegment = captorPlaintext.getValue(); - assertEquals(plaintextSegmentSize/2, - lastPlaintextSegment.readableBytes()); - assertEquals('x', - lastPlaintextSegment.getByte((plaintextSegmentSize/2) - 10)); - } - } - - @Test - public void testCorruptGcmEncryptedMessage() throws Exception { - TransportConf gcmConf = getConf(2, false); - - try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); - AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { - AuthMessage clientChallenge = client.challenge(); - AuthMessage serverResponse = server.response(clientChallenge); - client.deriveSessionCipher(clientChallenge, serverResponse); - - TransportCipher clientCipher = server.sessionCipher(); - assert (clientCipher instanceof GcmTransportCipher); - - GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher; - GcmTransportCipher.EncryptionHandler encryptionHandler = - gcmTransportCipher.getEncryptionHandler(); - GcmTransportCipher.DecryptionHandler decryptionHandler = - gcmTransportCipher.getDecryptionHandler(); - byte[] zeroData = new byte[1024 * 32]; - // Just writing some bytes. - ByteBuf buf = Unpooled.wrappedBuffer(zeroData); - - // Mock the context and capture the arguments passed to it - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - ChannelPromise promise = mock(ChannelPromise.class); - ArgumentCaptor captorWrappedEncrypted = - ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); - encryptionHandler.write(ctx, buf, promise); - verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise)); - - GcmTransportCipher.GcmEncryptedMessage encrypted = - captorWrappedEncrypted.getValue(); - ByteBuffer ciphertextBuffer = - ByteBuffer.allocate((int) encrypted.count()); - ByteBufferWriteableChannel channel = - new ByteBufferWriteableChannel(ciphertextBuffer); - encrypted.transferTo(channel, 0); - ciphertextBuffer.flip(); - ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); - - byte b = ciphertext.getByte(100); - // Inverting the bits of the 100th bit - ciphertext.setByte(100, ~b & 0xFF); - assertThrows(AEADBadTagException.class, () -> decryptionHandler.channelRead(ctx, ciphertext)); - } - } - - @Test - public void testCtrEncryptedMessageWhenTransferringZeroBytes() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); - AuthEngine server = new AuthEngine("appId", "secret", conf)) { - AuthMessage clientChallenge = client.challenge(); - AuthMessage serverResponse = server.response(clientChallenge); - client.deriveSessionCipher(clientChallenge, serverResponse); - TransportCipher clientCipher = server.sessionCipher(); - assert(clientCipher instanceof CtrTransportCipher); - CtrTransportCipher ctrTransportCipher = (CtrTransportCipher) clientCipher; - CtrTransportCipher.EncryptionHandler handler = - new CtrTransportCipher.EncryptionHandler(ctrTransportCipher); - int testDataLength = 4; - FileRegion region = mock(FileRegion.class); - when(region.count()).thenReturn((long) testDataLength); - // Make `region.transferTo` do nothing in first call and transfer 4 bytes in the second one. - when(region.transferTo(any(), anyLong())).thenAnswer(new Answer() { - - private boolean firstTime = true; - - @Override - public Long answer(InvocationOnMock invocationOnMock) throws Throwable { - if (firstTime) { - firstTime = false; - return 0L; - } else { - WritableByteChannel channel = invocationOnMock.getArgument(0); - channel.write(ByteBuffer.wrap(new byte[testDataLength])); - return (long) testDataLength; - } - } - }); - - CtrTransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(region); - ByteArrayWritableChannel channel = new ByteArrayWritableChannel(testDataLength); - // "transferTo" should act correctly when the underlying FileRegion transfers 0 bytes. - assertEquals(0L, emsg.transferTo(channel, emsg.transferred())); - assertEquals(testDataLength, emsg.transferTo(channel, emsg.transferred())); - assertEquals(emsg.transferred(), emsg.count()); - assertEquals(4, channel.length()); - } - } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java index 3ec9aa9926334..cb5929f7c65b4 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java @@ -78,8 +78,18 @@ public void testNewGcmAuth() throws Exception { } @Test - public void testAuthFailure() throws Exception { - ctx = new AuthTestCtx(); + public void testCtrAuthFailure() throws Exception { + ctx = new AuthTestCtx(new DummyRpcHandler(), "AES/CTR/NoPadding"); + ctx.createServer("server"); + + assertThrows(Exception.class, () -> ctx.createClient("client")); + assertFalse(ctx.authRpcHandler.isAuthenticated()); + assertFalse(ctx.serverChannel.isActive()); + } + + @Test + public void testGcmAuthFailure() throws Exception { + ctx = new AuthTestCtx(new DummyRpcHandler(), "AES/GCM/NoPadding"); ctx.createServer("server"); assertThrows(Exception.class, () -> ctx.createClient("client")); @@ -110,7 +120,7 @@ public void testSaslClientFallback() throws Exception { } @Test - public void testAuthReplay() throws Exception { + public void testCtrAuthReplay() throws Exception { // This test covers the case where an attacker replays a challenge message sniffed from the // network, but doesn't know the actual secret. The server should close the connection as // soon as a message is sent after authentication is performed. This is emulated by removing diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/CtrAuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/CtrAuthEngineSuite.java new file mode 100644 index 0000000000000..c353ee392ff4f --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/CtrAuthEngineSuite.java @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import com.google.crypto.tink.subtle.Hex; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.FileRegion; +import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.TransportConf; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +public class CtrAuthEngineSuite extends AuthEngineSuite { + private static final String inputIv = "fc6a5dc8b90a9dad8f54f08b51a59ed2"; + private static final String outputIv = "a72709baf00785cad6329ce09f631f71"; + + @BeforeAll + public static void setUp() { + conf = getConf(2, true); + } + + @Test + public void testAuthEngine() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + + TransportCipher serverCipher = server.sessionCipher(); + TransportCipher clientCipher = client.sessionCipher(); + assert(clientCipher instanceof CtrTransportCipher); + assert(serverCipher instanceof CtrTransportCipher); + CtrTransportCipher ctrClient = (CtrTransportCipher) clientCipher; + CtrTransportCipher ctrServer = (CtrTransportCipher) serverCipher; + assertArrayEquals(ctrServer.getInputIv(), ctrClient.getOutputIv()); + assertArrayEquals(ctrServer.getOutputIv(), ctrClient.getInputIv()); + assertEquals(ctrServer.getKey(), ctrClient.getKey()); + } + } + + @Test + public void testCtrFixedChallengeIvResponse() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf)) { + byte[] clientPrivateKey = Hex.decode(clientPrivate); + client.setClientPrivateKey(clientPrivateKey); + AuthMessage clientChallenge = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); + AuthMessage serverResponse = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex))); + // Verify that the client will accept an old transcript. + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = client.sessionCipher(); + assertEquals(clientCipher.getKeyId(), derivedKeyId); + assert(clientCipher instanceof CtrTransportCipher); + CtrTransportCipher ctrTransportCipher = (CtrTransportCipher) clientCipher; + assertEquals(Hex.encode(ctrTransportCipher.getInputIv()), inputIv); + assertEquals(Hex.encode(ctrTransportCipher.getOutputIv()), outputIv); + } + } + + @Test + public void testFixedChallengeResponseUnsafeVersion() throws Exception { + TransportConf v1Conf = getConf(1, true); + try (AuthEngine client = new AuthEngine("appId", "secret", v1Conf)) { + byte[] clientPrivateKey = Hex.decode(clientPrivate); + client.setClientPrivateKey(clientPrivateKey); + AuthMessage clientChallenge = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); + AuthMessage serverResponse = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex))); + // Verify that the client will accept an old transcript. + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = client.sessionCipher(); + assert(clientCipher instanceof CtrTransportCipher); + CtrTransportCipher ctrTransportCipher = (CtrTransportCipher) clientCipher; + assertEquals(Hex.encode(ctrTransportCipher.getKey().getEncoded()), unsafeDerivedKey); + assertEquals(Hex.encode(ctrTransportCipher.getInputIv()), inputIv); + assertEquals(Hex.encode(ctrTransportCipher.getOutputIv()), outputIv); + } + } + + @Test + public void testCtrEncryptedMessage() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + + TransportCipher clientCipher = server.sessionCipher(); + assert(clientCipher instanceof CtrTransportCipher); + CtrTransportCipher ctrTransportCipher = (CtrTransportCipher) clientCipher; + CtrTransportCipher.EncryptionHandler handler = + new CtrTransportCipher.EncryptionHandler(ctrTransportCipher); + + byte[] data = new byte[CtrTransportCipher.STREAM_BUFFER_SIZE + 1]; + new Random().nextBytes(data); + ByteBuf buf = Unpooled.wrappedBuffer(data); + + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length); + CtrTransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(buf); + while (emsg.transferred() < emsg.count()) { + emsg.transferTo(channel, emsg.transferred()); + } + assertEquals(data.length, channel.length()); + } + } + + @Test + public void testCtrEncryptedMessageWhenTransferringZeroBytes() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = server.sessionCipher(); + assert(clientCipher instanceof CtrTransportCipher); + CtrTransportCipher ctrTransportCipher = (CtrTransportCipher) clientCipher; + CtrTransportCipher.EncryptionHandler handler = + new CtrTransportCipher.EncryptionHandler(ctrTransportCipher); + int testDataLength = 4; + FileRegion region = mock(FileRegion.class); + when(region.count()).thenReturn((long) testDataLength); + // Make `region.transferTo` do nothing in first call and transfer 4 bytes in the second one. + when(region.transferTo(any(), anyLong())).thenAnswer(new Answer() { + + private boolean firstTime = true; + + @Override + public Long answer(InvocationOnMock invocationOnMock) throws Throwable { + if (firstTime) { + firstTime = false; + return 0L; + } else { + WritableByteChannel channel = invocationOnMock.getArgument(0); + channel.write(ByteBuffer.wrap(new byte[testDataLength])); + return (long) testDataLength; + } + } + }); + + CtrTransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(region); + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(testDataLength); + // "transferTo" should act correctly when the underlying FileRegion transfers 0 bytes. + assertEquals(0L, emsg.transferTo(channel, emsg.transferred())); + assertEquals(testDataLength, emsg.transferTo(channel, emsg.transferred())); + assertEquals(emsg.transferred(), emsg.count()); + assertEquals(4, channel.length()); + } + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java new file mode 100644 index 0000000000000..20efb8d57dcbf --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java @@ -0,0 +1,339 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import org.apache.spark.network.util.*; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import javax.crypto.AEADBadTagException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +public class GcmAuthEngineSuite extends AuthEngineSuite { + + @BeforeAll + public static void setUp() { + // Uses GCM mode + conf = getConf(2, false); + } + + @Test + public void testGcmEncryptedMessage() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = server.sessionCipher(); + // Verify that it derives a GcmTransportCipher + assert (clientCipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher; + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + // Allocating 1.5x the buffer size to test multiple segments and a fractional segment. + int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16; + byte[] data = new byte[plaintextSegmentSize + (plaintextSegmentSize / 2)]; + // Just writing some bytes. + data[0] = 'a'; + data[data.length / 2] = 'b'; + data[data.length - 10] = 'c'; + ByteBuf buf = Unpooled.wrappedBuffer(data); + + // Mock the context and capture the arguments passed to it + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + ArgumentCaptor captorWrappedEncrypted = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, buf, promise); + verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise)); + + // Get the encrypted value and pass it to the decryption handler + GcmTransportCipher.GcmEncryptedMessage encrypted = + captorWrappedEncrypted.getValue(); + ByteBuffer ciphertextBuffer = + ByteBuffer.allocate((int) encrypted.count()); + ByteBufferWriteableChannel channel = + new ByteBufferWriteableChannel(ciphertextBuffer); + encrypted.transferTo(channel, 0); + ciphertextBuffer.flip(); + ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); + + // Capture the decrypted values and verify them + ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, ciphertext); + verify(ctx, times(2)) + .fireChannelRead(captorPlaintext.capture()); + ByteBuf lastPlaintextSegment = captorPlaintext.getValue(); + assertEquals(plaintextSegmentSize/2, + lastPlaintextSegment.readableBytes()); + assertEquals('c', + lastPlaintextSegment.getByte((plaintextSegmentSize/2) - 10)); + } + } + + static class FakeRegion extends AbstractFileRegion { + private final ByteBuffer[] source; + private int sourcePosition; + private final long count; + + FakeRegion(ByteBuffer... source) { + this.source = source; + sourcePosition = 0; + count = remaining(); + } + + private long remaining() { + long remaining = 0; + for (ByteBuffer buffer : source) { + remaining += buffer.remaining(); + } + return remaining; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transferred() { + return count - remaining(); + } + + @Override + public long count() { + return count; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + if (sourcePosition < source.length) { + ByteBuffer currentBuffer = source[sourcePosition]; + long written = target.write(currentBuffer); + if (!currentBuffer.hasRemaining()) { + sourcePosition++; + } + return written; + } else { + return 0; + } + } + + @Override + protected void deallocate() { + } + } + + private static ByteBuffer getTestByteBuf(int size, byte fill) { + byte[] data = new byte[size]; + Arrays.fill(data, fill); + return ByteBuffer.wrap(data); + } + + @Test + public void testGcmEncryptedMessageFileRegion() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = server.sessionCipher(); + // Verify that it derives a GcmTransportCipher + assert (clientCipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher; + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + // Allocating 1.5x the buffer size to test multiple segments and a fractional segment. + int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16; + int halfSegmentSize = plaintextSegmentSize / 2; + int totalSize = plaintextSegmentSize + halfSegmentSize; + + // Set up some fragmented segments to test + ByteBuffer halfSegment = getTestByteBuf(halfSegmentSize, (byte) 'a'); + int smallFragmentSize = 128; + ByteBuffer smallFragment = getTestByteBuf(smallFragmentSize, (byte) 'b'); + int remainderSize = totalSize - halfSegmentSize - smallFragmentSize; + ByteBuffer remainder = getTestByteBuf(remainderSize, (byte) 'c'); + FakeRegion fakeRegion = new FakeRegion(halfSegment, smallFragment, remainder); + assertEquals(totalSize, fakeRegion.count()); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + ArgumentCaptor captorWrappedEncrypted = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, fakeRegion, promise); + verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise)); + + // Get the encrypted value and pass it to the decryption handler + GcmTransportCipher.GcmEncryptedMessage encrypted = + captorWrappedEncrypted.getValue(); + ByteBuffer ciphertextBuffer = + ByteBuffer.allocate((int) encrypted.count()); + ByteBufferWriteableChannel channel = + new ByteBufferWriteableChannel(ciphertextBuffer); + + // We'll simulate the FileRegion only transferring half a segment. + // The encrypted message should buffer the partial segment plaintext. + long ciphertextTransferred = 0; + while (ciphertextTransferred < encrypted.count()) { + long chunkTransferred = encrypted.transferTo(channel, 0); + ciphertextTransferred += chunkTransferred; + } + assertEquals(encrypted.count(), ciphertextTransferred); + + ciphertextBuffer.flip(); + ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); + + // Capture the decrypted values and verify them + ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, ciphertext); + verify(ctx, times(2)).fireChannelRead(captorPlaintext.capture()); + ByteBuf plaintext = captorPlaintext.getValue(); + // We expect this to be the last partial plaintext segment + int expectedLength = totalSize % plaintextSegmentSize; + assertEquals(expectedLength, plaintext.readableBytes()); + // This will be the "remainder" segment that is filled to 'c' + assertEquals('c', plaintext.getByte(0)); + } + } + + + @Test + public void testGcmUnalignedDecryption() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = server.sessionCipher(); + // Verify that it derives a GcmTransportCipher + assert (clientCipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher; + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + // Allocating 1.5x the buffer size to test multiple segments and a fractional segment. + int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16; + int plaintextSize = plaintextSegmentSize + (plaintextSegmentSize / 2); + byte[] data = new byte[plaintextSize]; + Arrays.fill(data, (byte) 'x'); + ByteBuf buf = Unpooled.wrappedBuffer(data); + + // Mock the context and capture the arguments passed to it + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + ArgumentCaptor captorWrappedEncrypted = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, buf, promise); + verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise)); + + // Get the encrypted value and pass it to the decryption handler + GcmTransportCipher.GcmEncryptedMessage encrypted = + captorWrappedEncrypted.getValue(); + ByteBuffer ciphertextBuffer = + ByteBuffer.allocate((int) encrypted.count()); + ByteBufferWriteableChannel channel = + new ByteBufferWriteableChannel(ciphertextBuffer); + encrypted.transferTo(channel, 0); + ciphertextBuffer.flip(); + ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); + + // Split up the ciphertext into some different sized chunks + int firstChunkSize = plaintextSize / 2; + ByteBuf mockCiphertext = spy(ciphertext); + when(mockCiphertext.readableBytes()) + .thenReturn(firstChunkSize, firstChunkSize).thenCallRealMethod(); + + // Capture the decrypted values and verify them + ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, mockCiphertext); + verify(ctx, times(2)).fireChannelRead(captorPlaintext.capture()); + ByteBuf lastPlaintextSegment = captorPlaintext.getValue(); + assertEquals(plaintextSegmentSize/2, + lastPlaintextSegment.readableBytes()); + assertEquals('x', + lastPlaintextSegment.getByte((plaintextSegmentSize/2) - 10)); + } + } + + @Test + public void testCorruptGcmEncryptedMessage() throws Exception { + TransportConf gcmConf = getConf(2, false); + + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + + TransportCipher clientCipher = server.sessionCipher(); + assert (clientCipher instanceof GcmTransportCipher); + + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher; + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + byte[] zeroData = new byte[1024 * 32]; + // Just writing some bytes. + ByteBuf buf = Unpooled.wrappedBuffer(zeroData); + + // Mock the context and capture the arguments passed to it + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + ArgumentCaptor captorWrappedEncrypted = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, buf, promise); + verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise)); + + GcmTransportCipher.GcmEncryptedMessage encrypted = + captorWrappedEncrypted.getValue(); + ByteBuffer ciphertextBuffer = + ByteBuffer.allocate((int) encrypted.count()); + ByteBufferWriteableChannel channel = + new ByteBufferWriteableChannel(ciphertextBuffer); + encrypted.transferTo(channel, 0); + ciphertextBuffer.flip(); + ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); + + byte b = ciphertext.getByte(100); + // Inverting the bits of the 100th bit + ciphertext.setByte(100, ~b & 0xFF); + assertThrows(AEADBadTagException.class, () -> decryptionHandler.channelRead(ctx, ciphertext)); + } + } +}