From 3044ac6a2b24b32fd901b07145424ab89697bd92 Mon Sep 17 00:00:00 2001
From: Belyakov Sergey Sergeevich
Date: Wed, 31 Jul 2024 13:01:59 +0700
Subject: [PATCH] add support for incremental rebalance within a single session

Signed-off-by: Belyakov Sergey Sergeevich
---
 config.go         |   2 +
 consumer_group.go | 560 +++++++++++++++++++++++++++++++++++++++++++---
 offset_manager.go |  13 +-
 3 files changed, 540 insertions(+), 35 deletions(-)

diff --git a/config.go b/config.go
index f2f197887..c801707f3 100644
--- a/config.go
+++ b/config.go
@@ -304,6 +304,8 @@ type Config struct {
 				Interval time.Duration
 			}
 			Rebalance struct {
+				// IsIncremental allows consumption to continue during a rebalance instead of stopping the session when a new member joins the group.
+				IsIncremental bool
 				// Strategy for allocating topic partitions to members.
 				// Deprecated: Strategy exists for historical compatibility
 				// and should not be used. Please use GroupStrategies.
diff --git a/consumer_group.go b/consumer_group.go
index 53b64dd3b..fa0a6904e 100644
--- a/consumer_group.go
+++ b/consumer_group.go
@@ -4,8 +4,10 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"slices"
 	"sort"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/rcrowley/go-metrics"
@@ -91,9 +93,11 @@ type consumerGroup struct {
 	closed     chan none
 	closeOnce  sync.Once
 
-	userData []byte
+	userData            []byte
+	rebalanceInProgress chan none
 
 	metricRegistry metrics.Registry
+	loopPartCheck  atomic.Bool
 }
 
 // NewConsumerGroup creates a new consumer group the given broker addresses and configuration.
@@ -135,14 +139,15 @@ func newConsumerGroup(groupID string, client Client) (ConsumerGroup, error) {
 	}
 	cg := &consumerGroup{
-		client:         client,
-		consumer:       consumer,
-		config:         config,
-		groupID:        groupID,
-		errors:         make(chan error, config.ChannelBufferSize),
-		closed:         make(chan none),
-		userData:       config.Consumer.Group.Member.UserData,
-		metricRegistry: newCleanupRegistry(config.MetricRegistry),
+		client:              client,
+		consumer:            consumer,
+		config:              config,
+		groupID:             groupID,
+		errors:              make(chan error, config.ChannelBufferSize),
+		closed:              make(chan none),
+		userData:            config.Consumer.Group.Member.UserData,
+		metricRegistry:      newCleanupRegistry(config.MetricRegistry),
+		rebalanceInProgress: make(chan none),
 	}
 	if config.Consumer.Group.InstanceId != "" && config.Version.IsAtLeast(V2_3_0_0) {
 		cg.groupInstanceId = &config.Consumer.Group.InstanceId
 	}
@@ -213,6 +218,10 @@ func (c *consumerGroup) Consume(ctx context.Context, topics []string, handler Co
 		return err
 	}
 
+	if c.config.Consumer.Group.Rebalance.IsIncremental {
+		go c.incRebalanceLoop(ctx, topics, handler, c.config.Consumer.Group.Rebalance.Retry.Max, sess)
+	}
+
 	// Wait for session exit signal or Close() call
 	select {
 	case <-c.closed:
@@ -223,6 +232,22 @@ func (c *consumerGroup) Consume(ctx context.Context, topics []string, handler Co
 	return sess.release(true)
 }
 
+func (c *consumerGroup) incRebalanceLoop(ctx context.Context, topics []string, handler ConsumerGroupHandler, retries int, sess *consumerGroupSession) {
+	for {
+		if _, ok := <-sess.rebalanceInProgress; !ok {
+			Logger.Printf("rebalanceInProgress channel closed, stopping incremental rebalance loop\n")
+			return
+		}
+
+		err := c.IncRebalance(ctx, topics, handler, retries, sess)
+		if err != nil {
+			Logger.Printf("Error during incremental rebalance: %v", err)
+			sess.cancel()
+			return
+		}
+	}
+}
+
 // Pause implements ConsumerGroup.
 func (c *consumerGroup) Pause(partitions map[string][]int32) {
 	c.consumer.Pause(partitions)
@@ -442,6 +467,280 @@ func (c *consumerGroup) newSession(ctx context.Context, topics []string, handler
 	return session, err
 }
 
+func membersFromJoinResp(joinResp *JoinGroupResponse) (map[string]ConsumerGroupMemberMetadata, error) {
+	// Extract the members' metadata if we joined as the leader
+	var members map[string]ConsumerGroupMemberMetadata
+	var err error
+	if joinResp.LeaderId == joinResp.MemberId {
+		members, err = joinResp.GetMembers()
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return members, nil
+}
+
+func (c *consumerGroup) leaderDistributionPlan(joinResp *JoinGroupResponse, strategy BalanceStrategy, members map[string]ConsumerGroupMemberMetadata) (map[string][]int32, []string, BalanceStrategyPlan, error) {
+	// Prepare distribution plan if we joined as the leader
+	var plan BalanceStrategyPlan
+	var allSubscribedTopicPartitions map[string][]int32
+	var allSubscribedTopics []string
+	var err error
+	if joinResp.LeaderId == joinResp.MemberId {
+		allSubscribedTopicPartitions, allSubscribedTopics, plan, err = c.balance(strategy, members)
+		if err != nil {
+			return nil, nil, nil, err
+		}
+	}
+
+	return allSubscribedTopicPartitions, allSubscribedTopics, plan, nil
+}
+
+func (c *consumerGroup) strategy(joinResp *JoinGroupResponse) (BalanceStrategy, error) {
+	var strategy BalanceStrategy
+	var ok bool
+	if strategy = c.config.Consumer.Group.Rebalance.Strategy; strategy == nil {
+		strategy, ok = c.findStrategy(joinResp.GroupProtocol, c.config.Consumer.Group.Rebalance.GroupStrategies)
+		if !ok {
+			// this case shouldn't happen in practice, since the leader will choose the protocol
+			// that all the members support
+			return nil, fmt.Errorf("unable to find selected strategy: %s", joinResp.GroupProtocol)
+		}
+	}
+
+	return strategy, nil
+}
+
+func (c *consumerGroup) joinIncRebalanceReq(coordinator *Broker, topics []string) (*JoinGroupResponse, error) {
+	// Join consumer group
+	var (
+		metricRegistry          = c.metricRegistry
+		consumerGroupJoinTotal  metrics.Counter
+		consumerGroupJoinFailed metrics.Counter
+	)
+
+	if metricRegistry != nil {
+		consumerGroupJoinTotal = metrics.GetOrRegisterCounter(fmt.Sprintf("consumer-group-join-total-%s", c.groupID), metricRegistry)
+		consumerGroupJoinFailed = metrics.GetOrRegisterCounter(fmt.Sprintf("consumer-group-join-failed-%s", c.groupID), metricRegistry)
+	}
+
+	join, err := c.joinGroupRequest(coordinator, topics)
+	if consumerGroupJoinTotal != nil {
+		consumerGroupJoinTotal.Inc(1)
+	}
+	if err != nil {
+		_ = coordinator.Close()
+		if consumerGroupJoinFailed != nil {
+			consumerGroupJoinFailed.Inc(1)
+		}
+		return nil, err
+	}
+	if !errors.Is(join.Err, ErrNoError) {
+		if consumerGroupJoinFailed != nil {
+			consumerGroupJoinFailed.Inc(1)
+		}
+	}
+
+	if join.Err != ErrNoError {
+		return nil, join.Err
+	}
+
+	return join, nil
+}
+
+func (c *consumerGroup) SyncGroupIncRebalanceReq(
+	coordinator *Broker,
+	members map[string]ConsumerGroupMemberMetadata,
+	generationID int32,
+	plan BalanceStrategyPlan,
+	strategy BalanceStrategy,
+) (*SyncGroupResponse, error) {
+	var (
+		metricRegistry          = c.metricRegistry
+		consumerGroupSyncTotal  metrics.Counter
+		consumerGroupSyncFailed metrics.Counter
+	)
+
+	if metricRegistry != nil {
+		consumerGroupSyncTotal = metrics.GetOrRegisterCounter(fmt.Sprintf("consumer-group-sync-total-%s", c.groupID), metricRegistry)
+		consumerGroupSyncFailed = metrics.GetOrRegisterCounter(fmt.Sprintf("consumer-group-sync-failed-%s", c.groupID), metricRegistry)
+	}
+
+	// Sync consumer group
+	syncGroupResponse, err := c.syncGroupRequest(coordinator, members, plan, generationID, strategy)
+	if consumerGroupSyncTotal != nil {
+		consumerGroupSyncTotal.Inc(1)
+	}
+	if err != nil {
+		_ = coordinator.Close()
+		if consumerGroupSyncFailed != nil {
+			consumerGroupSyncFailed.Inc(1)
+		}
+		return nil, err
+	}
+	if !errors.Is(syncGroupResponse.Err, ErrNoError) {
+		if consumerGroupSyncFailed != nil {
+			consumerGroupSyncFailed.Inc(1)
+		}
+	}
+
+	switch syncGroupResponse.Err {
+	case ErrNoError:
+	case ErrUnknownMemberId, ErrIllegalGeneration:
+		return nil, syncGroupResponse.Err
+	case ErrNotCoordinatorForConsumer, ErrRebalanceInProgress, ErrOffsetsLoadInProgress:
+		// retry after backoff
+		return nil, syncGroupResponse.Err
+
+	case ErrFencedInstancedId:
+		if c.groupInstanceId != nil {
+			Logger.Printf("SyncGroup failed: group instance id %s has been fenced\n", *c.groupInstanceId)
+		}
+		return nil, syncGroupResponse.Err
+	default:
+		return nil, syncGroupResponse.Err
+	}
+
+	return syncGroupResponse, nil
+}
+
+type ReJoinResponse struct {
+	join                         *JoinGroupResponse
+	sync                         *SyncGroupResponse
+	allSubscribedTopicPartitions map[string][]int32
+	allSubscribedTopics          []string
+}
+
+func (c *consumerGroup) ReJoin(
+	_ context.Context,
+	coordinator *Broker,
+	topics []string,
+	sess *consumerGroupSession,
+) (*ReJoinResponse, error) {
+	sess.Commit()
+
+	join, err := c.joinIncRebalanceReq(coordinator, topics)
+
+	if err != nil {
+		return nil, err
+	}
+
+	strategy, err := c.strategy(join)
+
+	if err != nil {
+		return nil, err
+	}
+
+	members, err := membersFromJoinResp(join)
+
+	if err != nil {
+		return nil, err
+	}
+
+	allSubscribedTopicPartitions, allSubscribedTopics, plan, err := c.leaderDistributionPlan(join, strategy, members)
+	if err != nil {
+		return nil, err
+	}
+
+	sess.updateGenerationID(join.GenerationId)
+	sess.offsets.updateGeneration(join.GenerationId)
+
+	syncGroupResponse, err := c.SyncGroupIncRebalanceReq(coordinator, members, join.GenerationId, plan, strategy)
+	if err != nil {
+		return nil, err
+	}
+
+	return &ReJoinResponse{
+		join:                         join,
+		sync:                         syncGroupResponse,
+		allSubscribedTopicPartitions: allSubscribedTopicPartitions,
+		allSubscribedTopics:          allSubscribedTopics,
+	}, nil
+}
+
+func (c *consumerGroup) claimsFromSyncResp(syncResponse *SyncGroupResponse) (map[string][]int32, error) {
+	var claims map[string][]int32
+	if len(syncResponse.MemberAssignment) > 0 {
+		members, err := syncResponse.GetMemberAssignment()
+		if err != nil {
+			return nil, err
+		}
+		claims = members.Topics
+
+		if members.UserData != nil {
+			c.userData = members.UserData
+		} else {
+			c.userData = c.config.Consumer.Group.Member.UserData
+		}
+	}
+
+	return claims, nil
+}
+
+func (c *consumerGroup) IncRebalance(
+	ctx context.Context,
+	topics []string,
+	handler ConsumerGroupHandler,
+	retries int,
+	sess *consumerGroupSession,
+) error {
+	Logger.Printf("start incremental rebalance\n")
+	defer func() {
+		Logger.Printf("end incremental rebalance\n")
+	}()
+
+	if ctx.Err() != nil {
+		return ctx.Err()
+	}
+
+	coordinator, err := c.client.Coordinator(c.groupID)
+	if err != nil {
+		return err
+	}
+
+	reJoinResp, err := c.ReJoin(context.TODO(), coordinator, topics, sess)
+	if err != nil {
+		if retries <= 0 {
+			return err
+		}
+
+		Logger.Printf("ReJoin failed, retrying\n")
+
+		return c.IncRebalance(ctx, topics, handler, retries-1, sess)
+	}
+
+	sess.Commit()
+
+	err = c.rebalanceIncClaims(reJoinResp, sess)
+
+	return err
+}
+
+func (c *consumerGroup) rebalanceIncClaims(reJoinResp *ReJoinResponse, sess *consumerGroupSession) error {
+	claims, err := c.claimsFromSyncResp(reJoinResp.sync)
+
+	if err != nil {
+		return err
+	}
+
+	// difference between the current claims and the previous claims
+	removedClaims, addedClaims := diffClaims(sess.claims, claims)
+
+	Logger.Printf("new claims: %v\n", claims)
+	sess.incStartNewClaims(reJoinResp.join.GenerationId, claims, removedClaims, addedClaims)
+
+	Logger.Printf("allSubscribedTopicPartitions after rebalance: %v\n", reJoinResp.allSubscribedTopicPartitions)
+
+	// only the leader needs to check whether there are newly-added partitions in order to trigger a rebalance
+	if reJoinResp.join.LeaderId == reJoinResp.join.MemberId {
+		Logger.Printf("I am leader %s %s\n", sess.parent.config.ClientID, reJoinResp.join.MemberId)
+		go c.loopCheckPartitionNumbers(reJoinResp.allSubscribedTopicPartitions, reJoinResp.allSubscribedTopics, sess)
+	}
+
+	return nil
+}
+
 func (c *consumerGroup) joinGroupRequest(coordinator *Broker, topics []string) (*JoinGroupResponse, error) {
 	req := &JoinGroupRequest{
 		GroupId: c.groupID,
@@ -666,7 +965,7 @@ func (c *consumerGroup) handleError(err error, topic string, partition int32) {
 		err = &ConsumerError{
 			Topic:     topic,
 			Partition: partition,
-			Err:       err,
+			Err:       fmt.Errorf("consumerGroup error: %w", err),
 		}
 	}
 
@@ -692,6 +991,12 @@ func (c *consumerGroup) handleError(err error, topic string, partition int32) {
 }
 
 func (c *consumerGroup) loopCheckPartitionNumbers(allSubscribedTopicPartitions map[string][]int32, topics []string, session *consumerGroupSession) {
+	if !c.loopPartCheck.CompareAndSwap(false, true) {
+		return
+	}
+
+	defer c.loopPartCheck.Store(false)
+
 	if c.config.Metadata.RefreshFrequency == time.Duration(0) {
 		return
 	}
@@ -795,19 +1100,22 @@ type ConsumerGroupSession interface {
 }
 
 type consumerGroupSession struct {
+	mu           sync.RWMutex
 	parent       *consumerGroup
 	memberID     string
 	generationID int32
 	handler      ConsumerGroupHandler
 
-	claims  map[string][]int32
-	offsets *offsetManager
-	ctx     context.Context
-	cancel  func()
+	claims        map[string][]int32
+	claimsBrokers map[string]map[int32]*consumerGroupClaim
+	offsets       *offsetManager
+	ctx           context.Context
+	cancel        func()
 
-	waitGroup       sync.WaitGroup
-	releaseOnce     sync.Once
-	hbDying, hbDead chan none
+	waitGroup           sync.WaitGroup
+	releaseOnce         sync.Once
+	hbDying, hbDead     chan none
+	rebalanceInProgress chan none
 }
 
 func newConsumerGroupSession(ctx context.Context, parent *consumerGroup, claims map[string][]int32, memberID string, generationID int32, handler ConsumerGroupHandler) (*consumerGroupSession, error) {
@@ -822,16 +1130,18 @@ func newConsumerGroupSession(ctx context.Context, parent *consumerGroup, claims
 
 	// init session
 	sess := &consumerGroupSession{
-		parent:       parent,
-		memberID:     memberID,
-		generationID: generationID,
-		handler:      handler,
-		offsets:      offsets,
-		claims:       claims,
-		ctx:          ctx,
-		cancel:       cancel,
-		hbDying:      make(chan none),
-		hbDead:       make(chan none),
+		parent:              parent,
+		memberID:            memberID,
+		generationID:        generationID,
+		handler:             handler,
+		offsets:             offsets,
+		claims:              claims,
+		claimsBrokers:       make(map[string]map[int32]*consumerGroupClaim),
+		ctx:                 ctx,
+		cancel:              cancel,
+		hbDying:             make(chan none),
+		hbDead:              make(chan none),
+		rebalanceInProgress: make(chan none),
 	}
 
 	// start heartbeat loop
@@ -871,7 +1181,9 @@ func newConsumerGroupSession(ctx context.Context, parent *consumerGroup, claims
 
 				// cancel the as session as soon as the first
 				// goroutine exits
-				defer sess.cancel()
+				if !sess.parent.config.Consumer.Group.Rebalance.IsIncremental {
+					defer sess.cancel()
+				}
 
 				// consume a single topic/partition, blocking
 				sess.consume(topic, partition)
@@ -881,9 +1193,44 @@ func newConsumerGroupSession(ctx context.Context, parent *consumerGroup, claims
 	return sess, nil
 }
 
-func (s *consumerGroupSession) Claims() map[string][]int32 { return s.claims }
-func (s *consumerGroupSession) MemberID() string           { return s.memberID }
-func (s *consumerGroupSession) GenerationID() int32        { return s.generationID }
+func (s *consumerGroupSession) Claims() map[string][]int32 {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	return s.claims
+}
+
+func (s *consumerGroupSession) MemberID() string { return s.memberID }
+func (s *consumerGroupSession) GenerationID() int32 {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	return s.generationID
+}
+
+func (s *consumerGroupSession) updateGenerationID(generationID int32) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	s.generationID = generationID
+}
+
+func (s *consumerGroupSession) updateClaims(claims map[string][]int32) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	s.claims = claims
+}
+
+func (s *consumerGroupSession) addBrokerClaim(topic string, partition int32, claim *consumerGroupClaim) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if _, ok := s.claimsBrokers[topic]; !ok {
+		s.claimsBrokers[topic] = make(map[int32]*consumerGroupClaim)
+	}
+	s.claimsBrokers[topic][partition] = claim
+}
 
 func (s *consumerGroupSession) MarkOffset(topic string, partition int32, offset int64, metadata string) {
 	if pom := s.offsets.findPOM(topic, partition); pom != nil {
@@ -932,6 +1279,8 @@ func (s *consumerGroupSession) consume(topic string, partition int32) {
 		return
 	}
 
+	s.addBrokerClaim(topic, partition, claim)
+
 	// handle errors
 	go func() {
 		for err := range claim.Errors() {
@@ -1000,6 +1349,8 @@ func (s *consumerGroupSession) heartbeatLoop() {
 			s.MemberID(), s.GenerationID())
 	}()
 
+	defer close(s.rebalanceInProgress)
+
 	pause := time.NewTicker(s.parent.config.Consumer.Group.Heartbeat.Interval)
 	defer pause.Stop()
 
@@ -1042,7 +1393,11 @@ func (s *consumerGroupSession) heartbeatLoop() {
 			retries = s.parent.config.Metadata.Retry.Max
 		case ErrRebalanceInProgress:
 			retries = s.parent.config.Metadata.Retry.Max
-			s.cancel()
+			if s.parent.config.Consumer.Group.Rebalance.IsIncremental {
+				s.rebalanceInProgress <- none{}
+			} else {
+				s.cancel()
+			}
 		case ErrUnknownMemberId, ErrIllegalGeneration:
 			return
 		case ErrFencedInstancedId:
@@ -1063,6 +1418,102 @@ func (s *consumerGroupSession) heartbeatLoop() {
 		}
 	}
 }
+func (s *consumerGroupSession) removeClaimAndPOMBrokers(removedClaims map[string][]int32) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	for topic, removePartitions := range removedClaims {
+		if claimsBroker, ok := s.claimsBrokers[topic]; ok {
+			for part, claim := range claimsBroker {
+				if slices.Contains(removePartitions, part) {
+					err := claim.Close()
+					if err != nil {
+						s.parent.handleError(fmt.Errorf("error closing claim: %w", err), topic, part)
+						continue
+					}
+					Logger.Printf("closed claim for topic %s, part %d", topic, part)
+				}
+			}
+		}
+	}
+
+	s.Commit()
+
+	for topic, removePartitions := range removedClaims {
+		if claimsBroker, ok := s.claimsBrokers[topic]; ok {
+			for part := range claimsBroker {
+				if slices.Contains(removePartitions, part) {
+					delete(s.claimsBrokers[topic], part)
+					if pom := s.offsets.findPOM(topic, part); pom != nil {
+						err := pom.Close()
+						if err != nil {
+							s.parent.handleError(fmt.Errorf("error closing POM: %w", err), topic, part)
+							continue
+						}
+						Logger.Printf("closed POM for topic %s, part %d", topic, part)
+					}
+				}
+			}
+			if len(s.claimsBrokers[topic]) == 0 {
+				delete(s.claimsBrokers, topic)
+			}
+		}
+	}
+
+	s.Commit()
+}
+
+func (s *consumerGroupSession) createNewPOMs(addedClaims map[string][]int32) {
+	// create a POM for each claim
+	for topic, partitions := range addedClaims {
+		for _, partition := range partitions {
+			pom, err := s.offsets.ManagePartition(topic, partition)
+			if err != nil {
+				_ = s.release(false)
+				Logger.Printf("error creating partition offset manager: %s\n", err)
+				return
+			}
+
+			// handle POM errors
+			go func(topic string, partition int32) {
+				for err := range pom.Errors() {
+					s.parent.handleError(err, topic, partition)
+				}
+			}(topic, partition)
+		}
+	}
+}
+
+func (s *consumerGroupSession) startClaims(addedClaims map[string][]int32) {
+	// start consuming
+	for topic, partitions := range addedClaims {
+		for _, partition := range partitions {
+			s.waitGroup.Add(1)
+
+			go func(topic string, partition int32) {
+				defer s.waitGroup.Done()
+
+				// consume a single topic/partition, blocking
+				Logger.Printf("start consuming new topic %s, partition %d\n", topic, partition)
+				s.consume(topic, partition)
+			}(topic, partition)
+		}
+	}
+}
+
+func (s *consumerGroupSession) incStartNewClaims(generationID int32, claims, removedClaims, addedClaims map[string][]int32) {
+	s.updateClaims(claims)
+	s.removeClaimAndPOMBrokers(removedClaims)
+	s.createNewPOMs(addedClaims)
+
+	if err := s.handler.Setup(s); err != nil {
+		_ = s.release(true)
+		Logger.Printf("error handler Setup: %s\n", err)
+		return
+	}
+
+	s.startClaims(addedClaims)
+}
 
 // --------------------------------------------------------------------
 
@@ -1158,3 +1609,49 @@ func (c *consumerGroupClaim) waitClosed() (errs ConsumerErrors) {
 	}
 	return
 }
+
+func diffClaims(oldClaims, newClaims map[string][]int32) (removedClaims, addedClaims map[string][]int32) {
+	removedClaims = make(map[string][]int32)
+	addedClaims = make(map[string][]int32)
+
+	for topic, partitions := range oldClaims {
+		if newPartitions, ok := newClaims[topic]; ok {
+			removed, added := diffPartitions(partitions, newPartitions)
+			if len(removed) > 0 {
+				removedClaims[topic] = removed
+			}
+			if len(added) > 0 {
+				addedClaims[topic] = added
+			}
+		} else {
+			removedClaims[topic] = partitions
+		}
+	}
+
+	for topic, partitions := range newClaims {
+		if _, ok := oldClaims[topic]; !ok {
+			addedClaims[topic] = partitions
+		}
+	}
+
+	return
+}
+
+func diffPartitions(partitions, newPartitions []int32) (removed, added []int32) {
+	removed = make([]int32, 0)
+	added = make([]int32, 0)
+
+	for _, partition := range partitions {
+		if !slices.Contains(newPartitions, partition) {
+			removed = append(removed, partition)
+		}
+	}
+
+	for _, partition := range newPartitions {
+		if !slices.Contains(partitions, partition) {
+			added = append(added, partition)
+		}
+	}
+
+	return
+}
diff --git a/offset_manager.go b/offset_manager.go
index 1bf545908..64ceb69ee 100644
--- a/offset_manager.go
+++ b/offset_manager.go
@@ -1,6 +1,7 @@
 package sarama
 
 import (
+	"fmt"
 	"sync"
 	"time"
 )
@@ -39,8 +40,8 @@ type offsetManager struct {
 	broker     *Broker
 	brokerLock sync.RWMutex
 
-	poms     map[string]map[int32]*partitionOffsetManager
 	pomsLock sync.RWMutex
+	poms     map[string]map[int32]*partitionOffsetManager
 
 	closeOnce sync.Once
 	closing   chan none
@@ -84,6 +85,12 @@ func newOffsetManagerFromClient(group, memberID string, generation int32, client
 	return om, nil
 }
 
+func (om *offsetManager) updateGeneration(generation int32) {
+	om.pomsLock.Lock()
+	defer om.pomsLock.Unlock()
+	om.generation = generation
+}
+
 func (om *offsetManager) ManagePartition(topic string, partition int32) (PartitionOffsetManager, error) {
 	pom, err := om.newPartitionOffsetManager(topic, partition)
 	if err != nil {
@@ -264,7 +271,7 @@ func (om *offsetManager) flushToBroker() {
 
 	resp, err := broker.CommitOffset(req)
 	if err != nil {
-		om.handleError(err)
+		om.handleError(fmt.Errorf("failed to commit offset: %w", err))
 		om.releaseCoordinator(broker)
 		_ = broker.Close()
 		return
@@ -621,7 +628,7 @@ func (pom *partitionOffsetManager) handleError(err error) {
 	cErr := &ConsumerError{
 		Topic:     pom.topic,
 		Partition: pom.partition,
-		Err:       err,
+		Err:       fmt.Errorf("partitionOffsetManager error: %w", err),
 	}
 
 	if pom.parent.conf.Consumer.Return.Errors {
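--
Usage sketch (illustrative only, not part of the diff): how a consumer might opt in to the
new option so that the partitions it keeps across a rebalance continue to be consumed.
The broker address, topic, group name, the exampleHandler type and the github.com/IBM/sarama
import path are assumptions for illustration, not taken from this patch.

package main

import (
	"context"
	"log"

	"github.com/IBM/sarama"
)

// exampleHandler is a hypothetical, minimal ConsumerGroupHandler used only for illustration.
type exampleHandler struct{}

func (exampleHandler) Setup(sarama.ConsumerGroupSession) error   { return nil }
func (exampleHandler) Cleanup(sarama.ConsumerGroupSession) error { return nil }
func (exampleHandler) ConsumeClaim(sess sarama.ConsumerGroupSession, claim sarama.ConsumerGroupClaim) error {
	for msg := range claim.Messages() {
		sess.MarkMessage(msg, "") // mark the message as processed
	}
	return nil
}

func main() {
	cfg := sarama.NewConfig()
	cfg.Version = sarama.V2_3_0_0
	// Opt in to the behaviour added by this patch: keep consuming the
	// partitions that remain assigned while the group rebalances.
	cfg.Consumer.Group.Rebalance.IsIncremental = true

	group, err := sarama.NewConsumerGroup([]string{"localhost:9092"}, "example-group", cfg)
	if err != nil {
		log.Fatalln(err)
	}
	defer group.Close()

	ctx := context.Background()
	for {
		// With IsIncremental enabled, ErrRebalanceInProgress is expected not to cancel
		// the session, so Consume should return less frequently than before.
		if err := group.Consume(ctx, []string{"example-topic"}, exampleHandler{}); err != nil {
			log.Println("consume error:", err)
		}
		if ctx.Err() != nil {
			return
		}
	}
}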