Skip to content
This repository was archived by the owner on Nov 27, 2023. It is now read-only.

Use LoadBalancer's VPC and subnet when x-aws-loadbalancer is set #1123

Merged
merged 1 commit into from
Jan 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ecs/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ type API interface {
CheckVPC(ctx context.Context, vpcID string) error
GetDefaultVPC(ctx context.Context) (string, error)
GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error)
IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error)
IsPublicSubnet(ctx context.Context, subNetID string) (bool, error)
GetRoleArn(ctx context.Context, name string) (string, error)
StackExists(ctx context.Context, name string) (bool, error)
CreateStack(ctx context.Context, name string, region string, template []byte) error
Expand Down Expand Up @@ -68,7 +68,7 @@ type API interface {
getURLWithPortMapping(ctx context.Context, targetGroupArns []string) ([]compose.PortPublisher, error)
ListTasks(ctx context.Context, cluster string, family string) ([]string, error)
GetPublicIPs(ctx context.Context, interfaces ...string) (map[string]string, error)
ResolveLoadBalancer(ctx context.Context, nameOrArn string) (awsResource, string, error)
ResolveLoadBalancer(ctx context.Context, nameOrArn string) (awsResource, string, string, []awsResource, error)
GetLoadBalancerURL(ctx context.Context, arn string) (string, error)
GetParameter(ctx context.Context, name string) (string, error)
SecurityGroupExists(ctx context.Context, sg string) (bool, error)
Expand Down
50 changes: 34 additions & 16 deletions ecs/awsResources.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ func (b *ecsAPIService) parse(ctx context.Context, project *types.Project, templ
if err != nil {
return r, err
}
r.vpc, r.subnets, err = b.parseVPCExtension(ctx, project)
err = b.parseLoadBalancerExtension(ctx, project, &r)
if err != nil {
return r, err
}
r.loadBalancer, r.loadBalancerType, err = b.parseLoadBalancerExtension(ctx, project)
err = b.parseVPCExtension(ctx, project, &r)
if err != nil {
return r, err
}
Expand Down Expand Up @@ -165,7 +165,7 @@ func (b *ecsAPIService) parseClusterExtension(ctx context.Context, project *type
return nil, nil
}

func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Project) (string, []awsResource, error) {
func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Project, r *awsResources) error {
var vpc string
if x, ok := project.Extensions[extensionVPC]; ok {
vpc = x.(string)
Expand All @@ -177,57 +177,75 @@ func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Pr
vpc = id[i+1:]
}

if r.vpc != "" {
if r.vpc != vpc {
return fmt.Errorf("load balancer set by %s is attached to VPC %s", extensionLoadBalancer, r.vpc)
}
return nil
}

err = b.aws.CheckVPC(ctx, vpc)
if err != nil {
return "", nil, err
return err
}

} else {
if r.vpc != "" {
return nil
}

defaultVPC, err := b.aws.GetDefaultVPC(ctx)
if err != nil {
return "", nil, err
return err
}
vpc = defaultVPC
}

subNets, err := b.aws.GetSubNets(ctx, vpc)
if err != nil {
return "", nil, err
return err
}

var publicSubNets []awsResource
for _, subNet := range subNets {
isPublic, err := b.aws.IsPublicSubnet(ctx, vpc, subNet.ID())
isPublic, err := b.aws.IsPublicSubnet(ctx, subNet.ID())
if err != nil {
return "", nil, err
return err
}
if isPublic {
publicSubNets = append(publicSubNets, subNet)
}
}

if len(publicSubNets) < 2 {
return "", nil, fmt.Errorf("VPC %s should have at least 2 associated public subnets in different availability zones", vpc)
return fmt.Errorf("VPC %s should have at least 2 associated public subnets in different availability zones", vpc)
}
return vpc, publicSubNets, nil

r.vpc = vpc
r.subnets = subNets
return nil
}

func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project) (awsResource, string, error) {
func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project, r *awsResources) error {
if x, ok := project.Extensions[extensionLoadBalancer]; ok {
nameOrArn := x.(string)
loadBalancer, loadBalancerType, err := b.aws.ResolveLoadBalancer(ctx, nameOrArn)
loadBalancer, loadBalancerType, vpc, subnets, err := b.aws.ResolveLoadBalancer(ctx, nameOrArn)
if err != nil {
return nil, "", err
return err
}

required := getRequiredLoadBalancerType(project)
if loadBalancerType != required {
return nil, "", fmt.Errorf("load balancer %q is of type %s, project require a %s", nameOrArn, loadBalancerType, required)
return fmt.Errorf("load balancer %q is of type %s, project require a %s", nameOrArn, loadBalancerType, required)
}

return loadBalancer, loadBalancerType, err
r.loadBalancer = loadBalancer
r.loadBalancerType = loadBalancerType
r.vpc = vpc
r.subnets = subnets
return err
}
return nil, "", nil
return nil
}

func (b *ecsAPIService) parseExternalNetworks(ctx context.Context, project *types.Project) (map[string]string, error) {
Expand Down
13 changes: 7 additions & 6 deletions ecs/aws_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 15 additions & 9 deletions ecs/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func (s sdk) GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error
return ids, nil
}

func (s sdk) IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error) {
func (s sdk) IsPublicSubnet(ctx context.Context, subNetID string) (bool, error) {
tables, err := s.EC2.DescribeRouteTablesWithContext(ctx, &ec2.DescribeRouteTablesInput{
Filters: []*ec2.Filter{
{
Expand Down Expand Up @@ -1045,31 +1045,37 @@ func (s sdk) GetPublicIPs(ctx context.Context, interfaces ...string) (map[string
}
}

func (s sdk) ResolveLoadBalancer(ctx context.Context, nameOrarn string) (awsResource, string, error) {
logrus.Debug("Check if LoadBalancer exists: ", nameOrarn)
func (s sdk) ResolveLoadBalancer(ctx context.Context, nameOrArn string) (awsResource, string, string, []awsResource, error) {
logrus.Debug("Check if LoadBalancer exists: ", nameOrArn)
var arns []*string
var names []*string
if arn.IsARN(nameOrarn) {
arns = append(arns, aws.String(nameOrarn))
if arn.IsARN(nameOrArn) {
arns = append(arns, aws.String(nameOrArn))
} else {
names = append(names, aws.String(nameOrarn))
names = append(names, aws.String(nameOrArn))
}

lbs, err := s.ELB.DescribeLoadBalancersWithContext(ctx, &elbv2.DescribeLoadBalancersInput{
LoadBalancerArns: arns,
Names: names,
})
if err != nil {
return nil, "", err
return nil, "", "", nil, err
}
if len(lbs.LoadBalancers) == 0 {
return nil, "", errors.Wrapf(errdefs.ErrNotFound, "load balancer %q does not exist", nameOrarn)
return nil, "", "", nil, errors.Wrapf(errdefs.ErrNotFound, "load balancer %q does not exist", nameOrArn)
}
it := lbs.LoadBalancers[0]
var subNets []awsResource
for _, az := range it.AvailabilityZones {
subNets = append(subNets, existingAWSResource{
id: aws.StringValue(az.SubnetId),
})
}
return existingAWSResource{
arn: aws.StringValue(it.LoadBalancerArn),
id: aws.StringValue(it.LoadBalancerName),
}, aws.StringValue(it.Type), nil
}, aws.StringValue(it.Type), aws.StringValue(it.VpcId), subNets, nil
}

func (s sdk) GetLoadBalancerURL(ctx context.Context, arn string) (string, error) {
Expand Down