Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 186 additions & 0 deletions pkg/authserver/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/stacklok/toolhive/pkg/auth/upstreamtoken"
servercrypto "github.com/stacklok/toolhive/pkg/authserver/server/crypto"
"github.com/stacklok/toolhive/pkg/authserver/server/keys"
"github.com/stacklok/toolhive/pkg/authserver/server/registration"
"github.com/stacklok/toolhive/pkg/authserver/server/session"
"github.com/stacklok/toolhive/pkg/authserver/storage"
"github.com/stacklok/toolhive/pkg/authserver/upstream"
)
Expand All @@ -44,6 +46,7 @@ const (
type testServer struct {
Server *httptest.Server
PrivateKey *rsa.PrivateKey
authServer Server
}

// testServerOptions configures the test server setup.
Expand Down Expand Up @@ -188,6 +191,7 @@ func setupTestServer(t *testing.T, opts ...testServerOption) *testServer {
return &testServer{
Server: httpServer,
PrivateKey: privateKey,
authServer: srv,
}
}

Expand Down Expand Up @@ -862,6 +866,7 @@ func setupTestServerWithOIDCProvider(t *testing.T, m *mockoidc.MockOIDC) *testSe
testServer: &testServer{
Server: httpServer,
PrivateKey: privateKey,
authServer: srv,
},
mockOIDC: m,
}
Expand Down Expand Up @@ -1162,3 +1167,184 @@ func TestIntegration_RefreshToken_ShortLivedAccessToken(t *testing.T) {
require.True(t, ok)
assert.Greater(t, int64(exp), time.Now().Unix(), "refreshed token exp must be in the future")
}

// TestIntegration_UpstreamTokenService_GetValidTokens tests the UpstreamTokenService
// end-to-end: a real auth server stores upstream tokens during the OAuth callback,
// and the service retrieves them by session ID extracted from the JWT.
func TestIntegration_UpstreamTokenService_GetValidTokens(t *testing.T) {
t.Parallel()

m := startMockOIDC(t)
ts := setupTestServerWithMockOIDC(t, m)

verifier := servercrypto.GeneratePKCEVerifier()
challenge := servercrypto.ComputePKCEChallenge(verifier)

// Complete the full OAuth flow — this stores upstream tokens in the auth server's storage.
authCode, _ := completeAuthorizationFlow(t, ts.Server.URL, authorizationParams{
ClientID: testClientID,
RedirectURI: testRedirectURI,
State: "upstream-svc-test",
Challenge: challenge,
Scope: "openid profile offline_access",
ResponseType: "code",
})

tokenData := exchangeCodeForTokens(t, ts.Server.URL, authCode, verifier, testAudience)

// Extract tsid from the access token JWT — this is the session ID used by storage.
accessToken, ok := tokenData["access_token"].(string)
require.True(t, ok)
tsid := extractTSID(t, accessToken, ts.PrivateKey.Public())

// Create the UpstreamTokenService using the auth server's storage and refresher.
// This mirrors how vMCP would compose these in production.
svc := upstreamtoken.NewInProcessService(
ts.authServer.IDPTokenStorage(),
ts.authServer.UpstreamTokenRefresher(),
)

// The service should return the upstream access token stored during callback.
cred, err := svc.GetValidTokens(context.Background(), tsid)
require.NoError(t, err)
require.NotNil(t, cred)
assert.NotEmpty(t, cred.AccessToken, "upstream access token should be present")
}

