diff --git a/cmd/apps/skychat/chat.go b/cmd/apps/skychat/chat.go index 2999f76b77..ba80c6f18e 100644 --- a/cmd/apps/skychat/chat.go +++ b/cmd/apps/skychat/chat.go @@ -108,10 +108,10 @@ func messageHandler(w http.ResponseWriter, req *http.Request) { addr := &app.Addr{PubKey: pk, Port: 1} connsMu.Lock() - conn := chatConns[pk] + conn, ok := chatConns[pk] connsMu.Unlock() - if conn == nil { + if !ok { var err error err = r.Do(func() error { conn, err = chatApp.Dial(addr) diff --git a/integration/test-messaging.sh b/integration/test-messaging.sh index 3a8ed4cf49..6a5add8b75 100755 --- a/integration/test-messaging.sh +++ b/integration/test-messaging.sh @@ -1,4 +1,3 @@ #!/usr/bin/env bash -source ./integration/generic/env-vars.sh -# curl --data {'"recipient":"'$PK_A'", "message":"Hello Joe!"}' -X POST $CHAT_C +curl --data {'"recipient":"'$PK_A'", "message":"Hello Joe!"}' -X POST $CHAT_C curl --data {'"recipient":"'$PK_C'", "message":"Hello Mike!"}' -X POST $CHAT_A diff --git a/pkg/app/app.go b/pkg/app/app.go index 71b07924eb..d0f40a2701 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -170,6 +170,8 @@ func (app *App) handleProto() { } func (app *App) serveConn(addr *LoopAddr, conn io.ReadWriteCloser) { + defer conn.Close() + for { buf := make([]byte, 32*1024) n, err := conn.Read(buf) @@ -183,11 +185,10 @@ func (app *App) serveConn(addr *LoopAddr, conn io.ReadWriteCloser) { } } - if app.conns[*addr] != nil { + app.mu.Lock() + if _, ok := app.conns[*addr]; ok { app.proto.Send(FrameClose, &addr, nil) // nolint: errcheck } - - app.mu.Lock() delete(app.conns, *addr) app.mu.Unlock() } @@ -251,13 +252,12 @@ func (app *App) confirmLoop(data []byte) error { type appConn struct { net.Conn - rw io.ReadWriteCloser laddr *Addr raddr *Addr } func newAppConn(conn net.Conn, laddr, raddr *Addr) *appConn { - return &appConn{conn, conn, laddr, raddr} + return &appConn{conn, laddr, raddr} } func (conn *appConn) LocalAddr() net.Addr { @@ -267,18 +267,3 @@ func (conn *appConn) LocalAddr() 
net.Addr { func (conn *appConn) RemoteAddr() net.Addr { return conn.raddr } - -func (conn *appConn) Write(p []byte) (n int, err error) { - return conn.rw.Write(p) -} - -func (conn *appConn) Read(p []byte) (n int, err error) { - return conn.rw.Read(p) -} - -func (conn *appConn) Close() error { - if conn == nil { - return nil - } - return conn.rw.Close() -} diff --git a/pkg/node/rpc.go b/pkg/node/rpc.go index b7e8ace227..1465103f87 100644 --- a/pkg/node/rpc.go +++ b/pkg/node/rpc.go @@ -55,7 +55,7 @@ func newTransportSummary(tm *transport.Manager, tp *transport.ManagedTransport, } summary := &TransportSummary{ - ID: tp.ID, + ID: tp.Entry.ID, Local: tm.Local(), Remote: remote, Type: tp.Type(), diff --git a/pkg/node/rpc_test.go b/pkg/node/rpc_test.go index feb888bc58..b8ec54c11d 100644 --- a/pkg/node/rpc_test.go +++ b/pkg/node/rpc_test.go @@ -246,7 +246,7 @@ func TestRPC(t *testing.T) { t.Run("Transport", func(t *testing.T) { var ids []uuid.UUID node.tm.WalkTransports(func(tp *transport.ManagedTransport) bool { - ids = append(ids, tp.ID) + ids = append(ids, tp.Entry.ID) return true }) diff --git a/pkg/router/router.go b/pkg/router/router.go index 74b6dde6ce..c0e9647891 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -98,6 +98,7 @@ func (r *Router) Serve(ctx context.Context) error { } go func(tp transport.Transport) { + defer tp.Close() for { if err := serve(tp); err != nil { if err != io.EOF { @@ -427,7 +428,7 @@ func (r *Router) setupProto(ctx context.Context) (*setup.Protocol, transport.Tra // TODO(evanlinjin): need string constant for tp type. 
tr, err := r.tm.CreateTransport(ctx, r.config.SetupNodes[0], dmsg.Type, false) if err != nil { - return nil, nil, fmt.Errorf("transport: %s", err) + return nil, nil, fmt.Errorf("setup transport: %s", err) } sProto := setup.NewSetupProtocol(tr) diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index 508950ea1f..93f2207657 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -85,7 +85,7 @@ func TestRouterForwarding(t *testing.T) { tr3, err := m3.CreateTransport(context.TODO(), pk2, "mock2", true) require.NoError(t, err) - rule := routing.ForwardRule(time.Now().Add(time.Hour), 4, tr3.ID) + rule := routing.ForwardRule(time.Now().Add(time.Hour), 4, tr3.Entry.ID) routeID, err := rt.AddRule(rule) require.NoError(t, err) @@ -197,9 +197,9 @@ func TestRouterApp(t *testing.T) { ni1, ni2 := noiseInstances(t, pk1, pk2, sk1, sk2) raddr := &app.Addr{PubKey: pk2, Port: 5} - require.NoError(t, r.pm.SetLoop(6, raddr, &loop{tr.ID, 4, ni1})) + require.NoError(t, r.pm.SetLoop(6, raddr, &loop{tr.Entry.ID, 4, ni1})) - tr2 := m2.Transport(tr.ID) + tr2 := m2.Transport(tr.Entry.ID) go proto.Send(app.FrameSend, &app.Packet{Addr: &app.LoopAddr{Port: 6, Remote: *raddr}, Payload: []byte("bar")}, nil) // nolint: errcheck packet := make(routing.Packet, 29) @@ -333,13 +333,13 @@ func TestRouterSetup(t *testing.T) { var routeID routing.RouteID t.Run("add route", func(t *testing.T) { - routeID, err = setup.AddRule(sProto, routing.ForwardRule(time.Now().Add(time.Hour), 2, tr.ID)) + routeID, err = setup.AddRule(sProto, routing.ForwardRule(time.Now().Add(time.Hour), 2, tr.Entry.ID)) require.NoError(t, err) rule, err := rt.Rule(routeID) require.NoError(t, err) assert.Equal(t, routing.RouteID(2), rule.RouteID()) - assert.Equal(t, tr.ID, rule.TransportID()) + assert.Equal(t, tr.Entry.ID, rule.TransportID()) }) t.Run("`confirm loop - responder", func(t *testing.T) { @@ -371,7 +371,7 @@ func TestRouterSetup(t *testing.T) { loop, err := r.pm.GetLoop(2, &app.Addr{PubKey: 
pk2, Port: 1}) require.NoError(t, err) require.NotNil(t, loop) - assert.Equal(t, tr.ID, loop.trID) + assert.Equal(t, tr.Entry.ID, loop.trID) assert.Equal(t, routing.RouteID(2), loop.routeID) addrs := [2]*app.Addr{} @@ -427,7 +427,7 @@ func TestRouterSetup(t *testing.T) { l, err := r.pm.GetLoop(2, &app.Addr{PubKey: pk2, Port: 1}) require.NoError(t, err) require.NotNil(t, l) - assert.Equal(t, tr.ID, l.trID) + assert.Equal(t, tr.Entry.ID, l.trID) assert.Equal(t, routing.RouteID(2), l.routeID) addrs := [2]*app.Addr{} diff --git a/pkg/setup/node_test.go b/pkg/setup/node_test.go index 284c70d2ce..0713fb7da1 100644 --- a/pkg/setup/node_test.go +++ b/pkg/setup/node_test.go @@ -99,12 +99,12 @@ func TestCreateLoop(t *testing.T) { l := &routing.Loop{LocalPort: 1, RemotePort: 2, Expiry: time.Now().Add(time.Hour), Forward: routing.Route{ - &routing.Hop{From: pk1, To: pk2, Transport: tr1.ID}, - &routing.Hop{From: pk2, To: pk3, Transport: tr3.ID}, + &routing.Hop{From: pk1, To: pk2, Transport: tr1.Entry.ID}, + &routing.Hop{From: pk2, To: pk3, Transport: tr3.Entry.ID}, }, Reverse: routing.Route{ - &routing.Hop{From: pk3, To: pk2, Transport: tr3.ID}, - &routing.Hop{From: pk2, To: pk1, Transport: tr1.ID}, + &routing.Hop{From: pk3, To: pk2, Transport: tr3.Entry.ID}, + &routing.Hop{From: pk2, To: pk1, Transport: tr1.Entry.ID}, }, } @@ -132,25 +132,25 @@ func TestCreateLoop(t *testing.T) { assert.Equal(t, uint16(1), rule.LocalPort()) rule = rules[2] assert.Equal(t, routing.RuleForward, rule.Type()) - assert.Equal(t, tr1.ID, rule.TransportID()) + assert.Equal(t, tr1.Entry.ID, rule.TransportID()) assert.Equal(t, routing.RouteID(2), rule.RouteID()) rules = n2.getRules() require.Len(t, rules, 2) rule = rules[1] assert.Equal(t, routing.RuleForward, rule.Type()) - assert.Equal(t, tr1.ID, rule.TransportID()) + assert.Equal(t, tr1.Entry.ID, rule.TransportID()) assert.Equal(t, routing.RouteID(1), rule.RouteID()) rule = rules[2] assert.Equal(t, routing.RuleForward, rule.Type()) - assert.Equal(t, 
tr3.ID, rule.TransportID()) + assert.Equal(t, tr3.Entry.ID, rule.TransportID()) assert.Equal(t, routing.RouteID(2), rule.RouteID()) rules = n3.getRules() require.Len(t, rules, 2) rule = rules[1] assert.Equal(t, routing.RuleForward, rule.Type()) - assert.Equal(t, tr3.ID, rule.TransportID()) + assert.Equal(t, tr3.Entry.ID, rule.TransportID()) assert.Equal(t, routing.RouteID(1), rule.RouteID()) rule = rules[2] assert.Equal(t, routing.RuleApp, rule.Type()) diff --git a/pkg/transport/log.go b/pkg/transport/log.go index 8ae0b1bc32..1983b0ab42 100644 --- a/pkg/transport/log.go +++ b/pkg/transport/log.go @@ -1,12 +1,16 @@ package transport import ( + "bytes" + "encoding/gob" "encoding/json" + "errors" "fmt" - "math/big" "os" "path/filepath" + "strconv" "sync" + "sync/atomic" "github.com/google/uuid" ) @@ -14,8 +18,55 @@ import ( // LogEntry represents a logging entry for a given Transport. // The entry is updated every time a packet is received or sent. type LogEntry struct { - ReceivedBytes *big.Int `json:"received"` // Total received bytes. - SentBytes *big.Int `json:"sent"` // Total sent bytes. + RecvBytes uint64 `json:"recv"` // Total received bytes. + SentBytes uint64 `json:"sent"` // Total sent bytes. +} + +// AddRecv records read. +func (le *LogEntry) AddRecv(n uint64) { + atomic.AddUint64(&le.RecvBytes, n) +} + +// AddSent records write. 
+func (le *LogEntry) AddSent(n uint64) { + atomic.AddUint64(&le.SentBytes, n) +} + +// MarshalJSON implements json.Marshaler +func (le *LogEntry) MarshalJSON() ([]byte, error) { + rb := strconv.FormatUint(atomic.LoadUint64(&le.RecvBytes), 10) + sb := strconv.FormatUint(atomic.LoadUint64(&le.SentBytes), 10) + return []byte(`{"recv":` + rb + `,"sent":` + sb + `}`), nil +} + +// GobEncode implements gob.GobEncoder +func (le *LogEntry) GobEncode() ([]byte, error) { + var b bytes.Buffer + enc := gob.NewEncoder(&b) + if err := enc.Encode(le.RecvBytes); err != nil { + return nil, err + } + if err := enc.Encode(le.SentBytes); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +// GobDecode implements gob.GobDecoder +func (le *LogEntry) GobDecode(b []byte) error { + r := bytes.NewReader(b) + dec := gob.NewDecoder(r) + var rb uint64 + if err := dec.Decode(&rb); err != nil { + return err + } + var sb uint64 + if err := dec.Decode(&sb); err != nil { + return err + } + atomic.StoreUint64(&le.RecvBytes, rb) + atomic.StoreUint64(&le.SentBytes, sb) + return nil } // LogStore stores transport log entries. @@ -32,14 +83,17 @@ type inMemoryTransportLogStore struct { // InMemoryTransportLogStore implements in-memory TransportLogStore. 
func InMemoryTransportLogStore() LogStore { return &inMemoryTransportLogStore{ - entries: map[uuid.UUID]*LogEntry{}, + entries: make(map[uuid.UUID]*LogEntry), } } func (tls *inMemoryTransportLogStore) Entry(id uuid.UUID) (*LogEntry, error) { tls.mu.Lock() - entry := tls.entries[id] + entry, ok := tls.entries[id] tls.mu.Unlock() + if !ok { + return entry, errors.New("transport log entry not found") + } return entry, nil } diff --git a/pkg/transport/log_test.go b/pkg/transport/log_test.go index b118f57deb..1c3f577728 100644 --- a/pkg/transport/log_test.go +++ b/pkg/transport/log_test.go @@ -1,8 +1,9 @@ package transport_test import ( + "encoding/json" + "fmt" "io/ioutil" - "math/big" "os" "testing" @@ -17,17 +18,22 @@ func testTransportLogStore(t *testing.T, logStore transport.LogStore) { t.Helper() id1 := uuid.New() - entry1 := &transport.LogEntry{big.NewInt(100), big.NewInt(200)} + entry1 := new(transport.LogEntry) + entry1.AddRecv(100) + entry1.AddSent(200) + id2 := uuid.New() - entry2 := &transport.LogEntry{big.NewInt(300), big.NewInt(400)} + entry2 := new(transport.LogEntry) + entry2.AddRecv(300) + entry2.AddSent(400) require.NoError(t, logStore.Record(id1, entry1)) require.NoError(t, logStore.Record(id2, entry2)) entry, err := logStore.Entry(id2) require.NoError(t, err) - assert.Equal(t, int64(300), entry.ReceivedBytes.Int64()) - assert.Equal(t, int64(400), entry.SentBytes.Int64()) + assert.Equal(t, uint64(300), entry.RecvBytes) + assert.Equal(t, uint64(400), entry.SentBytes) } func TestInMemoryTransportLogStore(t *testing.T) { @@ -43,3 +49,24 @@ func TestFileTransportLogStore(t *testing.T) { require.NoError(t, err) testTransportLogStore(t, ls) } + +func TestLogEntry_MarshalJSON(t *testing.T) { + entry := new(transport.LogEntry) + entry.AddSent(10) + entry.AddRecv(100) + b, err := json.Marshal(entry) + require.NoError(t, err) + fmt.Println(string(b)) + b, err = json.MarshalIndent(entry, "", "\t") + require.NoError(t, err) + fmt.Println(string(b)) +} + +func 
TestLogEntry_GobEncode(t *testing.T) { + var entry transport.LogEntry + + enc, err := entry.GobEncode() + require.NoError(t, err) + + require.NoError(t, entry.GobDecode(enc)) +} diff --git a/pkg/transport/managed_transport.go b/pkg/transport/managed_transport.go index a861b537e6..66a675fa5e 100644 --- a/pkg/transport/managed_transport.go +++ b/pkg/transport/managed_transport.go @@ -1,106 +1,108 @@ package transport import ( - "math/big" + "context" "sync" - - "github.com/google/uuid" + "time" ) +const logWriteInterval = time.Second * 3 + // ManagedTransport is a wrapper transport. It stores status and ID of // the Transport and can notify about network errors. type ManagedTransport struct { Transport - ID uuid.UUID - Public bool + Entry Entry Accepted bool + Setup bool LogEntry *LogEntry - doneChan chan struct{} - errChan chan error - mu sync.RWMutex - once sync.Once - - readLogChan chan int - writeLogChan chan int + done chan struct{} + update chan error + mu sync.RWMutex + once sync.Once } -func newManagedTransport(id uuid.UUID, tr Transport, public bool, accepted bool) *ManagedTransport { +func newManagedTransport(tr Transport, entry Entry, accepted bool) *ManagedTransport { return &ManagedTransport{ - ID: id, - Transport: tr, - Public: public, - Accepted: accepted, - doneChan: make(chan struct{}), - errChan: make(chan error), - readLogChan: make(chan int, 16), - writeLogChan: make(chan int, 16), - LogEntry: &LogEntry{new(big.Int), new(big.Int)}, + Transport: tr, + Entry: entry, + Accepted: accepted, + done: make(chan struct{}), + update: make(chan error, 16), + LogEntry: new(LogEntry), } } -// Read reads using underlying +// Read reads using underlying transport. func (tr *ManagedTransport) Read(p []byte) (n int, err error) { tr.mu.RLock() - n, err = tr.Transport.Read(p) // TODO: data race. 
- tr.mu.RUnlock() - - if err != nil { - tr.errChan <- err + n, err = tr.Transport.Read(p) + if n > 0 { + tr.LogEntry.AddRecv(uint64(n)) } - - tr.readLogChan <- n + if !tr.isClosing() { + select { + case tr.update <- err: + default: + } + } + tr.mu.RUnlock() return } -// Write writes to an underlying +// Write writes to an underlying transport. func (tr *ManagedTransport) Write(p []byte) (n int, err error) { tr.mu.RLock() n, err = tr.Transport.Write(p) - tr.mu.RUnlock() - - if err != nil { - tr.errChan <- err - return + if n > 0 { + tr.LogEntry.AddSent(uint64(n)) } - tr.writeLogChan <- n - + if !tr.isClosing() { + select { + case tr.update <- err: + default: + } + } + tr.mu.RUnlock() return } -// killWorker sends signal to Manager.manageTransport goroutine to exit -// it's safe to call it multiple times func (tr *ManagedTransport) killWorker() { tr.once.Do(func() { - close(tr.doneChan) + close(tr.done) }) } -// Close closes underlying +func (tr *ManagedTransport) killUpdate() { + tr.mu.Lock() + close(tr.update) + tr.update = nil + tr.mu.Unlock() +} + +// Close closes underlying transport and kills worker. 
func (tr *ManagedTransport) Close() error { if tr == nil { return nil } - - tr.mu.RLock() - err := tr.Transport.Close() - tr.mu.RUnlock() - tr.killWorker() - return err + return tr.Transport.Close() } func (tr *ManagedTransport) isClosing() bool { select { - case <-tr.doneChan: + case <-tr.done: return true default: return false } } -func (tr *ManagedTransport) updateTransport(newTr Transport) { +func (tr *ManagedTransport) updateTransport(ctx context.Context, newTr Transport, dc DiscoveryClient) error { tr.mu.Lock() tr.Transport = newTr + _, err := dc.UpdateStatuses(ctx, &Status{ID: tr.Entry.ID, IsUp: true, Updated: time.Now().UnixNano()}) tr.mu.Unlock() + return err } diff --git a/pkg/transport/manager.go b/pkg/transport/manager.go index 599f18a283..8614c39615 100644 --- a/pkg/transport/manager.go +++ b/pkg/transport/manager.go @@ -4,7 +4,6 @@ import ( "context" "errors" "io" - "math/big" "strings" "sync" "sync/atomic" @@ -210,8 +209,8 @@ func (tm *Manager) CreateTransport(ctx context.Context, remote cipher.PubKey, tp func (tm *Manager) DeleteTransport(id uuid.UUID) error { tm.mu.Lock() if tr, ok := tm.transports[id]; ok { - delete(tm.transports, id) _ = tr.Close() //nolint:errcheck + delete(tm.transports, id) } tm.mu.Unlock() @@ -235,10 +234,10 @@ func (tm *Manager) Close() error { tm.mu.Lock() statuses := make([]*Status, 0) for _, tr := range tm.transports { - if !tr.Public { + if !tr.Entry.Public { continue } - statuses = append(statuses, &Status{ID: tr.ID, IsUp: false}) + statuses = append(statuses, &Status{ID: tr.Entry.ID, IsUp: false}) tr.Close() } @@ -292,7 +291,7 @@ func (tm *Manager) createTransport(ctx context.Context, remote cipher.PubKey, tp } tm.Logger.Infof("Dialed to %s using %s factory. 
Transport ID: %s", remote, tpType, entry.ID) - mTr := newManagedTransport(entry.ID, tr, entry.Public, false) + mTr := newManagedTransport(tr, *entry, false) tm.mu.Lock() tm.transports[entry.ID] = mTr @@ -302,7 +301,7 @@ func (tm *Manager) createTransport(ctx context.Context, remote cipher.PubKey, tp case <-tm.doneChan: return nil, io.ErrClosedPipe case tm.TrChan <- mTr: - go tm.manageTransport(ctx, mTr, factory, remote, public, false) + go tm.manageTransport(ctx, mTr, factory, remote) return mTr, nil } } @@ -334,7 +333,8 @@ func (tm *Manager) acceptTransport(ctx context.Context, factory Factory) (*Manag if oldTr != nil { oldTr.killWorker() } - mTr := newManagedTransport(entry.ID, tr, entry.Public, true) + + mTr := newManagedTransport(tr, *entry, true) tm.mu.Lock() tm.transports[entry.ID] = mTr @@ -344,7 +344,7 @@ func (tm *Manager) acceptTransport(ctx context.Context, factory Factory) (*Manag case <-tm.doneChan: return nil, io.ErrClosedPipe case tm.TrChan <- mTr: - go tm.manageTransport(ctx, mTr, factory, remote, true, true) + go tm.manageTransport(ctx, mTr, factory, remote) return mTr, nil } } @@ -374,47 +374,67 @@ func (tm *Manager) isClosing() bool { } } -func (tm *Manager) manageTransport(ctx context.Context, mTr *ManagedTransport, factory Factory, remote cipher.PubKey, public bool, accepted bool) { +func (tm *Manager) manageTransport(ctx context.Context, mTr *ManagedTransport, factory Factory, remote cipher.PubKey) { + logTicker := time.NewTicker(logWriteInterval) + logUpdate := false + mgrQty := atomic.AddInt32(&tm.mgrQty, 1) - tm.Logger.Infof("Spawned manageTransport for mTr.ID: %v. mgrQty: %v", mTr.ID, mgrQty) + tm.Logger.Infof("Spawned manageTransport for mTr.ID: %v. 
mgrQty: %v PK: %s", mTr.Entry.ID, mgrQty, remote) + + defer func() { + logTicker.Stop() + if logUpdate { + if err := tm.config.LogStore.Record(mTr.Entry.ID, mTr.LogEntry); err != nil { + tm.Logger.Warnf("Failed to record log entry: %s", err) + } + } + mTr.killUpdate() + + mgrQty := atomic.AddInt32(&tm.mgrQty, -1) + tm.Logger.Infof("manageTransport exit for %v. mgrQty: %v", mTr.Entry.ID, mgrQty) + }() + for { select { - case <-mTr.doneChan: - mgrQty := atomic.AddInt32(&tm.mgrQty, -1) - tm.Logger.Infof("manageTransport exit for %v. mgrQty: %v", mTr.ID, mgrQty) + case <-mTr.done: return - case err := <-mTr.errChan: - if !mTr.isClosing() { - tm.Logger.Infof("Transport %s failed with error: %s. Re-dialing...", mTr.ID, err) - if accepted { - if err := tm.DeleteTransport(mTr.ID); err != nil { - tm.Logger.Warnf("Failed to delete accepted transport: %s", err) - } - } else { - tr, _, err := tm.dialTransport(ctx, factory, remote, public) - if err != nil { - tm.Logger.Infof("Failed to redial Transport %s: %s", mTr.ID, err) - if err := tm.DeleteTransport(mTr.ID); err != nil { - tm.Logger.Warnf("Failed to delete redialed transport: %s", err) - } - } else { - tm.Logger.Infof("Updating transport %s", mTr.ID) - mTr.updateTransport(tr) - } + + case <-logTicker.C: + if logUpdate { + if err := tm.config.LogStore.Record(mTr.Entry.ID, mTr.LogEntry); err != nil { + tm.Logger.Warnf("Failed to record log entry: %s", err) } - } else { - tm.Logger.Infof("Transport %s is already closing. 
Skipped error: %s", mTr.ID, err) } - case n := <-mTr.readLogChan: - mTr.LogEntry.ReceivedBytes.Add(mTr.LogEntry.ReceivedBytes, big.NewInt(int64(n))) - if err := tm.config.LogStore.Record(mTr.ID, mTr.LogEntry); err != nil { - tm.Logger.Warnf("Failed to record log entry: %s", err) + + case err, ok := <-mTr.update: + if !ok { + return } - case n := <-mTr.writeLogChan: - mTr.LogEntry.SentBytes.Add(mTr.LogEntry.SentBytes, big.NewInt(int64(n))) - if err := tm.config.LogStore.Record(mTr.ID, mTr.LogEntry); err != nil { - tm.Logger.Warnf("Failed to record log entry: %s", err) + + if err == nil { + logUpdate = true + continue + } + + tm.Logger.Infof("Transport %s failed with error: %s. Re-dialing...", mTr.Entry.ID, err) + if _, err := tm.config.DiscoveryClient.UpdateStatuses(ctx, &Status{ID: mTr.Entry.ID, IsUp: false, Updated: time.Now().UnixNano()}); err != nil { + tm.Logger.Warnf("Failed to change transport status: %s", err) + } + + // If we are the acceptor, we are not responsible for restarting transport. + // If the transport is private, we don't need to restart. 
+ if mTr.Accepted || !mTr.Entry.Public { + return } + + tr, _, err := tm.dialTransport(ctx, factory, remote, mTr.Entry.Public) + if err != nil { + tm.Logger.Infof("Failed to redial Transport %s: %s", mTr.Entry.ID, err) + continue + } + + tm.Logger.Infof("Updating transport %s", mTr.Entry.ID) + _ = mTr.updateTransport(ctx, tr, tm.config.DiscoveryClient) //nolint:errcheck } } } diff --git a/pkg/transport/manager_test.go b/pkg/transport/manager_test.go index 0994271e16..4f0b30f729 100644 --- a/pkg/transport/manager_test.go +++ b/pkg/transport/manager_test.go @@ -87,16 +87,16 @@ func TestTransportManager(t *testing.T) { time.Sleep(time.Second) - tr1 := m1.Transport(tr2.ID) + tr1 := m1.Transport(tr2.Entry.ID) require.NotNil(t, tr1) - dEntry, err := client.GetTransportByID(context.TODO(), tr2.ID) + dEntry, err := client.GetTransportByID(context.TODO(), tr2.Entry.ID) require.NoError(t, err) assert.Equal(t, SortPubKeys(pk2, pk1), dEntry.Entry.Edges()) assert.True(t, dEntry.IsUp) - require.NoError(t, m1.DeleteTransport(tr1.ID)) - dEntry, err = client.GetTransportByID(context.TODO(), tr1.ID) + require.NoError(t, m1.DeleteTransport(tr1.Entry.ID)) + dEntry, err = client.GetTransportByID(context.TODO(), tr1.Entry.ID) require.NoError(t, err) assert.False(t, dEntry.IsUp) @@ -106,12 +106,12 @@ func TestTransportManager(t *testing.T) { time.Sleep(time.Second) - dEntry, err = client.GetTransportByID(context.TODO(), tr1.ID) + dEntry, err = client.GetTransportByID(context.TODO(), tr1.Entry.ID) require.NoError(t, err) assert.True(t, dEntry.IsUp) - require.NoError(t, m2.DeleteTransport(tr2.ID)) - dEntry, err = client.GetTransportByID(context.TODO(), tr2.ID) + require.NoError(t, m2.DeleteTransport(tr2.Entry.ID)) + dEntry, err = client.GetTransportByID(context.TODO(), tr2.Entry.ID) require.NoError(t, err) assert.False(t, dEntry.IsUp) @@ -153,17 +153,17 @@ func TestTransportManagerReEstablishTransports(t *testing.T) { tr2, err := m2.CreateTransport(context.TODO(), pk1, "mock", true) 
require.NoError(t, err) - tr1 := m1.Transport(tr2.ID) + tr1 := m1.Transport(tr2.Entry.ID) require.NotNil(t, tr1) - dEntry, err := client.GetTransportByID(context.TODO(), tr2.ID) + dEntry, err := client.GetTransportByID(context.TODO(), tr2.Entry.ID) require.NoError(t, err) assert.Equal(t, SortPubKeys(pk2, pk1), dEntry.Entry.Edges()) assert.True(t, dEntry.IsUp) require.NoError(t, m2.Close()) - dEntry2, err := client.GetTransportByID(context.TODO(), tr2.ID) + dEntry2, err := client.GetTransportByID(context.TODO(), tr2.Entry.ID) require.NoError(t, err) assert.False(t, dEntry2.IsUp) @@ -176,7 +176,7 @@ func TestTransportManagerReEstablishTransports(t *testing.T) { go func() { m2errCh <- m2.Serve(context.TODO()) }() //time.Sleep(time.Second * 1) // TODO: this time.Sleep looks fishy - figure out later - dEntry3, err := client.GetTransportByID(context.TODO(), tr2.ID) + dEntry3, err := client.GetTransportByID(context.TODO(), tr2.Entry.ID) require.NoError(t, err) assert.True(t, dEntry3.IsUp) @@ -218,7 +218,7 @@ func TestTransportManagerLogs(t *testing.T) { time.Sleep(100 * time.Millisecond) - tr1 := m1.Transport(tr2.ID) + tr1 := m1.Transport(tr2.Entry.ID) require.NotNil(t, tr1) go tr1.Write([]byte("foo")) // nolint @@ -226,17 +226,18 @@ func TestTransportManagerLogs(t *testing.T) { _, err = tr2.Read(buf) require.NoError(t, err) - time.Sleep(100 * time.Millisecond) + // 2x log write interval just to be safe. 
+ time.Sleep(logWriteInterval * 2) - entry1, err := logStore1.Entry(tr1.ID) + entry1, err := logStore1.Entry(tr1.Entry.ID) require.NoError(t, err) - assert.Equal(t, uint64(3), entry1.SentBytes.Uint64()) - assert.Equal(t, uint64(0), entry1.ReceivedBytes.Uint64()) + assert.Equal(t, uint64(3), entry1.SentBytes) + assert.Equal(t, uint64(0), entry1.RecvBytes) - entry2, err := logStore2.Entry(tr1.ID) + entry2, err := logStore2.Entry(tr1.Entry.ID) require.NoError(t, err) - assert.Equal(t, uint64(0), entry2.SentBytes.Uint64()) - assert.Equal(t, uint64(3), entry2.ReceivedBytes.Uint64()) + assert.Equal(t, uint64(0), entry2.SentBytes) + assert.Equal(t, uint64(3), entry2.RecvBytes) require.NoError(t, m2.Close()) require.NoError(t, m1.Close()) @@ -314,7 +315,7 @@ func ExampleManager_CreateTransport() { return } - if (mtrAB.ID == uuid.UUID{}) { + if (mtrAB.Entry.ID == uuid.UUID{}) { fmt.Printf("Manager.CreateTransport failed on iteration %v", i) return }