From 11fbeaafa145e37cde7a6dac4dce6be0e6ce1605 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E5=BF=97=E5=AE=87?= Date: Fri, 10 May 2019 01:04:27 +1200 Subject: [PATCH] made modifications to channel ID thread-safe. --- .gitignore | 3 ++- pkg/messaging/channel.go | 23 ++++++++++++++++++++--- pkg/messaging/channel_test.go | 4 ++-- pkg/messaging/client.go | 10 +++++----- 4 files changed, 29 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 4cd5ed0cb7..3abd13c0f7 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,5 @@ pkg/node/foo/ /*-node /*-cli /*.json -/*.sh \ No newline at end of file +/*.sh +/*.log \ No newline at end of file diff --git a/pkg/messaging/channel.go b/pkg/messaging/channel.go index 574633225e..1256ed4485 100644 --- a/pkg/messaging/channel.go +++ b/pkg/messaging/channel.go @@ -16,7 +16,9 @@ import ( ) type channel struct { - ID byte + id byte // This is to be changed. + idMx sync.RWMutex + remotePK cipher.PubKey link *Link buf *bytes.Buffer @@ -59,6 +61,21 @@ func newChannel(initiator bool, secKey cipher.SecKey, remote cipher.PubKey, link }, nil } +// ID obtains the channel's id. +func (c *channel) ID() byte { + c.idMx.RLock() + id := c.id + c.idMx.RUnlock() + return id +} + +// SetID set's the channel's id. +func (c *channel) SetID(id byte) { + c.idMx.Lock() + c.id = id + c.idMx.Unlock() +} + // Edges returns the public keys of the channel's edge nodes func (c *channel) Edges() [2]cipher.PubKey { return transport.SortPubKeys(c.link.Local(), c.remotePK) @@ -122,7 +139,7 @@ func (c *channel) Write(p []byte) (n int, err error) { done := make(chan struct{}, 1) defer close(done) go func() { - n, err = c.link.Send(c.ID, buf) + n, err = c.link.Send(c.ID(), buf) n = n - (len(data) - len(p) + 2) select { case done <- struct{}{}: @@ -143,7 +160,7 @@ func (c *channel) Close() error { return ErrChannelClosed } - if _, err := c.link.SendCloseChannel(c.ID); err != nil { + if _, err := c.link.SendCloseChannel(c.ID()); err != nil { return err } diff --git a/pkg/messaging/channel_test.go b/pkg/messaging/channel_test.go index d8d7908b16..16b486610d 100644 --- a/pkg/messaging/channel_test.go +++ b/pkg/messaging/channel_test.go @@ -81,7 +81,7 @@ func TestChannelWrite(t *testing.T) { c, err := newChannel(true, sk, remotePK, l) require.NoError(t, err) - c.ID = 10 + c.SetID(10) rn := handshakeChannel(t, c, remotePK, remoteSK) @@ -118,7 +118,7 @@ func TestChannelClose(t *testing.T) { c, err := newChannel(true, sk, remotePK, l) require.NoError(t, err) - c.ID = 10 + c.SetID(10) handshakeChannel(t, c, remotePK, remoteSK) diff --git a/pkg/messaging/client.go b/pkg/messaging/client.go index 0ec5be2eac..c81abc44ad 100644 --- a/pkg/messaging/client.go +++ b/pkg/messaging/client.go @@ -179,7 +179,7 @@ func (c *Client) Dial(ctx context.Context, remote cipher.PubKey) (transport.Tran return nil, ctx.Err() } - c.Logger.Infof("Opened new channel local ID %d, remote ID %d with %s", localID, channel.ID, remote) + c.Logger.Infof("Opened new channel local ID %d, remote ID %d with %s", localID, channel.ID(), remote) // TODO: race condition return channel, nil } @@ -314,10 +314,10 @@ func (c *Client) onData(l *Link, frameType FrameType, body []byte) error { switch frameType { case FrameTypeCloseChannel: clientLink.chans.remove(channelID) - _, sendErr = l.SendChannelClosed(channel.ID) + _, sendErr = l.SendChannelClosed(channel.ID()) c.Logger.Debugf("Closed channel ID %d", channelID) case FrameTypeChannelOpened: - channel.ID = body[1] + channel.SetID(body[1]) if err := channel.ProcessMessage(body[2:]); err != nil { sendErr = fmt.Errorf("noise handshake: %s", err) } @@ -327,7 +327,7 @@ func (c *Client) onData(l *Link, frameType FrameType, body []byte) error { default: } case FrameTypeChannelClosed: - channel.ID = body[0] + channel.SetID(body[0]) select { case channel.waitChan <- false: case channel.closeChan <- struct{}{}: @@ -390,7 +390,7 @@ func (c *Client) openChannel(rID byte, remotePK []byte, noiseMsg []byte, chanLin } channel, err := newChannel(false, c.secKey, pubKey, chanLink.link) - channel.ID = rID + channel.SetID(rID) if err != nil { err = fmt.Errorf("noise setup: %s", err) return