Skip to content

Commit

Permalink
Implemented thread-safe closed var in messaging channel.
Browse files Browse the repository at this point in the history
  • Loading branch information
林志宇 committed May 9, 2019
1 parent 79fc1ac commit 53be5a0
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
15 changes: 11 additions & 4 deletions pkg/messaging/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (
"encoding/binary"
"io"
"sync"
"sync/atomic"
"time"
"unsafe"

"github.com/skycoin/skywire/internal/noise"
"github.com/skycoin/skywire/pkg/cipher"
Expand All @@ -20,7 +22,7 @@ type channel struct {
buf *bytes.Buffer

deadline time.Time
closed bool
closed unsafe.Pointer // unsafe.Pointer is used alongside 'atomic' module for fast, thread-safe access.

waitChan chan bool
readChan chan []byte
Expand Down Expand Up @@ -53,6 +55,7 @@ func newChannel(initiator bool, secKey cipher.SecKey, remote cipher.PubKey, link
remotePK: remote,
link: link,
buf: new(bytes.Buffer),
closed: unsafe.Pointer(new(bool)), //nolint:gosec
waitChan: make(chan bool),
readChan: make(chan []byte),
closeChan: make(chan struct{}),
Expand All @@ -77,7 +80,7 @@ func (c *channel) Read(p []byte) (n int, err error) {
}

func (c *channel) Write(p []byte) (n int, err error) {
if c.closed {
if c.isClosed() {
return 0, ErrChannelClosed
}

Expand Down Expand Up @@ -116,15 +119,15 @@ func (c *channel) Write(p []byte) (n int, err error) {
}

func (c *channel) Close() error {
if c.closed {
if c.isClosed() {
return ErrChannelClosed
}

if _, err := c.link.SendCloseChannel(c.ID); err != nil {
return err
}

c.closed = true
c.setClosed(true)

select {
case <-c.closeChan:
Expand Down Expand Up @@ -206,3 +209,7 @@ func (c *channel) readEncrypted(ctx context.Context, p []byte) (n int, err error

return copy(p, data), nil
}

// for getting and setting the 'closed' status.
func (c *channel) isClosed() bool { return *(*bool)(atomic.LoadPointer(&c.closed)) }
func (c *channel) setClosed(v bool) { atomic.StorePointer(&c.closed, unsafe.Pointer(&v)) } //nolint:gosec
2 changes: 1 addition & 1 deletion pkg/messaging/channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func TestChannelWrite(t *testing.T) {
_, err = c.Write([]byte("foo"))
require.Equal(t, ErrDeadlineExceeded, err)

c.closed = true
c.setClosed(true)
_, err = c.Write([]byte("foo"))
require.Equal(t, ErrChannelClosed, err)
}
Expand Down

0 comments on commit 53be5a0

Please sign in to comment.