Skip to content

[do not merge] feat: support account id in imds / new profile configs #3067

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .changelog/3a4c3951c2504554a64b14ce2dddf6ef.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"id": "3a4c3951-c250-4554-a64b-14ce2dddf6ef",
"type": "feature",
"description": "Support account ID retrieval in IMDS credentials provider, and support new IMDS profile name config:\n\n1. environment: `AWS_EC2_INSTANCE_PROFILE_NAME`\n2. shared config: `ec2_instance_profile_name`",
"modules": [
"config",
"credentials"
]
}
4 changes: 2 additions & 2 deletions config/config_source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,10 @@ func (f imdsForwarder) Do(r *http.Request) (*http.Response, error) {
header.Set(ttlHeader, r.Header.Get(ttlHeader))
return &http.Response{StatusCode: 200, Header: header, Body: io.NopCloser(strings.NewReader("validToken"))}, nil
}
if r.URL.Path == "/latest/meta-data/iam/security-credentials/" {
if r.URL.Path == "/latest/meta-data/iam/security-credentials-extended/" {
return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("RoleName"))}, nil
}
if r.URL.Path == "/latest/meta-data/iam/security-credentials/RoleName" {
if r.URL.Path == "/latest/meta-data/iam/security-credentials-extended/RoleName" {
return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(ecsResponse))}, nil
}
return f.innerClient.Do(r)
Expand Down
12 changes: 12 additions & 0 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"os"
"reflect"
"strings"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
Expand Down Expand Up @@ -148,6 +149,17 @@ func TestLoadDefaultConfig(t *testing.T) {
}
}

func TestLoadDefaultConfig_EmptyEC2InstanceProfileName(t *testing.T) {
t.Setenv(awsEc2InstanceProfileNameEnv, "")
_, err := LoadDefaultConfig(context.TODO())
if err == nil {
t.Fatal("expect error, got none")
}
if expect, actual := "env AWS_EC2_INSTANCE_PROFILE_NAME cannot be empty", err.Error(); !strings.Contains(actual, expect) {
t.Fatalf("expect error %s, got %s", expect, actual)
}
}

