Skip to content

Commit d8df82e

Browse files
authored
SNOW-1825790 Token cache refactor - v2 (#1299)
1 parent e926883 commit d8df82e

File tree

3 files changed

+121
-152
lines changed

3 files changed

+121
-152
lines changed

auth.go

+6-16
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ const (
2626
)
2727

2828
const (
29-
idToken = "ID_TOKEN"
30-
mfaToken = "MFATOKEN"
3129
clientStoreTemporaryCredential = "CLIENT_STORE_TEMPORARY_CREDENTIAL"
3230
clientRequestMfaToken = "CLIENT_REQUEST_MFA_TOKEN"
3331
idTokenAuthenticator = "ID_TOKEN"
@@ -365,10 +363,10 @@ func authenticate(
365363
logger.WithContext(ctx).Errorln("Authentication FAILED")
366364
sc.rest.TokenAccessor.SetTokens("", "", -1)
367365
if sessionParameters[clientRequestMfaToken] == true {
368-
credentialsStorage.deleteCredential(sc, mfaToken)
366+
credentialsStorage.deleteCredential(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User))
369367
}
370368
if sessionParameters[clientStoreTemporaryCredential] == true {
371-
credentialsStorage.deleteCredential(sc, idToken)
369+
credentialsStorage.deleteCredential(newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
372370
}
373371
code, err := strconv.Atoi(respd.Code)
374372
if err != nil {
@@ -384,11 +382,11 @@ func authenticate(
384382
sc.rest.TokenAccessor.SetTokens(respd.Data.Token, respd.Data.MasterToken, respd.Data.SessionID)
385383
if sessionParameters[clientRequestMfaToken] == true {
386384
token := respd.Data.MfaToken
387-
credentialsStorage.setCredential(sc, mfaToken, token)
385+
credentialsStorage.setCredential(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User), token)
388386
}
389387
if sessionParameters[clientStoreTemporaryCredential] == true {
390388
token := respd.Data.IDToken
391-
credentialsStorage.setCredential(sc, idToken, token)
389+
credentialsStorage.setCredential(newIDTokenSpec(sc.cfg.Host, sc.cfg.User), token)
392390
}
393391
return &respd.Data, nil
394392
}
@@ -523,7 +521,7 @@ func authenticateWithConfig(sc *snowflakeConn) error {
523521
sc.cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
524522
}
525523
if sc.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue {
526-
fillCachedIDToken(sc)
524+
sc.cfg.IDToken = credentialsStorage.getCredential(newIDTokenSpec(sc.cfg.Host, sc.cfg.User))
527525
}
528526
// Disable console login by default
529527
if sc.cfg.DisableConsoleLogin == configBoolNotSet {
@@ -536,7 +534,7 @@ func authenticateWithConfig(sc *snowflakeConn) error {
536534
sc.cfg.ClientRequestMfaToken = ConfigBoolTrue
537535
}
538536
if sc.cfg.ClientRequestMfaToken == ConfigBoolTrue {
539-
fillCachedMfaToken(sc)
537+
sc.cfg.MfaToken = credentialsStorage.getCredential(newMfaTokenSpec(sc.cfg.Host, sc.cfg.User))
540538
}
541539
}
542540

@@ -573,11 +571,3 @@ func authenticateWithConfig(sc *snowflakeConn) error {
573571
sc.ctx = context.WithValue(sc.ctx, SFSessionIDKey, authData.SessionID)
574572
return nil
575573
}
576-
577-
func fillCachedIDToken(sc *snowflakeConn) {
578-
credentialsStorage.getCredential(sc, idToken)
579-
}
580-
581-
func fillCachedMfaToken(sc *snowflakeConn) {
582-
credentialsStorage.getCredential(sc, mfaToken)
583-
}

secure_storage_manager.go

+82-64
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,61 @@ import (
1515
"github.com/99designs/keyring"
1616
)
1717

18+
type tokenType string
19+
20+
const (
21+
idToken tokenType = "ID_TOKEN"
22+
mfaToken tokenType = "MFATOKEN"
23+
)
24+
1825
const (
1926
driverName = "SNOWFLAKE-GO-DRIVER"
2027
credCacheDirEnv = "SF_TEMPORARY_CREDENTIAL_CACHE_DIR"
2128
credCacheFileName = "temporary_credential.json"
2229
)
2330

31+
type secureTokenSpec struct {
32+
host, user string
33+
tokenType tokenType
34+
}
35+
36+
func (t *secureTokenSpec) buildKey() string {
37+
return buildCredentialsKey(t.host, t.user, t.tokenType)
38+
}
39+
40+
func newMfaTokenSpec(host, user string) *secureTokenSpec {
41+
return &secureTokenSpec{
42+
host,
43+
user,
44+
mfaToken,
45+
}
46+
}
47+
48+
func newIDTokenSpec(host, user string) *secureTokenSpec {
49+
return &secureTokenSpec{
50+
host,
51+
user,
52+
idToken,
53+
}
54+
}
55+
2456
type secureStorageManager interface {
25-
setCredential(sc *snowflakeConn, credType, token string)
26-
getCredential(sc *snowflakeConn, credType string)
27-
deleteCredential(sc *snowflakeConn, credType string)
57+
setCredential(tokenSpec *secureTokenSpec, value string)
58+
getCredential(tokenSpec *secureTokenSpec) string
59+
deleteCredential(tokenSpec *secureTokenSpec)
2860
}
2961

3062
var credentialsStorage = newSecureStorageManager()
3163

3264
func newSecureStorageManager() secureStorageManager {
3365
switch runtime.GOOS {
3466
case "linux":
35-
return newFileBasedSecureStorageManager()
67+
ssm, err := newFileBasedSecureStorageManager()
68+
if err != nil {
69+
logger.Debugf("failed to create credentials cache dir. %v", err)
70+
return newNoopSecureStorageManager()
71+
}
72+
return ssm
3673
case "darwin", "windows":
3774
return newKeyringBasedSecureStorageManager()
3875
default:
@@ -46,20 +83,19 @@ type fileBasedSecureStorageManager struct {
4683
credCacheLock sync.RWMutex
4784
}
4885

49-
func newFileBasedSecureStorageManager() secureStorageManager {
86+
func newFileBasedSecureStorageManager() (*fileBasedSecureStorageManager, error) {
5087
ssm := &fileBasedSecureStorageManager{
5188
localCredCache: map[string]string{},
5289
credCacheLock: sync.RWMutex{},
5390
}
5491
credCacheDir := ssm.buildCredCacheDirPath()
5592
if err := ssm.createCacheDir(credCacheDir); err != nil {
56-
logger.Debugf("failed to create credentials cache dir. %v", err)
57-
return newNoopSecureStorageManager()
93+
return nil, err
5894
}
5995
credCacheFilePath := filepath.Join(credCacheDir, credCacheFileName)
6096
logger.Infof("Credentials cache path: %v", credCacheFilePath)
6197
ssm.credCacheFilePath = credCacheFilePath
62-
return ssm
98+
return ssm, nil
6399
}
64100

65101
func (ssm *fileBasedSecureStorageManager) createCacheDir(credCacheDir string) error {
@@ -87,14 +123,14 @@ func (ssm *fileBasedSecureStorageManager) buildCredCacheDirPath() string {
87123
return credCacheDir
88124
}
89125

90-
func (ssm *fileBasedSecureStorageManager) setCredential(sc *snowflakeConn, credType, token string) {
91-
if token == "" {
126+
func (ssm *fileBasedSecureStorageManager) setCredential(tokenSpec *secureTokenSpec, value string) {
127+
if value == "" {
92128
logger.Debug("no token provided")
93129
} else {
94-
credentialsKey := buildCredentialsKey(sc.cfg.Host, sc.cfg.User, credType)
130+
credentialsKey := tokenSpec.buildKey()
95131
ssm.credCacheLock.Lock()
96132
defer ssm.credCacheLock.Unlock()
97-
ssm.localCredCache[credentialsKey] = token
133+
ssm.localCredCache[credentialsKey] = value
98134

99135
j, err := json.Marshal(ssm.localCredCache)
100136
if err != nil {
@@ -135,8 +171,8 @@ func (ssm *fileBasedSecureStorageManager) setCredential(sc *snowflakeConn, credT
135171
}
136172
}
137173

138-
func (ssm *fileBasedSecureStorageManager) getCredential(sc *snowflakeConn, credType string) {
139-
credentialsKey := buildCredentialsKey(sc.cfg.Host, sc.cfg.User, credType)
174+
func (ssm *fileBasedSecureStorageManager) getCredential(tokenSpec *secureTokenSpec) string {
175+
credentialsKey := tokenSpec.buildKey()
140176
ssm.credCacheLock.Lock()
141177
defer ssm.credCacheLock.Unlock()
142178
localCredCache := ssm.readTemporaryCacheFile()
@@ -146,14 +182,7 @@ func (ssm *fileBasedSecureStorageManager) getCredential(sc *snowflakeConn, credT
146182
} else {
147183
logger.Debug("Returned credential is empty")
148184
}
149-
150-
if credType == idToken {
151-
sc.cfg.IDToken = cred
152-
} else if credType == mfaToken {
153-
sc.cfg.MfaToken = cred
154-
} else {
155-
logger.Debugf("Unrecognized type %v for local cached credential", credType)
156-
}
185+
return cred
157186
}
158187

159188
func (ssm *fileBasedSecureStorageManager) readTemporaryCacheFile() map[string]string {
@@ -171,10 +200,10 @@ func (ssm *fileBasedSecureStorageManager) readTemporaryCacheFile() map[string]st
171200
return ssm.localCredCache
172201
}
173202

174-
func (ssm *fileBasedSecureStorageManager) deleteCredential(sc *snowflakeConn, credType string) {
203+
func (ssm *fileBasedSecureStorageManager) deleteCredential(tokenSpec *secureTokenSpec) {
175204
ssm.credCacheLock.Lock()
176205
defer ssm.credCacheLock.Unlock()
177-
credentialsKey := buildCredentialsKey(sc.cfg.Host, sc.cfg.User, credType)
206+
credentialsKey := tokenSpec.buildKey()
178207
delete(ssm.localCredCache, credentialsKey)
179208
j, err := json.Marshal(ssm.localCredCache)
180209
if err != nil {
@@ -220,37 +249,35 @@ func (ssm *fileBasedSecureStorageManager) writeTemporaryCacheFile(input []byte)
220249
type keyringSecureStorageManager struct {
221250
}
222251

223-
func newKeyringBasedSecureStorageManager() secureStorageManager {
252+
func newKeyringBasedSecureStorageManager() *keyringSecureStorageManager {
224253
return &keyringSecureStorageManager{}
225254
}
226255

227-
func (ssm *keyringSecureStorageManager) setCredential(sc *snowflakeConn, credType, token string) {
228-
if token == "" {
256+
func (ssm *keyringSecureStorageManager) setCredential(tokenSpec *secureTokenSpec, value string) {
257+
if value == "" {
229258
logger.Debug("no token provided")
230259
} else {
231-
var credentialsKey string
260+
credentialsKey := tokenSpec.buildKey()
232261
if runtime.GOOS == "windows" {
233-
credentialsKey = driverName + ":" + credType
234262
ring, _ := keyring.Open(keyring.Config{
235-
WinCredPrefix: strings.ToUpper(sc.cfg.Host),
236-
ServiceName: strings.ToUpper(sc.cfg.User),
263+
WinCredPrefix: strings.ToUpper(tokenSpec.host),
264+
ServiceName: strings.ToUpper(tokenSpec.user),
237265
})
238266
item := keyring.Item{
239267
Key: credentialsKey,
240-
Data: []byte(token),
268+
Data: []byte(value),
241269
}
242270
if err := ring.Set(item); err != nil {
243271
logger.Debugf("Failed to write to Windows credential manager. Err: %v", err)
244272
}
245273
} else if runtime.GOOS == "darwin" {
246-
credentialsKey = buildCredentialsKey(sc.cfg.Host, sc.cfg.User, credType)
247274
ring, _ := keyring.Open(keyring.Config{
248275
ServiceName: credentialsKey,
249276
})
250-
account := strings.ToUpper(sc.cfg.User)
277+
account := strings.ToUpper(tokenSpec.user)
251278
item := keyring.Item{
252279
Key: account,
253-
Data: []byte(token),
280+
Data: []byte(value),
254281
}
255282
if err := ring.Set(item); err != nil {
256283
logger.Debugf("Failed to write to keychain. Err: %v", err)
@@ -259,26 +286,24 @@ func (ssm *keyringSecureStorageManager) setCredential(sc *snowflakeConn, credTyp
259286
}
260287
}
261288

262-
func (ssm *keyringSecureStorageManager) getCredential(sc *snowflakeConn, credType string) {
263-
var credentialsKey string
289+
func (ssm *keyringSecureStorageManager) getCredential(tokenSpec *secureTokenSpec) string {
264290
cred := ""
291+
credentialsKey := tokenSpec.buildKey()
265292
if runtime.GOOS == "windows" {
266-
credentialsKey = driverName + ":" + credType
267293
ring, _ := keyring.Open(keyring.Config{
268-
WinCredPrefix: strings.ToUpper(sc.cfg.Host),
269-
ServiceName: strings.ToUpper(sc.cfg.User),
294+
WinCredPrefix: strings.ToUpper(tokenSpec.host),
295+
ServiceName: strings.ToUpper(tokenSpec.user),
270296
})
271297
i, err := ring.Get(credentialsKey)
272298
if err != nil {
273299
logger.Debugf("Failed to read credentialsKey or could not find it in Windows Credential Manager. Error: %v", err)
274300
}
275301
cred = string(i.Data)
276302
} else if runtime.GOOS == "darwin" {
277-
credentialsKey = buildCredentialsKey(sc.cfg.Host, sc.cfg.User, credType)
278303
ring, _ := keyring.Open(keyring.Config{
279304
ServiceName: credentialsKey,
280305
})
281-
account := strings.ToUpper(sc.cfg.User)
306+
account := strings.ToUpper(tokenSpec.user)
282307
i, err := ring.Get(account)
283308
if err != nil {
284309
logger.Debugf("Failed to find the item in keychain or item does not exist. Error: %v", err)
@@ -290,59 +315,52 @@ func (ssm *keyringSecureStorageManager) getCredential(sc *snowflakeConn, credTyp
290315
logger.Debug("Successfully read token. Returning as string")
291316
}
292317
}
293-
294-
if credType == idToken {
295-
sc.cfg.IDToken = cred
296-
} else if credType == mfaToken {
297-
sc.cfg.MfaToken = cred
298-
} else {
299-
logger.Debugf("Unrecognized type %v for local cached credential", credType)
300-
}
318+
return cred
301319
}
302320

303-
func (ssm *keyringSecureStorageManager) deleteCredential(sc *snowflakeConn, credType string) {
304-
credentialsKey := driverName + ":" + credType
321+
func (ssm *keyringSecureStorageManager) deleteCredential(tokenSpec *secureTokenSpec) {
322+
credentialsKey := tokenSpec.buildKey()
305323
if runtime.GOOS == "windows" {
306324
ring, _ := keyring.Open(keyring.Config{
307-
WinCredPrefix: strings.ToUpper(sc.cfg.Host),
308-
ServiceName: strings.ToUpper(sc.cfg.User),
325+
WinCredPrefix: strings.ToUpper(tokenSpec.host),
326+
ServiceName: strings.ToUpper(tokenSpec.user),
309327
})
310-
err := ring.Remove(credentialsKey)
328+
err := ring.Remove(string(credentialsKey))
311329
if err != nil {
312330
logger.Debugf("Failed to delete credentialsKey in Windows Credential Manager. Error: %v", err)
313331
}
314332
} else if runtime.GOOS == "darwin" {
315-
credentialsKey = buildCredentialsKey(sc.cfg.Host, sc.cfg.User, credType)
316333
ring, _ := keyring.Open(keyring.Config{
317334
ServiceName: credentialsKey,
318335
})
319-
account := strings.ToUpper(sc.cfg.User)
336+
account := strings.ToUpper(tokenSpec.user)
320337
err := ring.Remove(account)
321338
if err != nil {
322339
logger.Debugf("Failed to delete credentialsKey in keychain. Error: %v", err)
323340
}
324341
}
325342
}
326343

327-
func buildCredentialsKey(host, user, credType string) string {
344+
func buildCredentialsKey(host, user string, credType tokenType) string {
328345
host = strings.ToUpper(host)
329346
user = strings.ToUpper(user)
330-
credType = strings.ToUpper(credType)
331-
return host + ":" + user + ":" + driverName + ":" + credType
347+
credTypeStr := strings.ToUpper(string(credType))
348+
return host + ":" + user + ":" + driverName + ":" + credTypeStr
332349
}
333350

334351
type noopSecureStorageManager struct {
335352
}
336353

337-
func newNoopSecureStorageManager() secureStorageManager {
354+
func newNoopSecureStorageManager() *noopSecureStorageManager {
338355
return &noopSecureStorageManager{}
339356
}
340357

341-
func (ssm *noopSecureStorageManager) setCredential(sc *snowflakeConn, credType, token string) {
358+
func (ssm *noopSecureStorageManager) setCredential(_ *secureTokenSpec, _ string) {
342359
}
343360

344-
func (ssm *noopSecureStorageManager) getCredential(sc *snowflakeConn, credType string) {
361+
func (ssm *noopSecureStorageManager) getCredential(_ *secureTokenSpec) string {
362+
return ""
345363
}
346364

347-
func (ssm *noopSecureStorageManager) deleteCredential(sc *snowflakeConn, credType string) { //TODO implement me
365+
func (ssm *noopSecureStorageManager) deleteCredential(_ *secureTokenSpec) { //TODO implement me
348366
}

0 commit comments

Comments
 (0)