Skip to content

Commit 1269096

Browse files
authored
CLOUDP-252326 add missing auth dependencies (#236)
1 parent e32fddf commit 1269096

File tree

7 files changed

+870
-6
lines changed

7 files changed

+870
-6
lines changed

auth/device_flow.go

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
// Copyright 2022 MongoDB Inc
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package auth
16+
17+
import (
18+
"context"
19+
"errors"
20+
"net/http"
21+
"net/url"
22+
"strings"
23+
"time"
24+
25+
"go.mongodb.org/ops-manager/opsmngr"
26+
)
27+
28+
const authExpiredError = "DEVICE_AUTHORIZATION_EXPIRED"
29+
30+
// DeviceCode holds information about the authorization-in-progress.
31+
type DeviceCode struct {
32+
UserCode string `json:"user_code"` //nolint:tagliatelle // UserCode is the code presented to users
33+
VerificationURI string `json:"verification_uri"` //nolint:tagliatelle // VerificationURI is the URI where users will need to confirm the code
34+
DeviceCode string `json:"device_code"` //nolint:tagliatelle // DeviceCode is the internal code to confirm the status of the flow
35+
ExpiresIn int `json:"expires_in"` //nolint:tagliatelle // ExpiresIn when the code will expire
36+
Interval int `json:"interval"` // Interval how often to verify the status of the code
37+
38+
timeNow func() time.Time
39+
timeSleep func(time.Duration)
40+
}
41+
42+
type RegistrationConfig struct {
43+
RegistrationURL string `json:"registrationUrl"`
44+
}
45+
46+
const deviceBasePath = "api/private/unauth/account/device"
47+
48+
// RequestCode initiates the authorization flow by requesting a code.
49+
func (c *Config) RequestCode(ctx context.Context) (*DeviceCode, *opsmngr.Response, error) {
50+
req, err := c.NewRequest(ctx, http.MethodPost, deviceBasePath+"/authorize",
51+
url.Values{
52+
"client_id": {c.ClientID},
53+
"scope": {strings.Join(c.Scopes, " ")},
54+
},
55+
)
56+
if err != nil {
57+
return nil, nil, err
58+
}
59+
var r *DeviceCode
60+
resp, err2 := c.Do(ctx, req, &r)
61+
return r, resp, err2
62+
}
63+
64+
// GetToken gets a device token.
65+
func (c *Config) GetToken(ctx context.Context, deviceCode string) (*Token, *opsmngr.Response, error) {
66+
req, err := c.NewRequest(ctx, http.MethodPost, deviceBasePath+"/token",
67+
url.Values{
68+
"client_id": {c.ClientID},
69+
"device_code": {deviceCode},
70+
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
71+
},
72+
)
73+
if err != nil {
74+
return nil, nil, err
75+
}
76+
var t *Token
77+
resp, err2 := c.Do(ctx, req, &t)
78+
if err2 != nil {
79+
return nil, resp, err2
80+
}
81+
return t, resp, err2
82+
}
83+
84+
// ErrTimeout is returned when polling the server for the granted token has timed out.
85+
var ErrTimeout = errors.New("authentication timed out")
86+
87+
// PollToken polls the server until an access token is granted or denied.
88+
func (c *Config) PollToken(ctx context.Context, code *DeviceCode) (*Token, *opsmngr.Response, error) {
89+
timeNow := code.timeNow
90+
if timeNow == nil {
91+
timeNow = time.Now
92+
}
93+
timeSleep := code.timeSleep
94+
if timeSleep == nil {
95+
timeSleep = time.Sleep
96+
}
97+
98+
checkInterval := time.Duration(code.Interval) * time.Second
99+
expiresAt := timeNow().Add(time.Duration(code.ExpiresIn) * time.Second)
100+
101+
for {
102+
timeSleep(checkInterval)
103+
token, resp, err := c.GetToken(ctx, code.DeviceCode)
104+
var target *opsmngr.ErrorResponse
105+
if errors.As(err, &target) && target.ErrorCode == "DEVICE_AUTHORIZATION_PENDING" {
106+
continue
107+
}
108+
if err != nil {
109+
return nil, resp, err
110+
}
111+
112+
if timeNow().After(expiresAt) {
113+
return nil, nil, ErrTimeout
114+
}
115+
return token, resp, nil
116+
}
117+
}
118+
119+
// RefreshToken takes a refresh token and gets a new access token.
120+
func (c *Config) RefreshToken(ctx context.Context, token string) (*Token, *opsmngr.Response, error) {
121+
req, err := c.NewRequest(ctx, http.MethodPost, deviceBasePath+"/token",
122+
url.Values{
123+
"client_id": {c.ClientID},
124+
"refresh_token": {token},
125+
"scope": {strings.Join(c.Scopes, " ")},
126+
"grant_type": {"refresh_token"},
127+
},
128+
)
129+
if err != nil {
130+
return nil, nil, err
131+
}
132+
var t *Token
133+
resp, err2 := c.Do(ctx, req, &t)
134+
if err2 != nil {
135+
return nil, resp, err2
136+
}
137+
return t, resp, err2
138+
}
139+
140+
// RevokeToken takes an access or refresh token and revokes it.
141+
func (c *Config) RevokeToken(ctx context.Context, token, tokenTypeHint string) (*opsmngr.Response, error) {
142+
req, err := c.NewRequest(ctx, http.MethodPost, deviceBasePath+"/revoke",
143+
url.Values{
144+
"client_id": {c.ClientID},
145+
"token": {token},
146+
"token_type_hint": {tokenTypeHint},
147+
},
148+
)
149+
if err != nil {
150+
return nil, err
151+
}
152+
153+
return c.Do(ctx, req, nil)
154+
}
155+
156+
// RegistrationConfig retrieves the config used for registration.
157+
func (c *Config) RegistrationConfig(ctx context.Context) (*RegistrationConfig, *opsmngr.Response, error) {
158+
req, err := c.NewRequest(ctx, http.MethodGet, deviceBasePath+"/registration", url.Values{})
159+
if err != nil {
160+
return nil, nil, err
161+
}
162+
var rc *RegistrationConfig
163+
resp, err := c.Do(ctx, req, &rc)
164+
if err != nil {
165+
return nil, resp, err
166+
}
167+
return rc, resp, err
168+
}
169+
170+
// IsTimeoutErr checks if the given error is for the case where the device flow has expired.
171+
func IsTimeoutErr(err error) bool {
172+
var target *opsmngr.ErrorResponse
173+
return errors.Is(err, ErrTimeout) || (errors.As(err, &target) && target.ErrorCode == authExpiredError)
174+
}

auth/device_flow_test.go

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
// Copyright 2022 MongoDB Inc
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package auth
16+
17+
import (
18+
"fmt"
19+
"net/http"
20+
"testing"
21+
22+
"github.com/go-test/deep"
23+
"go.mongodb.org/ops-manager/opsmngr"
24+
)
25+
26+
func TestConfig_RequestCode(t *testing.T) {
27+
config, mux, teardown := setup()
28+
defer teardown()
29+
30+
mux.HandleFunc("/api/private/unauth/account/device/authorize", func(w http.ResponseWriter, r *http.Request) {
31+
testMethod(t, r)
32+
fmt.Fprintf(w, `{
33+
"user_code": "QW3PYV7R",
34+
"verification_uri": "%s/account/connect",
35+
"device_code": "61eef18e310968047ff5e02a",
36+
"expires_in": 600,
37+
"interval": 10
38+
}`, baseURLPath)
39+
})
40+
41+
results, _, err := config.RequestCode(ctx)
42+
if err != nil {
43+
t.Fatalf("RequestCode returned error: %v", err)
44+
}
45+
46+
expected := &DeviceCode{
47+
UserCode: "QW3PYV7R",
48+
VerificationURI: baseURLPath + "/account/connect",
49+
DeviceCode: "61eef18e310968047ff5e02a",
50+
ExpiresIn: 600,
51+
Interval: 10,
52+
}
53+
54+
if diff := deep.Equal(results, expected); diff != nil {
55+
t.Error(diff)
56+
}
57+
}
58+
59+
func TestConfig_GetToken(t *testing.T) {
60+
config, mux, teardown := setup()
61+
defer teardown()
62+
63+
mux.HandleFunc("/api/private/unauth/account/device/token", func(w http.ResponseWriter, r *http.Request) {
64+
testMethod(t, r)
65+
fmt.Fprint(w, `{
66+
"access_token": "secret1",
67+
"refresh_token": "secret2",
68+
"scope": "openid",
69+
"id_token": "idtoken",
70+
"token_type": "Bearer",
71+
"expires_in": 3600
72+
}`)
73+
})
74+
code := &DeviceCode{
75+
DeviceCode: "61eef18e310968047ff5e02a",
76+
ExpiresIn: 600,
77+
Interval: 10,
78+
}
79+
results, _, err := config.GetToken(ctx, code.DeviceCode)
80+
if err != nil {
81+
t.Fatalf("GetToken returned error: %v", err)
82+
}
83+
84+
expected := &Token{
85+
AccessToken: "secret1",
86+
RefreshToken: "secret2",
87+
Scope: "openid",
88+
IDToken: "idtoken",
89+
TokenType: "Bearer",
90+
ExpiresIn: 3600,
91+
}
92+
93+
if diff := deep.Equal(results, expected); diff != nil {
94+
t.Error(diff)
95+
}
96+
}
97+
98+
func TestConfig_RefreshToken(t *testing.T) {
99+
config, mux, teardown := setup()
100+
defer teardown()
101+
102+
mux.HandleFunc("/api/private/unauth/account/device/token", func(w http.ResponseWriter, r *http.Request) {
103+
testMethod(t, r)
104+
fmt.Fprint(w, `{
105+
"access_token": "secret1",
106+
"refresh_token": "secret2",
107+
"scope": "openid",
108+
"id_token": "idtoken",
109+
"token_type": "Bearer",
110+
"expires_in": 3600
111+
}`)
112+
})
113+
114+
results, _, err := config.RefreshToken(ctx, "secret2")
115+
if err != nil {
116+
t.Fatalf("RefreshToken returned error: %v", err)
117+
}
118+
119+
expected := &Token{
120+
AccessToken: "secret1",
121+
RefreshToken: "secret2",
122+
Scope: "openid",
123+
IDToken: "idtoken",
124+
TokenType: "Bearer",
125+
ExpiresIn: 3600,
126+
}
127+
128+
if diff := deep.Equal(results, expected); diff != nil {
129+
t.Error(diff)
130+
}
131+
}
132+
133+
func TestConfig_PollToken(t *testing.T) {
134+
config, mux, teardown := setup()
135+
defer teardown()
136+
137+
mux.HandleFunc("/api/private/unauth/account/device/token", func(w http.ResponseWriter, r *http.Request) {
138+
testMethod(t, r)
139+
_, _ = fmt.Fprint(w, `{
140+
"access_token": "secret1",
141+
"refresh_token": "secret2",
142+
"scope": "openid",
143+
"id_token": "idtoken",
144+
"token_type": "Bearer",
145+
"expires_in": 3600
146+
}`)
147+
})
148+
code := &DeviceCode{
149+
DeviceCode: "61eef18e310968047ff5e02a",
150+
ExpiresIn: 600,
151+
Interval: 10,
152+
}
153+
results, _, err := config.PollToken(ctx, code)
154+
if err != nil {
155+
t.Fatalf("PollToken returned error: %v", err)
156+
}
157+
158+
expected := &Token{
159+
AccessToken: "secret1",
160+
RefreshToken: "secret2",
161+
Scope: "openid",
162+
IDToken: "idtoken",
163+
TokenType: "Bearer",
164+
ExpiresIn: 3600,
165+
}
166+
167+
if diff := deep.Equal(results, expected); diff != nil {
168+
t.Error(diff)
169+
}
170+
}
171+
172+
func TestConfig_RevokeToken(t *testing.T) {
173+
config, mux, teardown := setup()
174+
defer teardown()
175+
176+
mux.HandleFunc("/api/private/unauth/account/device/revoke", func(w http.ResponseWriter, r *http.Request) {
177+
testMethod(t, r)
178+
})
179+
180+
_, err := config.RevokeToken(ctx, "a", "refresh_token")
181+
if err != nil {
182+
t.Fatalf("RequestCode returned error: %v", err)
183+
}
184+
}
185+
186+
func TestConfig_RegistrationConfig(t *testing.T) {
187+
config, mux, teardown := setup()
188+
defer teardown()
189+
190+
mux.HandleFunc("/api/private/unauth/account/device/registration", func(w http.ResponseWriter, r *http.Request) {
191+
if http.MethodGet != r.Method {
192+
t.Errorf("Request method = %v, expected %v", r.Method, http.MethodGet)
193+
}
194+
195+
fmt.Fprint(w, `{
196+
"registrationUrl": "http://localhost:8080/account/register/cli"
197+
}`)
198+
})
199+
200+
results, _, err := config.RegistrationConfig(ctx)
201+
if err != nil {
202+
t.Fatalf("RegistrationConfig returned error: %v", err)
203+
}
204+
205+
expected := &RegistrationConfig{
206+
RegistrationURL: "http://localhost:8080/account/register/cli",
207+
}
208+
209+
if diff := deep.Equal(results, expected); diff != nil {
210+
t.Error(diff)
211+
}
212+
}
213+
214+
func TestIsTimeoutErr(t *testing.T) {
215+
err := &opsmngr.ErrorResponse{
216+
ErrorCode: "DEVICE_AUTHORIZATION_EXPIRED",
217+
}
218+
if !IsTimeoutErr(err) {
219+
t.Error("expected to be a timeout error")
220+
}
221+
}

0 commit comments

Comments
 (0)