diff --git a/pkg/messaging/channel.go b/pkg/messaging/channel.go index 7385dd002..574633225 100644 --- a/pkg/messaging/channel.go +++ b/pkg/messaging/channel.go @@ -30,13 +30,8 @@ type channel struct { doneChan chan struct{} noise *noise.Noise - rMx sync.Mutex - wMx sync.Mutex -} - -// Edges returns the public keys of the channel's edge nodes -func (c *channel) Edges() [2]cipher.PubKey { - return transport.SortPubKeys(c.link.Local(), c.remotePK) + rMx sync.Mutex // lock for decrypt cipher state + wMx sync.Mutex // lock for encrypt cipher state } func newChannel(initiator bool, secKey cipher.SecKey, remote cipher.PubKey, link *Link) (*channel, error) { @@ -64,6 +59,31 @@ func newChannel(initiator bool, secKey cipher.SecKey, remote cipher.PubKey, link }, nil } +// Edges returns the public keys of the channel's edge nodes +func (c *channel) Edges() [2]cipher.PubKey { + return transport.SortPubKeys(c.link.Local(), c.remotePK) +} + +// HandshakeMessage prepares a handshake message safely. +func (c *channel) HandshakeMessage() ([]byte, error) { + c.rMx.Lock() + c.wMx.Lock() + res, err := c.noise.HandshakeMessage() + c.rMx.Unlock() + c.wMx.Unlock() + return res, err +} + +// ProcessMessage reads a handshake message safely. +func (c *channel) ProcessMessage(msg []byte) error { + c.rMx.Lock() + c.wMx.Lock() + err := c.noise.ProcessMessage(msg) + c.rMx.Unlock() + c.wMx.Unlock() + return err +} + func (c *channel) Read(p []byte) (n int, err error) { if c.buf.Len() != 0 { return c.buf.Read(p) diff --git a/pkg/messaging/client.go b/pkg/messaging/client.go index 51473829d..0ec5be2ea 100644 --- a/pkg/messaging/client.go +++ b/pkg/messaging/client.go @@ -161,7 +161,7 @@ func (c *Client) Dial(ctx context.Context, remote cipher.PubKey) (transport.Tran } localID := clientLink.chans.add(channel) - msg, err := channel.noise.HandshakeMessage() + msg, err := channel.HandshakeMessage() if err != nil { return nil, fmt.Errorf("noise handshake: %s", err) } @@ -318,7 +318,7 @@ func (c *Client) onData(l *Link, frameType FrameType, body []byte) error { c.Logger.Debugf("Closed channel ID %d", channelID) case FrameTypeChannelOpened: channel.ID = body[1] - if err := channel.noise.ProcessMessage(body[2:]); err != nil { + if err := channel.ProcessMessage(body[2:]); err != nil { sendErr = fmt.Errorf("noise handshake: %s", err) } @@ -396,7 +396,7 @@ func (c *Client) openChannel(rID byte, remotePK []byte, noiseMsg []byte, chanLin return } - if err = channel.noise.ProcessMessage(noiseMsg); err != nil { + if err = channel.ProcessMessage(noiseMsg); err != nil { err = fmt.Errorf("noise handshake: %s", err) return } @@ -410,7 +410,7 @@ func (c *Client) openChannel(rID byte, remotePK []byte, noiseMsg []byte, chanLin } }() - noiseRes, err = channel.noise.HandshakeMessage() + noiseRes, err = channel.HandshakeMessage() if err != nil { err = fmt.Errorf("noise handshake: %s", err) return diff --git a/pkg/messaging/link.go b/pkg/messaging/link.go index 9583986e0..57f4121d4 100644 --- a/pkg/messaging/link.go +++ b/pkg/messaging/link.go @@ -78,7 +78,6 @@ func (c *Link) Open(wg *sync.WaitGroup) error { c.callbacks.HandshakeComplete(c) // Event loops. - var done = make(chan struct{}) wg.Add(1) go func() { // Exits when connection is closed. @@ -87,7 +86,6 @@ func (c *Link) Open(wg *sync.WaitGroup) error { } // TODO(evanlinjin): Determine if the 'close' is initiated from remote instance. c.callbacks.Close(c, false) - close(done) wg.Done() }() diff --git a/pkg/transport/handshake.go b/pkg/transport/handshake.go index 4a78e87a6..5f13078fb 100644 --- a/pkg/transport/handshake.go +++ b/pkg/transport/handshake.go @@ -12,16 +12,15 @@ import ( type settlementHandshake func(tm *Manager, tr Transport) (*Entry, error) -func (handshake settlementHandshake) Do(tm *Manager, tr Transport, timeout time.Duration) (*Entry, error) { - var entry *Entry - errCh := make(chan error, 1) +func (handshake settlementHandshake) Do(tm *Manager, tr Transport, timeout time.Duration) (entry *Entry, err error) { + done := make(chan struct{}, 1) + defer close(done) go func() { - e, err := handshake(tm, tr) - entry = e - errCh <- err + entry, err = handshake(tm, tr) + done <- struct{}{} }() select { - case err := <-errCh: + case <-done: return entry, err case <-time.After(timeout): return nil, errors.New("deadline exceeded")