From b30b3a2d894541555e33043266e45b6c6c97777b Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Tue, 4 Jun 2019 20:34:52 +0300 Subject: [PATCH] Add concurrent test for `dmsg.Server`'s `Serve` (WIP) --- pkg/dmsg/server_test.go | 168 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 162 insertions(+), 6 deletions(-) diff --git a/pkg/dmsg/server_test.go b/pkg/dmsg/server_test.go index f7d1209ae3..139b10f87a 100644 --- a/pkg/dmsg/server_test.go +++ b/pkg/dmsg/server_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "math/rand" "net" "sync" "testing" @@ -39,7 +40,7 @@ func TestNewServer(t *testing.T) { assert.NoError(t, s.Close()) } -func TestServer_ListenAndServe(t *testing.T) { +func TestServer_Serve(t *testing.T) { sPK, sSK := cipher.GenerateKeyPair() dc := client.NewMock() @@ -68,22 +69,23 @@ func TestServer_ListenAndServe(t *testing.T) { err = b.InitiateServerConnections(context.Background(), 1) require.NoError(t, err) - aDone := make(chan struct{}) + aDone := make(chan error) var aTransport transport.Transport go func() { // avoid ambiguity between this and the outer one var err error aTransport, err = a.Accept(context.Background()) - // check if really fails test - require.NoError(t, err) + + aDone <- err close(aDone) }() bTransport, err := b.Dial(context.Background(), aPK) require.NoError(t, err) - <-aDone + err = <-aDone + require.NoError(t, err) // must be 2 ServerConn's require.Equal(t, 2, len(s.conns)) @@ -132,12 +134,166 @@ func TestServer_ListenAndServe(t *testing.T) { err = b.Close() require.NoError(t, err) - time.Sleep(3 * time.Second) + time.Sleep(5 * time.Second) require.Equal(t, 0, len(s.conns)) require.Equal(t, 0, len(a.conns)) require.Equal(t, 0, len(b.conns)) }) + + t.Run("test transport establishment concurrently", func(t *testing.T) { + initiatorsCount := 4 + remotesCount := 4 + + initiators := make([]*Client, 0, initiatorsCount) + remotes := make([]*Client, 0, remotesCount) + + for i := 0; i < initiatorsCount; i++ { + pk, sk := cipher.GenerateKeyPair() + + c := NewClient(pk, sk, dc) + c.SetLogger(logging.MustGetLogger(fmt.Sprintf("Initiator %d", i))) + err := c.InitiateServerConnections(context.Background(), 1) + require.NoError(t, err) + + initiators = append(initiators, c) + } + + for i := 0; i < remotesCount; i++ { + pk, sk := cipher.GenerateKeyPair() + + c := NewClient(pk, sk, dc) + c.SetLogger(logging.MustGetLogger(fmt.Sprintf("Remote %d", i))) + err := c.InitiateServerConnections(context.Background(), 1) + require.NoError(t, err) + + remotes = append(remotes, c) + } + + rand := rand.New(rand.NewSource(time.Now().UnixNano())) + + usedRemotes := make(map[int]struct{}) + pickedRemotes := make([]int, 0, initiatorsCount) + for range initiators { + remote := rand.Intn(remotesCount) + usedRemotes[remote] = struct{}{} + pickedRemotes = append(pickedRemotes, remote) + } + + acceptErrs := make(chan error, remotesCount) + var remotesWG sync.WaitGroup + remotesTps := make(map[int]transport.Transport, len(usedRemotes)) + remotesWG.Add(len(usedRemotes)) + for i, r := range remotes { + if _, ok := usedRemotes[i]; ok { + go func() { + var ( + transport transport.Transport + err error + ) + + transport, err = r.Accept(context.Background()) + if err != nil { + acceptErrs <- err + } + + remotesTps[i] = transport + + remotesWG.Done() + }() + } + } + + dialErrs := make(chan error, initiatorsCount) + var initiatorsWG sync.WaitGroup + initiatorsTps := make([]transport.Transport, initiatorsCount) + initiatorsWG.Add(initiatorsCount) + for i := range initiators { + go func() { + var ( + transport transport.Transport + err error + ) + + transport, err = initiators[i].Dial(context.Background(), remotes[pickedRemotes[i]].pk) + if err != nil { + dialErrs <- err + } + + initiatorsTps = append(initiatorsTps, transport) + + initiatorsWG.Done() + }() + } + + initiatorsWG.Wait() + close(dialErrs) + err = <-dialErrs + require.NoError(t, err) + + remotesWG.Done() + close(acceptErrs) + err = <-acceptErrs + require.NoError(t, err) + + require.Equal(t, len(usedRemotes)+initiatorsCount, len(s.conns)) + + /*err = <-aDone + require.NoError(t, err) + + // must be 2 ServerConn's + require.Equal(t, 2, len(s.conns)) + + // must have ServerConn for A + aServerConn, ok := s.conns[aPK] + require.Equal(t, true, ok) + require.Equal(t, aPK, aServerConn.remoteClient) + + // must have ServerConn for B + bServerConn, ok := s.conns[bPK] + require.Equal(t, true, ok) + require.Equal(t, bPK, bServerConn.remoteClient) + + // must have a ClientConn + aClientConn, ok := a.conns[sPK] + require.Equal(t, true, ok) + require.Equal(t, sPK, aClientConn.remoteSrv) + + // must have a ClientConn + bClientConn, ok := b.conns[sPK] + require.Equal(t, true, ok) + require.Equal(t, sPK, bClientConn.remoteSrv) + + // check whether nextConn's contents are as must be + bNextConn := bServerConn.nextConns[bClientConn.nextInitID-2] + require.NotNil(t, bNextConn) + require.Equal(t, bNextConn.id, aServerConn.nextRespID-2) + + // check whether nextConn's contents are as must be + aNextConn := aServerConn.nextConns[aServerConn.nextRespID-2] + require.NotNil(t, aNextConn) + require.Equal(t, aNextConn.id, bClientConn.nextInitID-2) + + log.Printf("%v\n", s.conns) + + err = aTransport.Close() + require.NoError(t, err) + + err = bTransport.Close() + require.NoError(t, err) + + err = a.Close() + require.NoError(t, err) + + err = b.Close() + require.NoError(t, err) + + time.Sleep(5 * time.Second) + + require.Equal(t, 0, len(s.conns)) + require.Equal(t, 0, len(a.conns)) + require.Equal(t, 0, len(b.conns))*/ + }) } // Given two client instances (a & b) and a server instance (s),