Skip to content

Commit

Permalink
Merge pull request #1992 from zhaomoran/master
Browse files Browse the repository at this point in the history
feat: support SaslHandshakeRequest v0 for SCRAM
  • Loading branch information
bai authored Sep 8, 2021
2 parents 7f53062 + e7071f3 commit b64d8eb
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 1 deletion.
66 changes: 65 additions & 1 deletion broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,7 @@ func (b *Broker) authenticateViaSASL() error {
case SASLTypeOAuth:
return b.sendAndReceiveSASLOAuth(b.conf.Net.SASL.TokenProvider)
case SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512:
return b.sendAndReceiveSASLSCRAMv1()
return b.sendAndReceiveSASLSCRAM()
case SASLTypeGSSAPI:
return b.sendAndReceiveKerberos()
default:
Expand Down Expand Up @@ -1180,6 +1180,70 @@ func (b *Broker) sendClientMessage(message []byte) (bool, error) {
return isChallenge, err
}

func (b *Broker) sendAndReceiveSASLSCRAM() error {
if b.conf.Net.SASL.Version == SASLHandshakeV0 {
return b.sendAndReceiveSASLSCRAMv0()
}
return b.sendAndReceiveSASLSCRAMv1()
}

func (b *Broker) sendAndReceiveSASLSCRAMv0() error {
if err := b.sendAndReceiveSASLHandshake(b.conf.Net.SASL.Mechanism, SASLHandshakeV0); err != nil {
return err
}

scramClient := b.conf.Net.SASL.SCRAMClientGeneratorFunc()
if err := scramClient.Begin(b.conf.Net.SASL.User, b.conf.Net.SASL.Password, b.conf.Net.SASL.SCRAMAuthzID); err != nil {
return fmt.Errorf("failed to start SCRAM exchange with the server: %s", err.Error())
}

msg, err := scramClient.Step("")
if err != nil {
return fmt.Errorf("failed to advance the SCRAM exchange: %s", err.Error())
}

for !scramClient.Done() {
requestTime := time.Now()
// Will be decremented in updateIncomingCommunicationMetrics (except error)
b.addRequestInFlightMetrics(1)
length := len(msg)
authBytes := make([]byte, length+4) //4 byte length header + auth data
binary.BigEndian.PutUint32(authBytes, uint32(length))
copy(authBytes[4:], []byte(msg))
_, err := b.write(authBytes)
b.updateOutgoingCommunicationMetrics(length + 4)
if err != nil {
b.addRequestInFlightMetrics(-1)
Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
return err
}
b.correlationID++
header := make([]byte, 4)
_, err = b.readFull(header)
if err != nil {
b.addRequestInFlightMetrics(-1)
Logger.Printf("Failed to read response header while authenticating with SASL to broker %s: %s\n", b.addr, err.Error())
return err
}
payload := make([]byte, int32(binary.BigEndian.Uint32(header)))
n, err := b.readFull(payload)
if err != nil {
b.addRequestInFlightMetrics(-1)
Logger.Printf("Failed to read response payload while authenticating with SASL to broker %s: %s\n", b.addr, err.Error())
return err
}
b.updateIncomingCommunicationMetrics(n+4, time.Since(requestTime))
msg, err = scramClient.Step(string(payload))
if err != nil {
Logger.Println("SASL authentication failed", err)
return err
}
}

Logger.Println("SASL authentication succeeded")
return nil
}

func (b *Broker) sendAndReceiveSASLSCRAMv1() error {
if err := b.sendAndReceiveSASLHandshake(b.conf.Net.SASL.Mechanism, SASLHandshakeV1); err != nil {
return err
Expand Down
2 changes: 2 additions & 0 deletions broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,11 @@ func TestSASLSCRAMSHAXXX(t *testing.T) {

conf := NewTestConfig()
conf.Net.SASL.Mechanism = SASLTypeSCRAMSHA512
conf.Net.SASL.Version = SASLHandshakeV1
conf.Net.SASL.SCRAMClientGeneratorFunc = func() SCRAMClient { return test.scramClient }

broker.conf = conf
broker.conf.Version = V1_0_0_0
dialer := net.Dialer{
Timeout: conf.Net.DialTimeout,
KeepAlive: conf.Net.KeepAlive,
Expand Down

0 comments on commit b64d8eb

Please sign in to comment.