Skip to content

optional and mandatory claims check #17030

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,29 @@ public final class JwtAudienceValidator implements OAuth2TokenValidator<Jwt> {
private final JwtClaimValidator<Collection<String>> validator;

/**
* Constructs a {@link JwtAudienceValidator} using the provided parameters
* Constructs a {@link JwtAudienceValidator} using the provided parameters with
* {@link JwtClaimNames#ISS "iss"} claim is REQUIRED
* @param audience - The audience that each {@link Jwt} should have.
*/
public JwtAudienceValidator(String audience) {
this(audience, true);
}

/**
* Constructs a {@link JwtIssuerValidator} using the provided parameters
* @param audience - The audience that each {@link Jwt} should have.
* @param required -{@code true} if the {@link JwtClaimNames#AUD "aud"} claim is
* REQUIRED in the {@link Jwt}, {@code false} otherwise
*/
public JwtAudienceValidator(String audience, boolean required) {
Assert.notNull(audience, "audience cannot be null");
this.validator = new JwtClaimValidator<>(JwtClaimNames.AUD,
(claimValue) -> (claimValue != null) && claimValue.contains(audience));
(claimValue) -> (claimValue != null) ? claimValue.contains(audience) : !required);
}

@Override
public OAuth2TokenValidatorResult validate(Jwt token) {
Assert.notNull(token, "token cannot be null");
Assert.notNull(token, "jwt cannot be null");
return this.validator.validate(token);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

package org.springframework.security.oauth2.jwt;

import java.util.function.Predicate;

import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
import org.springframework.util.Assert;
Expand All @@ -33,19 +31,29 @@ public final class JwtIssuerValidator implements OAuth2TokenValidator<Jwt> {
private final JwtClaimValidator<Object> validator;

/**
* Constructs a {@link JwtIssuerValidator} using the provided parameters
* Constructs a {@link JwtIssuerValidator} using the provided parameters with
* {@link JwtClaimNames#ISS "iss"} claim is REQUIRED
* @param issuer - The issuer that each {@link Jwt} should have.
*/
public JwtIssuerValidator(String issuer) {
Assert.notNull(issuer, "issuer cannot be null");
this(issuer, true);
}

Predicate<Object> testClaimValue = (claimValue) -> (claimValue != null) && issuer.equals(claimValue.toString());
this.validator = new JwtClaimValidator<>(JwtClaimNames.ISS, testClaimValue);
/**
* Constructs a {@link JwtIssuerValidator} using the provided parameters
* @param issuer - The issuer that each {@link Jwt} should have.
* @param required -{@code true} if the {@link JwtClaimNames#ISS "iss"} claim is
* REQUIRED in the {@link Jwt}, {@code false} otherwise
*/
public JwtIssuerValidator(String issuer, boolean required) {
Assert.notNull(issuer, "issuer cannot be null");
this.validator = new JwtClaimValidator<>(JwtClaimNames.ISS,
(claimValue) -> (claimValue != null) ? issuer.equals(claimValue.toString()) : !required);
}

@Override
public OAuth2TokenValidatorResult validate(Jwt token) {
Assert.notNull(token, "token cannot be null");
Assert.notNull(token, "jwt cannot be null");
return this.validator.validate(token);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ public final class JwtTimestampValidator implements OAuth2TokenValidator<Jwt> {

private static final Duration DEFAULT_MAX_CLOCK_SKEW = Duration.of(60, ChronoUnit.SECONDS);

private final boolean required;

private final Duration clockSkew;

private Clock clock = Clock.systemUTC();
Expand All @@ -60,25 +62,38 @@ public final class JwtTimestampValidator implements OAuth2TokenValidator<Jwt> {
* A basic instance with no custom verification and the default max clock skew
*/
public JwtTimestampValidator() {
this(DEFAULT_MAX_CLOCK_SKEW);
this(DEFAULT_MAX_CLOCK_SKEW, false);
}

public JwtTimestampValidator(boolean required) {
this(DEFAULT_MAX_CLOCK_SKEW, required);
}

public JwtTimestampValidator(Duration clockSkew) {
this(clockSkew, false);
}

public JwtTimestampValidator(Duration clockSkew, boolean required) {
Assert.notNull(clockSkew, "clockSkew cannot be null");
this.required = required;
this.clockSkew = clockSkew;
}

@Override
public OAuth2TokenValidatorResult validate(Jwt jwt) {
Assert.notNull(jwt, "jwt cannot be null");
Instant expiry = jwt.getExpiresAt();
Instant notBefore = jwt.getNotBefore();
if (this.required && !(expiry != null || notBefore != null)) {
OAuth2Error oAuth2Error = createOAuth2Error("exp and nbf are required");
return OAuth2TokenValidatorResult.failure(oAuth2Error);
}
if (expiry != null) {
if (Instant.now(this.clock).minus(this.clockSkew).isAfter(expiry)) {
OAuth2Error oAuth2Error = createOAuth2Error(String.format("Jwt expired at %s", jwt.getExpiresAt()));
return OAuth2TokenValidatorResult.failure(oAuth2Error);
}
}
Instant notBefore = jwt.getNotBefore();
if (notBefore != null) {
if (Instant.now(this.clock).plus(this.clockSkew).isBefore(notBefore)) {
OAuth2Error oAuth2Error = createOAuth2Error(String.format("Jwt used before %s", jwt.getNotBefore()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,73 @@
*/
class JwtAudienceValidatorTests {

private final JwtAudienceValidator validator = new JwtAudienceValidator("audience");
private final JwtAudienceValidator validatorDefault = new JwtAudienceValidator("audience");

private final JwtAudienceValidator validatorRequiredTrue = new JwtAudienceValidator("audience", true);

private final JwtAudienceValidator validatorRequiredFalse = new JwtAudienceValidator("audience", false);

@Test
void givenRequiredDefaultJwtWithMatchingAudienceThenShouldValidate() {
Jwt jwt = TestJwts.jwt().audience(List.of("audience")).build();
OAuth2TokenValidatorResult result = this.validatorDefault.validate(jwt);
assertThat(result).isEqualTo(OAuth2TokenValidatorResult.success());
}

@Test
void givenRequiredJwtWithMatchingAudienceThenShouldValidate() {
Jwt jwt = TestJwts.jwt().audience(List.of("audience")).build();
OAuth2TokenValidatorResult result = this.validatorRequiredTrue.validate(jwt);
assertThat(result).isEqualTo(OAuth2TokenValidatorResult.success());
}

@Test
void givenJwtWithMatchingAudienceThenShouldValidate() {
void givenNotRequiredJwtWithMatchingAudienceThenShouldValidate() {
Jwt jwt = TestJwts.jwt().audience(List.of("audience")).build();
OAuth2TokenValidatorResult result = this.validator.validate(jwt);
OAuth2TokenValidatorResult result = this.validatorRequiredFalse.validate(jwt);
assertThat(result).isEqualTo(OAuth2TokenValidatorResult.success());
}

@Test
void givenJwtWithoutMatchingAudienceThenShouldValidate() {
void givenRequiredDefaultJwtWithoutMatchingAudienceThenShouldValidate() {
Jwt jwt = TestJwts.jwt().audience(List.of("other")).build();
OAuth2TokenValidatorResult result = this.validator.validate(jwt);
OAuth2TokenValidatorResult result = this.validatorDefault.validate(jwt);
assertThat(result.hasErrors()).isTrue();
}

@Test
void givenJwtWithoutAudienceThenShouldValidate() {
void givenRequiredJwtWithoutMatchingAudienceThenShouldValidate() {
Jwt jwt = TestJwts.jwt().audience(List.of("other")).build();
OAuth2TokenValidatorResult result = this.validatorRequiredTrue.validate(jwt);
assertThat(result.hasErrors()).isTrue();
}

@Test
void givenNotRequiredJwtWithoutMatchingAudienceThenShouldValidate() {
Jwt jwt = TestJwts.jwt().audience(List.of("other")).build();
OAuth2TokenValidatorResult result = this.validatorRequiredFalse.validate(jwt);
assertThat(result.hasErrors()).isTrue();
}

@Test
void givenRequiredDefaultJwtWithoutAudienceThenShouldValidate() {
Jwt jwt = TestJwts.jwt().audience(null).build();
OAuth2TokenValidatorResult result = this.validator.validate(jwt);
OAuth2TokenValidatorResult result = this.validatorDefault.validate(jwt);
assertThat(result.hasErrors()).isTrue();
}

@Test
void givenRequiredJwtWithoutAudienceThenShouldValidate() {
Jwt jwt = TestJwts.jwt().audience(null).build();
OAuth2TokenValidatorResult result = this.validatorRequiredTrue.validate(jwt);
assertThat(result.hasErrors()).isTrue();
}

@Test
void givenNotRequiredJwtWithoutAudienceThenShouldValidate() {
Jwt jwt = TestJwts.jwt().audience(null).build();
OAuth2TokenValidatorResult result = this.validatorRequiredFalse.validate(jwt);
assertThat(result.hasErrors()).isFalse();
}

}
Loading