Skip to content

Commit 4fffee2

Browse files
RashadAnsarialdas
authored andcommitted
Add custom jwt extractor to jwt config
1 parent 6b5e62b commit 4fffee2

File tree

2 files changed

+52
-21
lines changed

2 files changed

+52
-21
lines changed

middleware/jwt.go

+28-21
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,14 @@ type (
6868
// - "form:<name>"
6969
// Multiply sources example:
7070
// - "header: Authorization,cookie: myowncookie"
71-
7271
TokenLookup string
7372

73+
// TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context.
74+
// This is one of the two options to provide a token extractor.
75+
// The order of precedence is user-defined TokenLookupFuncs, and TokenLookup.
76+
// You can also provide both if you want.
77+
TokenLookupFuncs []TokenLookupFunc
78+
7479
// AuthScheme to be used in the Authorization header.
7580
// Optional. Default value "Bearer".
7681
AuthScheme string
@@ -103,7 +108,8 @@ type (
103108
// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context.
104109
JWTErrorHandlerWithContext func(error, echo.Context) error
105110

106-
jwtExtractor func(echo.Context) (string, error)
111+
// TokenLookupFunc defines a function for extracting JWT token from the given context.
112+
TokenLookupFunc func(echo.Context) (string, error)
107113
)
108114

109115
// Algorithms
@@ -120,13 +126,14 @@ var (
120126
var (
121127
// DefaultJWTConfig is the default JWT auth middleware config.
122128
DefaultJWTConfig = JWTConfig{
123-
Skipper: DefaultSkipper,
124-
SigningMethod: AlgorithmHS256,
125-
ContextKey: "user",
126-
TokenLookup: "header:" + echo.HeaderAuthorization,
127-
AuthScheme: "Bearer",
128-
Claims: jwt.MapClaims{},
129-
KeyFunc: nil,
129+
Skipper: DefaultSkipper,
130+
SigningMethod: AlgorithmHS256,
131+
ContextKey: "user",
132+
TokenLookup: "header:" + echo.HeaderAuthorization,
133+
TokenLookupFuncs: nil,
134+
AuthScheme: "Bearer",
135+
Claims: jwt.MapClaims{},
136+
KeyFunc: nil,
130137
}
131138
)
132139

@@ -163,7 +170,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
163170
if config.Claims == nil {
164171
config.Claims = DefaultJWTConfig.Claims
165172
}
166-
if config.TokenLookup == "" {
173+
if config.TokenLookup == "" && len(config.TokenLookupFuncs) == 0 {
167174
config.TokenLookup = DefaultJWTConfig.TokenLookup
168175
}
169176
if config.AuthScheme == "" {
@@ -179,7 +186,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
179186
// Initialize
180187
// Split sources
181188
sources := strings.Split(config.TokenLookup, ",")
182-
var extractors []jwtExtractor
189+
var extractors = config.TokenLookupFuncs
183190
for _, source := range sources {
184191
parts := strings.Split(source, ":")
185192

@@ -290,8 +297,8 @@ func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) {
290297
return config.SigningKey, nil
291298
}
292299

293-
// jwtFromHeader returns a `jwtExtractor` that extracts token from the request header.
294-
func jwtFromHeader(header string, authScheme string) jwtExtractor {
300+
// jwtFromHeader returns a `TokenLookupFunc` that extracts token from the request header.
301+
func jwtFromHeader(header string, authScheme string) TokenLookupFunc {
295302
return func(c echo.Context) (string, error) {
296303
auth := c.Request().Header.Get(header)
297304
l := len(authScheme)
@@ -302,8 +309,8 @@ func jwtFromHeader(header string, authScheme string) jwtExtractor {
302309
}
303310
}
304311

305-
// jwtFromQuery returns a `jwtExtractor` that extracts token from the query string.
306-
func jwtFromQuery(param string) jwtExtractor {
312+
// jwtFromQuery returns a `TokenLookupFunc` that extracts token from the query string.
313+
func jwtFromQuery(param string) TokenLookupFunc {
307314
return func(c echo.Context) (string, error) {
308315
token := c.QueryParam(param)
309316
if token == "" {
@@ -313,8 +320,8 @@ func jwtFromQuery(param string) jwtExtractor {
313320
}
314321
}
315322

316-
// jwtFromParam returns a `jwtExtractor` that extracts token from the url param string.
317-
func jwtFromParam(param string) jwtExtractor {
323+
// jwtFromParam returns a `TokenLookupFunc` that extracts token from the url param string.
324+
func jwtFromParam(param string) TokenLookupFunc {
318325
return func(c echo.Context) (string, error) {
319326
token := c.Param(param)
320327
if token == "" {
@@ -324,8 +331,8 @@ func jwtFromParam(param string) jwtExtractor {
324331
}
325332
}
326333

327-
// jwtFromCookie returns a `jwtExtractor` that extracts token from the named cookie.
328-
func jwtFromCookie(name string) jwtExtractor {
334+
// jwtFromCookie returns a `TokenLookupFunc` that extracts token from the named cookie.
335+
func jwtFromCookie(name string) TokenLookupFunc {
329336
return func(c echo.Context) (string, error) {
330337
cookie, err := c.Cookie(name)
331338
if err != nil {
@@ -335,8 +342,8 @@ func jwtFromCookie(name string) jwtExtractor {
335342
}
336343
}
337344

338-
// jwtFromForm returns a `jwtExtractor` that extracts token from the form field.
339-
func jwtFromForm(name string) jwtExtractor {
345+
// jwtFromForm returns a `TokenLookupFunc` that extracts token from the form field.
346+
func jwtFromForm(name string) TokenLookupFunc {
340347
return func(c echo.Context) (string, error) {
341348
field := c.FormValue(name)
342349
if field == "" {

middleware/jwt_test.go

+24
Original file line numberDiff line numberDiff line change
@@ -603,3 +603,27 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) {
603603

604604
assert.Equal(t, http.StatusTeapot, res.Code)
605605
}
606+
607+
func TestJWTConfig_TokenLookupFuncs(t *testing.T) {
608+
e := echo.New()
609+
610+
e.GET("/", func(c echo.Context) error {
611+
return c.String(http.StatusOK, "test")
612+
})
613+
614+
e.Use(JWTWithConfig(JWTConfig{
615+
TokenLookupFuncs: []TokenLookupFunc{
616+
func(c echo.Context) (string, error) {
617+
return c.Request().Header.Get("X-API-Key"), nil
618+
},
619+
},
620+
SigningKey: []byte("secret"),
621+
}))
622+
623+
req := httptest.NewRequest(http.MethodGet, "/", nil)
624+
req.Header.Set("X-API-Key", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
625+
res := httptest.NewRecorder()
626+
e.ServeHTTP(res, req)
627+
628+
assert.Equal(t, http.StatusOK, res.Code)
629+
}

0 commit comments

Comments
 (0)