Skip to content

NimbusJwtEncoder should simplify constructing with javax.security Keys #17033

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@

import java.net.URI;
import java.net.URL;
import java.security.KeyPair;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Date;
Expand All @@ -26,18 +31,27 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;

import javax.crypto.SecretKey;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JOSEObjectType;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSSigner;
import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory;
import com.nimbusds.jose.jwk.Curve;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKMatcher;
import com.nimbusds.jose.jwk.JWKSelector;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyType;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.produce.JWSSignerFactory;
Expand All @@ -47,10 +61,13 @@
import com.nimbusds.jwt.SignedJWT;

import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.util.function.ThrowingConsumer;

/**
* An implementation of a {@link JwtEncoder} that encodes a JSON Web Token (JWT) using the
Expand Down Expand Up @@ -83,6 +100,8 @@ public final class NimbusJwtEncoder implements JwtEncoder {

private static final JWSSignerFactory JWS_SIGNER_FACTORY = new DefaultJWSSignerFactory();

private final JwsHeader jwsHeader;

private final Map<JWK, JWSSigner> jwsSigners = new ConcurrentHashMap<>();

private final JWKSource<SecurityContext> jwkSource;
Expand All @@ -100,10 +119,22 @@ public final class NimbusJwtEncoder implements JwtEncoder {
* @param jwkSource the {@code com.nimbusds.jose.jwk.source.JWKSource}
*/
public NimbusJwtEncoder(JWKSource<SecurityContext> jwkSource) {
this.jwsHeader = DEFAULT_JWS_HEADER;
Assert.notNull(jwkSource, "jwkSource cannot be null");
this.jwkSource = jwkSource;
}

private NimbusJwtEncoder(JWK jwk) {
Assert.notNull(jwk, "jwk cannot be null");
this.jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
JwsAlgorithm algorithm = SignatureAlgorithm.from(jwk.getAlgorithm().getName());
if (algorithm == null) {
algorithm = MacAlgorithm.from(jwk.getAlgorithm().getName());
}
Assert.notNull(algorithm, "Failed to derive supported algorithm from " + jwk.getAlgorithm());
this.jwsHeader = JwsHeader.with(algorithm).type(jwk.getKeyType().getValue()).keyId(jwk.getKeyID()).build();
}

/**
* Use this strategy to reduce the list of matching JWKs when there is more than one.
* <p>
Expand All @@ -125,8 +156,9 @@ public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException {

JwsHeader headers = parameters.getJwsHeader();
if (headers == null) {
headers = DEFAULT_JWS_HEADER;
headers = this.jwsHeader;
}

JwtClaimsSet claims = parameters.getClaims();

JWK jwk = selectJwk(headers);
Expand Down Expand Up @@ -369,4 +401,212 @@ private static URI convertAsURI(String header, URL url) {
}
}

/**
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
* @param publicKey the {@link RSAPublicKey} and @Param privateKey the
* {@link RSAPrivateKey} to use for signing JWTs
* @return a {@link RsaKeyPairJwtEncoderBuilder}
* @since 7.0
*/
public static RsaKeyPairJwtEncoderBuilder withRsaKeyPair(RSAPublicKey publicKey, RSAPrivateKey privateKey) {
return new RsaKeyPairJwtEncoderBuilder(publicKey, privateKey);
}

/**
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
* @param publicKey the {@link ECPublicKey} and @param privateKey the
* {@link ECPrivateKey} to use for signing JWTs
* @return a {@link EcKeyPairJwtEncoderBuilder}
*/
public static EcKeyPairJwtEncoderBuilder withEcKeyPair(ECPublicKey publicKey, ECPrivateKey privateKey) {
return new EcKeyPairJwtEncoderBuilder(publicKey, privateKey);
}

/**
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
* @param secretKey
* @return
*/
public static SecretKeyJwtEncoderBuilder withSecretKey(SecretKey secretKey) {
return new SecretKeyJwtEncoderBuilder(secretKey);
}

/**
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
* {@link SecretKey}.
*
* @since 7.0
*/
public static final class SecretKeyJwtEncoderBuilder {

private static final ThrowingConsumer<OctetSequenceKey.Builder> defaultKid = OctetSequenceKey.Builder::keyIDFromThumbprint;

private final OctetSequenceKey.Builder builder;

private SecretKeyJwtEncoderBuilder(SecretKey secretKey) {
Assert.notNull(secretKey, "secretKey cannot be null");
OctetSequenceKey.Builder builder = new OctetSequenceKey.Builder(secretKey).keyUse(KeyUse.SIGNATURE)
.algorithm(JWSAlgorithm.HS256);
defaultKid.accept(builder, IllegalArgumentException::new);
this.builder = builder;
}

/**
* Sets the JWS algorithm to use for signing. Defaults to
* {@link JWSAlgorithm#HS256}. Must be an HMAC-based algorithm (HS256, HS384, or
* HS512).
* @param macAlgorithm the {@link MacAlgorithm} to use
* @return this builder instance for method chaining
*/
public SecretKeyJwtEncoderBuilder algorithm(MacAlgorithm macAlgorithm) {
Assert.notNull(macAlgorithm, "macAlgorithm cannot be null");
this.builder.algorithm(JWSAlgorithm.parse(macAlgorithm.getName()));
return this;
}

/**
* Post-process the {@link JWK} using the given {@link Consumer}. For example, you
* may use this to override the default {@code kid}
* @param jwkPostProcessor the post-processor to use
* @return this builder instance for method chaining
*/
public SecretKeyJwtEncoderBuilder jwkPostProcessor(Consumer<OctetSequenceKey.Builder> jwkPostProcessor) {
Assert.notNull(jwkPostProcessor, "jwkPostProcessor cannot be null");
jwkPostProcessor.accept(this.builder);
return this;
}

/**
* Builds the {@link NimbusJwtEncoder} instance.
* @return the configured {@link NimbusJwtEncoder}
* @throws IllegalStateException if the configured JWS algorithm is not compatible
* with a {@link SecretKey}.
*/
public NimbusJwtEncoder build() {
return new NimbusJwtEncoder(this.builder.build());
}

}

