Skip to content

Commit

Permalink
Refactoring in ML-KEM
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdettman committed Nov 28, 2024
1 parent f6f6a6c commit efe1f51
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,16 @@ public void init(SecureRandom random)
this.random = random;
}

public byte[][] generateKemKeyPair()
{
byte[] d = new byte[KyberSymBytes];
byte[] z = new byte[KyberSymBytes];
random.nextBytes(d);
random.nextBytes(z);

return generateKemKeyPairInternal(d, z);
}

//Internal functions are deterministic. No randomness is sampled inside them
public byte[][] generateKemKeyPairInternal(byte[] d, byte[] z)
{
Expand All @@ -202,7 +212,15 @@ public byte[][] generateKemKeyPairInternal(byte[] d, byte[] z)

byte[] outputPublicKey = new byte[KyberIndCpaPublicKeyBytes];
System.arraycopy(indCpaKeyPair[0], 0, outputPublicKey, 0, KyberIndCpaPublicKeyBytes);
return new byte[][]{ Arrays.copyOfRange(outputPublicKey, 0, outputPublicKey.length - 32), Arrays.copyOfRange(outputPublicKey, outputPublicKey.length - 32, outputPublicKey.length), s, hashedPublicKey, z, Arrays.concatenate(d, z)};
return new byte[][]
{
Arrays.copyOfRange(outputPublicKey, 0, outputPublicKey.length - 32),
Arrays.copyOfRange(outputPublicKey, outputPublicKey.length - 32, outputPublicKey.length),
s,
hashedPublicKey,
z,
Arrays.concatenate(d, z)
};
}

public byte[][] kemEncryptInternal(byte[] publicKeyInput, byte[] randBytes)
Expand Down Expand Up @@ -263,16 +281,6 @@ public byte[] kemDecryptInternal(byte[] secretKey, byte[] cipherText)
return Arrays.copyOfRange(kr, 0, sessionKeyLength);
}

public byte[][] generateKemKeyPair()
{
byte[] d = new byte[KyberSymBytes];
byte[] z = new byte[KyberSymBytes];
random.nextBytes(d);
random.nextBytes(z);

return generateKemKeyPairInternal(d, z);
}

public byte[][] kemEncrypt(byte[] publicKeyInput, byte[] randBytes)
{
//TODO: do input validation elsewhere?
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
package org.bouncycastle.pqc.crypto.mlkem;

import org.bouncycastle.crypto.digests.SHAKEDigest;
import org.bouncycastle.util.Arrays;

class MLKEMIndCpa
{
private MLKEMEngine engine;
private int kyberK;
private int eta1;
private int indCpaPublicKeyBytes;
private int polyVecBytes;
private int indCpaBytes;
Expand All @@ -20,7 +18,6 @@ public MLKEMIndCpa(MLKEMEngine engine)
{
this.engine = engine;
this.kyberK = engine.getKyberK();
this.eta1 = engine.getKyberEta1();
this.indCpaPublicKeyBytes = engine.getKyberPublicKeyBytes();
this.polyVecBytes = engine.getKyberPolyVecBytes();
this.indCpaBytes = engine.getKyberIndCpaBytes();
Expand Down Expand Up @@ -54,9 +51,7 @@ byte[][] generateKeyPair(byte[] d)
// (p, sigma) <- G(d || k)

byte[] buf = new byte[64];
byte[] k = new byte[1];
k[0] = (byte)kyberK;
symmetric.hash_g(buf, Arrays.concatenate(d, k));
symmetric.hash_g(buf, Arrays.append(d, (byte)kyberK));

byte[] publicSeed = new byte[32]; // p in docs
byte[] noiseSeed = new byte[32]; // sigma in docs
Expand Down Expand Up @@ -320,7 +315,6 @@ public void unpackSecretKey(PolyVec secretKeyPolyVec, byte[] secretKey)
public void generateMatrix(PolyVec[] aMatrix, byte[] seed, boolean transposed)
{
int i, j, k, ctr, off;
SHAKEDigest kyberXOF;
byte[] buf = new byte[KyberGenerateMatrixNBlocks * symmetric.xofBlockBytes + 2];
for (i = 0; i < kyberK; i++)
{
Expand Down Expand Up @@ -383,7 +377,6 @@ private static int rejectionSampling(Poly outputBuffer, int coeffOff, int len, b

public byte[] decrypt(byte[] secretKey, byte[] cipherText)
{
int i;
byte[] outputMessage = new byte[MLKEMEngine.getKyberIndCpaMsgBytes()];

PolyVec bp = new PolyVec(engine), secretKeyPolyVec = new PolyVec(engine);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,8 @@ public AsymmetricCipherKeyPair internalGenerateKeyPair(byte[] d, byte[] z)
byte[][] keyPair = mlkemParams.getEngine().generateKemKeyPairInternal(d, z);

MLKEMPublicKeyParameters pubKey = new MLKEMPublicKeyParameters(mlkemParams, keyPair[0], keyPair[1]);
MLKEMPrivateKeyParameters privKey = new MLKEMPrivateKeyParameters(mlkemParams, keyPair[2], keyPair[3], keyPair[4], keyPair[0], keyPair[1]);
MLKEMPrivateKeyParameters privKey = new MLKEMPrivateKeyParameters(mlkemParams, keyPair[2], keyPair[3], keyPair[4], keyPair[0], keyPair[1], keyPair[5]);

return new AsymmetricCipherKeyPair(pubKey, privKey);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ public void testMlKemRequestWithMlDsaCA()
new CMSProcessableCMPCertificate(cert),
new JceCMSContentEncryptorBuilder(CMSAlgorithm.AES128_CBC).setProvider("BC").build());

System.err.println(ASN1Dump.dumpAsString(encryptedCert.toASN1Structure()));
// System.err.println(ASN1Dump.dumpAsString(encryptedCert.toASN1Structure()));
CertificateResponseBuilder certRespBuilder = new CertificateResponseBuilder(senderReqMessage.getCertReqId(), new PKIStatusInfo(PKIStatus.granted));

certRespBuilder.withCertificate(encryptedCert);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public void setUp()
public void testParametersAndParamSpecs()
throws Exception
{
MLKEMParameters mldsaParameters[] = new MLKEMParameters[]
MLKEMParameters mlKemParameters[] = new MLKEMParameters[]
{
MLKEMParameters.ml_kem_512,
MLKEMParameters.ml_kem_768,
Expand All @@ -63,7 +63,7 @@ public void testParametersAndParamSpecs()

for (int i = 0; i != names.length; i++)
{
assertEquals(names[i], MLKEMParameterSpec.fromName(mldsaParameters[i].getName()).getName());
assertEquals(names[i], MLKEMParameterSpec.fromName(mlKemParameters[i].getName()).getName());
}

for (int i = 0; i != names.length; i++)
Expand Down

0 comments on commit efe1f51

Please sign in to comment.