@@ -11,76 +11,73 @@ import com.nimbusds.jwt.JWTClaimsSet
11
11
import com.nimbusds.jwt.SignedJWT
12
12
import com.nimbusds.oauth2.sdk.TokenRequest
13
13
import no.nav.security.mock.oauth2.extensions.clientIdAsString
14
+ import no.nav.security.mock.oauth2.extensions.issuerId
14
15
import okhttp3.HttpUrl
15
- import java.security.KeyPair
16
16
import java.security.KeyPairGenerator
17
17
import java.security.interfaces.RSAPrivateKey
18
18
import java.security.interfaces.RSAPublicKey
19
19
import java.time.Duration
20
20
import java.time.Instant
21
21
import java.util.Date
22
22
import java.util.UUID
23
+ import java.util.concurrent.ConcurrentHashMap
23
24
24
25
class OAuth2TokenProvider {
25
- private val jwkSet: JWKSet = generateJWKSet(DEFAULT_KEYID )
26
- private val rsaKey: RSAKey = jwkSet.getKeyByKeyId(DEFAULT_KEYID ) as RSAKey
26
+ private val signingKeys: ConcurrentHashMap <String , RSAKey > = ConcurrentHashMap ()
27
27
28
- fun publicJwkSet (): JWKSet {
29
- return jwkSet.toPublicJWKSet()
28
+ @JvmOverloads
29
+ fun publicJwkSet (issuerId : String = "default"): JWKSet {
30
+ return JWKSet (rsaKey(issuerId)).toPublicJWKSet()
30
31
}
31
32
32
33
fun idToken (
33
34
tokenRequest : TokenRequest ,
34
35
issuerUrl : HttpUrl ,
35
36
oAuth2TokenCallback : OAuth2TokenCallback ,
36
37
nonce : String? = null
37
- ) = createSignedJWT(
38
- defaultClaims(
39
- issuerUrl,
40
- oAuth2TokenCallback.subject(tokenRequest),
41
- listOf (tokenRequest.clientIdAsString()),
42
- nonce,
43
- oAuth2TokenCallback.addClaims(tokenRequest),
44
- oAuth2TokenCallback.tokenExpiry()
45
- )
46
- )
38
+ ) = defaultClaims(
39
+ issuerUrl,
40
+ oAuth2TokenCallback.subject(tokenRequest),
41
+ listOf (tokenRequest.clientIdAsString()),
42
+ nonce,
43
+ oAuth2TokenCallback.addClaims(tokenRequest),
44
+ oAuth2TokenCallback.tokenExpiry()
45
+ ).sign(issuerUrl.issuerId())
47
46
48
47
fun accessToken (
49
48
tokenRequest : TokenRequest ,
50
49
issuerUrl : HttpUrl ,
51
50
oAuth2TokenCallback : OAuth2TokenCallback ,
52
51
nonce : String? = null
53
- ) = createSignedJWT(
54
- defaultClaims(
55
- issuerUrl,
56
- oAuth2TokenCallback.subject(tokenRequest),
57
- oAuth2TokenCallback.audience(tokenRequest),
58
- nonce,
59
- oAuth2TokenCallback.addClaims(tokenRequest),
60
- oAuth2TokenCallback.tokenExpiry()
61
- )
62
- )
52
+ ) = defaultClaims(
53
+ issuerUrl,
54
+ oAuth2TokenCallback.subject(tokenRequest),
55
+ oAuth2TokenCallback.audience(tokenRequest),
56
+ nonce,
57
+ oAuth2TokenCallback.addClaims(tokenRequest),
58
+ oAuth2TokenCallback.tokenExpiry()
59
+ ).sign(issuerUrl.issuerId())
63
60
64
61
fun exchangeAccessToken (
65
62
tokenRequest : TokenRequest ,
66
63
issuerUrl : HttpUrl ,
67
64
claimsSet : JWTClaimsSet ,
68
65
oAuth2TokenCallback : OAuth2TokenCallback
69
66
) = Instant .now().let { now ->
70
- createSignedJWT(
71
- JWTClaimsSet .Builder (claimsSet)
72
- .issuer(issuerUrl.toString())
73
- .expirationTime(Date .from(now.plusSeconds(oAuth2TokenCallback.tokenExpiry())))
74
- .notBeforeTime(Date .from(now))
75
- .issueTime(Date .from(now))
76
- .jwtID(UUID .randomUUID().toString())
77
- .audience(oAuth2TokenCallback.audience(tokenRequest))
78
- .addClaims(oAuth2TokenCallback.addClaims(tokenRequest))
79
- .build()
80
- )
67
+ JWTClaimsSet .Builder (claimsSet)
68
+ .issuer(issuerUrl.toString())
69
+ .expirationTime(Date .from(now.plusSeconds(oAuth2TokenCallback.tokenExpiry())))
70
+ .notBeforeTime(Date .from(now))
71
+ .issueTime(Date .from(now))
72
+ .jwtID(UUID .randomUUID().toString())
73
+ .audience(oAuth2TokenCallback.audience(tokenRequest))
74
+ .addClaims(oAuth2TokenCallback.addClaims(tokenRequest))
75
+ .build()
76
+ .sign(issuerUrl.issuerId())
81
77
}
82
78
83
- fun jwt (claims : Map <String , Any >, expiry : Duration = Duration .ofHours(1)): SignedJWT =
79
+ @JvmOverloads
80
+ fun jwt (claims : Map <String , Any >, expiry : Duration = Duration .ofHours(1), issuerId : String = "default"): SignedJWT =
84
81
JWTClaimsSet .Builder ().let { builder ->
85
82
val now = Instant .now()
86
83
builder
@@ -89,18 +86,20 @@ class OAuth2TokenProvider {
89
86
.expirationTime(Date .from(now.plusSeconds(expiry.toSeconds())))
90
87
builder.addClaims(claims)
91
88
builder.build()
92
- }.let {
93
- createSignedJWT(it)
94
- }
89
+ }.sign(issuerId)
90
+
91
+ private fun rsaKey ( issuerId : String ): RSAKey = signingKeys.computeIfAbsent(issuerId) { generateRSAKey(issuerId) }
95
92
96
- private fun createSignedJWT (claimsSet : JWTClaimsSet ): SignedJWT {
97
- val header = JWSHeader .Builder (JWSAlgorithm .RS256 )
98
- .keyID(rsaKey.keyID)
99
- .type(JOSEObjectType .JWT )
100
- val signedJWT = SignedJWT (header.build(), claimsSet)
101
- val signer = RSASSASigner (rsaKey.toPrivateKey())
102
- signedJWT.sign(signer)
103
- return signedJWT
93
+ private fun JWTClaimsSet.sign (issuerId : String ): SignedJWT {
94
+ val key = rsaKey(issuerId)
95
+ return SignedJWT (
96
+ JWSHeader .Builder (JWSAlgorithm .RS256 )
97
+ .keyID(key.keyID)
98
+ .type(JOSEObjectType .JWT ).build(),
99
+ this
100
+ ).apply {
101
+ sign(RSASSASigner (key.toPrivateKey()))
102
+ }
104
103
}
105
104
106
105
private fun JWTClaimsSet.Builder.addClaims (claims : Map <String , Any > = emptyMap()) = apply {
@@ -130,21 +129,16 @@ class OAuth2TokenProvider {
130
129
}
131
130
132
131
companion object {
133
- private const val DEFAULT_KEYID = " mock-oauth2-server-key"
134
- private fun generateJWKSet (keyId : String ) =
135
- JWKSet (createRSAKey(keyId, generateKeyPair()))
136
-
137
- private fun generateKeyPair (): KeyPair =
132
+ private fun generateRSAKey (keyId : String ): RSAKey =
138
133
KeyPairGenerator .getInstance(" RSA" ).let {
139
134
it.initialize(2048 )
140
135
it.generateKeyPair()
136
+ }.let {
137
+ RSAKey .Builder (it.public as RSAPublicKey )
138
+ .privateKey(it.private as RSAPrivateKey )
139
+ .keyUse(KeyUse .SIGNATURE )
140
+ .keyID(keyId)
141
+ .build()
141
142
}
142
-
143
- private fun createRSAKey (keyID : String , keyPair : KeyPair ) =
144
- RSAKey .Builder (keyPair.public as RSAPublicKey )
145
- .privateKey(keyPair.private as RSAPrivateKey )
146
- .keyUse(KeyUse .SIGNATURE )
147
- .keyID(keyID)
148
- .build()
149
143
}
150
144
}
0 commit comments