Skip to content

Commit

Permalink
Merge further tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
林志宇 committed Jun 20, 2019
2 parents 32bab81 + d0741bb commit e316d53
Showing 1 changed file with 130 additions and 62 deletions.
192 changes: 130 additions & 62 deletions pkg/dmsg/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"log"
"math"
"math/rand"
Expand Down Expand Up @@ -198,19 +199,11 @@ func TestServer_Serve(t *testing.T) {
err = b.InitiateServerConnections(context.Background(), 1)
require.NoError(t, err)

aDone := make(chan struct{})
var aTransport transport.Transport
var aErr error
go func() {
aTransport, aErr = a.Accept(context.Background())
close(aDone)
}()

bTransport, err := b.Dial(context.Background(), aPK)
require.NoError(t, err)

<-aDone
require.NoError(t, aErr)
aTransport, err := a.Accept(context.Background())
require.NoError(t, err)

// must be 2 ServerConn's
require.Equal(t, 2, s.connCount())
Expand Down Expand Up @@ -351,6 +344,7 @@ func TestServer_Serve(t *testing.T) {
// channel to listen for `Accept` errors. Any single error must
// fail the test
acceptErrs := make(chan error, totalRemoteTpsCount)
var remotesTpsMX sync.Mutex
remotesTps := make(map[int][]transport.Transport, len(usedRemotes))
var remotesWG sync.WaitGroup
remotesWG.Add(totalRemoteTpsCount)
Expand All @@ -371,7 +365,9 @@ func TestServer_Serve(t *testing.T) {
}

// store transport
remotesTpsMX.Lock()
remotesTps[remoteInd] = append(remotesTps[remoteInd], transport)
remotesTpsMX.Unlock()

remotesWG.Done()
}(i)
Expand All @@ -382,6 +378,7 @@ func TestServer_Serve(t *testing.T) {
// channel to listen for `Dial` errors. Any single error must
// fail the test
dialErrs := make(chan error, initiatorsCount)
var initiatorsTpsMx sync.Mutex
initiatorsTps := make([]transport.Transport, 0, initiatorsCount)
var initiatorsWG sync.WaitGroup
initiatorsWG.Add(initiatorsCount)
Expand All @@ -400,7 +397,9 @@ func TestServer_Serve(t *testing.T) {
}

// store transport
initiatorsTpsMx.Lock()
initiatorsTps = append(initiatorsTps, transport)
initiatorsTpsMx.Unlock()

initiatorsWG.Done()
}(i)
Expand Down Expand Up @@ -553,7 +552,7 @@ func TestServer_Serve(t *testing.T) {
// create remote
a := NewClient(aPK, aSK, dc)
a.SetLogger(logging.MustGetLogger("A"))
err := a.InitiateServerConnections(context.Background(), 1)
err = a.InitiateServerConnections(context.Background(), 1)
require.NoError(t, err)

// create initiator
Expand All @@ -562,96 +561,165 @@ func TestServer_Serve(t *testing.T) {
err = b.InitiateServerConnections(context.Background(), 1)
require.NoError(t, err)

aDone := make(chan struct{})
var aTransport transport.Transport
var aErr error
go func() {
aTransport, aErr = a.Accept(context.Background())
close(aDone)
}()

bTransport, err := b.Dial(context.Background(), aPK)
require.NoError(t, err)

<-aDone
require.NoError(t, aErr)
aTransport, err := a.Accept(context.Background())
require.NoError(t, err)

aTpDone := make(chan struct{})
bTpDone := make(chan struct{})
readWriteStop := make(chan struct{})
readWriteDone := make(chan struct{})

var bErr error
var tpReadWriteWG sync.WaitGroup
tpReadWriteWG.Add(2)
// run infinite reading from tp loop in goroutine
var readErr, writeErr error
go func() {
// read/write to/from transport until the stop signal arrives
for {
select {
case <-aTpDone:
log.Println("ATransport DONE")
tpReadWriteWG.Done()
case <-readWriteStop:
close(readWriteDone)
return
default:
msg := make([]byte, 13)
if _, aErr = aTransport.Read(msg); aErr != nil {
tpReadWriteWG.Done()
msg := []byte("Hello there!")
if _, writeErr = bTransport.Write(msg); writeErr != nil {
close(readWriteDone)
return
}
log.Printf("GOT MESSAGE %s", string(msg))
}
}
}()

// run infinite writing to tp loop in goroutine
go func() {
for {
select {
case <-bTpDone:
log.Println("BTransport DONE")
tpReadWriteWG.Done()
return
default:
msg := []byte("Hello there!")
if _, bErr = bTransport.Write(msg); bErr != nil {
tpReadWriteWG.Done()
if _, readErr = aTransport.Read(msg); readErr != nil {
close(readWriteDone)
return
}
}
}
}()

ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
// continue creating transports until the error occurs
for {
ctx := context.Background()
if _, err = a.Dial(ctx, bPK); err != nil {
break
}
}
// must be error
require.Error(t, err)

// try to create another transport
_, err = a.Dial(ctx, bPK)
// must fail with timeout
// the same as above, transport is created by another client
for {
ctx := context.Background()
if _, err = b.Dial(ctx, aPK); err != nil {
break
}
}
// must be error
require.Error(t, err)

// wait more time to ensure that the initially created transport works
time.Sleep(2 * time.Second)

err = aTransport.Close()
require.NoError(t, err)

err = bTransport.Close()
require.NoError(t, err)

// stop reading/writing goroutines
close(aTpDone)
close(bTpDone)
close(readWriteStop)
<-readWriteDone

// wait for goroutines to stop
tpReadWriteWG.Wait()
// check that the initial transport had been working properly all the time
require.NoError(t, aErr)
require.NoError(t, bErr)
// if any error, it must be `io.EOF` for reader
if readErr != io.EOF {
require.NoError(t, readErr)
}
// if any error, it must be `io.ErrClosedPipe` for writer
if writeErr != io.ErrClosedPipe {
require.NoError(t, writeErr)
}

err = aTransport.Close()
err = a.Close()
require.NoError(t, err)

b.log.Println("BEFORE CLOSING")
err = b.Close()
b.log.Println("AFTER CLOSING")
require.NoError(t, err)
})

t.Run("test sent/received message consistency", func(t *testing.T) {
// generate keys for both clients
aPK, aSK := cipher.GenerateKeyPair()
bPK, bSK := cipher.GenerateKeyPair()

// create remote
a := NewClient(aPK, aSK, dc)
a.SetLogger(logging.MustGetLogger("A"))
err = a.InitiateServerConnections(context.Background(), 1)
require.NoError(t, err)

// create initiator
b := NewClient(bPK, bSK, dc)
b.SetLogger(logging.MustGetLogger("B"))
err = b.InitiateServerConnections(context.Background(), 1)
require.NoError(t, err)

// create transports
bTransport, err := b.Dial(context.Background(), aPK)
require.NoError(t, err)

aTransport, err := a.Accept(context.Background())
require.NoError(t, err)

msgCount := 100
for i := 0; i < msgCount; i++ {
msg := "Hello there!"

// write message of 12 bytes
_, err := bTransport.Write([]byte(msg))
require.NoError(t, err)

// create a receiving buffer of 5 bytes
recBuff := make([]byte, 5)

// read 5 bytes, 7 left
n, err := aTransport.Read(recBuff)
require.NoError(t, err)
require.Equal(t, n, len(recBuff))

received := string(recBuff[:n])

// read 5 more, 2 left
n, err = aTransport.Read(recBuff)
require.NoError(t, err)
require.Equal(t, n, len(recBuff))

received += string(recBuff[:n])

// read 2 bytes left
n, err = aTransport.Read(recBuff)
require.NoError(t, err)
require.Equal(t, n, len(msg)-len(recBuff)*2)

received += string(recBuff[:n])

// received string must be equal to the sent one
require.Equal(t, received, msg)
}

err = bTransport.Close()
require.NoError(t, err)

err = aTransport.Close()
require.NoError(t, err)

err = a.Close()
require.NoError(t, err)

err = b.Close()
require.NoError(t, err)
})

t.Run("test capped_transport_buffer_should_not_result_in_hang", func(t *testing.T) {

})
}

// Given two client instances (a & b) and a server instance (s),
Expand Down

0 comments on commit e316d53

Please sign in to comment.