diff --git a/pkg/dmsg/client_test.go b/pkg/dmsg/client_test.go index c7825bc84c..fca12a7f03 100644 --- a/pkg/dmsg/client_test.go +++ b/pkg/dmsg/client_test.go @@ -11,43 +11,230 @@ import ( "github.com/skycoin/skywire/pkg/cipher" ) +type transportWithError struct { + tr *Transport + err error +} + func TestClient(t *testing.T) { const acceptChSize = 128 logger := logging.MustGetLogger("dms_client") - p1, p2 := net.Pipe() - p1, p2 = invertedIDConn{p1}, invertedIDConn{p2} + t.Run("Two clients", func(t *testing.T) { + p1, p2 := net.Pipe() + p1, p2 = invertedIDConn{p1}, invertedIDConn{p2} + + var pk1, pk2 cipher.PubKey + err := pk1.Set("024ec47420176680816e0406250e7156465e4531f5b26057c9f6297bb0303558c7") + assert.NoError(t, err) + err = pk2.Set("031b80cd5773143a39d940dc0710b93dcccc262a85108018a7a95ab9af734f8055") + assert.NoError(t, err) + + conn1 := NewClientConn(logger, p1, pk1, pk2) + conn2 := NewClientConn(logger, p2, pk2, pk1) + + conn2.nextInitID = randID(false) + + ch1 := make(chan *Transport, acceptChSize) + ch2 := make(chan *Transport, acceptChSize) + + ctx := context.TODO() + + go func() { + _ = conn1.Serve(ctx, ch1) // nolint:errcheck + }() + + go func() { + _ = conn2.Serve(ctx, ch2) // nolint:errcheck + }() + + initID := conn1.nextInitID + assert.Nil(t, conn1.tps[initID]) + + tr1, err := conn1.DialTransport(ctx, pk2) + assert.NoError(t, err) + assert.NotNil(t, conn1.tps[initID]) + assert.Equal(t, initID+2, conn1.nextInitID) + + err = tr1.Close() + assert.NoError(t, err) + + err = conn1.Close() + assert.NoError(t, err) + + _, ok := <-conn1.done + assert.False(t, ok) + + err = conn2.Close() + assert.NoError(t, err) + + _, ok = <-conn1.done + assert.False(t, ok) + + _, ok = <-tr1.doneCh + assert.False(t, ok) + + _, ok = <-tr1.readCh + assert.False(t, ok) + }) + + t.Run("Three clients", func(t *testing.T) { + p1, p2 := net.Pipe() + p1, p2 = invertedIDConn{p1}, invertedIDConn{p2} + + p3, p4 := net.Pipe() + p3, p4 = invertedIDConn{p3}, invertedIDConn{p4} + + var pk1, pk2, pk3 cipher.PubKey + err := pk1.Set("024ec47420176680816e0406250e7156465e4531f5b26057c9f6297bb0303558c7") + assert.NoError(t, err) + err = pk2.Set("031b80cd5773143a39d940dc0710b93dcccc262a85108018a7a95ab9af734f8055") + assert.NoError(t, err) + err = pk3.Set("035b57eef30b9a6be1effc2c3337a3a1ffedcd04ffbac6667cd822892cf56be24a") + assert.NoError(t, err) + + conn1 := NewClientConn(logger, p1, pk1, pk2) + conn2 := NewClientConn(logger, p2, pk2, pk1) + conn3 := NewClientConn(logger, p3, pk2, pk3) + conn4 := NewClientConn(logger, p4, pk3, pk2) + + conn2.nextInitID = randID(false) + conn4.nextInitID = randID(false) + + ch1 := make(chan *Transport, acceptChSize) + ch2 := make(chan *Transport, acceptChSize) + ch3 := make(chan *Transport, acceptChSize) + ch4 := make(chan *Transport, acceptChSize) + + ctx := context.TODO() + + go func() { + _ = conn1.Serve(ctx, ch1) // nolint:errcheck + }() + + go func() { + _ = conn2.Serve(ctx, ch2) // nolint:errcheck + }() + + go func() { + _ = conn3.Serve(ctx, ch3) // nolint:errcheck + }() + + go func() { + _ = conn4.Serve(ctx, ch4) // nolint:errcheck + }() + + initID1 := conn1.nextInitID + assert.Nil(t, conn1.tps[initID1]) + + initID2 := conn2.nextInitID + assert.Nil(t, conn2.tps[initID2]) + + initID3 := conn3.nextInitID + assert.Nil(t, conn3.tps[initID3]) + + initID4 := conn4.nextInitID + assert.Nil(t, conn4.tps[initID4]) + + trCh1 := make(chan transportWithError) + trCh2 := make(chan transportWithError) + + go func() { + tr, err := conn1.DialTransport(ctx, pk2) + trCh1 <- transportWithError{ + tr: tr, + err: err, + } + }() + + go func() { + tr, err := conn3.DialTransport(ctx, pk3) + trCh2 <- transportWithError{ + tr: tr, + err: err, + } + }() + + twe1 := <-trCh1 + twe2 := <-trCh2 + + tr1, err := twe1.tr, twe1.err + + assert.NoError(t, err) + assert.NotNil(t, conn1.tps[initID1]) + assert.Equal(t, initID1+2, conn1.nextInitID) + + tr2, err := twe2.tr, twe2.err + assert.NoError(t, err) + assert.NotNil(t, conn3.tps[initID3]) + assert.Equal(t, initID3+2, conn3.nextInitID) + + errCh1 := make(chan error) + errCh2 := make(chan error) + errCh3 := make(chan error) + errCh4 := make(chan error) + + go func() { + errCh1 <- tr1.Close() + }() + + go func() { + errCh2 <- tr2.Close() + }() + + err = <-errCh1 + assert.NoError(t, err) + + err = <-errCh2 + assert.NoError(t, err) + + go func() { + errCh1 <- conn1.Close() + }() + + go func() { + errCh2 <- conn2.Close() + }() + + go func() { + errCh3 <- conn3.Close() + }() + + go func() { + errCh4 <- conn4.Close() + }() + + err = <-errCh1 + assert.NoError(t, err) - var pk1, pk2 cipher.PubKey - err := pk1.Set("024ec47420176680816e0406250e7156465e4531f5b26057c9f6297bb0303558c7") - assert.NoError(t, err) - err = pk2.Set("031b80cd5773143a39d940dc0710b93dcccc262a85108018a7a95ab9af734f8055") - assert.NoError(t, err) + err = <-errCh2 + assert.NoError(t, err) - conn1 := NewClientConn(logger, p1, pk1, pk2) - conn2 := NewClientConn(logger, p2, pk2, pk1) + err = <-errCh3 + assert.NoError(t, err) - conn2.nextInitID = randID(false) + err = <-errCh4 + assert.NoError(t, err) - ch1 := make(chan *Transport, acceptChSize) - ch2 := make(chan *Transport, acceptChSize) + _, ok := <-conn1.done + assert.False(t, ok) - ctx := context.TODO() + _, ok = <-conn3.done + assert.False(t, ok) - go func() { - _ = conn1.Serve(ctx, ch1) - }() + _, ok = <-tr1.doneCh + assert.False(t, ok) - go func() { - _ = conn2.Serve(ctx, ch2) - }() + _, ok = <-tr1.readCh + assert.False(t, ok) - tr, err := conn1.DialTransport(ctx, pk2) - assert.NoError(t, err) + _, ok = <-tr2.doneCh + assert.False(t, ok) - err = tr.Close() - assert.NoError(t, err) + _, ok = <-tr2.readCh + assert.False(t, ok) + }) } type invertedIDConn struct {