Skip to content

Commit

Permalink
Merge branch 'feature/skymsg' of https://github.com/evanlinjin/skywire
Browse files Browse the repository at this point in the history
…into feature/skymsg-tests
  • Loading branch information
Darkren committed Jun 11, 2019
2 parents 02a15a5 + 6933890 commit 4a4fabb
Show file tree
Hide file tree
Showing 18 changed files with 824 additions and 675 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*.dylib
*.test
*.out
.DS_Store

.idea/

Expand Down
27 changes: 23 additions & 4 deletions internal/ioutil/ack_waiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"encoding/binary"
"io"
"math"
"sync"
)
Expand Down Expand Up @@ -43,6 +44,22 @@ func (w *Uint16AckWaiter) RandSeq() error {
return nil
}

func (w *Uint16AckWaiter) stopWaiter(seq Uint16Seq) {
if waiter := w.waiters[seq]; waiter != nil {
close(waiter)
w.waiters[seq] = nil
}
}

// StopAll stops all active waiters.
func (w *Uint16AckWaiter) StopAll() {
w.mx.Lock()
for seq := range w.waiters {
w.stopWaiter(Uint16Seq(seq))
}
w.mx.Unlock()
}

// Wait performs the given action, and waits for given seq to be Done.
func (w *Uint16AckWaiter) Wait(ctx context.Context, action func(seq Uint16Seq) error) (err error) {
ackCh := make(chan struct{})
Expand All @@ -58,16 +75,18 @@ func (w *Uint16AckWaiter) Wait(ctx context.Context, action func(seq Uint16Seq) e
}

select {
case <-ackCh:
case _, ok := <-ackCh:
if !ok {
// waiter stopped manually.
return io.ErrClosedPipe
}
case <-ctx.Done():
err = ctx.Err()
}

w.mx.Lock()
close(ackCh)
w.waiters[seq] = nil
w.stopWaiter(seq)
w.mx.Unlock()

return err
}

