diff --git a/internal/ioutil/ack_waiter.go b/internal/ioutil/ack_waiter.go index 3cac07f844..4e72795769 100644 --- a/internal/ioutil/ack_waiter.go +++ b/internal/ioutil/ack_waiter.go @@ -4,7 +4,6 @@ import ( "context" "crypto/rand" "encoding/binary" - "io" "math" "sync" ) @@ -45,9 +44,8 @@ func (w *Uint16AckWaiter) RandSeq() error { } // Wait performs the given action, and waits for given seq to be Done. -func (w *Uint16AckWaiter) Wait(ctx context.Context, done <-chan struct{}, action func(seq Uint16Seq) error) error { +func (w *Uint16AckWaiter) Wait(ctx context.Context, action func(seq Uint16Seq) error) (err error) { ackCh := make(chan struct{}) - defer close(ackCh) w.mx.Lock() seq := w.nextSeq @@ -55,28 +53,30 @@ func (w *Uint16AckWaiter) Wait(ctx context.Context, done <-chan struct{}, action w.waiters[seq] = ackCh w.mx.Unlock() - if err := action(seq); err != nil { + if err = action(seq); err != nil { return err } select { case <-ackCh: - return nil - case <-done: - return io.ErrClosedPipe case <-ctx.Done(): - return ctx.Err() + err = ctx.Err() } + + w.mx.Lock() + close(ackCh) + w.waiters[seq] = nil + w.mx.Unlock() + + return err } // Done finishes given sequence. func (w *Uint16AckWaiter) Done(seq Uint16Seq) { w.mx.RLock() - ackCh := w.waiters[seq] - w.mx.RUnlock() - select { - case ackCh <- struct{}{}: + case w.waiters[seq] <- struct{}{}: default: } + w.mx.RUnlock() } diff --git a/internal/ioutil/ack_waiter_test.go b/internal/ioutil/ack_waiter_test.go new file mode 100644 index 0000000000..6a14aab1aa --- /dev/null +++ b/internal/ioutil/ack_waiter_test.go @@ -0,0 +1,24 @@ +package ioutil + +import ( + "context" + "testing" +) + +// Ensure that no race conditions occurs. 
+func TestUint16AckWaiter_Wait(t *testing.T) { + w := new(Uint16AckWaiter) + + seqChan := make(chan Uint16Seq) + defer close(seqChan) + for i := 0; i < 64; i++ { + go w.Wait(context.TODO(), func(seq Uint16Seq) error { //nolint:errcheck + seqChan <- seq + return nil + }) + seq := <-seqChan + for j := 0; j < i; j++ { + go w.Done(seq) + } + } +} diff --git a/pkg/dmsg/client.go b/pkg/dmsg/client.go index 24411561fc..b6be47dd9d 100644 --- a/pkg/dmsg/client.go +++ b/pkg/dmsg/client.go @@ -39,11 +39,15 @@ type ClientConn struct { // map of transports to remote dms_clients (key: tp_id, val: transport). tps [math.MaxUint16 + 1]*Transport mx sync.RWMutex // to protect tps. + + wg sync.WaitGroup } // NewClientConn creates a new ClientConn. func NewClientConn(log *logging.Logger, conn net.Conn, local, remote cipher.PubKey) *ClientConn { - return &ClientConn{log: log, Conn: conn, local: local, remoteSrv: remote, nextInitID: randID(true)} + cc := &ClientConn{log: log, Conn: conn, local: local, remoteSrv: remote, nextInitID: randID(true)} + cc.wg.Add(1) + return cc } func (c *ClientConn) delTp(id uint16) { @@ -92,7 +96,7 @@ func (c *ClientConn) getTp(id uint16) (*Transport, bool) { return tp, ok } -func (c *ClientConn) handleRequestFrame(ctx context.Context, done <-chan struct{}, accept chan<- *Transport, id uint16, p []byte) (cipher.PubKey, error) { +func (c *ClientConn) handleRequestFrame(ctx context.Context, accept chan<- *Transport, id uint16, p []byte) (cipher.PubKey, error) { // remote-initiated tps should: // - have a payload structured as 'init_pk:resp_pk'. // - resp_pk should be of local client. @@ -115,8 +119,6 @@ func (c *ClientConn) handleRequestFrame(ctx context.Context, done <-chan struct{ c.setTp(tp) select { - case <-done: - return initPK, ErrClientClosed case <-ctx.Done(): return initPK, ctx.Err() case accept <- tp: @@ -126,19 +128,22 @@ func (c *ClientConn) handleRequestFrame(ctx context.Context, done <-chan struct{ // Serve handles incoming frames. 
// Remote-initiated tps that are successfully created are pushing into 'accept' and exposed via 'Client.Accept()'. -func (c *ClientConn) Serve(ctx context.Context, done <-chan struct{}, accept chan<- *Transport) (err error) { +func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) (err error) { + defer c.wg.Done() + log := c.log.WithField("remoteServer", c.remoteSrv) + + log.WithField("connCount", incrementServeCount()).Infoln("ServingConn") defer func() { log.WithError(err).WithField("connCount", decrementServeCount()).Infoln("ClosingConn") }() - log.WithField("connCount", incrementServeCount()).Infoln("ServingConn") for { f, err := readFrame(c.Conn) if err != nil { return fmt.Errorf("read failed: %s", err) } - log = log.WithField("frame", f) + log = log.WithField("received", f) ft, id, p := f.Disassemble() @@ -160,7 +165,7 @@ func (c *ClientConn) Serve(ctx context.Context, done <-chan struct{}, accept cha case RequestType: // TODO(evanlinjin): Allow for REQUEST frame handling to be done in goroutine. // Currently this causes issues (probably because we need ACK frames). - initPK, err := c.handleRequestFrame(ctx, done, accept, id, p) + initPK, err := c.handleRequestFrame(ctx, accept, id, p) if err != nil { log.WithField("remoteClient", initPK).WithError(err).Infoln("FrameRejected") if err == ErrRequestCheckFailed { @@ -200,7 +205,7 @@ func (c *ClientConn) Close() error { } err := c.Conn.Close() c.mx.Unlock() - //c.wg.Wait() + c.wg.Wait() return err } @@ -223,7 +228,7 @@ type Client struct { // NewClient creates a new Client. 
func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc client.APIClient) *Client { return &Client{ - log: logging.MustGetLogger("dms_client"), + log: logging.MustGetLogger("dmsg_client"), pk: pk, sk: sk, dc: dc, @@ -260,7 +265,7 @@ func (c *Client) setConn(ctx context.Context, l *ClientConn) { c.mx.Lock() c.conns[l.remoteSrv] = l if err := c.updateDiscEntry(ctx); err != nil { - c.log.WithError(err).Warn("failed to update dms_client entry") + c.log.WithError(err).Warn("updateEntry: failed") } c.mx.Unlock() } @@ -269,7 +274,7 @@ func (c *Client) delConn(ctx context.Context, pk cipher.PubKey) { c.mx.Lock() delete(c.conns, pk) if err := c.updateDiscEntry(ctx); err != nil { - c.log.WithError(err).Warn("updateEntry: failed") + c.log.WithError(err).Warn("updateEntry: failed") } c.mx.Unlock() } @@ -312,8 +317,9 @@ func (c *Client) findServerEntries(ctx context.Context) ([]*client.Entry, error) case <-ctx.Done(): return nil, fmt.Errorf("dms_servers are not available: %s", err) default: - c.log.WithError(err).Warnf("no dms_servers found: trying again is 1 second...") - time.Sleep(time.Second) + retry := time.Second + c.log.WithError(err).Warnf("no dms_servers found: trying again in %v...", retry) + time.Sleep(retry) continue } } @@ -370,7 +376,7 @@ func (c *Client) findOrConnectToServer(ctx context.Context, srvPK cipher.PubKey) conn := NewClientConn(c.log, nc, c.pk, srvPK) c.setConn(ctx, conn) go func() { - if err := conn.Serve(ctx, c.done, c.accept); err != nil { + if err := conn.Serve(ctx, c.accept); err != nil { conn.log.WithError(err).WithField("dms_server", srvPK).Warn("connected with dms_server closed") c.delConn(ctx, srvPK) @@ -380,7 +386,7 @@ func (c *Client) findOrConnectToServer(ctx context.Context, srvPK cipher.PubKey) case <-c.done: case <-ctx.Done(): case <-t.C: - conn.log.WithField("dms_server", srvPK).Warn("reconnecting to dms_server") + conn.log.WithField("remoteServer", srvPK).Warn("Reconnecting") _, _ = c.findOrConnectToServer(ctx, 
srvPK) //nolint:errcheck } return diff --git a/pkg/dmsg/server.go b/pkg/dmsg/server.go index ea17fe5549..b777c88b20 100644 --- a/pkg/dmsg/server.go +++ b/pkg/dmsg/server.go @@ -8,6 +8,7 @@ import ( "math" "net" "sync" + "time" "github.com/skycoin/skycoin/src/util/logging" @@ -117,7 +118,7 @@ func (c *ServerConn) Serve(ctx context.Context, getConn getConnFunc) (err error) if err != nil { return fmt.Errorf("read failed: %s", err) } - log = log.WithField("frame", f) + log = log.WithField("received", f) ft, id, p := f.Disassemble() @@ -292,11 +293,11 @@ func (s *Server) Serve() error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := s.updateDiscEntry(ctx); err != nil { + if err := s.retryUpdateEntry(ctx, hsTimeout); err != nil { return fmt.Errorf("updating server's discovery entry failed with: %s", err) } - s.log.Infof("serving: pk(%s) addr(%s)", s.pk, s.lis.Addr()) + s.log.Infof("serving: pk(%s) addr(%s)", s.pk, s.Addr()) for { rawConn, err := s.lis.Accept() @@ -335,3 +336,23 @@ func (s *Server) updateDiscEntry(ctx context.Context) error { return s.dc.UpdateEntry(ctx, s.sk, entry) } + +func (s *Server) retryUpdateEntry(ctx context.Context, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + for { + if err := s.updateDiscEntry(ctx); err != nil { + select { + case <-ctx.Done(): + return ctx.Err() + default: + retry := time.Second + s.log.WithError(err).Warnf("updateEntry failed: trying again in %v...", retry) + time.Sleep(retry) + continue + } + } + return nil + } +} diff --git a/pkg/dmsg/transport.go b/pkg/dmsg/transport.go index a9391d6aa6..e17c21a76b 100644 --- a/pkg/dmsg/transport.go +++ b/pkg/dmsg/transport.go @@ -51,7 +51,7 @@ func NewTransport(conn net.Conn, log *logging.Logger, local, remote cipher.PubKe doneCh: make(chan struct{}), } if err := tp.ackWaiter.RandSeq(); err != nil { - log.Fatalln("failed to set ack_water seq:", err) + log.Fatalln("failed to set ack_waiter 
seq:", err) } return tp } @@ -107,14 +107,13 @@ func (c *Transport) Handshake(ctx context.Context) error { return err } } else { - c.log.Infof("tp_hs responding...") f := MakeFrame(AcceptType, c.id, combinePKs(c.remote, c.local)) if err := writeFrame(c.Conn, f); err != nil { - c.log.WithError(err).Error("tp_hs responded with error.") + c.log.WithError(err).Error("HandshakeFailed") c.close() return err } - c.log.Infoln("tp_hs responded:", f) + c.log.WithField("sent", f).Infoln("HandshakeCompleted") } return nil } @@ -213,9 +212,14 @@ func (c *Transport) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe default: ctx, cancel := context.WithTimeout(context.Background(), readTimeout) - defer cancel() - - err := c.ackWaiter.Wait(ctx, c.doneCh, func(seq ioutil.Uint16Seq) error { + go func() { + select { + case <-ctx.Done(): + case <-c.doneCh: + cancel() + } + }() + err := c.ackWaiter.Wait(ctx, func(seq ioutil.Uint16Seq) error { if err := writeFwdFrame(c.Conn, c.id, seq, p); err != nil { c.close() return err @@ -223,6 +227,7 @@ func (c *Transport) Write(p []byte) (int, error) { return nil }) if err != nil { + cancel() return 0, err } return len(p), nil diff --git a/pkg/transport/managed_transport.go b/pkg/transport/managed_transport.go index 1217c493a4..f7ab42dcc4 100644 --- a/pkg/transport/managed_transport.go +++ b/pkg/transport/managed_transport.go @@ -3,6 +3,9 @@ package transport import ( "math/big" "sync" + "time" + + "github.com/skycoin/skywire/pkg/cipher" "github.com/google/uuid" ) @@ -81,6 +84,22 @@ func (tr *ManagedTransport) Write(p []byte) (n int, err error) { return } +// Edges returns the edges of underlying transport. +func (tr *ManagedTransport) Edges() [2]cipher.PubKey { + tr.mu.RLock() + edges := tr.Transport.Edges() + tr.mu.RUnlock() + return edges +} + +// SetDeadline sets the deadline of the underlying transport. 
+func (tr *ManagedTransport) SetDeadline(t time.Time) error { + tr.mu.RLock() + err := tr.Transport.SetDeadline(t) + tr.mu.RUnlock() + return err +} + // IsClosing determines whether is closing. func (tr *ManagedTransport) IsClosing() bool { select { @@ -102,6 +121,6 @@ func (tr *ManagedTransport) Close() (err error) { func (tr *ManagedTransport) updateTransport(newTr Transport) { tr.mu.Lock() - tr.Transport = newTr + tr.Transport = newTr // TODO: data race. tr.mu.Unlock() } diff --git a/pkg/transport/manager.go b/pkg/transport/manager.go index 004266f0d4..869f93aa7c 100644 --- a/pkg/transport/manager.go +++ b/pkg/transport/manager.go @@ -34,7 +34,7 @@ type Manager struct { doneChan chan struct{} TrChan chan *ManagedTransport - mu sync.RWMutex + mx sync.RWMutex } // NewManager creates a Manager with the provided configuration and transport factories. @@ -53,7 +53,7 @@ func NewManager(config *ManagerConfig, factories ...Factory) (*Manager, error) { } return &Manager{ - Logger: logging.MustGetLogger("trmanager"), + Logger: logging.MustGetLogger("tpmanager"), config: config, factories: fMap, transports: make(map[uuid.UUID]*ManagedTransport), @@ -74,31 +74,32 @@ func (tm *Manager) Factories() []string { // Transport obtains a Transport via a given Transport ID. func (tm *Manager) Transport(id uuid.UUID) *ManagedTransport { - tm.mu.RLock() + tm.mx.RLock() tr := tm.transports[id] - tm.mu.RUnlock() + tm.mx.RUnlock() return tr } // WalkTransports ranges through all transports. func (tm *Manager) WalkTransports(walk func(tp *ManagedTransport) bool) { - tm.mu.RLock() + tm.mx.RLock() for _, tp := range tm.transports { - if ok := walk(tp); !ok { + if ok := walk(tp); !ok { // TODO: data race. break } } - tm.mu.RUnlock() + tm.mx.RUnlock() } // reconnectTransports tries to reconnect previously established transports. 
func (tm *Manager) reconnectTransports(ctx context.Context) { - tm.mu.RLock() + tm.mx.RLock() entries := make(map[Entry]struct{}) for tmEntry := range tm.entries { entries[tmEntry] = struct{}{} } - tm.mu.RUnlock() + tm.mx.RUnlock() + for entry := range entries { if tm.Transport(entry.ID) != nil { continue @@ -154,7 +155,7 @@ func (tm *Manager) createDefaultTransports(ctx context.Context) { if exist { continue } - _, err := tm.CreateTransport(ctx, pk, "messaging", true) + _, err := tm.CreateTransport(ctx, pk, "dmsg", true) if err != nil { tm.Logger.Warnf("Failed to establish transport to a node %s: %s", pk, err) } @@ -234,18 +235,18 @@ func (tm *Manager) CreateTransport(ctx context.Context, remote cipher.PubKey, tp // DeleteTransport disconnects and removes the Transport of Transport ID. func (tm *Manager) DeleteTransport(id uuid.UUID) error { - tm.mu.Lock() - tr := tm.transports[id] + tm.mx.Lock() + tp := tm.transports[id] delete(tm.transports, id) - tm.mu.Unlock() + tm.mx.Unlock() if _, err := tm.config.DiscoveryClient.UpdateStatuses(context.Background(), &Status{ID: id, IsUp: false}); err != nil { tm.Logger.Warnf("Failed to change transport status: %s", err) } tm.Logger.Infof("Unregistered transport %s", id) - if tr != nil { - return tr.Close() + if tp != nil { + tp.Close() } return nil @@ -257,7 +258,7 @@ func (tm *Manager) Close() error { close(tm.doneChan) tm.Logger.Info("Closing transport manager") - tm.mu.Lock() + tm.mx.Lock() statuses := make([]*Status, 0) for _, tr := range tm.transports { if !tr.Public { @@ -267,7 +268,7 @@ func (tm *Manager) Close() error { go tr.Close() } - tm.mu.Unlock() + tm.mx.Unlock() if _, err := tm.config.DiscoveryClient.UpdateStatuses(context.Background(), statuses...); err != nil { tm.Logger.Warnf("Failed to change transport status: %s", err) @@ -309,14 +310,14 @@ func (tm *Manager) createTransport(ctx context.Context, remote cipher.PubKey, tp tm.Logger.Infof("Dialed to %s using %s factory. 
Transport ID: %s", remote, tpType, entry.ID) managedTr := newManagedTransport(entry.ID, tr, entry.Public, false) - tm.mu.Lock() + tm.mx.Lock() tm.transports[entry.ID] = managedTr select { case <-tm.doneChan: case tm.TrChan <- managedTr: default: } - tm.mu.Unlock() + tm.mx.Unlock() go tm.manageTransport(ctx, managedTr, factory, remote, public, false) @@ -345,7 +346,7 @@ func (tm *Manager) acceptTransport(ctx context.Context, factory Factory) (*Manag tm.Logger.Infof("Accepted new transport with type %s from %s. ID: %s", factory.Type(), remote, entry.ID) managedTr := newManagedTransport(entry.ID, tr, entry.Public, true) - tm.mu.Lock() + tm.mx.Lock() tm.transports[entry.ID] = managedTr select { @@ -353,7 +354,7 @@ func (tm *Manager) acceptTransport(ctx context.Context, factory Factory) (*Manag case tm.TrChan <- managedTr: default: } - tm.mu.Unlock() + tm.mx.Unlock() go tm.manageTransport(ctx, managedTr, factory, remote, true, true) @@ -363,8 +364,8 @@ func (tm *Manager) acceptTransport(ctx context.Context, factory Factory) (*Manag } func (tm *Manager) walkEntries(walkFunc func(*Entry) bool) *Entry { - tm.mu.Lock() - defer tm.mu.Unlock() + tm.mx.Lock() + defer tm.mx.Unlock() for entry := range tm.entries { if walkFunc(&entry) { @@ -376,9 +377,9 @@ func (tm *Manager) walkEntries(walkFunc func(*Entry) bool) *Entry { } func (tm *Manager) addEntry(entry *Entry) { - tm.mu.Lock() + tm.mx.Lock() tm.entries[*entry] = struct{}{} - tm.mu.Unlock() + tm.mx.Unlock() } func (tm *Manager) manageTransport(ctx context.Context, managedTr *ManagedTransport, factory Factory, remote cipher.PubKey, public bool, accepted bool) { @@ -407,7 +408,6 @@ func (tm *Manager) manageTransport(ctx context.Context, managedTr *ManagedTransp } else { tm.Logger.Infof("Transport %s is already closing. Skipped error: %s", managedTr.ID, err) } - } }