Skip to content

Commit

Permalink
Split CreateTransport
Browse files Browse the repository at this point in the history
  • Loading branch information
nkryuchkov committed Jul 18, 2019
1 parent 3fb6859 commit 6ba5818
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 62 deletions.
2 changes: 1 addition & 1 deletion pkg/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ func (r *Router) setupProto(ctx context.Context) (*setup.Protocol, transport.Tra
}

// TODO(evanlinjin): need string constant for tp type.
tr, err := r.tm.CreateTransport(ctx, r.config.SetupNodes[0], dmsg.Type, false)
tr, err := r.tm.CreateSetupTransport(ctx, r.config.SetupNodes[0], dmsg.Type, false)
if err != nil {
return nil, nil, fmt.Errorf("setup transport: %s", err)
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/router/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ func TestRouterForwarding(t *testing.T) {
errCh <- r.Serve(context.TODO())
}()

tr1, err := m1.CreateTransport(context.TODO(), pk2, "mock", true)
tr1, err := m1.CreateDataTransport(context.TODO(), pk2, "mock", true)
require.NoError(t, err)

tr3, err := m3.CreateTransport(context.TODO(), pk2, "mock2", true)
tr3, err := m3.CreateDataTransport(context.TODO(), pk2, "mock2", true)
require.NoError(t, err)

rule := routing.ForwardRule(time.Now().Add(time.Hour), 4, tr3.Entry.ID)
Expand Down Expand Up @@ -201,7 +201,7 @@ func TestRouterApp(t *testing.T) {

time.Sleep(100 * time.Millisecond)

tr, err := m1.CreateTransport(context.TODO(), pk2, "mock", true)
tr, err := m1.CreateDataTransport(context.TODO(), pk2, "mock", true)
require.NoError(t, err)

rule := routing.AppRule(time.Now().Add(time.Hour), 4, pk2, 5, 6)
Expand Down Expand Up @@ -354,7 +354,7 @@ func TestRouterSetup(t *testing.T) {
errCh <- r.Serve(context.TODO())
}()

tr, err := m2.CreateTransport(context.TODO(), pk1, "mock", false)
tr, err := m2.CreateDataTransport(context.TODO(), pk1, "mock", false)
require.NoError(t, err)
sProto := setup.NewSetupProtocol(tr)

Expand Down
107 changes: 60 additions & 47 deletions pkg/transport/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func (tm *Manager) reconnectTransports(ctx context.Context) {
continue
}

_, err := tm.createTransport(ctx, remote, entry.Type, entry.Public)
_, err := tm.CreateDataTransport(ctx, remote, entry.Type, entry.Public)
if err != nil {
tm.Logger.Warnf("Failed to re-establish transport: %s", err)
continue
Expand Down Expand Up @@ -180,7 +180,7 @@ func (tm *Manager) createDefaultTransports(ctx context.Context) {
if exist {
continue
}
_, err := tm.CreateTransport(ctx, pk, "messaging", true)
_, err := tm.CreateDataTransport(ctx, pk, "messaging", true)
if err != nil {
tm.Logger.Warnf("Failed to establish transport to a node %s: %s", pk, err)
}
Expand Down Expand Up @@ -224,9 +224,64 @@ func (tm *Manager) Serve(ctx context.Context) error {
return nil
}

// CreateTransport begins to attempt to establish transports to the given 'remote' node.
func (tm *Manager) CreateTransport(ctx context.Context, remote cipher.PubKey, tpType string, public bool) (*ManagedTransport, error) {
return tm.createTransport(ctx, remote, tpType, public)
// CreateSetupTransport begins to attempt to establish setup transports to the given 'remote' node.
func (tm *Manager) CreateSetupTransport(ctx context.Context, remote cipher.PubKey, tpType string, public bool) (Transport, error) {
factory := tm.factories[tpType]
if factory == nil {
return nil, errors.New("unknown transport type")
}

tr, entry, err := tm.dialTransport(ctx, factory, remote, public)
if err != nil {
return nil, err
}

oldTr := tm.Transport(entry.ID)
if oldTr != nil {
oldTr.killWorker()
}

tm.Logger.Infof("Dialed to %s using %s factory. Transport ID: %s", remote, tpType, entry.ID)

select {
case <-tm.doneChan:
return nil, io.ErrClosedPipe
case tm.SetupTpChan <- tr:
return tr, nil
}
}

// CreateDataTransport begins to attempt to establish data transports to the given 'remote' node.
func (tm *Manager) CreateDataTransport(ctx context.Context, remote cipher.PubKey, tpType string, public bool) (*ManagedTransport, error) {
factory := tm.factories[tpType]
if factory == nil {
return nil, errors.New("unknown transport type")
}

tr, entry, err := tm.dialTransport(ctx, factory, remote, public)
if err != nil {
return nil, err
}

oldTr := tm.Transport(entry.ID)
if oldTr != nil {
oldTr.killWorker()
}

tm.Logger.Infof("Dialed to %s using %s factory. Transport ID: %s", remote, tpType, entry.ID)
mTr := newManagedTransport(tr, *entry, false)

tm.mu.Lock()
tm.transports[entry.ID] = mTr
tm.mu.Unlock()

select {
case <-tm.doneChan:
return nil, io.ErrClosedPipe
case tm.DataTpChan <- mTr:
go tm.manageTransport(ctx, mTr, factory, remote)
return mTr, nil
}
}

// DeleteTransport disconnects and removes the Transport of Transport ID.
Expand Down Expand Up @@ -316,48 +371,6 @@ func (tm *Manager) dialTransport(ctx context.Context, factory Factory, remote ci
return tr, entry, nil
}

func (tm *Manager) createTransport(ctx context.Context, remote cipher.PubKey, tpType string, public bool) (*ManagedTransport, error) {
factory := tm.factories[tpType]
if factory == nil {
return nil, errors.New("unknown transport type")
}

tr, entry, err := tm.dialTransport(ctx, factory, remote, public)
if err != nil {
return nil, err
}

oldTr := tm.Transport(entry.ID)
if oldTr != nil {
oldTr.killWorker()
}

tm.Logger.Infof("Dialed to %s using %s factory. Transport ID: %s", remote, tpType, entry.ID)
mTr := newManagedTransport(tr, *entry, false)

tm.mu.Lock()
tm.transports[entry.ID] = mTr
tm.mu.Unlock()

if tm.IsSetupTransport(tr) {
select {
case <-tm.doneChan:
return nil, io.ErrClosedPipe
case tm.SetupTpChan <- mTr:
go tm.manageTransport(ctx, mTr, factory, remote)
return mTr, nil
}
} else {
select {
case <-tm.doneChan:
return nil, io.ErrClosedPipe
case tm.DataTpChan <- mTr:
go tm.manageTransport(ctx, mTr, factory, remote)
return mTr, nil
}
}
}

func (tm *Manager) acceptTransport(ctx context.Context, factory Factory) (*ManagedTransport, error) {
tr, err := factory.Accept(ctx)
if err != nil {
Expand Down
16 changes: 8 additions & 8 deletions pkg/transport/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func TestTransportManager(t *testing.T) {
}
}()

tr2, err := m2.CreateTransport(context.TODO(), pk1, "mock", true)
tr2, err := m2.CreateDataTransport(context.TODO(), pk1, "mock", true)
require.NoError(t, err)

time.Sleep(time.Second)
Expand Down Expand Up @@ -151,7 +151,7 @@ func TestTransportManagerReEstablishTransports(t *testing.T) {
m2, err := NewManager(c2, f2)
require.NoError(t, err)

tr2, err := m2.CreateTransport(context.TODO(), pk1, "mock", true)
tr2, err := m2.CreateDataTransport(context.TODO(), pk1, "mock", true)
require.NoError(t, err)

tr1 := m1.Transport(tr2.Entry.ID)
Expand Down Expand Up @@ -214,7 +214,7 @@ func TestTransportManagerLogs(t *testing.T) {
m2, err := NewManager(c2, f2)
require.NoError(t, err)

tr2, err := m2.CreateTransport(context.TODO(), pk1, "mock", true)
tr2, err := m2.CreateDataTransport(context.TODO(), pk1, "mock", true)
require.NoError(t, err)

time.Sleep(100 * time.Millisecond)
Expand Down Expand Up @@ -315,19 +315,19 @@ func ExampleManager_CreateTransport() {
return
}

mtrAB, err := mgrA.CreateTransport(context.TODO(), pkB, "mock", true)
mtrAB, err := mgrA.CreateDataTransport(context.TODO(), pkB, "mock", true)
if err != nil {
fmt.Printf("Manager.CreateTransport failed on iteration %v with: %v\n", i, err)
fmt.Printf("Manager.CreateDataTransport failed on iteration %v with: %v\n", i, err)
return
}

if (mtrAB.Entry.ID == uuid.UUID{}) {
fmt.Printf("Manager.CreateTransport failed on iteration %v", i)
fmt.Printf("Manager.CreateDataTransport failed on iteration %v", i)
return
}
}

fmt.Println("Manager.CreateTransport success")
fmt.Println("Manager.CreateDataTransport success")

// Output: Manager.CreateTransport success
// Output: Manager.CreateDataTransport success
}
2 changes: 1 addition & 1 deletion pkg/visor/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func (r *RPC) AddTransport(in *AddTransportIn, out *TransportSummary) error {
defer cancel()
}

tp, err := r.node.tm.CreateTransport(ctx, in.RemotePK, in.TpType, in.Public)
tp, err := r.node.tm.CreateDataTransport(ctx, in.RemotePK, in.TpType, in.Public)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/visor/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func TestRPC(t *testing.T) {
require.NoError(t, <-errCh)
}()

_, err = tm2.CreateTransport(context.TODO(), pk1, "mock", true)
_, err = tm2.CreateDataTransport(context.TODO(), pk1, "mock", true)
require.NoError(t, err)

apps := []AppConfig{
Expand Down

0 comments on commit 6ba5818

Please sign in to comment.