Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1825790 Implement safer file based token cache #1327

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
25 changes: 25 additions & 0 deletions os_specific_posix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//go:build darwin || linux

package gosnowflake

import (
"fmt"
"os"
"syscall"
)

func provideFileOwner(file *os.File) (uint32, error) {
info, err := file.Stat()
if err != nil {
return 0, err
}
return provideOwnerFromStat(info, file.Name())
}

func provideOwnerFromStat(info os.FileInfo, filepath string) (uint32, error) {
nativeStat, ok := info.Sys().(*syscall.Stat_t)
if !ok {
return 0, fmt.Errorf("cannot cast file info for %v to *syscall.Stat_t", filepath)
}
return nativeStat.Uid, nil
}
12 changes: 12 additions & 0 deletions os_specific_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// go:build windows

package gosnowflake

import (
"errors"
"os"
)

func provideFileOwner(file *os.File) (uint32, error) {
return 0, errors.New("provideFileOwner is unsupported on windows")
}
458 changes: 328 additions & 130 deletions secure_storage_manager.go
Original file line number Diff line number Diff line change
@@ -3,16 +3,21 @@
package gosnowflake

import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"github.com/99designs/keyring"
"io"
"os"
"os/user"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"time"

"github.com/99designs/keyring"
)

type tokenType string
@@ -23,17 +28,27 @@ const (
)

const (
driverName = "SNOWFLAKE-GO-DRIVER"
credCacheDirEnv = "SF_TEMPORARY_CREDENTIAL_CACHE_DIR"
credCacheFileName = "temporary_credential.json"
credCacheFileName = "credential_cache_v1.json"
)

type cacheDirConf struct {
envVar string
pathSegments []string
}

var defaultLinuxCacheDirConf = []cacheDirConf{
{envVar: credCacheDirEnv, pathSegments: []string{}},
{envVar: "XDG_CACHE_DIR", pathSegments: []string{"snowflake"}},
{envVar: "HOME", pathSegments: []string{".cache", "snowflake"}},
}

type secureTokenSpec struct {
host, user string
tokenType tokenType
}

func (t *secureTokenSpec) buildKey() string {
func (t *secureTokenSpec) buildKey() (string, error) {
return buildCredentialsKey(t.host, t.user, t.tokenType)
}

@@ -69,181 +84,323 @@ func newSecureStorageManager() secureStorageManager {
logger.Debugf("failed to create credentials cache dir. %v", err)
return newNoopSecureStorageManager()
}
return ssm
return &threadSafeSecureStorageManager{&sync.Mutex{}, ssm}
case "darwin", "windows":
return newKeyringBasedSecureStorageManager()
return &threadSafeSecureStorageManager{&sync.Mutex{}, newKeyringBasedSecureStorageManager()}
default:
logger.Warnf("OS %v does not support credentials cache", runtime.GOOS)
return newNoopSecureStorageManager()
}
}

type fileBasedSecureStorageManager struct {
credCacheFilePath string
localCredCache map[string]string
credCacheLock sync.RWMutex
credDirPath string
}

