Skip to content

Commit 5f9f144

Browse files
committed
middleware: basic auth middleware can extract and check multiple auth headers
1 parent c7d6d43 commit 5f9f144

File tree

4 files changed

+178
-80
lines changed

4 files changed

+178
-80
lines changed

middleware/basic_auth.go

+26-18
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package middleware
22

33
import (
4+
"bytes"
45
"encoding/base64"
56
"net/http"
67
"strconv"
@@ -15,7 +16,8 @@ type (
1516
// Skipper defines a function to skip middleware.
1617
Skipper Skipper
1718

18-
// Validator is a function to validate BasicAuth credentials.
19+
// Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic
20+
// auth headers this function would be called once for each header until first valid result is returned
1921
// Required.
2022
Validator BasicAuthValidator
2123

@@ -71,30 +73,36 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
7173
return next(c)
7274
}
7375

74-
auth := c.Request().Header.Get(echo.HeaderAuthorization)
76+
var lastError error
7577
l := len(basic)
78+
for i, auth := range c.Request().Header[echo.HeaderAuthorization] {
79+
if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) {
80+
continue
81+
}
7682

77-
if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) {
7883
// Invalid base64 shouldn't be treated as error
7984
// instead should be treated as invalid client input
80-
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
81-
if err != nil {
82-
return echo.NewHTTPError(http.StatusBadRequest).SetInternal(err)
85+
b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:])
86+
if errDecode != nil {
87+
lastError = echo.NewHTTPError(http.StatusBadRequest).WithInternal(errDecode)
88+
continue
8389
}
84-
85-
cred := string(b)
86-
for i := 0; i < len(cred); i++ {
87-
if cred[i] == ':' {
88-
// Verify credentials
89-
valid, err := config.Validator(cred[:i], cred[i+1:], c)
90-
if err != nil {
91-
return err
92-
} else if valid {
93-
return next(c)
94-
}
95-
break
90+
idx := bytes.IndexByte(b, ':')
91+
if idx >= 0 {
92+
valid, errValidate := config.Validator(string(b[:idx]), string(b[idx+1:]), c)
93+
if errValidate != nil {
94+
lastError = errValidate
95+
} else if valid {
96+
return next(c)
9697
}
9798
}
99+
if i >= headerCountLimit { // guard against attacker maliciously sending huge amount of invalid headers
100+
break
101+
}
102+
}
103+
104+
if lastError != nil {
105+
return lastError
98106
}
99107

100108
realm := defaultRealm

middleware/basic_auth_test.go

+136-46
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package middleware
22

33
import (
44
"encoding/base64"
5+
"errors"
56
"net/http"
67
"net/http/httptest"
78
"strings"
@@ -11,11 +12,139 @@ import (
1112
"github.com/stretchr/testify/assert"
1213
)
1314

15+
func TestBasicAuthWithConfig(t *testing.T) {
16+
validatorFunc := func(u, p string, c echo.Context) (bool, error) {
17+
if u == "joe" && p == "secret" {
18+
return true, nil
19+
}
20+
if u == "error" {
21+
return false, errors.New(p)
22+
}
23+
return false, nil
24+
}
25+
defaultConfig := BasicAuthConfig{Validator: validatorFunc}
26+
27+
// we can not add OK value here because ranging over map returns random order. We just try to trigger break
28+
tooManyAuths := make([]string, 0)
29+
for i := 0; i < headerCountLimit+2; i++ {
30+
tooManyAuths = append(tooManyAuths, basic+" "+base64.StdEncoding.EncodeToString([]byte("nope:nope")))
31+
}
32+
33+
var testCases = []struct {
34+
name string
35+
givenConfig BasicAuthConfig
36+
whenAuth []string
37+
expectHeader string
38+
expectErr string
39+
}{
40+
{
41+
name: "ok",
42+
givenConfig: defaultConfig,
43+
whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
44+
},
45+
{
46+
name: "ok, from multiple auth headers one is ok",
47+
givenConfig: defaultConfig,
48+
whenAuth: []string{
49+
"Bearer " + base64.StdEncoding.EncodeToString([]byte("token")), // different type
50+
basic + " NOT_BASE64", // invalid basic auth
51+
basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), // OK
52+
},
53+
},
54+
{
55+
name: "nok, invalid Authorization header",
56+
givenConfig: defaultConfig,
57+
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
58+
expectHeader: basic + ` realm=Restricted`,
59+
expectErr: "code=401, message=Unauthorized",
60+
},
61+
{
62+
name: "nok, not base64 Authorization header",
63+
givenConfig: defaultConfig,
64+
whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"},
65+
expectErr: "code=400, message=Bad Request, internal=illegal base64 data at input byte 3",
66+
},
67+
{
68+
name: "nok, missing Authorization header",
69+
givenConfig: defaultConfig,
70+
expectHeader: basic + ` realm=Restricted`,
71+
expectErr: "code=401, message=Unauthorized",
72+
},
73+
{
74+
name: "nok, too many invalid Authorization header",
75+
givenConfig: defaultConfig,
76+
whenAuth: tooManyAuths,
77+
expectHeader: basic + ` realm=Restricted`,
78+
expectErr: "code=401, message=Unauthorized",
79+
},
80+
{
81+
name: "ok, realm",
82+
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
83+
whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
84+
},
85+
{
86+
name: "ok, realm, case-insensitive header scheme",
87+
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
88+
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
89+
},
90+
{
91+
name: "nok, realm, invalid Authorization header",
92+
givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
93+
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
94+
expectHeader: basic + ` realm="someRealm"`,
95+
expectErr: "code=401, message=Unauthorized",
96+
},
97+
{
98+
name: "nok, validator func returns an error",
99+
givenConfig: defaultConfig,
100+
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))},
101+
expectErr: "my_error",
102+
},
103+
{
104+
name: "ok, skipped",
105+
givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c echo.Context) bool {
106+
return true
107+
}},
108+
whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
109+
},
110+
}
111+
112+
for _, tc := range testCases {
113+
t.Run(tc.name, func(t *testing.T) {
114+
e := echo.New()
115+
116+
mw := BasicAuthWithConfig(tc.givenConfig)
117+
118+
h := mw(func(c echo.Context) error {
119+
return c.String(http.StatusTeapot, "test")
120+
})
121+
122+
req := httptest.NewRequest(http.MethodGet, "/", nil)
123+
res := httptest.NewRecorder()
124+
125+
if len(tc.whenAuth) != 0 {
126+
for _, a := range tc.whenAuth {
127+
req.Header.Add(echo.HeaderAuthorization, a)
128+
}
129+
}
130+
err := h(e.NewContext(req, res))
131+
132+
if tc.expectErr != "" {
133+
assert.Equal(t, http.StatusOK, res.Code)
134+
assert.EqualError(t, err, tc.expectErr)
135+
} else {
136+
assert.Equal(t, http.StatusTeapot, res.Code)
137+
assert.NoError(t, err)
138+
}
139+
if tc.expectHeader != "" {
140+
assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate))
141+
}
142+
})
143+
}
144+
}
145+
14146
func TestBasicAuth(t *testing.T) {
15147
e := echo.New()
16-
req := httptest.NewRequest(http.MethodGet, "/", nil)
17-
res := httptest.NewRecorder()
18-
c := e.NewContext(req, res)
19148
f := func(u, p string, c echo.Context) (bool, error) {
20149
if u == "joe" && p == "secret" {
21150
return true, nil
@@ -26,50 +155,11 @@ func TestBasicAuth(t *testing.T) {
26155
return c.String(http.StatusOK, "test")
27156
})
28157

29-
// Valid credentials
30-
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
31-
req.Header.Set(echo.HeaderAuthorization, auth)
32-
assert.NoError(t, h(c))
33-
34-
h = BasicAuthWithConfig(BasicAuthConfig{
35-
Skipper: nil,
36-
Validator: f,
37-
Realm: "someRealm",
38-
})(func(c echo.Context) error {
39-
return c.String(http.StatusOK, "test")
40-
})
41-
42-
// Valid credentials
43-
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
44-
req.Header.Set(echo.HeaderAuthorization, auth)
45-
assert.NoError(t, h(c))
158+
req := httptest.NewRequest(http.MethodGet, "/", nil)
159+
res := httptest.NewRecorder()
160+
c := e.NewContext(req, res)
46161

47-
// Case-insensitive header scheme
48-
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
162+
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
49163
req.Header.Set(echo.HeaderAuthorization, auth)
50164
assert.NoError(t, h(c))
51-
52-
// Invalid credentials
53-
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))
54-
req.Header.Set(echo.HeaderAuthorization, auth)
55-
he := h(c).(*echo.HTTPError)
56-
assert.Equal(t, http.StatusUnauthorized, he.Code)
57-
assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))
58-
59-
// Invalid base64 string
60-
auth = basic + " invalidString"
61-
req.Header.Set(echo.HeaderAuthorization, auth)
62-
he = h(c).(*echo.HTTPError)
63-
assert.Equal(t, http.StatusBadRequest, he.Code)
64-
65-
// Missing Authorization header
66-
req.Header.Del(echo.HeaderAuthorization)
67-
he = h(c).(*echo.HTTPError)
68-
assert.Equal(t, http.StatusUnauthorized, he.Code)
69-
70-
// Invalid Authorization header
71-
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
72-
req.Header.Set(echo.HeaderAuthorization, auth)
73-
he = h(c).(*echo.HTTPError)
74-
assert.Equal(t, http.StatusUnauthorized, he.Code)
75165
}

