diff --git a/rt/rs/security/sso/oidc/pom.xml b/rt/rs/security/sso/oidc/pom.xml
index 5fcda3ae890..a808fdf37fe 100644
--- a/rt/rs/security/sso/oidc/pom.xml
+++ b/rt/rs/security/sso/oidc/pom.xml
@@ -38,7 +38,7 @@
- -javaagent:${org.apache.openjpa:openjpa:jar}
+ -javaagent:${org.apache.openjpa:openjpa:jar} -javaagent:${org.mockito:mockito-core:jar}
@@ -64,6 +64,12 @@
junit
test
+
+ org.mockito
+ mockito-core
+ ${cxf.mockito.version}
+ test
+
org.hsqldb
hsqldb
diff --git a/rt/rs/security/sso/oidc/src/main/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationService.java b/rt/rs/security/sso/oidc/src/main/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationService.java
index 30d957bf6a2..192c3eaa716 100644
--- a/rt/rs/security/sso/oidc/src/main/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationService.java
+++ b/rt/rs/security/sso/oidc/src/main/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationService.java
@@ -57,11 +57,15 @@ public Response completeAuthentication(@Context OidcClientTokenContext oidcConte
URI redirectUri = null;
MultivaluedMap state = oidcContext.getState();
String location = state != null ? state.getFirst("state") : null;
- if (location == null && defaultLocation != null) {
+ if (location != null) {
+ URI requestedUri = URI.create(UrlUtils.urlDecode(location));
+ if (isSameOrigin(requestedUri)) {
+ redirectUri = requestedUri;
+ }
+ }
+ if (redirectUri == null && defaultLocation != null) {
String basePath = (String)mc.get("http.base.path");
redirectUri = UriBuilder.fromUri(basePath).path(defaultLocation).build();
- } else if (location != null) {
- redirectUri = URI.create(UrlUtils.urlDecode(location));
}
if (redirectUri != null) {
return Response.seeOther(redirectUri).build();
@@ -69,6 +73,26 @@ public Response completeAuthentication(@Context OidcClientTokenContext oidcConte
return Response.ok(oidcContext).build();
}
+ // The location is taken from the request state, so it can only be trusted as long as it
+ // stays within this application's own origin. An absolute value pointing at a different
+ // host (or a protocol-relative "//host" reference) would turn sign-in completion into an
+ // open redirect, so anything that is not same-origin is ignored here.
+ private boolean isSameOrigin(URI location) {
+ if (location.getScheme() == null && location.getAuthority() == null) {
+ // a path-only reference is resolved by the browser against the current request
+ return true;
+ }
+ String basePath = (String)mc.get("http.base.path");
+ if (basePath == null) {
+ return false;
+ }
+ URI base = URI.create(basePath);
+ return location.getScheme() != null
+ && location.getScheme().equalsIgnoreCase(base.getScheme())
+ && location.getAuthority() != null
+ && location.getAuthority().equalsIgnoreCase(base.getAuthority());
+ }
+
public void setDefaultLocation(String defaultLocation) {
this.defaultLocation = defaultLocation;
}
diff --git a/rt/rs/security/sso/oidc/src/test/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationServiceTest.java b/rt/rs/security/sso/oidc/src/test/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationServiceTest.java
new file mode 100644
index 00000000000..e64d5dd21a2
--- /dev/null
+++ b/rt/rs/security/sso/oidc/src/test/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationServiceTest.java
@@ -0,0 +1,103 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.cxf.rs.security.oidc.rp;
+
+import java.lang.reflect.Field;
+import java.net.URI;
+
+import jakarta.ws.rs.core.MultivaluedHashMap;
+import jakarta.ws.rs.core.MultivaluedMap;
+import jakarta.ws.rs.core.Response;
+
+import org.apache.cxf.jaxrs.ext.MessageContext;
+import org.apache.cxf.rs.security.oauth2.client.ClientTokenContextManager;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class OidcRpAuthenticationServiceTest {
+
+ private static final String BASE_PATH = "https://app.example.com:8080/services/";
+
+ @Test
+ public void testRejectsCrossOriginRedirectFromState() {
+ Response response = complete("https://evil.example.com/phish");
+ assertEquals(200, response.getStatus());
+ assertNull(response.getLocation());
+ }
+
+ @Test
+ public void testRejectsProtocolRelativeRedirectFromState() {
+ Response response = complete("//evil.example.com/phish");
+ assertEquals(200, response.getStatus());
+ assertNull(response.getLocation());
+ }
+
+ @Test
+ public void testRejectsUserinfoHostConfusionFromState() {
+ Response response = complete("https://app.example.com:8080@evil.example.com/phish");
+ assertEquals(200, response.getStatus());
+ assertNull(response.getLocation());
+ }
+
+ @Test
+ public void testAllowsSameOriginRedirectFromState() {
+ Response response = complete("https://app.example.com:8080/services/protected");
+ assertEquals(303, response.getStatus());
+ assertEquals(URI.create("https://app.example.com:8080/services/protected"), response.getLocation());
+ }
+
+ @Test
+ public void testAllowsRelativeRedirectFromState() {
+ Response response = complete("/services/protected");
+ assertEquals(303, response.getStatus());
+ assertEquals(URI.create("/services/protected"), response.getLocation());
+ }
+
+ private Response complete(String stateLocation) {
+ MessageContext messageContext = mock(MessageContext.class);
+ when(messageContext.get("http.base.path")).thenReturn(BASE_PATH);
+
+ MultivaluedMap state = new MultivaluedHashMap<>();
+ state.putSingle("state", stateLocation);
+
+ OidcClientTokenContext tokenContext = mock(OidcClientTokenContext.class);
+ when(tokenContext.getState()).thenReturn(state);
+
+ OidcRpAuthenticationService service = new OidcRpAuthenticationService();
+ service.setClientTokenContextManager(mock(ClientTokenContextManager.class));
+ setMessageContext(service, messageContext);
+
+ return service.completeAuthentication(tokenContext);
+ }
+
+ private static void setMessageContext(OidcRpAuthenticationService service, MessageContext messageContext) {
+ try {
+ Field field = OidcRpAuthenticationService.class.getDeclaredField("mc");
+ field.setAccessible(true);
+ field.set(service, messageContext);
+ } catch (ReflectiveOperationException ex) {
+ throw new IllegalStateException(ex);
+ }
+ }
+}