Skip to content

Commit

Permalink
Merge pull request #2197 from k-wall/kip-368
Browse files Browse the repository at this point in the history
KIP-368 : Allow SASL Connections to Periodically Re-Authenticate
  • Loading branch information
dnwe authored Apr 13, 2022
2 parents 602d831 + aba7b01 commit ed494ad
Show file tree
Hide file tree
Showing 7 changed files with 298 additions and 34 deletions.
72 changes: 56 additions & 16 deletions broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"io"
"math/rand"
"net"
"sort"
"strconv"
Expand Down Expand Up @@ -52,7 +53,8 @@ type Broker struct {
brokerRequestsInFlight metrics.Counter
brokerThrottleTime metrics.Histogram

kerberosAuthenticator GSSAPIKerberosAuth
kerberosAuthenticator GSSAPIKerberosAuth
clientSessionReauthenticationTimeMs int64
}

// SASLMechanism specifies the SASL mechanism the client uses to authenticate with the broker
Expand Down Expand Up @@ -923,6 +925,13 @@ func (b *Broker) sendWithPromise(rb protocolBody, promise *responsePromise) erro
return ErrNotConnected
}

if b.clientSessionReauthenticationTimeMs > 0 && currentUnixMilli() > b.clientSessionReauthenticationTimeMs {
err := b.authenticateViaSASL()
if err != nil {
return err
}
}

