From 2751a33ce273289c1f86767eb884d7e6241607a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tommy=20Tr=C3=B8en?= Date: Wed, 19 May 2021 14:50:39 +0200 Subject: [PATCH 1/3] tests: add junit-jupiter-params --- build.gradle.kts | 1 + 1 file changed, 1 insertion(+) diff --git a/build.gradle.kts b/build.gradle.kts index 33aaf182..9007b6e1 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -63,6 +63,7 @@ dependencies { implementation("org.freemarker:freemarker:$freemarkerVersion") testImplementation("org.assertj:assertj-core:$assertjVersion") testImplementation("org.junit.jupiter:junit-jupiter-api:$junitJupiterVersion") + testImplementation("org.junit.jupiter:junit-jupiter-params:$junitJupiterVersion") testImplementation("io.kotest:kotest-runner-junit5-jvm:$kotestVersion") // for kotest framework testImplementation("io.kotest:kotest-assertions-core-jvm:$kotestVersion") // for kotest core jvm assertions testImplementation("org.jetbrains.kotlin:kotlin-test-junit5:$kotlinVersion") From 430641f599295f4b8e0c61281c4d9baf4153cb64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tommy=20Tr=C3=B8en?= Date: Wed, 19 May 2021 14:51:04 +0200 Subject: [PATCH 2/3] chore: formatting --- .../kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequest.kt | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequest.kt b/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequest.kt index 38b17657..7464b2f3 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequest.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequest.kt @@ -2,7 +2,6 @@ package no.nav.security.mock.oauth2.http import com.nimbusds.oauth2.sdk.GrantType import com.nimbusds.oauth2.sdk.TokenRequest -import com.nimbusds.oauth2.sdk.auth.ClientAuthentication import com.nimbusds.oauth2.sdk.http.HTTPRequest import com.nimbusds.openid.connect.sdk.AuthenticationRequest import no.nav.security.mock.oauth2.extensions.clientAuthentication From 23a9d9f726627bf8e006c38539833fd8b30a29f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tommy=20Tr=C3=B8en?= Date: Wed, 19 May 2021 14:54:47 +0200 Subject: [PATCH 3/3] feat: fix #43 - support different signing keys per issuer (issuerId) * signing key per issuerId will be generated on-demand and reused --- .../oauth2/http/OAuth2HttpRequestHandler.kt | 9 +- .../mock/oauth2/token/OAuth2TokenProvider.kt | 114 +++++++++--------- .../e2e/MockOAuth2ServerIntegrationTest.kt | 2 +- .../oauth2/token/OAuth2TokenProviderTest.kt | 50 +++++++- 4 files changed, 108 insertions(+), 67 deletions(-) diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequestHandler.kt b/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequestHandler.kt index 897c612f..5b61dabe 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequestHandler.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequestHandler.kt @@ -70,7 +70,7 @@ class OAuth2HttpRequestHandler( AUTHORIZATION -> handleAuthenticationRequest(request) TOKEN -> handleTokenRequest(request) END_SESSION -> handleEndSessionRequest(request) - JWKS -> json(config.tokenProvider.publicJwkSet().toJSONObject()).also { log.debug("handle jwks request") } + JWKS -> handleJwksRequest(request) DEBUGGER -> debuggerRequestHandler.handleDebuggerForm(request).also { log.debug("handle debugger request") } DEBUGGER_CALLBACK -> debuggerRequestHandler.handleDebuggerCallback(request).also { log.debug("handle debugger callback request") } FAVICON -> OAuth2HttpResponse(status = 200) @@ -84,6 +84,13 @@ class OAuth2HttpRequestHandler( fun enqueueTokenCallback(oAuth2TokenCallback: OAuth2TokenCallback) = tokenCallbackQueue.add(oAuth2TokenCallback) + private fun handleJwksRequest(request: OAuth2HttpRequest): OAuth2HttpResponse { + log.debug("handle jwks request on url=${request.url}") + val issuerId = request.url.issuerId() + val jwkSet = config.tokenProvider.publicJwkSet(issuerId) + return json(jwkSet.toJSONObject()) + } + private fun handleEndSessionRequest(request: OAuth2HttpRequest): OAuth2HttpResponse { log.debug("handle end session request $request") val postLogoutRedirectUri = request.url.queryParameter("post_logout_redirect_uri") diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProvider.kt b/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProvider.kt index 553b1cb6..60b3b16e 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProvider.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProvider.kt @@ -11,8 +11,8 @@ import com.nimbusds.jwt.JWTClaimsSet import com.nimbusds.jwt.SignedJWT import com.nimbusds.oauth2.sdk.TokenRequest import no.nav.security.mock.oauth2.extensions.clientIdAsString +import no.nav.security.mock.oauth2.extensions.issuerId import okhttp3.HttpUrl -import java.security.KeyPair import java.security.KeyPairGenerator import java.security.interfaces.RSAPrivateKey import java.security.interfaces.RSAPublicKey @@ -20,13 +20,14 @@ import java.time.Duration import java.time.Instant import java.util.Date import java.util.UUID +import java.util.concurrent.ConcurrentHashMap class OAuth2TokenProvider { - private val jwkSet: JWKSet = generateJWKSet(DEFAULT_KEYID) - private val rsaKey: RSAKey = jwkSet.getKeyByKeyId(DEFAULT_KEYID) as RSAKey + private val signingKeys: ConcurrentHashMap = ConcurrentHashMap() - fun publicJwkSet(): JWKSet { - return jwkSet.toPublicJWKSet() + @JvmOverloads + fun publicJwkSet(issuerId: String = "default"): JWKSet { + return JWKSet(rsaKey(issuerId)).toPublicJWKSet() } fun idToken( @@ -34,32 +35,28 @@ class OAuth2TokenProvider { issuerUrl: HttpUrl, oAuth2TokenCallback: OAuth2TokenCallback, nonce: String? = null - ) = createSignedJWT( - defaultClaims( - issuerUrl, - oAuth2TokenCallback.subject(tokenRequest), - listOf(tokenRequest.clientIdAsString()), - nonce, - oAuth2TokenCallback.addClaims(tokenRequest), - oAuth2TokenCallback.tokenExpiry() - ) - ) + ) = defaultClaims( + issuerUrl, + oAuth2TokenCallback.subject(tokenRequest), + listOf(tokenRequest.clientIdAsString()), + nonce, + oAuth2TokenCallback.addClaims(tokenRequest), + oAuth2TokenCallback.tokenExpiry() + ).sign(issuerUrl.issuerId()) fun accessToken( tokenRequest: TokenRequest, issuerUrl: HttpUrl, oAuth2TokenCallback: OAuth2TokenCallback, nonce: String? = null - ) = createSignedJWT( - defaultClaims( - issuerUrl, - oAuth2TokenCallback.subject(tokenRequest), - oAuth2TokenCallback.audience(tokenRequest), - nonce, - oAuth2TokenCallback.addClaims(tokenRequest), - oAuth2TokenCallback.tokenExpiry() - ) - ) + ) = defaultClaims( + issuerUrl, + oAuth2TokenCallback.subject(tokenRequest), + oAuth2TokenCallback.audience(tokenRequest), + nonce, + oAuth2TokenCallback.addClaims(tokenRequest), + oAuth2TokenCallback.tokenExpiry() + ).sign(issuerUrl.issuerId()) fun exchangeAccessToken( tokenRequest: TokenRequest, @@ -67,20 +64,20 @@ class OAuth2TokenProvider { claimsSet: JWTClaimsSet, oAuth2TokenCallback: OAuth2TokenCallback ) = Instant.now().let { now -> - createSignedJWT( - JWTClaimsSet.Builder(claimsSet) - .issuer(issuerUrl.toString()) - .expirationTime(Date.from(now.plusSeconds(oAuth2TokenCallback.tokenExpiry()))) - .notBeforeTime(Date.from(now)) - .issueTime(Date.from(now)) - .jwtID(UUID.randomUUID().toString()) - .audience(oAuth2TokenCallback.audience(tokenRequest)) - .addClaims(oAuth2TokenCallback.addClaims(tokenRequest)) - .build() - ) + JWTClaimsSet.Builder(claimsSet) + .issuer(issuerUrl.toString()) + .expirationTime(Date.from(now.plusSeconds(oAuth2TokenCallback.tokenExpiry()))) + .notBeforeTime(Date.from(now)) + .issueTime(Date.from(now)) + .jwtID(UUID.randomUUID().toString()) + .audience(oAuth2TokenCallback.audience(tokenRequest)) + .addClaims(oAuth2TokenCallback.addClaims(tokenRequest)) + .build() + .sign(issuerUrl.issuerId()) } - fun jwt(claims: Map, expiry: Duration = Duration.ofHours(1)): SignedJWT = + @JvmOverloads + fun jwt(claims: Map, expiry: Duration = Duration.ofHours(1), issuerId: String = "default"): SignedJWT = JWTClaimsSet.Builder().let { builder -> val now = Instant.now() builder @@ -89,18 +86,20 @@ class OAuth2TokenProvider { .expirationTime(Date.from(now.plusSeconds(expiry.toSeconds()))) builder.addClaims(claims) builder.build() - }.let { - createSignedJWT(it) - } + }.sign(issuerId) + + private fun rsaKey(issuerId: String): RSAKey = signingKeys.computeIfAbsent(issuerId) { generateRSAKey(issuerId) } - private fun createSignedJWT(claimsSet: JWTClaimsSet): SignedJWT { - val header = JWSHeader.Builder(JWSAlgorithm.RS256) - .keyID(rsaKey.keyID) - .type(JOSEObjectType.JWT) - val signedJWT = SignedJWT(header.build(), claimsSet) - val signer = RSASSASigner(rsaKey.toPrivateKey()) - signedJWT.sign(signer) - return signedJWT + private fun JWTClaimsSet.sign(issuerId: String): SignedJWT { + val key = rsaKey(issuerId) + return SignedJWT( + JWSHeader.Builder(JWSAlgorithm.RS256) + .keyID(key.keyID) + .type(JOSEObjectType.JWT).build(), + this + ).apply { + sign(RSASSASigner(key.toPrivateKey())) + } } private fun JWTClaimsSet.Builder.addClaims(claims: Map = emptyMap()) = apply { @@ -130,21 +129,16 @@ class OAuth2TokenProvider { } companion object { - private const val DEFAULT_KEYID = "mock-oauth2-server-key" - private fun generateJWKSet(keyId: String) = - JWKSet(createRSAKey(keyId, generateKeyPair())) - - private fun generateKeyPair(): KeyPair = + private fun generateRSAKey(keyId: String): RSAKey = KeyPairGenerator.getInstance("RSA").let { it.initialize(2048) it.generateKeyPair() + }.let { + RSAKey.Builder(it.public as RSAPublicKey) + .privateKey(it.private as RSAPrivateKey) + .keyUse(KeyUse.SIGNATURE) + .keyID(keyId) + .build() } - - private fun createRSAKey(keyID: String, keyPair: KeyPair) = - RSAKey.Builder(keyPair.public as RSAPublicKey) - .privateKey(keyPair.private as RSAPrivateKey) - .keyUse(KeyUse.SIGNATURE) - .keyID(keyID) - .build() } } diff --git a/src/test/kotlin/no/nav/security/mock/oauth2/e2e/MockOAuth2ServerIntegrationTest.kt b/src/test/kotlin/no/nav/security/mock/oauth2/e2e/MockOAuth2ServerIntegrationTest.kt index 1ea3c0b5..f1629925 100644 --- a/src/test/kotlin/no/nav/security/mock/oauth2/e2e/MockOAuth2ServerIntegrationTest.kt +++ b/src/test/kotlin/no/nav/security/mock/oauth2/e2e/MockOAuth2ServerIntegrationTest.kt @@ -143,7 +143,7 @@ class MockOAuth2ServerIntegrationTest { @Test fun `anyToken should issue token with claims from input and be verifyable by servers keys`() { withMockOAuth2Server { - val customIssuer = "https://customissuer".toHttpUrl() + val customIssuer = "https://customissuer/default".toHttpUrl() val token = this.anyToken( customIssuer, mutableMapOf( diff --git a/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProviderTest.kt b/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProviderTest.kt index 29bdf7d6..5e9eeade 100644 --- a/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProviderTest.kt +++ b/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenProviderTest.kt @@ -2,24 +2,33 @@ package no.nav.security.mock.oauth2.token import com.nimbusds.jose.jwk.KeyType import com.nimbusds.jose.jwk.KeyUse +import com.nimbusds.jwt.SignedJWT import com.nimbusds.oauth2.sdk.GrantType +import com.nimbusds.oauth2.sdk.id.Issuer import io.kotest.assertions.asClue +import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.shouldBe import io.kotest.matchers.shouldNotBe +import no.nav.security.mock.oauth2.OAuth2Exception +import no.nav.security.mock.oauth2.extensions.verifySignatureAndIssuer import no.nav.security.mock.oauth2.testutils.nimbusTokenRequest import okhttp3.HttpUrl.Companion.toHttpUrl import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource internal class OAuth2TokenProviderTest { private val tokenProvider = OAuth2TokenProvider() - private val jwkSet = tokenProvider.publicJwkSet() @Test - fun `public jwks returns public part of JWKs`() = + fun `public jwks returns public part of JWKs`() { + val jwkSet = tokenProvider.publicJwkSet() jwkSet.keys.any { it.isPrivate } shouldNotBe true + } @Test fun `all keys in public jwks should contain kty, use and kid`() { + val jwkSet = tokenProvider.publicJwkSet() jwkSet.keys.forEach { it.keyID shouldNotBe null it.keyType shouldBe KeyType.RSA @@ -45,9 +54,9 @@ internal class OAuth2TokenProviderTest { "scope" to "scope1", "assertion" to initialToken.serialize() ), - "http://default_if_not_overridden".toHttpUrl(), - initialToken.jwtClaimsSet, - DefaultOAuth2TokenCallback( + issuerUrl = "http://default_if_not_overridden".toHttpUrl(), + claimsSet = initialToken.jwtClaimsSet, + oAuth2TokenCallback = DefaultOAuth2TokenCallback( claims = mapOf( "extraclaim" to "extra", "iss" to "http://overrideissuer" @@ -61,4 +70,35 @@ internal class OAuth2TokenProviderTest { it.claims["extraclaim"] shouldBe "extra" } } + + @Test + fun `publicJwks should return different signing key for each issuerId`() { + val keys1 = tokenProvider.publicJwkSet("issuer1").toJSONObject() + keys1 shouldBe tokenProvider.publicJwkSet("issuer1").toJSONObject() + val keys2 = tokenProvider.publicJwkSet("issuer2").toJSONObject() + keys2 shouldNotBe keys1 + } + + @ParameterizedTest + @ValueSource(strings = ["issuer1", "issuer2"]) + fun `ensure idToken is signed with same key as returned from public jwks`(issuerId: String) { + + val issuer = Issuer("http://localhost/$issuerId") + idToken(issuer.toString()).verifySignatureAndIssuer(issuer, tokenProvider.publicJwkSet(issuerId)) + + shouldThrow { + idToken(issuer.toString()).verifySignatureAndIssuer(issuer, tokenProvider.publicJwkSet("shouldfail")) + } + } + + private fun idToken(issuerUrl: String): SignedJWT = + tokenProvider.idToken( + tokenRequest = nimbusTokenRequest( + "client1", + "grant_type" to "authorization_code", + "code" to "123" + ), + issuerUrl = issuerUrl.toHttpUrl(), + oAuth2TokenCallback = DefaultOAuth2TokenCallback() + ) }