Skip to content
This repository was archived by the owner on Nov 27, 2023. It is now read-only.

Commit 075f547

Browse files
committed
Use LoadBalancer's VPC and subnet when x-aws-loadbalancer is set
Signed-off-by: Nicolas De Loof <[email protected]>
1 parent f6e5c91 commit 075f547

File tree

4 files changed

+58
-33
lines changed

4 files changed

+58
-33
lines changed

ecs/aws.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ type API interface {
4040
CheckVPC(ctx context.Context, vpcID string) error
4141
GetDefaultVPC(ctx context.Context) (string, error)
4242
GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error)
43-
IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error)
43+
IsPublicSubnet(ctx context.Context, subNetID string) (bool, error)
4444
GetRoleArn(ctx context.Context, name string) (string, error)
4545
StackExists(ctx context.Context, name string) (bool, error)
4646
CreateStack(ctx context.Context, name string, region string, template []byte) error
@@ -68,7 +68,7 @@ type API interface {
6868
getURLWithPortMapping(ctx context.Context, targetGroupArns []string) ([]compose.PortPublisher, error)
6969
ListTasks(ctx context.Context, cluster string, family string) ([]string, error)
7070
GetPublicIPs(ctx context.Context, interfaces ...string) (map[string]string, error)
71-
ResolveLoadBalancer(ctx context.Context, nameOrArn string) (awsResource, string, error)
71+
ResolveLoadBalancer(ctx context.Context, nameOrArn string) (awsResource, string, string, []awsResource, error)
7272
GetLoadBalancerURL(ctx context.Context, arn string) (string, error)
7373
GetParameter(ctx context.Context, name string) (string, error)
7474
SecurityGroupExists(ctx context.Context, sg string) (bool, error)

ecs/awsResources.go

+34-16
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,11 @@ func (b *ecsAPIService) parse(ctx context.Context, project *types.Project, templ
129129
if err != nil {
130130
return r, err
131131
}
132-
r.vpc, r.subnets, err = b.parseVPCExtension(ctx, project)
132+
err = b.parseLoadBalancerExtension(ctx, project, &r)
133133
if err != nil {
134134
return r, err
135135
}
136-
r.loadBalancer, r.loadBalancerType, err = b.parseLoadBalancerExtension(ctx, project)
136+
err = b.parseVPCExtension(ctx, project, &r)
137137
if err != nil {
138138
return r, err
139139
}
@@ -165,7 +165,7 @@ func (b *ecsAPIService) parseClusterExtension(ctx context.Context, project *type
165165
return nil, nil
166166
}
167167

168-
func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Project) (string, []awsResource, error) {
168+
func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Project, r *awsResources) error {
169169
var vpc string
170170
if x, ok := project.Extensions[extensionVPC]; ok {
171171
vpc = x.(string)
@@ -177,57 +177,75 @@ func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Pr
177177
vpc = id[i+1:]
178178
}
179179

180+
if r.vpc != "" {
181+
if r.vpc != vpc {
182+
return fmt.Errorf("load balancer set by %s is attached to VPC %s", extensionLoadBalancer, r.vpc)
183+
}
184+
return nil
185+
}
186+
180187
err = b.aws.CheckVPC(ctx, vpc)
181188
if err != nil {
182-
return "", nil, err
189+
return err
183190
}
184191

185192
} else {
193+
if r.vpc != "" {
194+
return nil
195+
}
196+
186197
defaultVPC, err := b.aws.GetDefaultVPC(ctx)
187198
if err != nil {
188-
return "", nil, err
199+
return err
189200
}
190201
vpc = defaultVPC
191202
}
192203

193204
subNets, err := b.aws.GetSubNets(ctx, vpc)
194205
if err != nil {
195-
return "", nil, err
206+
return err
196207
}
197208

198209
var publicSubNets []awsResource
199210
for _, subNet := range subNets {
200-
isPublic, err := b.aws.IsPublicSubnet(ctx, vpc, subNet.ID())
211+
isPublic, err := b.aws.IsPublicSubnet(ctx, subNet.ID())
201212
if err != nil {
202-
return "", nil, err
213+
return err
203214
}
204215
if isPublic {
205216
publicSubNets = append(publicSubNets, subNet)
206217
}
207218
}
208219

209220
if len(publicSubNets) < 2 {
210-
return "", nil, fmt.Errorf("VPC %s should have at least 2 associated public subnets in different availability zones", vpc)
221+
return fmt.Errorf("VPC %s should have at least 2 associated public subnets in different availability zones", vpc)
211222
}
212-
return vpc, publicSubNets, nil
223+
224+
r.vpc = vpc
225+
r.subnets = subNets
226+
return nil
213227
}
214228

