diff --git a/pkg/.DS_Store b/pkg/.DS_Store new file mode 100644 index 0000000000..5008ddfcf5 Binary files /dev/null and b/pkg/.DS_Store differ diff --git a/pkg/dmsg/client.go b/pkg/dmsg/client.go index 25b4e85050..d7eddf44d7 100644 --- a/pkg/dmsg/client.go +++ b/pkg/dmsg/client.go @@ -4,14 +4,11 @@ import ( "context" "errors" "fmt" - "io" "math" "net" "sync" "time" - "github.com/skycoin/skywire/internal/ioutil" - "github.com/sirupsen/logrus" "github.com/skycoin/skycoin/src/util/logging" @@ -29,15 +26,6 @@ var ( ErrClientClosed = errors.New("client closed") ) -type dialWaiter struct { - tp *Transport - accept chan bool -} - -func makeDialWaiter(tp *Transport) *dialWaiter { - return &dialWaiter{tp: tp, accept: make(chan bool)} -} - // ClientConn represents a connection between a dmsg.Client and dmsg.Server from a client's perspective. type ClientConn struct { log *logging.Logger @@ -52,11 +40,7 @@ type ClientConn struct { // Transports: map of transports to remote dms_clients (key: tp_id, val: transport). tps [math.MaxUint16 + 1]*Transport - - // Dial Waiters: locally-initiated transports awaiting remote confirmation. - dws map[uint16]*dialWaiter - - mx sync.RWMutex // to protect tps and dws + mx sync.RWMutex // to protect tps done chan struct{} once sync.Once @@ -71,14 +55,13 @@ func NewClientConn(log *logging.Logger, conn net.Conn, local, remote cipher.PubK local: local, remoteSrv: remote, nextInitID: randID(true), - dws: make(map[uint16]*dialWaiter), done: make(chan struct{}), } cc.wg.Add(1) return cc } -func (c *ClientConn) nextID(ctx context.Context) (uint16, error) { +func (c *ClientConn) getNextInitID(ctx context.Context) (uint16, error) { for { select { case <-c.done: @@ -98,98 +81,30 @@ func (c *ClientConn) nextID(ctx context.Context) (uint16, error) { } } -// add dial waiter -func (c *ClientConn) DialTransport(ctx context.Context, clientPK cipher.PubKey) (*Transport, error) { - var ( - id uint16 - tp *Transport - dw *dialWaiter - ) - - prepareWaiter := func() (err error) { - c.mx.Lock() - defer c.mx.Unlock() - - if id, err = c.nextID(ctx); err != nil { - return err - } - tp = NewTransport(c.Conn, c.log, c.local, clientPK, id) - dw = makeDialWaiter(tp) - c.dws[id] = dw - return nil - } - - startWaiting := func() error { - if err := tp.WriteRequest(); err != nil { - return err - } - select { - case <-c.done: - tp.close() - return io.ErrClosedPipe - case <-ctx.Done(): - tp.close() - return ctx.Err() - case ok := <-dw.accept: - if !ok { - tp.close() - return ErrRequestRejected - } - return nil - } - } - - stopWaiting := func() { - c.mx.Lock() - defer c.mx.Unlock() - - for { - select { - case <-dw.accept: - continue - default: - close(dw.accept) - delete(c.dws, id) - return - } - } - } - - defer stopWaiting() +func (c *ClientConn) addTp(ctx context.Context, clientPK cipher.PubKey) (*Transport, error) { + c.mx.Lock() + defer c.mx.Unlock() - if err := prepareWaiter(); err != nil { - return nil, err - } - if err := startWaiting(); err != nil { + id, err := c.getNextInitID(ctx) + if err != nil { return nil, err } + tp := NewTransport(c.Conn, c.log, c.local, clientPK, id, c.delTp) + c.tps[id] = tp return tp, nil } -func (c *ClientConn) completeDial(id uint16, accept bool) error { - c.mx.RLock() - defer c.mx.RUnlock() - - dw, ok := c.dws[id] - if !ok { - return errors.New("failed to complete dial: tp_id not found") - } +func (c *ClientConn) acceptTp(clientPK cipher.PubKey, id uint16) (*Transport, error) { + tp := NewTransport(c.Conn, c.log, c.local, clientPK, id, c.delTp) - select { - case dw.accept <- accept: - if accept { - c.tps[dw.tp.id] = dw.tp - } - return nil - default: - return errors.New("failed to complete dial: dial canceled locally") - } -} - -func (c *ClientConn) setTp(tp *Transport) { c.mx.Lock() c.tps[tp.id] = tp c.mx.Unlock() + + if err := tp.WriteAccept(); err != nil { + return nil, err + } + return tp, nil } func (c *ClientConn) delTp(id uint16) { @@ -210,7 +125,7 @@ func (c *ClientConn) handleRequestFrame(ctx context.Context, accept chan<- *Tran // remotely-initiated tps should: // - have a payload structured as 'init_pk:resp_pk'. // - resp_pk should be of local client. - // - use an odd tp_id with the intermediary dms_server. + // - use an odd tp_id with the intermediary dmsg_server. initPK, respPK, ok := splitPKs(p) if !ok || respPK != c.local || isInitiatorID(id) { if err := writeCloseFrame(c.Conn, id, 0); err != nil { @@ -218,40 +133,24 @@ func (c *ClientConn) handleRequestFrame(ctx context.Context, accept chan<- *Tran } return initPK, ErrRequestCheckFailed } - - tp := NewTransport(c.Conn, c.log, c.local, initPK, id) - if err := tp.WriteAccept(); err != nil { + tp, err := c.acceptTp(initPK, id) + if err != nil { return initPK, err } - c.setTp(tp) + go tp.Serve() select { case <-c.done: - if err := writeCloseFrame(c.Conn, id, 0); err != nil { - return initPK, err - } + _ = tp.Close() //nolint:errcheck return initPK, ErrClientClosed + case <-ctx.Done(): - if err := writeCloseFrame(c.Conn, id, 0); err != nil { - return initPK, err - } + _ = tp.Close() //nolint:errcheck return initPK, ctx.Err() - case accept <- tp: - } - return initPK, nil -} -func (c *ClientConn) handleAcceptFrame(id uint16, p []byte) (cipher.PubKey, error) { - // locally-initiated tps should: - initPK, respPK, ok := splitPKs(p) - if !ok || initPK != c.local || !isInitiatorID(id) { - _ = c.completeDial(id, false) //nolint:errcheck - if err := writeCloseFrame(c.Conn, id, 0); err != nil { - return respPK, err - } - return respPK, ErrRequestRejected + case accept <- tp: + return initPK, nil } - return respPK, c.completeDial(id, true) } // Serve handles incoming frames. @@ -284,64 +183,17 @@ func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) (err e // If tp of tp_id exists, attempt to forward frame to tp. // delete tp on any failure. - if tp, ok := c.getTp(id); ok { - log = log.WithField("remoteClient", tp.remote) - - switch ft { - case RequestType: - log.Infoln("TransportRejected: ID already occupied, malicious server.") - closeConn(log) - - case CloseType: - log.Infoln("CloseTransport") - tp.close() - c.delTp(tp.id) - - case AckType: - if len(p) != 2 { - log.Warnln("AckRejected: Invalid sequence.") - tp.close() - c.delTp(tp.id) - if err := writeCloseFrame(c.Conn, id, 0); err != nil { - return err - } - continue - } - if err := tp.InjectAck(ioutil.DecodeUint16Seq(p)); err != nil { - log.WithError(err).Warnln("AckRejected") - continue - } - log.Infoln("AckInjected") - - case FwdType: - if len(p) < 2 { - log.Warnln("FwdRejected: Invalid frame.") - if err := writeCloseFrame(c.Conn, id, 0); err != nil { - return err - } - continue - } - if err := tp.InjectFwd(p[2:]); err != nil { - log.WithError(err).Warnln("FwdRejected") - continue - } - if err := writeFrame(c.Conn, MakeFrame(AckType, tp.id, p[:2])); err != nil { - return err - } - log.Infoln("FwdInjected") - default: - log.Infoln("Unexpected") - if err := writeCloseFrame(c.Conn, id, 0); err != nil { - return err - } + if tp, ok := c.getTp(id); ok { + if err := tp.Inject(f); err != nil { + log.WithError(err).Warnf("Rejected [%s]: Transport closed.", ft) } - continue } // if tp does not exist, frame should be 'REQUEST'. // otherwise, handle any unexpected frames accordingly. + c.delTp(id) // rm tp in case closed tp is not fully removed. switch ft { @@ -350,54 +202,50 @@ func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) (err e go func(log *logrus.Entry) { defer c.wg.Done() - ctx, cancel := context.WithTimeout(ctx, acceptTimeout) - defer cancel() - initPK, err := c.handleRequestFrame(ctx, accept, id, p) if err != nil { - log.WithField("remoteClient", initPK).WithError(err).Infoln("Rejected:REQUEST") + log. + WithField("remoteClient", initPK). + WithError(err). + Infoln("Rejected [REQUEST]") if isWriteError(err) || err == ErrClientClosed { closeConn(log) } return } - log.WithField("remoteClient", initPK).Infoln("Accepted:REQUEST") + log. + WithField("remoteClient", initPK). + Infoln("Accepted [REQUEST]") }(log) - case AcceptType: - respPK, err := c.handleAcceptFrame(id, p) - if err != nil { - log.WithField("remoteClient", respPK).WithError(err).Infoln("Rejected:ACCEPT") - if isWriteError(err) || err == ErrClientClosed { - closeConn(log) - } - continue - } - log.WithField("remoteClient", respPK).Infoln("Accepted:ACCEPT") - - case CloseType: - log.Infoln("CloseIgnored: Transport already ") - default: - log.Infoln("Unexpected") - if err := writeCloseFrame(c.Conn, id, 0); err != nil { - return err + log.Infof("Ignored [%s]: No transport of given ID.", ft) + if ft != CloseType { + if err := writeCloseFrame(c.Conn, id, 0); err != nil { + return err + } } } } } // DialTransport dials a transport to remote dms_client. -//func (c *ClientConn) DialTransport(ctx context.Context, clientPK cipher.PubKey) (*Transport, error) { -// select { -// case <-c.done: -// return nil, ErrClientClosed -// case <-ctx.Done(): -// return nil, ctx.Err() -// default: -// return c.addDialWaiter(ctx, clientPK)() -// } -//} +func (c *ClientConn) DialTransport(ctx context.Context, clientPK cipher.PubKey) (*Transport, error) { + c.log.Warn("DialTransport...") + tp, err := c.addTp(ctx, clientPK) + if err != nil { + return nil, err + } + if err := tp.WriteRequest(); err != nil { + return nil, err + } + if err := tp.ReadAccept(ctx); err != nil { + return nil, err + } + c.log.Warn("DialTransport: Accepted.") + go tp.Serve() + return tp, nil +} // Close closes the connection to dms_server. func (c *ClientConn) Close() error { @@ -439,7 +287,7 @@ func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc client.APIClient) *Client sk: sk, dc: dc, conns: make(map[cipher.PubKey]*ClientConn), - accept: make(chan *Transport), + accept: make(chan *Transport, acceptChSize), done: make(chan struct{}), } } diff --git a/pkg/dmsg/frame.go b/pkg/dmsg/frame.go index 4b850ea417..27c7af6013 100644 --- a/pkg/dmsg/frame.go +++ b/pkg/dmsg/frame.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "fmt" "io" + "math" "sync/atomic" "time" @@ -16,11 +17,11 @@ const ( // Type returns the transport type string. Type = "dmsg" - hsTimeout = time.Second * 10 - readTimeout = time.Second * 10 - acceptTimeout = time.Second * 5 - readChSize = 20 - headerLen = 5 // fType(1 byte), chID(2 byte), payLen(2 byte) + hsTimeout = time.Second * 10 + tpBufCap = math.MaxUint16 + tpAckCap = math.MaxUint8 + acceptChSize = 20 + headerLen = 5 // fType(1 byte), chID(2 byte), payLen(2 byte) ) func isInitiatorID(tpID uint16) bool { return tpID%2 == 0 } diff --git a/pkg/dmsg/transport.go b/pkg/dmsg/transport.go index 8ab0b6bb83..8dd076ecd9 100644 --- a/pkg/dmsg/transport.go +++ b/pkg/dmsg/transport.go @@ -1,9 +1,9 @@ package dmsg import ( - "bufio" "context" "errors" + "fmt" "io" "net" "sync" @@ -15,10 +15,11 @@ import ( "github.com/skycoin/skywire/pkg/transport" ) -// Errors related to REQUESTs. +// Errors related to REQUEST frames. var ( - ErrRequestRejected = errors.New("request rejected") - ErrRequestCheckFailed = errors.New("request check failed") + ErrRequestRejected = errors.New("failed to create transport: request rejected") + ErrRequestCheckFailed = errors.New("failed to create transport: request check failed") + ErrAcceptCheckFailed = errors.New("failed to create transport: accept check failed") ) // Transport represents a connection from dmsg.Client to remote dmsg.Client (via dmsg.Server intermediary). @@ -31,37 +32,58 @@ type Transport struct { local cipher.PubKey remote cipher.PubKey // remote PK - ackWaiter ioutil.Uint16AckWaiter + inCh chan Frame + inMx sync.RWMutex - pWrite *bufio.Writer - pRead *io.PipeReader - mx sync.RWMutex + ackWaiter ioutil.Uint16AckWaiter + ackBuf []byte + buf net.Buffers + bufCh chan struct{} + bufSize int + bufMx sync.Mutex // protects 'buf' and 'bufCh' - closed ioutil.AtomicBool + once sync.Once + done chan struct{} + doneFunc func(id uint16) } // NewTransport creates a new dms_tp. -func NewTransport(conn net.Conn, log *logging.Logger, local, remote cipher.PubKey, id uint16) *Transport { - pRead, pWrite := io.Pipe() +func NewTransport(conn net.Conn, log *logging.Logger, local, remote cipher.PubKey, id uint16, doneFunc func(id uint16)) *Transport { tp := &Transport{ - Conn: conn, - log: log, - id: id, - local: local, - remote: remote, - pWrite: bufio.NewWriter(pWrite), - pRead: pRead, + Conn: conn, + log: log, + id: id, + local: local, + remote: remote, + inCh: make(chan Frame), + ackBuf: make([]byte, 0, tpAckCap), + bufCh: make(chan struct{}, 1), + done: make(chan struct{}), + doneFunc: doneFunc, } if err := tp.ackWaiter.RandSeq(); err != nil { log.Fatalln("failed to set ack_waiter seq:", err) } - tp.closed.Set(false) return tp } -func (tp *Transport) close() bool { - closed := tp.closed.Set(true) - _ = tp.pRead.Close() //nolint:errcheck +func (tp *Transport) close() (closed bool) { + tp.once.Do(func() { + closed = true + + close(tp.done) + tp.doneFunc(tp.id) + + tp.bufMx.Lock() + close(tp.bufCh) + tp.bufMx.Unlock() + + tp.inMx.Lock() + close(tp.inCh) + tp.inMx.Unlock() + + }) + tp.ackWaiter.StopAll() return closed } @@ -76,7 +98,12 @@ func (tp *Transport) Close() error { // IsClosed returns whether dms_tp is closed. func (tp *Transport) IsClosed() bool { - return tp.closed.Get() + select { + case <-tp.done: + return true + default: + return false + } } // Edges returns the local/remote edges of the transport (dms_client to dms_client). @@ -89,15 +116,36 @@ func (tp *Transport) Type() string { return Type } +// Inject injects a frame from 'ClientConn' to transport. +// Frame is then handled by 'tp.Serve'. +func (tp *Transport) Inject(f Frame) error { + if tp.IsClosed() { + return io.ErrClosedPipe + } + + tp.inMx.RLock() + defer tp.inMx.RUnlock() + + select { + case <-tp.done: + return io.ErrClosedPipe + case tp.inCh <- f: + return nil + } +} + +// WriteRequest writes a REQUEST frame to dmsg_server to be forwarded to associated client. func (tp *Transport) WriteRequest() error { f := MakeFrame(RequestType, tp.id, combinePKs(tp.local, tp.remote)) if err := writeFrame(tp.Conn, f); err != nil { + tp.log.WithError(err).Error("HandshakeFailed") tp.close() return err } return nil } +// WriteAccept writes an ACCEPT frame to dmsg_server to be forwarded to associated client. func (tp *Transport) WriteAccept() error { f := MakeFrame(AcceptType, tp.id, combinePKs(tp.remote, tp.local)) if err := writeFrame(tp.Conn, f); err != nil { @@ -109,47 +157,157 @@ func (tp *Transport) WriteAccept() error { return nil } -func (tp *Transport) InjectFwd(d []byte) error { - if tp.IsClosed() { +// ReadAccept awaits for an ACCEPT frame to be read from the remote client. +// TODO(evanlinjin): Cleanup errors. +func (tp *Transport) ReadAccept(ctx context.Context) (err error) { + defer func() { + tp.log.WithError(err).WithField("success", err == nil).Infoln("HandshakeDone") + }() + + select { + case <-tp.done: + tp.close() return io.ErrClosedPipe + + case <-ctx.Done(): + _ = tp.Close() //nolint:errcheck + return ctx.Err() + + case f, ok := <-tp.inCh: + if !ok { + tp.close() + return io.ErrClosedPipe + } + switch ft, id, p := f.Disassemble(); ft { + case AcceptType: + // locally-initiated tps should: + // - have a payload structured as 'init_pk:resp_pk'. + // - init_pk should be of local client. + // - resp_pk should be of remote client. + // - use an even number with the intermediary dmsg_server. + initPK, respPK, ok := splitPKs(p) + if !ok || initPK != tp.local || respPK != tp.remote || !isInitiatorID(id) { + _ = tp.Close() //nolint:errcheck + return ErrAcceptCheckFailed + } + return nil + + case CloseType: + tp.close() + return ErrRequestRejected + + default: + _ = tp.Close() //nolint:errcheck + return ErrAcceptCheckFailed + } } - // TODO(evanlinjin): find a better solution. - go func() { - tp.mx.Lock() - tp.pWrite.Write(d) //nolint:errcheck - tp.mx.Unlock() - go func() { - tp.mx.Lock() - tp.pWrite.Flush() //nolint:errcheck - tp.mx.Unlock() - }() - }() - return nil } -func (tp *Transport) InjectAck(seq ioutil.Uint16Seq) error { - if tp.IsClosed() { - return io.ErrClosedPipe +// Serve handles received frames. +func (tp *Transport) Serve() { + defer func() { + if tp.close() { + _ = writeCloseFrame(tp.Conn, tp.id, 0) //nolint:errcheck + } + }() + + for { + select { + case <-tp.done: + return + + case f, ok := <-tp.inCh: + if !ok { + return + } + log := tp.log. + WithField("remoteClient", tp.remote). + WithField("received", f) + + switch p := f.Pay(); f.Type() { + case FwdType: + if len(p) < 2 { + log.Warnln("Rejected [FWD]: Invalid payload size.") + return + } + ack := MakeFrame(AckType, tp.id, p[:2]) + + tp.bufMx.Lock() + if tp.bufSize += len(p[2:]); tp.bufSize > tpBufCap { + tp.ackBuf = append(tp.ackBuf, ack...) + } else { + go func() { + if err := writeFrame(tp.Conn, ack); err != nil { + tp.close() + } + }() + } + tp.buf = append(tp.buf, p[2:]) + select { + case <-tp.done: + case tp.bufCh <- struct{}{}: + default: + } + log.WithField("bufSize", fmt.Sprintf("%d/%d", tp.bufSize, tpBufCap)).Infoln("Injected [FWD]") + tp.bufMx.Unlock() + + case AckType: + if len(p) != 2 { + log.Warnln("Rejected [ACK]: Invalid payload size.") + return + } + tp.ackWaiter.Done(ioutil.DecodeUint16Seq(p[:2])) + log.Infoln("Injected [ACK]") + + case CloseType: + log.Infoln("Injected [CLOSE]: Closing transport...") + return + + case RequestType: + log.Warnln("Rejected [REQUEST]: ID already occupied, malicious server.") + _ = tp.Conn.Close() + return + + default: + tp.log.Infof("Rejected [%s]: Unexpected frame, malicious server (ignored for now).", f.Type()) + } + } } - tp.ackWaiter.Done(seq) - return nil } // Read implements io.Reader +// TODO(evanlinjin): read deadline. func (tp *Transport) Read(p []byte) (n int, err error) { - return tp.pRead.Read(p) +startRead: + tp.bufMx.Lock() + n, err = tp.buf.Read(p) + go func() { + if tp.bufSize -= n; tp.bufSize < tpBufCap { + if err := writeFrame(tp.Conn, tp.ackBuf); err != nil { + tp.close() + } + tp.ackBuf = make([]byte, 0, tpAckCap) + } + tp.bufMx.Unlock() + }() + + if tp.IsClosed() { + return n, err + } + if n > 0 { + return n, nil + } + <-tp.bufCh + goto startRead } // Write implements io.Writer +// TODO(evanlinjin): write deadline. func (tp *Transport) Write(p []byte) (int, error) { if tp.IsClosed() { return 0, io.ErrClosedPipe } - - ctx, cancel := context.WithTimeout(context.Background(), readTimeout) - defer cancel() - - err := tp.ackWaiter.Wait(ctx, func(seq ioutil.Uint16Seq) error { + err := tp.ackWaiter.Wait(context.Background(), func(seq ioutil.Uint16Seq) error { if err := writeFwdFrame(tp.Conn, tp.id, seq, p); err != nil { tp.close() return err diff --git a/pkg/router/router.go b/pkg/router/router.go index 0ef310a352..c676c06925 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -430,15 +430,25 @@ func (r *Router) setupProto(ctx context.Context) (*setup.Protocol, transport.Tra return sProto, tr, nil } -func (r *Router) fetchBestRoutes(source, destination cipher.PubKey) (routing.Route, routing.Route, error) { +func (r *Router) fetchBestRoutes(source, destination cipher.PubKey) (fwd routing.Route, rev routing.Route, err error) { r.Logger.Infof("Requesting new routes from %s to %s", source, destination) - forwardRoutes, reverseRoutes, err := r.config.RouteFinder.PairedRoutes(source, destination, minHops, maxHops) + + timer := time.NewTimer(time.Second * 10) + defer timer.Stop() + +fetchRoutesAgain: + fwdRoutes, revRoutes, err := r.config.RouteFinder.PairedRoutes(source, destination, minHops, maxHops) if err != nil { - return nil, nil, err + select { + case <-timer.C: + return nil, nil, err + default: + goto fetchRoutesAgain + } } - r.Logger.Infof("Found routes Forward: %s. Reverse %s", forwardRoutes, reverseRoutes) - return forwardRoutes[0], reverseRoutes[0], nil + r.Logger.Infof("Found routes Forward: %s. Reverse %s", fwdRoutes, revRoutes) + return fwdRoutes[0], revRoutes[0], nil } func (r *Router) advanceNoiseHandshake(addr *app.LoopAddr, noiseMsg []byte) (ni *noise.Noise, noiseRes []byte, err error) {