Skip to content

Commit

Permalink
Merge pull request #373 from evanlinjin/bug/link-race-gz-c
Browse files Browse the repository at this point in the history
Fixed a faulty test that was discovered after gz-c's PR.
  • Loading branch information
ivcosla authored May 27, 2019
2 parents 32611c5 + a04190a commit b917300
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 76 deletions.
4 changes: 3 additions & 1 deletion pkg/messaging/chan_list.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package messaging

import "sync"
import (
"sync"
)

type chanList struct {
sync.Mutex
Expand Down
74 changes: 32 additions & 42 deletions pkg/messaging/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ import (
"encoding/binary"
"io"
"sync"
"sync/atomic"
"time"
"unsafe"

"github.com/skycoin/skywire/internal/noise"
"github.com/skycoin/skywire/pkg/cipher"
Expand All @@ -24,13 +22,12 @@ type msgChannel struct {
buf *bytes.Buffer

deadline time.Time
closed unsafe.Pointer // unsafe.Pointer is used alongside 'atomic' module for fast, thread-safe access.

waitChan chan bool // waits for remote response (whether msgChannel is accepted or not).
readChan chan []byte
closeChan chan struct{}
closeChanMx sync.RWMutex // TODO(evanlinjin): This is a hack to avoid race conditions when closing msgChannel.
doneChan chan struct{}
waitChan chan bool // waits for remote response (whether msgChannel is accepted or not).
readChan chan []byte

doneChan chan struct{}
doneOnce sync.Once

noise *noise.Noise
rMx sync.Mutex // lock for decrypt cipher state
Expand All @@ -50,15 +47,13 @@ func newChannel(initiator bool, secKey cipher.SecKey, remote cipher.PubKey, link
}

return &msgChannel{
remotePK: remote,
link: link,
buf: new(bytes.Buffer),
closed: unsafe.Pointer(new(bool)), //nolint:gosec
waitChan: make(chan bool, 1), // should allows receive one reply.
readChan: make(chan []byte),
closeChan: make(chan struct{}),
doneChan: make(chan struct{}),
noise: noiseInstance,
remotePK: remote,
link: link,
buf: new(bytes.Buffer),
waitChan: make(chan bool, 1), // should allows receive one reply.
readChan: make(chan []byte),
doneChan: make(chan struct{}),
noise: noiseInstance,
}, nil
}

Expand Down Expand Up @@ -118,8 +113,10 @@ func (mCh *msgChannel) Read(p []byte) (n int, err error) {
}

func (mCh *msgChannel) Write(p []byte) (n int, err error) {
if mCh.isClosed() {
select {
case <-mCh.doneChan:
return 0, ErrChannelClosed
default:
}

ctx := context.Background()
Expand Down Expand Up @@ -153,22 +150,18 @@ func (mCh *msgChannel) Write(p []byte) (n int, err error) {
}

func (mCh *msgChannel) Close() error {
if mCh.isClosed() {
select {
case <-mCh.doneChan:
return ErrChannelClosed
default:
}

if _, err := mCh.link.SendCloseChannel(mCh.ID()); err != nil {
return err
}

mCh.setClosed(true)

select {
case <-mCh.closeChan:
case <-time.After(time.Second):
if mCh.close() {
if _, err := mCh.link.SendCloseChannel(mCh.ID()); err != nil {
return err
}
}

mCh.close()
return nil
}

Expand All @@ -181,16 +174,17 @@ func (mCh *msgChannel) Type() string {
return "messaging"
}

func (mCh *msgChannel) close() {
select {
case <-mCh.doneChan:
default:
close(mCh.doneChan)
// OnChannelClosed is the hook the client calls when a FrameTypeChannelClosed
// frame arrives for this channel. It delegates to close and returns its
// result unchanged.
func (mCh *msgChannel) OnChannelClosed() bool {
return mCh.close()
}

mCh.closeChanMx.Lock() // TODO(evanlinjin): START(avoid race condition).
close(mCh.closeChan) // TODO(evanlinjin): data race.
mCh.closeChanMx.Unlock() // TODO(evanlinjin): END(avoid race condition).
}
// close closes doneChan exactly once, guarded by doneOnce.
// It reports true only for the single call that actually performed the close;
// every later (or concurrent losing) call is a no-op and reports false.
func (mCh *msgChannel) close() bool {
	var didClose bool

	mCh.doneOnce.Do(func() {
		close(mCh.doneChan)
		didClose = true
	})

	return didClose
}

func (mCh *msgChannel) readEncrypted(ctx context.Context, p []byte) (n int, err error) {
Expand Down Expand Up @@ -246,7 +240,3 @@ func (mCh *msgChannel) readEncrypted(ctx context.Context, p []byte) (n int, err

return copy(p, data), nil
}

// for getting and setting the 'closed' status.
func (mCh *msgChannel) isClosed() bool { return *(*bool)(atomic.LoadPointer(&mCh.closed)) }
func (mCh *msgChannel) setClosed(v bool) { atomic.StorePointer(&mCh.closed, unsafe.Pointer(&v)) } //nolint:gosec
4 changes: 3 additions & 1 deletion pkg/messaging/channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ func TestChannelWrite(t *testing.T) {
_, err = c.Write([]byte("foo"))
require.Equal(t, ErrDeadlineExceeded, err)

c.setClosed(true)
closed := c.close()
require.True(t, closed)

_, err = c.Write([]byte("foo"))
require.Equal(t, ErrChannelClosed, err)
}
Expand Down
6 changes: 2 additions & 4 deletions pkg/messaging/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,14 +331,12 @@ func (c *Client) onData(l *Link, frameType FrameType, body []byte) error {
}
case FrameTypeChannelClosed:
channel.SetID(body[0])
channel.closeChanMx.RLock() // Begin: avoid race condition.
select {
case channel.waitChan <- false:
case channel.closeChan <- struct{}{}: // Previous data race.
clientLink.chans.remove(channelID)
default:
}
channel.closeChanMx.RUnlock() // End: avoid race condition.
channel.OnChannelClosed()
clientLink.chans.remove(channelID)
case FrameTypeSend:
go func() {
select {
Expand Down
100 changes: 72 additions & 28 deletions pkg/messaging/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os"
"sync"
"testing"
"time"

"github.com/skycoin/skycoin/src/util/logging"
"github.com/stretchr/testify/assert"
Expand All @@ -24,65 +25,108 @@ func TestMain(m *testing.M) {
}

func TestClientDial(t *testing.T) {
pk, sk := cipher.GenerateKeyPair()
discovery := client.NewMock()
c := NewClient(&Config{pk, sk, discovery, 0, 0})
c.retries = 0

srv, err := newMockServer(discovery)
require.NoError(t, err)
srvPK := srv.config.Public

anotherPK, anotherSK := cipher.GenerateKeyPair()
anotherClient := NewClient(&Config{anotherPK, anotherSK, discovery, 0, 0})
require.NoError(t, anotherClient.ConnectToInitialServers(context.TODO(), 1))
pk1, sk1 := cipher.GenerateKeyPair()
c1 := NewClient(&Config{pk1, sk1, discovery, 0, 0})

pk2, sk2 := cipher.GenerateKeyPair()
c2 := NewClient(&Config{pk2, sk2, discovery, 0, 0})
require.NoError(t, c2.ConnectToInitialServers(context.TODO(), 1))

var anotherTr transport.Transport
anotherErrCh := make(chan error)
var (
tp2 transport.Transport
tp2Err error
tp2Done = make(chan struct{})
)
go func() {
t, err := anotherClient.Accept(context.TODO())
anotherTr = t
anotherErrCh <- err
tp2, tp2Err = c2.Accept(context.TODO())
close(tp2Done)
}()

var tr transport.Transport
errCh := make(chan error)
var (
tp1 transport.Transport
tp1Err error
tp1Done = make(chan struct{})
)
go func() {
t, err := c.Dial(context.TODO(), anotherPK)
tr = t
errCh <- err
tp1, tp1Err = c1.Dial(context.TODO(), pk2)
close(tp1Done)
}()

require.NoError(t, <-errCh)
require.NotNil(t, c.getLink(srvPK).chans.get(0))
<-tp1Done
require.NoError(t, tp1Err)
require.NotNil(t, c1.getLink(srvPK).chans.get(0))

entry, err := discovery.Entry(context.TODO(), pk)
entry, err := discovery.Entry(context.TODO(), pk1)
require.NoError(t, err)
require.Len(t, entry.Client.DelegatedServers, 1)

require.NoError(t, <-anotherErrCh)
require.NotNil(t, anotherClient.getLink(srvPK).chans.get(0))
<-tp2Done
require.NoError(t, tp2Err)
require.NotNil(t, c2.getLink(srvPK).chans.get(0))

go tr.Write([]byte("foo")) // nolint: errcheck
go tp1.Write([]byte("foo")) // nolint: errcheck

buf := make([]byte, 3)
n, err := anotherTr.Read(buf)
n, err := tp2.Read(buf)
require.NoError(t, err)
assert.Equal(t, 3, n)
assert.Equal(t, []byte("foo"), buf)

go anotherTr.Write([]byte("bar")) // nolint: errcheck
go tp2.Write([]byte("bar")) // nolint: errcheck

buf = make([]byte, 3)
n, err = tr.Read(buf)
n, err = tp1.Read(buf)
require.NoError(t, err)
assert.Equal(t, 3, n)
assert.Equal(t, []byte("bar"), buf)

require.NoError(t, tr.Close())
require.NoError(t, anotherTr.Close())
require.NoError(t, tp1.Close())
require.NoError(t, tp2.Close())

// The messaging client is expected to eventually delete the channel from chanList.
require.True(t, retry(time.Second*10, time.Second, func() bool {
return c2.getLink(srvPK).chans.get(0) == nil
}))
}

// retry calls do once per tick until it reports success or deadline elapses.
// It returns true if do returned true before the deadline, false otherwise.
//
// The first attempt is made only after one full tick has passed. A
// non-positive tick never triggers an attempt (time.NewTicker would panic on
// it), so retry just waits out the deadline and returns false.
//
// do runs on a separate goroutine; an attempt that is still in flight when the
// deadline fires may finish after retry has already returned false.
func retry(deadline, tick time.Duration, do func() bool) bool {
	timer := time.NewTimer(deadline)
	defer timer.Stop()

	if tick <= 0 {
		<-timer.C
		return false
	}

	// stop tells the polling goroutine to exit once retry returns.
	stop := make(chan struct{})
	defer close(stop)

	// succeeded is closed by the polling goroutine when do reports success.
	succeeded := make(chan struct{})

	go func() {
		// A single ticker is reused for every attempt and released on exit,
		// instead of allocating a fresh, unstoppable one per loop iteration.
		ticker := time.NewTicker(tick)
		defer ticker.Stop()

		for {
			select {
			case <-stop:
				return
			case <-ticker.C:
				if do() {
					close(succeeded)
					return
				}
			}
		}
	}()

	select {
	case <-timer.C:
		return false
	case <-succeeded:
		return true
	}
}

type mockServer struct {
Expand Down

0 comments on commit b917300

Please sign in to comment.