func BenchmarkLoadProfile1(b *testing.B) {
benchConfigLoad(b, 1)
}
Expand Down
19 changes: 19 additions & 0 deletions config/env_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ const (
awsEc2MetadataDisabledEnv = "AWS_EC2_METADATA_DISABLED"
awsEc2MetadataV1DisabledEnv = "AWS_EC2_METADATA_V1_DISABLED"

awsEc2InstanceProfileNameEnv = "AWS_EC2_INSTANCE_PROFILE_NAME"

awsS3DisableMultiRegionAccessPointsEnv = "AWS_S3_DISABLE_MULTIREGION_ACCESS_POINTS"

awsUseDualStackEndpointEnv = "AWS_USE_DUALSTACK_ENDPOINT"
Expand Down Expand Up @@ -304,6 +306,9 @@ type EnvConfig struct {

// Indicates whether response checksum should be validated
ResponseChecksumValidation aws.ResponseChecksumValidation

// Profile name used for fetching IMDS credentials.
EC2InstanceProfileName string
}

// loadEnvConfig reads configuration values from the OS's environment variables.
Expand Down Expand Up @@ -347,6 +352,12 @@ func NewEnvConfig() (EnvConfig, error) {

cfg.AppID = os.Getenv(awsSdkUaAppIDEnv)

ec2InstanceProfileName, ok := os.LookupEnv(awsEc2InstanceProfileNameEnv)
if ok && ec2InstanceProfileName == "" {
return cfg, fmt.Errorf("env %s cannot be empty", awsEc2InstanceProfileNameEnv)
}
cfg.EC2InstanceProfileName = ec2InstanceProfileName

if err := setBoolPtrFromEnvVal(&cfg.DisableRequestCompression, []string{awsDisableRequestCompressionEnv}); err != nil {
return cfg, err
}
Expand Down Expand Up @@ -916,3 +927,11 @@ func (c EnvConfig) GetS3DisableExpressAuth() (value, ok bool) {

return *c.S3DisableExpressAuth, true
}

func (c EnvConfig) getEC2InstanceProfileName() (string, bool, error) {
if len(c.EC2InstanceProfileName) == 0 {
return "", false, nil
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was discussed to err on this rather than leaving it empty. My main concern would be that someone would want to set this to something on their shell, like export AWS_EC2_INSTANCE_PROFILE_NAME=$(some_process_that_returns_a_value), and that something returns an empty string, then they would make a request to the base URL rather than to the profile they were trying to use

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was addressed but not here -- it actually uses Lookupenv to check for explicit empty in the main load routine.

}

return c.EC2InstanceProfileName, true, nil
}
8 changes: 8 additions & 0 deletions config/env_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,14 @@ func TestNewEnvConfig(t *testing.T) {
Config: EnvConfig{},
WantErr: true,
},
54: {
Env: map[string]string{
"AWS_EC2_INSTANCE_PROFILE_NAME": "ProfileName",
},
Config: EnvConfig{
EC2InstanceProfileName: "ProfileName",
},
},
}

for i, c := range cases {
Expand Down
16 changes: 16 additions & 0 deletions config/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,22 @@ func getEC2RoleCredentialProviderOptions(ctx context.Context, configs configs) (
return
}

type ec2InstanceProfileNameProvider interface {
getEC2InstanceProfileName() (string, bool, error)
}

func getEC2InstanceProfileName(ctx context.Context, configs configs) (v string, found bool, err error) {
for _, config := range configs {
if p, ok := config.(ec2InstanceProfileNameProvider); ok {
v, found, err = p.getEC2InstanceProfileName()
if err != nil || found {
break
}
}
}
return
}

// defaultRegionProvider is an interface for retrieving a default region if a region was not resolved from other sources
type defaultRegionProvider interface {
getDefaultRegion(ctx context.Context) (string, bool, error)
Expand Down
21 changes: 17 additions & 4 deletions config/resolve_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ func resolveCredsFromProfile(ctx context.Context, cfg *aws.Config, envConfig *En

default:
ctx = addCredentialSource(ctx, aws.CredentialSourceIMDS)
err = resolveEC2RoleCredentials(ctx, cfg, configs)
err = resolveEC2RoleCredentials(ctx, cfg, envConfig, sharedConfig, configs)
}
if err != nil {
return ctx, err
Expand Down Expand Up @@ -379,7 +379,7 @@ func resolveCredsFromSource(ctx context.Context, cfg *aws.Config, envConfig *Env
switch sharedCfg.CredentialSource {
case credSourceEc2Metadata:
ctx = addCredentialSource(ctx, aws.CredentialSourceIMDS)
return ctx, resolveEC2RoleCredentials(ctx, cfg, configs)
return ctx, resolveEC2RoleCredentials(ctx, cfg, envConfig, sharedCfg, configs)

case credSourceEnvironment:
ctx = addCredentialSource(ctx, aws.CredentialSourceHTTP)
Expand All @@ -402,8 +402,21 @@ func resolveCredsFromSource(ctx context.Context, cfg *aws.Config, envConfig *Env
return ctx, nil
}

func resolveEC2RoleCredentials(ctx context.Context, cfg *aws.Config, configs configs) error {
optFns := make([]func(*ec2rolecreds.Options), 0, 2)
func resolveEC2RoleCredentials(ctx context.Context, cfg *aws.Config, envCfg *EnvConfig, sharedCfg *SharedConfig, configs configs) error {
optFns := make([]func(*ec2rolecreds.Options), 0, 3)

var profile string
if sharedCfg != nil && sharedCfg.EC2InstanceProfileName != "" {
profile = sharedCfg.EC2InstanceProfileName
}
if envCfg != nil && envCfg.EC2InstanceProfileName != "" {
profile = envCfg.EC2InstanceProfileName
}
if profile != "" {
optFns = append(optFns, func(o *ec2rolecreds.Options) {
o.ProfileName = profile // caller options will override
})
}

optFn, found, err := getEC2RoleCredentialProviderOptions(ctx, configs)
if err != nil {
Expand Down
108 changes: 106 additions & 2 deletions config/resolve_credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/internal/awstesting"
"github.com/aws/aws-sdk-go-v2/service/sso"
Expand Down Expand Up @@ -82,9 +83,15 @@ func setupCredentialsEndpoints() (aws.EndpointResolverWithOptions, func()) {

ec2MetadataServer := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/latest/meta-data/iam/security-credentials/RoleName" {
if r.URL.Path == "/latest/meta-data/iam/security-credentials-extended/RoleName" {
w.Write([]byte(ec2MetadataResponse))
} else if r.URL.Path == "/latest/meta-data/iam/security-credentials/" {
} else if r.URL.Path == "/latest/meta-data/iam/security-credentials-extended/LoadOptions" {
w.Write([]byte(ec2MetadataResponseLoadOptions))
} else if r.URL.Path == "/latest/meta-data/iam/security-credentials-extended/EnvCfg" {
w.Write([]byte(ec2MetadataResponseEnvCfg))
} else if r.URL.Path == "/latest/meta-data/iam/security-credentials-extended/SharedCfg" {
w.Write([]byte(ec2MetadataResponseSharedCfg))
} else if r.URL.Path == "/latest/meta-data/iam/security-credentials-extended/" {
w.Write([]byte("RoleName"))
} else if r.URL.Path == "/latest/api/token" {
header := w.Header()
Expand Down Expand Up @@ -750,6 +757,103 @@ func TestResolveCredentialsEcsContainer(t *testing.T) {

}

func TestResolveCredentialsEC2RoleCreds(t *testing.T) {
testCases := map[string]struct {
expectedAccessKey string
expectedSecretKey string
envVar map[string]string
configFile string
configProfile string
loadOptions func(*LoadOptions) error
}{
"no config whatsoever": {
expectedAccessKey: "ec2-access-key",
expectedSecretKey: "ec2-secret-key",
envVar: map[string]string{},
configFile: "",
},
"env cfg": {
expectedAccessKey: "ec2-access-key-envcfg",
expectedSecretKey: "ec2-secret-key-envcfg",
envVar: map[string]string{
"AWS_EC2_INSTANCE_PROFILE_NAME": "EnvCfg",
},
configFile: "",
},
"shared cfg": {
expectedAccessKey: "ec2-access-key-sharedcfg",
expectedSecretKey: "ec2-secret-key-sharedcfg",
envVar: map[string]string{},
configFile: filepath.Join("testdata", "config_source_shared"),
configProfile: "ec2metadata-profilename",
},
"loadopts + env cfg + shared cfg": {
expectedAccessKey: "ec2-access-key-loadopts",
expectedSecretKey: "ec2-secret-key-loadopts",
envVar: map[string]string{
"AWS_EC2_INSTANCE_PROFILE_NAME": "EnvCfg",
},
configFile: filepath.Join("testdata", "config_source_shared"),
configProfile: "ec2metadata-profilename",
loadOptions: WithEC2RoleCredentialOptions(func(o *ec2rolecreds.Options) {
o.ProfileName = "LoadOptions"
}),
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
endpointResolver, cleanupFn := setupCredentialsEndpoints()
defer cleanupFn()

// setupCredentialsEndpoints sets this above and then we hold onto
// it for this test
ec2MetadataURL := os.Getenv("AWS_EC2_METADATA_SERVICE_ENDPOINT")

restoreEnv := awstesting.StashEnv()
defer awstesting.PopEnv(restoreEnv)

os.Setenv("AWS_EC2_METADATA_SERVICE_ENDPOINT", ec2MetadataURL)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: use t.Setenv instead

for k, v := range tc.envVar {
os.Setenv(k, v)
}
var sharedConfigFiles []string
if tc.configFile != "" {
sharedConfigFiles = append(sharedConfigFiles, tc.configFile)
}
opts := []func(*LoadOptions) error{
WithEndpointResolverWithOptions(endpointResolver),
WithRetryer(func() aws.Retryer { return aws.NopRetryer{} }),
WithSharedConfigFiles(sharedConfigFiles),
WithSharedCredentialsFiles([]string{}),
}
if len(tc.configProfile) != 0 {
opts = append(opts, WithSharedConfigProfile(tc.configProfile))
}

if tc.loadOptions != nil {
opts = append(opts, tc.loadOptions)
}

cfg, err := LoadDefaultConfig(context.TODO(), opts...)
if err != nil {
t.Fatalf("could not load config: %s", err)
}
actual, err := cfg.Credentials.Retrieve(context.TODO())
if err != nil {
t.Fatalf("could not retrieve credentials: %s", err)
}
if actual.AccessKeyID != tc.expectedAccessKey {
t.Errorf("expected access key to be %s, got %s", tc.expectedAccessKey, actual.AccessKeyID)
}
if actual.SecretAccessKey != tc.expectedSecretKey {
t.Errorf("expected secret key to be %s, got %s", tc.expectedSecretKey, actual.SecretAccessKey)
}
})
}

}

type stubErrorClient struct {
err error
}
Expand Down
16 changes: 16 additions & 0 deletions config/shared_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ const (

ec2MetadataV1DisabledKey = "ec2_metadata_v1_disabled"

ec2InstanceProfileNameKey = "ec2_instance_profile_name"

// Use DualStack Endpoint Resolution
useDualStackEndpoint = "use_dualstack_endpoint"

Expand Down Expand Up @@ -357,6 +359,9 @@ type SharedConfig struct {

// ResponseChecksumValidation indicates if the response checksum should be validated
ResponseChecksumValidation aws.ResponseChecksumValidation

// Profile name used for fetching IMDS credentials.
EC2InstanceProfileName string
}

func (c SharedConfig) getDefaultsMode(ctx context.Context) (value aws.DefaultsMode, ok bool, err error) {
Expand Down Expand Up @@ -877,6 +882,7 @@ func mergeSections(dst *ini.Sections, src ini.Sections) error {
ec2MetadataServiceEndpointModeKey,
ec2MetadataServiceEndpointKey,
ec2MetadataV1DisabledKey,
ec2InstanceProfileNameKey,
useDualStackEndpoint,
useFIPSEndpointKey,
defaultsModeKey,
Expand Down Expand Up @@ -1110,6 +1116,8 @@ func (c *SharedConfig) setFromIniSection(profile string, section ini.Section) er
updateString(&c.EC2IMDSEndpoint, section, ec2MetadataServiceEndpointKey)
updateBoolPtr(&c.EC2IMDSv1Disabled, section, ec2MetadataV1DisabledKey)

updateString(&c.EC2InstanceProfileName, section, ec2InstanceProfileNameKey)

updateUseDualStackEndpoint(&c.UseDualStackEndpoint, section, useDualStackEndpoint)
updateUseFIPSEndpoint(&c.UseFIPSEndpoint, section, useFIPSEndpointKey)

Expand Down Expand Up @@ -1678,3 +1686,11 @@ func updateUseFIPSEndpoint(dst *aws.FIPSEndpointState, section ini.Section, key

return
}

func (c SharedConfig) getEC2InstanceProfileName() (string, bool, error) {
if len(c.EC2InstanceProfileName) == 0 {
return "", false, nil
}

return c.EC2InstanceProfileName, true, nil
}
9 changes: 9 additions & 0 deletions config/shared_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,15 @@ func TestNewSharedConfig(t *testing.T) {
},
Err: fmt.Errorf("invalid value for shared config profile field, response_checksum_validation=blabla, must be when_supported/when_required"),
},

"profile with ec2 instance profile name": {
ConfigFilenames: []string{testConfigFilename},
Profile: "ec2_instance_profile_name",
Expected: SharedConfig{
Profile: "ec2_instance_profile_name",
EC2InstanceProfileName: "ProfileName",
},
},
}

for name, c := range cases {
Expand Down
Loading