Skip to content

Commit

Permalink
apacheGH-524: Optimize the chacha20-poly1305 cipher
Browse files Browse the repository at this point in the history
Unroll the permutation loop of ChaCha20 and simplify the crypt() method
to work on bytes instead of ints. Unroll and optimize the pack/unpack
methods in ChaCha20 and in Poly1305. Don't use Integer.toUnsignedLong():
avoid the extra method call by doing (int & 0xFFFF_FFFFL) inline. In
Poly1305.processBlock(), use ints instead of longs where possible
(variables t0-t3).

All this brings a speed-up of about 40% for encryption/decryption.
Of the time spent in ChaCha20-Poly1305, about one third is spent in the
ChaChaEngine and two thirds in Poly1305.processBlock().

Bug: apache#524
  • Loading branch information
tomaswolf committed Jul 15, 2024
1 parent 5b00c1f commit e5b456a
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,11 @@ protected static class ChaChaEngine {
private static final int NONCE_OFFSET = 14;
private static final int NONCE_BYTES = 8;
private static final int NONCE_INTS = NONCE_BYTES / Integer.BYTES;
private static final int[] ENGINE_STATE_HEADER
= unpackSigmaString("expand 32-byte k".getBytes(StandardCharsets.US_ASCII));
private static final int[] ENGINE_STATE_HEADER = unpackSigmaString(
"expand 32-byte k".getBytes(StandardCharsets.US_ASCII));

protected final int[] x = new int[BLOCK_INTS];
protected final int[] engineState = new int[BLOCK_INTS];
protected final byte[] keyStream = new byte[BLOCK_BYTES];
protected final byte[] nonce = new byte[NONCE_BYTES];
protected long initialNonce;

Expand Down Expand Up @@ -181,19 +181,12 @@ protected void initCounter(long counter) {
// one-shot usage
protected void crypt(byte[] in, int offset, int length, byte[] out, int outOffset) {
while (length > 0) {
System.arraycopy(engineState, 0, x, 0, BLOCK_INTS);
permute(x);
setKeyStream(engineState);
int want = Math.min(BLOCK_BYTES, length);
for (int i = 0, j = 0; i < want; i += Integer.BYTES, j++) {
int keyStream = engineState[j] + x[j];
int take = Math.min(Integer.BYTES, length);
int input = unpackIntLE(in, offset, take);
int output = keyStream ^ input;
packIntLE(output, out, outOffset, take);
offset += take;
outOffset += take;
length -= take;
for (int i = 0; i < want; i++) {
out[outOffset++] = (byte) (in[offset++] ^ keyStream[i]);
}
length -= want;
int lo = ++engineState[COUNTER_OFFSET];
if (lo == 0) {
// overflow
Expand All @@ -210,56 +203,122 @@ protected byte[] polyKey() {
return block;
}

protected static void permute(int[] state) {
protected void setKeyStream(int[] engine) {
int x0 = engine[0];
int x1 = engine[1];
int x2 = engine[2];
int x3 = engine[3];
int x4 = engine[4];
int x5 = engine[5];
int x6 = engine[6];
int x7 = engine[7];
int x8 = engine[8];
int x9 = engine[9];
int x10 = engine[10];
int x11 = engine[11];
int x12 = engine[12];
int x13 = engine[13];
int x14 = engine[14];
int x15 = engine[15];

for (int i = 0; i < 10; i++) {
columnRound(state);
diagonalRound(state);
// Columns
// Quarter round 0, 4, 8, 12
x0 += x4;
x12 = Integer.rotateLeft(x12 ^ x0, 16);
x8 += x12;
x4 = Integer.rotateLeft(x4 ^ x8, 12);
x0 += x4;
x12 = Integer.rotateLeft(x12 ^ x0, 8);
x8 += x12;
x4 = Integer.rotateLeft(x4 ^ x8, 7);
// Quarter round 1, 5, 9, 13
x1 += x5;
x13 = Integer.rotateLeft(x13 ^ x1, 16);
x9 += x13;
x5 = Integer.rotateLeft(x5 ^ x9, 12);
x1 += x5;
x13 = Integer.rotateLeft(x13 ^ x1, 8);
x9 += x13;
x5 = Integer.rotateLeft(x5 ^ x9, 7);
// Quarter round 2, 6, 10, 14
x2 += x6;
x14 = Integer.rotateLeft(x14 ^ x2, 16);
x10 += x14;
x6 = Integer.rotateLeft(x6 ^ x10, 12);
x2 += x6;
x14 = Integer.rotateLeft(x14 ^ x2, 8);
x10 += x14;
x6 = Integer.rotateLeft(x6 ^ x10, 7);
// Quarter round 3, 7, 11, 15
x3 += x7;
x15 = Integer.rotateLeft(x15 ^ x3, 16);
x11 += x15;
x7 = Integer.rotateLeft(x7 ^ x11, 12);
x3 += x7;
x15 = Integer.rotateLeft(x15 ^ x3, 8);
x11 += x15;
x7 = Integer.rotateLeft(x7 ^ x11, 7);
// Diagonals
// Quarter round 0, 5, 10, 15
x0 += x5;
x15 = Integer.rotateLeft(x15 ^ x0, 16);
x10 += x15;
x5 = Integer.rotateLeft(x5 ^ x10, 12);
x0 += x5;
x15 = Integer.rotateLeft(x15 ^ x0, 8);
x10 += x15;
x5 = Integer.rotateLeft(x5 ^ x10, 7);
// Quarter round 1, 6, 11, 12
x1 += x6;
x12 = Integer.rotateLeft(x12 ^ x1, 16);
x11 += x12;
x6 = Integer.rotateLeft(x6 ^ x11, 12);
x1 += x6;
x12 = Integer.rotateLeft(x12 ^ x1, 8);
x11 += x12;
x6 = Integer.rotateLeft(x6 ^ x11, 7);
// Quarter round 2, 7, 8, 13
x2 += x7;
x13 = Integer.rotateLeft(x13 ^ x2, 16);
x8 += x13;
x7 = Integer.rotateLeft(x7 ^ x8, 12);
x2 += x7;
x13 = Integer.rotateLeft(x13 ^ x2, 8);
x8 += x13;
x7 = Integer.rotateLeft(x7 ^ x8, 7);
// Quarter round 3, 4, 9, 14
x3 += x4;
x14 = Integer.rotateLeft(x14 ^ x3, 16);
x9 += x14;
x4 = Integer.rotateLeft(x4 ^ x9, 12);
x3 += x4;
x14 = Integer.rotateLeft(x14 ^ x3, 8);
x9 += x14;
x4 = Integer.rotateLeft(x4 ^ x9, 7);
}
}

/**
 * Runs {@link #quarterRound} over the four index groups
 * {@code (n, n + 4, n + 8, n + 12)} for {@code n = 0..3}, in ascending order of {@code n}.
 *
 * @param state the state words, updated in place
 */
protected static void columnRound(int[] state) {
    for (int col = 0; col < 4; col++) {
        quarterRound(state, col, col + 4, col + 8, col + 12);
    }
}

/**
 * Runs {@link #quarterRound} over the four index groups
 * {@code (0, 5, 10, 15)}, {@code (1, 6, 11, 12)}, {@code (2, 7, 8, 13)} and
 * {@code (3, 4, 9, 14)}, in that order.
 *
 * @param state the state words, updated in place
 */
protected static void diagonalRound(int[] state) {
    for (int n = 0; n < 4; n++) {
        // Each successive base (4, 8, 12) is offset by one more position, wrapping modulo 4.
        int second = 4 + ((n + 1) & 3);
        int third = 8 + ((n + 2) & 3);
        int fourth = 12 + ((n + 3) & 3);
        quarterRound(state, n, second, third, fourth);
    }
}

/**
 * Mixes the four words of {@code state} selected by the given indices using
 * add / xor / rotate-left steps with rotation amounts 16, 12, 8 and 7.
 *
 * @param state the state words, updated in place
 * @param a     index of the first word
 * @param b     index of the second word
 * @param c     index of the third word
 * @param d     index of the fourth word
 */
protected static void quarterRound(int[] state, int a, int b, int c, int d) {
    // a += b; d = (d ^ a) <<< 16
    state[a] += state[b];
    state[d] ^= state[a];
    state[d] = Integer.rotateLeft(state[d], 16);

    // c += d; b = (b ^ c) <<< 12
    state[c] += state[d];
    state[b] ^= state[c];
    state[b] = Integer.rotateLeft(state[b], 12);

    // a += b; d = (d ^ a) <<< 8
    state[a] += state[b];
    state[d] ^= state[a];
    state[d] = Integer.rotateLeft(state[d], 8);

    // c += d; b = (b ^ c) <<< 7
    state[c] += state[d];
    state[b] ^= state[c];
    state[b] = Integer.rotateLeft(state[b], 7);
}

/**
 * Reads a full {@code int} ({@link Integer#BYTES} bytes) from {@code buf} starting at
 * {@code off} by delegating to the length-taking overload.
 *
 * @param  buf source bytes
 * @param  off index of the first byte to read
 * @return     the value produced by the three-argument overload
 */
private static int unpackIntLE(byte[] buf, int off) {
    final int width = Integer.BYTES;
    return unpackIntLE(buf, off, width);
}

private static int unpackIntLE(byte[] buf, int off, int len) {
int ret = 0;
for (int i = 0; i < len; i++) {
ret |= Byte.toUnsignedInt(buf[off + i]) << i * Byte.SIZE;
}
return ret;
Poly1305Mac.packIntLE(engine[0] + x0, keyStream, 0);
Poly1305Mac.packIntLE(engine[1] + x1, keyStream, 4);
Poly1305Mac.packIntLE(engine[2] + x2, keyStream, 8);
Poly1305Mac.packIntLE(engine[3] + x3, keyStream, 12);
Poly1305Mac.packIntLE(engine[4] + x4, keyStream, 16);
Poly1305Mac.packIntLE(engine[5] + x5, keyStream, 20);
Poly1305Mac.packIntLE(engine[6] + x6, keyStream, 24);
Poly1305Mac.packIntLE(engine[7] + x7, keyStream, 28);
Poly1305Mac.packIntLE(engine[8] + x8, keyStream, 32);
Poly1305Mac.packIntLE(engine[9] + x9, keyStream, 36);
Poly1305Mac.packIntLE(engine[10] + x10, keyStream, 40);
Poly1305Mac.packIntLE(engine[11] + x11, keyStream, 44);
Poly1305Mac.packIntLE(engine[12] + x12, keyStream, 48);
Poly1305Mac.packIntLE(engine[13] + x13, keyStream, 52);
Poly1305Mac.packIntLE(engine[14] + x14, keyStream, 56);
Poly1305Mac.packIntLE(engine[15] + x15, keyStream, 60);
}

private static void unpackIntsLE(byte[] buf, int off, int nrInts, int[] dst, int dstOff) {
for (int i = 0; i < nrInts; i++) {
dst[dstOff++] = unpackIntLE(buf, off);
dst[dstOff++] = Poly1305Mac.unpackIntLE(buf, off);
off += Integer.BYTES;
}
}
Expand All @@ -270,10 +329,5 @@ private static int[] unpackSigmaString(byte[] buf) {
return values;
}

/**
 * Stores the first {@code len} bytes of {@code value} into {@code dst} starting at
 * {@code off}, least significant byte first.
 *
 * @param value the integer to serialize
 * @param dst   destination array
 * @param off   index in {@code dst} of the first byte written
 * @param len   number of bytes to write
 */
private static void packIntLE(int value, byte[] dst, int off, int len) {
    int shift = 0;
    int pos = off;
    for (int remaining = len; remaining > 0; remaining--) {
        dst[pos++] = (byte) (value >>> shift);
        shift += Byte.SIZE;
    }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,10 @@ public void doFinal(byte[] out, int offset) throws Exception {
h3 = h3 & nb | g3 & b;
h4 = h4 & nb | g4 & b;

long f0 = Integer.toUnsignedLong(h0 | h1 << 26) + Integer.toUnsignedLong(k0);
long f1 = Integer.toUnsignedLong(h1 >>> 6 | h2 << 20) + Integer.toUnsignedLong(k1);
long f2 = Integer.toUnsignedLong(h2 >>> 12 | h3 << 14) + Integer.toUnsignedLong(k2);
long f3 = Integer.toUnsignedLong(h3 >>> 18 | h4 << 8) + Integer.toUnsignedLong(k3);
long f0 = ((h0 | h1 << 26) & 0xFFFF_FFFFL) + (k0 & 0xFFFF_FFFFL);
long f1 = ((h1 >>> 6 | h2 << 20) & 0xFFFF_FFFFL) + (k1 & 0xFFFF_FFFFL);
long f2 = ((h2 >>> 12 | h3 << 14) & 0xFFFF_FFFFL) + (k2 & 0xFFFF_FFFFL);
long f3 = ((h3 >>> 18 | h4 << 8) & 0xFFFF_FFFFL) + (k3 & 0xFFFF_FFFFL);

packIntLE((int) f0, out, offset);
f1 += f0 >>> 32;
Expand All @@ -188,15 +188,15 @@ private void processBlock() {
}
}

long t0 = Integer.toUnsignedLong(unpackIntLE(currentBlock, 0));
long t1 = Integer.toUnsignedLong(unpackIntLE(currentBlock, 4));
long t2 = Integer.toUnsignedLong(unpackIntLE(currentBlock, 8));
long t3 = Integer.toUnsignedLong(unpackIntLE(currentBlock, 12));
int t0 = unpackIntLE(currentBlock, 0);
int t1 = unpackIntLE(currentBlock, 4);
int t2 = unpackIntLE(currentBlock, 8);
int t3 = unpackIntLE(currentBlock, 12);

h0 += t0 & 0x3ffffff;
h1 += (t1 << 32 | t0) >>> 26 & 0x3ffffff;
h2 += (t2 << 32 | t1) >>> 20 & 0x3ffffff;
h3 += (t3 << 32 | t2) >>> 14 & 0x3ffffff;
h1 += (t0 >>> 26 | t1 << 6) & 0x3ffffff;
h2 += (t1 >>> 20 | t2 << 12) & 0x3ffffff;
h3 += (t2 >>> 14 | t3 << 18) & 0x3ffffff;
h4 += t3 >>> 8;

if (currentBlockOffset == BLOCK_SIZE) {
Expand Down Expand Up @@ -250,21 +250,22 @@ public int getDefaultBlockSize() {
return BLOCK_SIZE;
}

private static int unpackIntLE(byte[] buf, int off) {
int ret = 0;
for (int i = 0; i < Integer.BYTES; i++) {
ret |= Byte.toUnsignedInt(buf[off + i]) << i * Byte.SIZE;
}
/**
 * Assembles an {@code int} from the four bytes {@code buf[off] .. buf[off + 3]},
 * treating {@code buf[off]} as the least significant byte.
 *
 * @param  buf source bytes
 * @param  off index of the least significant byte
 * @return     the little-endian integer value
 */
public static int unpackIntLE(byte[] buf, int off) {
    return (buf[off] & 0xFF)
            | (buf[off + 1] & 0xFF) << 8
            | (buf[off + 2] & 0xFF) << 16
            | (buf[off + 3] & 0xFF) << 24;
}

private static void packIntLE(int value, byte[] dst, int off) {
for (int i = 0; i < Integer.BYTES; i++) {
dst[off + i] = (byte) (value >>> i * Byte.SIZE);
}
/**
 * Writes {@code value} into {@code dst[off] .. dst[off + 3]}, least significant byte first.
 *
 * @param value the integer to serialize
 * @param dst   destination array
 * @param off   index in {@code dst} that receives the least significant byte
 */
public static void packIntLE(int value, byte[] dst, int off) {
    dst[off] = (byte) value;
    dst[off + 1] = (byte) (value >>> 8);
    dst[off + 2] = (byte) (value >>> 16);
    dst[off + 3] = (byte) (value >>> 24);
}

/**
 * Multiplies two {@code int} values interpreted as unsigned 32-bit quantities.
 * <p>
 * Both operands are zero-extended to {@code long} with an inline mask (no
 * {@code Integer.toUnsignedLong()} call). Masking only the first operand would
 * sign-extend {@code i2} during binary numeric promotion and give a wrong result
 * whenever {@code i2} has its high bit set.
 *
 * @param  i1 first factor, treated as unsigned
 * @param  i2 second factor, treated as unsigned
 * @return    the low 64 bits of the unsigned product
 */
private static long unsignedProduct(int i1, int i2) {
    return (i1 & 0xFFFF_FFFFL) * (i2 & 0xFFFF_FFFFL);
}
}

0 comments on commit e5b456a

Please sign in to comment.