diff --git a/consumer_group.go b/consumer_group.go index 230c5cefe..b603d1705 100644 --- a/consumer_group.go +++ b/consumer_group.go @@ -114,6 +114,7 @@ func newConsumerGroup(groupID string, client Client) (ConsumerGroup, error) { groupID: groupID, errors: make(chan error, config.ChannelBufferSize), closed: make(chan none), + userData: config.Consumer.Group.Member.UserData, }, nil } @@ -329,7 +330,15 @@ func (c *consumerGroup) newSession(ctx context.Context, topics []string, handler return nil, err } claims = members.Topics - c.userData = members.UserData + + // in the case of stateful balance strategies, hold on to the returned + // assignment metadata, otherwise, reset the statically defined conusmer + // group metadata + if members.UserData != nil { + c.userData = members.UserData + } else { + c.userData = c.config.Consumer.Group.Member.UserData + } for _, partitions := range claims { sort.Sort(int32Slice(partitions)) @@ -351,14 +360,9 @@ func (c *consumerGroup) joinGroupRequest(coordinator *Broker, topics []string) ( req.RebalanceTimeout = int32(c.config.Consumer.Group.Rebalance.Timeout / time.Millisecond) } - // use static user-data if configured, otherwise use consumer-group userdata from the last sync - userData := c.config.Consumer.Group.Member.UserData - if len(userData) == 0 { - userData = c.userData - } meta := &ConsumerGroupMemberMetadata{ Topics: topics, - UserData: userData, + UserData: c.userData, } strategy := c.config.Consumer.Group.Rebalance.Strategy if err := req.AddGroupProtocolMetadata(strategy.Name(), meta); err != nil { diff --git a/functional_consumer_group_test.go b/functional_consumer_group_test.go index 581b498a7..13ed2e935 100644 --- a/functional_consumer_group_test.go +++ b/functional_consumer_group_test.go @@ -8,6 +8,7 @@ import ( "fmt" "log" "reflect" + "strings" "sync" "sync/atomic" "testing" @@ -48,6 +49,51 @@ func TestFuncConsumerGroupPartitioning(t *testing.T) { m2.AssertCleanShutdown() } +func TestFuncConsumerGroupPartitioningStateful(t *testing.T) { + checkKafkaVersion(t, "0.10.2") + setupFunctionalTest(t) + defer teardownFunctionalTest(t) + + groupID := testFuncConsumerGroupID(t) + + m1s := newTestStatefulStrategy(t) + config := defaultConfig("M1") + config.Consumer.Group.Rebalance.Strategy = m1s + config.Consumer.Group.Member.UserData = []byte(config.ClientID) + + // start M1 + m1 := runTestFuncConsumerGroupMemberWithConfig(t, config, groupID, 0, nil) + defer m1.Stop() + m1.WaitForState(2) + m1.WaitForClaims(map[string]int{"test.4": 4}) + m1.WaitForHandlers(4) + m1s.AssertInitialValues(1) + + m2s := newTestStatefulStrategy(t) + config = defaultConfig("M2") + config.Consumer.Group.Rebalance.Strategy = m2s + config.Consumer.Group.Member.UserData = []byte(config.ClientID) + + // start M2 + m2 := runTestFuncConsumerGroupMemberWithConfig(t, config, groupID, 0, nil, "test.1", "test.4") + defer m2.Stop() + m2.WaitForState(2) + m1s.AssertInitialValues(2) + m2s.AssertNoInitialValues() + + // assert that claims are shared among both members + m1.WaitForClaims(map[string]int{"test.4": 2}) + m1.WaitForHandlers(2) + m2.WaitForClaims(map[string]int{"test.1": 1, "test.4": 2}) + m2.WaitForHandlers(3) + + // shutdown M1, wait for M2 to take over + m1.AssertCleanShutdown() + m2.WaitForClaims(map[string]int{"test.1": 1, "test.4": 4}) + m2.WaitForHandlers(5) + m2s.AssertNoInitialValues() +} + func TestFuncConsumerGroupExcessConsumers(t *testing.T) { checkKafkaVersion(t, "0.10.2") setupFunctionalTest(t) @@ -305,15 +351,25 @@ type testFuncConsumerGroupMember struct { mu sync.RWMutex } -func runTestFuncConsumerGroupMember(t *testing.T, groupID, clientID string, maxMessages int32, sink *testFuncConsumerGroupSink, topics ...string) *testFuncConsumerGroupMember { - t.Helper() - - config := NewTestConfig() +func defaultConfig(clientID string) *Config { + config := NewConfig() config.ClientID = clientID config.Version = V0_10_2_0 config.Consumer.Return.Errors = true config.Consumer.Offsets.Initial = OffsetOldest config.Consumer.Group.Rebalance.Timeout = 10 * time.Second + return config +} + +func runTestFuncConsumerGroupMember(t *testing.T, groupID, clientID string, maxMessages int32, sink *testFuncConsumerGroupSink, topics ...string) *testFuncConsumerGroupMember { + t.Helper() + + config := defaultConfig(clientID) + return runTestFuncConsumerGroupMemberWithConfig(t, config, groupID, maxMessages, sink, topics...) +} + +func runTestFuncConsumerGroupMemberWithConfig(t *testing.T, config *Config, groupID string, maxMessages int32, sink *testFuncConsumerGroupSink, topics ...string) *testFuncConsumerGroupMember { + t.Helper() group, err := NewConsumerGroup(FunctionalTestEnv.KafkaBrokerAddrs, groupID, config) if err != nil { @@ -327,7 +383,7 @@ func runTestFuncConsumerGroupMember(t *testing.T, groupID, clientID string, maxM member := &testFuncConsumerGroupMember{ ConsumerGroup: group, - clientID: clientID, + clientID: config.ClientID, claims: make(map[string]int), maxMessages: maxMessages, isCapped: maxMessages != 0, @@ -488,3 +544,53 @@ func (m *testFuncConsumerGroupMember) loop(topics []string) { } } } + +func newTestStatefulStrategy(t *testing.T) *testStatefulStrategy { + return &testStatefulStrategy{ + BalanceStrategy: BalanceStrategyRange, + t: t, + } +} + +type testStatefulStrategy struct { + BalanceStrategy + t *testing.T + initial int32 + state sync.Map +} + +func (h *testStatefulStrategy) Name() string { + return "TestStatefulStrategy" +} + +func (h *testStatefulStrategy) Plan(members map[string]ConsumerGroupMemberMetadata, topics map[string][]int32) (BalanceStrategyPlan, error) { + h.state = sync.Map{} + for memberID, metadata := range members { + if !strings.HasSuffix(string(metadata.UserData), "-stateful") { + metadata.UserData = []byte(string(metadata.UserData) + "-stateful") + atomic.AddInt32(&h.initial, 1) + } + h.state.Store(memberID, metadata.UserData) + } + return h.BalanceStrategy.Plan(members, topics) +} + +func (h *testStatefulStrategy) AssignmentData(memberID string, topics map[string][]int32, generationID int32) ([]byte, error) { + if obj, ok := h.state.Load(memberID); ok { + return obj.([]byte), nil + } + return nil, nil +} + +func (h *testStatefulStrategy) AssertInitialValues(count int32) { + h.t.Helper() + actual := atomic.LoadInt32(&h.initial) + if actual != count { + h.t.Fatalf("unexpected count of initial values: %d, expected: %d", actual, count) + } +} + +func (h *testStatefulStrategy) AssertNoInitialValues() { + h.t.Helper() + h.AssertInitialValues(0) +}