Skip to content

Chore/sentry errors #238

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 120 additions & 64 deletions directive/authorized.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,18 @@ package directive
import (
"context"
"encoding/json"
"errors"
"fmt"
openmfperrors "github.com/openmfp/golang-commons/errors"
"github.com/openmfp/golang-commons/sentry"
"github.com/rs/zerolog/log"
"github.com/vektah/gqlparser/v2/gqlerror"
"strings"

"github.com/99designs/gqlgen/graphql"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
openmfpcontext "github.com/openmfp/golang-commons/context"
"github.com/openmfp/golang-commons/fga/helpers"
"github.com/openmfp/golang-commons/logger"
"github.com/vektah/gqlparser/v2/gqlerror"
"google.golang.org/grpc/metadata"
)

Expand Down Expand Up @@ -55,102 +57,156 @@ func extractNestedKeyFromArgs(args map[string]any, paramName string) (string, er
}

func Authorized(openfgaClient openfgav1.OpenFGAServiceClient, log *logger.Logger) func(context.Context, interface{}, graphql.Resolver, string, *string, *string, string) (interface{}, error) {
compLogger := log.ComponentLogger("authorizedDirective")
ac := authChecker{
log: compLogger,
openfgaClient: openfgaClient,
}

if !directiveConfiguration.DirectivesAuthorizationEnabled {
log.Trace().Msg("Authorization directive is disabled. Skipping authorization check.")
return func(ctx context.Context, obj interface{}, next graphql.Resolver, relation string, entityType *string, entityTypeParamName *string, entityParamName string) (interface{}, error) {
return next(ctx)
}
}

return func(ctx context.Context, obj interface{}, next graphql.Resolver, relation string, entityType *string, entityTypeParamName *string, entityParamName string) (interface{}, error) {

if openfgaClient == nil {
return nil, errors.New("OpenFGAServiceClient is nil. Cannot process request")
return nil, sentry.SentryError(openmfperrors.New("OpenFGAServiceClient is nil. Cannot process request"))
}

ctx, err := setTenantToContextForTechnicalUsers(ctx, log)
ctx, hasToken, err := ac.withTenantContextForTechnicalUsers(ctx)
if err != nil {
compLogger.Info().Err(err).Msg("error setting tenant context for technical users")
return nil, err
}

token, err := openmfpcontext.GetAuthHeaderFromContext(ctx)
hasToken := err == nil

if hasToken {
ctx = metadata.AppendToOutgoingContext(ctx, "authorization", token)
}

fctx := graphql.GetFieldContext(ctx)

entityID, err := extractNestedKeyFromArgs(fctx.Args, entityParamName)
entityID, tenantID, evaluatedEntityType, err := ac.prepareAuthCheckInputs(ctx, entityParamName, entityTypeParamName, entityType)
if err != nil {
compLogger.Info().Err(err).Msg("error when extracting values for auth check")
return nil, err
}

tenantID, err := openmfpcontext.GetTenantFromContext(ctx)
res, err := ac.executeTheAuthCheck(ctx, hasToken, entityID, tenantID, evaluatedEntityType, relation)
if err != nil {
return nil, err
compLogger.Error().Err(err).Msg("error in authorized directive")
return nil, sentry.SentryError(err)
}

evaluatedEntityType := ""
if entityTypeParamName != nil {
evaluatedEntityType, err = extractNestedKeyFromArgs(fctx.Args, *entityTypeParamName)
if err != nil {
return nil, err
}
} else if entityType != nil {
evaluatedEntityType = *entityType
if !res.Allowed {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we sure res is always not nil if no error occured?

log.Info().Bool("allowed", res.Allowed).Msg("not allowed")
return nil, gqlerror.Errorf("unauthorized")
}

if evaluatedEntityType == "" {
return nil, fmt.Errorf("make sure to either provide entityType or entityTypeParamName")
}
return next(ctx)
}
}

type authChecker struct {
log *logger.Logger
openfgaClient openfgav1.OpenFGAServiceClient
}

func (ac *authChecker) executeTheAuthCheck(ctx context.Context, hasToken bool, entityID string, tenantID string, evaluatedEntityType string, relation string) (*openfgav1.CheckResponse, error) {
storeID, err := helpers.GetStoreIDForTenant(ctx, ac.openfgaClient, tenantID)
if err != nil {
return nil, err
}
modelID, err := helpers.GetModelIDForTenant(ctx, ac.openfgaClient, tenantID)
if err != nil {
return nil, err
}

storeID, err := helpers.GetStoreIDForTenant(ctx, openfgaClient, tenantID)
var userID string
if hasToken {
user, err := openmfpcontext.GetWebTokenFromContext(ctx)
if err != nil {
return nil, err
}
modelID, err := helpers.GetModelIDForTenant(ctx, openfgaClient, tenantID)
userID = user.Subject
} else {
spiffe, err := openmfpcontext.GetSpiffeFromContext(ctx)
if err != nil {
return nil, err
return nil, openmfperrors.New("authorized was invoked without a user token or a spiffe header")
}
userID = strings.TrimPrefix(spiffe, "spiffe://")
log.Trace().Str("user", userID).Msg("using spiffe user in authorized directive")
}

var userID string
if hasToken {
user, err := openmfpcontext.GetWebTokenFromContext(ctx)
if err != nil {
return nil, err
}
userID = user.Subject
} else {
spiffe, err := openmfpcontext.GetSpiffeFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("authorized was invoked without a user token or a spiffe header")
}
userID = strings.TrimPrefix(spiffe, "spiffe://")
log.Trace().Str("user", userID).Msg("using spiffe user in authorized directive")
}
req := &openfgav1.CheckRequest{
StoreId: storeID,
AuthorizationModelId: modelID,
TupleKey: &openfgav1.CheckRequestTupleKey{
User: fmt.Sprintf("user:%s", helpers.SanitizeUserID(userID)),
Relation: relation,
Object: fmt.Sprintf("%s:%s", evaluatedEntityType, entityID),
},
}

req := &openfgav1.CheckRequest{
StoreId: storeID,
AuthorizationModelId: modelID,
TupleKey: &openfgav1.CheckRequestTupleKey{
User: fmt.Sprintf("user:%s", helpers.SanitizeUserID(userID)),
Relation: relation,
Object: fmt.Sprintf("%s:%s", evaluatedEntityType, entityID),
},
}
res, err := ac.openfgaClient.Check(ctx, req)
if err != nil {
return nil, err
}
if res == nil {
return nil, openmfperrors.New("received nil response from openfgaClient.Check with no error")
}
return res, nil
}

