From 95cbf466ccafa7e825a26222eb8a1ec6fe7de8ae Mon Sep 17 00:00:00 2001 From: gz-c Date: Sat, 25 May 2019 15:37:23 +0800 Subject: [PATCH 1/3] Fixing link race --- pkg/messaging/channel.go | 70 +++++++++++++++++----------------------- pkg/messaging/client.go | 7 ++-- 2 files changed, 32 insertions(+), 45 deletions(-) diff --git a/pkg/messaging/channel.go b/pkg/messaging/channel.go index db71957120..4bb893b897 100644 --- a/pkg/messaging/channel.go +++ b/pkg/messaging/channel.go @@ -6,9 +6,7 @@ import ( "encoding/binary" "io" "sync" - "sync/atomic" "time" - "unsafe" "github.com/skycoin/skywire/internal/noise" "github.com/skycoin/skywire/pkg/cipher" @@ -24,13 +22,12 @@ type msgChannel struct { buf *bytes.Buffer deadline time.Time - closed unsafe.Pointer // unsafe.Pointer is used alongside 'atomic' module for fast, thread-safe access. - waitChan chan bool // waits for remote response (whether msgChannel is accepted or not). - readChan chan []byte - closeChan chan struct{} - closeChanMx sync.RWMutex // TODO(evanlinjin): This is a hack to avoid race conditions when closing msgChannel. - doneChan chan struct{} + waitChan chan bool // waits for remote response (whether msgChannel is accepted or not). + readChan chan []byte + + doneChan chan struct{} + doneOnce sync.Once noise *noise.Noise rMx sync.Mutex // lock for decrypt cipher state @@ -50,15 +47,13 @@ func newChannel(initiator bool, secKey cipher.SecKey, remote cipher.PubKey, link } return &msgChannel{ - remotePK: remote, - link: link, - buf: new(bytes.Buffer), - closed: unsafe.Pointer(new(bool)), //nolint:gosec - waitChan: make(chan bool, 1), // should allows receive one reply. - readChan: make(chan []byte), - closeChan: make(chan struct{}), - doneChan: make(chan struct{}), - noise: noiseInstance, + remotePK: remote, + link: link, + buf: new(bytes.Buffer), + waitChan: make(chan bool, 1), // should allows receive one reply. + readChan: make(chan []byte), + doneChan: make(chan struct{}), + noise: noiseInstance, }, nil } @@ -118,8 +113,10 @@ func (mCh *msgChannel) Read(p []byte) (n int, err error) { } func (mCh *msgChannel) Write(p []byte) (n int, err error) { - if mCh.isClosed() { + select { + case <-mCh.doneChan: return 0, ErrChannelClosed + default: } ctx := context.Background() @@ -156,24 +153,18 @@ func (mCh *msgChannel) Write(p []byte) (n int, err error) { } } -func (mCh *msgChannel) Close() error { - if mCh.isClosed() { +func (mCh *msgChannel) RequestClose() error { + select { + case <-mCh.doneChan: return ErrChannelClosed + default: } if _, err := mCh.link.SendCloseChannel(mCh.ID()); err != nil { return err } - mCh.setClosed(true) - - select { - case <-mCh.closeChan: - case <-time.After(time.Second): - } - mCh.close() - return nil } func (mCh *msgChannel) SetDeadline(t time.Time) error { @@ -185,16 +176,17 @@ func (mCh *msgChannel) Type() string { return "messaging" } -func (mCh *msgChannel) close() { - select { - case <-mCh.doneChan: - default: - close(mCh.doneChan) +func (mCh *msgChannel) OnChannelClosed() bool { + return mCh.close() +} - mCh.closeChanMx.Lock() // TODO(evanlinjin): START(avoid race condition). - close(mCh.closeChan) // TODO(evanlinjin): data race. - mCh.closeChanMx.Unlock() // TODO(evanlinjin): END(avoid race condition). - } +func (mCh *msgChannel) close() bool { + closed := false + mCh.doneOnce.Do(func() { + close(mc.doneChan) + closed = true + }) + return closed } func (mCh *msgChannel) readEncrypted(ctx context.Context, p []byte) (n int, err error) { @@ -250,7 +242,3 @@ func (mCh *msgChannel) readEncrypted(ctx context.Context, p []byte) (n int, err return copy(p, data), nil } - -// for getting and setting the 'closed' status. -func (mCh *msgChannel) isClosed() bool { return *(*bool)(atomic.LoadPointer(&mCh.closed)) } -func (mCh *msgChannel) setClosed(v bool) { atomic.StorePointer(&mCh.closed, unsafe.Pointer(&v)) } //nolint:gosec diff --git a/pkg/messaging/client.go b/pkg/messaging/client.go index 3c2ef011a8..960f5bdd80 100644 --- a/pkg/messaging/client.go +++ b/pkg/messaging/client.go @@ -328,14 +328,13 @@ func (c *Client) onData(l *Link, frameType FrameType, body []byte) error { } case FrameTypeChannelClosed: channel.SetID(body[0]) - channel.closeChanMx.RLock() // TODO(evanlinjin): START(avoid race condition). select { case channel.waitChan <- false: - case channel.closeChan <- struct{}{}: // TODO(evanlinjin): data race. - clientLink.chans.remove(channelID) default: } - channel.closeChanMx.RUnlock() // TODO(evanlinjin): END(avoid race condition). + if channel.OnChannelClosed() { + clientLink.chans.remove(channelID) + } case FrameTypeSend: go func() { select { From f771869879dc168b8bb2308c07c9bd7a5164b064 Mon Sep 17 00:00:00 2001 From: gz-c Date: Sat, 25 May 2019 15:39:53 +0800 Subject: [PATCH 2/3] Fix compilation errors --- pkg/messaging/channel.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pkg/messaging/channel.go b/pkg/messaging/channel.go index 4bb893b897..998756958d 100644 --- a/pkg/messaging/channel.go +++ b/pkg/messaging/channel.go @@ -153,18 +153,20 @@ func (mCh *msgChannel) Write(p []byte) (n int, err error) { } } -func (mCh *msgChannel) RequestClose() error { +func (mCh *msgChannel) Close() error { select { case <-mCh.doneChan: return ErrChannelClosed default: } - if _, err := mCh.link.SendCloseChannel(mCh.ID()); err != nil { - return err + if mCh.close() { + if _, err := mCh.link.SendCloseChannel(mCh.ID()); err != nil { + return err + } } - mCh.close() + return nil } func (mCh *msgChannel) SetDeadline(t time.Time) error { @@ -183,7 +185,7 @@ func (mCh *msgChannel) OnChannelClosed() bool { func (mCh *msgChannel) close() bool { closed := false mCh.doneOnce.Do(func() { - close(mc.doneChan) + close(mCh.doneChan) closed = true }) return closed From 8d818cf87570d334e66b118ed26622c820827335 Mon Sep 17 00:00:00 2001 From: gz-c Date: Sat, 25 May 2019 17:37:18 +0800 Subject: [PATCH 3/3] Fix test compilation error --- pkg/messaging/channel_test.go | 4 +++- pkg/messaging/client.go | 5 ++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pkg/messaging/channel_test.go b/pkg/messaging/channel_test.go index bbdadd61c8..49b247c755 100644 --- a/pkg/messaging/channel_test.go +++ b/pkg/messaging/channel_test.go @@ -103,7 +103,9 @@ func TestChannelWrite(t *testing.T) { _, err = c.Write([]byte("foo")) require.Equal(t, ErrDeadlineExceeded, err) - c.setClosed(true) + closed := c.close() + require.True(t, closed) + _, err = c.Write([]byte("foo")) require.Equal(t, ErrChannelClosed, err) } diff --git a/pkg/messaging/client.go b/pkg/messaging/client.go index 960f5bdd80..b310fa5c95 100644 --- a/pkg/messaging/client.go +++ b/pkg/messaging/client.go @@ -332,9 +332,8 @@ func (c *Client) onData(l *Link, frameType FrameType, body []byte) error { case channel.waitChan <- false: default: } - if channel.OnChannelClosed() { - clientLink.chans.remove(channelID) - } + channel.OnChannelClosed() + clientLink.chans.remove(channelID) case FrameTypeSend: go func() { select {