Skip to content

Commit 0977eb4

Browse files
committed
feat: support multiple issuer:audience combinations by introducing an option for the expectedClaims. WithExpectedClaims can be called with multiple jwt.Expected parameters to allow different Issuer:Audience combinations to validate tokens
feat: support multiple issuers in a provider using WithAdditionalIssuers option Every effort has been made to ensure backwards compatibility. Some error messages will be different due to the wrapping of errors when multiple jwt.Expected are set. When validating the jwt, if an error is encountered, instead of returning immediately, the current error is wrapped. This is good and bad. Good because all verification failure causes are captured in a single wrapped error; Bad because all verification failure causes are captured in a single monolithic wrapped error. Unwrapping the error can be tedious if many jwt.Expected are included. There is likely a better way but this suits my purposes. A few more test cases will likely be needed in order to achieve true confidence in this change
1 parent f5f0a00 commit 0977eb4

File tree

11 files changed

+606
-53
lines changed

11 files changed

+606
-53
lines changed

README.md

-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ import (
4343
"log"
4444
"net/http"
4545

46-
"github.com/auth0/go-jwt-middleware/v2"
4746
"github.com/auth0/go-jwt-middleware/v2/validator"
4847
jwtmiddleware "github.com/auth0/go-jwt-middleware/v2"
4948
)

examples/gin-example/main.go

