diff --git a/admin.go b/admin.go index 43bf90b2c..618981f5e 100644 --- a/admin.go +++ b/admin.go @@ -104,6 +104,15 @@ type ClusterAdmin interface { // Get information about all log directories on the given set of brokers DescribeLogDirs(brokers []int32) (map[int32][]DescribeLogDirsResponseDirMetadata, error) + // Get information about SCRAM users + DescribeUserScramCredentials(users []string) ([]*DescribeUserScramCredentialsResult, error) + + // Delete SCRAM users + DeleteUserScramCredentials(delete []AlterUserScramCredentialsDelete) ([]*AlterUserScramCredentialsResult, error) + + // Upsert SCRAM users + UpsertUserScramCredentials(upsert []AlterUserScramCredentialsUpsert) ([]*AlterUserScramCredentialsResult, error) + // Close shuts down the admin and closes underlying client. Close() error } @@ -936,3 +945,61 @@ func (ca *clusterAdmin) DescribeLogDirs(brokerIds []int32) (allLogDirs map[int32 err = <-errChan return } + +func (ca *clusterAdmin) DescribeUserScramCredentials(users []string) ([]*DescribeUserScramCredentialsResult, error) { + req := &DescribeUserScramCredentialsRequest{} + for _, u := range users { + req.DescribeUsers = append(req.DescribeUsers, DescribeUserScramCredentialsRequestUser{ + Name: u, + }) + } + + b, err := ca.Controller() + if err != nil { + return nil, err + } + + rsp, err := b.DescribeUserScramCredentials(req) + if err != nil { + return nil, err + } + + return rsp.Results, nil +} + +func (ca *clusterAdmin) UpsertUserScramCredentials(upsert []AlterUserScramCredentialsUpsert) ([]*AlterUserScramCredentialsResult, error) { + res, err := ca.AlterUserScramCredentials(upsert, nil) + if err != nil { + return nil, err + } + + return res, nil +} + +func (ca *clusterAdmin) DeleteUserScramCredentials(delete []AlterUserScramCredentialsDelete) ([]*AlterUserScramCredentialsResult, error) { + res, err := ca.AlterUserScramCredentials(nil, delete) + if err != nil { + return nil, err + } + + return res, nil +} + +func (ca *clusterAdmin) AlterUserScramCredentials(u []AlterUserScramCredentialsUpsert, d []AlterUserScramCredentialsDelete) ([]*AlterUserScramCredentialsResult, error) { + req := &AlterUserScramCredentialsRequest{ + Deletions: d, + Upsertions: u, + } + + b, err := ca.Controller() + if err != nil { + return nil, err + } + + rsp, err := b.AlterUserScramCredentials(req) + if err != nil { + return nil, err + } + + return rsp.Results, nil +} diff --git a/alter_user_scram_credentials_request.go b/alter_user_scram_credentials_request.go new file mode 100644 index 000000000..0530d8946 --- /dev/null +++ b/alter_user_scram_credentials_request.go @@ -0,0 +1,142 @@ +package sarama + +type AlterUserScramCredentialsRequest struct { + Version int16 + + // Deletions represent list of SCRAM credentials to remove + Deletions []AlterUserScramCredentialsDelete + + // Upsertions represent list of SCRAM credentials to update/insert + Upsertions []AlterUserScramCredentialsUpsert +} + +type AlterUserScramCredentialsDelete struct { + Name string + Mechanism ScramMechanismType +} + +type AlterUserScramCredentialsUpsert struct { + Name string + Mechanism ScramMechanismType + Iterations int32 + Salt []byte + saltedPassword []byte + + // This field is never transmitted over the wire + // @see: https://tools.ietf.org/html/rfc5802 + Password []byte +} + +func (r *AlterUserScramCredentialsRequest) encode(pe packetEncoder) error { + pe.putCompactArrayLength(len(r.Deletions)) + for _, d := range r.Deletions { + if err := pe.putCompactString(d.Name); err != nil { + return err + } + pe.putInt8(int8(d.Mechanism)) + pe.putEmptyTaggedFieldArray() + } + + pe.putCompactArrayLength(len(r.Upsertions)) + for _, u := range r.Upsertions { + if err := pe.putCompactString(u.Name); err != nil { + return err + } + pe.putInt8(int8(u.Mechanism)) + pe.putInt32(u.Iterations) + + if err := pe.putCompactBytes(u.Salt); err != nil { + return err + } + + // do not transmit the password over the wire + formatter := scramFormatter{mechanism: u.Mechanism} + salted, err := formatter.saltedPassword(u.Password, u.Salt, int(u.Iterations)) + if err != nil { + return err + } + + if err := pe.putCompactBytes(salted); err != nil { + return err + } + pe.putEmptyTaggedFieldArray() + } + + pe.putEmptyTaggedFieldArray() + return nil +} + +func (r *AlterUserScramCredentialsRequest) decode(pd packetDecoder, version int16) error { + numDeletions, err := pd.getCompactArrayLength() + if err != nil { + return err + } + + r.Deletions = make([]AlterUserScramCredentialsDelete, numDeletions) + for i := 0; i < numDeletions; i++ { + r.Deletions[i] = AlterUserScramCredentialsDelete{} + if r.Deletions[i].Name, err = pd.getCompactString(); err != nil { + return err + } + mechanism, err := pd.getInt8() + if err != nil { + return err + } + r.Deletions[i].Mechanism = ScramMechanismType(mechanism) + if _, err = pd.getEmptyTaggedFieldArray(); err != nil { + return err + } + } + + numUpsertions, err := pd.getCompactArrayLength() + if err != nil { + return err + } + + r.Upsertions = make([]AlterUserScramCredentialsUpsert, numUpsertions) + for i := 0; i < numUpsertions; i++ { + r.Upsertions[i] = AlterUserScramCredentialsUpsert{} + if r.Upsertions[i].Name, err = pd.getCompactString(); err != nil { + return err + } + mechanism, err := pd.getInt8() + if err != nil { + return err + } + + r.Upsertions[i].Mechanism = ScramMechanismType(mechanism) + if r.Upsertions[i].Iterations, err = pd.getInt32(); err != nil { + return err + } + if r.Upsertions[i].Salt, err = pd.getCompactBytes(); err != nil { + return err + } + if r.Upsertions[i].saltedPassword, err = pd.getCompactBytes(); err != nil { + return err + } + if _, err = pd.getEmptyTaggedFieldArray(); err != nil { + return err + } + } + + if _, err = pd.getEmptyTaggedFieldArray(); err != nil { + return err + } + return nil +} + +func (r *AlterUserScramCredentialsRequest) key() int16 { + return 51 +} + +func (r *AlterUserScramCredentialsRequest) version() int16 { + return r.Version +} + +func (r *AlterUserScramCredentialsRequest) headerVersion() int16 { + return 2 +} + +func (r *AlterUserScramCredentialsRequest) requiredVersion() KafkaVersion { + return V2_7_0_0 +} diff --git a/alter_user_scram_credentials_request_test.go b/alter_user_scram_credentials_request_test.go new file mode 100644 index 000000000..6fe881906 --- /dev/null +++ b/alter_user_scram_credentials_request_test.go @@ -0,0 +1,60 @@ +package sarama + +import "testing" + +var ( + emptyAlterUserScramCredentialsRequest = []byte{ + 1, // Deletions + 1, // Upsertions + 0, // empty tagged fields + } + userAlterUserScramCredentialsRequest = []byte{ + 2, // Deletions array, length 1 + 7, // User name length 6 + 'd', 'e', 'l', 'e', 't', 'e', // User name + 2, // SCRAM_SHA_512 + 0, // empty tagged fields + 2, // Upsertions array, length 1 + 7, // User name length 6 + 'u', 'p', 's', 'e', 'r', 't', + 1, // SCRAM_SHA_256 + 0, 0, 16, 0, // iterations: 4096 + // salt bytes: + 6, 119, 111, 114, 108, 100, + // saltedPassword: + 33, 193, 85, 83, 3, 218, 48, 159, 107, 125, 30, 143, + 228, 86, 54, 191, 221, 220, 75, 245, 100, 5, 231, + 233, 78, 157, 21, 240, 231, 185, 203, 211, 128, + 0, // empty tagged fields + 0, // empty tagged fields + } +) + +func TestAlterUserScramCredentialsRequest(t *testing.T) { + request := &AlterUserScramCredentialsRequest{ + Version: 0, + Deletions: []AlterUserScramCredentialsDelete{}, + Upsertions: []AlterUserScramCredentialsUpsert{}, + } + + // Password is not transmitted, will fail with `testRequest` and `DeepEqual` check + testRequestEncode(t, "no upsertions/deletions", request, emptyAlterUserScramCredentialsRequest) + + request.Deletions = []AlterUserScramCredentialsDelete{ + { + Name: "delete", + Mechanism: SCRAM_MECHANISM_SHA_512, + }, + } + request.Upsertions = []AlterUserScramCredentialsUpsert{ + { + Name: "upsert", + Mechanism: SCRAM_MECHANISM_SHA_256, + Iterations: 4096, + Salt: []byte("world"), + Password: []byte("hello"), + }, + } + // Password is not transmitted, will fail with `testRequest` and `DeepEqual` check + testRequestEncode(t, "single deletion and upsertion", request, userAlterUserScramCredentialsRequest) +} diff --git a/alter_user_scram_credentials_response.go b/alter_user_scram_credentials_response.go new file mode 100644 index 000000000..31e167b5e --- /dev/null +++ b/alter_user_scram_credentials_response.go @@ -0,0 +1,94 @@ +package sarama + +import "time" + +type AlterUserScramCredentialsResponse struct { + Version int16 + + ThrottleTime time.Duration + + Results []*AlterUserScramCredentialsResult +} + +type AlterUserScramCredentialsResult struct { + User string + + ErrorCode KError + ErrorMessage *string +} + +func (r *AlterUserScramCredentialsResponse) encode(pe packetEncoder) error { + pe.putInt32(int32(r.ThrottleTime / time.Millisecond)) + pe.putCompactArrayLength(len(r.Results)) + + for _, u := range r.Results { + if err := pe.putCompactString(u.User); err != nil { + return err + } + pe.putInt16(int16(u.ErrorCode)) + if err := pe.putNullableCompactString(u.ErrorMessage); err != nil { + return err + } + pe.putEmptyTaggedFieldArray() + } + + pe.putEmptyTaggedFieldArray() + return nil +} + +func (r *AlterUserScramCredentialsResponse) decode(pd packetDecoder, version int16) error { + throttleTime, err := pd.getInt32() + if err != nil { + return err + } + r.ThrottleTime = time.Duration(throttleTime) * time.Millisecond + + numResults, err := pd.getCompactArrayLength() + if err != nil { + return err + } + + if numResults > 0 { + r.Results = make([]*AlterUserScramCredentialsResult, numResults) + for i := 0; i < numResults; i++ { + r.Results[i] = &AlterUserScramCredentialsResult{} + if r.Results[i].User, err = pd.getCompactString(); err != nil { + return err + } + + kerr, err := pd.getInt16() + if err != nil { + return err + } + + r.Results[i].ErrorCode = KError(kerr) + if r.Results[i].ErrorMessage, err = pd.getCompactNullableString(); err != nil { + return err + } + if _, err := pd.getEmptyTaggedFieldArray(); err != nil { + return err + } + } + } + + if _, err := pd.getEmptyTaggedFieldArray(); err != nil { + return err + } + return nil +} + +func (r *AlterUserScramCredentialsResponse) key() int16 { + return 51 +} + +func (r *AlterUserScramCredentialsResponse) version() int16 { + return r.Version +} + +func (r *AlterUserScramCredentialsResponse) headerVersion() int16 { + return 2 +} + +func (r *AlterUserScramCredentialsResponse) requiredVersion() KafkaVersion { + return V2_7_0_0 +} diff --git a/alter_user_scram_credentials_response_test.go b/alter_user_scram_credentials_response_test.go new file mode 100644 index 000000000..983500639 --- /dev/null +++ b/alter_user_scram_credentials_response_test.go @@ -0,0 +1,39 @@ +package sarama + +import ( + "testing" + "time" +) + +var ( + emptyAlterUserScramCredentialsResponse = []byte{ + 0, 0, 11, 184, // throttle time + 1, // empty results array + 0, // empty tagged fields + } + userAlterUserScramCredentialsResponse = []byte{ + 0, 0, 11, 184, // throttle time + 2, // results array length + 7, 'n', 'o', 'b', 'o', 'd', 'y', // User + 0, 11, // ErrorCode + 6, 'e', 'r', 'r', 'o', 'r', // ErrorMessage + 0, // empty tagged fields + 0, // empty tagged fields + } +) + +func TestAlterUserScramCredentialsResponse(t *testing.T) { + response := &AlterUserScramCredentialsResponse{ + Version: 0, + ThrottleTime: time.Second * 3, + } + testResponse(t, "empty response", response, emptyAlterUserScramCredentialsResponse) + + resultErrorMessage := "error" + response.Results = append(response.Results, &AlterUserScramCredentialsResult{ + User: "nobody", + ErrorCode: 11, + ErrorMessage: &resultErrorMessage, + }) + testResponse(t, "single user response", response, userAlterUserScramCredentialsResponse) +} diff --git a/broker.go b/broker.go index 4fc425f3a..0b3ea969c 100644 --- a/broker.go +++ b/broker.go @@ -693,6 +693,29 @@ func (b *Broker) DescribeLogDirs(request *DescribeLogDirsRequest) (*DescribeLogD return response, nil } +// DescribeUserScramCredentials sends a request to get SCRAM users +func (b *Broker) DescribeUserScramCredentials(req *DescribeUserScramCredentialsRequest) (*DescribeUserScramCredentialsResponse, error) { + res := new(DescribeUserScramCredentialsResponse) + + err := b.sendAndReceive(req, res) + if err != nil { + return nil, err + } + + return res, err +} + +func (b *Broker) AlterUserScramCredentials(req *AlterUserScramCredentialsRequest) (*AlterUserScramCredentialsResponse, error) { + res := new(AlterUserScramCredentialsResponse) + + err := b.sendAndReceive(req, res) + if err != nil { + return nil, err + } + + return res, nil +} + // readFull ensures the conn ReadDeadline has been setup before making a // call to io.ReadFull func (b *Broker) readFull(buf []byte) (n int, err error) { diff --git a/describe_user_scram_credentials_request.go b/describe_user_scram_credentials_request.go new file mode 100644 index 000000000..b5b59404b --- /dev/null +++ b/describe_user_scram_credentials_request.go @@ -0,0 +1,70 @@ +package sarama + +// DescribeUserScramCredentialsRequest is a request to get list of SCRAM user names +type DescribeUserScramCredentialsRequest struct { + // Version 0 is currently only supported + Version int16 + + // If this is an empty array, all users will be queried + DescribeUsers []DescribeUserScramCredentialsRequestUser +} + +// DescribeUserScramCredentialsRequestUser is a describe request about specific user name +type DescribeUserScramCredentialsRequestUser struct { + Name string +} + +func (r *DescribeUserScramCredentialsRequest) encode(pe packetEncoder) error { + pe.putCompactArrayLength(len(r.DescribeUsers)) + for _, d := range r.DescribeUsers { + if err := pe.putCompactString(d.Name); err != nil { + return err + } + pe.putEmptyTaggedFieldArray() + } + + pe.putEmptyTaggedFieldArray() + return nil +} + +func (r *DescribeUserScramCredentialsRequest) decode(pd packetDecoder, version int16) error { + n, err := pd.getCompactArrayLength() + if err != nil { + return err + } + if n == -1 { + n = 0 + } + + r.DescribeUsers = make([]DescribeUserScramCredentialsRequestUser, n) + for i := 0; i < n; i++ { + r.DescribeUsers[i] = DescribeUserScramCredentialsRequestUser{} + if r.DescribeUsers[i].Name, err = pd.getCompactString(); err != nil { + return err + } + if _, err = pd.getEmptyTaggedFieldArray(); err != nil { + return err + } + } + + if _, err = pd.getEmptyTaggedFieldArray(); err != nil { + return err + } + return nil +} + +func (r *DescribeUserScramCredentialsRequest) key() int16 { + return 50 +} + +func (r *DescribeUserScramCredentialsRequest) version() int16 { + return r.Version +} + +func (r *DescribeUserScramCredentialsRequest) headerVersion() int16 { + return 2 +} + +func (r *DescribeUserScramCredentialsRequest) requiredVersion() KafkaVersion { + return V2_7_0_0 +} diff --git a/describe_user_scram_credentials_request_test.go b/describe_user_scram_credentials_request_test.go new file mode 100644 index 000000000..87e52bab6 --- /dev/null +++ b/describe_user_scram_credentials_request_test.go @@ -0,0 +1,30 @@ +package sarama + +import "testing" + +var ( + emptyDescribeUserScramCredentialsRequest = []byte{ + 1, 0, // empty tagged fields + } + userDescribeUserScramCredentialsRequest = []byte{ + 2, // DescribeUsers array, Array length 1 + 7, // User name length 6 + 'r', 'a', 'n', 'd', 'o', 'm', // User name + 0, 0, // empty tagged fields + } +) + +func TestDescribeUserScramCredentialsRequest(t *testing.T) { + request := &DescribeUserScramCredentialsRequest{ + Version: 0, + DescribeUsers: []DescribeUserScramCredentialsRequestUser{}, + } + testRequest(t, "no users", request, emptyDescribeUserScramCredentialsRequest) + + request.DescribeUsers = []DescribeUserScramCredentialsRequestUser{ + { + Name: "random", + }, + } + testRequest(t, "single user", request, userDescribeUserScramCredentialsRequest) +} diff --git a/describe_user_scram_credentials_response.go b/describe_user_scram_credentials_response.go new file mode 100644 index 000000000..2656c2faa --- /dev/null +++ b/describe_user_scram_credentials_response.go @@ -0,0 +1,168 @@ +package sarama + +import "time" + +type ScramMechanismType int8 + +const ( + SCRAM_MECHANISM_UNKNOWN ScramMechanismType = iota // 0 + SCRAM_MECHANISM_SHA_256 // 1 + SCRAM_MECHANISM_SHA_512 // 2 +) + +func (s ScramMechanismType) String() string { + switch s { + case 1: + return SASLTypeSCRAMSHA256 + case 2: + return SASLTypeSCRAMSHA512 + default: + return "Unknown" + } +} + +type DescribeUserScramCredentialsResponse struct { + // Version 0 is currently only supported + Version int16 + + ThrottleTime time.Duration + + ErrorCode KError + ErrorMessage *string + + Results []*DescribeUserScramCredentialsResult +} + +type DescribeUserScramCredentialsResult struct { + User string + + ErrorCode KError + ErrorMessage *string + + CredentialInfos []*UserScramCredentialsResponseInfo +} + +type UserScramCredentialsResponseInfo struct { + Mechanism ScramMechanismType + Iterations int32 +} + +func (r *DescribeUserScramCredentialsResponse) encode(pe packetEncoder) error { + pe.putInt32(int32(r.ThrottleTime / time.Millisecond)) + + pe.putInt16(int16(r.ErrorCode)) + if err := pe.putNullableCompactString(r.ErrorMessage); err != nil { + return err + } + + pe.putCompactArrayLength(len(r.Results)) + for _, u := range r.Results { + if err := pe.putCompactString(u.User); err != nil { + return err + } + pe.putInt16(int16(u.ErrorCode)) + if err := pe.putNullableCompactString(u.ErrorMessage); err != nil { + return err + } + + pe.putCompactArrayLength(len(u.CredentialInfos)) + for _, c := range u.CredentialInfos { + pe.putInt8(int8(c.Mechanism)) + pe.putInt32(c.Iterations) + pe.putEmptyTaggedFieldArray() + } + + pe.putEmptyTaggedFieldArray() + } + + pe.putEmptyTaggedFieldArray() + return nil +} + +func (r *DescribeUserScramCredentialsResponse) decode(pd packetDecoder, version int16) error { + throttleTime, err := pd.getInt32() + if err != nil { + return err + } + r.ThrottleTime = time.Duration(throttleTime) * time.Millisecond + + kerr, err := pd.getInt16() + if err != nil { + return err + } + + r.ErrorCode = KError(kerr) + if r.ErrorMessage, err = pd.getCompactNullableString(); err != nil { + return err + } + + numUsers, err := pd.getCompactArrayLength() + if err != nil { + return err + } + + if numUsers > 0 { + r.Results = make([]*DescribeUserScramCredentialsResult, numUsers) + for i := 0; i < numUsers; i++ { + r.Results[i] = &DescribeUserScramCredentialsResult{} + if r.Results[i].User, err = pd.getCompactString(); err != nil { + return err + } + + errorCode, err := pd.getInt16() + if err != nil { + return err + } + r.Results[i].ErrorCode = KError(errorCode) + if r.Results[i].ErrorMessage, err = pd.getCompactNullableString(); err != nil { + return err + } + + numCredentialInfos, err := pd.getCompactArrayLength() + if err != nil { + return err + } + + r.Results[i].CredentialInfos = make([]*UserScramCredentialsResponseInfo, numCredentialInfos) + for j := 0; j < numCredentialInfos; j++ { + r.Results[i].CredentialInfos[j] = &UserScramCredentialsResponseInfo{} + scramMechanism, err := pd.getInt8() + if err != nil { + return err + } + r.Results[i].CredentialInfos[j].Mechanism = ScramMechanismType(scramMechanism) + if r.Results[i].CredentialInfos[j].Iterations, err = pd.getInt32(); err != nil { + return err + } + if _, err = pd.getEmptyTaggedFieldArray(); err != nil { + return err + } + } + + if _, err = pd.getEmptyTaggedFieldArray(); err != nil { + return err + } + } + } + + if _, err = pd.getEmptyTaggedFieldArray(); err != nil { + return err + } + return nil +} + +func (r *DescribeUserScramCredentialsResponse) key() int16 { + return 50 +} + +func (r *DescribeUserScramCredentialsResponse) version() int16 { + return r.Version +} + +func (r *DescribeUserScramCredentialsResponse) headerVersion() int16 { + return 2 +} + +func (r *DescribeUserScramCredentialsResponse) requiredVersion() KafkaVersion { + return V2_7_0_0 +} diff --git a/describe_user_scram_credentials_response_test.go b/describe_user_scram_credentials_response_test.go new file mode 100644 index 000000000..a251eaf7a --- /dev/null +++ b/describe_user_scram_credentials_response_test.go @@ -0,0 +1,56 @@ +package sarama + +import ( + "testing" + "time" +) + +var ( + emptyDescribeUserScramCredentialsResponse = []byte{ + 0, 0, 11, 184, // throttle time (3000 ms) + 0, 0, // no error code + 0, // no error message + 1, // empty array + 0, // tagged fields + } + + userDescribeUserScramCredentialsResponse = []byte{ + 0, 0, 11, 184, // throttle time (3000 ms) + 0, 11, // Error Code + 6, 'e', 'r', 'r', 'o', 'r', // ErrorMessage + 2, // Results array length + 7, 'n', 'o', 'b', 'o', 'd', 'y', // User + 0, 13, // User ErrorCode + 11, 'e', 'r', 'r', 'o', 'r', '_', 'u', 's', 'e', 'r', // User ErrorMessage + 2, // CredentialInfos array length + 2, // Mechanism + 0, 0, 16, 0, // Iterations + 0, 0, 0, + } +) + +func TestDescribeUserScramCredentialsResponse(t *testing.T) { + response := &DescribeUserScramCredentialsResponse{ + Version: 0, + ThrottleTime: time.Second * 3, + } + testResponse(t, "empty", response, emptyDescribeUserScramCredentialsResponse) + + responseErrorMessage := "error" + responseUserErrorMessage := "error_user" + + response.ErrorCode = 11 + response.ErrorMessage = &responseErrorMessage + response.Results = append(response.Results, &DescribeUserScramCredentialsResult{ + User: "nobody", + ErrorCode: 13, + ErrorMessage: &responseUserErrorMessage, + CredentialInfos: []*UserScramCredentialsResponseInfo{ + { + Mechanism: SCRAM_MECHANISM_SHA_512, + Iterations: 4096, + }, + }, + }) + testResponse(t, "empty", response, userDescribeUserScramCredentialsResponse) +} diff --git a/errors.go b/errors.go index da3353654..a5b6edbb5 100644 --- a/errors.go +++ b/errors.go @@ -52,6 +52,9 @@ var ErrControllerNotAvailable = errors.New("kafka: controller is not available") // the metadata. var ErrNoTopicsToUpdateMetadata = errors.New("kafka: no specific topics to update metadata") +// ErrUnknownScramMechanism is returned when user tries to AlterUserScramCredentials with unknown SCRAM mechanism +var ErrUnknownScramMechanism = errors.New("kafka: unknown SCRAM mechanism provided") + // PacketEncodingError is returned from a failure while encoding a Kafka packet. This can happen, for example, // if you try to encode a string over 2^15 characters in length, since Kafka's encoding rules do not permit that. type PacketEncodingError struct { diff --git a/packet_decoder.go b/packet_decoder.go index ed00ba350..184bc26ae 100644 --- a/packet_decoder.go +++ b/packet_decoder.go @@ -19,6 +19,7 @@ type packetDecoder interface { // Collections getBytes() ([]byte, error) getVarintBytes() ([]byte, error) + getCompactBytes() ([]byte, error) getRawBytes(length int) ([]byte, error) getString() (string, error) getNullableString() (*string, error) diff --git a/packet_encoder.go b/packet_encoder.go index 50c735c04..aea53ca83 100644 --- a/packet_encoder.go +++ b/packet_encoder.go @@ -20,6 +20,7 @@ type packetEncoder interface { // Collections putBytes(in []byte) error putVarintBytes(in []byte) error + putCompactBytes(in []byte) error putRawBytes(in []byte) error putCompactString(in string) error putNullableCompactString(in *string) error diff --git a/prep_encoder.go b/prep_encoder.go index 827542c50..0d0137487 100644 --- a/prep_encoder.go +++ b/prep_encoder.go @@ -77,6 +77,11 @@ func (pe *prepEncoder) putVarintBytes(in []byte) error { return pe.putRawBytes(in) } +func (pe *prepEncoder) putCompactBytes(in []byte) error { + pe.putUVarint(uint64(len(in) + 1)) + return pe.putRawBytes(in) +} + func (pe *prepEncoder) putCompactString(in string) error { pe.putCompactArrayLength(len(in)) return pe.putRawBytes([]byte(in)) diff --git a/real_decoder.go b/real_decoder.go index 8ac576db2..bffec2f39 100644 --- a/real_decoder.go +++ b/real_decoder.go @@ -170,6 +170,16 @@ func (rd *realDecoder) getVarintBytes() ([]byte, error) { return rd.getRawBytes(int(tmp)) } +func (rd *realDecoder) getCompactBytes() ([]byte, error) { + n, err := rd.getUVarint() + if err != nil { + return nil, err + } + + var length = int(n - 1) + return rd.getRawBytes(length) +} + func (rd *realDecoder) getStringLength() (int, error) { length, err := rd.getInt16() if err != nil { diff --git a/real_encoder.go b/real_encoder.go index ba073f7d3..c07204cbc 100644 --- a/real_encoder.go +++ b/real_encoder.go @@ -88,6 +88,11 @@ func (re *realEncoder) putVarintBytes(in []byte) error { return re.putRawBytes(in) } +func (re *realEncoder) putCompactBytes(in []byte) error { + re.putUVarint(uint64(len(in) + 1)) + return re.putRawBytes(in) +} + func (re *realEncoder) putCompactString(in string) error { re.putCompactArrayLength(len(in)) return re.putRawBytes([]byte(in)) diff --git a/request.go b/request.go index 0f76ae534..afa64ff92 100644 --- a/request.go +++ b/request.go @@ -186,6 +186,10 @@ func allocateBody(key, version int16) protocolBody { return &AlterPartitionReassignmentsRequest{} case 46: return &ListPartitionReassignmentsRequest{} + case 50: + return &DescribeUserScramCredentialsRequest{} + case 51: + return &AlterUserScramCredentialsRequest{} } return nil } diff --git a/scram_formatter.go b/scram_formatter.go new file mode 100644 index 000000000..2af9e4a69 --- /dev/null +++ b/scram_formatter.go @@ -0,0 +1,78 @@ +package sarama + +import ( + "crypto/hmac" + "crypto/sha256" + "crypto/sha512" + "hash" +) + +// ScramFormatter implementation +// @see: https://github.com/apache/kafka/blob/99b9b3e84f4e98c3f07714e1de6a139a004cbc5b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramFormatter.java#L93 +type scramFormatter struct { + mechanism ScramMechanismType +} + +func (s scramFormatter) mac(key []byte) (hash.Hash, error) { + var m hash.Hash + + switch s.mechanism { + case SCRAM_MECHANISM_SHA_256: + m = hmac.New(sha256.New, key) + + case SCRAM_MECHANISM_SHA_512: + m = hmac.New(sha512.New, key) + default: + return nil, ErrUnknownScramMechanism + } + + return m, nil +} + +func (s scramFormatter) hmac(key []byte, extra []byte) ([]byte, error) { + mac, err := s.mac(key) + if err != nil { + return nil, err + } + + if _, err := mac.Write(extra); err != nil { + return nil, err + } + return mac.Sum(nil), nil +} + +func (s scramFormatter) xor(result []byte, second []byte) { + for i := 0; i < len(result); i++ { + result[i] = result[i] ^ second[i] + } +} + +func (s scramFormatter) saltedPassword(password []byte, salt []byte, iterations int) ([]byte, error) { + mac, err := s.mac(password) + if err != nil { + return nil, err + } + + if _, err := mac.Write(salt); err != nil { + return nil, err + } + if _, err := mac.Write([]byte{0, 0, 0, 1}); err != nil { + return nil, err + } + + u1 := mac.Sum(nil) + prev := u1 + result := u1 + + for i := 2; i <= iterations; i++ { + ui, err := s.hmac(password, prev) + if err != nil { + return nil, err + } + + s.xor(result, ui) + prev = ui + } + + return result, nil +} diff --git a/scram_formatter_test.go b/scram_formatter_test.go new file mode 100644 index 000000000..b673a6a7d --- /dev/null +++ b/scram_formatter_test.go @@ -0,0 +1,80 @@ +package sarama + +import ( + "bytes" + "testing" +) + +/* +Following code can be used to validate saltedPassword implementation: + +
+import org.apache.kafka.common.security.scram.internals.ScramFormatter;
+import org.apache.kafka.common.security.scram.internals.ScramMechanism;
+import java.nio.charset.StandardCharsets;
+
+public class App {
+
+    public static String bytesToHex(byte[] in) {
+        final StringBuilder builder = new StringBuilder();
+        for(byte b : in) {
+            builder.append(String.format("0x%02x, ", b));
+        }
+        return builder.toString();
+    }
+
+	public static void main(String[] args) throws NoSuchAlgorithmException, InvalidKeyException {
+	   int digestIterations = 4096;
+	   String password = "hello";
+	   byte[] salt = "world".getBytes(StandardCharsets.UTF_8);
+	   byte[] saltedPassword = new ScramFormatter(ScramMechanism.SCRAM_SHA_256)
+			   .saltedPassword(password, salt, digestIterations);
+	   System.out.println(bytesToHex(saltedPassword));
+	}
+}
+
+*/ + +func TestScramSaltedPasswordSha512(t *testing.T) { + password := []byte("hello") + salt := []byte("world") + + formatter := scramFormatter{mechanism: SCRAM_MECHANISM_SHA_512} + result, _ := formatter.saltedPassword(password, salt, 4096) + + // calculated using ScramFormatter (see comment above) + expected := []byte{ + 0x35, 0x0c, 0x77, 0x84, 0x8a, 0x63, 0x06, 0x92, 0x00, + 0x6e, 0xc6, 0x6a, 0x0c, 0x39, 0xeb, 0xb0, 0x00, 0xd3, + 0xf8, 0x8a, 0x94, 0xae, 0x7f, 0x8c, 0xcd, 0x1d, 0x92, + 0x52, 0x6c, 0x5b, 0x16, 0x15, 0x86, 0x3b, 0xde, 0xa1, + 0x6c, 0x12, 0x9a, 0x7b, 0x09, 0xed, 0x0e, 0x38, 0xf2, + 0x07, 0x4d, 0x2f, 0xe2, 0x9f, 0x0f, 0x41, 0xe1, 0xfb, + 0x00, 0xc1, 0xd3, 0xbd, 0xd3, 0xfd, 0x51, 0x0b, 0xa9, + 0x8f, + } + + if !bytes.Equal(result, expected) { + t.Errorf("saltedPassword SHA-512 failed, expected: %v, result: %v", expected, result) + } +} + +func TestScramSaltedPasswordSha256(t *testing.T) { + password := []byte("hello") + salt := []byte("world") + + formatter := scramFormatter{mechanism: SCRAM_MECHANISM_SHA_256} + result, _ := formatter.saltedPassword(password, salt, 4096) + + // calculated using ScramFormatter (see comment above) + expected := []byte{ + 0xc1, 0x55, 0x53, 0x03, 0xda, 0x30, 0x9f, 0x6b, 0x7d, + 0x1e, 0x8f, 0xe4, 0x56, 0x36, 0xbf, 0xdd, 0xdc, 0x4b, + 0xf5, 0x64, 0x05, 0xe7, 0xe9, 0x4e, 0x9d, 0x15, 0xf0, + 0xe7, 0xb9, 0xcb, 0xd3, 0x80, + } + + if !bytes.Equal(result, expected) { + t.Errorf("saltedPassword SHA-256 failed, expected: %v, result: %v", expected, result) + } +}