Skip to content

add the ability to store additional grant request parameters for the OIDC token exchange grant request #16393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRe
return null;
}
return (T) new OAuth2AuthorizedClient(registration, cachedAuthorizedClient.getPrincipalName(),
cachedAuthorizedClient.getAccessToken(), cachedAuthorizedClient.getRefreshToken());
cachedAuthorizedClient.getAccessToken(), cachedAuthorizedClient.getRefreshToken(),
cachedAuthorizedClient.getAttributes());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ public final class OAuth2AuthorizationContext {
*/
public static final String PASSWORD_ATTRIBUTE_NAME = OAuth2AuthorizationContext.class.getName().concat(".PASSWORD");

public static final String ADDITIONAL_GRANT_REQUEST_PARAMETERS_ATTRIBUTE_NAME = OAuth2AuthorizationContext.class.getName().concat(".ADDITIONAL_GRANT_REQUEST_PARAMETERS");

private ClientRegistration clientRegistration;

private OAuth2AuthorizedClient authorizedClient;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.security.oauth2.client;

import java.io.Serializable;
import java.util.Map;

import org.springframework.lang.Nullable;
import org.springframework.security.core.SpringSecurityCoreVersion;
Expand Down Expand Up @@ -53,6 +54,8 @@ public class OAuth2AuthorizedClient implements Serializable {

private final OAuth2RefreshToken refreshToken;

private final Map<String, Object> attributes;

/**
* Constructs an {@code OAuth2AuthorizedClient} using the provided parameters.
* @param clientRegistration the authorized client's registration
Expand All @@ -73,13 +76,28 @@ public OAuth2AuthorizedClient(ClientRegistration clientRegistration, String prin
*/
public OAuth2AuthorizedClient(ClientRegistration clientRegistration, String principalName,
OAuth2AccessToken accessToken, @Nullable OAuth2RefreshToken refreshToken) {
this(clientRegistration, principalName, accessToken, refreshToken, null);
}

/**
* Constructs an {@code OAuth2AuthorizedClient} using the provided parameters.
* @param clientRegistration the authorized client's registration
* @param principalName the name of the End-User {@code Principal} (Resource Owner)
* @param accessToken the access token credential granted
* @param refreshToken the refresh token credential granted
* @param attributes associated with the client
*/
public OAuth2AuthorizedClient(ClientRegistration clientRegistration, String principalName,
OAuth2AccessToken accessToken, @Nullable OAuth2RefreshToken refreshToken,
@Nullable Map<String, Object> attributes) {
Assert.notNull(clientRegistration, "clientRegistration cannot be null");
Assert.hasText(principalName, "principalName cannot be empty");
Assert.notNull(accessToken, "accessToken cannot be null");
this.clientRegistration = clientRegistration;
this.principalName = principalName;
this.accessToken = accessToken;
this.refreshToken = refreshToken;
this.attributes = attributes;
}

/**
Expand Down Expand Up @@ -115,4 +133,21 @@ public OAuth2AccessToken getAccessToken() {
return this.refreshToken;
}

/**
* Returns the {@link Map} attributes.
* @return the {@link Map}
* @since 6.5
*/
public @Nullable Map<String, Object> getAttributes() {
return this.attributes;
}

@Nullable
@SuppressWarnings("unchecked")
public <T> T getAttribute(String name) {
if (this.getAttributes() == null) {
return null;
}
return (T) this.getAttributes().get(name);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.*;
import java.util.function.Function;

import org.springframework.lang.Nullable;
Expand Down Expand Up @@ -73,9 +74,19 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) {
if (!AuthorizationGrantType.TOKEN_EXCHANGE.equals(clientRegistration.getAuthorizationGrantType())) {
return null;
}

Map<String, Object> contextAdditionalGrantRequestParameters =
context.getAttribute(OAuth2AuthorizationContext.ADDITIONAL_GRANT_REQUEST_PARAMETERS_ATTRIBUTE_NAME);

OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient();
if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) {
// If client is already authorized but access token is NOT expired than no
Map<String, Object> authorizedClientAdditionalGrantRequestParameters =
authorizedClient != null
? authorizedClient.getAttribute(OAuth2AuthorizationContext.ADDITIONAL_GRANT_REQUEST_PARAMETERS_ATTRIBUTE_NAME)
: null;
if (authorizedClient != null
&& !hasTokenExpired(authorizedClient.getAccessToken())
&& Objects.equals(authorizedClientAdditionalGrantRequestParameters, contextAdditionalGrantRequestParameters)) {
// If client is already authorized but access token is NOT expired and the attributes are equal, then no
// need for re-authorization
return null;
}
Expand All @@ -86,11 +97,14 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) {

OAuth2Token actorToken = this.actorTokenResolver.apply(context);
TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, subjectToken,
actorToken);
actorToken, contextAdditionalGrantRequestParameters);
OAuth2AccessTokenResponse tokenResponse = getTokenResponse(clientRegistration, grantRequest);

Map<String, Object> authorizedClientAttributes = new HashMap<>();
authorizedClientAttributes.put(OAuth2AuthorizationContext.ADDITIONAL_GRANT_REQUEST_PARAMETERS_ATTRIBUTE_NAME, contextAdditionalGrantRequestParameters);

return new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(),
tokenResponse.getAccessToken(), tokenResponse.getRefreshToken());
tokenResponse.getAccessToken(), tokenResponse.getRefreshToken(), authorizedClientAttributes);
}

