Skip to content

Commit

Permalink
Fix data race when closing messaging.Client's newCh.
Browse files Browse the repository at this point in the history
* Added wg to await goroutines writing to newCh to end before closing newCh.
* Some comments.
  • Loading branch information
林志宇 committed May 21, 2019
1 parent e32959e commit 40b3ee6
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions pkg/messaging/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ type Client struct {
links map[cipher.PubKey]*clientLink
mu sync.RWMutex

newChan chan *channel
doneChan chan struct{}
newCh chan *channel // chan for newly opened channels
newWG sync.WaitGroup // waits for goroutines writing to newCh to end.

doneCh chan struct{}
}

// NewClient constructs a new Client.
Expand All @@ -78,8 +80,8 @@ func NewClient(conf *Config) *Client {
retries: conf.Retries,
retryDelay: conf.RetryDelay,
links: make(map[cipher.PubKey]*clientLink),
newChan: make(chan *channel),
doneChan: make(chan struct{}),
newCh: make(chan *channel),
doneCh: make(chan struct{}),
}
config := &LinkConfig{
Public: c.pubKey,
Expand Down Expand Up @@ -129,7 +131,7 @@ func (c *Client) ConnectToInitialServers(ctx context.Context, serverCount int) e
// Accept accepts a remotely-initiated Transport.
func (c *Client) Accept(ctx context.Context) (transport.Transport, error) {
select {
case ch, more := <-c.newChan:
case ch, more := <-c.newCh:
if !more {
return nil, ErrClientClosed
}
Expand Down Expand Up @@ -179,7 +181,7 @@ func (c *Client) Dial(ctx context.Context, remote cipher.PubKey) (transport.Tran
return nil, ctx.Err()
}

c.Logger.Infof("Opened new channel local ID %d, remote ID %d with %s", localID, channel.ID(), remote) // TODO: race condition
c.Logger.Infof("Opened new channel local ID %d, remote ID %d with %s", localID, channel.ID(), remote)
return channel, nil
}

Expand All @@ -197,10 +199,11 @@ func (c *Client) Type() string {
func (c *Client) Close() error {
c.Logger.Info("Closing link pool")
select {
case <-c.doneChan:
case <-c.doneCh:
default:
close(c.doneChan)
close(c.newChan)
close(c.doneCh)
c.newWG.Wait() // Ensure that 'c.newCh' is not being written to before closing.
close(c.newCh)
}
return c.pool.Close()
}
Expand Down Expand Up @@ -328,18 +331,18 @@ func (c *Client) onData(l *Link, frameType FrameType, body []byte) error {
}
case FrameTypeChannelClosed:
channel.SetID(body[0])
channel.closeChanMx.RLock() // TODO(evanlinjin): START(avoid race condition).
channel.closeChanMx.RLock() // Begin: avoid race condition.
select {
case channel.waitChan <- false:
case channel.closeChan <- struct{}{}: // TODO(evanlinjin): data race.
case channel.closeChan <- struct{}{}: // Previous data race.
clientLink.chans.remove(channelID)
default:
}
channel.closeChanMx.RUnlock() // TODO(evanlinjin): END(avoid race condition).
channel.closeChanMx.RUnlock() // End: avoid race condition.
case FrameTypeSend:
go func() {
select {
case <-c.doneChan:
case <-c.doneCh:
case <-channel.doneChan:
case channel.readChan <- body[1:]:
}
Expand All @@ -361,7 +364,7 @@ func (c *Client) onClose(l *Link, remote bool) {
}

select {
case <-c.doneChan:
case <-c.doneCh:
default:
c.Logger.Infof("Disconnected from the server %s. Trying to re-connect...", remotePK)
for attemp := 0; attemp < c.retries; attemp++ {
Expand Down Expand Up @@ -405,11 +408,13 @@ func (c *Client) openChannel(rID byte, remotePK []byte, noiseMsg []byte, chanLin

lID = chanLink.chans.add(channel)

c.newWG.Add(1) // Ensure that 'c.newCh' is not being written to before closing.
go func() {
select {
case <-c.doneChan:
case c.newChan <- channel:
case <-c.doneCh:
case c.newCh <- channel:
}
c.newWG.Done()
}()

noiseRes, err = channel.HandshakeMessage()
Expand Down

0 comments on commit 40b3ee6

Please sign in to comment.