res, err := openfgaClient.Check(ctx, req)
if err != nil {
log.Error().Err(err).Str("user", req.TupleKey.User).Msg("authorization check failed")
return nil, err
}
func (ac *authChecker) withTenantContextForTechnicalUsers(ctx context.Context) (context.Context, bool, error) {
newCtx, err := setTenantToContextForTechnicalUsers(ctx, ac.log)
if err != nil {
return ctx, false, openmfperrors.EnsureStack(err)
}

if !res.Allowed {
log.Warn().Bool("allowed", res.Allowed).Any("req", req).Msg("not allowed")
return nil, gqlerror.Errorf("unauthorized")
token, err := openmfpcontext.GetAuthHeaderFromContext(newCtx)
hasToken := err == nil

if hasToken {
newCtx = metadata.AppendToOutgoingContext(newCtx, "authorization", token)
}

return newCtx, hasToken, nil
}

func (ac *authChecker) prepareAuthCheckInputs(
ctx context.Context,
entityParamName string,
entityTypeParamName *string,
entityType *string,
) (
entityID string,
tenantID string,
evaluatedEntityType string,
err error,
) {
fctx := graphql.GetFieldContext(ctx)

entityID, err = extractNestedKeyFromArgs(fctx.Args, entityParamName)
if err != nil {
err = openmfperrors.EnsureStack(err)
return
}

tenantID, err = openmfpcontext.GetTenantFromContext(ctx)
if err != nil {
err = openmfperrors.EnsureStack(err)
return
}

if entityTypeParamName != nil {
evaluatedEntityType, err = extractNestedKeyFromArgs(fctx.Args, *entityTypeParamName)
if err != nil {
err = openmfperrors.EnsureStack(err)
return
}
} else if entityType != nil {
evaluatedEntityType = *entityType
}

return next(ctx)
if evaluatedEntityType == "" {
err = openmfperrors.New("make sure to either provide entityType or entityTypeParamName")
return
}
return
}
22 changes: 12 additions & 10 deletions directive/authorized_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"github.com/openmfp/golang-commons/sentry"
"testing"

"github.com/99designs/gqlgen/graphql"
Expand Down Expand Up @@ -59,7 +60,7 @@ func TestAuthorized(t *testing.T) {
graphqlArgs: map[string]any{
"non-existent": "something wrong",
},
expectedError: fmt.Errorf("unable to extract param from request for given paramName %q, param is of wrong type", "non-existent.nested"),
expectedError: sentry.SentryError(fmt.Errorf("unable to extract param from request for given paramName %q, param is of wrong type", "non-existent.nested")),
},
{
name: "should error if the entityParamName has the wrong type for a nested value",
Expand All @@ -69,7 +70,7 @@ func TestAuthorized(t *testing.T) {
"nested": map[string]any{},
},
},
expectedError: fmt.Errorf("unable to extract param from request for given paramName %q, param is of wrong type", "non-existent.nested"),
expectedError: sentry.SentryError(fmt.Errorf("unable to extract param from request for given paramName %q, param is of wrong type", "non-existent.nested")),
},
{
name: "should error if the entityTypeParamName is set and not part of the arguments",
Expand All @@ -78,7 +79,7 @@ func TestAuthorized(t *testing.T) {
"existent": "something",
},
entityTypeParamName: String("non-existent"),
expectedError: fmt.Errorf("unable to extract param from request for given paramName %q", "non-existent"),
expectedError: sentry.SentryError(fmt.Errorf("unable to extract param from request for given paramName %q", "non-existent")),
},
{
name: "should error if the entityType is set and but emtpy",
Expand All @@ -88,7 +89,7 @@ func TestAuthorized(t *testing.T) {
"existent": "something",
"emtpy": "",
},
expectedError: errors.New("make sure to either provide entityType or entityTypeParamName"),
expectedError: sentry.SentryError(errors.New("make sure to either provide entityType or entityTypeParamName")),
},
{
name: "should error if the request is not allowed",
Expand Down Expand Up @@ -207,9 +208,8 @@ func TestAuthorized(t *testing.T) {
},
fgaMocks: func(s *mocks.OpenFGAServiceClient) {
s.EXPECT().ListStores(mock.Anything, mock.Anything).Return(nil, errors.New("ListStores error"))
// s.EXPECT().Check(mock.Anything, mock.Anything).Return(&openfgav1.CheckResponse{Allowed: true}, nil)
},
expectedError: errors.New("ListStores error"),
expectedError: sentry.SentryError(errors.New("ListStores error")),
},
}

