Skip to content

Commit

Permalink
p2p: fix RLPx disconnect message decoding (#8056)
Browse files Browse the repository at this point in the history
The disconnect message could either be a plain integer, or a list with
one integer element. We were encoding it as a plain integer, but
decoding as a list. Change this to be able to decode any format.
  • Loading branch information
battlmonstr authored Aug 24, 2023
1 parent 66d93f2 commit bb2c2ad
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 20 deletions.
10 changes: 3 additions & 7 deletions cmd/observer/observer/handshake.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package observer

import (
"bytes"
"context"
"crypto/ecdsa"
"fmt"
"math/big"
"net"
"strings"
"time"

libcommon "github.com/ledgerwatch/erigon-lib/common"
Expand Down Expand Up @@ -215,15 +215,11 @@ func readMessage(conn *rlpx.Conn, expectedMessageID uint64, decodeError Handshak
return readMessage(conn, expectedMessageID, decodeError, message)
}
if messageID == RLPxMessageIDDisconnect {
var reason [1]p2p.DiscReason
err = rlp.DecodeBytes(data, &reason)
if (err != nil) && strings.Contains(err.Error(), "rlp: expected input list") {
err = rlp.DecodeBytes(data, &reason[0])
}
reason, err := p2p.DisconnectMessagePayloadDecode(bytes.NewBuffer(data))
if err != nil {
return NewHandshakeError(HandshakeErrorIDDisconnectDecode, err, 0)
}
return NewHandshakeError(HandshakeErrorIDDisconnect, reason[0], uint64(reason[0]))
return NewHandshakeError(HandshakeErrorIDDisconnect, reason, uint64(reason))
}
if messageID != expectedMessageID {
return NewHandshakeError(HandshakeErrorIDUnexpectedMessage, nil, messageID)
Expand Down
12 changes: 7 additions & 5 deletions p2p/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,13 @@ func (p *Peer) handle(msg Msg) error {
msg.Discard()
go SendItems(p.rw, pongMsg)
case msg.Code == discMsg:
// This is the last message. We don't need to discard or
// check errors because, the connection will be closed after it.
var m struct{ R DiscReason }
rlp.Decode(msg.Payload, &m)
return m.R
// This is the last message.
// We don't need to discard because the connection will be closed after it.
reason, err := DisconnectMessagePayloadDecode(msg.Payload)
if err != nil {
p.log.Debug("Peer.handle: failed to rlp.Decode msg.Payload", "err", err)
}
return reason
case msg.Code < baseProtocolLength:
// ignore other base protocol messages
msg.Discard()
Expand Down
43 changes: 35 additions & 8 deletions p2p/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"io"
"net"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -123,7 +124,7 @@ func (t *rlpxTransport) close(err error) {
if err := t.conn.SetWriteDeadline(deadline); err == nil {
// Connection supports write deadline.
t.wbuf.Reset()
rlp.Encode(&t.wbuf, []DiscReason{r}) //nolint:errcheck
_ = DisconnectMessagePayloadEncode(&t.wbuf, r)
t.conn.Write(discMsg, t.wbuf.Bytes()) //nolint:errcheck
}
}
Expand Down Expand Up @@ -169,13 +170,8 @@ func readProtocolHandshake(rw MsgReader) (*protoHandshake, error) {
if msg.Code == discMsg {
// Disconnect before protocol handshake is valid according to the
// spec and we send it ourself if the post-handshake checks fail.
// We can't return the reason directly, though, because it is echoed
// back otherwise. Wrap it in a string instead.
var reason [1]DiscReason
if err = rlp.Decode(msg.Payload, &reason); err != nil {
return nil, err
}
return nil, reason[0]
reason, _ := DisconnectMessagePayloadDecode(msg.Payload)
return nil, reason
}
if msg.Code != handshakeMsg {
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
Expand All @@ -189,3 +185,34 @@ func readProtocolHandshake(rw MsgReader) (*protoHandshake, error) {
}
return &hs, nil
}

func DisconnectMessagePayloadDecode(reader io.Reader) (DiscReason, error) {
var buffer bytes.Buffer
_, err := buffer.ReadFrom(reader)
if err != nil {
return DiscRequested, err
}
data := buffer.Bytes()
if len(data) == 0 {
return DiscRequested, nil
}

var reasonList struct{ Reason DiscReason }
err = rlp.DecodeBytes(data, &reasonList)

// en empty list
if (err != nil) && strings.Contains(err.Error(), "rlp: too few elements") {
return DiscRequested, nil
}

// not a list, try to decode as a plain integer
if (err != nil) && strings.Contains(err.Error(), "rlp: expected input list") {
err = rlp.DecodeBytes(data, &reasonList.Reason)
}

return reasonList.Reason, err
}

func DisconnectMessagePayloadEncode(writer io.Writer, reason DiscReason) error {
return rlp.Encode(writer, []DiscReason{reason})
}
52 changes: 52 additions & 0 deletions p2p/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package p2p

import (
"bytes"
"errors"
"reflect"
"sync"
Expand Down Expand Up @@ -146,3 +147,54 @@ func TestProtocolHandshakeErrors(t *testing.T) {
}
}
}

func TestDisconnectMessagePayloadDecode(t *testing.T) {
var buffer bytes.Buffer
err := DisconnectMessagePayloadEncode(&buffer, DiscTooManyPeers)
if err != nil {
t.Error(err)
}
reason, err := DisconnectMessagePayloadDecode(&buffer)
if err != nil {
t.Error(err)
}
if reason != DiscTooManyPeers {
t.Fail()
}

// plain integer
reason, err = DisconnectMessagePayloadDecode(bytes.NewBuffer([]byte{uint8(DiscTooManyPeers)}))
if err != nil {
t.Error(err)
}
if reason != DiscTooManyPeers {
t.Fail()
}

// single-element RLP list
reason, err = DisconnectMessagePayloadDecode(bytes.NewBuffer([]byte{0xC1, uint8(DiscTooManyPeers)}))
if err != nil {
t.Error(err)
}
if reason != DiscTooManyPeers {
t.Fail()
}

// empty RLP list
reason, err = DisconnectMessagePayloadDecode(bytes.NewBuffer([]byte{0xC0}))
if err != nil {
t.Error(err)
}
if reason != DiscRequested {
t.Fail()
}

// empty payload
reason, err = DisconnectMessagePayloadDecode(bytes.NewBuffer([]byte{}))
if err != nil {
t.Error(err)
}
if reason != DiscRequested {
t.Fail()
}
}

0 comments on commit bb2c2ad

Please sign in to comment.