diff --git a/rt/rs/security/sso/oidc/pom.xml b/rt/rs/security/sso/oidc/pom.xml index 5fcda3ae890..a808fdf37fe 100644 --- a/rt/rs/security/sso/oidc/pom.xml +++ b/rt/rs/security/sso/oidc/pom.xml @@ -38,7 +38,7 @@ - -javaagent:${org.apache.openjpa:openjpa:jar} + -javaagent:${org.apache.openjpa:openjpa:jar} -javaagent:${org.mockito:mockito-core:jar} @@ -64,6 +64,12 @@ junit test + + org.mockito + mockito-core + ${cxf.mockito.version} + test + org.hsqldb hsqldb diff --git a/rt/rs/security/sso/oidc/src/main/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationService.java b/rt/rs/security/sso/oidc/src/main/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationService.java index 30d957bf6a2..192c3eaa716 100644 --- a/rt/rs/security/sso/oidc/src/main/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationService.java +++ b/rt/rs/security/sso/oidc/src/main/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationService.java @@ -57,11 +57,15 @@ public Response completeAuthentication(@Context OidcClientTokenContext oidcConte URI redirectUri = null; MultivaluedMap state = oidcContext.getState(); String location = state != null ? state.getFirst("state") : null; - if (location == null && defaultLocation != null) { + if (location != null) { + URI requestedUri = URI.create(UrlUtils.urlDecode(location)); + if (isSameOrigin(requestedUri)) { + redirectUri = requestedUri; + } + } + if (redirectUri == null && defaultLocation != null) { String basePath = (String)mc.get("http.base.path"); redirectUri = UriBuilder.fromUri(basePath).path(defaultLocation).build(); - } else if (location != null) { - redirectUri = URI.create(UrlUtils.urlDecode(location)); } if (redirectUri != null) { return Response.seeOther(redirectUri).build(); @@ -69,6 +73,26 @@ public Response completeAuthentication(@Context OidcClientTokenContext oidcConte return Response.ok(oidcContext).build(); } + // The location is taken from the request state, so it can only be trusted as long as it + // stays within this application's own origin. An absolute value pointing at a different + // host (or a protocol-relative "//host" reference) would turn sign-in completion into an + // open redirect, so anything that is not same-origin is ignored here. + private boolean isSameOrigin(URI location) { + if (location.getScheme() == null && location.getAuthority() == null) { + // a path-only reference is resolved by the browser against the current request + return true; + } + String basePath = (String)mc.get("http.base.path"); + if (basePath == null) { + return false; + } + URI base = URI.create(basePath); + return location.getScheme() != null + && location.getScheme().equalsIgnoreCase(base.getScheme()) + && location.getAuthority() != null + && location.getAuthority().equalsIgnoreCase(base.getAuthority()); + } + public void setDefaultLocation(String defaultLocation) { this.defaultLocation = defaultLocation; } diff --git a/rt/rs/security/sso/oidc/src/test/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationServiceTest.java b/rt/rs/security/sso/oidc/src/test/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationServiceTest.java new file mode 100644 index 00000000000..e64d5dd21a2 --- /dev/null +++ b/rt/rs/security/sso/oidc/src/test/java/org/apache/cxf/rs/security/oidc/rp/OidcRpAuthenticationServiceTest.java @@ -0,0 +1,103 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.cxf.rs.security.oidc.rp; + +import java.lang.reflect.Field; +import java.net.URI; + +import jakarta.ws.rs.core.MultivaluedHashMap; +import jakarta.ws.rs.core.MultivaluedMap; +import jakarta.ws.rs.core.Response; + +import org.apache.cxf.jaxrs.ext.MessageContext; +import org.apache.cxf.rs.security.oauth2.client.ClientTokenContextManager; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class OidcRpAuthenticationServiceTest { + + private static final String BASE_PATH = "https://app.example.com:8080/services/"; + + @Test + public void testRejectsCrossOriginRedirectFromState() { + Response response = complete("https://evil.example.com/phish"); + assertEquals(200, response.getStatus()); + assertNull(response.getLocation()); + } + + @Test + public void testRejectsProtocolRelativeRedirectFromState() { + Response response = complete("//evil.example.com/phish"); + assertEquals(200, response.getStatus()); + assertNull(response.getLocation()); + } + + @Test + public void testRejectsUserinfoHostConfusionFromState() { + Response response = complete("https://app.example.com:8080@evil.example.com/phish"); + assertEquals(200, response.getStatus()); + assertNull(response.getLocation()); + } + + @Test + public void testAllowsSameOriginRedirectFromState() { + Response response = complete("https://app.example.com:8080/services/protected"); + assertEquals(303, response.getStatus()); + assertEquals(URI.create("https://app.example.com:8080/services/protected"), response.getLocation()); + } + + @Test + public void testAllowsRelativeRedirectFromState() { + Response response = complete("/services/protected"); + assertEquals(303, response.getStatus()); + assertEquals(URI.create("/services/protected"), response.getLocation()); + } + + private Response complete(String stateLocation) { + MessageContext messageContext = mock(MessageContext.class); + when(messageContext.get("http.base.path")).thenReturn(BASE_PATH); + + MultivaluedMap state = new MultivaluedHashMap<>(); + state.putSingle("state", stateLocation); + + OidcClientTokenContext tokenContext = mock(OidcClientTokenContext.class); + when(tokenContext.getState()).thenReturn(state); + + OidcRpAuthenticationService service = new OidcRpAuthenticationService(); + service.setClientTokenContextManager(mock(ClientTokenContextManager.class)); + setMessageContext(service, messageContext); + + return service.completeAuthentication(tokenContext); + } + + private static void setMessageContext(OidcRpAuthenticationService service, MessageContext messageContext) { + try { + Field field = OidcRpAuthenticationService.class.getDeclaredField("mc"); + field.setAccessible(true); + field.set(service, messageContext); + } catch (ReflectiveOperationException ex) { + throw new IllegalStateException(ex); + } + } +}