Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 43 additions & 7 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,31 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
}
flowType := getFlowFromChallenge(codeChallenge)

// Check if this provider requires PKCE for its own authorization flow
// Generate a separate PKCE pair for communication with the external provider
// This is separate from the user's PKCE and is used for provider-to-provider auth
providerRequiresPKCE := false
providerCodeVerifier := ""
if oauthProvider, ok := p.(provider.OAuthProvider); ok {
providerRequiresPKCE = oauthProvider.RequiresPKCE()
if providerRequiresPKCE {
// Uses oauth2 library's built-in PKCE support
providerCodeVerifier = oauth2.GenerateVerifier()
}
}

flowStateID := ""
if isPKCEFlow(flowType) {
flowState, err := generateFlowState(db, providerType, models.OAuth, codeChallengeMethod, codeChallenge, nil)
var flowState *models.FlowState
// Create FlowState if user is using PKCE OR provider requires PKCE
// This ensures we can store provider's code_verifier even for implicit flows
if isPKCEFlow(flowType) || providerRequiresPKCE {
// a bit hacky but we have a db constraint on code challenge method,
// so we default to s256 if the provider requires PKCE and the code challenge method is not provided
if providerRequiresPKCE && codeChallengeMethod == "" {
codeChallengeMethod = "s256"
}

flowState, err = generateFlowState(db, providerType, models.OAuth, codeChallengeMethod, codeChallenge, nil, providerCodeVerifier)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -129,6 +151,13 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
}
}

// Pass the GENERATED PKCE parameters (not user's PKCE) to the external provider's
// authorization URL using oauth2 library's built-in PKCE support
// This works for any OAuth provider that supports/requires PKCE
if flowState != nil && flowState.ProviderCodeVerifier != "" {
authUrlParams = append(authUrlParams, oauth2.S256ChallengeOption(flowState.ProviderCodeVerifier))
}

authURL := p.AuthCodeURL(tokenString, authUrlParams...)

return authURL, nil
Expand Down Expand Up @@ -237,8 +266,10 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
return terr
}
}
if flowState != nil {
// This means that the callback is using PKCE
// Check if USER requested PKCE (not just if FlowState exists)
// FlowState might exist only because provider requires PKCE
if flowState.IsUserPKCEFlow() {
// User wants PKCE flow - store tokens and return auth code
flowState.ProviderAccessToken = providerAccessToken
flowState.ProviderRefreshToken = providerRefreshToken
flowState.UserID = &(user.ID)
Expand All @@ -247,6 +278,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re

terr = tx.Update(flowState)
} else {
// User wants implicit flow - issue token directly
token, terr = a.issueRefreshToken(r, tx, user, models.OAuth, grantParams)
}

Expand All @@ -272,14 +304,15 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
}

