Skip to content

Commit

Permalink
Merge pull request #2248 from Shopify/dnwe/describe-groups-response
Browse files Browse the repository at this point in the history
fix(protocol): tidyup DescribeGroupsResponse
  • Loading branch information
dnwe authored Jun 7, 2022
2 parents 9bf344f + 41bea2e commit b2d1b0a
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 93 deletions.
156 changes: 90 additions & 66 deletions describe_groups_response.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
package sarama

type DescribeGroupsResponse struct {
Version int16
// Version defines the protocol version to use for encode and decode
Version int16
// ThrottleTimeMs contains the duration in milliseconds for which the
// request was throttled due to a quota violation, or zero if the request
// did not violate any quota.
ThrottleTimeMs int32
Groups []*GroupDescription
// Groups contains each described group.
Groups []*GroupDescription
}

func (r *DescribeGroupsResponse) encode(pe packetEncoder) error {
func (r *DescribeGroupsResponse) encode(pe packetEncoder) (err error) {
if r.Version >= 1 {
pe.putInt32(r.ThrottleTimeMs)
}
if err := pe.putArrayLength(len(r.Groups)); err != nil {
return err
}

for _, groupDescription := range r.Groups {
groupDescription.Version = r.Version
if err := groupDescription.encode(pe); err != nil {
for _, block := range r.Groups {
if err := block.encode(pe, r.Version); err != nil {
return err
}
}
Expand All @@ -31,17 +35,16 @@ func (r *DescribeGroupsResponse) decode(pd packetDecoder, version int16) (err er
return err
}
}
n, err := pd.getArrayLength()
if err != nil {
if numGroups, err := pd.getArrayLength(); err != nil {
return err
}

r.Groups = make([]*GroupDescription, n)
for i := 0; i < n; i++ {
r.Groups[i] = new(GroupDescription)
r.Groups[i].Version = r.Version
if err := r.Groups[i].decode(pd); err != nil {
return err
} else if numGroups > 0 {
r.Groups = make([]*GroupDescription, numGroups)
for i := 0; i < numGroups; i++ {
block := &GroupDescription{}
if err := block.decode(pd, r.Version); err != nil {
return err
}
r.Groups[i] = block
}
}

Expand All @@ -68,20 +71,32 @@ func (r *DescribeGroupsResponse) requiredVersion() KafkaVersion {
return V0_9_0_0
}

// GroupDescription contains each described group.
type GroupDescription struct {
// Version defines the protocol version to use for encode and decode
Version int16

Err KError
GroupId string
State string
ProtocolType string
Protocol string
Members map[string]*GroupMemberDescription
// Err contains the describe error as the KError type.
Err KError
// ErrorCode contains the describe error, or 0 if there was no error.
ErrorCode int16
// GroupId contains the group ID string.
GroupId string
// State contains the group state string, or the empty string.
State string
// ProtocolType contains the group protocol type, or the empty string.
ProtocolType string
// Protocol contains the group protocol data, or the empty string.
Protocol string
// Members contains the group members.
Members map[string]*GroupMemberDescription
// AuthorizedOperations contains a 32-bit bitfield to represent authorized
// operations for this group.
AuthorizedOperations int32
}

func (gd *GroupDescription) encode(pe packetEncoder) error {
pe.putInt16(int16(gd.Err))
func (gd *GroupDescription) encode(pe packetEncoder, version int16) (err error) {
gd.Version = version
pe.putInt16(gd.ErrorCode)

if err := pe.putString(gd.GroupId); err != nil {
return err
Expand All @@ -100,13 +115,8 @@ func (gd *GroupDescription) encode(pe packetEncoder) error {
return err
}

for memberId, groupMemberDescription := range gd.Members {
if err := pe.putString(memberId); err != nil {
return err
}
// encode with version
groupMemberDescription.Version = gd.Version
if err := groupMemberDescription.encode(pe); err != nil {
for _, block := range gd.Members {
if err := block.encode(pe, gd.Version); err != nil {
return err
}
}
Expand All @@ -118,44 +128,38 @@ func (gd *GroupDescription) encode(pe packetEncoder) error {
return nil
}

func (gd *GroupDescription) decode(pd packetDecoder) (err error) {
kerr, err := pd.getInt16()
if err != nil {
func (gd *GroupDescription) decode(pd packetDecoder, version int16) (err error) {
gd.Version = version
if gd.ErrorCode, err = pd.getInt16(); err != nil {
return err
}

gd.Err = KError(kerr)
gd.Err = KError(gd.ErrorCode)

if gd.GroupId, err = pd.getString(); err != nil {
return
return err
}
if gd.State, err = pd.getString(); err != nil {
return
return err
}
if gd.ProtocolType, err = pd.getString(); err != nil {
return
return err
}
if gd.Protocol, err = pd.getString(); err != nil {
return
}

n, err := pd.getArrayLength()
if err != nil {
return err
}

if n > 0 {
gd.Members = make(map[string]*GroupMemberDescription)
for i := 0; i < n; i++ {
memberId, err := pd.getString()
if err != nil {
if numMembers, err := pd.getArrayLength(); err != nil {
return err
} else if numMembers > 0 {
gd.Members = make(map[string]*GroupMemberDescription, numMembers)
for i := 0; i < numMembers; i++ {
block := &GroupMemberDescription{}
if err := block.decode(pd, gd.Version); err != nil {
return err
}

gd.Members[memberId] = new(GroupMemberDescription)
gd.Members[memberId].Version = gd.Version
if err := gd.Members[memberId].decode(pd); err != nil {
return err
if block != nil {
gd.Members[block.MemberId] = block
}
}
}
Expand All @@ -169,17 +173,33 @@ func (gd *GroupDescription) decode(pd packetDecoder) (err error) {
return nil
}

// GroupMemberDescription contains the group members.
type GroupMemberDescription struct {
// Version defines the protocol version to use for encode and decode
Version int16

GroupInstanceId *string
ClientId string
ClientHost string
MemberMetadata []byte
// MemberId contains the member ID assigned by the group coordinator.
MemberId string
// GroupInstanceId contains the unique identifier of the consumer instance
// provided by end user.
GroupInstanceId *string
// ClientId contains the client ID used in the member's latest join group
// request.
ClientId string
// ClientHost contains the client host.
ClientHost string
// MemberMetadata contains the metadata corresponding to the current group
// protocol in use.
MemberMetadata []byte
// MemberAssignment contains the current assignment provided by the group
// leader.
MemberAssignment []byte
}

func (gmd *GroupMemberDescription) encode(pe packetEncoder) error {
func (gmd *GroupMemberDescription) encode(pe packetEncoder, version int16) (err error) {
gmd.Version = version
if err := pe.putString(gmd.MemberId); err != nil {
return err
}
if gmd.Version >= 4 {
if err := pe.putNullableString(gmd.GroupInstanceId); err != nil {
return err
Expand All @@ -201,23 +221,27 @@ func (gmd *GroupMemberDescription) encode(pe packetEncoder) error {
return nil
}

func (gmd *GroupMemberDescription) decode(pd packetDecoder) (err error) {
func (gmd *GroupMemberDescription) decode(pd packetDecoder, version int16) (err error) {
gmd.Version = version
if gmd.MemberId, err = pd.getString(); err != nil {
return err
}
if gmd.Version >= 4 {
if gmd.GroupInstanceId, err = pd.getNullableString(); err != nil {
return
return err
}
}
if gmd.ClientId, err = pd.getString(); err != nil {
return
return err
}
if gmd.ClientHost, err = pd.getString(); err != nil {
return
return err
}
if gmd.MemberMetadata, err = pd.getBytes(); err != nil {
return
return err
}
if gmd.MemberAssignment, err = pd.getBytes(); err != nil {
return
return err
}

return nil
Expand Down
48 changes: 21 additions & 27 deletions describe_groups_response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"errors"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
)

var (
Expand Down Expand Up @@ -161,7 +163,7 @@ var (
func TestDescribeGroupsResponseV1plus(t *testing.T) {
groupInstanceId := "gid"
tests := []struct {
CaseName string
Name string
Version int16
MessageBytes []byte
Message *DescribeGroupsResponse
Expand All @@ -171,9 +173,7 @@ func TestDescribeGroupsResponseV1plus(t *testing.T) {
3,
describeGroupsResponseEmptyV3,
&DescribeGroupsResponse{
Version: 3,
ThrottleTimeMs: int32(0),
Groups: []*GroupDescription{},
Version: 3,
},
},
{
Expand All @@ -194,6 +194,7 @@ func TestDescribeGroupsResponseV1plus(t *testing.T) {
Members: map[string]*GroupMemberDescription{
"id": {
Version: 3,
MemberId: "id",
ClientId: "sarama",
ClientHost: "localhost",
MemberMetadata: []byte{1, 2, 3},
Expand All @@ -202,13 +203,9 @@ func TestDescribeGroupsResponseV1plus(t *testing.T) {
},
},
{
Version: 3,
Err: KError(30),
GroupId: "",
State: "",
ProtocolType: "",
Protocol: "",
Members: nil,
Version: 3,
Err: KError(30),
ErrorCode: 30,
},
},
},
Expand All @@ -218,9 +215,7 @@ func TestDescribeGroupsResponseV1plus(t *testing.T) {
4,
describeGroupsResponseEmptyV4,
&DescribeGroupsResponse{
Version: 4,
ThrottleTimeMs: int32(0),
Groups: []*GroupDescription{},
Version: 4,
},
},
{
Expand All @@ -241,6 +236,7 @@ func TestDescribeGroupsResponseV1plus(t *testing.T) {
Members: map[string]*GroupMemberDescription{
"id": {
Version: 4,
MemberId: "id",
GroupInstanceId: &groupInstanceId,
ClientId: "sarama",
ClientHost: "localhost",
Expand All @@ -250,25 +246,23 @@ func TestDescribeGroupsResponseV1plus(t *testing.T) {
},
},
{
Version: 4,
Err: KError(30),
GroupId: "",
State: "",
ProtocolType: "",
Protocol: "",
Members: nil,
Version: 4,
Err: KError(30),
ErrorCode: 30,
},
},
},
},
}

for _, c := range tests {
response := new(DescribeGroupsResponse)
testVersionDecodable(t, c.CaseName, response, c.MessageBytes, c.Version)
if !reflect.DeepEqual(c.Message, response) {
t.Errorf("case %s decode failed, expected:%+v got %+v", c.CaseName, c.Message, response)
}
testEncodable(t, c.CaseName, c.Message, c.MessageBytes)
t.Run(c.Name, func(t *testing.T) {
response := new(DescribeGroupsResponse)
testVersionDecodable(t, c.Name, response, c.MessageBytes, c.Version)
if !assert.Equal(t, c.Message, response) {
t.Errorf("case %s decode failed, expected:%+v got %+v", c.Name, c.Message, response)
}
testEncodable(t, c.Name, c.Message, c.MessageBytes)
})
}
}

0 comments on commit b2d1b0a

Please sign in to comment.