diff --git a/pkg/messaging/chan_list.go b/pkg/messaging/chan_list.go index 22594353b..4c5fe9d0a 100644 --- a/pkg/messaging/chan_list.go +++ b/pkg/messaging/chan_list.go @@ -26,12 +26,6 @@ func (c *chanList) add(channel *channel) byte { panic("no free channels") } -func (c *chanList) set(id byte, channel *channel) { - c.Lock() - c.chans[id] = channel - c.Unlock() -} - func (c *chanList) get(id byte) *channel { c.Lock() ch := c.chans[id] diff --git a/pkg/messaging/client.go b/pkg/messaging/client.go index 023c46d20..184dd84b7 100644 --- a/pkg/messaging/client.go +++ b/pkg/messaging/client.go @@ -290,12 +290,13 @@ func (c *Client) onData(l *Link, frameType FrameType, body []byte) error { c.Logger.Debugf("New frame %s from %s@%d", frameType, remotePK, channelID) if frameType == FrameTypeOpenChannel { - if msg, err := c.openChannel(channelID, body[1:34], body[34:], clientLink); err != nil { + if lID, msg, err := c.openChannel(channelID, body[1:34], body[34:], clientLink); err != nil { c.Logger.Warnf("Failed to open new channel for %s: %s", remotePK, err) _, sendErr = l.SendChannelClosed(channelID) } else { - c.Logger.Infof("Opened new channel ID %d with %s", channelID, hex.EncodeToString(body[1:34])) - _, sendErr = l.SendChannelOpened(channelID, msg) + c.Logger.Infof("Opened new channel local ID %d, remote ID %d with %s", lID, channelID, + hex.EncodeToString(body[1:34])) + _, sendErr = l.SendChannelOpened(lID, msg) } return c.warnSendError(remotePK, sendErr) @@ -312,7 +313,7 @@ func (c *Client) onData(l *Link, frameType FrameType, body []byte) error { switch frameType { case FrameTypeCloseChannel: clientLink.chans.remove(channelID) - _, sendErr = l.SendChannelClosed(channelID) + _, sendErr = l.SendChannelClosed(channel.ID) c.Logger.Debugf("Closed channel ID %d", channelID) case FrameTypeChannelOpened: if err := channel.noise.ProcessMessage(body[1:]); err != nil { @@ -324,6 +325,7 @@ func (c *Client) onData(l *Link, frameType FrameType, body []byte) error { default: } case FrameTypeChannelClosed: + channel.ID = body[0] select { case channel.waitChan <- false: case channel.closeChan <- struct{}{}: @@ -350,6 +352,10 @@ func (c *Client) onClose(l *Link, remote bool) { chanLink := c.links[remotePK] c.mu.RUnlock() + for _, channel := range chanLink.chans.dropAll() { + channel.close() + } + select { case <-c.doneChan: default: @@ -365,10 +371,6 @@ func (c *Client) onClose(l *Link, remote bool) { c.Logger.Infof("Closing link with the server %s", remotePK) - for _, channel := range chanLink.chans.dropAll() { - channel.close() - } - c.mu.Lock() delete(c.links, remotePK) c.mu.Unlock() @@ -378,28 +380,27 @@ func (c *Client) onClose(l *Link, remote bool) { } } -func (c *Client) openChannel(channelID byte, remotePK []byte, msg []byte, chanLink *clientLink) ([]byte, error) { - channel := chanLink.chans.get(channelID) - if channel != nil { - return nil, errors.New("channel is already opened") - } - - pubKey, err := cipher.NewPubKey(remotePK) +func (c *Client) openChannel(rID byte, remotePK []byte, noiseMsg []byte, chanLink *clientLink) (lID byte, noiseRes []byte, err error) { + var pubKey cipher.PubKey + pubKey, err = cipher.NewPubKey(remotePK) if err != nil { - return nil, err + return } - channel, err = newChannel(false, c.secKey, pubKey, chanLink.link) + + channel, err := newChannel(false, c.secKey, pubKey, chanLink.link) + channel.ID = rID if err != nil { - return nil, fmt.Errorf("noise setup: %s", err) + err = fmt.Errorf("noise setup: %s", err) + return } - channel.ID = channelID - chanLink.chans.set(channelID, channel) - - if err := channel.noise.ProcessMessage(msg); err != nil { - return nil, fmt.Errorf("noise handshake: %s", err) + if err = channel.noise.ProcessMessage(noiseMsg); err != nil { + err = fmt.Errorf("noise handshake: %s", err) + return } + lID = chanLink.chans.add(channel) + go func() { select { case <-c.doneChan: @@ -407,12 +408,13 @@ func (c *Client) openChannel(channelID byte, remotePK []byte, msg []byte, chanLi } }() - res, err := channel.noise.HandshakeMessage() + noiseRes, err = channel.noise.HandshakeMessage() if err != nil { - return nil, fmt.Errorf("noise handshake: %s", err) + err = fmt.Errorf("noise handshake: %s", err) + return } - return res, nil + return lID, noiseRes, err } func (c *Client) warnSendError(remote cipher.PubKey, err error) error {