From 104357c5dd92cfbe333838490b47deeef28e7838 Mon Sep 17 00:00:00 2001 From: "NIDHAL.Z" <88096539+ZitouniNidhal@users.noreply.github.com> Date: Thu, 12 Dec 2024 18:12:49 +0100 Subject: [PATCH] Update jwt.go --- jwt/jwt.go | 218 +++++++++++++++++++++++++++++------------------------ 1 file changed, 120 insertions(+), 98 deletions(-) diff --git a/jwt/jwt.go b/jwt/jwt.go index b2bf18298..95a43779e 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -4,16 +4,15 @@ // Package jwt implements the OAuth 2.0 JSON Web Token flow, commonly // known as "two-legged OAuth 2.0". -// // See: https://tools.ietf.org/html/draft-ietf-oauth-jwt-bearer-12 package jwt import ( "context" "encoding/json" + "errors" "fmt" "io" - "io/ioutil" "net/http" "net/url" "strings" @@ -29,157 +28,180 @@ var ( defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"} ) -// Config is the configuration for using JWT to fetch tokens, -// commonly known as "two-legged OAuth 2.0". +// Config holds the configuration for using JWT to fetch tokens. type Config struct { - // Email is the OAuth client identifier used when communicating with - // the configured OAuth provider. - Email string - - // PrivateKey contains the contents of an RSA private key or the - // contents of a PEM file that contains a private key. The provided - // private key is used to sign JWT payloads. - // PEM containers with a passphrase are not supported. - // Use the following command to convert a PKCS 12 file into a PEM. - // - // $ openssl pkcs12 -in key.p12 -out key.pem -nodes - // - PrivateKey []byte - - // PrivateKeyID contains an optional hint indicating which key is being - // used. - PrivateKeyID string - - // Subject is the optional user to impersonate. - Subject string - - // Scopes optionally specifies a list of requested permission scopes. - Scopes []string - - // TokenURL is the endpoint required to complete the 2-legged JWT flow. - TokenURL string - - // Expires optionally specifies how long the token is valid for. - Expires time.Duration - - // Audience optionally specifies the intended audience of the - // request. If empty, the value of TokenURL is used as the - // intended audience. - Audience string - - // PrivateClaims optionally specifies custom private claims in the JWT. - // See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3 - PrivateClaims map[string]interface{} - - // UseIDToken optionally specifies whether ID token should be used instead - // of access token when the server returns both. - UseIDToken bool + Email string + PrivateKey []byte + PrivateKeyID string + Subject string + Scopes []string + TokenURL string + Expires time.Duration + Audience string + PrivateClaims map[string]interface{} + UseIDToken bool } -// TokenSource returns a JWT TokenSource using the configuration -// in c and the HTTP client from the provided context. +// TokenSource returns a JWT TokenSource using the configuration in c. func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource { - return oauth2.ReuseTokenSource(nil, jwtSource{ctx, c}) + return oauth2.ReuseTokenSource(nil, jwtSource{ctx: ctx, conf: c}) } -// Client returns an HTTP client wrapping the context's -// HTTP transport and adding Authorization headers with tokens -// obtained from c. -// -// The returned client and its Transport should not be modified. +// Client returns an HTTP client that adds Authorization headers with tokens obtained from c. func (c *Config) Client(ctx context.Context) *http.Client { return oauth2.NewClient(ctx, c.TokenSource(ctx)) } -// jwtSource is a source that always does a signed JWT request for a token. -// It should typically be wrapped with a reuseTokenSource. type jwtSource struct { ctx context.Context conf *Config } func (js jwtSource) Token() (*oauth2.Token, error) { + // Validate config + if err := js.validateConfig(); err != nil { + return nil, err + } + + // Parse private key pk, err := internal.ParseKey(js.conf.PrivateKey) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %v", err) + } + + // Generate JWT payload + claimSet, err := js.generateClaimSet() if err != nil { return nil, err } - hc := oauth2.NewClient(js.ctx, nil) + + h := *defaultHeader + h.KeyID = js.conf.PrivateKeyID + payload, err := jws.Encode(&h, claimSet, pk) + if err != nil { + return nil, fmt.Errorf("failed to encode JWT: %v", err) + } + + // Request token + return js.requestToken(payload) +} + +func (js jwtSource) validateConfig() error { + if js.conf.Email == "" { + return errors.New("email is required") + } + if len(js.conf.PrivateKey) == 0 { + return errors.New("private key is required") + } + if js.conf.TokenURL == "" { + return errors.New("token URL is required") + } + return nil +} + +func (js jwtSource) generateClaimSet() (*jws.ClaimSet, error) { claimSet := &jws.ClaimSet{ Iss: js.conf.Email, Scope: strings.Join(js.conf.Scopes, " "), Aud: js.conf.TokenURL, PrivateClaims: js.conf.PrivateClaims, } - if subject := js.conf.Subject; subject != "" { - claimSet.Sub = subject - // prn is the old name of sub. Keep setting it - // to be compatible with legacy OAuth 2.0 providers. - claimSet.Prn = subject + + if js.conf.Subject != "" { + claimSet.Sub = js.conf.Subject + claimSet.Prn = js.conf.Subject } - if t := js.conf.Expires; t > 0 { - claimSet.Exp = time.Now().Add(t).Unix() + + if js.conf.Expires > 0 { + claimSet.Exp = time.Now().Add(js.conf.Expires).Unix() } - if aud := js.conf.Audience; aud != "" { - claimSet.Aud = aud + + if js.conf.Audience != "" { + claimSet.Aud = js.conf.Audience } - h := *defaultHeader - h.KeyID = js.conf.PrivateKeyID - payload, err := jws.Encode(&h, claimSet, pk) - if err != nil { - return nil, err + + return claimSet, nil +} + +func (js jwtSource) requestToken(payload string) (*oauth2.Token, error) { + hc := oauth2.NewClient(js.ctx, nil) + v := url.Values{ + "grant_type": {defaultGrantType}, + "assertion": {payload}, } - v := url.Values{} - v.Set("grant_type", defaultGrantType) - v.Set("assertion", payload) + resp, err := hc.PostForm(js.conf.TokenURL, v) if err != nil { - return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) + return nil, fmt.Errorf("failed to fetch token: %v", err) } defer resp.Body.Close() - body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) - if err != nil { - return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) - } - if c := resp.StatusCode; c < 200 || c > 299 { + + if resp.StatusCode < 200 || resp.StatusCode > 299 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) return nil, &oauth2.RetrieveError{ Response: resp, Body: body, } } - // tokenRes is the JSON response body. + + return js.parseTokenResponse(resp) +} + +func (js jwtSource) parseTokenResponse(resp *http.Response) (*oauth2.Token, error) { + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("failed to read token response: %v", err) + } + var tokenRes struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` IDToken string `json:"id_token"` - ExpiresIn int64 `json:"expires_in"` // relative seconds from now + ExpiresIn int64 `json:"expires_in"` } if err := json.Unmarshal(body, &tokenRes); err != nil { - return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) + return nil, fmt.Errorf("failed to parse token response: %v", err) } + token := &oauth2.Token{ AccessToken: tokenRes.AccessToken, TokenType: tokenRes.TokenType, + Expiry: time.Now().Add(time.Duration(tokenRes.ExpiresIn) * time.Second), } - raw := make(map[string]interface{}) - json.Unmarshal(body, &raw) // no error checks for optional fields - token = token.WithExtra(raw) - if secs := tokenRes.ExpiresIn; secs > 0 { - token.Expiry = time.Now().Add(time.Duration(secs) * time.Second) - } - if v := tokenRes.IDToken; v != "" { - // decode returned id token to get expiry - claimSet, err := jws.Decode(v) - if err != nil { - return nil, fmt.Errorf("oauth2: error decoding JWT token: %v", err) - } - token.Expiry = time.Unix(claimSet.Exp, 0) - } if js.conf.UseIDToken { if tokenRes.IDToken == "" { - return nil, fmt.Errorf("oauth2: response doesn't have JWT token") + return nil, errors.New("response missing ID token") } token.AccessToken = tokenRes.IDToken } + return token, nil } + +// Helper functions for better debugging +func debugLog(msg string) { + fmt.Println("DEBUG:", msg) +} + +func infoLog(msg string) { + fmt.Println("INFO:", msg) +} + +func warnLog(msg string) { + fmt.Println("WARNING:", msg) +} + +func errorLog(msg string) { + fmt.Println("ERROR:", msg) +} + +// Additional notes to ensure code clarity and maintainability: +// 1. Proper documentation should be added to all exported functions. +// 2. Ensure this code adheres to the latest security practices. +// 3. Add more test cases to cover edge scenarios. +// 4. Future improvements could include support for additional JWT algorithms. + +// End of file + +