From e5b456aa7001489185f0875dc26222740711e057 Mon Sep 17 00:00:00 2001 From: Thomas Wolf Date: Tue, 9 Jul 2024 21:55:14 +0200 Subject: [PATCH] GH-524: Optimize the chacha20-poly1305 cipher Unroll the permutation loop of ChaCha20 and simplify the crypt() method to work on bytes instead of ints. Unroll and optimize the pack/unpack methods in ChaCha20 and in Poly1305. Don't use Integer.toUnsignedLong(): avoid the extra method call by doing (int & 0xFFFF_FFFFL) inline. In Poly1305.processBlock(), use ints instead of longs where possible (variables t0-t3). All this brings a speed-up of about 40% for the encryption/decryption. Of the time spent in ChaCha20-Poly1305 about one third is spent in the ChaChaEngine and two thirds in Poly1305.processBlock(). Bug: https://github.com/apache/mina-sshd/issues/524 --- .../sshd/common/cipher/ChaCha20Cipher.java | 178 ++++++++++++------ .../apache/sshd/common/mac/Poly1305Mac.java | 43 ++--- 2 files changed, 138 insertions(+), 83 deletions(-) diff --git a/sshd-common/src/main/java/org/apache/sshd/common/cipher/ChaCha20Cipher.java b/sshd-common/src/main/java/org/apache/sshd/common/cipher/ChaCha20Cipher.java index 9d4023db0..71eec6e3f 100644 --- a/sshd-common/src/main/java/org/apache/sshd/common/cipher/ChaCha20Cipher.java +++ b/sshd-common/src/main/java/org/apache/sshd/common/cipher/ChaCha20Cipher.java @@ -144,11 +144,11 @@ protected static class ChaChaEngine { private static final int NONCE_OFFSET = 14; private static final int NONCE_BYTES = 8; private static final int NONCE_INTS = NONCE_BYTES / Integer.BYTES; - private static final int[] ENGINE_STATE_HEADER - = unpackSigmaString("expand 32-byte k".getBytes(StandardCharsets.US_ASCII)); + private static final int[] ENGINE_STATE_HEADER = unpackSigmaString( + "expand 32-byte k".getBytes(StandardCharsets.US_ASCII)); - protected final int[] x = new int[BLOCK_INTS]; protected final int[] engineState = new int[BLOCK_INTS]; + protected final byte[] keyStream = new byte[BLOCK_BYTES]; protected final byte[] nonce = new byte[NONCE_BYTES]; protected long initialNonce; @@ -181,19 +181,12 @@ protected void initCounter(long counter) { // one-shot usage protected void crypt(byte[] in, int offset, int length, byte[] out, int outOffset) { while (length > 0) { - System.arraycopy(engineState, 0, x, 0, BLOCK_INTS); - permute(x); + setKeyStream(engineState); int want = Math.min(BLOCK_BYTES, length); - for (int i = 0, j = 0; i < want; i += Integer.BYTES, j++) { - int keyStream = engineState[j] + x[j]; - int take = Math.min(Integer.BYTES, length); - int input = unpackIntLE(in, offset, take); - int output = keyStream ^ input; - packIntLE(output, out, outOffset, take); - offset += take; - outOffset += take; - length -= take; + for (int i = 0; i < want; i++) { + out[outOffset++] = (byte) (in[offset++] ^ keyStream[i]); } + length -= want; int lo = ++engineState[COUNTER_OFFSET]; if (lo == 0) { // overflow @@ -210,56 +203,122 @@ protected byte[] polyKey() { return block; } - protected static void permute(int[] state) { + protected void setKeyStream(int[] engine) { + int x0 = engine[0]; + int x1 = engine[1]; + int x2 = engine[2]; + int x3 = engine[3]; + int x4 = engine[4]; + int x5 = engine[5]; + int x6 = engine[6]; + int x7 = engine[7]; + int x8 = engine[8]; + int x9 = engine[9]; + int x10 = engine[10]; + int x11 = engine[11]; + int x12 = engine[12]; + int x13 = engine[13]; + int x14 = engine[14]; + int x15 = engine[15]; + for (int i = 0; i < 10; i++) { - columnRound(state); - diagonalRound(state); + // Columns + // Quarter round 0, 4, 8, 12 + x0 += x4; + x12 = Integer.rotateLeft(x12 ^ x0, 16); + x8 += x12; + x4 = Integer.rotateLeft(x4 ^ x8, 12); + x0 += x4; + x12 = Integer.rotateLeft(x12 ^ x0, 8); + x8 += x12; + x4 = Integer.rotateLeft(x4 ^ x8, 7); + // Quarter round 1, 5, 9, 13 + x1 += x5; + x13 = Integer.rotateLeft(x13 ^ x1, 16); + x9 += x13; + x5 = Integer.rotateLeft(x5 ^ x9, 12); + x1 += x5; + x13 = Integer.rotateLeft(x13 ^ x1, 8); + x9 += x13; + x5 = Integer.rotateLeft(x5 ^ x9, 7); + // Quarter round 2, 6, 10, 14 + x2 += x6; + x14 = Integer.rotateLeft(x14 ^ x2, 16); + x10 += x14; + x6 = Integer.rotateLeft(x6 ^ x10, 12); + x2 += x6; + x14 = Integer.rotateLeft(x14 ^ x2, 8); + x10 += x14; + x6 = Integer.rotateLeft(x6 ^ x10, 7); + // Quarter round 3, 7, 11, 15 + x3 += x7; + x15 = Integer.rotateLeft(x15 ^ x3, 16); + x11 += x15; + x7 = Integer.rotateLeft(x7 ^ x11, 12); + x3 += x7; + x15 = Integer.rotateLeft(x15 ^ x3, 8); + x11 += x15; + x7 = Integer.rotateLeft(x7 ^ x11, 7); + // Diagonals + // Quarter round 0, 5, 10, 15 + x0 += x5; + x15 = Integer.rotateLeft(x15 ^ x0, 16); + x10 += x15; + x5 = Integer.rotateLeft(x5 ^ x10, 12); + x0 += x5; + x15 = Integer.rotateLeft(x15 ^ x0, 8); + x10 += x15; + x5 = Integer.rotateLeft(x5 ^ x10, 7); + // Quarter round 1, 6, 11, 12 + x1 += x6; + x12 = Integer.rotateLeft(x12 ^ x1, 16); + x11 += x12; + x6 = Integer.rotateLeft(x6 ^ x11, 12); + x1 += x6; + x12 = Integer.rotateLeft(x12 ^ x1, 8); + x11 += x12; + x6 = Integer.rotateLeft(x6 ^ x11, 7); + // Quarter round 2, 7, 8, 13 + x2 += x7; + x13 = Integer.rotateLeft(x13 ^ x2, 16); + x8 += x13; + x7 = Integer.rotateLeft(x7 ^ x8, 12); + x2 += x7; + x13 = Integer.rotateLeft(x13 ^ x2, 8); + x8 += x13; + x7 = Integer.rotateLeft(x7 ^ x8, 7); + // Quarter round 3, 4, 9, 14 + x3 += x4; + x14 = Integer.rotateLeft(x14 ^ x3, 16); + x9 += x14; + x4 = Integer.rotateLeft(x4 ^ x9, 12); + x3 += x4; + x14 = Integer.rotateLeft(x14 ^ x3, 8); + x9 += x14; + x4 = Integer.rotateLeft(x4 ^ x9, 7); } - } - - protected static void columnRound(int[] state) { - quarterRound(state, 0, 4, 8, 12); - quarterRound(state, 1, 5, 9, 13); - quarterRound(state, 2, 6, 10, 14); - quarterRound(state, 3, 7, 11, 15); - } - - protected static void diagonalRound(int[] state) { - quarterRound(state, 0, 5, 10, 15); - quarterRound(state, 1, 6, 11, 12); - quarterRound(state, 2, 7, 8, 13); - quarterRound(state, 3, 4, 9, 14); - } - - protected static void quarterRound(int[] state, int a, int b, int c, int d) { - state[a] += state[b]; - state[d] = Integer.rotateLeft(state[d] ^ state[a], 16); - - state[c] += state[d]; - state[b] = Integer.rotateLeft(state[b] ^ state[c], 12); - state[a] += state[b]; - state[d] = Integer.rotateLeft(state[d] ^ state[a], 8); - - state[c] += state[d]; - state[b] = Integer.rotateLeft(state[b] ^ state[c], 7); - } - - private static int unpackIntLE(byte[] buf, int off) { - return unpackIntLE(buf, off, Integer.BYTES); - } - - private static int unpackIntLE(byte[] buf, int off, int len) { - int ret = 0; - for (int i = 0; i < len; i++) { - ret |= Byte.toUnsignedInt(buf[off + i]) << i * Byte.SIZE; - } - return ret; + Poly1305Mac.packIntLE(engine[0] + x0, keyStream, 0); + Poly1305Mac.packIntLE(engine[1] + x1, keyStream, 4); + Poly1305Mac.packIntLE(engine[2] + x2, keyStream, 8); + Poly1305Mac.packIntLE(engine[3] + x3, keyStream, 12); + Poly1305Mac.packIntLE(engine[4] + x4, keyStream, 16); + Poly1305Mac.packIntLE(engine[5] + x5, keyStream, 20); + Poly1305Mac.packIntLE(engine[6] + x6, keyStream, 24); + Poly1305Mac.packIntLE(engine[7] + x7, keyStream, 28); + Poly1305Mac.packIntLE(engine[8] + x8, keyStream, 32); + Poly1305Mac.packIntLE(engine[9] + x9, keyStream, 36); + Poly1305Mac.packIntLE(engine[10] + x10, keyStream, 40); + Poly1305Mac.packIntLE(engine[11] + x11, keyStream, 44); + Poly1305Mac.packIntLE(engine[12] + x12, keyStream, 48); + Poly1305Mac.packIntLE(engine[13] + x13, keyStream, 52); + Poly1305Mac.packIntLE(engine[14] + x14, keyStream, 56); + Poly1305Mac.packIntLE(engine[15] + x15, keyStream, 60); } private static void unpackIntsLE(byte[] buf, int off, int nrInts, int[] dst, int dstOff) { for (int i = 0; i < nrInts; i++) { - dst[dstOff++] = unpackIntLE(buf, off); + dst[dstOff++] = Poly1305Mac.unpackIntLE(buf, off); off += Integer.BYTES; } } @@ -270,10 +329,5 @@ private static int[] unpackSigmaString(byte[] buf) { return values; } - private static void packIntLE(int value, byte[] dst, int off, int len) { - for (int i = 0; i < len; i++) { - dst[off + i] = (byte) (value >>> i * Byte.SIZE); - } - } } } diff --git a/sshd-common/src/main/java/org/apache/sshd/common/mac/Poly1305Mac.java b/sshd-common/src/main/java/org/apache/sshd/common/mac/Poly1305Mac.java index dcb8919c9..62c8edb7b 100644 --- a/sshd-common/src/main/java/org/apache/sshd/common/mac/Poly1305Mac.java +++ b/sshd-common/src/main/java/org/apache/sshd/common/mac/Poly1305Mac.java @@ -163,10 +163,10 @@ public void doFinal(byte[] out, int offset) throws Exception { h3 = h3 & nb | g3 & b; h4 = h4 & nb | g4 & b; - long f0 = Integer.toUnsignedLong(h0 | h1 << 26) + Integer.toUnsignedLong(k0); - long f1 = Integer.toUnsignedLong(h1 >>> 6 | h2 << 20) + Integer.toUnsignedLong(k1); - long f2 = Integer.toUnsignedLong(h2 >>> 12 | h3 << 14) + Integer.toUnsignedLong(k2); - long f3 = Integer.toUnsignedLong(h3 >>> 18 | h4 << 8) + Integer.toUnsignedLong(k3); + long f0 = ((h0 | h1 << 26) & 0xFFFF_FFFFL) + (k0 & 0xFFFF_FFFFL); + long f1 = ((h1 >>> 6 | h2 << 20) & 0xFFFF_FFFFL) + (k1 & 0xFFFF_FFFFL); + long f2 = ((h2 >>> 12 | h3 << 14) & 0xFFFF_FFFFL) + (k2 & 0xFFFF_FFFFL); + long f3 = ((h3 >>> 18 | h4 << 8) & 0xFFFF_FFFFL) + (k3 & 0xFFFF_FFFFL); packIntLE((int) f0, out, offset); f1 += f0 >>> 32; @@ -188,15 +188,15 @@ private void processBlock() { } } - long t0 = Integer.toUnsignedLong(unpackIntLE(currentBlock, 0)); - long t1 = Integer.toUnsignedLong(unpackIntLE(currentBlock, 4)); - long t2 = Integer.toUnsignedLong(unpackIntLE(currentBlock, 8)); - long t3 = Integer.toUnsignedLong(unpackIntLE(currentBlock, 12)); + int t0 = unpackIntLE(currentBlock, 0); + int t1 = unpackIntLE(currentBlock, 4); + int t2 = unpackIntLE(currentBlock, 8); + int t3 = unpackIntLE(currentBlock, 12); h0 += t0 & 0x3ffffff; - h1 += (t1 << 32 | t0) >>> 26 & 0x3ffffff; - h2 += (t2 << 32 | t1) >>> 20 & 0x3ffffff; - h3 += (t3 << 32 | t2) >>> 14 & 0x3ffffff; + h1 += (t0 >>> 26 | t1 << 6) & 0x3ffffff; + h2 += (t1 >>> 20 | t2 << 12) & 0x3ffffff; + h3 += (t2 >>> 14 | t3 << 18) & 0x3ffffff; h4 += t3 >>> 8; if (currentBlockOffset == BLOCK_SIZE) { @@ -250,21 +250,22 @@ public int getDefaultBlockSize() { return BLOCK_SIZE; } - private static int unpackIntLE(byte[] buf, int off) { - int ret = 0; - for (int i = 0; i < Integer.BYTES; i++) { - ret |= Byte.toUnsignedInt(buf[off + i]) << i * Byte.SIZE; - } + public static int unpackIntLE(byte[] buf, int off) { + int ret = buf[off++] & 0xFF; + ret |= (buf[off++] & 0xFF) << 8; + ret |= (buf[off++] & 0xFF) << 16; + ret |= (buf[off] & 0xFF) << 24; return ret; } - private static void packIntLE(int value, byte[] dst, int off) { - for (int i = 0; i < Integer.BYTES; i++) { - dst[off + i] = (byte) (value >>> i * Byte.SIZE); - } + public static void packIntLE(int value, byte[] dst, int off) { + dst[off++] = (byte) value; + dst[off++] = (byte) (value >> 8); + dst[off++] = (byte) (value >> 16); + dst[off] = (byte) (value >> 24); } private static long unsignedProduct(int i1, int i2) { - return Integer.toUnsignedLong(i1) * Integer.toUnsignedLong(i2); + return (i1 & 0xFFFF_FFFFL) * i2; } }