Skip to content

Commit

Permalink
Fixes map concurrency issue on tests
Browse files Browse the repository at this point in the history
  • Loading branch information
raulnegreiros committed Jan 5, 2022
1 parent 49ab66f commit fb8dd6d
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions mockresponses.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sarama
import (
"fmt"
"strings"
"sync"
)

// TestReporter has methods matching go's testing.T to avoid importing
Expand Down Expand Up @@ -264,6 +265,7 @@ func (mor *MockOffsetResponse) getOffset(topic string, partition int32, time int
// MockFetchResponse is a `FetchResponse` builder.
type MockFetchResponse struct {
messages map[string]map[int32]map[int64]Encoder
messagesLock *sync.RWMutex
highWaterMarks map[string]map[int32]int64
t TestReporter
batchSize int
Expand All @@ -273,6 +275,7 @@ type MockFetchResponse struct {
func NewMockFetchResponse(t TestReporter, batchSize int) *MockFetchResponse {
return &MockFetchResponse{
messages: make(map[string]map[int32]map[int64]Encoder),
messagesLock: &sync.RWMutex{},
highWaterMarks: make(map[string]map[int32]int64),
t: t,
batchSize: batchSize,
Expand All @@ -285,6 +288,8 @@ func (mfr *MockFetchResponse) SetVersion(version int16) *MockFetchResponse {
}

func (mfr *MockFetchResponse) SetMessage(topic string, partition int32, offset int64, msg Encoder) *MockFetchResponse {
mfr.messagesLock.Lock()
defer mfr.messagesLock.Unlock()
partitions := mfr.messages[topic]
if partitions == nil {
partitions = make(map[int32]map[int64]Encoder)
Expand Down Expand Up @@ -339,6 +344,8 @@ func (mfr *MockFetchResponse) For(reqBody versionedDecoder) encoderWithHeader {
}

func (mfr *MockFetchResponse) getMessage(topic string, partition int32, offset int64) Encoder {
mfr.messagesLock.RLock()
defer mfr.messagesLock.RUnlock()
partitions := mfr.messages[topic]
if partitions == nil {
return nil
Expand All @@ -351,6 +358,8 @@ func (mfr *MockFetchResponse) getMessage(topic string, partition int32, offset i
}

func (mfr *MockFetchResponse) getMessageCount(topic string, partition int32) int {
mfr.messagesLock.RLock()
defer mfr.messagesLock.RUnlock()
partitions := mfr.messages[topic]
if partitions == nil {
return 0
Expand Down

0 comments on commit fb8dd6d

Please sign in to comment.