Skip to content

Commit

Permalink
Improved shutdown and blocked-chan management.
Browse files Browse the repository at this point in the history
  • Loading branch information
林志宇 committed Jun 6, 2019
1 parent d2a9d99 commit bd441d7
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 84 deletions.
72 changes: 48 additions & 24 deletions pkg/dmsg/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"sync"
"time"

"github.com/sirupsen/logrus"

"github.com/skycoin/skycoin/src/util/logging"

"github.com/skycoin/skywire/internal/noise"
Expand Down Expand Up @@ -40,12 +42,14 @@ type ClientConn struct {
tps [math.MaxUint16 + 1]*Transport
mx sync.RWMutex // to protect tps.

wg sync.WaitGroup
done chan struct{}
once sync.Once
wg sync.WaitGroup
}

// NewClientConn creates a new ClientConn.
func NewClientConn(log *logging.Logger, conn net.Conn, local, remote cipher.PubKey) *ClientConn {
cc := &ClientConn{log: log, Conn: conn, local: local, remoteSrv: remote, nextInitID: randID(true)}
cc := &ClientConn{log: log, Conn: conn, local: local, remoteSrv: remote, nextInitID: randID(true), done: make(chan struct{})}
cc.wg.Add(1)
return cc
}
Expand Down Expand Up @@ -75,6 +79,8 @@ func (c *ClientConn) addTp(ctx context.Context, clientPK cipher.PubKey) (*Transp
c.nextInitID += 2

select {
case <-c.done:
return nil, ErrClientClosed
case <-ctx.Done():
return nil, ctx.Err()
default:
Expand Down Expand Up @@ -104,22 +110,27 @@ func (c *ClientConn) handleRequestFrame(ctx context.Context, accept chan<- *Tran
initPK, respPK, ok := splitPKs(p)
if !ok || respPK != c.local || isInitiatorID(id) {
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
c.Close()
return initPK, err
}
return initPK, ErrRequestCheckFailed
}

tp := NewTransport(c.Conn, c.log, c.local, initPK, id)
if err := tp.Handshake(ctx); err != nil {
// return err here as response handshake is send via ClientConn and that shouldn't fail.
c.Close()
return initPK, err
}
c.setTp(tp)

select {
case <-c.done:
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
return initPK, err
}
return initPK, ErrClientClosed
case <-ctx.Done():
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
return initPK, err
}
return initPK, ctx.Err()
case accept <- tp:
}
Expand Down Expand Up @@ -153,6 +164,7 @@ func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) (err e
if !tp.InjectRead(f) {
log.WithField("remoteClient", tp.remote).Infoln("FrameTrashed")
c.delTp(id)
continue
}
log.WithField("remoteClient", tp.remote).Infoln("FrameInjected")
continue
Expand All @@ -163,21 +175,25 @@ func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) (err e
c.delTp(id) // rm tp in case closed tp is not fully removed.
switch ft {
case RequestType:
// TODO(evanlinjin): Allow for REQUEST frame handling to be done in goroutine.
// Currently this causes issues (probably because we need ACK frames).
initPK, err := c.handleRequestFrame(ctx, accept, id, p)
if err != nil {
log.WithField("remoteClient", initPK).WithError(err).Infoln("FrameRejected")
if err == ErrRequestCheckFailed {
continue
c.wg.Add(1)
go func(log *logrus.Entry) {
defer c.wg.Done()
ctx, cancel := context.WithTimeout(ctx, acceptTimeout)
defer cancel()
initPK, err := c.handleRequestFrame(ctx, accept, id, p)
if err != nil {
log.WithField("remoteClient", initPK).WithError(err).Infoln("TransportRejected")
if isWriteError(err) || err == ErrClientClosed {
log.WithError(c.Close()).Warn("ClosingConnection")
}
return
}
return err
}
log.WithField("remoteClient", initPK).Infoln("FrameAccepted")
log.WithField("remoteClient", initPK).Infoln("TransportAccepted")
}(log)
case CloseType:
log.Infoln("FrameIgnored")
log.Infoln("CloseTransportIgnored")
default:
log.Infoln("FrameUnexpected")
log.Infoln("Unexpected")
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
return err
}
Expand All @@ -187,16 +203,24 @@ func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) (err e

// DialTransport dials a transport to remote dms_client.
func (c *ClientConn) DialTransport(ctx context.Context, clientPK cipher.PubKey) (*Transport, error) {
tp, err := c.addTp(ctx, clientPK)
if err != nil {
return nil, err
select {
case <-c.done:
return nil, ErrClientClosed
case <-ctx.Done():
return nil, ctx.Err()
default:
tp, err := c.addTp(ctx, clientPK)
if err != nil {
return nil, err
}
return tp, tp.Handshake(ctx)
}
return tp, tp.Handshake(ctx)
}

// Close closes the connection to dms_server.
func (c *ClientConn) Close() error {
c.log.Infof("closingLink: remoteSrv(%v)", c.remoteSrv)
c.log.WithField("remoteServer", c.remoteSrv).Infoln("ClosingConnection")
c.once.Do(func() { close(c.done) })
c.mx.Lock()
for _, tp := range c.tps {
if tp != nil {
Expand Down Expand Up @@ -233,7 +257,7 @@ func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc client.APIClient) *Client
sk: sk,
dc: dc,
conns: make(map[cipher.PubKey]*ClientConn),
accept: make(chan *Transport, acceptChSize),
accept: make(chan *Transport),
done: make(chan struct{}),
}
}
Expand Down Expand Up @@ -377,7 +401,7 @@ func (c *Client) findOrConnectToServer(ctx context.Context, srvPK cipher.PubKey)
c.setConn(ctx, conn)
go func() {
if err := conn.Serve(ctx, c.accept); err != nil {
conn.log.WithError(err).WithField("dms_server", srvPK).Warn("connected with dms_server closed")
conn.log.WithError(err).WithField("remoteServer", srvPK).Warn("connected with server closed")
c.delConn(ctx, srvPK)

// reconnect logic.
Expand Down
24 changes: 18 additions & 6 deletions pkg/dmsg/frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ const (
// Type returns the transport type string.
Type = "dmsg"

hsTimeout = time.Second * 10
readTimeout = time.Second * 10
acceptChSize = 1
readChSize = 20
headerLen = 5 // fType(1 byte), chID(2 byte), payLen(2 byte)
hsTimeout = time.Second * 10
readTimeout = time.Second * 10
acceptTimeout = time.Second * 5
readChSize = 20
headerLen = 5 // fType(1 byte), chID(2 byte), payLen(2 byte)
)

func isInitiatorID(tpID uint16) bool { return tpID%2 == 0 }
Expand Down Expand Up @@ -116,9 +116,21 @@ func readFrame(r io.Reader) (Frame, error) {
return f, err
}

type writeError struct{ error }

func (e *writeError) Error() string { return "write error: " + e.error.Error() }

func isWriteError(err error) bool {
_, ok := err.(*writeError)
return ok
}

func writeFrame(w io.Writer, f Frame) error {
_, err := w.Write(f)
return err
if err != nil {
return &writeError{err}
}
return nil
}

func writeFwdFrame(w io.Writer, id uint16, seq ioutil.Uint16Seq, p []byte) error {
Expand Down
Loading

0 comments on commit bd441d7

Please sign in to comment.