diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index 9285f29fd..8f23bf548 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -4,7 +4,9 @@ package middleware import ( + "bytes" "encoding/base64" + "errors" "net/http" "strconv" "strings" @@ -21,6 +23,12 @@ type BasicAuthConfig struct { // Required. Validator BasicAuthValidator + // HeaderValidationLimit limits the amount of authorization headers will be validated + // for valid credentials. Set this value to be higher from in an environment where multiple + // basic auth headers could be received. + // Default value 1. + HeaderValidationLimit int + // Realm is a string to define realm attribute of BasicAuth. // Default value "Restricted". Realm string @@ -29,7 +37,7 @@ type BasicAuthConfig struct { // BasicAuthValidator defines a function to validate BasicAuth credentials. // The function should return a boolean indicating whether the credentials are valid, // and an error if any error occurs during the validation process. -type BasicAuthValidator func(string, string, echo.Context) (bool, error) +type BasicAuthValidator func(user string, password string, c echo.Context) (bool, error) const ( basic = "basic" @@ -38,8 +46,9 @@ const ( // DefaultBasicAuthConfig is the default BasicAuth middleware config. var DefaultBasicAuthConfig = BasicAuthConfig{ - Skipper: DefaultSkipper, - Realm: defaultRealm, + Skipper: DefaultSkipper, + Realm: defaultRealm, + HeaderValidationLimit: 1, } // BasicAuth returns an BasicAuth middleware. @@ -52,18 +61,30 @@ func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc { return BasicAuthWithConfig(c) } -// BasicAuthWithConfig returns an BasicAuth middleware with config. -// See `BasicAuth()`. +// BasicAuthWithConfig returns an BasicAuthWithConfig middleware with config. func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { - // Defaults + mw, err := config.ToMiddleware() + if err != nil { + panic(err) + } + return mw +} + +// ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration +func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Validator == nil { - panic("echo: basic-auth middleware requires a validator function") + return nil, errors.New("echo basic-auth middleware requires a validator function") } if config.Skipper == nil { - config.Skipper = DefaultBasicAuthConfig.Skipper + config.Skipper = DefaultSkipper + } + realm := defaultRealm + if config.Realm != "" && config.Realm != realm { + realm = strconv.Quote(config.Realm) } - if config.Realm == "" { - config.Realm = defaultRealm + maxValidationAttemptCount := 1 + if config.HeaderValidationLimit > 1 { + maxValidationAttemptCount = config.HeaderValidationLimit } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -72,40 +93,47 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { return next(c) } - auth := c.Request().Header.Get(echo.HeaderAuthorization) + var lastError error l := len(basic) + errCount := 0 + // multiple auth headers is something that can happen in environments like + // corporate test environments that are secured by application proxy servers where + // front facing proxy is also configured to require own basic auth value and does auth checks. + // In that case middleware can receive multiple auth headers. + for _, auth := range c.Request().Header[echo.HeaderAuthorization] { + if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) { + continue + } + if errCount >= maxValidationAttemptCount { + break + } - if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) { // Invalid base64 shouldn't be treated as error // instead should be treated as invalid client input - b, err := base64.StdEncoding.DecodeString(auth[l+1:]) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest).SetInternal(err) + b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:]) + if errDecode != nil { + lastError = echo.NewHTTPError(http.StatusBadRequest).WithInternal(errDecode) + continue } - - cred := string(b) - for i := 0; i < len(cred); i++ { - if cred[i] == ':' { - // Verify credentials - valid, err := config.Validator(cred[:i], cred[i+1:], c) - if err != nil { - return err - } else if valid { - return next(c) - } - break + idx := bytes.IndexByte(b, ':') + if idx >= 0 { + valid, errValidate := config.Validator(string(b[:idx]), string(b[idx+1:]), c) + if errValidate != nil { + lastError = errValidate + } else if valid { + return next(c) } + errCount++ } } - realm := defaultRealm - if config.Realm != defaultRealm { - realm = strconv.Quote(config.Realm) + if lastError != nil { + return lastError } // Need to return `401` for browsers to pop-up login box. c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm) return echo.ErrUnauthorized } - } + }, nil } diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index b3abfa172..d27cbf241 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -16,9 +16,45 @@ import ( ) func TestBasicAuth(t *testing.T) { + + validator := func(u, p string, c echo.Context) (bool, error) { + if u == "joe" && p == "secret" { + return true, nil + } + return false, nil + } + + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + + userPassB64 := base64.StdEncoding.EncodeToString([]byte("joe:secret")) + req.Header.Set(echo.HeaderAuthorization, basic+" "+userPassB64) + + e := echo.New() + c := e.NewContext(req, res) + + h := BasicAuth(validator)(func(c echo.Context) error { + return c.String(http.StatusIMUsed, "test") + }) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusIMUsed, res.Code) +} + +func TestBasicAuthPanic(t *testing.T) { + assert.PanicsWithError(t, "echo basic-auth middleware requires a validator function", func() { + BasicAuth(nil) + }) +} + +func TestBasicAuthWithConfig(t *testing.T) { e := echo.New() + exampleSecret := base64.StdEncoding.EncodeToString([]byte("joe:secret")) mockValidator := func(u, p string, c echo.Context) (bool, error) { + if u == "error" { + return false, errors.New("validator_error") + } if u == "joe" && p == "secret" { return true, nil } @@ -27,56 +63,83 @@ func TestBasicAuth(t *testing.T) { tests := []struct { name string - authHeader string + authHeader []string + config *BasicAuthConfig expectedCode int expectedAuth string - skipperResult bool - expectedErr bool + expectedErr string expectedErrMsg string }{ { name: "Valid credentials", - authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), + authHeader: []string{basic + " " + exampleSecret}, expectedCode: http.StatusOK, }, { name: "Case-insensitive header scheme", - authHeader: strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), + authHeader: []string{strings.ToUpper(basic) + " " + exampleSecret}, expectedCode: http.StatusOK, }, { name: "Invalid credentials", - authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")), + authHeader: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))}, expectedCode: http.StatusUnauthorized, expectedAuth: basic + ` realm="someRealm"`, - expectedErr: true, + expectedErr: "code=401, message=Unauthorized", expectedErrMsg: "Unauthorized", }, + { + name: "validator errors out at 2 tries", + authHeader: []string{ + basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")), + basic + " " + base64.StdEncoding.EncodeToString([]byte("error:secret")), + }, + config: &BasicAuthConfig{ + HeaderValidationLimit: 2, + Validator: mockValidator, + }, + expectedCode: http.StatusUnauthorized, + expectedAuth: "", + expectedErr: "validator_error", + expectedErrMsg: "Unauthorized", + }, + { + name: "Invalid credentials, default realm", + authHeader: []string{basic + " " + exampleSecret}, + expectedCode: http.StatusOK, + expectedAuth: basic + ` realm="Restricted"`, + }, { name: "Invalid base64 string", - authHeader: basic + " invalidString", + authHeader: []string{basic + " invalidString"}, expectedCode: http.StatusBadRequest, - expectedErr: true, + expectedErr: "code=400, message=Bad Request, internal=illegal base64 data at input byte 12", expectedErrMsg: "Bad Request", }, { name: "Missing Authorization header", expectedCode: http.StatusUnauthorized, - expectedErr: true, + expectedErr: "code=401, message=Unauthorized", expectedErrMsg: "Unauthorized", }, { name: "Invalid Authorization header", - authHeader: base64.StdEncoding.EncodeToString([]byte("invalid")), + authHeader: []string{base64.StdEncoding.EncodeToString([]byte("invalid"))}, expectedCode: http.StatusUnauthorized, - expectedErr: true, + expectedErr: "code=401, message=Unauthorized", expectedErrMsg: "Unauthorized", }, { - name: "Skipped Request", - authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip")), - expectedCode: http.StatusOK, - skipperResult: true, + name: "Skipped Request", + authHeader: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip"))}, + expectedCode: http.StatusOK, + config: &BasicAuthConfig{ + Validator: mockValidator, + Realm: "someRealm", + Skipper: func(c echo.Context) bool { + return true + }, + }, }, } @@ -87,26 +150,25 @@ func TestBasicAuth(t *testing.T) { res := httptest.NewRecorder() c := e.NewContext(req, res) - if tt.authHeader != "" { - req.Header.Set(echo.HeaderAuthorization, tt.authHeader) + for _, h := range tt.authHeader { + req.Header.Add(echo.HeaderAuthorization, h) } - h := BasicAuthWithConfig(BasicAuthConfig{ + config := BasicAuthConfig{ Validator: mockValidator, Realm: "someRealm", - Skipper: func(c echo.Context) bool { - return tt.skipperResult - }, - })(func(c echo.Context) error { + } + if tt.config != nil { + config = *tt.config + } + h := BasicAuthWithConfig(config)(func(c echo.Context) error { return c.String(http.StatusOK, "test") }) err := h(c) - if tt.expectedErr { - var he *echo.HTTPError - errors.As(err, &he) - assert.Equal(t, tt.expectedCode, he.Code) + if tt.expectedErr != "" { + assert.EqualError(t, err, tt.expectedErr) if tt.expectedAuth != "" { assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate)) }