Skip to content

Commit 989b2e5

Browse files
committed
refactor: use claims if userinfo n/a, parse token if flow is not code
1 parent ee14f97 commit 989b2e5

File tree

1 file changed

+49
-27
lines changed

1 file changed

+49
-27
lines changed

Diff for: goic.go

+49-27
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ func (g *Goic) AddProvider(p *Provider) *Provider {
112112
}
113113

114114
if _, err := p.getWellKnown(); err != nil {
115-
log.Fatalf("goic provider %s: cannot load well-known configuration: %s", p.Name, err.Error())
115+
p.SetErr(err)
116116
}
117117

118118
g.providers[p.Name] = p
@@ -150,21 +150,30 @@ func (g *Goic) RequestAuth(p *Provider, state, nonce, redir string, res http.Res
150150
// AuthRedirectURL gives the full auth redirect URL for the provider
151151
// It returns empty string when there is an error
152152
func AuthRedirectURL(p *Provider, state, nonce, redir string) string {
153-
redirect, err := http.NewRequest("GET", p.wellKnown.AuthURI, nil)
153+
redirect, err := http.NewRequest("GET", p.GetURI("auth"), nil)
154154
if err != nil {
155155
return ""
156156
}
157157

158158
qry := redirect.URL.Query()
159159
qry.Add("response_type", "code")
160+
if p.ResType != "" {
161+
qry.Set("response_type", p.ResType)
162+
}
163+
160164
qry.Add("redirect_uri", redir)
161165
qry.Add("client_id", p.clientID)
162166
qry.Add("scope", p.Scope)
163167
qry.Add("state", state)
164168
qry.Add("nonce", nonce)
165169
redirect.URL.RawQuery = qry.Encode()
166170

167-
return redirect.URL.String()
171+
query := ""
172+
if p.QueryFn != nil {
173+
query = "&" + p.QueryFn()
174+
}
175+
176+
return redirect.URL.String() + query
168177
}
169178

170179
// checkState checks if given state is valid (i.e. known)
@@ -187,26 +196,33 @@ func (g *Goic) checkState(state string) (string, error) {
187196

188197
// Authenticate tries to authenticate a user by given code and nonce
189198
// It is where token is requested and validated
190-
func (g *Goic) Authenticate(p *Provider, code, nonce, redir string) (*Token, error) {
199+
func (g *Goic) Authenticate(p *Provider, codeOrTok, nonce, redir string) (tok *Token, err error) {
200+
tok = &Token{Provider: p.Name}
191201
if !g.Supports(p.Name) {
192-
return &Token{Provider: p.Name}, ErrProviderSupport
202+
return tok, ErrProviderSupport
193203
}
194204

195-
tok, err := g.getToken(p, code, redir, "authorization_code")
196-
if err != nil {
197-
return tok, err
205+
isCode := p.ResType == "" || strings.Contains(" "+p.ResType+" ", " code ")
206+
// get token from code or just parse token
207+
if isCode {
208+
tok, err = g.getToken(p, codeOrTok, redir, "authorization_code")
209+
} else {
210+
tok, err = parseToken([]byte(codeOrTok), tok)
198211
}
199212

213+
if err != nil {
214+
return tok, fmt.Errorf("get token: %w", err)
215+
}
200216
if err := g.verifyToken(p, tok, nonce); err != nil {
201-
return tok, err
217+
return tok, fmt.Errorf("verify token: %w", err)
202218
}
203219

204220
return tok, nil
205221
}
206222

207223
// getToken actually gets token from Provider via wellKnown.TokenURI
208-
func (g *Goic) getToken(p *Provider, code, redir, grant string) (*Token, error) {
209-
tok := &Token{Provider: p.Name}
224+
func (g *Goic) getToken(p *Provider, code, redir, grant string) (tok *Token, err error) {
225+
tok = &Token{Provider: p.Name}
210226

211227
qry := url.Values{}
212228
qry.Add("grant_type", grant)
@@ -219,7 +235,7 @@ func (g *Goic) getToken(p *Provider, code, redir, grant string) (*Token, error)
219235
qry.Add("client_id", p.clientID)
220236
qry.Add("client_secret", p.clientSecret)
221237

222-
req, err := http.NewRequest("POST", p.wellKnown.TokenURI, strings.NewReader(qry.Encode()))
238+
req, err := http.NewRequest("POST", p.GetURI("token"), strings.NewReader(qry.Encode()))
223239
if err != nil {
224240
return tok, err
225241
}
@@ -231,38 +247,41 @@ func (g *Goic) getToken(p *Provider, code, redir, grant string) (*Token, error)
231247
}
232248
defer res.Body.Close()
233249

234-
body, err := ioutil.ReadAll(res.Body)
250+
body, err := io.ReadAll(res.Body)
235251
if err != nil {
236252
return tok, err
237253
}
238254

239-
if err := json.Unmarshal(body, &tok); err != nil {
255+
return parseToken(body, tok)
256+
}
257+
258+
func parseToken(tokByte []byte, tok *Token) (*Token, error) {
259+
if err := json.Unmarshal(tokByte, &tok); err != nil {
240260
return tok, err
241261
}
262+
if tok.IDToken == "" {
263+
return tok, ErrTokenEmpty
264+
}
242265

243266
if tok.Err != "" {
244267
msg := tok.Err
245268
if tok.ErrDesc != "" {
246269
msg += ": " + tok.ErrDesc
247270
}
248-
return tok, errors.New(msg)
271+
return tok, fmt.Errorf(msg)
249272
}
250-
251-
if tok.IDToken == "" {
252-
return tok, ErrTokenEmpty
253-
}
254-
255273
return tok, nil
256274
}
257275

258276
// verifyToken checks and verifies authenticity and ownership of Token
259-
func (g *Goic) verifyToken(p *Provider, tok *Token, nonce string) error {
260-
claims, err := verifyClaims(tok, nonce, p.clientID)
261-
if err != nil {
277+
func (g *Goic) verifyToken(p *Provider, tok *Token, nonce string) (err error) {
278+
// Data verification
279+
if err = tok.VerifyClaims(nonce, p.clientID); err != nil {
262280
return err
263281
}
264282

265-
_, err = jwt.ParseWithClaims(tok.IDToken, claims, func(t *jwt.Token) (interface{}, error) {
283+
// Signature verification
284+
_, err = jwt.ParseWithClaims(tok.IDToken, tok.Claims, func(t *jwt.Token) (any, error) {
266285
alg := t.Header["alg"].(string)
267286
al2 := alg[0:2]
268287
if al2 == "HS" {
@@ -397,8 +416,11 @@ func (g *Goic) UserInfo(tok *Token) *User {
397416
}
398417

399418
p := g.providers[tok.Provider]
419+
if p.GetURI("userinfo") == "" {
420+
return user.FromClaims(tok.Claims)
421+
}
400422

401-
req, err := http.NewRequest("GET", p.wellKnown.UserInfoURI, nil)
423+
req, err := http.NewRequest("GET", p.GetURI("userinfo"), nil)
402424
if err != nil {
403425
return user.withError(err)
404426
}
@@ -457,7 +479,7 @@ func (g *Goic) SignOut(tok *Token, redir string, res http.ResponseWriter, req *h
457479
return ErrProviderSupport
458480
}
459481

460-
redirect, err := http.NewRequest("GET", p.wellKnown.SignOutURI, nil)
482+
redirect, err := http.NewRequest("GET", p.GetURI("signout"), nil)
461483
if err != nil {
462484
return err
463485
}
@@ -497,7 +519,7 @@ func (g *Goic) RevokeToken(tok *Token) error {
497519
qry.Add("token", tk)
498520
qry.Add("token_type_hint", hint)
499521

500-
req, err := http.NewRequest("POST", p.wellKnown.RevokeURI, strings.NewReader(qry.Encode()))
522+
req, err := http.NewRequest("POST", p.GetURI("revoke"), strings.NewReader(qry.Encode()))
501523
if err != nil {
502524
return err
503525
}

0 commit comments

Comments
 (0)