diff --git a/internal/noise/noise.go b/internal/noise/noise.go index e58572c927..bdd1ee92ef 100644 --- a/internal/noise/noise.go +++ b/internal/noise/noise.go @@ -2,14 +2,16 @@ package noise import ( "crypto/rand" + "encoding/binary" + + "github.com/skycoin/skycoin/src/util/logging" "github.com/flynn/noise" "github.com/skycoin/skywire/pkg/cipher" ) -// packetsTillRekey is the number of packages after which we want to rekey for the noise protocol -const packetsTillRekey = 10 +var logger = logging.MustGetLogger("noise") // Config hold noise parameters. type Config struct { @@ -31,8 +33,11 @@ type Noise struct { enc *noise.CipherState dec *noise.CipherState - encN uint32 // counter to inform encrypting CipherState to re-key - decN uint32 // counter to inform decrypting CipherState to re-key + seq uint32 // sequence number, used as nonce for both encrypting and decrypting + previousSeq uint32 // sequence number last decrypted, check in order to avoid reply attacks + highestPrevious uint32 // highest sequence number received from the other end + //encN uint32 // counter to inform encrypting CipherState to re-key + //decN uint32 // counter to inform decrypting CipherState to re-key } // New creates a new Noise with: @@ -119,21 +124,28 @@ func (ns *Noise) RemoteStatic() cipher.PubKey { // EncryptUnsafe encrypts plaintext without interlocking, should only // be used with external lock. func (ns *Noise) EncryptUnsafe(plaintext []byte) []byte { - if ns.encN++; ns.encN > packetsTillRekey { - ns.enc.Rekey() - ns.encN = 0 - } - return ns.enc.Encrypt(nil, nil, plaintext) + ns.seq++ + seq := make([]byte, 4) + binary.BigEndian.PutUint32(seq, ns.seq) + + return append(seq, ns.enc.Cipher().Encrypt(nil, uint64(ns.seq), nil, plaintext)...) } // DecryptUnsafe decrypts ciphertext without interlocking, should only // be used with external lock. func (ns *Noise) DecryptUnsafe(ciphertext []byte) ([]byte, error) { - if ns.decN++; ns.decN > packetsTillRekey { - ns.dec.Rekey() - ns.decN = 0 + seq := binary.BigEndian.Uint32(ciphertext[:4]) + if seq <= ns.previousSeq { + logger.Warnf("current seq: %s is not higher than previous one: %s. "+ + "Highest sequence number received so far is: %s", ns.seq, ns.previousSeq, ns.highestPrevious) + } else { + if ns.previousSeq > ns.highestPrevious { + ns.highestPrevious = seq + } + ns.previousSeq = seq } - return ns.dec.Decrypt(nil, nil, ciphertext) + + return ns.dec.Cipher().Decrypt(nil, uint64(seq), nil, ciphertext[4:]) } // HandshakeFinished indicate whether handshake was completed. diff --git a/pkg/messaging/channel_test.go b/pkg/messaging/channel_test.go index 49b247c755..d7aebe3153 100644 --- a/pkg/messaging/channel_test.go +++ b/pkg/messaging/channel_test.go @@ -76,26 +76,38 @@ func TestChannelWrite(t *testing.T) { pk, sk := cipher.GenerateKeyPair() in, out := net.Pipe() + l, err := NewLink(in, &LinkConfig{Public: pk}, nil) require.NoError(t, err) - c, err := newChannel(true, sk, remotePK, l) require.NoError(t, err) c.SetID(10) rn := handshakeChannel(t, c, remotePK, remoteSK) - buf := make([]byte, 25) - go out.Read(buf) // nolint + var ( + readBuf = make([]byte, 29) + readErr error + readDone = make(chan struct{}) + ) + go func() { + _, readErr = out.Read(readBuf) + close(readDone) + }() + n, err := c.Write([]byte("foo")) require.NoError(t, err) assert.Equal(t, 3, n) - assert.Equal(t, FrameTypeSend, FrameType(buf[2])) - assert.Equal(t, byte(10), buf[3]) - require.Equal(t, uint16(19), binary.BigEndian.Uint16(buf[4:])) + <-readDone + assert.NoError(t, readErr) + assert.Equal(t, FrameTypeSend, FrameType(readBuf[2])) + assert.Equal(t, byte(10), readBuf[3]) + + // Encoded length should be length of encrypted payload "foo". + require.Equal(t, uint16(23), binary.BigEndian.Uint16(readBuf[4:])) - data, err := rn.DecryptUnsafe(buf[6:]) + data, err := rn.DecryptUnsafe(readBuf[6:]) require.NoError(t, err) assert.Equal(t, []byte("foo"), data) diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index d4f9c7aa0f..d4b3eb88ca 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -192,10 +192,10 @@ func TestRouterApp(t *testing.T) { tr2 := m2.Transport(tr.ID) go proto.Send(app.FrameSend, &app.Packet{Addr: &app.LoopAddr{Port: 6, Remote: *raddr}, Payload: []byte("bar")}, nil) // nolint: errcheck - packet := make(routing.Packet, 25) + packet := make(routing.Packet, 29) _, err = tr2.Read(packet) require.NoError(t, err) - assert.Equal(t, uint16(19), packet.Size()) + assert.Equal(t, uint16(23), packet.Size()) assert.Equal(t, routing.RouteID(4), packet.RouteID()) decrypted, err := ni2.DecryptUnsafe(packet.Payload()) require.NoError(t, err) diff --git a/pkg/transport/manager_test.go b/pkg/transport/manager_test.go index c0443c6ad2..dbe2b110e3 100644 --- a/pkg/transport/manager_test.go +++ b/pkg/transport/manager_test.go @@ -68,6 +68,8 @@ func TestTransportManager(t *testing.T) { tr2, err := m2.CreateTransport(context.TODO(), pk1, "mock", true) require.NoError(t, err) + time.Sleep(time.Second) + tr1 := m1.Transport(tr2.ID) require.NotNil(t, tr1)