Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for X25519MLKEM768 hybrid algorithm #1989

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ private enum All
OQS_mlkem1024(NamedGroup.OQS_mlkem1024, "ML-KEM"),
MLKEM512(NamedGroup.MLKEM512, "ML-KEM"),
MLKEM768(NamedGroup.MLKEM768, "ML-KEM"),
MLKEM1024(NamedGroup.MLKEM1024, "ML-KEM");
MLKEM1024(NamedGroup.MLKEM1024, "ML-KEM"),

X25519MLKEM768(NamedGroup.X25519MLKEM768, "ML-KEM");

private final int namedGroup;
private final String name;
Expand Down
8 changes: 8 additions & 0 deletions tls/src/main/java/org/bouncycastle/tls/NamedGroup.java
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ public class NamedGroup
public static final int MLKEM768 = 0x0768;
public static final int MLKEM1024 = 0x1024;

/*
* draft-kwiatkowski-tls-ecdhe-mlkem-03
*/
public static final int X25519MLKEM768 = 0x11EC;

/* Names of the actual underlying elliptic curves (not necessarily matching the NamedGroup names). */
private static final String[] CURVE_NAMES = new String[]{ "sect163k1", "sect163r1", "sect163r2", "sect193r1",
"sect193r2", "sect233k1", "sect233r1", "sect239k1", "sect283k1", "sect283r1", "sect409k1", "sect409r1",
Expand Down Expand Up @@ -310,6 +315,8 @@ public static String getKemName(int namedGroup)
case OQS_mlkem1024:
case MLKEM1024:
return "ML-KEM-1024";
case X25519MLKEM768:
return "X25519MLKEM768";
default:
return null;
}
Expand Down Expand Up @@ -502,6 +509,7 @@ public static boolean refersToASpecificKem(int namedGroup)
case MLKEM512:
case MLKEM768:
case MLKEM1024:
case X25519MLKEM768:
return true;
default:
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,13 @@ public TlsECDomain createECDomain(TlsECConfig ecConfig)

public TlsKemDomain createKemDomain(TlsKemConfig kemConfig)
{
return new BcTlsMLKemDomain(this, kemConfig);
switch (kemConfig.getNamedGroup())
{
case NamedGroup.X25519MLKEM768:
return new BcTlsX25519MLKemDomain(this, kemConfig);
default:
return new BcTlsMLKemDomain(this, kemConfig);
}
}

public TlsNonceGenerator createNonceGenerator(byte[] additionalSeedMaterial)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public static MLKEMParameters getDomainParameters(TlsKemConfig kemConfig)
return MLKEMParameters.ml_kem_512;
case NamedGroup.OQS_mlkem768:
case NamedGroup.MLKEM768:
case NamedGroup.X25519MLKEM768:
return MLKEMParameters.ml_kem_768;
case NamedGroup.OQS_mlkem1024:
case NamedGroup.MLKEM1024:
Expand All @@ -47,6 +48,11 @@ public BcTlsMLKemDomain(BcTlsCrypto crypto, TlsKemConfig kemConfig)
this.isServer = kemConfig.isServer();
}

public TlsKemConfig getTlsKemConfig()
{
return this.config;
}