215-
func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project) (awsResource, string, error) {
229+
func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project, r *awsResources) error {
216230
if x, ok := project.Extensions[extensionLoadBalancer]; ok {
217231
nameOrArn := x.(string)
218-
loadBalancer, loadBalancerType, err := b.aws.ResolveLoadBalancer(ctx, nameOrArn)
232+
loadBalancer, loadBalancerType, vpc, subnets, err := b.aws.ResolveLoadBalancer(ctx, nameOrArn)
219233
if err != nil {
220-
return nil, "", err
234+
return err
221235
}
222236

223237
required := getRequiredLoadBalancerType(project)
224238
if loadBalancerType != required {
225-
return nil, "", fmt.Errorf("load balancer %q is of type %s, project require a %s", nameOrArn, loadBalancerType, required)
239+
return fmt.Errorf("load balancer %q is of type %s, project require a %s", nameOrArn, loadBalancerType, required)
226240
}
227241

228-
return loadBalancer, loadBalancerType, err
242+
r.loadBalancer = loadBalancer
243+
r.loadBalancerType = loadBalancerType
244+
r.vpc = vpc
245+
r.subnets = subnets
246+
return err
229247
}
230-
return nil, "", nil
248+
return nil
231249
}
232250

233251
func (b *ecsAPIService) parseExternalNetworks(ctx context.Context, project *types.Project) (map[string]string, error) {

ecs/aws_mock.go

+7-6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ecs/sdk.go

+15-9
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ func (s sdk) GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error
210210
return ids, nil
211211
}
212212

213-
func (s sdk) IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error) {
213+
func (s sdk) IsPublicSubnet(ctx context.Context, subNetID string) (bool, error) {
214214
tables, err := s.EC2.DescribeRouteTablesWithContext(ctx, &ec2.DescribeRouteTablesInput{
215215
Filters: []*ec2.Filter{
216216
{
@@ -1045,31 +1045,37 @@ func (s sdk) GetPublicIPs(ctx context.Context, interfaces ...string) (map[string
10451045
}
10461046
}
10471047

1048-
func (s sdk) ResolveLoadBalancer(ctx context.Context, nameOrarn string) (awsResource, string, error) {
1049-
logrus.Debug("Check if LoadBalancer exists: ", nameOrarn)
1048+
func (s sdk) ResolveLoadBalancer(ctx context.Context, nameOrArn string) (awsResource, string, string, []awsResource, error) {
1049+
logrus.Debug("Check if LoadBalancer exists: ", nameOrArn)
10501050
var arns []*string
10511051
var names []*string
1052-
if arn.IsARN(nameOrarn) {
1053-
arns = append(arns, aws.String(nameOrarn))
1052+
if arn.IsARN(nameOrArn) {
1053+
arns = append(arns, aws.String(nameOrArn))
10541054
} else {
1055-
names = append(names, aws.String(nameOrarn))
1055+
names = append(names, aws.String(nameOrArn))
10561056
}
10571057

10581058
lbs, err := s.ELB.DescribeLoadBalancersWithContext(ctx, &elbv2.DescribeLoadBalancersInput{
10591059
LoadBalancerArns: arns,
10601060
Names: names,
10611061
})
10621062
if err != nil {
1063-
return nil, "", err
1063+
return nil, "", "", nil, err
10641064
}
10651065
if len(lbs.LoadBalancers) == 0 {
1066-
return nil, "", errors.Wrapf(errdefs.ErrNotFound, "load balancer %q does not exist", nameOrarn)
1066+
return nil, "", "", nil, errors.Wrapf(errdefs.ErrNotFound, "load balancer %q does not exist", nameOrArn)
10671067
}
10681068
it := lbs.LoadBalancers[0]
1069+
var subNets []awsResource
1070+
for _, az := range it.AvailabilityZones {
1071+
subNets = append(subNets, existingAWSResource{
1072+
id: aws.StringValue(az.SubnetId),
1073+
})
1074+
}
10691075
return existingAWSResource{
10701076
arn: aws.StringValue(it.LoadBalancerArn),
10711077
id: aws.StringValue(it.LoadBalancerName),
1072-
}, aws.StringValue(it.Type), nil
1078+
}, aws.StringValue(it.Type), aws.StringValue(it.VpcId), subNets, nil
10731079
}
10741080

10751081
func (s sdk) GetLoadBalancerURL(ctx context.Context, arn string) (string, error) {

0 commit comments

Comments
 (0)