diff --git a/pkg/messaging/client.go b/pkg/messaging/client.go index 5d9aa0b130..d7892106d5 100644 --- a/pkg/messaging/client.go +++ b/pkg/messaging/client.go @@ -64,8 +64,10 @@ type Client struct { links map[cipher.PubKey]*clientLink mu sync.RWMutex - newChan chan *channel - doneChan chan struct{} + newCh chan *channel // chan for newly opened channels + newWG sync.WaitGroup // waits for goroutines writing to newCh to end. + + doneCh chan struct{} } // NewClient constructs a new Client. @@ -78,8 +80,8 @@ func NewClient(conf *Config) *Client { retries: conf.Retries, retryDelay: conf.RetryDelay, links: make(map[cipher.PubKey]*clientLink), - newChan: make(chan *channel), - doneChan: make(chan struct{}), + newCh: make(chan *channel), + doneCh: make(chan struct{}), } config := &LinkConfig{ Public: c.pubKey, @@ -129,7 +131,7 @@ func (c *Client) ConnectToInitialServers(ctx context.Context, serverCount int) e // Accept accepts a remotely-initiated Transport. func (c *Client) Accept(ctx context.Context) (transport.Transport, error) { select { - case ch, more := <-c.newChan: + case ch, more := <-c.newCh: if !more { return nil, ErrClientClosed } @@ -179,7 +181,7 @@ func (c *Client) Dial(ctx context.Context, remote cipher.PubKey) (transport.Tran return nil, ctx.Err() } - c.Logger.Infof("Opened new channel local ID %d, remote ID %d with %s", localID, channel.ID(), remote) // TODO: race condition + c.Logger.Infof("Opened new channel local ID %d, remote ID %d with %s", localID, channel.ID(), remote) return channel, nil } @@ -197,10 +199,11 @@ func (c *Client) Type() string { func (c *Client) Close() error { c.Logger.Info("Closing link pool") select { - case <-c.doneChan: + case <-c.doneCh: default: - close(c.doneChan) - close(c.newChan) + close(c.doneCh) + c.newWG.Wait() // Ensure that 'c.newCh' is not being written to before closing. + close(c.newCh) } return c.pool.Close() } @@ -328,18 +331,18 @@ func (c *Client) onData(l *Link, frameType FrameType, body []byte) error { } case FrameTypeChannelClosed: channel.SetID(body[0]) - channel.closeChanMx.RLock() // TODO(evanlinjin): START(avoid race condition). + channel.closeChanMx.RLock() // Begin: avoid race condition. select { case channel.waitChan <- false: - case channel.closeChan <- struct{}{}: // TODO(evanlinjin): data race. + case channel.closeChan <- struct{}{}: // Previous data race. clientLink.chans.remove(channelID) default: } - channel.closeChanMx.RUnlock() // TODO(evanlinjin): END(avoid race condition). + channel.closeChanMx.RUnlock() // End: avoid race condition. case FrameTypeSend: go func() { select { - case <-c.doneChan: + case <-c.doneCh: case <-channel.doneChan: case channel.readChan <- body[1:]: } @@ -361,7 +364,7 @@ func (c *Client) onClose(l *Link, remote bool) { } select { - case <-c.doneChan: + case <-c.doneCh: default: c.Logger.Infof("Disconnected from the server %s. Trying to re-connect...", remotePK) for attemp := 0; attemp < c.retries; attemp++ { @@ -405,11 +408,13 @@ func (c *Client) openChannel(rID byte, remotePK []byte, noiseMsg []byte, chanLin lID = chanLink.chans.add(channel) + c.newWG.Add(1) // Ensure that 'c.newCh' is not being written to before closing. go func() { select { - case <-c.doneChan: - case c.newChan <- channel: + case <-c.doneCh: + case c.newCh <- channel: } + c.newWG.Done() }() noiseRes, err = channel.HandshakeMessage()