From 075f54713e1f4ccb1e1b94e3639a7c40176b2110 Mon Sep 17 00:00:00 2001 From: Nicolas De Loof Date: Tue, 12 Jan 2021 14:57:28 +0100 Subject: [PATCH] Use LoadBalancer's VPC and subnet when x-aws-loadbalancer is set Signed-off-by: Nicolas De Loof --- ecs/aws.go | 4 ++-- ecs/awsResources.go | 50 ++++++++++++++++++++++++++++++--------------- ecs/aws_mock.go | 13 ++++++------ ecs/sdk.go | 24 ++++++++++++++-------- 4 files changed, 58 insertions(+), 33 deletions(-) diff --git a/ecs/aws.go b/ecs/aws.go index e20fe66ad..83156b227 100644 --- a/ecs/aws.go +++ b/ecs/aws.go @@ -40,7 +40,7 @@ type API interface { CheckVPC(ctx context.Context, vpcID string) error GetDefaultVPC(ctx context.Context) (string, error) GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error) - IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error) + IsPublicSubnet(ctx context.Context, subNetID string) (bool, error) GetRoleArn(ctx context.Context, name string) (string, error) StackExists(ctx context.Context, name string) (bool, error) CreateStack(ctx context.Context, name string, region string, template []byte) error @@ -68,7 +68,7 @@ type API interface { getURLWithPortMapping(ctx context.Context, targetGroupArns []string) ([]compose.PortPublisher, error) ListTasks(ctx context.Context, cluster string, family string) ([]string, error) GetPublicIPs(ctx context.Context, interfaces ...string) (map[string]string, error) - ResolveLoadBalancer(ctx context.Context, nameOrArn string) (awsResource, string, error) + ResolveLoadBalancer(ctx context.Context, nameOrArn string) (awsResource, string, string, []awsResource, error) GetLoadBalancerURL(ctx context.Context, arn string) (string, error) GetParameter(ctx context.Context, name string) (string, error) SecurityGroupExists(ctx context.Context, sg string) (bool, error) diff --git a/ecs/awsResources.go b/ecs/awsResources.go index 26cbed64b..a7688a3a4 100644 --- a/ecs/awsResources.go +++ b/ecs/awsResources.go @@ -129,11 +129,11 @@ func (b *ecsAPIService) parse(ctx context.Context, project *types.Project, templ if err != nil { return r, err } - r.vpc, r.subnets, err = b.parseVPCExtension(ctx, project) + err = b.parseLoadBalancerExtension(ctx, project, &r) if err != nil { return r, err } - r.loadBalancer, r.loadBalancerType, err = b.parseLoadBalancerExtension(ctx, project) + err = b.parseVPCExtension(ctx, project, &r) if err != nil { return r, err } @@ -165,7 +165,7 @@ func (b *ecsAPIService) parseClusterExtension(ctx context.Context, project *type return nil, nil } -func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Project) (string, []awsResource, error) { +func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Project, r *awsResources) error { var vpc string if x, ok := project.Extensions[extensionVPC]; ok { vpc = x.(string) @@ -177,29 +177,40 @@ func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Pr vpc = id[i+1:] } + if r.vpc != "" { + if r.vpc != vpc { + return fmt.Errorf("load balancer set by %s is attached to VPC %s", extensionLoadBalancer, r.vpc) + } + return nil + } + err = b.aws.CheckVPC(ctx, vpc) if err != nil { - return "", nil, err + return err } } else { + if r.vpc != "" { + return nil + } + defaultVPC, err := b.aws.GetDefaultVPC(ctx) if err != nil { - return "", nil, err + return err } vpc = defaultVPC } subNets, err := b.aws.GetSubNets(ctx, vpc) if err != nil { - return "", nil, err + return err } var publicSubNets []awsResource for _, subNet := range subNets { - isPublic, err := b.aws.IsPublicSubnet(ctx, vpc, subNet.ID()) + isPublic, err := b.aws.IsPublicSubnet(ctx, subNet.ID()) if err != nil { - return "", nil, err + return err } if isPublic { publicSubNets = append(publicSubNets, subNet) @@ -207,27 +218,34 @@ func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Pr } if len(publicSubNets) < 2 { - return "", nil, fmt.Errorf("VPC %s should have at least 2 associated public subnets in different availability zones", vpc) + return fmt.Errorf("VPC %s should have at least 2 associated public subnets in different availability zones", vpc) } - return vpc, publicSubNets, nil + + r.vpc = vpc + r.subnets = subNets + return nil } -func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project) (awsResource, string, error) { +func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project, r *awsResources) error { if x, ok := project.Extensions[extensionLoadBalancer]; ok { nameOrArn := x.(string) - loadBalancer, loadBalancerType, err := b.aws.ResolveLoadBalancer(ctx, nameOrArn) + loadBalancer, loadBalancerType, vpc, subnets, err := b.aws.ResolveLoadBalancer(ctx, nameOrArn) if err != nil { - return nil, "", err + return err } required := getRequiredLoadBalancerType(project) if loadBalancerType != required { - return nil, "", fmt.Errorf("load balancer %q is of type %s, project require a %s", nameOrArn, loadBalancerType, required) + return fmt.Errorf("load balancer %q is of type %s, project require a %s", nameOrArn, loadBalancerType, required) } - return loadBalancer, loadBalancerType, err + r.loadBalancer = loadBalancer + r.loadBalancerType = loadBalancerType + r.vpc = vpc + r.subnets = subnets + return err } - return nil, "", nil + return nil } func (b *ecsAPIService) parseExternalNetworks(ctx context.Context, project *types.Project) (map[string]string, error) { diff --git a/ecs/aws_mock.go b/ecs/aws_mock.go index 486648f44..6d7ef4f6f 100644 --- a/ecs/aws_mock.go +++ b/ecs/aws_mock.go @@ -6,13 +6,12 @@ package ecs import ( context "context" - reflect "reflect" - cloudformation "github.com/aws/aws-sdk-go/service/cloudformation" ecs "github.com/aws/aws-sdk-go/service/ecs" compose "github.com/docker/compose-cli/api/compose" secrets "github.com/docker/compose-cli/api/secrets" gomock "github.com/golang/mock/gomock" + reflect "reflect" ) // MockAPI is a mock of API interface @@ -455,7 +454,7 @@ func (mr *MockAPIMockRecorder) InspectSecret(arg0, arg1 interface{}) *gomock.Cal } // IsPublicSubnet mocks base method -func (m *MockAPI) IsPublicSubnet(ctx context.Context, arg0 string, arg1 string) (bool, error) { +func (m *MockAPI) IsPublicSubnet(arg0 context.Context, arg1 string) (bool, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "IsPublicSubnet", arg0, arg1) ret0, _ := ret[0].(bool) @@ -605,13 +604,15 @@ func (mr *MockAPIMockRecorder) ResolveFileSystem(arg0, arg1 interface{}) *gomock } // ResolveLoadBalancer mocks base method -func (m *MockAPI) ResolveLoadBalancer(arg0 context.Context, arg1 string) (awsResource, string, error) { +func (m *MockAPI) ResolveLoadBalancer(arg0 context.Context, arg1 string) (awsResource, string, string, []awsResource, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ResolveLoadBalancer", arg0, arg1) ret0, _ := ret[0].(awsResource) ret1, _ := ret[1].(string) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret2, _ := ret[2].(string) + ret3, _ := ret[3].([]awsResource) + ret4, _ := ret[4].(error) + return ret0, ret1, ret2, ret3, ret4 } // ResolveLoadBalancer indicates an expected call of ResolveLoadBalancer diff --git a/ecs/sdk.go b/ecs/sdk.go index 374957bf8..1c9fa57d2 100644 --- a/ecs/sdk.go +++ b/ecs/sdk.go @@ -210,7 +210,7 @@ func (s sdk) GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error return ids, nil } -func (s sdk) IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error) { +func (s sdk) IsPublicSubnet(ctx context.Context, subNetID string) (bool, error) { tables, err := s.EC2.DescribeRouteTablesWithContext(ctx, &ec2.DescribeRouteTablesInput{ Filters: []*ec2.Filter{ { @@ -1045,14 +1045,14 @@ func (s sdk) GetPublicIPs(ctx context.Context, interfaces ...string) (map[string } } -func (s sdk) ResolveLoadBalancer(ctx context.Context, nameOrarn string) (awsResource, string, error) { - logrus.Debug("Check if LoadBalancer exists: ", nameOrarn) +func (s sdk) ResolveLoadBalancer(ctx context.Context, nameOrArn string) (awsResource, string, string, []awsResource, error) { + logrus.Debug("Check if LoadBalancer exists: ", nameOrArn) var arns []*string var names []*string - if arn.IsARN(nameOrarn) { - arns = append(arns, aws.String(nameOrarn)) + if arn.IsARN(nameOrArn) { + arns = append(arns, aws.String(nameOrArn)) } else { - names = append(names, aws.String(nameOrarn)) + names = append(names, aws.String(nameOrArn)) } lbs, err := s.ELB.DescribeLoadBalancersWithContext(ctx, &elbv2.DescribeLoadBalancersInput{ @@ -1060,16 +1060,22 @@ func (s sdk) ResolveLoadBalancer(ctx context.Context, nameOrarn string) (awsReso Names: names, }) if err != nil { - return nil, "", err + return nil, "", "", nil, err } if len(lbs.LoadBalancers) == 0 { - return nil, "", errors.Wrapf(errdefs.ErrNotFound, "load balancer %q does not exist", nameOrarn) + return nil, "", "", nil, errors.Wrapf(errdefs.ErrNotFound, "load balancer %q does not exist", nameOrArn) } it := lbs.LoadBalancers[0] + var subNets []awsResource + for _, az := range it.AvailabilityZones { + subNets = append(subNets, existingAWSResource{ + id: aws.StringValue(az.SubnetId), + }) + } return existingAWSResource{ arn: aws.StringValue(it.LoadBalancerArn), id: aws.StringValue(it.LoadBalancerName), - }, aws.StringValue(it.Type), nil + }, aws.StringValue(it.Type), aws.StringValue(it.VpcId), subNets, nil } func (s sdk) GetLoadBalancerURL(ctx context.Context, arn string) (string, error) {