Expand Down Expand Up @@ -239,7 +239,9 @@ func TestAuthorized(t *testing.T) {
ctx = openmfpcontext.AddAuthHeaderToContext(ctx, fmt.Sprintf("Bearer %s", token))

_, err := Authorized(openfgaMock, log.Logger)(ctx, nil, nextFn, test.relation, test.entityType, test.entityTypeParamName, test.entityParamName)
assert.Equal(t, test.expectedError, err)
if test.expectedError != nil {
assert.Error(t, test.expectedError, err)
}
})
}
}
Expand Down Expand Up @@ -306,7 +308,7 @@ func TestAuthorizedEdgeCases2(t *testing.T) {
},
fgaMocks: func(s *mocks.OpenFGAServiceClient) {
},
expectedError: fmt.Errorf("someone stored a wrong value in the [tenantId] key with type [<nil>], expected [string]"),
expectedError: errors.New("someone stored a wrong value in the [tenantId] key with type [<nil>], expected [string]"),
},
{
name: "Check() return error",
Expand All @@ -327,7 +329,7 @@ func TestAuthorizedEdgeCases2(t *testing.T) {
}, nil)
s.EXPECT().Check(mock.Anything, mock.Anything).Return(nil, errors.New("Check error"))
},
expectedError: errors.New("Check error"),
expectedError: sentry.SentryError(errors.New("Check error")),
},
}

Expand Down Expand Up @@ -357,7 +359,7 @@ func TestAuthorizedEdgeCases2(t *testing.T) {
ctx = openmfpcontext.AddAuthHeaderToContext(ctx, fmt.Sprintf("Bearer %s", token))

_, err := Authorized(openfgaMock, log.Logger)(ctx, nil, nextFn, test.relation, test.entityType, test.entityTypeParamName, test.entityParamName)
assert.Equal(t, test.expectedError, err)
assert.Error(t, test.expectedError, err)
})
}

Expand Down