diff --git a/balance_strategy.go b/balance_strategy.go index 67c4d96d0..56da276d2 100644 --- a/balance_strategy.go +++ b/balance_strategy.go @@ -47,6 +47,10 @@ type BalanceStrategy interface { // Plan accepts a map of `memberID -> metadata` and a map of `topic -> partitions` // and returns a distribution plan. Plan(members map[string]ConsumerGroupMemberMetadata, topics map[string][]int32) (BalanceStrategyPlan, error) + + // AssignmentData returns the serialized assignment data for the specified + // memberID + AssignmentData(memberID string, topics map[string][]int32, generationID int32) ([]byte, error) } // -------------------------------------------------------------------- @@ -132,6 +136,11 @@ func (s *balanceStrategy) Plan(members map[string]ConsumerGroupMemberMetadata, t return plan, nil } +// AssignmentData simple strategies do not require any shared assignment data +func (s *balanceStrategy) AssignmentData(memberID string, topics map[string][]int32, generationID int32) ([]byte, error) { + return nil, nil +} + type balanceStrategySortable struct { topic string memberIDs []string @@ -268,6 +277,15 @@ func (s *stickyBalanceStrategy) Plan(members map[string]ConsumerGroupMemberMetad return plan, nil } +// AssignmentData serializes the set of topics currently assigned to the +// specified member as part of the supplied balance plan +func (s *stickyBalanceStrategy) AssignmentData(memberID string, topics map[string][]int32, generationID int32) ([]byte, error) { + return encode(&StickyAssignorUserDataV1{ + Topics: topics, + Generation: generationID, + }, nil) +} + func strsContains(s []string, value string) bool { for _, entry := range s { if entry == value { diff --git a/balance_strategy_test.go b/balance_strategy_test.go index e9c72748b..f930d7663 100644 --- a/balance_strategy_test.go +++ b/balance_strategy_test.go @@ -1,6 +1,7 @@ package sarama import ( + "bytes" "fmt" "math" "math/rand" @@ -62,6 +63,27 @@ func TestBalanceStrategyRange(t *testing.T) { } } +func TestBalanceStrategyRangeAssignmentData(t *testing.T) { + + strategy := BalanceStrategyRange + + members := make(map[string]ConsumerGroupMemberMetadata, 2) + members["consumer1"] = ConsumerGroupMemberMetadata{ + Topics: []string{"topic1"}, + } + members["consumer2"] = ConsumerGroupMemberMetadata{ + Topics: []string{"topic1"}, + } + + actual, err := strategy.AssignmentData("consumer1", map[string][]int32{"topic1": {0, 1}}, 1) + if err != nil { + t.Errorf("Error building assignment data: %v", err) + } + if actual != nil { + t.Error("Invalid assignment data returned from AssignmentData") + } +} + func TestBalanceStrategyRoundRobin(t *testing.T) { tests := []struct { members map[string][]string @@ -191,6 +213,27 @@ func Test_deserializeTopicPartitionAssignment(t *testing.T) { } } +func TestBalanceStrategyRoundRobinAssignmentData(t *testing.T) { + + strategy := BalanceStrategyRoundRobin + + members := make(map[string]ConsumerGroupMemberMetadata, 2) + members["consumer1"] = ConsumerGroupMemberMetadata{ + Topics: []string{"topic1"}, + } + members["consumer2"] = ConsumerGroupMemberMetadata{ + Topics: []string{"topic1"}, + } + + actual, err := strategy.AssignmentData("consumer1", map[string][]int32{"topic1": {0, 1}}, 1) + if err != nil { + t.Errorf("Error building assignment data: %v", err) + } + if actual != nil { + t.Error("Invalid assignment data returned from AssignmentData") + } +} + func Test_prepopulateCurrentAssignments(t *testing.T) { type args struct { members map[string]ConsumerGroupMemberMetadata @@ -1950,6 +1993,29 @@ func Test_stickyBalanceStrategy_Plan_ConflictingPreviousAssignments(t *testing.T verifyFullyBalanced(t, plan) } +func Test_stickyBalanceStrategy_Plan_AssignmentData(t *testing.T) { + + s := &stickyBalanceStrategy{} + + members := make(map[string]ConsumerGroupMemberMetadata, 2) + members["consumer1"] = ConsumerGroupMemberMetadata{ + Topics: []string{"topic1"}, + } + members["consumer2"] = ConsumerGroupMemberMetadata{ + Topics: []string{"topic1"}, + } + + expected := encodeSubscriberPlanWithGeneration(t, map[string][]int32{"topic1": {0, 1}}, 1) + + actual, err := s.AssignmentData("consumer1", map[string][]int32{"topic1": {0, 1}}, 1) + if err != nil { + t.Errorf("Error building assignment data: %v", err) + } + if bytes.Compare(expected, actual) != 0 { + t.Error("Invalid assignment data returned from AssignmentData") + } +} + func BenchmarkStickAssignmentWithLargeNumberOfConsumersAndTopics(b *testing.B) { s := &stickyBalanceStrategy{} r := rand.New(rand.NewSource(time.Now().UnixNano())) diff --git a/consumer_group.go b/consumer_group.go index fc95cd0df..951f64b33 100644 --- a/consumer_group.go +++ b/consumer_group.go @@ -331,20 +331,14 @@ func (c *consumerGroup) syncGroupRequest(coordinator *Broker, plan BalanceStrate MemberId: c.memberID, GenerationId: generationID, } + strategy := c.config.Consumer.Group.Rebalance.Strategy for memberID, topics := range plan { assignment := &ConsumerGroupMemberAssignment{Topics: topics} - - // Include topic assignments in group-assignment userdata for each consumer-group member - if c.config.Consumer.Group.Rebalance.Strategy.Name() == StickyBalanceStrategyName { - userDataBytes, err := encode(&StickyAssignorUserDataV1{ - Topics: topics, - Generation: generationID, - }, nil) - if err != nil { - return nil, err - } - assignment.UserData = userDataBytes + userDataBytes, err := strategy.AssignmentData(memberID, topics, generationID) + if err != nil { + return nil, err } + assignment.UserData = userDataBytes if err := req.AddGroupAssignmentMember(memberID, assignment); err != nil { return nil, err }