+52-2
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,29 @@ import (
3939
// "username": "user123",
4040
// "shouldReject": true
4141
// }
42+
//
43+
// You can also try out the /multiple endpoint. This endpoint accepts tokens signed by multiple issuers. Try the
44+
// token below which has a different issuer:
45+
//
46+
// eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnby1qd3QtbWlkZGxld2FyZS1tdWx0aXBsZS1leGFtcGxlIiwiYXVkIjoiYXVkaWVuY2UtbXVsdGlwbGUtZXhhbXBsZSIsInN1YiI6IjEyMzQ1Njc4OTAiLCJuYW1lIjoiSm9obiBEb2UiLCJpYXQiOjE1MTYyMzkwMjIsInVzZXJuYW1lIjoidXNlcjEyMyJ9.9zV_bY1wAmQlMCPlXOppx1Y9_z_T_wNng9-yfQk4I0c
47+
//
48+
// which is signed with 'secret' and has the data:
49+
//
50+
// {
51+
// "iss": "go-jwt-middleware-multiple-example",
52+
// "aud": "audience-multiple-example",
53+
// "sub": "1234567890",
54+
// "name": "John Doe",
55+
// "iat": 1516239022,
56+
// "username": "user123"
57+
// }
58+
//
59+
// You can also try the previous tokens with the /multiple endpoint. The first token will be valid the second will fail because
60+
// the custom validator rejects it (shouldReject: true)
4261

4362
func main() {
4463
router := gin.Default()
64+
4565
router.GET("/", checkJWT(), func(ctx *gin.Context) {
4666
claims, ok := ctx.Request.Context().Value(jwtmiddleware.ContextKey{}).(*validator.ValidatedClaims)
4767
if !ok {
@@ -52,7 +72,37 @@ func main() {
5272
return
5373
}
5474

55-
customClaims, ok := claims.CustomClaims.(*CustomClaimsExample)
75+
localCustomClaims, ok := claims.CustomClaims.(*CustomClaimsExample)
76+
if !ok {
77+
ctx.AbortWithStatusJSON(
78+
http.StatusInternalServerError,
79+
map[string]string{"message": "Failed to cast custom JWT claims to specific type."},
80+
)
81+
return
82+
}
83+
84+
if len(localCustomClaims.Username) == 0 {
85+
ctx.AbortWithStatusJSON(
86+
http.StatusBadRequest,
87+
map[string]string{"message": "Username in JWT claims was empty."},
88+
)
89+
return
90+
}
91+
92+
ctx.JSON(http.StatusOK, claims)
93+
})
94+
95+
router.GET("/multiple", checkJWTMultiple(), func(ctx *gin.Context) {
96+
claims, ok := ctx.Request.Context().Value(jwtmiddleware.ContextKey{}).(*validator.ValidatedClaims)
97+
if !ok {
98+
ctx.AbortWithStatusJSON(
99+
http.StatusInternalServerError,
100+
map[string]string{"message": "Failed to get validated JWT claims."},
101+
)
102+
return
103+
}
104+
105+
localCustomClaims, ok := claims.CustomClaims.(*CustomClaimsExample)
56106
if !ok {
57107
ctx.AbortWithStatusJSON(
58108
http.StatusInternalServerError,
@@ -61,7 +111,7 @@ func main() {
61111
return
62112
}
63113

64-
if len(customClaims.Username) == 0 {
114+
if len(localCustomClaims.Username) == 0 {
65115
ctx.AbortWithStatusJSON(
66116
http.StatusBadRequest,
67117
map[string]string{"message": "Username in JWT claims was empty."},

examples/gin-example/middleware.go

+52-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"context"
5+
"gopkg.in/go-jose/go-jose.v2/jwt"
56
"log"
67
"net/http"
78
"time"
@@ -16,10 +17,12 @@ var (
1617
signingKey = []byte("secret")
1718

1819
// The issuer of our token.
19-
issuer = "go-jwt-middleware-example"
20+
issuer = "go-jwt-middleware-example"
21+
issuerTwo = "go-jwt-middleware-multiple-example"
2022

2123
// The audience of our token.
22-
audience = []string{"audience-example"}
24+
audience = []string{"audience-example"}
25+
audienceTwo = []string{"audience-multiple-example"}
2326

2427
// Our token must be signed using this data.
2528
keyFunc = func(ctx context.Context) (interface{}, error) {
@@ -76,3 +79,50 @@ func checkJWT() gin.HandlerFunc {
7679
}
7780
}
7881
}
82+
83+
func checkJWTMultiple() gin.HandlerFunc {
84+
// Set up the validator.
85+
jwtValidator, err := validator.NewValidator(
86+
keyFunc,
87+
validator.HS256,
88+
validator.WithCustomClaims(customClaims),
89+
validator.WithAllowedClockSkew(30*time.Second),
90+
validator.WithExpectedClaims(jwt.Expected{
91+
Issuer: issuer,
92+
Audience: audience,
93+
}, jwt.Expected{
94+
Issuer: issuerTwo,
95+
Audience: audienceTwo,
96+
}),
97+
)
98+
if err != nil {
99+
log.Fatalf("failed to set up the validator: %v", err)
100+
}
101+
102+
errorHandler := func(w http.ResponseWriter, r *http.Request, err error) {
103+
log.Printf("Encountered error while validating JWT: %v", err)
104+
}
105+
106+
middleware := jwtmiddleware.New(
107+
jwtValidator.ValidateToken,
108+
jwtmiddleware.WithErrorHandler(errorHandler),
109+
)
110+
111+
return func(ctx *gin.Context) {
112+
encounteredError := true
113+
var handler http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
114+
encounteredError = false
115+
ctx.Request = r
116+
ctx.Next()
117+
}
118+
119+
middleware.CheckJWT(handler).ServeHTTP(ctx.Writer, ctx.Request)
120+
121+
if encounteredError {
122+
ctx.AbortWithStatusJSON(
123+
http.StatusUnauthorized,
124+
map[string]string{"message": "JWT is invalid."},
125+
)
126+
}
127+
}
128+
}

examples/http-example/main.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@ import (
88
"net/http"
99
"time"
1010

11-
"github.com/auth0/go-jwt-middleware/v2"
12-
"github.com/auth0/go-jwt-middleware/v2/validator"
1311
jwtmiddleware "github.com/auth0/go-jwt-middleware/v2"
12+
"github.com/auth0/go-jwt-middleware/v2/validator"
1413
)
1514

1615
var (

examples/http-jwks-example/main.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@ import (
77
"net/url"
88
"time"
99

10-
"github.com/auth0/go-jwt-middleware/v2"
10+
jwtmiddleware "github.com/auth0/go-jwt-middleware/v2"
1111
"github.com/auth0/go-jwt-middleware/v2/jwks"
1212
"github.com/auth0/go-jwt-middleware/v2/validator"
13-
jwtmiddleware "github.com/auth0/go-jwt-middleware/v2"
1413
)
1514

1615
var handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

extractor.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func AuthHeaderTokenExtractor(r *http.Request) (string, error) {
2323

2424
authHeaderParts := strings.Fields(authHeader)
2525
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
26-
return "", errors.New("Authorization header format must be Bearer {token}")
26+
return "", errors.New("authorization header format must be Bearer {token}")
2727
}
2828

2929
return authHeaderParts[1], nil
@@ -34,7 +34,7 @@ func AuthHeaderTokenExtractor(r *http.Request) (string, error) {
3434
func CookieTokenExtractor(cookieName string) TokenExtractor {
3535
return func(r *http.Request) (string, error) {
3636
cookie, err := r.Cookie(cookieName)
37-
if err == http.ErrNoCookie {
37+
if errors.Is(err, http.ErrNoCookie) {
3838
return "", nil // No cookie, then no JWT, so no error.
3939
}
4040

extractor_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func Test_AuthHeaderTokenExtractor(t *testing.T) {
3838
"Authorization": []string{"i-am-a-token"},
3939
},
4040
},
41-
wantError: "Authorization header format must be Bearer {token}",
41+
wantError: "authorization header format must be Bearer {token}",
4242
},
4343
}
4444

jwks/provider.go

+54-7
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ import (
2121
// getting and caching JWKS which can help reduce request time and potential
2222
// rate limiting from your provider.
2323
type Provider struct {
24-
IssuerURL *url.URL // Required.
25-
CustomJWKSURI *url.URL // Optional.
26-
Client *http.Client
24+
IssuerURL *url.URL // Required.
25+
CustomJWKSURI *url.URL // Optional.
26+
AdditionalProviders []Provider // Optional
27+
Client *http.Client
2728
}
2829

2930
// ProviderOption is how options for the Provider are set up.
@@ -32,14 +33,24 @@ type ProviderOption func(*Provider)
3233
// NewProvider builds and returns a new *Provider.
3334
func NewProvider(issuerURL *url.URL, opts ...ProviderOption) *Provider {
3435
p := &Provider{
35-
IssuerURL: issuerURL,
36-
Client: &http.Client{},
36+
Client: &http.Client{},
37+
AdditionalProviders: make([]Provider, 0),
38+
}
39+
40+
if issuerURL != nil {
41+
p.IssuerURL = issuerURL
3742
}
3843

3944
for _, opt := range opts {
4045
opt(p)
4146
}
4247

48+
for _, provider := range p.AdditionalProviders {
49+
if provider.Client == nil {
50+
provider.Client = p.Client
51+
}
52+
}
53+
4354
return p
4455
}
4556

@@ -56,13 +67,47 @@ func WithCustomJWKSURI(jwksURI *url.URL) ProviderOption {
5667
func WithCustomClient(c *http.Client) ProviderOption {
5768
return func(p *Provider) {
5869
p.Client = c
70+
for _, provider := range p.AdditionalProviders {
71+
provider.Client = c
72+
}
73+
}
74+
}
75+
76+
// WithAdditionalProviders allows validation with mutliple IssuerURLs if desired. If multiple issuers are specified,
77+
// a jwt may be signed by any of them and be considered valid
78+
func WithAdditionalProviders(issuerURL *url.URL, customJWKSURI *url.URL) ProviderOption {
79+
return func(p *Provider) {
80+
p.AdditionalProviders = append(p.AdditionalProviders, Provider{
81+
IssuerURL: issuerURL,
82+
CustomJWKSURI: customJWKSURI,
83+
Client: p.Client,
84+
})
5985
}
6086
}
6187

6288
// KeyFunc adheres to the keyFunc signature that the Validator requires.
6389
// While it returns an interface to adhere to keyFunc, as long as the
6490
// error is nil the type will be *jose.JSONWebKeySet.
6591
func (p *Provider) KeyFunc(ctx context.Context) (interface{}, error) {
92+
rawJwks, err := p.keyFunc(ctx)
93+
94+
if len(p.AdditionalProviders) == 0 {
95+
return rawJwks, err
96+
} else {
97+
var jwks *jose.JSONWebKeySet
98+
jwks = rawJwks.(*jose.JSONWebKeySet)
99+
for _, provider := range p.AdditionalProviders {
100+
if rawJwks, err = provider.keyFunc(ctx); err != nil {
101+
continue
102+
} else {
103+
jwks.Keys = append(jwks.Keys, rawJwks.(*jose.JSONWebKeySet).Keys...)
104+
}
105+
}
106+
return jwks, err
107+
}
108+
}
109+
110+
func (p *Provider) keyFunc(ctx context.Context) (interface{}, error) {
66111
jwksURI := p.CustomJWKSURI
67112
if jwksURI == nil {
68113
wkEndpoints, err := oidc.GetWellKnownEndpointsFromIssuerURL(ctx, p.Client, *p.IssuerURL)
@@ -85,10 +130,12 @@ func (p *Provider) KeyFunc(ctx context.Context) (interface{}, error) {
85130
if err != nil {
86131
return nil, err
87132
}
88-
defer response.Body.Close()
133+
defer func() {
134+
_ = response.Body.Close()
135+
}()
89136

90137
var jwks jose.JSONWebKeySet
91-
if err := json.NewDecoder(response.Body).Decode(&jwks); err != nil {
138+
if err = json.NewDecoder(response.Body).Decode(&jwks); err != nil {
92139
return nil, fmt.Errorf("could not decode jwks: %w", err)
93140
}
94141

validator/option.go

+14
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package validator
22

33
import (
4+
"gopkg.in/go-jose/go-jose.v2/jwt"
45
"time"
56
)
67

@@ -26,3 +27,16 @@ func WithCustomClaims(f func() CustomClaims) Option {
2627
v.customClaims = f
2728
}
2829
}
30+
31+
// WithExpectedClaims allows fine-grained customization of the expected claims
32+
func WithExpectedClaims(expectedClaims ...jwt.Expected) Option {
33+
return func(v *Validator) {
34+
if len(expectedClaims) == 0 {
35+
return
36+
}
37+
if v.expectedClaims == nil {
38+
v.expectedClaims = make([]jwt.Expected, 0)
39+
}
40+
v.expectedClaims = append(v.expectedClaims, expectedClaims...)
41+
}
42+
}

0 commit comments

Comments
 (0)