Skip to content

Commit 2b4c5a4

Browse files
committed
expose source for KeyAuth/JWT key/token validation/parsing function to allow custom logic depending from where key/token value was extracted
1 parent 0d85116 commit 2b4c5a4

10 files changed

+101
-50
lines changed

middleware/basic_auth.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
7474

7575
b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:])
7676
if errDecode != nil {
77-
lastError = fmt.Errorf("invalid basic auth value: %w", errDecode)
77+
lastError = echo.ErrUnauthorized.WithInternal(fmt.Errorf("invalid basic auth value: %w", errDecode))
7878
continue
7979
}
8080
idx := bytes.IndexByte(b, ':')

middleware/basic_auth_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func TestBasicAuth(t *testing.T) {
5656
name: "nok, not base64 Authorization header",
5757
givenConfig: defaultConfig,
5858
whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"},
59-
expectErr: "invalid basic auth value: illegal base64 data at input byte 3",
59+
expectErr: "code=401, message=Unauthorized, internal=invalid basic auth value: illegal base64 data at input byte 3",
6060
},
6161
{
6262
name: "nok, missing Authorization header",

middleware/csrf.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
144144
var lastTokenErr error
145145
outer:
146146
for _, extractor := range extractors {
147-
clientTokens, err := extractor(c)
147+
clientTokens, _, err := extractor(c)
148148
if err != nil {
149149
lastExtractorErr = err
150150
continue

middleware/extractor.go

+37-19
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,24 @@ const (
1313
extractorLimit = 20
1414
)
1515

16+
// ExtractorSource is type to indicate source for extracted value
17+
type ExtractorSource string
18+
19+
const (
20+
// ExtractorSourceHeader means value was extracted from request header
21+
ExtractorSourceHeader ExtractorSource = "header"
22+
// ExtractorSourceQuery means value was extracted from request query parameters
23+
ExtractorSourceQuery ExtractorSource = "query"
24+
// ExtractorSourcePathParam means value was extracted from route path parameters
25+
ExtractorSourcePathParam ExtractorSource = "param"
26+
// ExtractorSourceCookie means value was extracted from request cookies
27+
ExtractorSourceCookie ExtractorSource = "cookie"
28+
// ExtractorSourceForm means value was extracted from request form values
29+
ExtractorSourceForm ExtractorSource = "form"
30+
// ExtractorSourceCustom means value was extracted by custom extractor
31+
ExtractorSourceCustom ExtractorSource = "custom"
32+
)
33+
1634
// ValueExtractorError is error type when middleware extractor is unable to extract value from lookups
1735
type ValueExtractorError struct {
1836
message string
@@ -31,7 +49,7 @@ var errCookieExtractorValueMissing = &ValueExtractorError{message: "missing valu
3149
var errFormExtractorValueMissing = &ValueExtractorError{message: "missing value in the form"}
3250

3351
// ValuesExtractor defines a function for extracting values (keys/tokens) from the given context.
34-
type ValuesExtractor func(c echo.Context) ([]string, error)
52+
type ValuesExtractor func(c echo.Context) ([]string, ExtractorSource, error)
3553

3654
func createExtractors(lookups string) ([]ValuesExtractor, error) {
3755
if lookups == "" {
@@ -75,10 +93,10 @@ func valuesFromHeader(header string, valuePrefix string) ValuesExtractor {
7593
prefixLen := len(valuePrefix)
7694
// standard library parses http.Request header keys in canonical form but we may provide something else so fix this
7795
header = textproto.CanonicalMIMEHeaderKey(header)
78-
return func(c echo.Context) ([]string, error) {
96+
return func(c echo.Context) ([]string, ExtractorSource, error) {
7997
values := c.Request().Header.Values(header)
8098
if len(values) == 0 {
81-
return nil, errHeaderExtractorValueMissing
99+
return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing
82100
}
83101

84102
result := make([]string, 0)
@@ -100,30 +118,30 @@ func valuesFromHeader(header string, valuePrefix string) ValuesExtractor {
100118

101119
if len(result) == 0 {
102120
if prefixLen > 0 {
103-
return nil, errHeaderExtractorValueInvalid
121+
return nil, ExtractorSourceHeader, errHeaderExtractorValueInvalid
104122
}
105-
return nil, errHeaderExtractorValueMissing
123+
return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing
106124
}
107-
return result, nil
125+
return result, ExtractorSourceHeader, nil
108126
}
109127
}
110128

111129
// valuesFromQuery returns a function that extracts values from the query string.
112130
func valuesFromQuery(param string) ValuesExtractor {
113-
return func(c echo.Context) ([]string, error) {
131+
return func(c echo.Context) ([]string, ExtractorSource, error) {
114132
result := c.QueryParams()[param]
115133
if len(result) == 0 {
116-
return nil, errQueryExtractorValueMissing
134+
return nil, ExtractorSourceQuery, errQueryExtractorValueMissing
117135
} else if len(result) > extractorLimit-1 {
118136
result = result[:extractorLimit]
119137
}
120-
return result, nil
138+
return result, ExtractorSourceQuery, nil
121139
}
122140
}
123141

124142
// valuesFromParam returns a function that extracts values from the url param string.
125143
func valuesFromParam(param string) ValuesExtractor {
126-
return func(c echo.Context) ([]string, error) {
144+
return func(c echo.Context) ([]string, ExtractorSource, error) {
127145
result := make([]string, 0)
128146
for i, p := range c.PathParams() {
129147
if param == p.Name {
@@ -134,18 +152,18 @@ func valuesFromParam(param string) ValuesExtractor {
134152
}
135153
}
136154
if len(result) == 0 {
137-
return nil, errParamExtractorValueMissing
155+
return nil, ExtractorSourcePathParam, errParamExtractorValueMissing
138156
}
139-
return result, nil
157+
return result, ExtractorSourcePathParam, nil
140158
}
141159
}
142160

143161
// valuesFromCookie returns a function that extracts values from the named cookie.
144162
func valuesFromCookie(name string) ValuesExtractor {
145-
return func(c echo.Context) ([]string, error) {
163+
return func(c echo.Context) ([]string, ExtractorSource, error) {
146164
cookies := c.Cookies()
147165
if len(cookies) == 0 {
148-
return nil, errCookieExtractorValueMissing
166+
return nil, ExtractorSourceCookie, errCookieExtractorValueMissing
149167
}
150168

151169
result := make([]string, 0)
@@ -158,26 +176,26 @@ func valuesFromCookie(name string) ValuesExtractor {
158176
}
159177
}
160178
if len(result) == 0 {
161-
return nil, errCookieExtractorValueMissing
179+
return nil, ExtractorSourceCookie, errCookieExtractorValueMissing
162180
}
163-
return result, nil
181+
return result, ExtractorSourceCookie, nil
164182
}
165183
}
166184

167185
// valuesFromForm returns a function that extracts values from the form field.
168186
func valuesFromForm(name string) ValuesExtractor {
169-
return func(c echo.Context) ([]string, error) {
187+
return func(c echo.Context) ([]string, ExtractorSource, error) {
170188
if c.Request().Form == nil {
171189
_ = c.Request().ParseMultipartForm(32 << 20) // same what `c.Request().FormValue(name)` does
172190
}
173191
values := c.Request().Form[name]
174192
if len(values) == 0 {
175-
return nil, errFormExtractorValueMissing
193+
return nil, ExtractorSourceForm, errFormExtractorValueMissing
176194
}
177195
if len(values) > extractorLimit-1 {
178196
values = values[:extractorLimit]
179197
}
180198
result := append([]string{}, values...)
181-
return result, nil
199+
return result, ExtractorSourceForm, nil
182200
}
183201
}

middleware/extractor_test.go

+18-6
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ func TestCreateExtractors(t *testing.T) {
2020
givenPathParams echo.PathParams
2121
whenLoopups string
2222
expectValues []string
23+
expectSource ExtractorSource
2324
expectCreateError string
2425
expectError string
2526
}{
@@ -32,6 +33,7 @@ func TestCreateExtractors(t *testing.T) {
3233
},
3334
whenLoopups: "header:Authorization:Bearer ",
3435
expectValues: []string{"token"},
36+
expectSource: ExtractorSourceHeader,
3537
},
3638
{
3739
name: "ok, form",
@@ -45,6 +47,7 @@ func TestCreateExtractors(t *testing.T) {
4547
},
4648
whenLoopups: "form:name",
4749
expectValues: []string{"Jon Snow"},
50+
expectSource: ExtractorSourceForm,
4851
},
4952
{
5053
name: "ok, cookie",
@@ -55,6 +58,7 @@ func TestCreateExtractors(t *testing.T) {
5558
},
5659
whenLoopups: "cookie:_csrf",
5760
expectValues: []string{"token"},
61+
expectSource: ExtractorSourceCookie,
5862
},
5963
{
6064
name: "ok, param",
@@ -63,6 +67,7 @@ func TestCreateExtractors(t *testing.T) {
6367
},
6468
whenLoopups: "param:id",
6569
expectValues: []string{"123"},
70+
expectSource: ExtractorSourcePathParam,
6671
},
6772
{
6873
name: "ok, query",
@@ -72,6 +77,7 @@ func TestCreateExtractors(t *testing.T) {
7277
},
7378
whenLoopups: "query:id",
7479
expectValues: []string{"999"},
80+
expectSource: ExtractorSourceQuery,
7581
},
7682
{
7783
name: "nok, invalid lookup",
@@ -102,8 +108,9 @@ func TestCreateExtractors(t *testing.T) {
102108
assert.NoError(t, err)
103109

104110
for _, e := range extractors {
105-
values, eErr := e(c)
111+
values, source, eErr := e(c)
106112
assert.Equal(t, tc.expectValues, values)
113+
assert.Equal(t, tc.expectSource, source)
107114
if tc.expectError != "" {
108115
assert.EqualError(t, eErr, tc.expectError)
109116
return
@@ -228,8 +235,9 @@ func TestValuesFromHeader(t *testing.T) {
228235

229236
extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix)
230237

231-
values, err := extractor(c)
238+
values, source, err := extractor(c)
232239
assert.Equal(t, tc.expectValues, values)
240+
assert.Equal(t, ExtractorSourceHeader, source)
233241
if tc.expectError != "" {
234242
assert.EqualError(t, err, tc.expectError)
235243
} else {
@@ -289,8 +297,9 @@ func TestValuesFromQuery(t *testing.T) {
289297

290298
extractor := valuesFromQuery(tc.whenName)
291299

292-
values, err := extractor(c)
300+
values, source, err := extractor(c)
293301
assert.Equal(t, tc.expectValues, values)
302+
assert.Equal(t, ExtractorSourceQuery, source)
294303
if tc.expectError != "" {
295304
assert.EqualError(t, err, tc.expectError)
296305
} else {
@@ -368,8 +377,9 @@ func TestValuesFromParam(t *testing.T) {
368377

369378
extractor := valuesFromParam(tc.whenName)
370379

371-
values, err := extractor(c)
380+
values, source, err := extractor(c)
372381
assert.Equal(t, tc.expectValues, values)
382+
assert.Equal(t, ExtractorSourcePathParam, source)
373383
if tc.expectError != "" {
374384
assert.EqualError(t, err, tc.expectError)
375385
} else {
@@ -448,8 +458,9 @@ func TestValuesFromCookie(t *testing.T) {
448458

449459
extractor := valuesFromCookie(tc.whenName)
450460

451-
values, err := extractor(c)
461+
values, source, err := extractor(c)
452462
assert.Equal(t, tc.expectValues, values)
463+
assert.Equal(t, ExtractorSourceCookie, source)
453464
if tc.expectError != "" {
454465
assert.EqualError(t, err, tc.expectError)
455466
} else {
@@ -578,8 +589,9 @@ func TestValuesFromForm(t *testing.T) {
578589

579590
extractor := valuesFromForm(tc.whenName)
580591

581-
values, err := extractor(c)
592+
values, source, err := extractor(c)
582593
assert.Equal(t, tc.expectValues, values)
594+
assert.Equal(t, ExtractorSourceForm, source)
583595
if tc.expectError != "" {
584596
assert.EqualError(t, err, tc.expectError)
585597
} else {

middleware/jwt.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ type JWTConfig struct {
6464
// ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token
6565
// parsing fails or parsed token is invalid.
6666
// Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library
67-
ParseTokenFunc func(c echo.Context, auth string) (interface{}, error)
67+
ParseTokenFunc func(c echo.Context, auth string, source ExtractorSource) (interface{}, error)
6868
}
6969

7070
// JWTSuccessHandler defines a function which is executed for a valid token.
@@ -101,7 +101,7 @@ var DefaultJWTConfig = JWTConfig{
101101
// For missing token, it returns "400 - Bad Request" error.
102102
//
103103
// See: https://jwt.io/introduction
104-
func JWT(parseTokenFunc func(c echo.Context, auth string) (interface{}, error)) echo.MiddlewareFunc {
104+
func JWT(parseTokenFunc func(c echo.Context, auth string, source ExtractorSource) (interface{}, error)) echo.MiddlewareFunc {
105105
c := DefaultJWTConfig
106106
c.ParseTokenFunc = parseTokenFunc
107107
return JWTWithConfig(c)
@@ -152,13 +152,13 @@ func (config JWTConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
152152
var lastExtractorErr error
153153
var lastTokenErr error
154154
for _, extractor := range extractors {
155-
auths, extrErr := extractor(c)
155+
auths, source, extrErr := extractor(c)
156156
if extrErr != nil {
157157
lastExtractorErr = extrErr
158158
continue
159159
}
160160
for _, auth := range auths {
161-
token, err := config.ParseTokenFunc(c, auth)
161+
token, err := config.ParseTokenFunc(c, auth, source)
162162
if err != nil {
163163
lastTokenErr = err
164164
continue

middleware/jwt_external_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import (
2121
// This is one of the options to provide a token validation key.
2222
// The order of precedence is a user-defined SigningKeys and SigningKey.
2323
// Required if signingKey is not provided
24-
func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]interface{}) func(c echo.Context, auth string) (interface{}, error) {
24+
func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]interface{}) func(c echo.Context, auth string, source middleware.ExtractorSource) (interface{}, error) {
2525
// keyFunc defines a user-defined function that supplies the public key for a token validation.
2626
// The function shall take care of verifying the signing algorithm and selecting the proper key.
2727
// A user-defined KeyFunc can be useful if tokens are issued by an external party.
@@ -41,7 +41,7 @@ func CreateJWTGoParseTokenFunc(signingKey interface{}, signingKeys map[string]in
4141
return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
4242
}
4343

44-
return func(c echo.Context, auth string) (interface{}, error) {
44+
return func(c echo.Context, auth string, source middleware.ExtractorSource) (interface{}, error) {
4545
token, err := jwt.ParseWithClaims(auth, jwt.MapClaims{}, keyFunc) // you could add your default claims here
4646
if err != nil {
4747
return nil, err

0 commit comments

Comments
 (0)