Skip to content

Commit

Permalink
Merge pull request #159 from gotd/feat/message-id-buf
Browse files Browse the repository at this point in the history
feat: message id buffer
  • Loading branch information
ernado authored Feb 21, 2021
2 parents d23d301 + ea71050 commit 5b8019f
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 13 deletions.
52 changes: 52 additions & 0 deletions internal/proto/message_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,55 @@ func NewMessageIDGen(now func() time.Time, n int) *MessageIDGen {
now: now,
}
}

// MessageIDBuf stores last N message ids and is used in replay attack mitigation.
type MessageIDBuf struct {
mux sync.Mutex
buf []int64
}

// NewMessageIDBuf initializes new message id buffer for last N stored values.
func NewMessageIDBuf(n int) *MessageIDBuf {
return &MessageIDBuf{
buf: make([]int64, n),
}
}

// Consume returns false if message should be discarded.
func (b *MessageIDBuf) Consume(newID int64) bool {
// In addition, the identifiers (msg_id) of the last N messages received
// from the other side must be stored, and if a message comes in with an
// msg_id lower than all or equal to any of the stored values, that message
// is to be ignored. Otherwise, the new message msg_id is added to the set,
// and, if the number of stored msg_id values is greater than N, the oldest
// (i. e. the lowest) is discarded.
//
// https://core.telegram.org/mtproto/security_guidelines#checking-msg-id

b.mux.Lock()
defer b.mux.Unlock()

var (
minIDx int
minID int64
)
for i, id := range b.buf {
if id == newID {
// Equal to stored value.
return false
}
// Searching for minimum value.
if id < minID {
minIDx = i
minID = id
}
}
if newID < minID {
// Lower than all stored values.
return false
}

// Message is accepted. Replacing lowest message id with new id.
b.buf[minIDx] = newID
return true
}
29 changes: 29 additions & 0 deletions internal/proto/message_id_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gotd/neo"
Expand Down Expand Up @@ -89,3 +90,31 @@ func BenchmarkMsgIDGen_New(b *testing.B) {
_ = gen.New(MessageFromServer)
}
}

func TestNewMessageIDBuf(t *testing.T) {
t.Run("Zero", func(t *testing.T) {
buf := NewMessageIDBuf(10)

assert.False(t, buf.Consume(0))
})
t.Run("Ok", func(t *testing.T) {
buf := NewMessageIDBuf(10)

assert.True(t, buf.Consume(1))
assert.False(t, buf.Consume(1))

t.Run("Sequence", func(t *testing.T) {
for i := 2; i <= 20; i++ {
assert.True(t, buf.Consume(int64(i)))
}
assert.False(t, buf.Consume(-1))
})
})
}

func BenchmarkMessageIDBuf(b *testing.B) {
buf := NewMessageIDBuf(100)
for i := 0; i < b.N; i++ {
buf.Consume(int64(i))
}
}
28 changes: 15 additions & 13 deletions mtproto/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@ type Conn struct {

// Wrappers for external world, like current time, logs or PRNG.
// Should be immutable.
clock clock.Clock
rand io.Reader
cipher Cipher
log *zap.Logger
messageID MessageIDSource
clock clock.Clock
rand io.Reader
cipher Cipher
log *zap.Logger
messageID MessageIDSource
messageIDBuf *proto.MessageIDBuf // replay attack protection

// use session() to access authKey, salt or sessionID.
sessionMux sync.RWMutex
Expand Down Expand Up @@ -92,14 +93,15 @@ func New(addr string, opt Options) *Conn {
opt.setDefaults()

conn := &Conn{
addr: addr,
transport: opt.Transport,
clock: opt.Clock,
rand: opt.Random,
cipher: opt.Cipher,
log: opt.Logger,
ping: map[int64]func(){},
messageID: opt.MessageID,
addr: addr,
transport: opt.Transport,
clock: opt.Clock,
rand: opt.Random,
cipher: opt.Cipher,
log: opt.Logger,
ping: map[int64]func(){},
messageID: opt.MessageID,
messageIDBuf: proto.NewMessageIDBuf(100),

ackSendChan: make(chan int64),
ackInterval: opt.AckInterval,
Expand Down
3 changes: 3 additions & 0 deletions mtproto/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ func (c *Conn) read(ctx context.Context, b *bin.Buffer) (*crypto.EncryptedMessag
if err := checkMessageID(c.clock.Now(), msg.MessageID); err != nil {
return nil, xerrors.Errorf("bad message id: %w", err)
}
if !c.messageIDBuf.Consume(msg.MessageID) {
return nil, xerrors.Errorf("duplicate or too low message id: %w", errRejected)
}

return msg, nil
}
Expand Down

0 comments on commit 5b8019f

Please sign in to comment.