if !b.conf.Version.IsAtLeast(rb.requiredVersion()) {
return ErrUnsupportedVersion
}
Expand Down Expand Up @@ -1263,7 +1272,7 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error {

// Will be decremented in updateIncomingCommunicationMetrics (except error)
b.addRequestInFlightMetrics(1)
bytesWritten, err := b.sendSASLPlainAuthClientResponse(correlationID)
bytesWritten, resVersion, err := b.sendSASLPlainAuthClientResponse(correlationID)
b.updateOutgoingCommunicationMetrics(bytesWritten)

if err != nil {
Expand All @@ -1274,7 +1283,8 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error {

b.correlationID++

bytesRead, err := b.receiveSASLServerResponse(&SaslAuthenticateResponse{}, correlationID)
res := &SaslAuthenticateResponse{}
bytesRead, err := b.receiveSASLServerResponse(res, correlationID, resVersion)
b.updateIncomingCommunicationMetrics(bytesRead, time.Since(requestTime))

// With v1 sasl we get an error message set in the response we can return
Expand All @@ -1288,6 +1298,10 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error {
return nil
}

func currentUnixMilli() int64 {
return time.Now().UnixNano() / int64(time.Millisecond)
}

// sendAndReceiveSASLOAuth performs the authentication flow as described by KIP-255
// https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=75968876
func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
Expand Down Expand Up @@ -1327,7 +1341,7 @@ func (b *Broker) sendClientMessage(message []byte) (bool, error) {
b.addRequestInFlightMetrics(1)
correlationID := b.correlationID

bytesWritten, err := b.sendSASLOAuthBearerClientMessage(message, correlationID)
bytesWritten, resVersion, err := b.sendSASLOAuthBearerClientMessage(message, correlationID)
b.updateOutgoingCommunicationMetrics(bytesWritten)
if err != nil {
b.addRequestInFlightMetrics(-1)
Expand All @@ -1337,7 +1351,7 @@ func (b *Broker) sendClientMessage(message []byte) (bool, error) {
b.correlationID++

res := &SaslAuthenticateResponse{}
bytesRead, err := b.receiveSASLServerResponse(res, correlationID)
bytesRead, err := b.receiveSASLServerResponse(res, correlationID, resVersion)

requestLatency := time.Since(requestTime)
b.updateIncomingCommunicationMetrics(bytesRead, requestLatency)
Expand Down Expand Up @@ -1464,7 +1478,7 @@ func (b *Broker) sendAndReceiveSASLSCRAMv1() error {
}

func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (int, error) {
rb := &SaslAuthenticateRequest{msg}
rb := b.createSaslAuthenticateRequest(msg)
req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
buf, err := encode(req, b.conf.MetricRegistry)
if err != nil {
Expand All @@ -1474,6 +1488,15 @@ func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (i
return b.write(buf)
}

func (b *Broker) createSaslAuthenticateRequest(msg []byte) *SaslAuthenticateRequest {
authenticateRequest := SaslAuthenticateRequest{SaslAuthBytes: msg}
if b.conf.Version.IsAtLeast(V2_2_0_0) {
authenticateRequest.Version = 1
}

return &authenticateRequest
}

func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, error) {
buf := make([]byte, responseLengthSize+correlationIDSize)
_, err := b.readFull(buf)
Expand Down Expand Up @@ -1538,32 +1561,34 @@ func mapToString(extensions map[string]string, keyValSep string, elemSep string)
return strings.Join(buf, elemSep)
}

func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, error) {
func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, int16, error) {
authBytes := []byte(b.conf.Net.SASL.AuthIdentity + "\x00" + b.conf.Net.SASL.User + "\x00" + b.conf.Net.SASL.Password)
rb := &SaslAuthenticateRequest{authBytes}
rb := b.createSaslAuthenticateRequest(authBytes)
req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
buf, err := encode(req, b.conf.MetricRegistry)
if err != nil {
return 0, err
return 0, rb.Version, err
}

return b.write(buf)
write, err := b.write(buf)
return write, rb.Version, err
}

func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, error) {
rb := &SaslAuthenticateRequest{initialResp}
func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, int16, error) {
rb := b.createSaslAuthenticateRequest(initialResp)

req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}

buf, err := encode(req, b.conf.MetricRegistry)
if err != nil {
return 0, err
return 0, rb.version(), err
}

return b.write(buf)
write, err := b.write(buf)
return write, rb.version(), err
}

func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32) (int, error) {
func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32, resVersion int16) (int, error) {
buf := make([]byte, responseLengthSize+correlationIDSize)
bytesRead, err := b.readFull(buf)
if err != nil {
Expand All @@ -1587,7 +1612,7 @@ func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correl
return bytesRead, err
}

if err := versionedDecode(buf, res, 0); err != nil {
if err := versionedDecode(buf, res, resVersion); err != nil {
return bytesRead, err
}

Expand All @@ -1599,6 +1624,21 @@ func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correl
return bytesRead, err
}

if res.SessionLifetimeMs > 0 {
// Follows the Java Kafka implementation from SaslClientAuthenticator.ReauthInfo#setAuthenticationEndAndSessionReauthenticationTimes
// pick a random percentage between 85% and 95% for session re-authentication
positiveSessionLifetimeMs := res.SessionLifetimeMs
authenticationEndMs := currentUnixMilli()
pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount := 0.85
pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously := 0.10
pctToUse := pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount + rand.Float64()*pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously
sessionLifetimeMsToUse := int64(float64(positiveSessionLifetimeMs) * pctToUse)
DebugLogger.Printf("Session expiration in %d ms and session re-authentication on or after %d ms", positiveSessionLifetimeMs, sessionLifetimeMsToUse)
b.clientSessionReauthenticationTimeMs = authenticationEndMs + sessionLifetimeMsToUse
} else {
b.clientSessionReauthenticationTimeMs = 0
}

return bytesRead, nil
}

Expand Down
159 changes: 159 additions & 0 deletions broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,165 @@ func TestBuildClientFirstMessage(t *testing.T) {
}
}

func TestKip368ReAuthenticationSuccess(t *testing.T) {
sessionLifetimeMs := int64(100)

mockBroker := NewMockBroker(t, 0)

countSaslAuthRequests := func() (count int) {
for _, rr := range mockBroker.History() {
switch rr.Request.(type) {
case *SaslAuthenticateRequest:
count++
}
}
return
}

mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).
SetAuthBytes([]byte(`response_payload`)).
SetSessionLifetimeMs(sessionLifetimeMs)

mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).
SetEnabledMechanisms([]string{SASLTypePlaintext})

mockApiVersions := NewMockApiVersionsResponse(t)

mockBroker.SetHandlerByMap(map[string]MockResponse{
"SaslAuthenticateRequest": mockSASLAuthResponse,
"SaslHandshakeRequest": mockSASLHandshakeResponse,
"ApiVersionsRequest": mockApiVersions,
})

broker := NewBroker(mockBroker.Addr())

conf := NewTestConfig()
conf.Net.SASL.Enable = true
conf.Net.SASL.Mechanism = SASLTypePlaintext
conf.Net.SASL.Version = SASLHandshakeV1
conf.Net.SASL.AuthIdentity = "authid"
conf.Net.SASL.User = "token"
conf.Net.SASL.Password = "password"

broker.conf = conf
broker.conf.Version = V2_2_0_0

err := broker.Open(conf)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = broker.Close() })

connected, err := broker.Connected()
if err != nil || !connected {
t.Fatal(err)
}

actualSaslAuthRequests := countSaslAuthRequests()
if actualSaslAuthRequests != 1 {
t.Fatalf("unexpected number of SaslAuthRequests during initial authentication: %d", actualSaslAuthRequests)
}

timeout := time.After(time.Duration(sessionLifetimeMs) * time.Millisecond)