rurl := a.getExternalRedirectURL(r)
if flowState != nil {
// This means that the callback is using PKCE
// Set the flowState.AuthCode to the query param here
// Check if USER requested PKCE (not just if FlowState exists)
if flowState.IsUserPKCEFlow() {
// User wants PKCE - return auth code
rurl, err = a.prepPKCERedirectURL(rurl, flowState.AuthCode)
if err != nil {
return err
}
} else if token != nil {
// User wants implicit flow - return token directly
q := url.Values{}
q.Set("provider_token", providerAccessToken)
// Because not all providers give out a refresh token
Expand Down Expand Up @@ -644,6 +677,9 @@ func (a *API) Provider(ctx context.Context, name string, scopes string) (provide
case "spotify":
pConfig = config.External.Spotify
p, err = provider.NewSpotifyProvider(pConfig, scopes)
case "supabase":
pConfig = config.External.Supabase
p, err = provider.NewSupabaseProvider(ctx, pConfig, scopes)
case "slack":
pConfig = config.External.Slack
p, err = provider.NewSlackProvider(pConfig, scopes)
Expand Down
16 changes: 15 additions & 1 deletion internal/api/external_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ import (
"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/api/provider"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/utilities"
"golang.org/x/oauth2"
)

// OAuthProviderData contains the userData and token returned by the oauth provider
Expand Down Expand Up @@ -83,7 +85,19 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s
"code": oauthCode,
}).Debug("Exchanging oauth code")

token, err := oAuthProvider.GetOAuthToken(oauthCode)
// Build token exchange options, including PKCE code verifier if available
var tokenOpts []oauth2.AuthCodeOption
flowStateID := getFlowStateID(ctx)
if flowStateID != "" {
db := a.db.WithContext(ctx)
flowState, fsErr := models.FindFlowStateByID(db, flowStateID)
if fsErr == nil && flowState.ProviderCodeVerifier != "" {
// Pass PKCE code verifier for token exchange
tokenOpts = append(tokenOpts, oauth2.VerifierOption(flowState.ProviderCodeVerifier))
}
}

token, err := oAuthProvider.GetOAuthToken(oauthCode, tokenOpts...)
if err != nil {
return nil, apierrors.NewInternalServerError("Unable to exchange external code: %s", oauthCode).WithInternalError(err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/api/magic_link.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error {
}

if isPKCEFlow(flowType) {
if _, err = generateFlowState(db, models.MagicLink.String(), models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID); err != nil {
if _, err = generateFlowState(db, models.MagicLink.String(), models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID, ""); err != nil {
return err
}
}
Expand Down
4 changes: 2 additions & 2 deletions internal/api/pkce.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ func getFlowFromChallenge(codeChallenge string) models.FlowType {
}

// Should only be used with Auth Code of PKCE Flows
func generateFlowState(tx *storage.Connection, providerType string, authenticationMethod models.AuthenticationMethod, codeChallengeMethodParam string, codeChallenge string, userID *uuid.UUID) (*models.FlowState, error) {
func generateFlowState(tx *storage.Connection, providerType string, authenticationMethod models.AuthenticationMethod, codeChallengeMethodParam string, codeChallenge string, userID *uuid.UUID, providerCodeVerifier string) (*models.FlowState, error) {
codeChallengeMethod, err := models.ParseCodeChallengeMethod(codeChallengeMethodParam)
if err != nil {
return nil, err
}
flowState := models.NewFlowState(providerType, codeChallenge, codeChallengeMethod, authenticationMethod, userID)
flowState := models.NewFlowState(providerType, codeChallenge, codeChallengeMethod, authenticationMethod, userID, providerCodeVerifier)
if err := tx.Create(flowState); err != nil {
return nil, err
}
Expand Down
11 changes: 8 additions & 3 deletions internal/api/provider/apple.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,13 @@ func NewAppleProvider(ctx context.Context, ext conf.OAuthProviderConfiguration)
}

// GetOAuthToken returns the apple provider access token
func (p AppleProvider) GetOAuthToken(code string) (*oauth2.Token, error) {
opts := []oauth2.AuthCodeOption{
func (p AppleProvider) GetOAuthToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
appleOpts := []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("client_id", p.ClientID),
oauth2.SetAuthURLParam("secret", p.ClientSecret),
}
return p.Exchange(context.Background(), code, opts...)
appleOpts = append(appleOpts, opts...)
return p.Exchange(context.Background(), code, appleOpts...)
}

func (p AppleProvider) AuthCodeURL(state string, args ...oauth2.AuthCodeOption) string {
Expand Down Expand Up @@ -172,3 +173,7 @@ func (p AppleProvider) ParseUser(data string, userData *UserProvidedData) error
userData.Metadata.FullName = strings.TrimSpace(u.Name.FirstName + " " + u.Name.LastName)
return nil
}
// RequiresPKCE returns false as this provider does not require PKCE
func (p *AppleProvider) RequiresPKCE() bool {
return false
}
8 changes: 6 additions & 2 deletions internal/api/provider/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ func NewAzureProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuth
}, nil
}

func (g azureProvider) GetOAuthToken(code string) (*oauth2.Token, error) {
return g.Exchange(context.Background(), code)
func (g azureProvider) GetOAuthToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return g.Exchange(context.Background(), code, opts...)
}

func DetectAzureIDTokenIssuer(ctx context.Context, idToken string) (string, error) {
Expand Down Expand Up @@ -162,3 +162,7 @@ func (g azureProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*Use

return nil, fmt.Errorf("azure: no OIDC ID token present in response")
}
// RequiresPKCE returns false as this provider does not require PKCE
func (p *azureProvider) RequiresPKCE() bool {
return false
}
8 changes: 6 additions & 2 deletions internal/api/provider/bitbucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ func NewBitbucketProvider(ext conf.OAuthProviderConfiguration) (OAuthProvider, e
}, nil
}

func (g bitbucketProvider) GetOAuthToken(code string) (*oauth2.Token, error) {
return g.Exchange(context.Background(), code)
func (g bitbucketProvider) GetOAuthToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return g.Exchange(context.Background(), code, opts...)
}

func (g bitbucketProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) {
Expand Down Expand Up @@ -102,3 +102,7 @@ func (g bitbucketProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (

return data, nil
}
// RequiresPKCE returns false as this provider does not require PKCE
func (p *bitbucketProvider) RequiresPKCE() bool {
return false
}
8 changes: 6 additions & 2 deletions internal/api/provider/discord.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ func NewDiscordProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAu
}, nil
}

func (g discordProvider) GetOAuthToken(code string) (*oauth2.Token, error) {
return g.Exchange(context.Background(), code)
func (g discordProvider) GetOAuthToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return g.Exchange(context.Background(), code, opts...)
}

func (g discordProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) {
Expand Down Expand Up @@ -118,3 +118,7 @@ func (g discordProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*U

return data, nil
}
// RequiresPKCE returns false as this provider does not require PKCE
func (p *discordProvider) RequiresPKCE() bool {
return false
}
8 changes: 6 additions & 2 deletions internal/api/provider/facebook.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ func NewFacebookProvider(ext conf.OAuthProviderConfiguration, scopes string) (OA
}, nil
}

func (p facebookProvider) GetOAuthToken(code string) (*oauth2.Token, error) {
return p.Exchange(context.Background(), code)
func (p facebookProvider) GetOAuthToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return p.Exchange(context.Background(), code, opts...)
}

func (p facebookProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) {
Expand Down Expand Up @@ -110,3 +110,7 @@ func (p facebookProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*

return data, nil
}
// RequiresPKCE returns false as this provider does not require PKCE
func (p *facebookProvider) RequiresPKCE() bool {
return false
}
8 changes: 6 additions & 2 deletions internal/api/provider/figma.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ func NewFigmaProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuth
}, nil
}

func (p figmaProvider) GetOAuthToken(code string) (*oauth2.Token, error) {
return p.Exchange(context.Background(), code)
func (p figmaProvider) GetOAuthToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return p.Exchange(context.Background(), code, opts...)
}

func (p figmaProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) {
Expand Down Expand Up @@ -93,3 +93,7 @@ func (p figmaProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*Use
}
return data, nil
}
// RequiresPKCE returns false as this provider does not require PKCE
func (p *figmaProvider) RequiresPKCE() bool {
return false
}
8 changes: 6 additions & 2 deletions internal/api/provider/fly.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func NewFlyProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthPr
}, nil
}