/**
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
* {@link KeyPair}.
*
* @since 7.0
*/

public static final class RsaKeyPairJwtEncoderBuilder {

private static final ThrowingConsumer<RSAKey.Builder> defaultKid = RSAKey.Builder::keyIDFromThumbprint;

private final RSAKey.Builder builder;

private RsaKeyPairJwtEncoderBuilder(RSAPublicKey publicKey, RSAPrivateKey privateKey) {
Assert.notNull(publicKey, "publicKey cannot be null");
Assert.notNull(privateKey, "privateKey cannot be null");
RSAKey.Builder builder = new RSAKey.Builder(publicKey).privateKey(privateKey)
.keyUse(KeyUse.SIGNATURE)
.algorithm(JWSAlgorithm.RS256);
defaultKid.accept(builder, IllegalArgumentException::new);
this.builder = builder;
}

/**
* Sets the JWS algorithm to use for signing. Defaults to
* {@link SignatureAlgorithm#RS256}. Must be an RSA-based algorithm
* @param signatureAlgorithm the {@link SignatureAlgorithm} to use
* @return this builder instance for method chaining
*/
public RsaKeyPairJwtEncoderBuilder algorithm(SignatureAlgorithm signatureAlgorithm) {
Assert.notNull(signatureAlgorithm, "signatureAlgorithm cannot be null");
this.builder.algorithm(JWSAlgorithm.parse(signatureAlgorithm.getName()));
return this;
}

/**
* Add commentMore actions Post-process the {@link JWK} using the given
* {@link Consumer}. For example, you may use this to override the default
* {@code kid}
* @param jwkPostProcessor the post-processor to use
* @return this builder instance for method chaining
*/
public RsaKeyPairJwtEncoderBuilder jwkPostProcessor(Consumer<RSAKey.Builder> jwkPostProcessor) {
Assert.notNull(jwkPostProcessor, "jwkPostProcessor cannot be null");
jwkPostProcessor.accept(this.builder);
return this;
}

/**
* Builds the {@link NimbusJwtEncoder} instance.
* @return the configured {@link NimbusJwtEncoder}
*/
public NimbusJwtEncoder build() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please move this down to RsaKeyJwtEncoderBuilder and EcKeyJwtEncoderBuilder. Each should be able to use the constructor defaults to have the following implementation for build():

return new NimbusJwtEncoder(this.builder.build());

return new NimbusJwtEncoder(this.builder.build());
}

}

/**
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
* {@link ECPublicKey} and {@link ECPrivateKey}.
* <p>
* This builder is used to create a {@link NimbusJwtEncoder}
*
* @since 7.0
*/
public static final class EcKeyPairJwtEncoderBuilder {

private static final ThrowingConsumer<ECKey.Builder> defaultKid = ECKey.Builder::keyIDFromThumbprint;

private final ECKey.Builder builder;

private EcKeyPairJwtEncoderBuilder(ECPublicKey publicKey, ECPrivateKey privateKey) {
Assert.notNull(publicKey, "publicKey cannot be null");
Assert.notNull(privateKey, "privateKey cannot be null");
Curve curve = Curve.forECParameterSpec(publicKey.getParams());
Assert.notNull(curve, "Unable to determine Curve for EC public key.");
ECKey.Builder builder = new ECKey.Builder(curve, publicKey).privateKey(privateKey)
.keyUse(KeyUse.SIGNATURE)
.algorithm(JWSAlgorithm.ES256);
defaultKid.accept(builder, IllegalArgumentException::new);
this.builder = builder;
}

/**
* Sets the JWS algorithm to use for signing. Defaults to
* {@link SignatureAlgorithm#ES256}. Must be an EC-based algorithm (ES256, ES384,
* or ES512).
* @param signatureAlgorithm the {@link SignatureAlgorithm} to use
* @return this builder instance for method chaining
*/
public EcKeyPairJwtEncoderBuilder algorithm(SignatureAlgorithm signatureAlgorithm) {
Assert.state(JWSAlgorithm.Family.EC.contains(JWSAlgorithm.parse(signatureAlgorithm.getName())),
() -> "The algorithm '" + signatureAlgorithm + "' is not compatible with an ECKey. "
+ "Please use one of the ES256, ES384, or ES512 algorithms.");
this.builder.algorithm(JWSAlgorithm.parse(signatureAlgorithm.getName()));
return this;
}

/**
* Post-process the {@link JWK} using the given {@link Consumer}. For example, you
* may use this to override the default {@code kid}
* @param jwkPostProcessor the post-processor to use
* @return this builder instance for method chaining
*/
public EcKeyPairJwtEncoderBuilder jwkPostProcessor(Consumer<ECKey.Builder> jwkPostProcessor) {
Assert.notNull(jwkPostProcessor, "jwkPostProcessor cannot be null");
jwkPostProcessor.accept(this.builder);
return this;
}

/**
* Builds the {@link NimbusJwtEncoder} instance.
* @return the configured {@link NimbusJwtEncoder}
*/
public NimbusJwtEncoder build() {
return new NimbusJwtEncoder(this.builder.build());
}

}

}
Loading