diff --git a/mocks/sync_producer.go b/mocks/sync_producer.go index fa86b245c..7541841f6 100644 --- a/mocks/sync_producer.go +++ b/mocks/sync_producer.go @@ -57,6 +57,31 @@ func (sp *SyncProducer) SendMessage(msg *sarama.ProducerMessage) (partition int3 } } +// SendMessages corresponds with the SendMessages method of sarama's SyncProducer implementation. +// You have to set expectations on the mock producer before calling SendMessages, so it knows +// how to handle them. If there is no more remaining expectations when SendMessages is called, +// the mock producer will write an error to the test state object. +func (sp *SyncProducer) SendMessages(msgs []*sarama.ProducerMessage) error { + sp.l.Lock() + defer sp.l.Unlock() + + if len(sp.expectations) >= len(msgs) { + expectations := sp.expectations[0 : len(msgs)-1] + sp.expectations = sp.expectations[len(msgs):] + + for _, expectation := range expectations { + if expectation.Result != errProduceSuccess { + return expectation.Result + } + + } + return nil + } else { + sp.t.Errorf("Insufficient expectations set on this mock producer to handle the input messages.") + return errOutOfExpectations + } +} + // Close corresponds with the Close method of sarama's SyncProducer implementation. // By closing a mock syncproducer, you also tell it that no more SendMessage calls will follow, // so it will write an error to the test state if there's any remaining expectations. diff --git a/sync_producer.go b/sync_producer.go index 2e6f87b8e..b181527f0 100644 --- a/sync_producer.go +++ b/sync_producer.go @@ -16,6 +16,12 @@ type SyncProducer interface { // of the produced message, or an error if the message failed to produce. SendMessage(msg *ProducerMessage) (partition int32, offset int64, err error) + // SendMessages produces a given set of messages, and returns only when all + // messages in the set have either succeeded or failed. Note that messages + // can succeed and fail individually; if some succeed and some fail, + // SendMessages will return an error. + SendMessages(msgs []*ProducerMessage) error + // Close shuts down the producer and flushes any messages it may have buffered. // You must call this function before a producer object passes out of scope, as // it may otherwise leak memory. You must call this before calling Close on the @@ -65,21 +71,56 @@ func (sp *syncProducer) SendMessage(msg *ProducerMessage) (partition int32, offs msg.Metadata = oldMetadata }() - expectation := make(chan error, 1) + expectation := make(chan *ProducerError, 1) msg.Metadata = expectation sp.producer.Input() <- msg if err := <-expectation; err != nil { - return -1, -1, err + return -1, -1, err.Err } return msg.Partition, msg.Offset, nil } +func (sp *syncProducer) SendMessages(msgs []*ProducerMessage) error { + savedMetadata := make([]interface{}, len(msgs)) + for i := range msgs { + savedMetadata[i] = msgs[i].Metadata + } + defer func() { + for i := range msgs { + msgs[i].Metadata = savedMetadata[i] + } + }() + + expectations := make(chan chan *ProducerError, len(msgs)) + go func() { + for _, msg := range msgs { + expectation := make(chan *ProducerError, 1) + msg.Metadata = expectation + sp.producer.Input() <- msg + expectations <- expectation + } + close(expectations) + }() + + var errors ProducerErrors + for expectation := range expectations { + if err := <-expectation; err != nil { + errors = append(errors, err) + } + } + + if len(errors) > 0 { + return errors + } + return nil +} + func (sp *syncProducer) handleSuccesses() { defer sp.wg.Done() for msg := range sp.producer.Successes() { - expectation := msg.Metadata.(chan error) + expectation := msg.Metadata.(chan *ProducerError) expectation <- nil } } @@ -87,8 +128,8 @@ func (sp *syncProducer) handleSuccesses() { func (sp *syncProducer) handleErrors() { defer sp.wg.Done() for err := range sp.producer.Errors() { - expectation := err.Msg.Metadata.(chan error) - expectation <- err.Err + expectation := err.Msg.Metadata.(chan *ProducerError) + expectation <- err } } diff --git a/sync_producer_test.go b/sync_producer_test.go index 765877466..12ed20e1f 100644 --- a/sync_producer_test.go +++ b/sync_producer_test.go @@ -54,6 +54,53 @@ func TestSyncProducer(t *testing.T) { seedBroker.Close() } +func TestSyncProducerBatch(t *testing.T) { + seedBroker := NewMockBroker(t, 1) + leader := NewMockBroker(t, 2) + + metadataResponse := new(MetadataResponse) + metadataResponse.AddBroker(leader.Addr(), leader.BrokerID()) + metadataResponse.AddTopicPartition("my_topic", 0, leader.BrokerID(), nil, nil, ErrNoError) + seedBroker.Returns(metadataResponse) + + prodSuccess := new(ProduceResponse) + prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError) + leader.Returns(prodSuccess) + + config := NewConfig() + config.Producer.Flush.Messages = 3 + producer, err := NewSyncProducer([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + err = producer.SendMessages([]*ProducerMessage{ + &ProducerMessage{ + Topic: "my_topic", + Value: StringEncoder(TestMessage), + Metadata: "test", + }, + &ProducerMessage{ + Topic: "my_topic", + Value: StringEncoder(TestMessage), + Metadata: "test", + }, + &ProducerMessage{ + Topic: "my_topic", + Value: StringEncoder(TestMessage), + Metadata: "test", + }, + }) + + if err != nil { + t.Error(err) + } + + safeClose(t, producer) + leader.Close() + seedBroker.Close() +} + func TestConcurrentSyncProducer(t *testing.T) { seedBroker := NewMockBroker(t, 1) leader := NewMockBroker(t, 2)