Expand Down
2 changes: 1 addition & 1 deletion internal/ioutil/ack_waiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ func TestUint16AckWaiter_Wait(t *testing.T) {
seqChan := make(chan Uint16Seq)
defer close(seqChan)
for i := 0; i < 64; i++ {
go w.Wait(context.TODO(), func(seq Uint16Seq) error { //nolint:errcheck
go w.Wait(context.TODO(), func(seq Uint16Seq) error { //nolint:errcheck,unparam
seqChan <- seq
return nil
})
Expand Down
23 changes: 23 additions & 0 deletions internal/ioutil/atomic_bool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package ioutil

import "sync/atomic"

// AtomicBool implements a thread-safe boolean value.
type AtomicBool struct {
flag int32
}

// Set set's the boolean to specified value
// and returns true if the value is changed.
func (b *AtomicBool) Set(v bool) bool {
newF := int32(0)
if v {
newF = 1
}
return newF != atomic.SwapInt32(&b.flag, newF)
}

// Get obtains the current boolean value.
func (b *AtomicBool) Get() bool {
return atomic.LoadInt32(&b.flag) == 1
}
190 changes: 122 additions & 68 deletions pkg/dmsg/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"sync"
"time"

"github.com/sirupsen/logrus"

"github.com/skycoin/skycoin/src/util/logging"

"github.com/skycoin/skywire/internal/noise"
Expand Down Expand Up @@ -36,94 +38,119 @@ type ClientConn struct {
// locally-initiated tps use an even tp_id between local and intermediary dms_server.
nextInitID uint16

// map of transports to remote dms_clients (key: tp_id, val: transport).
// Transports: map of transports to remote dms_clients (key: tp_id, val: transport).
tps [math.MaxUint16 + 1]*Transport
mx sync.RWMutex // to protect tps.
mx sync.RWMutex // to protect tps

wg sync.WaitGroup
done chan struct{}
once sync.Once
wg sync.WaitGroup
}

// NewClientConn creates a new ClientConn.
func NewClientConn(log *logging.Logger, conn net.Conn, local, remote cipher.PubKey) *ClientConn {
cc := &ClientConn{log: log, Conn: conn, local: local, remoteSrv: remote, nextInitID: randID(true)}
cc := &ClientConn{
log: log,
Conn: conn,
local: local,
remoteSrv: remote,
nextInitID: randID(true),
done: make(chan struct{}),
}
cc.wg.Add(1)
return cc
}

func (c *ClientConn) delTp(id uint16) {
c.mx.Lock()
c.tps[id] = nil
c.mx.Unlock()
}

func (c *ClientConn) setTp(tp *Transport) {
c.mx.Lock()
c.tps[tp.id] = tp
c.mx.Unlock()
func (c *ClientConn) getNextInitID(ctx context.Context) (uint16, error) {
for {
select {
case <-c.done:
return 0, ErrClientClosed
case <-ctx.Done():
return 0, ctx.Err()
default:
if ch := c.tps[c.nextInitID]; ch != nil && !ch.IsClosed() {
c.nextInitID += 2
continue
}
c.tps[c.nextInitID] = nil
id := c.nextInitID
c.nextInitID = id + 2
return id, nil
}
}
}

// keeps record of a locally-initiated tp to 'clientPK'.
// assigns an even tp_id and keeps track of it in tps map.
func (c *ClientConn) addTp(ctx context.Context, clientPK cipher.PubKey) (*Transport, error) {
c.mx.Lock()
defer c.mx.Unlock()

for {
if ch := c.tps[c.nextInitID]; ch == nil || ch.IsDone() {
break
}
c.nextInitID += 2
id, err := c.getNextInitID(ctx)
if err != nil {
return nil, err
}
tp := NewTransport(c.Conn, c.log, c.local, clientPK, id, c.delTp)
c.tps[id] = tp
return tp, nil
}

select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
func (c *ClientConn) acceptTp(clientPK cipher.PubKey, id uint16) (*Transport, error) {
tp := NewTransport(c.Conn, c.log, c.local, clientPK, id, c.delTp)

c.mx.Lock()
c.tps[tp.id] = tp
c.mx.Unlock()

if err := tp.WriteAccept(); err != nil {
return nil, err
}
return tp, nil
}

id := c.nextInitID
c.nextInitID = id + 2
ch := NewTransport(c.Conn, c.log, c.local, clientPK, id)
c.tps[id] = ch
return ch, nil
func (c *ClientConn) delTp(id uint16) {
c.mx.Lock()
c.tps[id] = nil
c.mx.Unlock()
}

func (c *ClientConn) getTp(id uint16) (*Transport, bool) {
c.mx.RLock()
tp := c.tps[id]
c.mx.RUnlock()
ok := tp != nil && !tp.IsDone()
ok := tp != nil && !tp.IsClosed()
return tp, ok
}

func (c *ClientConn) handleRequestFrame(ctx context.Context, accept chan<- *Transport, id uint16, p []byte) (cipher.PubKey, error) {
// remote-initiated tps should:
// remotely-initiated tps should:
// - have a payload structured as 'init_pk:resp_pk'.
// - resp_pk should be of local client.
// - use an odd tp_id with the intermediary dms_server.
// - use an odd tp_id with the intermediary dmsg_server.
initPK, respPK, ok := splitPKs(p)
if !ok || respPK != c.local || isInitiatorID(id) {
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
c.Close()
return initPK, err
}
return initPK, ErrRequestCheckFailed
}

tp := NewTransport(c.Conn, c.log, c.local, initPK, id)
if err := tp.Handshake(ctx); err != nil {
// return err here as response handshake is send via ClientConn and that shouldn't fail.
c.Close()
tp, err := c.acceptTp(initPK, id)
if err != nil {
return initPK, err
}
c.setTp(tp)
go tp.Serve()

select {
case <-c.done:
_ = tp.Close() //nolint:errcheck
return initPK, ErrClientClosed

case <-ctx.Done():
_ = tp.Close() //nolint:errcheck
return initPK, ctx.Err()

case accept <- tp:
return initPK, nil
}
return initPK, nil
}

// Serve handles incoming frames.
Expand All @@ -133,11 +160,18 @@ func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) (err e

log := c.log.WithField("remoteServer", c.remoteSrv)

log.WithField("connCount", incrementServeCount()).Infoln("ServingConn")
log.WithField("connCount", incrementServeCount()).
Infoln("ServingConn")
defer func() {
log.WithError(err).WithField("connCount", decrementServeCount()).Infoln("ClosingConn")
log.WithError(err).
WithField("connCount", decrementServeCount()).
Infoln("ClosingConn")
}()

closeConn := func(log *logrus.Entry) {
log.WithError(c.Close()).Warn("ClosingConnection")
}

for {
f, err := readFrame(c.Conn)
if err != nil {
Expand All @@ -147,56 +181,76 @@ func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) (err e

ft, id, p := f.Disassemble()

// If tp of tp_id exists, attempt to forward frame to tp.
// delete tp on any failure.

if tp, ok := c.getTp(id); ok {
// If tp of tp_id exists, attempt to forward frame to tp.
// delete tp on any failure.
if !tp.InjectRead(f) {
log.WithField("remoteClient", tp.remote).Infoln("FrameTrashed")
c.delTp(id)
if err := tp.Inject(f); err != nil {
log.WithError(err).Warnf("Rejected [%s]: Transport closed.", ft)
}
log.WithField("remoteClient", tp.remote).Infoln("FrameInjected")
continue
}

// if tp does not exist, frame should be 'REQUEST'.
// otherwise, handle any unexpected frames accordingly.

c.delTp(id) // rm tp in case closed tp is not fully removed.

switch ft {
case RequestType:
// TODO(evanlinjin): Allow for REQUEST frame handling to be done in goroutine.
// Currently this causes issues (probably because we need ACK frames).
initPK, err := c.handleRequestFrame(ctx, accept, id, p)
if err != nil {
log.WithField("remoteClient", initPK).WithError(err).Infoln("FrameRejected")
if err == ErrRequestCheckFailed {
continue
c.wg.Add(1)
go func(log *logrus.Entry) {
defer c.wg.Done()

initPK, err := c.handleRequestFrame(ctx, accept, id, p)
if err != nil {
log.
WithField("remoteClient", initPK).
WithError(err).
Infoln("Rejected [REQUEST]")
if isWriteError(err) || err == ErrClientClosed {
closeConn(log)
}
return
}
return err
}
log.WithField("remoteClient", initPK).Infoln("FrameAccepted")
case CloseType:
log.Infoln("FrameIgnored")
log.
WithField("remoteClient", initPK).
Infoln("Accepted [REQUEST]")
}(log)

default:
log.Infoln("FrameUnexpected")
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
return err
log.Infof("Ignored [%s]: No transport of given ID.", ft)
if ft != CloseType {
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
return err
}
}
}
}
}

// DialTransport dials a transport to remote dms_client.
func (c *ClientConn) DialTransport(ctx context.Context, clientPK cipher.PubKey) (*Transport, error) {
c.log.Warn("DialTransport...")
tp, err := c.addTp(ctx, clientPK)
if err != nil {
return nil, err
}
return tp, tp.Handshake(ctx)
if err := tp.WriteRequest(); err != nil {
return nil, err
}
if err := tp.ReadAccept(ctx); err != nil {
return nil, err
}
c.log.Warn("DialTransport: Accepted.")
go tp.Serve()
return tp, nil
}

// Close closes the connection to dms_server.
func (c *ClientConn) Close() error {
c.log.Infof("closingLink: remoteSrv(%v)", c.remoteSrv)
c.log.WithField("remoteServer", c.remoteSrv).Infoln("ClosingConnection")
c.once.Do(func() { close(c.done) })
c.mx.Lock()
for _, tp := range c.tps {
if tp != nil {
Expand Down Expand Up @@ -377,7 +431,7 @@ func (c *Client) findOrConnectToServer(ctx context.Context, srvPK cipher.PubKey)
c.setConn(ctx, conn)
go func() {
if err := conn.Serve(ctx, c.accept); err != nil {
conn.log.WithError(err).WithField("dms_server", srvPK).Warn("connected with dms_server closed")
conn.log.WithError(err).WithField("remoteServer", srvPK).Warn("connected with server closed")
c.delConn(ctx, srvPK)

// reconnect logic.
Expand Down
Loading

0 comments on commit 4a4fabb

Please sign in to comment.