From 982a8c9b3cb783a3d5e54b5a945b3c27e71698fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E5=BF=97=E5=AE=87?= Date: Mon, 17 Jun 2019 17:28:48 +0800 Subject: [PATCH] Fixed dmsg accept-transport behaviour. --- pkg/dmsg/client.go | 39 ++++++++----------------- pkg/dmsg/server_test.go | 1 + pkg/dmsg/transport.go | 63 +++++++++++++++++++++++++++++++---------- 3 files changed, 61 insertions(+), 42 deletions(-) diff --git a/pkg/dmsg/client.go b/pkg/dmsg/client.go index 5cc1bd59ed..ab63215773 100644 --- a/pkg/dmsg/client.go +++ b/pkg/dmsg/client.go @@ -98,17 +98,10 @@ func (c *ClientConn) addTp(ctx context.Context, clientPK cipher.PubKey) (*Transp return tp, nil } -func (c *ClientConn) acceptTp(clientPK cipher.PubKey, id uint16) (*Transport, error) { - tp := NewTransport(c.Conn, c.log, c.local, clientPK, id, c.delTp) - +func (c *ClientConn) setTp(tp *Transport) { c.mx.Lock() c.tps[tp.id] = tp c.mx.Unlock() - - if err := tp.WriteAccept(); err != nil { - return nil, err - } - return tp, nil } func (c *ClientConn) delTp(id uint16) { @@ -131,7 +124,7 @@ func (c *ClientConn) setNextInitID(nextInitID uint16) { c.mx.Unlock() } -func (c *ClientConn) handleRequestFrame(ctx context.Context, accept chan<- *Transport, id uint16, p []byte) (cipher.PubKey, error) { +func (c *ClientConn) handleRequestFrame(accept chan<- *Transport, id uint16, p []byte) (cipher.PubKey, error) { // remotely-initiated tps should: // - have a payload structured as 'init_pk:resp_pk'. // - resp_pk should be of local client. @@ -143,22 +136,20 @@ func (c *ClientConn) handleRequestFrame(ctx context.Context, accept chan<- *Tran } return initPK, ErrRequestCheckFailed } - tp, err := c.acceptTp(initPK, id) - if err != nil { - return initPK, err - } - go tp.Serve() + + tp := NewTransport(c.Conn, c.log, c.local, initPK, id, c.delTp) + c.setTp(tp) select { case <-c.done: _ = tp.Close() //nolint:errcheck return initPK, ErrClientClosed - case <-ctx.Done(): - _ = tp.Close() //nolint:errcheck - return initPK, ctx.Err() - case accept <- tp: + if err := tp.WriteAccept(); err != nil { + return initPK, err + } + go tp.Serve() return initPK, nil } } @@ -206,21 +197,15 @@ func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) (err e c.wg.Add(1) go func(log *logrus.Entry) { defer c.wg.Done() - - initPK, err := c.handleRequestFrame(ctx, accept, id, p) + initPK, err := c.handleRequestFrame(accept, id, p) if err != nil { - log. - WithField("remoteClient", initPK). - WithError(err). - Infoln("Rejected [REQUEST]") + log.WithField("remoteClient", initPK).WithError(err).Infoln("Rejected [REQUEST]") if isWriteError(err) || err == ErrClientClosed { closeConn(log) } return } - log. - WithField("remoteClient", initPK). - Infoln("Accepted [REQUEST]") + log.WithField("remoteClient", initPK).Infoln("Accepted [REQUEST]") }(log) default: diff --git a/pkg/dmsg/server_test.go b/pkg/dmsg/server_test.go index a936051911..e0886c99dc 100644 --- a/pkg/dmsg/server_test.go +++ b/pkg/dmsg/server_test.go @@ -17,6 +17,7 @@ import ( "golang.org/x/net/nettest" "github.com/skycoin/skycoin/src/util/logging" + "github.com/skycoin/skywire/internal/noise" "github.com/skycoin/skywire/pkg/cipher" "github.com/skycoin/skywire/pkg/messaging-discovery/client" diff --git a/pkg/dmsg/transport.go b/pkg/dmsg/transport.go index e433f40283..6fbddd9c2d 100644 --- a/pkg/dmsg/transport.go +++ b/pkg/dmsg/transport.go @@ -42,9 +42,11 @@ type Transport struct { bufSize int bufMx sync.Mutex // protects 'buf' and 'bufCh' - once sync.Once - done chan struct{} - doneFunc func(id uint16) + servingOnce sync.Once + serving chan struct{} + doneOnce sync.Once + done chan struct{} + doneFunc func(id uint16) // contains a method to remove the transport from dmsg.Client } // NewTransport creates a new dms_tp. @@ -59,6 +61,7 @@ func NewTransport(conn net.Conn, log *logging.Logger, local, remote cipher.PubKe ackBuf: make([]byte, 0, tpAckCap), buf: make(net.Buffers, 0, tpBufFrameCap), bufCh: make(chan struct{}, 1), + serving: make(chan struct{}), done: make(chan struct{}), doneFunc: doneFunc, } @@ -68,8 +71,16 @@ func NewTransport(conn net.Conn, log *logging.Logger, local, remote cipher.PubKe return tp } +func (tp *Transport) serve() (started bool) { + tp.servingOnce.Do(func() { + started = true + close(tp.serving) + }) + return started +} + func (tp *Transport) close() (closed bool) { - tp.once.Do(func() { + tp.doneOnce.Do(func() { closed = true close(tp.done) @@ -85,6 +96,7 @@ func (tp *Transport) close() (closed bool) { }) + tp.serve() // just in case. tp.ackWaiter.StopAll() return closed } @@ -147,14 +159,20 @@ func (tp *Transport) WriteRequest() error { } // WriteAccept writes an ACCEPT frame to dmsg_server to be forwarded to associated client. -func (tp *Transport) WriteAccept() error { +func (tp *Transport) WriteAccept() (err error) { + defer func() { + if err != nil { + tp.log.WithError(err).WithField("remote", tp.remote).Warnln("(HANDSHAKE) Rejected locally.") + } else { + tp.log.WithField("remote", tp.remote).Infoln("(HANDSHAKE) Accepted locally.") + } + }() + f := MakeFrame(AcceptType, tp.id, combinePKs(tp.remote, tp.local)) - if err := writeFrame(tp.Conn, f); err != nil { - tp.log.WithError(err).Error("HandshakeFailed") + if err = writeFrame(tp.Conn, f); err != nil { tp.close() return err } - tp.log.WithField("sent", f).Infoln("HandshakeCompleted") return nil } @@ -162,7 +180,11 @@ func (tp *Transport) WriteAccept() error { // TODO(evanlinjin): Cleanup errors. func (tp *Transport) ReadAccept(ctx context.Context) (err error) { defer func() { - tp.log.WithError(err).WithField("success", err == nil).Infoln("HandshakeDone") + if err != nil { + tp.log.WithError(err).WithField("remote", tp.remote).Warnln("(HANDSHAKE) Rejected by remote.") + } else { + tp.log.WithField("remote", tp.remote).Infoln("(HANDSHAKE) Accepted by remote.") + } }() select { @@ -206,6 +228,10 @@ func (tp *Transport) ReadAccept(ctx context.Context) (err error) { // Serve handles received frames. func (tp *Transport) Serve() { + if !tp.serve() { + return + } + defer func() { if tp.close() { _ = writeCloseFrame(tp.Conn, tp.id, 0) //nolint:errcheck @@ -221,9 +247,7 @@ func (tp *Transport) Serve() { if !ok { return } - log := tp.log. - WithField("remoteClient", tp.remote). - WithField("received", f) + log := tp.log.WithField("remoteClient", tp.remote).WithField("received", f) switch p := f.Pay(); f.Type() { case FwdType: @@ -279,6 +303,8 @@ func (tp *Transport) Serve() { // Read implements io.Reader // TODO(evanlinjin): read deadline. func (tp *Transport) Read(p []byte) (n int, err error) { + <-tp.serving + startRead: tp.bufMx.Lock() n, err = tp.buf.Read(p) @@ -293,12 +319,16 @@ startRead: } tp.bufMx.Unlock() - if tp.IsClosed() { - return n, err + if err != nil { + if tp.IsClosed() { + return n, err + } + err = nil } - if n > 0 { + if n > 0 || len(p) == 0 { return n, nil } + <-tp.bufCh goto startRead } @@ -306,9 +336,12 @@ startRead: // Write implements io.Writer // TODO(evanlinjin): write deadline. func (tp *Transport) Write(p []byte) (int, error) { + <-tp.serving + if tp.IsClosed() { return 0, io.ErrClosedPipe } + err := tp.ackWaiter.Wait(context.Background(), func(seq ioutil.Uint16Seq) error { if err := writeFwdFrame(tp.Conn, tp.id, seq, p); err != nil { tp.close()