Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ML-DSA support for TLSv1.3 (draft-tls-westerbaan-mldsa-00) #2020

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ private static Map<String, PublicKeyFilter> createFiltersClient()
addFilter(filters, DSAPublicKey.class, "DSA");
addFilter(filters, ECPublicKey.class, "EC");

addFilter((filters), "ML-DSA-44");
addFilter((filters), "ML-DSA-65");
addFilter((filters), "ML-DSA-87");

return Collections.unmodifiableMap(filters);
}

Expand Down Expand Up @@ -201,6 +205,10 @@ private static Map<String, PublicKeyFilter> createFiltersServer()
KeyExchangeAlgorithm.SRP_RSA);
addFilterLegacyServer(filters, ProvAlgorithmChecker.KU_KEY_ENCIPHERMENT, "RSA", KeyExchangeAlgorithm.RSA);

addFilter((filters), "ML-DSA-44");
addFilter((filters), "ML-DSA-65");
addFilter((filters), "ML-DSA-87");

return Collections.unmodifiableMap(filters);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ private enum All
sm2sig_sm3(SignatureScheme.sm2sig_sm3, "SM3withSM2", "EC"),

// TODO[tls] Need mechanism for restricting signature schemes to TLS 1.3+ before adding
// DRAFT_mldsa44(SignatureScheme.DRAFT_mldsa44, "ML-DSA-44", "ML-DSA-44"),
// DRAFT_mldsa65(SignatureScheme.DRAFT_mldsa65, "ML-DSA-65", "ML-DSA-65"),
// DRAFT_mldsa87(SignatureScheme.DRAFT_mldsa87, "ML-DSA-87", "ML-DSA-87"),
DRAFT_mldsa44(SignatureScheme.DRAFT_mldsa44, "ML-DSA-44", "ML-DSA-44"),
DRAFT_mldsa65(SignatureScheme.DRAFT_mldsa65, "ML-DSA-65", "ML-DSA-65"),
DRAFT_mldsa87(SignatureScheme.DRAFT_mldsa87, "ML-DSA-87", "ML-DSA-87"),

