diff --git a/pkg/dmsg/client_test.go b/pkg/dmsg/client_test.go index fca12a7f03..a4de5ae36f 100644 --- a/pkg/dmsg/client_test.go +++ b/pkg/dmsg/client_test.go @@ -4,6 +4,7 @@ import ( "context" "net" "testing" + "time" "github.com/skycoin/skycoin/src/util/logging" "github.com/stretchr/testify/assert" @@ -11,14 +12,17 @@ import ( "github.com/skycoin/skywire/pkg/cipher" ) +const ( + acceptChSize = 128 + chanReadThreshold = time.Second * 5 +) + type transportWithError struct { tr *Transport err error } func TestClient(t *testing.T) { - const acceptChSize = 128 - logger := logging.MustGetLogger("dms_client") t.Run("Two clients", func(t *testing.T) { @@ -63,20 +67,13 @@ func TestClient(t *testing.T) { err = conn1.Close() assert.NoError(t, err) - _, ok := <-conn1.done - assert.False(t, ok) - err = conn2.Close() assert.NoError(t, err) - _, ok = <-conn1.done - assert.False(t, ok) - - _, ok = <-tr1.doneCh - assert.False(t, ok) - - _, ok = <-tr1.readCh - assert.False(t, ok) + assert.False(t, isDoneChannelOpen(conn1.done)) + assert.False(t, isDoneChannelOpen(conn2.done)) + assert.False(t, isDoneChannelOpen(tr1.doneCh)) + assert.False(t, isReadChannelOpen(tr1.readCh)) }) t.Run("Three clients", func(t *testing.T) { @@ -217,24 +214,31 @@ func TestClient(t *testing.T) { err = <-errCh4 assert.NoError(t, err) - _, ok := <-conn1.done - assert.False(t, ok) - - _, ok = <-conn3.done - assert.False(t, ok) - - _, ok = <-tr1.doneCh - assert.False(t, ok) - - _, ok = <-tr1.readCh - assert.False(t, ok) + assert.False(t, isDoneChannelOpen(conn1.done)) + assert.False(t, isDoneChannelOpen(conn3.done)) + assert.False(t, isDoneChannelOpen(tr1.doneCh)) + assert.False(t, isReadChannelOpen(tr1.readCh)) + assert.False(t, isDoneChannelOpen(tr2.doneCh)) + assert.False(t, isReadChannelOpen(tr2.readCh)) + }) +} - _, ok = <-tr2.doneCh - assert.False(t, ok) +func isDoneChannelOpen(ch chan struct{}) bool { + select { + case _, ok := <-ch: + return ok + case <-time.After(chanReadThreshold): + return false + } +} - _, ok = <-tr2.readCh - assert.False(t, ok) - }) +func isReadChannelOpen(ch chan Frame) bool { + select { + case _, ok := <-ch: + return ok + case <-time.After(chanReadThreshold): + return false + } } type invertedIDConn struct {