Skip to content

Commit

Permalink
Fix AWS credential expiration
Browse files Browse the repository at this point in the history
Use the AWS cached credential provider to automatically handle credentials. The CredentialsCache will automatically handle refreshing expired credentials and keeping them cached as long as necessary.

Replaces prometheus-community#634 as this offloads more of the work to the AWS SDK

Signed-off-by: Joe Adams <[email protected]>
  • Loading branch information
sysadmind committed Oct 17, 2022
1 parent 70152fe commit 1e64068
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions pkg/roundtripper/roundtripper.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ const (

type AWSSigningTransport struct {
t http.RoundTripper
creds aws.Credentials
creds aws.CredentialsProvider
region string
log log.Logger
}
Expand All @@ -48,12 +48,17 @@ func NewAWSSigningTransport(transport http.RoundTripper, region string, log log.
return nil, err
}

creds, err := cfg.Credentials.Retrieve(context.Background())
// Run a single fetch credentials operation to ensure that the credentials
// are valid before returning the transport.
_, err = cfg.Credentials.Retrieve(context.Background())
if err != nil {
_ = level.Error(log).Log("msg", "fail to retrive aws credentials", "err", err)
return nil, err
}

// Build a cached credentials provider to manage the credentials and prevent new credentials on every request.
creds := aws.NewCredentialsCache(cfg.Credentials)

return &AWSSigningTransport{
t: transport,
region: region,
Expand All @@ -69,8 +74,15 @@ func (a *AWSSigningTransport) RoundTrip(req *http.Request) (*http.Response, erro
_ = level.Error(a.log).Log("msg", "fail to hash request body", "err", err)
return nil, err
}

creds, err := a.creds.Retrieve(context.Background())
if err != nil {
_ = level.Error(a.log).Log("msg", "fail to retrive aws credentials", "err", err)
return nil, err
}

req.Body = newReader
err = signer.SignHTTP(context.Background(), a.creds, req, payloadHash, service, a.region, time.Now())
err = signer.SignHTTP(context.Background(), creds, req, payloadHash, service, a.region, time.Now())
if err != nil {
_ = level.Error(a.log).Log("msg", "fail to sign request body", "err", err)
return nil, err
Expand Down

0 comments on commit 1e64068

Please sign in to comment.