diff --git a/cmd/skywire-cli/commands/visor/transports.go b/cmd/skywire-cli/commands/visor/transports.go index c95c48207..99cf89cd1 100644 --- a/cmd/skywire-cli/commands/visor/transports.go +++ b/cmd/skywire-cli/commands/visor/transports.go @@ -7,12 +7,11 @@ import ( "text/tabwriter" "time" - "github.com/skycoin/dmsg" "github.com/skycoin/dmsg/cipher" "github.com/spf13/cobra" "github.com/skycoin/skywire/cmd/skywire-cli/internal" - "github.com/skycoin/skywire/pkg/snet/directtp/tptypes" + "github.com/skycoin/skywire/pkg/transport/network" "github.com/skycoin/skywire/pkg/visor" ) @@ -113,15 +112,15 @@ var addTpCmd = &cobra.Command{ logger.Infof("Established %v transport to %v", transportType, pk) } else { - transportTypes := []string{ - tptypes.STCP, - tptypes.STCPR, - tptypes.SUDPH, - dmsg.Type, + transportTypes := []network.Type{ + network.STCP, + network.STCPR, + network.SUDPH, + network.DMSG, } for _, transportType := range transportTypes { - tp, err = rpcClient().AddTransport(pk, transportType, public, timeout) + tp, err = rpcClient().AddTransport(pk, string(transportType), public, timeout) if err == nil { logger.Infof("Established %v transport to %v", transportType, pk) break diff --git a/pkg/app/appevent/utils.go b/pkg/app/appevent/utils.go new file mode 100644 index 000000000..064627400 --- /dev/null +++ b/pkg/app/appevent/utils.go @@ -0,0 +1,26 @@ +package appevent + +import "context" + +// SendTCPDial sends tcp dial event +func (eb *Broadcaster) SendTCPDial(ctx context.Context, remoteNet, remoteAddr string) { + data := TCPDialData{RemoteNet: remoteNet, RemoteAddr: remoteAddr} + event := NewEvent(TCPDial, data) + eb.sendEvent(ctx, event) +} + +// SendTPClose sends transport close event +func (eb *Broadcaster) SendTPClose(ctx context.Context, netType, addr string) { + data := TCPCloseData{RemoteNet: string(netType), RemoteAddr: addr} + event := NewEvent(TCPClose, data) + if err := eb.Broadcast(context.Background(), event); err != nil { + eb.log.WithError(err).Errorln("Failed to broadcast TCPClose event") + } +} + +func (eb *Broadcaster) sendEvent(_ context.Context, event *Event) { + err := eb.Broadcast(context.Background(), event) //nolint:errcheck + if err != nil { + eb.log.Warn("Failed to broadcast event: %v", event) + } +} diff --git a/pkg/app/conn_test.go b/pkg/app/conn_test.go index 7f337dd29..5b25497d7 100644 --- a/pkg/app/conn_test.go +++ b/pkg/app/conn_test.go @@ -18,7 +18,7 @@ import ( "github.com/skycoin/skywire/pkg/app/appserver" "github.com/skycoin/skywire/pkg/app/idmanager" "github.com/skycoin/skywire/pkg/routing" - "github.com/skycoin/skywire/pkg/snet/snettest" + "github.com/skycoin/skywire/pkg/util/cipherutil" ) func TestConn_Read(t *testing.T) { @@ -175,7 +175,7 @@ func (p *wrappedConn) RemoteAddr() net.Addr { func TestConn_TestConn(t *testing.T) { mp := func() (net.Conn, net.Conn, func(), error) { netType := appnet.TypeSkynet - keys := snettest.GenKeyPairs(2) + keys := cipherutil.GenKeyPairs(2) fmt.Printf("C1 Local: %s\n", keys[0].PK) fmt.Printf("C2 Local: %s\n", keys[1].PK) p1, p2 := net.Pipe() @@ -219,7 +219,7 @@ func TestConn_TestConn(t *testing.T) { rpcS := rpc.NewServer() - appKeys := snettest.GenKeyPairs(2) + appKeys := cipherutil.GenKeyPairs(2) var ( procKey1 appcommon.ProcKey diff --git a/pkg/dmsgc/dmsgc.go b/pkg/dmsgc/dmsgc.go new file mode 100644 index 000000000..bd6d36a36 --- /dev/null +++ b/pkg/dmsgc/dmsgc.go @@ -0,0 +1,42 @@ +package dmsgc + +import ( + "context" + + "github.com/skycoin/dmsg" + "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/dmsg/disc" + "github.com/skycoin/skycoin/src/util/logging" + + "github.com/skycoin/skywire/pkg/app/appevent" +) + +// DmsgConfig defines config for Dmsg network. +type DmsgConfig struct { + Discovery string `json:"discovery"` + SessionsCount int `json:"sessions_count"` +} + +// New makes new dmsg client from configuration +func New(pk cipher.PubKey, sk cipher.SecKey, eb *appevent.Broadcaster, conf *DmsgConfig) *dmsg.Client { + dmsgConf := &dmsg.Config{ + MinSessions: conf.SessionsCount, + Callbacks: &dmsg.ClientCallbacks{ + OnSessionDial: func(network, addr string) error { + data := appevent.TCPDialData{RemoteNet: network, RemoteAddr: addr} + event := appevent.NewEvent(appevent.TCPDial, data) + _ = eb.Broadcast(context.Background(), event) //nolint:errcheck + // @evanlinjin: An error is not returned here as this will cancel the session dial. + return nil + }, + OnSessionDisconnect: func(network, addr string, _ error) { + data := appevent.TCPCloseData{RemoteNet: network, RemoteAddr: addr} + event := appevent.NewEvent(appevent.TCPClose, data) + _ = eb.Broadcast(context.Background(), event) //nolint:errcheck + }, + }, + } + dmsgC := dmsg.NewClient(pk, sk, disc.NewHTTP(conf.Discovery), dmsgConf) + dmsgC.SetLogger(logging.MustGetLogger("dmsgC")) + return dmsgC +} diff --git a/pkg/router/route_group_test.go b/pkg/router/route_group_test.go index cab574704..3ad02ca34 100644 --- a/pkg/router/route_group_test.go +++ b/pkg/router/route_group_test.go @@ -1,23 +1,12 @@ package router import ( - "context" - "fmt" - "io" - "strconv" - "strings" - "sync" "testing" - "time" "github.com/skycoin/dmsg/cipher" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/skycoin/skywire/pkg/routing" - "github.com/skycoin/skywire/pkg/snet/directtp/tptypes" - "github.com/skycoin/skywire/pkg/snet/snettest" - "github.com/skycoin/skywire/pkg/transport" ) func TestNewRouteGroup(t *testing.T) { @@ -26,484 +15,6 @@ func TestNewRouteGroup(t *testing.T) { require.Equal(t, DefaultRouteGroupConfig(), rg.cfg) } -// Uncomment for debugging -/* -func TestRouteGroupAlignment(t *testing.T) { - alignment.PrintStruct(RouteGroup{}) -} -*/ - -func TestRouteGroup_Close(t *testing.T) { - rg1, rg2, m1, m2, teardown := setupEnv(t) - - ctx, cancel := context.WithCancel(context.Background()) - - // push close packet from transport to route group - go pushPackets(ctx, m2, rg2) - go pushPackets(ctx, m1, rg1) - - err := rg1.Close() - require.NoError(t, err) - require.True(t, rg1.isClosed()) - require.True(t, rg2.isRemoteClosed()) - // rg1 should be done (not getting any new data, returning `io.EOF` on further reads) - // but not closed - require.False(t, rg2.isClosed()) - - err = rg1.Close() - require.Equal(t, io.ErrClosedPipe, err) - - err = rg2.Close() - require.NoError(t, err) - require.True(t, rg2.isClosed()) - - err = rg2.Close() - require.Equal(t, io.ErrClosedPipe, err) - - cancel() - teardown() -} - -func TestRouteGroup_Read(t *testing.T) { - rg1, rg2, m1, m2, teardown := setupEnv(t) - - ctx, cancel := context.WithCancel(context.Background()) - - // push close packet from transport to route group - go pushPackets(ctx, m2, rg2) - go pushPackets(ctx, m1, rg1) - - msg1 := []byte("hello1") - msg2 := []byte("hello2") - msg3 := []byte("hello3") - buf1 := make([]byte, len(msg1)) - buf2 := make([]byte, len(msg2)) - buf3 := make([]byte, len(msg2)/2) - buf4 := make([]byte, len(msg2)/2) - - rg1.readCh <- msg1 - rg2.readCh <- msg2 - rg2.readCh <- msg3 - - n, err := rg1.Read([]byte{}) - require.Equal(t, 0, n) - require.NoError(t, err) - - n, err = rg1.Read(buf1) - require.NoError(t, err) - require.Equal(t, msg1, buf1) - require.Equal(t, len(msg1), n) - - n, err = rg2.Read(buf2) - require.NoError(t, err) - require.Equal(t, msg2, buf2) - require.Equal(t, len(msg2), n) - - // Test short reads. - n, err = rg2.Read(buf3) - require.NoError(t, err) - require.Equal(t, msg3[0:len(msg3)/2], buf3) - require.Equal(t, len(msg3)/2, n) - - n, err = rg2.Read(buf4) - require.NoError(t, err) - require.Equal(t, msg3[len(msg3)/2:], buf4) - require.Equal(t, len(msg3)/2, n) - - require.NoError(t, rg1.Close()) - require.NoError(t, rg2.Close()) - cancel() - teardown() -} - -func TestRouteGroup_Write(t *testing.T) { - rg1, rg2, m1, m2, teardown := setupEnv(t) - - testWrite(t, rg1, rg2, m1, m2) - - require.NoError(t, rg1.Close()) - require.NoError(t, rg2.Close()) - teardown() -} - -func testWrite(t *testing.T, rg1, rg2 *RouteGroup, m1, m2 *transport.Manager) { - msg1 := []byte("hello1") - msg2 := []byte("hello2") - - n, err := rg1.Write([]byte{}) - require.Equal(t, 0, n) - require.NoError(t, err) - - n, err = rg2.Write([]byte{}) - require.Equal(t, 0, n) - require.NoError(t, err) - - _, err = rg1.Write(msg1) - require.NoError(t, err) - - _, err = rg2.Write(msg2) - require.NoError(t, err) - - recv, err := m1.ReadPacket() - require.NoError(t, err) - require.Equal(t, msg2, recv.Payload()) - - recv, err = m2.ReadPacket() - require.NoError(t, err) - require.Equal(t, msg1, recv.Payload()) - - rg1.mu.Lock() - tpBackup := rg1.tps[0] - rg1.tps[0] = nil - rg1.mu.Unlock() - _, err = rg1.Write(msg1) - require.Equal(t, ErrBadTransport, err) - - rg1.mu.Lock() - rg1.tps[0] = tpBackup - - tpsBackup := rg1.tps - rg1.tps = nil - rg1.mu.Unlock() - _, err = rg1.Write(msg1) - require.Equal(t, ErrNoTransports, err) - - rg1.mu.Lock() - rg1.tps = tpsBackup - - fwdBackup := rg1.fwd - rg1.fwd = nil - rg1.mu.Unlock() - _, err = rg1.Write(msg1) - require.Equal(t, ErrNoRules, err) - - rg1.mu.Lock() - rg1.fwd = fwdBackup - rg1.mu.Unlock() -} - -func TestRouteGroup_ReadWrite(t *testing.T) { - const iterations = 3 - - for i := 0; i < iterations; i++ { - testReadWrite(t, iterations) - } -} - -func testReadWrite(t *testing.T, iterations int) { - rg1, rg2, m1, m2, teardown := setupEnv(t) - - ctx, cancel := context.WithCancel(context.Background()) - - // push close packet from transport to route group - go pushPackets(ctx, m2, rg2) - go pushPackets(ctx, m1, rg1) - - testRouteGroupReadWrite(t, iterations, rg1, rg2) - - assert.NoError(t, rg1.Close()) - assert.NoError(t, rg2.Close()) - cancel() - teardown() -} - -func testRouteGroupReadWrite(t *testing.T, iterations int, rg1, rg2 io.ReadWriter) { - msg1 := []byte("hello1_") - msg2 := []byte("hello2_") - - t.Run("Group", func(t *testing.T) { - t.Run("MultipleWriteRead", func(t *testing.T) { - testMultipleWR(t, iterations, rg1, rg2, msg1, msg2) - }) - - t.Run("SingleReadWrite", func(t *testing.T) { - testSingleRW(t, rg1, rg2, msg1, msg2) - }) - - t.Run("MultipleReadWrite", func(t *testing.T) { - testMultipleRW(t, iterations, rg1, rg2, msg1, msg2) - }) - - t.Run("SingleWriteRead", func(t *testing.T) { - testSingleWR(t, rg1, rg2, msg1, msg2) - }) - }) -} - -func testSingleWR(t *testing.T, rg1, rg2 io.ReadWriter, msg1, msg2 []byte) { - _, err := rg1.Write(msg1) - require.NoError(t, err) - - _, err = rg2.Write(msg2) - require.NoError(t, err) - - buf1 := make([]byte, len(msg2)) - _, err = rg1.Read(buf1) - require.NoError(t, err) - require.Equal(t, msg2, buf1) - - buf2 := make([]byte, len(msg1)) - _, err = rg2.Read(buf2) - require.NoError(t, err) - require.Equal(t, msg1, buf2) -} - -func testMultipleRW(t *testing.T, iterations int, rg1, rg2 io.ReadWriter, msg1, msg2 []byte) { - var err1, err2 error - - for i := 0; i < iterations; i++ { - var wg sync.WaitGroup - - wg.Add(1) - - go func() { - defer wg.Done() - - time.Sleep(100 * time.Millisecond) - - for j := 0; j < iterations; j++ { - _, err := rg1.Write(append(msg1, []byte(strconv.Itoa(j))...)) - require.NoError(t, err) - - _, err = rg2.Write(append(msg2, []byte(strconv.Itoa(j))...)) - require.NoError(t, err) - } - }() - - require.NoError(t, err1) - require.NoError(t, err2) - - for j := 0; j < iterations; j++ { - msg := append(msg2, []byte(strconv.Itoa(j))...) - buf1 := make([]byte, len(msg)) - _, err := rg1.Read(buf1) - require.NoError(t, err) - require.Equal(t, msg, buf1) - } - - for j := 0; j < iterations; j++ { - msg := append(msg1, []byte(strconv.Itoa(j))...) - buf2 := make([]byte, len(msg)) - _, err := rg2.Read(buf2) - require.NoError(t, err) - require.Equal(t, msg, buf2) - } - - wg.Wait() - } -} - -func testSingleRW(t *testing.T, rg1, rg2 io.ReadWriter, msg1, msg2 []byte) { - var err1, err2 error - - go func() { - time.Sleep(1 * time.Second) - _, err1 = rg1.Write(msg1) - _, err2 = rg2.Write(msg2) - }() - - require.NoError(t, err1) - require.NoError(t, err2) - - buf1 := make([]byte, len(msg2)) - _, err := rg1.Read(buf1) - require.NoError(t, err) - require.Equal(t, msg2, buf1) - - buf2 := make([]byte, len(msg1)) - _, err = rg2.Read(buf2) - require.NoError(t, err) - require.Equal(t, msg1, buf2) -} - -func testMultipleWR(t *testing.T, iterations int, rg1, rg2 io.ReadWriter, msg1, msg2 []byte) { - for i := 0; i < iterations; i++ { - for j := 0; j < iterations; j++ { - _, err := rg1.Write(append(msg1, []byte(strconv.Itoa(j))...)) - require.NoError(t, err) - - _, err = rg2.Write(append(msg2, []byte(strconv.Itoa(j))...)) - require.NoError(t, err) - } - - for j := 0; j < iterations; j++ { - msg := append(msg2, []byte(strconv.Itoa(j))...) - buf1 := make([]byte, len(msg)) - _, err := rg1.Read(buf1) - require.NoError(t, err) - require.Equal(t, msg, buf1) - } - - for j := 0; j < iterations; j++ { - msg := append(msg1, []byte(strconv.Itoa(j))...) - buf2 := make([]byte, len(msg)) - _, err := rg2.Read(buf2) - require.NoError(t, err) - require.Equal(t, msg, buf2) - } - } -} - -func TestArbitrarySizeOneMessage(t *testing.T) { - // Test fails if message size is above 4059 - const ( - value1 = 4058 // dmsg/noise.maxFrameSize - 38 - value2 = 4059 // dmsg/noise.maxFrameSize - 37 - ) - - var wg sync.WaitGroup - - wg.Add(1) - - t.Run("Value1", func(t *testing.T) { - defer wg.Done() - testArbitrarySizeOneMessage(t, value1) - }) - - wg.Wait() - - t.Run("Value2", func(t *testing.T) { - testArbitrarySizeOneMessage(t, value2) - }) -} - -func TestArbitrarySizeMultipleMessagesByChunks(t *testing.T) { - // Test fails if message size is above 64810 - const ( - value1 = 64810 // 2^16 - 726 - value2 = 64811 // 2^16 - 725 - ) - - var wg sync.WaitGroup - - wg.Add(1) - - t.Run("Value1", func(t *testing.T) { - defer wg.Done() - testArbitrarySizeMultipleMessagesByChunks(t, value1) - }) - - wg.Wait() - - t.Run("Value2", func(t *testing.T) { - testArbitrarySizeMultipleMessagesByChunks(t, value2) - }) -} - -func testArbitrarySizeMultipleMessagesByChunks(t *testing.T, size int) { - rg1, rg2, m1, m2, teardown := setupEnv(t) - - ctx, cancel := context.WithCancel(context.Background()) - - // push close packet from transport to route group - go pushPackets(ctx, m2, rg2) - go pushPackets(ctx, m1, rg1) - - defer func() { - cancel() - teardown() - }() - - chunkSize := 1024 - - msg := []byte(strings.Repeat("A", size)) - - for offset := 0; offset < size; offset += chunkSize { - _, err := rg1.Write(msg[offset : offset+chunkSize]) - require.NoError(t, err) - } - - for offset := 0; offset < size; offset += chunkSize { - buf := make([]byte, chunkSize) - n, err := rg2.Read(buf) - require.NoError(t, err) - require.Equal(t, chunkSize, n) - require.Equal(t, msg[offset:offset+chunkSize], buf) - } - - var ( - errCh = make(chan error) - nCh = make(chan int) - bufCh = make(chan []byte) - ) - go func() { - buf := make([]byte, size) - n, err := rg2.Read(buf) - errCh <- err - nCh <- n - bufCh <- buf - }() - - // close remote to simulate `io.EOF` on local connection - require.NoError(t, rg1.Close()) - - err := <-errCh - n := <-nCh - readBuf := <-bufCh - close(nCh) - close(errCh) - close(bufCh) - require.Equal(t, io.EOF, err) - require.Equal(t, 0, n) - require.Equal(t, make([]byte, size), readBuf) - - require.NoError(t, rg2.Close()) -} - -func testArbitrarySizeOneMessage(t *testing.T, size int) { - rg1, rg2, m1, m2, teardown := setupEnv(t) - - ctx, cancel := context.WithCancel(context.Background()) - - // push close packet from transport to route group - go pushPackets(ctx, m2, rg2) - go pushPackets(ctx, m1, rg1) - - defer func() { - cancel() - teardown() - }() - - msg := []byte(strings.Repeat("A", size)) - - _, err := rg1.Write(msg) - require.NoError(t, err) - - buf := make([]byte, size) - n, err := rg2.Read(buf) - require.NoError(t, err) - require.Equal(t, size, n) - require.Equal(t, msg, buf) - - var ( - errCh = make(chan error) - nCh = make(chan int) - bufCh = make(chan []byte) - ) - go func() { - buf := make([]byte, size) - n, err := rg2.Read(buf) - errCh <- err - nCh <- n - bufCh <- buf - }() - - // close remote to simulate `io.EOF` on local connection - require.NoError(t, rg1.Close()) - - err = <-errCh - n = <-nCh - readBuf := <-bufCh - close(nCh) - close(errCh) - close(bufCh) - require.Equal(t, io.EOF, err) - require.Equal(t, 0, n) - require.Equal(t, make([]byte, size), readBuf) - - require.NoError(t, rg2.Close()) -} - func TestRouteGroup_LocalAddr(t *testing.T) { rg := createRouteGroup(DefaultRouteGroupConfig()) require.Equal(t, rg.desc.Dst(), rg.LocalAddr()) @@ -518,91 +29,6 @@ func TestRouteGroup_RemoteAddr(t *testing.T) { require.NoError(t, rg.Close()) } -// TODO(darkrengarius): Uncomment and fix. -/* -func TestRouteGroup_TestConn(t *testing.T) { - mp := func() (c1, c2 net.Conn, stop func(), err error) { - rg1, rg2, m1, m2, teardown := setupEnv(t) - - ctx, cancel := context.WithCancel(context.Background()) - - // push close packet from transport to route group - go pushPackets(ctx, m2, rg2) - go pushPackets(ctx, m1, rg1) - - stop = func() { - _ = rg1.Close() // nolint:errcheck - _ = rg2.Close() // nolint:errcheck - cancel() - teardown() - } - - return rg1, rg2, stop, nil - } - - nettest.TestConn(t, mp) -} -*/ - -func pushPackets(ctx context.Context, from *transport.Manager, to *RouteGroup) { - for { - select { - case <-ctx.Done(): - return - default: - } - - packet, err := from.ReadPacket() - if err != nil { - panic(err) - } - - payload := packet.Payload() - if len(payload) != int(packet.Size()) { - panic("malformed packet") - } - - switch packet.Type() { - case routing.ClosePacket: - if to.isClosed() { - panic(io.ErrClosedPipe) - } - - if err := to.handleClosePacket(routing.CloseCode(packet.Payload()[0])); err != nil { - panic(err) - } - - return - case routing.DataPacket: - if !safeSend(ctx, to, payload) { - return - } - case routing.HandshakePacket: - // error won't happen with the handshake packet - _ = to.handlePacket(packet) //nolint:errcheck - default: - panic(fmt.Sprintf("wrong packet type %v", packet.Type())) - } - } -} - -func safeSend(ctx context.Context, to *RouteGroup, payload []byte) (keepSending bool) { - defer func() { - if r := recover(); r != nil { - keepSending = r == "send on closed channel" - } - }() - - select { - case <-ctx.Done(): - return false - case <-to.closed: - return false - case to.readCh <- payload: - return true - } -} - func createRouteGroup(cfg *RouteGroupConfig) *RouteGroup { rt := routing.NewTable() @@ -616,67 +42,3 @@ func createRouteGroup(cfg *RouteGroupConfig) *RouteGroup { return rg } - -func setupEnv(t *testing.T) (rg1, rg2 *RouteGroup, m1, m2 *transport.Manager, teardown func()) { - keys := snettest.GenKeyPairs(2) - - pk1 := keys[0].PK - pk2 := keys[1].PK - - // create test env - nEnv := snettest.NewEnv(t, keys, []string{tptypes.STCP}) - - tpDisc := transport.NewDiscoveryMock() - tpKeys := snettest.GenKeyPairs(2) - - m1, m2, tp1, tp2, err := transport.CreateTransportPair(tpDisc, tpKeys, nEnv, tptypes.STCP) - require.NoError(t, err) - require.NotNil(t, tp1) - require.NotNil(t, tp2) - require.NotNil(t, tp1.Entry) - require.NotNil(t, tp2.Entry) - - // because some subtests of `TestConn` are highly specific in their behavior, - // it's best to exceed the `readCh` size - rgCfg := &RouteGroupConfig{ - ReadChBufSize: defaultReadChBufSize * 3, - KeepAliveInterval: defaultRouteGroupKeepAliveInterval, - } - - rg1 = createRouteGroup(rgCfg) - rg2 = createRouteGroup(rgCfg) - - r1RtIDs, err := rg1.rt.ReserveKeys(1) - require.NoError(t, err) - - r2RtIDs, err := rg2.rt.ReserveKeys(1) - require.NoError(t, err) - - r1FwdRule := routing.ForwardRule(ruleKeepAlive, r1RtIDs[0], r2RtIDs[0], tp1.Entry.ID, pk2, pk1, 0, 0) - err = rg1.rt.SaveRule(r1FwdRule) - require.NoError(t, err) - - r2FwdRule := routing.ForwardRule(ruleKeepAlive, r2RtIDs[0], r1RtIDs[0], tp2.Entry.ID, pk1, pk2, 0, 0) - err = rg2.rt.SaveRule(r2FwdRule) - require.NoError(t, err) - - r1FwdRtDesc := r1FwdRule.RouteDescriptor() - rg1.mu.Lock() - rg1.desc = r1FwdRtDesc.Invert() - rg1.tps = append(rg1.tps, tp1) - rg1.fwd = append(rg1.fwd, r1FwdRule) - rg1.mu.Unlock() - - r2FwdRtDesc := r2FwdRule.RouteDescriptor() - rg2.mu.Lock() - rg2.desc = r2FwdRtDesc.Invert() - rg2.tps = append(rg2.tps, tp2) - rg2.fwd = append(rg2.fwd, r2FwdRule) - rg2.mu.Unlock() - - teardown = func() { - nEnv.Teardown() - } - - return rg1, rg2, m1, m2, teardown -} diff --git a/pkg/router/router.go b/pkg/router/router.go index 668a2c8cd..4836d77d0 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -21,9 +21,8 @@ import ( "github.com/skycoin/skywire/pkg/routing" "github.com/skycoin/skywire/pkg/setup/setupclient" "github.com/skycoin/skywire/pkg/skyenv" - "github.com/skycoin/skywire/pkg/snet" - "github.com/skycoin/skywire/pkg/snet/directtp/noisewrapper" "github.com/skycoin/skywire/pkg/transport" + "github.com/skycoin/skywire/pkg/transport/network" ) //go:generate mockery -name Router -case underscore -inpkg @@ -37,7 +36,6 @@ const ( handshakeAwaitTimeout = 2 * time.Second - minHops = 0 maxHops = 50 retryDuration = 2 * time.Second retryInterval = 500 * time.Millisecond @@ -145,8 +143,8 @@ type router struct { mx sync.Mutex conf *Config logger *logging.Logger - n *snet.Network - sl *snet.Listener + sl *dmsg.Listener + dmsgC *dmsg.Client trustedVisors map[cipher.PubKey]struct{} tm *transport.Manager rt routing.Table @@ -159,10 +157,10 @@ type router struct { } // New constructs a new Router. -func New(n *snet.Network, config *Config) (Router, error) { +func New(dmsgC *dmsg.Client, config *Config) (Router, error) { config.SetDefaults() - sl, err := n.Listen(dmsg.Type, skyenv.DmsgAwaitSetupPort) + sl, err := dmsgC.Listen(skyenv.DmsgAwaitSetupPort) if err != nil { return nil, err } @@ -175,10 +173,10 @@ func New(n *snet.Network, config *Config) (Router, error) { r := &router{ conf: config, logger: config.Logger, - n: n, tm: config.TransportManager, rt: routing.NewTable(), sl: sl, + dmsgC: dmsgC, rgsNs: make(map[routing.RouteDescriptor]*NoiseRouteGroup), rgsRaw: make(map[routing.RouteDescriptor]*RouteGroup), rpcSrv: rpc.NewServer(), @@ -232,7 +230,7 @@ func (r *router) DialRoutes( Reverse: reversePath, } - rules, err := r.conf.RouteGroupDialer.Dial(ctx, r.logger, r.n, r.conf.SetupNodes, req) + rules, err := r.conf.RouteGroupDialer.Dial(ctx, r.logger, r.dmsgC, r.conf.SetupNodes, req) if err != nil { r.logger.WithError(err).Error("Error dialing route group") return nil, err @@ -350,7 +348,7 @@ func (r *router) serveTransportManager(ctx context.Context) { func (r *router) serveSetup() { for { - conn, err := r.sl.AcceptConn() + conn, err := r.sl.AcceptStream() if err != nil { log := r.logger.WithError(err) if err == dmsg.ErrEntityClosed { @@ -361,12 +359,13 @@ func (r *router) serveSetup() { return } - if !r.SetupIsTrusted(conn.RemotePK()) { + remotePK := conn.RawRemoteAddr().PK + if !r.SetupIsTrusted(remotePK) { r.logger.Warnf("closing conn from untrusted setup node: %v", conn.Close()) continue } - r.logger.Infof("handling setup request: setupPK(%s)", conn.RemotePK()) + r.logger.Infof("handling setup request: setupPK(%s)", remotePK) go r.rpcSrv.ServeConn(conn) } @@ -451,7 +450,7 @@ func (r *router) saveRouteGroupRules(rules routing.EdgeRules, nsConf noise.Confi if rg.encrypt { // wrapping rg with noise - wrappedRG, err := noisewrapper.WrapConn(nsConf, rg) + wrappedRG, err := network.EncryptConn(nsConf, rg) if err != nil { r.logger.WithError(err).Errorf("Failed to wrap route group (%s): %v, closing...", &rules.Desc, err) if err := rg.Close(); err != nil { diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index df22a96c0..5140908b8 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -1,30 +1,14 @@ package router import ( - "context" "fmt" "log" - "net" "os" - "sync" "testing" - "time" - "github.com/google/uuid" "github.com/sirupsen/logrus" - "github.com/skycoin/dmsg" - "github.com/skycoin/dmsg/cipher" "github.com/skycoin/skycoin/src/util/logging" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - "github.com/skycoin/skywire/internal/testhelpers" - "github.com/skycoin/skywire/pkg/routefinder/rfclient" - "github.com/skycoin/skywire/pkg/routing" - "github.com/skycoin/skywire/pkg/setup/setupclient" - "github.com/skycoin/skywire/pkg/snet" - "github.com/skycoin/skywire/pkg/snet/snettest" "github.com/skycoin/skywire/pkg/transport" ) @@ -44,732 +28,6 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -// Test ensures that we can establish connection between 2 routers. 1st router dials -// the 2nd one, 2nd one accepts. We get 2 noise-wrapped route groups and check that -// these route groups correctly communicate with each other. -func Test_router_NoiseRouteGroups(t *testing.T) { - // We're doing 2 key pairs for 2 communicating routers. - keys := snettest.GenKeyPairs(2) - - desc := routing.NewRouteDescriptor(keys[0].PK, keys[1].PK, 1, 1) - - forwardHops := []routing.Hop{ - {From: keys[0].PK, To: keys[1].PK, TpID: transport.MakeTransportID(keys[0].PK, keys[1].PK, dmsg.Type)}, - } - - reverseHops := []routing.Hop{ - {From: keys[1].PK, To: keys[0].PK, TpID: transport.MakeTransportID(keys[1].PK, keys[0].PK, dmsg.Type)}, - } - - // Route that will be established - route := routing.BidirectionalRoute{ - Desc: desc, - KeepAlive: DefaultRouteKeepAlive, - Forward: forwardHops, - Reverse: reverseHops, - } - - // Create test env - nEnv := snettest.NewEnv(t, keys, []string{dmsg.Type}) - defer nEnv.Teardown() - - tpD := transport.NewDiscoveryMock() - - // Prepare transports - m0, m1, _, _, err := transport.CreateTransportPair(tpD, keys[:2], nEnv, dmsg.Type) - require.NoError(t, err) - - forward := [2]cipher.PubKey{keys[0].PK, keys[1].PK} - backward := [2]cipher.PubKey{keys[1].PK, keys[0].PK} - - // Paths to be returned from route finder - rfPaths := make(map[routing.PathEdges][][]routing.Hop) - rfPaths[forward] = append(rfPaths[forward], forwardHops) - rfPaths[backward] = append(rfPaths[backward], reverseHops) - - rfCl := &rfclient.MockClient{} - rfCl.On("FindRoutes", mock.Anything, []routing.PathEdges{forward, backward}, - &rfclient.RouteOptions{MinHops: minHops, MaxHops: maxHops}).Return(rfPaths, testhelpers.NoErr) - - r0Logger := logging.MustGetLogger(fmt.Sprintf("router_%d", 0)) - - fwdRt, revRt := route.ForwardAndReverse() - srcPK := route.Desc.SrcPK() - dstPK := route.Desc.DstPK() - - fwdRules0 := routing.ForwardRule(route.KeepAlive, 1, 2, forwardHops[0].TpID, srcPK, dstPK, 1, 1) - revRules0 := routing.ConsumeRule(route.KeepAlive, 3, srcPK, dstPK, 1, 1) - - // Edge rules to be returned from route group dialer - initEdge := routing.EdgeRules{Desc: revRt.Desc, Forward: fwdRules0, Reverse: revRules0} - - setupCl0 := &setupclient.MockRouteGroupDialer{} - setupCl0.On("Dial", mock.Anything, r0Logger, nEnv.Nets[0], mock.Anything, route). - Return(initEdge, testhelpers.NoErr) - - r0Conf := &Config{ - Logger: r0Logger, - PubKey: keys[0].PK, - SecKey: keys[0].SK, - TransportManager: m0, - RouteFinder: rfCl, - RouteGroupDialer: setupCl0, - } - - // Create routers - r0Ifc, err := New(nEnv.Nets[0], r0Conf) - require.NoError(t, err) - - r0, ok := r0Ifc.(*router) - require.True(t, ok) - - r1Conf := &Config{ - Logger: logging.MustGetLogger(fmt.Sprintf("router_%d", 1)), - PubKey: keys[1].PK, - SecKey: keys[1].SK, - TransportManager: m1, - } - - r1Ifc, err := New(nEnv.Nets[1], r1Conf) - require.NoError(t, err) - - r1, ok := r1Ifc.(*router) - require.True(t, ok) - - ctx := context.Background() - - nrg1IfcCh := make(chan net.Conn) - acceptErrCh := make(chan error) - go func() { - nrg1Ifc, err := r1.AcceptRoutes(ctx) - acceptErrCh <- err - nrg1IfcCh <- nrg1Ifc - close(acceptErrCh) - close(nrg1IfcCh) - }() - - dialErrCh := make(chan error) - nrg0IfcCh := make(chan net.Conn) - go func() { - nrg0Ifc, err := r0.DialRoutes(context.Background(), r1.conf.PubKey, 1, 1, nil) - dialErrCh <- err - nrg0IfcCh <- nrg0Ifc - close(dialErrCh) - close(nrg0IfcCh) - }() - - fwdRules1 := routing.ForwardRule(route.KeepAlive, 4, 3, reverseHops[0].TpID, dstPK, srcPK, 1, 1) - revRules1 := routing.ConsumeRule(route.KeepAlive, 2, dstPK, srcPK, 1, 1) - - // This edge is returned by the setup node to accepting router - respEdge := routing.EdgeRules{Desc: fwdRt.Desc, Forward: fwdRules1, Reverse: revRules1} - - // Unblock AcceptRoutes, imitates setup node request with EdgeRules - r1.accept <- respEdge - - // At some point raw route group gets into `rgsRaw` and waits for - // handshake packets. we're waiting for this moment in the cycle - // to start passing packets from the transport to route group - for { - r0.mx.Lock() - if _, ok := r0.rgsRaw[initEdge.Desc]; ok { - rg := r0.rgsRaw[initEdge.Desc] - go pushPackets(ctx, m0, rg) - r0.mx.Unlock() - break - } - r0.mx.Unlock() - } - - for { - r1.mx.Lock() - if _, ok := r1.rgsRaw[respEdge.Desc]; ok { - rg := r1.rgsRaw[respEdge.Desc] - go pushPackets(ctx, m1, rg) - r1.mx.Unlock() - break - } - r1.mx.Unlock() - } - - require.NoError(t, <-acceptErrCh) - require.NoError(t, <-dialErrCh) - - nrg0Ifc := <-nrg0IfcCh - require.NotNil(t, nrg0Ifc) - nrg1Ifc := <-nrg1IfcCh - require.NotNil(t, nrg1Ifc) - - nrg0, ok := nrg0Ifc.(*NoiseRouteGroup) - require.True(t, ok) - require.NotNil(t, nrg0) - - nrg1, ok := nrg1Ifc.(*NoiseRouteGroup) - require.True(t, ok) - require.NotNil(t, nrg1) - - data := []byte("Hello there!") - n, err := nrg0.Write(data) - require.NoError(t, err) - require.Equal(t, len(data), n) - - received := make([]byte, 1024) - n, err = nrg1.Read(received) - require.NoError(t, err) - require.Equal(t, len(data), n) - require.Equal(t, data, received[:n]) - - require.True(t, nrg0.IsAlive()) - require.True(t, nrg1.IsAlive()) - - err = nrg0.Close() - require.NoError(t, err) - - require.False(t, nrg0.IsAlive()) - require.False(t, nrg1.IsAlive()) - - require.True(t, nrg1.rg.isRemoteClosed()) - err = nrg1.Close() - require.NoError(t, err) - require.False(t, nrg1.IsAlive()) -} - -func TestRouter_Serve(t *testing.T) { - // We are generating two key pairs - one for the a `Router`, the other to send packets to `Router`. - keys := snettest.GenKeyPairs(2) - - // create test env - nEnv := snettest.NewEnv(t, keys, []string{dmsg.Type}) - defer nEnv.Teardown() - - rEnv := NewTestEnv(t, nEnv.Nets) - defer rEnv.Teardown() - - // Create routers - r0Ifc, err := New(nEnv.Nets[0], rEnv.GenRouterConfig(0)) - require.NoError(t, err) - - r0, ok := r0Ifc.(*router) - require.True(t, ok) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - require.NoError(t, r0.tm.Close()) - require.NoError(t, r0.Serve(ctx)) -} - -const ruleKeepAlive = 1 * time.Hour - -// Ensure that received packets are handled properly in `(*Router).handleTransportPacket()`. -func TestRouter_handleTransportPacket(t *testing.T) { - // We are generating two key pairs - one for the a `Router`, the other to send packets to `Router`. - keys := snettest.GenKeyPairs(2) - - pk1 := keys[0].PK - pk2 := keys[1].PK - - // create test env - nEnv := snettest.NewEnv(t, keys, []string{dmsg.Type}) - defer nEnv.Teardown() - - rEnv := NewTestEnv(t, nEnv.Nets) - defer rEnv.Teardown() - - // Create routers - r0Ifc, err := New(nEnv.Nets[0], rEnv.GenRouterConfig(0)) - require.NoError(t, err) - - r0, ok := r0Ifc.(*router) - require.True(t, ok) - - r1Ifc, err := New(nEnv.Nets[1], rEnv.GenRouterConfig(1)) - require.NoError(t, err) - - r1, ok := r1Ifc.(*router) - require.True(t, ok) - - defer func() { - require.NoError(t, r0.Close()) - require.NoError(t, r1.Close()) - }() - - // Create dmsg transport between two `snet.Network` entities. - tp1, err := rEnv.TpMngrs[1].SaveTransport(context.TODO(), pk1, dmsg.Type, transport.LabelUser) - require.NoError(t, err) - - testHandlePackets(t, r0, r1, tp1, pk1, pk2) -} - -func testHandlePackets(t *testing.T, r0, r1 *router, tp1 *transport.ManagedTransport, pk1, pk2 cipher.PubKey) { - var wg sync.WaitGroup - - wg.Add(1) - t.Run("handlePacket_fwdRule", func(t *testing.T) { - defer wg.Done() - - testForwardRule(t, r0, r1, tp1, pk1, pk2) - }) - wg.Wait() - - wg.Add(1) - t.Run("handlePacket_intFwdRule", func(t *testing.T) { - defer wg.Done() - - testIntermediaryForwardRule(t, r0, r1, tp1) - }) - wg.Wait() - - wg.Add(1) - t.Run("handlePacket_cnsmRule", func(t *testing.T) { - defer wg.Done() - - testConsumeRule(t, r0, r1, tp1, pk1, pk2) - }) - wg.Wait() - - wg.Add(1) - t.Run("handlePacket_close_initiator", func(t *testing.T) { - defer wg.Done() - - testClosePacketInitiator(t, r0, r1, pk1, pk2, tp1) - }) - wg.Wait() - - wg.Add(1) - t.Run("handlePacket_close_remote", func(t *testing.T) { - defer wg.Done() - - testClosePacketRemote(t, r0, r1, pk1, pk2, tp1) - }) - wg.Wait() - - wg.Add(1) - t.Run("handlePacket_keepalive", func(t *testing.T) { - defer wg.Done() - - testKeepAlivePacket(t, r0, r1, pk1, pk2) - }) - wg.Wait() -} - -func testKeepAlivePacket(t *testing.T, r0, r1 *router, pk1, pk2 cipher.PubKey) { - defer clearRouterRules(r0, r1) - defer clearRouteGroups(r0, r1) - - rtIDs, err := r0.ReserveKeys(1) - require.NoError(t, err) - - rtID := rtIDs[0] - - cnsmRule := routing.ConsumeRule(100*time.Millisecond, rtID, pk2, pk1, 0, 0) - err = r0.rt.SaveRule(cnsmRule) - require.NoError(t, err) - require.Len(t, r0.rt.AllRules(), 1) - - time.Sleep(10 * time.Millisecond) - - packet := routing.MakeKeepAlivePacket(rtIDs[0]) - require.NoError(t, r0.handleTransportPacket(context.TODO(), packet)) - - require.Len(t, r0.rt.AllRules(), 1) - time.Sleep(10 * time.Millisecond) - require.Len(t, r0.rt.AllRules(), 1) - - time.Sleep(200 * time.Millisecond) - require.Len(t, r0.rt.AllRules(), 0) -} - -func testClosePacketRemote(t *testing.T, r0, r1 *router, pk1, pk2 cipher.PubKey, tp1 *transport.ManagedTransport) { - defer clearRouterRules(r0, r1) - defer clearRouteGroups(r0, r1) - - // reserve FWD IDs for r0. - intFwdID, err := r0.ReserveKeys(1) - require.NoError(t, err) - - // reserve FWD and CNSM IDs for r1. - r1RtIDs, err := r1.ReserveKeys(2) - require.NoError(t, err) - - intFwdRule := routing.IntermediaryForwardRule(1*time.Hour, intFwdID[0], r1RtIDs[1], tp1.Entry.ID) - err = r0.rt.SaveRule(intFwdRule) - require.NoError(t, err) - - routeID := routing.RouteID(7) - fwdRule := routing.ForwardRule(ruleKeepAlive, r1RtIDs[0], routeID, tp1.Entry.ID, pk1, pk2, 0, 0) - cnsmRule := routing.ConsumeRule(ruleKeepAlive, r1RtIDs[1], pk2, pk1, 0, 0) - - err = r1.rt.SaveRule(fwdRule) - require.NoError(t, err) - - err = r1.rt.SaveRule(cnsmRule) - require.NoError(t, err) - - fwdRtDesc := fwdRule.RouteDescriptor() - - rules := routing.EdgeRules{ - Desc: fwdRtDesc.Invert(), - Forward: fwdRule, - Reverse: cnsmRule, - } - - rg1 := NewRouteGroup(DefaultRouteGroupConfig(), r1.rt, rules.Desc) - rg1.appendRules(rules.Forward, rules.Reverse, r1.tm.Transport(rules.Forward.NextTransportID())) - - nrg1 := &NoiseRouteGroup{rg: rg1} - r1.rgsNs[rg1.desc] = nrg1 - - packet := routing.MakeClosePacket(intFwdID[0], routing.CloseRequested) - err = r0.handleTransportPacket(context.TODO(), packet) - require.NoError(t, err) - - recvPacket, err := r1.tm.ReadPacket() - require.NoError(t, err) - require.Equal(t, packet.Size(), recvPacket.Size()) - require.Equal(t, packet.Payload(), recvPacket.Payload()) - require.Equal(t, packet.Type(), recvPacket.Type()) - require.Equal(t, r1RtIDs[1], recvPacket.RouteID()) - - err = r1.handleTransportPacket(context.TODO(), recvPacket) - require.NoError(t, err) - - require.True(t, nrg1.rg.isRemoteClosed()) - require.False(t, nrg1.isClosed()) - require.Len(t, r1.rgsNs, 0) - require.Len(t, r0.rt.AllRules(), 0) - require.Len(t, r1.rt.AllRules(), 0) -} - -func testClosePacketInitiator(t *testing.T, r0, r1 *router, pk1, pk2 cipher.PubKey, tp1 *transport.ManagedTransport) { - defer clearRouterRules(r0, r1) - defer clearRouteGroups(r0, r1) - - // reserve FWD IDs for r0. - intFwdID, err := r0.ReserveKeys(1) - require.NoError(t, err) - - // reserve FWD and CNSM IDs for r1. - r1RtIDs, err := r1.ReserveKeys(2) - require.NoError(t, err) - - intFwdRule := routing.IntermediaryForwardRule(1*time.Hour, intFwdID[0], r1RtIDs[1], tp1.Entry.ID) - err = r0.rt.SaveRule(intFwdRule) - require.NoError(t, err) - - routeID := routing.RouteID(7) - fwdRule := routing.ForwardRule(ruleKeepAlive, r1RtIDs[0], routeID, tp1.Entry.ID, pk1, pk2, 0, 0) - cnsmRule := routing.ConsumeRule(ruleKeepAlive, r1RtIDs[1], pk2, pk1, 0, 0) - - err = r1.rt.SaveRule(fwdRule) - require.NoError(t, err) - - err = r1.rt.SaveRule(cnsmRule) - require.NoError(t, err) - - fwdRtDesc := fwdRule.RouteDescriptor() - - rules := routing.EdgeRules{ - Desc: fwdRtDesc.Invert(), - Forward: fwdRule, - Reverse: cnsmRule, - } - - rg1 := NewRouteGroup(DefaultRouteGroupConfig(), r1.rt, rules.Desc) - rg1.appendRules(rules.Forward, rules.Reverse, r1.tm.Transport(rules.Forward.NextTransportID())) - - nrg1 := &NoiseRouteGroup{rg: rg1} - r1.rgsNs[rg1.desc] = nrg1 - - packet := routing.MakeClosePacket(intFwdID[0], routing.CloseRequested) - err = r0.handleTransportPacket(context.TODO(), packet) - require.NoError(t, err) - - recvPacket, err := r1.tm.ReadPacket() - require.NoError(t, err) - require.Equal(t, packet.Size(), recvPacket.Size()) - require.Equal(t, packet.Payload(), recvPacket.Payload()) - require.Equal(t, packet.Type(), recvPacket.Type()) - require.Equal(t, r1RtIDs[1], recvPacket.RouteID()) - - rg1.closeDone.Add(1) - rg1.closeInitiated = 1 - - err = r1.handleTransportPacket(context.TODO(), recvPacket) - require.NoError(t, err) - - require.Len(t, r1.rgsNs, 0) - require.Len(t, r0.rt.AllRules(), 0) - // since this is the close initiator but the close routine wasn't called, - // forward rule is left - require.Len(t, r1.rt.AllRules(), 1) -} - -// TEST: Ensure handleTransportPacket does as expected. -// After setting a rule in r0, r0 should forward a packet to r1 (as specified in the given rule) -// when r0.handleTransportPacket() is called. -func testForwardRule(t *testing.T, r0, r1 *router, tp1 *transport.ManagedTransport, pk1, pk2 cipher.PubKey) { - defer clearRouterRules(r0, r1) - defer clearRouteGroups(r0, r1) - - // Add a FWD rule for r0. - fwdRtID, err := r0.ReserveKeys(1) - require.NoError(t, err) - - routeID := routing.RouteID(1) - fwdRule := routing.ForwardRule(ruleKeepAlive, fwdRtID[0], routeID, tp1.Entry.ID, pk1, pk2, 0, 0) - err = r0.rt.SaveRule(fwdRule) - require.NoError(t, err) - - rules := routing.EdgeRules{Desc: fwdRule.RouteDescriptor(), Forward: fwdRule, Reverse: nil} - rg0 := NewRouteGroup(DefaultRouteGroupConfig(), r0.rt, rules.Desc) - rg0.appendRules(rules.Forward, rules.Reverse, r0.tm.Transport(rules.Forward.NextTransportID())) - - nrg0 := &NoiseRouteGroup{rg: rg0} - r0.rgsNs[rg0.desc] = nrg0 - - // Call handleTransportPacket for r0 (this should in turn, use the rule we added). - packet, err := routing.MakeDataPacket(fwdRtID[0], []byte("This is a test!")) - require.NoError(t, err) - - require.NoError(t, r0.handleTransportPacket(context.TODO(), packet)) - - // r1 should receive the packet handled by r0. - recvPacket, err := r1.tm.ReadPacket() - assert.NoError(t, err) - assert.Equal(t, packet.Size(), recvPacket.Size()) - assert.Equal(t, packet.Payload(), recvPacket.Payload()) - assert.Equal(t, routeID, recvPacket.RouteID()) -} - -func testIntermediaryForwardRule(t *testing.T, r0, r1 *router, tp1 *transport.ManagedTransport) { - defer clearRouterRules(r0, r1) - defer clearRouteGroups(r0, r1) - - // Add a FWD rule for r0. - fwdRtID, err := r0.ReserveKeys(1) - require.NoError(t, err) - - fwdRule := routing.IntermediaryForwardRule(ruleKeepAlive, fwdRtID[0], routing.RouteID(5), tp1.Entry.ID) - err = r0.rt.SaveRule(fwdRule) - require.NoError(t, err) - - // Call handleTransportPacket for r0 (this should in turn, use the rule we added). - packet, err := routing.MakeDataPacket(fwdRtID[0], []byte("This is a test!")) - require.NoError(t, err) - - require.NoError(t, r0.handleTransportPacket(context.TODO(), packet)) - - // r1 should receive the packet handled by r0. - recvPacket, err := r1.tm.ReadPacket() - assert.NoError(t, err) - assert.Equal(t, packet.Size(), recvPacket.Size()) - assert.Equal(t, packet.Payload(), recvPacket.Payload()) - assert.Equal(t, routing.RouteID(5), recvPacket.RouteID()) -} - -func testConsumeRule(t *testing.T, r0, r1 *router, tp1 *transport.ManagedTransport, pk1, pk2 cipher.PubKey) { - defer clearRouterRules(r0, r1) - defer clearRouteGroups(r0, r1) - - // one for consume rule and one for reverse forward rule - dstRtIDs, err := r1.ReserveKeys(2) - require.NoError(t, err) - - intFwdRtID, err := r0.ReserveKeys(1) - require.NoError(t, err) - - intFwdRule := routing.IntermediaryForwardRule(ruleKeepAlive, intFwdRtID[0], dstRtIDs[1], tp1.Entry.ID) - err = r0.rt.SaveRule(intFwdRule) - require.NoError(t, err) - - routeID := routing.RouteID(7) - fwdRule := routing.ForwardRule(ruleKeepAlive, dstRtIDs[0], routeID, tp1.Entry.ID, pk1, pk2, 0, 0) - cnsmRule := routing.ConsumeRule(ruleKeepAlive, dstRtIDs[1], pk2, pk1, 0, 0) - - err = r1.rt.SaveRule(fwdRule) - require.NoError(t, err) - - err = r1.rt.SaveRule(cnsmRule) - require.NoError(t, err) - - fwdRtDesc := fwdRule.RouteDescriptor() - - rules := routing.EdgeRules{ - Desc: fwdRtDesc.Invert(), - Forward: fwdRule, - Reverse: cnsmRule, - } - - rg1 := NewRouteGroup(DefaultRouteGroupConfig(), r1.rt, rules.Desc) - rg1.appendRules(rules.Forward, rules.Reverse, r1.tm.Transport(rules.Forward.NextTransportID())) - - nrg1 := &NoiseRouteGroup{rg: rg1} - r1.rgsNs[rg1.desc] = nrg1 - - packet, err := routing.MakeDataPacket(intFwdRtID[0], []byte("test intermediary forward")) - require.NoError(t, err) - - require.NoError(t, r0.handleTransportPacket(context.TODO(), packet)) - - recvPacket, err := r1.tm.ReadPacket() - assert.NoError(t, err) - assert.Equal(t, packet.Size(), recvPacket.Size()) - assert.Equal(t, packet.Payload(), recvPacket.Payload()) - assert.Equal(t, dstRtIDs[1], recvPacket.RouteID()) - - consumeMsg := []byte("test_consume") - packet, err = routing.MakeDataPacket(dstRtIDs[1], consumeMsg) - require.NoError(t, err) - - require.NoError(t, r1.handleTransportPacket(context.TODO(), packet)) - - nrg, ok := r1.noiseRouteGroup(fwdRtDesc.Invert()) - require.True(t, ok) - require.NotNil(t, nrg) - - data := <-nrg.rg.readCh - require.Equal(t, consumeMsg, data) -} - -func TestRouter_Rules(t *testing.T) { - pk, sk := cipher.GenerateKeyPair() - - env := snettest.NewEnv(t, []snettest.KeyPair{{PK: pk, SK: sk}}, []string{dmsg.Type}) - defer env.Teardown() - - rt := routing.NewTable() - - // We are generating two key pairs - one for the a `Router`, the other to send packets to `Router`. - keys := snettest.GenKeyPairs(2) - - // create test env - nEnv := snettest.NewEnv(t, keys, []string{dmsg.Type}) - defer nEnv.Teardown() - - rEnv := NewTestEnv(t, nEnv.Nets) - defer rEnv.Teardown() - - rIfc, err := New(nEnv.Nets[0], rEnv.GenRouterConfig(0)) - require.NoError(t, err) - - r, ok := rIfc.(*router) - require.True(t, ok) - - defer func() { - require.NoError(t, r.Close()) - }() - - r.rt = rt - - // TEST: Set and get expired and unexpired rule. - t.Run("GetRule", func(t *testing.T) { - testGetRule(t, r, rt) - }) - - // TEST: Ensure removing route descriptor works properly. - t.Run("RemoveRouteDescriptor", func(t *testing.T) { - testRemoveRouteDescriptor(t, r, rt) - }) -} - -func testRemoveRouteDescriptor(t *testing.T, r *router, rt routing.Table) { - clearRoutingTableRules(rt) - - localPK, _ := cipher.GenerateKeyPair() - remotePK, _ := cipher.GenerateKeyPair() - - id, err := r.rt.ReserveKeys(1) - require.NoError(t, err) - - rule := routing.ConsumeRule(10*time.Minute, id[0], localPK, remotePK, 2, 3) - err = r.rt.SaveRule(rule) - require.NoError(t, err) - - desc := routing.NewRouteDescriptor(localPK, remotePK, 3, 2) - r.RemoveRouteDescriptor(desc) - assert.Equal(t, 1, rt.Count()) - - desc = routing.NewRouteDescriptor(localPK, remotePK, 2, 3) - r.RemoveRouteDescriptor(desc) - assert.Equal(t, 0, rt.Count()) -} - -func testGetRule(t *testing.T, r *router, rt routing.Table) { - clearRoutingTableRules(rt) - - expiredID, err := r.rt.ReserveKeys(1) - require.NoError(t, err) - - expiredRule := routing.IntermediaryForwardRule(-10*time.Minute, expiredID[0], 3, uuid.New()) - err = r.rt.SaveRule(expiredRule) - require.NoError(t, err) - - id, err := r.rt.ReserveKeys(1) - require.NoError(t, err) - - rule := routing.IntermediaryForwardRule(10*time.Minute, id[0], 3, uuid.New()) - err = r.rt.SaveRule(rule) - require.NoError(t, err) - - defer r.rt.DelRules([]routing.RouteID{id[0], expiredID[0]}) - - // rule should already be expired at this point due to the execution time. - // However, we'll just a bit to be sure - time.Sleep(1 * time.Millisecond) - - _, err = r.GetRule(expiredID[0]) - require.Error(t, err) - - _, err = r.GetRule(123) - require.Error(t, err) - - gotRule, err := r.GetRule(id[0]) - require.NoError(t, err) - assert.Equal(t, rule, gotRule) -} - -func TestRouter_SetupIsTrusted(t *testing.T) { - keys := snettest.GenKeyPairs(2) - - nEnv := snettest.NewEnv(t, keys, []string{dmsg.Type}) - defer nEnv.Teardown() - - rEnv := NewTestEnv(t, nEnv.Nets) - defer rEnv.Teardown() - - routerConfig := rEnv.GenRouterConfig(0) - routerConfig.SetupNodes = append(routerConfig.SetupNodes, keys[0].PK) - - r0, err := New(nEnv.Nets[0], routerConfig) - require.NoError(t, err) - - assert.True(t, r0.SetupIsTrusted(keys[0].PK)) - assert.False(t, r0.SetupIsTrusted(keys[1].PK)) -} - -func clearRouteGroups(routers ...*router) { - for _, r := range routers { - r.rgsNs = make(map[routing.RouteDescriptor]*NoiseRouteGroup) - } -} - -func clearRouterRules(routers ...*router) { - for _, r := range routers { - rules := r.rt.AllRules() - for _, rule := range rules { - r.rt.DelRules([]routing.RouteID{rule.KeyRouteID()}) - } - } -} - -func clearRoutingTableRules(rt routing.Table) { - rules := rt.AllRules() - for _, rule := range rules { - rt.DelRules([]routing.RouteID{rule.KeyRouteID()}) - } -} - type TestEnv struct { TpD transport.DiscoveryClient @@ -779,42 +37,6 @@ type TestEnv struct { teardown func() } -func NewTestEnv(t *testing.T, nets []*snet.Network) *TestEnv { - tpD := transport.NewDiscoveryMock() - - mConfs := make([]*transport.ManagerConfig, len(nets)) - ms := make([]*transport.Manager, len(nets)) - - for i, n := range nets { - var err error - - mConfs[i] = &transport.ManagerConfig{ - PubKey: n.LocalPK(), - SecKey: n.LocalSK(), - DiscoveryClient: tpD, - LogStore: transport.InMemoryTransportLogStore(), - } - - ms[i], err = transport.NewManager(nil, n, mConfs[i]) - require.NoError(t, err) - - go ms[i].Serve(context.TODO()) - } - - teardown := func() { - for _, m := range ms { - assert.NoError(t, m.Close()) - } - } - - return &TestEnv{ - TpD: tpD, - TpMngrConfs: mConfs, - TpMngrs: ms, - teardown: teardown, - } -} - func (e *TestEnv) GenRouterConfig(i int) *Config { return &Config{ Logger: logging.MustGetLogger(fmt.Sprintf("router_%d", i)), diff --git a/pkg/router/routerclient/client.go b/pkg/router/routerclient/client.go index ac920464b..f1976b27f 100644 --- a/pkg/router/routerclient/client.go +++ b/pkg/router/routerclient/client.go @@ -12,7 +12,7 @@ import ( "github.com/skycoin/skywire/pkg/routing" "github.com/skycoin/skywire/pkg/skyenv" - "github.com/skycoin/skywire/pkg/snet" + "github.com/skycoin/skywire/pkg/transport/network" ) // RPCName is the RPC gateway object name. @@ -26,7 +26,7 @@ type Client struct { } // NewClient creates a new Client. -func NewClient(ctx context.Context, dialer snet.Dialer, rPK cipher.PubKey) (*Client, error) { +func NewClient(ctx context.Context, dialer network.Dialer, rPK cipher.PubKey) (*Client, error) { s, err := dialer.Dial(ctx, rPK, skyenv.DmsgAwaitSetupPort) if err != nil { return nil, fmt.Errorf("dial %v@%v: %w", rPK, skyenv.DmsgAwaitSetupPort, err) diff --git a/pkg/router/routerclient/dmsg_wrapper.go b/pkg/router/routerclient/dmsg_wrapper.go index 245f58158..1f03d240e 100644 --- a/pkg/router/routerclient/dmsg_wrapper.go +++ b/pkg/router/routerclient/dmsg_wrapper.go @@ -7,11 +7,11 @@ import ( "github.com/skycoin/dmsg" "github.com/skycoin/dmsg/cipher" - "github.com/skycoin/skywire/pkg/snet" + "github.com/skycoin/skywire/pkg/transport/network" ) // WrapDmsgClient wraps a dmsg client to implement snet.Dialer -func WrapDmsgClient(dmsgC *dmsg.Client) snet.Dialer { +func WrapDmsgClient(dmsgC *dmsg.Client) network.Dialer { return &dmsgClientDialer{Client: dmsgC} } @@ -24,5 +24,5 @@ func (w *dmsgClientDialer) Dial(ctx context.Context, remote cipher.PubKey, port } func (w *dmsgClientDialer) Type() string { - return dmsg.Type + return string(network.DMSG) } diff --git a/pkg/router/routerclient/map.go b/pkg/router/routerclient/map.go index 7745c5927..c7397d665 100644 --- a/pkg/router/routerclient/map.go +++ b/pkg/router/routerclient/map.go @@ -5,7 +5,7 @@ import ( "github.com/skycoin/dmsg/cipher" - "github.com/skycoin/skywire/pkg/snet" + "github.com/skycoin/skywire/pkg/transport/network" ) // Map is a map of router RPC clients associated with the router's visor PK. @@ -18,7 +18,7 @@ type dialResult struct { // MakeMap makes a Map of the router clients, where the key is the router's visor public key. // It creates these router clients by dialing to them concurrently. -func MakeMap(ctx context.Context, dialer snet.Dialer, pks []cipher.PubKey) (Map, error) { +func MakeMap(ctx context.Context, dialer network.Dialer, pks []cipher.PubKey) (Map, error) { ctx, cancel := context.WithCancel(ctx) defer cancel() diff --git a/pkg/router/routerclient/map_test.go b/pkg/router/routerclient/map_test.go index f8d7b1a04..d066f1781 100644 --- a/pkg/router/routerclient/map_test.go +++ b/pkg/router/routerclient/map_test.go @@ -10,7 +10,6 @@ import ( "time" "github.com/sirupsen/logrus" - "github.com/skycoin/dmsg" "github.com/skycoin/dmsg/cipher" "github.com/skycoin/skycoin/src/util/logging" "github.com/stretchr/testify/assert" @@ -20,6 +19,7 @@ import ( "github.com/skycoin/skywire/pkg/router" "github.com/skycoin/skywire/pkg/routing" + "github.com/skycoin/skywire/pkg/transport/network" ) func TestMakeMap(t *testing.T) { @@ -160,5 +160,5 @@ func (d *testDialer) Dial(_ context.Context, remote cipher.PubKey, _ uint16) (ne } func (testDialer) Type() string { - return dmsg.Type + return string(network.DMSG) } diff --git a/pkg/servicedisc/autoconnect.go b/pkg/servicedisc/autoconnect.go index 0012d9fc7..07aa5abc7 100644 --- a/pkg/servicedisc/autoconnect.go +++ b/pkg/servicedisc/autoconnect.go @@ -8,8 +8,8 @@ import ( "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/skywire/internal/netutil" - "github.com/skycoin/skywire/pkg/snet/directtp/tptypes" "github.com/skycoin/skywire/pkg/transport" + "github.com/skycoin/skywire/pkg/transport/network" ) const ( @@ -59,8 +59,8 @@ func (a *autoconnector) Run(ctx context.Context) error { absent := a.filterDuplicates(addresses, tps) for _, pk := range absent { a.log.WithField("pk", pk).Infoln("Adding transport to public visor") - logger := a.log.WithField("pk", pk).WithField("type", tptypes.STCPR) - if _, err := a.tm.SaveTransport(ctx, pk, tptypes.STCPR, transport.LabelAutomatic); err != nil { + logger := a.log.WithField("pk", pk).WithField("type", string(network.STCPR)) + if _, err := a.tm.SaveTransport(ctx, pk, network.STCPR, transport.LabelAutomatic); err != nil { logger.WithError(err).Warnln("Failed to add transport to public visor") continue } diff --git a/pkg/setup/config.go b/pkg/setup/config.go index c99127240..295a49b3e 100644 --- a/pkg/setup/config.go +++ b/pkg/setup/config.go @@ -5,7 +5,7 @@ import ( "github.com/skycoin/dmsg/cipher" - "github.com/skycoin/skywire/pkg/snet" + "github.com/skycoin/skywire/pkg/dmsgc" ) //go:generate readmegen -n Config -o ./README.md ./config.go @@ -18,9 +18,9 @@ const ( // Config defines configuration parameters for setup Node. type Config struct { - PK cipher.PubKey `json:"public_key"` - SK cipher.SecKey `json:"secret_key"` - Dmsg snet.DmsgConfig `json:"dmsg"` - TransportDiscovery string `json:"transport_discovery"` - LogLevel string `json:"log_level"` + PK cipher.PubKey `json:"public_key"` + SK cipher.SecKey `json:"secret_key"` + Dmsg dmsgc.DmsgConfig `json:"dmsg"` + TransportDiscovery string `json:"transport_discovery"` + LogLevel string `json:"log_level"` } diff --git a/pkg/setup/id_reserver.go b/pkg/setup/id_reserver.go index 03441c984..1df6a22ae 100644 --- a/pkg/setup/id_reserver.go +++ b/pkg/setup/id_reserver.go @@ -12,7 +12,7 @@ import ( "github.com/skycoin/skywire/pkg/router/routerclient" "github.com/skycoin/skywire/pkg/routing" - "github.com/skycoin/skywire/pkg/snet" + "github.com/skycoin/skywire/pkg/transport/network" ) // ErrNoKey is returned when key is not found. @@ -51,7 +51,7 @@ type idReserver struct { // NewIDReserver creates a new route ID reserver from a dialer and a slice of paths. // The exact number of route IDs to reserve from each router is determined from the slice of paths. -func NewIDReserver(ctx context.Context, dialer snet.Dialer, paths [][]routing.Hop) (IDReserver, error) { +func NewIDReserver(ctx context.Context, dialer network.Dialer, paths [][]routing.Hop) (IDReserver, error) { var total int // the total number of route IDs we reserve from the routers // Prepare 'rec': A map representing the number of expected rules per visor PK. diff --git a/pkg/setup/node.go b/pkg/setup/node.go index 7c6b5c419..2453a75d5 100644 --- a/pkg/setup/node.go +++ b/pkg/setup/node.go @@ -16,7 +16,7 @@ import ( "github.com/skycoin/skywire/pkg/routing" "github.com/skycoin/skywire/pkg/setup/setupmetrics" "github.com/skycoin/skywire/pkg/skyenv" - "github.com/skycoin/skywire/pkg/snet" + "github.com/skycoin/skywire/pkg/transport/network" ) var log = logging.MustGetLogger("setup_node") @@ -106,7 +106,7 @@ func (sn *Node) Serve(ctx context.Context, m setupmetrics.Metrics) error { // * Intermediary rules are broadcasted to the intermediary routers. // * Edge rules are broadcasted to the responding router. // * Edge rules is returned (to the initiating router). -func CreateRouteGroup(ctx context.Context, dialer snet.Dialer, biRt routing.BidirectionalRoute, metrics setupmetrics.Metrics) (resp routing.EdgeRules, err error) { +func CreateRouteGroup(ctx context.Context, dialer network.Dialer, biRt routing.BidirectionalRoute, metrics setupmetrics.Metrics) (resp routing.EdgeRules, err error) { log := logging.MustGetLogger(fmt.Sprintf("request:%s->%s", biRt.Desc.SrcPK(), biRt.Desc.DstPK())) log.Info("Processing request.") defer metrics.RecordRoute()(&err) @@ -158,7 +158,7 @@ func CreateRouteGroup(ctx context.Context, dialer snet.Dialer, biRt routing.Bidi // ReserveRouteIDs dials to all routers and reserves required route IDs from them. // The number of route IDs to be reserved per router, is extrapolated from the 'route' input. -func ReserveRouteIDs(ctx context.Context, log logrus.FieldLogger, dialer snet.Dialer, route routing.BidirectionalRoute) (idR IDReserver, err error) { +func ReserveRouteIDs(ctx context.Context, log logrus.FieldLogger, dialer network.Dialer, route routing.BidirectionalRoute) (idR IDReserver, err error) { log.Debug("Reserving route IDs...") defer func() { if err != nil { diff --git a/pkg/setup/rpc_gateway.go b/pkg/setup/rpc_gateway.go index dd8ebf750..28edfaf1f 100644 --- a/pkg/setup/rpc_gateway.go +++ b/pkg/setup/rpc_gateway.go @@ -10,7 +10,7 @@ import ( "github.com/skycoin/skywire/pkg/routing" "github.com/skycoin/skywire/pkg/setup/setupmetrics" - "github.com/skycoin/skywire/pkg/snet" + "github.com/skycoin/skywire/pkg/transport/network" ) // RPCGateway is a RPC interface for setup node. @@ -19,7 +19,7 @@ type RPCGateway struct { Ctx context.Context Conn net.Conn ReqPK cipher.PubKey - Dialer snet.Dialer + Dialer network.Dialer Timeout time.Duration } diff --git a/pkg/setup/setupclient/client.go b/pkg/setup/setupclient/client.go index 86e319b4a..a017d57b7 100644 --- a/pkg/setup/setupclient/client.go +++ b/pkg/setup/setupclient/client.go @@ -3,6 +3,7 @@ package setupclient import ( "context" "errors" + "net" "net/rpc" "github.com/skycoin/dmsg" @@ -11,7 +12,6 @@ import ( "github.com/skycoin/skywire/pkg/routing" "github.com/skycoin/skywire/pkg/skyenv" - "github.com/skycoin/skywire/pkg/snet" ) const rpcName = "RPCGateway" @@ -19,21 +19,19 @@ const rpcName = "RPCGateway" // Client is an RPC client for setup node. type Client struct { log *logging.Logger - n *snet.Network setupNodes []cipher.PubKey - conn *snet.Conn + conn net.Conn rpc *rpc.Client } // NewClient creates a new Client. -func NewClient(ctx context.Context, log *logging.Logger, n *snet.Network, setupNodes []cipher.PubKey) (*Client, error) { +func NewClient(ctx context.Context, log *logging.Logger, dmsgC *dmsg.Client, setupNodes []cipher.PubKey) (*Client, error) { client := &Client{ log: log, - n: n, setupNodes: setupNodes, } - conn, err := client.dial(ctx) + conn, err := client.dial(ctx, dmsgC) if err != nil { return nil, err } @@ -45,9 +43,10 @@ func NewClient(ctx context.Context, log *logging.Logger, n *snet.Network, setupN return client, nil } -func (c *Client) dial(ctx context.Context) (*snet.Conn, error) { +func (c *Client) dial(ctx context.Context, dmsgC *dmsg.Client) (net.Conn, error) { for _, sPK := range c.setupNodes { - conn, err := c.n.Dial(ctx, dmsg.Type, sPK, skyenv.DmsgSetupPort) + addr := dmsg.Addr{PK: sPK, Port: skyenv.DmsgSetupPort} + conn, err := dmsgC.Dial(ctx, addr) if err != nil { c.log.WithError(err).Warnf("failed to dial to setup node: setupPK(%s)", sPK) continue diff --git a/pkg/setup/setupclient/mock_route_group_dialer.go b/pkg/setup/setupclient/mock_route_group_dialer.go deleted file mode 100644 index 6d098af92..000000000 --- a/pkg/setup/setupclient/mock_route_group_dialer.go +++ /dev/null @@ -1,40 +0,0 @@ -// Code generated by mockery v1.0.0. DO NOT EDIT. - -package setupclient - -import ( - context "context" - - cipher "github.com/skycoin/dmsg/cipher" - logging "github.com/skycoin/skycoin/src/util/logging" - mock "github.com/stretchr/testify/mock" - - routing "github.com/skycoin/skywire/pkg/routing" - snet "github.com/skycoin/skywire/pkg/snet" -) - -// MockRouteGroupDialer is an autogenerated mock type for the RouteGroupDialer type -type MockRouteGroupDialer struct { - mock.Mock -} - -// Dial provides a mock function with given fields: ctx, log, n, setupNodes, req -func (_m *MockRouteGroupDialer) Dial(ctx context.Context, log *logging.Logger, n *snet.Network, setupNodes []cipher.PubKey, req routing.BidirectionalRoute) (routing.EdgeRules, error) { - ret := _m.Called(ctx, log, n, setupNodes, req) - - var r0 routing.EdgeRules - if rf, ok := ret.Get(0).(func(context.Context, *logging.Logger, *snet.Network, []cipher.PubKey, routing.BidirectionalRoute) routing.EdgeRules); ok { - r0 = rf(ctx, log, n, setupNodes, req) - } else { - r0 = ret.Get(0).(routing.EdgeRules) - } - - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *logging.Logger, *snet.Network, []cipher.PubKey, routing.BidirectionalRoute) error); ok { - r1 = rf(ctx, log, n, setupNodes, req) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} diff --git a/pkg/setup/setupclient/wrappers.go b/pkg/setup/setupclient/wrappers.go index 96bb4f5aa..030f91a81 100644 --- a/pkg/setup/setupclient/wrappers.go +++ b/pkg/setup/setupclient/wrappers.go @@ -4,11 +4,11 @@ import ( "context" "fmt" + "github.com/skycoin/dmsg" "github.com/skycoin/dmsg/cipher" "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/skywire/pkg/routing" - "github.com/skycoin/skywire/pkg/snet" ) //go:generate mockery -name RouteGroupDialer -case underscore -inpkg @@ -18,7 +18,7 @@ type RouteGroupDialer interface { Dial( ctx context.Context, log *logging.Logger, - n *snet.Network, + dmsgC *dmsg.Client, setupNodes []cipher.PubKey, req routing.BidirectionalRoute, ) (routing.EdgeRules, error) @@ -35,11 +35,11 @@ func NewSetupNodeDialer() RouteGroupDialer { func (d *setupNodeDialer) Dial( ctx context.Context, log *logging.Logger, - n *snet.Network, + dmsgC *dmsg.Client, setupNodes []cipher.PubKey, req routing.BidirectionalRoute, ) (routing.EdgeRules, error) { - client, err := NewClient(ctx, log, n, setupNodes) + client, err := NewClient(ctx, log, dmsgC, setupNodes) if err != nil { return routing.EdgeRules{}, err } diff --git a/pkg/setup/testing_test.go b/pkg/setup/testing_test.go index e9d263eb8..b07dd14a7 100644 --- a/pkg/setup/testing_test.go +++ b/pkg/setup/testing_test.go @@ -9,7 +9,6 @@ import ( "testing" "time" - "github.com/skycoin/dmsg" "github.com/skycoin/dmsg/cipher" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -17,11 +16,11 @@ import ( "github.com/skycoin/skywire/pkg/router/routerclient" "github.com/skycoin/skywire/pkg/routing" - "github.com/skycoin/skywire/pkg/snet" + "github.com/skycoin/skywire/pkg/transport/network" ) // creates a mock dialer -func newMockDialer(t *testing.T, gateways map[cipher.PubKey]interface{}) snet.Dialer { +func newMockDialer(t *testing.T, gateways map[cipher.PubKey]interface{}) network.Dialer { newRPCConn := func(gw interface{}) net.Conn { connC, connS := net.Pipe() t.Cleanup(func() { @@ -38,7 +37,7 @@ func newMockDialer(t *testing.T, gateways map[cipher.PubKey]interface{}) snet.Di if gateways == nil { conn := newRPCConn(new(mockGatewayForDialer)) - dialer := new(snet.MockDialer) + dialer := new(network.MockDialer) dialer.On("Dial", mock.Anything, mock.Anything, mock.Anything).Return(conn, nil) return dialer } @@ -52,7 +51,7 @@ func newMockDialer(t *testing.T, gateways map[cipher.PubKey]interface{}) snet.Di type mockDialer map[cipher.PubKey]net.Conn -func (d mockDialer) Type() string { return dmsg.Type } +func (d mockDialer) Type() string { return string(network.DMSG) } func (d mockDialer) Dial(_ context.Context, remote cipher.PubKey, _ uint16) (net.Conn, error) { conn, ok := d[remote] diff --git a/pkg/snet/conn.go b/pkg/snet/conn.go deleted file mode 100644 index fdaf75cb7..000000000 --- a/pkg/snet/conn.go +++ /dev/null @@ -1,38 +0,0 @@ -package snet - -import ( - "net" - - "github.com/skycoin/dmsg/cipher" -) - -// Conn represent a connection between nodes in Skywire. -type Conn struct { - net.Conn - lPK cipher.PubKey - rPK cipher.PubKey - lPort uint16 - rPort uint16 - network string -} - -func makeConn(conn net.Conn, network string) *Conn { - lPK, lPort := disassembleAddr(conn.LocalAddr()) - rPK, rPort := disassembleAddr(conn.RemoteAddr()) - return &Conn{Conn: conn, lPK: lPK, rPK: rPK, lPort: lPort, rPort: rPort, network: network} -} - -// LocalPK returns local public key of connection. -func (c Conn) LocalPK() cipher.PubKey { return c.lPK } - -// RemotePK returns remote public key of connection. -func (c Conn) RemotePK() cipher.PubKey { return c.rPK } - -// LocalPort returns local port of connection. -func (c Conn) LocalPort() uint16 { return c.lPort } - -// RemotePort returns remote port of connection. -func (c Conn) RemotePort() uint16 { return c.rPort } - -// Network returns network of connection. -func (c Conn) Network() string { return c.network } diff --git a/pkg/snet/dialer.go b/pkg/snet/dialer.go deleted file mode 100644 index 728edece3..000000000 --- a/pkg/snet/dialer.go +++ /dev/null @@ -1,16 +0,0 @@ -package snet - -import ( - "context" - "net" - - "github.com/skycoin/dmsg/cipher" -) - -//go:generate mockery -name Dialer -case underscore -inpkg - -// Dialer is an entity that can be dialed and asked for its type. -type Dialer interface { - Dial(ctx context.Context, remote cipher.PubKey, port uint16) (net.Conn, error) - Type() string -} diff --git a/pkg/snet/directtp/client.go b/pkg/snet/directtp/client.go deleted file mode 100644 index be99aa86a..000000000 --- a/pkg/snet/directtp/client.go +++ /dev/null @@ -1,537 +0,0 @@ -package directtp - -import ( - "context" - "errors" - "fmt" - "io" - "net" - "strings" - "sync" - "time" - - "github.com/AudriusButkevicius/pfilter" - "github.com/skycoin/dmsg" - "github.com/skycoin/dmsg/cipher" - "github.com/skycoin/skycoin/src/util/logging" - "github.com/xtaci/kcp-go" - - "github.com/skycoin/skywire/internal/packetfilter" - "github.com/skycoin/skywire/pkg/snet/arclient" - "github.com/skycoin/skywire/pkg/snet/directtp/pktable" - "github.com/skycoin/skywire/pkg/snet/directtp/porter" - "github.com/skycoin/skywire/pkg/snet/directtp/tpconn" - "github.com/skycoin/skywire/pkg/snet/directtp/tphandshake" - "github.com/skycoin/skywire/pkg/snet/directtp/tplistener" - "github.com/skycoin/skywire/pkg/snet/directtp/tptypes" - "github.com/skycoin/skywire/pkg/util/netutil" -) - -const ( - // holePunchMessage is sent in a dummy UDP packet that is sent by both parties to establish UDP hole punching. - holePunchMessage = "holepunch" - dialTimeout = 30 * time.Second - // dialConnPriority and visorsConnPriority are used to set an order how connection filters apply. - dialConnPriority = 2 - visorsConnPriority = 3 -) - -var ( - // ErrUnknownTransportType is returned when transport type is unknown. - ErrUnknownTransportType = errors.New("unknown transport type") - - // ErrTimeout indicates a timeout. - ErrTimeout = errors.New("timeout") - - // ErrAlreadyListening is returned when transport is already listening. - ErrAlreadyListening = errors.New("already listening") - - // ErrNotListening is returned when transport is not listening. - ErrNotListening = errors.New("not listening") - - // ErrPortOccupied is returned when port is occupied. - ErrPortOccupied = errors.New("port is already occupied") -) - -// Client is the central control for incoming and outgoing 'Conn's. -type Client interface { - Dial(ctx context.Context, rPK cipher.PubKey, rPort uint16) (*tpconn.Conn, error) - Listen(lPort uint16) (*tplistener.Listener, error) - LocalAddr() (net.Addr, error) - Serve() error - Close() error - Type() string -} - -// Config configures Client. -type Config struct { - Type string - PK cipher.PubKey - SK cipher.SecKey - LocalAddr string - Table pktable.PKTable - AddressResolver arclient.APIClient - BeforeDialCallback BeforeDialCallback -} - -// BeforeDialCallback is triggered before client dials. -// If a non-nil error is returned, the dial is instantly terminated. -type BeforeDialCallback func(network, addr string) (err error) - -type client struct { - conf Config - mu sync.Mutex - done chan struct{} - once sync.Once - log *logging.Logger - porter *porter.Porter - listener net.Listener - listening chan struct{} - listeners map[uint16]*tplistener.Listener // key: lPort - sudphPacketFilter *pfilter.PacketFilter - sudphListener net.PacketConn - sudphVisorsConn net.PacketConn - beforeDialCallback BeforeDialCallback -} - -// NewClient creates a net Client. -func NewClient(conf Config, masterLogger *logging.MasterLogger) Client { - return &client{ - conf: conf, - log: masterLogger.PackageLogger(conf.Type), - porter: porter.New(porter.MinEphemeral), - listeners: make(map[uint16]*tplistener.Listener), - done: make(chan struct{}), - listening: make(chan struct{}), - beforeDialCallback: conf.BeforeDialCallback, - } -} - -// Serve serves the listening portion of the client. -func (c *client) Serve() error { - switch c.conf.Type { - case tptypes.STCP, tptypes.STCPR: - if c.listener != nil { - return ErrAlreadyListening - } - case tptypes.SUDPH: - if c.sudphListener != nil { - return ErrAlreadyListening - } - } - - go func() { - l, err := c.listen(c.conf.LocalAddr) - if err != nil { - c.log.Errorf("Failed to listen on %q: %v", c.conf.LocalAddr, err) - return - } - - c.listener = l - close(c.listening) - - if c.conf.Type == tptypes.STCPR { - localAddr := c.listener.Addr().String() - _, port, err := net.SplitHostPort(localAddr) - if err != nil { - c.log.Errorf("Failed to extract port from addr %v: %v", err) - return - } - hasPublic, err := netutil.HasPublicIP() - if err != nil { - c.log.Errorf("Failed to check for public IP: %v", err) - } - if !hasPublic { - c.log.Infof("Not binding STCPR: no public IP address found") - return - } - if err := c.conf.AddressResolver.BindSTCPR(context.Background(), port); err != nil { - c.log.Errorf("Failed to bind STCPR: %v", err) - return - } - } - - c.log.Infof("listening on addr: %v", c.listener.Addr()) - - for { - if err := c.acceptConn(); err != nil { - if strings.Contains(err.Error(), io.EOF.Error()) { - continue // likely it's a dummy connection from service discovery - } - - c.log.Warnf("failed to accept incoming connection: %v", err) - - if !tphandshake.IsHandshakeError(err) { - c.log.Warnf("stopped serving") - return - } - } - } - }() - - return nil -} - -func (c *client) LocalAddr() (net.Addr, error) { - <-c.listening - - switch c.conf.Type { - case tptypes.STCP, tptypes.STCPR: - if c.listener == nil { - return nil, ErrNotListening - } - - return c.listener.Addr(), nil - case tptypes.SUDPH: - if c.sudphListener == nil { - return nil, ErrNotListening - } - - return c.listener.Addr(), nil - } - - return nil, ErrUnknownTransportType -} - -func (c *client) acceptConn() error { - if c.isClosed() { - return io.ErrClosedPipe - } - - conn, err := c.listener.Accept() - if err != nil { - return err - } - - remoteAddr := conn.RemoteAddr() - - c.log.Infof("Accepted connection from %v", remoteAddr) - - var lis *tplistener.Listener - - hs := tphandshake.ResponderHandshake(func(f2 tphandshake.Frame2) error { - c.mu.Lock() - defer c.mu.Unlock() - - var ok bool - if lis, ok = c.listeners[f2.DstAddr.Port]; !ok { - return errors.New("not listening on given port") - } - - return nil - }) - - connConfig := tpconn.Config{ - Log: c.log, - Conn: conn, - LocalPK: c.conf.PK, - LocalSK: c.conf.SK, - Deadline: time.Now().Add(tphandshake.Timeout), - Handshake: hs, - FreePort: nil, - Encrypt: true, - Initiator: false, - } - - wrappedConn, err := tpconn.NewConn(connConfig) - if err != nil { - return err - } - - return lis.Introduce(wrappedConn) -} - -// Dial dials a new Conn to specified remote public key and port. -func (c *client) Dial(ctx context.Context, rPK cipher.PubKey, rPort uint16) (*tpconn.Conn, error) { - if c.isClosed() { - return nil, io.ErrClosedPipe - } - - c.log.Infof("Dialing PK %v", rPK) - - var visorConn net.Conn - - switch c.conf.Type { - case tptypes.STCP: - addr, ok := c.conf.Table.Addr(rPK) - if !ok { - return nil, fmt.Errorf("pk table: entry of %s does not exist", rPK) - } - - conn, err := c.dial(addr) - if err != nil { - return nil, err - } - - visorConn = conn - - case tptypes.STCPR, tptypes.SUDPH: - visorData, err := c.conf.AddressResolver.Resolve(ctx, c.Type(), rPK) - if err != nil { - return nil, fmt.Errorf("resolve PK: %w", err) - } - - c.log.Infof("Resolved PK %v to visor data %v", rPK, visorData) - - conn, err := c.dialVisor(ctx, visorData) - if err != nil { - return nil, err - } - - visorConn = conn - - default: - return nil, ErrUnknownTransportType - } - - c.log.Infof("Dialed %v:%v@%v", rPK, rPort, visorConn.RemoteAddr()) - - lPort, freePort, err := c.porter.ReserveEphemeral(ctx) - if err != nil { - return nil, err - } - - hs := tphandshake.InitiatorHandshake(c.conf.SK, dmsg.Addr{PK: c.conf.PK, Port: lPort}, dmsg.Addr{PK: rPK, Port: rPort}) - - connConfig := tpconn.Config{ - Log: c.log, - Conn: visorConn, - LocalPK: c.conf.PK, - LocalSK: c.conf.SK, - Deadline: time.Now().Add(tphandshake.Timeout), - Handshake: hs, - FreePort: freePort, - Encrypt: true, - Initiator: true, - } - - return tpconn.NewConn(connConfig) -} - -func (c *client) dial(addr string) (net.Conn, error) { - switch c.conf.Type { - case tptypes.STCP, tptypes.STCPR: - return net.Dial("tcp", addr) - - case tptypes.SUDPH: - return c.dialUDPWithTimeout(addr) - - default: - return nil, ErrUnknownTransportType - } -} - -func (c *client) dialContext(ctx context.Context, addr string) (net.Conn, error) { - dialer := net.Dialer{} - switch c.conf.Type { - case tptypes.STCP, tptypes.STCPR: - return dialer.DialContext(ctx, "tcp", addr) - - case tptypes.SUDPH: - return c.dialUDPWithTimeout(addr) - - default: - return nil, ErrUnknownTransportType - } -} - -func (c *client) listen(addr string) (net.Listener, error) { - switch c.conf.Type { - case tptypes.STCP, tptypes.STCPR: - return net.Listen("tcp", addr) - - case tptypes.SUDPH: - packetListener, err := net.ListenPacket("udp", "") - if err != nil { - return nil, err - } - - c.sudphListener = packetListener - - c.sudphPacketFilter = pfilter.NewPacketFilter(packetListener) - c.sudphVisorsConn = c.sudphPacketFilter.NewConn(visorsConnPriority, nil) - - c.sudphPacketFilter.Start() - - addrCh, err := c.conf.AddressResolver.BindSUDPH(c.sudphPacketFilter) - if err != nil { - return nil, err - } - - go func() { - for addr := range addrCh { - udpAddr, err := net.ResolveUDPAddr("udp", addr.Addr) - if err != nil { - c.log.WithError(err).Errorf("Failed to resolve UDP address %q", addr) - continue - } - - c.log.Infof("Sending hole punch packet to %v", addr) - - if _, err := c.sudphVisorsConn.WriteTo([]byte(holePunchMessage), udpAddr); err != nil { - c.log.WithError(err).Errorf("Failed to send hole punch packet to %v", udpAddr) - continue - } - - c.log.Infof("Sent hole punch packet to %v", addr) - } - }() - - return kcp.ServeConn(nil, 0, 0, c.sudphVisorsConn) - - default: - return nil, ErrUnknownTransportType - } -} - -func (c *client) dialUDP(remoteAddr string) (net.Conn, error) { - rAddr, err := net.ResolveUDPAddr("udp", remoteAddr) - if err != nil { - return nil, fmt.Errorf("net.ResolveUDPAddr (remote): %w", err) - } - - dialConn := c.sudphPacketFilter.NewConn(dialConnPriority, packetfilter.NewKCPConversationFilter()) - - if _, err := dialConn.WriteTo([]byte(holePunchMessage), rAddr); err != nil { - return nil, fmt.Errorf("dialConn.WriteTo: %w", err) - } - - kcpConn, err := kcp.NewConn(remoteAddr, nil, 0, 0, dialConn) - if err != nil { - return nil, err - } - - return kcpConn, nil -} - -func (c *client) dialUDPWithTimeout(addr string) (net.Conn, error) { - timer := time.NewTimer(dialTimeout) - defer timer.Stop() - - c.log.Infof("Dialing %v", addr) - - for { - select { - case <-timer.C: - return nil, ErrTimeout - default: - conn, err := c.dialUDP(addr) - if err == nil { - c.log.Infof("Dialed %v", addr) - return conn, nil - } - - c.log.WithError(err). - Warnf("Failed to dial %v, trying again: %v", addr, err) - } - } -} - -func (c *client) dialVisor(ctx context.Context, visorData arclient.VisorData) (net.Conn, error) { - if visorData.IsLocal { - for _, host := range visorData.Addresses { - addr := net.JoinHostPort(host, visorData.Port) - - if c.beforeDialCallback != nil { - if err := c.beforeDialCallback(c.conf.Type, addr); err != nil { - return nil, err - } - } - - conn, err := c.dialContext(ctx, addr) - if err == nil { - return conn, nil - } - } - } - - addr := visorData.RemoteAddr - if _, _, err := net.SplitHostPort(addr); err != nil { - addr = net.JoinHostPort(addr, visorData.Port) - } - - if c.beforeDialCallback != nil { - if err := c.beforeDialCallback(c.conf.Type, addr); err != nil { - return nil, err - } - } - - return c.dialContext(ctx, addr) -} - -// Listen creates a new listener for sudp. -// The created Listener cannot actually accept remote connections unless Serve is called beforehand. -func (c *client) Listen(lPort uint16) (*tplistener.Listener, error) { - if c.isClosed() { - return nil, io.ErrClosedPipe - } - - ok, freePort := c.porter.Reserve(lPort) - if !ok { - return nil, ErrPortOccupied - } - - c.mu.Lock() - defer c.mu.Unlock() - - lAddr := dmsg.Addr{PK: c.conf.PK, Port: lPort} - lis := tplistener.NewListener(lAddr, freePort) - c.listeners[lPort] = lis - - return lis, nil -} - -// Close closes the Client. -func (c *client) Close() error { - if c == nil { - return nil - } - - c.once.Do(func() { - close(c.done) - - c.mu.Lock() - defer c.mu.Unlock() - - if c.listener != nil { - if err := c.listener.Close(); err != nil { - c.log.WithError(err).Warnf("Failed to close listener") - } - } - - for _, lis := range c.listeners { - if err := lis.Close(); err != nil { - c.log.WithError(err).Warnf("Failed to close listener") - } - } - - switch c.Type() { - case tptypes.STCPR, tptypes.SUDPH: - if err := c.conf.AddressResolver.Close(); err != nil { - c.log.WithError(err).Warnf("Failed to close address-resolver") - } - } - - if c.sudphVisorsConn != nil { - if err := c.sudphVisorsConn.Close(); err != nil { - c.log.WithError(err).Warnf("Failed to close connection to visors") - } - } - }) - - return nil -} - -func (c *client) isClosed() bool { - select { - case <-c.done: - return true - default: - return false - } -} - -// Type returns the stream type. -func (c *client) Type() string { - return c.conf.Type -} diff --git a/pkg/snet/directtp/noisewrapper/noisewrapper.go b/pkg/snet/directtp/noisewrapper/noisewrapper.go deleted file mode 100644 index a95ad88d5..000000000 --- a/pkg/snet/directtp/noisewrapper/noisewrapper.go +++ /dev/null @@ -1,27 +0,0 @@ -package noisewrapper - -import ( - "fmt" - "net" - "time" - - "github.com/skycoin/dmsg/noise" -) - -// HSTimeout sets handshake timeout. -const HSTimeout = 5 * time.Second - -// WrapConn wraps `conn` with noise. -func WrapConn(config noise.Config, conn net.Conn) (net.Conn, error) { - ns, err := noise.New(noise.HandshakeKK, config) - if err != nil { - return nil, fmt.Errorf("failed to prepare stream noise object: %w", err) - } - - wrappedConn, err := noise.WrapConn(conn, ns, HSTimeout) - if err != nil { - return nil, fmt.Errorf("error performing noise handshake: %w", err) - } - - return wrappedConn, nil -} diff --git a/pkg/snet/directtp/tpconn/conn.go b/pkg/snet/directtp/tpconn/conn.go deleted file mode 100644 index 6aae796dd..000000000 --- a/pkg/snet/directtp/tpconn/conn.go +++ /dev/null @@ -1,103 +0,0 @@ -package tpconn - -import ( - "fmt" - "net" - "time" - - "github.com/skycoin/dmsg" - "github.com/skycoin/dmsg/cipher" - "github.com/skycoin/dmsg/noise" - "github.com/skycoin/skycoin/src/util/logging" - - "github.com/skycoin/skywire/pkg/snet/directtp/noisewrapper" - "github.com/skycoin/skywire/pkg/snet/directtp/tphandshake" -) - -// Conn wraps an underlying net.Conn and modifies various methods to integrate better with the 'network' package. -type Conn struct { - net.Conn - lAddr dmsg.Addr - rAddr dmsg.Addr - freePort func() -} - -// Config describes a config for Conn. -type Config struct { - Log *logging.Logger - Conn net.Conn - LocalPK cipher.PubKey - LocalSK cipher.SecKey - Deadline time.Time - Handshake tphandshake.Handshake - FreePort func() - Encrypt bool - Initiator bool -} - -// NewConn creates a new Conn. -func NewConn(c Config) (*Conn, error) { - if c.Log != nil { - c.Log.Infof("Performing handshake with %v", c.Conn.RemoteAddr()) - } - - lAddr, rAddr, err := c.Handshake(c.Conn, c.Deadline) - if err != nil { - if err := c.Conn.Close(); err != nil && c.Log != nil { - c.Log.WithError(err).Warnf("Failed to close connection") - } - - if c.FreePort != nil { - c.FreePort() - } - - return nil, err - } - - if c.Log != nil { - c.Log.Infof("Sent handshake to %v, local addr %v, remote addr %v", c.Conn.RemoteAddr(), lAddr, rAddr) - } - - if c.Encrypt { - config := noise.Config{ - LocalPK: c.LocalPK, - LocalSK: c.LocalSK, - RemotePK: rAddr.PK, - Initiator: c.Initiator, - } - - wrappedConn, err := noisewrapper.WrapConn(config, c.Conn) - if err != nil { - return nil, fmt.Errorf("encrypt connection to %v@%v: %w", rAddr, c.Conn.RemoteAddr(), err) - } - - c.Conn = wrappedConn - - if c.Log != nil { - c.Log.Infof("Connection with %v@%v is encrypted", rAddr, c.Conn.RemoteAddr()) - } - } else if c.Log != nil { - c.Log.Infof("Connection with %v@%v is NOT encrypted", rAddr, c.Conn.RemoteAddr()) - } - - return &Conn{Conn: c.Conn, lAddr: lAddr, rAddr: rAddr, freePort: c.FreePort}, nil -} - -// LocalAddr implements net.Conn -func (c *Conn) LocalAddr() net.Addr { - return c.lAddr -} - -// RemoteAddr implements net.Conn -func (c *Conn) RemoteAddr() net.Addr { - return c.rAddr -} - -// Close implements net.Conn -func (c *Conn) Close() error { - if c.freePort != nil { - c.freePort() - } - - return c.Conn.Close() -} diff --git a/pkg/snet/directtp/tplistener/listener.go b/pkg/snet/directtp/tplistener/listener.go deleted file mode 100644 index ccb2ccb80..000000000 --- a/pkg/snet/directtp/tplistener/listener.go +++ /dev/null @@ -1,79 +0,0 @@ -package tplistener - -import ( - "io" - "net" - "sync" - - "github.com/skycoin/dmsg" - - "github.com/skycoin/skywire/pkg/snet/directtp/tpconn" -) - -// Listener implements net.Listener -type Listener struct { - lAddr dmsg.Addr - mx sync.Mutex - once sync.Once - freePort func() - accept chan *tpconn.Conn - done chan struct{} -} - -// NewListener returns a new Listener. -func NewListener(lAddr dmsg.Addr, freePort func()) *Listener { - return &Listener{ - lAddr: lAddr, - freePort: freePort, - accept: make(chan *tpconn.Conn), - done: make(chan struct{}), - } -} - -// Introduce is used by Client to introduce Conn to Listener. -func (l *Listener) Introduce(conn *tpconn.Conn) error { - select { - case <-l.done: - return io.ErrClosedPipe - default: - l.mx.Lock() - defer l.mx.Unlock() - - select { - case l.accept <- conn: - return nil - case <-l.done: - return io.ErrClosedPipe - } - } -} - -// Accept implements net.Listener -func (l *Listener) Accept() (net.Conn, error) { - c, ok := <-l.accept - if !ok { - return nil, io.ErrClosedPipe - } - - return c, nil -} - -// Close implements net.Listener -func (l *Listener) Close() error { - l.once.Do(func() { - close(l.done) - - l.mx.Lock() - close(l.accept) - l.mx.Unlock() - - l.freePort() - }) - - return nil -} - -// Addr implements net.Listener -func (l *Listener) Addr() net.Addr { - return l.lAddr -} diff --git a/pkg/snet/directtp/tptypes/tptypes.go b/pkg/snet/directtp/tptypes/tptypes.go deleted file mode 100644 index bf0f385ef..000000000 --- a/pkg/snet/directtp/tptypes/tptypes.go +++ /dev/null @@ -1,11 +0,0 @@ -package tptypes - -const ( - // STCP is a type of a transport that works via TCP and resolves addresses using PK table. - STCP = "stcp" - // STCPR is a type of a transport that works via TCP and resolves addresses using address-resolver service. - STCPR = "stcpr" - // SUDPH is a type of a transport that works via UDP, resolves addresses using address-resolver service, - // and uses UDP hole punching. - SUDPH = "sudph" -) diff --git a/pkg/snet/listener.go b/pkg/snet/listener.go deleted file mode 100644 index d0809f327..000000000 --- a/pkg/snet/listener.go +++ /dev/null @@ -1,39 +0,0 @@ -package snet - -import ( - "net" - - "github.com/skycoin/dmsg/cipher" -) - -// Listener represents a listener. -type Listener struct { - net.Listener - lPK cipher.PubKey - lPort uint16 - network string -} - -func makeListener(l net.Listener, network string) *Listener { - lPK, lPort := disassembleAddr(l.Addr()) - return &Listener{Listener: l, lPK: lPK, lPort: lPort, network: network} -} - -// LocalPK returns a local public key of listener. -func (l Listener) LocalPK() cipher.PubKey { return l.lPK } - -// LocalPort returns a local port of listener. -func (l Listener) LocalPort() uint16 { return l.lPort } - -// Network returns a network of listener. -func (l Listener) Network() string { return l.network } - -// AcceptConn accepts a connection from listener. -func (l Listener) AcceptConn() (*Conn, error) { - conn, err := l.Listener.Accept() - if err != nil { - return nil, err - } - - return makeConn(conn, l.network), nil -} diff --git a/pkg/snet/network.go b/pkg/snet/network.go deleted file mode 100644 index 24162bc88..000000000 --- a/pkg/snet/network.go +++ /dev/null @@ -1,425 +0,0 @@ -package snet - -import ( - "context" - "errors" - "fmt" - "net" - "strings" - "sync" - "time" - - "github.com/skycoin/dmsg" - "github.com/skycoin/dmsg/cipher" - "github.com/skycoin/dmsg/disc" - "github.com/skycoin/skycoin/src/util/logging" - - "github.com/skycoin/skywire/pkg/app/appevent" - "github.com/skycoin/skywire/pkg/snet/arclient" - "github.com/skycoin/skywire/pkg/snet/directtp" - "github.com/skycoin/skywire/pkg/snet/directtp/pktable" - "github.com/skycoin/skywire/pkg/snet/directtp/tptypes" -) - -var log = logging.MustGetLogger("snet") - -var ( - // ErrUnknownNetwork occurs on attempt to dial an unknown network type. - ErrUnknownNetwork = errors.New("unknown network type") - knownNetworks = map[string]struct{}{ - dmsg.Type: {}, - tptypes.STCP: {}, - tptypes.STCPR: {}, - tptypes.SUDPH: {}, - } -) - -// IsKnownNetwork tells whether network type `netType` is known. -func IsKnownNetwork(netType string) bool { - _, ok := knownNetworks[netType] - return ok -} - -// NetworkConfig is a common interface for network configs. -type NetworkConfig interface { - Type() string -} - -// DmsgConfig defines config for Dmsg network. -type DmsgConfig struct { - Discovery string `json:"discovery"` - SessionsCount int `json:"sessions_count"` -} - -// Type returns DmsgType. -func (c *DmsgConfig) Type() string { - return dmsg.Type -} - -// STCPConfig defines config for STCP network. -type STCPConfig struct { - PKTable map[cipher.PubKey]string `json:"pk_table"` - LocalAddr string `json:"local_address"` -} - -// Type returns STCP type. -func (c *STCPConfig) Type() string { - return tptypes.STCP -} - -// Config represents a network configuration. -type Config struct { - PubKey cipher.PubKey - SecKey cipher.SecKey - ARClient arclient.APIClient - NetworkConfigs NetworkConfigs -} - -// NetworkConfigs represents all network configs. -type NetworkConfigs struct { - Dmsg *DmsgConfig // The dmsg service will not be started if nil. - STCP *STCPConfig // The stcp service will not be started if nil. -} - -// NetworkClients represents all network clients. -type NetworkClients struct { - DmsgC *dmsg.Client - Direct map[string]directtp.Client -} - -// Network represents a network between nodes in Skywire. -type Network struct { - conf Config - netsMu sync.RWMutex - nets map[string]struct{} // networks to be used with transports - clients NetworkClients - - onNewNetworkTypeMu sync.Mutex - onNewNetworkType func(netType string) -} - -// New creates a network from a config. -func New(conf Config, eb *appevent.Broadcaster, masterLogger *logging.MasterLogger) (*Network, error) { - clients := NetworkClients{ - Direct: make(map[string]directtp.Client), - } - - if conf.NetworkConfigs.Dmsg != nil { - dmsgConf := &dmsg.Config{ - MinSessions: conf.NetworkConfigs.Dmsg.SessionsCount, - Callbacks: &dmsg.ClientCallbacks{ - OnSessionDial: func(network, addr string) error { - data := appevent.TCPDialData{RemoteNet: network, RemoteAddr: addr} - event := appevent.NewEvent(appevent.TCPDial, data) - _ = eb.Broadcast(context.Background(), event) //nolint:errcheck - // @evanlinjin: An error is not returned here as this will cancel the session dial. - return nil - }, - OnSessionDisconnect: func(network, addr string, _ error) { - data := appevent.TCPCloseData{RemoteNet: network, RemoteAddr: addr} - event := appevent.NewEvent(appevent.TCPClose, data) - _ = eb.Broadcast(context.Background(), event) //nolint:errcheck - }, - }, - } - clients.DmsgC = dmsg.NewClient(conf.PubKey, conf.SecKey, disc.NewHTTP(conf.NetworkConfigs.Dmsg.Discovery), dmsgConf) - clients.DmsgC.SetLogger(masterLogger.PackageLogger("snet.dmsgC")) - } - - if conf.NetworkConfigs.STCP != nil { - conf := directtp.Config{ - Type: tptypes.STCP, - PK: conf.PubKey, - SK: conf.SecKey, - Table: pktable.NewTable(conf.NetworkConfigs.STCP.PKTable), - LocalAddr: conf.NetworkConfigs.STCP.LocalAddr, - BeforeDialCallback: func(network, addr string) error { - data := appevent.TCPDialData{RemoteNet: network, RemoteAddr: addr} - event := appevent.NewEvent(appevent.TCPDial, data) - _ = eb.Broadcast(context.Background(), event) //nolint:errcheck - return nil - }, - } - clients.Direct[tptypes.STCP] = directtp.NewClient(conf, masterLogger) - } - - if conf.ARClient != nil { - stcprConf := directtp.Config{ - Type: tptypes.STCPR, - PK: conf.PubKey, - SK: conf.SecKey, - AddressResolver: conf.ARClient, - BeforeDialCallback: func(network, addr string) error { - data := appevent.TCPDialData{RemoteNet: network, RemoteAddr: addr} - event := appevent.NewEvent(appevent.TCPDial, data) - _ = eb.Broadcast(context.Background(), event) //nolint:errcheck - return nil - }, - } - - clients.Direct[tptypes.STCPR] = directtp.NewClient(stcprConf, masterLogger) - - sudphConf := directtp.Config{ - Type: tptypes.SUDPH, - PK: conf.PubKey, - SK: conf.SecKey, - AddressResolver: conf.ARClient, - } - - clients.Direct[tptypes.SUDPH] = directtp.NewClient(sudphConf, masterLogger) - } - - return NewRaw(conf, clients), nil -} - -// NewRaw creates a network from a config and a dmsg client. -func NewRaw(conf Config, clients NetworkClients) *Network { - n := &Network{ - conf: conf, - nets: make(map[string]struct{}), - clients: clients, - } - - if clients.DmsgC != nil { - n.addNetworkType(dmsg.Type) - } - - for k, v := range clients.Direct { - if v != nil { - n.addNetworkType(k) - } - } - - return n -} - -// Conf gets network configuration. -func (n *Network) Conf() Config { - return n.conf -} - -// Init initiates server connections. -func (n *Network) Init() error { - if n.clients.DmsgC != nil { - time.Sleep(200 * time.Millisecond) - go n.clients.DmsgC.Serve(context.Background()) - time.Sleep(200 * time.Millisecond) - } - - if n.conf.NetworkConfigs.STCP != nil { - if client, ok := n.clients.Direct[tptypes.STCP]; ok && client != nil && n.conf.NetworkConfigs.STCP.LocalAddr != "" { - if err := client.Serve(); err != nil { - return fmt.Errorf("failed to initiate 'stcp': %w", err) - } - } else { - log.Infof("No config found for stcp") - } - } - - if n.conf.ARClient != nil { - if client, ok := n.clients.Direct[tptypes.STCPR]; ok && client != nil { - if err := client.Serve(); err != nil { - return fmt.Errorf("failed to initiate 'stcpr': %w", err) - } - } else { - log.Infof("No config found for stcpr") - } - - if client, ok := n.clients.Direct[tptypes.SUDPH]; ok && client != nil { - if err := client.Serve(); err != nil { - return fmt.Errorf("failed to initiate 'sudph': %w", err) - } - } else { - log.Infof("No config found for sudph") - } - } - - return nil -} - -// OnNewNetworkType sets callback to be called when new network type is ready. -func (n *Network) OnNewNetworkType(callback func(netType string)) { - n.onNewNetworkTypeMu.Lock() - n.onNewNetworkType = callback - n.onNewNetworkTypeMu.Unlock() -} - -// IsNetworkReady checks whether network of type `netType` is ready. -func (n *Network) IsNetworkReady(netType string) bool { - n.netsMu.Lock() - _, ok := n.nets[netType] - n.netsMu.Unlock() - return ok -} - -// Close closes underlying connections. -func (n *Network) Close() error { - n.netsMu.Lock() - defer n.netsMu.Unlock() - - wg := new(sync.WaitGroup) - - var dmsgErr error - if n.clients.DmsgC != nil { - wg.Add(1) - go func() { - dmsgErr = n.clients.DmsgC.Close() - wg.Done() - }() - } - - directErrors := make(chan error) - - for _, directClient := range n.clients.Direct { - if directClient == nil { - continue - } - err := directClient.Close() - if err != nil { - directErrors <- err - } - } - close(directErrors) - wg.Wait() - if dmsgErr != nil { - return dmsgErr - } - - for err := range directErrors { - if err != nil { - return err - } - } - - return nil -} - -// LocalPK returns local public key. -func (n *Network) LocalPK() cipher.PubKey { return n.conf.PubKey } - -// LocalSK returns local secure key. -func (n *Network) LocalSK() cipher.SecKey { return n.conf.SecKey } - -// TransportNetworks returns network types that are used for transports. -func (n *Network) TransportNetworks() []string { - n.netsMu.RLock() - networks := make([]string, 0, len(n.nets)) - for network := range n.nets { - networks = append(networks, network) - } - n.netsMu.RUnlock() - - return networks -} - -// Dmsg returns underlying dmsg client. -func (n *Network) Dmsg() *dmsg.Client { return n.clients.DmsgC } - -// STcp returns the underlying stcp.Client. -func (n *Network) STcp() (directtp.Client, bool) { - return n.getClient(tptypes.STCP) -} - -// STcpr returns the underlying stcpr.Client. -func (n *Network) STcpr() (directtp.Client, bool) { - return n.getClient(tptypes.STCPR) -} - -// SUdpH returns the underlying sudph.Client. -func (n *Network) SUdpH() (directtp.Client, bool) { - return n.getClient(tptypes.SUDPH) -} - -func (n *Network) getClient(tpType string) (directtp.Client, bool) { - c, ok := n.clients.Direct[tpType] - return c, ok -} - -// Dial dials a visor by its public key and returns a connection. -func (n *Network) Dial(ctx context.Context, network string, pk cipher.PubKey, port uint16) (*Conn, error) { - switch network { - case dmsg.Type: - addr := dmsg.Addr{ - PK: pk, - Port: port, - } - - conn, err := n.clients.DmsgC.Dial(ctx, addr) - if err != nil { - return nil, fmt.Errorf("dmsg client dial %v: %w", addr, err) - } - - return makeConn(conn, network), nil - default: - client, ok := n.clients.Direct[network] - if !ok { - return nil, ErrUnknownNetwork - } - - conn, err := client.Dial(ctx, pk, port) - if err != nil { - return nil, fmt.Errorf("dial: %w", err) - } - - log.Infof("Dialed %v, conn local address %q, remote address %q", network, conn.LocalAddr(), conn.RemoteAddr()) - return makeConn(conn, network), nil - } -} - -// Listen listens on the specified port. -func (n *Network) Listen(network string, port uint16) (*Listener, error) { - switch network { - case dmsg.Type: - lis, err := n.clients.DmsgC.Listen(port) - if err != nil { - return nil, err - } - - return makeListener(lis, network), nil - default: - client, ok := n.clients.Direct[network] - if !ok { - return nil, ErrUnknownNetwork - } - - lis, err := client.Listen(port) - if err != nil { - return nil, fmt.Errorf("listen: %w", err) - } - - return makeListener(lis, network), nil - } -} - -func (n *Network) addNetworkType(netType string) { - n.netsMu.Lock() - defer n.netsMu.Unlock() - - if _, ok := n.nets[netType]; !ok { - n.nets[netType] = struct{}{} - n.onNewNetworkTypeMu.Lock() - if n.onNewNetworkType != nil { - n.onNewNetworkType(netType) - } - n.onNewNetworkTypeMu.Unlock() - } -} - -func disassembleAddr(addr net.Addr) (pk cipher.PubKey, port uint16) { - strs := strings.Split(addr.String(), ":") - if len(strs) != 2 { - panic(fmt.Errorf("network.disassembleAddr: %v %s", "invalid addr", addr.String())) - } - - if err := pk.Set(strs[0]); err != nil { - panic(fmt.Errorf("network.disassembleAddr: %v %s", err, addr.String())) - } - - if strs[1] != "~" { - if _, err := fmt.Sscanf(strs[1], "%d", &port); err != nil { - panic(fmt.Errorf("network.disassembleAddr: %w", err)) - } - } - - return -} diff --git a/pkg/snet/network_test.go b/pkg/snet/network_test.go deleted file mode 100644 index 3dce54c36..000000000 --- a/pkg/snet/network_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package snet - -import ( - "testing" - - "github.com/skycoin/dmsg" - "github.com/skycoin/dmsg/cipher" - "github.com/stretchr/testify/require" -) - -func TestDisassembleAddr(t *testing.T) { - pk, _ := cipher.GenerateKeyPair() - port := uint16(2) - addr := dmsg.Addr{ - PK: pk, Port: port, - } - - gotPK, gotPort := disassembleAddr(addr) - require.Equal(t, pk, gotPK) - require.Equal(t, port, gotPort) -} diff --git a/pkg/snet/snettest/env.go b/pkg/snet/snettest/env.go deleted file mode 100644 index e883a5bdb..000000000 --- a/pkg/snet/snettest/env.go +++ /dev/null @@ -1,197 +0,0 @@ -package snettest - -import ( - "context" - "strconv" - "testing" - - "github.com/skycoin/dmsg" - "github.com/skycoin/dmsg/cipher" - "github.com/skycoin/dmsg/disc" - "github.com/skycoin/skycoin/src/util/logging" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/net/nettest" - - "github.com/skycoin/skywire/pkg/snet" - "github.com/skycoin/skywire/pkg/snet/arclient" - "github.com/skycoin/skywire/pkg/snet/directtp" - "github.com/skycoin/skywire/pkg/snet/directtp/pktable" - "github.com/skycoin/skywire/pkg/snet/directtp/tptypes" -) - -// KeyPair holds a public/private key pair. -type KeyPair struct { - PK cipher.PubKey - SK cipher.SecKey -} - -// GenKeyPairs generates 'n' number of key pairs. -func GenKeyPairs(n int) []KeyPair { - pairs := make([]KeyPair, n) - for i := range pairs { - pk, sk, err := cipher.GenerateDeterministicKeyPair([]byte{byte(i)}) - if err != nil { - panic(err) - } - - pairs[i] = KeyPair{PK: pk, SK: sk} - } - - return pairs -} - -// Env contains a network test environment. -type Env struct { - DmsgD disc.APIClient - DmsgS *dmsg.Server - Keys []KeyPair - Nets []*snet.Network - teardown func() -} - -// NewEnv creates a `network.Network` test environment. -// `nPairs` is the public/private key pairs of all the `network.Network`s to be created. -func NewEnv(t *testing.T, keys []KeyPair, networks []string) *Env { - // Prepare `dmsg`. - dmsgD := disc.NewMock(0) - dmsgS, dmsgSErr := createDmsgSrv(t, dmsgD) - - const baseSTCPPort = 7033 - - tableEntries := make(map[cipher.PubKey]string) - for i, pair := range keys { - tableEntries[pair.PK] = "127.0.0.1:" + strconv.Itoa(baseSTCPPort+i) - } - - table := pktable.NewTable(tableEntries) - - var hasDmsg, hasStcp, hasStcpr, hasSudph bool - - for _, network := range networks { - switch network { - case dmsg.Type: - hasDmsg = true - case tptypes.STCP: - hasStcp = true - case tptypes.STCPR: - hasStcpr = true - case tptypes.SUDPH: - hasSudph = true - } - } - - // Prepare `snets`. - ns := make([]*snet.Network, len(keys)) - - const stcpBasePort = 7033 - - for i, pairs := range keys { - networkConfigs := snet.NetworkConfigs{ - Dmsg: &snet.DmsgConfig{ - SessionsCount: 1, - }, - STCP: &snet.STCPConfig{ - LocalAddr: "127.0.0.1:" + strconv.Itoa(stcpBasePort+i), - }, - } - - clients := snet.NetworkClients{ - Direct: make(map[string]directtp.Client), - } - - if hasDmsg { - clients.DmsgC = dmsg.NewClient(pairs.PK, pairs.SK, dmsgD, nil) - go clients.DmsgC.Serve(context.Background()) - } - - addressResolver := new(arclient.MockAPIClient) - - if hasStcp { - conf := directtp.Config{ - Type: tptypes.STCP, - PK: pairs.PK, - SK: pairs.SK, - Table: table, - LocalAddr: networkConfigs.STCP.LocalAddr, - } - - clients.Direct[tptypes.STCP] = directtp.NewClient(conf, logging.NewMasterLogger()) - } - - if hasStcpr { - conf := directtp.Config{ - Type: tptypes.STCPR, - PK: pairs.PK, - SK: pairs.SK, - AddressResolver: addressResolver, - } - - clients.Direct[tptypes.STCPR] = directtp.NewClient(conf, logging.NewMasterLogger()) - } - - if hasSudph { - conf := directtp.Config{ - Type: tptypes.SUDPH, - PK: pairs.PK, - SK: pairs.SK, - AddressResolver: addressResolver, - } - - clients.Direct[tptypes.SUDPH] = directtp.NewClient(conf, logging.NewMasterLogger()) - } - - snetConfig := snet.Config{ - PubKey: pairs.PK, - SecKey: pairs.SK, - NetworkConfigs: networkConfigs, - } - - n := snet.NewRaw(snetConfig, clients) - require.NoError(t, n.Init()) - ns[i] = n - } - - // Prepare teardown closure. - teardown := func() { - for _, n := range ns { - assert.NoError(t, n.Close()) - } - assert.NoError(t, dmsgS.Close()) - for err := range dmsgSErr { - assert.NoError(t, err) - } - } - - return &Env{ - DmsgD: dmsgD, - DmsgS: dmsgS, - Keys: keys, - Nets: ns, - teardown: teardown, - } -} - -// Teardown shutdowns the Env. -func (e *Env) Teardown() { e.teardown() } - -func createDmsgSrv(t *testing.T, dc disc.APIClient) (srv *dmsg.Server, srvErr <-chan error) { - pk, sk, err := cipher.GenerateDeterministicKeyPair([]byte("s")) - require.NoError(t, err) - - l, err := nettest.NewLocalListener("tcp") - require.NoError(t, err) - - srv = dmsg.NewServer(pk, sk, dc, &dmsg.ServerConfig{MaxSessions: 100}, nil) - - errCh := make(chan error, 1) - - go func() { - errCh <- srv.Serve(l, "") - close(errCh) - }() - - <-srv.Ready() - - return srv, errCh -} diff --git a/pkg/transport/entry.go b/pkg/transport/entry.go index f474e6fb1..24f0a9ad2 100644 --- a/pkg/transport/entry.go +++ b/pkg/transport/entry.go @@ -7,6 +7,8 @@ import ( "github.com/google/uuid" "github.com/skycoin/dmsg/cipher" + + "github.com/skycoin/skywire/pkg/transport/network" ) var ( @@ -38,7 +40,7 @@ type Entry struct { Edges [2]cipher.PubKey `json:"edges"` // Type represents the transport type. - Type string `json:"type"` + Type network.Type `json:"type"` // Public determines whether the transport is to be exposed to other nodes or not. // Public transports are to be registered in the Transport Discovery. @@ -48,10 +50,10 @@ type Entry struct { } // MakeEntry creates a new transport entry -func MakeEntry(initiator, target cipher.PubKey, tpType string, public bool, label Label) Entry { +func MakeEntry(initiator, target cipher.PubKey, netType network.Type, public bool, label Label) Entry { entry := Entry{ - ID: MakeTransportID(initiator, target, tpType), - Type: tpType, + ID: MakeTransportID(initiator, target, netType), + Type: netType, Public: public, Label: label, } diff --git a/pkg/transport/handshake.go b/pkg/transport/handshake.go index 899447160..16e3fdf8e 100644 --- a/pkg/transport/handshake.go +++ b/pkg/transport/handshake.go @@ -11,7 +11,7 @@ import ( "github.com/skycoin/dmsg/cipher" "github.com/skycoin/dmsg/httputil" - "github.com/skycoin/skywire/pkg/snet" + "github.com/skycoin/skywire/pkg/transport/network" ) type hsResponse byte @@ -23,7 +23,7 @@ const ( responseInvalidEntry ) -func makeEntryFromTpConn(conn *snet.Conn, isInitiator bool) Entry { +func makeEntryFromTpConn(conn network.Conn, isInitiator bool) Entry { initiator, target := conn.LocalPK(), conn.RemotePK() if !isInitiator { initiator, target = target, initiator @@ -76,10 +76,10 @@ func receiveAndVerifyEntry(r io.Reader, expected *Entry, remotePK cipher.PubKey) // SettlementHS represents a settlement handshake. // This is the handshake responsible for registering a transport to transport discovery. -type SettlementHS func(ctx context.Context, dc DiscoveryClient, conn *snet.Conn, sk cipher.SecKey) error +type SettlementHS func(ctx context.Context, dc DiscoveryClient, conn network.Conn, sk cipher.SecKey) error // Do performs the settlement handshake. -func (hs SettlementHS) Do(ctx context.Context, dc DiscoveryClient, conn *snet.Conn, sk cipher.SecKey) (err error) { +func (hs SettlementHS) Do(ctx context.Context, dc DiscoveryClient, conn network.Conn, sk cipher.SecKey) (err error) { done := make(chan struct{}) go func() { err = hs(ctx, dc, conn, sk) @@ -98,7 +98,7 @@ func (hs SettlementHS) Do(ctx context.Context, dc DiscoveryClient, conn *snet.Co // The handshake logic only REGISTERS the transport, and does not update the status of the transport. func MakeSettlementHS(init bool) SettlementHS { // initiating logic. - initHS := func(ctx context.Context, dc DiscoveryClient, conn *snet.Conn, sk cipher.SecKey) (err error) { + initHS := func(ctx context.Context, dc DiscoveryClient, conn network.Conn, sk cipher.SecKey) (err error) { entry := makeEntryFromTpConn(conn, true) // TODO(evanlinjin): Probably not needed as this is called in mTp already. Need to double check. @@ -138,7 +138,7 @@ func MakeSettlementHS(init bool) SettlementHS { } // responding logic. - respHS := func(ctx context.Context, dc DiscoveryClient, conn *snet.Conn, sk cipher.SecKey) error { + respHS := func(ctx context.Context, dc DiscoveryClient, conn network.Conn, sk cipher.SecKey) error { entry := makeEntryFromTpConn(conn, false) // receive, verify and sign entry. diff --git a/pkg/transport/handshake_test.go b/pkg/transport/handshake_test.go deleted file mode 100644 index d0225a768..000000000 --- a/pkg/transport/handshake_test.go +++ /dev/null @@ -1,62 +0,0 @@ -package transport_test - -import ( - "context" - "testing" - "time" - - "github.com/skycoin/dmsg" - "github.com/stretchr/testify/require" - - "github.com/skycoin/skywire/pkg/skyenv" - "github.com/skycoin/skywire/pkg/snet/snettest" - "github.com/skycoin/skywire/pkg/transport" -) - -func TestSettlementHS(t *testing.T) { - tpDisc := transport.NewDiscoveryMock() - - keys := snettest.GenKeyPairs(2) - nEnv := snettest.NewEnv(t, keys, []string{dmsg.Type}) - defer nEnv.Teardown() - - // TEST: Perform a handshake between two snet.Network instances. - t.Run("Do", func(t *testing.T) { - lis1, err := nEnv.Nets[1].Listen(dmsg.Type, skyenv.DmsgTransportPort) - require.NoError(t, err) - - errCh1 := make(chan error, 1) - go func() { - defer close(errCh1) - conn1, err := lis1.AcceptConn() - if err != nil { - errCh1 <- err - return - } - errCh1 <- transport.MakeSettlementHS(false).Do(context.TODO(), tpDisc, conn1, keys[1].SK) - }() - - const entryTimeout = 5 * time.Second - start := time.Now() - - // Wait until entry is set. - // TODO: Implement more elegant solution. - for { - if time.Since(start) > entryTimeout { - t.Fatal("Entry in Dmsg Discovery is not set within expected time") - } - - if _, err := nEnv.DmsgD.Entry(context.TODO(), keys[1].PK); err == nil { - break - } - } - - conn0, err := nEnv.Nets[0].Dial(context.TODO(), dmsg.Type, keys[1].PK, skyenv.DmsgTransportPort) - require.NoError(t, err) - require.NoError(t, transport.MakeSettlementHS(true).Do(context.TODO(), tpDisc, conn0, keys[0].SK)) - - require.NoError(t, <-errCh1) - }) -} - -// TODO(evanlinjin): This will need further testing. diff --git a/pkg/transport/managed_transport.go b/pkg/transport/managed_transport.go index 2cbcf4874..c1ed2e22f 100644 --- a/pkg/transport/managed_transport.go +++ b/pkg/transport/managed_transport.go @@ -16,9 +16,10 @@ import ( "github.com/skycoin/dmsg/netutil" "github.com/skycoin/skycoin/src/util/logging" + "github.com/skycoin/skywire/pkg/app/appevent" "github.com/skycoin/skywire/pkg/routing" "github.com/skycoin/skywire/pkg/skyenv" - "github.com/skycoin/skywire/pkg/snet" + "github.com/skycoin/skywire/pkg/transport/network" ) const logWriteInterval = time.Second * 3 @@ -45,12 +46,11 @@ const ( // ManagedTransportConfig is a configuration for managed transport. type ManagedTransportConfig struct { - Net *snet.Network + client network.Client + ebc *appevent.Broadcaster DC DiscoveryClient LS LogStore RemotePK cipher.PubKey - NetName string - AfterClosed TPCloseCallback TransportLabel Label } @@ -62,13 +62,13 @@ type ManagedTransport struct { log *logging.Logger rPK cipher.PubKey - netName string Entry Entry LogEntry *LogEntry logUpdates uint32 - dc DiscoveryClient - ls LogStore + dc DiscoveryClient + ls LogStore + ebc *appevent.Broadcaster isUp bool // records last successful status update to discovery isUpErr error // records whether the last status update was successful or not @@ -77,8 +77,8 @@ type ManagedTransport struct { redialCancel context.CancelFunc // for canceling redialling logic redialMx sync.Mutex - n *snet.Network - conn *snet.Conn + client network.Client + conn network.Conn connCh chan struct{} connMx sync.Mutex @@ -87,29 +87,25 @@ type ManagedTransport struct { wg sync.WaitGroup remoteAddr string - - afterClosedMu sync.RWMutex - afterClosed TPCloseCallback } // NewManagedTransport creates a new ManagedTransport. func NewManagedTransport(conf ManagedTransportConfig, isInitiator bool) *ManagedTransport { - initiator, target := conf.Net.LocalPK(), conf.RemotePK + initiator, target := conf.client.PK(), conf.RemotePK if !isInitiator { initiator, target = target, initiator } mt := &ManagedTransport{ - log: logging.MustGetLogger(fmt.Sprintf("tp:%s", conf.RemotePK.String()[:6])), - rPK: conf.RemotePK, - netName: conf.NetName, - n: conf.Net, - dc: conf.DC, - ls: conf.LS, - Entry: MakeEntry(initiator, target, conf.NetName, true, conf.TransportLabel), - LogEntry: new(LogEntry), - connCh: make(chan struct{}, 1), - done: make(chan struct{}), - afterClosed: conf.AfterClosed, + log: logging.MustGetLogger(fmt.Sprintf("tp:%s", conf.RemotePK.String()[:6])), + rPK: conf.RemotePK, + dc: conf.DC, + ls: conf.LS, + client: conf.client, + Entry: MakeEntry(initiator, target, conf.client.Type(), true, conf.TransportLabel), + LogEntry: new(LogEntry), + connCh: make(chan struct{}, 1), + done: make(chan struct{}), + ebc: conf.ebc, } mt.wg.Add(2) return mt @@ -222,12 +218,6 @@ func (mt *ManagedTransport) Serve(readCh chan<- routing.Packet) { } } -func (mt *ManagedTransport) onAfterClosed(f TPCloseCallback) { - mt.afterClosedMu.Lock() - mt.afterClosed = f - mt.afterClosedMu.Unlock() -} - func (mt *ManagedTransport) isServing() bool { select { case <-mt.done: @@ -253,13 +243,8 @@ func (mt *ManagedTransport) Close() (err error) { func (mt *ManagedTransport) close() { mt.disconnect() - - mt.afterClosedMu.RLock() - afterClosed := mt.afterClosed - mt.afterClosedMu.RUnlock() - - if afterClosed != nil { - afterClosed(mt.netName, mt.remoteAddr) + if mt.Type() == network.STCPR && mt.remoteAddr != "" { + mt.ebc.SendTPClose(context.Background(), string(network.STCPR), mt.remoteAddr) } } @@ -271,11 +256,11 @@ func (mt *ManagedTransport) disconnect() { } // Accept accepts a new underlying connection. -func (mt *ManagedTransport) Accept(ctx context.Context, conn *snet.Conn) error { +func (mt *ManagedTransport) Accept(ctx context.Context, conn network.Conn) error { mt.connMx.Lock() defer mt.connMx.Unlock() - if conn.Network() != mt.netName { + if conn.Network() != mt.Type() { return ErrWrongNetwork } @@ -292,7 +277,7 @@ func (mt *ManagedTransport) Accept(ctx context.Context, conn *snet.Conn) error { defer cancel() mt.log.Debug("Performing settlement handshake...") - if err := MakeSettlementHS(false).Do(ctx, mt.dc, conn, mt.n.LocalSK()); err != nil { + if err := MakeSettlementHS(false).Do(ctx, mt.dc, conn, mt.client.SK()); err != nil { return fmt.Errorf("settlement handshake failed: %w", err) } @@ -316,7 +301,7 @@ func (mt *ManagedTransport) Dial(ctx context.Context) error { } func (mt *ManagedTransport) dial(ctx context.Context) error { - tp, err := mt.n.Dial(ctx, mt.netName, mt.rPK, skyenv.DmsgTransportPort) + conn, err := mt.client.Dial(ctx, mt.rPK, skyenv.DmsgTransportPort) if err != nil { return fmt.Errorf("snet.Dial: %w", err) } @@ -324,11 +309,11 @@ func (mt *ManagedTransport) dial(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, time.Second*20) defer cancel() - if err := MakeSettlementHS(true).Do(ctx, mt.dc, tp, mt.n.LocalSK()); err != nil { + if err := MakeSettlementHS(true).Do(ctx, mt.dc, conn, mt.client.SK()); err != nil { return fmt.Errorf("settlement handshake failed: %w", err) } - if err := mt.setConn(tp); err != nil { + if err := mt.setConn(conn); err != nil { return fmt.Errorf("setConn: %w", err) } @@ -389,18 +374,18 @@ func (mt *ManagedTransport) redialLoop(ctx context.Context) error { func (mt *ManagedTransport) isLeastSignificantEdge() bool { sorted := SortEdges(mt.Entry.Edges[0], mt.Entry.Edges[1]) - return sorted[0] == mt.n.LocalPK() + return sorted[0] == mt.client.PK() } func (mt *ManagedTransport) isInitiator() bool { - return mt.Entry.EdgeIndex(mt.n.LocalPK()) == 0 + return mt.Entry.EdgeIndex(mt.client.PK()) == 0 } /* <<< UNDERLYING CONNECTION >>> */ -func (mt *ManagedTransport) getConn() *snet.Conn { +func (mt *ManagedTransport) getConn() network.Conn { if !mt.isServing() { return nil } @@ -413,7 +398,7 @@ func (mt *ManagedTransport) getConn() *snet.Conn { // setConn sets 'mt.conn' (the underlying connection). // If 'mt.conn' is already occupied, close the newly introduced connection. -func (mt *ManagedTransport) setConn(newConn *snet.Conn) error { +func (mt *ManagedTransport) setConn(newConn network.Conn) error { if mt.conn != nil { if mt.isLeastSignificantEdge() { @@ -573,7 +558,7 @@ func (mt *ManagedTransport) WritePacket(ctx context.Context, packet routing.Pack func (mt *ManagedTransport) readPacket() (packet routing.Packet, err error) { log := mt.log.WithField("func", "readPacket") - var conn *snet.Conn + var conn network.Conn for { if conn = mt.getConn(); conn != nil { break @@ -639,4 +624,4 @@ func (mt *ManagedTransport) logMod() bool { func (mt *ManagedTransport) Remote() cipher.PubKey { return mt.rPK } // Type returns the transport type. -func (mt *ManagedTransport) Type() string { return mt.netName } +func (mt *ManagedTransport) Type() network.Type { return mt.client.Type() } diff --git a/pkg/transport/manager.go b/pkg/transport/manager.go index e4135fbb2..e5150a8df 100644 --- a/pkg/transport/manager.go +++ b/pkg/transport/manager.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "io" - "strings" "sync" "time" @@ -13,17 +12,13 @@ import ( "github.com/skycoin/dmsg/cipher" "github.com/skycoin/skycoin/src/util/logging" + "github.com/skycoin/skywire/pkg/app/appevent" "github.com/skycoin/skywire/pkg/routing" "github.com/skycoin/skywire/pkg/skyenv" - "github.com/skycoin/skywire/pkg/snet" - "github.com/skycoin/skywire/pkg/snet/arclient" - "github.com/skycoin/skywire/pkg/snet/directtp/tptypes" - "github.com/skycoin/skywire/pkg/snet/snettest" + "github.com/skycoin/skywire/pkg/transport/network" + "github.com/skycoin/skywire/pkg/transport/network/addrresolver" ) -// TPCloseCallback triggers after a session is closed. -type TPCloseCallback func(network, addr string) - // ManagerConfig configures a Manager. type ManagerConfig struct { PubKey cipher.PubKey @@ -34,57 +29,44 @@ type ManagerConfig struct { // Manager manages Transports. type Manager struct { - Logger *logging.Logger - Conf *ManagerConfig - tps map[uuid.UUID]*ManagedTransport - n *snet.Network - - listenersMu sync.Mutex - listeners []*snet.Listener - servingNetsMu sync.Mutex - servingNets map[string]struct{} - readCh chan routing.Packet - mx sync.RWMutex - wgMu sync.Mutex - wg sync.WaitGroup - serveOnce sync.Once // ensure we only serve once. - closeOnce sync.Once // ensure we only close once. - done chan struct{} - - afterTPClosed TPCloseCallback + Logger *logging.Logger + Conf *ManagerConfig + tps map[uuid.UUID]*ManagedTransport + arClient addrresolver.APIClient + ebc *appevent.Broadcaster + + readCh chan routing.Packet + mx sync.RWMutex + wgMu sync.Mutex + wg sync.WaitGroup + serveOnce sync.Once // ensure we only serve once. + closeOnce sync.Once // ensure we only close once. + done chan struct{} + + factory network.ClientFactory + netClients map[network.Type]network.Client } // NewManager creates a Manager with the provided configuration and transport factories. // 'factories' should be ordered by preference. -func NewManager(log *logging.Logger, n *snet.Network, config *ManagerConfig) (*Manager, error) { +func NewManager(log *logging.Logger, arClient addrresolver.APIClient, ebc *appevent.Broadcaster, config *ManagerConfig, factory network.ClientFactory) (*Manager, error) { if log == nil { log = logging.MustGetLogger("tp_manager") } tm := &Manager{ - Logger: log, - Conf: config, - servingNets: make(map[string]struct{}), - tps: make(map[uuid.UUID]*ManagedTransport), - n: n, - readCh: make(chan routing.Packet, 20), - done: make(chan struct{}), + Logger: log, + Conf: config, + tps: make(map[uuid.UUID]*ManagedTransport), + readCh: make(chan routing.Packet, 20), + done: make(chan struct{}), + netClients: make(map[network.Type]network.Client), + arClient: arClient, + factory: factory, + ebc: ebc, } return tm, nil } -// OnAfterTPClosed sets callback which will fire after transport gets closed. -func (tm *Manager) OnAfterTPClosed(f TPCloseCallback) { - tm.mx.Lock() - defer tm.mx.Unlock() - - tm.afterTPClosed = f - - // set callback for all already known tps - for _, tp := range tm.tps { - tp.onAfterClosed(f) - } -} - // Serve runs listening loop across all registered factories. func (tm *Manager) Serve(ctx context.Context) { tm.serveOnce.Do(func() { @@ -92,89 +74,75 @@ func (tm *Manager) Serve(ctx context.Context) { }) } -func (tm *Manager) serveNetwork(ctx context.Context, netType string) { - if tm.isClosing() { - return - } - - // this func may be called by either initiating routing or a callback, - // so we should check whether this type of network is already being served - tm.servingNetsMu.Lock() - if _, ok := tm.servingNets[netType]; ok { - tm.servingNetsMu.Unlock() - return - } - tm.servingNets[netType] = struct{}{} - tm.servingNetsMu.Unlock() +func (tm *Manager) serve(ctx context.Context) { + tm.initClients() + tm.runClients(ctx) + tm.initTransports(ctx) + tm.Logger.Info("transport manager is serving.") +} - lis, err := tm.n.Listen(netType, skyenv.DmsgTransportPort) - if err != nil { - tm.Logger.WithError(err).Fatalf("failed to listen on network '%s' of port '%d'", - netType, skyenv.DmsgTransportPort) - return +func (tm *Manager) initClients() { + acceptedNetworks := []network.Type{network.STCP, network.STCPR, network.SUDPH, network.DMSG} + for _, netType := range acceptedNetworks { + client, err := tm.factory.MakeClient(netType) + if err != nil { + tm.Logger.Warnf("Cannot initialize %s transport client", netType) + continue + } + tm.netClients[netType] = client } - tm.Logger.Infof("listening on network: %s", netType) - tm.listenersMu.Lock() - tm.listeners = append(tm.listeners, lis) - tm.listenersMu.Unlock() +} +func (tm *Manager) runClients(ctx context.Context) { if tm.isClosing() { return } + for _, client := range tm.netClients { + tm.Logger.Infof("Serving %s network", client.Type()) + err := client.Start() + if err != nil { + tm.Logger.WithError(err).Errorf("Failed to listen on %s network", client.Type()) + continue + } + lis, err := client.Listen(skyenv.DmsgTransportPort) + if err != nil { + tm.Logger.WithError(err).Fatalf("failed to listen on network '%s' of port '%d'", + client.Type(), skyenv.DmsgTransportPort) + return + } + tm.Logger.Infof("listening on network: %s", client.Type()) + tm.wgMu.Lock() + tm.wg.Add(1) + tm.wgMu.Unlock() + go tm.acceptTransports(ctx, lis) + } +} - tm.wgMu.Lock() - tm.wg.Add(1) - tm.wgMu.Unlock() - - go func() { - defer tm.wg.Done() - for { - select { - case <-ctx.Done(): - return - case <-tm.done: - return - default: - if err := tm.acceptTransport(ctx, lis); err != nil { - tm.Logger.Warnf("Failed to accept connection: %v", err) - if strings.Contains(err.Error(), "closed") { - return - } +func (tm *Manager) acceptTransports(ctx context.Context, lis network.Listener) { + defer tm.wg.Done() + for { + select { + case <-ctx.Done(): + case <-tm.done: + return + default: + if err := tm.acceptTransport(ctx, lis); err != nil { + tm.Logger.Warnf("Failed to accept connection: %v", err) + if errors.Is(err, io.ErrClosedPipe) { + return } } } - }() -} - -func (tm *Manager) serve(ctx context.Context) { - // TODO: to get rid of this callback, we need to have method on future network interface like: `Ready() <-chan struct{}` - // some networks may not be ready yet, so we're setting a callback first - tm.n.OnNewNetworkType(func(netType string) { - tm.serveNetwork(ctx, netType) - }) - - // here we may start serving all the networks which are ready at this point - for _, netType := range tm.n.TransportNetworks() { - tm.serveNetwork(ctx, netType) } +} - tm.initTransports(ctx) - tm.Logger.Info("transport manager is serving.") - - // closing logic - <-tm.done - - tm.Logger.Info("transport manager is closing.") - defer tm.Logger.Info("transport manager closed.") - - // Close all listeners. - tm.listenersMu.Lock() - for i, lis := range tm.listeners { - if err := lis.Close(); err != nil { - tm.Logger.Warnf("listener %d of network '%s' closed with error: %v", i, lis.Network(), err) - } +// Networks returns all the network types contained within the TransportManager. +func (tm *Manager) Networks() []string { + var nets []string + for netType := range tm.netClients { + nets = append(nets, string(netType)) } - tm.listenersMu.Unlock() + return nets } func (tm *Manager) initTransports(ctx context.Context) { @@ -191,7 +159,7 @@ func (tm *Manager) initTransports(ctx context.Context) { remote = entry.Entry.RemoteEdge(tm.Conf.PubKey) tpID = entry.Entry.ID ) - isInitiator := tm.n.LocalPK() == entry.Entry.Edges[0] + isInitiator := tm.Conf.PubKey == entry.Entry.Edges[0] if _, err := tm.saveTransport(ctx, remote, isInitiator, tpType, entry.Entry.Label); err != nil { tm.Logger.Warnf("INIT: failed to init tp: type(%s) remote(%s) tpID(%s)", tpType, remote, tpID) } else { @@ -200,7 +168,13 @@ func (tm *Manager) initTransports(ctx context.Context) { } } -func (tm *Manager) acceptTransport(ctx context.Context, lis *snet.Listener) error { +// Stcpr returns stcpr client +func (tm *Manager) Stcpr() (network.Client, bool) { + c, ok := tm.netClients[network.STCP] + return c, ok +} + +func (tm *Manager) acceptTransport(ctx context.Context, lis network.Listener) error { conn, err := lis.AcceptConn() // TODO: tcp panic. if err != nil { return err @@ -219,18 +193,22 @@ func (tm *Manager) acceptTransport(ctx context.Context, lis *snet.Listener) erro tpID := tm.tpIDFromPK(conn.RemotePK(), conn.Network()) + client, ok := tm.netClients[network.Type(conn.Network())] + if !ok { + return fmt.Errorf("client not found for the type %s", conn.Network()) + } + mTp, ok := tm.tps[tpID] if !ok { tm.Logger.Debugln("No TP found, creating new one") mTp = NewManagedTransport(ManagedTransportConfig{ - Net: tm.n, + client: client, DC: tm.Conf.DiscoveryClient, LS: tm.Conf.LogStore, RemotePK: conn.RemotePK(), - NetName: lis.Network(), - AfterClosed: tm.afterTPClosed, TransportLabel: LabelUser, + ebc: tm.ebc, }, false) go func() { @@ -258,18 +236,28 @@ func (tm *Manager) acceptTransport(ctx context.Context, lis *snet.Listener) erro // ErrNotFound is returned when requested transport is not found var ErrNotFound = errors.New("transport not found") +// ErrUnknownNetwork occurs on attempt to use an unknown network type. +var ErrUnknownNetwork = errors.New("unknown network type") + +// IsKnownNetwork returns true when netName is a known +// network type that we are able to operate in +func (tm *Manager) IsKnownNetwork(netName network.Type) bool { + _, ok := tm.netClients[netName] + return ok +} + // GetTransport gets transport entity to the given remote -func (tm *Manager) GetTransport(remote cipher.PubKey, tpType string) (*ManagedTransport, error) { +func (tm *Manager) GetTransport(remote cipher.PubKey, netType network.Type) (*ManagedTransport, error) { tm.mx.RLock() defer tm.mx.RUnlock() - if !snet.IsKnownNetwork(tpType) { - return nil, snet.ErrUnknownNetwork + if !tm.IsKnownNetwork(netType) { + return nil, ErrUnknownNetwork } - tpID := tm.tpIDFromPK(remote, tpType) + tpID := tm.tpIDFromPK(remote, netType) tp, ok := tm.tps[tpID] if !ok { - return nil, fmt.Errorf("transport to %s of type %s error: %w", remote, tpType, ErrNotFound) + return nil, fmt.Errorf("transport to %s of type %s error: %w", remote, netType, ErrNotFound) } return tp, nil } @@ -297,14 +285,12 @@ func (tm *Manager) GetTransportsByLabel(label Label) []*ManagedTransport { } // SaveTransport begins to attempt to establish data transports to the given 'remote' visor. -func (tm *Manager) SaveTransport(ctx context.Context, remote cipher.PubKey, tpType string, label Label) (*ManagedTransport, error) { - +func (tm *Manager) SaveTransport(ctx context.Context, remote cipher.PubKey, netType network.Type, label Label) (*ManagedTransport, error) { if tm.isClosing() { return nil, io.ErrClosedPipe } - for { - mTp, err := tm.saveTransport(ctx, remote, true, tpType, label) + mTp, err := tm.saveTransport(ctx, remote, true, netType, label) if err != nil { if err == ErrNotServing { @@ -316,20 +302,14 @@ func (tm *Manager) SaveTransport(ctx context.Context, remote cipher.PubKey, tpTy } } -// isSTCPPKError returns true if the error is a STCP table error. -// This occurs the requested remote public key does not exist in the STCP table. -func isSTCPTableError(remotePK cipher.PubKey, err error) bool { - return err.Error() == fmt.Sprintf("pk table: entry of %s does not exist", remotePK.String()) -} - -func (tm *Manager) saveTransport(ctx context.Context, remote cipher.PubKey, initiator bool, netName string, label Label) (*ManagedTransport, error) { +func (tm *Manager) saveTransport(ctx context.Context, remote cipher.PubKey, initiator bool, netType network.Type, label Label) (*ManagedTransport, error) { tm.mx.Lock() defer tm.mx.Unlock() - if !snet.IsKnownNetwork(netName) { - return nil, snet.ErrUnknownNetwork + if !tm.IsKnownNetwork(netType) { + return nil, ErrUnknownNetwork } - tpID := tm.tpIDFromPK(remote, netName) + tpID := tm.tpIDFromPK(remote, netType) tm.Logger.Debugf("Initializing TP with ID %s", tpID) oldMTp, ok := tm.tps[tpID] @@ -338,41 +318,41 @@ func (tm *Manager) saveTransport(ctx context.Context, remote cipher.PubKey, init return oldMTp, nil } - afterTPClosed := tm.afterTPClosed + client, ok := tm.netClients[network.Type(netType)] + if !ok { + return nil, fmt.Errorf("client not found for the type %s", netType) + } mTp := NewManagedTransport(ManagedTransportConfig{ - Net: tm.n, + client: client, + ebc: tm.ebc, DC: tm.Conf.DiscoveryClient, LS: tm.Conf.LogStore, RemotePK: remote, - NetName: netName, - AfterClosed: afterTPClosed, TransportLabel: label, }, initiator) - if mTp.netName == tptypes.STCPR { - ar := mTp.n.Conf().ARClient - if ar != nil { - visorData, err := ar.Resolve(context.Background(), mTp.netName, remote) - if err == nil { - mTp.remoteAddr = visorData.RemoteAddr - } else { - if err != arclient.ErrNoEntry { - return nil, fmt.Errorf("failed to resolve %s: %w", remote, err) - } + // todo: do we need this here? Client dial will run resolve anyway + if mTp.Type() == network.STCPR && tm.arClient != nil { + visorData, err := tm.arClient.Resolve(context.Background(), string(mTp.Type()), remote) + if err == nil { + mTp.remoteAddr = visorData.RemoteAddr + } else { + if err != addrresolver.ErrNoEntry { + return nil, fmt.Errorf("failed to resolve %s: %w", remote, err) } } } - tm.Logger.Debugf("Dialing transport to %v via %v", mTp.Remote(), mTp.netName) + tm.Logger.Debugf("Dialing transport to %v via %v", mTp.Remote(), mTp.client.Type()) if err := mTp.Dial(ctx); err != nil { - tm.Logger.Debugf("Error dialing transport to %v via %v: %v", mTp.Remote(), mTp.netName, err) + tm.Logger.Debugf("Error dialing transport to %v via %v: %v", mTp.Remote(), mTp.client.Type(), err) // The first occurs when an old tp is returned by 'tm.saveTransport', meaning a tp of the same transport ID was // just deleted (and has not yet fully closed). Hence, we should close and delete the old tp and try again. // The second occurs when the tp type is STCP and the requested remote PK is not associated with an IP address in // the STCP table. There is no point in retrying as a connection would be impossible, so we just return an // error. - if err == ErrNotServing || isSTCPTableError(remote, err) { + if err == ErrNotServing || errors.Is(err, network.ErrStcpEntryNotFound) { if closeErr := mTp.Close(); closeErr != nil { tm.Logger.WithError(err).Warn("Closing mTp returns non-nil error.") } @@ -387,7 +367,7 @@ func (tm *Manager) saveTransport(ctx context.Context, remote cipher.PubKey, init tm.deleteTransport(mTp.Entry.ID) }() tm.tps[tpID] = mTp - tm.Logger.Infof("saved transport: remote(%s) type(%s) tpID(%s)", remote, netName, tpID) + tm.Logger.Infof("saved transport: remote(%s) type(%s) tpID(%s)", remote, netType, tpID) return mTp, nil } @@ -399,7 +379,7 @@ func (tm *Manager) STCPRRemoteAddrs() []string { defer tm.mx.RUnlock() for _, tp := range tm.tps { - if tp.Entry.Type == tptypes.STCPR && tp.remoteAddr != "" { + if tp.Entry.Type == network.STCPR && tp.remoteAddr != "" { addrs = append(addrs, tp.remoteAddr) } } @@ -453,11 +433,6 @@ func (tm *Manager) ReadPacket() (routing.Packet, error) { STATE */ -// Networks returns all the network types contained within the TransportManager. -func (tm *Manager) Networks() []string { - return tm.n.TransportNetworks() -} - // Transport obtains a Transport via a given Transport ID. func (tm *Manager) Transport(id uuid.UUID) *ManagedTransport { tm.mx.RLock() @@ -484,16 +459,13 @@ func (tm *Manager) Local() cipher.PubKey { // Close closes opened transports and registered factories. func (tm *Manager) Close() error { - tm.closeOnce.Do(func() { - tm.close() - }) + tm.closeOnce.Do(tm.close) return nil } func (tm *Manager) close() { - if tm == nil { - return - } + tm.Logger.Info("transport manager is closing.") + defer tm.Logger.Info("transport manager closed.") tm.mx.Lock() defer tm.mx.Unlock() @@ -524,54 +496,6 @@ func (tm *Manager) isClosing() bool { } } -func (tm *Manager) tpIDFromPK(pk cipher.PubKey, tpType string) uuid.UUID { - return MakeTransportID(tm.Conf.PubKey, pk, tpType) -} - -// CreateTransportPair create a new transport pair for tests. -func CreateTransportPair( - tpDisc DiscoveryClient, - keys []snettest.KeyPair, - nEnv *snettest.Env, - network string, -) (m0 *Manager, m1 *Manager, tp0 *ManagedTransport, tp1 *ManagedTransport, err error) { - // Prepare tp manager 0. - pk0, sk0 := keys[0].PK, keys[0].SK - ls0 := InMemoryTransportLogStore() - m0, err = NewManager(nil, nEnv.Nets[0], &ManagerConfig{ - PubKey: pk0, - SecKey: sk0, - DiscoveryClient: tpDisc, - LogStore: ls0, - }) - if err != nil { - return nil, nil, nil, nil, err - } - - go m0.Serve(context.TODO()) - - // Prepare tp manager 1. - pk1, sk1 := keys[1].PK, keys[1].SK - ls1 := InMemoryTransportLogStore() - m1, err = NewManager(nil, nEnv.Nets[1], &ManagerConfig{ - PubKey: pk1, - SecKey: sk1, - DiscoveryClient: tpDisc, - LogStore: ls1, - }) - if err != nil { - return nil, nil, nil, nil, err - } - - go m1.Serve(context.TODO()) - - // Create data transport between manager 1 & manager 2. - tp1, err = m1.SaveTransport(context.TODO(), pk0, network, LabelUser) - if err != nil { - return nil, nil, nil, nil, err - } - - tp0 = m0.Transport(MakeTransportID(pk0, pk1, network)) - - return m0, m1, tp0, tp1, nil +func (tm *Manager) tpIDFromPK(pk cipher.PubKey, netType network.Type) uuid.UUID { + return MakeTransportID(tm.Conf.PubKey, pk, netType) } diff --git a/pkg/transport/manager_test.go b/pkg/transport/manager_test.go index 1a37b2e44..33d229a4a 100644 --- a/pkg/transport/manager_test.go +++ b/pkg/transport/manager_test.go @@ -1,22 +1,15 @@ package transport_test import ( - "context" - "fmt" "io/ioutil" "log" "os" "testing" - "time" - "github.com/skycoin/dmsg" "github.com/skycoin/dmsg/cipher" "github.com/skycoin/skycoin/src/util/logging" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/skycoin/skywire/pkg/routing" - "github.com/skycoin/skywire/pkg/snet/snettest" "github.com/skycoin/skywire/pkg/transport" ) @@ -37,127 +30,6 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } - -// TODO: test hangs if Manager is closed to early, needs to receive an error though -func TestNewManager(t *testing.T) { - tpDisc := transport.NewDiscoveryMock() - - keys := snettest.GenKeyPairs(2) - nEnv := snettest.NewEnv(t, keys, []string{dmsg.Type}) - defer nEnv.Teardown() - - // Prepare tp manager 0. - pk0, sk0 := keys[0].PK, keys[0].SK - ls0 := transport.InMemoryTransportLogStore() - m0, err := transport.NewManager(nil, nEnv.Nets[0], &transport.ManagerConfig{ - PubKey: pk0, - SecKey: sk0, - DiscoveryClient: tpDisc, - LogStore: ls0, - }) - require.NoError(t, err) - go m0.Serve(context.TODO()) - defer func() { require.NoError(t, m0.Close()) }() - - // Prepare tp manager 1. - pk1, sk1 := keys[1].PK, keys[1].SK - ls1 := transport.InMemoryTransportLogStore() - m2, err := transport.NewManager(nil, nEnv.Nets[1], &transport.ManagerConfig{ - PubKey: pk1, - SecKey: sk1, - DiscoveryClient: tpDisc, - LogStore: ls1, - }) - require.NoError(t, err) - go m2.Serve(context.TODO()) - defer func() { require.NoError(t, m2.Close()) }() - - // Create data transport between manager 1 & manager 2. - tp2, err := m2.SaveTransport(context.TODO(), pk0, "dmsg", transport.LabelUser) - require.NoError(t, err) - tp1 := m0.Transport(transport.MakeTransportID(pk0, pk1, "dmsg")) - require.NotNil(t, tp1) - - fmt.Println("transports created") - - totalSent2 := 0 - totalSent1 := 0 - - // Check read/writes are of expected. - t.Run("check_read_write", func(t *testing.T) { - - for i := 0; i < 10; i++ { - totalSent2 += i - rID := routing.RouteID(i) - payload := cipher.RandByte(i) - - packet, err := routing.MakeDataPacket(rID, payload) - require.NoError(t, err) - - require.NoError(t, tp2.WritePacket(context.TODO(), packet)) - - recv, err := m0.ReadPacket() - require.NoError(t, err) - require.Equal(t, rID, recv.RouteID()) - require.Equal(t, uint16(i), recv.Size()) - require.Equal(t, payload, recv.Payload()) - } - - for i := 0; i < 20; i++ { - totalSent1 += i - rID := routing.RouteID(i) - payload := cipher.RandByte(i) - - packet, err := routing.MakeDataPacket(rID, payload) - require.NoError(t, err) - - require.NoError(t, tp1.WritePacket(context.TODO(), packet)) - - recv, err := m2.ReadPacket() - require.NoError(t, err) - require.Equal(t, rID, recv.RouteID()) - require.Equal(t, uint16(i), recv.Size()) - require.Equal(t, payload, recv.Payload()) - } - }) - - // Ensure tp log entries are of expected. - t.Run("check_tp_logs", func(t *testing.T) { - - // 1.5x log write interval just to be safe. - time.Sleep(time.Second * 9 / 2) - - entry1, err := ls0.Entry(tp1.Entry.ID) - require.NoError(t, err) - assert.Equal(t, uint64(totalSent1), entry1.SentBytes) - assert.Equal(t, uint64(totalSent2), entry1.RecvBytes) - - entry2, err := ls1.Entry(tp2.Entry.ID) - require.NoError(t, err) - assert.Equal(t, uint64(totalSent2), entry2.SentBytes) - assert.Equal(t, uint64(totalSent1), entry2.RecvBytes) - }) - - // Ensure deleting a transport works as expected. - t.Run("check_delete_tp", func(t *testing.T) { - - // Make transport ID. - tpID := transport.MakeTransportID(pk0, pk1, "dmsg") - - // Ensure transports are registered properly in tp discovery. - entry, err := tpDisc.GetTransportByID(context.TODO(), tpID) - require.NoError(t, err) - assert.Equal(t, transport.SortEdges(pk0, pk1), entry.Entry.Edges) - assert.True(t, entry.IsUp) - - m2.DeleteTransport(tp2.Entry.ID) - - _, err = tpDisc.GetTransportByID(context.TODO(), tpID) - require.NotNil(t, err) - require.Contains(t, err.Error(), "not found") - }) -} - func TestSortEdges(t *testing.T) { for i := 0; i < 100; i++ { keyA, _ := cipher.GenerateKeyPair() diff --git a/pkg/snet/arclient/client.go b/pkg/transport/network/addrresolver/client.go similarity index 89% rename from pkg/snet/arclient/client.go rename to pkg/transport/network/addrresolver/client.go index c51d02dd7..286e114c2 100644 --- a/pkg/snet/arclient/client.go +++ b/pkg/transport/network/addrresolver/client.go @@ -1,5 +1,5 @@ -// Package arclient implements address resolver client -package arclient +// Package addrresolver implements address resolver client +package addrresolver import ( "bytes" @@ -15,7 +15,6 @@ import ( "time" "github.com/AudriusButkevicius/pfilter" - "github.com/skycoin/dmsg" "github.com/skycoin/dmsg/cipher" dmsgnetutil "github.com/skycoin/dmsg/netutil" "github.com/skycoin/skycoin/src/util/logging" @@ -24,8 +23,6 @@ import ( "github.com/skycoin/skywire/internal/httpauth" "github.com/skycoin/skywire/internal/netutil" "github.com/skycoin/skywire/internal/packetfilter" - "github.com/skycoin/skywire/pkg/snet/directtp/tpconn" - "github.com/skycoin/skywire/pkg/snet/directtp/tphandshake" ) const ( @@ -56,8 +53,8 @@ type Error struct { type APIClient interface { io.Closer BindSTCPR(ctx context.Context, port string) error - BindSUDPH(filter *pfilter.PacketFilter) (<-chan RemoteVisor, error) - Resolve(ctx context.Context, tType string, pk cipher.PubKey) (VisorData, error) + BindSUDPH(filter *pfilter.PacketFilter, handshake Handshake) (<-chan RemoteVisor, error) + Resolve(ctx context.Context, netType string, pk cipher.PubKey) (VisorData, error) Health(ctx context.Context) (int, error) } @@ -222,7 +219,10 @@ func (c *httpClient) BindSTCPR(ctx context.Context, port string) error { return nil } -func (c *httpClient) BindSUDPH(filter *pfilter.PacketFilter) (<-chan RemoteVisor, error) { +// Handshake type is used to decouple client from handshake and network packages +type Handshake func(net.Conn) (net.Conn, error) + +func (c *httpClient) BindSUDPH(filter *pfilter.PacketFilter, hs Handshake) (<-chan RemoteVisor, error) { if !c.isReady() { c.log.Infof("BindSUDPR: Address resolver is not ready yet, waiting...") <-c.ready @@ -242,8 +242,11 @@ func (c *httpClient) BindSUDPH(filter *pfilter.PacketFilter) (<-chan RemoteVisor } c.log.Infof("SUDPH Local port: %v", localPort) - - arConn, err := c.wrapConn(c.sudphConn) + kcpConn, err := kcp.NewConn(c.remoteUDPAddr, nil, 0, 0, c.sudphConn) + if err != nil { + return nil, err + } + arConn, err := hs(kcpConn) if err != nil { return nil, err } @@ -389,34 +392,6 @@ func (c *httpClient) readSUDPHMessages(reader io.Reader) <-chan RemoteVisor { return addrCh } -func (c *httpClient) wrapConn(conn net.PacketConn) (*tpconn.Conn, error) { - arKCPConn, err := kcp.NewConn(c.remoteUDPAddr, nil, 0, 0, conn) - if err != nil { - return nil, err - } - - emptyAddr := dmsg.Addr{PK: cipher.PubKey{}, Port: 0} - hs := tphandshake.InitiatorHandshake(c.sk, dmsg.Addr{PK: c.pk, Port: 0}, emptyAddr) - - connConfig := tpconn.Config{ - Log: c.log, - Conn: arKCPConn, - LocalPK: c.pk, - LocalSK: c.sk, - Deadline: time.Now().Add(tphandshake.Timeout), - Handshake: hs, - Encrypt: false, - Initiator: true, - } - - arConn, err := tpconn.NewConn(connConfig) - if err != nil { - return nil, fmt.Errorf("newConn: %w", err) - } - - return arConn, nil -} - func (c *httpClient) Close() error { select { case <-c.closed: diff --git a/pkg/snet/arclient/client_test.go b/pkg/transport/network/addrresolver/client_test.go similarity index 99% rename from pkg/snet/arclient/client_test.go rename to pkg/transport/network/addrresolver/client_test.go index 304f33ddb..1746c0e7a 100644 --- a/pkg/snet/arclient/client_test.go +++ b/pkg/transport/network/addrresolver/client_test.go @@ -1,4 +1,4 @@ -package arclient +package addrresolver import ( "context" diff --git a/pkg/snet/arclient/mock_api_client.go b/pkg/transport/network/addrresolver/mock_api_client.go similarity index 94% rename from pkg/snet/arclient/mock_api_client.go rename to pkg/transport/network/addrresolver/mock_api_client.go index 5ee1f8dfc..a6cd7281d 100644 --- a/pkg/snet/arclient/mock_api_client.go +++ b/pkg/transport/network/addrresolver/mock_api_client.go @@ -1,6 +1,6 @@ // Code generated by mockery v1.0.0. DO NOT EDIT. -package arclient +package addrresolver import ( context "context" @@ -30,7 +30,7 @@ func (_m *MockAPIClient) BindSTCPR(ctx context.Context, port string) error { } // BindSUDPH provides a mock function with given fields: filter -func (_m *MockAPIClient) BindSUDPH(filter *pfilter.PacketFilter) (<-chan RemoteVisor, error) { +func (_m *MockAPIClient) BindSUDPH(filter *pfilter.PacketFilter, handshake Handshake) (<-chan RemoteVisor, error) { ret := _m.Called(filter) var r0 <-chan RemoteVisor diff --git a/pkg/transport/network/client.go b/pkg/transport/network/client.go new file mode 100644 index 000000000..a5edc85b6 --- /dev/null +++ b/pkg/transport/network/client.go @@ -0,0 +1,332 @@ +package network + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "sync" + + "github.com/skycoin/dmsg" + "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/skycoin/src/util/logging" + + "github.com/skycoin/skywire/pkg/app/appevent" + "github.com/skycoin/skywire/pkg/transport/network/addrresolver" + "github.com/skycoin/skywire/pkg/transport/network/handshake" + "github.com/skycoin/skywire/pkg/transport/network/porter" + "github.com/skycoin/skywire/pkg/transport/network/stcp" +) + +// Client provides access to skywire network +// It allows dialing remote visors using their public keys, as +// well as listening to incoming connections from other visors +type Client interface { + // Dial remote visor, that is listening on the given skywire port + Dial(ctx context.Context, remote cipher.PubKey, port uint16) (Conn, error) + // Start initializes the client and prepares it for listening. It is required + // to be called to start accepting connections + Start() error + // Listen on the given skywire port. This can be called multiple times + // for different ports for the same client. It requires Start to be called + // to start accepting connections + Listen(port uint16) (Listener, error) + // LocalAddr returns the actual network address under which this client listens to + // new connections + LocalAddr() (net.Addr, error) + // PK returns public key of the visor running this client + PK() cipher.PubKey + // SK returns secret key of the visor running this client + SK() cipher.SecKey + // Close the client, stop accepting connections. Connections returned by the + // client should be closed manually + Close() error + // Type returns skywire network type in which this client operates + Type() Type +} + +// ClientFactory is used to create Client instances +// and holds dependencies for different clients +type ClientFactory struct { + PK cipher.PubKey + SK cipher.SecKey + ListenAddr string + PKTable stcp.PKTable + ARClient addrresolver.APIClient + EB *appevent.Broadcaster + DmsgC *dmsg.Client +} + +// MakeClient creates a new client of specified type +func (f *ClientFactory) MakeClient(netType Type) (Client, error) { + log := logging.MustGetLogger(string(netType)) + p := porter.New(porter.MinEphemeral) + + generic := &genericClient{} + generic.listenStarted = make(chan struct{}) + generic.done = make(chan struct{}) + generic.listeners = make(map[uint16]*listener) + generic.log = log + generic.porter = p + generic.eb = f.EB + generic.lPK = f.PK + generic.lSK = f.SK + generic.listenAddr = f.ListenAddr + + resolved := &resolvedClient{genericClient: generic, ar: f.ARClient} + + switch netType { + case STCP: + return newStcp(generic, f.PKTable), nil + case STCPR: + return newStcpr(resolved), nil + case SUDPH: + return newSudph(resolved), nil + case DMSG: + return newDmsgClient(f.DmsgC), nil + } + return nil, fmt.Errorf("cannot initiate client, type %s not supported", netType) +} + +// genericClient unites common logic for all clients +// The main responsibility is handshaking over incoming +// and outgoing raw network connections, obtaining remote information +// from the handshake and wrapping raw connections with skywire +// connection type. +// Incoming connections also directed to appropriate listener using +// skywire port, obtained from incoming connection handshake +type genericClient struct { + lPK cipher.PubKey + lSK cipher.SecKey + listenAddr string + netType Type + + log *logging.Logger + porter *porter.Porter + eb *appevent.Broadcaster + + connListener net.Listener + listeners map[uint16]*listener + listenStarted chan struct{} + mu sync.RWMutex + done chan struct{} + closeOnce sync.Once +} + +// initConnection will initialize skywire connection over opened raw connection to +// the remote client +// The process will perform handshake over raw connection +func (c *genericClient) initConnection(ctx context.Context, conn net.Conn, rPK cipher.PubKey, rPort uint16) (*conn, error) { + lPort, freePort, err := c.porter.ReserveEphemeral(ctx) + if err != nil { + return nil, err + } + lAddr, rAddr := dmsg.Addr{PK: c.lPK, Port: lPort}, dmsg.Addr{PK: rPK, Port: rPort} + remoteAddr := conn.RemoteAddr() + c.log.Infof("Performing handshake with %v", remoteAddr) + hs := handshake.InitiatorHandshake(c.lSK, lAddr, rAddr) + return c.wrapConn(conn, hs, true, freePort) +} + +// acceptConnections continuously accepts incoming connections that come from given listener +// these connections will be properly handshaked and passed to an appropriate skywire listener +// using skywire port +func (c *genericClient) acceptConnections(lis net.Listener) { + c.mu.Lock() + c.connListener = lis + close(c.listenStarted) + c.mu.Unlock() + c.log.Infof("listening on addr: %v", c.connListener.Addr()) + for { + if err := c.acceptConn(); err != nil { + if errors.Is(err, io.EOF) { + continue // likely it's a dummy connection from service discovery + } + c.log.Warnf("failed to accept incoming connection: %v", err) + if !handshake.IsHandshakeError(err) { + c.log.Warnf("stopped serving") + return + } + } + } +} + +// wrapConn performs handshake over provided raw connection and wraps it in +// network.Conn type using the data obtained from handshake process +func (c *genericClient) wrapConn(rawConn net.Conn, hs handshake.Handshake, initiator bool, onClose func()) (*conn, error) { + conn, err := doHandshake(rawConn, hs, c.netType, c.log) + if err != nil { + onClose() + } + conn.freePort = onClose + c.log.Infof("Sent handshake to %v, local addr %v, remote addr %v", rawConn.RemoteAddr(), conn.lAddr, conn.rAddr) + if err := conn.encrypt(c.lPK, c.lSK, initiator); err != nil { + return nil, err + } + return conn, nil +} + +// acceptConn accepts new connection in underlying raw network listener, +// performs handshake, and using the data from the handshake wraps +// connection and delivers it to the appropriate listener. +// The listener is chosen using skywire port from the incoming visor connection +func (c *genericClient) acceptConn() error { + if c.isClosed() { + return io.ErrClosedPipe + } + conn, err := c.connListener.Accept() + if err != nil { + return err + } + remoteAddr := conn.RemoteAddr() + c.log.Infof("Accepted connection from %v", remoteAddr) + + onClose := func() {} + hs := handshake.ResponderHandshake(handshake.MakeF2PortChecker(c.checkListener)) + wrappedConn, err := c.wrapConn(conn, hs, false, onClose) + if err != nil { + return err + } + lis, err := c.getListener(wrappedConn.lAddr.Port) + if err != nil { + return err + } + return lis.introduce(wrappedConn) +} + +// LocalAddr returns local address. This is network address the client +// listens to for incoming connections, not skywire address +func (c *genericClient) LocalAddr() (net.Addr, error) { + <-c.listenStarted + if c.isClosed() { + return nil, ErrNotListening + } + return c.connListener.Addr(), nil +} + +// getListener returns listener to specified skywire port +func (c *genericClient) getListener(port uint16) (*listener, error) { + c.mu.Lock() + defer c.mu.Unlock() + lis, ok := c.listeners[port] + if !ok { + return nil, errors.New("not listening on given port") + } + return lis, nil +} + +func (c *genericClient) checkListener(port uint16) error { + _, err := c.getListener(port) + return err +} + +// Listen starts listening on a specified port number. The port is a skywire port +// and is not related to local OS ports. Underlying connection will most likely use +// a different port number +// Listen requires Serve to be called, which will accept connections to all skywire ports +func (c *genericClient) Listen(port uint16) (Listener, error) { + if c.isClosed() { + return nil, io.ErrClosedPipe + } + + ok, freePort := c.porter.Reserve(port) + if !ok { + return nil, ErrPortOccupied + } + + c.mu.Lock() + defer c.mu.Unlock() + + lAddr := dmsg.Addr{PK: c.lPK, Port: port} + lis := newListener(lAddr, freePort, c.netType) + c.listeners[port] = lis + + return lis, nil +} + +func (c *genericClient) isClosed() bool { + select { + case <-c.done: + return true + default: + return false + } +} + +// PK implements interface +func (c *genericClient) PK() cipher.PubKey { + return c.lPK +} + +// SK implements interface +func (c *genericClient) SK() cipher.SecKey { + return c.lSK +} + +// Close implements interface +func (c *genericClient) Close() error { + c.closeOnce.Do(func() { + close(c.done) + + c.mu.Lock() + defer c.mu.Unlock() + + if c.connListener != nil { + if err := c.connListener.Close(); err != nil { + c.log.WithError(err).Warnf("Failed to close incoming connection listener") + } + } + + for _, lis := range c.listeners { + if err := lis.Close(); err != nil { + c.log.WithError(err).WithField("addr", lis.Addr().String()).Warnf("Failed to close listener") + } + } + }) + + return nil +} + +// Type implements interface +func (c *genericClient) Type() Type { + return c.netType +} + +// resolvedClient is a wrapper around genericClient, +// for the types of transports that use address resolver service +// to resolve addresses of remote visors +type resolvedClient struct { + *genericClient + ar addrresolver.APIClient +} + +type dialFunc func(ctx context.Context, addr string) (net.Conn, error) + +// dialVisor uses address resovler to obtain network address of the target visor +// and dials that visor address(es) +// dial process is specific to transport type and is provided by the client +func (c *resolvedClient) dialVisor(ctx context.Context, rPK cipher.PubKey, dial dialFunc) (net.Conn, error) { + c.log.Infof("Dialing PK %v", rPK) + visorData, err := c.ar.Resolve(ctx, string(c.netType), rPK) + if err != nil { + return nil, fmt.Errorf("resolve PK: %w", err) + } + c.log.Infof("Resolved PK %v to visor data %v", rPK, visorData) + + if visorData.IsLocal { + for _, host := range visorData.Addresses { + addr := net.JoinHostPort(host, visorData.Port) + conn, err := dial(ctx, addr) + if err == nil { + return conn, nil + } + } + } + + addr := visorData.RemoteAddr + if _, _, err := net.SplitHostPort(addr); err != nil { + addr = net.JoinHostPort(addr, visorData.Port) + } + return dial(ctx, addr) +} diff --git a/pkg/transport/network/connection.go b/pkg/transport/network/connection.go new file mode 100644 index 000000000..2bd5b5d48 --- /dev/null +++ b/pkg/transport/network/connection.go @@ -0,0 +1,134 @@ +package network + +import ( + "fmt" + "net" + "time" + + "github.com/skycoin/dmsg" + "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/dmsg/noise" + "github.com/skycoin/skycoin/src/util/logging" + + "github.com/skycoin/skywire/pkg/transport/network/handshake" +) + +const encryptHSTimout = 5 * time.Second + +// Conn represents a network connection between two visors in skywire network +// This connection wraps raw network connection and is ready to use for sending data. +// It also provides skywire-specific methods on top of net.Conn +type Conn interface { + net.Conn + // LocalPK returns local public key of connection + LocalPK() cipher.PubKey + + // RemotePK returns remote public key of connection + RemotePK() cipher.PubKey + + // LocalPort returns local skywire port of connection + // This is not underlying OS port, but port within skywire network + LocalPort() uint16 + + // RemotePort returns remote skywire port of connection + // This is not underlying OS port, but port within skywire network + RemotePort() uint16 + + // Network returns network of connection + Network() Type +} + +type conn struct { + net.Conn + lAddr, rAddr dmsg.Addr + freePort func() + connType Type +} + +// DoHandshake performs given handshake over given raw connection and wraps +// connection in network.Conn +func DoHandshake(rawConn net.Conn, hs handshake.Handshake, netType Type, log *logging.Logger) (Conn, error) { + return doHandshake(rawConn, hs, netType, log) +} + +// handshake performs given handshake over given raw connection and wraps +// connection in network.conn +func doHandshake(rawConn net.Conn, hs handshake.Handshake, netType Type, log *logging.Logger) (*conn, error) { + lAddr, rAddr, err := hs(rawConn, time.Now().Add(handshake.Timeout)) + if err != nil { + if err := rawConn.Close(); err != nil { + log.WithError(err).Warnf("Failed to close connection") + } + return nil, err + } + handshakedConn := &conn{Conn: rawConn, lAddr: lAddr, rAddr: rAddr, connType: netType} + return handshakedConn, nil +} + +func (c *conn) encrypt(lPK cipher.PubKey, lSK cipher.SecKey, initator bool) error { + config := noise.Config{ + LocalPK: lPK, + LocalSK: lSK, + RemotePK: c.rAddr.PK, + Initiator: initator, + } + + wrappedConn, err := EncryptConn(config, c.Conn) + if err != nil { + return fmt.Errorf("encrypt connection to %v@%v: %w", c.rAddr, c.Conn.RemoteAddr(), err) + } + + c.Conn = wrappedConn + return nil +} + +// EncryptConn encrypts given connection +func EncryptConn(config noise.Config, conn net.Conn) (net.Conn, error) { + ns, err := noise.New(noise.HandshakeKK, config) + if err != nil { + return nil, fmt.Errorf("failed to prepare stream noise object: %w", err) + } + + wrappedConn, err := noise.WrapConn(conn, ns, encryptHSTimout) + if err != nil { + return nil, fmt.Errorf("error performing noise handshake: %w", err) + } + + return wrappedConn, nil +} + +// LocalAddr implements net.Conn +func (c *conn) LocalAddr() net.Addr { + return c.lAddr +} + +// RemoteAddr implements net.Conn +func (c *conn) RemoteAddr() net.Addr { + return c.rAddr +} + +// Close implements net.Conn +func (c *conn) Close() error { + if c.freePort != nil { + c.freePort() + } + + return c.Conn.Close() +} + +// LocalPK returns local public key of connection +func (c *conn) LocalPK() cipher.PubKey { return c.lAddr.PK } + +// RemotePK returns remote public key of connection +func (c *conn) RemotePK() cipher.PubKey { return c.rAddr.PK } + +// LocalPort returns local skywire port of connection +// This is not underlying OS port, but port within skywire network +func (c *conn) LocalPort() uint16 { return c.lAddr.Port } + +// RemotePort returns remote skywire port of connection +// This is not underlying OS port, but port within skywire network +func (c *conn) RemotePort() uint16 { return c.rAddr.Port } + +// Network returns network of connection +func (c *conn) Network() Type { return c.connType } diff --git a/pkg/transport/network/dmsg.go b/pkg/transport/network/dmsg.go new file mode 100644 index 000000000..f9f97121e --- /dev/null +++ b/pkg/transport/network/dmsg.go @@ -0,0 +1,135 @@ +package network + +import ( + "context" + "fmt" + "net" + + "github.com/skycoin/dmsg" + "github.com/skycoin/dmsg/cipher" +) + +// dmsgClientAdapter is a wrapper around dmsg.Client to conform to Client +// interface +type dmsgClientAdapter struct { + dmsgC *dmsg.Client +} + +func newDmsgClient(dmsgC *dmsg.Client) Client { + return &dmsgClientAdapter{dmsgC: dmsgC} +} + +// LocalAddr implements interface +func (c *dmsgClientAdapter) LocalAddr() (net.Addr, error) { + for _, ses := range c.dmsgC.AllSessions() { + return ses.SessionCommon.GetConn().LocalAddr(), nil + } + return nil, fmt.Errorf("not listening to dmsg") +} + +// Dial implements Client interface +func (c *dmsgClientAdapter) Dial(ctx context.Context, remote cipher.PubKey, port uint16) (Conn, error) { + conn, err := c.dmsgC.DialStream(ctx, dmsg.Addr{PK: remote, Port: port}) + if err != nil { + return nil, err + } + return &dmsgConnAdapter{conn}, nil +} + +// Start implements Client interface +func (c *dmsgClientAdapter) Start() error { + // no need to serve, the wrapped dmsgC is already serving + return nil +} + +// Listen implements Client interface +func (c *dmsgClientAdapter) Listen(port uint16) (Listener, error) { + lis, err := c.dmsgC.Listen(port) + if err != nil { + return nil, err + } + return &dmsgListenerAdapter{lis}, nil +} + +// PK implements Client interface +func (c *dmsgClientAdapter) PK() cipher.PubKey { + return c.dmsgC.LocalPK() +} + +// SK implements Client interface +func (c *dmsgClientAdapter) SK() cipher.SecKey { + return c.dmsgC.LocalSK() +} + +// Close implements Client interface +func (c *dmsgClientAdapter) Close() error { + // this client is for transport usage, but dmsgC it wraps may be used in + // other places. It should be closed by whoever initialized it, not here + return nil +} + +// Type implements Client interface +func (c *dmsgClientAdapter) Type() Type { + return DMSG +} + +// wrapper around listener returned by dmsg.Client +// that conforms to Listener interface +type dmsgListenerAdapter struct { + *dmsg.Listener +} + +// AcceptConn implements Listener interface +func (lis *dmsgListenerAdapter) AcceptConn() (Conn, error) { + stream, err := lis.Listener.AcceptStream() + if err != nil { + return nil, err + } + return &dmsgConnAdapter{stream}, nil +} + +// Network implements Listener interface +func (lis *dmsgListenerAdapter) Network() Type { + return DMSG +} + +// PK implements Listener interface +func (lis *dmsgListenerAdapter) PK() cipher.PubKey { + return lis.Listener.DmsgAddr().PK +} + +// Port implements Listener interface +func (lis *dmsgListenerAdapter) Port() uint16 { + return lis.DmsgAddr().Port +} + +// wrapper around connection returned by dmsg.Client +// that conforms to Conn interface +type dmsgConnAdapter struct { + *dmsg.Stream +} + +// LocalPK implements Conn interface +func (c *dmsgConnAdapter) LocalPK() cipher.PubKey { + return c.RawLocalAddr().PK +} + +// RemotePK implements Conn interface +func (c *dmsgConnAdapter) RemotePK() cipher.PubKey { + return c.RawRemoteAddr().PK +} + +// LocalPort implements Conn interface +func (c *dmsgConnAdapter) LocalPort() uint16 { + return c.RawLocalAddr().Port +} + +// RemotePort implements Conn interface +func (c *dmsgConnAdapter) RemotePort() uint16 { + return c.RawRemoteAddr().Port +} + +// Network implements Conn interface +func (c *dmsgConnAdapter) Network() Type { + return DMSG +} diff --git a/pkg/snet/directtp/tphandshake/handshake.go b/pkg/transport/network/handshake/handshake.go similarity index 93% rename from pkg/snet/directtp/tphandshake/handshake.go rename to pkg/transport/network/handshake/handshake.go index e037bbf29..10afde452 100644 --- a/pkg/snet/directtp/tphandshake/handshake.go +++ b/pkg/transport/network/handshake/handshake.go @@ -1,4 +1,4 @@ -package tphandshake +package handshake import ( "bytes" @@ -98,8 +98,19 @@ func InitiatorHandshake(lSK cipher.SecKey, localAddr, remoteAddr dmsg.Addr) Hand }) } +// CheckF2 checks second frame of handshake +type CheckF2 = func(f2 Frame2) error + +// MakeF2PortChecker returns new CheckF2 function that will use +// port checker to check port in Frame2 +func MakeF2PortChecker(portChecker func(port uint16) error) CheckF2 { + return func(f2 Frame2) error { + return portChecker(f2.DstAddr.Port) + } +} + // ResponderHandshake creates the handshake logic on the responder's side. -func ResponderHandshake(checkF2 func(f2 Frame2) error) Handshake { +func ResponderHandshake(checkF2 CheckF2) Handshake { return handshakeMiddleware(func(conn net.Conn, deadline time.Time) (lAddr, rAddr dmsg.Addr, err error) { if err = readFrame0(conn); err != nil { return dmsg.Addr{}, dmsg.Addr{}, err diff --git a/pkg/snet/directtp/tphandshake/handshake_test.go b/pkg/transport/network/handshake/handshake_test.go similarity index 98% rename from pkg/snet/directtp/tphandshake/handshake_test.go rename to pkg/transport/network/handshake/handshake_test.go index 8731ed569..474527f9d 100644 --- a/pkg/snet/directtp/tphandshake/handshake_test.go +++ b/pkg/transport/network/handshake/handshake_test.go @@ -1,4 +1,4 @@ -package tphandshake +package handshake import ( "errors" diff --git a/pkg/transport/network/listener.go b/pkg/transport/network/listener.go new file mode 100644 index 000000000..96ebb33bd --- /dev/null +++ b/pkg/transport/network/listener.go @@ -0,0 +1,114 @@ +package network + +import ( + "io" + "net" + "sync" + + "github.com/skycoin/dmsg" + "github.com/skycoin/dmsg/cipher" +) + +// Listener represents a skywire network listener. It wraps net.Listener +// with other skywire-specific data +// Listener implements net.Listener +type Listener interface { + net.Listener + PK() cipher.PubKey + Port() uint16 + Network() Type + AcceptConn() (Conn, error) +} + +type listener struct { + lAddr dmsg.Addr + mx sync.Mutex + once sync.Once + freePort func() + accept chan *conn + done chan struct{} + network Type +} + +// NewListener returns a new Listener. +func NewListener(lAddr dmsg.Addr, freePort func(), network Type) Listener { + return newListener(lAddr, freePort, network) +} + +func newListener(lAddr dmsg.Addr, freePort func(), network Type) *listener { + return &listener{ + lAddr: lAddr, + freePort: freePort, + accept: make(chan *conn), + done: make(chan struct{}), + network: network, + } +} + +// Accept implements net.Listener, returns generic net.Conn +func (l *listener) Accept() (net.Conn, error) { + return l.AcceptConn() +} + +// AcceptConn accepts a skywire connection and returns network.Conn +func (l *listener) AcceptConn() (Conn, error) { + c, ok := <-l.accept + if !ok { + return nil, io.ErrClosedPipe + } + + return c, nil +} + +// Close implements net.Listener +func (l *listener) Close() error { + l.once.Do(func() { + close(l.done) + l.mx.Lock() + close(l.accept) + l.mx.Unlock() + for conn := range l.accept { + conn.Close() //nolint: errcheck, gosec + } + l.freePort() + }) + + return nil +} + +// Addr implements net.Listener +func (l *listener) Addr() net.Addr { + return l.lAddr +} + +// Addr implements net.Listener +func (l *listener) PK() cipher.PubKey { + return l.lAddr.PK +} + +// Addr implements net.Listener +func (l *listener) Port() uint16 { + return l.lAddr.Port +} + +// Network returns network type +func (l *listener) Network() Type { + return l.network +} + +// Introduce is used by Client to introduce a new connection to this Listener +func (l *listener) introduce(conn *conn) error { + select { + case <-l.done: + return io.ErrClosedPipe + default: + l.mx.Lock() + defer l.mx.Unlock() + select { + case l.accept <- conn: + return nil + case <-l.done: + return io.ErrClosedPipe + } + } +} diff --git a/pkg/snet/mock_dialer.go b/pkg/transport/network/mock_dialer.go similarity index 98% rename from pkg/snet/mock_dialer.go rename to pkg/transport/network/mock_dialer.go index 7cbb57833..e0cef6585 100644 --- a/pkg/snet/mock_dialer.go +++ b/pkg/transport/network/mock_dialer.go @@ -1,6 +1,6 @@ // Code generated by mockery v1.0.0. DO NOT EDIT. -package snet +package network import ( context "context" diff --git a/pkg/transport/network/network.go b/pkg/transport/network/network.go new file mode 100644 index 000000000..354eeea07 --- /dev/null +++ b/pkg/transport/network/network.go @@ -0,0 +1,50 @@ +package network + +import ( + "context" + "errors" + "net" + + "github.com/skycoin/dmsg/cipher" +) + +// Type is a type of network. Type affects the way connection is established +// and the way data is sent +type Type string + +const ( + // STCPR is a type of a transport that works via TCP and resolves addresses using address-resolver service. + STCPR Type = "stcpr" + // SUDPH is a type of a transport that works via UDP, resolves addresses using address-resolver service, + // and uses UDP hole punching. + SUDPH Type = "sudph" + // STCP is a type of a transport that works via TCP and resolves addresses using PK table. + STCP Type = "stcp" + // DMSG is a type of a transport that works through an intermediary service + DMSG Type = "dmsg" +) + +//go:generate mockery -name Dialer -case underscore -inpkg + +// Dialer is an entity that can be dialed and asked for its type. +type Dialer interface { + Dial(ctx context.Context, remote cipher.PubKey, port uint16) (net.Conn, error) + Type() string +} + +var ( + // ErrUnknownTransportType is returned when transport type is unknown. + ErrUnknownTransportType = errors.New("unknown transport type") + + // ErrTimeout indicates a timeout. + ErrTimeout = errors.New("timeout") + + // ErrAlreadyListening is returned when transport is already listening. + ErrAlreadyListening = errors.New("already listening") + + // ErrNotListening is returned when transport is not listening. + ErrNotListening = errors.New("not listening") + + // ErrPortOccupied is returned when port is occupied. + ErrPortOccupied = errors.New("port is already occupied") +) diff --git a/pkg/snet/directtp/porter/porter.go b/pkg/transport/network/porter/porter.go similarity index 100% rename from pkg/snet/directtp/porter/porter.go rename to pkg/transport/network/porter/porter.go diff --git a/pkg/transport/network/stcp.go b/pkg/transport/network/stcp.go new file mode 100644 index 000000000..0972e04c7 --- /dev/null +++ b/pkg/transport/network/stcp.go @@ -0,0 +1,74 @@ +package network + +import ( + "context" + "errors" + "io" + "net" + + "github.com/skycoin/dmsg/cipher" + + "github.com/skycoin/skywire/pkg/transport/network/stcp" +) + +// STCPConfig defines config for STCP network. +type STCPConfig struct { + PKTable map[cipher.PubKey]string `json:"pk_table"` + LocalAddr string `json:"local_address"` +} + +type stcpClient struct { + *genericClient + table stcp.PKTable +} + +func newStcp(generic *genericClient, table stcp.PKTable) Client { + client := &stcpClient{genericClient: generic, table: table} + client.netType = STCP + return client +} + +// ErrStcpEntryNotFound is returned when requested PK is not found in the local +// PK table +var ErrStcpEntryNotFound = errors.New("entry not found in PK table") + +// Dial implements Client interface +func (c *stcpClient) Dial(ctx context.Context, rPK cipher.PubKey, rPort uint16) (Conn, error) { + if c.isClosed() { + return nil, io.ErrClosedPipe + } + + c.log.Infof("Dialing PK %v", rPK) + + var conn net.Conn + addr, ok := c.table.Addr(rPK) + if !ok { + return nil, ErrStcpEntryNotFound + } + c.eb.SendTCPDial(context.Background(), string(STCP), addr) + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + + c.log.Infof("Dialed %v:%v@%v", rPK, rPort, conn.RemoteAddr()) + return c.initConnection(ctx, conn, rPK, rPort) +} + +// Start implements Client interface +func (c *stcpClient) Start() error { + if c.connListener != nil { + return ErrAlreadyListening + } + go c.serve() + return nil +} + +func (c *stcpClient) serve() { + lis, err := net.Listen("tcp", c.listenAddr) + if err != nil { + c.log.Errorf("Failed to listen on %q: %v", c.listenAddr, err) + return + } + c.acceptConnections(lis) +} diff --git a/pkg/snet/directtp/pktable/pktable.go b/pkg/transport/network/stcp/pktable.go similarity index 99% rename from pkg/snet/directtp/pktable/pktable.go rename to pkg/transport/network/stcp/pktable.go index bc40761f7..f7662d7ba 100644 --- a/pkg/snet/directtp/pktable/pktable.go +++ b/pkg/transport/network/stcp/pktable.go @@ -1,4 +1,4 @@ -package pktable +package stcp import ( "bufio" diff --git a/pkg/transport/network/stcpr.go b/pkg/transport/network/stcpr.go new file mode 100644 index 000000000..5444da12f --- /dev/null +++ b/pkg/transport/network/stcpr.go @@ -0,0 +1,86 @@ +package network + +import ( + "context" + "fmt" + "io" + "net" + + "github.com/skycoin/dmsg/cipher" + + "github.com/skycoin/skywire/pkg/util/netutil" +) + +type stcprClient struct { + *resolvedClient +} + +func newStcpr(resolved *resolvedClient) Client { + client := &stcprClient{resolvedClient: resolved} + client.netType = STCPR + return client +} + +// Dial implements interface +func (c *stcprClient) Dial(ctx context.Context, rPK cipher.PubKey, rPort uint16) (Conn, error) { + if c.isClosed() { + return nil, io.ErrClosedPipe + } + c.log.Infof("Dialing PK %v", rPK) + visorData, err := c.ar.Resolve(ctx, string(STCPR), rPK) + if err != nil { + return nil, fmt.Errorf("resolve PK: %w", err) + } + c.log.Infof("Resolved PK %v to visor data %v", rPK, visorData) + conn, err := c.dialVisor(ctx, rPK, c.dial) + if err != nil { + return nil, err + } + + return c.initConnection(ctx, conn, rPK, rPort) +} + +func (c *stcprClient) dial(ctx context.Context, addr string) (net.Conn, error) { + c.eb.SendTCPDial(context.Background(), string(STCPR), addr) + dialer := net.Dialer{} + return dialer.DialContext(ctx, "tcp", addr) +} + +// Start implements Client interface +func (c *stcprClient) Start() error { + if c.connListener != nil { + return ErrAlreadyListening + } + go c.serve() + return nil +} + +func (c *stcprClient) serve() { + lis, err := net.Listen("tcp", "") + if err != nil { + c.log.Errorf("Failed to listen on random port: %v", err) + return + } + + localAddr := lis.Addr().String() + _, port, err := net.SplitHostPort(localAddr) + if err != nil { + c.log.Errorf("Failed to extract port from addr %v: %v", err) + return + } + hasPublic, err := netutil.HasPublicIP() + if err != nil { + c.log.Errorf("Failed to check for public IP: %v", err) + } + if !hasPublic { + c.log.Infof("Not binding STCPR: no public IP address found") + return + } + c.log.Infof("Binding") + if err := c.ar.BindSTCPR(context.Background(), port); err != nil { + c.log.Errorf("Failed to bind STCPR: %v", err) + return + } + c.log.Infof("Successfully bound stcpr to port %s", port) + c.acceptConnections(lis) +} diff --git a/pkg/transport/network/sudph.go b/pkg/transport/network/sudph.go new file mode 100644 index 000000000..fc4bef3e3 --- /dev/null +++ b/pkg/transport/network/sudph.go @@ -0,0 +1,163 @@ +package network + +import ( + "context" + "fmt" + "io" + "net" + "time" + + "github.com/AudriusButkevicius/pfilter" + "github.com/skycoin/dmsg" + "github.com/skycoin/dmsg/cipher" + "github.com/xtaci/kcp-go" + + "github.com/skycoin/skywire/internal/packetfilter" + "github.com/skycoin/skywire/pkg/transport/network/addrresolver" + "github.com/skycoin/skywire/pkg/transport/network/handshake" +) + +const ( + // holePunchMessage is sent in a dummy UDP packet that is sent by both parties to establish UDP hole punching. + holePunchMessage = "holepunch" + // dialConnPriority and visorsConnPriority are used to set an order how connection filters apply. + dialConnPriority = 2 + visorsConnPriority = 3 + dialTimeout = 30 * time.Second +) + +type sudphClient struct { + *resolvedClient + filter *pfilter.PacketFilter +} + +func newSudph(resolved *resolvedClient) Client { + client := &sudphClient{resolvedClient: resolved} + client.netType = SUDPH + return client +} + +// Start implements Client interface +func (c *sudphClient) Start() error { + if c.connListener != nil { + return ErrAlreadyListening + } + go c.serve() + return nil +} + +func (c *sudphClient) serve() { + lis, err := c.listen() + if err != nil { + c.log.Errorf("Failed to listen on random port: %v", err) + return + } + c.acceptConnections(lis) +} + +// listen +func (c *sudphClient) listen() (net.Listener, error) { + packetListener, err := net.ListenPacket("udp", "") + if err != nil { + return nil, err + } + c.filter = pfilter.NewPacketFilter(packetListener) + sudphVisorsConn := c.filter.NewConn(visorsConnPriority, nil) + c.filter.Start() + c.log.Infof("Binding") + addrCh, err := c.ar.BindSUDPH(c.filter, c.makeBindHandshake()) + if err != nil { + return nil, err + } + go c.acceptAddresses(sudphVisorsConn, addrCh) + return kcp.ServeConn(nil, 0, 0, sudphVisorsConn) +} + +// make a handshake function that is compatible with address resolver interface +func (c *sudphClient) makeBindHandshake() func(in net.Conn) (net.Conn, error) { + emptyAddr := dmsg.Addr{PK: cipher.PubKey{}, Port: 0} + hs := handshake.InitiatorHandshake(c.SK(), dmsg.Addr{PK: c.PK(), Port: 0}, emptyAddr) + return func(in net.Conn) (net.Conn, error) { + return doHandshake(in, hs, SUDPH, c.log) + } +} + +// acceptAddresses will read visor addresses from addrCh and send holepunch +// packets to them +// Basically each address coming from addrCh is a dial request from some remote +// visor to us. Dialing visor contacts address resolver and gives the address to +// it, address resolver in turn sends us this address. +func (c *sudphClient) acceptAddresses(conn net.PacketConn, addrCh <-chan addrresolver.RemoteVisor) { + for addr := range addrCh { + udpAddr, err := net.ResolveUDPAddr("udp", addr.Addr) + if err != nil { + c.log.WithError(err).Errorf("Failed to resolve UDP address %q", addr) + continue + } + + c.log.Infof("Sending hole punch packet to %v", addr) + + if _, err := conn.WriteTo([]byte(holePunchMessage), udpAddr); err != nil { + c.log.WithError(err).Errorf("Failed to send hole punch packet to %v", udpAddr) + continue + } + c.log.Infof("Sent hole punch packet to %v", addr) + } +} + +// Dial implements interface +func (c *sudphClient) Dial(ctx context.Context, rPK cipher.PubKey, rPort uint16) (Conn, error) { + if c.isClosed() { + return nil, io.ErrClosedPipe + } + // this will lookup visor address in address resolver and then dial that address + conn, err := c.dialVisor(ctx, rPK, c.dialWithTimeout) + if err != nil { + return nil, err + } + + return c.initConnection(ctx, conn, rPK, rPort) +} + +func (c *sudphClient) dialWithTimeout(ctx context.Context, addr string) (net.Conn, error) { + timedCtx, cancel := context.WithTimeout(ctx, dialTimeout) + defer cancel() + c.log.Infof("Dialing %v", addr) + + for { + select { + case <-timedCtx.Done(): + return nil, timedCtx.Err() + default: + conn, err := c.dial(addr) + if err == nil { + c.log.Infof("Dialed %v", addr) + return conn, nil + } + c.log.WithError(err). + Warnf("Failed to dial %v, trying again: %v", addr, err) + } + } +} + +// dial will send holepunch packet to the remote addr over UDP, and +// return the connection +func (c *sudphClient) dial(remoteAddr string) (net.Conn, error) { + rAddr, err := net.ResolveUDPAddr("udp", remoteAddr) + if err != nil { + return nil, fmt.Errorf("net.ResolveUDPAddr (remote): %w", err) + } + + dialConn := c.filter.NewConn(dialConnPriority, packetfilter.NewKCPConversationFilter()) + + if _, err := dialConn.WriteTo([]byte(holePunchMessage), rAddr); err != nil { + return nil, fmt.Errorf("dialConn.WriteTo: %w", err) + } + + kcpConn, err := kcp.NewConn(remoteAddr, nil, 0, 0, dialConn) + if err != nil { + return nil, err + } + + return kcpConn, nil +} diff --git a/pkg/transport/setup/rpc.go b/pkg/transport/setup/rpc.go index d8989b440..625226157 100644 --- a/pkg/transport/setup/rpc.go +++ b/pkg/transport/setup/rpc.go @@ -9,6 +9,7 @@ import ( "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/skywire/pkg/transport" + "github.com/skycoin/skywire/pkg/transport/network" ) // TransportGateway that exposes methods to be used via RPC @@ -20,7 +21,7 @@ type TransportGateway struct { // TransportRequest to perform an action over RPC type TransportRequest struct { RemotePK cipher.PubKey - Type string + Type network.Type } // UUIDRequest contains id in UUID format @@ -33,7 +34,7 @@ type TransportResponse struct { ID uuid.UUID Local cipher.PubKey Remote cipher.PubKey - Type string + Type network.Type IsUp bool } diff --git a/pkg/transport/transport.go b/pkg/transport/transport.go index d4af4a0e4..679d017fd 100644 --- a/pkg/transport/transport.go +++ b/pkg/transport/transport.go @@ -9,6 +9,8 @@ import ( "github.com/google/uuid" "github.com/skycoin/dmsg/cipher" "github.com/skycoin/skycoin/src/util/logging" + + "github.com/skycoin/skywire/pkg/transport/network" ) var log = logging.MustGetLogger("transport") @@ -17,7 +19,8 @@ var log = logging.MustGetLogger("transport") // Generated uuid is: // - always the same for a given pair // - GenTransportUUID(keyA,keyB) == GenTransportUUID(keyB, keyA) -func MakeTransportID(keyA, keyB cipher.PubKey, tpType string) uuid.UUID { +func MakeTransportID(keyA, keyB cipher.PubKey, netType network.Type) uuid.UUID { + tpType := string(netType) keys := SortEdges(keyA, keyB) b := make([]byte, 33*2+len(tpType)) i := 0 diff --git a/pkg/util/cipherutil/cipherutil.go b/pkg/util/cipherutil/cipherutil.go new file mode 100644 index 000000000..d537e8368 --- /dev/null +++ b/pkg/util/cipherutil/cipherutil.go @@ -0,0 +1,24 @@ +package cipherutil + +import "github.com/skycoin/dmsg/cipher" + +// KeyPair is a pair of public and secret keys +type KeyPair struct { + PK cipher.PubKey + SK cipher.SecKey +} + +// GenKeyPairs generates n random key pairs +func GenKeyPairs(n int) []KeyPair { + pairs := make([]KeyPair, n) + for i := range pairs { + pk, sk, err := cipher.GenerateDeterministicKeyPair([]byte{byte(i)}) + if err != nil { + panic(err) + } + + pairs[i] = KeyPair{PK: pk, SK: sk} + } + + return pairs +} diff --git a/pkg/visor/api.go b/pkg/visor/api.go index 3817ca5c3..c98da58af 100644 --- a/pkg/visor/api.go +++ b/pkg/visor/api.go @@ -20,6 +20,7 @@ import ( "github.com/skycoin/skywire/pkg/routing" "github.com/skycoin/skywire/pkg/skyenv" "github.com/skycoin/skywire/pkg/transport" + "github.com/skycoin/skywire/pkg/transport/network" "github.com/skycoin/skywire/pkg/util/netutil" "github.com/skycoin/skywire/pkg/util/updater" "github.com/skycoin/skywire/pkg/visor/dmsgtracker" @@ -280,7 +281,7 @@ func (v *Visor) StartApp(appName string) error { if appName == skyenv.VPNClientName { // todo: can we use some kind of app start hook that will be used for both autostart // and start? Reason: this is also called in init for autostart - maker := vpnEnvMaker(v.conf, v.net, v.tpM.STCPRRemoteAddrs()) + maker := vpnEnvMaker(v.conf, v.dmsgC, v.tpM.STCPRRemoteAddrs()) envs, err = maker() if err != nil { return err @@ -478,10 +479,10 @@ func (v *Visor) TransportTypes() ([]string, error) { func (v *Visor) Transports(types []string, pks []cipher.PubKey, logs bool) ([]*TransportSummary, error) { var result []*TransportSummary - typeIncluded := func(tType string) bool { + typeIncluded := func(tType network.Type) bool { if types != nil { for _, ft := range types { - if tType == ft { + if string(tType) == ft { return true } } @@ -532,7 +533,7 @@ func (v *Visor) AddTransport(remote cipher.PubKey, tpType string, public bool, t v.log.Debugf("Saving transport to %v via %v", remote, tpType) - tp, err := v.tpM.SaveTransport(ctx, remote, tpType, transport.LabelUser) + tp, err := v.tpM.SaveTransport(ctx, remote, network.Type(tpType), transport.LabelUser) if err != nil { return nil, err } diff --git a/pkg/visor/init.go b/pkg/visor/init.go index 834da9aec..f6f4ca5a8 100644 --- a/pkg/visor/init.go +++ b/pkg/visor/init.go @@ -23,15 +23,16 @@ import ( "github.com/skycoin/skywire/pkg/app/appevent" "github.com/skycoin/skywire/pkg/app/appserver" "github.com/skycoin/skywire/pkg/app/launcher" + "github.com/skycoin/skywire/pkg/dmsgc" "github.com/skycoin/skywire/pkg/routefinder/rfclient" "github.com/skycoin/skywire/pkg/router" "github.com/skycoin/skywire/pkg/servicedisc" "github.com/skycoin/skywire/pkg/setup/setupclient" "github.com/skycoin/skywire/pkg/skyenv" - "github.com/skycoin/skywire/pkg/snet" - "github.com/skycoin/skywire/pkg/snet/arclient" - "github.com/skycoin/skywire/pkg/snet/directtp/tptypes" "github.com/skycoin/skywire/pkg/transport" + "github.com/skycoin/skywire/pkg/transport/network" + "github.com/skycoin/skywire/pkg/transport/network/addrresolver" + "github.com/skycoin/skywire/pkg/transport/network/stcp" ts "github.com/skycoin/skywire/pkg/transport/setup" "github.com/skycoin/skywire/pkg/transport/tpdclient" "github.com/skycoin/skywire/pkg/util/netutil" @@ -62,10 +63,10 @@ var ( ar vinit.Module // App discovery disc vinit.Module - // Snet (different network types) - sn vinit.Module // dmsg pty: a remote terminal to the visor working over dmsg protocol pty vinit.Module + // Dmsg module + dmsgC vinit.Module // Transport manager tr vinit.Module // Transport setup @@ -104,20 +105,20 @@ func registerModules(logger *logging.MasterLogger) { ebc = maker("event_broadcaster", initEventBroadcaster) ar = maker("address_resolver", initAddressResolver) disc = maker("discovery", initDiscovery) - sn = maker("snet", initSNet, &ar, &disc, &ebc) - dmsgCtrl = maker("dmsg_ctrl", initDmsgCtrl, &sn) - pty = maker("dmsg_pty", initDmsgpty, &sn) - tr = maker("transport", initTransport, &sn, &ebc) - trs = maker("transport_setup", initTransportSetup, &sn, &tr) - rt = maker("router", initRouter, &tr, &sn) - launch = maker("launcher", initLauncher, &ebc, &disc, &sn, &tr, &rt) + dmsgC = maker("dmsg", initDmsg, &ebc) + dmsgCtrl = maker("dmsg_ctrl", initDmsgCtrl, &dmsgC) + pty = maker("dmsg_pty", initDmsgpty, &dmsgC) + tr = maker("transport", initTransport, &ar, &ebc, &dmsgC) + rt = maker("router", initRouter, &tr, &dmsgC) + launch = maker("launcher", initLauncher, &ebc, &disc, &dmsgC, &tr, &rt) cli = maker("cli", initCLI) - hvs = maker("hypervisors", initHypervisors, &sn) + hvs = maker("hypervisors", initHypervisors, &dmsgC) ut = maker("uptime_tracker", initUptimeTracker) pv = maker("public_visors", initPublicVisors, &tr) - pvs = maker("public_visor", initPublicVisor, &sn, &ar, &disc) - vis = vinit.MakeModule("visor", vinit.DoNothing, logger, &up, &ebc, &ar, &disc, &sn, &pty, - &tr, &rt, &launch, &cli, &hvs, &ut, &pv, &pvs, &trs, &dmsgCtrl) + pvs = maker("public_visor", initPublicVisor, &tr, &ar, &disc) + trs = maker("transport_setup", initTransportSetup, &dmsgC, &tr) + vis = vinit.MakeModule("visor", vinit.DoNothing, logger, &up, &ebc, &ar, &disc, &pty, + &tr, &rt, &launch, &cli, &trs, &hvs, &ut, &pv, &pvs, &dmsgCtrl) hv = maker("hypervisor", initHypervisor, &vis) } @@ -149,7 +150,7 @@ func initEventBroadcaster(ctx context.Context, v *Visor, log *logging.Logger) er func initAddressResolver(ctx context.Context, v *Visor, log *logging.Logger) error { conf := v.conf.Transport - arClient, err := arclient.NewHTTP(conf.AddressResolver, v.conf.PK, v.conf.SK, log) + arClient, err := addrresolver.NewHTTP(conf.AddressResolver, v.conf.PK, v.conf.SK, log) if err != nil { err := fmt.Errorf("failed to create address resolver client: %w", err) return err @@ -180,37 +181,28 @@ func initDiscovery(ctx context.Context, v *Visor, log *logging.Logger) error { return nil } -func initSNet(ctx context.Context, v *Visor, log *logging.Logger) error { - nc := snet.NetworkConfigs{ - Dmsg: v.conf.Dmsg, - STCP: v.conf.STCP, +func initDmsg(ctx context.Context, v *Visor, log *logging.Logger) error { + if v.conf.Dmsg == nil { + return fmt.Errorf("cannot initialize dmsg: empty configuration") } + dmsgC := dmsgc.New(v.conf.PK, v.conf.SK, v.ebc, v.conf.Dmsg) - conf := snet.Config{ - PubKey: v.conf.PK, - SecKey: v.conf.SK, - ARClient: v.arClient, - NetworkConfigs: nc, - } - - n, err := snet.New(conf, v.ebc, v.MasterLogger()) - if err != nil { - return err - } - - if err := n.Init(); err != nil { - return err - } - v.pushCloseStack("snet", n.Close) + time.Sleep(200 * time.Millisecond) + go dmsgC.Serve(context.Background()) + time.Sleep(200 * time.Millisecond) v.initLock.Lock() - v.net = n + v.dmsgC = dmsgC v.initLock.Unlock() + + v.pushCloseStack("dmsgC", func() error { + return dmsgC.Close() + }) return nil } func initDmsgCtrl(ctx context.Context, v *Visor, _ *logging.Logger) error { - dmsgC := v.net.Dmsg() + dmsgC := v.dmsgC if dmsgC == nil { return nil } @@ -220,7 +212,7 @@ func initDmsgCtrl(ctx context.Context, v *Visor, _ *logging.Logger) error { select { case <-time.After(dmsgTimeout): logger.Warn("Failed to connect to the dmsg network, will try again later.") - case <-v.net.Dmsg().Ready(): + case <-v.dmsgC.Ready(): logger.Info("Connected to the dmsg network.") } // dmsgctrl setup @@ -251,22 +243,24 @@ func initTransport(ctx context.Context, v *Visor, log *logging.Logger) error { LogStore: logS, } managerLogger := v.MasterLogger().PackageLogger("transport_manager") - tpM, err := transport.NewManager(managerLogger, v.net, &tpMConf) + + // todo: pass down configuration? + table := stcp.NewTable(v.conf.STCP.PKTable) + factory := network.ClientFactory{ + PK: v.conf.PK, + SK: v.conf.SK, + ListenAddr: v.conf.STCP.LocalAddr, + PKTable: table, + ARClient: v.arClient, + EB: v.ebc, + DmsgC: v.dmsgC, + } + tpM, err := transport.NewManager(managerLogger, v.arClient, v.ebc, &tpMConf, factory) if err != nil { err := fmt.Errorf("failed to start transport manager: %w", err) return err } - tpM.OnAfterTPClosed(func(network, addr string) { - if network == tptypes.STCPR && addr != "" { - data := appevent.TCPCloseData{RemoteNet: network, RemoteAddr: addr} - event := appevent.NewEvent(appevent.TCPClose, data) - if err := v.ebc.Broadcast(context.Background(), event); err != nil { - v.log.WithError(err).Errorln("Failed to broadcast TCPClose event") - } - } - }) - ctx, cancel := context.WithCancel(context.Background()) wg := new(sync.WaitGroup) wg.Add(1) @@ -291,7 +285,7 @@ func initTransport(ctx context.Context, v *Visor, log *logging.Logger) error { func initTransportSetup(ctx context.Context, v *Visor, log *logging.Logger) error { ctx, cancel := context.WithCancel(ctx) - ts, err := ts.NewTransportListener(ctx, v.conf, v.net.Dmsg(), v.tpM, v.MasterLogger()) + ts, err := ts.NewTransportListener(ctx, v.conf, v.dmsgC, v.tpM, v.MasterLogger()) if err != nil { cancel() return err @@ -320,7 +314,7 @@ func initRouter(ctx context.Context, v *Visor, log *logging.Logger) error { MinHops: v.conf.Routing.MinHops, } - r, err := router.New(v.net, &rConf) + r, err := router.New(v.dmsgC, &rConf) if err != nil { err := fmt.Errorf("failed to create router: %w", err) return err @@ -393,15 +387,15 @@ func initLauncher(ctx context.Context, v *Visor, log *logging.Logger) error { launchLog := v.MasterLogger().PackageLogger("launcher") - launch, err := launcher.NewLauncher(launchLog, launchConf, v.net.Dmsg(), v.router, procM) + launch, err := launcher.NewLauncher(launchLog, launchConf, v.dmsgC, v.router, procM) if err != nil { err := fmt.Errorf("failed to start launcher: %w", err) return err } err = launch.AutoStart(launcher.EnvMap{ - skyenv.VPNClientName: vpnEnvMaker(v.conf, v.net, v.tpM.STCPRRemoteAddrs()), - skyenv.VPNServerName: vpnEnvMaker(v.conf, v.net, nil), + skyenv.VPNClientName: vpnEnvMaker(v.conf, v.dmsgC, v.tpM.STCPRRemoteAddrs()), + skyenv.VPNServerName: vpnEnvMaker(v.conf, v.dmsgC, nil), }) if err != nil { @@ -418,7 +412,7 @@ func initLauncher(ctx context.Context, v *Visor, log *logging.Logger) error { } // Make an env maker function for vpn application -func vpnEnvMaker(conf *visorconfig.V1, n *snet.Network, tpRemoteAddrs []string) launcher.EnvMaker { +func vpnEnvMaker(conf *visorconfig.V1, dmsgC *dmsg.Client, tpRemoteAddrs []string) launcher.EnvMaker { return launcher.EnvMaker(func() ([]string, error) { var envCfg vpn.DirectRoutesEnvConfig @@ -427,7 +421,7 @@ func vpnEnvMaker(conf *visorconfig.V1, n *snet.Network, tpRemoteAddrs []string) r := dmsgnetutil.NewRetrier(logrus.New(), 1*time.Second, 10*time.Second, 0, 1) err := r.Do(context.Background(), func() error { - for _, ses := range n.Dmsg().AllSessions() { + for _, ses := range dmsgC.AllSessions() { envCfg.DmsgServers = append(envCfg.DmsgServers, ses.RemoteTCPAddr().String()) } @@ -518,7 +512,7 @@ func initHypervisors(ctx context.Context, v *Visor, log *logging.Logger) error { go func(hvErrs chan error) { defer wg.Done() - ServeRPCClient(ctx, log, v.net, rpcS, addr, hvErrs) + ServeRPCClient(ctx, log, v.dmsgC, rpcS, addr, hvErrs) }(hvErrs) v.pushCloseStack("hypervisor."+hvPK.String()[:shortHashLen], func() error { @@ -595,7 +589,7 @@ func initPublicVisor(_ context.Context, v *Visor, log *logging.Logger) error { } // todo: consider moving this to transport into some helper function - stcpr, ok := v.net.STcpr() + stcpr, ok := v.tpM.Stcpr() if !ok { return nil } @@ -663,7 +657,7 @@ func initHypervisor(_ context.Context, v *Visor, log *logging.Logger) error { conf.DmsgDiscovery = v.conf.Dmsg.Discovery // Prepare hypervisor. - hv, err := New(conf, v, v.net.Dmsg()) + hv, err := New(conf, v, v.dmsgC) if err != nil { v.log.Fatalln("Failed to start hypervisor:", err) } diff --git a/pkg/visor/init_unix.go b/pkg/visor/init_unix.go index 14bc0e23f..f0c4df7a8 100644 --- a/pkg/visor/init_unix.go +++ b/pkg/visor/init_unix.go @@ -45,7 +45,7 @@ func initDmsgpty(ctx context.Context, v *Visor, log *logging.Logger) error { v.log.Errorf("Cannot add itself to the pty whitelist: %s", err) } - dmsgC := v.net.Dmsg() + dmsgC := v.dmsgC if dmsgC == nil { err := errors.New("cannot create dmsgpty with nil dmsg client") return err diff --git a/pkg/visor/rpc.go b/pkg/visor/rpc.go index 43bb84765..34f0bcb9b 100644 --- a/pkg/visor/rpc.go +++ b/pkg/visor/rpc.go @@ -14,6 +14,7 @@ import ( "github.com/skycoin/skywire/pkg/app/launcher" "github.com/skycoin/skywire/pkg/routing" "github.com/skycoin/skywire/pkg/transport" + "github.com/skycoin/skywire/pkg/transport/network" "github.com/skycoin/skywire/pkg/util/rpcutil" "github.com/skycoin/skywire/pkg/util/updater" ) @@ -121,7 +122,7 @@ type TransportSummary struct { ID uuid.UUID `json:"id"` Local cipher.PubKey `json:"local_pk"` Remote cipher.PubKey `json:"remote_pk"` - Type string `json:"type"` + Type network.Type `json:"type"` Log *transport.LogEntry `json:"log,omitempty"` IsSetup bool `json:"is_setup"` IsUp bool `json:"is_up"` diff --git a/pkg/visor/rpc_client.go b/pkg/visor/rpc_client.go index 7545c3ef3..b573273b9 100644 --- a/pkg/visor/rpc_client.go +++ b/pkg/visor/rpc_client.go @@ -25,8 +25,9 @@ import ( "github.com/skycoin/skywire/pkg/router" "github.com/skycoin/skywire/pkg/routing" "github.com/skycoin/skywire/pkg/skyenv" - "github.com/skycoin/skywire/pkg/snet/snettest" "github.com/skycoin/skywire/pkg/transport" + "github.com/skycoin/skywire/pkg/transport/network" + "github.com/skycoin/skywire/pkg/util/cipherutil" "github.com/skycoin/skywire/pkg/util/updater" ) @@ -445,7 +446,7 @@ func (rc *rpcClient) UpdateStatus() (string, error) { type mockRPCClient struct { startedAt time.Time o *Overview - tpTypes []string + tpTypes []network.Type rt routing.Table logS appcommon.LogStore sync.RWMutex @@ -455,7 +456,7 @@ type mockRPCClient struct { func NewMockRPCClient(r *rand.Rand, maxTps int, maxRules int) (cipher.PubKey, API, error) { log := logging.MustGetLogger("mock-rpc-client") - types := []string{"messaging", "native"} + types := []network.Type{"messaging", "native"} localPK, _ := cipher.GenerateKeyPair() log.Infof("generating mock client with: localPK(%s) maxTps(%d) maxRules(%d)", localPK, maxTps, maxRules) @@ -496,7 +497,7 @@ func NewMockRPCClient(r *rand.Rand, maxTps int, maxRules int) (cipher.PubKey, AP panic(err) } - keys := snettest.GenKeyPairs(2) + keys := cipherutil.GenKeyPairs(2) fwdRule := routing.ForwardRule(ruleKeepAlive, fwdRID[0], routing.RouteID(r.Uint32()), uuid.New(), keys[0].PK, keys[1].PK, 0, 0) if err := rt.SaveRule(fwdRule); err != nil { @@ -759,7 +760,11 @@ func (mc *mockRPCClient) GetAppConnectionsSummary(_ string) ([]appserver.Connect // TransportTypes implements API. func (mc *mockRPCClient) TransportTypes() ([]string, error) { - return mc.tpTypes, nil + var res []string + for _, tptype := range mc.tpTypes { + res = append(res, string(tptype)) + } + return res, nil } // Transports implements API. @@ -770,7 +775,7 @@ func (mc *mockRPCClient) Transports(types []string, pks []cipher.PubKey, logs bo tp := tp if types != nil { for _, reqT := range types { - if tp.Type == reqT { + if string(tp.Type) == reqT { goto TypeOK } } @@ -817,10 +822,10 @@ func (mc *mockRPCClient) Transport(tid uuid.UUID) (*TransportSummary, error) { // AddTransport implements API. func (mc *mockRPCClient) AddTransport(remote cipher.PubKey, tpType string, _ bool, _ time.Duration) (*TransportSummary, error) { summary := &TransportSummary{ - ID: transport.MakeTransportID(mc.o.PubKey, remote, tpType), + ID: transport.MakeTransportID(mc.o.PubKey, remote, network.Type(tpType)), Local: mc.o.PubKey, Remote: remote, - Type: tpType, + Type: network.Type(tpType), Log: new(transport.LogEntry), } return summary, mc.do(true, func() error { diff --git a/pkg/visor/rpc_client_serve.go b/pkg/visor/rpc_client_serve.go index ec1760f1f..c9e449225 100644 --- a/pkg/visor/rpc_client_serve.go +++ b/pkg/visor/rpc_client_serve.go @@ -2,14 +2,13 @@ package visor import ( "context" + "net" "net/rpc" "time" "github.com/sirupsen/logrus" "github.com/skycoin/dmsg" "github.com/skycoin/dmsg/netutil" - - "github.com/skycoin/skywire/pkg/snet" ) func isDone(ctx context.Context) bool { @@ -22,15 +21,16 @@ func isDone(ctx context.Context) bool { } // ServeRPCClient repetitively dials to a remote dmsg address and serves a RPC server to that address. -func ServeRPCClient(ctx context.Context, log logrus.FieldLogger, n *snet.Network, rpcS *rpc.Server, rAddr dmsg.Addr, errCh chan<- error) { +func ServeRPCClient(ctx context.Context, log logrus.FieldLogger, dmsgC *dmsg.Client, rpcS *rpc.Server, rAddr dmsg.Addr, errCh chan<- error) { const maxBackoff = time.Second * 5 retry := netutil.NewRetrier(log, netutil.DefaultInitBackoff, maxBackoff, netutil.DefaultTries, netutil.DefaultFactor) for { - var conn *snet.Conn + var conn net.Conn err := retry.Do(ctx, func() (rErr error) { log.Info("Dialing...") - conn, rErr = n.Dial(ctx, dmsg.Type, rAddr.PK, rAddr.Port) + addr := dmsg.Addr{PK: rAddr.PK, Port: rAddr.Port} + conn, rErr = dmsgC.Dial(ctx, addr) return rErr }) if err != nil { diff --git a/pkg/visor/rpc_test.go b/pkg/visor/rpc_test.go index 178506d47..e064b4ed4 100644 --- a/pkg/visor/rpc_test.go +++ b/pkg/visor/rpc_test.go @@ -15,8 +15,8 @@ import ( "github.com/skycoin/skywire/internal/testhelpers" "github.com/skycoin/skywire/internal/utclient" "github.com/skycoin/skywire/pkg/routefinder/rfclient" - "github.com/skycoin/skywire/pkg/snet/arclient" "github.com/skycoin/skywire/pkg/transport" + "github.com/skycoin/skywire/pkg/transport/network/addrresolver" "github.com/skycoin/skywire/pkg/visor/visorconfig" ) @@ -40,7 +40,7 @@ func TestHealth(t *testing.T) { utClient := &utclient.MockAPIClient{} utClient.On("Health", mock.Anything).Return(http.StatusOK, nil) - arClient := &arclient.MockAPIClient{} + arClient := &addrresolver.MockAPIClient{} arClient.On("Health", mock.Anything).Return(http.StatusOK, nil) rfClient := &rfclient.MockClient{} diff --git a/pkg/visor/visor.go b/pkg/visor/visor.go index 9a8277e07..9486028bc 100644 --- a/pkg/visor/visor.go +++ b/pkg/visor/visor.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/skycoin/dmsg" "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/skywire/internal/utclient" @@ -18,9 +19,8 @@ import ( "github.com/skycoin/skywire/pkg/restart" "github.com/skycoin/skywire/pkg/routefinder/rfclient" "github.com/skycoin/skywire/pkg/router" - "github.com/skycoin/skywire/pkg/snet" - "github.com/skycoin/skywire/pkg/snet/arclient" "github.com/skycoin/skywire/pkg/transport" + "github.com/skycoin/skywire/pkg/transport/network/addrresolver" "github.com/skycoin/skywire/pkg/util/updater" "github.com/skycoin/skywire/pkg/visor/logstore" "github.com/skycoin/skywire/pkg/visor/visorconfig" @@ -54,11 +54,11 @@ type Visor struct { updater *updater.Updater uptimeTracker utclient.APIClient - ebc *appevent.Broadcaster // event broadcaster + ebc *appevent.Broadcaster // event broadcaster + dmsgC *dmsg.Client - net *snet.Network tpM *transport.Manager - arClient arclient.APIClient + arClient addrresolver.APIClient router router.Router rfClient rfclient.Client @@ -219,6 +219,6 @@ func (v *Visor) uptimeTrackerClient() utclient.APIClient { } // addressResolverClient is a convenience function to obtain uptime address resovler client. -func (v *Visor) addressResolverClient() arclient.APIClient { +func (v *Visor) addressResolverClient() addrresolver.APIClient { return v.arClient } diff --git a/pkg/visor/visorconfig/config.go b/pkg/visor/visorconfig/config.go index d6cecdc92..dbdb8ebe6 100644 --- a/pkg/visor/visorconfig/config.go +++ b/pkg/visor/visorconfig/config.go @@ -5,10 +5,11 @@ import ( "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/skywire/pkg/app/launcher" + "github.com/skycoin/skywire/pkg/dmsgc" "github.com/skycoin/skywire/pkg/restart" "github.com/skycoin/skywire/pkg/routing" "github.com/skycoin/skywire/pkg/skyenv" - "github.com/skycoin/skywire/pkg/snet" + "github.com/skycoin/skywire/pkg/transport/network" "github.com/skycoin/skywire/pkg/visor/hypervisorconfig" ) @@ -18,7 +19,7 @@ import ( func MakeBaseConfig(common *Common) *V1 { conf := new(V1) conf.Common = common - conf.Dmsg = &snet.DmsgConfig{ + conf.Dmsg = &dmsgc.DmsgConfig{ Discovery: skyenv.DefaultDmsgDiscAddr, SessionsCount: 1, } @@ -78,7 +79,7 @@ func defaultConfigFromCommon(cc *Common, hypervisor bool) (*V1, error) { CLIAddr: skyenv.DefaultDmsgPtyCLIAddr, } - conf.STCP = &snet.STCPConfig{ + conf.STCP = &network.STCPConfig{ LocalAddr: skyenv.DefaultSTCPAddr, PKTable: nil, } diff --git a/pkg/visor/visorconfig/v0.go b/pkg/visor/visorconfig/v0.go index a80b0c72e..b4d03e042 100644 --- a/pkg/visor/visorconfig/v0.go +++ b/pkg/visor/visorconfig/v0.go @@ -3,8 +3,9 @@ package visorconfig import ( "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/skywire/pkg/dmsgc" "github.com/skycoin/skywire/pkg/routing" - "github.com/skycoin/skywire/pkg/snet" + "github.com/skycoin/skywire/pkg/transport/network" ) // V0Name is the version string before proper versioning is implemented. @@ -20,11 +21,11 @@ type V0 struct { SecKey cipher.SecKey `json:"secret_key"` } `json:"key_pair"` - Dmsg *snet.DmsgConfig `json:"dmsg"` + Dmsg *dmsgc.DmsgConfig `json:"dmsg"` DmsgPty *V1Dmsgpty `json:"dmsg_pty,omitempty"` - STCP *snet.STCPConfig `json:"stcp,omitempty"` + STCP *network.STCPConfig `json:"stcp,omitempty"` Transport *struct { Discovery string `json:"discovery"` diff --git a/pkg/visor/visorconfig/v1.go b/pkg/visor/visorconfig/v1.go index 3f1cc3c12..42815688d 100644 --- a/pkg/visor/visorconfig/v1.go +++ b/pkg/visor/visorconfig/v1.go @@ -8,7 +8,8 @@ import ( "github.com/skycoin/dmsg/cipher" "github.com/skycoin/skywire/pkg/app/launcher" - "github.com/skycoin/skywire/pkg/snet" + "github.com/skycoin/skywire/pkg/dmsgc" + "github.com/skycoin/skywire/pkg/transport/network" "github.com/skycoin/skywire/pkg/visor/hypervisorconfig" ) @@ -39,13 +40,13 @@ type V1 struct { *Common mu sync.RWMutex - Dmsg *snet.DmsgConfig `json:"dmsg"` - Dmsgpty *V1Dmsgpty `json:"dmsgpty,omitempty"` - STCP *snet.STCPConfig `json:"stcp,omitempty"` - Transport *V1Transport `json:"transport"` - Routing *V1Routing `json:"routing"` - UptimeTracker *V1UptimeTracker `json:"uptime_tracker,omitempty"` - Launcher *V1Launcher `json:"launcher"` + Dmsg *dmsgc.DmsgConfig `json:"dmsg"` + Dmsgpty *V1Dmsgpty `json:"dmsgpty,omitempty"` + STCP *network.STCPConfig `json:"stcp,omitempty"` + Transport *V1Transport `json:"transport"` + Routing *V1Routing `json:"routing"` + UptimeTracker *V1UptimeTracker `json:"uptime_tracker,omitempty"` + Launcher *V1Launcher `json:"launcher"` Hypervisors []cipher.PubKey `json:"hypervisors"` CLIAddr string `json:"cli_addr"`