From 985f0566ed1110b747888ab347052abf1b82d29f Mon Sep 17 00:00:00 2001 From: Evan Lin Date: Tue, 13 Aug 2019 07:43:29 +0800 Subject: [PATCH] Use a single underlying connection for transports. --- pkg/router/route_manager.go | 4 + pkg/router/router.go | 4 +- pkg/routing/rule.go | 2 +- pkg/transport/managed_transport.go | 204 ++++++++++++++--------------- pkg/transport/manager.go | 13 +- pkg/transport/manager_test.go | 4 + vendor/modules.txt | 2 +- 7 files changed, 118 insertions(+), 115 deletions(-) diff --git a/pkg/router/route_manager.go b/pkg/router/route_manager.go index d944c56efa..614a568f1d 100644 --- a/pkg/router/route_manager.go +++ b/pkg/router/route_manager.go @@ -35,6 +35,10 @@ func (rm *routeManager) GetRule(routeID routing.RouteID) (routing.Rule, error) { return nil, errors.New("unknown RouteID") } + if len(rule) < 13 { + return nil, errors.New("corrupted rule") + } + if rule.Expiry().Before(time.Now()) { return nil, errors.New("expired routing rule") } diff --git a/pkg/router/router.go b/pkg/router/router.go index 248edd6f5b..c2556d9060 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -87,7 +87,7 @@ func (r *Router) Serve(ctx context.Context) error { if err != nil { return } - if err := r.handlePacket(ctx, packet); err != nil { + if err := r.handlePacket(ctx, packet); err != nil { // TODO: race if err == transport.ErrNotServing { r.Logger.WithError(err).Warnf("Stopped serving Transport.") return @@ -214,7 +214,7 @@ func (r *Router) forwardPacket(ctx context.Context, payload []byte, rule routing if err := tp.WritePacket(ctx, rule.RouteID(), payload); err != nil { return err } - r.Logger.Infof("Forwarded packet via Transport %s using rule %d", rule.TransportID(), rule.RouteID()) + r.Logger.Infof("Forwarded packet via Transport %s using rule %d", rule.TransportID(), rule.RouteID()) // TODO: race TransportID() return nil } diff --git a/pkg/routing/rule.go b/pkg/routing/rule.go index a9a6012e6a..9893e3a5e3 100644 --- a/pkg/routing/rule.go +++ b/pkg/routing/rule.go @@ -43,7 +43,7 @@ func (r Rule) Expiry() time.Time { } // Type returns type of a rule. -func (r Rule) Type() RuleType { +func (r Rule) Type() RuleType { // TODO: segfault return RuleType(r[8]) } diff --git a/pkg/transport/managed_transport.go b/pkg/transport/managed_transport.go index f0c636587f..84654371a3 100644 --- a/pkg/transport/managed_transport.go +++ b/pkg/transport/managed_transport.go @@ -21,8 +21,13 @@ const logWriteInterval = time.Second * 3 // Records number of managedTransports. var mTpCount int32 -// ErrNotServing is the error returned when a transport is no longer served. -var ErrNotServing = errors.New("transport is no longer being served") +var ( + // ErrNotServing is the error returned when a transport is no longer served. + ErrNotServing = errors.New("transport is no longer being served") + + // ErrConnAlreadyExists occurs when an underlying transport connection already exists. + ErrConnAlreadyExists = errors.New("underlying transport connection already exists") +) // ManagedTransport manages a direct line of communication between two visor nodes. // It is made up of two underlying uni-directional connections. @@ -40,11 +45,9 @@ type ManagedTransport struct { LogEntry *LogEntry logUpdates uint32 - readConn Transport - writeConn Transport - acceptCh chan Transport - acceptMx sync.RWMutex - dialMx sync.Mutex + conn Transport + connCh chan struct{} + connMx sync.Mutex done chan struct{} once sync.Once @@ -62,7 +65,7 @@ func NewManagedTransport(fac Factory, dc DiscoveryClient, ls LogStore, rPK ciphe ls: ls, Entry: makeEntry(fac.Local(), rPK, dmsg.Type), LogEntry: new(LogEntry), - acceptCh: make(chan Transport, 1), + connCh: make(chan struct{}, 1), done: make(chan struct{}), } mt.wg.Add(2) @@ -70,7 +73,7 @@ func NewManagedTransport(fac Factory, dc DiscoveryClient, ls LogStore, rPK ciphe } // Serve serves and manages the transport. -func (mt *ManagedTransport) Serve(readCh chan<- routing.Packet) { +func (mt *ManagedTransport) Serve(readCh chan<- routing.Packet, done <-chan struct{}) { defer mt.wg.Done() ctx, cancel := context.WithCancel(context.Background()) @@ -93,23 +96,14 @@ func (mt *ManagedTransport) Serve(readCh chan<- routing.Packet) { } } - // End reading connection. - mt.acceptMx.Lock() - close(mt.acceptCh) - mt.acceptCh = nil - if mt.readConn != nil { - _ = mt.readConn.Close() //nolint:errcheck - mt.readConn = nil - } - mt.acceptMx.Unlock() - - // End writing connection. - mt.dialMx.Lock() - if mt.writeConn != nil { - _ = mt.writeConn.Close() //nolint:errcheck - mt.writeConn = nil + // End connection. + mt.connMx.Lock() + close(mt.connCh) + if mt.conn != nil { + _ = mt.conn.Close() //nolint:errcheck + mt.conn = nil } - mt.dialMx.Unlock() + mt.connMx.Unlock() }() go func() { @@ -123,13 +117,16 @@ func (mt *ManagedTransport) Serve(readCh chan<- routing.Packet) { if err == ErrNotServing { return } + mt.connMx.Lock() + mt.clearConn(ctx) + mt.connMx.Unlock() mt.log.Warnf("failed to read packet: %v", err) continue } - if !mt.isServing() { - return + select { + case <-done: + case readCh <- p: } - readCh <- p // TODO: data race } }() @@ -145,16 +142,13 @@ func (mt *ManagedTransport) Serve(readCh chan<- routing.Packet) { } } else { // If there has not been any activity, ensure underlying 'write' tp is still up. - mt.dialMx.Lock() - if mt.writeConn == nil { - if !mt.isServing() { - return - } + mt.connMx.Lock() + if mt.conn == nil { if err := mt.dial(ctx); err != nil { - mt.log.Warnf("failed to dial underlying 'write' transport: %v", err) + mt.log.Warnf("failed to redial underlying connection: %v", err) } } - mt.dialMx.Unlock() + mt.connMx.Unlock() } } } @@ -188,10 +182,10 @@ func (mt *ManagedTransport) close() (closed bool) { return closed } -// Accept accepts a new underlying 'read' connection (and close/replace the old one). +// Accept accepts a new underlying connection. func (mt *ManagedTransport) Accept(ctx context.Context, tp Transport) error { - mt.acceptMx.RLock() - defer mt.acceptMx.RUnlock() + mt.connMx.Lock() + defer mt.connMx.Unlock() if !mt.isServing() { _ = tp.Close() //nolint:errcheck @@ -204,125 +198,127 @@ func (mt *ManagedTransport) Accept(ctx context.Context, tp Transport) error { return fmt.Errorf("settlement handshake failed: %v", err) } - for { - select { - case oldTp, ok := <-mt.acceptCh: - if !ok { - return ErrNotServing - } - _ = oldTp.Close() //nolint:errcheck - default: - mt.acceptCh <- tp - return nil - } - } + return mt.setIfConnNil(ctx, tp) } -// Dial dials a new underlying 'write' connection (and close/replace the old one). +// Dial dials a new underlying connection. func (mt *ManagedTransport) Dial(ctx context.Context) error { - mt.dialMx.Lock() - defer mt.dialMx.Unlock() + mt.connMx.Lock() + defer mt.connMx.Unlock() if !mt.isServing() { return ErrNotServing } - if mt.writeConn != nil { - _ = mt.writeConn.Close() //nolint:errcheck + if mt.conn != nil { + return nil } return mt.dial(ctx) } +// TODO: Figure out where this fella is called. func (mt *ManagedTransport) dial(ctx context.Context) error { tp, err := mt.fac.Dial(ctx, mt.rPK) if err != nil { return err } + ctx, cancel := context.WithTimeout(ctx, time.Second*20) defer cancel() if err := MakeSettlementHS(true).Do(ctx, mt.dc, tp, mt.lSK); err != nil { return fmt.Errorf("settlement handshake failed: %v", err) } - mt.writeConn = tp + + return mt.setIfConnNil(ctx, tp) +} + +func (mt *ManagedTransport) getConn() Transport { + mt.connMx.Lock() + conn := mt.conn + mt.connMx.Unlock() + return conn +} + +// sets conn if `mt.conn` is nil otherwise, closes the conn. +// TODO: Add logging here. +func (mt *ManagedTransport) setIfConnNil(ctx context.Context, conn Transport) error { + if mt.conn != nil { + _ = conn.Close() //nolint:errcheck + return ErrConnAlreadyExists + } + + if _, err := mt.dc.UpdateStatuses(ctx, &Status{ID: mt.Entry.ID, IsUp: true}); err != nil { + mt.log.Warnf("Failed to update transport status: %s", err) + } + mt.log.Infoln("Status updated: UP") + mt.conn = conn + select { + case mt.connCh <- struct{}{}: + default: + } return nil } +func (mt *ManagedTransport) clearConn(ctx context.Context) { + if _, err := mt.dc.UpdateStatuses(ctx, &Status{ID: mt.Entry.ID, IsUp: false}); err != nil { + mt.log.Warnf("Failed to update transport status: %s", err) + } + mt.log.Infoln("Status updated: DOWN") + mt.conn = nil +} + // WritePacket writes a packet to the remote. -func (mt *ManagedTransport) WritePacket(ctx context.Context, rtID routing.RouteID, payload []byte) (err error) { - mt.dialMx.Lock() - defer mt.dialMx.Unlock() +func (mt *ManagedTransport) WritePacket(ctx context.Context, rtID routing.RouteID, payload []byte) error { + mt.connMx.Lock() + defer mt.connMx.Unlock() if !mt.isServing() { return ErrNotServing } - if mt.writeConn == nil { // TODO: race condition + if mt.conn == nil { if err := mt.dial(ctx); err != nil { - return fmt.Errorf("failed to redial transport: %v", err) + return fmt.Errorf("failed to redial underlying connection: %v", err) } } - n, err := mt.writeConn.Write(routing.MakePacket(rtID, payload)) + n, err := mt.conn.Write(routing.MakePacket(rtID, payload)) if err != nil { - if _, err := mt.dc.UpdateStatuses(context.Background(), &Status{ID: mt.Entry.ID, IsUp: false}); err != nil { - mt.log.Warnf("Failed to change transport status: %s", err) - } - mt.writeConn = nil + mt.clearConn(ctx) return err } - if n > 0 { - mt.logSent(uint64(len(payload))) + if n > 6 { + mt.logSent(uint64(n - 6)) } return nil } -func (mt *ManagedTransport) latestReadTp() (Transport, error) { - mt.acceptMx.RLock() - defer mt.acceptMx.RUnlock() - - if mt.readConn != nil { - return mt.readConn, nil - } - - select { - case <-mt.done: - return nil, ErrNotServing - - case tp, ok := <-mt.acceptCh: - if !ok { - return nil, ErrNotServing - } - mt.readConn = tp - return mt.readConn, nil - } -} - // WARNING: Not thread safe. func (mt *ManagedTransport) readPacket() (packet routing.Packet, err error) { - tp, err := mt.latestReadTp() - if err != nil { - return nil, err - } - - defer func() { - if err != nil && mt.isServing() { - mt.acceptMx.RLock() - mt.readConn = nil - mt.acceptMx.RUnlock() + var conn Transport + for { + if conn = mt.getConn(); conn != nil { + break } - }() + select { + case <-mt.done: + return nil, ErrNotServing + case <-mt.connCh: + } + } h := make(routing.Packet, 6) - if _, err := io.ReadFull(tp, h); err != nil { + if _, err = io.ReadFull(conn, h); err != nil { return nil, err } - p := make([]byte, h.Size()) - if _, err := io.ReadFull(tp, p); err != nil { + if _, err = io.ReadFull(conn, p); err != nil { return nil, err } packet = append(h, p...) - mt.logRecv(uint64(len(p))) + if n := len(packet); n > 6 { + mt.logRecv(uint64(n - 6)) + } mt.log.Infof("recv packet: rtID(%d) size(%d)", packet.RouteID(), packet.Size()) return packet, nil } diff --git a/pkg/transport/manager.go b/pkg/transport/manager.go index e7645ab7e7..cc17d0ef94 100644 --- a/pkg/transport/manager.go +++ b/pkg/transport/manager.go @@ -43,6 +43,8 @@ func NewManager(config *ManagerConfig, factories ...Factory) (*Manager, error) { log := logging.MustGetLogger("tp_manager") ctx := context.Background() + done := make(chan struct{}) + fMap := make(map[string]Factory) for _, factory := range factories { fMap[factory.Type()] = factory @@ -62,7 +64,7 @@ func NewManager(config *ManagerConfig, factories ...Factory) (*Manager, error) { continue } mTp := NewManagedTransport(fac, config.DiscoveryClient, config.LogStore, entry.Entry.RemoteEdge(config.PubKey), config.SecKey) - go mTp.Serve(rCh) + go mTp.Serve(rCh, done) tpMap[entry.Entry.ID] = mTp } @@ -73,7 +75,7 @@ func NewManager(config *ManagerConfig, factories ...Factory) (*Manager, error) { tps: tpMap, setupCh: make(chan Transport, 9), // TODO: eliminate or justify buffering here readCh: rCh, - done: make(chan struct{}), + done: done, }, nil } @@ -159,10 +161,7 @@ func (tm *Manager) acceptTransport(ctx context.Context, factory Factory) error { if err := mTp.Accept(ctx, tr); err != nil { return err } - if err := mTp.Dial(ctx); err != nil { - return err - } - go mTp.Serve(tm.readCh) + go mTp.Serve(tm.readCh, tm.done) tm.tps[tpID] = mTp } else { @@ -199,7 +198,7 @@ func (tm *Manager) SaveTransport(ctx context.Context, remote cipher.PubKey, tpTy if err := mTp.Dial(ctx); err != nil { tm.Logger.Warnf("underlying 'write' tp failed, will retry: %v", err) } - go mTp.Serve(tm.readCh) + go mTp.Serve(tm.readCh, tm.done) tm.tps[tpID] = mTp tm.Logger.Infof("saved transport: remote(%s) type(%s) tpID(%s)", remote, tpType, tpID) diff --git a/pkg/transport/manager_test.go b/pkg/transport/manager_test.go index 42f4caf48b..2275bff987 100644 --- a/pkg/transport/manager_test.go +++ b/pkg/transport/manager_test.go @@ -58,6 +58,7 @@ func TestNewManager(t *testing.T) { require.NoError(t, m1.Close()) require.NoError(t, <-m1Err) }() + fmt.Println("tp manager 1 prepared") // Prepare tp manager 2. pk2, sk2 := keys[1].PK, keys[1].SK @@ -75,6 +76,7 @@ func TestNewManager(t *testing.T) { require.NoError(t, m2.Close()) require.NoError(t, <-m2Err) }() + fmt.Println("tp manager 2 prepared") // Create data transport between manager 1 & manager 2. tp2, err := m2.SaveTransport(context.TODO(), pk1, "dmsg") @@ -82,6 +84,8 @@ func TestNewManager(t *testing.T) { tp1 := m1.Transport(transport.MakeTransportID(pk1, pk2, "dmsg")) require.NotNil(t, tp1) + fmt.Println("transports created") + totalSent2 := 0 totalSent1 := 0 diff --git a/vendor/modules.txt b/vendor/modules.txt index 04d03828e2..0b9ef24e79 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -62,7 +62,7 @@ github.com/prometheus/procfs/internal/fs # github.com/sirupsen/logrus v1.4.2 github.com/sirupsen/logrus github.com/sirupsen/logrus/hooks/syslog -# github.com/skycoin/dmsg v0.0.0-20190805065636-70f4c32a994f => ../dmsg +# github.com/skycoin/dmsg v0.0.0-20190805065636-70f4c32a994f github.com/skycoin/dmsg/cipher github.com/skycoin/dmsg github.com/skycoin/dmsg/disc