Skip to content

Commit dac8b61

Browse files
authored
feature: add support for request/response interceptors, support CORS with credentials (#199)
* feat: add request/response interceptor feature to route handling * feat: move CORS functionality into CorsInterceptor
1 parent cf05496 commit dac8b61

File tree

8 files changed

+205
-60
lines changed

8 files changed

+205
-60
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package no.nav.security.mock.oauth2.http
2+
3+
import mu.KotlinLogging
4+
5+
private val log = KotlinLogging.logger {}
6+
7+
class CorsInterceptor(
8+
private val allowedMethods: List<String> = listOf("POST", "GET", "OPTIONS")
9+
) : ResponseInterceptor {
10+
11+
companion object HeaderNames {
12+
const val ORIGIN = "origin"
13+
const val ACCESS_CONTROL_ALLOW_CREDENTIALS = "access-control-allow-credentials"
14+
const val ACCESS_CONTROL_REQUEST_HEADERS = "access-control-request-headers"
15+
const val ACCESS_CONTROL_ALLOW_HEADERS = "access-control-allow-headers"
16+
const val ACCESS_CONTROL_ALLOW_METHODS = "access-control-allow-methods"
17+
const val ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin"
18+
}
19+
20+
override fun intercept(request: OAuth2HttpRequest, response: OAuth2HttpResponse): OAuth2HttpResponse {
21+
val origin = request.headers[ORIGIN]
22+
log.debug("intercept response if request origin header is set: $origin")
23+
return if (origin != null) {
24+
val headers = response.headers.newBuilder()
25+
if (request.method == "OPTIONS") {
26+
val reqHeader = request.headers[ACCESS_CONTROL_REQUEST_HEADERS]
27+
if (reqHeader != null) {
28+
headers[ACCESS_CONTROL_ALLOW_HEADERS] = reqHeader
29+
}
30+
headers[ACCESS_CONTROL_ALLOW_METHODS] = allowedMethods.joinToString(", ")
31+
}
32+
headers[ACCESS_CONTROL_ALLOW_ORIGIN] = origin
33+
headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true"
34+
log.debug("adding CORS response headers")
35+
response.copy(headers = headers.build())
36+
} else {
37+
response
38+
}
39+
}
40+
}

src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequestHandler.kt

+2-12
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import com.nimbusds.oauth2.sdk.GrantType.REFRESH_TOKEN
1010
import com.nimbusds.oauth2.sdk.OAuth2Error
1111
import com.nimbusds.oauth2.sdk.ParseException
1212
import com.nimbusds.openid.connect.sdk.AuthenticationRequest
13-
import io.netty.handler.codec.http.HttpHeaderNames
1413
import mu.KotlinLogging
1514
import no.nav.security.mock.oauth2.OAuth2Config
1615
import no.nav.security.mock.oauth2.OAuth2Exception
@@ -38,7 +37,6 @@ import no.nav.security.mock.oauth2.login.LoginRequestHandler
3837
import no.nav.security.mock.oauth2.token.DefaultOAuth2TokenCallback
3938
import no.nav.security.mock.oauth2.token.OAuth2TokenCallback
4039
import no.nav.security.mock.oauth2.userinfo.userInfo
41-
import okhttp3.Headers
4240
import java.net.URLEncoder
4341
import java.nio.charset.Charset
4442
import java.util.concurrent.BlockingQueue
@@ -75,6 +73,7 @@ class OAuth2HttpRequestHandler(private val config: OAuth2Config) {
7573

7674
val authorizationServer: Route = routes {
7775
exceptionHandler(exceptionHandler)
76+
interceptors(CorsInterceptor())
7877
wellKnown()
7978
jwks()
8079
authorization()
@@ -139,16 +138,7 @@ class OAuth2HttpRequestHandler(private val config: OAuth2Config) {
139138
}
140139
}
141140

142-
private fun Route.Builder.preflight() = options {
143-
OAuth2HttpResponse(
144-
status = 200,
145-
headers = Headers.headersOf(
146-
HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString(), "*",
147-
HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS.toString(), "*",
148-
HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS.toString(), "*"
149-
)
150-
)
151-
}
141+
private fun Route.Builder.preflight() = options { OAuth2HttpResponse(status = 204) }
152142

153143
private fun tokenCallbackFromQueueOrDefault(issuerId: String): OAuth2TokenCallback =
154144
when (issuerId) {

src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpResponse.kt

-3
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ data class OAuth2TokenResponse(
6363
fun json(anyObject: Any): OAuth2HttpResponse = OAuth2HttpResponse(
6464
headers = Headers.headersOf(
6565
HttpHeaderNames.CONTENT_TYPE.toString(), "application/json;charset=UTF-8",
66-
HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString(), "*"
6766
),
6867
status = 200,
6968
body = when (anyObject) {
@@ -78,7 +77,6 @@ fun json(anyObject: Any): OAuth2HttpResponse = OAuth2HttpResponse(
7877
fun html(content: String): OAuth2HttpResponse = OAuth2HttpResponse(
7978
headers = Headers.headersOf(
8079
HttpHeaderNames.CONTENT_TYPE.toString(), "text/html;charset=UTF-8",
81-
HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString(), "*"
8280
),
8381
status = 200,
8482
body = content
@@ -116,7 +114,6 @@ fun oauth2Error(error: ErrorObject): OAuth2HttpResponse {
116114
return OAuth2HttpResponse(
117115
headers = Headers.headersOf(
118116
HttpHeaderNames.CONTENT_TYPE.toString(), "application/json;charset=UTF-8",
119-
HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString(), "*"
120117
),
121118
status = responseCode,
122119
body = objectMapper

src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRouter.kt

+69-28
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,35 @@ import no.nav.security.mock.oauth2.extensions.endsWith
66
private val log = KotlinLogging.logger { }
77

88
typealias RequestHandler = (OAuth2HttpRequest) -> OAuth2HttpResponse
9-
internal typealias ExceptionHandler = (OAuth2HttpRequest, Throwable) -> OAuth2HttpResponse
9+
10+
interface Interceptor
11+
12+
fun interface RequestInterceptor : Interceptor {
13+
fun intercept(request: OAuth2HttpRequest): OAuth2HttpRequest
14+
}
15+
16+
fun interface ResponseInterceptor : Interceptor {
17+
fun intercept(request: OAuth2HttpRequest, response: OAuth2HttpResponse): OAuth2HttpResponse
18+
}
1019

1120
interface Route : RequestHandler {
21+
1222
fun match(request: OAuth2HttpRequest): Boolean
1323

1424
class Builder {
1525
private val routes: MutableList<Route> = mutableListOf()
26+
private val interceptors: MutableList<Interceptor> = mutableListOf()
1627

1728
private var exceptionHandler: ExceptionHandler = { _, throwable ->
1829
throw throwable
1930
}
2031

32+
fun interceptors(vararg interceptor: Interceptor) = apply {
33+
interceptor.forEach {
34+
interceptors.add(it)
35+
}
36+
}
37+
2138
fun attach(vararg route: Route) = apply {
2239
route.forEach {
2340
routes.add(it)
@@ -56,40 +73,64 @@ interface Route : RequestHandler {
5673
routes.add(routeFromPathAndMethod(path, method, requestHandler))
5774
}
5875

59-
fun build(): Route = object : PathRoute {
60-
override fun matchPath(request: OAuth2HttpRequest): Boolean =
61-
routes.any { it.matchPath(request) }
62-
63-
override fun match(request: OAuth2HttpRequest): Boolean =
64-
routes.firstOrNull { it.match(request) } != null
65-
66-
override fun invoke(request: OAuth2HttpRequest): OAuth2HttpResponse =
67-
try {
68-
routes.firstOrNull { it.match(request) }?.invoke(request) ?: noMatch(request)
69-
} catch (t: Throwable) {
70-
exceptionHandler(request, t)
71-
}
72-
73-
override fun toString(): String = routes.toString()
74-
75-
private fun noMatch(request: OAuth2HttpRequest): OAuth2HttpResponse {
76-
log.debug("no route matching url=${request.url} with method=${request.method}")
77-
return if (matchPath(request)) {
78-
methodNotAllowed()
79-
} else {
80-
notFound("no routes found")
81-
}
82-
}
83-
84-
private fun Route.matchPath(request: OAuth2HttpRequest): Boolean = (this as? PathRoute)?.matchPath(request) ?: false
85-
}
76+
fun build(): Route = PathRouter(routes, interceptors, exceptionHandler)
8677
}
8778
}
8879

80+
internal typealias ExceptionHandler = (OAuth2HttpRequest, Throwable) -> OAuth2HttpResponse
81+
8982
internal interface PathRoute : Route {
9083
fun matchPath(request: OAuth2HttpRequest): Boolean
9184
}
9285

86+
internal class PathRouter(
87+
private val routes: MutableList<Route>,
88+
private val interceptors: MutableList<Interceptor>,
89+
private val exceptionHandler: ExceptionHandler,
90+
) : PathRoute {
91+
92+
override fun matchPath(request: OAuth2HttpRequest): Boolean = routes.any { it.matchPath(request) }
93+
override fun match(request: OAuth2HttpRequest): Boolean = routes.firstOrNull { it.match(request) } != null
94+
95+
override fun invoke(request: OAuth2HttpRequest): OAuth2HttpResponse = runCatching {
96+
routes.findHandler(request).invokeWith(request, interceptors)
97+
}.getOrElse {
98+
exceptionHandler(request, it)
99+
}
100+
101+
override fun toString(): String = routes.toString()
102+
103+
private fun MutableList<Route>.findHandler(request: OAuth2HttpRequest): RequestHandler =
104+
this.firstOrNull { it.match(request) } ?: { req -> noMatch(req) }
105+
106+
private fun RequestHandler.invokeWith(request: OAuth2HttpRequest, interceptors: MutableList<Interceptor>): OAuth2HttpResponse {
107+
return if (interceptors.size > 0) {
108+
109+
val filteredRequest = interceptors.filterIsInstance<RequestInterceptor>().fold(request) { next, interceptor ->
110+
interceptor.intercept(next)
111+
}
112+
val res = this.invoke(filteredRequest)
113+
val filteredResponse = interceptors.filterIsInstance<ResponseInterceptor>().fold(res.copy()) { next, interceptor ->
114+
interceptor.intercept(request, next)
115+
}
116+
filteredResponse
117+
} else {
118+
this.invoke(request)
119+
}
120+
}
121+
122+
private fun noMatch(request: OAuth2HttpRequest): OAuth2HttpResponse {
123+
log.debug("no route matching url=${request.url} with method=${request.method}")
124+
return if (matchPath(request)) {
125+
methodNotAllowed()
126+
} else {
127+
notFound("no routes found")
128+
}
129+
}
130+
131+
private fun Route.matchPath(request: OAuth2HttpRequest): Boolean = (this as? PathRoute)?.matchPath(request) ?: false
132+
}
133+
93134
fun routes(vararg route: Route): Route = routes {
94135
attach(*route)
95136
}

src/test/kotlin/examples/kotlin/ktor/client/OAuth2Client.kt

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import io.ktor.client.request.header
1414
import io.ktor.http.Headers
1515
import io.ktor.http.Parameters
1616
import io.ktor.http.headersOf
17+
import io.ktor.util.InternalAPI
1718
import java.nio.charset.StandardCharsets
1819
import java.security.KeyPair
1920
import java.security.interfaces.RSAPrivateKey
@@ -33,6 +34,7 @@ val httpClient = HttpClient(CIO) {
3334
}
3435
}
3536

37+
@OptIn(InternalAPI::class)
3638
suspend fun HttpClient.tokenRequest(url: String, auth: Auth, params: Map<String, String>) =
3739
submitForm<TokenResponse>(
3840
url = url,

src/test/kotlin/no/nav/security/mock/oauth2/e2e/CorsHeadersIntegrationTest.kt

+37-13
Original file line numberDiff line numberDiff line change
@@ -3,46 +3,68 @@ package no.nav.security.mock.oauth2.e2e
33
import com.nimbusds.oauth2.sdk.GrantType
44
import io.kotest.assertions.asClue
55
import io.kotest.matchers.shouldBe
6-
import io.netty.handler.codec.http.HttpHeaderNames
6+
import no.nav.security.mock.oauth2.http.CorsInterceptor.HeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS
7+
import no.nav.security.mock.oauth2.http.CorsInterceptor.HeaderNames.ACCESS_CONTROL_ALLOW_HEADERS
8+
import no.nav.security.mock.oauth2.http.CorsInterceptor.HeaderNames.ACCESS_CONTROL_ALLOW_METHODS
9+
import no.nav.security.mock.oauth2.http.CorsInterceptor.HeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN
10+
import no.nav.security.mock.oauth2.http.CorsInterceptor.HeaderNames.ACCESS_CONTROL_REQUEST_HEADERS
711
import no.nav.security.mock.oauth2.testutils.client
812
import no.nav.security.mock.oauth2.testutils.get
913
import no.nav.security.mock.oauth2.testutils.options
1014
import no.nav.security.mock.oauth2.testutils.tokenRequest
1115
import no.nav.security.mock.oauth2.token.DefaultOAuth2TokenCallback
1216
import no.nav.security.mock.oauth2.withMockOAuth2Server
17+
import okhttp3.Headers
1318
import org.junit.jupiter.api.Test
1419

1520
class CorsHeadersIntegrationTest {
1621
private val client = client()
1722

23+
private val origin = "https://theorigin"
24+
1825
@Test
19-
fun `preflight response should allow all origin, all methods and all headers`() {
26+
fun `preflight response should allow specific origin, methods and headers`() {
2027
withMockOAuth2Server {
21-
client.options(this.baseUrl()).asClue {
22-
it.code shouldBe 200
23-
it.headers[HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString()] shouldBe "*"
24-
it.headers[HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS.toString()] shouldBe "*"
25-
it.headers[HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS.toString()] shouldBe "*"
28+
client.options(
29+
this.baseUrl(),
30+
Headers.headersOf(
31+
"origin", origin,
32+
ACCESS_CONTROL_REQUEST_HEADERS, "X-MY-HEADER"
33+
)
34+
).asClue {
35+
it.code shouldBe 204
36+
it.headers[ACCESS_CONTROL_ALLOW_ORIGIN] shouldBe origin
37+
it.headers[ACCESS_CONTROL_ALLOW_METHODS] shouldBe "POST, GET, OPTIONS"
38+
it.headers[ACCESS_CONTROL_ALLOW_HEADERS] shouldBe "X-MY-HEADER"
39+
it.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] shouldBe "true"
2640
}
2741
}
2842
}
2943

3044
@Test
31-
fun `wellknown response should allow all origins`() {
45+
fun `wellknown response should allow origin`() {
3246
withMockOAuth2Server {
33-
client.get(this.wellKnownUrl("issuer")).asClue {
47+
client.get(
48+
this.wellKnownUrl("issuer"),
49+
Headers.headersOf("origin", origin)
50+
).asClue {
3451
it.code shouldBe 200
35-
it.headers[HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString()] shouldBe "*"
52+
it.headers[ACCESS_CONTROL_ALLOW_ORIGIN] shouldBe origin
53+
it.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] shouldBe "true"
3654
}
3755
}
3856
}
3957

4058
@Test
4159
fun `jwks response should allow all origins`() {
4260
withMockOAuth2Server {
43-
client.get(this.jwksUrl("issuer")).asClue {
61+
client.get(
62+
this.jwksUrl("issuer"),
63+
Headers.headersOf("origin", origin)
64+
).asClue {
4465
it.code shouldBe 200
45-
it.headers[HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString()] shouldBe "*"
66+
it.headers[ACCESS_CONTROL_ALLOW_ORIGIN] shouldBe origin
67+
it.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] shouldBe "true"
4668
}
4769
}
4870
}
@@ -56,6 +78,7 @@ class CorsHeadersIntegrationTest {
5678

5779
val response = client.tokenRequest(
5880
this.tokenEndpointUrl(issuerId),
81+
Headers.headersOf("origin", origin),
5982
mapOf(
6083
"grant_type" to GrantType.REFRESH_TOKEN.value,
6184
"refresh_token" to "canbewhatever",
@@ -65,7 +88,8 @@ class CorsHeadersIntegrationTest {
6588
)
6689

6790
response.code shouldBe 200
68-
response.headers[HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString()] shouldBe "*"
91+
response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] shouldBe origin
92+
response.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] shouldBe "true"
6993
}
7094
}
7195
}

0 commit comments

Comments
 (0)