diff --git a/pkg/transport-discovery/client/client.go b/pkg/transport-discovery/client/client.go index d29ec03d0b..78c09b3716 100644 --- a/pkg/transport-discovery/client/client.go +++ b/pkg/transport-discovery/client/client.go @@ -73,6 +73,16 @@ func (c *apiClient) Get(ctx context.Context, path string) (*http.Response, error return c.client.Do(req.WithContext(ctx)) } +// Delete performs a new DELETE request. +func (c *apiClient) Delete(ctx context.Context, path string) (*http.Response, error) { + req, err := http.NewRequest(http.MethodDelete, c.client.Addr()+path, new(bytes.Buffer)) + if err != nil { + return nil, err + } + + return c.client.Do(req.WithContext(ctx)) +} + // RegisterTransports registers new Transports. func (c *apiClient) RegisterTransports(ctx context.Context, entries ...*transport.SignedEntry) error { if len(entries) == 0 { @@ -150,6 +160,26 @@ func (c *apiClient) GetTransportsByEdge(ctx context.Context, pk cipher.PubKey) ( return entries, nil } +// DeleteTransport deletes given transport by it's ID. A visor can only delete transports if he is one of it's edges. +func (c *apiClient) DeleteTransport(ctx context.Context, id uuid.UUID) error { + resp, err := c.Delete(ctx, fmt.Sprintf("/transports/id:%s", id.String())) + if resp != nil { + defer func() { + if err := resp.Body.Close(); err != nil { + log.WithError(err).Warn("Failed to close HTTP response body") + } + }() + } + if err != nil { + return err + } + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("status: %d, error: %v", resp.StatusCode, extractError(resp.Body)) + } + + return nil +} + // UpdateStatuses updates statuses of transports in discovery. func (c *apiClient) UpdateStatuses(ctx context.Context, statuses ...*transport.Status) ([]*transport.EntryWithStatus, error) { if len(statuses) == 0 { diff --git a/pkg/transport/discovery.go b/pkg/transport/discovery.go index 2fdee01a2c..5fbcafa25b 100644 --- a/pkg/transport/discovery.go +++ b/pkg/transport/discovery.go @@ -3,6 +3,7 @@ package transport import ( "context" "errors" + "fmt" "sync" "time" @@ -15,6 +16,7 @@ type DiscoveryClient interface { RegisterTransports(ctx context.Context, entries ...*SignedEntry) error GetTransportByID(ctx context.Context, id uuid.UUID) (*EntryWithStatus, error) GetTransportsByEdge(ctx context.Context, pk cipher.PubKey) ([]*EntryWithStatus, error) + DeleteTransport(ctx context.Context, id uuid.UUID) error UpdateStatuses(ctx context.Context, statuses ...*Status) ([]*EntryWithStatus, error) } @@ -81,6 +83,21 @@ func (td *mockDiscoveryClient) GetTransportsByEdge(ctx context.Context, pk ciphe return res, nil } +// NOTE that mock implementation doesn't checks whether the transport to be deleted is valid or not, this is, that +// it can be deleted by the visor who called DeleteTransport +func (td *mockDiscoveryClient) DeleteTransport(ctx context.Context, id uuid.UUID) error { + td.Lock() + defer td.Unlock() + + _, ok := td.entries[id] + if !ok { + return fmt.Errorf("transport with id: %s not found in transport discovery", id) + } + + delete(td.entries, id) + return nil +} + func (td *mockDiscoveryClient) UpdateStatuses(ctx context.Context, statuses ...*Status) ([]*EntryWithStatus, error) { res := make([]*EntryWithStatus, 0) diff --git a/pkg/transport/manager.go b/pkg/transport/manager.go index 0aedab6a1d..e463803efc 100644 --- a/pkg/transport/manager.go +++ b/pkg/transport/manager.go @@ -6,6 +6,7 @@ import ( "io" "strings" "sync" + "time" "github.com/SkycoinProject/skywire-mainnet/internal/skyenv" "github.com/SkycoinProject/skywire-mainnet/pkg/snet/snettest" @@ -224,7 +225,7 @@ func (tm *Manager) saveTransport(remote cipher.PubKey, netName string) (*Managed return mTp, nil } -// DeleteTransport disconnects and removes the Transport of Transport ID. +// DeleteTransport deregisters the Transport of Transport ID in transport discovery and deletes it locally. func (tm *Manager) DeleteTransport(id uuid.UUID) { tm.mx.Lock() defer tm.mx.Unlock() @@ -234,8 +235,17 @@ func (tm *Manager) DeleteTransport(id uuid.UUID) { if tp, ok := tm.tps[id]; ok { tp.Close() + tm.Logger.Infof("Deregister transport %s from manager", id) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + err := tm.Conf.DiscoveryClient.DeleteTransport(ctx, id) + if err != nil { + tm.Logger.Errorf("Deregister transport %s from discovery failed with error: %s", id, err) + } + tm.Logger.Infof("Deregister transport %s from discovery", id) + delete(tm.tps, id) - tm.Logger.Infof("Unregistered transport %s", id) } } diff --git a/pkg/transport/manager_test.go b/pkg/transport/manager_test.go index 8ed8a4859d..1a2ed69f68 100644 --- a/pkg/transport/manager_test.go +++ b/pkg/transport/manager_test.go @@ -45,12 +45,38 @@ func TestNewManager(t *testing.T) { nEnv := snettest.NewEnv(t, keys, []string{dmsg.Type}) defer nEnv.Teardown() - m0, m1, tp0, tp1, err := transport.CreateTransportPair(tpDisc, keys, nEnv, "dmsg") + // Prepare tp manager 0. + pk0, sk0 := keys[0].PK, keys[0].SK + ls0 := transport.InMemoryTransportLogStore() + m0, err := transport.NewManager(nEnv.Nets[0], &transport.ManagerConfig{ + PubKey: pk0, + SecKey: sk0, + DiscoveryClient: tpDisc, + LogStore: ls0, + }) + require.NoError(t, err) + go m0.Serve(context.TODO()) defer func() { require.NoError(t, m0.Close()) }() - defer func() { require.NoError(t, m1.Close()) }() + // Prepare tp manager 1. + pk1, sk1 := keys[1].PK, keys[1].SK + ls1 := transport.InMemoryTransportLogStore() + m2, err := transport.NewManager(nEnv.Nets[1], &transport.ManagerConfig{ + PubKey: pk1, + SecKey: sk1, + DiscoveryClient: tpDisc, + LogStore: ls1, + }) require.NoError(t, err) - require.NotNil(t, tp0) + go m2.Serve(context.TODO()) + defer func() { require.NoError(t, m2.Close()) }() + + // Create data transport between manager 1 & manager 2. + tp2, err := m2.SaveTransport(context.TODO(), pk0, "dmsg") + require.NoError(t, err) + tp1 := m0.Transport(transport.MakeTransportID(pk0, pk1, "dmsg")) + require.NotNil(t, tp1) + fmt.Println("transports created") totalSent2 := 0 @@ -63,8 +89,7 @@ func TestNewManager(t *testing.T) { totalSent2 += i rID := routing.RouteID(i) payload := cipher.RandByte(i) - packet := routing.MakeDataPacket(rID, payload) - require.NoError(t, tp1.WritePacket(context.TODO(), packet)) + require.NoError(t, tp2.WritePacket(context.TODO(), routing.MakeDataPacket(rID, payload))) recv, err := m0.ReadPacket() require.NoError(t, err) @@ -77,10 +102,9 @@ func TestNewManager(t *testing.T) { totalSent1 += i rID := routing.RouteID(i) payload := cipher.RandByte(i) - packet := routing.MakeDataPacket(rID, payload) - require.NoError(t, tp0.WritePacket(context.TODO(), packet)) + require.NoError(t, tp1.WritePacket(context.TODO(), routing.MakeDataPacket(rID, payload))) - recv, err := m1.ReadPacket() + recv, err := m2.ReadPacket() require.NoError(t, err) require.Equal(t, rID, recv.RouteID()) require.Equal(t, uint16(i), recv.Size()) @@ -94,12 +118,12 @@ func TestNewManager(t *testing.T) { // 1.5x log write interval just to be safe. time.Sleep(time.Second * 9 / 2) - entry1, err := m0.Conf.LogStore.Entry(tp0.Entry.ID) + entry1, err := ls0.Entry(tp1.Entry.ID) require.NoError(t, err) assert.Equal(t, uint64(totalSent1), entry1.SentBytes) assert.Equal(t, uint64(totalSent2), entry1.RecvBytes) - entry2, err := m1.Conf.LogStore.Entry(tp1.Entry.ID) + entry2, err := ls1.Entry(tp2.Entry.ID) require.NoError(t, err) assert.Equal(t, uint64(totalSent2), entry2.SentBytes) assert.Equal(t, uint64(totalSent1), entry2.RecvBytes) @@ -109,18 +133,17 @@ func TestNewManager(t *testing.T) { t.Run("check_delete_tp", func(t *testing.T) { // Make transport ID. - tpID := transport.MakeTransportID(m0.Conf.PubKey, m1.Conf.PubKey, "dmsg") + tpID := transport.MakeTransportID(pk0, pk1, "dmsg") // Ensure transports are registered properly in tp discovery. entry, err := tpDisc.GetTransportByID(context.TODO(), tpID) require.NoError(t, err) - assert.Equal(t, transport.SortEdges(m0.Conf.PubKey, m1.Conf.PubKey), entry.Entry.Edges) + assert.Equal(t, transport.SortEdges(pk0, pk1), entry.Entry.Edges) assert.True(t, entry.IsUp) - m1.DeleteTransport(tp1.Entry.ID) - entry, err = tpDisc.GetTransportByID(context.TODO(), tpID) - require.NoError(t, err) - assert.False(t, entry.IsUp) + m2.DeleteTransport(tp2.Entry.ID) + _, err = tpDisc.GetTransportByID(context.TODO(), tpID) + require.Contains(t, err.Error(), "not found") }) }