Skip to content

Commit 9abdb01

Browse files
committed
- Renamed cache file to credential_cache_v1.json
- Added sha256 checksum for credentials key - Fixed some issues - Added more tests
1 parent d1e6bd0 commit 9abdb01

3 files changed

+160
-95
lines changed

secure_storage_manager.go

+57-62
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
package gosnowflake
44

55
import (
6+
"crypto/sha256"
7+
"encoding/hex"
68
"encoding/json"
79
"errors"
810
"fmt"
911
"github.com/99designs/keyring"
10-
"golang.org/x/sys/unix"
1112
"os"
1213
"path/filepath"
1314
"runtime"
@@ -24,9 +25,20 @@ const (
2425

2526
const (
2627
credCacheDirEnv = "SF_TEMPORARY_CREDENTIAL_CACHE_DIR"
27-
credCacheFileName = "temporary_credential.json"
28+
credCacheFileName = "credential_cache_v1.json"
2829
)
2930

31+
type cacheDirConf struct {
32+
envVar string
33+
pathSegments []string
34+
}
35+
36+
var defaultLinuxCacheDirConf = []cacheDirConf{
37+
{envVar: credCacheDirEnv, pathSegments: []string{}},
38+
{envVar: "XDG_CACHE_DIR", pathSegments: []string{"snowflake"}},
39+
{envVar: "HOME", pathSegments: []string{".cache", "snowflake"}},
40+
}
41+
3042
type secureTokenSpec struct {
3143
host, user string
3244
tokenType tokenType
@@ -72,6 +84,7 @@ func newSecureStorageManager() secureStorageManager {
7284
case "darwin", "windows":
7385
return newKeyringBasedSecureStorageManager()
7486
default:
87+
logger.Infof("OS %v does not support credentials cache", runtime.GOOS)
7588
return newNoopSecureStorageManager()
7689
}
7790
}
@@ -81,27 +94,16 @@ type fileBasedSecureStorageManager struct {
8194
}
8295

8396
func newFileBasedSecureStorageManager() (*fileBasedSecureStorageManager, error) {
84-
credDirPath := buildCredCacheDirPath()
85-
if credDirPath == "" {
86-
return nil, fmt.Errorf("failed to build cache dir path")
97+
credDirPath, err := buildCredCacheDirPath(defaultLinuxCacheDirConf)
98+
if err != nil {
99+
return nil, err
87100
}
88101
ssm := &fileBasedSecureStorageManager{
89102
credDirPath: credDirPath,
90103
}
91104
return ssm, nil
92105
}
93106

94-
func (ssm *fileBasedSecureStorageManager) createCacheDir(credCacheDir string) error {
95-
_, err := os.Stat(credCacheDir)
96-
if os.IsNotExist(err) {
97-
if err = os.MkdirAll(credCacheDir, os.ModePerm); err != nil {
98-
return fmt.Errorf("failed to create cache directory. %v, err: %v", credCacheDir, err)
99-
}
100-
return nil
101-
}
102-
return err
103-
}
104-
105107
func lookupCacheDir(envVar string, pathSegments ...string) (string, error) {
106108
envVal := os.Getenv(envVar)
107109
if envVal == "" {
@@ -110,27 +112,21 @@ func lookupCacheDir(envVar string, pathSegments ...string) (string, error) {
110112

111113
fileInfo, err := os.Stat(envVal)
112114
if err != nil {
113-
return "", fmt.Errorf("failed to stat %s=%s, due to %w", envVar, envVal, err)
115+
return "", fmt.Errorf("failed to stat %s=%s, due to %v", envVar, envVal, err)
114116
}
115117

116118
if !fileInfo.IsDir() {
117119
return "", fmt.Errorf("environment variable %s=%s is not a directory", envVar, envVal)
118120
}
119121

120-
cacheDir := envVal
122+
cacheDir := filepath.Join(envVal, filepath.Join(pathSegments...))
121123

122-
if len(pathSegments) > 0 {
123-
for _, pathSegment := range pathSegments {
124-
err := os.Mkdir(pathSegment, os.ModePerm)
125-
if err != nil {
126-
return "", fmt.Errorf("failed to create cache directory. %v, err: %w", pathSegment, err)
127-
}
128-
cacheDir = filepath.Join(cacheDir, pathSegment)
129-
}
130-
fileInfo, err = os.Stat(cacheDir)
131-
if err != nil {
132-
return "", fmt.Errorf("failed to stat %s=%s, due to %w", envVar, cacheDir, err)
133-
}
124+
if err = os.MkdirAll(cacheDir, os.FileMode(0o755)); err != nil {
125+
return "", err
126+
}
127+
fileInfo, err = os.Stat(cacheDir)
128+
if err != nil {
129+
return "", fmt.Errorf("failed to stat %s=%s, due to %w", envVar, cacheDir, err)
134130
}
135131

136132
if fileInfo.Mode().Perm() != 0o700 {
@@ -143,27 +139,18 @@ func lookupCacheDir(envVar string, pathSegments ...string) (string, error) {
143139
return cacheDir, nil
144140
}
145141

146-
func buildCredCacheDirPath() string {
147-
type cacheDirConf struct {
148-
envVar string
149-
pathSegments []string
150-
}
151-
confs := []cacheDirConf{
152-
{envVar: credCacheDirEnv, pathSegments: []string{}},
153-
{envVar: "XDG_CACHE_DIR", pathSegments: []string{"snowflake"}},
154-
{envVar: "HOME", pathSegments: []string{".cache", "snowflake"}},
155-
}
142+
func buildCredCacheDirPath(confs []cacheDirConf) (string, error) {
156143
for _, conf := range confs {
157144
path, err := lookupCacheDir(conf.envVar, conf.pathSegments...)
158145
if err != nil {
159-
logger.Debugf("Skipping %s in cache directory lookup due to %w", conf.envVar, err)
146+
logger.Debugf("Skipping %s in cache directory lookup due to %v", conf.envVar, err)
160147
} else {
161148
logger.Infof("Using %s as cache directory", path)
162-
return path
149+
return path, nil
163150
}
164151
}
165152

166-
return ""
153+
return "", errors.New("no credentials cache directory found")
167154
}
168155

169156
func (ssm *fileBasedSecureStorageManager) getTokens(data map[string]any) map[string]interface{} {
@@ -198,26 +185,23 @@ func (ssm *fileBasedSecureStorageManager) setCredential(tokenSpec *secureTokenSp
198185
err = ssm.writeTemporaryCacheFile(credCache)
199186
if err != nil {
200187
logger.Warnf("Set credential failed. Unable to write cache. %v", err)
201-
return
202188
}
203-
204-
return
205189
}
206190

207191
func (ssm *fileBasedSecureStorageManager) lockPath() string {
208192
return filepath.Join(ssm.credDirPath, credCacheFileName+".lck")
209193
}
210194

211195
func (ssm *fileBasedSecureStorageManager) lockFile() error {
212-
const NUM_RETRIES = 10
213-
const RETRY_INTERVAL = 100 * time.Millisecond
196+
const numRetries = 10
197+
const retryInterval = 100 * time.Millisecond
214198
lockPath := ssm.lockPath()
215199
locked := false
216-
for i := 0; i < NUM_RETRIES; i++ {
200+
for i := 0; i < numRetries; i++ {
217201
err := os.Mkdir(lockPath, 0o700)
218202
if err != nil {
219203
if errors.Is(err, os.ErrExist) {
220-
time.Sleep(RETRY_INTERVAL)
204+
time.Sleep(retryInterval)
221205
continue
222206
}
223207
return fmt.Errorf("failed to create cache lock: %v, err: %v", lockPath, err)
@@ -228,13 +212,13 @@ func (ssm *fileBasedSecureStorageManager) lockFile() error {
228212

229213
if !locked {
230214
logger.Warnf("failed to lock cache lock. lockPath: %v.", lockPath)
231-
var stat unix.Stat_t
232-
err := unix.Stat(lockPath, &stat)
233-
if err != nil {
215+
fileInfo, err := os.Stat(lockPath)
216+
if err != nil && !errors.Is(err, os.ErrNotExist) {
234217
return fmt.Errorf("failed to stat %v and determine if lock is stale. err: %v", lockPath, err)
235218
}
236219

237-
if stat.Ctim.Nano()+time.Second.Nanoseconds() < time.Now().UnixNano() {
220+
if fileInfo.ModTime().Add(time.Second).UnixNano() < time.Now().UnixNano() {
221+
logger.Debugf("removing credentials cache lock file, stale for %v", time.Now().UnixNano()-fileInfo.ModTime().UnixNano())
238222
err := os.Remove(lockPath)
239223
if err != nil {
240224
return fmt.Errorf("failed to remove %v while trying to remove stale lock. err: %v", lockPath, err)
@@ -290,7 +274,7 @@ func (ssm *fileBasedSecureStorageManager) ensurePermissions() error {
290274
}
291275

292276
if dirInfo.Mode().Perm() != 0o700 {
293-
return fmt.Errorf("incorrect permissions(%o, expected 700) for %s.", dirInfo.Mode().Perm(), ssm.credDirPath)
277+
return fmt.Errorf("incorrect permissions(%o, expected 700) for %s", dirInfo.Mode().Perm(), ssm.credDirPath)
294278
}
295279

296280
fileInfo, err := os.Stat(ssm.credFilePath())
@@ -302,7 +286,7 @@ func (ssm *fileBasedSecureStorageManager) ensurePermissions() error {
302286
logger.Debugf("Incorrect permissions(%o, expected 600) for credential file.", fileInfo.Mode().Perm())
303287
err := os.Chmod(ssm.credFilePath(), 0o600)
304288
if err != nil {
305-
return fmt.Errorf("Failed to chmod credential file: %v", err)
289+
return fmt.Errorf("failed to chmod credential file: %v", err)
306290
}
307291
logger.Debug("Successfully fixed credential file permissions.")
308292
}
@@ -324,7 +308,7 @@ func (ssm *fileBasedSecureStorageManager) readTemporaryCacheFile() map[string]an
324308
}
325309

326310
credentialsMap := map[string]any{}
327-
err = json.Unmarshal([]byte(jsonData), &credentialsMap)
311+
err = json.Unmarshal(jsonData, &credentialsMap)
328312
if err != nil {
329313
logger.Warnf("Failed to unmarshal credential cache file. %v.\n", err)
330314
}
@@ -347,10 +331,7 @@ func (ssm *fileBasedSecureStorageManager) deleteCredential(tokenSpec *secureToke
347331
err = ssm.writeTemporaryCacheFile(credCache)
348332
if err != nil {
349333
logger.Warnf("Set credential failed. Unable to write cache. %v", err)
350-
return
351334
}
352-
353-
return
354335
}
355336

356337
func (ssm *fileBasedSecureStorageManager) writeTemporaryCacheFile(cache map[string]any) error {
@@ -359,6 +340,18 @@ func (ssm *fileBasedSecureStorageManager) writeTemporaryCacheFile(cache map[stri
359340
return fmt.Errorf("failed to marshal credential cache map. %w", err)
360341
}
361342

343+
stat, err := os.Stat(ssm.credFilePath())
344+
if err != nil && !errors.Is(err, os.ErrNotExist) {
345+
return err
346+
}
347+
if err == nil {
348+
if stat.Mode().String() != "-rw-------" {
349+
if err = os.Chmod(ssm.credFilePath(), 0600); err != nil {
350+
return fmt.Errorf("cannot chmod file %v to 600. %v", ssm.credFilePath(), err)
351+
}
352+
}
353+
}
354+
362355
err = os.WriteFile(ssm.credFilePath(), bytes, 0600)
363356
if err != nil {
364357
return fmt.Errorf("failed to write the credential cache file: %w", err)
@@ -462,8 +455,10 @@ func (ssm *keyringSecureStorageManager) deleteCredential(tokenSpec *secureTokenS
462455
}
463456

464457
func buildCredentialsKey(host, user string, credType tokenType) string {
465-
credTypeStr := string(credType)
466-
return host + ":" + user + ":" + credTypeStr
458+
plainCredKey := host + ":" + user + ":" + string(credType)
459+
checksum := sha256.New()
460+
checksum.Write([]byte(plainCredKey))
461+
return hex.EncodeToString(checksum.Sum(nil))
467462
}
468463

469464
type noopSecureStorageManager struct {

secure_storage_manager_test.go

+81-30
Original file line numberDiff line numberDiff line change
@@ -4,44 +4,95 @@ package gosnowflake
44

55
import (
66
"os"
7+
"path/filepath"
78
"testing"
9+
"time"
810
)
911

10-
type EnvOverride struct {
11-
env string
12-
oldValue string
13-
}
12+
func TestBuildCredCacheDirPath(t *testing.T) {
13+
skipOnWindows(t, "permission model is different")
14+
testRoot1, err := os.MkdirTemp("", "")
15+
assertNilF(t, err)
16+
defer os.RemoveAll(testRoot1)
17+
testRoot2, err := os.MkdirTemp("", "")
18+
assertNilF(t, err)
19+
defer os.RemoveAll(testRoot2)
1420

15-
func (e *EnvOverride) rollback() {
16-
if e.oldValue != "" {
17-
os.Setenv(e.env, e.oldValue)
18-
} else {
19-
os.Unsetenv(e.env)
20-
}
21-
}
21+
assertNilF(t, os.Setenv("CACHE_DIR_TEST_NOT_EXISTING", "/tmp/not_existing_dir"))
22+
assertNilF(t, os.Setenv("CACHE_DIR_TEST_1", testRoot1))
23+
assertNilF(t, os.Setenv("CACHE_DIR_TEST_2", testRoot2))
24+
25+
t.Run("cannot find any dir", func(t *testing.T) {
26+
_, err := buildCredCacheDirPath([]cacheDirConf{
27+
{envVar: "CACHE_DIR_TEST_NOT_EXISTING"},
28+
})
29+
assertEqualE(t, err.Error(), "no credentials cache directory found")
30+
_, err = os.Stat("/tmp/not_existing_dir")
31+
assertStringContainsE(t, err.Error(), "no such file or directory")
32+
})
33+
34+
t.Run("should use first dir that exists", func(t *testing.T) {
35+
path, err := buildCredCacheDirPath([]cacheDirConf{
36+
{envVar: "CACHE_DIR_TEST_NOT_EXISTING"},
37+
{envVar: "CACHE_DIR_TEST_1"},
38+
})
39+
assertNilF(t, err)
40+
assertEqualE(t, path, testRoot1)
41+
stat, err := os.Stat(testRoot1)
42+
assertNilF(t, err)
43+
assertEqualE(t, stat.Mode().String(), "drwx------")
44+
})
2245

23-
func override_env(env string, value string) EnvOverride {
24-
oldValue := os.Getenv(env)
25-
os.Setenv(env, value)
26-
return EnvOverride{env, oldValue}
46+
t.Run("should use first dir that exists and append segments", func(t *testing.T) {
47+
path, err := buildCredCacheDirPath([]cacheDirConf{
48+
{envVar: "CACHE_DIR_TEST_NOT_EXISTING"},
49+
{envVar: "CACHE_DIR_TEST_2", pathSegments: []string{"sub1", "sub2"}},
50+
})
51+
assertNilF(t, err)
52+
assertEqualE(t, path, filepath.Join(testRoot2, "sub1", "sub2"))
53+
stat, err := os.Stat(testRoot2)
54+
assertNilF(t, err)
55+
assertEqualE(t, stat.Mode().String(), "drwx------")
56+
})
2757
}
2858

2959
func TestSnowflakeFileBasedSecureStorageManager(t *testing.T) {
30-
//skipOnNonLinux(t, "Not supported on non-linux")
31-
os.Mkdir("./testdata", 0777)
32-
credCacheDirEnvOverride := override_env(credCacheDirEnv, "./testdata")
60+
skipOnWindows(t, "file system permission is different")
61+
credCacheDir, err := os.MkdirTemp("", "")
62+
assertNilF(t, err)
63+
assertNilF(t, os.MkdirAll(credCacheDir, 0777))
64+
credCacheDirEnvOverride := overrideEnv(credCacheDirEnv, credCacheDir)
3365
defer credCacheDirEnvOverride.rollback()
34-
fbss, err := newFileBasedSecureStorageManager()
35-
if err != nil {
36-
t.Fatal(err)
37-
}
66+
ssm, err := newFileBasedSecureStorageManager()
67+
assertNilF(t, err)
68+
69+
t.Run("success", func(t *testing.T) {
70+
tokenSpec := newMfaTokenSpec("host.com", "johndoe")
71+
cred := "token123"
72+
ssm.setCredential(tokenSpec, cred)
73+
assertEqualE(t, ssm.getCredential(tokenSpec), cred)
74+
ssm.deleteCredential(tokenSpec)
75+
assertEqualE(t, ssm.getCredential(tokenSpec), "")
76+
})
77+
78+
t.Run("unlock stale cache", func(t *testing.T) {
79+
startTime := time.Now()
80+
assertNilF(t, os.Mkdir(ssm.lockPath(), 0o700))
81+
ssm.getCredential(newMfaTokenSpec("ignored", "ignored"))
82+
assertTrueE(t, time.Since(startTime).Milliseconds() > 1000)
83+
})
3884

39-
tokenSpec := newMfaTokenSpec("host.xd", "johndoe")
40-
cred := "token123"
41-
fbss.setCredential(tokenSpec, cred)
42-
assertEqualE(t, fbss.getCredential(tokenSpec), cred)
43-
fbss.deleteCredential(tokenSpec)
44-
assertEqualE(t, fbss.getCredential(tokenSpec), "")
85+
t.Run("should not modify keys other than tokens", func(t *testing.T) {
86+
content := []byte(`{
87+
"otherKey": "otherValue"
88+
}`)
89+
err = os.WriteFile(ssm.credFilePath(), content, 0o600)
90+
assertNilF(t, err)
91+
ssm.setCredential(newMfaTokenSpec("somehost.com", "someUser"), "someToken")
92+
result, err := os.ReadFile(ssm.credFilePath())
93+
assertNilF(t, err)
94+
assertStringContainsE(t, string(result), `"otherKey":"otherValue"`)
95+
})
4596
}
4697

4798
func TestSetAndGetCredentialMfa(t *testing.T) {
@@ -96,8 +147,8 @@ func TestBuildCredentialsKey(t *testing.T) {
96147
credType tokenType
97148
out string
98149
}{
99-
{"testaccount.snowflakecomputing.com", "testuser", "mfaToken", "TESTACCOUNT.SNOWFLAKECOMPUTING.COM:TESTUSER:SNOWFLAKE-GO-DRIVER:MFATOKEN"},
100-
{"testaccount.snowflakecomputing.com", "testuser", "IdToken", "TESTACCOUNT.SNOWFLAKECOMPUTING.COM:TESTUSER:SNOWFLAKE-GO-DRIVER:IDTOKEN"},
150+
{"testaccount.snowflakecomputing.com", "testuser", "mfaToken", "c4e781475e7a5e74aca87cd462afafa8cc48ebff6f6ccb5054b894dae5eb6345"}, // pragma: allowlist secret
151+
{"testaccount.snowflakecomputing.com", "testuser", "IdToken", "5014e26489992b6ea56b50e936ba85764dc51338f60441bdd4a69eac7e15bada"}, // pragma: allowlist secret
101152
}
102153
for _, test := range testcases {
103154
target := buildCredentialsKey(test.host, test.user, test.credType)

0 commit comments

Comments
 (0)