Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support different signing keys per issuer (issuerId) #44

Merged
merged 3 commits into from
May 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ dependencies {
implementation("org.freemarker:freemarker:$freemarkerVersion")
testImplementation("org.assertj:assertj-core:$assertjVersion")
testImplementation("org.junit.jupiter:junit-jupiter-api:$junitJupiterVersion")
testImplementation("org.junit.jupiter:junit-jupiter-params:$junitJupiterVersion")
testImplementation("io.kotest:kotest-runner-junit5-jvm:$kotestVersion") // for kotest framework
testImplementation("io.kotest:kotest-assertions-core-jvm:$kotestVersion") // for kotest core jvm assertions
testImplementation("org.jetbrains.kotlin:kotlin-test-junit5:$kotlinVersion")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package no.nav.security.mock.oauth2.http

import com.nimbusds.oauth2.sdk.GrantType
import com.nimbusds.oauth2.sdk.TokenRequest
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication
import com.nimbusds.oauth2.sdk.http.HTTPRequest
import com.nimbusds.openid.connect.sdk.AuthenticationRequest
import no.nav.security.mock.oauth2.extensions.clientAuthentication
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class OAuth2HttpRequestHandler(
AUTHORIZATION -> handleAuthenticationRequest(request)
TOKEN -> handleTokenRequest(request)
END_SESSION -> handleEndSessionRequest(request)
JWKS -> json(config.tokenProvider.publicJwkSet().toJSONObject()).also { log.debug("handle jwks request") }
JWKS -> handleJwksRequest(request)
DEBUGGER -> debuggerRequestHandler.handleDebuggerForm(request).also { log.debug("handle debugger request") }
DEBUGGER_CALLBACK -> debuggerRequestHandler.handleDebuggerCallback(request).also { log.debug("handle debugger callback request") }
FAVICON -> OAuth2HttpResponse(status = 200)
Expand All @@ -84,6 +84,13 @@ class OAuth2HttpRequestHandler(

fun enqueueTokenCallback(oAuth2TokenCallback: OAuth2TokenCallback) = tokenCallbackQueue.add(oAuth2TokenCallback)

private fun handleJwksRequest(request: OAuth2HttpRequest): OAuth2HttpResponse {
log.debug("handle jwks request on url=${request.url}")
val issuerId = request.url.issuerId()
val jwkSet = config.tokenProvider.publicJwkSet(issuerId)
return json(jwkSet.toJSONObject())
}

private fun handleEndSessionRequest(request: OAuth2HttpRequest): OAuth2HttpResponse {
log.debug("handle end session request $request")
val postLogoutRedirectUri = request.url.queryParameter("post_logout_redirect_uri")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,76 +11,73 @@ import com.nimbusds.jwt.JWTClaimsSet
import com.nimbusds.jwt.SignedJWT
import com.nimbusds.oauth2.sdk.TokenRequest
import no.nav.security.mock.oauth2.extensions.clientIdAsString
import no.nav.security.mock.oauth2.extensions.issuerId
import okhttp3.HttpUrl
import java.security.KeyPair
import java.security.KeyPairGenerator
import java.security.interfaces.RSAPrivateKey
import java.security.interfaces.RSAPublicKey
import java.time.Duration
import java.time.Instant
import java.util.Date
import java.util.UUID
import java.util.concurrent.ConcurrentHashMap

class OAuth2TokenProvider {
private val jwkSet: JWKSet = generateJWKSet(DEFAULT_KEYID)
private val rsaKey: RSAKey = jwkSet.getKeyByKeyId(DEFAULT_KEYID) as RSAKey
private val signingKeys: ConcurrentHashMap<String, RSAKey> = ConcurrentHashMap()

fun publicJwkSet(): JWKSet {
return jwkSet.toPublicJWKSet()
@JvmOverloads
fun publicJwkSet(issuerId: String = "default"): JWKSet {
return JWKSet(rsaKey(issuerId)).toPublicJWKSet()
}

fun idToken(
tokenRequest: TokenRequest,
issuerUrl: HttpUrl,
oAuth2TokenCallback: OAuth2TokenCallback,
nonce: String? = null
) = createSignedJWT(
defaultClaims(
issuerUrl,
oAuth2TokenCallback.subject(tokenRequest),
listOf(tokenRequest.clientIdAsString()),
nonce,
oAuth2TokenCallback.addClaims(tokenRequest),
oAuth2TokenCallback.tokenExpiry()
)
)
) = defaultClaims(
issuerUrl,
oAuth2TokenCallback.subject(tokenRequest),
listOf(tokenRequest.clientIdAsString()),
nonce,
oAuth2TokenCallback.addClaims(tokenRequest),
oAuth2TokenCallback.tokenExpiry()
).sign(issuerUrl.issuerId())

fun accessToken(
tokenRequest: TokenRequest,
issuerUrl: HttpUrl,
oAuth2TokenCallback: OAuth2TokenCallback,
nonce: String? = null
) = createSignedJWT(
defaultClaims(
issuerUrl,
oAuth2TokenCallback.subject(tokenRequest),
oAuth2TokenCallback.audience(tokenRequest),
nonce,
oAuth2TokenCallback.addClaims(tokenRequest),
oAuth2TokenCallback.tokenExpiry()
)
)
) = defaultClaims(
issuerUrl,
oAuth2TokenCallback.subject(tokenRequest),
oAuth2TokenCallback.audience(tokenRequest),
nonce,
oAuth2TokenCallback.addClaims(tokenRequest),
oAuth2TokenCallback.tokenExpiry()
).sign(issuerUrl.issuerId())

fun exchangeAccessToken(
tokenRequest: TokenRequest,
issuerUrl: HttpUrl,
claimsSet: JWTClaimsSet,
oAuth2TokenCallback: OAuth2TokenCallback
) = Instant.now().let { now ->
createSignedJWT(
JWTClaimsSet.Builder(claimsSet)
.issuer(issuerUrl.toString())
.expirationTime(Date.from(now.plusSeconds(oAuth2TokenCallback.tokenExpiry())))
.notBeforeTime(Date.from(now))
.issueTime(Date.from(now))
.jwtID(UUID.randomUUID().toString())
.audience(oAuth2TokenCallback.audience(tokenRequest))
.addClaims(oAuth2TokenCallback.addClaims(tokenRequest))
.build()
)
JWTClaimsSet.Builder(claimsSet)
.issuer(issuerUrl.toString())
.expirationTime(Date.from(now.plusSeconds(oAuth2TokenCallback.tokenExpiry())))
.notBeforeTime(Date.from(now))
.issueTime(Date.from(now))
.jwtID(UUID.randomUUID().toString())
.audience(oAuth2TokenCallback.audience(tokenRequest))
.addClaims(oAuth2TokenCallback.addClaims(tokenRequest))
.build()
.sign(issuerUrl.issuerId())
}

fun jwt(claims: Map<String, Any>, expiry: Duration = Duration.ofHours(1)): SignedJWT =
@JvmOverloads
fun jwt(claims: Map<String, Any>, expiry: Duration = Duration.ofHours(1), issuerId: String = "default"): SignedJWT =
JWTClaimsSet.Builder().let { builder ->
val now = Instant.now()
builder
Expand All @@ -89,18 +86,20 @@ class OAuth2TokenProvider {
.expirationTime(Date.from(now.plusSeconds(expiry.toSeconds())))
builder.addClaims(claims)
builder.build()
}.let {
createSignedJWT(it)
}
}.sign(issuerId)

private fun rsaKey(issuerId: String): RSAKey = signingKeys.computeIfAbsent(issuerId) { generateRSAKey(issuerId) }

private fun createSignedJWT(claimsSet: JWTClaimsSet): SignedJWT {
val header = JWSHeader.Builder(JWSAlgorithm.RS256)
.keyID(rsaKey.keyID)
.type(JOSEObjectType.JWT)
val signedJWT = SignedJWT(header.build(), claimsSet)
val signer = RSASSASigner(rsaKey.toPrivateKey())
signedJWT.sign(signer)
return signedJWT
private fun JWTClaimsSet.sign(issuerId: String): SignedJWT {
val key = rsaKey(issuerId)
return SignedJWT(
JWSHeader.Builder(JWSAlgorithm.RS256)
.keyID(key.keyID)
.type(JOSEObjectType.JWT).build(),
this
).apply {
sign(RSASSASigner(key.toPrivateKey()))
}
}

private fun JWTClaimsSet.Builder.addClaims(claims: Map<String, Any> = emptyMap()) = apply {
Expand Down Expand Up @@ -130,21 +129,16 @@ class OAuth2TokenProvider {
}

companion object {
private const val DEFAULT_KEYID = "mock-oauth2-server-key"
private fun generateJWKSet(keyId: String) =
JWKSet(createRSAKey(keyId, generateKeyPair()))

private fun generateKeyPair(): KeyPair =
private fun generateRSAKey(keyId: String): RSAKey =
KeyPairGenerator.getInstance("RSA").let {
it.initialize(2048)
it.generateKeyPair()
}.let {
RSAKey.Builder(it.public as RSAPublicKey)
.privateKey(it.private as RSAPrivateKey)
.keyUse(KeyUse.SIGNATURE)
.keyID(keyId)
.build()
}

private fun createRSAKey(keyID: String, keyPair: KeyPair) =
RSAKey.Builder(keyPair.public as RSAPublicKey)
.privateKey(keyPair.private as RSAPrivateKey)
.keyUse(KeyUse.SIGNATURE)
.keyID(keyID)
.build()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class MockOAuth2ServerIntegrationTest {
@Test
fun `anyToken should issue token with claims from input and be verifyable by servers keys`() {
withMockOAuth2Server {
val customIssuer = "https://customissuer".toHttpUrl()
val customIssuer = "https://customissuer/default".toHttpUrl()
val token = this.anyToken(
customIssuer,
mutableMapOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,33 @@ package no.nav.security.mock.oauth2.token

import com.nimbusds.jose.jwk.KeyType
import com.nimbusds.jose.jwk.KeyUse
import com.nimbusds.jwt.SignedJWT
import com.nimbusds.oauth2.sdk.GrantType
import com.nimbusds.oauth2.sdk.id.Issuer
import io.kotest.assertions.asClue
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.shouldBe
import io.kotest.matchers.shouldNotBe
import no.nav.security.mock.oauth2.OAuth2Exception
import no.nav.security.mock.oauth2.extensions.verifySignatureAndIssuer
import no.nav.security.mock.oauth2.testutils.nimbusTokenRequest
import okhttp3.HttpUrl.Companion.toHttpUrl
import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ValueSource

internal class OAuth2TokenProviderTest {
private val tokenProvider = OAuth2TokenProvider()
private val jwkSet = tokenProvider.publicJwkSet()

@Test
fun `public jwks returns public part of JWKs`() =
fun `public jwks returns public part of JWKs`() {
val jwkSet = tokenProvider.publicJwkSet()
jwkSet.keys.any { it.isPrivate } shouldNotBe true
}

@Test
fun `all keys in public jwks should contain kty, use and kid`() {
val jwkSet = tokenProvider.publicJwkSet()
jwkSet.keys.forEach {
it.keyID shouldNotBe null
it.keyType shouldBe KeyType.RSA
Expand All @@ -45,9 +54,9 @@ internal class OAuth2TokenProviderTest {
"scope" to "scope1",
"assertion" to initialToken.serialize()
),
"http://default_if_not_overridden".toHttpUrl(),
initialToken.jwtClaimsSet,
DefaultOAuth2TokenCallback(
issuerUrl = "http://default_if_not_overridden".toHttpUrl(),
claimsSet = initialToken.jwtClaimsSet,
oAuth2TokenCallback = DefaultOAuth2TokenCallback(
claims = mapOf(
"extraclaim" to "extra",
"iss" to "http://overrideissuer"
Expand All @@ -61,4 +70,35 @@ internal class OAuth2TokenProviderTest {
it.claims["extraclaim"] shouldBe "extra"
}
}

@Test
fun `publicJwks should return different signing key for each issuerId`() {
val keys1 = tokenProvider.publicJwkSet("issuer1").toJSONObject()
keys1 shouldBe tokenProvider.publicJwkSet("issuer1").toJSONObject()
val keys2 = tokenProvider.publicJwkSet("issuer2").toJSONObject()
keys2 shouldNotBe keys1
}

@ParameterizedTest
@ValueSource(strings = ["issuer1", "issuer2"])
fun `ensure idToken is signed with same key as returned from public jwks`(issuerId: String) {

val issuer = Issuer("http://localhost/$issuerId")
idToken(issuer.toString()).verifySignatureAndIssuer(issuer, tokenProvider.publicJwkSet(issuerId))

shouldThrow<OAuth2Exception> {
idToken(issuer.toString()).verifySignatureAndIssuer(issuer, tokenProvider.publicJwkSet("shouldfail"))
}
}

private fun idToken(issuerUrl: String): SignedJWT =
tokenProvider.idToken(
tokenRequest = nimbusTokenRequest(
"client1",
"grant_type" to "authorization_code",
"code" to "123"
),
issuerUrl = issuerUrl.toHttpUrl(),
oAuth2TokenCallback = DefaultOAuth2TokenCallback()
)
}