loop:
for actualSaslAuthRequests < 2 {
select {
case <-timeout:
break loop
default:
time.Sleep(10 * time.Millisecond)
// put some traffic on the wire
_, err = broker.ApiVersions(&ApiVersionsRequest{})
if err != nil {
t.Fatal(err)
}
actualSaslAuthRequests = countSaslAuthRequests()
}
}

if actualSaslAuthRequests < 2 {
t.Fatalf("sasl reauth has not occurred within expected timeframe")
}

mockBroker.Close()
}

func TestKip368ReAuthenticationFailure(t *testing.T) {
sessionLifetimeMs := int64(100)

mockBroker := NewMockBroker(t, 0)

mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).
SetAuthBytes([]byte(`response_payload`)).
SetSessionLifetimeMs(sessionLifetimeMs)

mockSASLAuthErrorResponse := NewMockSaslAuthenticateResponse(t).
SetError(ErrSASLAuthenticationFailed)

mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).
SetEnabledMechanisms([]string{SASLTypePlaintext})

mockApiVersions := NewMockApiVersionsResponse(t)

mockBroker.SetHandlerByMap(map[string]MockResponse{
"SaslAuthenticateRequest": mockSASLAuthResponse,
"SaslHandshakeRequest": mockSASLHandshakeResponse,
"ApiVersionsRequest": mockApiVersions,
})

broker := NewBroker(mockBroker.Addr())

conf := NewTestConfig()
conf.Net.SASL.Enable = true
conf.Net.SASL.Mechanism = SASLTypePlaintext
conf.Net.SASL.Version = SASLHandshakeV1
conf.Net.SASL.AuthIdentity = "authid"
conf.Net.SASL.User = "token"
conf.Net.SASL.Password = "password"

broker.conf = conf
broker.conf.Version = V2_2_0_0

err := broker.Open(conf)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = broker.Close() })

connected, err := broker.Connected()
if err != nil || !connected {
t.Fatal(err)
}

mockBroker.SetHandlerByMap(map[string]MockResponse{
"SaslAuthenticateRequest": mockSASLAuthErrorResponse,
"SaslHandshakeRequest": mockSASLHandshakeResponse,
"ApiVersionsRequest": mockApiVersions,
})

timeout := time.After(time.Duration(sessionLifetimeMs) * time.Millisecond)

var apiVersionError error
loop:
for apiVersionError == nil {
select {
case <-timeout:
break loop
default:
time.Sleep(10 * time.Millisecond)
// put some traffic on the wire
_, apiVersionError = broker.ApiVersions(&ApiVersionsRequest{})
}
}

if !errors.Is(apiVersionError, ErrSASLAuthenticationFailed) {
t.Fatalf("sasl reauth has not failed in the expected way %v", apiVersionError)
}

mockBroker.Close()
}

// We're not testing encoding/decoding here, so most of the requests/responses will be empty for simplicity's sake
var brokerTestTable = []struct {
version KafkaVersion
Expand Down
15 changes: 12 additions & 3 deletions mockresponses.go
Original file line number Diff line number Diff line change
Expand Up @@ -1057,19 +1057,23 @@ func (mr *MockListAclsResponse) For(reqBody versionedDecoder) encoderWithHeader
}

type MockSaslAuthenticateResponse struct {
t TestReporter
kerror KError
saslAuthBytes []byte
t TestReporter
kerror KError
saslAuthBytes []byte
sessionLifetimeMs int64
}

func NewMockSaslAuthenticateResponse(t TestReporter) *MockSaslAuthenticateResponse {
return &MockSaslAuthenticateResponse{t: t}
}

func (msar *MockSaslAuthenticateResponse) For(reqBody versionedDecoder) encoderWithHeader {
req := reqBody.(*SaslAuthenticateRequest)
res := &SaslAuthenticateResponse{}
res.Version = req.Version
res.Err = msar.kerror
res.SaslAuthBytes = msar.saslAuthBytes
res.SessionLifetimeMs = msar.sessionLifetimeMs
return res
}

Expand All @@ -1083,6 +1087,11 @@ func (msar *MockSaslAuthenticateResponse) SetAuthBytes(saslAuthBytes []byte) *Mo
return msar
}

func (msar *MockSaslAuthenticateResponse) SetSessionLifetimeMs(sessionLifetimeMs int64) *MockSaslAuthenticateResponse {
msar.sessionLifetimeMs = sessionLifetimeMs
return msar
}

type MockDeleteAclsResponse struct {
t TestReporter
}
Expand Down
Loading

0 comments on commit ed494ad

Please sign in to comment.