diff --git a/pkg/dmsg/server_test.go b/pkg/dmsg/server_test.go
index 6bea69a1b2..8b4431b53d 100644
--- a/pkg/dmsg/server_test.go
+++ b/pkg/dmsg/server_test.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"io"
 	"log"
 	"math"
 	"math/rand"
@@ -198,19 +199,11 @@ func TestServer_Serve(t *testing.T) {
 	err = b.InitiateServerConnections(context.Background(), 1)
 	require.NoError(t, err)
 
-	aDone := make(chan struct{})
-	var aTransport transport.Transport
-	var aErr error
-	go func() {
-		aTransport, aErr = a.Accept(context.Background())
-		close(aDone)
-	}()
-
 	bTransport, err := b.Dial(context.Background(), aPK)
 	require.NoError(t, err)
 
-	<-aDone
-	require.NoError(t, aErr)
+	aTransport, err := a.Accept(context.Background())
+	require.NoError(t, err)
 
 	// must be 2 ServerConn's
 	require.Equal(t, 2, s.connCount())
@@ -351,6 +344,7 @@ func TestServer_Serve(t *testing.T) {
 		// channel to listen for `Accept` errors. Any single error must
 		// fail the test
 		acceptErrs := make(chan error, totalRemoteTpsCount)
+		var remotesTpsMX sync.Mutex
 		remotesTps := make(map[int][]transport.Transport, len(usedRemotes))
 		var remotesWG sync.WaitGroup
 		remotesWG.Add(totalRemoteTpsCount)
@@ -371,7 +365,9 @@ func TestServer_Serve(t *testing.T) {
 				}
 
 				// store transport
+				remotesTpsMX.Lock()
 				remotesTps[remoteInd] = append(remotesTps[remoteInd], transport)
+				remotesTpsMX.Unlock()
 
 				remotesWG.Done()
 			}(i)
@@ -382,6 +378,7 @@ func TestServer_Serve(t *testing.T) {
 		// channel to listen for `Dial` errors. Any single error must
 		// fail the test
 		dialErrs := make(chan error, initiatorsCount)
+		var initiatorsTpsMx sync.Mutex
 		initiatorsTps := make([]transport.Transport, 0, initiatorsCount)
 		var initiatorsWG sync.WaitGroup
 		initiatorsWG.Add(initiatorsCount)
@@ -400,7 +397,9 @@ func TestServer_Serve(t *testing.T) {
 				}
 
 				// store transport
+				initiatorsTpsMx.Lock()
 				initiatorsTps = append(initiatorsTps, transport)
+				initiatorsTpsMx.Unlock()
 
 				initiatorsWG.Done()
 			}(i)
@@ -553,7 +552,7 @@ func TestServer_Serve(t *testing.T) {
 		// create remote
 		a := NewClient(aPK, aSK, dc)
 		a.SetLogger(logging.MustGetLogger("A"))
-		err := a.InitiateServerConnections(context.Background(), 1)
+		err = a.InitiateServerConnections(context.Background(), 1)
 		require.NoError(t, err)
 
 		// create initiator
@@ -562,96 +561,165 @@ func TestServer_Serve(t *testing.T) {
 		err = b.InitiateServerConnections(context.Background(), 1)
 		require.NoError(t, err)
 
-		aDone := make(chan struct{})
-		var aTransport transport.Transport
-		var aErr error
-		go func() {
-			aTransport, aErr = a.Accept(context.Background())
-			close(aDone)
-		}()
-
 		bTransport, err := b.Dial(context.Background(), aPK)
 		require.NoError(t, err)
 
-		<-aDone
-		require.NoError(t, aErr)
+		aTransport, err := a.Accept(context.Background())
+		require.NoError(t, err)
 
-		aTpDone := make(chan struct{})
-		bTpDone := make(chan struct{})
+		readWriteStop := make(chan struct{})
+		readWriteDone := make(chan struct{})
 
-		var bErr error
-		var tpReadWriteWG sync.WaitGroup
-		tpReadWriteWG.Add(2)
-		// run infinite reading from tp loop in goroutine
+		var readErr, writeErr error
 		go func() {
+			// read/write to/from transport until the stop signal arrives
 			for {
 				select {
-				case <-aTpDone:
-					log.Println("ATransport DONE")
-					tpReadWriteWG.Done()
+				case <-readWriteStop:
+					close(readWriteDone)
 					return
 				default:
-					msg := make([]byte, 13)
-					if _, aErr = aTransport.Read(msg); aErr != nil {
-						tpReadWriteWG.Done()
+					msg := []byte("Hello there!")
+					if _, writeErr = bTransport.Write(msg); writeErr != nil {
+						close(readWriteDone)
 						return
 					}
-					log.Printf("GOT MESSAGE %s", string(msg))
-				}
-			}
-		}()
-
-		// run infinite writing to tp loop in goroutine
-		go func() {
-			for {
-				select {
-				case <-bTpDone:
-					log.Println("BTransport DONE")
-					tpReadWriteWG.Done()
-					return
-				default:
-					msg := []byte("Hello there!")
-					if _, bErr = bTransport.Write(msg); bErr != nil {
-						tpReadWriteWG.Done()
+					if _, readErr = aTransport.Read(msg); readErr != nil {
+						close(readWriteDone)
 						return
 					}
 				}
 			}
 		}()
 
-		ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
-		defer cancel()
+		// keep creating transports until an error occurs
+		for {
+			ctx := context.Background()
+			if _, err = a.Dial(ctx, bPK); err != nil {
+				break
+			}
+		}
+		// must be an error
+		require.Error(t, err)
 
-		// try to create another transport
-		_, err = a.Dial(ctx, bPK)
-		// must fail with timeout
+		// same as above, but the transport is created by the other client
+		for {
+			ctx := context.Background()
+			if _, err = b.Dial(ctx, aPK); err != nil {
+				break
+			}
+		}
+		// must be an error
 		require.Error(t, err)
 
 		// wait more time to ensure that the initially created transport works
 		time.Sleep(2 * time.Second)
 
+		err = aTransport.Close()
+		require.NoError(t, err)
+
+		err = bTransport.Close()
+		require.NoError(t, err)
+
 		// stop reading/writing goroutines
-		close(aTpDone)
-		close(bTpDone)
+		close(readWriteStop)
+		<-readWriteDone
 
-		// wait for goroutines to stop
-		tpReadWriteWG.Wait()
 		// check that the initial transport had been working properly all the time
-		require.NoError(t, aErr)
-		require.NoError(t, bErr)
+		// if there is any error, it must be `io.EOF` for the reader
+		if readErr != io.EOF {
+			require.NoError(t, readErr)
+		}
+		// if there is any error, it must be `io.ErrClosedPipe` for the writer
+		if writeErr != io.ErrClosedPipe {
+			require.NoError(t, writeErr)
+		}
 
-		err = aTransport.Close()
+		err = a.Close()
+		require.NoError(t, err)
+
+		b.log.Println("BEFORE CLOSING")
+		err = b.Close()
+		b.log.Println("AFTER CLOSING")
+		require.NoError(t, err)
+	})
+
+	t.Run("test sent/received message consistency", func(t *testing.T) {
+		// generate keys for both clients
+		aPK, aSK := cipher.GenerateKeyPair()
+		bPK, bSK := cipher.GenerateKeyPair()
+
+		// create remote
+		a := NewClient(aPK, aSK, dc)
+		a.SetLogger(logging.MustGetLogger("A"))
+		err = a.InitiateServerConnections(context.Background(), 1)
 		require.NoError(t, err)
 
+		// create initiator
+		b := NewClient(bPK, bSK, dc)
+		b.SetLogger(logging.MustGetLogger("B"))
+		err = b.InitiateServerConnections(context.Background(), 1)
+		require.NoError(t, err)
+
+		// create transports
+		bTransport, err := b.Dial(context.Background(), aPK)
+		require.NoError(t, err)
+
+		aTransport, err := a.Accept(context.Background())
+		require.NoError(t, err)
+
+		msgCount := 100
+		for i := 0; i < msgCount; i++ {
+			msg := "Hello there!"
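+			// each iteration writes the full 12-byte message, then reads it
+			// back through a smaller 5-byte buffer, checking that partial
+			// reads return the stream's bytes in order and in full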
+
+			// write a 12-byte message
+			_, err := bTransport.Write([]byte(msg))
+			require.NoError(t, err)
+
+			// create a receiving buffer of 5 bytes
+			recBuff := make([]byte, 5)
+
+			// read 5 bytes, 7 bytes left
+			n, err := aTransport.Read(recBuff)
+			require.NoError(t, err)
+			require.Equal(t, n, len(recBuff))
+
+			received := string(recBuff[:n])
+
+			// read 5 more bytes, 2 bytes left
+			n, err = aTransport.Read(recBuff)
+			require.NoError(t, err)
+			require.Equal(t, n, len(recBuff))
+
+			received += string(recBuff[:n])
+
+			// read the 2 remaining bytes
+			n, err = aTransport.Read(recBuff)
+			require.NoError(t, err)
+			require.Equal(t, n, len(msg)-len(recBuff)*2)
+
+			received += string(recBuff[:n])
+
+			// received string must be equal to the sent one
+			require.Equal(t, received, msg)
+		}
+
 		err = bTransport.Close()
 		require.NoError(t, err)
 
+		err = aTransport.Close()
+		require.NoError(t, err)
+
 		err = a.Close()
 		require.NoError(t, err)
 
 		err = b.Close()
 		require.NoError(t, err)
 	})
+
+	t.Run("test capped_transport_buffer_should_not_result_in_hang", func(t *testing.T) {
+
+	})
 }
 
 // Given two client instances (a & b) and a server instance (s),