Skip to content

Commit

Permalink
Made noise handshakes within messaging thread safe.
Browse files Browse the repository at this point in the history
  • Loading branch information
林志宇 committed May 9, 2019
1 parent 53be5a0 commit cfcbbcd
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 20 deletions.
34 changes: 27 additions & 7 deletions pkg/messaging/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,8 @@ type channel struct {
doneChan chan struct{}

noise *noise.Noise
rMx sync.Mutex
wMx sync.Mutex
}

// Edges returns the public keys of the channel's edge nodes
func (c *channel) Edges() [2]cipher.PubKey {
return transport.SortPubKeys(c.link.Local(), c.remotePK)
rMx sync.Mutex // lock for decrypt cipher state
wMx sync.Mutex // lock for encrypt cipher state
}

func newChannel(initiator bool, secKey cipher.SecKey, remote cipher.PubKey, link *Link) (*channel, error) {
Expand Down Expand Up @@ -64,6 +59,31 @@ func newChannel(initiator bool, secKey cipher.SecKey, remote cipher.PubKey, link
}, nil
}

// Edges returns the public keys of the channel's edge nodes
func (c *channel) Edges() [2]cipher.PubKey {
return transport.SortPubKeys(c.link.Local(), c.remotePK)
}

// HandshakeMessage prepares a handshake message safely.
func (c *channel) HandshakeMessage() ([]byte, error) {
c.rMx.Lock()
c.wMx.Lock()
res, err := c.noise.HandshakeMessage()
c.rMx.Unlock()
c.wMx.Unlock()
return res, err
}

// ProcessMessage reads a handshake message safely.
func (c *channel) ProcessMessage(msg []byte) error {
c.rMx.Lock()
c.wMx.Lock()
err := c.noise.ProcessMessage(msg)
c.rMx.Unlock()
c.wMx.Unlock()
return err
}

func (c *channel) Read(p []byte) (n int, err error) {
if c.buf.Len() != 0 {
return c.buf.Read(p)
Expand Down
8 changes: 4 additions & 4 deletions pkg/messaging/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func (c *Client) Dial(ctx context.Context, remote cipher.PubKey) (transport.Tran
}
localID := clientLink.chans.add(channel)

msg, err := channel.noise.HandshakeMessage()
msg, err := channel.HandshakeMessage()
if err != nil {
return nil, fmt.Errorf("noise handshake: %s", err)
}
Expand Down Expand Up @@ -318,7 +318,7 @@ func (c *Client) onData(l *Link, frameType FrameType, body []byte) error {
c.Logger.Debugf("Closed channel ID %d", channelID)
case FrameTypeChannelOpened:
channel.ID = body[1]
if err := channel.noise.ProcessMessage(body[2:]); err != nil {
if err := channel.ProcessMessage(body[2:]); err != nil {
sendErr = fmt.Errorf("noise handshake: %s", err)
}

Expand Down Expand Up @@ -396,7 +396,7 @@ func (c *Client) openChannel(rID byte, remotePK []byte, noiseMsg []byte, chanLin
return
}

if err = channel.noise.ProcessMessage(noiseMsg); err != nil {
if err = channel.ProcessMessage(noiseMsg); err != nil {
err = fmt.Errorf("noise handshake: %s", err)
return
}
Expand All @@ -410,7 +410,7 @@ func (c *Client) openChannel(rID byte, remotePK []byte, noiseMsg []byte, chanLin
}
}()

noiseRes, err = channel.noise.HandshakeMessage()
noiseRes, err = channel.HandshakeMessage()
if err != nil {
err = fmt.Errorf("noise handshake: %s", err)
return
Expand Down
2 changes: 0 additions & 2 deletions pkg/messaging/link.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ func (c *Link) Open(wg *sync.WaitGroup) error {
c.callbacks.HandshakeComplete(c)

// Event loops.
var done = make(chan struct{})
wg.Add(1)
go func() {
// Exits when connection is closed.
Expand All @@ -87,7 +86,6 @@ func (c *Link) Open(wg *sync.WaitGroup) error {
}
// TODO(evanlinjin): Determine if the 'close' is initiated from remote instance.
c.callbacks.Close(c, false)
close(done)
wg.Done()
}()

Expand Down
13 changes: 6 additions & 7 deletions pkg/transport/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,15 @@ import (

type settlementHandshake func(tm *Manager, tr Transport) (*Entry, error)

func (handshake settlementHandshake) Do(tm *Manager, tr Transport, timeout time.Duration) (*Entry, error) {
var entry *Entry
errCh := make(chan error, 1)
func (handshake settlementHandshake) Do(tm *Manager, tr Transport, timeout time.Duration) (entry *Entry, err error) {
done := make(chan struct{}, 1)
defer close(done)
go func() {
e, err := handshake(tm, tr)
entry = e
errCh <- err
entry, err = handshake(tm, tr)
done <- struct{}{}
}()
select {
case err := <-errCh:
case <-done:
return entry, err
case <-time.After(timeout):
return nil, errors.New("deadline exceeded")
Expand Down

0 comments on commit cfcbbcd

Please sign in to comment.