private OAuth2Token resolveSubjectToken(OAuth2AuthorizationContext context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;

import java.util.Map;

/**
* A Token Exchange Grant request that holds the {@link OAuth2Token subject token} and
* optional {@link OAuth2Token actor token}.
Expand All @@ -53,6 +55,8 @@ public class TokenExchangeGrantRequest extends AbstractOAuth2AuthorizationGrantR

private final OAuth2Token actorToken;

private final Map<String, Object> additionalParameters;

/**
* Constructs a {@code TokenExchangeGrantRequest} using the provided parameters.
* @param clientRegistration the client registration
Expand All @@ -61,12 +65,24 @@ public class TokenExchangeGrantRequest extends AbstractOAuth2AuthorizationGrantR
*/
public TokenExchangeGrantRequest(ClientRegistration clientRegistration, OAuth2Token subjectToken,
OAuth2Token actorToken) {
this(clientRegistration,subjectToken,actorToken,null);
}

/**
* Constructs a {@code TokenExchangeGrantRequest} using the provided parameters.
* @param clientRegistration the client registration
* @param subjectToken the subject token
* @param actorToken the actor token
*/
public TokenExchangeGrantRequest(ClientRegistration clientRegistration, OAuth2Token subjectToken,
OAuth2Token actorToken, Map<String, Object> additionalParameters) {
super(AuthorizationGrantType.TOKEN_EXCHANGE, clientRegistration);
Assert.isTrue(AuthorizationGrantType.TOKEN_EXCHANGE.equals(clientRegistration.getAuthorizationGrantType()),
"clientRegistration.authorizationGrantType must be AuthorizationGrantType.TOKEN_EXCHANGE");
Assert.notNull(subjectToken, "subjectToken cannot be null");
this.subjectToken = subjectToken;
this.actorToken = actorToken;
this.additionalParameters = additionalParameters;
}

/**
Expand All @@ -85,6 +101,14 @@ public OAuth2Token getActorToken() {
return this.actorToken;
}

/**
* Returns the {@link Map additional parameters}.
* @return the {@link Map additional parameters}
*/
public Map<String, Object> getAdditionalParameters() {
return this.additionalParameters;
}

/**
* Populate default parameters for the Token Exchange Grant.
* @param grantRequest the authorization grant request
Expand Down