diff --git a/admin_test.go b/admin_test.go index 677e20a0e8..b729d6ee9b 100644 --- a/admin_test.go +++ b/admin_test.go @@ -1130,6 +1130,38 @@ func TestClusterAdminCreateAcls(t *testing.T) { } } +func TestClusterAdminCreateAclErrorHandling(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + defer seedBroker.Close() + + seedBroker.SetHandlerByMap(map[string]MockResponse{ + "MetadataRequest": NewMockMetadataResponse(t). + SetController(seedBroker.BrokerID()). + SetBroker(seedBroker.Addr(), seedBroker.BrokerID()), + "CreateAclsRequest": NewMockCreateAclsResponseWithError(t), + }) + + config := NewTestConfig() + config.Version = V1_0_0_0 + admin, err := NewClusterAdmin([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + r := Resource{ResourceType: AclResourceTopic, ResourceName: "my_topic"} + a := Acl{Host: "localhost", Operation: AclOperationAlter, PermissionType: AclPermissionAny} + + err = admin.CreateACL(r, a) + if err == nil { + t.Fatal(errors.New("error should have been thrown")) + } + + err = admin.Close() + if err != nil { + t.Fatal(err) + } +} + func TestClusterAdminListAcls(t *testing.T) { seedBroker := NewMockBroker(t, 1) defer seedBroker.Close() diff --git a/broker.go b/broker.go index 2d75a8cd3d..e481ad711c 100644 --- a/broker.go +++ b/broker.go @@ -666,6 +666,17 @@ func (b *Broker) CreateAcls(request *CreateAclsRequest) (*CreateAclsResponse, er return nil, err } + errs := make([]error, 0) + for _, res := range response.AclCreationResponses { + if !errors.Is(res.Err, ErrNoError) { + errs = append(errs, res.Err) + } + } + + if len(errs) > 0 { + return response, Wrap(ErrCreateACLs, errs...) + } + return response, nil } diff --git a/errors.go b/errors.go index 507002bfa6..bb64090427 100644 --- a/errors.go +++ b/errors.go @@ -64,6 +64,9 @@ var ErrReassignPartitions = errors.New("failed to reassign partitions for topic" // ErrDeleteRecords is the type of error returned when fail to delete the required records var ErrDeleteRecords = errors.New("kafka server: failed to delete records") +// ErrCreateACLs is the type of error returned when ACL creation failed +var ErrCreateACLs = errors.New("kafka server: failed to create one or more ACL rules") + // MultiErrorFormat specifies the formatter applied to format multierrors. The // default implementation is a consensed version of the hashicorp/go-multierror // default one diff --git a/message.go b/message.go index fd0d1d90b7..c6f35a3f5e 100644 --- a/message.go +++ b/message.go @@ -42,6 +42,28 @@ func (cc CompressionCodec) String() string { }[int(cc)] } +// UnmarshalText returns a CompressionCodec from its string representation. +func (cc *CompressionCodec) UnmarshalText(text []byte) error { + codecs := map[string]CompressionCodec{ + "none": CompressionNone, + "gzip": CompressionGZIP, + "snappy": CompressionSnappy, + "lz4": CompressionLZ4, + "zstd": CompressionZSTD, + } + codec, ok := codecs[string(text)] + if !ok { + return fmt.Errorf("cannot parse %q as a compression codec", string(text)) + } + *cc = codec + return nil +} + +// MarshalText transforms a CompressionCodec into its string representation. +func (cc CompressionCodec) MarshalText() ([]byte, error) { + return []byte(cc.String()), nil +} + // Message is a kafka message type type Message struct { Codec CompressionCodec // codec used to compress the message contents diff --git a/message_test.go b/message_test.go index a6c7cff2a5..d7bd430d38 100644 --- a/message_test.go +++ b/message_test.go @@ -244,3 +244,32 @@ func TestMessageDecodingUnknownVersions(t *testing.T) { t.Error("Decoding an unknown magic byte produced an unknown error ", err) } } + +func TestCompressionCodecUnmarshal(t *testing.T) { + cases := []struct { + Input string + Expected CompressionCodec + ExpectedError bool + }{ + {"none", CompressionNone, false}, + {"zstd", CompressionZSTD, false}, + {"gzip", CompressionGZIP, false}, + {"unknown", CompressionNone, true}, + } + for _, c := range cases { + var cc CompressionCodec + err := cc.UnmarshalText([]byte(c.Input)) + if err != nil && !c.ExpectedError { + t.Errorf("UnmarshalText(%q) error:\n%+v", c.Input, err) + continue + } + if err == nil && c.ExpectedError { + t.Errorf("UnmarshalText(%q) got %v but expected error", c.Input, cc) + continue + } + if cc != c.Expected { + t.Errorf("UnmarshalText(%q) got %v but expected %v", c.Input, cc, c.Expected) + continue + } + } +} diff --git a/mockresponses.go b/mockresponses.go index da816963a5..d26a448879 100644 --- a/mockresponses.go +++ b/mockresponses.go @@ -998,6 +998,24 @@ func (mr *MockCreateAclsResponse) For(reqBody versionedDecoder) encoderWithHeade return res } +type MockCreateAclsResponseError struct { + t TestReporter +} + +func NewMockCreateAclsResponseWithError(t TestReporter) *MockCreateAclsResponseError { + return &MockCreateAclsResponseError{t: t} +} + +func (mr *MockCreateAclsResponseError) For(reqBody versionedDecoder) encoderWithHeader { + req := reqBody.(*CreateAclsRequest) + res := &CreateAclsResponse{} + + for range req.AclCreations { + res.AclCreationResponses = append(res.AclCreationResponses, &AclCreationResponse{Err: ErrInvalidRequest}) + } + return res +} + type MockListAclsResponse struct { t TestReporter }