diff --git a/pkg/messaging/chan_list.go b/pkg/messaging/chan_list.go index d7491ba6b..d9fe89546 100644 --- a/pkg/messaging/chan_list.go +++ b/pkg/messaging/chan_list.go @@ -1,6 +1,8 @@ package messaging -import "sync" +import ( + "sync" +) type chanList struct { sync.Mutex diff --git a/pkg/messaging/channel.go b/pkg/messaging/channel.go index 998756958..15adc1a4e 100644 --- a/pkg/messaging/channel.go +++ b/pkg/messaging/channel.go @@ -134,15 +134,11 @@ func (mCh *msgChannel) Write(p []byte) (n int, err error) { binary.BigEndian.PutUint16(buf[:2], uint16(len(data))) copy(buf[2:], data) - done := make(chan struct{}, 1) - defer close(done) + done := make(chan struct{}) go func() { n, err = mCh.link.Send(mCh.ID(), buf) n = n - (len(data) - len(p) + 2) - select { - case done <- struct{}{}: - default: - } + close(done) }() select { diff --git a/pkg/messaging/client.go b/pkg/messaging/client.go index b310fa5c9..44eb6feba 100644 --- a/pkg/messaging/client.go +++ b/pkg/messaging/client.go @@ -64,8 +64,10 @@ type Client struct { links map[cipher.PubKey]*clientLink mu sync.RWMutex - newChan chan *msgChannel - doneChan chan struct{} + newCh chan *msgChannel // chan for newly opened channels + newWG sync.WaitGroup // waits for goroutines writing to newCh to end. + + doneCh chan struct{} } // NewClient constructs a new Client. @@ -78,8 +80,8 @@ func NewClient(conf *Config) *Client { retries: conf.Retries, retryDelay: conf.RetryDelay, links: make(map[cipher.PubKey]*clientLink), - newChan: make(chan *msgChannel), - doneChan: make(chan struct{}), + newCh: make(chan *msgChannel), + doneCh: make(chan struct{}), } config := &LinkConfig{ Public: c.pubKey, @@ -129,7 +131,7 @@ func (c *Client) ConnectToInitialServers(ctx context.Context, serverCount int) e // Accept accepts a remotely-initiated Transport. func (c *Client) Accept(ctx context.Context) (transport.Transport, error) { select { - case ch, more := <-c.newChan: + case ch, more := <-c.newCh: if !more { return nil, ErrClientClosed } @@ -179,7 +181,7 @@ func (c *Client) Dial(ctx context.Context, remote cipher.PubKey) (transport.Tran return nil, ctx.Err() } - c.Logger.Infof("Opened new channel local ID %d, remote ID %d with %s", localID, channel.ID(), remote) // TODO: race condition + c.Logger.Infof("Opened new channel local ID %d, remote ID %d with %s", localID, channel.ID(), remote) return channel, nil } @@ -197,10 +199,11 @@ func (c *Client) Type() string { func (c *Client) Close() error { c.Logger.Info("Closing link pool") select { - case <-c.doneChan: + case <-c.doneCh: default: - close(c.doneChan) - close(c.newChan) + close(c.doneCh) + c.newWG.Wait() // Ensure that 'c.newCh' is not being written to before closing. + close(c.newCh) } return c.pool.Close() } @@ -337,7 +340,7 @@ func (c *Client) onData(l *Link, frameType FrameType, body []byte) error { case FrameTypeSend: go func() { select { - case <-c.doneChan: + case <-c.doneCh: case <-channel.doneChan: case channel.readChan <- body[1:]: } @@ -359,7 +362,7 @@ func (c *Client) onClose(l *Link, remote bool) { } select { - case <-c.doneChan: + case <-c.doneCh: default: c.Logger.Infof("Disconnected from the server %s. Trying to re-connect...", remotePK) for attemp := 0; attemp < c.retries; attemp++ { @@ -403,11 +406,13 @@ func (c *Client) openChannel(rID byte, remotePK []byte, noiseMsg []byte, chanLin lID = chanLink.chans.add(channel) + c.newWG.Add(1) // Ensure that 'c.newCh' is not being written to before closing. go func() { select { - case <-c.doneChan: - case c.newChan <- channel: + case <-c.doneCh: + case c.newCh <- channel: } + c.newWG.Done() }() noiseRes, err = channel.HandshakeMessage() diff --git a/pkg/messaging/client_test.go b/pkg/messaging/client_test.go index 8b6472589..33598e8c3 100644 --- a/pkg/messaging/client_test.go +++ b/pkg/messaging/client_test.go @@ -7,6 +7,7 @@ import ( "os" "sync" "testing" + "time" "github.com/skycoin/skycoin/src/util/logging" "github.com/stretchr/testify/assert" @@ -24,65 +25,108 @@ func TestMain(m *testing.M) { } func TestClientDial(t *testing.T) { - pk, sk := cipher.GenerateKeyPair() discovery := client.NewMock() - c := NewClient(&Config{pk, sk, discovery, 0, 0}) - c.retries = 0 srv, err := newMockServer(discovery) require.NoError(t, err) srvPK := srv.config.Public - anotherPK, anotherSK := cipher.GenerateKeyPair() - anotherClient := NewClient(&Config{anotherPK, anotherSK, discovery, 0, 0}) - require.NoError(t, anotherClient.ConnectToInitialServers(context.TODO(), 1)) + pk1, sk1 := cipher.GenerateKeyPair() + c1 := NewClient(&Config{pk1, sk1, discovery, 0, 0}) + + pk2, sk2 := cipher.GenerateKeyPair() + c2 := NewClient(&Config{pk2, sk2, discovery, 0, 0}) + require.NoError(t, c2.ConnectToInitialServers(context.TODO(), 1)) - var anotherTr transport.Transport - anotherErrCh := make(chan error) + var ( + tp2 transport.Transport + tp2Err error + tp2Done = make(chan struct{}) + ) go func() { - t, err := anotherClient.Accept(context.TODO()) - anotherTr = t - anotherErrCh <- err + tp2, tp2Err = c2.Accept(context.TODO()) + close(tp2Done) }() - var tr transport.Transport - errCh := make(chan error) + var ( + tp1 transport.Transport + tp1Err error + tp1Done = make(chan struct{}) + ) go func() { - t, err := c.Dial(context.TODO(), anotherPK) - tr = t - errCh <- err + tp1, tp1Err = c1.Dial(context.TODO(), pk2) + close(tp1Done) }() - require.NoError(t, <-errCh) - require.NotNil(t, c.getLink(srvPK).chans.get(0)) + <-tp1Done + require.NoError(t, tp1Err) + require.NotNil(t, c1.getLink(srvPK).chans.get(0)) - entry, err := discovery.Entry(context.TODO(), pk) + entry, err := discovery.Entry(context.TODO(), pk1) require.NoError(t, err) require.Len(t, entry.Client.DelegatedServers, 1) - require.NoError(t, <-anotherErrCh) - require.NotNil(t, anotherClient.getLink(srvPK).chans.get(0)) + <-tp2Done + require.NoError(t, tp2Err) + require.NotNil(t, c2.getLink(srvPK).chans.get(0)) - go tr.Write([]byte("foo")) // nolint: errcheck + go tp1.Write([]byte("foo")) // nolint: errcheck buf := make([]byte, 3) - n, err := anotherTr.Read(buf) + n, err := tp2.Read(buf) require.NoError(t, err) assert.Equal(t, 3, n) assert.Equal(t, []byte("foo"), buf) - go anotherTr.Write([]byte("bar")) // nolint: errcheck + go tp2.Write([]byte("bar")) // nolint: errcheck buf = make([]byte, 3) - n, err = tr.Read(buf) + n, err = tp1.Read(buf) require.NoError(t, err) assert.Equal(t, 3, n) assert.Equal(t, []byte("bar"), buf) - require.NoError(t, tr.Close()) - require.NoError(t, anotherTr.Close()) + require.NoError(t, tp1.Close()) + require.NoError(t, tp2.Close()) + + // It is expected for the messaging client to delete the channel for chanList eventually. + require.True(t, retry(time.Second*10, time.Second, func() bool { + return c2.getLink(srvPK).chans.get(0) == nil + })) +} + +// retries until successful under a given deadline. +// 'tick' specifies the break duration before retry. +func retry(deadline, tick time.Duration, do func() bool) bool { + timer := time.NewTimer(deadline) + defer timer.Stop() + + done := make(chan struct{}) + doneOnce := new(sync.Once) + defer doneOnce.Do(func() { close(done) }) - require.Nil(t, anotherClient.getLink(srvPK).chans.get(0)) + go func() { + for { + select { + case <-done: + return + case <-time.Tick(tick): + if ok := do(); ok { + doneOnce.Do(func() { close(done) }) + return + } + } + } + }() + + for { + select { + case <-timer.C: + return false + case <-done: + return true + } + } } type mockServer struct { diff --git a/pkg/node/node.go b/pkg/node/node.go index 5953d2c17..13ed9996c 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -245,35 +245,33 @@ func (node *Node) Start() error { // Close safely stops spawned Apps and messaging Node. func (node *Node) Close() (err error) { if node.rpcListener != nil { - node.logger.Info("Stopping RPC interface") - if rpcErr := node.rpcListener.Close(); rpcErr != nil && err == nil { - err = rpcErr + if err = node.rpcListener.Close(); err != nil { + node.logger.WithError(err).Error("failed to stop RPC interface") + } else { + node.logger.Info("RPC interface stopped successfully") } } - for _, dialer := range node.rpcDialers { - err = dialer.Close() + for i, dialer := range node.rpcDialers { + if err = dialer.Close(); err != nil { + node.logger.WithError(err).Errorf("(%d) failed to stop RPC dialer", i) + } else { + node.logger.Infof("(%d) RPC dialer closed successfully", i) + } } - node.startedMu.Lock() - for app, bind := range node.startedApps { - if appErr := node.stopApp(app, bind); appErr != nil && err == nil { - err = appErr + for a, bind := range node.startedApps { + if err = node.stopApp(a, bind); err != nil { + node.logger.WithError(err).Errorf("(%s) failed to stop app", a) + } else { + node.logger.Infof("(%s) app stopped successfully", a) } } node.startedMu.Unlock() - - if node.rpcListener != nil { - node.logger.Info("Stopping RPC interface") - if rpcErr := node.rpcListener.Close(); rpcErr != nil && err == nil { - err = rpcErr - } - } - - node.logger.Info("Stopping router") - if msgErr := node.router.Close(); msgErr != nil && err == nil { - err = msgErr + if err = node.router.Close(); err != nil { + node.logger.WithError(err).Error("failed to stop router") + } else { + node.logger.Info("router stopped successfully") } - return err }