Skip to content

Commit

Permalink
First implementation of atomic noise.ReadWriter.
Browse files Browse the repository at this point in the history
* Removed locks from noise.Noise.

* Added read+decrypt and encrypt+write operation locks to noise.ReadWriter.

* Test noise.ReadWriter concurrently.

* Updated codebase to accomodate changes to noise library.
  • Loading branch information
林志宇 committed May 7, 2019
1 parent e0a77f0 commit 7051916
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 117 deletions.
34 changes: 5 additions & 29 deletions internal/noise/net.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package noise

import (
"bytes"
"errors"
"io"
"math"
"net"
"net/rpc"
Expand Down Expand Up @@ -147,10 +147,7 @@ func (a Addr) String() string {
// Conn wraps a net.Conn and encrypts the connection with noise.
type Conn struct {
net.Conn
buf bytes.Buffer
rMu sync.Mutex
wMu sync.Mutex
ns *ReadWriter
ns *ReadWriter
}

// WrapConn wraps a provided net.Conn with noise.
Expand All @@ -164,36 +161,15 @@ func WrapConn(conn net.Conn, ns *Noise, hsTimeout time.Duration) (*Conn, error)

// Read reads from the noise-encrypted connection.
func (c *Conn) Read(b []byte) (int, error) {
c.rMu.Lock()
defer c.rMu.Unlock()
// First check buffer.
if c.buf.Len() > 0 {
return c.buf.Read(b)
}
// Grab packet, and copy to 'b'.
// If packet is too large, copy len(b) and Write the rest to buffer.
plainText, err := c.ns.ReadPacketUnsafe()
if err != nil {
return 0, err
}
n := copy(b, plainText)
if n < len(plainText) {
c.buf.Write(plainText[n:]) // Will panic if buffer is too large.
}
return n, nil
return c.ns.Read(b)
}

// Write writes to the noise-encrypted connection.
func (c *Conn) Write(b []byte) (int, error) {
if len(b) > math.MaxUint16 {
return 0, ErrPacketTooBig
}
c.wMu.Lock()
defer c.wMu.Unlock()
if _, err := c.ns.WriteUnsafe(b); err != nil {
return 0, err
return 0, io.ErrShortWrite
}
return len(b), nil
return c.ns.Write(b)
}

// LocalAddr returns the local address of the connection.
Expand Down
36 changes: 7 additions & 29 deletions internal/noise/noise.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package noise

import (
"crypto/rand"
"sync"
"sync/atomic"

"github.com/flynn/noise"

Expand All @@ -22,6 +20,7 @@ type Config struct {
}

// Noise handles the handshake and the frame's cryptography.
// All operations on Noise are not guaranteed to be thread-safe.
type Noise struct {
pk cipher.PubKey
sk cipher.SecKey
Expand All @@ -32,11 +31,8 @@ type Noise struct {
enc *noise.CipherState
dec *noise.CipherState

encN uint32
decN uint32

encMu sync.Mutex
decMu sync.Mutex
encN uint32 // counter to inform encrypting CipherState to re-key
decN uint32 // counter to inform decrypting CipherState to re-key
}

// New creates a new Noise with:
Expand Down Expand Up @@ -123,41 +119,23 @@ func (ns *Noise) RemoteStatic() cipher.PubKey {
// EncryptUnsafe encrypts plaintext without interlocking, should only
// be used with external lock.
func (ns *Noise) EncryptUnsafe(plaintext []byte) []byte {
atomic.AddUint32(&ns.encN, 1)
if atomic.CompareAndSwapUint32(&ns.encN, packetsTillRekey, 0) {
if ns.encN++; ns.encN > packetsTillRekey {
ns.enc.Rekey()
ns.encN = 0
}

return ns.enc.Encrypt(nil, nil, plaintext)
}

// Encrypt encrypts plaintext.
func (ns *Noise) Encrypt(plaintext []byte) []byte {
ns.encMu.Lock()
res := ns.EncryptUnsafe(plaintext)
ns.encMu.Unlock()
return res
}

// DecryptUnsafe decrypts ciphertext without interlocking, should only
// be used with external lock.
func (ns *Noise) DecryptUnsafe(ciphertext []byte) ([]byte, error) {
atomic.AddUint32(&ns.decN, 1)
if atomic.CompareAndSwapUint32(&ns.decN, packetsTillRekey, 0) {
if ns.decN++; ns.decN > packetsTillRekey {
ns.dec.Rekey()
ns.decN = 0
}

return ns.dec.Decrypt(nil, nil, ciphertext)
}

// Decrypt decrypts ciphertext.
func (ns *Noise) Decrypt(ciphertext []byte) ([]byte, error) {
ns.decMu.Lock()
res, err := ns.DecryptUnsafe(ciphertext)
ns.decMu.Unlock()
return res, err
}

// HandshakeFinished indicate whether handshake was completed.
func (ns *Noise) HandshakeFinished() bool {
return ns.hs.MessageIndex() == len(ns.pattern.Messages)
Expand Down
16 changes: 8 additions & 8 deletions internal/noise/noise_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ func TestKKAndSecp256k1(t *testing.T) {
require.True(t, nI.HandshakeFinished())
require.True(t, nR.HandshakeFinished())

encrypted := nI.Encrypt([]byte("foo"))
decrypted, err := nR.Decrypt(encrypted)
encrypted := nI.EncryptUnsafe([]byte("foo"))
decrypted, err := nR.DecryptUnsafe(encrypted)
require.NoError(t, err)
assert.Equal(t, []byte("foo"), decrypted)

encrypted = nR.Encrypt([]byte("bar"))
decrypted, err = nI.Decrypt(encrypted)
encrypted = nR.EncryptUnsafe([]byte("bar"))
decrypted, err = nI.DecryptUnsafe(encrypted)
require.NoError(t, err)
assert.Equal(t, []byte("bar"), decrypted)

Expand Down Expand Up @@ -105,13 +105,13 @@ func TestXKAndSecp256k1(t *testing.T) {
require.True(t, nI.HandshakeFinished())
require.True(t, nR.HandshakeFinished())

encrypted := nI.Encrypt([]byte("foo"))
decrypted, err := nR.Decrypt(encrypted)
encrypted := nI.EncryptUnsafe([]byte("foo"))
decrypted, err := nR.DecryptUnsafe(encrypted)
require.NoError(t, err)
assert.Equal(t, []byte("foo"), decrypted)

encrypted = nR.Encrypt([]byte("bar"))
decrypted, err = nI.Decrypt(encrypted)
encrypted = nR.EncryptUnsafe([]byte("bar"))
decrypted, err = nI.DecryptUnsafe(encrypted)
require.NoError(t, err)
assert.Equal(t, []byte("bar"), decrypted)

Expand Down
66 changes: 24 additions & 42 deletions internal/noise/read_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package noise
import (
"errors"
"io"
"sync"
"time"

"github.com/skycoin/skywire/internal/ioutil"
Expand All @@ -13,63 +14,44 @@ import (
type ReadWriter struct {
lrw *ioutil.LenReadWriter
ns *Noise

rMx sync.Mutex
wMx sync.Mutex
}

// NewReadWriter constructs a new ReadWriter.
func NewReadWriter(rw io.ReadWriter, ns *Noise) *ReadWriter {
return &ReadWriter{ioutil.NewLenReadWriter(rw), ns}
}

// ReadPacket returns single received len prepended packet.
func (rw *ReadWriter) ReadPacket() (data []byte, err error) {
data, err = rw.lrw.ReadPacket()
if err != nil {
return
return &ReadWriter{
lrw: ioutil.NewLenReadWriter(rw),
ns: ns,
}

return rw.ns.Decrypt(data)
}

// ReadPacketUnsafe returns single received len prepended packet using DecryptUnsafe.
func (rw *ReadWriter) ReadPacketUnsafe() (data []byte, err error) {
data, err = rw.lrw.ReadPacket()
if err != nil {
return
}

return rw.ns.DecryptUnsafe(data)
}
func (rw *ReadWriter) Read(p []byte) (int, error) {
rw.rMx.Lock()
defer rw.rMx.Unlock()

func (rw *ReadWriter) Read(p []byte) (n int, err error) {
var data []byte
data, err = rw.ReadPacket()
ciphertext, err := rw.lrw.ReadPacket()
if err != nil {
return
return 0, err
}

if len(data) > len(p) {
err = io.ErrShortBuffer
return
plaintext, err := rw.ns.DecryptUnsafe(ciphertext)
if err != nil {
return 0, err
}

return copy(p, data), nil
}

// WriteUnsafe implements io.Writer using EncryptUnsafe.
func (rw *ReadWriter) WriteUnsafe(p []byte) (n int, err error) {
encrypted := rw.ns.EncryptUnsafe(p)
n, err = rw.lrw.Write(encrypted)
if n != len(encrypted) {
err = io.ErrShortWrite
return
if len(plaintext) > len(p) {
return 0, io.ErrShortBuffer
}
return len(p), err
return copy(p, plaintext), nil
}

func (rw *ReadWriter) Write(p []byte) (n int, err error) {
encrypted := rw.ns.Encrypt(p)
n, err = rw.lrw.Write(encrypted)
if n != len(encrypted) {
rw.wMx.Lock()
defer rw.wMx.Unlock()

ciphertext := rw.ns.EncryptUnsafe(p)
n, err = rw.lrw.Write(ciphertext)
if n != len(ciphertext) {
err = io.ErrShortWrite
return
}
Expand Down
90 changes: 90 additions & 0 deletions internal/noise/read_writer_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package noise

import (
"fmt"
"net"
"testing"
"time"
Expand All @@ -11,6 +12,95 @@ import (
"github.com/skycoin/skywire/pkg/cipher"
)

func TestNewReadWriter(t *testing.T) {

type Result struct {
n int
err error
b []byte
}

t.Run("concurrent", func(t *testing.T) {
aPK, aSK := cipher.GenerateKeyPair()
bPK, bSK := cipher.GenerateKeyPair()

aNs, err := KKAndSecp256k1(Config{
LocalPK: aPK,
LocalSK: aSK,
RemotePK: bPK,
Initiator: true,
})
require.NoError(t, err)

bNs, err := KKAndSecp256k1(Config{
LocalPK: bPK,
LocalSK: bSK,
RemotePK: aPK,
Initiator: false,
})

aConn, bConn := net.Pipe()
defer func() {
_ = aConn.Close() //nolint:errcheck
_ = bConn.Close() //nolint:errcheck
}()

aRW := NewReadWriter(aConn, aNs)
bRW := NewReadWriter(bConn, bNs)

hsCh := make(chan error, 2)
defer close(hsCh)
go func() { hsCh <- aRW.Handshake(time.Second) }()
go func() { hsCh <- bRW.Handshake(time.Second) }()
require.NoError(t, <-hsCh)
require.NoError(t, <-hsCh)

const groupSize = 10
const totalGroups = 5
const msgCount = totalGroups * groupSize

writes := make([][]byte, msgCount)

wCh := make(chan Result, msgCount)
defer close(wCh)
rCh := make(chan Result, msgCount)
defer close(rCh)

for i := 0; i < msgCount; i++ {
writes[i] = []byte(fmt.Sprintf("this is message: %d", i))
}

for i := 0; i < totalGroups; i++ {
go func(i int) {
for j := 0; j < groupSize; j++ {
go func(i, j int) {
b := writes[i*j]
n, err := aRW.Write(b)
wCh <- Result{n: n, err: err, b: b}
}(i, j)
go func(i, j int) {
buf := make([]byte, 100)
n, err := bRW.Read(buf)
rCh <- Result{n: n, err: err, b: buf[:n]}
}(i, j)
}
}(i)
}

for i := 0; i < msgCount; i++ {
w := <-wCh
fmt.Printf("write_result[%d]: b(%s) err(%v)\n", i, string(w.b), w.err)
assert.NoError(t, w.err)
assert.True(t, w.n > 0)

r := <-rCh
fmt.Printf(" read_result[%d]: b(%s) err(%v)\n", i, string(r.b), r.err)
assert.NoError(t, r.err)
assert.True(t, r.n > 0)
}
})
}

func TestReadWriterKKPattern(t *testing.T) {
pkI, skI := cipher.GenerateKeyPair()
pkR, skR := cipher.GenerateKeyPair()
Expand Down
4 changes: 2 additions & 2 deletions pkg/messaging/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (c *channel) Write(p []byte) (n int, err error) {
error
})
go func() {
data := c.noise.Encrypt(p)
data := c.noise.EncryptUnsafe(p)
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, uint16(len(data)))
n, err := c.link.Send(c.ID, append(buf, data...))
Expand Down Expand Up @@ -181,7 +181,7 @@ func (c *channel) readEncrypted(ctx context.Context, p []byte) (n int, err error
return 0, err
}

data, err := c.noise.Decrypt(encrypted)
data, err := c.noise.DecryptUnsafe(encrypted)
if err != nil {
return 0, err
}
Expand Down
Loading

0 comments on commit 7051916

Please sign in to comment.