public BcTlsSecret adoptLocalSecret(byte[] secret)
{
return crypto.adoptLocalSecret(secret);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package org.bouncycastle.tls.crypto.impl.bc;

import java.io.IOException;

import org.bouncycastle.crypto.AsymmetricCipherKeyPair;
import org.bouncycastle.crypto.SecretWithEncapsulation;
import org.bouncycastle.pqc.crypto.mlkem.MLKEMPrivateKeyParameters;
import org.bouncycastle.pqc.crypto.mlkem.MLKEMPublicKeyParameters;
import org.bouncycastle.tls.crypto.TlsAgreement;
import org.bouncycastle.tls.crypto.TlsSecret;
import org.bouncycastle.util.Arrays;

public class BcTlsX25519MLKem implements TlsAgreement
{
protected final BcTlsX25519MLKemDomain domain;

protected AsymmetricCipherKeyPair mlkemLocalKeyPair;
protected MLKEMPublicKeyParameters mlkemPeerPublicKey;
protected byte[] x25519PrivateKey;
protected byte[] x25519PeerPublicKey;

protected byte[] mlkemCiphertext;
protected byte[] mlkemSecret;

public BcTlsX25519MLKem(BcTlsX25519MLKemDomain domain)
{
this.domain = domain;
}

public byte[] generateEphemeral() throws IOException
{
this.x25519PrivateKey = domain.generateX25519PrivateKey();
byte[] x25519Key = domain.getX25519PublicKey(x25519PrivateKey);
byte[] mlkemKey;
if (domain.getKemDomain().getTlsKemConfig().isServer())
{
mlkemKey = Arrays.clone(mlkemCiphertext);
}
else
{
this.mlkemLocalKeyPair = domain.getKemDomain().generateKeyPair();
mlkemKey = domain.getKemDomain().encodePublicKey((MLKEMPublicKeyParameters)mlkemLocalKeyPair.getPublic());
}
return Arrays.concatenate(mlkemKey, x25519Key);
}

public void receivePeerValue(byte[] peerValue) throws IOException
{
this.x25519PeerPublicKey = Arrays.copyOfRange(peerValue, peerValue.length - domain.getX25519PublicKeyByteLength(), peerValue.length);
byte[] mlkemKey = Arrays.copyOf(peerValue, peerValue.length - domain.getX25519PublicKeyByteLength());
if (domain.getKemDomain().getTlsKemConfig().isServer())
{
this.mlkemPeerPublicKey = domain.getKemDomain().decodePublicKey(mlkemKey);
SecretWithEncapsulation encap = domain.getKemDomain().encapsulate(mlkemPeerPublicKey);
mlkemCiphertext = encap.getEncapsulation();
mlkemSecret = encap.getSecret();
}
else
{
this.mlkemCiphertext = Arrays.clone(mlkemKey);
}
}

public TlsSecret calculateSecret() throws IOException
{
byte[] x25519Secret = domain.calculateX25519Secret(x25519PrivateKey, x25519PeerPublicKey);
if (!domain.getKemDomain().getTlsKemConfig().isServer())
{
mlkemSecret = domain.getKemDomain().decapsulate((MLKEMPrivateKeyParameters)mlkemLocalKeyPair.getPrivate(), mlkemCiphertext).extract();
}
return domain.getKemDomain().adoptLocalSecret(Arrays.concatenate(mlkemSecret, x25519Secret));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package org.bouncycastle.tls.crypto.impl.bc;

import java.io.IOException;
import org.bouncycastle.math.ec.rfc7748.X25519;
import org.bouncycastle.tls.AlertDescription;
import org.bouncycastle.tls.TlsFatalAlert;
import org.bouncycastle.tls.crypto.TlsAgreement;
import org.bouncycastle.tls.crypto.TlsKemConfig;
import org.bouncycastle.tls.crypto.TlsKemDomain;

public class BcTlsX25519MLKemDomain implements TlsKemDomain
{
protected final BcTlsMLKemDomain kemDomain;
protected final BcTlsCrypto crypto;

public BcTlsX25519MLKemDomain(BcTlsCrypto crypto, TlsKemConfig kemConfig)
{
this.kemDomain = new BcTlsMLKemDomain(crypto, kemConfig);
this.crypto = crypto;
}

public TlsAgreement createKem()
{
return new BcTlsX25519MLKem(this);
}

public BcTlsMLKemDomain getKemDomain()
{
return kemDomain;
}

public byte[] generateX25519PrivateKey() throws IOException
{
byte[] privateKey = new byte[X25519.SCALAR_SIZE];
crypto.getSecureRandom().nextBytes(privateKey);
return privateKey;
}

public byte[] getX25519PublicKey(byte[] privateKey) throws IOException
{
byte[] publicKey = new byte[X25519.POINT_SIZE];
X25519.scalarMultBase(privateKey, 0, publicKey, 0);
return publicKey;
}

public int getX25519PublicKeyByteLength() throws IOException
{
return X25519.POINT_SIZE;
}

public byte[] calculateX25519Secret(byte[] privateKey, byte[] peerPublicKey) throws IOException
{
byte[] secret = new byte[X25519.POINT_SIZE];
if (!X25519.calculateAgreement(privateKey, 0, peerPublicKey, 0, secret, 0))
{
throw new TlsFatalAlert(AlertDescription.handshake_failure);
}
return secret;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ else if (NamedGroup.refersToASpecificKem(namedGroup))
case NamedGroup.MLKEM512:
case NamedGroup.MLKEM768:
case NamedGroup.MLKEM1024:
case NamedGroup.X25519MLKEM768:
return null;
}
}
Expand Down Expand Up @@ -858,7 +859,13 @@ public TlsECDomain createECDomain(TlsECConfig ecConfig)

public TlsKemDomain createKemDomain(TlsKemConfig kemConfig)
{
return new JceTlsMLKemDomain(this, kemConfig);
switch (kemConfig.getNamedGroup())
{
case NamedGroup.X25519MLKEM768:
return new JceTlsX25519MLKemDomain(this, kemConfig);
default:
return new JceTlsMLKemDomain(this, kemConfig);
}
}

public TlsSecret hkdfInit(int cryptoHashAlgorithm)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public static MLKEMParameters getDomainParameters(TlsKemConfig kemConfig)
return MLKEMParameters.ml_kem_512;
case NamedGroup.OQS_mlkem768:
case NamedGroup.MLKEM768:
case NamedGroup.X25519MLKEM768:
return MLKEMParameters.ml_kem_768;
case NamedGroup.OQS_mlkem1024:
case NamedGroup.MLKEM1024:
Expand All @@ -47,6 +48,11 @@ public JceTlsMLKemDomain(JcaTlsCrypto crypto, TlsKemConfig kemConfig)
this.isServer = kemConfig.isServer();
}

public TlsKemConfig getTlsKemConfig()
{
return config;
}

public JceTlsSecret adoptLocalSecret(byte[] secret)
{
return crypto.adoptLocalSecret(secret);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package org.bouncycastle.tls.crypto.impl.jcajce;

import java.io.IOException;
import java.math.BigInteger;
import java.security.KeyPair;
import java.security.PublicKey;

import org.bouncycastle.crypto.AsymmetricCipherKeyPair;
import org.bouncycastle.crypto.SecretWithEncapsulation;
import org.bouncycastle.pqc.crypto.mlkem.MLKEMPrivateKeyParameters;
import org.bouncycastle.pqc.crypto.mlkem.MLKEMPublicKeyParameters;
import org.bouncycastle.tls.crypto.TlsAgreement;
import org.bouncycastle.tls.crypto.TlsSecret;
import org.bouncycastle.util.Arrays;

public class JceTlsX25519MLKem implements TlsAgreement
{
protected final JceTlsX25519MLKemDomain domain;

protected KeyPair x25519LocalKeyPair;
protected PublicKey x25519PeerPublicKey;
protected AsymmetricCipherKeyPair mlkemLocalKeyPair;
protected MLKEMPublicKeyParameters mlkemPeerPublicKey;

protected byte[] mlkemCiphertext;
protected byte[] mlkemSecret;

public JceTlsX25519MLKem(JceTlsX25519MLKemDomain domain)
{
this.domain = domain;
}

public byte[] generateEphemeral() throws IOException
{
this.x25519LocalKeyPair = domain.generateX25519KeyPair();
byte[] x25519Key = domain.encodeX25519PublicKey(x25519LocalKeyPair.getPublic());
byte[] mlkemKey;
if (domain.getKemDomain().getTlsKemConfig().isServer())
{
mlkemKey = Arrays.clone(mlkemCiphertext);
}
else
{
this.mlkemLocalKeyPair = domain.getKemDomain().generateKeyPair();
mlkemKey = domain.getKemDomain().encodePublicKey((MLKEMPublicKeyParameters)mlkemLocalKeyPair.getPublic());

}
return Arrays.concatenate(mlkemKey, x25519Key);
}

public void receivePeerValue(byte[] peerValue) throws IOException
{
byte[] xdhKey = Arrays.copyOfRange(peerValue, peerValue.length - domain.getX25519PublicKeyByteLength(), peerValue.length);
byte[] mlkemKey = Arrays.copyOf(peerValue,peerValue.length - domain.getX25519PublicKeyByteLength());
this.x25519PeerPublicKey = domain.decodeX25519PublicKey(xdhKey);
if (domain.getKemDomain().getTlsKemConfig().isServer())
{
this.mlkemPeerPublicKey = domain.getKemDomain().decodePublicKey(mlkemKey);
SecretWithEncapsulation encap = domain.getKemDomain().encapsulate(mlkemPeerPublicKey);
this.mlkemCiphertext = encap.getEncapsulation();
mlkemSecret = encap.getSecret();
}
else
{
this.mlkemCiphertext = Arrays.clone(mlkemKey);
}
}

public TlsSecret calculateSecret() throws IOException
{
byte[] x25519Secret = domain.calculateX25519AgreementToBytes(x25519LocalKeyPair.getPrivate(), x25519PeerPublicKey);
if (!domain.getKemDomain().getTlsKemConfig().isServer())
{
mlkemSecret = domain.getKemDomain().decapsulate((MLKEMPrivateKeyParameters)mlkemLocalKeyPair.getPrivate(), mlkemCiphertext).extract();
}
return domain.getKemDomain().adoptLocalSecret(Arrays.concatenate(mlkemSecret, x25519Secret));
}
}
Loading