Skip to content

Commit e3b1c37

Browse files
committed
Create cache dir with correct perm
1 parent ffad6d7 commit e3b1c37

File tree

2 files changed

+26
-31
lines changed

2 files changed

+26
-31
lines changed

secure_storage_manager.go

+23-28
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,14 @@ func lookupCacheDir(envVar string, pathSegments ...string) (string, error) {
124124
}
125125

126126
cacheDir := filepath.Join(envVal, filepath.Join(pathSegments...))
127+
parentOfCacheDir := cacheDir[:strings.LastIndex(cacheDir, "/")]
127128

128-
if err = os.MkdirAll(cacheDir, os.FileMode(0o755)); err != nil {
129+
if err = os.MkdirAll(parentOfCacheDir, os.FileMode(0755)); err != nil {
129130
return "", err
130131
}
131132

132-
if err = os.Chmod(cacheDir, os.FileMode(0700)); err != nil {
133+
// We don't check if permissions are incorrect here if a directory exists, because we check it later.
134+
if err = os.Mkdir(cacheDir, os.FileMode(0700)); err != nil && !errors.Is(err, os.ErrExist) {
133135
return "", err
134136
}
135137

@@ -164,6 +166,17 @@ func (ssm *fileBasedSecureStorageManager) getTokens(data map[string]any) map[str
164166
return tokens
165167
}
166168

169+
func (ssm *fileBasedSecureStorageManager) withLock(action func(cacheFile *os.File)) {
170+
err := ssm.lockFile()
171+
if err != nil {
172+
logger.Warnf("Unable to lock cache. %v", err)
173+
return
174+
}
175+
defer ssm.unlockFile()
176+
177+
ssm.withCacheFile(action)
178+
}
179+
167180
func (ssm *fileBasedSecureStorageManager) withCacheFile(action func(*os.File)) {
168181
cacheFile, err := os.OpenFile(ssm.credFilePath(), os.O_CREATE|os.O_RDWR, 0600)
169182
if err != nil {
@@ -184,14 +197,8 @@ func (ssm *fileBasedSecureStorageManager) setCredential(tokenSpec *secureTokenSp
184197
logger.Warn(err)
185198
return
186199
}
187-
err = ssm.lockFile()
188-
if err != nil {
189-
logger.Warnf("Set credential failed. Unable to lock cache. %v", err)
190-
return
191-
}
192-
defer ssm.unlockFile()
193200

194-
ssm.withCacheFile(func(cacheFile *os.File) {
201+
ssm.withLock(func(cacheFile *os.File) {
195202
credCache, err := ssm.readTemporaryCacheFile(cacheFile)
196203
if err != nil {
197204
logger.Warnf("Error while reading cache file. %v", err)
@@ -233,7 +240,7 @@ func (ssm *fileBasedSecureStorageManager) lockFile() error {
233240

234241
locked := false
235242
for i := 0; i < numRetries; i++ {
236-
err := os.Mkdir(lockPath, 0o700)
243+
err := os.Mkdir(lockPath, 0700)
237244
if err != nil {
238245
if errors.Is(err, os.ErrExist) {
239246
time.Sleep(retryInterval)
@@ -264,15 +271,9 @@ func (ssm *fileBasedSecureStorageManager) getCredential(tokenSpec *secureTokenSp
264271
logger.Warn(err)
265272
return ""
266273
}
267-
err = ssm.lockFile()
268-
if err != nil {
269-
logger.Warnf("Failed to lock credential cache file. %v", err)
270-
return ""
271-
}
272-
defer ssm.unlockFile()
273274

274275
ret := ""
275-
ssm.withCacheFile(func(cacheFile *os.File) {
276+
ssm.withLock(func(cacheFile *os.File) {
276277
credCache, err := ssm.readTemporaryCacheFile(cacheFile)
277278
if err != nil {
278279
logger.Warnf("Error while reading cache file. %v", err)
@@ -303,7 +304,7 @@ func (ssm *fileBasedSecureStorageManager) ensurePermissions(cacheFile *os.File)
303304
return err
304305
}
305306

306-
if dirInfo.Mode().Perm() != 0o700&os.ModePerm {
307+
if dirInfo.Mode().Perm() != 0700&os.ModePerm {
307308
return fmt.Errorf("incorrect permissions(%o, expected 700) for %s", dirInfo.Mode().Perm(), ssm.credDirPath)
308309
}
309310

@@ -312,7 +313,7 @@ func (ssm *fileBasedSecureStorageManager) ensurePermissions(cacheFile *os.File)
312313
return err
313314
}
314315

315-
if fileInfo.Mode().Perm() != 0o600&os.ModePerm {
316+
if fileInfo.Mode().Perm() != 0600&os.ModePerm {
316317
return fmt.Errorf("incorrect permissions(%v, expected 600) for credential file", fileInfo.Mode().Perm())
317318
}
318319

@@ -335,15 +336,15 @@ func (ssm *fileBasedSecureStorageManager) ensureOwnerForFile(file *os.File) erro
335336
return ssm.ensureOwner(ownerUID)
336337
}
337338

338-
func (ssm *fileBasedSecureStorageManager) ensureOwner(ownerId uint32) error {
339+
func (ssm *fileBasedSecureStorageManager) ensureOwner(ownerID uint32) error {
339340
currentUser, err := user.Current()
340341
if err != nil {
341342
return err
342343
}
343344
if errors.Is(err, os.ErrNotExist) {
344345
return nil
345346
}
346-
if strconv.Itoa(int(ownerId)) != currentUser.Uid {
347+
if strconv.Itoa(int(ownerID)) != currentUser.Uid {
347348
return errors.New("incorrect owner of " + ssm.credDirPath)
348349
}
349350
return nil
@@ -389,14 +390,8 @@ func (ssm *fileBasedSecureStorageManager) deleteCredential(tokenSpec *secureToke
389390
logger.Warn(err)
390391
return
391392
}
392-
err = ssm.lockFile()
393-
if err != nil {
394-
logger.Warnf("Set credential failed. Unable to lock cache. %v", err)
395-
return
396-
}
397-
defer ssm.unlockFile()
398393

399-
ssm.withCacheFile(func(cacheFile *os.File) {
394+
ssm.withLock(func(cacheFile *os.File) {
400395
credCache, err := ssm.readTemporaryCacheFile(cacheFile)
401396
if err != nil {
402397
logger.Warnf("Error while reading cache file. %v", err)

secure_storage_manager_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ func TestSnowflakeFileBasedSecureStorageManager(t *testing.T) {
119119

120120
t.Run("unlock stale cache", func(t *testing.T) {
121121
tokenSpec := newMfaTokenSpec("stale", "cache")
122-
assertNilF(t, os.Mkdir(ssm.lockPath(), 0o700))
122+
assertNilF(t, os.Mkdir(ssm.lockPath(), 0700))
123123
time.Sleep(1000 * time.Millisecond)
124124
ssm.setCredential(tokenSpec, "unlocked")
125125
assertEqualE(t, ssm.getCredential(tokenSpec), "unlocked")
@@ -128,7 +128,7 @@ func TestSnowflakeFileBasedSecureStorageManager(t *testing.T) {
128128
t.Run("wait for other process to unlock cache", func(t *testing.T) {
129129
tokenSpec := newMfaTokenSpec("stale", "cache")
130130
startTime := time.Now()
131-
assertNilF(t, os.Mkdir(ssm.lockPath(), 0o700))
131+
assertNilF(t, os.Mkdir(ssm.lockPath(), 0700))
132132
time.Sleep(500 * time.Millisecond)
133133
go func() {
134134
time.Sleep(500 * time.Millisecond)
@@ -144,7 +144,7 @@ func TestSnowflakeFileBasedSecureStorageManager(t *testing.T) {
144144
content := []byte(`{
145145
"otherKey": "otherValue"
146146
}`)
147-
err = os.WriteFile(ssm.credFilePath(), content, 0o600)
147+
err = os.WriteFile(ssm.credFilePath(), content, 0600)
148148
assertNilF(t, err)
149149
ssm.setCredential(newMfaTokenSpec("somehost.com", "someUser"), "someToken")
150150
result, err := os.ReadFile(ssm.credFilePath())

0 commit comments

Comments
 (0)