/*
* Legacy/Historical: mostly not supported in 1.3, except ecdsa_sha1 and rsa_pkcs1_sha1 are
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,11 +455,12 @@ public boolean hasSignatureScheme(int signatureScheme)
switch (signatureScheme)
{
case SignatureScheme.sm2sig_sm3:
return false;
// TODO[tls] Test coverage before adding
case SignatureScheme.DRAFT_mldsa44:
case SignatureScheme.DRAFT_mldsa65:
case SignatureScheme.DRAFT_mldsa87:
return false;
return true;
default:
{
short signature = SignatureScheme.getSignatureAlgorithm(signatureScheme);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ else if ("Ed448".equalsIgnoreCase(algorithm))
{
signer = new JcaTlsEd448Signer(crypto, privateKey);
}
else if ("ML-DSA-44".equalsIgnoreCase(algorithm)
|| "ML-DSA-65".equalsIgnoreCase(algorithm)
|| "ML-DSA-87".equalsIgnoreCase(algorithm))
{
signer = new JcaTlsMLDSASigner(crypto, privateKey);
}
else
{
throw new IllegalArgumentException("'privateKey' type not supported: " + privateKey.getClass().getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ public Tls13Verifier createVerifier(int signatureScheme) throws IOException
case SignatureScheme.DRAFT_mldsa44:
case SignatureScheme.DRAFT_mldsa65:
case SignatureScheme.DRAFT_mldsa87:
return crypto.createTls13Verifier("ML-DSA", null, getPubKeyMLDSA());

default:
throw new TlsFatalAlert(AlertDescription.internal_error);
Expand Down Expand Up @@ -396,6 +397,11 @@ PublicKey getPubKeyRSA() throws IOException
return getPublicKey();
}

PublicKey getPubKeyMLDSA() throws IOException
{
return getPublicKey();
}

public short getLegacySignatureAlgorithm() throws IOException
{
PublicKey publicKey = getPublicKey();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -778,11 +778,12 @@ public boolean hasSignatureScheme(int signatureScheme)
switch (signatureScheme)
{
case SignatureScheme.sm2sig_sm3:
return false;
// TODO[tls] Implement before adding
case SignatureScheme.DRAFT_mldsa44:
case SignatureScheme.DRAFT_mldsa65:
case SignatureScheme.DRAFT_mldsa87:
return false;
return true;
default:
{
short signature = SignatureScheme.getSignatureAlgorithm(signatureScheme);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package org.bouncycastle.tls.crypto.impl.jcajce;

import java.io.IOException;
import java.security.PrivateKey;

import org.bouncycastle.tls.SignatureAndHashAlgorithm;
import org.bouncycastle.tls.crypto.TlsSigner;
import org.bouncycastle.tls.crypto.TlsStreamSigner;

public class JcaTlsMLDSASigner
implements TlsSigner
{
protected final JcaTlsCrypto crypto;
protected final PrivateKey privateKey;

public JcaTlsMLDSASigner(JcaTlsCrypto crypto, PrivateKey privateKey)
{
if (null == crypto)
{
throw new NullPointerException("crypto");
}
if (null == privateKey)
{
throw new NullPointerException("privateKey");
}

this.crypto = crypto;
this.privateKey = privateKey;
}

public byte[] generateRawSignature(SignatureAndHashAlgorithm algorithm, byte[] hash) throws IOException
{
throw new UnsupportedOperationException();
}

public TlsStreamSigner getStreamSigner(SignatureAndHashAlgorithm algorithm) throws IOException
{
if (algorithm == null)
{
throw new IllegalStateException("Invalid algorithm: " + algorithm);
}

return crypto.createStreamSigner("ML-DSA", null, privateKey, false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;

import org.bouncycastle.test.PrintTestResult;

public class AllTests
Expand All @@ -25,6 +26,7 @@ public static Test suite()
suite.addTestSuite(ConfigTest.class);
suite.addTestSuite(ECDSACredentialsTest.class);
suite.addTestSuite(EdDSACredentialsTest.class);
suite.addTestSuite(MLDSACredentialsTest.class);
suite.addTestSuite(InstanceTest.class);
suite.addTestSuite(KeyManagerFactoryTest.class);
suite.addTestSuite(PSSCredentialsTest.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
package org.bouncycastle.jsse.provider.test;

import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.KeyPair;
import java.security.KeyStore;
import java.security.SecureRandom;
import java.security.cert.X509Certificate;
import java.util.concurrent.CountDownLatch;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManagerFactory;

import junit.framework.TestCase;

public class MLDSACredentialsTest
extends TestCase
{
protected void setUp()
{
ProviderUtils.setupLowPriority(false);
}

private static final String HOST = "localhost";
private static final int PORT_NO_13_MLDSA44 = 9050;
private static final int PORT_NO_13_MLDSA65 = 9051;
private static final int PORT_NO_13_MLDSA87 = 9052;

static class MLDSAClient
implements TestProtocolUtil.BlockingCallable
{
private final int port;
private final String protocol;
private final KeyStore trustStore;
private final KeyStore clientStore;
private final char[] clientKeyPass;
private final CountDownLatch latch;

MLDSAClient(int port, String protocol, KeyStore clientStore, char[] clientKeyPass,
X509Certificate trustAnchor) throws GeneralSecurityException, IOException
{
KeyStore trustStore = createKeyStore();
trustStore.setCertificateEntry("server", trustAnchor);

this.port = port;
this.protocol = protocol;
this.trustStore = trustStore;
this.clientStore = clientStore;
this.clientKeyPass = clientKeyPass;
this.latch = new CountDownLatch(1);
}

public Exception call() throws Exception
{
try
{
TrustManagerFactory trustMgrFact = TrustManagerFactory.getInstance("PKIX",
ProviderUtils.PROVIDER_NAME_BCJSSE);
trustMgrFact.init(trustStore);

KeyManagerFactory keyMgrFact = KeyManagerFactory.getInstance("PKIX",
ProviderUtils.PROVIDER_NAME_BCJSSE);
keyMgrFact.init(clientStore, clientKeyPass);

SSLContext clientContext = SSLContext.getInstance("TLS", ProviderUtils.PROVIDER_NAME_BCJSSE);
clientContext.init(keyMgrFact.getKeyManagers(), trustMgrFact.getTrustManagers(),
SecureRandom.getInstance("DEFAULT", ProviderUtils.PROVIDER_NAME_BC));

SSLSocketFactory fact = clientContext.getSocketFactory();
SSLSocket cSock = (SSLSocket)fact.createSocket(HOST, port);
cSock.setEnabledProtocols(new String[]{ protocol });

SSLSession session = cSock.getSession();
assertNotNull(session);
assertFalse("SSL_NULL_WITH_NULL_NULL".equals(session.getCipherSuite()));
assertEquals("CN=Test CA Certificate", session.getLocalPrincipal().getName());
assertEquals("CN=Test CA Certificate", session.getPeerPrincipal().getName());

TestProtocolUtil.doClientProtocol(cSock, "Hello");
}
finally
{
latch.countDown();
}

return null;
}

public void await()
throws InterruptedException
{
latch.await();
}
}

static class MLDSAServer
implements TestProtocolUtil.BlockingCallable
{
private final int port;
private final String protocol;
private final KeyStore serverStore;
private final char[] keyPass;
private final KeyStore trustStore;
private final CountDownLatch latch;

MLDSAServer(int port, String protocol, KeyStore serverStore, char[] keyPass, X509Certificate trustAnchor)
throws GeneralSecurityException, IOException
{
KeyStore trustStore = createKeyStore();
trustStore.setCertificateEntry("client", trustAnchor);

this.port = port;
this.protocol = protocol;
this.serverStore = serverStore;
this.keyPass = keyPass;
this.trustStore = trustStore;
this.latch = new CountDownLatch(1);
}

public Exception call() throws Exception
{
try
{
KeyManagerFactory keyMgrFact = KeyManagerFactory.getInstance("PKIX",
ProviderUtils.PROVIDER_NAME_BCJSSE);
keyMgrFact.init(serverStore, keyPass);

TrustManagerFactory trustMgrFact = TrustManagerFactory.getInstance("PKIX",
ProviderUtils.PROVIDER_NAME_BCJSSE);
trustMgrFact.init(trustStore);

SSLContext serverContext = SSLContext.getInstance("TLS", ProviderUtils.PROVIDER_NAME_BCJSSE);
serverContext.init(keyMgrFact.getKeyManagers(), trustMgrFact.getTrustManagers(),
SecureRandom.getInstance("DEFAULT", ProviderUtils.PROVIDER_NAME_BC));

SSLServerSocketFactory fact = serverContext.getServerSocketFactory();
SSLServerSocket sSock = (SSLServerSocket)fact.createServerSocket(port);

SSLUtils.enableAll(sSock);
sSock.setNeedClientAuth(true);

latch.countDown();

SSLSocket sslSock = (SSLSocket)sSock.accept();
sslSock.setEnabledProtocols(new String[]{ protocol });

SSLSession session = sslSock.getSession();
assertNotNull(session);
assertFalse("SSL_NULL_WITH_NULL_NULL".equals(session.getCipherSuite()));
assertEquals("CN=Test CA Certificate", session.getLocalPrincipal().getName());
assertEquals("CN=Test CA Certificate", session.getPeerPrincipal().getName());

TestProtocolUtil.doServerProtocol(sslSock, "World");

sslSock.close();
sSock.close();
}
finally
{
latch.countDown();
}

return null;
}

public void await() throws InterruptedException
{
latch.await();
}
}

public void test13_MLDSA44() throws Exception
{
implTestMLDSACredentials(PORT_NO_13_MLDSA44, "TLSv1.3", TestUtils.generateMLDSAKeyPair("ML-DSA-44"));
}

public void test13_MLDSA65() throws Exception
{
implTestMLDSACredentials(PORT_NO_13_MLDSA65, "TLSv1.3", TestUtils.generateMLDSAKeyPair("ML-DSA-65"));
}

public void test13_MLDSA87() throws Exception
{
implTestMLDSACredentials(PORT_NO_13_MLDSA87, "TLSv1.3", TestUtils.generateMLDSAKeyPair("ML-DSA-87"));
}

private void implTestMLDSACredentials(int port, String protocol, KeyPair caKeyPair) throws Exception
{
char[] keyPass = "keyPassword".toCharArray();

X509Certificate caCert = TestUtils.generateRootCert(caKeyPair);

KeyStore serverKs = createKeyStore();
serverKs.setKeyEntry("server", caKeyPair.getPrivate(), keyPass, new X509Certificate[]{ caCert });

KeyStore clientKs = createKeyStore();
clientKs.setKeyEntry("client", caKeyPair.getPrivate(), keyPass, new X509Certificate[]{ caCert });

TestProtocolUtil.runClientAndServer(new MLDSAServer(port, protocol, serverKs, keyPass, caCert),
new MLDSAClient(port, protocol, clientKs, keyPass, caCert));
}

private static KeyStore createKeyStore() throws GeneralSecurityException, IOException
{
/*
* NOTE: At the time of writing, default JKS implementation can't recover PKCS8 private keys
* with version != 0, which e.g. is the case when a public key is included, which the BC
* provider currently does for MLDSA.
*/
// KeyStore keyStore = KeyStore.getInstance("JKS");
KeyStore keyStore = KeyStore.getInstance("PKCS12", "BC");
keyStore.load(null, null);
return keyStore;
}
}
Loading