Skip to content

Commit 5352be7

Browse files
Remove Extra Determine ALG Function
1 parent 08437e8 commit 5352be7

File tree

5 files changed

+67
-96
lines changed

5 files changed

+67
-96
lines changed

authentication/authentication.go

+2-30
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,12 @@ import (
44
"context"
55
"encoding/json"
66
"errors"
7-
"fmt"
87
"net/http"
98
"net/url"
109
"reflect"
1110
"strings"
1211
"time"
1312

14-
"github.com/lestrrat-go/jwx/v2/jwa"
15-
1613
"github.com/auth0/go-auth0/authentication/oauth"
1714
"github.com/auth0/go-auth0/internal/client"
1815
"github.com/auth0/go-auth0/internal/idtokenvalidator"
@@ -253,7 +250,7 @@ func (a *Authentication) addClientAuthenticationToURLValues(params oauth.ClientA
253250

254251
switch {
255252
case a.clientAssertionSigningKey != "" && a.clientAssertionSigningAlg != "":
256-
alg, err := determineAlg(a.clientAssertionSigningAlg)
253+
alg, err := client.DetermineSigningAlgorithm(a.clientAssertionSigningAlg)
257254
if err != nil {
258255
return err
259256
}
@@ -295,7 +292,7 @@ func (a *Authentication) addClientAuthenticationToClientAuthStruct(params *oauth
295292
}
296293

297294
if a.clientAssertionSigningKey != "" && a.clientAssertionSigningAlg != "" {
298-
alg, err := determineAlg(a.clientAssertionSigningAlg)
295+
alg, err := client.DetermineSigningAlgorithm(a.clientAssertionSigningAlg)
299296
if err != nil {
300297
return err
301298
}
@@ -323,28 +320,3 @@ func (a *Authentication) addClientAuthenticationToClientAuthStruct(params *oauth
323320

324321
return nil
325322
}
326-
327-
func determineAlg(alg string) (jwa.SignatureAlgorithm, error) {
328-
switch alg {
329-
case "RS256":
330-
return jwa.RS256, nil
331-
case "RS384":
332-
return jwa.RS384, nil
333-
case "RS512":
334-
return jwa.RS512, nil
335-
case "PS256":
336-
return jwa.PS256, nil
337-
case "PS384":
338-
return jwa.PS384, nil
339-
case "PS512":
340-
return jwa.PS512, nil
341-
case "ES256":
342-
return jwa.ES256, nil
343-
case "ES384":
344-
return jwa.ES384, nil
345-
case "ES512":
346-
return jwa.ES512, nil
347-
default:
348-
return "", fmt.Errorf("unsupported client assertion algorithm %q provided", alg)
349-
}
350-
}

authentication/oauth_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ func TestLoginWithClientCredentials(t *testing.T) {
318318
Audience: "test-audience",
319319
}, oauth.IDTokenValidationOptions{})
320320

321-
assert.ErrorContains(t, err, "unsupported client assertion algorithm \"invalid-alg\" provided")
321+
assert.ErrorContains(t, err, "unsupported client assertion algorithm \"invalid-alg\"")
322322
})
323323

324324
t.Run("Should support passing an organization", func(t *testing.T) {

internal/client/jwt_token_source.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414
"golang.org/x/oauth2/clientcredentials"
1515
)
1616

17-
// privateKeyJwtTokenSource implements oauth2.TokenSource for Private Key JWT client authentication
17+
// privateKeyJwtTokenSource implements oauth2.TokenSource for Private Key JWT client authentication.
1818
type privateKeyJwtTokenSource struct {
1919
ctx context.Context
2020
uri string
@@ -24,7 +24,7 @@ type privateKeyJwtTokenSource struct {
2424
audience string
2525
}
2626

27-
// newPrivateKeyJwtTokenSource creates a new token source that uses Private Key JWT authentication
27+
// newPrivateKeyJwtTokenSource creates a new token source that uses Private Key JWT authentication.
2828
func newPrivateKeyJwtTokenSource(
2929
ctx context.Context,
3030
uri,
@@ -45,9 +45,9 @@ func newPrivateKeyJwtTokenSource(
4545
return oauth2.ReuseTokenSource(nil, source)
4646
}
4747

48-
// Token generates a new token using Private Key JWT client authentication
48+
// Token generates a new token using Private Key JWT client authentication.
4949
func (p privateKeyJwtTokenSource) Token() (*oauth2.Token, error) {
50-
alg, err := determineAlg(p.clientAssertionSigningAlg)
50+
alg, err := DetermineSigningAlgorithm(p.clientAssertionSigningAlg)
5151
if err != nil {
5252
return nil, fmt.Errorf("invalid algorithm: %w", err)
5353
}
@@ -86,8 +86,8 @@ func (p privateKeyJwtTokenSource) Token() (*oauth2.Token, error) {
8686
return token, nil
8787
}
8888

89-
// determineAlg returns the appropriate JWA signature algorithm based on the string representation
90-
func determineAlg(alg string) (jwa.SignatureAlgorithm, error) {
89+
// DetermineSigningAlgorithm returns the appropriate JWA signature algorithm based on the string representation.
90+
func DetermineSigningAlgorithm(alg string) (jwa.SignatureAlgorithm, error) {
9191
switch alg {
9292
case "RS256":
9393
return jwa.RS256, nil
@@ -112,7 +112,7 @@ func determineAlg(alg string) (jwa.SignatureAlgorithm, error) {
112112
}
113113
}
114114

115-
// CreateClientAssertion creates a JWT token for client authentication with the specified lifetime
115+
// CreateClientAssertion creates a JWT token for client authentication with the specified lifetime.
116116
func CreateClientAssertion(alg jwa.SignatureAlgorithm, signingKey, clientID, audience string) (string, error) {
117117
key, err := jwk.ParseKey([]byte(signingKey), jwk.WithPEM(true))
118118
if err != nil {
@@ -146,7 +146,7 @@ func CreateClientAssertion(alg jwa.SignatureAlgorithm, signingKey, clientID, aud
146146
return string(signedToken), nil
147147
}
148148

149-
// verifyKeyCompatibility checks if the provided key is compatible with the specified algorithm
149+
// verifyKeyCompatibility checks if the provided key is compatible with the specified algorithm.
150150
func verifyKeyCompatibility(alg jwa.SignatureAlgorithm, key jwk.Key) error {
151151
keyType := key.KeyType()
152152

internal/client/jwt_token_source_test.go

+54-55
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ func TestDetermineAlg(t *testing.T) {
4646

4747
for _, tc := range testCases {
4848
t.Run(tc.name, func(t *testing.T) {
49-
alg, err := determineAlg(tc.algorithm)
49+
alg, err := DetermineSigningAlgorithm(tc.algorithm)
5050

5151
if tc.expectedError {
5252
assert.Error(t, err)
@@ -84,7 +84,7 @@ func TestClientAssertion(t *testing.T) {
8484
}
8585

8686
// Get the signed assertion
87-
alg, err := determineAlg(ts.clientAssertionSigningAlg)
87+
alg, err := DetermineSigningAlgorithm(ts.clientAssertionSigningAlg)
8888
require.NoError(t, err)
8989

9090
baseURL, err := url.Parse(ts.uri)
@@ -149,7 +149,7 @@ func TestECClientAssertion(t *testing.T) {
149149
}
150150

151151
// Get the signed assertion
152-
alg, err := determineAlg(ts.clientAssertionSigningAlg)
152+
alg, err := DetermineSigningAlgorithm(ts.clientAssertionSigningAlg)
153153
require.NoError(t, err)
154154

155155
baseURL, err := url.Parse(ts.uri)
@@ -224,7 +224,7 @@ func TestIncompatibleKeyTypeForAlgorithm(t *testing.T) {
224224
}
225225

226226
// Get the signed assertion
227-
alg, err := determineAlg(ts.clientAssertionSigningAlg)
227+
alg, err := DetermineSigningAlgorithm(ts.clientAssertionSigningAlg)
228228
require.NoError(t, err)
229229

230230
baseURL, err := url.Parse(ts.uri)
@@ -312,58 +312,57 @@ func TestPrivateKeyJwtTokenSource(t *testing.T) {
312312
assert.Equal(t, "Bearer", token.TokenType)
313313
}
314314

315-
316315
func TestPrivateKeyJwtTokenSourceRefresh(t *testing.T) {
317-
// Generate a test RSA key
318-
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
319-
require.NoError(t, err)
320-
privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey)
321-
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
322-
Type: "RSA PRIVATE KEY",
323-
Bytes: privateKeyBytes,
324-
})
325-
326-
// Track token request count
327-
requestCount := 0
328-
329-
// Create a test server
330-
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
331-
requestCount++
332-
333-
// Return a token with short expiration
334-
w.Header().Set("Content-Type", "application/json")
335-
w.Write([]byte(fmt.Sprintf(`{
316+
// Generate a test RSA key
317+
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
318+
require.NoError(t, err)
319+
privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey)
320+
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
321+
Type: "RSA PRIVATE KEY",
322+
Bytes: privateKeyBytes,
323+
})
324+
325+
// Track token request count
326+
requestCount := 0
327+
328+
// Create a test server
329+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
330+
requestCount++
331+
332+
// Return a token with short expiration
333+
w.Header().Set("Content-Type", "application/json")
334+
w.Write(fmt.Appendf(nil, `{
336335
"access_token": "mock-token-%d",
337336
"token_type": "Bearer",
338337
"expires_in": 2
339-
}`, requestCount)))
340-
}))
341-
defer server.Close()
342-
343-
// Create token source
344-
tokenSource := newPrivateKeyJwtTokenSource(
345-
context.Background(),
346-
server.URL,
347-
"RS256",
348-
string(privateKeyPEM),
349-
"test-client-id",
350-
"test-audience",
351-
)
352-
353-
// Get first token
354-
token1, err := tokenSource.Token()
355-
require.NoError(t, err)
356-
assert.Equal(t, "mock-token-1", token1.AccessToken)
357-
358-
// Wait for token to expire (just over 2 seconds)
359-
time.Sleep(3 * time.Second)
360-
361-
// Get second token - should trigger a refresh
362-
token2, err := tokenSource.Token()
363-
require.NoError(t, err)
364-
assert.Equal(t, "mock-token-2", token2.AccessToken)
365-
assert.NotEqual(t, token1.AccessToken, token2.AccessToken)
366-
367-
// Verify server received two requests
368-
assert.Equal(t, 2, requestCount)
369-
}
338+
}`, requestCount))
339+
}))
340+
defer server.Close()
341+
342+
// Create token source
343+
tokenSource := newPrivateKeyJwtTokenSource(
344+
context.Background(),
345+
server.URL,
346+
"RS256",
347+
string(privateKeyPEM),
348+
"test-client-id",
349+
"test-audience",
350+
)
351+
352+
// Get first token
353+
token1, err := tokenSource.Token()
354+
require.NoError(t, err)
355+
assert.Equal(t, "mock-token-1", token1.AccessToken)
356+
357+
// Wait for token to expire (just over 2 seconds)
358+
time.Sleep(3 * time.Second)
359+
360+
// Get second token - should trigger a refresh
361+
token2, err := tokenSource.Token()
362+
require.NoError(t, err)
363+
assert.Equal(t, "mock-token-2", token2.AccessToken)
364+
assert.NotEqual(t, token1.AccessToken, token2.AccessToken)
365+
366+
// Verify server received two requests
367+
assert.Equal(t, 2, requestCount)
368+
}

internal/idtokenvalidator/idtokenvalidator.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func New(
4545
if i := strings.Index(domain, "//"); i != -1 {
4646
domain = domain[i+2:]
4747
}
48-
alg, err := determineAlg(idTokenSigningAlg)
48+
alg, err := determineSigningAlgorithm(idTokenSigningAlg)
4949
if err != nil {
5050
return nil, err
5151
}
@@ -201,7 +201,7 @@ func (i *IDTokenValidator) Validate(idToken string, optional ValidationOptions)
201201
return err
202202
}
203203

204-
func determineAlg(alg string) (jwa.SignatureAlgorithm, error) {
204+
func determineSigningAlgorithm(alg string) (jwa.SignatureAlgorithm, error) {
205205
switch alg {
206206
case "HS256":
207207
return jwa.HS256, nil

0 commit comments

Comments
 (0)