Skip to content

Commit

Permalink
Merge pull request #1956 from joewreschnig/messagechecker
Browse files Browse the repository at this point in the history
Allow checking the entire `ProducerMessage` in the mock producers
  • Loading branch information
dnwe authored Jun 9, 2021
2 parents ad66013 + ab79623 commit 03b4a4f
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 54 deletions.
84 changes: 57 additions & 27 deletions mocks/async_producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (
// Before you can send messages to it's Input channel, you have to set expectations
// so it knows how to handle the input; it returns an error if the number of messages
// received is bigger then the number of expectations set. You can also set a
// function in each expectation so that the message value is checked by this function
// and an error is returned if the match fails.
// function in each expectation so that the message is checked by this function and
// an error is returned if the match fails.
type AsyncProducer struct {
l sync.Mutex
t ErrorReporter
Expand All @@ -21,12 +21,13 @@ type AsyncProducer struct {
successes chan *sarama.ProducerMessage
errors chan *sarama.ProducerError
lastOffset int64
*TopicConfig
}

// NewAsyncProducer instantiates a new Producer mock. The t argument should
// be the *testing.T instance of your test method. An error will be written to it if
// an expectation is violated. The config argument is used to determine whether it
// should ack successes on the Successes channel.
// should ack successes on the Successes channel and to handle partitioning.
func NewAsyncProducer(t ErrorReporter, config *sarama.Config) *AsyncProducer {
if config == nil {
config = sarama.NewConfig()
Expand All @@ -38,6 +39,7 @@ func NewAsyncProducer(t ErrorReporter, config *sarama.Config) *AsyncProducer {
input: make(chan *sarama.ProducerMessage, config.ChannelBufferSize),
successes: make(chan *sarama.ProducerMessage, config.ChannelBufferSize),
errors: make(chan *sarama.ProducerError, config.ChannelBufferSize),
TopicConfig: NewTopicConfig(),
}

go func() {
Expand All @@ -47,35 +49,45 @@ func NewAsyncProducer(t ErrorReporter, config *sarama.Config) *AsyncProducer {
close(mp.closed)
}()

partitioners := make(map[string]sarama.Partitioner, 1)

for msg := range mp.input {
partitioner := partitioners[msg.Topic]
if partitioner == nil {
partitioner = config.Producer.Partitioner(msg.Topic)
partitioners[msg.Topic] = partitioner
}
mp.l.Lock()
if mp.expectations == nil || len(mp.expectations) == 0 {
mp.expectations = nil
mp.t.Errorf("No more expectation set on this mock producer to handle the input message.")
} else {
expectation := mp.expectations[0]
mp.expectations = mp.expectations[1:]
if expectation.CheckFunction != nil {
if val, err := msg.Value.Encode(); err != nil {
mp.t.Errorf("Input message encoding failed: %s", err.Error())
mp.errors <- &sarama.ProducerError{Err: err, Msg: msg}
} else {
err = expectation.CheckFunction(val)

partition, err := partitioner.Partition(msg, mp.partitions(msg.Topic))
if err != nil {
mp.t.Errorf("Partitioner returned an error: %s", err.Error())
mp.errors <- &sarama.ProducerError{Err: err, Msg: msg}
} else {
msg.Partition = partition
if expectation.CheckFunction != nil {
err := expectation.CheckFunction(msg)
if err != nil {
mp.t.Errorf("Check function returned an error: %s", err.Error())
mp.errors <- &sarama.ProducerError{Err: err, Msg: msg}
}
}
}
if expectation.Result == errProduceSuccess {
mp.lastOffset++
if config.Producer.Return.Successes {
msg.Offset = mp.lastOffset
mp.successes <- msg
}
} else {
if config.Producer.Return.Errors {
mp.errors <- &sarama.ProducerError{Err: expectation.Result, Msg: msg}
if expectation.Result == errProduceSuccess {
mp.lastOffset++
if config.Producer.Return.Successes {
msg.Offset = mp.lastOffset
mp.successes <- msg
}
} else {
if config.Producer.Return.Errors {
mp.errors <- &sarama.ProducerError{Err: expectation.Result, Msg: msg}
}
}
}
}
Expand Down Expand Up @@ -135,15 +147,35 @@ func (mp *AsyncProducer) Errors() <-chan *sarama.ProducerError {
// Setting expectations
////////////////////////////////////////////////

// ExpectInputWithMessageCheckerFunctionAndSucceed sets an expectation on the mock producer that a
// message will be provided on the input channel. The mock producer will call the given function to
// check the message. If an error is returned it will be made available on the Errors channel
// otherwise the mock will handle the message as if it produced successfully, i.e. it will make it
// available on the Successes channel if the Producer.Return.Successes setting is set to true.
func (mp *AsyncProducer) ExpectInputWithMessageCheckerFunctionAndSucceed(cf MessageChecker) {
mp.l.Lock()
defer mp.l.Unlock()
mp.expectations = append(mp.expectations, &producerExpectation{Result: errProduceSuccess, CheckFunction: cf})
}

// ExpectInputWithMessageCheckerFunctionAndFail sets an expectation on the mock producer that a
// message will be provided on the input channel. The mock producer will first call the given
// function to check the message. If an error is returned it will be made available on the Errors
// channel otherwise the mock will handle the message as if it failed to produce successfully. This
// means it will make a ProducerError available on the Errors channel.
func (mp *AsyncProducer) ExpectInputWithMessageCheckerFunctionAndFail(cf MessageChecker, err error) {
mp.l.Lock()
defer mp.l.Unlock()
mp.expectations = append(mp.expectations, &producerExpectation{Result: err, CheckFunction: cf})
}

// ExpectInputWithCheckerFunctionAndSucceed sets an expectation on the mock producer that a message
// will be provided on the input channel. The mock producer will call the given function to check
// the message value. If an error is returned it will be made available on the Errors channel
// otherwise the mock will handle the message as if it produced successfully, i.e. it will make
// it available on the Successes channel if the Producer.Return.Successes setting is set to true.
func (mp *AsyncProducer) ExpectInputWithCheckerFunctionAndSucceed(cf ValueChecker) {
mp.l.Lock()
defer mp.l.Unlock()
mp.expectations = append(mp.expectations, &producerExpectation{Result: errProduceSuccess, CheckFunction: cf})
mp.ExpectInputWithMessageCheckerFunctionAndSucceed(messageValueChecker(cf))
}

// ExpectInputWithCheckerFunctionAndFail sets an expectation on the mock producer that a message
Expand All @@ -152,22 +184,20 @@ func (mp *AsyncProducer) ExpectInputWithCheckerFunctionAndSucceed(cf ValueChecke
// otherwise the mock will handle the message as if it failed to produce successfully. This means
// it will make a ProducerError available on the Errors channel.
func (mp *AsyncProducer) ExpectInputWithCheckerFunctionAndFail(cf ValueChecker, err error) {
mp.l.Lock()
defer mp.l.Unlock()
mp.expectations = append(mp.expectations, &producerExpectation{Result: err, CheckFunction: cf})
mp.ExpectInputWithMessageCheckerFunctionAndFail(messageValueChecker(cf), err)
}

// ExpectInputAndSucceed sets an expectation on the mock producer that a message will be provided
// on the input channel. The mock producer will handle the message as if it is produced successfully,
// i.e. it will make it available on the Successes channel if the Producer.Return.Successes setting
// is set to true.
func (mp *AsyncProducer) ExpectInputAndSucceed() {
mp.ExpectInputWithCheckerFunctionAndSucceed(nil)
mp.ExpectInputWithMessageCheckerFunctionAndSucceed(nil)
}

// ExpectInputAndFail sets an expectation on the mock producer that a message will be provided
// on the input channel. The mock producer will handle the message as if it failed to produce
// successfully. This means it will make a ProducerError available on the Errors channel.
func (mp *AsyncProducer) ExpectInputAndFail(err error) {
mp.ExpectInputWithCheckerFunctionAndFail(nil, err)
mp.ExpectInputWithMessageCheckerFunctionAndFail(nil, err)
}
42 changes: 42 additions & 0 deletions mocks/async_producer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,45 @@ func TestProducerWithCheckerFunction(t *testing.T) {
t.Error("Expected to report a value check error, found: ", err1.Err)
}
}

func TestProducerWithBrokenPartitioner(t *testing.T) {
trm := newTestReporterMock()
config := sarama.NewConfig()
config.Producer.Partitioner = func(string) sarama.Partitioner {
return brokePartitioner{}
}
mp := NewAsyncProducer(trm, config)
mp.ExpectInputWithMessageCheckerFunctionAndSucceed(func(msg *sarama.ProducerMessage) error {
if msg.Partition != 15 {
t.Error("Expected partition 15, found: ", msg.Partition)
}
if msg.Topic != "test" {
t.Errorf(`Expected topic "test", found: %q`, msg.Topic)
}
return nil
})
mp.ExpectInputAndSucceed() // should actually fail in partitioning

mp.Input() <- &sarama.ProducerMessage{Topic: "test"}
mp.Input() <- &sarama.ProducerMessage{Topic: "not-test"}
if err := mp.Close(); err != nil {
t.Error(err)
}

if len(trm.errors) != 1 || !strings.Contains(trm.errors[0], "partitioning unavailable") {
t.Error("Expected to report partitioning unavailable, found", trm.errors)
}
}

// brokeProducer refuses to partition anything not on the “test” topic, and sends everything on
// that topic to partition 15.
type brokePartitioner struct{}

func (brokePartitioner) Partition(msg *sarama.ProducerMessage, n int32) (int32, error) {
if msg.Topic == "test" {
return 15, nil
}
return 0, errors.New("partitioning unavailable")
}

func (brokePartitioner) RequiresConsistency() bool { return false }
58 changes: 57 additions & 1 deletion mocks/mocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package mocks

import (
"errors"
"fmt"

"github.com/Shopify/sarama"
)
Expand All @@ -29,6 +30,26 @@ type ErrorReporter interface {
// to check the value passed.
type ValueChecker func(val []byte) error

// MessageChecker is a function type to be set in each expectation of the producer mocks
// to check the message passed.
type MessageChecker func(*sarama.ProducerMessage) error

// messageValueChecker wraps a ValueChecker into a MessageChecker.
// Failure to encode the message value will return an error and not call
// the wrapped ValueChecker.
func messageValueChecker(f ValueChecker) MessageChecker {
if f == nil {
return nil
}
return func(msg *sarama.ProducerMessage) error {
val, err := msg.Value.Encode()
if err != nil {
return fmt.Errorf("Input message encoding failed: %s", err.Error())
}
return f(val)
}
}

var (
errProduceSuccess error = nil
errOutOfExpectations = errors.New("No more expectations set on mock")
Expand All @@ -39,7 +60,42 @@ const AnyOffset int64 = -1000

type producerExpectation struct {
Result error
CheckFunction ValueChecker
CheckFunction MessageChecker
}

// TopicConfig describes a mock topic structure for the mock producers’ partitioning needs.
type TopicConfig struct {
overridePartitions map[string]int32
defaultPartitions int32
}

// NewTopicConfig makes a configuration which defaults to 32 partitions for every topic.
func NewTopicConfig() *TopicConfig {
return &TopicConfig{
overridePartitions: make(map[string]int32, 0),
defaultPartitions: 32,
}
}

// SetDefaultPartitions sets the number of partitions any topic not explicitly configured otherwise
// (by SetPartitions) will have from the perspective of created partitioners.
func (pc *TopicConfig) SetDefaultPartitions(n int32) {
pc.defaultPartitions = n
}

// SetPartitions sets the number of partitions the partitioners will see for specific topics. This
// only applies to messages produced after setting them.
func (pc *TopicConfig) SetPartitions(partitions map[string]int32) {
for p, n := range partitions {
pc.overridePartitions[p] = n
}
}

func (pc *TopicConfig) partitions(topic string) int32 {
if n, found := pc.overridePartitions[topic]; found {
return n
}
return pc.defaultPartitions
}

// NewTestConfig returns a config meant to be used by tests.
Expand Down
Loading

0 comments on commit 03b4a4f

Please sign in to comment.