From aba7b01604dc76e8c71141638ada700d10c38e8e Mon Sep 17 00:00:00 2001 From: kwall Date: Tue, 29 Mar 2022 17:10:08 +0100 Subject: [PATCH] feat: KIP-368 support periodic re-auth Allow SASL Connections to Periodically Re-Authenticate [KIP-368](https://cwiki.apache.org/confluence/display/KAFKA/KIP-368%3A+Allow+SASL+Connections+to+Periodically+Re-Authenticate) --- broker.go | 72 ++++++++++--- broker_test.go | 159 +++++++++++++++++++++++++++++ mockresponses.go | 15 ++- sasl_authenticate_request.go | 12 ++- sasl_authenticate_request_test.go | 7 ++ sasl_authenticate_response.go | 35 +++++-- sasl_authenticate_response_test.go | 32 ++++-- 7 files changed, 298 insertions(+), 34 deletions(-) diff --git a/broker.go b/broker.go index e481ad711..f529fc24a 100644 --- a/broker.go +++ b/broker.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "math/rand" "net" "sort" "strconv" @@ -52,7 +53,8 @@ type Broker struct { brokerRequestsInFlight metrics.Counter brokerThrottleTime metrics.Histogram - kerberosAuthenticator GSSAPIKerberosAuth + kerberosAuthenticator GSSAPIKerberosAuth + clientSessionReauthenticationTimeMs int64 } // SASLMechanism specifies the SASL mechanism the client uses to authenticate with the broker @@ -923,6 +925,13 @@ func (b *Broker) sendWithPromise(rb protocolBody, promise *responsePromise) erro return ErrNotConnected } + if b.clientSessionReauthenticationTimeMs > 0 && currentUnixMilli() > b.clientSessionReauthenticationTimeMs { + err := b.authenticateViaSASL() + if err != nil { + return err + } + } + if !b.conf.Version.IsAtLeast(rb.requiredVersion()) { return ErrUnsupportedVersion } @@ -1263,7 +1272,7 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error { // Will be decremented in updateIncomingCommunicationMetrics (except error) b.addRequestInFlightMetrics(1) - bytesWritten, err := b.sendSASLPlainAuthClientResponse(correlationID) + bytesWritten, resVersion, err := b.sendSASLPlainAuthClientResponse(correlationID) b.updateOutgoingCommunicationMetrics(bytesWritten) if err != nil { @@ -1274,7 +1283,8 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error { b.correlationID++ - bytesRead, err := b.receiveSASLServerResponse(&SaslAuthenticateResponse{}, correlationID) + res := &SaslAuthenticateResponse{} + bytesRead, err := b.receiveSASLServerResponse(res, correlationID, resVersion) b.updateIncomingCommunicationMetrics(bytesRead, time.Since(requestTime)) // With v1 sasl we get an error message set in the response we can return @@ -1288,6 +1298,10 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error { return nil } +func currentUnixMilli() int64 { + return time.Now().UnixNano() / int64(time.Millisecond) +} + // sendAndReceiveSASLOAuth performs the authentication flow as described by KIP-255 // https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=75968876 func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error { @@ -1327,7 +1341,7 @@ func (b *Broker) sendClientMessage(message []byte) (bool, error) { b.addRequestInFlightMetrics(1) correlationID := b.correlationID - bytesWritten, err := b.sendSASLOAuthBearerClientMessage(message, correlationID) + bytesWritten, resVersion, err := b.sendSASLOAuthBearerClientMessage(message, correlationID) b.updateOutgoingCommunicationMetrics(bytesWritten) if err != nil { b.addRequestInFlightMetrics(-1) @@ -1337,7 +1351,7 @@ func (b *Broker) sendClientMessage(message []byte) (bool, error) { b.correlationID++ res := &SaslAuthenticateResponse{} - bytesRead, err := b.receiveSASLServerResponse(res, correlationID) + bytesRead, err := b.receiveSASLServerResponse(res, correlationID, resVersion) requestLatency := time.Since(requestTime) b.updateIncomingCommunicationMetrics(bytesRead, requestLatency) @@ -1464,7 +1478,7 @@ func (b *Broker) sendAndReceiveSASLSCRAMv1() error { } func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (int, error) { - rb := &SaslAuthenticateRequest{msg} + rb := b.createSaslAuthenticateRequest(msg) req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb} buf, err := encode(req, b.conf.MetricRegistry) if err != nil { @@ -1474,6 +1488,15 @@ func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (i return b.write(buf) } +func (b *Broker) createSaslAuthenticateRequest(msg []byte) *SaslAuthenticateRequest { + authenticateRequest := SaslAuthenticateRequest{SaslAuthBytes: msg} + if b.conf.Version.IsAtLeast(V2_2_0_0) { + authenticateRequest.Version = 1 + } + + return &authenticateRequest +} + func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, error) { buf := make([]byte, responseLengthSize+correlationIDSize) _, err := b.readFull(buf) @@ -1538,32 +1561,34 @@ func mapToString(extensions map[string]string, keyValSep string, elemSep string) return strings.Join(buf, elemSep) } -func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, error) { +func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, int16, error) { authBytes := []byte(b.conf.Net.SASL.AuthIdentity + "\x00" + b.conf.Net.SASL.User + "\x00" + b.conf.Net.SASL.Password) - rb := &SaslAuthenticateRequest{authBytes} + rb := b.createSaslAuthenticateRequest(authBytes) req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb} buf, err := encode(req, b.conf.MetricRegistry) if err != nil { - return 0, err + return 0, rb.Version, err } - return b.write(buf) + write, err := b.write(buf) + return write, rb.Version, err } -func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, error) { - rb := &SaslAuthenticateRequest{initialResp} +func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, int16, error) { + rb := b.createSaslAuthenticateRequest(initialResp) req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb} buf, err := encode(req, b.conf.MetricRegistry) if err != nil { - return 0, err + return 0, rb.version(), err } - return b.write(buf) + write, err := b.write(buf) + return write, rb.version(), err } -func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32) (int, error) { +func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32, resVersion int16) (int, error) { buf := make([]byte, responseLengthSize+correlationIDSize) bytesRead, err := b.readFull(buf) if err != nil { @@ -1587,7 +1612,7 @@ func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correl return bytesRead, err } - if err := versionedDecode(buf, res, 0); err != nil { + if err := versionedDecode(buf, res, resVersion); err != nil { return bytesRead, err } @@ -1599,6 +1624,21 @@ func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correl return bytesRead, err } + if res.SessionLifetimeMs > 0 { + // Follows the Java Kafka implementation from SaslClientAuthenticator.ReauthInfo#setAuthenticationEndAndSessionReauthenticationTimes + // pick a random percentage between 85% and 95% for session re-authentication + positiveSessionLifetimeMs := res.SessionLifetimeMs + authenticationEndMs := currentUnixMilli() + pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount := 0.85 + pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously := 0.10 + pctToUse := pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount + rand.Float64()*pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously + sessionLifetimeMsToUse := int64(float64(positiveSessionLifetimeMs) * pctToUse) + DebugLogger.Printf("Session expiration in %d ms and session re-authentication on or after %d ms", positiveSessionLifetimeMs, sessionLifetimeMsToUse) + b.clientSessionReauthenticationTimeMs = authenticationEndMs + sessionLifetimeMsToUse + } else { + b.clientSessionReauthenticationTimeMs = 0 + } + return bytesRead, nil } diff --git a/broker_test.go b/broker_test.go index 7d04473f5..0c23ab39a 100644 --- a/broker_test.go +++ b/broker_test.go @@ -828,6 +828,165 @@ func TestBuildClientFirstMessage(t *testing.T) { } } +func TestKip368ReAuthenticationSuccess(t *testing.T) { + sessionLifetimeMs := int64(100) + + mockBroker := NewMockBroker(t, 0) + + countSaslAuthRequests := func() (count int) { + for _, rr := range mockBroker.History() { + switch rr.Request.(type) { + case *SaslAuthenticateRequest: + count++ + } + } + return + } + + mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t). + SetAuthBytes([]byte(`response_payload`)). + SetSessionLifetimeMs(sessionLifetimeMs) + + mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t). + SetEnabledMechanisms([]string{SASLTypePlaintext}) + + mockApiVersions := NewMockApiVersionsResponse(t) + + mockBroker.SetHandlerByMap(map[string]MockResponse{ + "SaslAuthenticateRequest": mockSASLAuthResponse, + "SaslHandshakeRequest": mockSASLHandshakeResponse, + "ApiVersionsRequest": mockApiVersions, + }) + + broker := NewBroker(mockBroker.Addr()) + + conf := NewTestConfig() + conf.Net.SASL.Enable = true + conf.Net.SASL.Mechanism = SASLTypePlaintext + conf.Net.SASL.Version = SASLHandshakeV1 + conf.Net.SASL.AuthIdentity = "authid" + conf.Net.SASL.User = "token" + conf.Net.SASL.Password = "password" + + broker.conf = conf + broker.conf.Version = V2_2_0_0 + + err := broker.Open(conf) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = broker.Close() }) + + connected, err := broker.Connected() + if err != nil || !connected { + t.Fatal(err) + } + + actualSaslAuthRequests := countSaslAuthRequests() + if actualSaslAuthRequests != 1 { + t.Fatalf("unexpected number of SaslAuthRequests during initial authentication: %d", actualSaslAuthRequests) + } + + timeout := time.After(time.Duration(sessionLifetimeMs) * time.Millisecond) + +loop: + for actualSaslAuthRequests < 2 { + select { + case <-timeout: + break loop + default: + time.Sleep(10 * time.Millisecond) + // put some traffic on the wire + _, err = broker.ApiVersions(&ApiVersionsRequest{}) + if err != nil { + t.Fatal(err) + } + actualSaslAuthRequests = countSaslAuthRequests() + } + } + + if actualSaslAuthRequests < 2 { + t.Fatalf("sasl reauth has not occurred within expected timeframe") + } + + mockBroker.Close() +} + +func TestKip368ReAuthenticationFailure(t *testing.T) { + sessionLifetimeMs := int64(100) + + mockBroker := NewMockBroker(t, 0) + + mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t). + SetAuthBytes([]byte(`response_payload`)). + SetSessionLifetimeMs(sessionLifetimeMs) + + mockSASLAuthErrorResponse := NewMockSaslAuthenticateResponse(t). + SetError(ErrSASLAuthenticationFailed) + + mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t). + SetEnabledMechanisms([]string{SASLTypePlaintext}) + + mockApiVersions := NewMockApiVersionsResponse(t) + + mockBroker.SetHandlerByMap(map[string]MockResponse{ + "SaslAuthenticateRequest": mockSASLAuthResponse, + "SaslHandshakeRequest": mockSASLHandshakeResponse, + "ApiVersionsRequest": mockApiVersions, + }) + + broker := NewBroker(mockBroker.Addr()) + + conf := NewTestConfig() + conf.Net.SASL.Enable = true + conf.Net.SASL.Mechanism = SASLTypePlaintext + conf.Net.SASL.Version = SASLHandshakeV1 + conf.Net.SASL.AuthIdentity = "authid" + conf.Net.SASL.User = "token" + conf.Net.SASL.Password = "password" + + broker.conf = conf + broker.conf.Version = V2_2_0_0 + + err := broker.Open(conf) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = broker.Close() }) + + connected, err := broker.Connected() + if err != nil || !connected { + t.Fatal(err) + } + + mockBroker.SetHandlerByMap(map[string]MockResponse{ + "SaslAuthenticateRequest": mockSASLAuthErrorResponse, + "SaslHandshakeRequest": mockSASLHandshakeResponse, + "ApiVersionsRequest": mockApiVersions, + }) + + timeout := time.After(time.Duration(sessionLifetimeMs) * time.Millisecond) + + var apiVersionError error +loop: + for apiVersionError == nil { + select { + case <-timeout: + break loop + default: + time.Sleep(10 * time.Millisecond) + // put some traffic on the wire + _, apiVersionError = broker.ApiVersions(&ApiVersionsRequest{}) + } + } + + if !errors.Is(apiVersionError, ErrSASLAuthenticationFailed) { + t.Fatalf("sasl reauth has not failed in the expected way %v", apiVersionError) + } + + mockBroker.Close() +} + // We're not testing encoding/decoding here, so most of the requests/responses will be empty for simplicity's sake var brokerTestTable = []struct { version KafkaVersion diff --git a/mockresponses.go b/mockresponses.go index d26a44887..fff9dd77e 100644 --- a/mockresponses.go +++ b/mockresponses.go @@ -1057,9 +1057,10 @@ func (mr *MockListAclsResponse) For(reqBody versionedDecoder) encoderWithHeader } type MockSaslAuthenticateResponse struct { - t TestReporter - kerror KError - saslAuthBytes []byte + t TestReporter + kerror KError + saslAuthBytes []byte + sessionLifetimeMs int64 } func NewMockSaslAuthenticateResponse(t TestReporter) *MockSaslAuthenticateResponse { @@ -1067,9 +1068,12 @@ func NewMockSaslAuthenticateResponse(t TestReporter) *MockSaslAuthenticateRespon } func (msar *MockSaslAuthenticateResponse) For(reqBody versionedDecoder) encoderWithHeader { + req := reqBody.(*SaslAuthenticateRequest) res := &SaslAuthenticateResponse{} + res.Version = req.Version res.Err = msar.kerror res.SaslAuthBytes = msar.saslAuthBytes + res.SessionLifetimeMs = msar.sessionLifetimeMs return res } @@ -1083,6 +1087,11 @@ func (msar *MockSaslAuthenticateResponse) SetAuthBytes(saslAuthBytes []byte) *Mo return msar } +func (msar *MockSaslAuthenticateResponse) SetSessionLifetimeMs(sessionLifetimeMs int64) *MockSaslAuthenticateResponse { + msar.sessionLifetimeMs = sessionLifetimeMs + return msar +} + type MockDeleteAclsResponse struct { t TestReporter } diff --git a/sasl_authenticate_request.go b/sasl_authenticate_request.go index 90504df6f..5bb0988ea 100644 --- a/sasl_authenticate_request.go +++ b/sasl_authenticate_request.go @@ -1,6 +1,8 @@ package sarama type SaslAuthenticateRequest struct { + // Version defines the protocol version to use for encode and decode + Version int16 SaslAuthBytes []byte } @@ -12,6 +14,7 @@ func (r *SaslAuthenticateRequest) encode(pe packetEncoder) error { } func (r *SaslAuthenticateRequest) decode(pd packetDecoder, version int16) (err error) { + r.Version = version r.SaslAuthBytes, err = pd.getBytes() return err } @@ -21,7 +24,7 @@ func (r *SaslAuthenticateRequest) key() int16 { } func (r *SaslAuthenticateRequest) version() int16 { - return 0 + return r.Version } func (r *SaslAuthenticateRequest) headerVersion() int16 { @@ -29,5 +32,10 @@ func (r *SaslAuthenticateRequest) headerVersion() int16 { } func (r *SaslAuthenticateRequest) requiredVersion() KafkaVersion { - return V1_0_0_0 + switch r.Version { + case 1: + return V2_2_0_0 + default: + return V1_0_0_0 + } } diff --git a/sasl_authenticate_request_test.go b/sasl_authenticate_request_test.go index bf75004d2..fd13a7c4a 100644 --- a/sasl_authenticate_request_test.go +++ b/sasl_authenticate_request_test.go @@ -11,3 +11,10 @@ func TestSaslAuthenticateRequest(t *testing.T) { request.SaslAuthBytes = []byte(`foo`) testRequest(t, "basic", request, saslAuthenticateRequest) } + +func TestSaslAuthenticateRequestV1(t *testing.T) { + request := new(SaslAuthenticateRequest) + request.Version = 1 + request.SaslAuthBytes = []byte(`foo`) + testRequest(t, "basic", request, saslAuthenticateRequest) +} diff --git a/sasl_authenticate_response.go b/sasl_authenticate_response.go index 3ef57b5af..37c8e45da 100644 --- a/sasl_authenticate_response.go +++ b/sasl_authenticate_response.go @@ -1,9 +1,12 @@ package sarama type SaslAuthenticateResponse struct { - Err KError - ErrorMessage *string - SaslAuthBytes []byte + // Version defines the protocol version to use for encode and decode + Version int16 + Err KError + ErrorMessage *string + SaslAuthBytes []byte + SessionLifetimeMs int64 } func (r *SaslAuthenticateResponse) encode(pe packetEncoder) error { @@ -11,10 +14,17 @@ func (r *SaslAuthenticateResponse) encode(pe packetEncoder) error { if err := pe.putNullableString(r.ErrorMessage); err != nil { return err } - return pe.putBytes(r.SaslAuthBytes) + if err := pe.putBytes(r.SaslAuthBytes); err != nil { + return err + } + if r.Version > 0 { + pe.putInt64(r.SessionLifetimeMs) + } + return nil } func (r *SaslAuthenticateResponse) decode(pd packetDecoder, version int16) error { + r.Version = version kerr, err := pd.getInt16() if err != nil { return err @@ -26,7 +36,13 @@ func (r *SaslAuthenticateResponse) decode(pd packetDecoder, version int16) error return err } - r.SaslAuthBytes, err = pd.getBytes() + if r.SaslAuthBytes, err = pd.getBytes(); err != nil { + return err + } + + if version > 0 { + r.SessionLifetimeMs, err = pd.getInt64() + } return err } @@ -36,7 +52,7 @@ func (r *SaslAuthenticateResponse) key() int16 { } func (r *SaslAuthenticateResponse) version() int16 { - return 0 + return r.Version } func (r *SaslAuthenticateResponse) headerVersion() int16 { @@ -44,5 +60,10 @@ func (r *SaslAuthenticateResponse) headerVersion() int16 { } func (r *SaslAuthenticateResponse) requiredVersion() KafkaVersion { - return V1_0_0_0 + switch r.Version { + case 1: + return V2_2_0_0 + default: + return V1_0_0_0 + } } diff --git a/sasl_authenticate_response_test.go b/sasl_authenticate_response_test.go index 048dade19..ed037ed08 100644 --- a/sasl_authenticate_response_test.go +++ b/sasl_authenticate_response_test.go @@ -2,11 +2,19 @@ package sarama import "testing" -var saslAuthenticatResponseErr = []byte{ - 0, 58, - 0, 3, 'e', 'r', 'r', - 0, 0, 0, 3, 'm', 's', 'g', -} +var ( + saslAuthenticateResponseErr = []byte{ + 0, 58, + 0, 3, 'e', 'r', 'r', + 0, 0, 0, 3, 'm', 's', 'g', + } + saslAuthenticateResponseErrV1 = []byte{ + 0, 58, + 0, 3, 'e', 'r', 'r', + 0, 0, 0, 3, 'm', 's', 'g', + 0, 0, 0, 0, 0, 0, 0, 1, + } +) func TestSaslAuthenticateResponse(t *testing.T) { response := new(SaslAuthenticateResponse) @@ -15,5 +23,17 @@ func TestSaslAuthenticateResponse(t *testing.T) { response.ErrorMessage = &msg response.SaslAuthBytes = []byte(`msg`) - testResponse(t, "authenticate response", response, saslAuthenticatResponseErr) + testResponse(t, "authenticate response", response, saslAuthenticateResponseErr) +} + +func TestSaslAuthenticateResponseV1(t *testing.T) { + response := new(SaslAuthenticateResponse) + response.Err = ErrSASLAuthenticationFailed + msg := "err" + response.Version = 1 + response.ErrorMessage = &msg + response.SaslAuthBytes = []byte(`msg`) + response.SessionLifetimeMs = 1 + + testResponse(t, "authenticate response", response, saslAuthenticateResponseErrV1) }