middleware/extractor.go

+10-10
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ import (
99
)
1010

1111
const (
12-
// extractorLimit is arbitrary number to limit values extractor can return. this limits possible resource exhaustion
12+
// headerCountLimit is arbitrary number to limit number of headers processed. this limits possible resource exhaustion
1313
// attack vector
14-
extractorLimit = 20
14+
headerCountLimit = 20
1515
)
1616

1717
var errHeaderExtractorValueMissing = errors.New("missing value in request header")
@@ -105,14 +105,14 @@ func valuesFromHeader(header string, valuePrefix string) ValuesExtractor {
105105
for i, value := range values {
106106
if prefixLen == 0 {
107107
result = append(result, value)
108-
if i >= extractorLimit-1 {
108+
if i >= headerCountLimit-1 {
109109
break
110110
}
111111
continue
112112
}
113113
if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) {
114114
result = append(result, value[prefixLen:])
115-
if i >= extractorLimit-1 {
115+
if i >= headerCountLimit-1 {
116116
break
117117
}
118118
}
@@ -134,8 +134,8 @@ func valuesFromQuery(param string) ValuesExtractor {
134134
result := c.QueryParams()[param]
135135
if len(result) == 0 {
136136
return nil, errQueryExtractorValueMissing
137-
} else if len(result) > extractorLimit-1 {
138-
result = result[:extractorLimit]
137+
} else if len(result) > headerCountLimit-1 {
138+
result = result[:headerCountLimit]
139139
}
140140
return result, nil
141141
}
@@ -149,7 +149,7 @@ func valuesFromParam(param string) ValuesExtractor {
149149
for i, p := range c.ParamNames() {
150150
if param == p {
151151
result = append(result, paramVales[i])
152-
if i >= extractorLimit-1 {
152+
if i >= headerCountLimit-1 {
153153
break
154154
}
155155
}
@@ -173,7 +173,7 @@ func valuesFromCookie(name string) ValuesExtractor {
173173
for i, cookie := range cookies {
174174
if name == cookie.Name {
175175
result = append(result, cookie.Value)
176-
if i >= extractorLimit-1 {
176+
if i >= headerCountLimit-1 {
177177
break
178178
}
179179
}
@@ -195,8 +195,8 @@ func valuesFromForm(name string) ValuesExtractor {
195195
if len(values) == 0 {
196196
return nil, errFormExtractorValueMissing
197197
}
198-
if len(values) > extractorLimit-1 {
199-
values = values[:extractorLimit]
198+
if len(values) > headerCountLimit-1 {
199+
values = values[:headerCountLimit]
200200
}
201201
result := append([]string{}, values...)
202202
return result, nil

0 commit comments

Comments
 (0)