From 7f531e6b7523f8245801436643ddfda30caf68d6 Mon Sep 17 00:00:00 2001 From: Linuka Ratnayake <79963204+linukaratnayake@users.noreply.github.com> Date: Mon, 3 Feb 2025 22:14:45 +0530 Subject: [PATCH] Add support for X25519MLKEM768 hybrid algorithm --- .../jsse/provider/NamedGroupInfo.java | 4 +- .../java/org/bouncycastle/tls/NamedGroup.java | 8 ++ .../tls/crypto/impl/bc/BcTlsCrypto.java | 8 +- .../tls/crypto/impl/bc/BcTlsMLKemDomain.java | 6 ++ .../tls/crypto/impl/bc/BcTlsX25519MLKem.java | 73 +++++++++++++++ .../impl/bc/BcTlsX25519MLKemDomain.java | 60 +++++++++++++ .../tls/crypto/impl/jcajce/JcaTlsCrypto.java | 9 +- .../crypto/impl/jcajce/JceTlsMLKemDomain.java | 6 ++ .../crypto/impl/jcajce/JceTlsX25519MLKem.java | 78 ++++++++++++++++ .../impl/jcajce/JceTlsX25519MLKemDomain.java | 88 +++++++++++++++++++ 10 files changed, 337 insertions(+), 3 deletions(-) create mode 100644 tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsX25519MLKem.java create mode 100644 tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsX25519MLKemDomain.java create mode 100644 tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JceTlsX25519MLKem.java create mode 100644 tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JceTlsX25519MLKemDomain.java diff --git a/tls/src/main/java/org/bouncycastle/jsse/provider/NamedGroupInfo.java b/tls/src/main/java/org/bouncycastle/jsse/provider/NamedGroupInfo.java index f2af7facd9..c382ff42c6 100644 --- a/tls/src/main/java/org/bouncycastle/jsse/provider/NamedGroupInfo.java +++ b/tls/src/main/java/org/bouncycastle/jsse/provider/NamedGroupInfo.java @@ -81,7 +81,9 @@ private enum All OQS_mlkem1024(NamedGroup.OQS_mlkem1024, "ML-KEM"), MLKEM512(NamedGroup.MLKEM512, "ML-KEM"), MLKEM768(NamedGroup.MLKEM768, "ML-KEM"), - MLKEM1024(NamedGroup.MLKEM1024, "ML-KEM"); + MLKEM1024(NamedGroup.MLKEM1024, "ML-KEM"), + + X25519MLKEM768(NamedGroup.X25519MLKEM768, "ML-KEM"); private final int namedGroup; private final String name; diff --git a/tls/src/main/java/org/bouncycastle/tls/NamedGroup.java b/tls/src/main/java/org/bouncycastle/tls/NamedGroup.java index b8fc46d1c5..ba32a080ff 100644 --- a/tls/src/main/java/org/bouncycastle/tls/NamedGroup.java +++ b/tls/src/main/java/org/bouncycastle/tls/NamedGroup.java @@ -116,6 +116,11 @@ public class NamedGroup public static final int MLKEM768 = 0x0768; public static final int MLKEM1024 = 0x1024; + /* + * draft-kwiatkowski-tls-ecdhe-mlkem-03 + */ + public static final int X25519MLKEM768 = 0x11EC; + /* Names of the actual underlying elliptic curves (not necessarily matching the NamedGroup names). */ private static final String[] CURVE_NAMES = new String[]{ "sect163k1", "sect163r1", "sect163r2", "sect193r1", "sect193r2", "sect233k1", "sect233r1", "sect239k1", "sect283k1", "sect283r1", "sect409k1", "sect409r1", @@ -310,6 +315,8 @@ public static String getKemName(int namedGroup) case OQS_mlkem1024: case MLKEM1024: return "ML-KEM-1024"; + case X25519MLKEM768: + return "X25519MLKEM768"; default: return null; } @@ -502,6 +509,7 @@ public static boolean refersToASpecificKem(int namedGroup) case MLKEM512: case MLKEM768: case MLKEM1024: + case X25519MLKEM768: return true; default: return false; diff --git a/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsCrypto.java b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsCrypto.java index 9207c2f544..813c04ef64 100644 --- a/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsCrypto.java +++ b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsCrypto.java @@ -219,7 +219,13 @@ public TlsECDomain createECDomain(TlsECConfig ecConfig) public TlsKemDomain createKemDomain(TlsKemConfig kemConfig) { - return new BcTlsMLKemDomain(this, kemConfig); + switch (kemConfig.getNamedGroup()) + { + case NamedGroup.X25519MLKEM768: + return new BcTlsX25519MLKemDomain(this, kemConfig); + default: + return new BcTlsMLKemDomain(this, kemConfig); + } } public TlsNonceGenerator createNonceGenerator(byte[] additionalSeedMaterial) diff --git a/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsMLKemDomain.java b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsMLKemDomain.java index 4b2fc43294..3f69cb9a8d 100644 --- a/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsMLKemDomain.java +++ b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsMLKemDomain.java @@ -25,6 +25,7 @@ public static MLKEMParameters getDomainParameters(TlsKemConfig kemConfig) return MLKEMParameters.ml_kem_512; case NamedGroup.OQS_mlkem768: case NamedGroup.MLKEM768: + case NamedGroup.X25519MLKEM768: return MLKEMParameters.ml_kem_768; case NamedGroup.OQS_mlkem1024: case NamedGroup.MLKEM1024: @@ -47,6 +48,11 @@ public BcTlsMLKemDomain(BcTlsCrypto crypto, TlsKemConfig kemConfig) this.isServer = kemConfig.isServer(); } + public TlsKemConfig getTlsKemConfig() + { + return this.config; + } + public BcTlsSecret adoptLocalSecret(byte[] secret) { return crypto.adoptLocalSecret(secret); diff --git a/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsX25519MLKem.java b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsX25519MLKem.java new file mode 100644 index 0000000000..2aa212ab99 --- /dev/null +++ b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsX25519MLKem.java @@ -0,0 +1,73 @@ +package org.bouncycastle.tls.crypto.impl.bc; + +import java.io.IOException; + +import org.bouncycastle.crypto.AsymmetricCipherKeyPair; +import org.bouncycastle.crypto.SecretWithEncapsulation; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMPrivateKeyParameters; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMPublicKeyParameters; +import org.bouncycastle.tls.crypto.TlsAgreement; +import org.bouncycastle.tls.crypto.TlsSecret; +import org.bouncycastle.util.Arrays; + +public class BcTlsX25519MLKem implements TlsAgreement +{ + protected final BcTlsX25519MLKemDomain domain; + + protected AsymmetricCipherKeyPair mlkemLocalKeyPair; + protected MLKEMPublicKeyParameters mlkemPeerPublicKey; + protected byte[] x25519PrivateKey; + protected byte[] x25519PeerPublicKey; + + protected byte[] mlkemCiphertext; + protected byte[] mlkemSecret; + + public BcTlsX25519MLKem(BcTlsX25519MLKemDomain domain) + { + this.domain = domain; + } + + public byte[] generateEphemeral() throws IOException + { + this.x25519PrivateKey = domain.generateX25519PrivateKey(); + byte[] x25519Key = domain.getX25519PublicKey(x25519PrivateKey); + byte[] mlkemKey; + if (domain.getKemDomain().getTlsKemConfig().isServer()) + { + mlkemKey = Arrays.clone(mlkemCiphertext); + } + else + { + this.mlkemLocalKeyPair = domain.getKemDomain().generateKeyPair(); + mlkemKey = domain.getKemDomain().encodePublicKey((MLKEMPublicKeyParameters)mlkemLocalKeyPair.getPublic()); + } + return Arrays.concatenate(mlkemKey, x25519Key); + } + + public void receivePeerValue(byte[] peerValue) throws IOException + { + this.x25519PeerPublicKey = Arrays.copyOfRange(peerValue, peerValue.length - domain.getX25519PublicKeyByteLength(), peerValue.length); + byte[] mlkemKey = Arrays.copyOf(peerValue, peerValue.length - domain.getX25519PublicKeyByteLength()); + if (domain.getKemDomain().getTlsKemConfig().isServer()) + { + this.mlkemPeerPublicKey = domain.getKemDomain().decodePublicKey(mlkemKey); + SecretWithEncapsulation encap = domain.getKemDomain().encapsulate(mlkemPeerPublicKey); + mlkemCiphertext = encap.getEncapsulation(); + mlkemSecret = encap.getSecret(); + } + else + { + this.mlkemCiphertext = Arrays.clone(mlkemKey); + } + } + + public TlsSecret calculateSecret() throws IOException + { + byte[] x25519Secret = domain.calculateX25519Secret(x25519PrivateKey, x25519PeerPublicKey); + if (!domain.getKemDomain().getTlsKemConfig().isServer()) + { + mlkemSecret = domain.getKemDomain().decapsulate((MLKEMPrivateKeyParameters)mlkemLocalKeyPair.getPrivate(), mlkemCiphertext).extract(); + } + return domain.getKemDomain().adoptLocalSecret(Arrays.concatenate(mlkemSecret, x25519Secret)); + } +} diff --git a/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsX25519MLKemDomain.java b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsX25519MLKemDomain.java new file mode 100644 index 0000000000..660d2bd197 --- /dev/null +++ b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/bc/BcTlsX25519MLKemDomain.java @@ -0,0 +1,60 @@ +package org.bouncycastle.tls.crypto.impl.bc; + +import java.io.IOException; +import org.bouncycastle.math.ec.rfc7748.X25519; +import org.bouncycastle.tls.AlertDescription; +import org.bouncycastle.tls.TlsFatalAlert; +import org.bouncycastle.tls.crypto.TlsAgreement; +import org.bouncycastle.tls.crypto.TlsKemConfig; +import org.bouncycastle.tls.crypto.TlsKemDomain; + +public class BcTlsX25519MLKemDomain implements TlsKemDomain +{ + protected final BcTlsMLKemDomain kemDomain; + protected final BcTlsCrypto crypto; + + public BcTlsX25519MLKemDomain(BcTlsCrypto crypto, TlsKemConfig kemConfig) + { + this.kemDomain = new BcTlsMLKemDomain(crypto, kemConfig); + this.crypto = crypto; + } + + public TlsAgreement createKem() + { + return new BcTlsX25519MLKem(this); + } + + public BcTlsMLKemDomain getKemDomain() + { + return kemDomain; + } + + public byte[] generateX25519PrivateKey() throws IOException + { + byte[] privateKey = new byte[X25519.SCALAR_SIZE]; + crypto.getSecureRandom().nextBytes(privateKey); + return privateKey; + } + + public byte[] getX25519PublicKey(byte[] privateKey) throws IOException + { + byte[] publicKey = new byte[X25519.POINT_SIZE]; + X25519.scalarMultBase(privateKey, 0, publicKey, 0); + return publicKey; + } + + public int getX25519PublicKeyByteLength() throws IOException + { + return X25519.POINT_SIZE; + } + + public byte[] calculateX25519Secret(byte[] privateKey, byte[] peerPublicKey) throws IOException + { + byte[] secret = new byte[X25519.POINT_SIZE]; + if (!X25519.calculateAgreement(privateKey, 0, peerPublicKey, 0, secret, 0)) + { + throw new TlsFatalAlert(AlertDescription.handshake_failure); + } + return secret; + } +} diff --git a/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JcaTlsCrypto.java b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JcaTlsCrypto.java index d89b0257ec..eb16aa9b77 100644 --- a/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JcaTlsCrypto.java +++ b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JcaTlsCrypto.java @@ -458,6 +458,7 @@ else if (NamedGroup.refersToASpecificKem(namedGroup)) case NamedGroup.MLKEM512: case NamedGroup.MLKEM768: case NamedGroup.MLKEM1024: + case NamedGroup.X25519MLKEM768: return null; } } @@ -858,7 +859,13 @@ public TlsECDomain createECDomain(TlsECConfig ecConfig) public TlsKemDomain createKemDomain(TlsKemConfig kemConfig) { - return new JceTlsMLKemDomain(this, kemConfig); + switch (kemConfig.getNamedGroup()) + { + case NamedGroup.X25519MLKEM768: + return new JceTlsX25519MLKemDomain(this, kemConfig); + default: + return new JceTlsMLKemDomain(this, kemConfig); + } } public TlsSecret hkdfInit(int cryptoHashAlgorithm) diff --git a/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JceTlsMLKemDomain.java b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JceTlsMLKemDomain.java index 5aaf97b7db..b9f22b3848 100644 --- a/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JceTlsMLKemDomain.java +++ b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JceTlsMLKemDomain.java @@ -25,6 +25,7 @@ public static MLKEMParameters getDomainParameters(TlsKemConfig kemConfig) return MLKEMParameters.ml_kem_512; case NamedGroup.OQS_mlkem768: case NamedGroup.MLKEM768: + case NamedGroup.X25519MLKEM768: return MLKEMParameters.ml_kem_768; case NamedGroup.OQS_mlkem1024: case NamedGroup.MLKEM1024: @@ -47,6 +48,11 @@ public JceTlsMLKemDomain(JcaTlsCrypto crypto, TlsKemConfig kemConfig) this.isServer = kemConfig.isServer(); } + public TlsKemConfig getTlsKemConfig() + { + return config; + } + public JceTlsSecret adoptLocalSecret(byte[] secret) { return crypto.adoptLocalSecret(secret); diff --git a/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JceTlsX25519MLKem.java b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JceTlsX25519MLKem.java new file mode 100644 index 0000000000..34075c124f --- /dev/null +++ b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JceTlsX25519MLKem.java @@ -0,0 +1,78 @@ +package org.bouncycastle.tls.crypto.impl.jcajce; + +import java.io.IOException; +import java.math.BigInteger; +import java.security.KeyPair; +import java.security.PublicKey; + +import org.bouncycastle.crypto.AsymmetricCipherKeyPair; +import org.bouncycastle.crypto.SecretWithEncapsulation; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMPrivateKeyParameters; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMPublicKeyParameters; +import org.bouncycastle.tls.crypto.TlsAgreement; +import org.bouncycastle.tls.crypto.TlsSecret; +import org.bouncycastle.util.Arrays; + +public class JceTlsX25519MLKem implements TlsAgreement +{ + protected final JceTlsX25519MLKemDomain domain; + + protected KeyPair x25519LocalKeyPair; + protected PublicKey x25519PeerPublicKey; + protected AsymmetricCipherKeyPair mlkemLocalKeyPair; + protected MLKEMPublicKeyParameters mlkemPeerPublicKey; + + protected byte[] mlkemCiphertext; + protected byte[] mlkemSecret; + + public JceTlsX25519MLKem(JceTlsX25519MLKemDomain domain) + { + this.domain = domain; + } + + public byte[] generateEphemeral() throws IOException + { + this.x25519LocalKeyPair = domain.generateX25519KeyPair(); + byte[] x25519Key = domain.encodeX25519PublicKey(x25519LocalKeyPair.getPublic()); + byte[] mlkemKey; + if (domain.getKemDomain().getTlsKemConfig().isServer()) + { + mlkemKey = Arrays.clone(mlkemCiphertext); + } + else + { + this.mlkemLocalKeyPair = domain.getKemDomain().generateKeyPair(); + mlkemKey = domain.getKemDomain().encodePublicKey((MLKEMPublicKeyParameters)mlkemLocalKeyPair.getPublic()); + + } + return Arrays.concatenate(mlkemKey, x25519Key); + } + + public void receivePeerValue(byte[] peerValue) throws IOException + { + byte[] xdhKey = Arrays.copyOfRange(peerValue, peerValue.length - domain.getX25519PublicKeyByteLength(), peerValue.length); + byte[] mlkemKey = Arrays.copyOf(peerValue,peerValue.length - domain.getX25519PublicKeyByteLength()); + this.x25519PeerPublicKey = domain.decodeX25519PublicKey(xdhKey); + if (domain.getKemDomain().getTlsKemConfig().isServer()) + { + this.mlkemPeerPublicKey = domain.getKemDomain().decodePublicKey(mlkemKey); + SecretWithEncapsulation encap = domain.getKemDomain().encapsulate(mlkemPeerPublicKey); + this.mlkemCiphertext = encap.getEncapsulation(); + mlkemSecret = encap.getSecret(); + } + else + { + this.mlkemCiphertext = Arrays.clone(mlkemKey); + } + } + + public TlsSecret calculateSecret() throws IOException + { + byte[] x25519Secret = domain.calculateX25519AgreementToBytes(x25519LocalKeyPair.getPrivate(), x25519PeerPublicKey); + if (!domain.getKemDomain().getTlsKemConfig().isServer()) + { + mlkemSecret = domain.getKemDomain().decapsulate((MLKEMPrivateKeyParameters)mlkemLocalKeyPair.getPrivate(), mlkemCiphertext).extract(); + } + return domain.getKemDomain().adoptLocalSecret(Arrays.concatenate(mlkemSecret, x25519Secret)); + } +} diff --git a/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JceTlsX25519MLKemDomain.java b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JceTlsX25519MLKemDomain.java new file mode 100644 index 0000000000..ac05e934c6 --- /dev/null +++ b/tls/src/main/java/org/bouncycastle/tls/crypto/impl/jcajce/JceTlsX25519MLKemDomain.java @@ -0,0 +1,88 @@ +package org.bouncycastle.tls.crypto.impl.jcajce; + +import java.io.IOException; +import java.security.GeneralSecurityException; +import java.security.KeyPair; +import java.security.PrivateKey; +import java.security.PublicKey; + +import org.bouncycastle.math.ec.rfc7748.X25519; +import org.bouncycastle.tls.AlertDescription; +import org.bouncycastle.tls.TlsFatalAlert; +import org.bouncycastle.tls.crypto.TlsAgreement; +import org.bouncycastle.tls.crypto.TlsCryptoException; +import org.bouncycastle.tls.crypto.TlsKemConfig; +import org.bouncycastle.tls.crypto.TlsKemDomain; +import org.bouncycastle.util.Arrays; + +public class JceTlsX25519MLKemDomain implements TlsKemDomain +{ + protected final JceTlsMLKemDomain mlkemDomain; + protected final JceX25519Domain x25519Domain; + protected final JcaTlsCrypto crypto; + + public JceTlsX25519MLKemDomain(JcaTlsCrypto crypto, TlsKemConfig pqcConfig) + { + this.mlkemDomain = new JceTlsMLKemDomain(crypto, pqcConfig); + this.crypto = crypto; + this.x25519Domain = new JceX25519Domain(crypto); + } + + public TlsAgreement createKem() + { + return new JceTlsX25519MLKem(this); + } + + public JceTlsMLKemDomain getKemDomain() + { + return mlkemDomain; + } + + public KeyPair generateX25519KeyPair() + { + try + { + return x25519Domain.generateKeyPair(); + } + catch (Exception e) + { + throw Exceptions.illegalStateException("Unable to create key pair: " + e.getMessage(), e); + } + } + + public byte[] encodeX25519PublicKey(PublicKey publicKey) throws IOException + { + return XDHUtil.encodePublicKey(publicKey); + } + + public int getX25519PublicKeyByteLength() throws IOException + { + return X25519.POINT_SIZE; + } + + public PublicKey decodeX25519PublicKey(byte[] x25519Key) throws IOException + { + return x25519Domain.decodePublicKey(x25519Key); + } + + public byte[] calculateX25519AgreementToBytes(PrivateKey privateKey, PublicKey publicKey) throws IOException + { + try + { + byte[] secret = crypto.calculateKeyAgreement("X25519", privateKey, publicKey, "TlsPremasterSecret"); + if (secret == null || secret.length != 32) + { + throw new TlsCryptoException("Invalid secret calculated"); + } + if (Arrays.areAllZeroes(secret, 0, secret.length)) + { + throw new TlsFatalAlert(AlertDescription.handshake_failure); + } + return secret; + } + catch (GeneralSecurityException e) + { + throw new TlsCryptoException("Cannot calculate secret", e); + } + } +}