Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/client server reconnection and timeout #175

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
.DS_Store

.idea/
.vscode/
bin/

/dmsg-discovery
Expand Down
62 changes: 46 additions & 16 deletions pkg/dmsg/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ import (
"github.com/skycoin/skywire-utilities/pkg/netutil"
)

// TODO(evanlinjin): We should implement exponential backoff at some point.
const serveWait = time.Second

// SessionDialCallback is triggered BEFORE a session is dialed to.
// If a non-nil error is returned, the session dial is instantly terminated.
type SessionDialCallback func(network, addr string) (err error)
Expand Down Expand Up @@ -74,6 +71,10 @@ type Client struct {
conf *Config
porter *netutil.Porter

bo time.Duration // initial backoff duration
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use the netutil.Retrier here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

netutil.Retrier doesn't work in this case as we need to wait with exponential backoff in multiple places. And to do that in the retrier we have to return an error but by doing that we have to restart the main loop all over again instead of the current version where we wait between each entry here.

maxBO time.Duration // maximum backoff duration
factor float64 // multiplier for the backoff duration that is applied on every retry

errCh chan error
done chan struct{}
once sync.Once
Expand All @@ -82,20 +83,24 @@ type Client struct {

// NewClient creates a dmsg client entity.
func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc disc.APIClient, conf *Config) *Client {
c := new(Client)
c.ready = make(chan struct{})
c.porter = netutil.NewPorter(netutil.PorterMinEphemeral)
c.errCh = make(chan error, 10)
c.done = make(chan struct{})

log := logging.MustGetLogger("dmsg_client")

// Init config.
if conf == nil {
conf = DefaultConfig()
}
conf.Ensure()
c.conf = conf

c := &Client{
ready: make(chan struct{}),
porter: netutil.NewPorter(netutil.PorterMinEphemeral),
errCh: make(chan error, 10),
done: make(chan struct{}),
conf: conf,
bo: time.Second * 5,
maxBO: time.Minute,
factor: netutil.DefaultFactor,
}

// Init common fields.
c.EntityCommon.init(pk, sk, dc, log, conf.UpdateInterval)
Expand Down Expand Up @@ -156,15 +161,15 @@ func (ce *Client) Serve(ctx context.Context) {
if err == context.Canceled || err == context.DeadlineExceeded {
return
}
time.Sleep(time.Second) // TODO(evanlinjin): Implement exponential back off.
ce.serveWait()
continue
}
if len(entries) == 0 {
ce.log.Warnf("No entries found. Retrying after %s...", serveWait.String())
time.Sleep(serveWait)
ce.log.Warnf("No entries found. Retrying after %s...", ce.bo.String())
ce.serveWait()
}

for _, entry := range entries {
for n, entry := range entries {
if isClosed(ce.done) {
return
}
Expand All @@ -183,11 +188,21 @@ func (ce *Client) Serve(ctx context.Context) {
}

if err := ce.EnsureSession(cancellabelCtx, entry); err != nil {
ce.log.WithField("remote_pk", entry.Static).WithError(err).Warn("Failed to establish session.")
if err == context.Canceled || err == context.DeadlineExceeded {
ce.log.WithField("remote_pk", entry.Static).WithError(err).Warn("Failed to establish session.")
return
}
time.Sleep(serveWait)
// we send an error if this is the last server
if n == (len(entries) - 1) {
if !isClosed(ce.done) {
ce.sesMx.Lock()
ce.errCh <- err
ce.sesMx.Unlock()
}
}
ce.log.WithField("remote_pk", entry.Static).WithError(err).WithField("current_backoff", ce.bo.String()).
Warn("Failed to establish session.")
ce.serveWait()
}
}
// We dial all servers and wait for error or done signal.
Expand Down Expand Up @@ -443,6 +458,21 @@ func (ce *Client) ConnectionsSummary() ConnectionsSummary {
return out
}

func (ce *Client) serveWait() {
bo := ce.bo

t := time.NewTimer(bo)
defer t.Stop()

if newBO := time.Duration(float64(bo) * ce.factor); ce.maxBO == 0 || newBO <= ce.maxBO {
ce.bo = newBO
if newBO > ce.maxBO {
ce.bo = ce.maxBO
}
}
<-t.C
}

func hasPK(pks []cipher.PubKey, pk cipher.PubKey) bool {
for _, oldPK := range pks {
if oldPK == pk {
Expand Down
13 changes: 6 additions & 7 deletions pkg/dmsg/session_common.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package dmsg

import (
"bufio"
"encoding/binary"
"io"
"net"
Expand Down Expand Up @@ -62,11 +61,11 @@ func (sc *SessionCommon) initClient(entity *EntityCommon, conn net.Conn, rPK cip
return err
}

r := bufio.NewReader(conn)
if err := noise.InitiatorHandshake(ns, r, conn); err != nil {
rw := noise.NewReadWriter(conn, ns)
if err := rw.Handshake(time.Second * 5); err != nil {
return err
}
if r.Buffered() > 0 {
if rw.Buffered() > 0 {
return ErrSessionHandshakeExtraBytes
}

Expand Down Expand Up @@ -95,11 +94,11 @@ func (sc *SessionCommon) initServer(entity *EntityCommon, conn net.Conn) error {
return err
}

r := bufio.NewReader(conn)
if err := noise.ResponderHandshake(ns, r, conn); err != nil {
rw := noise.NewReadWriter(conn, ns)
if err := rw.Handshake(time.Second * 5); err != nil {
return err
}
if r.Buffered() > 0 {
if rw.Buffered() > 0 {
return ErrSessionHandshakeExtraBytes
}

Expand Down
5 changes: 5 additions & 0 deletions pkg/noise/read_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ func (rw *ReadWriter) Handshake(hsTimeout time.Duration) error {
}
}

// Buffered returns the number of bytes that can be read from the buffer rawInput.
func (rw *ReadWriter) Buffered() int {
return rw.rawInput.Buffered()
}

// LocalStatic returns the local static public key.
func (rw *ReadWriter) LocalStatic() cipher.PubKey {
return rw.ns.LocalStatic()
Expand Down