func (p flyProvider) GetOAuthToken(code string) (*oauth2.Token, error) {
return p.Exchange(context.Background(), code)
func (p flyProvider) GetOAuthToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return p.Exchange(context.Background(), code, opts...)
}

func (p flyProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) {
Expand Down Expand Up @@ -101,3 +101,7 @@ func (p flyProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserP
}
return data, nil
}
// RequiresPKCE returns false as this provider does not require PKCE
func (p *flyProvider) RequiresPKCE() bool {
return false
}
8 changes: 6 additions & 2 deletions internal/api/provider/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ func NewGithubProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAut
}, nil
}

func (g githubProvider) GetOAuthToken(code string) (*oauth2.Token, error) {
return g.Exchange(context.Background(), code)
func (g githubProvider) GetOAuthToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return g.Exchange(context.Background(), code, opts...)
}

func (g githubProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) {
Expand Down Expand Up @@ -108,3 +108,7 @@ func (g githubProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*Us

return data, nil
}
// RequiresPKCE returns false as this provider does not require PKCE
func (p *githubProvider) RequiresPKCE() bool {
return false
}
8 changes: 6 additions & 2 deletions internal/api/provider/gitlab.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ func NewGitlabProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAut
}, nil
}

func (g gitlabProvider) GetOAuthToken(code string) (*oauth2.Token, error) {
return g.Exchange(context.Background(), code)
func (g gitlabProvider) GetOAuthToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return g.Exchange(context.Background(), code, opts...)
}

func (g gitlabProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) {
Expand Down Expand Up @@ -105,3 +105,7 @@ func (g gitlabProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*Us

return data, nil
}
// RequiresPKCE returns false as this provider does not require PKCE
func (p *gitlabProvider) RequiresPKCE() bool {
return false
}
8 changes: 6 additions & 2 deletions internal/api/provider/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ func NewGoogleProvider(ctx context.Context, ext conf.OAuthProviderConfiguration,
}, nil
}

func (g googleProvider) GetOAuthToken(code string) (*oauth2.Token, error) {
return g.Exchange(context.Background(), code)
func (g googleProvider) GetOAuthToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return g.Exchange(context.Background(), code, opts...)
}

const UserInfoEndpointGoogle = "https://www.googleapis.com/userinfo/v2/me"
Expand Down Expand Up @@ -142,3 +142,7 @@ func OverrideGoogleProvider(issuer, userInfo string) {
internalIssuerGoogle = issuer
internalUserInfoEndpointGoogle = userInfo
}
// RequiresPKCE returns false as this provider does not require PKCE
func (p *googleProvider) RequiresPKCE() bool {
return false
}
8 changes: 6 additions & 2 deletions internal/api/provider/kakao.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ type kakaoUser struct {
} `json:"kakao_account"`
}

func (p kakaoProvider) GetOAuthToken(code string) (*oauth2.Token, error) {
return p.Exchange(context.Background(), code)
func (p kakaoProvider) GetOAuthToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return p.Exchange(context.Background(), code, opts...)
}

func (p kakaoProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) {
Expand Down Expand Up @@ -105,3 +105,7 @@ func NewKakaoProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuth
APIHost: apiHost,
}, nil
}
// RequiresPKCE returns false as this provider does not require PKCE
func (p *kakaoProvider) RequiresPKCE() bool {
return false
}
8 changes: 6 additions & 2 deletions internal/api/provider/keycloak.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ func NewKeycloakProvider(ext conf.OAuthProviderConfiguration, scopes string) (OA
}, nil
}

func (g keycloakProvider) GetOAuthToken(code string) (*oauth2.Token, error) {
return g.Exchange(context.Background(), code)
func (g keycloakProvider) GetOAuthToken(code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return g.Exchange(context.Background(), code, opts...)
}

func (g keycloakProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) {
Expand Down Expand Up @@ -132,3 +132,7 @@ func (g keycloakProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*
return data, nil

}
// RequiresPKCE returns false as this provider does not require PKCE
func (p *keycloakProvider) RequiresPKCE() bool {
return false
}
Loading