Skip to content

Commit

Permalink
add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
k-wall committed Mar 30, 2022
1 parent 88b1c46 commit 455bae0
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 3 deletions.
159 changes: 159 additions & 0 deletions broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,165 @@ func TestBuildClientFirstMessage(t *testing.T) {
}
}

func TestKip368ReAuthenticationSuccess(t *testing.T) {
sessionLifetimeMs := int64(100)

mockBroker := NewMockBroker(t, 0)

countSaslAuthRequests := func() (count int) {
for _, rr := range mockBroker.History() {
switch rr.Request.(type) {
case *SaslAuthenticateRequest:
count++
}
}
return
}

mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).
SetAuthBytes([]byte(`response_payload`)).
SetSessionLifetimeMs(sessionLifetimeMs)

mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).
SetEnabledMechanisms([]string{SASLTypePlaintext})

mockApiVersions := NewMockApiVersionsResponse(t)

mockBroker.SetHandlerByMap(map[string]MockResponse{
"SaslAuthenticateRequest": mockSASLAuthResponse,
"SaslHandshakeRequest": mockSASLHandshakeResponse,
"ApiVersionsRequest": mockApiVersions,
})

broker := NewBroker(mockBroker.Addr())

conf := NewTestConfig()
conf.Net.SASL.Enable = true
conf.Net.SASL.Mechanism = SASLTypePlaintext
conf.Net.SASL.Version = SASLHandshakeV1
conf.Net.SASL.AuthIdentity = "authid"
conf.Net.SASL.User = "token"
conf.Net.SASL.Password = "password"

broker.conf = conf
broker.conf.Version = V2_2_0_0

err := broker.Open(conf)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = broker.Close() })

connected, err := broker.Connected()
if err != nil || !connected {
t.Fatal(err)
}

actualSaslAuthRequests := countSaslAuthRequests()
if actualSaslAuthRequests != 1 {
t.Fatalf("unexpected number of SaslAuthRequests during initial authentication: %d", actualSaslAuthRequests)
}

timeout := time.After(time.Duration(sessionLifetimeMs) * time.Millisecond)

loop:
for actualSaslAuthRequests < 2 {
select {
case <-timeout:
break loop
default:
time.Sleep(10 * time.Millisecond)
// put some traffic on the wire
_, err = broker.ApiVersions(&ApiVersionsRequest{})
if err != nil {
t.Fatal(err)
}
actualSaslAuthRequests = countSaslAuthRequests()
}
}

if actualSaslAuthRequests < 2 {
t.Fatalf("sasl reauth has not occurred within expected timeframe")
}

mockBroker.Close()
}

func TestKip368ReAuthenticationFailure(t *testing.T) {
sessionLifetimeMs := int64(100)

mockBroker := NewMockBroker(t, 0)

mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).
SetAuthBytes([]byte(`response_payload`)).
SetSessionLifetimeMs(sessionLifetimeMs)

mockSASLAuthErrorResponse := NewMockSaslAuthenticateResponse(t).
SetError(ErrSASLAuthenticationFailed)

mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).
SetEnabledMechanisms([]string{SASLTypePlaintext})

mockApiVersions := NewMockApiVersionsResponse(t)

mockBroker.SetHandlerByMap(map[string]MockResponse{
"SaslAuthenticateRequest": mockSASLAuthResponse,
"SaslHandshakeRequest": mockSASLHandshakeResponse,
"ApiVersionsRequest": mockApiVersions,
})

broker := NewBroker(mockBroker.Addr())

conf := NewTestConfig()
conf.Net.SASL.Enable = true
conf.Net.SASL.Mechanism = SASLTypePlaintext
conf.Net.SASL.Version = SASLHandshakeV1
conf.Net.SASL.AuthIdentity = "authid"
conf.Net.SASL.User = "token"
conf.Net.SASL.Password = "password"

broker.conf = conf
broker.conf.Version = V2_2_0_0

err := broker.Open(conf)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = broker.Close() })

connected, err := broker.Connected()
if err != nil || !connected {
t.Fatal(err)
}

mockBroker.SetHandlerByMap(map[string]MockResponse{
"SaslAuthenticateRequest": mockSASLAuthErrorResponse,
"SaslHandshakeRequest": mockSASLHandshakeResponse,
"ApiVersionsRequest": mockApiVersions,
})

timeout := time.After(time.Duration(sessionLifetimeMs) * time.Millisecond)

var apiVersionError error
loop:
for apiVersionError == nil {
select {
case <-timeout:
break loop
default:
time.Sleep(10 * time.Millisecond)
// put some traffic on the wire
_, apiVersionError = broker.ApiVersions(&ApiVersionsRequest{})
}
}

if !errors.Is(apiVersionError, ErrSASLAuthenticationFailed) {
t.Fatalf("sasl reauth has not failed in the expected way %v", apiVersionError)
}

mockBroker.Close()
}

// We're not testing encoding/decoding here, so most of the requests/responses will be empty for simplicity's sake
var brokerTestTable = []struct {
version KafkaVersion
Expand Down
15 changes: 12 additions & 3 deletions mockresponses.go
Original file line number Diff line number Diff line change
Expand Up @@ -1057,19 +1057,23 @@ func (mr *MockListAclsResponse) For(reqBody versionedDecoder) encoderWithHeader
}

type MockSaslAuthenticateResponse struct {
t TestReporter
kerror KError
saslAuthBytes []byte
t TestReporter
kerror KError
saslAuthBytes []byte
sessionLifetimeMs int64
}

func NewMockSaslAuthenticateResponse(t TestReporter) *MockSaslAuthenticateResponse {
return &MockSaslAuthenticateResponse{t: t}
}

func (msar *MockSaslAuthenticateResponse) For(reqBody versionedDecoder) encoderWithHeader {
req := reqBody.(*SaslAuthenticateRequest)
res := &SaslAuthenticateResponse{}
res.Version = req.Version
res.Err = msar.kerror
res.SaslAuthBytes = msar.saslAuthBytes
res.SessionLifetimeMs = msar.sessionLifetimeMs
return res
}

Expand All @@ -1083,6 +1087,11 @@ func (msar *MockSaslAuthenticateResponse) SetAuthBytes(saslAuthBytes []byte) *Mo
return msar
}

func (msar *MockSaslAuthenticateResponse) SetSessionLifetimeMs(sessionLifetimeMs int64) *MockSaslAuthenticateResponse {
msar.sessionLifetimeMs = sessionLifetimeMs
return msar
}

type MockDeleteAclsResponse struct {
t TestReporter
}
Expand Down

0 comments on commit 455bae0

Please sign in to comment.