diff --git a/pkg/router/router.go b/pkg/router/router.go index c891a8b1e1..7606ce4cdc 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -434,7 +434,7 @@ func (r *Router) setupProto(ctx context.Context) (*setup.Protocol, transport.Tra } // TODO(evanlinjin): need string constant for tp type. - tr, err := r.tm.CreateTransport(ctx, r.config.SetupNodes[0], dmsg.Type, false) + tr, err := r.tm.CreateSetupTransport(ctx, r.config.SetupNodes[0], dmsg.Type, false) if err != nil { return nil, nil, fmt.Errorf("setup transport: %s", err) } diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index ec63d7821c..d8897dcb33 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -78,10 +78,10 @@ func TestRouterForwarding(t *testing.T) { errCh <- r.Serve(context.TODO()) }() - tr1, err := m1.CreateTransport(context.TODO(), pk2, "mock", true) + tr1, err := m1.CreateDataTransport(context.TODO(), pk2, "mock", true) require.NoError(t, err) - tr3, err := m3.CreateTransport(context.TODO(), pk2, "mock2", true) + tr3, err := m3.CreateDataTransport(context.TODO(), pk2, "mock2", true) require.NoError(t, err) rule := routing.ForwardRule(time.Now().Add(time.Hour), 4, tr3.Entry.ID) @@ -201,7 +201,7 @@ func TestRouterApp(t *testing.T) { time.Sleep(100 * time.Millisecond) - tr, err := m1.CreateTransport(context.TODO(), pk2, "mock", true) + tr, err := m1.CreateDataTransport(context.TODO(), pk2, "mock", true) require.NoError(t, err) rule := routing.AppRule(time.Now().Add(time.Hour), 4, pk2, 5, 6) @@ -354,7 +354,7 @@ func TestRouterSetup(t *testing.T) { errCh <- r.Serve(context.TODO()) }() - tr, err := m2.CreateTransport(context.TODO(), pk1, "mock", false) + tr, err := m2.CreateDataTransport(context.TODO(), pk1, "mock", false) require.NoError(t, err) sProto := setup.NewSetupProtocol(tr) diff --git a/pkg/transport/manager.go b/pkg/transport/manager.go index c9176b3277..3970ca9af6 100644 --- a/pkg/transport/manager.go +++ b/pkg/transport/manager.go @@ -135,7 +135,7 @@ func (tm *Manager) reconnectTransports(ctx context.Context) { continue } - _, err := tm.createTransport(ctx, remote, entry.Type, entry.Public) + _, err := tm.CreateDataTransport(ctx, remote, entry.Type, entry.Public) if err != nil { tm.Logger.Warnf("Failed to re-establish transport: %s", err) continue @@ -180,7 +180,7 @@ func (tm *Manager) createDefaultTransports(ctx context.Context) { if exist { continue } - _, err := tm.CreateTransport(ctx, pk, "messaging", true) + _, err := tm.CreateDataTransport(ctx, pk, "messaging", true) if err != nil { tm.Logger.Warnf("Failed to establish transport to a node %s: %s", pk, err) } @@ -224,9 +224,64 @@ func (tm *Manager) Serve(ctx context.Context) error { return nil } -// CreateTransport begins to attempt to establish transports to the given 'remote' node. -func (tm *Manager) CreateTransport(ctx context.Context, remote cipher.PubKey, tpType string, public bool) (*ManagedTransport, error) { - return tm.createTransport(ctx, remote, tpType, public) +// CreateSetupTransport begins to attempt to establish setup transports to the given 'remote' node. +func (tm *Manager) CreateSetupTransport(ctx context.Context, remote cipher.PubKey, tpType string, public bool) (Transport, error) { + factory := tm.factories[tpType] + if factory == nil { + return nil, errors.New("unknown transport type") + } + + tr, entry, err := tm.dialTransport(ctx, factory, remote, public) + if err != nil { + return nil, err + } + + oldTr := tm.Transport(entry.ID) + if oldTr != nil { + oldTr.killWorker() + } + + tm.Logger.Infof("Dialed to %s using %s factory. Transport ID: %s", remote, tpType, entry.ID) + + select { + case <-tm.doneChan: + return nil, io.ErrClosedPipe + case tm.SetupTpChan <- tr: + return tr, nil + } +} + +// CreateDataTransport begins to attempt to establish data transports to the given 'remote' node. +func (tm *Manager) CreateDataTransport(ctx context.Context, remote cipher.PubKey, tpType string, public bool) (*ManagedTransport, error) { + factory := tm.factories[tpType] + if factory == nil { + return nil, errors.New("unknown transport type") + } + + tr, entry, err := tm.dialTransport(ctx, factory, remote, public) + if err != nil { + return nil, err + } + + oldTr := tm.Transport(entry.ID) + if oldTr != nil { + oldTr.killWorker() + } + + tm.Logger.Infof("Dialed to %s using %s factory. Transport ID: %s", remote, tpType, entry.ID) + mTr := newManagedTransport(tr, *entry, false) + + tm.mu.Lock() + tm.transports[entry.ID] = mTr + tm.mu.Unlock() + + select { + case <-tm.doneChan: + return nil, io.ErrClosedPipe + case tm.DataTpChan <- mTr: + go tm.manageTransport(ctx, mTr, factory, remote) + return mTr, nil + } } // DeleteTransport disconnects and removes the Transport of Transport ID. @@ -316,48 +371,6 @@ func (tm *Manager) dialTransport(ctx context.Context, factory Factory, remote ci return tr, entry, nil } -func (tm *Manager) createTransport(ctx context.Context, remote cipher.PubKey, tpType string, public bool) (*ManagedTransport, error) { - factory := tm.factories[tpType] - if factory == nil { - return nil, errors.New("unknown transport type") - } - - tr, entry, err := tm.dialTransport(ctx, factory, remote, public) - if err != nil { - return nil, err - } - - oldTr := tm.Transport(entry.ID) - if oldTr != nil { - oldTr.killWorker() - } - - tm.Logger.Infof("Dialed to %s using %s factory. Transport ID: %s", remote, tpType, entry.ID) - mTr := newManagedTransport(tr, *entry, false) - - tm.mu.Lock() - tm.transports[entry.ID] = mTr - tm.mu.Unlock() - - if tm.IsSetupTransport(tr) { - select { - case <-tm.doneChan: - return nil, io.ErrClosedPipe - case tm.SetupTpChan <- mTr: - go tm.manageTransport(ctx, mTr, factory, remote) - return mTr, nil - } - } else { - select { - case <-tm.doneChan: - return nil, io.ErrClosedPipe - case tm.DataTpChan <- mTr: - go tm.manageTransport(ctx, mTr, factory, remote) - return mTr, nil - } - } -} - func (tm *Manager) acceptTransport(ctx context.Context, factory Factory) (*ManagedTransport, error) { tr, err := factory.Accept(ctx) if err != nil { diff --git a/pkg/transport/manager_test.go b/pkg/transport/manager_test.go index 07748b4d78..1463e38ec4 100644 --- a/pkg/transport/manager_test.go +++ b/pkg/transport/manager_test.go @@ -83,7 +83,7 @@ func TestTransportManager(t *testing.T) { } }() - tr2, err := m2.CreateTransport(context.TODO(), pk1, "mock", true) + tr2, err := m2.CreateDataTransport(context.TODO(), pk1, "mock", true) require.NoError(t, err) time.Sleep(time.Second) @@ -151,7 +151,7 @@ func TestTransportManagerReEstablishTransports(t *testing.T) { m2, err := NewManager(c2, f2) require.NoError(t, err) - tr2, err := m2.CreateTransport(context.TODO(), pk1, "mock", true) + tr2, err := m2.CreateDataTransport(context.TODO(), pk1, "mock", true) require.NoError(t, err) tr1 := m1.Transport(tr2.Entry.ID) @@ -214,7 +214,7 @@ func TestTransportManagerLogs(t *testing.T) { m2, err := NewManager(c2, f2) require.NoError(t, err) - tr2, err := m2.CreateTransport(context.TODO(), pk1, "mock", true) + tr2, err := m2.CreateDataTransport(context.TODO(), pk1, "mock", true) require.NoError(t, err) time.Sleep(100 * time.Millisecond) @@ -315,19 +315,19 @@ func ExampleManager_CreateTransport() { return } - mtrAB, err := mgrA.CreateTransport(context.TODO(), pkB, "mock", true) + mtrAB, err := mgrA.CreateDataTransport(context.TODO(), pkB, "mock", true) if err != nil { - fmt.Printf("Manager.CreateTransport failed on iteration %v with: %v\n", i, err) + fmt.Printf("Manager.CreateDataTransport failed on iteration %v with: %v\n", i, err) return } if (mtrAB.Entry.ID == uuid.UUID{}) { - fmt.Printf("Manager.CreateTransport failed on iteration %v", i) + fmt.Printf("Manager.CreateDataTransport failed on iteration %v", i) return } } - fmt.Println("Manager.CreateTransport success") + fmt.Println("Manager.CreateDataTransport success") - // Output: Manager.CreateTransport success + // Output: Manager.CreateDataTransport success } diff --git a/pkg/visor/rpc.go b/pkg/visor/rpc.go index 2b4711f583..0ddc7ffded 100644 --- a/pkg/visor/rpc.go +++ b/pkg/visor/rpc.go @@ -207,7 +207,7 @@ func (r *RPC) AddTransport(in *AddTransportIn, out *TransportSummary) error { defer cancel() } - tp, err := r.node.tm.CreateTransport(ctx, in.RemotePK, in.TpType, in.Public) + tp, err := r.node.tm.CreateDataTransport(ctx, in.RemotePK, in.TpType, in.Public) if err != nil { return err } diff --git a/pkg/visor/rpc_test.go b/pkg/visor/rpc_test.go index c0b24c4147..c7503dbfd2 100644 --- a/pkg/visor/rpc_test.go +++ b/pkg/visor/rpc_test.go @@ -112,7 +112,7 @@ func TestRPC(t *testing.T) { require.NoError(t, <-errCh) }() - _, err = tm2.CreateTransport(context.TODO(), pk1, "mock", true) + _, err = tm2.CreateDataTransport(context.TODO(), pk1, "mock", true) require.NoError(t, err) apps := []AppConfig{