Skip to content

Commit cb11980

Browse files
authored
Merge pull request #279 from Azridum/refresh-token-resolver
Refresh and Access token resolve handler
2 parents fb61132 + d92fb72 commit cb11980

File tree

4 files changed

+189
-22
lines changed

4 files changed

+189
-22
lines changed

server/handler.go

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package server
33
import (
44
"context"
55
"net/http"
6+
"strings"
67
"time"
78

89
"github.com/go-oauth2/oauth2/v4"
@@ -49,8 +50,14 @@ type (
4950
// ExtensionFieldsHandler in response to the access token with the extension of the field
5051
ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{})
5152

52-
// ResponseTokenHandler response token handing
53+
// ResponseTokenHandler response token handling
5354
ResponseTokenHandler func(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error
55+
56+
// Handler to fetch the refresh token from the request
57+
RefreshTokenResolveHandler func(r *http.Request) (string, error)
58+
59+
// Handler to fetch the access token from the request
60+
AccessTokenResolveHandler func(r *http.Request) (string, bool)
5461
)
5562

5663
// ClientFormHandler get client data from form
@@ -71,3 +78,44 @@ func ClientBasicHandler(r *http.Request) (string, string, error) {
7178
}
7279
return username, password, nil
7380
}
81+
82+
func RefreshTokenFormResolveHandler(r *http.Request) (string, error) {
83+
rt := r.FormValue("refresh_token")
84+
if rt == "" {
85+
return "", errors.ErrInvalidRequest
86+
}
87+
88+
return rt, nil
89+
}
90+
91+
func RefreshTokenCookieResolveHandler(r *http.Request) (string, error) {
92+
c, err := r.Cookie("refresh_token")
93+
if err != nil {
94+
return "", errors.ErrInvalidRequest
95+
}
96+
97+
return c.Value, nil
98+
}
99+
100+
func AccessTokenDefaultResolveHandler(r *http.Request) (string, bool) {
101+
token := ""
102+
auth := r.Header.Get("Authorization")
103+
prefix := "Bearer "
104+
105+
if auth != "" && strings.HasPrefix(auth, prefix) {
106+
token = auth[len(prefix):]
107+
} else {
108+
token = r.FormValue("access_token")
109+
}
110+
111+
return token, token != ""
112+
}
113+
114+
func AccessTokenCookieResolveHandler(r *http.Request) (string, bool) {
115+
c, err := r.Cookie("access_token")
116+
if err != nil {
117+
return "", false
118+
}
119+
120+
return c.Value, true
121+
}

server/handler_test.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
package server
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"net/url"
7+
"strings"
8+
"testing"
9+
"time"
10+
11+
"github.com/go-oauth2/oauth2/v4/errors"
12+
. "github.com/smartystreets/goconvey/convey"
13+
)
14+
15+
func TestRefreshTokenFormResolveHandler(t *testing.T) {
16+
Convey("Correct Request", t, func() {
17+
f := url.Values{}
18+
f.Add("refresh_token", "test_token")
19+
20+
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
21+
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
22+
23+
token, err := RefreshTokenFormResolveHandler(r)
24+
So(err, ShouldBeNil)
25+
So(token, ShouldEqual, "test_token")
26+
})
27+
28+
Convey("Missing Refresh Token", t, func() {
29+
r := httptest.NewRequest("POST", "/", nil)
30+
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
31+
32+
token, err := RefreshTokenFormResolveHandler(r)
33+
So(err, ShouldBeError, errors.ErrInvalidRequest)
34+
So(token, ShouldBeEmpty)
35+
})
36+
}
37+
38+
func TestRefreshTokenCookieResolveHandler(t *testing.T) {
39+
Convey("Correct Request", t, func() {
40+
r := httptest.NewRequest(http.MethodPost, "/", nil)
41+
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
42+
r.AddCookie(&http.Cookie{
43+
Name: "refresh_token",
44+
Value: "test_token",
45+
HttpOnly: true,
46+
Path: "/",
47+
Domain: ".example.com",
48+
Expires: time.Now().Add(time.Hour),
49+
})
50+
51+
token, err := RefreshTokenCookieResolveHandler(r)
52+
So(err, ShouldBeNil)
53+
So(token, ShouldEqual, "test_token")
54+
})
55+
56+
Convey("Missing Refresh Token", t, func() {
57+
r := httptest.NewRequest("POST", "/", nil)
58+
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
59+
60+
token, err := RefreshTokenCookieResolveHandler(r)
61+
So(err, ShouldBeError, errors.ErrInvalidRequest)
62+
So(token, ShouldBeEmpty)
63+
})
64+
}
65+
66+
func TestAccessTokenDefaultHandler(t *testing.T) {
67+
Convey("Request Has Header", t, func() {
68+
r := httptest.NewRequest(http.MethodPost, "/", nil)
69+
r.Header.Add("Authorization", "Bearer test_token")
70+
71+
token, ok := AccessTokenDefaultResolveHandler(r)
72+
So(ok, ShouldBeTrue)
73+
So(token, ShouldEqual, "test_token")
74+
})
75+
76+
Convey("Request Has FormValue", t, func() {
77+
f := url.Values{}
78+
f.Add("access_token", "test_token")
79+
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
80+
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
81+
82+
token, ok := AccessTokenDefaultResolveHandler(r)
83+
So(ok, ShouldBeTrue)
84+
So(token, ShouldEqual, "test_token")
85+
})
86+
87+
Convey("Request Has Nothing", t, func() {
88+
r := httptest.NewRequest(http.MethodPost, "/", nil)
89+
90+
token, ok := AccessTokenDefaultResolveHandler(r)
91+
So(ok, ShouldBeFalse)
92+
So(token, ShouldBeEmpty)
93+
})
94+
}
95+
96+
func TestAccessTokenCookieHandler(t *testing.T) {
97+
Convey("Request Has Cookie", t, func() {
98+
r := httptest.NewRequest(http.MethodPost, "/", nil)
99+
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
100+
r.AddCookie(&http.Cookie{
101+
Name: "access_token",
102+
Value: "test_token",
103+
HttpOnly: true,
104+
Path: "/",
105+
Domain: ".example.com",
106+
Expires: time.Now().Add(time.Hour),
107+
})
108+
109+
token, ok := AccessTokenCookieResolveHandler(r)
110+
So(ok, ShouldBeTrue)
111+
So(token, ShouldEqual, "test_token")
112+
})
113+
114+
Convey("Request Has No Cookie", t, func() {
115+
r := httptest.NewRequest(http.MethodPost, "/", nil)
116+
117+
token, ok := AccessTokenCookieResolveHandler(r)
118+
So(ok, ShouldBeFalse)
119+
So(token, ShouldBeEmpty)
120+
})
121+
}

