Skip to content

Commit 1ac4a8f

Browse files
committed
Adds JWTConfig.ParseTokenFunc to JWT middleware to allow different libraries implementing JWT parsing.
1 parent fdacff0 commit 1ac4a8f

File tree

2 files changed

+228
-12
lines changed

2 files changed

+228
-12
lines changed

middleware/jwt.go

+36-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package middleware
22

33
import (
4+
"errors"
45
"fmt"
56
"net/http"
67
"reflect"
@@ -49,7 +50,8 @@ type (
4950
// Optional. Default value "user".
5051
ContextKey string
5152

52-
// Claims are extendable claims data defining token content.
53+
// Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation.
54+
// Not used if custom ParseTokenFunc is set.
5355
// Optional. Default value jwt.MapClaims
5456
Claims jwt.Claims
5557

@@ -74,13 +76,20 @@ type (
7476
// KeyFunc defines a user-defined function that supplies the public key for a token validation.
7577
// The function shall take care of verifying the signing algorithm and selecting the proper key.
7678
// A user-defined KeyFunc can be useful if tokens are issued by an external party.
79+
// Used by default ParseTokenFunc implementation.
7780
//
7881
// When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored.
7982
// This is one of the three options to provide a token validation key.
8083
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
8184
// Required if neither SigningKeys nor SigningKey is provided.
85+
// Not used if custom ParseTokenFunc is set.
8286
// Default to an internal implementation verifying the signing algorithm and selecting the proper key.
8387
KeyFunc jwt.Keyfunc
88+
89+
// ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token
90+
// parsing fails or parsed token is invalid.
91+
// Defaults to implementation using `github.com/dgrijalva/jwt-go` as JWT implementation library
92+
ParseTokenFunc func(auth string, c echo.Context) (interface{}, error)
8493
}
8594

8695
// JWTSuccessHandler defines a function which is executed for a valid token.
@@ -140,7 +149,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
140149
if config.Skipper == nil {
141150
config.Skipper = DefaultJWTConfig.Skipper
142151
}
143-
if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil {
152+
if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil && config.ParseTokenFunc == nil {
144153
panic("echo: jwt middleware requires signing key")
145154
}
146155
if config.SigningMethod == "" {
@@ -161,6 +170,9 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
161170
if config.KeyFunc == nil {
162171
config.KeyFunc = config.defaultKeyFunc
163172
}
173+
if config.ParseTokenFunc == nil {
174+
config.ParseTokenFunc = config.defaultParseToken
175+
}
164176

165177
// Initialize
166178
// Split sources
@@ -214,16 +226,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
214226
return err
215227
}
216228

217-
token := new(jwt.Token)
218-
// Issue #647, #656
219-
if _, ok := config.Claims.(jwt.MapClaims); ok {
220-
token, err = jwt.Parse(auth, config.KeyFunc)
221-
} else {
222-
t := reflect.ValueOf(config.Claims).Type().Elem()
223-
claims := reflect.New(t).Interface().(jwt.Claims)
224-
token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc)
225-
}
226-
if err == nil && token.Valid {
229+
token, err := config.ParseTokenFunc(auth, c)
230+
if err == nil {
227231
// Store user information from token into context.
228232
c.Set(config.ContextKey, token)
229233
if config.SuccessHandler != nil {
@@ -246,6 +250,26 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
246250
}
247251
}
248252

253+
func (config *JWTConfig) defaultParseToken(auth string, c echo.Context) (interface{}, error) {
254+
token := new(jwt.Token)
255+
var err error
256+
// Issue #647, #656
257+
if _, ok := config.Claims.(jwt.MapClaims); ok {
258+
token, err = jwt.Parse(auth, config.KeyFunc)
259+
} else {
260+
t := reflect.ValueOf(config.Claims).Type().Elem()
261+
claims := reflect.New(t).Interface().(jwt.Claims)
262+
token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc)
263+
}
264+
if err != nil {
265+
return nil, err
266+
}
267+
if !token.Valid {
268+
return nil, errors.New("invalid token")
269+
}
270+
return token, nil
271+
}
272+
249273
// defaultKeyFunc returns a signing key of the given token.
250274
func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) {
251275
// Check the signing method

middleware/jwt_test.go

+192
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package middleware
22

33
import (
44
"errors"
5+
"fmt"
56
"net/http"
67
"net/http/httptest"
78
"net/url"
@@ -404,3 +405,194 @@ func TestJWTwithKID(t *testing.T) {
404405
}
405406
}
406407
}
408+
409+
func TestJWTConfig_skipper(t *testing.T) {
410+
e := echo.New()
411+
412+
e.Use(JWTWithConfig(JWTConfig{
413+
Skipper: func(context echo.Context) bool {
414+
return true // skip everything
415+
},
416+
SigningKey: []byte("secret"),
417+
}))
418+
419+
isCalled := false
420+
e.GET("/", func(c echo.Context) error {
421+
isCalled = true
422+
return c.String(http.StatusTeapot, "test")
423+
})
424+
425+
req := httptest.NewRequest(http.MethodGet, "/", nil)
426+
res := httptest.NewRecorder()
427+
e.ServeHTTP(res, req)
428+
429+
assert.Equal(t, http.StatusTeapot, res.Code)
430+
assert.True(t, isCalled)
431+
}
432+
433+
func TestJWTConfig_BeforeFunc(t *testing.T) {
434+
e := echo.New()
435+
e.GET("/", func(c echo.Context) error {
436+
return c.String(http.StatusTeapot, "test")
437+
})
438+
439+
isCalled := false
440+
e.Use(JWTWithConfig(JWTConfig{
441+
BeforeFunc: func(context echo.Context) {
442+
isCalled = true
443+
},
444+
SigningKey: []byte("secret"),
445+
}))
446+
447+
req := httptest.NewRequest(http.MethodGet, "/", nil)
448+
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
449+
res := httptest.NewRecorder()
450+
e.ServeHTTP(res, req)
451+
452+
assert.Equal(t, http.StatusTeapot, res.Code)
453+
assert.True(t, isCalled)
454+
}
455+
456+
func TestJWTConfig_extractorErrorHandling(t *testing.T) {
457+
var testCases = []struct {
458+
name string
459+
given JWTConfig
460+
expectStatusCode int
461+
}{
462+
{
463+
name: "ok, ErrorHandler is executed",
464+
given: JWTConfig{
465+
SigningKey: []byte("secret"),
466+
ErrorHandler: func(err error) error {
467+
return echo.NewHTTPError(http.StatusTeapot, "custom_error")
468+
},
469+
},
470+
expectStatusCode: http.StatusTeapot,
471+
},
472+
{
473+
name: "ok, ErrorHandlerWithContext is executed",
474+
given: JWTConfig{
475+
SigningKey: []byte("secret"),
476+
ErrorHandlerWithContext: func(err error, context echo.Context) error {
477+
return echo.NewHTTPError(http.StatusTeapot, "custom_error")
478+
},
479+
},
480+
expectStatusCode: http.StatusTeapot,
481+
},
482+
}
483+
484+
for _, tc := range testCases {
485+
t.Run(tc.name, func(t *testing.T) {
486+
e := echo.New()
487+
e.GET("/", func(c echo.Context) error {
488+
return c.String(http.StatusNotImplemented, "should not end up here")
489+
})
490+
491+
e.Use(JWTWithConfig(tc.given))
492+
493+
req := httptest.NewRequest(http.MethodGet, "/", nil)
494+
res := httptest.NewRecorder()
495+
e.ServeHTTP(res, req)
496+
497+
assert.Equal(t, tc.expectStatusCode, res.Code)
498+
})
499+
}
500+
}
501+
502+
func TestJWTConfig_parseTokenErrorHandling(t *testing.T) {
503+
var testCases = []struct {
504+
name string
505+
given JWTConfig
506+
expectErr string
507+
}{
508+
{
509+
name: "ok, ErrorHandler is executed",
510+
given: JWTConfig{
511+
SigningKey: []byte("secret"),
512+
ErrorHandler: func(err error) error {
513+
return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error())
514+
},
515+
},
516+
expectErr: "{\"message\":\"ErrorHandler: parsing failed\"}\n",
517+
},
518+
{
519+
name: "ok, ErrorHandlerWithContext is executed",
520+
given: JWTConfig{
521+
SigningKey: []byte("secret"),
522+
ErrorHandlerWithContext: func(err error, context echo.Context) error {
523+
return echo.NewHTTPError(http.StatusTeapot, "ErrorHandlerWithContext: "+err.Error())
524+
},
525+
},
526+
expectErr: "{\"message\":\"ErrorHandlerWithContext: parsing failed\"}\n",
527+
},
528+
}
529+
530+
for _, tc := range testCases {
531+
t.Run(tc.name, func(t *testing.T) {
532+
e := echo.New()
533+
//e.Debug = true
534+
e.GET("/", func(c echo.Context) error {
535+
return c.String(http.StatusNotImplemented, "should not end up here")
536+
})
537+
538+
config := tc.given
539+
parseTokenCalled := false
540+
config.ParseTokenFunc = func(auth string, c echo.Context) (interface{}, error) {
541+
parseTokenCalled = true
542+
return nil, errors.New("parsing failed")
543+
}
544+
e.Use(JWTWithConfig(config))
545+
546+
req := httptest.NewRequest(http.MethodGet, "/", nil)
547+
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
548+
res := httptest.NewRecorder()
549+
550+
e.ServeHTTP(res, req)
551+
552+
assert.Equal(t, http.StatusTeapot, res.Code)
553+
assert.Equal(t, tc.expectErr, res.Body.String())
554+
assert.True(t, parseTokenCalled)
555+
})
556+
}
557+
}
558+
559+
func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) {
560+
e := echo.New()
561+
e.GET("/", func(c echo.Context) error {
562+
return c.String(http.StatusTeapot, "test")
563+
})
564+
565+
// example of minimal custom ParseTokenFunc implementation. Allows you to use different versions of `github.com/dgrijalva/jwt-go`
566+
// with current JWT middleware
567+
signingKey := []byte("secret")
568+
569+
config := JWTConfig{
570+
ParseTokenFunc: func(auth string, c echo.Context) (interface{}, error) {
571+
keyFunc := func(t *jwt.Token) (interface{}, error) {
572+
if t.Method.Alg() != "HS256" {
573+
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
574+
}
575+
return signingKey, nil
576+
}
577+
578+
// claims are of type `jwt.MapClaims` when token is created with `jwt.Parse`
579+
token, err := jwt.Parse(auth, keyFunc)
580+
if err != nil {
581+
return nil, err
582+
}
583+
if !token.Valid {
584+
return nil, errors.New("invalid token")
585+
}
586+
return token, nil
587+
},
588+
}
589+
590+
e.Use(JWTWithConfig(config))
591+
592+
req := httptest.NewRequest(http.MethodGet, "/", nil)
593+
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
594+
res := httptest.NewRecorder()
595+
e.ServeHTTP(res, req)
596+
597+
assert.Equal(t, http.StatusTeapot, res.Code)
598+
}

0 commit comments

Comments
 (0)