From 21783cc3915c7e1c8b2ba7c2f2862a151b4ec8e4 Mon Sep 17 00:00:00 2001
From: Alex Yu
Date: Mon, 15 Apr 2019 12:28:00 +0300
Subject: [PATCH] Changes:

1. Attempt to adopt tests from skywire-services

---
 pkg/setup/node.go      | 286 ++++++++++++++++++++++++++++++++
 pkg/setup/node_test.go | 360 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 646 insertions(+)
 create mode 100644 pkg/setup/node.go
 create mode 100644 pkg/setup/node_test.go

diff --git a/pkg/setup/node.go b/pkg/setup/node.go
new file mode 100644
index 0000000000..3f085e7c05
--- /dev/null
+++ b/pkg/setup/node.go
@@ -0,0 +1,286 @@
+package setup
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"time"
+
+	"github.com/skycoin/skycoin/src/util/logging"
+	"github.com/skycoin/skywire/pkg/cipher"
+	"github.com/skycoin/skywire/pkg/messaging"
+	mClient "github.com/skycoin/skywire/pkg/messaging-discovery/client"
+	"github.com/skycoin/skywire/pkg/routing"
+	"github.com/skycoin/skywire/pkg/transport"
+	trClient "github.com/skycoin/skywire/pkg/transport-discovery/client"
+
+	"github.com/watercompany/skywire-services/internal/metrics"
+)
+
+// Hop is a wrapper around a transport hop that records the route ID
+// assigned to the rule installed on that hop.
+type Hop struct {
+	*routing.Hop
+	routeID routing.RouteID
+}
+
+// Node performs route setup operations over a messaging channel.
+type Node struct {
+	Logger *logging.Logger
+
+	tm        *transport.Manager
+	messenger *messaging.Client
+	srvCount  int
+	metrics   metrics.Recorder
+}
+
+// NewNode constructs a new setup Node.
+func NewNode(conf *Config, metrics metrics.Recorder) (*Node, error) {
+	pk := conf.PubKey
+	sk := conf.SecKey
+
+	logger := logging.NewMasterLogger()
+	if lvl, err := logging.LevelFromString(conf.LogLevel); err == nil {
+		logger.SetLevel(lvl)
+	}
+	messenger := messaging.NewClient(&messaging.Config{
+		PubKey:     pk,
+		SecKey:     sk,
+		Discovery:  mClient.NewHTTP(conf.Messaging.Discovery),
+		Retries:    10,
+		RetryDelay: time.Second,
+	})
+	messenger.Logger = logger.PackageLogger("messenger")
+
+	trDiscovery, err := trClient.NewHTTP(conf.TransportDiscovery, pk, sk)
+	if err != nil {
+		return nil, fmt.Errorf("trdiscovery: %s", err)
+	}
+
+	tmConf := &transport.ManagerConfig{
+		PubKey:          pk,
+		SecKey:          sk,
+		DiscoveryClient: trDiscovery,
+		LogStore:        transport.InMemoryTransportLogStore(),
+	}
+
+	tm, err := transport.NewManager(tmConf, messenger)
+	if err != nil {
+		return nil, fmt.Errorf("transport manager: %s", err)
+	}
+	tm.Logger = logger.PackageLogger("trmanager")
+
+	return &Node{
+		Logger:    logger.PackageLogger("routesetup"),
+		metrics:   metrics,
+		tm:        tm,
+		messenger: messenger,
+		srvCount:  conf.Messaging.ServerCount,
+	}, nil
+}
+
+// Serve starts the transport listening loop.
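+// It connects to the configured number of messaging servers, then accepts
+// transports and serves setup requests on them until ctx is done or the
+// node is closed.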
+func (sn *Node) Serve(ctx context.Context) error {
+	if sn.srvCount > 0 {
+		if err := sn.messenger.ConnectToInitialServers(ctx, sn.srvCount); err != nil {
+			return fmt.Errorf("messaging: %s", err)
+		}
+		sn.Logger.Info("Connected to messaging servers")
+	}
+
+	acceptCh, dialCh := sn.tm.Observe()
+	go func() {
+		for tr := range acceptCh {
+			go func(t transport.Transport) {
+				for {
+					if err := sn.serveTransport(t); err != nil {
+						sn.Logger.Warnf("Failed to serve Transport: %s", err)
+						return
+					}
+				}
+			}(tr)
+		}
+	}()
+
+	// Drain dial notifications; dialed transports are handled at the call
+	// sites that create them.
+	go func() {
+		for range dialCh {
+		}
+	}()
+
+	sn.Logger.Info("Starting Setup Node")
+	return sn.tm.Serve(ctx)
+}
+
+// createLoop creates the forward and reverse routes of l, then confirms the
+// loop with the responder and the initiator.
+func (sn *Node) createLoop(l *routing.Loop) error {
+	sn.Logger.Infof("Creating new Loop %s", l)
+	rRouteID, err := sn.createRoute(l.Expiry, l.Reverse, l.LocalPort, l.RemotePort)
+	if err != nil {
+		return err
+	}
+
+	fRouteID, err := sn.createRoute(l.Expiry, l.Forward, l.RemotePort, l.LocalPort)
+	if err != nil {
+		return err
+	}
+
+	if len(l.Forward) == 0 || len(l.Reverse) == 0 {
+		return nil
+	}
+
+	initiator := l.Initiator()
+	responder := l.Responder()
+
+	ldR := &LoopData{RemotePK: initiator, RemotePort: l.LocalPort, LocalPort: l.RemotePort, RouteID: rRouteID, NoiseMessage: l.NoiseMessage}
+	noiseRes, err := sn.connectLoop(responder, ldR)
+	if err != nil {
+		sn.Logger.Warnf("Failed to confirm loop with responder: %s", err)
+		return fmt.Errorf("loop connect: %s", err)
+	}
+
+	ldI := &LoopData{RemotePK: responder, RemotePort: l.RemotePort, LocalPort: l.LocalPort, RouteID: fRouteID, NoiseMessage: noiseRes}
+	if _, err := sn.connectLoop(initiator, ldI); err != nil {
+		sn.Logger.Warnf("Failed to confirm loop with initiator: %s", err)
+		if err := sn.closeLoop(responder, ldR); err != nil {
+			sn.Logger.Warnf("Failed to close loop: %s", err)
+		}
+
+		return fmt.Errorf("loop connect: %s", err)
+	}
+
+	sn.Logger.Infof("Created Loop %s", l)
+	return nil
+}
+
+// createRoute installs rules along route in reverse order, so that each
+// forward rule can reference the route ID already assigned on the next hop.
+// It returns the route ID assigned on the initiating node.
+func (sn *Node) createRoute(expireAt time.Time, route routing.Route, rport, lport uint16) (routing.RouteID, error) {
+	if len(route) == 0 {
+		return 0, nil
+	}
+
+	sn.Logger.Infof("Creating new Route %s", route)
+	r := make([]*Hop, len(route))
+
+	initiator := route[0].From
+	for idx := len(r) - 1; idx >= 0; idx-- {
+		hop := &Hop{Hop: route[idx]}
+		r[idx] = hop
+		var rule routing.Rule
+		if idx == len(r)-1 {
+			rule = routing.AppRule(expireAt, 0, initiator, lport, rport)
+		} else {
+			nextHop := r[idx+1]
+			rule = routing.ForwardRule(expireAt, nextHop.routeID, nextHop.Transport)
+		}
+
+		routeID, err := sn.setupRule(hop.To, rule)
+		if err != nil {
+			return 0, fmt.Errorf("rule setup: %s", err)
+		}
+
+		hop.routeID = routeID
+	}
+
+	rule := routing.ForwardRule(expireAt, r[0].routeID, r[0].Transport)
+	routeID, err := sn.setupRule(initiator, rule)
+	if err != nil {
+		return 0, fmt.Errorf("rule setup: %s", err)
+	}
+
+	return routeID, nil
+}
+
+// Close closes the underlying transport manager.
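+// Closing the manager also causes a running Serve to return.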
+func (sn *Node) Close() error {
+	return sn.tm.Close()
+}
+
+// serveTransport reads a single setup packet from tr, dispatches it and
+// writes back a success or failure response.
+func (sn *Node) serveTransport(tr transport.Transport) error {
+	proto := NewSetupProtocol(tr)
+	sp, data, err := proto.ReadPacket()
+	if err != nil {
+		return err
+	}
+
+	sn.Logger.Infof("Got new Setup request with type %s: %s", sp, string(data))
+
+	startTime := time.Now()
+	switch sp {
+	case PacketCreateLoop:
+		loop := &routing.Loop{}
+		if err = json.Unmarshal(data, loop); err == nil {
+			err = sn.createLoop(loop)
+		}
+	case PacketCloseLoop:
+		ld := &LoopData{}
+		if err = json.Unmarshal(data, ld); err == nil {
+			remote, ok := sn.tm.Remote(tr.Edges())
+			if !ok {
+				return errors.New("configured PubKey not found in edges")
+			}
+			err = sn.closeLoop(ld.RemotePK, &LoopData{RemotePK: remote, RemotePort: ld.LocalPort, LocalPort: ld.RemotePort})
+		}
+	default:
+		err = errors.New("unknown foundation packet")
+	}
+	sn.metrics.Record(time.Since(startTime), err != nil)
+
+	if err != nil {
+		sn.Logger.Infof("Setup request with type %s failed: %s", sp, err)
+		return proto.WritePacket(RespFailure, err)
+	}
+
+	return proto.WritePacket(RespSuccess, nil)
+}
+
+// connectLoop confirms the loop with the remote node identified by on and
+// returns that node's noise response message.
+func (sn *Node) connectLoop(on cipher.PubKey, ld *LoopData) (noiseRes []byte, err error) {
+	tr, err := sn.tm.CreateTransport(context.Background(), on, "messaging", false)
+	if err != nil {
+		err = fmt.Errorf("transport: %s", err)
+		return
+	}
+	defer tr.Close()
+
+	proto := NewSetupProtocol(tr)
+	res, err := ConfirmLoop(proto, ld)
+	if err != nil {
+		return nil, err
+	}
+
+	sn.Logger.Infof("Confirmed loop on %s with %s. RemotePort: %d. LocalPort: %d", on, ld.RemotePK, ld.RemotePort, ld.LocalPort)
+	return res, nil
+}
+
+// closeLoop notifies the remote node identified by on that the loop is closed.
+func (sn *Node) closeLoop(on cipher.PubKey, ld *LoopData) error {
+	tr, err := sn.tm.CreateTransport(context.Background(), on, "messaging", false)
+	if err != nil {
+		return fmt.Errorf("transport: %s", err)
+	}
+	defer tr.Close()
+
+	proto := NewSetupProtocol(tr)
+	if err := LoopClosed(proto, ld); err != nil {
+		return err
+	}
+
+	sn.Logger.Infof("Closed loop on %s. LocalPort: %d", on, ld.LocalPort)
+	return nil
+}
+
+// setupRule sends rule to the node identified by pubKey and returns the
+// route ID assigned by that node.
+func (sn *Node) setupRule(pubKey cipher.PubKey, rule routing.Rule) (routeID routing.RouteID, err error) {
+	tr, err := sn.tm.CreateTransport(context.Background(), pubKey, "messaging", false)
+	if err != nil {
+		err = fmt.Errorf("transport: %s", err)
+		return
+	}
+	defer tr.Close()
+
+	proto := NewSetupProtocol(tr)
+	routeID, err = AddRule(proto, rule)
+	if err != nil {
+		return
+	}
+
+	sn.Logger.Infof("Set rule of type %s on %s with ID %d", rule.Type(), pubKey, routeID)
+	return routeID, nil
+}
diff --git a/pkg/setup/node_test.go b/pkg/setup/node_test.go
new file mode 100644
index 0000000000..34b5220c6c
--- /dev/null
+++ b/pkg/setup/node_test.go
@@ -0,0 +1,360 @@
+package setup
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/skycoin/skycoin/src/util/logging"
+	"github.com/skycoin/skywire/pkg/cipher"
+	"github.com/skycoin/skywire/pkg/routing"
+	"github.com/skycoin/skywire/pkg/transport"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/watercompany/skywire-services/internal/metrics"
+)
+
+func TestCreateLoop(t *testing.T) {
+	client := transport.NewDiscoveryMock()
+	logStore := transport.InMemoryTransportLogStore()
+
+	pk1, sk1 := cipher.GenerateKeyPair()
+	pk2, sk2 := cipher.GenerateKeyPair()
+	pk3, sk3 := cipher.GenerateKeyPair()
+	pk4, sk4 := cipher.GenerateKeyPair()
+	pkS, skS := cipher.GenerateKeyPair()
+
+	c1 := &transport.ManagerConfig{PubKey: pk1, SecKey: sk1, DiscoveryClient: client, LogStore: logStore}
+	c2 := &transport.ManagerConfig{PubKey: pk2, SecKey: sk2, DiscoveryClient: client, LogStore: logStore}
+	c3 := &transport.ManagerConfig{PubKey: pk3, SecKey: sk3, DiscoveryClient: client, LogStore: logStore}
+	c4 := &transport.ManagerConfig{PubKey: pk4, SecKey: sk4, DiscoveryClient: client, LogStore: logStore}
+	cS := &transport.ManagerConfig{PubKey: pkS, SecKey: skS, DiscoveryClient: client, LogStore: logStore}
+
+	f1, f2 := transport.NewMockFactoryPair(pk1, pk2)
+	f3, f4 := transport.NewMockFactoryPair(pk2, pk3)
+	f3.SetType("mock2")
+	f4.SetType("mock2")
+
+	fs1, fs2 := transport.NewMockFactoryPair(pk1, pkS)
+	fs1.SetType("messaging")
+	fs2.SetType("messaging")
+	fs3, fs4 := transport.NewMockFactoryPair(pk2, pkS)
+	fs3.SetType("messaging")
+	fs5, fs6 := transport.NewMockFactoryPair(pk3, pkS)
+	fs5.SetType("messaging")
+	fs7, fs8 := transport.NewMockFactoryPair(pk4, pkS)
+	fs7.SetType("messaging")
+
+	fS := newMuxFactory(pkS, "messaging", map[cipher.PubKey]transport.Factory{pk1: fs2, pk2: fs4, pk3: fs6, pk4: fs8})
+
+	m1, err := transport.NewManager(c1, f1, fs1)
+	require.NoError(t, err)
+
+	m2, err := transport.NewManager(c2, f2, f3, fs3)
+	require.NoError(t, err)
+
+	m3, err := transport.NewManager(c3, f4, fs5)
+	require.NoError(t, err)
+
+	m4, err := transport.NewManager(c4, fs7)
+	require.NoError(t, err)
+
+	mS, err := transport.NewManager(cS, fS)
+	require.NoError(t, err)
+
+	n1 := newMockNode(m1)
+	go n1.serve() // nolint: errcheck
+	n2 := newMockNode(m2)
+	go n2.serve() // nolint: errcheck
+	n3 := newMockNode(m3)
+	go n3.serve() // nolint: errcheck
+
+	tr1, err := m1.CreateTransport(context.TODO(), pk2, "mock", true)
+	require.NoError(t, err)
+
+	tr3, err := m3.CreateTransport(context.TODO(), pk2, "mock2", true)
+	require.NoError(t, err)
+
+	l := &routing.Loop{LocalPort: 1, RemotePort: 2, Expiry: time.Now().Add(time.Hour),
+		Forward: routing.Route{
+			&routing.Hop{From: pk1, To: pk2, Transport: tr1.ID},
+			&routing.Hop{From: pk2, To: pk3, Transport: tr3.ID},
+		},
+		Reverse: routing.Route{
+			&routing.Hop{From: pk3, To: pk2, Transport: tr3.ID},
+			&routing.Hop{From: pk2, To: pk1, Transport: tr1.ID},
+		},
+	}
+
+	time.Sleep(100 * time.Millisecond)
+
+	sn := &Node{logging.MustGetLogger("routesetup"), mS, nil, 0, metrics.NewDummy()}
+	errChan := make(chan error)
+	go func() {
+		errChan <- sn.Serve(context.TODO())
+	}()
+
+	tr, err := m4.CreateTransport(context.TODO(), pkS, "messaging", false)
+	require.NoError(t, err)
+
+	proto := NewSetupProtocol(tr)
+	require.NoError(t, CreateLoop(proto, l))
+
+	rules := n1.getRules()
+	require.Len(t, rules, 2)
+	rule := rules[1]
+	assert.Equal(t, routing.RuleApp, rule.Type())
+	assert.Equal(t, routing.RouteID(2), rule.RouteID())
+	assert.Equal(t, pk3, rule.RemotePK())
+	assert.Equal(t, uint16(2), rule.RemotePort())
+	assert.Equal(t, uint16(1), rule.LocalPort())
+	rule = rules[2]
+	assert.Equal(t, routing.RuleForward, rule.Type())
+	assert.Equal(t, tr1.ID, rule.TransportID())
+	assert.Equal(t, routing.RouteID(2), rule.RouteID())
+
+	rules = n2.getRules()
+	require.Len(t, rules, 2)
+	rule = rules[1]
+	assert.Equal(t, routing.RuleForward, rule.Type())
+	assert.Equal(t, tr1.ID, rule.TransportID())
+	assert.Equal(t, routing.RouteID(1), rule.RouteID())
+	rule = rules[2]
+	assert.Equal(t, routing.RuleForward, rule.Type())
+	assert.Equal(t, tr3.ID, rule.TransportID())
+	assert.Equal(t, routing.RouteID(2), rule.RouteID())
+
+	rules = n3.getRules()
+	require.Len(t, rules, 2)
+	rule = rules[1]
+	assert.Equal(t, routing.RuleForward, rule.Type())
+	assert.Equal(t, tr3.ID, rule.TransportID())
+	assert.Equal(t, routing.RouteID(1), rule.RouteID())
+	rule = rules[2]
+	assert.Equal(t, routing.RuleApp, rule.Type())
+	assert.Equal(t, routing.RouteID(1), rule.RouteID())
+	assert.Equal(t, pk1, rule.RemotePK())
+	assert.Equal(t, uint16(1), rule.RemotePort())
+	assert.Equal(t, uint16(2), rule.LocalPort())
+
+	require.NoError(t, sn.Close())
+	require.NoError(t, <-errChan)
+}
+
+func TestCloseLoop(t *testing.T) {
+	client := transport.NewDiscoveryMock()
+	logStore := transport.InMemoryTransportLogStore()
+
+	pk1, sk1 := cipher.GenerateKeyPair()
+	pk3, sk3 := cipher.GenerateKeyPair()
+	pkS, skS := cipher.GenerateKeyPair()
+
+	c1 := &transport.ManagerConfig{PubKey: pk1, SecKey: sk1, DiscoveryClient: client, LogStore: logStore}
+	c3 := &transport.ManagerConfig{PubKey: pk3, SecKey: sk3, DiscoveryClient: client, LogStore: logStore}
+	cS := &transport.ManagerConfig{PubKey: pkS, SecKey: skS, DiscoveryClient: client, LogStore: logStore}
+
+	fs1, fs2 := transport.NewMockFactoryPair(pk1, pkS)
+	fs1.SetType("messaging")
+	fs2.SetType("messaging")
+	fs5, fs6 := transport.NewMockFactoryPair(pk3, pkS)
+	fs5.SetType("messaging")
+
+	fS := newMuxFactory(pkS, "messaging", map[cipher.PubKey]transport.Factory{pk1: fs2, pk3: fs6})
+
+	m1, err := transport.NewManager(c1, fs1)
+	require.NoError(t, err)
+
+	m3, err := transport.NewManager(c3, fs5)
+	require.NoError(t, err)
+
+	mS, err := transport.NewManager(cS, fS)
+	require.NoError(t, err)
+
+	n3 := newMockNode(m3)
+	go n3.serve() // nolint: errcheck
+
+	time.Sleep(100 * time.Millisecond)
+
+	sn := &Node{logging.MustGetLogger("routesetup"), mS, nil, 0, metrics.NewDummy()}
+	errChan := make(chan error)
+	go func() {
+		errChan <- sn.Serve(context.TODO())
+	}()
+
+	n3.setRule(1, routing.AppRule(time.Now(), 2, pk1, 1, 2))
+	rules := n3.getRules()
+	require.Len(t, rules, 1)
+
+	tr, err := m1.CreateTransport(context.TODO(), pkS, "messaging", false)
"messaging", false) + require.NoError(t, err) + + proto := NewSetupProtocol(tr) + require.NoError(t, CloseLoop(proto, &LoopData{RemotePK: pk3, RemotePort: 2, LocalPort: 1})) + + rules = n3.getRules() + require.Len(t, rules, 0) + require.Nil(t, rules[1]) + + require.NoError(t, sn.Close()) + require.NoError(t, <-errChan) +} + +type muxFactory struct { + pk cipher.PubKey + fType string + factories map[cipher.PubKey]transport.Factory +} + +func newMuxFactory(pk cipher.PubKey, fType string, factories map[cipher.PubKey]transport.Factory) *muxFactory { + return &muxFactory{pk, fType, factories} +} + +func (f *muxFactory) Accept(ctx context.Context) (transport.Transport, error) { + trChan := make(chan transport.Transport) + defer close(trChan) + + errChan := make(chan error) + + for _, factory := range f.factories { + go func(ff transport.Factory) { + tr, err := ff.Accept(ctx) + if err != nil { + errChan <- err + } else { + trChan <- tr + } + }(factory) + } + + select { + case tr := <-trChan: + return tr, nil + case err := <-errChan: + return nil, err + } +} + +func (f *muxFactory) Dial(ctx context.Context, remote cipher.PubKey) (transport.Transport, error) { + return f.factories[remote].Dial(ctx, remote) +} + +func (f *muxFactory) Close() error { + var err error + for _, factory := range f.factories { + if fErr := factory.Close(); err == nil && fErr != nil { + err = fErr + } + } + + return err +} + +func (f *muxFactory) Local() cipher.PubKey { + return f.pk +} + +func (f *muxFactory) Type() string { + return f.fType +} + +type mockNode struct { + sync.Mutex + rules map[routing.RouteID]routing.Rule + tm *transport.Manager +} + +func newMockNode(tm *transport.Manager) *mockNode { + return &mockNode{tm: tm, rules: make(map[routing.RouteID]routing.Rule)} +} + +func (n *mockNode) serve() error { + acceptCh, dialCh := n.tm.Observe() + go func() { + for tr := range dialCh { + go func(t transport.Transport) { n.serveTransport(t) }(tr) // nolint: errcheck + } + }() + + go func() { + for tr := range acceptCh { + go func(t transport.Transport) { n.serveTransport(t) }(tr) // nolint: errcheck + } + }() + + return n.tm.Serve(context.Background()) +} + +func (n *mockNode) setRule(id routing.RouteID, rule routing.Rule) { + n.Lock() + n.rules[id] = rule + n.Unlock() +} + +func (n *mockNode) getRules() map[routing.RouteID]routing.Rule { + res := make(map[routing.RouteID]routing.Rule) + n.Lock() + for id, rule := range n.rules { + res[id] = rule + } + n.Unlock() + return res +} + +func (n *mockNode) serveTransport(tr transport.Transport) error { + proto := NewSetupProtocol(tr) + sp, data, err := proto.ReadPacket() + if err != nil { + return err + } + + n.Lock() + switch sp { + case PacketAddRules: + rules := []routing.Rule{} + json.Unmarshal(data, &rules) // nolint: errcheck + for _, rule := range rules { + for i := routing.RouteID(1); i < 255; i++ { + if n.rules[i] == nil { + n.rules[i] = rule + break + } + } + } + case PacketConfirmLoop: + ld := LoopData{} + json.Unmarshal(data, &ld) // nolint: errcheck + for _, rule := range n.rules { + if rule.Type() == routing.RuleApp && rule.RemotePK() == ld.RemotePK && + rule.RemotePort() == ld.RemotePort && rule.LocalPort() == ld.LocalPort { + + rule.SetRouteID(ld.RouteID) + break + } + } + case PacketLoopClosed: + ld := &LoopData{} + json.Unmarshal(data, ld) // nolint: errcheck + for routeID, rule := range n.rules { + if rule.Type() == routing.RuleApp && rule.RemotePK() == ld.RemotePK && + rule.RemotePort() == ld.RemotePort && rule.LocalPort() == ld.LocalPort { + + 
+				break
+			}
+		}
+	default:
+		err = errors.New("unknown foundation packet")
+	}
+	n.Unlock()
+
+	if err != nil {
+		return proto.WritePacket(RespFailure, err)
+	}
+
+	return proto.WritePacket(RespSuccess, nil)
+}
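+
+// Illustrative usage sketch (not executed by the test suite): a client node
+// drives the same protocol the tests above exercise. Names refer to helpers
+// and values defined in this file; error handling is elided.
+//
+//	tr, _ := m1.CreateTransport(context.TODO(), pkS, "messaging", false)
+//	proto := NewSetupProtocol(tr)
+//	_ = CreateLoop(proto, l) // ask the setup node to install loop rules
+//	_ = CloseLoop(proto, &LoopData{RemotePK: pk3, RemotePort: 2, LocalPort: 1})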