diff --git a/go.sum b/go.sum index 88c3a66580..553501acd5 100644 --- a/go.sum +++ b/go.sum @@ -116,8 +116,10 @@ golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/pkg/dmsg/server_test.go b/pkg/dmsg/server_test.go index 1222f2fc93..644efed20c 100644 --- a/pkg/dmsg/server_test.go +++ b/pkg/dmsg/server_test.go @@ -801,8 +801,7 @@ func testReconnect(t *testing.T, randomAddr bool) { assert.Equal(t, 0, s.connCount()) - remote := NewClient(remotePK, remoteSK, dc) - remote.SetLogger(logging.MustGetLogger("remote")) + remote := NewClient(remotePK, remoteSK, dc, SetLogger(logging.MustGetLogger("remote"))) err = remote.InitiateServerConnections(ctx, 1) require.NoError(t, err) @@ -813,14 +812,16 @@ func testReconnect(t *testing.T, randomAddr bool) { return nil })) - initiator := NewClient(initiatorPK, initiatorSK, dc) - initiator.SetLogger(logging.MustGetLogger("initiator")) + initiator := NewClient(initiatorPK, initiatorSK, dc, SetLogger(logging.MustGetLogger("initiator"))) err = initiator.InitiateServerConnections(ctx, 1) require.NoError(t, err) initiatorTransport, err := initiator.Dial(ctx, remotePK) require.NoError(t, err) + remoteTransport, err := remote.Accept(context.Background()) + require.NoError(t, err) + require.NoError(t, testWithTimeout(smallDelay, func() error { if s.connCount() != 2 { return errors.New("s.conns is not equal to 2") @@ -835,6 +836,10 @@ func testReconnect(t *testing.T, randomAddr bool) { assert.False(t, isDoneChannelOpen(initTr.done)) assert.False(t, isReadChannelOpen(initTr.inCh)) + remoteTr := remoteTransport.(*Transport) + assert.False(t, isDoneChannelOpen(remoteTr.done)) + assert.False(t, isReadChannelOpen(remoteTr.inCh)) + assert.Equal(t, 0, s.connCount()) addr := "" @@ -857,21 +862,16 @@ func testReconnect(t *testing.T, randomAddr bool) { return nil })) - remoteDone := make(chan struct{}) - var remoteErr error - go func() { - _, remoteErr = remote.Accept(ctx) - close(remoteDone) - }() - require.NoError(t, testWithTimeout(smallDelay, func() error { _, err = initiator.Dial(ctx, remotePK) + if err != nil { + return err + } + + _, err = remote.Accept(context.Background()) return err })) - <-remoteDone - require.NoError(t, remoteErr) - err = s.Close() assert.NoError(t, err) }