diff --git a/apikey.go b/apikey.go index ed441ac..ee3c58d 100644 --- a/apikey.go +++ b/apikey.go @@ -25,8 +25,6 @@ const ApiKeyTablePK = "value" const ( paramNewKeyId = "newKeyId" paramNewKeySecret = "newKeySecret" - paramOldKeyId = "oldKeyId" - paramOldKeySecret = "oldKeySecret" ) const ( @@ -458,22 +456,28 @@ func (a *App) CreateApiKey(w http.ResponseWriter, r *http.Request) { // any number of times to continue the process. A status of 200 does not indicate that all keys were encrypted using the // new key. Check the response data to determine if the rotation process is complete. func (a *App) RotateApiKey(w http.ResponseWriter, r *http.Request) { - requestBody, err := parseRotateKeyRequestBody(r.Body) + var requestBody map[string]string + err := json.NewDecoder(r.Body).Decode(&requestBody) if err != nil { - if strings.HasSuffix(err.Error(), "is required") { - jsonResponse(w, err, http.StatusBadRequest) - } else { - log.Printf("invalid request in RotateApiKey: %s", err) - jsonResponse(w, invalidRequest, http.StatusBadRequest) - } + log.Printf("invalid request in ActivateApiKey: %s", err) + jsonResponse(w, invalidRequest, http.StatusBadRequest) + return + } + + if requestBody[paramNewKeyId] == "" { + jsonResponse(w, paramNewKeyId+" is required", http.StatusBadRequest) return } - oldKey := ApiKey{Key: requestBody[paramOldKeyId], Store: a.GetDB()} - err = oldKey.loadAndCheck(requestBody[paramOldKeySecret]) + if requestBody[paramNewKeySecret] == "" { + jsonResponse(w, paramNewKeySecret+" is required", http.StatusBadRequest) + return + } + + oldKey, err := getAPIKey(r) if err != nil { - log.Printf("old key is not valid: %s", err) - jsonResponse(w, apiKeyNotFound, http.StatusNotFound) + log.Printf("Rotate API key error: %v", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) return } @@ -509,22 +513,6 @@ func (a *App) RotateApiKey(w http.ResponseWriter, r *http.Request) { jsonResponse(w, responseBody, http.StatusOK) } -func parseRotateKeyRequestBody(body io.Reader) (map[string]string, error) { - var requestBody map[string]string - err := json.NewDecoder(body).Decode(&requestBody) - if err != nil { - return nil, fmt.Errorf("invalid request in RotateApiKey: %w", err) - } - - fields := []string{paramNewKeyId, paramNewKeySecret, paramOldKeyId, paramOldKeySecret} - for _, field := range fields { - if _, ok := requestBody[field]; !ok { - return nil, fmt.Errorf("%s is required", field) - } - } - return requestBody, nil -} - func (k *ApiKey) loadAndCheck(secret string) error { err := k.Load() if err != nil { diff --git a/apikey_test.go b/apikey_test.go index 0a085c9..918e2b6 100644 --- a/apikey_test.go +++ b/apikey_test.go @@ -2,6 +2,7 @@ package mfa import ( "bytes" + "context" "crypto/aes" "crypto/rand" "encoding/base64" @@ -10,6 +11,7 @@ import ( "fmt" "io" "net/http" + "net/http/httptest" "regexp" "testing" "time" @@ -372,78 +374,70 @@ func (ms *MfaSuite) TestAppRotateApiKey() { tests := []struct { name string body any + key ApiKey wantStatus int - wantError error + wantError string }{ { - name: "missing oldKeyId", - body: map[string]interface{}{ - paramNewKeyId: newKey.Key, - paramNewKeySecret: newKey.Secret, - paramOldKeySecret: key.Secret, - }, - wantStatus: http.StatusBadRequest, - wantError: errors.New("oldKeyId is required"), - }, - { - name: "missing oldKeySecret", + name: "missing key", body: map[string]interface{}{ paramNewKeyId: newKey.Key, paramNewKeySecret: newKey.Secret, - paramOldKeyId: key.Key, }, - wantStatus: http.StatusBadRequest, - wantError: errors.New("oldKeySecret is required"), + wantStatus: http.StatusUnauthorized, + wantError: "Unauthorized", }, { name: "missing newKeyId", body: map[string]interface{}{ paramNewKeySecret: newKey.Secret, - paramOldKeyId: key.Key, - paramOldKeySecret: key.Secret, }, + key: key, wantStatus: http.StatusBadRequest, - wantError: errors.New("newKeyId is required"), + wantError: "newKeyId is required", }, { name: "missing newKeySecret", body: map[string]interface{}{ - paramNewKeyId: newKey.Key, - paramOldKeyId: key.Key, - paramOldKeySecret: key.Secret, + paramNewKeyId: newKey.Key, }, + key: key, wantStatus: http.StatusBadRequest, - wantError: errors.New("newKeySecret is required"), + wantError: "newKeySecret is required", }, { name: "good", body: map[string]interface{}{ paramNewKeyId: newKey.Key, paramNewKeySecret: newKey.Secret, - paramOldKeyId: user.ApiKey.Key, - paramOldKeySecret: key.Secret, }, + key: key, wantStatus: http.StatusOK, }, } for _, tt := range tests { ms.Run(tt.name, func() { - res := &lambdaResponseWriter{Headers: http.Header{}} - req := requestWithUser(tt.body, key) - ms.app.RotateApiKey(res, req) - - if tt.wantError != nil { - ms.Equal(tt.wantStatus, res.Status, fmt.Sprintf("CreateApiKey response: %s", res.Body)) - var se simpleError - ms.decodeBody(res.Body, &se) - ms.ErrorIs(se, tt.wantError) + jsonBody, err := json.Marshal(tt.body) + must(err) + b := io.NopCloser(bytes.NewReader(jsonBody)) + request, _ := http.NewRequest(http.MethodPost, "/api-key/rotate", b) + request.Header.Set(HeaderAPIKey, tt.key.Key) + request.Header.Set(HeaderAPISecret, tt.key.Secret) + + ctxWithUser := context.WithValue(request.Context(), UserContextKey, tt.key) + request = request.WithContext(ctxWithUser) + + res := httptest.NewRecorder() + Router(ms.app).ServeHTTP(res, request) + ms.Equal(tt.wantStatus, res.Code, "incorrect http status, body: %s", res.Body.String()) + + if tt.wantError != "" { + ms.Contains(res.Body.String(), tt.wantError) return } - ms.Equal(tt.wantStatus, res.Status, fmt.Sprintf("CreateApiKey response: %s", res.Body)) - var response map[string]int - ms.decodeBody(res.Body, &response) + ms.decodeBody(res.Body.Bytes(), &response) ms.Equal(1, response["totpComplete"]) ms.Equal(1, response["webauthnComplete"]) diff --git a/auth.go b/auth.go index c2c4935..1773738 100644 --- a/auth.go +++ b/auth.go @@ -7,14 +7,19 @@ import ( "strings" ) +const ( + HeaderAPIKey = "x-mfa-apikey" + HeaderAPISecret = "x-mfa-apisecret" +) + type User interface{} // AuthenticateRequest checks the provided API key against the keys stored in the database. If the key is active and // valid, a Webauthn client and WebauthnUser are created and stored in the request context. func AuthenticateRequest(r *http.Request) (User, error) { // get key and secret from headers - key := r.Header.Get("x-mfa-apikey") - secret := r.Header.Get("x-mfa-apisecret") + key := r.Header.Get(HeaderAPIKey) + secret := r.Header.Get(HeaderAPISecret) if key == "" || secret == "" { return nil, fmt.Errorf("x-mfa-apikey and x-mfa-apisecret are required") diff --git a/openapi.yaml b/openapi.yaml index b448211..89ef141 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -488,7 +488,8 @@ paths: operationId: rotateApiKey summary: Rotate API Key description: > - All data in webauthn and totp tables that is encrypted by the old key will be re-encrypted using the new key. + All data in webauthn and totp tables that is encrypted by the old key (identified by the request headers + `x-mfa-apikey` and `x-mfa-apisecret`) will be re-encrypted using a new key identified in the request body. If the process does not run to completion, this endpoint can be called any number of times to continue the process. A status of 200 does not indicate that all keys were encrypted using the new key. Check the response data to determine if the rotation process is complete. @@ -500,16 +501,6 @@ paths: schema: type: object properties: - oldKeyId: - type: string - description: old API Key ID - required: true - example: 0123456789012345678901234567890123456789 - oldKeySecret: - type: string - description: old API Key secret - required: true - example: 0123456789012345678901234567890123456789012= newKeyId: type: string description: new API Key ID diff --git a/webauthn_test.go b/webauthn_test.go index 7fbb741..20b5e6c 100644 --- a/webauthn_test.go +++ b/webauthn_test.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "io" + "log" "net/http" "net/http/httptest" "strings" @@ -744,6 +745,7 @@ func Test_GetPublicKeyAsBytes(t *testing.T) { func Router(app *App) http.Handler { mux := &http.ServeMux{} + mux.HandleFunc("POST /api-key/rotate", app.RotateApiKey) mux.HandleFunc(fmt.Sprintf("DELETE /webauthn/credential/{%s}", IDParam), app.DeleteCredential) // Ensure a request without an id gets handled properly mux.HandleFunc("DELETE /webauthn/credential/", app.DeleteCredential) @@ -752,11 +754,13 @@ func Router(app *App) http.Handler { return testAuthnMiddleware(mux) } +// testAuthnMiddleware is a copy of the authenticationMiddleware function func testAuthnMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user, err := AuthenticateRequest(r) if err != nil { - http.Error(w, fmt.Sprintf("unable to authenticate request: %s", err), http.StatusUnauthorized) + log.Printf("unable to authenticate request: %s", err) + http.Error(w, "Unauthorized", http.StatusUnauthorized) return } @@ -841,8 +845,8 @@ func (ms *MfaSuite) Test_DeleteCredential() { ms.T().Run(tt.name, func(t *testing.T) { request, _ := http.NewRequest("DELETE", fmt.Sprintf("/webauthn/credential/%s", tt.credID), nil) - request.Header.Set("x-mfa-apikey", tt.user.ApiKeyValue) - request.Header.Set("x-mfa-apisecret", tt.user.ApiKey.Secret) + request.Header.Set(HeaderAPIKey, tt.user.ApiKeyValue) + request.Header.Set(HeaderAPISecret, tt.user.ApiKey.Secret) request.Header.Set("x-mfa-RPDisplayName", "TestRPName") request.Header.Set("x-mfa-RPID", "111.11.11.11") request.Header.Set("x-mfa-UserUUID", tt.user.ID)