Skip to content

Commit 9b0a432

Browse files
authored
[feat aga] Implement endpoint group management with port override conflict resolution (#4470)
* [feat aga] few bugfixes and code refactoring * [feat aga] Implement model builder for endpoint groups * [feat aga] Implement endpoint group deployer with port override conflict resolution * [feat aga] addressing feedback for refactoring
1 parent d278541 commit 9b0a432

31 files changed

+4965
-137
lines changed

controllers/aga/globalaccelerator_controller.go

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,6 @@ const (
6868
requeueMessage = "Monitoring provisioning state"
6969
statusUpdateRequeueTime = 1 * time.Minute
7070

71-
// Status reason constants
72-
EndpointLoadFailed = "EndpointLoadFailed"
73-
7471
// Metric stage constants
7572
MetricStageFetchGlobalAccelerator = "fetch_globalAccelerator"
7673
MetricStageAddFinalizers = "add_finalizers"
@@ -90,7 +87,7 @@ const (
9087
func NewGlobalAcceleratorReconciler(k8sClient client.Client, eventRecorder record.EventRecorder, finalizerManager k8s.FinalizerManager, config config.ControllerConfig, cloud services.Cloud, logger logr.Logger, metricsCollector lbcmetrics.MetricCollector, reconcileCounters *metricsutil.ReconcileCounters) *globalAcceleratorReconciler {
9188

9289
// Create tracking provider
93-
trackingProvider := tracking.NewDefaultProvider(agaTagPrefix, config.ClusterName, tracking.WithRegion(config.AWSConfig.Region))
90+
trackingProvider := tracking.NewDefaultProvider(agaTagPrefix, config.ClusterName, tracking.WithRegion(cloud.Region()))
9491

9592
// Create model builder
9693
agaModelBuilder := aga.NewDefaultModelBuilder(
@@ -99,7 +96,7 @@ func NewGlobalAcceleratorReconciler(k8sClient client.Client, eventRecorder recor
9996
trackingProvider,
10097
config.FeatureGates,
10198
config.ClusterName,
102-
config.AWSConfig.Region,
99+
cloud.Region(),
103100
config.DefaultTags,
104101
config.ExternalManagedTags,
105102
logger.WithName("aga-model-builder"),
@@ -272,11 +269,11 @@ func (r *globalAcceleratorReconciler) buildModel(ctx context.Context, ga *agaapi
272269
func (r *globalAcceleratorReconciler) reconcileGlobalAcceleratorResources(ctx context.Context, ga *agaapi.GlobalAccelerator) error {
273270
r.logger.Info("Reconciling GlobalAccelerator resources", "globalAccelerator", k8s.NamespacedName(ga))
274271

275-
// Get all endpoints from GA
276-
endpoints := aga.GetAllEndpointsFromGA(ga)
272+
// Get all desired endpoints from GA
273+
endpoints := aga.GetAllDesiredEndpointsFromGA(ga)
277274

278275
// Track referenced endpoints
279-
r.referenceTracker.UpdateReferencesForGA(ga, endpoints)
276+
r.referenceTracker.UpdateDesiredEndpointReferencesForGA(ga, endpoints)
280277

281278
// Update resource watches with the endpointResourcesManager
282279
r.endpointResourcesManager.MonitorEndpointResources(ga, endpoints)
@@ -285,10 +282,10 @@ func (r *globalAcceleratorReconciler) reconcileGlobalAcceleratorResources(ctx co
285282
_, fatalErrors := r.endpointLoader.LoadEndpoints(ctx, ga, endpoints)
286283
if len(fatalErrors) > 0 {
287284
err := fmt.Errorf("failed to load endpoints: %v", fatalErrors[0])
288-
r.logger.Error(err, "Fatal error loading endpoints")
289-
285+
r.eventRecorder.Event(ga, corev1.EventTypeWarning, k8s.GlobalAcceleratorEventReasonFailedEndpointLoad, fmt.Sprintf("Failed to reconcile due to %v", err))
286+
r.logger.Error(err, fmt.Sprintf("fatal error loading endpoints for %v", k8s.NamespacedName(ga)))
290287
// Handle other endpoint loading errors
291-
if statusErr := r.statusUpdater.UpdateStatusFailure(ctx, ga, EndpointLoadFailed, err.Error()); statusErr != nil {
288+
if statusErr := r.statusUpdater.UpdateStatusFailure(ctx, ga, agadeploy.EndpointLoadFailed, err.Error()); statusErr != nil {
292289
r.logger.Error(statusErr, "Failed to update GlobalAccelerator status after endpoint load failure")
293290
}
294291
return err
@@ -302,6 +299,8 @@ func (r *globalAcceleratorReconciler) reconcileGlobalAcceleratorResources(ctx co
302299
}
303300
r.metricsCollector.ObserveControllerReconcileLatency(controllerName, MetricStageBuildModel, buildModelFn)
304301
if err != nil {
302+
r.eventRecorder.Event(ga, corev1.EventTypeWarning, k8s.GatewayEventReasonFailedBuildModel, fmt.Sprintf("Failed to build model: %v", err))
303+
r.logger.Error(err, fmt.Sprintf("Failed to build model for: %v", k8s.NamespacedName(ga)))
305304
// Update status to indicate model building failure
306305
if statusErr := r.statusUpdater.UpdateStatusFailure(ctx, ga, agadeploy.ModelBuildFailed, fmt.Sprintf("Failed to build model: %v", err)); statusErr != nil {
307306
r.logger.Error(statusErr, "Failed to update GlobalAccelerator status after model build failure")
@@ -316,7 +315,7 @@ func (r *globalAcceleratorReconciler) reconcileGlobalAcceleratorResources(ctx co
316315
r.metricsCollector.ObserveControllerReconcileLatency(controllerName, MetricStageDeployStack, deployStackFn)
317316
if err != nil {
318317
r.eventRecorder.Event(ga, corev1.EventTypeWarning, k8s.GlobalAcceleratorEventReasonFailedDeploy, fmt.Sprintf("Failed to deploy stack due to %v", err))
319-
318+
r.logger.Error(err, fmt.Sprintf("Failed to deploy stack for: %v", k8s.NamespacedName(ga)))
320319
// Update status to indicate deployment failure
321320
if statusErr := r.statusUpdater.UpdateStatusFailure(ctx, ga, agadeploy.DeploymentFailed, fmt.Sprintf("Failed to deploy stack: %v", err)); statusErr != nil {
322321
r.logger.Error(statusErr, "Failed to update GlobalAccelerator status after deployment failure")

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ require (
2525
github.com/golang/mock v1.6.0
2626
github.com/google/go-cmp v0.7.0
2727
github.com/google/uuid v1.6.0
28+
github.com/hashicorp/golang-lru v1.0.2
2829
github.com/onsi/ginkgo/v2 v2.23.3
2930
github.com/onsi/gomega v1.37.0
3031
github.com/pkg/errors v0.9.1
@@ -104,7 +105,6 @@ require (
104105
github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 // indirect
105106
github.com/hashicorp/errwrap v1.1.0 // indirect
106107
github.com/hashicorp/go-multierror v1.1.1 // indirect
107-
github.com/hashicorp/golang-lru v1.0.2 // indirect
108108
github.com/huandu/xstrings v1.5.0 // indirect
109109
github.com/imkira/go-interpol v1.1.0 // indirect
110110
github.com/inconshreveable/mousetrap v1.1.0 // indirect

main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ func main() {
241241
}
242242

243243
// Setup GlobalAccelerator controller only if enabled
244-
if aga.IsAGAControllerEnabled(controllerCFG.FeatureGates, controllerCFG.AWSConfig.Region) {
244+
if aga.IsAGAControllerEnabled(controllerCFG.FeatureGates, cloud.Region()) {
245245
agaReconciler := agacontroller.NewGlobalAcceleratorReconciler(mgr.GetClient(), mgr.GetEventRecorderFor("globalAccelerator"),
246246
finalizerManager, controllerCFG, cloud, ctrl.Log.WithName("controllers").WithName("globalAccelerator"), lbcMetricsCollector, reconcileCounters)
247247
if err := agaReconciler.SetupWithManager(ctx, mgr, clientSet); err != nil {
@@ -442,7 +442,7 @@ func main() {
442442
networkingwebhook.NewIngressValidator(mgr.GetClient(), controllerCFG.IngressConfig, ctrl.Log, lbcMetricsCollector).SetupWithManager(mgr)
443443

444444
// Setup GlobalAccelerator validator only if enabled
445-
if aga.IsAGAControllerEnabled(controllerCFG.FeatureGates, controllerCFG.AWSConfig.Region) {
445+
if aga.IsAGAControllerEnabled(controllerCFG.FeatureGates, cloud.Region()) {
446446
agawebhook.NewGlobalAcceleratorValidator(ctrl.Log, lbcMetricsCollector).SetupWithManager(mgr)
447447
}
448448
//+kubebuilder:scaffold:builder

pkg/aga/endpoint_utils.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ type EndpointReference struct {
2828
Endpoint *agaapi.GlobalAcceleratorEndpoint
2929
}
3030

31-
// GetAllEndpointsFromGA extracts all endpoint references from a GlobalAccelerator resource
32-
func GetAllEndpointsFromGA(ga *agaapi.GlobalAccelerator) []EndpointReference {
31+
// GetAllDesiredEndpointsFromGA extracts all endpoint references from a GlobalAccelerator resource
32+
func GetAllDesiredEndpointsFromGA(ga *agaapi.GlobalAccelerator) []EndpointReference {
3333
if ga == nil || ga.Spec.Listeners == nil {
3434
return nil
3535
}

pkg/aga/endpoint_utils_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ func TestGetAllEndpointsFromGA(t *testing.T) {
180180
}
181181
}
182182

183-
result := GetAllEndpointsFromGA(tt.ga)
183+
result := GetAllDesiredEndpointsFromGA(tt.ga)
184184

185185
// Compare lengths
186186
assert.Equal(t, len(tt.expected), len(result))
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
package aga
2+
3+
import (
4+
"context"
5+
"fmt"
6+
awssdk "github.com/aws/aws-sdk-go-v2/aws"
7+
agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1"
8+
agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga"
9+
"sigs.k8s.io/aws-load-balancer-controller/pkg/model/core"
10+
)
11+
12+
// endpointGroupBuilder builds EndpointGroup model resources
13+
type endpointGroupBuilder interface {
14+
// Build builds all endpoint groups for all listeners
15+
Build(ctx context.Context, stack core.Stack, listeners []*agamodel.Listener, listenerConfigs []agaapi.GlobalAcceleratorListener) ([]*agamodel.EndpointGroup, error)
16+
17+
// buildEndpointGroupsForListener builds endpoint groups for a specific listener
18+
buildEndpointGroupsForListener(ctx context.Context, stack core.Stack, listener *agamodel.Listener, endpointGroups []agaapi.GlobalAcceleratorEndpointGroup, listenerIndex int) ([]*agamodel.EndpointGroup, error)
19+
}
20+
21+
// NewEndpointGroupBuilder constructs new endpointGroupBuilder
22+
func NewEndpointGroupBuilder(clusterRegion string) endpointGroupBuilder {
23+
return &defaultEndpointGroupBuilder{
24+
clusterRegion: clusterRegion,
25+
}
26+
}
27+
28+
var _ endpointGroupBuilder = &defaultEndpointGroupBuilder{}
29+
30+
type defaultEndpointGroupBuilder struct {
31+
clusterRegion string
32+
}
33+
34+
// Build builds EndpointGroup model resources
35+
func (b *defaultEndpointGroupBuilder) Build(ctx context.Context, stack core.Stack, listeners []*agamodel.Listener, listenerConfigs []agaapi.GlobalAcceleratorListener) ([]*agamodel.EndpointGroup, error) {
36+
if listeners == nil || len(listeners) == 0 {
37+
return nil, nil
38+
}
39+
40+
var result []*agamodel.EndpointGroup
41+
42+
// Create a map of all listener port ranges
43+
listenerPortRanges := make(map[string][]agamodel.PortRange) // Maps listener ID to its port ranges
44+
for _, listener := range listeners {
45+
listenerPortRanges[listener.ID()] = listener.Spec.PortRanges
46+
}
47+
48+
for i, listener := range listeners {
49+
listenerConfig := listenerConfigs[i]
50+
if listenerConfig.EndpointGroups == nil {
51+
continue
52+
}
53+
54+
listenerEndpointGroups, err := b.buildEndpointGroupsForListener(ctx, stack, listener, *listenerConfig.EndpointGroups, i)
55+
if err != nil {
56+
return nil, err
57+
}
58+
result = append(result, listenerEndpointGroups...)
59+
}
60+
61+
// Validate endpoint ports in all port overrides across all listeners
62+
if err := b.validateEndpointPortOverridesCrossListeners(result, listenerPortRanges); err != nil {
63+
return nil, err
64+
}
65+
66+
return result, nil
67+
}
68+
69+
// validateEndpointPortOverridesCrossListeners performs validations for endpoint port overrides across all listeners
70+
func (b *defaultEndpointGroupBuilder) validateEndpointPortOverridesCrossListeners(endpointGroups []*agamodel.EndpointGroup, listenerPortRanges map[string][]agamodel.PortRange) error {
71+
// Track endpoint port usage across all endpoint groups
72+
endpointPortUsage := make(map[int32]string) // Maps endpoint port to listener ID
73+
74+
// Check all endpoint groups for port overrides
75+
for _, endpointGroup := range endpointGroups {
76+
listenerID := endpointGroup.Listener.ID()
77+
78+
for _, portOverride := range endpointGroup.Spec.PortOverrides {
79+
endpointPort := portOverride.EndpointPort
80+
81+
// Rule 1: Check if endpoint port is within any listener's port range
82+
if err := b.validateEndpointPortOverridesWithinListener(endpointPort, listenerPortRanges); err != nil {
83+
return err
84+
}
85+
86+
// Rule 2: Check for duplicate endpoint port usage across listeners
87+
if existingListenerID, exists := endpointPortUsage[endpointPort]; exists && existingListenerID != listenerID {
88+
return fmt.Errorf("duplicate endpoint port %d: the same endpoint port cannot be used in port overrides from different listeners (used in %s and %s)",
89+
endpointPort, existingListenerID, listenerID)
90+
}
91+
92+
// Register this endpoint port usage
93+
endpointPortUsage[endpointPort] = listenerID
94+
}
95+
}
96+
97+
return nil
98+
}
99+
100+
// validateEndpointPortOverridesWithinListener checks if an endpoint port is within any listener's port range
101+
func (b *defaultEndpointGroupBuilder) validateEndpointPortOverridesWithinListener(endpointPort int32, listenerPortRanges map[string][]agamodel.PortRange) error {
102+
for listenerID, portRanges := range listenerPortRanges {
103+
if IsPortInRanges(endpointPort, portRanges) {
104+
// Find the specific port range for the error message
105+
for _, portRange := range portRanges {
106+
if endpointPort >= portRange.FromPort && endpointPort <= portRange.ToPort {
107+
return fmt.Errorf("endpoint port %d conflicts with listener %s port range %d-%d: endpoint port cannot be included in any listener port range",
108+
endpointPort, listenerID, portRange.FromPort, portRange.ToPort)
109+
}
110+
}
111+
}
112+
}
113+
return nil
114+
}
115+
116+
// buildEndpointGroupsForListener builds EndpointGroup models for a specific listener
117+
func (b *defaultEndpointGroupBuilder) buildEndpointGroupsForListener(ctx context.Context, stack core.Stack, listener *agamodel.Listener, endpointGroups []agaapi.GlobalAcceleratorEndpointGroup, listenerIndex int) ([]*agamodel.EndpointGroup, error) {
118+
var result []*agamodel.EndpointGroup
119+
120+
for i, endpointGroup := range endpointGroups {
121+
spec, err := b.buildEndpointGroupSpec(ctx, listener, endpointGroup)
122+
if err != nil {
123+
return nil, err
124+
}
125+
126+
resourceID := fmt.Sprintf("EndpointGroup-%d-%d", listenerIndex, i)
127+
endpointGroupModel := agamodel.NewEndpointGroup(stack, resourceID, spec, listener)
128+
result = append(result, endpointGroupModel)
129+
}
130+
131+
return result, nil
132+
}
133+
134+
// buildEndpointGroupSpec builds the EndpointGroupSpec for a single EndpointGroup model resource
135+
func (b *defaultEndpointGroupBuilder) buildEndpointGroupSpec(ctx context.Context, listener *agamodel.Listener, endpointGroup agaapi.GlobalAcceleratorEndpointGroup) (agamodel.EndpointGroupSpec, error) {
136+
region, err := b.determineRegion(endpointGroup)
137+
if err != nil {
138+
return agamodel.EndpointGroupSpec{}, err
139+
}
140+
141+
// Handle trafficDialPercentage
142+
trafficDialPercentage := endpointGroup.TrafficDialPercentage
143+
144+
portOverrides, err := b.buildPortOverrides(ctx, listener, endpointGroup)
145+
if err != nil {
146+
return agamodel.EndpointGroupSpec{}, err
147+
}
148+
149+
return agamodel.EndpointGroupSpec{
150+
ListenerARN: listener.ListenerARN(),
151+
Region: region,
152+
TrafficDialPercentage: trafficDialPercentage,
153+
PortOverrides: portOverrides,
154+
}, nil
155+
}
156+
157+
// validateListenerPortOverrideWithinListenerPortRanges ensures all listener ports used in port overrides are
158+
// contained within the listener's port ranges
159+
func (b *defaultEndpointGroupBuilder) validateListenerPortOverrideWithinListenerPortRanges(listener *agamodel.Listener, portOverrides []agamodel.PortOverride) error {
160+
if len(portOverrides) == 0 {
161+
return nil
162+
}
163+
164+
for _, portOverride := range portOverrides {
165+
listenerPort := portOverride.ListenerPort
166+
if !IsPortInRanges(listenerPort, listener.Spec.PortRanges) {
167+
return fmt.Errorf("port override listener port %d is not within any listener port ranges - this will cause AWS Global Accelerator to reject the configuration", listenerPort)
168+
}
169+
}
170+
return nil
171+
}
172+
173+
// determineRegion determines the region for the endpoint group
174+
func (b *defaultEndpointGroupBuilder) determineRegion(endpointGroup agaapi.GlobalAcceleratorEndpointGroup) (string, error) {
175+
// Use explicit region from endpoint group if specified
176+
if endpointGroup.Region != nil && awssdk.ToString(endpointGroup.Region) != "" {
177+
return awssdk.ToString(endpointGroup.Region), nil
178+
}
179+
180+
// Default to cluster region if available
181+
if b.clusterRegion != "" {
182+
return b.clusterRegion, nil
183+
}
184+
return "", fmt.Errorf("region is required for endpoint group but neither specified in the endpoint group nor available from cluster configuration")
185+
}
186+
187+
// buildPortOverrides builds the port overrides for the endpoint group
188+
func (b *defaultEndpointGroupBuilder) buildPortOverrides(_ context.Context, listener *agamodel.Listener, endpointGroup agaapi.GlobalAcceleratorEndpointGroup) ([]agamodel.PortOverride, error) {
189+
if endpointGroup.PortOverrides == nil {
190+
return nil, nil
191+
}
192+
193+
var portOverrides []agamodel.PortOverride
194+
for _, po := range *endpointGroup.PortOverrides {
195+
portOverrides = append(portOverrides, agamodel.PortOverride{
196+
ListenerPort: po.ListenerPort,
197+
EndpointPort: po.EndpointPort,
198+
})
199+
}
200+
201+
// Validate all port override rules
202+
if err := b.validatePortOverrides(listener, portOverrides); err != nil {
203+
return []agamodel.PortOverride{}, err
204+
}
205+
206+
return portOverrides, nil
207+
}
208+
209+
// validateNoDuplicatePorts checks both listener and endpoint ports for duplicates in a single pass
210+
func (b *defaultEndpointGroupBuilder) validateNoDuplicatePorts(portOverrides []agamodel.PortOverride) error {
211+
if len(portOverrides) <= 1 {
212+
return nil
213+
}
214+
215+
listenerPorts := make(map[int32]bool)
216+
endpointPorts := make(map[int32]bool)
217+
218+
for _, portOverride := range portOverrides {
219+
// Check for duplicate listener ports
220+
listenerPort := portOverride.ListenerPort
221+
if listenerPorts[listenerPort] {
222+
return fmt.Errorf("duplicate listener port %d in port overrides: each listener port can only be used once in port overrides for an endpoint group", listenerPort)
223+
}
224+
listenerPorts[listenerPort] = true
225+
226+
// Check for duplicate endpoint ports
227+
endpointPort := portOverride.EndpointPort
228+
if endpointPorts[endpointPort] {
229+
return fmt.Errorf("duplicate endpoint port %d in port overrides: each endpoint port can only be used once in port overrides for an endpoint group", endpointPort)
230+
}
231+
endpointPorts[endpointPort] = true
232+
}
233+
234+
return nil
235+
}
236+
237+
// validatePortOverrides is a wrapper function that runs all port override validation rules
238+
func (b *defaultEndpointGroupBuilder) validatePortOverrides(listener *agamodel.Listener, portOverrides []agamodel.PortOverride) error {
239+
// Validate listener port overrides against listener port ranges
240+
if err := b.validateListenerPortOverrideWithinListenerPortRanges(listener, portOverrides); err != nil {
241+
return err
242+
}
243+
244+
// Check for duplicate listener and endpoint ports within this endpoint group's port overrides
245+
if err := b.validateNoDuplicatePorts(portOverrides); err != nil {
246+
return err
247+
}
248+
249+
return nil
250+
}

0 commit comments

Comments
 (0)