func newFileBasedSecureStorageManager() (*fileBasedSecureStorageManager, error) {
ssm := &fileBasedSecureStorageManager{
localCredCache: map[string]string{},
credCacheLock: sync.RWMutex{},
}
credCacheDir := ssm.buildCredCacheDirPath()
if err := ssm.createCacheDir(credCacheDir); err != nil {
credDirPath, err := buildCredCacheDirPath(defaultLinuxCacheDirConf)
if err != nil {
return nil, err
}
credCacheFilePath := filepath.Join(credCacheDir, credCacheFileName)
logger.Infof("Credentials cache path: %v", credCacheFilePath)
ssm.credCacheFilePath = credCacheFilePath
ssm := &fileBasedSecureStorageManager{
credDirPath: credDirPath,
}
return ssm, nil
}

func (ssm *fileBasedSecureStorageManager) createCacheDir(credCacheDir string) error {
_, err := os.Stat(credCacheDir)
if os.IsNotExist(err) {
if err = os.MkdirAll(credCacheDir, os.ModePerm); err != nil {
return fmt.Errorf("failed to create cache directory. %v, err: %v", credCacheDir, err)
func lookupCacheDir(envVar string, pathSegments ...string) (string, error) {
envVal := os.Getenv(envVar)
if envVal == "" {
return "", fmt.Errorf("environment variable %s not set", envVar)
}

fileInfo, err := os.Stat(envVal)
if err != nil {
return "", fmt.Errorf("failed to stat %s=%s, due to %v", envVar, envVal, err)
}

if !fileInfo.IsDir() {
return "", fmt.Errorf("environment variable %s=%s is not a directory", envVar, envVal)
}

cacheDir := filepath.Join(envVal, filepath.Join(pathSegments...))
parentOfCacheDir := cacheDir[:strings.LastIndex(cacheDir, "/")]

if err = os.MkdirAll(parentOfCacheDir, os.FileMode(0755)); err != nil {
return "", err
}

// We don't check if permissions are incorrect here if a directory exists, because we check it later.
if err = os.Mkdir(cacheDir, os.FileMode(0700)); err != nil && !errors.Is(err, os.ErrExist) {
return "", err
}

return cacheDir, nil
}

func buildCredCacheDirPath(confs []cacheDirConf) (string, error) {
for _, conf := range confs {
path, err := lookupCacheDir(conf.envVar, conf.pathSegments...)
if err != nil {
logger.Debugf("Skipping %s in cache directory lookup due to %v", conf.envVar, err)
} else {
logger.Debugf("Using %s as cache directory", path)
return path, nil
}
return nil
}
return err

return "", errors.New("no credentials cache directory found")
}

func (ssm *fileBasedSecureStorageManager) buildCredCacheDirPath() string {
credCacheDir := os.Getenv(credCacheDirEnv)
if credCacheDir != "" {
return credCacheDir
func (ssm *fileBasedSecureStorageManager) getTokens(data map[string]any) map[string]interface{} {
val, ok := data["tokens"]
if !ok {
return map[string]interface{}{}
}
home := os.Getenv("HOME")
if home == "" {
logger.Info("HOME is blank")
return ""

tokens, ok := val.(map[string]interface{})
if !ok {
return map[string]interface{}{}
}

return tokens
}

func (ssm *fileBasedSecureStorageManager) withLock(action func(cacheFile *os.File)) {
err := ssm.lockFile()
if err != nil {
logger.Warnf("Unable to lock cache. %v", err)
return
}
credCacheDir = filepath.Join(home, ".cache", "snowflake")
return credCacheDir
defer ssm.unlockFile()

ssm.withCacheFile(action)
}

func (ssm *fileBasedSecureStorageManager) withCacheFile(action func(*os.File)) {
cacheFile, err := os.OpenFile(ssm.credFilePath(), os.O_CREATE|os.O_RDWR, 0600)
if err != nil {
logger.Warnf("cannot access %v. %v", ssm.credFilePath(), err)
return
}
defer func(file *os.File) {
if err := file.Close(); err != nil {
logger.Warnf("cannot release file descriptor for %v. %v", ssm.credFilePath(), err)
}
}(cacheFile)

cacheDir, err := os.Open(ssm.credDirPath)
if err != nil {
logger.Warnf("cannot access %v. %v", ssm.credDirPath, err)
}

if err := ssm.ensurePermissionsAndOwner(cacheFile, 0600); err != nil {
logger.Warnf("failed to ensure permission for temporary cache file. %v", err)
return
}
if err := ssm.ensurePermissionsAndOwner(cacheDir, 0700|os.ModeDir); err != nil {
logger.Warnf("failed to ensure permission for temporary cache dir. %v", err)
return
}

action(cacheFile)
}

func (ssm *fileBasedSecureStorageManager) setCredential(tokenSpec *secureTokenSpec, value string) {
if value == "" {
logger.Debug("no token provided")
} else {
credentialsKey := tokenSpec.buildKey()
ssm.credCacheLock.Lock()
defer ssm.credCacheLock.Unlock()
ssm.localCredCache[credentialsKey] = value
credentialsKey, err := tokenSpec.buildKey()
if err != nil {
logger.Warn(err)
return
}

j, err := json.Marshal(ssm.localCredCache)
ssm.withLock(func(cacheFile *os.File) {
credCache, err := ssm.readTemporaryCacheFile(cacheFile)
if err != nil {
logger.Warnf("failed to convert credential to JSON.")
logger.Warnf("Error while reading cache file. %v", err)
return
}
tokens := ssm.getTokens(credCache)
tokens[credentialsKey] = value
credCache["tokens"] = tokens
err = ssm.writeTemporaryCacheFile(credCache, cacheFile)
if err != nil {
logger.Warnf("Set credential failed. Unable to write cache. %v", err)
}
})
}

logger.Debugf("writing credential cache file. %v\n", ssm.credCacheFilePath)
credCacheLockFileName := ssm.credCacheFilePath + ".lck"
logger.Debugf("Creating lock file. %v", credCacheLockFileName)
err = os.Mkdir(credCacheLockFileName, 0600)

switch {
case os.IsExist(err):
statinfo, err := os.Stat(credCacheLockFileName)
if err != nil {
logger.Debugf("failed to write credential cache file. file: %v, err: %v. ignored.\n", ssm.credCacheFilePath, err)
return
}
if time.Since(statinfo.ModTime()) < 15*time.Minute {
logger.Debugf("other process locks the cache file. %v. ignored.\n", ssm.credCacheFilePath)
return
}
if err = os.Remove(credCacheLockFileName); err != nil {
logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err)
return
}
if err = os.Mkdir(credCacheLockFileName, 0600); err != nil {
logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err)
return
}
func (ssm *fileBasedSecureStorageManager) lockPath() string {
return filepath.Join(ssm.credDirPath, credCacheFileName+".lck")
}

func (ssm *fileBasedSecureStorageManager) lockFile() error {
const numRetries = 10
const retryInterval = 100 * time.Millisecond
lockPath := ssm.lockPath()

fileInfo, err := os.Stat(lockPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("failed to stat %v and determine if lock is stale. err: %v", lockPath, err)
}

// removing stale lock
now := time.Now()
if !errors.Is(err, os.ErrNotExist) && fileInfo.ModTime().Add(time.Second).UnixNano() < now.UnixNano() {
logger.Debugf("removing credentials cache lock file, stale for %vms", (now.UnixNano()-fileInfo.ModTime().UnixNano())/1000/1000)
err = os.Remove(lockPath)
if err != nil {
return fmt.Errorf("failed to remove %v while trying to remove stale lock. err: %v", lockPath, err)
}
defer os.RemoveAll(credCacheLockFileName)
}

if err = os.WriteFile(ssm.credCacheFilePath, j, 0644); err != nil {
logger.Debugf("Failed to write the cache file. File: %v err: %v.", ssm.credCacheFilePath, err)
locked := false
for i := 0; i < numRetries; i++ {
err := os.Mkdir(lockPath, 0700)
if err != nil {
if errors.Is(err, os.ErrExist) {
time.Sleep(retryInterval)
continue
}
return fmt.Errorf("failed to create cache lock: %v, err: %v", lockPath, err)
}
locked = true
break
}
if !locked {
return fmt.Errorf("failed to lock cache. lockPath: %v", lockPath)
}
return nil
}

func (ssm *fileBasedSecureStorageManager) unlockFile() {
lockPath := ssm.lockPath()
err := os.Remove(lockPath)
if err != nil {
logger.Warnf("Failed to unlock cache lock: %v. %v", lockPath, err)
}
}

func (ssm *fileBasedSecureStorageManager) getCredential(tokenSpec *secureTokenSpec) string {
credentialsKey := tokenSpec.buildKey()
ssm.credCacheLock.Lock()
defer ssm.credCacheLock.Unlock()
localCredCache := ssm.readTemporaryCacheFile()
cred := localCredCache[credentialsKey]
if cred != "" {
logger.Debug("Successfully read token. Returning as string")
} else {
logger.Debug("Returned credential is empty")
credentialsKey, err := tokenSpec.buildKey()
if err != nil {
logger.Warn(err)
return ""
}
return cred

ret := ""
ssm.withLock(func(cacheFile *os.File) {
credCache, err := ssm.readTemporaryCacheFile(cacheFile)
if err != nil {
logger.Warnf("Error while reading cache file. %v", err)
return
}
cred, ok := ssm.getTokens(credCache)[credentialsKey]
if !ok {
return
}

credStr, ok := cred.(string)
if !ok {
return
}

ret = credStr
})
return ret
}

func (ssm *fileBasedSecureStorageManager) readTemporaryCacheFile() map[string]string {
jsonData, err := os.ReadFile(ssm.credCacheFilePath)
func (ssm *fileBasedSecureStorageManager) credFilePath() string {
return filepath.Join(ssm.credDirPath, credCacheFileName)
}

func (ssm *fileBasedSecureStorageManager) ensurePermissionsAndOwner(f *os.File, expectedMode os.FileMode) error {
fileInfo, err := f.Stat()
if err != nil {
logger.Debugf("Failed to read credential file: %v", err)
return nil
return err
}

if fileInfo.Mode() != expectedMode {
return fmt.Errorf("incorrect permissions(%v, expected %v) for credential file", fileInfo.Mode(), expectedMode)
}

ownerUID, err := provideFileOwner(f)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
err = json.Unmarshal([]byte(jsonData), &ssm.localCredCache)
currentUser, err := user.Current()
if err != nil {
logger.Debugf("failed to read JSON. Err: %v", err)
return err
}
if errors.Is(err, os.ErrNotExist) {
return nil
}
if strconv.Itoa(int(ownerUID)) != currentUser.Uid {
return errors.New("incorrect owner of " + ssm.credDirPath)
}
return nil
}

func (ssm *fileBasedSecureStorageManager) readTemporaryCacheFile(cacheFile *os.File) (map[string]any, error) {

jsonData, err := io.ReadAll(cacheFile)
if err != nil {
logger.Warnf("Failed to read credential cache file. %v.\n", err)
return map[string]any{}, nil
}
if _, err = cacheFile.Seek(0, 0); err != nil {
return map[string]any{}, fmt.Errorf("cannot seek to the beginning of a cache file. %v", err)
}

if len(jsonData) == 0 {
// Happens when the file didn't exist before.
return map[string]any{}, nil
}

return ssm.localCredCache
credentialsMap := map[string]any{}
err = json.Unmarshal(jsonData, &credentialsMap)
if err != nil {
return map[string]any{}, fmt.Errorf("failed to unmarshal credential cache file. %v", err)
}

return credentialsMap, nil
}

func (ssm *fileBasedSecureStorageManager) deleteCredential(tokenSpec *secureTokenSpec) {
ssm.credCacheLock.Lock()
defer ssm.credCacheLock.Unlock()
credentialsKey := tokenSpec.buildKey()
delete(ssm.localCredCache, credentialsKey)
j, err := json.Marshal(ssm.localCredCache)
credentialsKey, err := tokenSpec.buildKey()
if err != nil {
logger.Warnf("failed to convert credential to JSON.")
logger.Warn(err)
return
}
ssm.writeTemporaryCacheFile(j)
}

func (ssm *fileBasedSecureStorageManager) writeTemporaryCacheFile(input []byte) {
logger.Debugf("writing credential cache file. %v\n", ssm.credCacheFilePath)
credCacheLockFileName := ssm.credCacheFilePath + ".lck"
err := os.Mkdir(credCacheLockFileName, 0600)
logger.Debugf("Creating lock file. %v", credCacheLockFileName)

switch {
case os.IsExist(err):
statinfo, err := os.Stat(credCacheLockFileName)
ssm.withLock(func(cacheFile *os.File) {
credCache, err := ssm.readTemporaryCacheFile(cacheFile)
if err != nil {
logger.Debugf("failed to write credential cache file. file: %v, err: %v. ignored.\n", ssm.credCacheFilePath, err)
return
}
if time.Since(statinfo.ModTime()) < 15*time.Minute {
logger.Debugf("other process locks the cache file. %v. ignored.\n", ssm.credCacheFilePath)
return
}
if err = os.Remove(credCacheLockFileName); err != nil {
logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err)
logger.Warnf("Error while reading cache file. %v", err)
return
}
if err = os.Mkdir(credCacheLockFileName, 0600); err != nil {
logger.Debugf("failed to delete lock file. file: %v, err: %v. ignored.\n", credCacheLockFileName, err)
return
delete(ssm.getTokens(credCache), credentialsKey)

err = ssm.writeTemporaryCacheFile(credCache, cacheFile)
if err != nil {
logger.Warnf("Set credential failed. Unable to write cache. %v", err)
}
})
}

func (ssm *fileBasedSecureStorageManager) writeTemporaryCacheFile(cache map[string]any, cacheFile *os.File) error {
bytes, err := json.Marshal(cache)
if err != nil {
return fmt.Errorf("failed to marshal credential cache map. %w", err)
}
defer os.RemoveAll(credCacheLockFileName)

if err = os.WriteFile(ssm.credCacheFilePath, input, 0644); err != nil {
logger.Debugf("Failed to write the cache file. File: %v err: %v.", ssm.credCacheFilePath, err)
if err = cacheFile.Truncate(0); err != nil {
return fmt.Errorf("error while truncating credentials cache. %v", err)
}
_, err = cacheFile.Write(bytes)
if err != nil {
return fmt.Errorf("failed to write the credential cache file: %w", err)
}
return nil
}

type keyringSecureStorageManager struct {
@@ -257,7 +414,11 @@ func (ssm *keyringSecureStorageManager) setCredential(tokenSpec *secureTokenSpec
if value == "" {
logger.Debug("no token provided")
} else {
credentialsKey := tokenSpec.buildKey()
credentialsKey, err := tokenSpec.buildKey()
if err != nil {
logger.Warn(err)
return
}
if runtime.GOOS == "windows" {
ring, _ := keyring.Open(keyring.Config{
WinCredPrefix: strings.ToUpper(tokenSpec.host),
@@ -288,7 +449,11 @@ func (ssm *keyringSecureStorageManager) setCredential(tokenSpec *secureTokenSpec

func (ssm *keyringSecureStorageManager) getCredential(tokenSpec *secureTokenSpec) string {
cred := ""
credentialsKey := tokenSpec.buildKey()
credentialsKey, err := tokenSpec.buildKey()
if err != nil {
logger.Warn(err)
return ""
}
if runtime.GOOS == "windows" {
ring, _ := keyring.Open(keyring.Config{
WinCredPrefix: strings.ToUpper(tokenSpec.host),
@@ -319,7 +484,11 @@ func (ssm *keyringSecureStorageManager) getCredential(tokenSpec *secureTokenSpec
}

func (ssm *keyringSecureStorageManager) deleteCredential(tokenSpec *secureTokenSpec) {
credentialsKey := tokenSpec.buildKey()
credentialsKey, err := tokenSpec.buildKey()
if err != nil {
logger.Warn(err)
return
}
if runtime.GOOS == "windows" {
ring, _ := keyring.Open(keyring.Config{
WinCredPrefix: strings.ToUpper(tokenSpec.host),
@@ -341,11 +510,17 @@ func (ssm *keyringSecureStorageManager) deleteCredential(tokenSpec *secureTokenS
}
}

func buildCredentialsKey(host, user string, credType tokenType) string {
host = strings.ToUpper(host)
user = strings.ToUpper(user)
credTypeStr := strings.ToUpper(string(credType))
return host + ":" + user + ":" + driverName + ":" + credTypeStr
func buildCredentialsKey(host, user string, credType tokenType) (string, error) {
if host == "" {
return "", errors.New("host is not provided to store in token cache, skipping")
}
if user == "" {
return "", errors.New("user is not provided to store in token cache, skipping")
}
plainCredKey := host + ":" + user + ":" + string(credType)
checksum := sha256.New()
checksum.Write([]byte(plainCredKey))
return hex.EncodeToString(checksum.Sum(nil)), nil
}

type noopSecureStorageManager struct {
@@ -362,5 +537,28 @@ func (ssm *noopSecureStorageManager) getCredential(_ *secureTokenSpec) string {
return ""
}

func (ssm *noopSecureStorageManager) deleteCredential(_ *secureTokenSpec) { //TODO implement me
func (ssm *noopSecureStorageManager) deleteCredential(_ *secureTokenSpec) {
}

type threadSafeSecureStorageManager struct {
mu *sync.Mutex
delegate secureStorageManager
}

func (ssm *threadSafeSecureStorageManager) setCredential(tokenSpec *secureTokenSpec, value string) {
ssm.mu.Lock()
defer ssm.mu.Unlock()
ssm.delegate.setCredential(tokenSpec, value)
}

func (ssm *threadSafeSecureStorageManager) getCredential(tokenSpec *secureTokenSpec) string {
ssm.mu.Lock()
defer ssm.mu.Unlock()
return ssm.delegate.getCredential(tokenSpec)
}

func (ssm *threadSafeSecureStorageManager) deleteCredential(tokenSpec *secureTokenSpec) {
ssm.mu.Lock()
defer ssm.mu.Unlock()
ssm.delegate.deleteCredential(tokenSpec)
}
226 changes: 223 additions & 3 deletions secure_storage_manager_test.go
Original file line number Diff line number Diff line change
@@ -3,9 +3,200 @@
package gosnowflake

import (
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
)

func TestBuildCredCacheDirPath(t *testing.T) {
skipOnWindows(t, "permission model is different")
testRoot1, err := os.MkdirTemp("", "")
assertNilF(t, err)
defer os.RemoveAll(testRoot1)
testRoot2, err := os.MkdirTemp("", "")
assertNilF(t, err)
defer os.RemoveAll(testRoot2)

env1 := overrideEnv("CACHE_DIR_TEST_NOT_EXISTING", "/tmp/not_existing_dir")
defer env1.rollback()
env2 := overrideEnv("CACHE_DIR_TEST_1", testRoot1)
defer env2.rollback()
env3 := overrideEnv("CACHE_DIR_TEST_2", testRoot2)
defer env3.rollback()

t.Run("cannot find any dir", func(t *testing.T) {
_, err := buildCredCacheDirPath([]cacheDirConf{
{envVar: "CACHE_DIR_TEST_NOT_EXISTING"},
})
assertEqualE(t, err.Error(), "no credentials cache directory found")
_, err = os.Stat("/tmp/not_existing_dir")
assertStringContainsE(t, err.Error(), "no such file or directory")
})

t.Run("should use first dir that exists", func(t *testing.T) {
path, err := buildCredCacheDirPath([]cacheDirConf{
{envVar: "CACHE_DIR_TEST_NOT_EXISTING"},
{envVar: "CACHE_DIR_TEST_1"},
})
assertNilF(t, err)
assertEqualE(t, path, testRoot1)
stat, err := os.Stat(testRoot1)
assertNilF(t, err)
assertEqualE(t, stat.Mode(), 0700|os.ModeDir)
})

t.Run("should use first dir that exists and append segments", func(t *testing.T) {
path, err := buildCredCacheDirPath([]cacheDirConf{
{envVar: "CACHE_DIR_TEST_NOT_EXISTING"},
{envVar: "CACHE_DIR_TEST_2", pathSegments: []string{"sub1", "sub2"}},
})
assertNilF(t, err)
assertEqualE(t, path, filepath.Join(testRoot2, "sub1", "sub2"))
stat, err := os.Stat(testRoot2)
assertNilF(t, err)
assertEqualE(t, stat.Mode(), 0700|os.ModeDir)
})
}

func TestSnowflakeFileBasedSecureStorageManager(t *testing.T) {
skipOnWindows(t, "file system permission is different")
credCacheDir, err := os.MkdirTemp("", "")
assertNilF(t, err)
assertNilF(t, os.MkdirAll(credCacheDir, os.ModePerm))
credCacheDirEnvOverride := overrideEnv(credCacheDirEnv, credCacheDir)
defer credCacheDirEnvOverride.rollback()
ssm, err := newFileBasedSecureStorageManager()
assertNilF(t, err)

t.Run("store single token", func(t *testing.T) {
tokenSpec := newMfaTokenSpec("host.com", "johndoe")
cred := "token123"
ssm.setCredential(tokenSpec, cred)
assertEqualE(t, ssm.getCredential(tokenSpec), cred)
ssm.deleteCredential(tokenSpec)
assertEqualE(t, ssm.getCredential(tokenSpec), "")
})

t.Run("store tokens of different types, hosts and users", func(t *testing.T) {
mfaTokenSpec := newMfaTokenSpec("host.com", "johndoe")
mfaCred := "token12"
idTokenSpec := newIDTokenSpec("host.com", "johndoe")
idCred := "token34"
idTokenSpec2 := newIDTokenSpec("host.org", "johndoe")
idCred2 := "token56"
idTokenSpec3 := newIDTokenSpec("host.com", "someoneelse")
idCred3 := "token78"
ssm.setCredential(mfaTokenSpec, mfaCred)
ssm.setCredential(idTokenSpec, idCred)
ssm.setCredential(idTokenSpec2, idCred2)
ssm.setCredential(idTokenSpec3, idCred3)
assertEqualE(t, ssm.getCredential(mfaTokenSpec), mfaCred)
assertEqualE(t, ssm.getCredential(idTokenSpec), idCred)
assertEqualE(t, ssm.getCredential(idTokenSpec2), idCred2)
assertEqualE(t, ssm.getCredential(idTokenSpec3), idCred3)
ssm.deleteCredential(mfaTokenSpec)
assertEqualE(t, ssm.getCredential(mfaTokenSpec), "")
assertEqualE(t, ssm.getCredential(idTokenSpec), idCred)
assertEqualE(t, ssm.getCredential(idTokenSpec2), idCred2)
assertEqualE(t, ssm.getCredential(idTokenSpec3), idCred3)
})

t.Run("override single token", func(t *testing.T) {
mfaTokenSpec := newMfaTokenSpec("host.com", "johndoe")
mfaCred := "token123"
idTokenSpec := newIDTokenSpec("host.com", "johndoe")
idCred := "token456"
ssm.setCredential(mfaTokenSpec, mfaCred)
ssm.setCredential(idTokenSpec, idCred)
assertEqualE(t, ssm.getCredential(mfaTokenSpec), mfaCred)
mfaCredOverride := "token789"
ssm.setCredential(mfaTokenSpec, mfaCredOverride)
assertEqualE(t, ssm.getCredential(mfaTokenSpec), mfaCredOverride)
ssm.setCredential(idTokenSpec, idCred)
})

t.Run("unlock stale cache", func(t *testing.T) {
tokenSpec := newMfaTokenSpec("stale", "cache")
assertNilF(t, os.Mkdir(ssm.lockPath(), 0700))
time.Sleep(1000 * time.Millisecond)
ssm.setCredential(tokenSpec, "unlocked")
assertEqualE(t, ssm.getCredential(tokenSpec), "unlocked")
})

t.Run("wait for other process to unlock cache", func(t *testing.T) {
tokenSpec := newMfaTokenSpec("stale", "cache")
startTime := time.Now()
assertNilF(t, os.Mkdir(ssm.lockPath(), 0700))
time.Sleep(500 * time.Millisecond)
go func() {
time.Sleep(500 * time.Millisecond)
assertNilF(t, os.Remove(ssm.lockPath()))
}()
ssm.setCredential(tokenSpec, "unlocked")
totalDurationMillis := time.Since(startTime).Milliseconds()
assertEqualE(t, ssm.getCredential(tokenSpec), "unlocked")
assertTrueE(t, totalDurationMillis > 1000 && totalDurationMillis < 1200)
})

t.Run("should not modify keys other than tokens", func(t *testing.T) {
content := []byte(`{
"otherKey": "otherValue"
}`)
err = os.WriteFile(ssm.credFilePath(), content, 0600)
assertNilF(t, err)
ssm.setCredential(newMfaTokenSpec("somehost.com", "someUser"), "someToken")
result, err := os.ReadFile(ssm.credFilePath())
assertNilF(t, err)
assertStringContainsE(t, string(result), `"otherKey":"otherValue"`)
})

t.Run("should not modify file if it has wrong permission", func(t *testing.T) {
tokenSpec := newMfaTokenSpec("somehost.com", "someUser")
ssm.setCredential(tokenSpec, "initialValue")
assertEqualE(t, ssm.getCredential(tokenSpec), "initialValue")
err = os.Chmod(ssm.credFilePath(), 0644)
assertNilF(t, err)
defer func() {
assertNilE(t, os.Chmod(ssm.credFilePath(), 0600))
}()
ssm.setCredential(tokenSpec, "newValue")
assertEqualE(t, ssm.getCredential(tokenSpec), "")
fileContent, err := os.ReadFile(ssm.credFilePath())
assertNilF(t, err)
var m map[string]any
err = json.Unmarshal(fileContent, &m)
assertNilF(t, err)
cacheKey, err := tokenSpec.buildKey()
assertNilF(t, err)
tokens := m["tokens"].(map[string]any)
assertEqualE(t, tokens[cacheKey], "initialValue")
})

t.Run("should not modify file if its dir has wrong permission", func(t *testing.T) {
tokenSpec := newMfaTokenSpec("somehost.com", "someUser")
ssm.setCredential(tokenSpec, "initialValue")
assertEqualE(t, ssm.getCredential(tokenSpec), "initialValue")
err = os.Chmod(ssm.credDirPath, 0777)
assertNilF(t, err)
defer func() {
assertNilE(t, os.Chmod(ssm.credDirPath, 0700))
}()
ssm.setCredential(tokenSpec, "newValue")
assertEqualE(t, ssm.getCredential(tokenSpec), "")
fileContent, err := os.ReadFile(ssm.credFilePath())
assertNilF(t, err)
var m map[string]any
err = json.Unmarshal(fileContent, &m)
assertNilF(t, err)
cacheKey, err := tokenSpec.buildKey()
assertNilF(t, err)
tokens := m["tokens"].(map[string]any)
assertEqualE(t, tokens[cacheKey], "initialValue")
})
}

func TestSetAndGetCredentialMfa(t *testing.T) {
for _, tokenSpec := range []*secureTokenSpec{
newMfaTokenSpec("testhost", "testuser"),
@@ -25,6 +216,34 @@ func TestSetAndGetCredentialMfa(t *testing.T) {
}
}

func TestSkipStoringCredentialIfUserIsEmpty(t *testing.T) {
tokenSpecs := []*secureTokenSpec{
newMfaTokenSpec("mfaHost.com", ""),
newIDTokenSpec("idHost.com", ""),
}

for _, tokenSpec := range tokenSpecs {
t.Run(tokenSpec.host, func(t *testing.T) {
credentialsStorage.setCredential(tokenSpec, "non-empty-value")
assertEqualE(t, credentialsStorage.getCredential(tokenSpec), "")
})
}
}

func TestSkipStoringCredentialIfHostIsEmpty(t *testing.T) {
tokenSpecs := []*secureTokenSpec{
newMfaTokenSpec("", "mfaUser"),
newIDTokenSpec("", "idUser"),
}

for _, tokenSpec := range tokenSpecs {
t.Run(tokenSpec.user, func(t *testing.T) {
credentialsStorage.setCredential(tokenSpec, "non-empty-value")
assertEqualE(t, credentialsStorage.getCredential(tokenSpec), "")
})
}
}

func TestStoreTemporaryCredental(t *testing.T) {
if runningOnGithubAction() {
t.Skip("cannot write to github file system")
@@ -58,11 +277,12 @@ func TestBuildCredentialsKey(t *testing.T) {
credType tokenType
out string
}{
{"testaccount.snowflakecomputing.com", "testuser", "mfaToken", "TESTACCOUNT.SNOWFLAKECOMPUTING.COM:TESTUSER:SNOWFLAKE-GO-DRIVER:MFATOKEN"},
{"testaccount.snowflakecomputing.com", "testuser", "IdToken", "TESTACCOUNT.SNOWFLAKECOMPUTING.COM:TESTUSER:SNOWFLAKE-GO-DRIVER:IDTOKEN"},
{"testaccount.snowflakecomputing.com", "testuser", "mfaToken", "c4e781475e7a5e74aca87cd462afafa8cc48ebff6f6ccb5054b894dae5eb6345"}, // pragma: allowlist secret
{"testaccount.snowflakecomputing.com", "testuser", "IdToken", "5014e26489992b6ea56b50e936ba85764dc51338f60441bdd4a69eac7e15bada"}, // pragma: allowlist secret
}
for _, test := range testcases {
target := buildCredentialsKey(test.host, test.user, test.credType)
target, err := buildCredentialsKey(test.host, test.user, test.credType)
assertNilF(t, err)
if target != test.out {
t.Fatalf("failed to convert target. expected: %v, but got: %v", test.out, target)
}
25 changes: 25 additions & 0 deletions util_test.go
Original file line number Diff line number Diff line change
@@ -404,6 +404,12 @@ func skipOnMac(t *testing.T, reason string) {
}
}

func skipOnWindows(t *testing.T, reason string) {
if runtime.GOOS == "windows" {
t.Skip("skipped on Windows: " + reason)
}
}

func randomString(n int) string {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
alpha := []rune("abcdefghijklmnopqrstuvwxyz")
@@ -440,3 +446,22 @@ func TestInternal(t *testing.T) {
ctx = WithInternal(ctx)
assertTrueE(t, isInternal(ctx))
}

type envOverride struct {
envName string
oldValue string
}

func (e *envOverride) rollback() {
if e.oldValue != "" {
os.Setenv(e.envName, e.oldValue)
} else {
os.Unsetenv(e.envName)
}
}

func overrideEnv(env string, value string) envOverride {
oldValue := os.Getenv(env)
os.Setenv(env, value)
return envOverride{env, oldValue}
}