Skip to content

Commit

Permalink
Improved dmsg logic.
Browse files Browse the repository at this point in the history
* Each dmsg.Transport now has it's own event loop.

* Made use of net.Buffers to make dmsg.Transport behave more like net.Conn
  • Loading branch information
林志宇 committed Jun 9, 2019
1 parent 63dc3f2 commit 23fcf30
Show file tree
Hide file tree
Showing 5 changed files with 285 additions and 268 deletions.
Binary file added pkg/.DS_Store
Binary file not shown.
268 changes: 58 additions & 210 deletions pkg/dmsg/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@ import (
"context"
"errors"
"fmt"
"io"
"math"
"net"
"sync"
"time"

"github.com/skycoin/skywire/internal/ioutil"

"github.com/sirupsen/logrus"

"github.com/skycoin/skycoin/src/util/logging"
Expand All @@ -29,15 +26,6 @@ var (
ErrClientClosed = errors.New("client closed")
)

type dialWaiter struct {
tp *Transport
accept chan bool
}

func makeDialWaiter(tp *Transport) *dialWaiter {
return &dialWaiter{tp: tp, accept: make(chan bool)}
}

// ClientConn represents a connection between a dmsg.Client and dmsg.Server from a client's perspective.
type ClientConn struct {
log *logging.Logger
Expand All @@ -52,11 +40,7 @@ type ClientConn struct {

// Transports: map of transports to remote dms_clients (key: tp_id, val: transport).
tps [math.MaxUint16 + 1]*Transport

// Dial Waiters: locally-initiated transports awaiting remote confirmation.
dws map[uint16]*dialWaiter

mx sync.RWMutex // to protect tps and dws
mx sync.RWMutex // to protect tps

done chan struct{}
once sync.Once
Expand All @@ -71,14 +55,13 @@ func NewClientConn(log *logging.Logger, conn net.Conn, local, remote cipher.PubK
local: local,
remoteSrv: remote,
nextInitID: randID(true),
dws: make(map[uint16]*dialWaiter),
done: make(chan struct{}),
}
cc.wg.Add(1)
return cc
}

func (c *ClientConn) nextID(ctx context.Context) (uint16, error) {
func (c *ClientConn) getNextInitID(ctx context.Context) (uint16, error) {
for {
select {
case <-c.done:
Expand All @@ -98,98 +81,30 @@ func (c *ClientConn) nextID(ctx context.Context) (uint16, error) {
}
}

// add dial waiter
func (c *ClientConn) DialTransport(ctx context.Context, clientPK cipher.PubKey) (*Transport, error) {
var (
id uint16
tp *Transport
dw *dialWaiter
)

prepareWaiter := func() (err error) {
c.mx.Lock()
defer c.mx.Unlock()

if id, err = c.nextID(ctx); err != nil {
return err
}
tp = NewTransport(c.Conn, c.log, c.local, clientPK, id)
dw = makeDialWaiter(tp)
c.dws[id] = dw
return nil
}

startWaiting := func() error {
if err := tp.WriteRequest(); err != nil {
return err
}
select {
case <-c.done:
tp.close()
return io.ErrClosedPipe
case <-ctx.Done():
tp.close()
return ctx.Err()
case ok := <-dw.accept:
if !ok {
tp.close()
return ErrRequestRejected
}
return nil
}
}

stopWaiting := func() {
c.mx.Lock()
defer c.mx.Unlock()

for {
select {
case <-dw.accept:
continue
default:
close(dw.accept)
delete(c.dws, id)
return
}
}
}

defer stopWaiting()
func (c *ClientConn) addTp(ctx context.Context, clientPK cipher.PubKey) (*Transport, error) {
c.mx.Lock()
defer c.mx.Unlock()

if err := prepareWaiter(); err != nil {
return nil, err
}
if err := startWaiting(); err != nil {
id, err := c.getNextInitID(ctx)
if err != nil {
return nil, err
}
tp := NewTransport(c.Conn, c.log, c.local, clientPK, id, c.delTp)
c.tps[id] = tp
return tp, nil
}

func (c *ClientConn) completeDial(id uint16, accept bool) error {
c.mx.RLock()
defer c.mx.RUnlock()

dw, ok := c.dws[id]
if !ok {
return errors.New("failed to complete dial: tp_id not found")
}
func (c *ClientConn) acceptTp(clientPK cipher.PubKey, id uint16) (*Transport, error) {
tp := NewTransport(c.Conn, c.log, c.local, clientPK, id, c.delTp)

select {
case dw.accept <- accept:
if accept {
c.tps[dw.tp.id] = dw.tp
}
return nil
default:
return errors.New("failed to complete dial: dial canceled locally")
}
}

func (c *ClientConn) setTp(tp *Transport) {
c.mx.Lock()
c.tps[tp.id] = tp
c.mx.Unlock()

if err := tp.WriteAccept(); err != nil {
return nil, err
}
return tp, nil
}

func (c *ClientConn) delTp(id uint16) {
Expand All @@ -210,48 +125,32 @@ func (c *ClientConn) handleRequestFrame(ctx context.Context, accept chan<- *Tran
// remotely-initiated tps should:
// - have a payload structured as 'init_pk:resp_pk'.
// - resp_pk should be of local client.
// - use an odd tp_id with the intermediary dms_server.
// - use an odd tp_id with the intermediary dmsg_server.
initPK, respPK, ok := splitPKs(p)
if !ok || respPK != c.local || isInitiatorID(id) {
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
return initPK, err
}
return initPK, ErrRequestCheckFailed
}

tp := NewTransport(c.Conn, c.log, c.local, initPK, id)
if err := tp.WriteAccept(); err != nil {
tp, err := c.acceptTp(initPK, id)
if err != nil {
return initPK, err
}
c.setTp(tp)
go tp.Serve()

select {
case <-c.done:
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
return initPK, err
}
_ = tp.Close() //nolint:errcheck
return initPK, ErrClientClosed

case <-ctx.Done():
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
return initPK, err
}
_ = tp.Close() //nolint:errcheck
return initPK, ctx.Err()
case accept <- tp:
}
return initPK, nil
}

func (c *ClientConn) handleAcceptFrame(id uint16, p []byte) (cipher.PubKey, error) {
// locally-initiated tps should:
initPK, respPK, ok := splitPKs(p)
if !ok || initPK != c.local || !isInitiatorID(id) {
_ = c.completeDial(id, false) //nolint:errcheck
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
return respPK, err
}
return respPK, ErrRequestRejected
case accept <- tp:
return initPK, nil
}
return respPK, c.completeDial(id, true)
}

// Serve handles incoming frames.
Expand Down Expand Up @@ -284,64 +183,17 @@ func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) (err e

// If tp of tp_id exists, attempt to forward frame to tp.
// delete tp on any failure.
if tp, ok := c.getTp(id); ok {
log = log.WithField("remoteClient", tp.remote)

switch ft {
case RequestType:
log.Infoln("TransportRejected: ID already occupied, malicious server.")
closeConn(log)

case CloseType:
log.Infoln("CloseTransport")
tp.close()
c.delTp(tp.id)

case AckType:
if len(p) != 2 {
log.Warnln("AckRejected: Invalid sequence.")
tp.close()
c.delTp(tp.id)
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
return err
}
continue
}
if err := tp.InjectAck(ioutil.DecodeUint16Seq(p)); err != nil {
log.WithError(err).Warnln("AckRejected")
continue
}
log.Infoln("AckInjected")

case FwdType:
if len(p) < 2 {
log.Warnln("FwdRejected: Invalid frame.")
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
return err
}
continue
}
if err := tp.InjectFwd(p[2:]); err != nil {
log.WithError(err).Warnln("FwdRejected")
continue
}
if err := writeFrame(c.Conn, MakeFrame(AckType, tp.id, p[:2])); err != nil {
return err
}
log.Infoln("FwdInjected")

default:
log.Infoln("Unexpected")
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
return err
}
if tp, ok := c.getTp(id); ok {
if err := tp.Inject(f); err != nil {
log.WithError(err).Warnf("Rejected [%s]: Transport closed.", ft)
}

continue
}

// if tp does not exist, frame should be 'REQUEST'.
// otherwise, handle any unexpected frames accordingly.

c.delTp(id) // rm tp in case closed tp is not fully removed.

switch ft {
Expand All @@ -350,54 +202,50 @@ func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) (err e
go func(log *logrus.Entry) {
defer c.wg.Done()

ctx, cancel := context.WithTimeout(ctx, acceptTimeout)
defer cancel()

initPK, err := c.handleRequestFrame(ctx, accept, id, p)
if err != nil {
log.WithField("remoteClient", initPK).WithError(err).Infoln("Rejected:REQUEST")
log.
WithField("remoteClient", initPK).
WithError(err).
Infoln("Rejected [REQUEST]")
if isWriteError(err) || err == ErrClientClosed {
closeConn(log)
}
return
}
log.WithField("remoteClient", initPK).Infoln("Accepted:REQUEST")
log.
WithField("remoteClient", initPK).
Infoln("Accepted [REQUEST]")
}(log)

case AcceptType:
respPK, err := c.handleAcceptFrame(id, p)
if err != nil {
log.WithField("remoteClient", respPK).WithError(err).Infoln("Rejected:ACCEPT")
if isWriteError(err) || err == ErrClientClosed {
closeConn(log)
}
continue
}
log.WithField("remoteClient", respPK).Infoln("Accepted:ACCEPT")

case CloseType:
log.Infoln("CloseIgnored: Transport already ")

default:
log.Infoln("Unexpected")
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
return err
log.Infof("Ignored [%s]: No transport of given ID.", ft)
if ft != CloseType {
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
return err
}
}
}
}
}

// DialTransport dials a transport to remote dms_client.
//func (c *ClientConn) DialTransport(ctx context.Context, clientPK cipher.PubKey) (*Transport, error) {
// select {
// case <-c.done:
// return nil, ErrClientClosed
// case <-ctx.Done():
// return nil, ctx.Err()
// default:
// return c.addDialWaiter(ctx, clientPK)()
// }
//}
func (c *ClientConn) DialTransport(ctx context.Context, clientPK cipher.PubKey) (*Transport, error) {
c.log.Warn("DialTransport...")
tp, err := c.addTp(ctx, clientPK)
if err != nil {
return nil, err
}
if err := tp.WriteRequest(); err != nil {
return nil, err
}
if err := tp.ReadAccept(ctx); err != nil {
return nil, err
}
c.log.Warn("DialTransport: Accepted.")
go tp.Serve()
return tp, nil
}

// Close closes the connection to dms_server.
func (c *ClientConn) Close() error {
Expand Down Expand Up @@ -439,7 +287,7 @@ func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc client.APIClient) *Client
sk: sk,
dc: dc,
conns: make(map[cipher.PubKey]*ClientConn),
accept: make(chan *Transport),
accept: make(chan *Transport, acceptChSize),
done: make(chan struct{}),
}
}
Expand Down
Loading

0 comments on commit 23fcf30

Please sign in to comment.