diff --git a/internal/noise/net.go b/internal/noise/net.go index 06a4659df6..77383e1830 100644 --- a/internal/noise/net.go +++ b/internal/noise/net.go @@ -1,8 +1,8 @@ package noise import ( - "bytes" "errors" + "io" "math" "net" "net/rpc" @@ -147,10 +147,7 @@ func (a Addr) String() string { // Conn wraps a net.Conn and encrypts the connection with noise. type Conn struct { net.Conn - buf bytes.Buffer - rMu sync.Mutex - wMu sync.Mutex - ns *ReadWriter + ns *ReadWriter } // WrapConn wraps a provided net.Conn with noise. @@ -164,36 +161,15 @@ func WrapConn(conn net.Conn, ns *Noise, hsTimeout time.Duration) (*Conn, error) // Read reads from the noise-encrypted connection. func (c *Conn) Read(b []byte) (int, error) { - c.rMu.Lock() - defer c.rMu.Unlock() - // First check buffer. - if c.buf.Len() > 0 { - return c.buf.Read(b) - } - // Grab packet, and copy to 'b'. - // If packet is too large, copy len(b) and Write the rest to buffer. - plainText, err := c.ns.ReadPacketUnsafe() - if err != nil { - return 0, err - } - n := copy(b, plainText) - if n < len(plainText) { - c.buf.Write(plainText[n:]) // Will panic if buffer is too large. - } - return n, nil + return c.ns.Read(b) } // Write writes to the noise-encrypted connection. func (c *Conn) Write(b []byte) (int, error) { if len(b) > math.MaxUint16 { - return 0, ErrPacketTooBig - } - c.wMu.Lock() - defer c.wMu.Unlock() - if _, err := c.ns.WriteUnsafe(b); err != nil { - return 0, err + return 0, io.ErrShortWrite } - return len(b), nil + return c.ns.Write(b) } // LocalAddr returns the local address of the connection. diff --git a/internal/noise/noise.go b/internal/noise/noise.go index ab0d82620d..e58572c927 100644 --- a/internal/noise/noise.go +++ b/internal/noise/noise.go @@ -2,8 +2,6 @@ package noise import ( "crypto/rand" - "sync" - "sync/atomic" "github.com/flynn/noise" @@ -22,6 +20,7 @@ type Config struct { } // Noise handles the handshake and the frame's cryptography. +// All operations on Noise are not guaranteed to be thread-safe. type Noise struct { pk cipher.PubKey sk cipher.SecKey @@ -32,11 +31,8 @@ type Noise struct { enc *noise.CipherState dec *noise.CipherState - encN uint32 - decN uint32 - - encMu sync.Mutex - decMu sync.Mutex + encN uint32 // counter to inform encrypting CipherState to re-key + decN uint32 // counter to inform decrypting CipherState to re-key } // New creates a new Noise with: @@ -123,41 +119,23 @@ func (ns *Noise) RemoteStatic() cipher.PubKey { // EncryptUnsafe encrypts plaintext without interlocking, should only // be used with external lock. func (ns *Noise) EncryptUnsafe(plaintext []byte) []byte { - atomic.AddUint32(&ns.encN, 1) - if atomic.CompareAndSwapUint32(&ns.encN, packetsTillRekey, 0) { + if ns.encN++; ns.encN > packetsTillRekey { ns.enc.Rekey() + ns.encN = 0 } - return ns.enc.Encrypt(nil, nil, plaintext) } -// Encrypt encrypts plaintext. -func (ns *Noise) Encrypt(plaintext []byte) []byte { - ns.encMu.Lock() - res := ns.EncryptUnsafe(plaintext) - ns.encMu.Unlock() - return res -} - // DecryptUnsafe decrypts ciphertext without interlocking, should only // be used with external lock. func (ns *Noise) DecryptUnsafe(ciphertext []byte) ([]byte, error) { - atomic.AddUint32(&ns.decN, 1) - if atomic.CompareAndSwapUint32(&ns.decN, packetsTillRekey, 0) { + if ns.decN++; ns.decN > packetsTillRekey { ns.dec.Rekey() + ns.decN = 0 } - return ns.dec.Decrypt(nil, nil, ciphertext) } -// Decrypt decrypts ciphertext. -func (ns *Noise) Decrypt(ciphertext []byte) ([]byte, error) { - ns.decMu.Lock() - res, err := ns.DecryptUnsafe(ciphertext) - ns.decMu.Unlock() - return res, err -} - // HandshakeFinished indicate whether handshake was completed. func (ns *Noise) HandshakeFinished() bool { return ns.hs.MessageIndex() == len(ns.pattern.Messages) diff --git a/internal/noise/noise_test.go b/internal/noise/noise_test.go index fd7a11d3ed..d9841f4595 100644 --- a/internal/noise/noise_test.go +++ b/internal/noise/noise_test.go @@ -48,13 +48,13 @@ func TestKKAndSecp256k1(t *testing.T) { require.True(t, nI.HandshakeFinished()) require.True(t, nR.HandshakeFinished()) - encrypted := nI.Encrypt([]byte("foo")) - decrypted, err := nR.Decrypt(encrypted) + encrypted := nI.EncryptUnsafe([]byte("foo")) + decrypted, err := nR.DecryptUnsafe(encrypted) require.NoError(t, err) assert.Equal(t, []byte("foo"), decrypted) - encrypted = nR.Encrypt([]byte("bar")) - decrypted, err = nI.Decrypt(encrypted) + encrypted = nR.EncryptUnsafe([]byte("bar")) + decrypted, err = nI.DecryptUnsafe(encrypted) require.NoError(t, err) assert.Equal(t, []byte("bar"), decrypted) @@ -105,13 +105,13 @@ func TestXKAndSecp256k1(t *testing.T) { require.True(t, nI.HandshakeFinished()) require.True(t, nR.HandshakeFinished()) - encrypted := nI.Encrypt([]byte("foo")) - decrypted, err := nR.Decrypt(encrypted) + encrypted := nI.EncryptUnsafe([]byte("foo")) + decrypted, err := nR.DecryptUnsafe(encrypted) require.NoError(t, err) assert.Equal(t, []byte("foo"), decrypted) - encrypted = nR.Encrypt([]byte("bar")) - decrypted, err = nI.Decrypt(encrypted) + encrypted = nR.EncryptUnsafe([]byte("bar")) + decrypted, err = nI.DecryptUnsafe(encrypted) require.NoError(t, err) assert.Equal(t, []byte("bar"), decrypted) diff --git a/internal/noise/read_writer.go b/internal/noise/read_writer.go index 317fe0d7f2..25fb25c367 100644 --- a/internal/noise/read_writer.go +++ b/internal/noise/read_writer.go @@ -3,6 +3,7 @@ package noise import ( "errors" "io" + "sync" "time" "github.com/skycoin/skywire/internal/ioutil" @@ -13,63 +14,44 @@ import ( type ReadWriter struct { lrw *ioutil.LenReadWriter ns *Noise + + rMx sync.Mutex + wMx sync.Mutex } // NewReadWriter constructs a new ReadWriter. func NewReadWriter(rw io.ReadWriter, ns *Noise) *ReadWriter { - return &ReadWriter{ioutil.NewLenReadWriter(rw), ns} -} - -// ReadPacket returns single received len prepended packet. -func (rw *ReadWriter) ReadPacket() (data []byte, err error) { - data, err = rw.lrw.ReadPacket() - if err != nil { - return + return &ReadWriter{ + lrw: ioutil.NewLenReadWriter(rw), + ns: ns, } - - return rw.ns.Decrypt(data) } -// ReadPacketUnsafe returns single received len prepended packet using DecryptUnsafe. -func (rw *ReadWriter) ReadPacketUnsafe() (data []byte, err error) { - data, err = rw.lrw.ReadPacket() - if err != nil { - return - } - - return rw.ns.DecryptUnsafe(data) -} +func (rw *ReadWriter) Read(p []byte) (int, error) { + rw.rMx.Lock() + defer rw.rMx.Unlock() -func (rw *ReadWriter) Read(p []byte) (n int, err error) { - var data []byte - data, err = rw.ReadPacket() + ciphertext, err := rw.lrw.ReadPacket() if err != nil { - return + return 0, err } - - if len(data) > len(p) { - err = io.ErrShortBuffer - return + plaintext, err := rw.ns.DecryptUnsafe(ciphertext) + if err != nil { + return 0, err } - - return copy(p, data), nil -} - -// WriteUnsafe implements io.Writer using EncryptUnsafe. -func (rw *ReadWriter) WriteUnsafe(p []byte) (n int, err error) { - encrypted := rw.ns.EncryptUnsafe(p) - n, err = rw.lrw.Write(encrypted) - if n != len(encrypted) { - err = io.ErrShortWrite - return + if len(plaintext) > len(p) { + return 0, io.ErrShortBuffer } - return len(p), err + return copy(p, plaintext), nil } func (rw *ReadWriter) Write(p []byte) (n int, err error) { - encrypted := rw.ns.Encrypt(p) - n, err = rw.lrw.Write(encrypted) - if n != len(encrypted) { + rw.wMx.Lock() + defer rw.wMx.Unlock() + + ciphertext := rw.ns.EncryptUnsafe(p) + n, err = rw.lrw.Write(ciphertext) + if n != len(ciphertext) { err = io.ErrShortWrite return } diff --git a/internal/noise/read_writer_test.go b/internal/noise/read_writer_test.go index 31ab88e116..234de7ace4 100644 --- a/internal/noise/read_writer_test.go +++ b/internal/noise/read_writer_test.go @@ -1,6 +1,7 @@ package noise import ( + "fmt" "net" "testing" "time" @@ -11,6 +12,95 @@ import ( "github.com/skycoin/skywire/pkg/cipher" ) +func TestNewReadWriter(t *testing.T) { + + type Result struct { + n int + err error + b []byte + } + + t.Run("concurrent", func(t *testing.T) { + aPK, aSK := cipher.GenerateKeyPair() + bPK, bSK := cipher.GenerateKeyPair() + + aNs, err := KKAndSecp256k1(Config{ + LocalPK: aPK, + LocalSK: aSK, + RemotePK: bPK, + Initiator: true, + }) + require.NoError(t, err) + + bNs, err := KKAndSecp256k1(Config{ + LocalPK: bPK, + LocalSK: bSK, + RemotePK: aPK, + Initiator: false, + }) + + aConn, bConn := net.Pipe() + defer func() { + _ = aConn.Close() //nolint:errcheck + _ = bConn.Close() //nolint:errcheck + }() + + aRW := NewReadWriter(aConn, aNs) + bRW := NewReadWriter(bConn, bNs) + + hsCh := make(chan error, 2) + defer close(hsCh) + go func() { hsCh <- aRW.Handshake(time.Second) }() + go func() { hsCh <- bRW.Handshake(time.Second) }() + require.NoError(t, <-hsCh) + require.NoError(t, <-hsCh) + + const groupSize = 10 + const totalGroups = 5 + const msgCount = totalGroups * groupSize + + writes := make([][]byte, msgCount) + + wCh := make(chan Result, msgCount) + defer close(wCh) + rCh := make(chan Result, msgCount) + defer close(rCh) + + for i := 0; i < msgCount; i++ { + writes[i] = []byte(fmt.Sprintf("this is message: %d", i)) + } + + for i := 0; i < totalGroups; i++ { + go func(i int) { + for j := 0; j < groupSize; j++ { + go func(i, j int) { + b := writes[i*j] + n, err := aRW.Write(b) + wCh <- Result{n: n, err: err, b: b} + }(i, j) + go func(i, j int) { + buf := make([]byte, 100) + n, err := bRW.Read(buf) + rCh <- Result{n: n, err: err, b: buf[:n]} + }(i, j) + } + }(i) + } + + for i := 0; i < msgCount; i++ { + w := <-wCh + fmt.Printf("write_result[%d]: b(%s) err(%v)\n", i, string(w.b), w.err) + assert.NoError(t, w.err) + assert.True(t, w.n > 0) + + r := <-rCh + fmt.Printf(" read_result[%d]: b(%s) err(%v)\n", i, string(r.b), r.err) + assert.NoError(t, r.err) + assert.True(t, r.n > 0) + } + }) +} + func TestReadWriterKKPattern(t *testing.T) { pkI, skI := cipher.GenerateKeyPair() pkR, skR := cipher.GenerateKeyPair() diff --git a/pkg/messaging/channel.go b/pkg/messaging/channel.go index 0cc95aa7e2..8fc47d7f7c 100644 --- a/pkg/messaging/channel.go +++ b/pkg/messaging/channel.go @@ -90,7 +90,7 @@ func (c *channel) Write(p []byte) (n int, err error) { error }) go func() { - data := c.noise.Encrypt(p) + data := c.noise.EncryptUnsafe(p) buf := make([]byte, 2) binary.BigEndian.PutUint16(buf, uint16(len(data))) n, err := c.link.Send(c.ID, append(buf, data...)) @@ -181,7 +181,7 @@ func (c *channel) readEncrypted(ctx context.Context, p []byte) (n int, err error return 0, err } - data, err := c.noise.Decrypt(encrypted) + data, err := c.noise.DecryptUnsafe(encrypted) if err != nil { return 0, err } diff --git a/pkg/messaging/channel_test.go b/pkg/messaging/channel_test.go index aa117408af..6cb1bbab75 100644 --- a/pkg/messaging/channel_test.go +++ b/pkg/messaging/channel_test.go @@ -34,14 +34,14 @@ func TestChannelRead(t *testing.T) { require.Equal(t, ErrDeadlineExceeded, err) go func() { - data := rn.Encrypt([]byte("foo")) + data := rn.EncryptUnsafe([]byte("foo")) buf := make([]byte, 2) binary.BigEndian.PutUint16(buf, uint16(len(data))) buf = append(buf, data...) c.readChan <- buf[0:3] c.readChan <- buf[3:] - data = rn.Encrypt([]byte("foo")) + data = rn.EncryptUnsafe([]byte("foo")) buf = make([]byte, 2) binary.BigEndian.PutUint16(buf, uint16(len(data))) buf = append(buf, data...) @@ -95,7 +95,7 @@ func TestChannelWrite(t *testing.T) { assert.Equal(t, byte(10), buf[3]) require.Equal(t, uint16(19), binary.BigEndian.Uint16(buf[4:])) - data, err := rn.Decrypt(buf[6:]) + data, err := rn.DecryptUnsafe(buf[6:]) require.NoError(t, err) assert.Equal(t, []byte("foo"), data) diff --git a/pkg/router/router.go b/pkg/router/router.go index 2a5c3d5f1e..65213b151d 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -230,7 +230,7 @@ func (r *Router) consumePacket(payload []byte, rule routing.Rule) error { return errors.New("unknown loop") } - data, err := l.noise.Decrypt(payload) + data, err := l.noise.DecryptUnsafe(payload) if err != nil { return fmt.Errorf("noise: %s", err) } @@ -260,7 +260,7 @@ func (r *Router) forwardAppPacket(appConn *app.Protocol, packet *app.Packet) err return errors.New("unknown transport") } - p := routing.MakePacket(l.routeID, l.noise.Encrypt(packet.Payload)) + p := routing.MakePacket(l.routeID, l.noise.EncryptUnsafe(packet.Payload)) r.Logger.Infof("Forwarded App packet from LocalPort %d using route ID %d", packet.Addr.Port, l.routeID) _, err = tr.Write(p) return err diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index 2b269c2d30..1470a8009d 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -197,11 +197,11 @@ func TestRouterApp(t *testing.T) { require.NoError(t, err) assert.Equal(t, uint16(19), packet.Size()) assert.Equal(t, routing.RouteID(4), packet.RouteID()) - decrypted, err := ni2.Decrypt(packet.Payload()) + decrypted, err := ni2.DecryptUnsafe(packet.Payload()) require.NoError(t, err) assert.Equal(t, []byte("bar"), decrypted) - _, err = tr2.Write(routing.MakePacket(routeID, ni2.Encrypt([]byte("foo")))) + _, err = tr2.Write(routing.MakePacket(routeID, ni2.EncryptUnsafe([]byte("foo")))) require.NoError(t, err) time.Sleep(100 * time.Millisecond) diff --git a/setup-node b/setup-node new file mode 100755 index 0000000000..3696a1cfb3 Binary files /dev/null and b/setup-node differ