diff --git a/broker.go b/broker.go index e481ad711c..b945e9f602 100644 --- a/broker.go +++ b/broker.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "math/rand" "net" "sort" "strconv" @@ -52,7 +53,8 @@ type Broker struct { brokerRequestsInFlight metrics.Counter brokerThrottleTime metrics.Histogram - kerberosAuthenticator GSSAPIKerberosAuth + kerberosAuthenticator GSSAPIKerberosAuth + clientSessionReauthenticationTimeMs int64 } // SASLMechanism specifies the SASL mechanism the client uses to authenticate with the broker @@ -923,6 +925,13 @@ func (b *Broker) sendWithPromise(rb protocolBody, promise *responsePromise) erro return ErrNotConnected } + if b.clientSessionReauthenticationTimeMs > 0 && currentUnixMilli() > b.clientSessionReauthenticationTimeMs { + err := b.authenticateViaSASL() + if err != nil { + return err + } + } + if !b.conf.Version.IsAtLeast(rb.requiredVersion()) { return ErrUnsupportedVersion } @@ -1263,7 +1272,7 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error { // Will be decremented in updateIncomingCommunicationMetrics (except error) b.addRequestInFlightMetrics(1) - bytesWritten, err := b.sendSASLPlainAuthClientResponse(correlationID) + bytesWritten, resVersion, err := b.sendSASLPlainAuthClientResponse(correlationID) b.updateOutgoingCommunicationMetrics(bytesWritten) if err != nil { @@ -1274,7 +1283,8 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error { b.correlationID++ - bytesRead, err := b.receiveSASLServerResponse(&SaslAuthenticateResponse{}, correlationID) + res := &SaslAuthenticateResponse{} + bytesRead, err := b.receiveSASLServerResponse(res, correlationID, resVersion) b.updateIncomingCommunicationMetrics(bytesRead, time.Since(requestTime)) // With v1 sasl we get an error message set in the response we can return @@ -1288,6 +1298,10 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error { return nil } +func currentUnixMilli() int64 { + return time.Now().UnixNano() / int64(time.Millisecond) +} + // sendAndReceiveSASLOAuth performs the authentication flow as described by KIP-255 // https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=75968876 func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error { @@ -1327,7 +1341,7 @@ func (b *Broker) sendClientMessage(message []byte) (bool, error) { b.addRequestInFlightMetrics(1) correlationID := b.correlationID - bytesWritten, err := b.sendSASLOAuthBearerClientMessage(message, correlationID) + bytesWritten, resVersion, err := b.sendSASLOAuthBearerClientMessage(message, correlationID) b.updateOutgoingCommunicationMetrics(bytesWritten) if err != nil { b.addRequestInFlightMetrics(-1) @@ -1337,7 +1351,7 @@ func (b *Broker) sendClientMessage(message []byte) (bool, error) { b.correlationID++ res := &SaslAuthenticateResponse{} - bytesRead, err := b.receiveSASLServerResponse(res, correlationID) + bytesRead, err := b.receiveSASLServerResponse(res, correlationID, resVersion) requestLatency := time.Since(requestTime) b.updateIncomingCommunicationMetrics(bytesRead, requestLatency) @@ -1464,7 +1478,7 @@ func (b *Broker) sendAndReceiveSASLSCRAMv1() error { } func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (int, error) { - rb := &SaslAuthenticateRequest{msg} + rb := b.createSaslAuthenticateRequest(msg) req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb} buf, err := encode(req, b.conf.MetricRegistry) if err != nil { @@ -1474,6 +1488,15 @@ func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (i return b.write(buf) } +func (b *Broker) createSaslAuthenticateRequest(msg []byte) *SaslAuthenticateRequest { + authenticateRequest := SaslAuthenticateRequest{SaslAuthBytes: msg} + if b.conf.Version.IsAtLeast(V2_2_0_0) { + authenticateRequest.Version = 1 + } + + return &authenticateRequest +} + func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, error) { buf := make([]byte, responseLengthSize+correlationIDSize) _, err := b.readFull(buf) @@ -1538,32 +1561,34 @@ func mapToString(extensions map[string]string, keyValSep string, elemSep string) return strings.Join(buf, elemSep) } -func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, error) { +func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, int16, error) { authBytes := []byte(b.conf.Net.SASL.AuthIdentity + "\x00" + b.conf.Net.SASL.User + "\x00" + b.conf.Net.SASL.Password) - rb := &SaslAuthenticateRequest{authBytes} + rb := b.createSaslAuthenticateRequest(authBytes) req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb} buf, err := encode(req, b.conf.MetricRegistry) if err != nil { - return 0, err + return 0, rb.Version, err } - return b.write(buf) + write, err := b.write(buf) + return write, rb.Version, err } -func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, error) { - rb := &SaslAuthenticateRequest{initialResp} +func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, int16, error) { + rb := b.createSaslAuthenticateRequest(initialResp) req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb} buf, err := encode(req, b.conf.MetricRegistry) if err != nil { - return 0, err + return 0, rb.version(), err } - return b.write(buf) + write, err := b.write(buf) + return write, rb.version(), err } -func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32) (int, error) { +func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32, resVersion int16) (int, error) { buf := make([]byte, responseLengthSize+correlationIDSize) bytesRead, err := b.readFull(buf) if err != nil { @@ -1587,7 +1612,7 @@ func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correl return bytesRead, err } - if err := versionedDecode(buf, res, 0); err != nil { + if err := versionedDecode(buf, res, resVersion); err != nil { return bytesRead, err } @@ -1599,6 +1624,22 @@ func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correl return bytesRead, err } + if res.SessionLifetimeMs > 0 { + // Follows the Java Kafka implementation from SaslClientAuthenticator.ReauthInfo#setAuthenticationEndAndSessionReauthenticationTimes + // pick a random percentage between 85% and 95% for session re-authentication + positiveSessionLifetimeMs := res.SessionLifetimeMs + authenticationEndNanos := currentUnixMilli() + pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount := 0.85 + pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously := 0.10 + pctToUse := pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount + rand.Float64()*pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously + sessionLifetimeMsToUse := int64(float64(positiveSessionLifetimeMs) * pctToUse) + DebugLogger.Printf("Session expiration in %d ms and session re-authentication on or after %d ms", positiveSessionLifetimeMs, sessionLifetimeMsToUse) + clientSessionReauthenticationTimeMs := authenticationEndNanos + sessionLifetimeMsToUse + b.clientSessionReauthenticationTimeMs = clientSessionReauthenticationTimeMs + } else { + b.clientSessionReauthenticationTimeMs = 0 + } + return bytesRead, nil } diff --git a/sasl_authenticate_request.go b/sasl_authenticate_request.go index 90504df6f5..5bb0988ea5 100644 --- a/sasl_authenticate_request.go +++ b/sasl_authenticate_request.go @@ -1,6 +1,8 @@ package sarama type SaslAuthenticateRequest struct { + // Version defines the protocol version to use for encode and decode + Version int16 SaslAuthBytes []byte } @@ -12,6 +14,7 @@ func (r *SaslAuthenticateRequest) encode(pe packetEncoder) error { } func (r *SaslAuthenticateRequest) decode(pd packetDecoder, version int16) (err error) { + r.Version = version r.SaslAuthBytes, err = pd.getBytes() return err } @@ -21,7 +24,7 @@ func (r *SaslAuthenticateRequest) key() int16 { } func (r *SaslAuthenticateRequest) version() int16 { - return 0 + return r.Version } func (r *SaslAuthenticateRequest) headerVersion() int16 { @@ -29,5 +32,10 @@ func (r *SaslAuthenticateRequest) headerVersion() int16 { } func (r *SaslAuthenticateRequest) requiredVersion() KafkaVersion { - return V1_0_0_0 + switch r.Version { + case 1: + return V2_2_0_0 + default: + return V1_0_0_0 + } } diff --git a/sasl_authenticate_request_test.go b/sasl_authenticate_request_test.go index bf75004d27..fd13a7c4aa 100644 --- a/sasl_authenticate_request_test.go +++ b/sasl_authenticate_request_test.go @@ -11,3 +11,10 @@ func TestSaslAuthenticateRequest(t *testing.T) { request.SaslAuthBytes = []byte(`foo`) testRequest(t, "basic", request, saslAuthenticateRequest) } + +func TestSaslAuthenticateRequestV1(t *testing.T) { + request := new(SaslAuthenticateRequest) + request.Version = 1 + request.SaslAuthBytes = []byte(`foo`) + testRequest(t, "basic", request, saslAuthenticateRequest) +} diff --git a/sasl_authenticate_response.go b/sasl_authenticate_response.go index 3ef57b5afa..37c8e45dae 100644 --- a/sasl_authenticate_response.go +++ b/sasl_authenticate_response.go @@ -1,9 +1,12 @@ package sarama type SaslAuthenticateResponse struct { - Err KError - ErrorMessage *string - SaslAuthBytes []byte + // Version defines the protocol version to use for encode and decode + Version int16 + Err KError + ErrorMessage *string + SaslAuthBytes []byte + SessionLifetimeMs int64 } func (r *SaslAuthenticateResponse) encode(pe packetEncoder) error { @@ -11,10 +14,17 @@ func (r *SaslAuthenticateResponse) encode(pe packetEncoder) error { if err := pe.putNullableString(r.ErrorMessage); err != nil { return err } - return pe.putBytes(r.SaslAuthBytes) + if err := pe.putBytes(r.SaslAuthBytes); err != nil { + return err + } + if r.Version > 0 { + pe.putInt64(r.SessionLifetimeMs) + } + return nil } func (r *SaslAuthenticateResponse) decode(pd packetDecoder, version int16) error { + r.Version = version kerr, err := pd.getInt16() if err != nil { return err @@ -26,7 +36,13 @@ func (r *SaslAuthenticateResponse) decode(pd packetDecoder, version int16) error return err } - r.SaslAuthBytes, err = pd.getBytes() + if r.SaslAuthBytes, err = pd.getBytes(); err != nil { + return err + } + + if version > 0 { + r.SessionLifetimeMs, err = pd.getInt64() + } return err } @@ -36,7 +52,7 @@ func (r *SaslAuthenticateResponse) key() int16 { } func (r *SaslAuthenticateResponse) version() int16 { - return 0 + return r.Version } func (r *SaslAuthenticateResponse) headerVersion() int16 { @@ -44,5 +60,10 @@ func (r *SaslAuthenticateResponse) headerVersion() int16 { } func (r *SaslAuthenticateResponse) requiredVersion() KafkaVersion { - return V1_0_0_0 + switch r.Version { + case 1: + return V2_2_0_0 + default: + return V1_0_0_0 + } } diff --git a/sasl_authenticate_response_test.go b/sasl_authenticate_response_test.go index 048dade197..ed037ed086 100644 --- a/sasl_authenticate_response_test.go +++ b/sasl_authenticate_response_test.go @@ -2,11 +2,19 @@ package sarama import "testing" -var saslAuthenticatResponseErr = []byte{ - 0, 58, - 0, 3, 'e', 'r', 'r', - 0, 0, 0, 3, 'm', 's', 'g', -} +var ( + saslAuthenticateResponseErr = []byte{ + 0, 58, + 0, 3, 'e', 'r', 'r', + 0, 0, 0, 3, 'm', 's', 'g', + } + saslAuthenticateResponseErrV1 = []byte{ + 0, 58, + 0, 3, 'e', 'r', 'r', + 0, 0, 0, 3, 'm', 's', 'g', + 0, 0, 0, 0, 0, 0, 0, 1, + } +) func TestSaslAuthenticateResponse(t *testing.T) { response := new(SaslAuthenticateResponse) @@ -15,5 +23,17 @@ func TestSaslAuthenticateResponse(t *testing.T) { response.ErrorMessage = &msg response.SaslAuthBytes = []byte(`msg`) - testResponse(t, "authenticate response", response, saslAuthenticatResponseErr) + testResponse(t, "authenticate response", response, saslAuthenticateResponseErr) +} + +func TestSaslAuthenticateResponseV1(t *testing.T) { + response := new(SaslAuthenticateResponse) + response.Err = ErrSASLAuthenticationFailed + msg := "err" + response.Version = 1 + response.ErrorMessage = &msg + response.SaslAuthBytes = []byte(`msg`) + response.SessionLifetimeMs = 1 + + testResponse(t, "authenticate response", response, saslAuthenticateResponseErrV1) }