diff --git a/writer_test.go b/writer_test.go index d9ba2b0e4..15cedc465 100644 --- a/writer_test.go +++ b/writer_test.go @@ -5,6 +5,7 @@ import ( "errors" "io" "math" + "sync" "testing" "time" ) @@ -664,16 +665,30 @@ func testWriterSmallBatchBytes(t *testing.T) { } type testRetryWriter struct { - errs int + ch chan writerMessage + join sync.WaitGroup } func (w *testRetryWriter) messages() chan<- writerMessage { - ch := make(chan writerMessage, 1) + return w.ch +} - go func() { - for { - msg := <-ch - if w.errs > 0 { +func (w *testRetryWriter) close() { + close(w.ch) + w.join.Wait() +} + +func (w *testRetryWriter) run(errs int) { + w.join.Add(1) + defer w.join.Done() + + var done bool + for !done { + msg, ok := <-w.ch + if !ok { + done = true + } else { + if errs > 0 { msg.res <- writerResponse{ id: msg.id, err: &WriterError{ @@ -681,7 +696,7 @@ func (w *testRetryWriter) messages() chan<- writerMessage { Err: errors.New("bad attempt"), }, } - w.errs -= 1 + errs -= 1 } else { msg.res <- writerResponse{ id: msg.id, @@ -689,13 +704,13 @@ func (w *testRetryWriter) messages() chan<- writerMessage { } } } - }() - - return ch + } } -func (w *testRetryWriter) close() { - +func newTestRetryWriter(_ int, _ WriterConfig, _ *writerStats) partitionWriter { + w := &testRetryWriter{ch: make(chan writerMessage, 1)} + go w.run(2) + return w } func testWriterRetries(t *testing.T) { @@ -737,12 +752,10 @@ func testWriterRetries(t *testing.T) { for i, tc := range tcs { w := newTestWriter(WriterConfig{ - Topic: topic, - MaxAttempts: 2, - Balancer: &RoundRobin{}, - newPartitionWriter: func(_ int, _ WriterConfig, _ *writerStats) partitionWriter { - return &testRetryWriter{errs: 2} - }, + Topic: topic, + MaxAttempts: 2, + Balancer: &RoundRobin{}, + newPartitionWriter: newTestRetryWriter, }) err := w.WriteMessages(context.Background(), tc.msgs()...)