diff --git a/go.mod b/go.mod index 849ed6757b..9f86146687 100644 --- a/go.mod +++ b/go.mod @@ -28,4 +28,4 @@ require ( ) // Uncomment for tests with alternate branches of 'dmsg' -//replace github.com/skycoin/dmsg => ../dmsg +replace github.com/skycoin/dmsg => ../dmsg diff --git a/go.sum b/go.sum index 9a133695e3..cdbcb7386a 100644 --- a/go.sum +++ b/go.sum @@ -52,6 +52,7 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.5 h1:hyz3dwM5QLc1Rfoz4FuWJQG5BN7tc6K1MndAUnGpQr4= github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/magiconair/properties v1.8.0 h1:LLgXmsheXeRoUOBOjtwPQCWIYqM/LU1ayDtDePerRcY= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= diff --git a/pkg/app/protocol.go b/pkg/app/protocol.go index 6c80784d22..1f0a3d52fb 100644 --- a/pkg/app/protocol.go +++ b/pkg/app/protocol.go @@ -20,7 +20,7 @@ func (f Frame) String() string { case FrameCreateLoop: return "CreateLoop" case FrameConfirmLoop: - return "ConfirmLoop" + return "OnConfirmLoop" case FrameSend: return "Send" case FrameClose: @@ -35,7 +35,7 @@ const ( FrameInit Frame = iota // FrameCreateLoop represents CreateLoop request frame type. FrameCreateLoop - // FrameConfirmLoop represents ConfirmLoop request frame type. + // FrameConfirmLoop represents OnConfirmLoop request frame type. FrameConfirmLoop // FrameSend represents Send frame type. FrameSend diff --git a/pkg/network/network.go b/pkg/network/network.go new file mode 100644 index 0000000000..e8f70cb651 --- /dev/null +++ b/pkg/network/network.go @@ -0,0 +1,172 @@ +package network + +import ( + "context" + "errors" + "fmt" + "github.com/skycoin/skycoin/src/util/logging" + "net" + "strings" + "sync" + + "github.com/skycoin/dmsg" + "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/dmsg/disc" +) + +// Default ports. +// TODO(evanlinjin): Define these properly. These are currently random. +const ( + SetupPort = uint16(36) // Listening port of a setup node. + AwaitSetupPort = uint16(136) // Listening port of a visor node for setup operations. + TransportPort = uint16(45) // Listening port of a visor node for incoming transports. +) + +// Networks. 
+const ( + DmsgNet = "dmsg" +) + +var ( + ErrUnknownNetwork = errors.New("unknown network type") +) + +type Config struct { + PubKey cipher.PubKey + SecKey cipher.SecKey + TpNetworks []string // networks to be used with transports + + DmsgDiscAddr string + DmsgMinSrvs int +} + +// Network represents +type Network struct { + conf Config + dmsgC *dmsg.Client +} + +func New(conf Config) *Network { + dmsgC := dmsg.NewClient(conf.PubKey, conf.SecKey, disc.NewHTTP(conf.DmsgDiscAddr), dmsg.SetLogger(logging.MustGetLogger("network.dmsgC"))) + return &Network{ + conf: conf, + dmsgC: dmsgC, + } +} + +func (n *Network) Init(ctx context.Context) error { + fmt.Println("dmsg: min_servers:", n.conf.DmsgMinSrvs) + if err := n.dmsgC.InitiateServerConnections(ctx, n.conf.DmsgMinSrvs); err != nil { + return fmt.Errorf("failed to initiate 'dmsg': %v", err) + } + return nil +} + +func (n *Network) Close() error { + wg := new(sync.WaitGroup) + wg.Add(1) + + var dmsgErr error + go func() { + dmsgErr = n.dmsgC.Close() + wg.Done() + }() + + wg.Wait() + if dmsgErr != nil { + return dmsgErr + } + return nil +} + +func (n *Network) LocalPK() cipher.PubKey { return n.conf.PubKey } + +func (n *Network) LocalSK() cipher.SecKey { return n.conf.SecKey } + +func (n *Network) Dmsg() *dmsg.Client { return n.dmsgC } + +func (n *Network) Dial(network string, pk cipher.PubKey, port uint16) (*Conn, error) { + ctx := context.Background() + switch network { + case DmsgNet: + conn, err := n.dmsgC.Dial(ctx, pk, port) + if err != nil { + return nil, err + } + return makeConn(conn, network), nil + default: + return nil, ErrUnknownNetwork + } +} + +func (n *Network) Listen(network string, port uint16) (*Listener, error) { + switch network { + case DmsgNet: + lis, err := n.dmsgC.Listen(port) + if err != nil { + return nil, err + } + return makeListener(lis, network), nil + default: + return nil, ErrUnknownNetwork + } +} + +type Listener struct { + net.Listener + lPK cipher.PubKey + lPort uint16 + network string +} + +func makeListener(l net.Listener, network string) *Listener { + lPK, lPort := disassembleAddr(l.Addr()) + return &Listener{Listener: l, lPK: lPK, lPort: lPort, network: network} +} + +func (l Listener) LocalPK() cipher.PubKey { return l.lPK } +func (l Listener) LocalPort() uint16 { return l.lPort } +func (l Listener) Network() string { return l.network } + +func (l Listener) AcceptConn() (*Conn, error) { + conn, err := l.Listener.Accept() + if err != nil { + return nil, err + } + return makeConn(conn, l.network), nil +} + +type Conn struct { + net.Conn + lPK cipher.PubKey + rPK cipher.PubKey + lPort uint16 + rPort uint16 + network string +} + +func makeConn(conn net.Conn, network string) *Conn { + lPK, lPort := disassembleAddr(conn.LocalAddr()) + rPK, rPort := disassembleAddr(conn.RemoteAddr()) + return &Conn{Conn: conn, lPK: lPK, rPK: rPK, lPort: lPort, rPort: rPort, network: network} +} + +func (c Conn) LocalPK() cipher.PubKey { return c.lPK } +func (c Conn) RemotePK() cipher.PubKey { return c.rPK } +func (c Conn) LocalPort() uint16 { return c.lPort } +func (c Conn) RemotePort() uint16 { return c.rPort } +func (c Conn) Network() string { return c.network } + +func disassembleAddr(addr net.Addr) (pk cipher.PubKey, port uint16) { + strs := strings.Split(addr.String(), ":") + if len(strs) != 2 { + panic(fmt.Errorf("network.disassembleAddr: %v %s", "invalid addr", addr.String())) + } + if err := pk.Set(strs[0]); err != nil { + panic(fmt.Errorf("network.disassembleAddr: %v %s", err, addr.String())) + } + if _, err := 
fmt.Sscanf(strs[1], "%d", &port); err != nil { + panic(fmt.Errorf("network.disassembleAddr: %v", err)) + } + return +} diff --git a/pkg/network/network_test.go b/pkg/network/network_test.go new file mode 100644 index 0000000000..673b649fde --- /dev/null +++ b/pkg/network/network_test.go @@ -0,0 +1,21 @@ +package network + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/skycoin/dmsg" + "github.com/skycoin/dmsg/cipher" +) + +func TestDisassembleAddr(t *testing.T) { + pk, _ := cipher.GenerateKeyPair() + port := uint16(2) + addr := dmsg.Addr{ + PK: pk, Port: port, + } + gotPK, gotPort := disassembleAddr(addr) + require.Equal(t, pk, gotPK) + require.Equal(t, port, gotPort) +} diff --git a/pkg/network/testing.go b/pkg/network/testing.go new file mode 100644 index 0000000000..852405502c --- /dev/null +++ b/pkg/network/testing.go @@ -0,0 +1,87 @@ +package network + +//import ( +// "github.com/skycoin/dmsg/disc" +// "github.com/skycoin/skycoin/src/util/logging" +// "github.com/stretchr/testify/require" +// "golang.org/x/net/nettest" +// "testing" +//) +// +//// KeyPair holds a public/private key pair. +//type KeyPair struct { +// PK cipher.PubKey +// SK cipher.SecKey +//} +// +//// GenKeyPairs generates 'n' number of key pairs. +//func GenKeyPairs(n int) []KeyPair { +// pairs := make([]KeyPair, n) +// for i := range pairs { +// pk, sk, err := cipher.GenerateDeterministicKeyPair([]byte{byte(i)}) +// if err != nil { +// panic(err) +// } +// pairs[i] = KeyPair{PK: pk, SK: sk} +// } +// return pairs +//} +// +//// TestEnv contains a dmsg environment. +//type TestEnv struct { +// Disc disc.APIClient +// Srv *Server +// Clients []*Client +// teardown func() +//} +// +//// SetupTestEnv creates a dmsg TestEnv. +//func SetupTestEnv(t *testing.T, keyPairs []KeyPair) *TestEnv { +// discovery := disc.NewMock() +// +// srv, srvErr := createServer(t, discovery) +// +// clients := make([]*Client, len(keyPairs)) +// for i, pair := range keyPairs { +// t.Logf("dmsg_client[%d] PK: %s\n", i, pair.PK) +// c := NewClient(pair.PK, pair.SK, discovery, +// SetLogger(logging.MustGetLogger(fmt.Sprintf("client_%d:%s", i, pair.PK.String()[:6])))) +// require.NoError(t, c.InitiateServerConnections(context.TODO(), 1)) +// clients[i] = c +// } +// +// teardown := func() { +// for _, c := range clients { +// require.NoError(t, c.Close()) +// } +// require.NoError(t, srv.Close()) +// for err := range srvErr { +// require.NoError(t, err) +// } +// } +// +// return &TestEnv{ +// Disc: discovery, +// Srv: srv, +// Clients: clients, +// teardown: teardown, +// } +//} +// +//// TearDown shutdowns the TestEnv. 
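// Illustrative sketch, not from this changeset: the call sequence the new pkg/network
// API above is designed for — construct a Network, Init it against dmsg, then Listen
// and Dial by network type and port. The package/function names, discovery address and
// keys below are placeholders.
package example

import (
	"context"

	"github.com/skycoin/dmsg/cipher"

	"github.com/skycoin/skywire/pkg/network"
)

func useNetwork(pk cipher.PubKey, sk cipher.SecKey, remote cipher.PubKey) error {
	n := network.New(network.Config{
		PubKey:       pk,
		SecKey:       sk,
		TpNetworks:   []string{network.DmsgNet},
		DmsgDiscAddr: "http://dmsg.discovery.example.com", // placeholder address
		DmsgMinSrvs:  1,
	})
	if err := n.Init(context.Background()); err != nil {
		return err
	}
	defer func() { _ = n.Close() }() //nolint:errcheck

	// Accept side: listen for incoming transport connections.
	lis, err := n.Listen(network.DmsgNet, network.TransportPort)
	if err != nil {
		return err
	}
	defer func() { _ = lis.Close() }() //nolint:errcheck

	// Dial side: reach a remote visor on its transport port.
	conn, err := n.Dial(network.DmsgNet, remote, network.TransportPort)
	if err != nil {
		return err
	}
	return conn.Close()
}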
+//func (e *TestEnv) TearDown() { e.teardown() } +// +//func createServer(t *testing.T, dc disc.APIClient) (srv *dmsg.Server, srvErr <-chan error) { +// pk, sk, err := cipher.GenerateDeterministicKeyPair([]byte("s")) +// require.NoError(t, err) +// l, err := nettest.NewLocalListener("tcp") +// require.NoError(t, err) +// srv, err = dmsg.NewServer(pk, sk, "", l, dc) +// require.NoError(t, err) +// errCh := make(chan error, 1) +// go func() { +// errCh <- srv.Serve() +// close(errCh) +// }() +// return srv, errCh +//} \ No newline at end of file diff --git a/pkg/router/route_manager.go b/pkg/router/route_manager.go index a7d8a7fa34..3a96a11216 100644 --- a/pkg/router/route_manager.go +++ b/pkg/router/route_manager.go @@ -1,28 +1,153 @@ package router import ( + "context" "encoding/json" "errors" "fmt" - "io" "time" + "github.com/skycoin/dmsg/cipher" + + "github.com/skycoin/skywire/pkg/network" + "github.com/skycoin/skywire/pkg/setup" + "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/skywire/pkg/routing" - "github.com/skycoin/skywire/pkg/setup" ) -type setupCallbacks struct { - ConfirmLoop func(loop routing.Loop, rule routing.Rule) (err error) - LoopClosed func(loop routing.Loop) error +const ( + rtGarbageCollectDuration = time.Minute * 5 +) + +type setupConfig struct { + SetupPKs []cipher.PubKey // Trusted setup PKs. + OnConfirmLoop func(loop routing.Loop, rule routing.Rule) (err error) + OnLoopClosed func(loop routing.Loop) error +} + +func (sc setupConfig) SetupIsTrusted(sPK cipher.PubKey) bool { + for _, pk := range sc.SetupPKs { + if sPK == pk { + return true + } + } + return false } type routeManager struct { Logger *logging.Logger + conf setupConfig + n *network.Network + sl *network.Listener // Listens for setup node requests. + rt *managedRoutingTable + done chan struct{} +} + +func newRouteManager(n *network.Network, rt routing.Table, config setupConfig) (*routeManager, error) { + sl, err := n.Listen(network.DmsgNet, network.AwaitSetupPort) + if err != nil { + return nil, err + } + return &routeManager{ + Logger: logging.MustGetLogger("route_manager"), + conf: config, + n: n, + sl: sl, + rt: manageRoutingTable(rt), + done: make(chan struct{}), + }, nil +} + +func (rm *routeManager) Close() error { + close(rm.done) + return rm.sl.Close() +} + +func (rm *routeManager) Serve() { + + // Routing table garbage collect loop. + go rm.rtGarbageCollectLoop() + + // Accept setup node request loop. 
+ for { + conn, err := rm.sl.AcceptConn() + if err != nil { + rm.Logger.WithError(err).Warnf("stopped serving") + return + } + if !rm.conf.SetupIsTrusted(conn.RemotePK()) { + rm.Logger.Warnf("closing conn from untrusted setup node: %v", conn.Close()) + continue + } + go func(conn *network.Conn) { + rm.Logger.Infof("handling setup request: setupPK(%s)", conn.RemotePK()) + defer func() { _ = conn.Close() }() //nolint:errcheck + + if err := rm.handleSetupConn(conn); err != nil { + rm.Logger.WithError(err).Warnf("setup request failed: setupPK(%s)", conn.RemotePK()) + } + rm.Logger.Infof("successfully handled setup request: setupPK(%s)", conn.RemotePK()) + }(conn) + } +} + +func (rm *routeManager) rtGarbageCollectLoop() { + ticker := time.NewTicker(rtGarbageCollectDuration) + defer ticker.Stop() + for { + select { + case <-rm.done: + return + case <-ticker.C: + if err := rm.rt.Cleanup(); err != nil { + rm.Logger.WithError(err).Warnf("routing table cleanup returned error") + } + } + } +} + +func (rm *routeManager) handleSetupConn(conn *network.Conn) error { + proto := setup.NewSetupProtocol(conn) + t, body, err := proto.ReadPacket() + + if err != nil { + return err + } + rm.Logger.Infof("Got new Setup request with type %s", t) + + var respBody interface{} + switch t { + case setup.PacketAddRules: + respBody, err = rm.addRoutingRules(body) + case setup.PacketDeleteRules: + respBody, err = rm.deleteRoutingRules(body) + case setup.PacketConfirmLoop: + err = rm.confirmLoop(body) + case setup.PacketLoopClosed: + err = rm.loopClosed(body) + default: + err = errors.New("unknown foundation packet") + } - rt *managedRoutingTable - callbacks *setupCallbacks + if err != nil { + rm.Logger.Infof("Setup request with type %s failed: %s", t, err) + return proto.WritePacket(setup.RespFailure, err.Error()) + } + return proto.WritePacket(setup.RespSuccess, respBody) +} + +func (rm *routeManager) dialSetupConn(ctx context.Context) (*network.Conn, error) { + for _, sPK := range rm.conf.SetupPKs { + conn, err := rm.n.Dial(network.DmsgNet, sPK, network.SetupPort) + if err != nil { + rm.Logger.WithError(err).Warnf("failed to dial to setup node: setupPK(%s)", sPK) + continue + } + return conn, nil + } + return nil, errors.New("failed to dial to a setup node") } func (rm *routeManager) GetRule(routeID routing.RouteID) (routing.Rule, error) { @@ -78,38 +203,6 @@ func (rm *routeManager) RemoveLoopRule(loop routing.Loop) error { return nil } -func (rm *routeManager) Serve(rw io.ReadWriter) error { - proto := setup.NewSetupProtocol(rw) - t, body, err := proto.ReadPacket() - - if err != nil { - return err - } - rm.Logger.Infof("Got new Setup request with type %s", t) - - var respBody interface{} - switch t { - case setup.PacketAddRules: - respBody, err = rm.addRoutingRules(body) - case setup.PacketDeleteRules: - respBody, err = rm.deleteRoutingRules(body) - case setup.PacketConfirmLoop: - err = rm.confirmLoop(body) - case setup.PacketLoopClosed: - err = rm.loopClosed(body) - default: - err = errors.New("unknown foundation packet") - } - - if err != nil { - rm.Logger.Infof("Setup request with type %s failed: %s", t, err) - return proto.WritePacket(setup.RespFailure, err.Error()) - } - - return proto.WritePacket(setup.RespSuccess, respBody) - -} - func (rm *routeManager) addRoutingRules(data []byte) ([]routing.RouteID, error) { var rules []routing.Rule if err := json.Unmarshal(data, &rules); err != nil { @@ -181,7 +274,7 @@ func (rm *routeManager) confirmLoop(data []byte) error { return errors.New("reverse rule is not forward") 
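// Illustrative sketch, not from this changeset: the dialing side of the exchange that
// routeManager.handleSetupConn above serves. A setup node dials the visor's
// AwaitSetupPort and writes a PacketAddRules request over the setup protocol; the
// function and package names here are placeholders.
package example

import (
	"context"

	"github.com/skycoin/dmsg"
	"github.com/skycoin/dmsg/cipher"

	"github.com/skycoin/skywire/pkg/network"
	"github.com/skycoin/skywire/pkg/routing"
	"github.com/skycoin/skywire/pkg/setup"
)

func pushRules(ctx context.Context, dmsgC *dmsg.Client, visor cipher.PubKey, rules []routing.Rule) error {
	conn, err := dmsgC.Dial(ctx, visor, network.AwaitSetupPort)
	if err != nil {
		return err
	}
	defer func() { _ = conn.Close() }() //nolint:errcheck

	proto := setup.NewSetupProtocol(conn)
	if err := proto.WritePacket(setup.PacketAddRules, rules); err != nil {
		return err
	}
	// The visor's handleSetupConn replies with RespSuccess (assigned route IDs)
	// or RespFailure (an error string).
	_, _, err = proto.ReadPacket()
	return err
}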
} - if err = rm.callbacks.ConfirmLoop(ld.Loop, rule); err != nil { + if err = rm.conf.OnConfirmLoop(ld.Loop, rule); err != nil { return fmt.Errorf("confirm: %s", err) } @@ -201,5 +294,5 @@ func (rm *routeManager) loopClosed(data []byte) error { return err } - return rm.callbacks.LoopClosed(ld.Loop) + return rm.conf.OnLoopClosed(ld.Loop) } diff --git a/pkg/router/route_manager_test.go b/pkg/router/route_manager_test.go index 80f2372aee..c2450d6649 100644 --- a/pkg/router/route_manager_test.go +++ b/pkg/router/route_manager_test.go @@ -76,7 +76,7 @@ func TestRouteManagerAddRemoveRule(t *testing.T) { in, out := net.Pipe() errCh := make(chan error) go func() { - errCh <- rm.Serve(out) + errCh <- rm.handleSetupConn(out) }() proto := setup.NewSetupProtocol(in) @@ -102,7 +102,7 @@ func TestRouteManagerDeleteRules(t *testing.T) { in, out := net.Pipe() errCh := make(chan error) go func() { - errCh <- rm.Serve(out) + errCh <- rm.handleSetupConn(out) }() proto := setup.NewSetupProtocol(in) @@ -123,8 +123,8 @@ func TestRouteManagerConfirmLoop(t *testing.T) { rt := manageRoutingTable(routing.InMemoryRoutingTable()) var inLoop routing.Loop var inRule routing.Rule - callbacks := &setupCallbacks{ - ConfirmLoop: func(loop routing.Loop, rule routing.Rule) (err error) { + callbacks := &setupConfig{ + OnConfirmLoop: func(loop routing.Loop, rule routing.Rule) (err error) { inLoop = loop inRule = rule return nil @@ -135,7 +135,7 @@ func TestRouteManagerConfirmLoop(t *testing.T) { in, out := net.Pipe() errCh := make(chan error) go func() { - errCh <- rm.Serve(out) + errCh <- rm.handleSetupConn(out) }() proto := setup.NewSetupProtocol(in) @@ -172,8 +172,8 @@ func TestRouteManagerConfirmLoop(t *testing.T) { func TestRouteManagerLoopClosed(t *testing.T) { rt := manageRoutingTable(routing.InMemoryRoutingTable()) var inLoop routing.Loop - callbacks := &setupCallbacks{ - LoopClosed: func(loop routing.Loop) error { + callbacks := &setupConfig{ + OnLoopClosed: func(loop routing.Loop) error { inLoop = loop return nil }, @@ -183,7 +183,7 @@ func TestRouteManagerLoopClosed(t *testing.T) { in, out := net.Pipe() errCh := make(chan error) go func() { - errCh <- rm.Serve(out) + errCh <- rm.handleSetupConn(out) }() proto := setup.NewSetupProtocol(in) diff --git a/pkg/router/router.go b/pkg/router/router.go index 248edd6f5b..3779f94c2c 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -10,7 +10,8 @@ import ( "sync" "time" - "github.com/skycoin/dmsg" + "github.com/skycoin/skywire/pkg/network" + "github.com/skycoin/dmsg/cipher" "github.com/skycoin/skycoin/src/util/logging" @@ -47,34 +48,41 @@ type Config struct { type Router struct { Logger *logging.Logger - config *Config - tm *transport.Manager - pm *portManager - rm *routeManager + conf *Config + staticPorts map[routing.Port]struct{} - expiryTicker *time.Ticker - wg sync.WaitGroup + n *network.Network + tm *transport.Manager + pm *portManager + rm *routeManager - staticPorts map[routing.Port]struct{} - mu sync.Mutex + wg sync.WaitGroup + mx sync.Mutex } // New constructs a new Router. 
-func New(config *Config) *Router { +func New(n *network.Network, config *Config) (*Router, error) { r := &Router{ - Logger: config.Logger, - tm: config.TransportManager, - pm: newPortManager(10), - config: config, - expiryTicker: time.NewTicker(10 * time.Minute), - staticPorts: make(map[routing.Port]struct{}), - } - callbacks := &setupCallbacks{ - ConfirmLoop: r.confirmLoop, - LoopClosed: r.loopClosed, - } - r.rm = &routeManager{r.Logger, manageRoutingTable(config.RoutingTable), callbacks} - return r + Logger: config.Logger, + n: n, + tm: config.TransportManager, + pm: newPortManager(10), + conf: config, + staticPorts: make(map[routing.Port]struct{}), + } + + // Prepare route manager. + rm, err := newRouteManager(n, config.RoutingTable, setupConfig{ + SetupPKs: config.SetupNodes, + OnConfirmLoop: r.confirmLoop, + OnLoopClosed: r.loopClosed, + }) + if err != nil { + return nil, err + } + r.rm = rm + + return r, nil } // Serve starts transport listening loop. @@ -97,44 +105,15 @@ func (r *Router) Serve(ctx context.Context) error { } }() + r.wg.Add(1) go func() { - for { - conn, err := r.tm.AcceptSetupConn() - if err != nil { - return - } - r.Logger.Infof("New remotely-initiated transport: purpose(setup)") - go r.serveSetup(conn) - } - }() - - go func() { - for range r.expiryTicker.C { - if err := r.rm.rt.Cleanup(); err != nil { - r.Logger.Warnf("Failed to expiry routes: %s", err) - } - } + r.rm.Serve() + r.wg.Done() }() return r.tm.Serve(ctx) } -func (r *Router) serveSetup(conn io.ReadWriteCloser) { - defer func() { - if err := conn.Close(); err != nil { - r.Logger.Warnf("Failed to close transport: %s", err) - } - }() - for { - if err := r.rm.Serve(conn); err != nil { - if err != io.EOF { - r.Logger.Warnf("Stopped serving Transport: %s", err) - } - return - } - } -} - func (r *Router) handlePacket(ctx context.Context, packet routing.Packet) error { rule, err := r.rm.GetRule(packet.RouteID()) if err != nil { @@ -157,9 +136,9 @@ func (r *Router) ServeApp(conn net.Conn, port routing.Port, appConf *app.Config) return err } - r.mu.Lock() + r.mx.Lock() r.staticPorts[port] = struct{}{} - r.mu.Unlock() + r.mx.Unlock() callbacks := &appCallbacks{ CreateLoop: r.requestLoop, @@ -177,9 +156,9 @@ func (r *Router) ServeApp(conn net.Conn, port routing.Port, appConf *app.Config) } } - r.mu.Lock() + r.mx.Lock() delete(r.staticPorts, port) - r.mu.Unlock() + r.mx.Unlock() if err == io.EOF { return nil @@ -192,17 +171,19 @@ func (r *Router) Close() error { if r == nil { return nil } - r.Logger.Info("Closing all App connections and Loops") - r.expiryTicker.Stop() for _, conn := range r.pm.AppConns() { if err := conn.Close(); err != nil { - log.WithError(err).Warn("Failed to close connection") + r.Logger.WithError(err).Warn("Failed to close connection") } } + if err := r.rm.Close(); err != nil { + r.Logger.WithError(err).Warnf("closing route_manager returned error") + } r.wg.Wait() + return r.tm.Close() } @@ -236,7 +217,7 @@ func (r *Router) consumePacket(payload []byte, rule routing.Rule) error { } func (r *Router) forwardAppPacket(ctx context.Context, appConn *app.Protocol, packet *app.Packet) error { - if packet.Loop.Remote.PubKey == r.config.PubKey { + if packet.Loop.Remote.PubKey == r.conf.PubKey { return r.forwardLocalAppPacket(packet) } @@ -276,8 +257,8 @@ func (r *Router) requestLoop(ctx context.Context, appConn *app.Protocol, raddr r return routing.Addr{}, err } - laddr := routing.Addr{PubKey: r.config.PubKey, Port: lport} - if raddr.PubKey == r.config.PubKey { + laddr := routing.Addr{PubKey: 
r.conf.PubKey, Port: lport} + if raddr.PubKey == r.conf.PubKey { if err := r.confirmLocalLoop(laddr, raddr); err != nil { return routing.Addr{}, fmt.Errorf("confirm: %s", err) } @@ -300,17 +281,16 @@ func (r *Router) requestLoop(ctx context.Context, appConn *app.Protocol, raddr r Reverse: reverseRoute, } - proto, tr, err := r.setupProto(ctx) + sConn, err := r.rm.dialSetupConn(ctx) if err != nil { return routing.Addr{}, err } defer func() { - if err := tr.Close(); err != nil { + if err := sConn.Close(); err != nil { r.Logger.Warnf("Failed to close transport: %s", err) } }() - - if err := setup.CreateLoop(ctx, proto, ld); err != nil { + if err := setup.CreateLoop(ctx, setup.NewSetupProtocol(sConn), ld); err != nil { return routing.Addr{}, fmt.Errorf("route setup: %s", err) } @@ -342,7 +322,7 @@ func (r *Router) confirmLoop(l routing.Loop, rule routing.Rule) error { return err } - addrs := [2]routing.Addr{{PubKey: r.config.PubKey, Port: l.Local.Port}, l.Remote} + addrs := [2]routing.Addr{{PubKey: r.conf.PubKey, Port: l.Local.Port}, l.Remote} if err = b.conn.Send(app.FrameConfirmLoop, addrs, nil); err != nil { r.Logger.Warnf("Failed to notify App about new loop: %s", err) } @@ -355,22 +335,18 @@ func (r *Router) closeLoop(ctx context.Context, appConn *app.Protocol, loop rout r.Logger.Warnf("Failed to remove loop: %s", err) } - proto, tr, err := r.setupProto(ctx) + sConn, err := r.rm.dialSetupConn(ctx) if err != nil { return err } - defer func() { - if err := tr.Close(); err != nil { + if err := sConn.Close(); err != nil { r.Logger.Warnf("Failed to close transport: %s", err) } }() - - ld := routing.LoopData{Loop: loop} - if err := setup.CloseLoop(ctx, proto, ld); err != nil { + if err := setup.CloseLoop(ctx, setup.NewSetupProtocol(sConn), routing.LoopData{Loop: loop}); err != nil { return fmt.Errorf("route setup: %s", err) } - r.Logger.Infof("Closed loop %s", loop) return nil } @@ -394,9 +370,9 @@ func (r *Router) loopClosed(loop routing.Loop) error { } func (r *Router) destroyLoop(loop routing.Loop) error { - r.mu.Lock() + r.mx.Lock() _, ok := r.staticPorts[loop.Local.Port] - r.mu.Unlock() + r.mx.Unlock() if ok { if err := r.pm.RemoveLoop(loop.Local.Port, loop.Remote); err != nil { @@ -409,20 +385,6 @@ func (r *Router) destroyLoop(loop routing.Loop) error { return r.rm.RemoveLoopRule(loop) } -func (r *Router) setupProto(ctx context.Context) (*setup.Protocol, transport.Transport, error) { - if len(r.config.SetupNodes) == 0 { - return nil, nil, errors.New("route setup: no nodes") - } - - tr, err := r.tm.DialSetupConn(ctx, r.config.SetupNodes[0], dmsg.Type) - if err != nil { - return nil, nil, fmt.Errorf("setup transport: %s", err) - } - - sProto := setup.NewSetupProtocol(tr) - return sProto, tr, nil -} - func (r *Router) fetchBestRoutes(source, destination cipher.PubKey) (fwd routing.Route, rev routing.Route, err error) { r.Logger.Infof("Requesting new routes from %s to %s", source, destination) @@ -430,7 +392,7 @@ func (r *Router) fetchBestRoutes(source, destination cipher.PubKey) (fwd routing defer timer.Stop() fetchRoutesAgain: - fwdRoutes, revRoutes, err := r.config.RouteFinder.PairedRoutes(source, destination, minHops, maxHops) + fwdRoutes, revRoutes, err := r.conf.RouteFinder.PairedRoutes(source, destination, minHops, maxHops) if err != nil { select { case <-timer.C: @@ -444,12 +406,6 @@ fetchRoutesAgain: return fwdRoutes[0], revRoutes[0], nil } -// IsSetupTransport checks whether `tr` is running in the `setup` mode. 
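// Illustrative sketch, not from this changeset: the Router's new constructor shape.
// New now takes the shared *network.Network and returns an error because it opens the
// route manager's setup listener up front; SetupIsTrusted(pk) replaces the old
// IsSetupTransport check. Names below are placeholders.
package example

import (
	"context"

	"github.com/skycoin/skywire/pkg/network"
	"github.com/skycoin/skywire/pkg/router"
)

func runRouter(ctx context.Context, n *network.Network, conf *router.Config) error {
	r, err := router.New(n, conf)
	if err != nil {
		return err
	}
	defer func() { _ = r.Close() }() //nolint:errcheck
	// Serve starts the route manager's setup-connection loop and then blocks on the
	// transport manager's Serve.
	return r.Serve(ctx)
}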
-func (r *Router) IsSetupTransport(mTp *transport.ManagedTransport) bool { - for _, pk := range r.config.SetupNodes { - if mTp.Remote() == pk { - return true - } - } - return false +func (r *Router) SetupIsTrusted(sPK cipher.PubKey) bool { + return r.rm.conf.SetupIsTrusted(sPK) } diff --git a/pkg/setup/node.go b/pkg/setup/node.go index b5d16164cf..67d7eb9905 100644 --- a/pkg/setup/node.go +++ b/pkg/setup/node.go @@ -5,6 +5,8 @@ import ( "encoding/json" "errors" "fmt" + "github.com/skycoin/dmsg" + "github.com/skycoin/skywire/pkg/network" "time" "github.com/skycoin/dmsg/cipher" @@ -13,8 +15,6 @@ import ( "github.com/skycoin/skywire/pkg/metrics" "github.com/skycoin/skywire/pkg/routing" - "github.com/skycoin/skywire/pkg/transport" - "github.com/skycoin/skywire/pkg/transport/dmsg" ) // Hop is a wrapper around transport hop to add functionality @@ -25,56 +25,74 @@ type Hop struct { // Node performs routes setup operations over messaging channel. type Node struct { - Logger *logging.Logger - messenger *dmsg.Client - srvCount int - metrics metrics.Recorder + Logger *logging.Logger + dmsgC *dmsg.Client + dmsgL *dmsg.Listener + srvCount int + metrics metrics.Recorder } // NewNode constructs a new SetupNode. func NewNode(conf *Config, metrics metrics.Recorder) (*Node, error) { - pk := conf.PubKey - sk := conf.SecKey + ctx := context.Background() logger := logging.NewMasterLogger() if lvl, err := logging.LevelFromString(conf.LogLevel); err == nil { logger.SetLevel(lvl) } - messenger := dmsg.NewClient(pk, sk, disc.NewHTTP(conf.Messaging.Discovery), dmsg.SetLogger(logger.PackageLogger(dmsg.Type))) + log := logger.PackageLogger("setup_node") + + // Prepare dmsg. + dmsgC := dmsg.NewClient( + conf.PubKey, + conf.SecKey, + disc.NewHTTP(conf.Messaging.Discovery), + dmsg.SetLogger(logger.PackageLogger(dmsg.Type)), + ) + if err := dmsgC.InitiateServerConnections(ctx, conf.Messaging.ServerCount); err != nil { + return nil, fmt.Errorf("failed to init dmsg: %s", err) + } + log.Info("connected to dmsg servers") + + dmsgL, err := dmsgC.Listen(network.SetupPort) + if err != nil { + return nil, fmt.Errorf("failed to listen on dmsg port %d: %v", network.SetupPort, dmsgL) + } + log.Info("started listening for dmsg connections") return &Node{ - Logger: logger.PackageLogger("routesetup"), - metrics: metrics, - messenger: messenger, - srvCount: conf.Messaging.ServerCount, + Logger: log, + dmsgC: dmsgC, + dmsgL: dmsgL, + srvCount: conf.Messaging.ServerCount, + metrics: metrics, }, nil } // Serve starts transport listening loop. 
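// Illustrative note, not from this changeset: the dmsg port convention the setup node
// and visor code above rely on, using the placeholder values defined in pkg/network:
//
//   setup node: dmsgC.Listen(network.SetupPort)                   // 36, dialed by visors via routeManager.dialSetupConn
//   setup node: dmsgC.Dial(ctx, visorPK, network.AwaitSetupPort)  // 136, to reach a visor's route manager
//   visor node: n.Listen(network.DmsgNet, network.AwaitSetupPort) // route manager's setup listener
//   visor node: n.Listen(network.DmsgNet, network.TransportPort)  // 45, transport manager's listener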
func (sn *Node) Serve(ctx context.Context) error { - if sn.srvCount > 0 { - if err := sn.messenger.InitiateServerConnections(ctx, sn.srvCount); err != nil { - return fmt.Errorf("messaging: %s", err) - } - sn.Logger.Info("Connected to messaging servers") + if err := sn.dmsgC.InitiateServerConnections(ctx, sn.srvCount); err != nil { + return fmt.Errorf("messaging: %s", err) } + sn.Logger.Info("Connected to messaging servers") + sn.Logger.Info("Starting Setup Node") for { - tp, err := sn.messenger.Accept(ctx) + conn, err := sn.dmsgL.AcceptTransport() if err != nil { return err } - go func(tp transport.Transport) { - if err := sn.serveTransport(ctx, tp); err != nil { + go func(conn *dmsg.Transport) { + if err := sn.serveTransport(ctx, conn); err != nil { sn.Logger.Warnf("Failed to serve Transport: %s", err) } - }(tp) + }(conn) } } -func (sn *Node) serveTransport(ctx context.Context, tr transport.Transport) error { +func (sn *Node) serveTransport(ctx context.Context, tr *dmsg.Transport) error { ctx, cancel := context.WithTimeout(ctx, ServeTransportTimeout) defer cancel() @@ -217,7 +235,7 @@ func (sn *Node) createRoute(ctx context.Context, expireAt time.Time, route routi } func (sn *Node) connectLoop(ctx context.Context, on cipher.PubKey, ld routing.LoopData) error { - tr, err := sn.messenger.Dial(ctx, on) + tr, err := sn.dmsgC.Dial(ctx, on, network.AwaitSetupPort) if err != nil { return fmt.Errorf("transport: %s", err) } @@ -240,14 +258,14 @@ func (sn *Node) Close() error { if sn == nil { return nil } - return sn.messenger.Close() + return sn.dmsgC.Close() } func (sn *Node) closeLoop(ctx context.Context, on cipher.PubKey, ld routing.LoopData) error { fmt.Printf(">>> BEGIN: closeLoop(%s, ld)\n", on) defer fmt.Printf(">>> END: closeLoop(%s, ld)\n", on) - tr, err := sn.messenger.Dial(ctx, on) + tr, err := sn.dmsgC.Dial(ctx, on, network.AwaitSetupPort) fmt.Println(">>> *****: closeLoop() dialed:", err) if err != nil { return fmt.Errorf("transport: %s", err) @@ -269,7 +287,7 @@ func (sn *Node) closeLoop(ctx context.Context, on cipher.PubKey, ld routing.Loop func (sn *Node) setupRule(ctx context.Context, pubKey cipher.PubKey, rule routing.Rule) (routeID routing.RouteID, err error) { sn.Logger.Debugf("dialing to %s to setup rule: %v\n", pubKey, rule) - tr, err := sn.messenger.Dial(ctx, pubKey) + tr, err := sn.dmsgC.Dial(ctx, pubKey, network.AwaitSetupPort) if err != nil { err = fmt.Errorf("transport: %s", err) return diff --git a/pkg/setup/node_test.go b/pkg/setup/node_test.go index a46ebad161..aaa534ff5f 100644 --- a/pkg/setup/node_test.go +++ b/pkg/setup/node_test.go @@ -71,9 +71,9 @@ func TestNode(t *testing.T) { // CLOSURE: sets up setup node. prepSetupNode := func(c *dmsg.Client) (*Node, func()) { sn := &Node{ - Logger: logging.MustGetLogger("setup_node"), - messenger: c, - metrics: metrics.NewDummy(), + Logger: logging.MustGetLogger("setup_node"), + dmsgC: c, + metrics: metrics.NewDummy(), } go func() { _ = sn.Serve(context.TODO()) }() //nolint:errcheck return sn, func() { @@ -92,7 +92,7 @@ func TestNode(t *testing.T) { // prepare and serve setup node (using client 0). sn, closeSetup := prepSetupNode(clients[0]) - setupPK := sn.messenger.Local() + setupPK := sn.dmsgC.Local() defer closeSetup() // prepare loop creation (client_1 will use this to request loop creation with setup node). @@ -156,7 +156,7 @@ func TestNode(t *testing.T) { _ = proto.WritePacket(RespSuccess, rIDs) //nolint:errcheck } - // CLOSURE: emulates how a visor node should react when expecting an ConfirmLoop packet. 
+ // CLOSURE: emulates how a visor node should react when expecting an OnConfirmLoop packet. expectConfirmLoop := func(client int) { tp, err := clients[client].Accept(context.TODO()) require.NoError(t, err) @@ -178,7 +178,7 @@ func TestNode(t *testing.T) { require.Equal(t, ld.Loop.Local, d.Loop.Remote) require.Equal(t, ld.Loop.Remote, d.Loop.Local) default: - t.Fatalf("We shouldn't be receiving a ConfirmLoop packet from client %d", client) + t.Fatalf("We shouldn't be receiving a OnConfirmLoop packet from client %d", client) } // TODO: This error is not checked due to a bug in dmsg. @@ -209,7 +209,7 @@ func TestNode(t *testing.T) { // prepare and serve setup node. sn, closeSetup := prepSetupNode(clients[0]) - setupPK := sn.messenger.Local() + setupPK := sn.dmsgC.Local() defer closeSetup() // prepare loop data describing the loop that is to be closed. diff --git a/pkg/setup/protocol.go b/pkg/setup/protocol.go index 6532a7e636..4cdb0a6823 100644 --- a/pkg/setup/protocol.go +++ b/pkg/setup/protocol.go @@ -24,11 +24,11 @@ func (sp PacketType) String() string { case PacketCreateLoop: return "CreateLoop" case PacketConfirmLoop: - return "ConfirmLoop" + return "OnConfirmLoop" case PacketCloseLoop: return "CloseLoop" case PacketLoopClosed: - return "LoopClosed" + return "OnLoopClosed" case RespSuccess: return "Success" case RespFailure: @@ -44,11 +44,11 @@ const ( PacketDeleteRules // PacketCreateLoop represents CreateLoop foundation packet. PacketCreateLoop - // PacketConfirmLoop represents ConfirmLoop foundation packet. + // PacketConfirmLoop represents OnConfirmLoop foundation packet. PacketConfirmLoop // PacketCloseLoop represents CloseLoop foundation packet. PacketCloseLoop - // PacketLoopClosed represents LoopClosed foundation packet. + // PacketLoopClosed represents OnLoopClosed foundation packet. PacketLoopClosed // RespFailure represents failure response for a foundation packet. @@ -138,7 +138,7 @@ func CreateLoop(ctx context.Context, p *Protocol, ld routing.LoopDescriptor) err return readAndDecodePacketWithTimeout(ctx, p, nil) // TODO: data race. } -// ConfirmLoop sends ConfirmLoop setup request. +// OnConfirmLoop sends OnConfirmLoop setup request. func ConfirmLoop(ctx context.Context, p *Protocol, ld routing.LoopData) error { if err := p.WritePacket(PacketConfirmLoop, ld); err != nil { return err @@ -154,7 +154,7 @@ func CloseLoop(ctx context.Context, p *Protocol, ld routing.LoopData) error { return readAndDecodePacketWithTimeout(ctx, p, nil) } -// LoopClosed sends LoopClosed setup request. +// OnLoopClosed sends OnLoopClosed setup request. func LoopClosed(ctx context.Context, p *Protocol, ld routing.LoopData) error { if err := p.WritePacket(PacketLoopClosed, ld); err != nil { return err diff --git a/pkg/transport/dmsg/dmsg.go b/pkg/transport/dmsg/dmsg.go deleted file mode 100644 index d3d95466f1..0000000000 --- a/pkg/transport/dmsg/dmsg.go +++ /dev/null @@ -1,66 +0,0 @@ -package dmsg - -import ( - "context" - "net" - "time" - - "github.com/skycoin/dmsg" - "github.com/skycoin/dmsg/cipher" - "github.com/skycoin/dmsg/disc" - "github.com/skycoin/skycoin/src/util/logging" - - "github.com/skycoin/skywire/pkg/transport" -) - -const ( - // Type is a wrapper type for "github.com/skycoin/dmsg".Type - Type = dmsg.Type -) - -// Config configures dmsg -type Config struct { - PubKey cipher.PubKey - SecKey cipher.SecKey - Discovery disc.APIClient - Retries int - RetryDelay time.Duration -} - -// Server is an alias for dmsg.Server. 
-type Server = dmsg.Server - -// NewServer is an alias for dmsg.NewServer. -func NewServer(pk cipher.PubKey, sk cipher.SecKey, addr string, l net.Listener, dc disc.APIClient) (*Server, error) { - return dmsg.NewServer(pk, sk, addr, l, dc) -} - -// ClientOption is a wrapper type for "github.com/skycoin/dmsg".ClientOption -type ClientOption = dmsg.ClientOption - -// Client is a wrapper type for "github.com/skycoin/dmsg".Client -type Client struct { - *dmsg.Client -} - -// NewClient is a wrapper type for "github.com/skycoin/dmsg".NewClient -func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc disc.APIClient, opts ...ClientOption) *Client { - return &Client{ - Client: dmsg.NewClient(pk, sk, dc, opts...), - } -} - -// Accept is a wrapper type for "github.com/skycoin/dmsg".Accept -func (c *Client) Accept(ctx context.Context) (transport.Transport, error) { - return c.Client.Accept(ctx) -} - -// Dial is a wrapper type for "github.com/skycoin/dmsg".Dial -func (c *Client) Dial(ctx context.Context, remote cipher.PubKey) (transport.Transport, error) { - return c.Client.Dial(ctx, remote) -} - -// SetLogger is a wrapper type for "github.com/skycoin/dmsg".SetLogger -func SetLogger(log *logging.Logger) ClientOption { - return dmsg.SetLogger(log) -} diff --git a/pkg/transport/dmsg/testing.go b/pkg/transport/dmsg/testing.go deleted file mode 100644 index a681035afb..0000000000 --- a/pkg/transport/dmsg/testing.go +++ /dev/null @@ -1,92 +0,0 @@ -package dmsg - -import ( - "context" - "fmt" - "testing" - - "github.com/skycoin/dmsg" - "github.com/skycoin/dmsg/cipher" - "github.com/skycoin/dmsg/disc" - "github.com/skycoin/skycoin/src/util/logging" - "github.com/stretchr/testify/require" - "golang.org/x/net/nettest" -) - -// KeyPair holds a public/private key pair. -type KeyPair struct { - PK cipher.PubKey - SK cipher.SecKey -} - -// GenKeyPairs generates 'n' number of key pairs. -func GenKeyPairs(n int) []KeyPair { - pairs := make([]KeyPair, n) - for i := range pairs { - pk, sk, err := cipher.GenerateDeterministicKeyPair([]byte{byte(i)}) - if err != nil { - panic(err) - } - pairs[i] = KeyPair{PK: pk, SK: sk} - } - return pairs -} - -// TestEnv contains a dmsg environment. -type TestEnv struct { - Disc disc.APIClient - Srv *Server - Clients []*Client - teardown func() -} - -// SetupTestEnv creates a dmsg TestEnv. -func SetupTestEnv(t *testing.T, keyPairs []KeyPair) *TestEnv { - discovery := disc.NewMock() - - srv, srvErr := createServer(t, discovery) - - clients := make([]*Client, len(keyPairs)) - for i, pair := range keyPairs { - t.Logf("dmsg_client[%d] PK: %s\n", i, pair.PK) - c := NewClient(pair.PK, pair.SK, discovery, - SetLogger(logging.MustGetLogger(fmt.Sprintf("client_%d:%s", i, pair.PK.String()[:6])))) - require.NoError(t, c.InitiateServerConnections(context.TODO(), 1)) - clients[i] = c - } - - teardown := func() { - for _, c := range clients { - require.NoError(t, c.Close()) - } - require.NoError(t, srv.Close()) - for err := range srvErr { - require.NoError(t, err) - } - } - - return &TestEnv{ - Disc: discovery, - Srv: srv, - Clients: clients, - teardown: teardown, - } -} - -// TearDown shutdowns the TestEnv. 
-func (e *TestEnv) TearDown() { e.teardown() } - -func createServer(t *testing.T, dc disc.APIClient) (srv *dmsg.Server, srvErr <-chan error) { - pk, sk, err := cipher.GenerateDeterministicKeyPair([]byte("s")) - require.NoError(t, err) - l, err := nettest.NewLocalListener("tcp") - require.NoError(t, err) - srv, err = dmsg.NewServer(pk, sk, "", l, dc) - require.NoError(t, err) - errCh := make(chan error, 1) - go func() { - errCh <- srv.Serve() - close(errCh) - }() - return srv, errCh -} diff --git a/pkg/transport/handshake.go b/pkg/transport/handshake.go index aa590bb5be..6efcad5eae 100644 --- a/pkg/transport/handshake.go +++ b/pkg/transport/handshake.go @@ -7,6 +7,8 @@ import ( "fmt" "io" + "github.com/skycoin/skywire/pkg/network" + "github.com/skycoin/dmsg/cipher" ) @@ -19,8 +21,8 @@ func makeEntry(pk1, pk2 cipher.PubKey, tpType string) Entry { } } -func makeEntryFromTp(tp Transport) Entry { - return makeEntry(tp.LocalPK(), tp.RemotePK(), tp.Type()) +func makeEntryFromTpConn(conn *network.Conn) Entry { + return makeEntry(conn.LocalPK(), conn.RemotePK(), conn.Network()) } func compareEntries(expected, received *Entry) error { @@ -59,13 +61,13 @@ func receiveAndVerifyEntry(r io.Reader, expected *Entry, remotePK cipher.PubKey) // SettlementHS represents a settlement handshake. // This is the handshake responsible for registering a transport to transport discovery. -type SettlementHS func(ctx context.Context, dc DiscoveryClient, tp Transport, sk cipher.SecKey) error +type SettlementHS func(ctx context.Context, dc DiscoveryClient, conn *network.Conn, sk cipher.SecKey) error // Do performs the settlement handshake. -func (hs SettlementHS) Do(ctx context.Context, dc DiscoveryClient, tp Transport, sk cipher.SecKey) (err error) { +func (hs SettlementHS) Do(ctx context.Context, dc DiscoveryClient, conn *network.Conn, sk cipher.SecKey) (err error) { done := make(chan struct{}) go func() { - err = hs(ctx, dc, tp, sk) + err = hs(ctx, dc, conn, sk) close(done) }() select { @@ -81,23 +83,23 @@ func (hs SettlementHS) Do(ctx context.Context, dc DiscoveryClient, tp Transport, func MakeSettlementHS(init bool) SettlementHS { // initiating logic. - initHS := func(ctx context.Context, dc DiscoveryClient, tp Transport, sk cipher.SecKey) (err error) { - entry := makeEntryFromTp(tp) + initHS := func(ctx context.Context, dc DiscoveryClient, conn *network.Conn, sk cipher.SecKey) (err error) { + entry := makeEntryFromTpConn(conn) defer func() { _, _ = dc.UpdateStatuses(ctx, &Status{ID: entry.ID, IsUp: err == nil}) }() //nolint:errcheck // create signed entry and send it to responding visor node. - se, ok := NewSignedEntry(&entry, tp.LocalPK(), sk) + se, ok := NewSignedEntry(&entry, conn.LocalPK(), sk) if !ok { return errors.New("failed to sign entry") } - if err := json.NewEncoder(tp).Encode(se); err != nil { + if err := json.NewEncoder(conn).Encode(se); err != nil { return fmt.Errorf("failed to write entry: %v", err) } // await okay signal. accepted := make([]byte, 1) - if _, err := io.ReadFull(tp, accepted); err != nil { + if _, err := io.ReadFull(conn, accepted); err != nil { return fmt.Errorf("failed to read response: %v", err) } if accepted[0] == 0 { @@ -107,15 +109,15 @@ func MakeSettlementHS(init bool) SettlementHS { } // responding logic. 
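// Illustrative sketch, not from this changeset: driving the reworked settlement
// handshake over a *network.Conn. The dialing side runs it with init=true, the
// accepting side with init=false, mirroring the calls in managed_transport.go below.
// The function name and timeout wrapper are placeholders.
package example

import (
	"context"
	"time"

	"github.com/skycoin/dmsg/cipher"

	"github.com/skycoin/skywire/pkg/network"
	"github.com/skycoin/skywire/pkg/transport"
)

func settle(ctx context.Context, dc transport.DiscoveryClient, conn *network.Conn,
	sk cipher.SecKey, initiator bool) error {

	ctx, cancel := context.WithTimeout(ctx, 20*time.Second)
	defer cancel()
	return transport.MakeSettlementHS(initiator).Do(ctx, dc, conn, sk)
}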
- respHS := func(ctx context.Context, dc DiscoveryClient, tp Transport, sk cipher.SecKey) error { - entry := makeEntryFromTp(tp) + respHS := func(ctx context.Context, dc DiscoveryClient, conn *network.Conn, sk cipher.SecKey) error { + entry := makeEntryFromTpConn(conn) // receive, verify and sign entry. - recvSE, err := receiveAndVerifyEntry(tp, &entry, tp.RemotePK()) + recvSE, err := receiveAndVerifyEntry(conn, &entry, conn.RemotePK()) if err != nil { return err } - if ok := recvSE.Sign(tp.LocalPK(), sk); !ok { + if ok := recvSE.Sign(conn.LocalPK(), sk); !ok { return errors.New("failed to sign received entry") } entry = *recvSE.Entry @@ -124,7 +126,7 @@ func MakeSettlementHS(init bool) SettlementHS { _ = dc.RegisterTransports(ctx, recvSE) //nolint:errcheck // inform initiating visor node. - if _, err := tp.Write([]byte{1}); err != nil { + if _, err := conn.Write([]byte{1}); err != nil { return fmt.Errorf("failed to accept transport settlement: write failed: %v", err) } return nil diff --git a/pkg/transport/managed_transport.go b/pkg/transport/managed_transport.go index 5f97a5f886..9746d4fc1d 100644 --- a/pkg/transport/managed_transport.go +++ b/pkg/transport/managed_transport.go @@ -9,9 +9,10 @@ import ( "sync/atomic" "time" + "github.com/skycoin/skywire/pkg/network" + "github.com/skycoin/skywire/pkg/routing" - "github.com/skycoin/dmsg" "github.com/skycoin/dmsg/cipher" "github.com/skycoin/skycoin/src/util/logging" ) @@ -34,18 +35,17 @@ var ( type ManagedTransport struct { log *logging.Logger - lSK cipher.SecKey - rPK cipher.PubKey - - fac Factory - dc DiscoveryClient - ls LogStore - + rPK cipher.PubKey + netName string Entry Entry LogEntry *LogEntry logUpdates uint32 - conn Transport + dc DiscoveryClient + ls LogStore + + n *network.Network + conn *network.Conn connCh chan struct{} connMx sync.Mutex @@ -55,15 +55,15 @@ type ManagedTransport struct { } // NewManagedTransport creates a new ManagedTransport. -func NewManagedTransport(fac Factory, dc DiscoveryClient, ls LogStore, rPK cipher.PubKey, lSK cipher.SecKey) *ManagedTransport { +func NewManagedTransport(n *network.Network, dc DiscoveryClient, ls LogStore, rPK cipher.PubKey, netName string) *ManagedTransport { mt := &ManagedTransport{ log: logging.MustGetLogger(fmt.Sprintf("tp:%s", rPK.String()[:6])), - lSK: lSK, rPK: rPK, - fac: fac, + netName: netName, + n: n, dc: dc, ls: ls, - Entry: makeEntry(fac.Local(), rPK, dmsg.Type), + Entry: makeEntry(n.LocalPK(), rPK, netName), LogEntry: new(LogEntry), connCh: make(chan struct{}, 1), done: make(chan struct{}), @@ -183,22 +183,26 @@ func (mt *ManagedTransport) close() (closed bool) { } // Accept accepts a new underlying connection. -func (mt *ManagedTransport) Accept(ctx context.Context, tp Transport) error { +func (mt *ManagedTransport) Accept(ctx context.Context, conn *network.Conn) error { mt.connMx.Lock() defer mt.connMx.Unlock() + if conn.Network() != mt.netName { + return errors.New("wrong network") // TODO: Make global var. + } + if !mt.isServing() { - _ = tp.Close() //nolint:errcheck + _ = conn.Close() //nolint:errcheck return ErrNotServing } ctx, cancel := context.WithTimeout(ctx, time.Second*20) defer cancel() - if err := MakeSettlementHS(false).Do(ctx, mt.dc, tp, mt.lSK); err != nil { + if err := MakeSettlementHS(false).Do(ctx, mt.dc, conn, mt.n.LocalSK()); err != nil { return fmt.Errorf("settlement handshake failed: %v", err) } - return mt.setIfConnNil(ctx, tp) + return mt.setIfConnNil(ctx, conn) } // Dial dials a new underlying connection. 
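// Illustrative sketch, not from this changeset: the lifecycle the transport manager
// applies to an inbound *network.Conn after this change — look up or create the
// ManagedTransport for the remote key and network name, serve it, and let Accept run
// the responder-side settlement handshake. readCh and done stand in for the manager's
// channels; names are placeholders.
package example

import (
	"context"

	"github.com/skycoin/skywire/pkg/network"
	"github.com/skycoin/skywire/pkg/routing"
	"github.com/skycoin/skywire/pkg/transport"
)

func acceptTransport(ctx context.Context, n *network.Network, dc transport.DiscoveryClient,
	ls transport.LogStore, conn *network.Conn,
	readCh chan routing.Packet, done chan struct{}) error {

	mTp := transport.NewManagedTransport(n, dc, ls, conn.RemotePK(), conn.Network())
	go mTp.Serve(readCh, done)
	return mTp.Accept(ctx, conn)
}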
@@ -218,21 +222,21 @@ func (mt *ManagedTransport) Dial(ctx context.Context) error { // TODO: Figure out where this fella is called. func (mt *ManagedTransport) dial(ctx context.Context) error { - tp, err := mt.fac.Dial(ctx, mt.rPK) + tp, err := mt.n.Dial(mt.netName, mt.rPK, network.TransportPort) if err != nil { return err } ctx, cancel := context.WithTimeout(ctx, time.Second*20) defer cancel() - if err := MakeSettlementHS(true).Do(ctx, mt.dc, tp, mt.lSK); err != nil { + if err := MakeSettlementHS(true).Do(ctx, mt.dc, tp, mt.n.LocalSK()); err != nil { return fmt.Errorf("settlement handshake failed: %v", err) } return mt.setIfConnNil(ctx, tp) } -func (mt *ManagedTransport) getConn() Transport { +func (mt *ManagedTransport) getConn() *network.Conn { mt.connMx.Lock() conn := mt.conn mt.connMx.Unlock() @@ -241,7 +245,7 @@ func (mt *ManagedTransport) getConn() Transport { // sets conn if `mt.conn` is nil otherwise, closes the conn. // TODO: Add logging here. -func (mt *ManagedTransport) setIfConnNil(ctx context.Context, conn Transport) error { +func (mt *ManagedTransport) setIfConnNil(ctx context.Context, conn *network.Conn) error { if mt.conn != nil { _ = conn.Close() //nolint:errcheck return ErrConnAlreadyExists @@ -304,7 +308,7 @@ func (mt *ManagedTransport) WritePacket(ctx context.Context, rtID routing.RouteI // WARNING: Not thread safe. func (mt *ManagedTransport) readPacket() (packet routing.Packet, err error) { - var conn Transport + var conn *network.Conn for { if conn = mt.getConn(); conn != nil { break @@ -358,4 +362,4 @@ func (mt *ManagedTransport) logMod() bool { func (mt *ManagedTransport) Remote() cipher.PubKey { return mt.rPK } // Type returns the transport type. -func (mt *ManagedTransport) Type() string { return mt.fac.Type() } +func (mt *ManagedTransport) Type() string { return mt.netName } diff --git a/pkg/transport/manager.go b/pkg/transport/manager.go index f633110749..25cb57542b 100644 --- a/pkg/transport/manager.go +++ b/pkg/transport/manager.go @@ -3,10 +3,13 @@ package transport import ( "context" "errors" + "fmt" "io" "strings" "sync" + "github.com/skycoin/skywire/pkg/network" + "github.com/skycoin/skywire/pkg/routing" "github.com/google/uuid" @@ -18,54 +21,60 @@ import ( type ManagerConfig struct { PubKey cipher.PubKey SecKey cipher.SecKey + DefaultNodes []cipher.PubKey // Nodes to automatically connect to + Networks []string // Networks to use. DiscoveryClient DiscoveryClient LogStore LogStore - DefaultNodes []cipher.PubKey // Nodes to automatically connect to } // Manager manages Transports. type Manager struct { - Logger *logging.Logger - conf *ManagerConfig - setupPKS []cipher.PubKey - facs map[string]Factory - tps map[uuid.UUID]*ManagedTransport - - setupCh chan Transport - readCh chan routing.Packet - mx sync.RWMutex - done chan struct{} + Logger *logging.Logger + conf *ManagerConfig + nets map[string]struct{} + tps map[uuid.UUID]*ManagedTransport + n *network.Network + + readCh chan routing.Packet + mx sync.RWMutex + done chan struct{} } // NewManager creates a Manager with the provided configuration and transport factories. // 'factories' should be ordered by preference. 
-func NewManager(config *ManagerConfig, setupPKs []cipher.PubKey, factories ...Factory) (*Manager, error) { - tm := &Manager{ - Logger: logging.MustGetLogger("tp_manager"), - conf: config, - setupPKS: setupPKs, - facs: make(map[string]Factory), - tps: make(map[uuid.UUID]*ManagedTransport), - setupCh: make(chan Transport, 9), // TODO: eliminate or justify buffering here - readCh: make(chan routing.Packet, 20), - done: make(chan struct{}), +func NewManager(n *network.Network, config *ManagerConfig) (*Manager, error) { + nets := make(map[string]struct{}) + for _, n := range config.Networks { + nets[n] = struct{}{} } - for _, factory := range factories { - tm.facs[factory.Type()] = factory + tm := &Manager{ + Logger: logging.MustGetLogger("tp_manager"), + conf: config, + nets: nets, + tps: make(map[uuid.UUID]*ManagedTransport), + n: n, + readCh: make(chan routing.Packet, 20), + done: make(chan struct{}), } return tm, nil } // Serve runs listening loop across all registered factories. func (tm *Manager) Serve(ctx context.Context) error { - tm.mx.Lock() - tm.initTransports(ctx) - tm.mx.Unlock() - + var listeners []*network.Listener var wg sync.WaitGroup - for _, factory := range tm.facs { + + for _, netName := range tm.conf.Networks { + lis, err := tm.n.Listen(netName, network.TransportPort) + if err != nil { + return fmt.Errorf("failed to listen on network '%s' of port '%d': %v", + netName, network.TransportPort, err) + } + tm.Logger.Infof("listening on network: %s", netName) + listeners = append(listeners, lis) + wg.Add(1) - go func(f Factory) { + go func(netName string) { defer wg.Done() for { select { @@ -74,7 +83,7 @@ func (tm *Manager) Serve(ctx context.Context) error { case <-tm.done: return default: - if err := tm.acceptTransport(ctx, f); err != nil { + if err := tm.acceptTransport(ctx, lis); err != nil { if strings.Contains(err.Error(), "closed") { return } @@ -82,15 +91,32 @@ func (tm *Manager) Serve(ctx context.Context) error { } } } - }(factory) + }(netName) + } + tm.Logger.Info("transport manager is serving.") + + // closing logic + <-tm.done + + tm.Logger.Info("transport manager is closing.") + defer tm.Logger.Info("transport manager closed.") + + // Close all listeners. + for i, lis := range listeners { + if err := lis.Close(); err != nil { + tm.Logger.Warnf("listener %d of network '%s' closed with error: %v", i, lis.Network(), err) + } } - tm.Logger.Info("TransportManager is serving.") wg.Wait() + close(tm.readCh) return nil } func (tm *Manager) initTransports(ctx context.Context) { + tm.mx.Lock() + defer tm.mx.Unlock() + entries, err := tm.conf.DiscoveryClient.GetTransportsByEdge(ctx, tm.conf.PubKey) if err != nil { log.Warnf("No transports found for local node: %v", err) @@ -107,11 +133,12 @@ func (tm *Manager) initTransports(ctx context.Context) { } } -func (tm *Manager) acceptTransport(ctx context.Context, factory Factory) error { - tr, err := factory.Accept(ctx) +func (tm *Manager) acceptTransport(ctx context.Context, lis *network.Listener) error { + conn, err := lis.AcceptConn() if err != nil { return err } + tm.Logger.Infof("recv transport connection request: type(%s) remote(%s)", lis.Network(), conn.RemotePK()) tm.mx.Lock() defer tm.mx.Unlock() @@ -120,30 +147,26 @@ func (tm *Manager) acceptTransport(ctx context.Context, factory Factory) error { return errors.New("transport.Manager is closing. Skipping incoming transport") } - if tm.IsSetupPK(tr.RemotePK()) { - tm.setupCh <- tr - return nil - } - // For transports for purpose(data). 
- tpID := tm.tpIDFromPK(tr.RemotePK(), tr.Type()) + + tpID := tm.tpIDFromPK(conn.RemotePK(), conn.Network()) mTp, ok := tm.tps[tpID] if !ok { - mTp = NewManagedTransport(factory, tm.conf.DiscoveryClient, tm.conf.LogStore, tr.RemotePK(), tm.conf.SecKey) - if err := mTp.Accept(ctx, tr); err != nil { + mTp = NewManagedTransport(tm.n, tm.conf.DiscoveryClient, tm.conf.LogStore, conn.RemotePK(), lis.Network()) + if err := mTp.Accept(ctx, conn); err != nil { return err } go mTp.Serve(tm.readCh, tm.done) tm.tps[tpID] = mTp } else { - if err := mTp.Accept(ctx, tr); err != nil { + if err := mTp.Accept(ctx, conn); err != nil { return err } } - tm.Logger.Infof("accepted tp: type(%s) remote(%s) tpID(%s) new(%v)", factory.Type(), tr.RemotePK(), tpID, !ok) + tm.Logger.Infof("accepted tp: type(%s) remote(%s) tpID(%s) new(%v)", lis.Network(), conn.RemotePK(), tpID, !ok) return nil } @@ -164,24 +187,23 @@ func (tm *Manager) SaveTransport(ctx context.Context, remote cipher.PubKey, tpTy return mTp, nil } -func (tm *Manager) saveTransport(remote cipher.PubKey, tpType string) (*ManagedTransport, error) { - factory, ok := tm.facs[tpType] - if !ok { +func (tm *Manager) saveTransport(remote cipher.PubKey, netName string) (*ManagedTransport, error) { + if _, ok := tm.nets[netName]; !ok { return nil, errors.New("unknown transport type") } - tpID := tm.tpIDFromPK(remote, tpType) + tpID := tm.tpIDFromPK(remote, netName) tp, ok := tm.tps[tpID] if ok { return tp, nil } - mTp := NewManagedTransport(factory, tm.conf.DiscoveryClient, tm.conf.LogStore, remote, tm.conf.SecKey) + mTp := NewManagedTransport(tm.n, tm.conf.DiscoveryClient, tm.conf.LogStore, remote, netName) go mTp.Serve(tm.readCh, tm.done) tm.tps[tpID] = mTp - tm.Logger.Infof("saved transport: remote(%s) type(%s) tpID(%s)", remote, tpType, tpID) + tm.Logger.Infof("saved transport: remote(%s) type(%s) tpID(%s)", remote, netName, tpID) return mTp, nil } @@ -209,65 +231,13 @@ func (tm *Manager) ReadPacket() (routing.Packet, error) { return p, nil } -/* - SETUP LOGIC -*/ - -// SetupPKs returns setup node list contained within the TransportManager. -func (tm *Manager) SetupPKs() []cipher.PubKey { - return tm.setupPKS -} - -// IsSetupPK checks whether provided `pk` is of `setup` purpose. -func (tm *Manager) IsSetupPK(pk cipher.PubKey) bool { - for _, sPK := range tm.setupPKS { - if sPK == pk { - return true - } - } - return false -} - -// DialSetupConn dials to a remote setup node. -func (tm *Manager) DialSetupConn(ctx context.Context, remote cipher.PubKey, tpType string) (Transport, error) { - tm.mx.Lock() - defer tm.mx.Unlock() - if tm.isClosing() { - return nil, io.ErrClosedPipe - } - - factory, ok := tm.facs[tpType] - if !ok { - return nil, errors.New("unknown transport type") - } - tr, err := factory.Dial(ctx, remote) - if err != nil { - return nil, err - } - tm.Logger.Infof("Dialed to setup node %s using %s factory.", remote, tpType) - return tr, nil -} - -// AcceptSetupConn accepts a connection from a remote setup node. -func (tm *Manager) AcceptSetupConn() (Transport, error) { - tp, ok := <-tm.setupCh - if !ok { - return nil, ErrNotServing - } - return tp, nil -} - /* STATE */ -// Factories returns all the factory types contained within the TransportManager. -func (tm *Manager) Factories() []string { - fTypes, i := make([]string, len(tm.facs)), 0 - for _, f := range tm.facs { - fTypes[i], i = f.Type(), i+1 - } - return fTypes +// Networks returns all the network types contained within the TransportManager. 
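// Illustrative sketch, not from this changeset: creating the transport manager against
// the shared Network. The Networks field selects which network types Serve listens on
// (at network.TransportPort) and transports may be saved over; values and the function
// name are placeholders.
package example

import (
	"github.com/skycoin/dmsg/cipher"

	"github.com/skycoin/skywire/pkg/network"
	"github.com/skycoin/skywire/pkg/transport"
)

func newTransportManager(n *network.Network, pk cipher.PubKey, sk cipher.SecKey,
	dc transport.DiscoveryClient, ls transport.LogStore) (*transport.Manager, error) {

	return transport.NewManager(n, &transport.ManagerConfig{
		PubKey:          pk,
		SecKey:          sk,
		Networks:        []string{network.DmsgNet},
		DiscoveryClient: dc,
		LogStore:        ls,
	})
}

// Typical use (assumed): run tm.Serve(ctx) in a goroutine, then
// tm.SaveTransport(ctx, remotePK, network.DmsgNet) to dial and persist a transport.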
+func (tm *Manager) Networks() []string { + return tm.conf.Networks } // Transport obtains a Transport via a given Transport ID. @@ -304,13 +274,6 @@ func (tm *Manager) Close() error { defer tm.mx.Unlock() close(tm.done) - tm.Logger.Info("closing transport manager...") - defer tm.Logger.Infof("transport manager closed.") - - go func() { - for range tm.readCh { - } - }() i, statuses := 0, make([]*Status, len(tm.tps)) for _, tr := range tm.tps { @@ -321,16 +284,6 @@ func (tm *Manager) Close() error { if _, err := tm.conf.DiscoveryClient.UpdateStatuses(context.Background(), statuses...); err != nil { tm.Logger.Warnf("failed to update transport statuses: %v", err) } - - tm.Logger.Infof("closing transport factories...") - for _, f := range tm.facs { - if err := f.Close(); err != nil { - tm.Logger.Warnf("Failed to close factory: %s", err) - } - } - - close(tm.setupCh) - close(tm.readCh) return nil } diff --git a/pkg/transport/mock.go b/pkg/transport/mock.go index f57360629e..f66636683d 100644 --- a/pkg/transport/mock.go +++ b/pkg/transport/mock.go @@ -47,7 +47,7 @@ func (f *MockFactory) SetType(fType string) { } // Accept waits for new net.Conn notification from another MockFactory. -func (f *MockFactory) Accept(ctx context.Context) (Transport, error) { +func (f *MockFactory) Accept(ctx context.Context) (*MockTransport, error) { select { case conn, ok := <-f.in: if !ok { @@ -61,7 +61,7 @@ func (f *MockFactory) Accept(ctx context.Context) (Transport, error) { } // Dial creates pair of net.Conn via net.Pipe and passes one end to another MockFactory. -func (f *MockFactory) Dial(ctx context.Context, remote cipher.PubKey) (Transport, error) { +func (f *MockFactory) Dial(ctx context.Context, remote cipher.PubKey) (*MockTransport, error) { in, out := net.Pipe() select { case <-f.outDone: @@ -177,12 +177,12 @@ func MockTransportManagersPair() (pk1, pk2 cipher.PubKey, m1, m2 *Manager, errCh c1 := &ManagerConfig{PubKey: pk1, SecKey: sk1, DiscoveryClient: discovery, LogStore: logs} c2 := &ManagerConfig{PubKey: pk2, SecKey: sk2, DiscoveryClient: discovery, LogStore: logs} - f1, f2 := NewMockFactoryPair(pk1, pk2) + //f1, f2 := NewMockFactoryPair(pk1, pk2) - if m1, err = NewManager(c1, nil, f1); err != nil { + if m1, err = NewManager(nil, c1); err != nil { return } - if m2, err = NewManager(c2, nil, f2); err != nil { + if m2, err = NewManager(nil, c2); err != nil { return } diff --git a/pkg/transport/tcp_transport.go b/pkg/transport/tcp_transport.go index e5d112fc5e..c9075ab5e0 100644 --- a/pkg/transport/tcp_transport.go +++ b/pkg/transport/tcp_transport.go @@ -24,12 +24,12 @@ type TCPFactory struct { } // NewTCPFactory constructs a new TCP Factory. -func NewTCPFactory(lpk cipher.PubKey, pkt PubKeyTable, l *net.TCPListener) Factory { +func NewTCPFactory(lpk cipher.PubKey, pkt PubKeyTable, l *net.TCPListener) *TCPFactory { return &TCPFactory{l, lpk, pkt} } // Accept accepts a remotely-initiated Transport. -func (f *TCPFactory) Accept(ctx context.Context) (Transport, error) { +func (f *TCPFactory) Accept(ctx context.Context) (*TCPTransport, error) { conn, err := f.l.AcceptTCP() if err != nil { return nil, err @@ -45,7 +45,7 @@ func (f *TCPFactory) Accept(ctx context.Context) (Transport, error) { } // Dial initiates a Transport with a remote node. 
-func (f *TCPFactory) Dial(ctx context.Context, remote cipher.PubKey) (Transport, error) { +func (f *TCPFactory) Dial(ctx context.Context, remote cipher.PubKey) (*TCPTransport, error) { raddr := f.pkt.RemoteAddr(remote) if raddr == nil { return nil, ErrUnknownRemote diff --git a/pkg/transport/transport.go b/pkg/transport/transport.go index ee35cf75bb..362b157ee9 100644 --- a/pkg/transport/transport.go +++ b/pkg/transport/transport.go @@ -3,63 +3,15 @@ package transport import ( - "context" "crypto/sha256" - "math/big" - "time" - "github.com/google/uuid" "github.com/skycoin/dmsg/cipher" "github.com/skycoin/skycoin/src/util/logging" + "math/big" ) var log = logging.MustGetLogger("transport") -// Transport represents communication between two nodes via a single hop. -type Transport interface { - - // Read implements io.Reader - Read(p []byte) (n int, err error) - - // Write implements io.Writer - Write(p []byte) (n int, err error) - - // Close implements io.Closer - Close() error - - // LocalPK returns local public key of transport - LocalPK() cipher.PubKey - - // RemotePK returns remote public key of transport - RemotePK() cipher.PubKey - - // SetDeadline functions the same as that from net.Conn - // With a Transport, we don't have a distinction between write and read timeouts. - SetDeadline(t time.Time) error - - // Type returns the string representation of the transport type. - Type() string -} - -// Factory generates Transports of a certain type. -type Factory interface { - - // Accept accepts a remotely-initiated Transport. - Accept(ctx context.Context) (Transport, error) - - // Dial initiates a Transport with a remote node. - Dial(ctx context.Context, remote cipher.PubKey) (Transport, error) - - // Close implements io.Closer - Close() error - - // Local returns the local public key. - Local() cipher.PubKey - - // Type returns the Transport type. - Type() string -} - // MakeTransportID generates uuid.UUID from pair of keys + type + public // Generated uuid is: // - always the same for a given pair diff --git a/pkg/visor/config.go b/pkg/visor/config.go index bb82104240..8b603f2804 100644 --- a/pkg/visor/config.go +++ b/pkg/visor/config.go @@ -14,7 +14,6 @@ import ( "github.com/skycoin/skywire/pkg/routing" "github.com/skycoin/skywire/pkg/transport" trClient "github.com/skycoin/skywire/pkg/transport-discovery/client" - "github.com/skycoin/skywire/pkg/transport/dmsg" ) // Config defines configuration parameters for Node. @@ -64,7 +63,7 @@ type Config struct { } // MessagingConfig returns config for dmsg client. -func (c *Config) MessagingConfig() (*dmsg.Config, error) { +func (c *Config) MessagingConfig() (*DmsgConfig, error) { msgConfig := c.Messaging @@ -72,7 +71,7 @@ func (c *Config) MessagingConfig() (*dmsg.Config, error) { return nil, errors.New("empty discovery") } - return &dmsg.Config{ + return &DmsgConfig{ PubKey: c.Node.StaticPubKey, SecKey: c.Node.StaticSecKey, Discovery: disc.NewHTTP(msgConfig.Discovery), @@ -164,6 +163,14 @@ type HypervisorConfig struct { Addr string `json:"address"` } +type DmsgConfig struct { + PubKey cipher.PubKey + SecKey cipher.SecKey + Discovery disc.APIClient + Retries int + RetryDelay time.Duration +} + // AppConfig defines app startup parameters. 
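For illustration only (not part of this patch): the MakeTransportID comment above promises a uuid that is always the same for a given key pair plus type. The sketch below shows one way such an order-independent ID can be derived; it is not necessarily the algorithm MakeTransportID actually uses.

package example // illustrative sketch, not part of the diff

import (
	"crypto/sha256"

	"github.com/google/uuid"
	"github.com/skycoin/dmsg/cipher"
)

// exampleTransportID derives an order-independent, deterministic ID from a key
// pair and a network name. It is not necessarily the algorithm used by
// transport.MakeTransportID; it only illustrates the property described above.
func exampleTransportID(a, b cipher.PubKey, netName string) uuid.UUID {
	if b.Hex() < a.Hex() { // canonical ordering, so (a, b) and (b, a) match
		a, b = b, a
	}
	sum := sha256.Sum256([]byte(a.Hex() + b.Hex() + netName))
	id, err := uuid.FromBytes(sum[:16]) // any 16 bytes form a valid UUID
	if err != nil {
		panic(err)
	}
	return id
}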
type AppConfig struct { Version string `json:"version"` diff --git a/pkg/visor/rpc.go b/pkg/visor/rpc.go index 701a1ff1f1..2f79d9d997 100644 --- a/pkg/visor/rpc.go +++ b/pkg/visor/rpc.go @@ -78,7 +78,7 @@ func (r *RPC) Summary(_ *struct{}, out *Summary) error { var summaries []*TransportSummary r.node.tm.WalkTransports(func(tp *transport.ManagedTransport) bool { summaries = append(summaries, - newTransportSummary(r.node.tm, tp, false, r.node.router.IsSetupTransport(tp))) + newTransportSummary(r.node.tm, tp, false, r.node.router.SetupIsTrusted(tp.Remote()))) return true }) *out = Summary{ @@ -129,7 +129,7 @@ func (r *RPC) SetAutoStart(in *SetAutoStartIn, _ *struct{}) error { // TransportTypes lists all transport types supported by the Node. func (r *RPC) TransportTypes(_ *struct{}, out *[]string) error { - *out = r.node.tm.Factories() + *out = r.node.tm.Networks() return nil } @@ -166,7 +166,7 @@ func (r *RPC) Transports(in *TransportsIn, out *[]*TransportSummary) error { } r.node.tm.WalkTransports(func(tp *transport.ManagedTransport) bool { if typeIncluded(tp.Type()) && pkIncluded(r.node.tm.Local(), tp.Remote()) { - *out = append(*out, newTransportSummary(r.node.tm, tp, in.ShowLogs, r.node.router.IsSetupTransport(tp))) + *out = append(*out, newTransportSummary(r.node.tm, tp, in.ShowLogs, r.node.router.SetupIsTrusted(tp.Remote()))) } return true }) @@ -179,7 +179,7 @@ func (r *RPC) Transport(in *uuid.UUID, out *TransportSummary) error { if tp == nil { return ErrNotFound } - *out = *newTransportSummary(r.node.tm, tp, true, r.node.router.IsSetupTransport(tp)) + *out = *newTransportSummary(r.node.tm, tp, true, r.node.router.SetupIsTrusted(tp.Remote())) return nil } @@ -204,7 +204,7 @@ func (r *RPC) AddTransport(in *AddTransportIn, out *TransportSummary) error { if err != nil { return err } - *out = *newTransportSummary(r.node.tm, tp, false, r.node.router.IsSetupTransport(tp)) + *out = *newTransportSummary(r.node.tm, tp, false, r.node.router.SetupIsTrusted(tp.Remote())) return nil } diff --git a/pkg/visor/visor.go b/pkg/visor/visor.go index 5806fd8104..037b6add36 100644 --- a/pkg/visor/visor.go +++ b/pkg/visor/visor.go @@ -6,6 +6,9 @@ import ( "context" "errors" "fmt" + "github.com/skycoin/dmsg" + "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/skywire/pkg/network" "io" "net" "net/rpc" @@ -27,7 +30,6 @@ import ( "github.com/skycoin/skywire/pkg/router" "github.com/skycoin/skywire/pkg/routing" "github.com/skycoin/skywire/pkg/transport" - "github.com/skycoin/skywire/pkg/transport/dmsg" "github.com/skycoin/skywire/pkg/util/pathutil" ) @@ -78,7 +80,7 @@ type PacketRouter interface { io.Closer Serve(ctx context.Context) error ServeApp(conn net.Conn, port routing.Port, appConf *app.Config) error - IsSetupTransport(tr *transport.ManagedTransport) bool + SetupIsTrusted(sPK cipher.PubKey) bool } // Node provides messaging runtime for Apps by setting up all @@ -86,7 +88,7 @@ type PacketRouter interface { type Node struct { config *Config router PacketRouter - messenger *dmsg.Client + n *network.Network tm *transport.Manager rt routing.Table executer appExecuter @@ -109,6 +111,8 @@ type Node struct { // NewNode constructs new Node. 
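For illustration only (not part of this patch): the RPC layer above now calls r.node.router.SetupIsTrusted(tp.Remote()) instead of IsSetupTransport(tp). The router implementation is not included in this diff; the sketch below simply mirrors the membership check of the removed Manager.IsSetupPK helper and is only a guess at the intended behaviour.

package example // illustrative sketch, not part of the diff

import "github.com/skycoin/dmsg/cipher"

// setupIsTrusted mirrors the membership check of the removed Manager.IsSetupPK
// helper: a setup node is trusted if its key appears in the configured list.
// The real router.SetupIsTrusted implementation is not shown in this diff.
func setupIsTrusted(sPK cipher.PubKey, setupNodes []cipher.PubKey) bool {
	for _, pk := range setupNodes {
		if pk == sPK {
			return true
		}
	}
	return false
}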
func NewNode(config *Config, masterLogger *logging.MasterLogger) (*Node, error) { + ctx := context.Background() + node := &Node{ config: config, executer: newOSExecuter(), @@ -120,12 +124,18 @@ func NewNode(config *Config, masterLogger *logging.MasterLogger) (*Node, error) pk := config.Node.StaticPubKey sk := config.Node.StaticSecKey - mConfig, err := config.MessagingConfig() - if err != nil { - return nil, fmt.Errorf("invalid Messaging config: %s", err) - } - node.messenger = dmsg.NewClient(mConfig.PubKey, mConfig.SecKey, mConfig.Discovery) + fmt.Println("min servers:", config.Messaging.ServerCount) + node.n = network.New(network.Config{ + PubKey: pk, + SecKey: sk, + TpNetworks: []string{dmsg.Type}, // TODO: Have some way to configure this. + DmsgDiscAddr: config.Messaging.Discovery, + DmsgMinSrvs: config.Messaging.ServerCount, + }) + if err := node.n.Init(ctx); err != nil { + return nil, fmt.Errorf("failed to init network: %v", err) + } trDiscovery, err := config.TransportDiscovery() if err != nil { @@ -136,12 +146,14 @@ func NewNode(config *Config, masterLogger *logging.MasterLogger) (*Node, error) return nil, fmt.Errorf("invalid TransportLogStore: %s", err) } tmConfig := &transport.ManagerConfig{ - PubKey: pk, SecKey: sk, + PubKey: pk, + SecKey: sk, + DefaultNodes: config.TrustedNodes, + Networks: []string{dmsg.Type}, // TODO: Have some way to configure this. DiscoveryClient: trDiscovery, LogStore: logStore, - DefaultNodes: config.TrustedNodes, } - node.tm, err = transport.NewManager(tmConfig, config.Routing.SetupNodes, node.messenger) + node.tm, err = transport.NewManager(node.n, tmConfig) if err != nil { return nil, fmt.Errorf("transport manager: %s", err) } @@ -159,7 +171,10 @@ func NewNode(config *Config, masterLogger *logging.MasterLogger) (*Node, error) RouteFinder: routeFinder.NewHTTP(config.Routing.RouteFinder, time.Duration(config.Routing.RouteFinderTimeout)), SetupNodes: config.Routing.SetupNodes, } - r := router.New(rConfig) + r, err := router.New(node.n, rConfig) + if err != nil { + return nil, fmt.Errorf("failed to setup router: %v", err) + } node.router = r node.appsConf, err = config.AppsConfig() @@ -204,11 +219,6 @@ func NewNode(config *Config, masterLogger *logging.MasterLogger) (*Node, error) // Start spawns auto-started Apps, starts router and RPC interfaces . func (node *Node) Start() error { ctx := context.Background() - err := node.messenger.InitiateServerConnections(ctx, node.config.Messaging.ServerCount) - if err != nil { - return fmt.Errorf("%s: %s", dmsg.Type, err) - } - node.logger.Info("Connected to messaging servers") pathutil.EnsureDir(node.dir()) node.closePreviousApps() diff --git a/vendor/github.com/skycoin/dmsg/addr.go b/vendor/github.com/skycoin/dmsg/addr.go new file mode 100644 index 0000000000..65e7f71b28 --- /dev/null +++ b/vendor/github.com/skycoin/dmsg/addr.go @@ -0,0 +1,23 @@ +package dmsg + +import ( + "fmt" + + "github.com/skycoin/dmsg/cipher" +) + +// Addr implements net.Addr for skywire addresses. +type Addr struct { + PK cipher.PubKey + Port uint16 +} + +// Network returns "dmsg" +func (Addr) Network() string { + return Type +} + +// String returns public key and port of node split by colon. 
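For illustration only (not part of this patch): the new dmsg.Addr type implements net.Addr, with Network() fixed to "dmsg" and String() (defined in the next hunk) rendering the public key and port. A tiny usage sketch:

package example // illustrative sketch, not part of the diff

import (
	"fmt"
	"net"

	"github.com/skycoin/dmsg"
	"github.com/skycoin/dmsg/cipher"
)

// printAddr uses dmsg.Addr through the standard net.Addr interface.
func printAddr(pk cipher.PubKey) {
	var a net.Addr = dmsg.Addr{PK: pk, Port: 36}
	// Network() is always "dmsg"; String() renders "<hex public key>:<port>".
	fmt.Println(a.Network(), a.String())
}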
+func (a Addr) String() string { + return fmt.Sprintf("%s:%d", a.PK, a.Port) +} diff --git a/vendor/github.com/skycoin/dmsg/client.go b/vendor/github.com/skycoin/dmsg/client.go index 8e6a20fede..2e7d898a2b 100644 --- a/vendor/github.com/skycoin/dmsg/client.go +++ b/vendor/github.com/skycoin/dmsg/client.go @@ -8,7 +8,6 @@ import ( "sync" "time" - "github.com/sirupsen/logrus" "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/dmsg/cipher" @@ -31,270 +30,6 @@ var ( ErrClientAcceptMaxed = errors.New("client accepts buffer maxed") ) -// ClientConn represents a connection between a dmsg.Client and dmsg.Server from a client's perspective. -type ClientConn struct { - log *logging.Logger - - net.Conn // conn to dmsg server - local cipher.PubKey // local client's pk - remoteSrv cipher.PubKey // dmsg server's public key - - // nextInitID keeps track of unused tp_ids to assign a future locally-initiated tp. - // locally-initiated tps use an even tp_id between local and intermediary dms_server. - nextInitID uint16 - - // Transports: map of transports to remote dms_clients (key: tp_id, val: transport). - tps map[uint16]*Transport - mx sync.RWMutex // to protect tps - - done chan struct{} - once sync.Once - wg sync.WaitGroup -} - -// NewClientConn creates a new ClientConn. -func NewClientConn(log *logging.Logger, conn net.Conn, local, remote cipher.PubKey) *ClientConn { - cc := &ClientConn{ - log: log, - Conn: conn, - local: local, - remoteSrv: remote, - nextInitID: randID(true), - tps: make(map[uint16]*Transport), - done: make(chan struct{}), - } - cc.wg.Add(1) - return cc -} - -// RemotePK returns the remote Server's PK that the ClientConn is connected to. -func (c *ClientConn) RemotePK() cipher.PubKey { return c.remoteSrv } - -func (c *ClientConn) getNextInitID(ctx context.Context) (uint16, error) { - for { - select { - case <-c.done: - return 0, ErrClientClosed - case <-ctx.Done(): - return 0, ctx.Err() - default: - if ch := c.tps[c.nextInitID]; ch != nil && !ch.IsClosed() { - c.nextInitID += 2 - continue - } - c.tps[c.nextInitID] = nil - id := c.nextInitID - c.nextInitID = id + 2 - return id, nil - } - } -} - -func (c *ClientConn) addTp(ctx context.Context, clientPK cipher.PubKey) (*Transport, error) { - c.mx.Lock() - defer c.mx.Unlock() - - id, err := c.getNextInitID(ctx) - if err != nil { - return nil, err - } - tp := NewTransport(c.Conn, c.log, c.local, clientPK, id, c.delTp) - c.tps[id] = tp - return tp, nil -} - -func (c *ClientConn) setTp(tp *Transport) { - c.mx.Lock() - c.tps[tp.id] = tp - c.mx.Unlock() -} - -func (c *ClientConn) delTp(id uint16) { - c.mx.Lock() - c.tps[id] = nil - c.mx.Unlock() -} - -func (c *ClientConn) getTp(id uint16) (*Transport, bool) { - c.mx.RLock() - tp := c.tps[id] - c.mx.RUnlock() - ok := tp != nil && !tp.IsClosed() - return tp, ok -} - -func (c *ClientConn) setNextInitID(nextInitID uint16) { - c.mx.Lock() - c.nextInitID = nextInitID - c.mx.Unlock() -} - -func (c *ClientConn) readOK() error { - fr, err := readFrame(c.Conn) - if err != nil { - return errors.New("failed to get OK from server") - } - - ft, _, _ := fr.Disassemble() - if ft != OkType { - return fmt.Errorf("wrong frame from server: %v", ft) - } - - return nil -} - -func (c *ClientConn) handleRequestFrame(accept chan<- *Transport, id uint16, p []byte) (cipher.PubKey, error) { - // remotely-initiated tps should: - // - have a payload structured as 'init_pk:resp_pk'. - // - resp_pk should be of local client. - // - use an odd tp_id with the intermediary dmsg_server. 
- initPK, respPK, ok := splitPKs(p) - if !ok || respPK != c.local || isInitiatorID(id) { - if err := writeCloseFrame(c.Conn, id, 0); err != nil { - return initPK, err - } - return initPK, ErrRequestCheckFailed - } - - tp := NewTransport(c.Conn, c.log, c.local, initPK, id, c.delTp) - - select { - case <-c.done: - if err := tp.Close(); err != nil { - log.WithError(err).Warn("Failed to close transport") - } - return initPK, ErrClientClosed - default: - select { - case accept <- tp: - c.setTp(tp) - if err := tp.WriteAccept(); err != nil { - return initPK, err - } - go tp.Serve() - return initPK, nil - - default: - if err := tp.Close(); err != nil { - log.WithError(err).Warn("Failed to close transport") - } - return initPK, ErrClientAcceptMaxed - } - } -} - -// Serve handles incoming frames. -// Remote-initiated tps that are successfully created are pushing into 'accept' and exposed via 'Client.Accept()'. -func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) (err error) { - log := c.log.WithField("remoteServer", c.remoteSrv) - log.WithField("connCount", incrementServeCount()).Infoln("ServingConn") - defer func() { - c.close() - log.WithError(err).WithField("connCount", decrementServeCount()).Infoln("ConnectionClosed") - c.wg.Done() - }() - - for { - f, err := readFrame(c.Conn) - if err != nil { - return fmt.Errorf("read failed: %s", err) - } - log = log.WithField("received", f) - - ft, id, p := f.Disassemble() - - // If tp of tp_id exists, attempt to forward frame to tp. - // delete tp on any failure. - - if tp, ok := c.getTp(id); ok { - if err := tp.HandleFrame(f); err != nil { - log.WithError(err).Warnf("Rejected [%s]: Transport closed.", ft) - } - continue - } - - // if tp does not exist, frame should be 'REQUEST'. - // otherwise, handle any unexpected frames accordingly. - - c.delTp(id) // rm tp in case closed tp is not fully removed. - - switch ft { - case RequestType: - c.wg.Add(1) - go func(log *logrus.Entry) { - defer c.wg.Done() - initPK, err := c.handleRequestFrame(accept, id, p) - if err != nil { - log.WithField("remoteClient", initPK).WithError(err).Infoln("Rejected [REQUEST]") - if isWriteError(err) || err == ErrClientClosed { - err := c.Close() - log.WithError(err).Warn("ClosingConnection") - } - return - } - log.WithField("remoteClient", initPK).Infoln("Accepted [REQUEST]") - }(log) - - default: - log.Debugf("Ignored [%s]: No transport of given ID.", ft) - if ft != CloseType { - if err := writeCloseFrame(c.Conn, id, 0); err != nil { - return err - } - } - } - } -} - -// DialTransport dials a transport to remote dms_client. -func (c *ClientConn) DialTransport(ctx context.Context, clientPK cipher.PubKey) (*Transport, error) { - tp, err := c.addTp(ctx, clientPK) - if err != nil { - return nil, err - } - if err := tp.WriteRequest(); err != nil { - return nil, err - } - if err := tp.ReadAccept(ctx); err != nil { - return nil, err - } - go tp.Serve() - return tp, nil -} - -func (c *ClientConn) close() (closed bool) { - if c == nil { - return false - } - c.once.Do(func() { - closed = true - c.log.WithField("remoteServer", c.remoteSrv).Infoln("ClosingConnection") - close(c.done) - c.mx.Lock() - for _, tp := range c.tps { - tp := tp - go func() { - if err := tp.Close(); err != nil { - log.WithError(err).Warn("Failed to close transport") - } - }() - } - if err := c.Conn.Close(); err != nil { - log.WithError(err).Warn("Failed to close connection") - } - c.mx.Unlock() - }) - return closed -} - -// Close closes the connection to dms_server. 
-func (c *ClientConn) Close() error { - if c.close() { - c.wg.Wait() - } - return nil -} - // ClientOption represents an optional argument for Client. type ClientOption func(c *Client) error @@ -320,21 +55,25 @@ type Client struct { conns map[cipher.PubKey]*ClientConn // conns with messaging servers. Key: pk of server mx sync.RWMutex - accept chan *Transport - done chan struct{} - once sync.Once + pm *PortManager + + // accept map[uint16]chan *transport + done chan struct{} + once sync.Once } // NewClient creates a new Client. func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc disc.APIClient, opts ...ClientOption) *Client { c := &Client{ - log: logging.MustGetLogger("dmsg_client"), - pk: pk, - sk: sk, - dc: dc, - conns: make(map[cipher.PubKey]*ClientConn), - accept: make(chan *Transport, AcceptBufferSize), - done: make(chan struct{}), + log: logging.MustGetLogger("dmsg_client"), + pk: pk, + sk: sk, + dc: dc, + conns: make(map[cipher.PubKey]*ClientConn), + pm: newPortManager(), + // accept: make(chan *transport, AcceptBufferSize), + // accept: make(map[uint16]chan *transport), + done: make(chan struct{}), } for _, opt := range opts { if err := opt(c); err != nil { @@ -474,7 +213,7 @@ func (c *Client) findOrConnectToServer(ctx context.Context, srvPK cipher.PubKey) return nil, err } - conn := NewClientConn(c.log, nc, c.pk, srvPK) + conn := NewClientConn(c.log, nc, c.pk, srvPK, c.pm) if err := conn.readOK(); err != nil { return nil, err } @@ -482,7 +221,7 @@ func (c *Client) findOrConnectToServer(ctx context.Context, srvPK cipher.PubKey) c.setConn(ctx, conn) go func() { - err := conn.Serve(ctx, c.accept) + err := conn.Serve(ctx) conn.log.WithError(err).WithField("remoteServer", srvPK).Warn("connected with server closed") c.delConn(ctx, srvPK) @@ -503,23 +242,17 @@ func (c *Client) findOrConnectToServer(ctx context.Context, srvPK cipher.PubKey) return conn, nil } -// Accept accepts remotely-initiated tps. -func (c *Client) Accept(ctx context.Context) (*Transport, error) { - select { - case tp, ok := <-c.accept: - if !ok { - return nil, ErrClientClosed - } - return tp, nil - case <-c.done: - return nil, ErrClientClosed - case <-ctx.Done(): - return nil, ctx.Err() +// Listen creates a listener on a given port, adds it to port manager and returns the listener. +func (c *Client) Listen(port uint16) (*Listener, error) { + l, ok := c.pm.NewListener(c.pk, port) + if !ok { + return nil, errors.New("port is busy") } + return l, nil } // Dial dials a transport to remote dms_client. -func (c *Client) Dial(ctx context.Context, remote cipher.PubKey) (*Transport, error) { +func (c *Client) Dial(ctx context.Context, remote cipher.PubKey, port uint16) (*Transport, error) { entry, err := c.dc.Entry(ctx, remote) if err != nil { return nil, fmt.Errorf("get entry failure: %s", err) @@ -536,14 +269,16 @@ func (c *Client) Dial(ctx context.Context, remote cipher.PubKey) (*Transport, er c.log.WithError(err).Warn("failed to connect to server") continue } - return conn.DialTransport(ctx, remote) + return conn.DialTransport(ctx, remote, port) } return nil, errors.New("failed to find dms_servers for given client pk") } -// Local returns the local dms_client's public key. -func (c *Client) Local() cipher.PubKey { - return c.pk +// Addr returns the local dms_client's public key. +func (c *Client) Addr() net.Addr { + return Addr{ + PK: c.pk, + } } // Type returns the transport type. 
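For illustration only (not part of this patch): the client API above replaces the single buffered Accept channel with per-port listeners, so a responder now listens on a port and an initiator dials a (public key, port) pair. A sketch of that flow, assuming both clients are already connected to a common dmsg server; all names are illustrative.

package example // illustrative sketch, not part of the diff

import (
	"context"
	"fmt"

	"github.com/skycoin/dmsg"
	"github.com/skycoin/dmsg/cipher"
)

// pingPong sketches the new port-based API: respClient listens on a port and
// initClient dials that (public key, port) pair.
func pingPong(initClient, respClient *dmsg.Client, respPK cipher.PubKey) error {
	lis, err := respClient.Listen(8080) // fails with "port is busy" if taken
	if err != nil {
		return err
	}
	go func() {
		conn, err := lis.Accept() // net.Conn (a *dmsg.Transport underneath)
		if err != nil {
			return
		}
		fmt.Println("accepted connection from", conn.RemoteAddr())
	}()

	tp, err := initClient.Dial(context.Background(), respPK, 8080)
	if err != nil {
		return err
	}
	_, err = tp.Write([]byte("hello"))
	return err
}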
@@ -570,14 +305,13 @@ func (c *Client) Close() error { c.conns = make(map[cipher.PubKey]*ClientConn) c.mx.Unlock() - for { - select { - case <-c.accept: - default: - close(c.accept) - return - } + c.pm.mu.Lock() + defer c.pm.mu.Unlock() + + for _, lis := range c.pm.listeners { + lis.close() } }) + return nil } diff --git a/vendor/github.com/skycoin/dmsg/client_conn.go b/vendor/github.com/skycoin/dmsg/client_conn.go new file mode 100644 index 0000000000..9ee1895af2 --- /dev/null +++ b/vendor/github.com/skycoin/dmsg/client_conn.go @@ -0,0 +1,301 @@ +package dmsg + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net" + "sync" + + "github.com/sirupsen/logrus" + "github.com/skycoin/skycoin/src/util/logging" + + "github.com/skycoin/dmsg/cipher" +) + +// ClientConn represents a connection between a dmsg.Client and dmsg.Server from a client's perspective. +type ClientConn struct { + log *logging.Logger + + net.Conn // conn to dmsg server + local cipher.PubKey // local client's pk + remoteSrv cipher.PubKey // dmsg server's public key + + // nextInitID keeps track of unused tp_ids to assign a future locally-initiated tp. + // locally-initiated tps use an even tp_id between local and intermediary dms_server. + nextInitID uint16 + + // Transports: map of transports to remote dms_clients (key: tp_id, val: transport). + tps map[uint16]*Transport + mx sync.RWMutex // to protect tps + + pm *PortManager + + done chan struct{} + once sync.Once + wg sync.WaitGroup +} + +// NewClientConn creates a new ClientConn. +func NewClientConn(log *logging.Logger, conn net.Conn, local, remote cipher.PubKey, pm *PortManager) *ClientConn { + cc := &ClientConn{ + log: log, + Conn: conn, + local: local, + remoteSrv: remote, + nextInitID: randID(true), + tps: make(map[uint16]*Transport), + pm: pm, + done: make(chan struct{}), + } + cc.wg.Add(1) + return cc +} + +// RemotePK returns the remote Server's PK that the ClientConn is connected to. 
+func (c *ClientConn) RemotePK() cipher.PubKey { return c.remoteSrv } + +func (c *ClientConn) getNextInitID(ctx context.Context) (uint16, error) { + for { + select { + case <-c.done: + return 0, ErrClientClosed + case <-ctx.Done(): + return 0, ctx.Err() + default: + if ch := c.tps[c.nextInitID]; ch != nil && !ch.IsClosed() { + c.nextInitID += 2 + continue + } + c.tps[c.nextInitID] = nil + id := c.nextInitID + c.nextInitID = id + 2 + return id, nil + } + } +} + +func (c *ClientConn) addTp(ctx context.Context, rPK cipher.PubKey, lPort, rPort uint16) (*Transport, error) { + c.mx.Lock() + defer c.mx.Unlock() + + id, err := c.getNextInitID(ctx) + if err != nil { + return nil, err + } + tp := NewTransport(c.Conn, c.log, Addr{c.local, lPort}, Addr{rPK, rPort}, id, c.delTp) + c.tps[id] = tp + return tp, nil +} + +func (c *ClientConn) setTp(tp *Transport) { + c.mx.Lock() + c.tps[tp.id] = tp + c.mx.Unlock() +} + +func (c *ClientConn) delTp(id uint16) { + c.mx.Lock() + c.tps[id] = nil + c.mx.Unlock() +} + +func (c *ClientConn) getTp(id uint16) (*Transport, bool) { + c.mx.RLock() + tp := c.tps[id] + c.mx.RUnlock() + ok := tp != nil && !tp.IsClosed() + return tp, ok +} + +func (c *ClientConn) setNextInitID(nextInitID uint16) { + c.mx.Lock() + c.nextInitID = nextInitID + c.mx.Unlock() +} + +func (c *ClientConn) readOK() error { + fr, err := readFrame(c.Conn) + if err != nil { + return errors.New("failed to get OK from server") + } + + ft, _, _ := fr.Disassemble() + if ft != OkType { + return fmt.Errorf("wrong frame from server: %v", ft) + } + + return nil +} + +func (c *ClientConn) handleRequestFrame(id uint16, p []byte) (cipher.PubKey, error) { + // remotely-initiated tps should: + // - have a payload structured as HandshakePayload marshaled to JSON. + // - resp_pk should be of local client. + // - use an odd tp_id with the intermediary dmsg_server. + payload, err := unmarshalHandshakePayload(p) + if err != nil { + // TODO(nkryuchkov): When implementing reasons, send that payload format is incorrect. + if err := writeCloseFrame(c.Conn, id, PlaceholderReason); err != nil { + return cipher.PubKey{}, err + } + return cipher.PubKey{}, ErrRequestCheckFailed + } + + if payload.RespPK != c.local || isInitiatorID(id) { + // TODO(nkryuchkov): When implementing reasons, send that payload is malformed. + if err := writeCloseFrame(c.Conn, id, PlaceholderReason); err != nil { + return payload.InitPK, err + } + return payload.InitPK, ErrRequestCheckFailed + } + + lis, ok := c.pm.Listener(payload.Port) + if !ok { + // TODO(nkryuchkov): When implementing reasons, send that port is not listening + if err := writeCloseFrame(c.Conn, id, PlaceholderReason); err != nil { + return payload.InitPK, err + } + return payload.InitPK, ErrPortNotListening + } + + tp := NewTransport(c.Conn, c.log, Addr{c.local, payload.Port}, Addr{payload.InitPK, 0}, id, c.delTp) // TODO: Have proper remote port. + + select { + case <-c.done: + if err := tp.Close(); err != nil { + log.WithError(err).Warn("Failed to close transport") + } + return payload.InitPK, ErrClientClosed + + default: + err := lis.IntroduceTransport(tp) + if err == nil || err == ErrClientAcceptMaxed { + c.setTp(tp) + } + return payload.InitPK, err + } +} + +// Serve handles incoming frames. +// Remote-initiated tps that are successfully created are pushing into 'accept' and exposed via 'Client.Accept()'. 
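For illustration only (not part of this patch): handleRequestFrame above expects the REQUEST payload to be a JSON-encoded HandshakePayload (the struct is added to frame.go later in this diff). A sketch of what that payload looks like, assuming cipher.PubKey marshals to its hex string form as it does elsewhere in the JSON configs:

package example // illustrative sketch, not part of the diff

import (
	"encoding/json"
	"fmt"

	"github.com/skycoin/dmsg"
	"github.com/skycoin/dmsg/cipher"
)

// showRequestPayload encodes the REQUEST handshake payload the same way
// marshalHandshakePayload does (plain JSON). Assuming cipher.PubKey marshals
// to its hex string form, the output is shaped like:
//   {"version":"1","init_pk":"03...","resp_pk":"02...","port":8080}
func showRequestPayload(initPK, respPK cipher.PubKey) error {
	p := dmsg.HandshakePayload{
		Version: dmsg.HandshakePayloadVersion,
		InitPK:  initPK,
		RespPK:  respPK,
		Port:    8080,
	}
	b, err := json.Marshal(p)
	if err != nil {
		return err
	}
	fmt.Println(string(b))
	return nil
}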
+func (c *ClientConn) Serve(ctx context.Context) (err error) { + log := c.log.WithField("remoteServer", c.remoteSrv) + log.WithField("connCount", incrementServeCount()).Infoln("ServingConn") + defer func() { + c.close() + log.WithError(err).WithField("connCount", decrementServeCount()).Infoln("ConnectionClosed") + c.wg.Done() + }() + + for { + f, err := readFrame(c.Conn) + if err != nil { + return fmt.Errorf("read failed: %s", err) + } + log = log.WithField("received", f) + + ft, id, p := f.Disassemble() + + // If tp of tp_id exists, attempt to forward frame to tp. + // delete tp on any failure. + + if tp, ok := c.getTp(id); ok { + if err := tp.HandleFrame(f); err != nil { + log.WithError(err).Warnf("Rejected [%s]: Transport closed.", ft) + } + continue + } + + // if tp does not exist, frame should be 'REQUEST'. + // otherwise, handle any unexpected frames accordingly. + + c.delTp(id) // rm tp in case closed tp is not fully removed. + + switch ft { + case RequestType: + c.wg.Add(1) + go func(log *logrus.Entry) { + defer c.wg.Done() + initPK, err := c.handleRequestFrame(id, p) + if err != nil { + log.WithField("remoteClient", initPK).WithError(err).Infoln("Rejected [REQUEST]") + if isWriteError(err) || err == ErrClientClosed { + err := c.Close() + log.WithError(err).Warn("ClosingConnection") + } + return + } + log.WithField("remoteClient", initPK).Infoln("Accepted [REQUEST]") + }(log) + + default: + log.Debugf("Ignored [%s]: No transport of given ID.", ft) + if ft != CloseType { + if err := writeCloseFrame(c.Conn, id, PlaceholderReason); err != nil { + return err + } + } + } + } +} + +// DialTransport dials a transport to remote dms_client. +func (c *ClientConn) DialTransport(ctx context.Context, clientPK cipher.PubKey, port uint16) (*Transport, error) { + tp, err := c.addTp(ctx, clientPK, 0, port) // TODO: Have proper local port. + if err != nil { + return nil, err + } + if err := tp.WriteRequest(port); err != nil { + return nil, err + } + if err := tp.ReadAccept(ctx); err != nil { + return nil, err + } + go tp.Serve() + return tp, nil +} + +func (c *ClientConn) close() (closed bool) { + if c == nil { + return false + } + c.once.Do(func() { + closed = true + c.log.WithField("remoteServer", c.remoteSrv).Infoln("ClosingConnection") + close(c.done) + c.mx.Lock() + for _, tp := range c.tps { + tp := tp + go func() { + if err := tp.Close(); err != nil { + log.WithError(err).Warn("Failed to close transport") + } + }() + } + if err := c.Conn.Close(); err != nil { + log.WithError(err).Warn("Failed to close connection") + } + c.mx.Unlock() + }) + return closed +} + +// Close closes the connection to dms_server. +func (c *ClientConn) Close() error { + if c.close() { + c.wg.Wait() + } + return nil +} + +func marshalHandshakePayload(p HandshakePayload) ([]byte, error) { + return json.Marshal(p) +} + +func unmarshalHandshakePayload(b []byte) (HandshakePayload, error) { + var p HandshakePayload + err := json.Unmarshal(b, &p) + return p, err +} diff --git a/vendor/github.com/skycoin/dmsg/frame.go b/vendor/github.com/skycoin/dmsg/frame.go index 78e10edf5f..33b354ef95 100644 --- a/vendor/github.com/skycoin/dmsg/frame.go +++ b/vendor/github.com/skycoin/dmsg/frame.go @@ -16,6 +16,9 @@ import ( const ( // Type returns the transport type string. Type = "dmsg" + // HandshakePayloadVersion contains payload version to maintain compatibility with future versions + // of HandshakePayload format. 
+ HandshakePayloadVersion = "1" tpBufCap = math.MaxUint16 tpBufFrameCap = math.MaxUint8 @@ -31,6 +34,15 @@ var ( AcceptBufferSize = 20 ) +// HandshakePayload represents format of payload sent with REQUEST frames. +// TODO(evanlinjin): Use 'dmsg.Addr' for PK:Port pair. +type HandshakePayload struct { + Version string `json:"version"` // just in case the struct changes. + InitPK cipher.PubKey `json:"init_pk"` + RespPK cipher.PubKey `json:"resp_pk"` + Port uint16 `json:"port"` +} + func isInitiatorID(tpID uint16) bool { return tpID%2 == 0 } func randID(initiator bool) uint16 { @@ -76,6 +88,11 @@ const ( AckType = FrameType(0xb) ) +// Reasons for closing frames +const ( + PlaceholderReason = iota +) + // Frame is the dmsg data unit. type Frame []byte diff --git a/vendor/github.com/skycoin/dmsg/listener.go b/vendor/github.com/skycoin/dmsg/listener.go new file mode 100644 index 0000000000..f24de89405 --- /dev/null +++ b/vendor/github.com/skycoin/dmsg/listener.go @@ -0,0 +1,122 @@ +package dmsg + +import ( + "net" + "sync" + + "github.com/skycoin/dmsg/cipher" +) + +// Listener listens for remote-initiated transports. +type Listener struct { + pk cipher.PubKey + port uint16 + mx sync.Mutex // protects 'accept' + accept chan *Transport + done chan struct{} + once sync.Once +} + +func newListener(pk cipher.PubKey, port uint16) *Listener { + return &Listener{ + pk: pk, + port: port, + accept: make(chan *Transport, AcceptBufferSize), + done: make(chan struct{}), + } +} + +// Accept accepts a connection. +func (l *Listener) Accept() (net.Conn, error) { + return l.AcceptTransport() +} + +// Close closes the listener. +func (l *Listener) Close() error { + closed := false + l.once.Do(func() { + closed = true + l.close() + }) + if !closed { + return ErrClientClosed + } + return nil +} + +func (l *Listener) close() { + l.mx.Lock() + defer l.mx.Unlock() + close(l.done) + for { + select { + case <-l.accept: + default: + close(l.accept) + return + } + } +} + +func (l *Listener) isClosed() bool { + select { + case <-l.done: + return true + default: + return false + } +} + +// Addr returns the listener's address. +func (l *Listener) Addr() net.Addr { + return Addr{ + PK: l.pk, + Port: l.port, + } +} + +// AcceptTransport accepts a transport connection. +func (l *Listener) AcceptTransport() (*Transport, error) { + select { + case <-l.done: + return nil, ErrClientClosed + case tp, ok := <-l.accept: + if !ok { + return nil, ErrClientClosed + } + return tp, nil + } +} + +// Type returns the transport type. +func (l *Listener) Type() string { + return Type +} + +// IntroduceTransport handles a transport after receiving a REQUEST frame. 
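For illustration only (not part of this patch): a listener buffers up to AcceptBufferSize remote-initiated transports, and AcceptTransport returns ErrClientClosed once the listener is closed. A typical consumer loop might look like this sketch:

package example // illustrative sketch, not part of the diff

import "github.com/skycoin/dmsg"

// serveListener drains a dmsg.Listener until it is closed. AcceptTransport
// returns dmsg.ErrClientClosed once Close has been called, which ends the loop.
func serveListener(lis *dmsg.Listener, handle func(*dmsg.Transport)) {
	for {
		tp, err := lis.AcceptTransport()
		if err != nil { // dmsg.ErrClientClosed after lis.Close()
			return
		}
		go handle(tp)
	}
}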
+func (l *Listener) IntroduceTransport(tp *Transport) error { + l.mx.Lock() + defer l.mx.Unlock() + + if l.isClosed() { + return ErrClientClosed + } + + select { + case <-l.done: + return ErrClientClosed + + case l.accept <- tp: + if err := tp.WriteAccept(); err != nil { + return err + } + go tp.Serve() + return nil + + default: + if err := tp.Close(); err != nil { + log.WithError(err).Warn("Failed to close transport") + } + return ErrClientAcceptMaxed + } +} diff --git a/vendor/github.com/skycoin/dmsg/port_manager.go b/vendor/github.com/skycoin/dmsg/port_manager.go new file mode 100644 index 0000000000..63540c7017 --- /dev/null +++ b/vendor/github.com/skycoin/dmsg/port_manager.go @@ -0,0 +1,72 @@ +package dmsg + +import ( + "math/rand" + "sync" + "time" + + "github.com/skycoin/dmsg/cipher" +) + +const ( + firstEphemeralPort = 49152 + lastEphemeralPort = 65535 +) + +// PortManager manages ports of nodes. +type PortManager struct { + mu sync.RWMutex + rand *rand.Rand + listeners map[uint16]*Listener +} + +func newPortManager() *PortManager { + return &PortManager{ + rand: rand.New(rand.NewSource(time.Now().UnixNano())), + listeners: make(map[uint16]*Listener), + } +} + +// Listener returns a listener assigned to a given port. +func (pm *PortManager) Listener(port uint16) (*Listener, bool) { + pm.mu.RLock() + defer pm.mu.RUnlock() + + l, ok := pm.listeners[port] + return l, ok +} + +// NewListener assigns listener to port if port is available. +func (pm *PortManager) NewListener(pk cipher.PubKey, port uint16) (*Listener, bool) { + pm.mu.Lock() + defer pm.mu.Unlock() + if _, ok := pm.listeners[port]; ok { + return nil, false + } + l := newListener(pk, port) + pm.listeners[port] = l + return l, true +} + +// RemoveListener removes listener assigned to port. +func (pm *PortManager) RemoveListener(port uint16) { + pm.mu.Lock() + defer pm.mu.Unlock() + + delete(pm.listeners, port) +} + +// NextEmptyEphemeralPort returns next random ephemeral port. +// It has a value between firstEphemeralPort and lastEphemeralPort. 
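For illustration only (not part of this patch): PortManager.NewListener above refuses a port that already has a listener, which Client.Listen surfaces as a "port is busy" error. A small sketch of that behaviour through the public API:

package example // illustrative sketch, not part of the diff

import (
	"fmt"

	"github.com/skycoin/dmsg"
)

// listenTwice shows the port manager through the public API: the second
// Listen on the same port fails because the port already has a listener.
func listenTwice(c *dmsg.Client) {
	if _, err := c.Listen(80); err != nil {
		fmt.Println("first Listen failed:", err)
		return
	}
	if _, err := c.Listen(80); err != nil {
		fmt.Println("second Listen failed as expected:", err) // "port is busy"
	}
}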
+func (pm *PortManager) NextEmptyEphemeralPort() uint16 { + for { + port := pm.randomEphemeralPort() + if _, ok := pm.Listener(port); !ok { + return port + } + } +} + +func (pm *PortManager) randomEphemeralPort() uint16 { + return uint16(firstEphemeralPort + pm.rand.Intn(lastEphemeralPort-firstEphemeralPort)) +} diff --git a/vendor/github.com/skycoin/dmsg/server.go b/vendor/github.com/skycoin/dmsg/server.go index 4433b65cd8..ba0ee3dd37 100644 --- a/vendor/github.com/skycoin/dmsg/server.go +++ b/vendor/github.com/skycoin/dmsg/server.go @@ -143,6 +143,7 @@ func (c *ServerConn) Serve(ctx context.Context, getConn getConnFunc) (err error) log.WithError(err).Warn("Failed to close connection") } }() + log.WithField("connCount", incrementServeCount()).Infoln("ServingConn") err = c.writeOK() @@ -155,7 +156,7 @@ func (c *ServerConn) Serve(ctx context.Context, getConn getConnFunc) (err error) if err != nil { return fmt.Errorf("read failed: %s", err) } - log = log.WithField("received", f) + log := log.WithField("received", f) ft, id, p := f.Disassemble() @@ -200,7 +201,7 @@ func (c *ServerConn) Serve(ctx context.Context, getConn getConnFunc) (err error) func (c *ServerConn) delChan(id uint16, why byte) error { c.delNext(id) - if err := writeFrame(c.Conn, MakeFrame(CloseType, id, []byte{why})); err != nil { + if err := writeCloseFrame(c.Conn, id, why); err != nil { return fmt.Errorf("failed to write frame: %s", err) } return nil @@ -227,11 +228,11 @@ func (c *ServerConn) forwardFrame(ft FrameType, id uint16, p []byte) (*NextConn, // nolint:unparam func (c *ServerConn) handleRequest(ctx context.Context, getLink getConnFunc, id uint16, p []byte) (*NextConn, byte, bool) { - initPK, respPK, ok := splitPKs(p) - if !ok || initPK != c.PK() { + payload, err := unmarshalHandshakePayload(p) + if err != nil || payload.InitPK != c.PK() { return nil, 0, false } - respL, ok := getLink(respPK) + respL, ok := getLink(payload.RespPK) if !ok { return nil, 0, false } diff --git a/vendor/github.com/skycoin/dmsg/testing.go b/vendor/github.com/skycoin/dmsg/testing.go index ef9095b9f8..49a181b755 100644 --- a/vendor/github.com/skycoin/dmsg/testing.go +++ b/vendor/github.com/skycoin/dmsg/testing.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "io" + "net" "testing" "time" @@ -42,10 +43,12 @@ func checkConnCount(t *testing.T, delay time.Duration, count int, ccs ...connCou })) } -func checkTransportsClosed(t *testing.T, transports ...*Transport) { - for _, transport := range transports { - assert.False(t, isDoneChanOpen(transport.done)) - assert.False(t, isReadChanOpen(transport.inCh)) +func checkTransportsClosed(t *testing.T, transports ...net.Conn) { + for _, tr := range transports { + if tr, ok := tr.(*Transport); ok && tr != nil { + assert.False(t, isDoneChanOpen(tr.done)) + assert.False(t, isReadChanOpen(tr.inCh)) + } } } diff --git a/vendor/github.com/skycoin/dmsg/transport.go b/vendor/github.com/skycoin/dmsg/transport.go index 734983de93..e9133ce76f 100644 --- a/vendor/github.com/skycoin/dmsg/transport.go +++ b/vendor/github.com/skycoin/dmsg/transport.go @@ -19,16 +19,18 @@ var ( ErrRequestRejected = errors.New("failed to create transport: request rejected") ErrRequestCheckFailed = errors.New("failed to create transport: request check failed") ErrAcceptCheckFailed = errors.New("failed to create transport: accept check failed") + ErrPortNotListening = errors.New("failed to create transport: port not listening") ) -// Transport represents a connection from dmsg.Client to remote dmsg.Client (via dmsg.Server intermediary). 
+// Transport represents communication between two nodes via a single hop: +// a connection from dmsg.Client to remote dmsg.Client (via dmsg.Server intermediary). type Transport struct { net.Conn // underlying connection to dmsg.Server log *logging.Logger - id uint16 // tp ID that identifies this dmsg.Transport - local cipher.PubKey // local PK - remote cipher.PubKey // remote PK + id uint16 // tp ID that identifies this dmsg.transport + local Addr // local PK + remote Addr // remote PK inCh chan Frame // handles incoming frames (from dmsg.Client) inMx sync.Mutex // protects 'inCh' @@ -49,7 +51,7 @@ type Transport struct { } // NewTransport creates a new dms_tp. -func NewTransport(conn net.Conn, log *logging.Logger, local, remote cipher.PubKey, id uint16, doneFunc func(id uint16)) *Transport { +func NewTransport(conn net.Conn, log *logging.Logger, local, remote Addr, id uint16, doneFunc func(id uint16)) *Transport { tp := &Transport{ Conn: conn, log: log, @@ -113,7 +115,7 @@ func (tp *Transport) close() (closed bool) { // Close closes the dmsg_tp. func (tp *Transport) Close() error { if tp.close() { - if err := writeFrame(tp.Conn, MakeFrame(CloseType, tp.id, []byte{0})); err != nil { + if err := writeCloseFrame(tp.Conn, tp.id, PlaceholderReason); err != nil { log.WithError(err).Warn("Failed to write frame") } } @@ -132,14 +134,20 @@ func (tp *Transport) IsClosed() bool { // LocalPK returns the local public key of the transport. func (tp *Transport) LocalPK() cipher.PubKey { - return tp.local + return tp.local.PK } // RemotePK returns the remote public key of the transport. func (tp *Transport) RemotePK() cipher.PubKey { - return tp.remote + return tp.remote.PK } +// Local returns local address in from : +func (tp *Transport) LocalAddr() net.Addr { return tp.local } + +// Remote returns remote address in form : +func (tp *Transport) RemoteAddr() net.Addr { return tp.remote } + // Type returns the transport type. func (tp *Transport) Type() string { return Type @@ -162,8 +170,18 @@ func (tp *Transport) HandleFrame(f Frame) error { } // WriteRequest writes a REQUEST frame to dmsg_server to be forwarded to associated client. -func (tp *Transport) WriteRequest() error { - f := MakeFrame(RequestType, tp.id, combinePKs(tp.local, tp.remote)) +func (tp *Transport) WriteRequest(port uint16) error { + payload := HandshakePayload{ + Version: HandshakePayloadVersion, + InitPK: tp.local.PK, + RespPK: tp.remote.PK, + Port: port, + } + payloadBytes, err := marshalHandshakePayload(payload) + if err != nil { + return err + } + f := MakeFrame(RequestType, tp.id, payloadBytes) if err := writeFrame(tp.Conn, f); err != nil { tp.log.WithError(err).Error("HandshakeFailed") tp.close() @@ -182,7 +200,7 @@ func (tp *Transport) WriteAccept() (err error) { } }() - f := MakeFrame(AcceptType, tp.id, combinePKs(tp.remote, tp.local)) + f := MakeFrame(AcceptType, tp.id, combinePKs(tp.remote.PK, tp.local.PK)) if err = writeFrame(tp.Conn, f); err != nil { tp.close() return err @@ -225,7 +243,7 @@ func (tp *Transport) ReadAccept(ctx context.Context) (err error) { // - resp_pk should be of remote client. // - use an even number with the intermediary dmsg_server. 
initPK, respPK, ok := splitPKs(p) - if !ok || initPK != tp.local || respPK != tp.remote || !isInitiatorID(id) { + if !ok || initPK != tp.local.PK || respPK != tp.remote.PK || !isInitiatorID(id) { if err := tp.Close(); err != nil { log.WithError(err).Warn("Failed to close transport") } @@ -257,7 +275,7 @@ func (tp *Transport) Serve() { // also write CLOSE frame if this is the first time 'close' is triggered defer func() { if tp.close() { - if err := writeCloseFrame(tp.Conn, tp.id, 0); err != nil { + if err := writeCloseFrame(tp.Conn, tp.id, PlaceholderReason); err != nil { log.WithError(err).Warn("Failed to write close frame") } } diff --git a/vendor/modules.txt b/vendor/modules.txt index ed78a2a16d..ddacc28645 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -62,7 +62,7 @@ github.com/prometheus/procfs/internal/fs # github.com/sirupsen/logrus v1.4.2 github.com/sirupsen/logrus github.com/sirupsen/logrus/hooks/syslog -# github.com/skycoin/dmsg v0.0.0-20190805065636-70f4c32a994f +# github.com/skycoin/dmsg v0.0.0-20190805065636-70f4c32a994f => ../dmsg github.com/skycoin/dmsg/cipher github.com/skycoin/dmsg github.com/skycoin/dmsg/disc @@ -80,8 +80,8 @@ github.com/spf13/cobra # github.com/spf13/pflag v1.0.3 github.com/spf13/pflag # github.com/stretchr/testify v1.3.0 -github.com/stretchr/testify/require github.com/stretchr/testify/assert +github.com/stretchr/testify/require # go.etcd.io/bbolt v1.3.3 go.etcd.io/bbolt # golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4 @@ -94,8 +94,8 @@ golang.org/x/crypto/internal/chacha20 golang.org/x/crypto/internal/subtle golang.org/x/crypto/poly1305 # golang.org/x/net v0.0.0-20190724013045-ca1201d0de80 -golang.org/x/net/nettest golang.org/x/net/context +golang.org/x/net/nettest golang.org/x/net/proxy golang.org/x/net/internal/socks # golang.org/x/sys v0.0.0-20190804053845-51ab0e2deafa
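For illustration only (not part of this patch): a closing sketch of the Addr-based Transport from the transport.go hunks above. Per the TODO in ClientConn.DialTransport, the initiating side currently reports local port 0, while the remote address carries the dialed port; all names are illustrative.

package example // illustrative sketch, not part of the diff

import (
	"context"

	"github.com/skycoin/dmsg"
	"github.com/skycoin/dmsg/cipher"
)

// dialAndInspect dials a remote client and inspects both transport endpoints.
// Per the TODO in ClientConn.DialTransport, the initiating side currently
// reports local port 0; the remote address carries the dialed port.
func dialAndInspect(c *dmsg.Client, remotePK cipher.PubKey) (local, remote dmsg.Addr, err error) {
	tp, err := c.Dial(context.Background(), remotePK, 8080)
	if err != nil {
		return dmsg.Addr{}, dmsg.Addr{}, err
	}
	local = tp.LocalAddr().(dmsg.Addr)   // {PK: c's key, Port: 0}
	remote = tp.RemoteAddr().(dmsg.Addr) // {PK: remotePK, Port: 8080}
	return local, remote, nil
}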