Skip to content

Commit

Permalink
feat: kip-368 support
Browse files Browse the repository at this point in the history
  • Loading branch information
k-wall committed Mar 29, 2022
1 parent f07b7b8 commit 88b1c46
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 31 deletions.
73 changes: 57 additions & 16 deletions broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"io"
"math/rand"
"net"
"sort"
"strconv"
Expand Down Expand Up @@ -52,7 +53,8 @@ type Broker struct {
brokerRequestsInFlight metrics.Counter
brokerThrottleTime metrics.Histogram

kerberosAuthenticator GSSAPIKerberosAuth
kerberosAuthenticator GSSAPIKerberosAuth
clientSessionReauthenticationTimeMs int64
}

// SASLMechanism specifies the SASL mechanism the client uses to authenticate with the broker
Expand Down Expand Up @@ -923,6 +925,13 @@ func (b *Broker) sendWithPromise(rb protocolBody, promise *responsePromise) erro
return ErrNotConnected
}

if b.clientSessionReauthenticationTimeMs > 0 && currentUnixMilli() > b.clientSessionReauthenticationTimeMs {
err := b.authenticateViaSASL()
if err != nil {
return err
}
}

if !b.conf.Version.IsAtLeast(rb.requiredVersion()) {
return ErrUnsupportedVersion
}
Expand Down Expand Up @@ -1263,7 +1272,7 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error {

// Will be decremented in updateIncomingCommunicationMetrics (except error)
b.addRequestInFlightMetrics(1)
bytesWritten, err := b.sendSASLPlainAuthClientResponse(correlationID)
bytesWritten, resVersion, err := b.sendSASLPlainAuthClientResponse(correlationID)
b.updateOutgoingCommunicationMetrics(bytesWritten)

if err != nil {
Expand All @@ -1274,7 +1283,8 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error {

b.correlationID++

bytesRead, err := b.receiveSASLServerResponse(&SaslAuthenticateResponse{}, correlationID)
res := &SaslAuthenticateResponse{}
bytesRead, err := b.receiveSASLServerResponse(res, correlationID, resVersion)
b.updateIncomingCommunicationMetrics(bytesRead, time.Since(requestTime))

// With v1 sasl we get an error message set in the response we can return
Expand All @@ -1288,6 +1298,10 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error {
return nil
}

func currentUnixMilli() int64 {
return time.Now().UnixNano() / int64(time.Millisecond)
}

// sendAndReceiveSASLOAuth performs the authentication flow as described by KIP-255
// https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=75968876
func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
Expand Down Expand Up @@ -1327,7 +1341,7 @@ func (b *Broker) sendClientMessage(message []byte) (bool, error) {
b.addRequestInFlightMetrics(1)
correlationID := b.correlationID

bytesWritten, err := b.sendSASLOAuthBearerClientMessage(message, correlationID)
bytesWritten, resVersion, err := b.sendSASLOAuthBearerClientMessage(message, correlationID)
b.updateOutgoingCommunicationMetrics(bytesWritten)
if err != nil {
b.addRequestInFlightMetrics(-1)
Expand All @@ -1337,7 +1351,7 @@ func (b *Broker) sendClientMessage(message []byte) (bool, error) {
b.correlationID++

res := &SaslAuthenticateResponse{}
bytesRead, err := b.receiveSASLServerResponse(res, correlationID)
bytesRead, err := b.receiveSASLServerResponse(res, correlationID, resVersion)

requestLatency := time.Since(requestTime)
b.updateIncomingCommunicationMetrics(bytesRead, requestLatency)
Expand Down Expand Up @@ -1464,7 +1478,7 @@ func (b *Broker) sendAndReceiveSASLSCRAMv1() error {
}

func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (int, error) {
rb := &SaslAuthenticateRequest{msg}
rb := b.createSaslAuthenticateRequest(msg)
req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
buf, err := encode(req, b.conf.MetricRegistry)
if err != nil {
Expand All @@ -1474,6 +1488,15 @@ func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (i
return b.write(buf)
}

func (b *Broker) createSaslAuthenticateRequest(msg []byte) *SaslAuthenticateRequest {
authenticateRequest := SaslAuthenticateRequest{SaslAuthBytes: msg}
if b.conf.Version.IsAtLeast(V2_2_0_0) {
authenticateRequest.Version = 1
}

return &authenticateRequest
}

func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, error) {
buf := make([]byte, responseLengthSize+correlationIDSize)
_, err := b.readFull(buf)
Expand Down Expand Up @@ -1538,32 +1561,34 @@ func mapToString(extensions map[string]string, keyValSep string, elemSep string)
return strings.Join(buf, elemSep)
}

func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, error) {
func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, int16, error) {
authBytes := []byte(b.conf.Net.SASL.AuthIdentity + "\x00" + b.conf.Net.SASL.User + "\x00" + b.conf.Net.SASL.Password)
rb := &SaslAuthenticateRequest{authBytes}
rb := b.createSaslAuthenticateRequest(authBytes)
req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
buf, err := encode(req, b.conf.MetricRegistry)
if err != nil {
return 0, err
return 0, rb.Version, err
}

return b.write(buf)
write, err := b.write(buf)
return write, rb.Version, err
}

func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, error) {
rb := &SaslAuthenticateRequest{initialResp}
func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, int16, error) {
rb := b.createSaslAuthenticateRequest(initialResp)

req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}

buf, err := encode(req, b.conf.MetricRegistry)
if err != nil {
return 0, err
return 0, rb.version(), err
}

return b.write(buf)
write, err := b.write(buf)
return write, rb.version(), err
}

func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32) (int, error) {
func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32, resVersion int16) (int, error) {
buf := make([]byte, responseLengthSize+correlationIDSize)
bytesRead, err := b.readFull(buf)
if err != nil {
Expand All @@ -1587,7 +1612,7 @@ func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correl
return bytesRead, err
}

if err := versionedDecode(buf, res, 0); err != nil {
if err := versionedDecode(buf, res, resVersion); err != nil {
return bytesRead, err
}

Expand All @@ -1599,6 +1624,22 @@ func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correl
return bytesRead, err
}

if res.SessionLifetimeMs > 0 {
// Follows the Java Kafka implementation from SaslClientAuthenticator.ReauthInfo#setAuthenticationEndAndSessionReauthenticationTimes
// pick a random percentage between 85% and 95% for session re-authentication
positiveSessionLifetimeMs := res.SessionLifetimeMs
authenticationEndNanos := currentUnixMilli()
pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount := 0.85
pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously := 0.10
pctToUse := pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount + rand.Float64()*pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously
sessionLifetimeMsToUse := int64(float64(positiveSessionLifetimeMs) * pctToUse)
DebugLogger.Printf("Session expiration in %d ms and session re-authentication on or after %d ms", positiveSessionLifetimeMs, sessionLifetimeMsToUse)
clientSessionReauthenticationTimeMs := authenticationEndNanos + sessionLifetimeMsToUse
b.clientSessionReauthenticationTimeMs = clientSessionReauthenticationTimeMs
} else {
b.clientSessionReauthenticationTimeMs = 0
}

return bytesRead, nil
}

Expand Down
12 changes: 10 additions & 2 deletions sasl_authenticate_request.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package sarama

type SaslAuthenticateRequest struct {
// Version defines the protocol version to use for encode and decode
Version int16
SaslAuthBytes []byte
}

Expand All @@ -12,6 +14,7 @@ func (r *SaslAuthenticateRequest) encode(pe packetEncoder) error {
}

func (r *SaslAuthenticateRequest) decode(pd packetDecoder, version int16) (err error) {
r.Version = version
r.SaslAuthBytes, err = pd.getBytes()
return err
}
Expand All @@ -21,13 +24,18 @@ func (r *SaslAuthenticateRequest) key() int16 {
}

func (r *SaslAuthenticateRequest) version() int16 {
return 0
return r.Version
}

func (r *SaslAuthenticateRequest) headerVersion() int16 {
return 1
}

func (r *SaslAuthenticateRequest) requiredVersion() KafkaVersion {
return V1_0_0_0
switch r.Version {
case 1:
return V2_2_0_0
default:
return V1_0_0_0
}
}
7 changes: 7 additions & 0 deletions sasl_authenticate_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,10 @@ func TestSaslAuthenticateRequest(t *testing.T) {
request.SaslAuthBytes = []byte(`foo`)
testRequest(t, "basic", request, saslAuthenticateRequest)
}

func TestSaslAuthenticateRequestV1(t *testing.T) {
request := new(SaslAuthenticateRequest)
request.Version = 1
request.SaslAuthBytes = []byte(`foo`)
testRequest(t, "basic", request, saslAuthenticateRequest)
}
35 changes: 28 additions & 7 deletions sasl_authenticate_response.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,30 @@
package sarama

type SaslAuthenticateResponse struct {
Err KError
ErrorMessage *string
SaslAuthBytes []byte
// Version defines the protocol version to use for encode and decode
Version int16
Err KError
ErrorMessage *string
SaslAuthBytes []byte
SessionLifetimeMs int64
}

func (r *SaslAuthenticateResponse) encode(pe packetEncoder) error {
pe.putInt16(int16(r.Err))
if err := pe.putNullableString(r.ErrorMessage); err != nil {
return err
}
return pe.putBytes(r.SaslAuthBytes)
if err := pe.putBytes(r.SaslAuthBytes); err != nil {
return err
}
if r.Version > 0 {
pe.putInt64(r.SessionLifetimeMs)
}
return nil
}

func (r *SaslAuthenticateResponse) decode(pd packetDecoder, version int16) error {
r.Version = version
kerr, err := pd.getInt16()
if err != nil {
return err
Expand All @@ -26,7 +36,13 @@ func (r *SaslAuthenticateResponse) decode(pd packetDecoder, version int16) error
return err
}

r.SaslAuthBytes, err = pd.getBytes()
if r.SaslAuthBytes, err = pd.getBytes(); err != nil {
return err
}

if version > 0 {
r.SessionLifetimeMs, err = pd.getInt64()
}

return err
}
Expand All @@ -36,13 +52,18 @@ func (r *SaslAuthenticateResponse) key() int16 {
}

func (r *SaslAuthenticateResponse) version() int16 {
return 0
return r.Version
}

func (r *SaslAuthenticateResponse) headerVersion() int16 {
return 0
}

func (r *SaslAuthenticateResponse) requiredVersion() KafkaVersion {
return V1_0_0_0
switch r.Version {
case 1:
return V2_2_0_0
default:
return V1_0_0_0
}
}
32 changes: 26 additions & 6 deletions sasl_authenticate_response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,19 @@ package sarama

import "testing"

var saslAuthenticatResponseErr = []byte{
0, 58,
0, 3, 'e', 'r', 'r',
0, 0, 0, 3, 'm', 's', 'g',
}
var (
saslAuthenticateResponseErr = []byte{
0, 58,
0, 3, 'e', 'r', 'r',
0, 0, 0, 3, 'm', 's', 'g',
}
saslAuthenticateResponseErrV1 = []byte{
0, 58,
0, 3, 'e', 'r', 'r',
0, 0, 0, 3, 'm', 's', 'g',
0, 0, 0, 0, 0, 0, 0, 1,
}
)

func TestSaslAuthenticateResponse(t *testing.T) {
response := new(SaslAuthenticateResponse)
Expand All @@ -15,5 +23,17 @@ func TestSaslAuthenticateResponse(t *testing.T) {
response.ErrorMessage = &msg
response.SaslAuthBytes = []byte(`msg`)

testResponse(t, "authenticate response", response, saslAuthenticatResponseErr)
testResponse(t, "authenticate response", response, saslAuthenticateResponseErr)
}

func TestSaslAuthenticateResponseV1(t *testing.T) {
response := new(SaslAuthenticateResponse)
response.Err = ErrSASLAuthenticationFailed
msg := "err"
response.Version = 1
response.ErrorMessage = &msg
response.SaslAuthBytes = []byte(`msg`)
response.SessionLifetimeMs = 1

testResponse(t, "authenticate response", response, saslAuthenticateResponseErrV1)
}

0 comments on commit 88b1c46

Please sign in to comment.