Skip to content

Commit 0bf7acb

Browse files
authoredMay 20, 2021
Support different signing keys per issuer (issuerId) (#44)
* signing key per issuerId will be generated on-demand and reused * feat: fix #43 - support different signing keys per issuer (issuerId)
1 parent 55e132a commit 0bf7acb

File tree

6 files changed

+109
-68
lines changed

6 files changed

+109
-68
lines changed
 

‎build.gradle.kts

+1
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ dependencies {
6363
implementation("org.freemarker:freemarker:$freemarkerVersion")
6464
testImplementation("org.assertj:assertj-core:$assertjVersion")
6565
testImplementation("org.junit.jupiter:junit-jupiter-api:$junitJupiterVersion")
66+
testImplementation("org.junit.jupiter:junit-jupiter-params:$junitJupiterVersion")
6667
testImplementation("io.kotest:kotest-runner-junit5-jvm:$kotestVersion") // for kotest framework
6768
testImplementation("io.kotest:kotest-assertions-core-jvm:$kotestVersion") // for kotest core jvm assertions
6869
testImplementation("org.jetbrains.kotlin:kotlin-test-junit5:$kotlinVersion")

‎src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequest.kt

-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package no.nav.security.mock.oauth2.http
22

33
import com.nimbusds.oauth2.sdk.GrantType
44
import com.nimbusds.oauth2.sdk.TokenRequest
5-
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication
65
import com.nimbusds.oauth2.sdk.http.HTTPRequest
76
import com.nimbusds.openid.connect.sdk.AuthenticationRequest
87
import no.nav.security.mock.oauth2.extensions.clientAuthentication

‎src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequestHandler.kt

+8-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class OAuth2HttpRequestHandler(
7070
AUTHORIZATION -> handleAuthenticationRequest(request)
7171
TOKEN -> handleTokenRequest(request)
7272
END_SESSION -> handleEndSessionRequest(request)
73-
JWKS -> json(config.tokenProvider.publicJwkSet().toJSONObject()).also { log.debug("handle jwks request") }
73+
JWKS -> handleJwksRequest(request)
7474
DEBUGGER -> debuggerRequestHandler.handleDebuggerForm(request).also { log.debug("handle debugger request") }
7575
DEBUGGER_CALLBACK -> debuggerRequestHandler.handleDebuggerCallback(request).also { log.debug("handle debugger callback request") }
7676
FAVICON -> OAuth2HttpResponse(status = 200)
@@ -84,6 +84,13 @@ class OAuth2HttpRequestHandler(
8484

8585
fun enqueueTokenCallback(oAuth2TokenCallback: OAuth2TokenCallback) = tokenCallbackQueue.add(oAuth2TokenCallback)
8686

87+
private fun handleJwksRequest(request: OAuth2HttpRequest): OAuth2HttpResponse {
88+
log.debug("handle jwks request on url=${request.url}")
89+
val issuerId = request.url.issuerId()
90+
val jwkSet = config.tokenProvider.publicJwkSet(issuerId)
91+
return json(jwkSet.toJSONObject())
92+
}
93+
8794
private fun handleEndSessionRequest(request: OAuth2HttpRequest): OAuth2HttpResponse {
8895
log.debug("handle end session request $request")
8996
val postLogoutRedirectUri = request.url.queryParameter("post_logout_redirect_uri")

‎src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProvider.kt

+54-60
Original file line numberDiff line numberDiff line change
@@ -11,76 +11,73 @@ import com.nimbusds.jwt.JWTClaimsSet
1111
import com.nimbusds.jwt.SignedJWT
1212
import com.nimbusds.oauth2.sdk.TokenRequest
1313
import no.nav.security.mock.oauth2.extensions.clientIdAsString
14+
import no.nav.security.mock.oauth2.extensions.issuerId
1415
import okhttp3.HttpUrl
15-
import java.security.KeyPair
1616
import java.security.KeyPairGenerator
1717
import java.security.interfaces.RSAPrivateKey
1818
import java.security.interfaces.RSAPublicKey
1919
import java.time.Duration
2020
import java.time.Instant
2121
import java.util.Date
2222
import java.util.UUID
23+
import java.util.concurrent.ConcurrentHashMap
2324

2425
class OAuth2TokenProvider {
25-
private val jwkSet: JWKSet = generateJWKSet(DEFAULT_KEYID)
26-
private val rsaKey: RSAKey = jwkSet.getKeyByKeyId(DEFAULT_KEYID) as RSAKey
26+
private val signingKeys: ConcurrentHashMap<String, RSAKey> = ConcurrentHashMap()
2727

28-
fun publicJwkSet(): JWKSet {
29-
return jwkSet.toPublicJWKSet()
28+
@JvmOverloads
29+
fun publicJwkSet(issuerId: String = "default"): JWKSet {
30+
return JWKSet(rsaKey(issuerId)).toPublicJWKSet()
3031
}
3132

3233
fun idToken(
3334
tokenRequest: TokenRequest,
3435
issuerUrl: HttpUrl,
3536
oAuth2TokenCallback: OAuth2TokenCallback,
3637
nonce: String? = null
37-
) = createSignedJWT(
38-
defaultClaims(
39-
issuerUrl,
40-
oAuth2TokenCallback.subject(tokenRequest),
41-
listOf(tokenRequest.clientIdAsString()),
42-
nonce,
43-
oAuth2TokenCallback.addClaims(tokenRequest),
44-
oAuth2TokenCallback.tokenExpiry()
45-
)
46-
)
38+
) = defaultClaims(
39+
issuerUrl,
40+
oAuth2TokenCallback.subject(tokenRequest),
41+
listOf(tokenRequest.clientIdAsString()),
42+
nonce,
43+
oAuth2TokenCallback.addClaims(tokenRequest),
44+
oAuth2TokenCallback.tokenExpiry()
45+
).sign(issuerUrl.issuerId())
4746

4847
fun accessToken(
4948
tokenRequest: TokenRequest,
5049
issuerUrl: HttpUrl,
5150
oAuth2TokenCallback: OAuth2TokenCallback,
5251
nonce: String? = null
53-
) = createSignedJWT(
54-
defaultClaims(
55-
issuerUrl,
56-
oAuth2TokenCallback.subject(tokenRequest),
57-
oAuth2TokenCallback.audience(tokenRequest),
58-
nonce,
59-
oAuth2TokenCallback.addClaims(tokenRequest),
60-
oAuth2TokenCallback.tokenExpiry()
61-
)
62-
)
52+
) = defaultClaims(
53+
issuerUrl,
54+
oAuth2TokenCallback.subject(tokenRequest),
55+
oAuth2TokenCallback.audience(tokenRequest),
56+
nonce,
57+
oAuth2TokenCallback.addClaims(tokenRequest),
58+
oAuth2TokenCallback.tokenExpiry()
59+
).sign(issuerUrl.issuerId())
6360

6461
fun exchangeAccessToken(
6562
tokenRequest: TokenRequest,
6663
issuerUrl: HttpUrl,
6764
claimsSet: JWTClaimsSet,
6865
oAuth2TokenCallback: OAuth2TokenCallback
6966
) = Instant.now().let { now ->
70-
createSignedJWT(
71-
JWTClaimsSet.Builder(claimsSet)
72-
.issuer(issuerUrl.toString())
73-
.expirationTime(Date.from(now.plusSeconds(oAuth2TokenCallback.tokenExpiry())))
74-
.notBeforeTime(Date.from(now))
75-
.issueTime(Date.from(now))
76-
.jwtID(UUID.randomUUID().toString())
77-
.audience(oAuth2TokenCallback.audience(tokenRequest))
78-
.addClaims(oAuth2TokenCallback.addClaims(tokenRequest))
79-
.build()
80-
)
67+
JWTClaimsSet.Builder(claimsSet)
68+
.issuer(issuerUrl.toString())
69+
.expirationTime(Date.from(now.plusSeconds(oAuth2TokenCallback.tokenExpiry())))
70+
.notBeforeTime(Date.from(now))
71+
.issueTime(Date.from(now))
72+
.jwtID(UUID.randomUUID().toString())
73+
.audience(oAuth2TokenCallback.audience(tokenRequest))
74+
.addClaims(oAuth2TokenCallback.addClaims(tokenRequest))
75+
.build()
76+
.sign(issuerUrl.issuerId())
8177
}
8278

83-
fun jwt(claims: Map<String, Any>, expiry: Duration = Duration.ofHours(1)): SignedJWT =
79+
@JvmOverloads
80+
fun jwt(claims: Map<String, Any>, expiry: Duration = Duration.ofHours(1), issuerId: String = "default"): SignedJWT =
8481
JWTClaimsSet.Builder().let { builder ->
8582
val now = Instant.now()
8683
builder
@@ -89,18 +86,20 @@ class OAuth2TokenProvider {
8986
.expirationTime(Date.from(now.plusSeconds(expiry.toSeconds())))
9087
builder.addClaims(claims)
9188
builder.build()
92-
}.let {
93-
createSignedJWT(it)
94-
}
89+
}.sign(issuerId)
90+
91+
private fun rsaKey(issuerId: String): RSAKey = signingKeys.computeIfAbsent(issuerId) { generateRSAKey(issuerId) }
9592

96-
private fun createSignedJWT(claimsSet: JWTClaimsSet): SignedJWT {
97-
val header = JWSHeader.Builder(JWSAlgorithm.RS256)
98-
.keyID(rsaKey.keyID)
99-
.type(JOSEObjectType.JWT)
100-
val signedJWT = SignedJWT(header.build(), claimsSet)
101-
val signer = RSASSASigner(rsaKey.toPrivateKey())
102-
signedJWT.sign(signer)
103-
return signedJWT
93+
private fun JWTClaimsSet.sign(issuerId: String): SignedJWT {
94+
val key = rsaKey(issuerId)
95+
return SignedJWT(
96+
JWSHeader.Builder(JWSAlgorithm.RS256)
97+
.keyID(key.keyID)
98+
.type(JOSEObjectType.JWT).build(),
99+
this
100+
).apply {
101+
sign(RSASSASigner(key.toPrivateKey()))
102+
}
104103
}
105104

106105
private fun JWTClaimsSet.Builder.addClaims(claims: Map<String, Any> = emptyMap()) = apply {
@@ -130,21 +129,16 @@ class OAuth2TokenProvider {
130129
}
131130

132131
companion object {
133-
private const val DEFAULT_KEYID = "mock-oauth2-server-key"
134-
private fun generateJWKSet(keyId: String) =
135-
JWKSet(createRSAKey(keyId, generateKeyPair()))
136-
137-
private fun generateKeyPair(): KeyPair =
132+
private fun generateRSAKey(keyId: String): RSAKey =
138133
KeyPairGenerator.getInstance("RSA").let {
139134
it.initialize(2048)
140135
it.generateKeyPair()
136+
}.let {
137+
RSAKey.Builder(it.public as RSAPublicKey)
138+
.privateKey(it.private as RSAPrivateKey)
139+
.keyUse(KeyUse.SIGNATURE)
140+
.keyID(keyId)
141+
.build()
141142
}
142-
143-
private fun createRSAKey(keyID: String, keyPair: KeyPair) =
144-
RSAKey.Builder(keyPair.public as RSAPublicKey)
145-
.privateKey(keyPair.private as RSAPrivateKey)
146-
.keyUse(KeyUse.SIGNATURE)
147-
.keyID(keyID)
148-
.build()
149143
}
150144
}

‎src/test/kotlin/no/nav/security/mock/oauth2/e2e/MockOAuth2ServerIntegrationTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ class MockOAuth2ServerIntegrationTest {
143143
@Test
144144
fun `anyToken should issue token with claims from input and be verifyable by servers keys`() {
145145
withMockOAuth2Server {
146-
val customIssuer = "https://customissuer".toHttpUrl()
146+
val customIssuer = "https://customissuer/default".toHttpUrl()
147147
val token = this.anyToken(
148148
customIssuer,
149149
mutableMapOf(

‎src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProviderTest.kt

+45-5
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,33 @@ package no.nav.security.mock.oauth2.token
22

33
import com.nimbusds.jose.jwk.KeyType
44
import com.nimbusds.jose.jwk.KeyUse
5+
import com.nimbusds.jwt.SignedJWT
56
import com.nimbusds.oauth2.sdk.GrantType
7+
import com.nimbusds.oauth2.sdk.id.Issuer
68
import io.kotest.assertions.asClue
9+
import io.kotest.assertions.throwables.shouldThrow
710
import io.kotest.matchers.shouldBe
811
import io.kotest.matchers.shouldNotBe
12+
import no.nav.security.mock.oauth2.OAuth2Exception
13+
import no.nav.security.mock.oauth2.extensions.verifySignatureAndIssuer
914
import no.nav.security.mock.oauth2.testutils.nimbusTokenRequest
1015
import okhttp3.HttpUrl.Companion.toHttpUrl
1116
import org.junit.jupiter.api.Test
17+
import org.junit.jupiter.params.ParameterizedTest
18+
import org.junit.jupiter.params.provider.ValueSource
1219

1320
internal class OAuth2TokenProviderTest {
1421
private val tokenProvider = OAuth2TokenProvider()
15-
private val jwkSet = tokenProvider.publicJwkSet()
1622

1723
@Test
18-
fun `public jwks returns public part of JWKs`() =
24+
fun `public jwks returns public part of JWKs`() {
25+
val jwkSet = tokenProvider.publicJwkSet()
1926
jwkSet.keys.any { it.isPrivate } shouldNotBe true
27+
}
2028

2129
@Test
2230
fun `all keys in public jwks should contain kty, use and kid`() {
31+
val jwkSet = tokenProvider.publicJwkSet()
2332
jwkSet.keys.forEach {
2433
it.keyID shouldNotBe null
2534
it.keyType shouldBe KeyType.RSA
@@ -45,9 +54,9 @@ internal class OAuth2TokenProviderTest {
4554
"scope" to "scope1",
4655
"assertion" to initialToken.serialize()
4756
),
48-
"http://default_if_not_overridden".toHttpUrl(),
49-
initialToken.jwtClaimsSet,
50-
DefaultOAuth2TokenCallback(
57+
issuerUrl = "http://default_if_not_overridden".toHttpUrl(),
58+
claimsSet = initialToken.jwtClaimsSet,
59+
oAuth2TokenCallback = DefaultOAuth2TokenCallback(
5160
claims = mapOf(
5261
"extraclaim" to "extra",
5362
"iss" to "http://overrideissuer"
@@ -61,4 +70,35 @@ internal class OAuth2TokenProviderTest {
6170
it.claims["extraclaim"] shouldBe "extra"
6271
}
6372
}
73+
74+
@Test
75+
fun `publicJwks should return different signing key for each issuerId`() {
76+
val keys1 = tokenProvider.publicJwkSet("issuer1").toJSONObject()
77+
keys1 shouldBe tokenProvider.publicJwkSet("issuer1").toJSONObject()
78+
val keys2 = tokenProvider.publicJwkSet("issuer2").toJSONObject()
79+
keys2 shouldNotBe keys1
80+
}
81+
82+
@ParameterizedTest
83+
@ValueSource(strings = ["issuer1", "issuer2"])
84+
fun `ensure idToken is signed with same key as returned from public jwks`(issuerId: String) {
85+
86+
val issuer = Issuer("http://localhost/$issuerId")
87+
idToken(issuer.toString()).verifySignatureAndIssuer(issuer, tokenProvider.publicJwkSet(issuerId))
88+
89+
shouldThrow<OAuth2Exception> {
90+
idToken(issuer.toString()).verifySignatureAndIssuer(issuer, tokenProvider.publicJwkSet("shouldfail"))
91+
}
92+
}
93+
94+
private fun idToken(issuerUrl: String): SignedJWT =
95+
tokenProvider.idToken(
96+
tokenRequest = nimbusTokenRequest(
97+
"client1",
98+
"grant_type" to "authorization_code",
99+
"code" to "123"
100+
),
101+
issuerUrl = issuerUrl.toHttpUrl(),
102+
oAuth2TokenCallback = DefaultOAuth2TokenCallback()
103+
)
64104
}

0 commit comments

Comments
 (0)
Please sign in to comment.