server/server.go

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"fmt"
77
"net/http"
88
"net/url"
9-
"strings"
109
"time"
1110

1211
"github.com/go-oauth2/oauth2/v4"
@@ -25,8 +24,10 @@ func NewServer(cfg *Config, manager oauth2.Manager) *Server {
2524
Manager: manager,
2625
}
2726

28-
// default handler
27+
// default handlers
2928
srv.ClientInfoHandler = ClientBasicHandler
29+
srv.RefreshTokenResolveHandler = RefreshTokenFormResolveHandler
30+
srv.AccessTokenResolveHandler = AccessTokenDefaultResolveHandler
3031

3132
srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) {
3233
return "", errors.ErrAccessDenied
@@ -56,6 +57,8 @@ type Server struct {
5657
AccessTokenExpHandler AccessTokenExpHandler
5758
AuthorizeScopeHandler AuthorizeScopeHandler
5859
ResponseTokenHandler ResponseTokenHandler
60+
RefreshTokenResolveHandler RefreshTokenResolveHandler
61+
AccessTokenResolveHandler AccessTokenResolveHandler
5962
}
6063

6164
func (s *Server) handleError(w http.ResponseWriter, req *AuthorizeRequest, err error) error {
@@ -367,10 +370,10 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oau
367370
case oauth2.ClientCredentials:
368371
tgr.Scope = r.FormValue("scope")
369372
case oauth2.Refreshing:
370-
tgr.Refresh = r.FormValue("refresh_token")
373+
tgr.Refresh, err = s.RefreshTokenResolveHandler(r)
371374
tgr.Scope = r.FormValue("scope")
372-
if tgr.Refresh == "" {
373-
return "", nil, errors.ErrInvalidRequest
375+
if err != nil {
376+
return "", nil, err
374377
}
375378
}
376379
return gt, tgr, nil
@@ -569,27 +572,12 @@ func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Head
569572
return data, statusCode, re.Header
570573
}
571574

572-
// BearerAuth parse bearer token
573-
func (s *Server) BearerAuth(r *http.Request) (string, bool) {
574-
auth := r.Header.Get("Authorization")
575-
prefix := "Bearer "
576-
token := ""
577-
578-
if auth != "" && strings.HasPrefix(auth, prefix) {
579-
token = auth[len(prefix):]
580-
} else {
581-
token = r.FormValue("access_token")
582-
}
583-
584-
return token, token != ""
585-
}
586-
587575
// ValidationBearerToken validation the bearer tokens
588576
// https://tools.ietf.org/html/rfc6750
589577
func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
590578
ctx := r.Context()
591579

592-
accessToken, ok := s.BearerAuth(r)
580+
accessToken, ok := s.AccessTokenResolveHandler(r)
593581
if !ok {
594582
return nil, errors.ErrInvalidAccessToken
595583
}

server/server_config.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,13 @@ func (s *Server) SetAuthorizeScopeHandler(handler AuthorizeScopeHandler) {
9393
func (s *Server) SetResponseTokenHandler(handler ResponseTokenHandler) {
9494
s.ResponseTokenHandler = handler
9595
}
96+
97+
// SetRefreshTokenResolveHandler refresh token resolver
98+
func (s *Server) SetRefreshTokenResolveHandler(handler RefreshTokenResolveHandler) {
99+
s.RefreshTokenResolveHandler = handler
100+
}
101+
102+
// SetAccessTokenResolveHandler access token resolver
103+
func (s *Server) SetAccessTokenResolveHandler(handler AccessTokenResolveHandler) {
104+
s.AccessTokenResolveHandler = handler
105+
}

0 commit comments

Comments
 (0)