// TestIntegration_UpstreamTokenService_RefreshExpiredTokens verifies the transparent
// refresh path: upstream tokens are expired in storage, and the service uses the
// refresher (backed by mockoidc) to get fresh tokens without re-authentication.
func TestIntegration_UpstreamTokenService_RefreshExpiredTokens(t *testing.T) {
t.Parallel()

m := startMockOIDC(t)
ts := setupTestServerWithMockOIDC(t, m)

verifier := servercrypto.GeneratePKCEVerifier()
challenge := servercrypto.ComputePKCEChallenge(verifier)

authCode, _ := completeAuthorizationFlow(t, ts.Server.URL, authorizationParams{
ClientID: testClientID,
RedirectURI: testRedirectURI,
State: "upstream-refresh-test",
Challenge: challenge,
Scope: "openid profile offline_access",
ResponseType: "code",
})

tokenData := exchangeCodeForTokens(t, ts.Server.URL, authCode, verifier, testAudience)

accessToken, ok := tokenData["access_token"].(string)
require.True(t, ok)
tsid := extractTSID(t, accessToken, ts.PrivateKey.Public())

stor := ts.authServer.IDPTokenStorage()

// Read the stored tokens, then overwrite them with an expired ExpiresAt.
original, err := stor.GetUpstreamTokens(context.Background(), tsid)
require.NoError(t, err)
require.NotNil(t, original)
originalAccessToken := original.AccessToken

// Queue a new user for mockoidc's refresh token endpoint response.
m.QueueUser(&mockoidc.MockUser{
Subject: "mock-user-sub-123",
Email: "testuser@example.com",
})

// Store tokens back with ExpiresAt in the past to simulate expiry.
expired := &storage.UpstreamTokens{
ProviderID: original.ProviderID,
AccessToken: original.AccessToken,
RefreshToken: original.RefreshToken,
IDToken: original.IDToken,
ExpiresAt: time.Now().Add(-1 * time.Hour),
UserID: original.UserID,
UpstreamSubject: original.UpstreamSubject,
ClientID: original.ClientID,
}
require.NoError(t, stor.StoreUpstreamTokens(context.Background(), tsid, expired))

// The service should transparently refresh the expired tokens.
svc := upstreamtoken.NewInProcessService(stor, ts.authServer.UpstreamTokenRefresher())

cred, err := svc.GetValidTokens(context.Background(), tsid)
require.NoError(t, err)
require.NotNil(t, cred)
assert.NotEmpty(t, cred.AccessToken, "refreshed upstream access token should be present")

// Verify storage was updated with non-expired tokens after refresh.
refreshed, err := stor.GetUpstreamTokens(context.Background(), tsid)
require.NoError(t, err, "refreshed tokens should be retrievable without ErrExpired")
assert.True(t, refreshed.ExpiresAt.After(time.Now()),
"refreshed tokens should have a future expiry, got %v", refreshed.ExpiresAt)
_ = originalAccessToken // used only to confirm the flow completed
}

// TestIntegration_UpstreamTokenService_SessionNotFound verifies that the service
// returns ErrSessionNotFound for a non-existent session.
func TestIntegration_UpstreamTokenService_SessionNotFound(t *testing.T) {
t.Parallel()

m := startMockOIDC(t)
ts := setupTestServerWithMockOIDC(t, m)

svc := upstreamtoken.NewInProcessService(
ts.authServer.IDPTokenStorage(),
ts.authServer.UpstreamTokenRefresher(),
)

cred, err := svc.GetValidTokens(context.Background(), "non-existent-session-id")
require.Error(t, err)
assert.ErrorIs(t, err, upstreamtoken.ErrSessionNotFound)
assert.Nil(t, cred)
}

// TestIntegration_UpstreamTokenService_NoRefreshToken verifies that the service
// returns ErrNoRefreshToken when the upstream access token is expired but no
// refresh token is available.
func TestIntegration_UpstreamTokenService_NoRefreshToken(t *testing.T) {
t.Parallel()

m := startMockOIDC(t)
ts := setupTestServerWithMockOIDC(t, m)

stor := ts.authServer.IDPTokenStorage()

// Store expired tokens without a refresh token.
sessionID := "no-refresh-session"
require.NoError(t, stor.StoreUpstreamTokens(context.Background(), sessionID, &storage.UpstreamTokens{
ProviderID: "test",
AccessToken: "expired-access",
RefreshToken: "", // no refresh token
ExpiresAt: time.Now().Add(-1 * time.Hour),
UserID: "user-1",
UpstreamSubject: "sub-1",
ClientID: "client-1",
}))

svc := upstreamtoken.NewInProcessService(stor, ts.authServer.UpstreamTokenRefresher())

cred, err := svc.GetValidTokens(context.Background(), sessionID)
require.Error(t, err)
assert.ErrorIs(t, err, upstreamtoken.ErrNoRefreshToken)
assert.Nil(t, cred)
}

// extractTSID parses a JWT access token and extracts the tsid claim.
func extractTSID(t *testing.T, accessToken string, publicKey any) string {
t.Helper()

parsed, err := jwt.ParseSigned(accessToken, []jose.SignatureAlgorithm{jose.RS256})
require.NoError(t, err)

var claims map[string]interface{}
err = parsed.Claims(publicKey, &claims)
require.NoError(t, err)

tsid, ok := claims[session.TokenSessionIDClaimKey].(string)
require.True(t, ok, "tsid claim should be present in access token")
require.NotEmpty(t, tsid)

return tsid
}
Loading