Skip to content

Commit

Permalink
fix AckWaiter hang issue
Browse files Browse the repository at this point in the history
  • Loading branch information
林志宇 committed Jun 19, 2019
1 parent 862da5a commit bb2d5e1
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 17 deletions.
4 changes: 2 additions & 2 deletions internal/ioutil/ack_waiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (w *Uint16AckWaiter) StopAll() {

// Wait performs the given action, and waits for given seq to be Done.
func (w *Uint16AckWaiter) Wait(ctx context.Context, action func(seq Uint16Seq) error) (err error) {
ackCh := make(chan struct{})
ackCh := make(chan struct{}, 1)

w.mx.Lock()
seq := w.nextSeq
Expand All @@ -78,7 +78,7 @@ func (w *Uint16AckWaiter) Wait(ctx context.Context, action func(seq Uint16Seq) e
case _, ok := <-ackCh:
if !ok {
// waiter stopped manually.
return io.ErrClosedPipe
err = io.ErrClosedPipe
}
case <-ctx.Done():
err = ctx.Err()
Expand Down
49 changes: 35 additions & 14 deletions internal/ioutil/ack_waiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,44 @@ package ioutil

import (
"context"
"sync"
"testing"
)

// Ensure that no race conditions occurs.
func TestUint16AckWaiter_Wait(t *testing.T) {
w := new(Uint16AckWaiter)

seqChan := make(chan Uint16Seq)
defer close(seqChan)
for i := 0; i < 64; i++ {
go w.Wait(context.TODO(), func(seq Uint16Seq) error { //nolint:errcheck,unparam
seqChan <- seq
return nil
})
seq := <-seqChan
for j := 0; j < i; j++ {
go w.Done(seq)

// Ensure that no race conditions occurs when
// each concurrent call to 'Uint16AckWaiter.Wait()' is met with
// multiple concurrent calls to 'Uint16AckWaiter.Done()' with the same seq.
t.Run("ensure_no_race_conditions", func(*testing.T) {
w := new(Uint16AckWaiter)
defer w.StopAll()

seqChan := make(chan Uint16Seq)
defer close(seqChan)

wg := new(sync.WaitGroup)

for i := 0; i < 64; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_ = w.Wait(context.TODO(), func(seq Uint16Seq) error { //nolint:errcheck,unparam
seqChan <- seq
return nil
})
}()

seq := <-seqChan
for j := 0; j <= i; j++ {
wg.Add(1)
go func() {
defer wg.Done()
w.Done(seq)
}()
}
}
}

wg.Wait()
})
}
1 change: 0 additions & 1 deletion pkg/dmsg/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ var (
ErrClientClosed = errors.New("client closed")
// ErrClientAcceptMaxed indicates that the client cannot take in more accepts.
ErrClientAcceptMaxed = errors.New("client accepts buffer maxed")

)

// ClientConn represents a connection between a dmsg.Client and dmsg.Server from a client's perspective.
Expand Down

0 comments on commit bb2d5e1

Please sign in to comment.