|
| 1 | +package no.nav.security.mock.oauth2 |
| 2 | + |
| 3 | +import com.nimbusds.jwt.SignedJWT |
| 4 | +import com.nimbusds.oauth2.sdk.AuthorizationCode |
| 5 | +import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant |
| 6 | +import com.nimbusds.oauth2.sdk.TokenRequest |
| 7 | +import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic |
| 8 | +import com.nimbusds.oauth2.sdk.auth.Secret |
| 9 | +import com.nimbusds.oauth2.sdk.id.ClientID |
| 10 | +import mu.KotlinLogging |
| 11 | +import no.nav.security.mock.oauth2.extensions.asOAuth2HttpRequest |
| 12 | +import no.nav.security.mock.oauth2.extensions.toAuthorizationEndpointUrl |
| 13 | +import no.nav.security.mock.oauth2.extensions.toJwksUrl |
| 14 | +import no.nav.security.mock.oauth2.extensions.toTokenEndpointUrl |
| 15 | +import no.nav.security.mock.oauth2.extensions.toWellKnownUrl |
| 16 | +import no.nav.security.mock.oauth2.http.OAuth2HttpRequestHandler |
| 17 | +import no.nav.security.mock.oauth2.http.OAuth2HttpResponse |
| 18 | +import no.nav.security.mock.oauth2.token.OAuth2TokenCallback |
| 19 | +import no.nav.security.mock.oauth2.token.OAuth2TokenProvider |
| 20 | +import okhttp3.HttpUrl |
| 21 | +import okhttp3.mockwebserver.Dispatcher |
| 22 | +import okhttp3.mockwebserver.MockResponse |
| 23 | +import okhttp3.mockwebserver.MockWebServer |
| 24 | +import okhttp3.mockwebserver.RecordedRequest |
| 25 | +import java.io.IOException |
| 26 | +import java.net.InetSocketAddress |
| 27 | +import java.net.URI |
| 28 | +import java.util.concurrent.BlockingQueue |
| 29 | +import java.util.concurrent.LinkedBlockingQueue |
| 30 | + |
| 31 | +private val log = KotlinLogging.logger {} |
| 32 | + |
| 33 | +class MockOAuth2Server( |
| 34 | + config: OAuth2Config = OAuth2Config() |
| 35 | +) { |
| 36 | + private val mockWebServer: MockWebServer = MockWebServer() |
| 37 | + private val tokenProvider: OAuth2TokenProvider = |
| 38 | + OAuth2TokenProvider() |
| 39 | + |
| 40 | + var dispatcher: Dispatcher = MockOAuth2Dispatcher(config) |
| 41 | + |
| 42 | + fun start() { |
| 43 | + mockWebServer.start() |
| 44 | + mockWebServer.dispatcher = dispatcher |
| 45 | + } |
| 46 | + |
| 47 | + fun start(port: Int = 0) { |
| 48 | + val address = InetSocketAddress(0).address |
| 49 | + log.info("attempting to start server on port $port and InetAddress=$address") |
| 50 | + mockWebServer.start(address, port) |
| 51 | + mockWebServer.dispatcher = dispatcher |
| 52 | + } |
| 53 | + |
| 54 | + @Throws(IOException::class) |
| 55 | + fun shutdown() { |
| 56 | + mockWebServer.shutdown() |
| 57 | + } |
| 58 | + |
| 59 | + fun url(path: String): HttpUrl = mockWebServer.url(path) |
| 60 | + fun enqueueResponse(response: MockResponse) = (dispatcher as MockOAuth2Dispatcher).enqueueResponse(response) |
| 61 | + fun enqueueCallback(oAuth2TokenCallback: OAuth2TokenCallback) = (dispatcher as MockOAuth2Dispatcher).enqueueTokenCallback(oAuth2TokenCallback) |
| 62 | + fun takeRequest(): RecordedRequest = mockWebServer.takeRequest() |
| 63 | + |
| 64 | + fun wellKnownUrl(issuerId: String): HttpUrl = mockWebServer.url(issuerId).toWellKnownUrl() |
| 65 | + fun tokenEndpointUrl(issuerId: String): HttpUrl = mockWebServer.url(issuerId).toTokenEndpointUrl() |
| 66 | + fun jwksUrl(issuerId: String): HttpUrl = mockWebServer.url(issuerId).toJwksUrl() |
| 67 | + fun issuerUrl(issuerId: String): HttpUrl = mockWebServer.url(issuerId) |
| 68 | + fun authorizationEndpointUrl(issuerId: String): HttpUrl = mockWebServer.url(issuerId).toAuthorizationEndpointUrl() |
| 69 | + fun baseUrl(): HttpUrl = mockWebServer.url("") |
| 70 | + |
| 71 | + fun issueToken(issuerId: String, clientId: String, OAuth2TokenCallback: OAuth2TokenCallback): SignedJWT { |
| 72 | + val uri = tokenEndpointUrl(issuerId) |
| 73 | + val issuerUrl = issuerUrl(issuerId) |
| 74 | + val tokenRequest = TokenRequest( |
| 75 | + uri.toUri(), |
| 76 | + ClientSecretBasic(ClientID(clientId), Secret("secret")), |
| 77 | + AuthorizationCodeGrant(AuthorizationCode("123"), URI.create("http://localhost")) |
| 78 | + ) |
| 79 | + return tokenProvider.accessToken(tokenRequest, issuerUrl, null, OAuth2TokenCallback) |
| 80 | + } |
| 81 | +} |
| 82 | + |
| 83 | +class MockOAuth2Dispatcher( |
| 84 | + config: OAuth2Config |
| 85 | +) : Dispatcher() { |
| 86 | + private val httpRequestHandler: OAuth2HttpRequestHandler = OAuth2HttpRequestHandler(config) |
| 87 | + private val responseQueue: BlockingQueue<MockResponse> = LinkedBlockingQueue() |
| 88 | + |
| 89 | + fun enqueueResponse(mockResponse: MockResponse) = responseQueue.add(mockResponse) |
| 90 | + fun enqueueTokenCallback(oAuth2TokenCallback: OAuth2TokenCallback) = httpRequestHandler.enqueueTokenCallback(oAuth2TokenCallback) |
| 91 | + |
| 92 | + override fun dispatch(request: RecordedRequest): MockResponse = |
| 93 | + when { |
| 94 | + responseQueue.peek() != null -> responseQueue.take() |
| 95 | + else -> mockResponse(httpRequestHandler.handleRequest(request.asOAuth2HttpRequest())) |
| 96 | + } |
| 97 | + |
| 98 | + |
| 99 | + private fun mockResponse(response: OAuth2HttpResponse): MockResponse = |
| 100 | + MockResponse() |
| 101 | + .setHeaders(response.headers) |
| 102 | + .setResponseCode(response.status) |
| 103 | + .apply { |
| 104 | + response.body?.let { this.setBody(it) } |
| 105 | + } |
| 106 | +} |
0 commit comments