Skip to content

Commit

Permalink
Fixed dmsg accept-transport behaviour.
Browse files Browse the repository at this point in the history
  • Loading branch information
林志宇 committed Jun 17, 2019
1 parent ce59ee1 commit 982a8c9
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 42 deletions.
39 changes: 12 additions & 27 deletions pkg/dmsg/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,10 @@ func (c *ClientConn) addTp(ctx context.Context, clientPK cipher.PubKey) (*Transp
return tp, nil
}

func (c *ClientConn) acceptTp(clientPK cipher.PubKey, id uint16) (*Transport, error) {
tp := NewTransport(c.Conn, c.log, c.local, clientPK, id, c.delTp)

func (c *ClientConn) setTp(tp *Transport) {
c.mx.Lock()
c.tps[tp.id] = tp
c.mx.Unlock()

if err := tp.WriteAccept(); err != nil {
return nil, err
}
return tp, nil
}

func (c *ClientConn) delTp(id uint16) {
Expand All @@ -131,7 +124,7 @@ func (c *ClientConn) setNextInitID(nextInitID uint16) {
c.mx.Unlock()
}

func (c *ClientConn) handleRequestFrame(ctx context.Context, accept chan<- *Transport, id uint16, p []byte) (cipher.PubKey, error) {
func (c *ClientConn) handleRequestFrame(accept chan<- *Transport, id uint16, p []byte) (cipher.PubKey, error) {
// remotely-initiated tps should:
// - have a payload structured as 'init_pk:resp_pk'.
// - resp_pk should be of local client.
Expand All @@ -143,22 +136,20 @@ func (c *ClientConn) handleRequestFrame(ctx context.Context, accept chan<- *Tran
}
return initPK, ErrRequestCheckFailed
}
tp, err := c.acceptTp(initPK, id)
if err != nil {
return initPK, err
}
go tp.Serve()

tp := NewTransport(c.Conn, c.log, c.local, initPK, id, c.delTp)
c.setTp(tp)

select {
case <-c.done:
_ = tp.Close() //nolint:errcheck
return initPK, ErrClientClosed

case <-ctx.Done():
_ = tp.Close() //nolint:errcheck
return initPK, ctx.Err()

case accept <- tp:
if err := tp.WriteAccept(); err != nil {
return initPK, err
}
go tp.Serve()
return initPK, nil
}
}
Expand Down Expand Up @@ -206,21 +197,15 @@ func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) (err e
c.wg.Add(1)
go func(log *logrus.Entry) {
defer c.wg.Done()

initPK, err := c.handleRequestFrame(ctx, accept, id, p)
initPK, err := c.handleRequestFrame(accept, id, p)
if err != nil {
log.
WithField("remoteClient", initPK).
WithError(err).
Infoln("Rejected [REQUEST]")
log.WithField("remoteClient", initPK).WithError(err).Infoln("Rejected [REQUEST]")
if isWriteError(err) || err == ErrClientClosed {
closeConn(log)
}
return
}
log.
WithField("remoteClient", initPK).
Infoln("Accepted [REQUEST]")
log.WithField("remoteClient", initPK).Infoln("Accepted [REQUEST]")
}(log)

default:
Expand Down
1 change: 1 addition & 0 deletions pkg/dmsg/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"golang.org/x/net/nettest"

"github.com/skycoin/skycoin/src/util/logging"

"github.com/skycoin/skywire/internal/noise"
"github.com/skycoin/skywire/pkg/cipher"
"github.com/skycoin/skywire/pkg/messaging-discovery/client"
Expand Down
63 changes: 48 additions & 15 deletions pkg/dmsg/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@ type Transport struct {
bufSize int
bufMx sync.Mutex // protects 'buf' and 'bufCh'

once sync.Once
done chan struct{}
doneFunc func(id uint16)
servingOnce sync.Once
serving chan struct{}
doneOnce sync.Once
done chan struct{}
doneFunc func(id uint16) // contains a method to remove the transport from dmsg.Client
}

// NewTransport creates a new dms_tp.
Expand All @@ -59,6 +61,7 @@ func NewTransport(conn net.Conn, log *logging.Logger, local, remote cipher.PubKe
ackBuf: make([]byte, 0, tpAckCap),
buf: make(net.Buffers, 0, tpBufFrameCap),
bufCh: make(chan struct{}, 1),
serving: make(chan struct{}),
done: make(chan struct{}),
doneFunc: doneFunc,
}
Expand All @@ -68,8 +71,16 @@ func NewTransport(conn net.Conn, log *logging.Logger, local, remote cipher.PubKe
return tp
}

func (tp *Transport) serve() (started bool) {
tp.servingOnce.Do(func() {
started = true
close(tp.serving)
})
return started
}

func (tp *Transport) close() (closed bool) {
tp.once.Do(func() {
tp.doneOnce.Do(func() {
closed = true

close(tp.done)
Expand All @@ -85,6 +96,7 @@ func (tp *Transport) close() (closed bool) {

})

tp.serve() // just in case.
tp.ackWaiter.StopAll()
return closed
}
Expand Down Expand Up @@ -147,22 +159,32 @@ func (tp *Transport) WriteRequest() error {
}

// WriteAccept writes an ACCEPT frame to dmsg_server to be forwarded to associated client.
func (tp *Transport) WriteAccept() error {
func (tp *Transport) WriteAccept() (err error) {
defer func() {
if err != nil {
tp.log.WithError(err).WithField("remote", tp.remote).Warnln("(HANDSHAKE) Rejected locally.")
} else {
tp.log.WithField("remote", tp.remote).Infoln("(HANDSHAKE) Accepted locally.")
}
}()

f := MakeFrame(AcceptType, tp.id, combinePKs(tp.remote, tp.local))
if err := writeFrame(tp.Conn, f); err != nil {
tp.log.WithError(err).Error("HandshakeFailed")
if err = writeFrame(tp.Conn, f); err != nil {
tp.close()
return err
}
tp.log.WithField("sent", f).Infoln("HandshakeCompleted")
return nil
}

// ReadAccept awaits for an ACCEPT frame to be read from the remote client.
// TODO(evanlinjin): Cleanup errors.
func (tp *Transport) ReadAccept(ctx context.Context) (err error) {
defer func() {
tp.log.WithError(err).WithField("success", err == nil).Infoln("HandshakeDone")
if err != nil {
tp.log.WithError(err).WithField("remote", tp.remote).Warnln("(HANDSHAKE) Rejected by remote.")
} else {
tp.log.WithField("remote", tp.remote).Infoln("(HANDSHAKE) Accepted by remote.")
}
}()

select {
Expand Down Expand Up @@ -206,6 +228,10 @@ func (tp *Transport) ReadAccept(ctx context.Context) (err error) {

// Serve handles received frames.
func (tp *Transport) Serve() {
if !tp.serve() {
return
}

defer func() {
if tp.close() {
_ = writeCloseFrame(tp.Conn, tp.id, 0) //nolint:errcheck
Expand All @@ -221,9 +247,7 @@ func (tp *Transport) Serve() {
if !ok {
return
}
log := tp.log.
WithField("remoteClient", tp.remote).
WithField("received", f)
log := tp.log.WithField("remoteClient", tp.remote).WithField("received", f)

switch p := f.Pay(); f.Type() {
case FwdType:
Expand Down Expand Up @@ -279,6 +303,8 @@ func (tp *Transport) Serve() {
// Read implements io.Reader
// TODO(evanlinjin): read deadline.
func (tp *Transport) Read(p []byte) (n int, err error) {
<-tp.serving

startRead:
tp.bufMx.Lock()
n, err = tp.buf.Read(p)
Expand All @@ -293,22 +319,29 @@ startRead:
}
tp.bufMx.Unlock()

if tp.IsClosed() {
return n, err
if err != nil {
if tp.IsClosed() {
return n, err
}
err = nil
}
if n > 0 {
if n > 0 || len(p) == 0 {
return n, nil
}

<-tp.bufCh
goto startRead
}

// Write implements io.Writer
// TODO(evanlinjin): write deadline.
func (tp *Transport) Write(p []byte) (int, error) {
<-tp.serving

if tp.IsClosed() {
return 0, io.ErrClosedPipe
}

err := tp.ackWaiter.Wait(context.Background(), func(seq ioutil.Uint16Seq) error {
if err := writeFwdFrame(tp.Conn, tp.id, seq, p); err != nil {
tp.close()
Expand Down

0 comments on commit 982a8c9

Please sign in to comment.