From 8468f56a1525fb9576417a6c152016866b018883 Mon Sep 17 00:00:00 2001 From: Nikita Kryuchkov Date: Sat, 29 Jun 2019 22:36:05 +0300 Subject: [PATCH 01/13] Check if receiver is nil in close() methods --- internal/noise/net.go | 3 +++ internal/therealproxy/client.go | 3 +++ internal/therealproxy/server.go | 3 +++ internal/therealssh/auth.go | 3 +++ internal/therealssh/channel.go | 4 ++++ internal/therealssh/client.go | 8 ++++++++ internal/therealssh/server.go | 4 ++++ internal/therealssh/session.go | 3 +++ pkg/app/app.go | 11 +++++++---- pkg/app/conn.go | 4 ++++ pkg/app/protocol.go | 3 +++ pkg/dmsg/client.go | 13 ++++++++++++- pkg/dmsg/server.go | 4 ++++ pkg/dmsg/transport.go | 7 +++++++ pkg/messaging/channel.go | 7 +++++++ pkg/messaging/factory.go | 4 ++++ pkg/messaging/link.go | 3 +++ pkg/messaging/pool.go | 4 ++++ pkg/node/node.go | 3 +++ pkg/node/node_test.go | 3 +++ pkg/router/port_manager.go | 4 ++++ pkg/router/router.go | 4 ++++ pkg/routing/boltdb_routing_table.go | 3 +++ pkg/setup/node.go | 3 +++ pkg/setup/node_test.go | 4 ++++ pkg/transport/managed_transport.go | 4 ++++ pkg/transport/manager.go | 4 ++++ pkg/transport/mock.go | 6 ++++++ pkg/transport/tcp_transport.go | 3 +++ 29 files changed, 127 insertions(+), 5 deletions(-) diff --git a/internal/noise/net.go b/internal/noise/net.go index 26dc525e27..b19db510f2 100644 --- a/internal/noise/net.go +++ b/internal/noise/net.go @@ -78,6 +78,9 @@ func (d *RPCClientDialer) Run(srv *rpc.Server, retry time.Duration) error { // Close closes the handler. func (d *RPCClientDialer) Close() (err error) { + if d == nil { + return nil + } d.mu.Lock() if d.done != nil { close(d.done) diff --git a/internal/therealproxy/client.go b/internal/therealproxy/client.go index 5bbd5b25b0..397cc4408f 100644 --- a/internal/therealproxy/client.go +++ b/internal/therealproxy/client.go @@ -74,5 +74,8 @@ func (c *Client) ListenAndServe(addr string) error { // Close implement io.Closer. func (c *Client) Close() error { + if c == nil { + return nil + } return c.listener.Close() } diff --git a/internal/therealproxy/server.go b/internal/therealproxy/server.go index d117e8ff5b..4f0c3a7f39 100644 --- a/internal/therealproxy/server.go +++ b/internal/therealproxy/server.go @@ -55,6 +55,9 @@ func (s *Server) Serve(l net.Listener) error { // Close implement io.Closer. func (s *Server) Close() error { + if s == nil { + return nil + } return s.listener.Close() } diff --git a/internal/therealssh/auth.go b/internal/therealssh/auth.go index a8e25f50ec..41d9fbd3ad 100644 --- a/internal/therealssh/auth.go +++ b/internal/therealssh/auth.go @@ -62,6 +62,9 @@ func NewFileAuthorizer(authFile string) (*FileAuthorizer, error) { // Close releases underlying file pointer. func (auth *FileAuthorizer) Close() error { + if auth == nil { + return nil + } return auth.authFile.Close() } diff --git a/internal/therealssh/channel.go b/internal/therealssh/channel.go index c3fd9a1b6e..f8d310a14e 100644 --- a/internal/therealssh/channel.go +++ b/internal/therealssh/channel.go @@ -246,6 +246,10 @@ func (sshCh *SSHChannel) WindowChange(sz *pty.Winsize) error { // Close safely closes Channel resources. func (sshCh *SSHChannel) Close() error { + if sshCh == nil { + return nil + } + select { case <-sshCh.dataCh: default: diff --git a/internal/therealssh/client.go b/internal/therealssh/client.go index bec4631c09..30c1fca807 100644 --- a/internal/therealssh/client.go +++ b/internal/therealssh/client.go @@ -149,6 +149,10 @@ func (c *Client) serveConn(conn net.Conn) error { // Close closes all opened channels. func (c *Client) Close() error { + if c == nil { + return nil + } + for _, sshCh := range c.chans.dropAll() { sshCh.Close() } @@ -226,6 +230,10 @@ func (rpc *RPCClient) WindowChange(args *WindowChangeArgs, _ *int) error { // Close defines close client RPC request. func (rpc *RPCClient) Close(channelID *uint32, _ *struct{}) error { + if rpc == nil { + return nil + } + sshCh := rpc.c.chans.getChannel(*channelID) if sshCh == nil { return errors.New("unknown ssh channel") diff --git a/internal/therealssh/server.go b/internal/therealssh/server.go index c8c987fded..556c0a93e9 100644 --- a/internal/therealssh/server.go +++ b/internal/therealssh/server.go @@ -172,6 +172,10 @@ func (s *Server) Serve(conn net.Conn) error { // Close closes all opened channels. func (s *Server) Close() error { + if s == nil { + return nil + } + for _, channel := range s.chans.dropAll() { channel.Close() } diff --git a/internal/therealssh/session.go b/internal/therealssh/session.go index 555beeb0d7..eefc72649e 100644 --- a/internal/therealssh/session.go +++ b/internal/therealssh/session.go @@ -123,5 +123,8 @@ func (s *Session) Read(p []byte) (int, error) { // Close releases PTY resources. func (s *Session) Close() error { + if s == nil { + return nil + } return s.pty.Close() } diff --git a/pkg/app/app.go b/pkg/app/app.go index 6081052ff2..71b07924eb 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -91,6 +91,10 @@ func Setup(config *Config) (*App, error) { // Close implements io.Closer for an App. func (app *App) Close() error { + if app == nil { + return nil + } + select { case <-app.doneChan: // already closed default: @@ -217,10 +221,6 @@ func (app *App) closeConn(data []byte) error { delete(app.conns, *addr) app.mu.Unlock() - if conn == nil { - return nil - } - return conn.Close() } @@ -277,5 +277,8 @@ func (conn *appConn) Read(p []byte) (n int, err error) { } func (conn *appConn) Close() error { + if conn == nil { + return nil + } return conn.rw.Close() } diff --git a/pkg/app/conn.go b/pkg/app/conn.go index 4bc7927100..43eb4e5dfa 100644 --- a/pkg/app/conn.go +++ b/pkg/app/conn.go @@ -72,6 +72,10 @@ func (conn *PipeConn) Write(b []byte) (n int, err error) { // Close closes the connection. func (conn *PipeConn) Close() error { + if conn == nil { + return nil + } + inErr := conn.inFile.Close() outErr := conn.outFile.Close() if inErr != nil { diff --git a/pkg/app/protocol.go b/pkg/app/protocol.go index 2df8734935..dcdd20cdeb 100644 --- a/pkg/app/protocol.go +++ b/pkg/app/protocol.go @@ -126,6 +126,9 @@ func (p *Protocol) Serve(handleFunc func(Frame, []byte) (interface{}, error)) er // Close closes underlying ReadWriter. func (p *Protocol) Close() error { + if p == nil { + return nil + } p.chans.closeAll() return p.rw.Close() } diff --git a/pkg/dmsg/client.go b/pkg/dmsg/client.go index 0bceec2bb7..3e68aa1efd 100644 --- a/pkg/dmsg/client.go +++ b/pkg/dmsg/client.go @@ -257,14 +257,18 @@ func (c *ClientConn) DialTransport(ctx context.Context, clientPK cipher.PubKey) } func (c *ClientConn) close() (closed bool) { + if c == nil { + return false + } c.once.Do(func() { closed = true c.log.WithField("remoteServer", c.remoteSrv).Infoln("ClosingConnection") close(c.done) c.mx.Lock() for _, tp := range c.tps { + // Nil check is required here to keep 8192 running goroutines limit in tests with -race flag. if tp != nil { - go tp.Close() //nolint:errcheck + go tp.Close() // nolint:errcheck } } _ = c.Conn.Close() //nolint:errcheck @@ -275,6 +279,9 @@ func (c *ClientConn) close() (closed bool) { // Close closes the connection to dms_server. func (c *ClientConn) Close() error { + if c == nil { + return nil + } if c.close() { c.wg.Wait() } @@ -540,6 +547,10 @@ func (c *Client) Type() string { // Close closes the dms_client and associated connections. // TODO(evaninjin): proper error handling. func (c *Client) Close() error { + if c == nil { + return nil + } + c.once.Do(func() { close(c.done) for { diff --git a/pkg/dmsg/server.go b/pkg/dmsg/server.go index 07aa172b25..49f4e64db6 100644 --- a/pkg/dmsg/server.go +++ b/pkg/dmsg/server.go @@ -293,6 +293,10 @@ func (s *Server) connCount() int { // Close closes the dms_server. func (s *Server) Close() (err error) { + if s == nil { + return nil + } + if err = s.lis.Close(); err != nil { return err } diff --git a/pkg/dmsg/transport.go b/pkg/dmsg/transport.go index fdd99faabb..db741117fb 100644 --- a/pkg/dmsg/transport.go +++ b/pkg/dmsg/transport.go @@ -80,6 +80,10 @@ func (tp *Transport) serve() (started bool) { } func (tp *Transport) close() (closed bool) { + if tp == nil { + return false + } + tp.doneOnce.Do(func() { closed = true @@ -102,6 +106,9 @@ func (tp *Transport) close() (closed bool) { // Close closes the dmsg_tp. func (tp *Transport) Close() error { + if tp == nil { + return nil + } if tp.close() { _ = writeFrame(tp.Conn, MakeFrame(CloseType, tp.id, []byte{0})) //nolint:errcheck } diff --git a/pkg/messaging/channel.go b/pkg/messaging/channel.go index 02c298abec..3c0c003f6a 100644 --- a/pkg/messaging/channel.go +++ b/pkg/messaging/channel.go @@ -159,6 +159,10 @@ func (mCh *msgChannel) write(p []byte) (n int, err error) { } func (mCh *msgChannel) Close() error { + if mCh == nil { + return nil + } + select { case <-mCh.doneChan: return ErrChannelClosed @@ -188,6 +192,9 @@ func (mCh *msgChannel) OnChannelClosed() bool { } func (mCh *msgChannel) close() bool { + if mCh == nil { + return false + } closed := false mCh.doneOnce.Do(func() { close(mCh.doneChan) diff --git a/pkg/messaging/factory.go b/pkg/messaging/factory.go index b3c5204e47..10fc1dfb55 100644 --- a/pkg/messaging/factory.go +++ b/pkg/messaging/factory.go @@ -197,6 +197,10 @@ func (msgFactory *MsgFactory) Type() string { // Close closes underlying link pool. func (msgFactory *MsgFactory) Close() error { + if msgFactory == nil { + return nil + } + msgFactory.Logger.Info("Closing link pool") select { case <-msgFactory.doneCh: diff --git a/pkg/messaging/link.go b/pkg/messaging/link.go index f1b724117b..ab4c4ee354 100644 --- a/pkg/messaging/link.go +++ b/pkg/messaging/link.go @@ -94,6 +94,9 @@ func (c *Link) Open(wg *sync.WaitGroup) error { // Close closes the connection with the remote instance. func (c *Link) Close() error { + if c == nil { + return nil + } return c.rw.Close() } diff --git a/pkg/messaging/pool.go b/pkg/messaging/pool.go index 122706768a..208a3f1849 100644 --- a/pkg/messaging/pool.go +++ b/pkg/messaging/pool.go @@ -122,6 +122,10 @@ func (p *Pool) Respond(l net.Listener) error { // Close closes the Pool. func (p *Pool) Close() error { + if p == nil { + return nil + } + p.closeDoneChan() p.listenerMutex.Lock() if p.listener != nil { diff --git a/pkg/node/node.go b/pkg/node/node.go index 0d91b6bc57..942da44137 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -305,6 +305,9 @@ func (node *Node) stopUnhandledApp(name string, pid int) { // Close safely stops spawned Apps and messaging Node. func (node *Node) Close() (err error) { + if node == nil { + return nil + } if node.rpcListener != nil { if err = node.rpcListener.Close(); err != nil { node.logger.WithError(err).Error("failed to stop RPC interface") diff --git a/pkg/node/node_test.go b/pkg/node/node_test.go index de6a0867ea..a3ae8d573c 100644 --- a/pkg/node/node_test.go +++ b/pkg/node/node_test.go @@ -268,6 +268,9 @@ func (r *mockRouter) ServeApp(conn net.Conn, port uint16, appConf *app.Config) e } func (r *mockRouter) Close() error { + if r == nil { + return nil + } r.didClose = true r.Lock() if r.errChan != nil { diff --git a/pkg/router/port_manager.go b/pkg/router/port_manager.go index 718ba9c883..b5d2019687 100644 --- a/pkg/router/port_manager.go +++ b/pkg/router/port_manager.go @@ -62,6 +62,10 @@ func (pm *portManager) AppPorts(appConn *app.Protocol) []uint16 { } func (pm *portManager) Close(port uint16) []app.Addr { + if pm == nil { + return nil + } + b := pm.ports.remove(port) if b == nil { return nil diff --git a/pkg/router/router.go b/pkg/router/router.go index 2f1851fdb7..d54280e924 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -164,6 +164,10 @@ func (r *Router) ServeApp(conn net.Conn, port uint16, appConf *app.Config) error // Close safely stops Router. func (r *Router) Close() error { + if r == nil { + return nil + } + r.Logger.Info("Closing all App connections and Loops") r.expiryTicker.Stop() diff --git a/pkg/routing/boltdb_routing_table.go b/pkg/routing/boltdb_routing_table.go index e2394740bc..1e08a13bde 100644 --- a/pkg/routing/boltdb_routing_table.go +++ b/pkg/routing/boltdb_routing_table.go @@ -142,6 +142,9 @@ func (rt *boltDBRoutingTable) Count() (count int) { // Close closes underlying BoltDB instance. func (rt *boltDBRoutingTable) Close() error { + if rt == nil { + return nil + } return rt.db.Close() } diff --git a/pkg/setup/node.go b/pkg/setup/node.go index 56fe1ed2e4..0ff38c6e42 100644 --- a/pkg/setup/node.go +++ b/pkg/setup/node.go @@ -181,6 +181,9 @@ func (sn *Node) createRoute(expireAt time.Time, route routing.Route, rport, lpor // Close closes underlying transport manager. func (sn *Node) Close() error { + if sn == nil { + return nil + } return sn.tm.Close() } diff --git a/pkg/setup/node_test.go b/pkg/setup/node_test.go index 02ae11f489..b175671acf 100644 --- a/pkg/setup/node_test.go +++ b/pkg/setup/node_test.go @@ -263,6 +263,10 @@ func (f *muxFactory) Dial(ctx context.Context, remote cipher.PubKey) (transport. } func (f *muxFactory) Close() error { + if f == nil { + return nil + } + var err error for _, factory := range f.factories { if fErr := factory.Close(); err == nil && fErr != nil { diff --git a/pkg/transport/managed_transport.go b/pkg/transport/managed_transport.go index 0873f9b2d4..a861b537e6 100644 --- a/pkg/transport/managed_transport.go +++ b/pkg/transport/managed_transport.go @@ -78,6 +78,10 @@ func (tr *ManagedTransport) killWorker() { // Close closes underlying func (tr *ManagedTransport) Close() error { + if tr == nil { + return nil + } + tr.mu.RLock() err := tr.Transport.Close() tr.mu.RUnlock() diff --git a/pkg/transport/manager.go b/pkg/transport/manager.go index fc2b8517c1..4f2fc24637 100644 --- a/pkg/transport/manager.go +++ b/pkg/transport/manager.go @@ -226,6 +226,10 @@ func (tm *Manager) DeleteTransport(id uuid.UUID) error { // Close closes opened transports and registered factories. func (tm *Manager) Close() error { + if tm == nil { + return nil + } + close(tm.doneChan) tm.Logger.Info("Closing transport manager") diff --git a/pkg/transport/mock.go b/pkg/transport/mock.go index a260aaaae6..6077b4c226 100644 --- a/pkg/transport/mock.go +++ b/pkg/transport/mock.go @@ -71,6 +71,9 @@ func (f *MockFactory) Dial(ctx context.Context, remote cipher.PubKey) (Transport // Close closes notification channel between a pair of MockFactories. func (f *MockFactory) Close() error { + if f == nil { + return nil + } select { case <-f.inDone: default: @@ -125,6 +128,9 @@ func (m *MockTransport) Write(p []byte) (n int, err error) { // Close implements closer for mock transport func (m *MockTransport) Close() error { + if m == nil { + return nil + } return m.rw.Close() } diff --git a/pkg/transport/tcp_transport.go b/pkg/transport/tcp_transport.go index b2b117e301..5ebf1873fa 100644 --- a/pkg/transport/tcp_transport.go +++ b/pkg/transport/tcp_transport.go @@ -61,6 +61,9 @@ func (f *TCPFactory) Dial(ctx context.Context, remote cipher.PubKey) (Transport, // Close implements io.Closer func (f *TCPFactory) Close() error { + if f == nil { + return nil + } return f.l.Close() } From c6ffbc59634cd1f2abbb72f2db5467cf763bb354 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Sun, 30 Jun 2019 17:26:56 +0300 Subject: [PATCH 02/13] Protect write/close operations on `dataCh` with `dataChMx` --- internal/therealssh/channel.go | 5 +++++ internal/therealssh/client.go | 2 ++ internal/therealssh/server.go | 2 ++ 3 files changed, 9 insertions(+) diff --git a/internal/therealssh/channel.go b/internal/therealssh/channel.go index c3fd9a1b6e..47abdf4640 100644 --- a/internal/therealssh/channel.go +++ b/internal/therealssh/channel.go @@ -11,6 +11,7 @@ import ( "os/user" "path/filepath" "strings" + "sync" "github.com/kr/pty" @@ -35,6 +36,8 @@ type SSHChannel struct { session *Session listener *net.UnixListener + + dataChMx sync.Mutex dataCh chan []byte } @@ -249,7 +252,9 @@ func (sshCh *SSHChannel) Close() error { select { case <-sshCh.dataCh: default: + sshCh.dataChMx.Lock() close(sshCh.dataCh) + sshCh.dataChMx.Unlock() } close(sshCh.msgCh) diff --git a/internal/therealssh/client.go b/internal/therealssh/client.go index bec4631c09..bc0d9b6534 100644 --- a/internal/therealssh/client.go +++ b/internal/therealssh/client.go @@ -134,7 +134,9 @@ func (c *Client) serveConn(conn net.Conn) error { case CmdChannelOpenResponse, CmdChannelResponse: sshCh.msgCh <- data case CmdChannelData: + sshCh.dataChMx.Lock() sshCh.dataCh <- data + sshCh.dataChMx.Unlock() case CmdChannelServerClose: err = sshCh.Close() default: diff --git a/internal/therealssh/server.go b/internal/therealssh/server.go index c8c987fded..1f21985d36 100644 --- a/internal/therealssh/server.go +++ b/internal/therealssh/server.go @@ -125,7 +125,9 @@ func (s *Server) HandleData(remotePK cipher.PubKey, localID uint32, data []byte) return errors.New("session is not started") } + channel.dataChMx.Lock() channel.dataCh <- data + channel.dataChMx.Unlock() return nil } From 3c6b4f1e3ae7c96d34511fa400da271654ffd372 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Sun, 30 Jun 2019 17:35:26 +0300 Subject: [PATCH 03/13] Add `done` channel, correspoding `IsClosed` func `IsClosed` is now being called before writing to `dataCh` in order to prevent writing to closed channel --- internal/therealssh/channel.go | 17 ++++++++++++++++- internal/therealssh/client.go | 4 +++- internal/therealssh/server.go | 4 +++- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/internal/therealssh/channel.go b/internal/therealssh/channel.go index 47abdf4640..4eaa326931 100644 --- a/internal/therealssh/channel.go +++ b/internal/therealssh/channel.go @@ -39,11 +39,14 @@ type SSHChannel struct { dataChMx sync.Mutex dataCh chan []byte + + done chan struct{} } // OpenChannel constructs new SSHChannel with empty Session. func OpenChannel(remoteID uint32, remoteAddr *app.Addr, conn net.Conn) *SSHChannel { - return &SSHChannel{RemoteID: remoteID, conn: conn, RemoteAddr: remoteAddr, msgCh: make(chan []byte), dataCh: make(chan []byte)} + return &SSHChannel{RemoteID: remoteID, conn: conn, RemoteAddr: remoteAddr, msgCh: make(chan []byte), + dataCh: make(chan []byte), done: make(chan struct{})} } // OpenClientChannel constructs new client SSHChannel with empty Session. @@ -249,6 +252,8 @@ func (sshCh *SSHChannel) WindowChange(sz *pty.Winsize) error { // Close safely closes Channel resources. func (sshCh *SSHChannel) Close() error { + close(sshCh.done) + select { case <-sshCh.dataCh: default: @@ -278,6 +283,16 @@ func (sshCh *SSHChannel) Close() error { return nil } +// IsClosed returns whether the Channel is closed. +func (sshCh *SSHChannel) IsClosed() bool { + select { + case <-sshCh.done: + return true + default: + return false + } +} + func debug(format string, v ...interface{}) { if !Debug { return diff --git a/internal/therealssh/client.go b/internal/therealssh/client.go index bc0d9b6534..3f3e767870 100644 --- a/internal/therealssh/client.go +++ b/internal/therealssh/client.go @@ -135,7 +135,9 @@ func (c *Client) serveConn(conn net.Conn) error { sshCh.msgCh <- data case CmdChannelData: sshCh.dataChMx.Lock() - sshCh.dataCh <- data + if !sshCh.IsClosed() { + sshCh.dataCh <- data + } sshCh.dataChMx.Unlock() case CmdChannelServerClose: err = sshCh.Close() diff --git a/internal/therealssh/server.go b/internal/therealssh/server.go index 1f21985d36..3c548af1bf 100644 --- a/internal/therealssh/server.go +++ b/internal/therealssh/server.go @@ -126,7 +126,9 @@ func (s *Server) HandleData(remotePK cipher.PubKey, localID uint32, data []byte) } channel.dataChMx.Lock() - channel.dataCh <- data + if !channel.IsClosed() { + channel.dataCh <- data + } channel.dataChMx.Unlock() return nil } From 33cee7ea2c487812f8d8d94724517bc1911b60d3 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Sun, 30 Jun 2019 17:46:40 +0300 Subject: [PATCH 04/13] Protect `Close` flow with `sync.Once` --- internal/therealssh/channel.go | 66 +++++++++++++++++++++------------- 1 file changed, 42 insertions(+), 24 deletions(-) diff --git a/internal/therealssh/channel.go b/internal/therealssh/channel.go index 4eaa326931..b6439e8e5b 100644 --- a/internal/therealssh/channel.go +++ b/internal/therealssh/channel.go @@ -40,7 +40,8 @@ type SSHChannel struct { dataChMx sync.Mutex dataCh chan []byte - done chan struct{} + doneOnce sync.Once + done chan struct{} } // OpenChannel constructs new SSHChannel with empty Session. @@ -250,34 +251,51 @@ func (sshCh *SSHChannel) WindowChange(sz *pty.Winsize) error { return sshCh.session.WindowChange(sz) } -// Close safely closes Channel resources. -func (sshCh *SSHChannel) Close() error { - close(sshCh.done) +func (sshCh *SSHChannel) close() (closed bool, err error) { + sshCh.doneOnce.Do(func() { + closed = true - select { - case <-sshCh.dataCh: - default: - sshCh.dataChMx.Lock() - close(sshCh.dataCh) - sshCh.dataChMx.Unlock() - } - close(sshCh.msgCh) + close(sshCh.done) - var sErr, lErr error - if sshCh.session != nil { - sErr = sshCh.session.Close() - } + select { + case <-sshCh.dataCh: + default: + sshCh.dataChMx.Lock() + close(sshCh.dataCh) + sshCh.dataChMx.Unlock() + } + close(sshCh.msgCh) - if sshCh.listener != nil { - lErr = sshCh.listener.Close() - } + var sErr, lErr error + if sshCh.session != nil { + sErr = sshCh.session.Close() + } - if sErr != nil { - return sErr - } + if sshCh.listener != nil { + lErr = sshCh.listener.Close() + } - if lErr != nil { - return lErr + if sErr != nil { + err = sErr + return + } + + if lErr != nil { + err = lErr + } + }) + + return closed, err +} + +// Close safely closes Channel resources. +func (sshCh *SSHChannel) Close() error { + closed, err := sshCh.close() + if err != nil { + return err + } + if !closed { + return errors.New("channel is already closed") } return nil From 25cbf6de0aa933098e21727470632b90a58edef8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E5=BF=97=E5=AE=87?= Date: Wed, 3 Jul 2019 05:25:26 +0800 Subject: [PATCH 05/13] Fixed routing logic. * Fixed various race conditions in 'app' module. * Fix chance where app.appConn does not get closed. * Fixed various memory leaks in managedTransport (where channels are not closed). * Improved transport logging so that we don't need to write to disk on every network operation. --- cmd/apps/skychat/chat.go | 4 +- integration/test-messaging.sh | 2 +- pkg/app/app.go | 20 +++--- pkg/node/rpc.go | 2 +- pkg/node/rpc_test.go | 2 +- pkg/router/router.go | 3 +- pkg/router/router_test.go | 14 ++-- pkg/setup/node_test.go | 16 ++--- pkg/transport/log.go | 73 ++++++++++++++++++-- pkg/transport/log_test.go | 35 ++++++++-- pkg/transport/managed_transport.go | 96 ++++++++++++++------------- pkg/transport/manager.go | 103 +++++++++++++++++------------ pkg/transport/manager_test.go | 36 +++++----- 13 files changed, 262 insertions(+), 144 deletions(-) diff --git a/cmd/apps/skychat/chat.go b/cmd/apps/skychat/chat.go index 2999f76b77..ba80c6f18e 100644 --- a/cmd/apps/skychat/chat.go +++ b/cmd/apps/skychat/chat.go @@ -108,10 +108,10 @@ func messageHandler(w http.ResponseWriter, req *http.Request) { addr := &app.Addr{PubKey: pk, Port: 1} connsMu.Lock() - conn := chatConns[pk] + conn, ok := chatConns[pk] connsMu.Unlock() - if conn == nil { + if !ok { var err error err = r.Do(func() error { conn, err = chatApp.Dial(addr) diff --git a/integration/test-messaging.sh b/integration/test-messaging.sh index 3a8ed4cf49..47e5e1fa26 100755 --- a/integration/test-messaging.sh +++ b/integration/test-messaging.sh @@ -1,4 +1,4 @@ #!/usr/bin/env bash source ./integration/generic/env-vars.sh -# curl --data {'"recipient":"'$PK_A'", "message":"Hello Joe!"}' -X POST $CHAT_C +curl --data {'"recipient":"'$PK_A'", "message":"Hello Joe!"}' -X POST $CHAT_C curl --data {'"recipient":"'$PK_C'", "message":"Hello Mike!"}' -X POST $CHAT_A diff --git a/pkg/app/app.go b/pkg/app/app.go index 6081052ff2..e1b69232f9 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -166,6 +166,8 @@ func (app *App) handleProto() { } func (app *App) serveConn(addr *LoopAddr, conn io.ReadWriteCloser) { + defer conn.Close() + for { buf := make([]byte, 32*1024) n, err := conn.Read(buf) @@ -179,11 +181,10 @@ func (app *App) serveConn(addr *LoopAddr, conn io.ReadWriteCloser) { } } - if app.conns[*addr] != nil { + app.mu.Lock() + if _, ok := app.conns[*addr]; ok { app.proto.Send(FrameClose, &addr, nil) // nolint: errcheck } - - app.mu.Lock() delete(app.conns, *addr) app.mu.Unlock() } @@ -251,13 +252,12 @@ func (app *App) confirmLoop(data []byte) error { type appConn struct { net.Conn - rw io.ReadWriteCloser - laddr *Addr - raddr *Addr + laddr *Addr + raddr *Addr } func newAppConn(conn net.Conn, laddr, raddr *Addr) *appConn { - return &appConn{conn, conn, laddr, raddr} + return &appConn{conn, laddr, raddr} } func (conn *appConn) LocalAddr() net.Addr { @@ -269,13 +269,13 @@ func (conn *appConn) RemoteAddr() net.Addr { } func (conn *appConn) Write(p []byte) (n int, err error) { - return conn.rw.Write(p) + return conn.Conn.Write(p) } func (conn *appConn) Read(p []byte) (n int, err error) { - return conn.rw.Read(p) + return conn.Conn.Read(p) } func (conn *appConn) Close() error { - return conn.rw.Close() + return conn.Conn.Close() } diff --git a/pkg/node/rpc.go b/pkg/node/rpc.go index b7e8ace227..1465103f87 100644 --- a/pkg/node/rpc.go +++ b/pkg/node/rpc.go @@ -55,7 +55,7 @@ func newTransportSummary(tm *transport.Manager, tp *transport.ManagedTransport, } summary := &TransportSummary{ - ID: tp.ID, + ID: tp.Entry.ID, Local: tm.Local(), Remote: remote, Type: tp.Type(), diff --git a/pkg/node/rpc_test.go b/pkg/node/rpc_test.go index feb888bc58..b8ec54c11d 100644 --- a/pkg/node/rpc_test.go +++ b/pkg/node/rpc_test.go @@ -246,7 +246,7 @@ func TestRPC(t *testing.T) { t.Run("Transport", func(t *testing.T) { var ids []uuid.UUID node.tm.WalkTransports(func(tp *transport.ManagedTransport) bool { - ids = append(ids, tp.ID) + ids = append(ids, tp.Entry.ID) return true }) diff --git a/pkg/router/router.go b/pkg/router/router.go index a920d79f91..9f5e58cdaa 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -98,6 +98,7 @@ func (r *Router) Serve(ctx context.Context) error { } go func(tp transport.Transport) { + defer tp.Close() for { if err := serve(tp); err != nil { if err != io.EOF { @@ -423,7 +424,7 @@ func (r *Router) setupProto(ctx context.Context) (*setup.Protocol, transport.Tra // TODO(evanlinjin): need string constant for tp type. tr, err := r.tm.CreateTransport(ctx, r.config.SetupNodes[0], dmsg.Type, false) if err != nil { - return nil, nil, fmt.Errorf("transport: %s", err) + return nil, nil, fmt.Errorf("setup transport: %s", err) } sProto := setup.NewSetupProtocol(tr) diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index 508950ea1f..93f2207657 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -85,7 +85,7 @@ func TestRouterForwarding(t *testing.T) { tr3, err := m3.CreateTransport(context.TODO(), pk2, "mock2", true) require.NoError(t, err) - rule := routing.ForwardRule(time.Now().Add(time.Hour), 4, tr3.ID) + rule := routing.ForwardRule(time.Now().Add(time.Hour), 4, tr3.Entry.ID) routeID, err := rt.AddRule(rule) require.NoError(t, err) @@ -197,9 +197,9 @@ func TestRouterApp(t *testing.T) { ni1, ni2 := noiseInstances(t, pk1, pk2, sk1, sk2) raddr := &app.Addr{PubKey: pk2, Port: 5} - require.NoError(t, r.pm.SetLoop(6, raddr, &loop{tr.ID, 4, ni1})) + require.NoError(t, r.pm.SetLoop(6, raddr, &loop{tr.Entry.ID, 4, ni1})) - tr2 := m2.Transport(tr.ID) + tr2 := m2.Transport(tr.Entry.ID) go proto.Send(app.FrameSend, &app.Packet{Addr: &app.LoopAddr{Port: 6, Remote: *raddr}, Payload: []byte("bar")}, nil) // nolint: errcheck packet := make(routing.Packet, 29) @@ -333,13 +333,13 @@ func TestRouterSetup(t *testing.T) { var routeID routing.RouteID t.Run("add route", func(t *testing.T) { - routeID, err = setup.AddRule(sProto, routing.ForwardRule(time.Now().Add(time.Hour), 2, tr.ID)) + routeID, err = setup.AddRule(sProto, routing.ForwardRule(time.Now().Add(time.Hour), 2, tr.Entry.ID)) require.NoError(t, err) rule, err := rt.Rule(routeID) require.NoError(t, err) assert.Equal(t, routing.RouteID(2), rule.RouteID()) - assert.Equal(t, tr.ID, rule.TransportID()) + assert.Equal(t, tr.Entry.ID, rule.TransportID()) }) t.Run("`confirm loop - responder", func(t *testing.T) { @@ -371,7 +371,7 @@ func TestRouterSetup(t *testing.T) { loop, err := r.pm.GetLoop(2, &app.Addr{PubKey: pk2, Port: 1}) require.NoError(t, err) require.NotNil(t, loop) - assert.Equal(t, tr.ID, loop.trID) + assert.Equal(t, tr.Entry.ID, loop.trID) assert.Equal(t, routing.RouteID(2), loop.routeID) addrs := [2]*app.Addr{} @@ -427,7 +427,7 @@ func TestRouterSetup(t *testing.T) { l, err := r.pm.GetLoop(2, &app.Addr{PubKey: pk2, Port: 1}) require.NoError(t, err) require.NotNil(t, l) - assert.Equal(t, tr.ID, l.trID) + assert.Equal(t, tr.Entry.ID, l.trID) assert.Equal(t, routing.RouteID(2), l.routeID) addrs := [2]*app.Addr{} diff --git a/pkg/setup/node_test.go b/pkg/setup/node_test.go index 2c5534ea78..224ba4e0be 100644 --- a/pkg/setup/node_test.go +++ b/pkg/setup/node_test.go @@ -99,12 +99,12 @@ func TestCreateLoop(t *testing.T) { l := &routing.Loop{LocalPort: 1, RemotePort: 2, Expiry: time.Now().Add(time.Hour), Forward: routing.Route{ - &routing.Hop{From: pk1, To: pk2, Transport: tr1.ID}, - &routing.Hop{From: pk2, To: pk3, Transport: tr3.ID}, + &routing.Hop{From: pk1, To: pk2, Transport: tr1.Entry.ID}, + &routing.Hop{From: pk2, To: pk3, Transport: tr3.Entry.ID}, }, Reverse: routing.Route{ - &routing.Hop{From: pk3, To: pk2, Transport: tr3.ID}, - &routing.Hop{From: pk2, To: pk1, Transport: tr1.ID}, + &routing.Hop{From: pk3, To: pk2, Transport: tr3.Entry.ID}, + &routing.Hop{From: pk2, To: pk1, Transport: tr1.Entry.ID}, }, } @@ -132,25 +132,25 @@ func TestCreateLoop(t *testing.T) { assert.Equal(t, uint16(1), rule.LocalPort()) rule = rules[2] assert.Equal(t, routing.RuleForward, rule.Type()) - assert.Equal(t, tr1.ID, rule.TransportID()) + assert.Equal(t, tr1.Entry.ID, rule.TransportID()) assert.Equal(t, routing.RouteID(2), rule.RouteID()) rules = n2.getRules() require.Len(t, rules, 2) rule = rules[1] assert.Equal(t, routing.RuleForward, rule.Type()) - assert.Equal(t, tr1.ID, rule.TransportID()) + assert.Equal(t, tr1.Entry.ID, rule.TransportID()) assert.Equal(t, routing.RouteID(1), rule.RouteID()) rule = rules[2] assert.Equal(t, routing.RuleForward, rule.Type()) - assert.Equal(t, tr3.ID, rule.TransportID()) + assert.Equal(t, tr3.Entry.ID, rule.TransportID()) assert.Equal(t, routing.RouteID(2), rule.RouteID()) rules = n3.getRules() require.Len(t, rules, 2) rule = rules[1] assert.Equal(t, routing.RuleForward, rule.Type()) - assert.Equal(t, tr3.ID, rule.TransportID()) + assert.Equal(t, tr3.Entry.ID, rule.TransportID()) assert.Equal(t, routing.RouteID(1), rule.RouteID()) rule = rules[2] assert.Equal(t, routing.RuleApp, rule.Type()) diff --git a/pkg/transport/log.go b/pkg/transport/log.go index 8ae0b1bc32..699be3252f 100644 --- a/pkg/transport/log.go +++ b/pkg/transport/log.go @@ -2,6 +2,7 @@ package transport import ( "encoding/json" + "errors" "fmt" "math/big" "os" @@ -14,8 +15,69 @@ import ( // LogEntry represents a logging entry for a given Transport. // The entry is updated every time a packet is received or sent. type LogEntry struct { - ReceivedBytes *big.Int `json:"received"` // Total received bytes. - SentBytes *big.Int `json:"sent"` // Total sent bytes. + RecvBytes big.Int `json:"recv"` // Total received bytes. + SentBytes big.Int `json:"sent"` // Total sent bytes. + rMx sync.Mutex + sMx sync.Mutex +} + +// AddRecv records read. +func (le *LogEntry) AddRecv(n int64) { + le.rMx.Lock() + le.RecvBytes.Add(&le.RecvBytes, big.NewInt(n)) + le.rMx.Unlock() +} + +// AddSent records write. +func (le *LogEntry) AddSent(n int64) { + le.sMx.Lock() + le.SentBytes.Add(&le.SentBytes, big.NewInt(n)) + le.sMx.Unlock() +} + +// MarshalJSON implements json.Marshaller +func (le *LogEntry) MarshalJSON() ([]byte, error) { + le.rMx.Lock() + recv := le.RecvBytes.String() + le.rMx.Unlock() + + le.sMx.Lock() + sent := le.SentBytes.String() + le.sMx.Unlock() + + data := `{"recv":` + recv + `,"sent":` + sent + `}` + return []byte(data), nil +} + +// GobEncode implements gob.GobEncoder +func (le *LogEntry) GobEncode() ([]byte, error) { + le.rMx.Lock() + rb, err := le.RecvBytes.GobEncode() + le.rMx.Unlock() + if err != nil { + return nil, err + } + le.sMx.Lock() + sb, err := le.SentBytes.GobEncode() + le.sMx.Unlock() + if err != nil { + return nil, err + } + return append(rb, sb...), err +} + +// GobDecode implements gob.GobDecoder +func (le *LogEntry) GobDecode(b []byte) error { + le.rMx.Lock() + err := le.RecvBytes.GobDecode(b) + le.rMx.Unlock() + if err != nil { + return err + } + le.sMx.Lock() + err = le.SentBytes.GobDecode(b) + le.sMx.Unlock() + return err } // LogStore stores transport log entries. @@ -32,14 +94,17 @@ type inMemoryTransportLogStore struct { // InMemoryTransportLogStore implements in-memory TransportLogStore. func InMemoryTransportLogStore() LogStore { return &inMemoryTransportLogStore{ - entries: map[uuid.UUID]*LogEntry{}, + entries: make(map[uuid.UUID]*LogEntry), } } func (tls *inMemoryTransportLogStore) Entry(id uuid.UUID) (*LogEntry, error) { tls.mu.Lock() - entry := tls.entries[id] + entry, ok := tls.entries[id] tls.mu.Unlock() + if !ok { + return entry, errors.New("not found") + } return entry, nil } diff --git a/pkg/transport/log_test.go b/pkg/transport/log_test.go index b118f57deb..52e3a15035 100644 --- a/pkg/transport/log_test.go +++ b/pkg/transport/log_test.go @@ -1,8 +1,9 @@ package transport_test import ( + "encoding/json" + "fmt" "io/ioutil" - "math/big" "os" "testing" @@ -17,16 +18,21 @@ func testTransportLogStore(t *testing.T, logStore transport.LogStore) { t.Helper() id1 := uuid.New() - entry1 := &transport.LogEntry{big.NewInt(100), big.NewInt(200)} + entry1 := new(transport.LogEntry) + entry1.AddRecv(100) + entry1.AddSent(200) + id2 := uuid.New() - entry2 := &transport.LogEntry{big.NewInt(300), big.NewInt(400)} + entry2 := new(transport.LogEntry) + entry2.AddRecv(300) + entry2.AddSent(400) require.NoError(t, logStore.Record(id1, entry1)) require.NoError(t, logStore.Record(id2, entry2)) entry, err := logStore.Entry(id2) require.NoError(t, err) - assert.Equal(t, int64(300), entry.ReceivedBytes.Int64()) + assert.Equal(t, int64(300), entry.RecvBytes.Int64()) assert.Equal(t, int64(400), entry.SentBytes.Int64()) } @@ -43,3 +49,24 @@ func TestFileTransportLogStore(t *testing.T) { require.NoError(t, err) testTransportLogStore(t, ls) } + +func TestLogEntry_MarshalJSON(t *testing.T) { + entry := new(transport.LogEntry) + entry.AddSent(10) + entry.AddRecv(100) + b, err := json.Marshal(entry) + require.NoError(t, err) + fmt.Println(string(b)) + b, err = json.MarshalIndent(entry, "", "\t") + require.NoError(t, err) + fmt.Println(string(b)) +} + +func TestLogEntry_GobEncode(t *testing.T) { + var entry transport.LogEntry + + enc, err := entry.GobEncode() + require.NoError(t, err) + + require.NoError(t, entry.GobDecode(enc)) +} diff --git a/pkg/transport/managed_transport.go b/pkg/transport/managed_transport.go index 0873f9b2d4..f2fcc204de 100644 --- a/pkg/transport/managed_transport.go +++ b/pkg/transport/managed_transport.go @@ -1,55 +1,51 @@ package transport import ( - "math/big" + "context" "sync" - - "github.com/google/uuid" + "time" ) // ManagedTransport is a wrapper transport. It stores status and ID of // the Transport and can notify about network errors. type ManagedTransport struct { Transport - ID uuid.UUID - Public bool + Entry Entry Accepted bool + Setup bool LogEntry *LogEntry - doneChan chan struct{} - errChan chan error - mu sync.RWMutex - once sync.Once - - readLogChan chan int - writeLogChan chan int + done chan struct{} + update chan error + mu sync.RWMutex + once sync.Once } -func newManagedTransport(id uuid.UUID, tr Transport, public bool, accepted bool) *ManagedTransport { +func newManagedTransport(tr Transport, entry Entry, accepted bool) *ManagedTransport { return &ManagedTransport{ - ID: id, - Transport: tr, - Public: public, - Accepted: accepted, - doneChan: make(chan struct{}), - errChan: make(chan error), - readLogChan: make(chan int, 16), - writeLogChan: make(chan int, 16), - LogEntry: &LogEntry{new(big.Int), new(big.Int)}, + Transport: tr, + Entry: entry, + Accepted: accepted, + done: make(chan struct{}), + update: make(chan error, 16), + LogEntry: new(LogEntry), } } // Read reads using underlying func (tr *ManagedTransport) Read(p []byte) (n int, err error) { tr.mu.RLock() - n, err = tr.Transport.Read(p) // TODO: data race. - tr.mu.RUnlock() - - if err != nil { - tr.errChan <- err + n, err = tr.Transport.Read(p) + if n > 0 { + tr.LogEntry.AddRecv(int64(n)) } - - tr.readLogChan <- n + if !tr.isClosing() { + select { + case tr.update <- err: + default: + } + } + tr.mu.RUnlock() return } @@ -57,46 +53,54 @@ func (tr *ManagedTransport) Read(p []byte) (n int, err error) { func (tr *ManagedTransport) Write(p []byte) (n int, err error) { tr.mu.RLock() n, err = tr.Transport.Write(p) - tr.mu.RUnlock() - - if err != nil { - tr.errChan <- err - return + if n > 0 { + tr.LogEntry.AddSent(int64(n)) } - tr.writeLogChan <- n - + if !tr.isClosing() { + select { + case tr.update <- err: + default: + } + } + tr.mu.RUnlock() return } -// killWorker sends signal to Manager.manageTransport goroutine to exit -// it's safe to call it multiple times func (tr *ManagedTransport) killWorker() { tr.once.Do(func() { - close(tr.doneChan) + close(tr.done) }) } +func (tr *ManagedTransport) killUpdate() { + tr.mu.Lock() + close(tr.update) + tr.update = nil + tr.mu.Unlock() +} + // Close closes underlying func (tr *ManagedTransport) Close() error { - tr.mu.RLock() - err := tr.Transport.Close() - tr.mu.RUnlock() - + if tr == nil { + return nil + } tr.killWorker() - return err + return tr.Transport.Close() } func (tr *ManagedTransport) isClosing() bool { select { - case <-tr.doneChan: + case <-tr.done: return true default: return false } } -func (tr *ManagedTransport) updateTransport(newTr Transport) { +func (tr *ManagedTransport) updateTransport(ctx context.Context, newTr Transport, dc DiscoveryClient) error { tr.mu.Lock() tr.Transport = newTr + _, err := dc.UpdateStatuses(ctx, &Status{ID: tr.Entry.ID, IsUp: true, Updated: time.Now().UnixNano()}) tr.mu.Unlock() + return err } diff --git a/pkg/transport/manager.go b/pkg/transport/manager.go index 87f200732c..dd02191ffc 100644 --- a/pkg/transport/manager.go +++ b/pkg/transport/manager.go @@ -4,7 +4,6 @@ import ( "context" "errors" "io" - "math/big" "strings" "sync" "sync/atomic" @@ -210,8 +209,8 @@ func (tm *Manager) CreateTransport(ctx context.Context, remote cipher.PubKey, tp func (tm *Manager) DeleteTransport(id uuid.UUID) error { tm.mu.Lock() if tr, ok := tm.transports[id]; ok { - delete(tm.transports, id) _ = tr.Close() //nolint:errcheck + delete(tm.transports, id) } tm.mu.Unlock() @@ -231,10 +230,10 @@ func (tm *Manager) Close() error { tm.mu.Lock() statuses := make([]*Status, 0) for _, tr := range tm.transports { - if !tr.Public { + if !tr.Entry.Public { continue } - statuses = append(statuses, &Status{ID: tr.ID, IsUp: false}) + statuses = append(statuses, &Status{ID: tr.Entry.ID, IsUp: false}) tr.Close() } @@ -288,7 +287,7 @@ func (tm *Manager) createTransport(ctx context.Context, remote cipher.PubKey, tp } tm.Logger.Infof("Dialed to %s using %s factory. Transport ID: %s", remote, tpType, entry.ID) - mTr := newManagedTransport(entry.ID, tr, entry.Public, false) + mTr := newManagedTransport(tr, *entry, false) tm.mu.Lock() tm.transports[entry.ID] = mTr @@ -298,7 +297,7 @@ func (tm *Manager) createTransport(ctx context.Context, remote cipher.PubKey, tp case <-tm.doneChan: return nil, io.ErrClosedPipe case tm.TrChan <- mTr: - go tm.manageTransport(ctx, mTr, factory, remote, public, false) + go tm.manageTransport(ctx, mTr, factory, remote) return mTr, nil } } @@ -330,7 +329,8 @@ func (tm *Manager) acceptTransport(ctx context.Context, factory Factory) (*Manag if oldTr != nil { oldTr.killWorker() } - mTr := newManagedTransport(entry.ID, tr, entry.Public, true) + + mTr := newManagedTransport(tr, *entry, true) tm.mu.Lock() tm.transports[entry.ID] = mTr @@ -340,7 +340,7 @@ func (tm *Manager) acceptTransport(ctx context.Context, factory Factory) (*Manag case <-tm.doneChan: return nil, io.ErrClosedPipe case tm.TrChan <- mTr: - go tm.manageTransport(ctx, mTr, factory, remote, true, true) + go tm.manageTransport(ctx, mTr, factory, remote) return mTr, nil } } @@ -370,47 +370,68 @@ func (tm *Manager) isClosing() bool { } } -func (tm *Manager) manageTransport(ctx context.Context, mTr *ManagedTransport, factory Factory, remote cipher.PubKey, public bool, accepted bool) { +func (tm *Manager) manageTransport(ctx context.Context, mTr *ManagedTransport, factory Factory, remote cipher.PubKey) { + + logTicker := time.NewTicker(time.Second * 5) + logUpdate := false + mgrQty := atomic.AddInt32(&tm.mgrQty, 1) - tm.Logger.Infof("Spawned manageTransport for mTr.ID: %v. mgrQty: %v", mTr.ID, mgrQty) + tm.Logger.Infof("Spawned manageTransport for mTr.ID: %v. mgrQty: %v PK: %s", mTr.Entry.ID, mgrQty, remote) + + defer func() { + logTicker.Stop() + if logUpdate { + if err := tm.config.LogStore.Record(mTr.Entry.ID, mTr.LogEntry); err != nil { + tm.Logger.Warnf("Failed to record log entry: %s", err) + } + } + mTr.killUpdate() + + mgrQty := atomic.AddInt32(&tm.mgrQty, -1) + tm.Logger.Infof("manageTransport exit for %v. mgrQty: %v", mTr.Entry.ID, mgrQty) + }() + for { select { - case <-mTr.doneChan: - mgrQty := atomic.AddInt32(&tm.mgrQty, -1) - tm.Logger.Infof("manageTransport exit for %v. mgrQty: %v", mTr.ID, mgrQty) + case <-mTr.done: return - case err := <-mTr.errChan: - if !mTr.isClosing() { - tm.Logger.Infof("Transport %s failed with error: %s. Re-dialing...", mTr.ID, err) - if accepted { - if err := tm.DeleteTransport(mTr.ID); err != nil { - tm.Logger.Warnf("Failed to delete accepted transport: %s", err) - } - } else { - tr, _, err := tm.dialTransport(ctx, factory, remote, public) - if err != nil { - tm.Logger.Infof("Failed to redial Transport %s: %s", mTr.ID, err) - if err := tm.DeleteTransport(mTr.ID); err != nil { - tm.Logger.Warnf("Failed to delete redialed transport: %s", err) - } - } else { - tm.Logger.Infof("Updating transport %s", mTr.ID) - mTr.updateTransport(tr) - } + + case <-logTicker.C: + if logUpdate { + if err := tm.config.LogStore.Record(mTr.Entry.ID, mTr.LogEntry); err != nil { + tm.Logger.Warnf("Failed to record log entry: %s", err) } - } else { - tm.Logger.Infof("Transport %s is already closing. Skipped error: %s", mTr.ID, err) } - case n := <-mTr.readLogChan: - mTr.LogEntry.ReceivedBytes.Add(mTr.LogEntry.ReceivedBytes, big.NewInt(int64(n))) - if err := tm.config.LogStore.Record(mTr.ID, mTr.LogEntry); err != nil { - tm.Logger.Warnf("Failed to record log entry: %s", err) + + case err, ok := <-mTr.update: + if !ok { + return } - case n := <-mTr.writeLogChan: - mTr.LogEntry.SentBytes.Add(mTr.LogEntry.SentBytes, big.NewInt(int64(n))) - if err := tm.config.LogStore.Record(mTr.ID, mTr.LogEntry); err != nil { - tm.Logger.Warnf("Failed to record log entry: %s", err) + + if err == nil { + logUpdate = true + continue } + + tm.Logger.Infof("Transport %s failed with error: %s. Re-dialing...", mTr.Entry.ID, err) + if _, err := tm.config.DiscoveryClient.UpdateStatuses(ctx, &Status{ID: mTr.Entry.ID, IsUp: false, Updated: time.Now().UnixNano()}); err != nil { + tm.Logger.Warnf("Failed to change transport status: %s", err) + } + + // If we are the acceptor, we are not responsible for restarting transport. + // If the transport is private, we don't need to restart. + if mTr.Accepted || !mTr.Entry.Public { + return + } + + tr, _, err := tm.dialTransport(ctx, factory, remote, mTr.Entry.Public) + if err != nil { + tm.Logger.Infof("Failed to redial Transport %s: %s", mTr.Entry.ID, err) + continue + } + + tm.Logger.Infof("Updating transport %s", mTr.Entry.ID) + _ = mTr.updateTransport(ctx, tr, tm.config.DiscoveryClient) //nolint:errcheck } } } diff --git a/pkg/transport/manager_test.go b/pkg/transport/manager_test.go index 0994271e16..7d896cf811 100644 --- a/pkg/transport/manager_test.go +++ b/pkg/transport/manager_test.go @@ -87,16 +87,16 @@ func TestTransportManager(t *testing.T) { time.Sleep(time.Second) - tr1 := m1.Transport(tr2.ID) + tr1 := m1.Transport(tr2.Entry.ID) require.NotNil(t, tr1) - dEntry, err := client.GetTransportByID(context.TODO(), tr2.ID) + dEntry, err := client.GetTransportByID(context.TODO(), tr2.Entry.ID) require.NoError(t, err) assert.Equal(t, SortPubKeys(pk2, pk1), dEntry.Entry.Edges()) assert.True(t, dEntry.IsUp) - require.NoError(t, m1.DeleteTransport(tr1.ID)) - dEntry, err = client.GetTransportByID(context.TODO(), tr1.ID) + require.NoError(t, m1.DeleteTransport(tr1.Entry.ID)) + dEntry, err = client.GetTransportByID(context.TODO(), tr1.Entry.ID) require.NoError(t, err) assert.False(t, dEntry.IsUp) @@ -106,12 +106,12 @@ func TestTransportManager(t *testing.T) { time.Sleep(time.Second) - dEntry, err = client.GetTransportByID(context.TODO(), tr1.ID) + dEntry, err = client.GetTransportByID(context.TODO(), tr1.Entry.ID) require.NoError(t, err) assert.True(t, dEntry.IsUp) - require.NoError(t, m2.DeleteTransport(tr2.ID)) - dEntry, err = client.GetTransportByID(context.TODO(), tr2.ID) + require.NoError(t, m2.DeleteTransport(tr2.Entry.ID)) + dEntry, err = client.GetTransportByID(context.TODO(), tr2.Entry.ID) require.NoError(t, err) assert.False(t, dEntry.IsUp) @@ -153,17 +153,17 @@ func TestTransportManagerReEstablishTransports(t *testing.T) { tr2, err := m2.CreateTransport(context.TODO(), pk1, "mock", true) require.NoError(t, err) - tr1 := m1.Transport(tr2.ID) + tr1 := m1.Transport(tr2.Entry.ID) require.NotNil(t, tr1) - dEntry, err := client.GetTransportByID(context.TODO(), tr2.ID) + dEntry, err := client.GetTransportByID(context.TODO(), tr2.Entry.ID) require.NoError(t, err) assert.Equal(t, SortPubKeys(pk2, pk1), dEntry.Entry.Edges()) assert.True(t, dEntry.IsUp) require.NoError(t, m2.Close()) - dEntry2, err := client.GetTransportByID(context.TODO(), tr2.ID) + dEntry2, err := client.GetTransportByID(context.TODO(), tr2.Entry.ID) require.NoError(t, err) assert.False(t, dEntry2.IsUp) @@ -176,7 +176,7 @@ func TestTransportManagerReEstablishTransports(t *testing.T) { go func() { m2errCh <- m2.Serve(context.TODO()) }() //time.Sleep(time.Second * 1) // TODO: this time.Sleep looks fishy - figure out later - dEntry3, err := client.GetTransportByID(context.TODO(), tr2.ID) + dEntry3, err := client.GetTransportByID(context.TODO(), tr2.Entry.ID) require.NoError(t, err) assert.True(t, dEntry3.IsUp) @@ -218,7 +218,7 @@ func TestTransportManagerLogs(t *testing.T) { time.Sleep(100 * time.Millisecond) - tr1 := m1.Transport(tr2.ID) + tr1 := m1.Transport(tr2.Entry.ID) require.NotNil(t, tr1) go tr1.Write([]byte("foo")) // nolint @@ -226,17 +226,17 @@ func TestTransportManagerLogs(t *testing.T) { _, err = tr2.Read(buf) require.NoError(t, err) - time.Sleep(100 * time.Millisecond) + time.Sleep(time.Second * 10) - entry1, err := logStore1.Entry(tr1.ID) + entry1, err := logStore1.Entry(tr1.Entry.ID) require.NoError(t, err) assert.Equal(t, uint64(3), entry1.SentBytes.Uint64()) - assert.Equal(t, uint64(0), entry1.ReceivedBytes.Uint64()) + assert.Equal(t, uint64(0), entry1.RecvBytes.Uint64()) - entry2, err := logStore2.Entry(tr1.ID) + entry2, err := logStore2.Entry(tr1.Entry.ID) require.NoError(t, err) assert.Equal(t, uint64(0), entry2.SentBytes.Uint64()) - assert.Equal(t, uint64(3), entry2.ReceivedBytes.Uint64()) + assert.Equal(t, uint64(3), entry2.RecvBytes.Uint64()) require.NoError(t, m2.Close()) require.NoError(t, m1.Close()) @@ -314,7 +314,7 @@ func ExampleManager_CreateTransport() { return } - if (mtrAB.ID == uuid.UUID{}) { + if (mtrAB.Entry.ID == uuid.UUID{}) { fmt.Printf("Manager.CreateTransport failed on iteration %v", i) return } From 27998929c1723ecc885cde61f7a9b7a92d4297ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E5=BF=97=E5=AE=87?= Date: Wed, 3 Jul 2019 19:50:31 +0800 Subject: [PATCH 06/13] Rm redundant line from test-messaging.sh --- integration/test-messaging.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/integration/test-messaging.sh b/integration/test-messaging.sh index 47e5e1fa26..6a5add8b75 100755 --- a/integration/test-messaging.sh +++ b/integration/test-messaging.sh @@ -1,4 +1,3 @@ #!/usr/bin/env bash -source ./integration/generic/env-vars.sh curl --data {'"recipient":"'$PK_A'", "message":"Hello Joe!"}' -X POST $CHAT_C curl --data {'"recipient":"'$PK_C'", "message":"Hello Mike!"}' -X POST $CHAT_A From 69a6556c7cbcd2d4a6c5abef5bb54c2e2f1d6ece Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E5=BF=97=E5=AE=87?= Date: Wed, 3 Jul 2019 19:50:58 +0800 Subject: [PATCH 07/13] Format code. --- pkg/app/app.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/app/app.go b/pkg/app/app.go index e1b69232f9..1342d9c05e 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -252,8 +252,8 @@ func (app *App) confirmLoop(data []byte) error { type appConn struct { net.Conn - laddr *Addr - raddr *Addr + laddr *Addr + raddr *Addr } func newAppConn(conn net.Conn, laddr, raddr *Addr) *appConn { From 0983ce0ea77cf541cc6212afc87075336990bf80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E5=BF=97=E5=AE=87?= Date: Wed, 3 Jul 2019 21:14:37 +0800 Subject: [PATCH 08/13] Changed to use uint64 for transport.LogEntry. --- pkg/transport/log.go | 69 +++++++++++++----------------- pkg/transport/log_test.go | 4 +- pkg/transport/managed_transport.go | 4 +- pkg/transport/manager_test.go | 8 ++-- 4 files changed, 37 insertions(+), 48 deletions(-) diff --git a/pkg/transport/log.go b/pkg/transport/log.go index 699be3252f..b5a07182e3 100644 --- a/pkg/transport/log.go +++ b/pkg/transport/log.go @@ -1,13 +1,16 @@ package transport import ( + "bytes" + "encoding/gob" "encoding/json" "errors" "fmt" - "math/big" "os" "path/filepath" + "strconv" "sync" + "sync/atomic" "github.com/google/uuid" ) @@ -15,69 +18,55 @@ import ( // LogEntry represents a logging entry for a given Transport. // The entry is updated every time a packet is received or sent. type LogEntry struct { - RecvBytes big.Int `json:"recv"` // Total received bytes. - SentBytes big.Int `json:"sent"` // Total sent bytes. - rMx sync.Mutex - sMx sync.Mutex + RecvBytes uint64 `json:"recv"` // Total received bytes. + SentBytes uint64 `json:"sent"` // Total sent bytes. } // AddRecv records read. -func (le *LogEntry) AddRecv(n int64) { - le.rMx.Lock() - le.RecvBytes.Add(&le.RecvBytes, big.NewInt(n)) - le.rMx.Unlock() +func (le *LogEntry) AddRecv(n uint64) { + atomic.AddUint64(&le.RecvBytes, n) } // AddSent records write. -func (le *LogEntry) AddSent(n int64) { - le.sMx.Lock() - le.SentBytes.Add(&le.SentBytes, big.NewInt(n)) - le.sMx.Unlock() +func (le *LogEntry) AddSent(n uint64) { + atomic.AddUint64(&le.SentBytes, n) } // MarshalJSON implements json.Marshaller func (le *LogEntry) MarshalJSON() ([]byte, error) { - le.rMx.Lock() - recv := le.RecvBytes.String() - le.rMx.Unlock() - - le.sMx.Lock() - sent := le.SentBytes.String() - le.sMx.Unlock() - - data := `{"recv":` + recv + `,"sent":` + sent + `}` - return []byte(data), nil + rb := strconv.FormatUint(atomic.LoadUint64(&le.RecvBytes), 10) + sb := strconv.FormatUint(atomic.LoadUint64(&le.SentBytes), 10) + return []byte(`{"recv":` + rb + `,"sent":` + sb + `}`), nil } // GobEncode implements gob.GobEncoder func (le *LogEntry) GobEncode() ([]byte, error) { - le.rMx.Lock() - rb, err := le.RecvBytes.GobEncode() - le.rMx.Unlock() - if err != nil { + var b bytes.Buffer + enc := gob.NewEncoder(&b) + if err := enc.Encode(le.RecvBytes); err != nil { return nil, err } - le.sMx.Lock() - sb, err := le.SentBytes.GobEncode() - le.sMx.Unlock() - if err != nil { + if err := enc.Encode(le.SentBytes); err != nil { return nil, err } - return append(rb, sb...), err + return b.Bytes(), nil } // GobDecode implements gob.GobDecoder func (le *LogEntry) GobDecode(b []byte) error { - le.rMx.Lock() - err := le.RecvBytes.GobDecode(b) - le.rMx.Unlock() - if err != nil { + r := bytes.NewReader(b) + dec := gob.NewDecoder(r) + var rb uint64 + if err := dec.Decode(&rb); err != nil { + return err + } + var sb uint64 + if err := dec.Decode(&sb); err != nil { return err } - le.sMx.Lock() - err = le.SentBytes.GobDecode(b) - le.sMx.Unlock() - return err + atomic.StoreUint64(&le.RecvBytes, rb) + atomic.StoreUint64(&le.SentBytes, sb) + return nil } // LogStore stores transport log entries. diff --git a/pkg/transport/log_test.go b/pkg/transport/log_test.go index 52e3a15035..1c3f577728 100644 --- a/pkg/transport/log_test.go +++ b/pkg/transport/log_test.go @@ -32,8 +32,8 @@ func testTransportLogStore(t *testing.T, logStore transport.LogStore) { entry, err := logStore.Entry(id2) require.NoError(t, err) - assert.Equal(t, int64(300), entry.RecvBytes.Int64()) - assert.Equal(t, int64(400), entry.SentBytes.Int64()) + assert.Equal(t, uint64(300), entry.RecvBytes) + assert.Equal(t, uint64(400), entry.SentBytes) } func TestInMemoryTransportLogStore(t *testing.T) { diff --git a/pkg/transport/managed_transport.go b/pkg/transport/managed_transport.go index f2fcc204de..5ed1d673be 100644 --- a/pkg/transport/managed_transport.go +++ b/pkg/transport/managed_transport.go @@ -37,7 +37,7 @@ func (tr *ManagedTransport) Read(p []byte) (n int, err error) { tr.mu.RLock() n, err = tr.Transport.Read(p) if n > 0 { - tr.LogEntry.AddRecv(int64(n)) + tr.LogEntry.AddRecv(uint64(n)) } if !tr.isClosing() { select { @@ -54,7 +54,7 @@ func (tr *ManagedTransport) Write(p []byte) (n int, err error) { tr.mu.RLock() n, err = tr.Transport.Write(p) if n > 0 { - tr.LogEntry.AddSent(int64(n)) + tr.LogEntry.AddSent(uint64(n)) } if !tr.isClosing() { select { diff --git a/pkg/transport/manager_test.go b/pkg/transport/manager_test.go index 7d896cf811..c0c7f58360 100644 --- a/pkg/transport/manager_test.go +++ b/pkg/transport/manager_test.go @@ -230,13 +230,13 @@ func TestTransportManagerLogs(t *testing.T) { entry1, err := logStore1.Entry(tr1.Entry.ID) require.NoError(t, err) - assert.Equal(t, uint64(3), entry1.SentBytes.Uint64()) - assert.Equal(t, uint64(0), entry1.RecvBytes.Uint64()) + assert.Equal(t, uint64(3), entry1.SentBytes) + assert.Equal(t, uint64(0), entry1.RecvBytes) entry2, err := logStore2.Entry(tr1.Entry.ID) require.NoError(t, err) - assert.Equal(t, uint64(0), entry2.SentBytes.Uint64()) - assert.Equal(t, uint64(3), entry2.RecvBytes.Uint64()) + assert.Equal(t, uint64(0), entry2.SentBytes) + assert.Equal(t, uint64(3), entry2.RecvBytes) require.NoError(t, m2.Close()) require.NoError(t, m1.Close()) From 3b79fd236d6178aaa4abcd4382e1eba6d96e2da4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E5=BF=97=E5=AE=87?= Date: Wed, 3 Jul 2019 21:27:53 +0800 Subject: [PATCH 09/13] Added transport.logWriteInterval --- pkg/transport/managed_transport.go | 2 ++ pkg/transport/manager.go | 3 +-- pkg/transport/manager_test.go | 3 ++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pkg/transport/managed_transport.go b/pkg/transport/managed_transport.go index 5ed1d673be..cfab50ade7 100644 --- a/pkg/transport/managed_transport.go +++ b/pkg/transport/managed_transport.go @@ -6,6 +6,8 @@ import ( "time" ) +const logWriteInterval = time.Second * 3 + // ManagedTransport is a wrapper transport. It stores status and ID of // the Transport and can notify about network errors. type ManagedTransport struct { diff --git a/pkg/transport/manager.go b/pkg/transport/manager.go index dd02191ffc..ae643431cb 100644 --- a/pkg/transport/manager.go +++ b/pkg/transport/manager.go @@ -371,8 +371,7 @@ func (tm *Manager) isClosing() bool { } func (tm *Manager) manageTransport(ctx context.Context, mTr *ManagedTransport, factory Factory, remote cipher.PubKey) { - - logTicker := time.NewTicker(time.Second * 5) + logTicker := time.NewTicker(logWriteInterval) logUpdate := false mgrQty := atomic.AddInt32(&tm.mgrQty, 1) diff --git a/pkg/transport/manager_test.go b/pkg/transport/manager_test.go index c0c7f58360..4f0b30f729 100644 --- a/pkg/transport/manager_test.go +++ b/pkg/transport/manager_test.go @@ -226,7 +226,8 @@ func TestTransportManagerLogs(t *testing.T) { _, err = tr2.Read(buf) require.NoError(t, err) - time.Sleep(time.Second * 10) + // 2x log write interval just to be safe. + time.Sleep(logWriteInterval * 2) entry1, err := logStore1.Entry(tr1.Entry.ID) require.NoError(t, err) From ddd2593517905ca6ed7f3c13898183e385c660a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E5=BF=97=E5=AE=87?= Date: Wed, 3 Jul 2019 21:47:11 +0800 Subject: [PATCH 10/13] Fix comments. --- pkg/transport/log.go | 2 +- pkg/transport/managed_transport.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/transport/log.go b/pkg/transport/log.go index b5a07182e3..1983b0ab42 100644 --- a/pkg/transport/log.go +++ b/pkg/transport/log.go @@ -92,7 +92,7 @@ func (tls *inMemoryTransportLogStore) Entry(id uuid.UUID) (*LogEntry, error) { entry, ok := tls.entries[id] tls.mu.Unlock() if !ok { - return entry, errors.New("not found") + return entry, errors.New("transport log entry not found") } return entry, nil diff --git a/pkg/transport/managed_transport.go b/pkg/transport/managed_transport.go index cfab50ade7..66a675fa5e 100644 --- a/pkg/transport/managed_transport.go +++ b/pkg/transport/managed_transport.go @@ -34,7 +34,7 @@ func newManagedTransport(tr Transport, entry Entry, accepted bool) *ManagedTrans } } -// Read reads using underlying +// Read reads using underlying transport. func (tr *ManagedTransport) Read(p []byte) (n int, err error) { tr.mu.RLock() n, err = tr.Transport.Read(p) @@ -51,7 +51,7 @@ func (tr *ManagedTransport) Read(p []byte) (n int, err error) { return } -// Write writes to an underlying +// Write writes to an underlying transport. func (tr *ManagedTransport) Write(p []byte) (n int, err error) { tr.mu.RLock() n, err = tr.Transport.Write(p) @@ -81,7 +81,7 @@ func (tr *ManagedTransport) killUpdate() { tr.mu.Unlock() } -// Close closes underlying +// Close closes underlying transport and kills worker. func (tr *ManagedTransport) Close() error { if tr == nil { return nil From 51d2ab0bc49784f755bac52b9f2b98dec01c73aa Mon Sep 17 00:00:00 2001 From: Nikita Kryuchkov Date: Wed, 3 Jul 2019 16:59:11 +0300 Subject: [PATCH 11/13] Fix build after merging --- pkg/messaging/channel.go | 0 pkg/messaging/factory.go | 0 pkg/messaging/link.go | 0 pkg/messaging/pool.go | 0 4 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 pkg/messaging/channel.go delete mode 100644 pkg/messaging/factory.go delete mode 100644 pkg/messaging/link.go delete mode 100644 pkg/messaging/pool.go diff --git a/pkg/messaging/channel.go b/pkg/messaging/channel.go deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/pkg/messaging/factory.go b/pkg/messaging/factory.go deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/pkg/messaging/link.go b/pkg/messaging/link.go deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/pkg/messaging/pool.go b/pkg/messaging/pool.go deleted file mode 100644 index e69de29bb2..0000000000 From fb9507aef3c136c564a9516070bb002e52ac2e39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E5=BF=97=E5=AE=87?= Date: Wed, 3 Jul 2019 23:37:27 +0800 Subject: [PATCH 12/13] Removed duplicate modules. --- Makefile | 2 +- go.mod | 1 - internal/ioutil/ack_waiter.go | 101 ------- internal/ioutil/ack_waiter_test.go | 47 ---- internal/ioutil/atomic_bool.go | 23 -- internal/ioutil/buf_read.go | 16 -- internal/ioutil/len_read_writer.go | 57 ---- internal/ioutil/len_read_writer_test.go | 72 ----- internal/noise/dh.go | 37 --- internal/noise/net.go | 239 ----------------- internal/noise/net_test.go | 341 ------------------------ internal/noise/noise.go | 152 ----------- internal/noise/noise_test.go | 139 ---------- internal/noise/read_writer.go | 172 ------------ internal/noise/read_writer_test.go | 215 --------------- pkg/cipher/cipher.go | 255 ------------------ pkg/cipher/cipher_test.go | 100 ------- pkg/manager/node.go | 2 +- pkg/node/node.go | 2 +- pkg/router/loop_list.go | 2 +- pkg/router/router.go | 2 +- pkg/router/router_test.go | 2 +- vendor/modules.txt | 2 +- 23 files changed, 7 insertions(+), 1974 deletions(-) delete mode 100644 internal/ioutil/ack_waiter.go delete mode 100644 internal/ioutil/ack_waiter_test.go delete mode 100644 internal/ioutil/atomic_bool.go delete mode 100644 internal/ioutil/buf_read.go delete mode 100644 internal/ioutil/len_read_writer.go delete mode 100644 internal/ioutil/len_read_writer_test.go delete mode 100644 internal/noise/dh.go delete mode 100644 internal/noise/net.go delete mode 100644 internal/noise/net_test.go delete mode 100644 internal/noise/noise.go delete mode 100644 internal/noise/noise_test.go delete mode 100644 internal/noise/read_writer.go delete mode 100644 internal/noise/read_writer_test.go delete mode 100644 pkg/cipher/cipher.go delete mode 100644 pkg/cipher/cipher_test.go diff --git a/Makefile b/Makefile index 454c1e33d5..e5ad3fa044 100644 --- a/Makefile +++ b/Makefile @@ -64,7 +64,7 @@ test: ## Run tests ${OPTS} go test ${TEST_OPTS} ./internal/... #${OPTS} go test -race -tags no_ci -cover -timeout=5m ./pkg/... ${OPTS} go test ${TEST_OPTS} ./pkg/app/... - ${OPTS} go test ${TEST_OPTS} ./pkg/cipher/... +# ${OPTS} go test ${TEST_OPTS} ./pkg/cipher/... ${OPTS} go test ${TEST_OPTS} ./pkg/manager/... ${OPTS} go test ${TEST_OPTS} ./pkg/node/... ${OPTS} go test ${TEST_OPTS} ./pkg/route-finder/... diff --git a/go.mod b/go.mod index 2d0733a747..8372dd0a9d 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.12 require ( github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 - github.com/flynn/noise v0.0.0-20180327030543-2492fe189ae6 github.com/go-chi/chi v4.0.2+incompatible github.com/google/uuid v1.1.1 github.com/gorilla/handlers v1.4.0 diff --git a/internal/ioutil/ack_waiter.go b/internal/ioutil/ack_waiter.go deleted file mode 100644 index db68b8820b..0000000000 --- a/internal/ioutil/ack_waiter.go +++ /dev/null @@ -1,101 +0,0 @@ -package ioutil - -import ( - "context" - "crypto/rand" - "encoding/binary" - "io" - "math" - "sync" -) - -// Uint16Seq is part of the acknowledgement-waiting logic. -type Uint16Seq uint16 - -// DecodeUint16Seq decodes a slice to Uint16Seq. -func DecodeUint16Seq(b []byte) Uint16Seq { - if len(b) < 2 { - return 0 - } - return Uint16Seq(binary.BigEndian.Uint16(b[:2])) -} - -// Encode encodes the Uint16Seq to a 2-byte slice. -func (s Uint16Seq) Encode() []byte { - b := make([]byte, 2) - binary.BigEndian.PutUint16(b, uint16(s)) - return b -} - -// Uint16AckWaiter implements acknowledgement-waiting logic (with uint16 sequences). -type Uint16AckWaiter struct { - nextSeq Uint16Seq - waiters [math.MaxUint16 + 1]chan struct{} - mx sync.RWMutex -} - -// RandSeq should only be run once on startup. It is not thread-safe. -func (w *Uint16AckWaiter) RandSeq() error { - b := make([]byte, 2) - if _, err := rand.Read(b); err != nil { - return err - } - w.nextSeq = DecodeUint16Seq(b) - return nil -} - -func (w *Uint16AckWaiter) stopWaiter(seq Uint16Seq) { - if waiter := w.waiters[seq]; waiter != nil { - close(waiter) - w.waiters[seq] = nil - } -} - -// StopAll stops all active waiters. -func (w *Uint16AckWaiter) StopAll() { - w.mx.Lock() - for seq := range w.waiters { - w.stopWaiter(Uint16Seq(seq)) - } - w.mx.Unlock() -} - -// Wait performs the given action, and waits for given seq to be Done. -func (w *Uint16AckWaiter) Wait(ctx context.Context, action func(seq Uint16Seq) error) (err error) { - ackCh := make(chan struct{}, 1) - - w.mx.Lock() - seq := w.nextSeq - w.nextSeq++ - w.waiters[seq] = ackCh - w.mx.Unlock() - - if err = action(seq); err != nil { - return err - } - - select { - case _, ok := <-ackCh: - if !ok { - // waiter stopped manually. - err = io.ErrClosedPipe - } - case <-ctx.Done(): - err = ctx.Err() - } - - w.mx.Lock() - w.stopWaiter(seq) - w.mx.Unlock() - return err -} - -// Done finishes given sequence. -func (w *Uint16AckWaiter) Done(seq Uint16Seq) { - w.mx.RLock() - select { - case w.waiters[seq] <- struct{}{}: - default: - } - w.mx.RUnlock() -} diff --git a/internal/ioutil/ack_waiter_test.go b/internal/ioutil/ack_waiter_test.go deleted file mode 100644 index 3ba22c5910..0000000000 --- a/internal/ioutil/ack_waiter_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package ioutil_test - -import ( - "context" - "sync" - "testing" - - "github.com/skycoin/skywire/internal/ioutil" -) - -func TestUint16AckWaiter_Wait(t *testing.T) { - - // Ensure that no race conditions occurs when - // each concurrent call to 'Uint16AckWaiter.Wait()' is met with - // multiple concurrent calls to 'Uint16AckWaiter.Done()' with the same seq. - t.Run("ensure_no_race_conditions", func(*testing.T) { - w := new(ioutil.Uint16AckWaiter) - defer w.StopAll() - - seqChan := make(chan ioutil.Uint16Seq) - defer close(seqChan) - - wg := new(sync.WaitGroup) - - for i := 0; i < 64; i++ { - wg.Add(1) - go func() { - defer wg.Done() - _ = w.Wait(context.TODO(), func(seq ioutil.Uint16Seq) error { //nolint:errcheck,unparam - seqChan <- seq - return nil - }) - }() - - seq := <-seqChan - for j := 0; j <= i; j++ { - wg.Add(1) - go func() { - defer wg.Done() - w.Done(seq) - }() - } - } - - wg.Wait() - }) -} diff --git a/internal/ioutil/atomic_bool.go b/internal/ioutil/atomic_bool.go deleted file mode 100644 index dab1f0b472..0000000000 --- a/internal/ioutil/atomic_bool.go +++ /dev/null @@ -1,23 +0,0 @@ -package ioutil - -import "sync/atomic" - -// AtomicBool implements a thread-safe boolean value. -type AtomicBool struct { - flag int32 -} - -// Set set's the boolean to specified value -// and returns true if the value is changed. -func (b *AtomicBool) Set(v bool) bool { - newF := int32(0) - if v { - newF = 1 - } - return atomic.CompareAndSwapInt32(&b.flag, b.flag, newF) -} - -// Get obtains the current boolean value. -func (b *AtomicBool) Get() bool { - return atomic.LoadInt32(&b.flag) == 1 -} diff --git a/internal/ioutil/buf_read.go b/internal/ioutil/buf_read.go deleted file mode 100644 index b6310fd7ff..0000000000 --- a/internal/ioutil/buf_read.go +++ /dev/null @@ -1,16 +0,0 @@ -package ioutil - -import ( - "bytes" -) - -// BufRead is designed to help writing 'io.Reader' implementations. -// It reads from 'data' into 'p'. If 'p' is short, write to 'buf'. -// Note that one should check if 'buf' has data and read from that first before calling this function. -func BufRead(buf *bytes.Buffer, data, p []byte) (int, error) { - n := copy(p, data) - if n < len(data) { - buf.Write(data[n:]) - } - return n, nil -} diff --git a/internal/ioutil/len_read_writer.go b/internal/ioutil/len_read_writer.go deleted file mode 100644 index dffaf48b72..0000000000 --- a/internal/ioutil/len_read_writer.go +++ /dev/null @@ -1,57 +0,0 @@ -package ioutil - -import ( - "bytes" - "encoding/binary" - "io" - "sync" -) - -// LenReadWriter writes len prepended packets and always reads whole -// packet. If read buffer is smaller than packet, LenReadWriter will -// buffer unread part and will return it first in subsequent reads. -type LenReadWriter struct { - io.ReadWriter - buf bytes.Buffer - mx sync.Mutex -} - -// NewLenReadWriter constructs a new LenReadWriter. -func NewLenReadWriter(rw io.ReadWriter) *LenReadWriter { - return &LenReadWriter{ReadWriter: rw} -} - -// ReadPacket returns single received len prepended packet. -func (rw *LenReadWriter) ReadPacket() ([]byte, error) { - h := make([]byte, 2) - if _, err := io.ReadFull(rw.ReadWriter, h); err != nil { - return nil, err - } - data := make([]byte, binary.BigEndian.Uint16(h)) - _, err := io.ReadFull(rw.ReadWriter, data) - return data, err -} - -func (rw *LenReadWriter) Read(p []byte) (n int, err error) { - rw.mx.Lock() - defer rw.mx.Unlock() - - if rw.buf.Len() != 0 { - return rw.buf.Read(p) - } - - var data []byte - data, err = rw.ReadPacket() - if err != nil { - return - } - - return BufRead(&rw.buf, data, p) -} - -func (rw *LenReadWriter) Write(p []byte) (n int, err error) { - buf := make([]byte, 2) - binary.BigEndian.PutUint16(buf, uint16(len(p))) - n, err = rw.ReadWriter.Write(append(buf, p...)) - return n - 2, err -} diff --git a/internal/ioutil/len_read_writer_test.go b/internal/ioutil/len_read_writer_test.go deleted file mode 100644 index 0b55b1b7aa..0000000000 --- a/internal/ioutil/len_read_writer_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package ioutil - -import ( - "log" - "net" - "os" - "testing" - - "github.com/skycoin/skycoin/src/util/logging" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestMain(m *testing.M) { - loggingLevel, ok := os.LookupEnv("TEST_LOGGING_LEVEL") - if ok { - lvl, err := logging.LevelFromString(loggingLevel) - if err != nil { - log.Fatal(err) - } - logging.SetLevel(lvl) - } else { - logging.Disable() - } - - os.Exit(m.Run()) -} - -func TestLenReadWriter(t *testing.T) { - in, out := net.Pipe() - rwIn := NewLenReadWriter(in) - rwOut := NewLenReadWriter(out) - - errCh := make(chan error) - go func() { - _, err := rwIn.Write([]byte("foo")) - errCh <- err - }() - - buf := make([]byte, 2) - n, err := rwOut.Read(buf) - require.NoError(t, err) - require.NoError(t, <-errCh) - assert.Equal(t, 2, n) - assert.Equal(t, []byte("fo"), buf) - - buf = make([]byte, 2) - n, err = rwOut.Read(buf) - require.NoError(t, err) - assert.Equal(t, 1, n) - assert.Equal(t, []byte("o"), buf[:n]) - - go func() { - _, err := rwIn.Write([]byte("foo")) - errCh <- err - }() - - packet, err := rwOut.ReadPacket() - require.NoError(t, err) - require.NoError(t, <-errCh) - assert.Equal(t, []byte("foo"), packet) - - go func() { - _, err := rwOut.ReadPacket() - errCh <- err - }() - - n, err = rwIn.Write([]byte("bar")) - require.NoError(t, err) - require.NoError(t, <-errCh) - assert.Equal(t, 3, n) -} diff --git a/internal/noise/dh.go b/internal/noise/dh.go deleted file mode 100644 index 6da7ab6931..0000000000 --- a/internal/noise/dh.go +++ /dev/null @@ -1,37 +0,0 @@ -package noise - -import ( - "io" - - "github.com/flynn/noise" - "github.com/skycoin/skycoin/src/cipher" -) - -// Secp256k1 implements `noise.DHFunc`. -type Secp256k1 struct{} - -// GenerateKeypair helps to implement `noise.DHFunc`. -func (Secp256k1) GenerateKeypair(_ io.Reader) (noise.DHKey, error) { - pk, sk := cipher.GenerateKeyPair() - return noise.DHKey{ - Private: sk[:], - Public: pk[:], - }, nil -} - -// DH helps to implement `noise.DHFunc`. -func (Secp256k1) DH(sk, pk []byte) []byte { - return append( - cipher.MustECDH(cipher.MustNewPubKey(pk), cipher.MustNewSecKey(sk)), - byte(0)) -} - -// DHLen helps to implement `noise.DHFunc`. -func (Secp256k1) DHLen() int { - return 33 -} - -// DHName helps to implement `noise.DHFunc`. -func (Secp256k1) DHName() string { - return "Secp256k1" -} diff --git a/internal/noise/net.go b/internal/noise/net.go deleted file mode 100644 index 45f288e1e9..0000000000 --- a/internal/noise/net.go +++ /dev/null @@ -1,239 +0,0 @@ -package noise - -import ( - "errors" - "io" - "math" - "net" - "net/rpc" - "sync" - "time" - - "github.com/flynn/noise" - "github.com/skycoin/dmsg/cipher" -) - -var ( - // ErrAlreadyServing is returned when an operation fails due to an operation - // that is currently running. - ErrAlreadyServing = errors.New("already serving") - - // ErrPacketTooBig occurs when data is too large. - ErrPacketTooBig = errors.New("data too large to contain within a packet") - - // HandshakeXK is the XK handshake pattern. - HandshakeXK = noise.HandshakeXK - - // HandshakeKK is the KK handshake pattern. - HandshakeKK = noise.HandshakeKK - - // AcceptHandshakeTimeout determines how long a noise hs should take. - AcceptHandshakeTimeout = time.Second * 10 -) - -// RPCClientDialer attempts to redial to a remotely served RPCClient. -// It exposes an RPCServer to the remote server. -// The connection is encrypted via noise. -type RPCClientDialer struct { - config Config - pattern noise.HandshakePattern - addr string - conn net.Conn - mu sync.Mutex - done chan struct{} // nil: loop is not running, non-nil: loop is running. -} - -// NewRPCClientDialer creates a new RPCClientDialer. -func NewRPCClientDialer(addr string, pattern noise.HandshakePattern, config Config) *RPCClientDialer { - return &RPCClientDialer{config: config, pattern: pattern, addr: addr} -} - -// Run repeatedly dials to remote until a successful connection is established. -// It exposes a RPC Server. -// It will return if Close is called or crypto fails. -func (d *RPCClientDialer) Run(srv *rpc.Server, retry time.Duration) error { - if ok := d.setDone(); !ok { - return ErrAlreadyServing - } - for { - if err := d.establishConn(); err != nil { - // Only return if not network error. - if _, ok := err.(net.Error); !ok { - return err - } - } else { - // Only serve when then dial succeeds. - srv.ServeConn(d.conn) - d.setConn(nil) - } - select { - case <-d.done: - d.clearDone() - return nil - case <-time.After(retry): - } - } -} - -// Close closes the handler. -func (d *RPCClientDialer) Close() (err error) { - d.mu.Lock() - if d.done != nil { - close(d.done) - } - if d.conn != nil { - err = d.conn.Close() - } - d.mu.Unlock() - return -} - -// This operation should be atomic, hence protected by mutex. -func (d *RPCClientDialer) establishConn() error { - d.mu.Lock() - defer d.mu.Unlock() - - conn, err := net.Dial("tcp", d.addr) - if err != nil { - return err - } - ns, err := New(d.pattern, d.config) - if err != nil { - return err - } - conn, err = WrapConn(conn, ns, time.Second*5) - if err != nil { - return err - } - d.conn = conn - return nil -} - -func (d *RPCClientDialer) setConn(conn net.Conn) { - d.mu.Lock() - d.conn = conn - d.mu.Unlock() -} - -func (d *RPCClientDialer) setDone() (ok bool) { - d.mu.Lock() - if ok = d.done == nil; ok { - d.done = make(chan struct{}) - } - d.mu.Unlock() - return -} - -func (d *RPCClientDialer) clearDone() { - d.mu.Lock() - d.done = nil - d.mu.Unlock() -} - -// Addr is the address of a either an AppNode or ManagerNode. -type Addr struct { - PK cipher.PubKey - Addr net.Addr -} - -// Network returns the network type. -func (a Addr) Network() string { - return "noise" -} - -// String implements fmt.Stringer -func (a Addr) String() string { - return a.Addr.String() + "(" + a.PK.Hex() + ")" -} - -// Conn wraps a net.Conn and encrypts the connection with noise. -type Conn struct { - net.Conn - ns *ReadWriter -} - -// WrapConn wraps a provided net.Conn with noise. -func WrapConn(conn net.Conn, ns *Noise, hsTimeout time.Duration) (*Conn, error) { - rw := NewReadWriter(conn, ns) - if err := rw.Handshake(hsTimeout); err != nil { - return nil, err - } - return &Conn{Conn: conn, ns: rw}, nil -} - -// Read reads from the noise-encrypted connection. -func (c *Conn) Read(b []byte) (int, error) { - return c.ns.Read(b) -} - -// Write writes to the noise-encrypted connection. -func (c *Conn) Write(b []byte) (int, error) { - if len(b) > math.MaxUint16 { - return 0, io.ErrShortWrite - } - return c.ns.Write(b) -} - -// LocalAddr returns the local address of the connection. -func (c *Conn) LocalAddr() net.Addr { - return &Addr{ - PK: c.ns.LocalStatic(), - Addr: c.Conn.LocalAddr(), - } -} - -// RemoteAddr returns the remote address of the connection. -func (c *Conn) RemoteAddr() net.Addr { - return &Addr{ - PK: c.ns.RemoteStatic(), - Addr: c.Conn.RemoteAddr(), - } -} - -// Listener accepts incoming connections and encrypts with noise. -type Listener struct { - net.Listener - pk cipher.PubKey - sk cipher.SecKey - init bool - pattern noise.HandshakePattern -} - -// WrapListener wraps a listener and encrypts incoming connections with noise. -func WrapListener(lis net.Listener, pk cipher.PubKey, sk cipher.SecKey, init bool, pattern noise.HandshakePattern) *Listener { - return &Listener{Listener: lis, pk: pk, sk: sk, init: init, pattern: pattern} -} - -// Accept calls Accept from the underlying net.Listener and encrypts the -// obtained net.Conn with noise. -func (ml *Listener) Accept() (net.Conn, error) { - for { - conn, err := ml.Listener.Accept() - if err != nil { - return nil, err - } - ns, err := New(ml.pattern, Config{ - LocalPK: ml.pk, - LocalSK: ml.sk, - Initiator: ml.init, - }) - if err != nil { - return nil, err - } - rw := NewReadWriter(conn, ns) - if err := rw.Handshake(AcceptHandshakeTimeout); err != nil { - noiseLogger.WithError(err).Warn("accept: noise handshake failed.") - continue - } - noiseLogger.Infoln("accepted:", rw.RemoteStatic()) - return &Conn{Conn: conn, ns: rw}, nil - } -} - -// Addr returns the local address of the noise-encrypted Listener. -func (ml *Listener) Addr() net.Addr { - return &Addr{ - PK: ml.pk, - Addr: ml.Listener.Addr(), - } -} diff --git a/internal/noise/net_test.go b/internal/noise/net_test.go deleted file mode 100644 index ae9130ead3..0000000000 --- a/internal/noise/net_test.go +++ /dev/null @@ -1,341 +0,0 @@ -package noise - -import ( - "fmt" - "io" - "log" - "net" - "net/rpc" - "sync" - "testing" - "time" - - "github.com/flynn/noise" - "github.com/skycoin/dmsg/cipher" - "github.com/stretchr/testify/require" -) - -type TestRPC struct{} - -type AddIn struct{ A, B int } - -func (r *TestRPC) Add(in *AddIn, out *int) error { - *out = in.A + in.B - return nil -} - -func TestRPCClientDialer(t *testing.T) { - var ( - pattern = HandshakeXK - ) - - svr := rpc.NewServer() - require.NoError(t, svr.Register(new(TestRPC))) - - lPK, lSK := cipher.GenerateKeyPair() - var l net.Listener - var lAddr string - - setup := func() { - if len(lAddr) == 0 { - lAddr = ":0" - } - var err error - - l, err = net.Listen("tcp", lAddr) - require.NoError(t, err) - - l = WrapListener(l, lPK, lSK, false, pattern) - lAddr = l.Addr().(*Addr).Addr.String() - t.Logf("Listening on %s", lAddr) - } - - teardown := func() { - if l != nil { - require.NoError(t, l.Close()) - l = nil - } - } - - t.Run("RunRetry", func(t *testing.T) { - setup() - defer teardown() // Just in case of failure. - - const reconCount = 5 - const retry = time.Second / 4 - - dPK, dSK := cipher.GenerateKeyPair() - d := NewRPCClientDialer(lAddr, pattern, Config{ - LocalPK: dPK, - LocalSK: dSK, - RemotePK: lPK, - Initiator: true, - }) - dDone := make(chan error, 1) - - go func() { - dDone <- d.Run(svr, retry) - close(dDone) - }() - - for i := 0; i < reconCount; i++ { - teardown() - time.Sleep(retry * 2) // Dialer shouldn't quit retrying in this time. - setup() - - conn, err := l.Accept() - require.NoError(t, err) - - in, out := &AddIn{A: i, B: i}, new(int) - require.NoError(t, rpc.NewClient(conn).Call("TestRPC.Add", in, out)) - require.Equal(t, in.A+in.B, *out) - require.NoError(t, conn.Close()) - } - - _ = d.Close() - require.NoError(t, <-dDone) - }) -} - -func TestConn(t *testing.T) { - type Result struct { - N int - Err error - } - - const timeout = time.Second - - aPK, aSK := cipher.GenerateKeyPair() - bPK, bSK := cipher.GenerateKeyPair() - - aNs, err := XKAndSecp256k1(Config{LocalPK: aPK, LocalSK: aSK, RemotePK: bPK, Initiator: true}) - require.NoError(t, err) - bNs, err := XKAndSecp256k1(Config{LocalPK: bPK, LocalSK: bSK, Initiator: false}) - require.NoError(t, err) - - aConn, bConn := net.Pipe() - defer func() { _, _ = aConn.Close(), bConn.Close() }() - - aRW := NewReadWriter(aConn, aNs) - bRW := NewReadWriter(bConn, bNs) - - errChan := make(chan error, 2) - go func() { errChan <- aRW.Handshake(timeout) }() - go func() { errChan <- bRW.Handshake(timeout) }() - require.NoError(t, <-errChan) - require.NoError(t, <-errChan) - close(errChan) - - a := &Conn{Conn: aConn, ns: aRW} - b := &Conn{Conn: bConn, ns: bRW} - - t.Run("ReadWrite", func(t *testing.T) { - aResults := make(chan Result) - bResults := make(chan Result) - - for i := 0; i < 10; i++ { - msgAtoB := []byte(fmt.Sprintf("this is message %d from A for B", i)) - - go func() { - n, err := a.Write(msgAtoB) - aResults <- Result{N: n, Err: err} - }() - - receivedMsgAtoB := make([]byte, len(msgAtoB)) - n, err := io.ReadFull(b, receivedMsgAtoB) - require.Equal(t, len(msgAtoB), n) - require.NoError(t, err) - - aResult := <-aResults - require.Equal(t, len(msgAtoB), aResult.N) - require.NoError(t, aResult.Err) - - msgBtoA := []byte(fmt.Sprintf("this is message %d from B for A", i)) - - go func() { - n, err := b.Write(msgAtoB) - bResults <- Result{N: n, Err: err} - }() - - receivedMsgBtoA := make([]byte, len(msgBtoA)) - n, err = io.ReadFull(a, receivedMsgBtoA) - require.Equal(t, len(msgBtoA), n) - require.NoError(t, err) - - bResult := <-bResults - require.Equal(t, len(msgBtoA), bResult.N) - require.NoError(t, bResult.Err) - } - }) - - t.Run("ReadWriteConcurrent", func(t *testing.T) { - type ReadResult struct { - Msg string - N int - Err error - } - const ( - MsgCount = 100 - MsgLen = 4 - ) - var ( - aMap = make(map[string]struct{}) - bMap = make(map[string]struct{}) - aWrites = make(chan Result, MsgCount) - bWrites = make(chan Result, MsgCount) - aReads = make(chan ReadResult, MsgCount) - bReads = make(chan ReadResult, MsgCount) - ) - randSleep := func() { time.Sleep(time.Duration(cipher.RandByte(1)[0]) / 255 * time.Second) } - - for i := 0; i < MsgCount; i++ { - msg := fmt.Sprintf("%4d", i) - go func() { - randSleep() - n, err := a.Write([]byte(msg)) - aWrites <- Result{N: n, Err: err} - }() - go func() { - randSleep() - n, err := b.Write([]byte(msg)) - bWrites <- Result{N: n, Err: err} - }() - go func() { - randSleep() - msg := make([]byte, MsgLen) - n, err := io.ReadFull(a, msg) - aReads <- ReadResult{Msg: string(msg), N: n, Err: err} - }() - go func() { - randSleep() - msg := make([]byte, MsgLen) - n, err := io.ReadFull(b, msg) - bReads <- ReadResult{Msg: string(msg), N: n, Err: err} - }() - } - - for i := 0; i < MsgCount; i++ { - aWrite := <-aWrites - require.NoError(t, aWrite.Err) - require.Equal(t, MsgLen, aWrite.N) - - bWrite := <-bWrites - require.NoError(t, bWrite.Err) - require.Equal(t, MsgLen, bWrite.N) - - aRead := <-aReads - require.NoError(t, aRead.Err) - require.Equal(t, MsgLen, aRead.N) - _, aHas := aMap[aRead.Msg] - require.False(t, aHas) - aMap[aRead.Msg] = struct{}{} - - bRead := <-bReads - require.NoError(t, bRead.Err) - require.Equal(t, MsgLen, bRead.N) - _, bHas := bMap[bRead.Msg] - require.False(t, bHas) - bMap[bRead.Msg] = struct{}{} - } - - require.Len(t, aMap, MsgCount) - require.Len(t, bMap, MsgCount) - }) - - t.Run("ReadWriteIrregular", func(t *testing.T) { - const segLen = 100 - const segCount = 1000 - - aResults := make([]Result, segCount) - - msg := cipher.RandByte(segLen * segCount) - - wg := new(sync.WaitGroup) - wg.Add(1) - go func() { - for i := 0; i < segCount; i++ { - n, err := a.Write(msg[i*segLen : (i+1)*segLen]) - aResults[i] = Result{N: n, Err: err} - } - wg.Done() - }() - - msgResult := make([]byte, len(msg)) - _, err := io.ReadFull(b, msgResult) - require.NoError(t, err) - require.Equal(t, msg, msgResult) - - wg.Wait() - - for i, r := range aResults { - require.NoError(t, r.Err, i) - require.Equal(t, segLen, r.N, i) - } - }) -} - -func TestListener(t *testing.T) { - const ( - connCount = 10 - msg = "Hello, world!" - timeout = time.Second - ) - var ( - pattern = noise.HandshakeXK - ) - - dialAndWrite := func(remote cipher.PubKey, addr string) error { - pk, sk := cipher.GenerateKeyPair() - conn, err := net.Dial("tcp", addr) - if err != nil { - return err - } - ns, err := New(pattern, Config{LocalPK: pk, LocalSK: sk, RemotePK: remote, Initiator: true}) - if err != nil { - return err - } - conn, err = WrapConn(conn, ns, timeout) - if err != nil { - return err - } - _, err = conn.Write([]byte(msg)) - if err != nil { - return err - } - return conn.Close() - } - - lPK, lSK := cipher.GenerateKeyPair() - l, err := net.Listen("tcp", "") - require.NoError(t, err) - defer l.Close() - - l = WrapListener(l, lPK, lSK, false, pattern) - addr := l.Addr().(*Addr) - - t.Run("Accept", func(t *testing.T) { - hResults := make([]error, connCount) - wg := new(sync.WaitGroup) - wg.Add(1) - go func() { - for i := 0; i < connCount; i++ { - hResults[i] = dialAndWrite(lPK, addr.Addr.String()) - } - wg.Done() - }() - for i := 0; i < connCount; i++ { - lConn, err := l.Accept() - require.NoError(t, err) - rec := make([]byte, len(msg)) - n, err := io.ReadFull(lConn, rec) - log.Printf("Accept('%s'): received: '%s'", lConn.RemoteAddr(), string(rec)) - require.Equal(t, len(msg), n) - require.NoError(t, err) - require.NoError(t, lConn.Close()) - } - wg.Wait() - for i := 0; i < connCount; i++ { - require.NoError(t, hResults[i]) - } - }) -} diff --git a/internal/noise/noise.go b/internal/noise/noise.go deleted file mode 100644 index 6ea3ec4463..0000000000 --- a/internal/noise/noise.go +++ /dev/null @@ -1,152 +0,0 @@ -package noise - -import ( - "crypto/rand" - "encoding/binary" - - "github.com/flynn/noise" - "github.com/skycoin/dmsg/cipher" - "github.com/skycoin/skycoin/src/util/logging" -) - -var noiseLogger = logging.MustGetLogger("noise") // TODO: initialize properly or remove - -// Config hold noise parameters. -type Config struct { - LocalPK cipher.PubKey // Local instance static public key. - LocalSK cipher.SecKey // Local instance static secret key. - RemotePK cipher.PubKey // Remote instance static public key. - Initiator bool // Whether the local instance initiates the connection. -} - -// Noise handles the handshake and the frame's cryptography. -// All operations on Noise are not guaranteed to be thread-safe. -type Noise struct { - pk cipher.PubKey - sk cipher.SecKey - init bool - - pattern noise.HandshakePattern - hs *noise.HandshakeState - enc *noise.CipherState - dec *noise.CipherState - - seq uint32 // sequence number, used as nonce for both encrypting and decrypting - previousSeq uint32 // sequence number last decrypted, check in order to avoid reply attacks - highestPrevious uint32 // highest sequence number received from the other end - //encN uint32 // counter to inform encrypting CipherState to re-key - //decN uint32 // counter to inform decrypting CipherState to re-key -} - -// New creates a new Noise with: -// - provided pattern for handshake. -// - Secp256k1 for the curve. -func New(pattern noise.HandshakePattern, config Config) (*Noise, error) { - nc := noise.Config{ - CipherSuite: noise.NewCipherSuite(Secp256k1{}, noise.CipherChaChaPoly, noise.HashSHA256), - Random: rand.Reader, - Pattern: pattern, - Initiator: config.Initiator, - StaticKeypair: noise.DHKey{ - Public: config.LocalPK[:], - Private: config.LocalSK[:], - }, - } - if !config.RemotePK.Null() { - nc.PeerStatic = config.RemotePK[:] - } - - hs, err := noise.NewHandshakeState(nc) - if err != nil { - return nil, err - } - return &Noise{ - pk: config.LocalPK, - sk: config.LocalSK, - init: config.Initiator, - pattern: pattern, - hs: hs, - }, nil -} - -// KKAndSecp256k1 creates a new Noise with: -// - KK pattern for handshake. -// - Secp256k1 for the curve. -func KKAndSecp256k1(config Config) (*Noise, error) { - return New(noise.HandshakeKK, config) -} - -// XKAndSecp256k1 creates a new Noise with: -// - XK pattern for handshake. -// - Secp256 for the curve. -func XKAndSecp256k1(config Config) (*Noise, error) { - return New(noise.HandshakeXK, config) -} - -// HandshakeMessage generates handshake message for a current handshake state. -func (ns *Noise) HandshakeMessage() (res []byte, err error) { - if ns.hs.MessageIndex() < len(ns.pattern.Messages)-1 { - res, _, _, err = ns.hs.WriteMessage(nil, nil) - return - } - - res, ns.dec, ns.enc, err = ns.hs.WriteMessage(nil, nil) - return res, err -} - -// ProcessMessage processes a received handshake message and appends the payload. -func (ns *Noise) ProcessMessage(msg []byte) (err error) { - if ns.hs.MessageIndex() < len(ns.pattern.Messages)-1 { - _, _, _, err = ns.hs.ReadMessage(nil, msg) - return - } - - _, ns.enc, ns.dec, err = ns.hs.ReadMessage(nil, msg) - return err -} - -// LocalStatic returns the local static public key. -func (ns *Noise) LocalStatic() cipher.PubKey { - return ns.pk -} - -// RemoteStatic returns the remote static public key. -func (ns *Noise) RemoteStatic() cipher.PubKey { - pk, err := cipher.NewPubKey(ns.hs.PeerStatic()) - if err != nil { - panic(err) - } - return cipher.PubKey(pk) -} - -// EncryptUnsafe encrypts plaintext without interlocking, should only -// be used with external lock. -func (ns *Noise) EncryptUnsafe(plaintext []byte) []byte { - ns.seq++ - seq := make([]byte, 4) - binary.BigEndian.PutUint32(seq, ns.seq) - - return append(seq, ns.enc.Cipher().Encrypt(nil, uint64(ns.seq), nil, plaintext)...) -} - -// DecryptUnsafe decrypts ciphertext without interlocking, should only -// be used with external lock. -func (ns *Noise) DecryptUnsafe(ciphertext []byte) ([]byte, error) { - seq := binary.BigEndian.Uint32(ciphertext[:4]) - if seq <= ns.previousSeq { - noiseLogger.Warnf("current seq: %s is not higher than previous one: %s. "+ - "Highest sequence number received so far is: %s", ns.seq, ns.previousSeq, ns.highestPrevious) - } else { - if ns.previousSeq > ns.highestPrevious { - ns.highestPrevious = seq - } - ns.previousSeq = seq - } - - return ns.dec.Cipher().Decrypt(nil, uint64(seq), nil, ciphertext[4:]) -} - -// HandshakeFinished indicate whether handshake was completed. -func (ns *Noise) HandshakeFinished() bool { - return ns.hs.MessageIndex() == len(ns.pattern.Messages) -} diff --git a/internal/noise/noise_test.go b/internal/noise/noise_test.go deleted file mode 100644 index 9aefde1b3a..0000000000 --- a/internal/noise/noise_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package noise - -import ( - "log" - "os" - "testing" - - "github.com/skycoin/dmsg/cipher" - "github.com/skycoin/skycoin/src/util/logging" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestMain(m *testing.M) { - loggingLevel, ok := os.LookupEnv("TEST_LOGGING_LEVEL") - if ok { - lvl, err := logging.LevelFromString(loggingLevel) - if err != nil { - log.Fatal(err) - } - logging.SetLevel(lvl) - } else { - logging.Disable() - } - - os.Exit(m.Run()) -} - -func TestKKAndSecp256k1(t *testing.T) { - pkI, skI := cipher.GenerateKeyPair() - pkR, skR := cipher.GenerateKeyPair() - - confI := Config{ - LocalPK: pkI, - LocalSK: skI, - RemotePK: pkR, - Initiator: true, - } - - confR := Config{ - LocalPK: pkR, - LocalSK: skR, - RemotePK: pkI, - Initiator: false, - } - - nI, err := KKAndSecp256k1(confI) - require.NoError(t, err) - - nR, err := KKAndSecp256k1(confR) - require.NoError(t, err) - - // -> e, es - msg, err := nI.HandshakeMessage() - require.NoError(t, err) - require.Error(t, nR.ProcessMessage(append(msg, 1))) - require.NoError(t, nR.ProcessMessage(msg)) - - // <- e, ee - msg, err = nR.HandshakeMessage() - require.NoError(t, err) - require.Error(t, nI.ProcessMessage(append(msg, 1))) - require.NoError(t, nI.ProcessMessage(msg)) - - require.True(t, nI.HandshakeFinished()) - require.True(t, nR.HandshakeFinished()) - - encrypted := nI.EncryptUnsafe([]byte("foo")) - decrypted, err := nR.DecryptUnsafe(encrypted) - require.NoError(t, err) - assert.Equal(t, []byte("foo"), decrypted) - - encrypted = nR.EncryptUnsafe([]byte("bar")) - decrypted, err = nI.DecryptUnsafe(encrypted) - require.NoError(t, err) - assert.Equal(t, []byte("bar"), decrypted) - - encrypted = nI.EncryptUnsafe([]byte("baz")) - decrypted, err = nR.DecryptUnsafe(encrypted) - require.NoError(t, err) - assert.Equal(t, []byte("baz"), decrypted) -} - -func TestXKAndSecp256k1(t *testing.T) { - pkI, skI := cipher.GenerateKeyPair() - pkR, skR := cipher.GenerateKeyPair() - - confI := Config{ - LocalPK: pkI, - LocalSK: skI, - RemotePK: pkR, - Initiator: true, - } - - confR := Config{ - LocalPK: pkR, - LocalSK: skR, - Initiator: false, - } - - nI, err := XKAndSecp256k1(confI) - require.NoError(t, err) - - nR, err := XKAndSecp256k1(confR) - require.NoError(t, err) - - // -> e, es - msg, err := nI.HandshakeMessage() - require.NoError(t, err) - require.NoError(t, nR.ProcessMessage(msg)) - - // <- e, ee - msg, err = nR.HandshakeMessage() - require.NoError(t, err) - require.NoError(t, nI.ProcessMessage(msg)) - - // -> s, se - msg, err = nI.HandshakeMessage() - require.NoError(t, err) - require.NoError(t, nR.ProcessMessage(msg)) - - require.True(t, nI.HandshakeFinished()) - require.True(t, nR.HandshakeFinished()) - - encrypted := nI.EncryptUnsafe([]byte("foo")) - decrypted, err := nR.DecryptUnsafe(encrypted) - require.NoError(t, err) - assert.Equal(t, []byte("foo"), decrypted) - - encrypted = nR.EncryptUnsafe([]byte("bar")) - decrypted, err = nI.DecryptUnsafe(encrypted) - require.NoError(t, err) - assert.Equal(t, []byte("bar"), decrypted) - - encrypted = nI.EncryptUnsafe([]byte("baz")) - decrypted, err = nR.DecryptUnsafe(encrypted) - require.NoError(t, err) - assert.Equal(t, []byte("baz"), decrypted) -} diff --git a/internal/noise/read_writer.go b/internal/noise/read_writer.go deleted file mode 100644 index e5e6355819..0000000000 --- a/internal/noise/read_writer.go +++ /dev/null @@ -1,172 +0,0 @@ -package noise - -import ( - "bytes" - "encoding/binary" - "errors" - "io" - "sync" - "time" - - "github.com/skycoin/dmsg/cipher" - - "github.com/skycoin/skywire/internal/ioutil" -) - -// ReadWriter implements noise encrypted read writer. -type ReadWriter struct { - origin io.ReadWriter - ns *Noise - rBuf bytes.Buffer - rMx sync.Mutex - wMx sync.Mutex -} - -// NewReadWriter constructs a new ReadWriter. -func NewReadWriter(rw io.ReadWriter, ns *Noise) *ReadWriter { - return &ReadWriter{ - origin: rw, - ns: ns, - } -} - -func (rw *ReadWriter) Read(p []byte) (int, error) { - rw.rMx.Lock() - defer rw.rMx.Unlock() - - if rw.rBuf.Len() > 0 { - return rw.rBuf.Read(p) - } - - ciphertext, err := rw.readPacket() - if err != nil { - return 0, err - } - plaintext, err := rw.ns.DecryptUnsafe(ciphertext) - if err != nil { - return 0, err - } - return ioutil.BufRead(&rw.rBuf, plaintext, p) -} - -func (rw *ReadWriter) readPacket() ([]byte, error) { - h := make([]byte, 2) - if _, err := io.ReadFull(rw.origin, h); err != nil { - return nil, err - } - data := make([]byte, binary.BigEndian.Uint16(h)) - _, err := io.ReadFull(rw.origin, data) - return data, err -} - -func (rw *ReadWriter) Write(p []byte) (int, error) { - rw.wMx.Lock() - defer rw.wMx.Unlock() - - ciphertext := rw.ns.EncryptUnsafe(p) - - if err := rw.writePacket(ciphertext); err != nil { - return 0, err - } - return len(p), nil -} - -func (rw *ReadWriter) writePacket(p []byte) error { - buf := make([]byte, 2) - binary.BigEndian.PutUint16(buf, uint16(len(p))) - _, err := rw.origin.Write(append(buf, p...)) - return err -} - -// Handshake performs a Noise handshake using the provided io.ReadWriter. -func (rw *ReadWriter) Handshake(hsTimeout time.Duration) error { - doneChan := make(chan error) - go func() { - if rw.ns.init { - doneChan <- rw.initiatorHandshake() - } else { - doneChan <- rw.responderHandshake() - } - }() - - select { - case err := <-doneChan: - return err - case <-time.After(hsTimeout): - return errors.New("timeout") - } -} - -// LocalStatic returns the local static public key. -func (rw *ReadWriter) LocalStatic() cipher.PubKey { - return rw.ns.LocalStatic() -} - -// RemoteStatic returns the remote static public key. -func (rw *ReadWriter) RemoteStatic() cipher.PubKey { - return rw.ns.RemoteStatic() -} - -func (rw *ReadWriter) initiatorHandshake() error { - for { - msg, err := rw.ns.HandshakeMessage() - if err != nil { - return err - } - - if err := rw.writePacket(msg); err != nil { - return err - } - - if rw.ns.HandshakeFinished() { - break - } - - res, err := rw.readPacket() - if err != nil { - return err - } - - if err = rw.ns.ProcessMessage(res); err != nil { - return err - } - - if rw.ns.HandshakeFinished() { - break - } - } - - return nil -} - -func (rw *ReadWriter) responderHandshake() error { - for { - msg, err := rw.readPacket() - if err != nil { - return err - } - - if err := rw.ns.ProcessMessage(msg); err != nil { - return err - } - - if rw.ns.HandshakeFinished() { - break - } - - res, err := rw.ns.HandshakeMessage() - if err != nil { - return err - } - - if err := rw.writePacket(res); err != nil { - return err - } - - if rw.ns.HandshakeFinished() { - break - } - } - - return nil -} diff --git a/internal/noise/read_writer_test.go b/internal/noise/read_writer_test.go deleted file mode 100644 index 637e6753b2..0000000000 --- a/internal/noise/read_writer_test.go +++ /dev/null @@ -1,215 +0,0 @@ -package noise - -import ( - "fmt" - "net" - "testing" - "time" - - "github.com/skycoin/dmsg/cipher" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestNewReadWriter(t *testing.T) { - - type Result struct { - n int - err error - b []byte - } - - t.Run("concurrent", func(t *testing.T) { - aPK, aSK := cipher.GenerateKeyPair() - bPK, bSK := cipher.GenerateKeyPair() - - aNs, err := KKAndSecp256k1(Config{ - LocalPK: aPK, - LocalSK: aSK, - RemotePK: bPK, - Initiator: true, - }) - require.NoError(t, err) - - bNs, err := KKAndSecp256k1(Config{ - LocalPK: bPK, - LocalSK: bSK, - RemotePK: aPK, - Initiator: false, - }) - require.NoError(t, err) - - aConn, bConn := net.Pipe() - defer func() { - _ = aConn.Close() //nolint:errcheck - _ = bConn.Close() //nolint:errcheck - }() - - aRW := NewReadWriter(aConn, aNs) - bRW := NewReadWriter(bConn, bNs) - - hsCh := make(chan error, 2) - defer close(hsCh) - go func() { hsCh <- aRW.Handshake(time.Second) }() - go func() { hsCh <- bRW.Handshake(time.Second) }() - require.NoError(t, <-hsCh) - require.NoError(t, <-hsCh) - - const groupSize = 10 - const totalGroups = 5 - const msgCount = totalGroups * groupSize - - writes := make([][]byte, msgCount) - - wCh := make(chan Result, msgCount) - defer close(wCh) - rCh := make(chan Result, msgCount) - defer close(rCh) - - for i := 0; i < msgCount; i++ { - writes[i] = []byte(fmt.Sprintf("this is message: %d", i)) - } - - for i := 0; i < totalGroups; i++ { - go func(i int) { - for j := 0; j < groupSize; j++ { - go func(i, j int) { - b := writes[i*j] - n, err := aRW.Write(b) - wCh <- Result{n: n, err: err, b: b} - }(i, j) - go func() { - buf := make([]byte, 100) - n, err := bRW.Read(buf) - rCh <- Result{n: n, err: err, b: buf[:n]} - }() - } - }(i) - } - - for i := 0; i < msgCount; i++ { - w := <-wCh - fmt.Printf("write_result[%d]: b(%s) err(%v)\n", i, string(w.b), w.err) - assert.NoError(t, w.err) - assert.True(t, w.n > 0) - - r := <-rCh - fmt.Printf(" read_result[%d]: b(%s) err(%v)\n", i, string(r.b), r.err) - assert.NoError(t, r.err) - assert.True(t, r.n > 0) - } - }) -} - -func TestReadWriterKKPattern(t *testing.T) { - pkI, skI := cipher.GenerateKeyPair() - pkR, skR := cipher.GenerateKeyPair() - - confI := Config{ - LocalPK: pkI, - LocalSK: skI, - RemotePK: pkR, - Initiator: true, - } - - confR := Config{ - LocalPK: pkR, - LocalSK: skR, - RemotePK: pkI, - Initiator: false, - } - - nI, err := KKAndSecp256k1(confI) - require.NoError(t, err) - - nR, err := KKAndSecp256k1(confR) - require.NoError(t, err) - - connI, connR := net.Pipe() - rwI := NewReadWriter(connI, nI) - rwR := NewReadWriter(connR, nR) - - errCh := make(chan error) - go func() { errCh <- rwR.Handshake(time.Second) }() - require.NoError(t, rwI.Handshake(time.Second)) - require.NoError(t, <-errCh) - - go func() { - _, err := rwI.Write([]byte("foo")) - errCh <- err - }() - - buf := make([]byte, 3) - n, err := rwR.Read(buf) - require.NoError(t, err) - require.NoError(t, <-errCh) - assert.Equal(t, 3, n) - assert.Equal(t, []byte("foo"), buf) - - go func() { - _, err := rwI.Read(buf) - errCh <- err - }() - - n, err = rwR.Write([]byte("bar")) - require.NoError(t, err) - require.NoError(t, <-errCh) - assert.Equal(t, 3, n) - assert.Equal(t, []byte("bar"), buf) -} - -func TestReadWriterXKPattern(t *testing.T) { - pkI, skI := cipher.GenerateKeyPair() - pkR, skR := cipher.GenerateKeyPair() - - confI := Config{ - LocalPK: pkI, - LocalSK: skI, - RemotePK: pkR, - Initiator: true, - } - - confR := Config{ - LocalPK: pkR, - LocalSK: skR, - Initiator: false, - } - - nI, err := XKAndSecp256k1(confI) - require.NoError(t, err) - - nR, err := XKAndSecp256k1(confR) - require.NoError(t, err) - - connI, connR := net.Pipe() - rwI := NewReadWriter(connI, nI) - rwR := NewReadWriter(connR, nR) - - errCh := make(chan error) - go func() { errCh <- rwR.Handshake(time.Second) }() - require.NoError(t, rwI.Handshake(time.Second)) - require.NoError(t, <-errCh) - - go func() { - _, err := rwI.Write([]byte("foo")) - errCh <- err - }() - - buf := make([]byte, 3) - n, err := rwR.Read(buf) - require.NoError(t, err) - require.NoError(t, <-errCh) - assert.Equal(t, 3, n) - assert.Equal(t, []byte("foo"), buf) - - go func() { - _, err := rwI.Read(buf) - errCh <- err - }() - - n, err = rwR.Write([]byte("bar")) - require.NoError(t, err) - require.NoError(t, <-errCh) - assert.Equal(t, 3, n) - assert.Equal(t, []byte("bar"), buf) -} diff --git a/pkg/cipher/cipher.go b/pkg/cipher/cipher.go deleted file mode 100644 index 9d8a897000..0000000000 --- a/pkg/cipher/cipher.go +++ /dev/null @@ -1,255 +0,0 @@ -// Package cipher implements common golang encoding interfaces for -// github.com/skycoin/skycoin/src/cipher -package cipher - -import ( - "fmt" - "strings" - - "github.com/skycoin/skycoin/src/cipher" -) - -func init() { - cipher.DebugLevel2 = false // DebugLevel2 causes ECDH to be really slow -} - -// GenerateKeyPair creates key pair -func GenerateKeyPair() (PubKey, SecKey) { - pk, sk := cipher.GenerateKeyPair() - return PubKey(pk), SecKey(sk) -} - -// GenerateDeterministicKeyPair generates deterministic key pair -func GenerateDeterministicKeyPair(seed []byte) (PubKey, SecKey, error) { - pk, sk, err := cipher.GenerateDeterministicKeyPair(seed) - return PubKey(pk), SecKey(sk), err -} - -// NewPubKey converts []byte to a PubKey -func NewPubKey(b []byte) (PubKey, error) { - pk, err := cipher.NewPubKey(b) - return PubKey(pk), err -} - -// PubKey is a wrapper type for cipher.PubKey that implements common -// golang interfaces. -type PubKey cipher.PubKey - -// Hex returns a hex encoded PubKey string -func (pk PubKey) Hex() string { - return cipher.PubKey(pk).Hex() -} - -// Null returns true if PubKey is the null PubKey -func (pk PubKey) Null() bool { - return cipher.PubKey(pk).Null() -} - -// String implements fmt.Stringer for PubKey. Returns Hex representation. -func (pk PubKey) String() string { - return pk.Hex() -} - -// Set implements pflag.Value for PubKey. -func (pk *PubKey) Set(s string) error { - cPK, err := cipher.PubKeyFromHex(s) - if err != nil { - return err - } - *pk = PubKey(cPK) - return nil -} - -// Type implements pflag.Value for PubKey. -func (pk PubKey) Type() string { - return "cipher.PubKey" -} - -// MarshalText implements encoding.TextMarshaler. -func (pk PubKey) MarshalText() ([]byte, error) { - return []byte(pk.Hex()), nil -} - -// UnmarshalText implements encoding.TextUnmarshaler. -func (pk *PubKey) UnmarshalText(data []byte) error { - dPK, err := cipher.PubKeyFromHex(string(data)) - if err == nil { - *pk = PubKey(dPK) - } - return err -} - -// MarshalBinary implements encoding.BinaryMarshaler. -func (pk PubKey) MarshalBinary() ([]byte, error) { - return pk[:], nil -} - -// UnmarshalBinary implements encoding.BinaryUnmarshaler. -func (pk *PubKey) UnmarshalBinary(data []byte) error { - dPK, err := cipher.NewPubKey(data) - if err == nil { - *pk = PubKey(dPK) - } - return err -} - -// PubKeys represents a slice of PubKeys. -type PubKeys []PubKey - -// String implements stringer for PubKeys. -func (p PubKeys) String() string { - res := "public keys:\n" - for _, pk := range p { - res += fmt.Sprintf("\t%s\n", pk) - } - return res -} - -// Set implements pflag.Value for PubKeys. -func (p *PubKeys) Set(list string) error { - *p = PubKeys{} - for _, s := range strings.Split(list, ",") { - var pk PubKey - if err := pk.Set(strings.TrimSpace(s)); err != nil { - return err - } - *p = append(*p, pk) - } - return nil -} - -// Type implements pflag.Value for PubKeys. -func (p PubKeys) Type() string { - return "cipher.PubKeys" -} - -// SecKey is a wrapper type for cipher.SecKey that implements common -// golang interfaces. -type SecKey cipher.SecKey - -// Hex returns a hex encoded SecKey string -func (sk SecKey) Hex() string { - return cipher.SecKey(sk).Hex() -} - -// Null returns true if SecKey is the null SecKey. -func (sk SecKey) Null() bool { - return cipher.SecKey(sk).Null() -} - -// String implements fmt.Stringer for SecKey. Returns Hex representation. -func (sk SecKey) String() string { - return sk.Hex() -} - -// Set implements pflag.Value for SecKey. -func (sk *SecKey) Set(s string) error { - cSK, err := cipher.SecKeyFromHex(s) - if err != nil { - return err - } - *sk = SecKey(cSK) - return nil -} - -// Type implements pflag.Value for SecKey. -func (sk *SecKey) Type() string { - return "cipher.SecKey" -} - -// MarshalText implements encoding.TextMarshaler. -func (sk SecKey) MarshalText() ([]byte, error) { - return []byte(sk.Hex()), nil -} - -// UnmarshalText implements encoding.TextUnmarshaler. -func (sk *SecKey) UnmarshalText(data []byte) error { - dSK, err := cipher.SecKeyFromHex(string(data)) - if err == nil { - *sk = SecKey(dSK) - } - return err -} - -// MarshalBinary implements encoding.BinaryMarshaler. -func (sk SecKey) MarshalBinary() ([]byte, error) { - return sk[:], nil -} - -// UnmarshalBinary implements encoding.BinaryUnmarshaler. -func (sk *SecKey) UnmarshalBinary(data []byte) error { - dSK, err := cipher.NewSecKey(data) - if err == nil { - *sk = SecKey(dSK) - } - return err -} - -// PubKey recovers the public key for a secret key -func (sk SecKey) PubKey() (PubKey, error) { - pk, err := cipher.PubKeyFromSecKey(cipher.SecKey(sk)) - return PubKey(pk), err -} - -// Sig is a wrapper type for cipher.Sig that implements common golang interfaces. -type Sig cipher.Sig - -// Hex returns a hex encoded Sig string -func (sig Sig) Hex() string { - return cipher.Sig(sig).Hex() -} - -// String implements fmt.Stringer for Sig. Returns Hex representation. -func (sig Sig) String() string { - return sig.Hex() -} - -// Null returns true if Sig is a null Sig -func (sig Sig) Null() bool { - return sig == Sig{} -} - -// MarshalText implements encoding.TextMarshaler. -func (sig Sig) MarshalText() ([]byte, error) { - return []byte(sig.Hex()), nil -} - -// UnmarshalText implements encoding.TextUnmarshaler. -func (sig *Sig) UnmarshalText(data []byte) error { - dSig, err := cipher.SigFromHex(string(data)) - if err == nil { - *sig = Sig(dSig) - } - return err -} - -// SignPayload creates Sig for payload using SHA256 -func SignPayload(payload []byte, sec SecKey) (Sig, error) { - sig, err := cipher.SignHash(cipher.SumSHA256(payload), cipher.SecKey(sec)) - return Sig(sig), err -} - -// VerifyPubKeySignedPayload verifies that SHA256 hash of the payload was signed by PubKey -func VerifyPubKeySignedPayload(pubkey PubKey, sig Sig, payload []byte) error { - return cipher.VerifyPubKeySignedHash(cipher.PubKey(pubkey), cipher.Sig(sig), cipher.SumSHA256(payload)) -} - -// RandByte returns rand N bytes -func RandByte(n int) []byte { - return cipher.RandByte(n) -} - -// SHA256 is a wrapper type for cipher.SHA256 that implements common -// golang interfaces. -type SHA256 cipher.SHA256 - -// SHA256FromBytes converts []byte to SHA256 -func SHA256FromBytes(b []byte) (SHA256, error) { - h, err := cipher.SHA256FromBytes(b) - return SHA256(h), err -} - -// SumSHA256 sum sha256 -func SumSHA256(b []byte) SHA256 { - return SHA256(cipher.SumSHA256(b)) -} diff --git a/pkg/cipher/cipher_test.go b/pkg/cipher/cipher_test.go deleted file mode 100644 index a9a97ec740..0000000000 --- a/pkg/cipher/cipher_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package cipher - -import ( - "log" - "os" - "testing" - - "github.com/skycoin/skycoin/src/util/logging" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestMain(m *testing.M) { - loggingLevel, ok := os.LookupEnv("TEST_LOGGING_LEVEL") - if ok { - lvl, err := logging.LevelFromString(loggingLevel) - if err != nil { - log.Fatal(err) - } - logging.SetLevel(lvl) - } else { - logging.Disable() - } - - os.Exit(m.Run()) -} - -func TestPubKeyString(t *testing.T) { - p, _ := GenerateKeyPair() - require.Equal(t, p.Hex(), p.String()) -} - -func TestPubKeyTextMarshaller(t *testing.T) { - p, _ := GenerateKeyPair() - h, err := p.MarshalText() - require.NoError(t, err) - - var p2 PubKey - err = p2.UnmarshalText(h) - require.NoError(t, err) - require.Equal(t, p, p2) -} - -func TestPubKeyBinaryMarshaller(t *testing.T) { - p, _ := GenerateKeyPair() - b, err := p.MarshalBinary() - require.NoError(t, err) - - var p2 PubKey - err = p2.UnmarshalBinary(b) - require.NoError(t, err) - require.Equal(t, p, p2) -} - -func TestSecKeyString(t *testing.T) { - _, s := GenerateKeyPair() - require.Equal(t, s.Hex(), s.String()) -} - -func TestSecKeyTextMarshaller(t *testing.T) { - _, s := GenerateKeyPair() - h, err := s.MarshalText() - require.NoError(t, err) - - var s2 SecKey - err = s2.UnmarshalText(h) - require.NoError(t, err) - require.Equal(t, s, s2) -} - -func TestSecKeyBinaryMarshaller(t *testing.T) { - _, s := GenerateKeyPair() - b, err := s.MarshalBinary() - require.NoError(t, err) - - var s2 SecKey - err = s2.UnmarshalBinary(b) - require.NoError(t, err) - require.Equal(t, s, s2) -} - -func TestSigString(t *testing.T) { - _, sk := GenerateKeyPair() - sig, err := SignPayload([]byte("foo"), sk) - require.NoError(t, err) - assert.Equal(t, sig.Hex(), sig.String()) -} - -func TestSigTextMarshaller(t *testing.T) { - _, sk := GenerateKeyPair() - sig, err := SignPayload([]byte("foo"), sk) - require.NoError(t, err) - h, err := sig.MarshalText() - require.NoError(t, err) - - var sig2 Sig - err = sig2.UnmarshalText(h) - require.NoError(t, err) - assert.Equal(t, sig, sig2) -} diff --git a/pkg/manager/node.go b/pkg/manager/node.go index b0de9967e0..5e05a145a0 100644 --- a/pkg/manager/node.go +++ b/pkg/manager/node.go @@ -19,7 +19,7 @@ import ( "github.com/skycoin/dmsg/cipher" "github.com/skycoin/skycoin/src/util/logging" - "github.com/skycoin/skywire/internal/noise" + "github.com/skycoin/dmsg/noise" "github.com/skycoin/skywire/pkg/httputil" "github.com/skycoin/skywire/pkg/node" "github.com/skycoin/skywire/pkg/routing" diff --git a/pkg/node/node.go b/pkg/node/node.go index 0c9bb1707c..eb93c052ce 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -21,7 +21,7 @@ import ( "github.com/skycoin/skycoin/src/util/logging" - "github.com/skycoin/skywire/internal/noise" + "github.com/skycoin/dmsg/noise" "github.com/skycoin/skywire/pkg/app" routeFinder "github.com/skycoin/skywire/pkg/route-finder/client" "github.com/skycoin/skywire/pkg/router" diff --git a/pkg/router/loop_list.go b/pkg/router/loop_list.go index 7641f93dfc..c78f737529 100644 --- a/pkg/router/loop_list.go +++ b/pkg/router/loop_list.go @@ -5,7 +5,7 @@ import ( "github.com/google/uuid" - "github.com/skycoin/skywire/internal/noise" + "github.com/skycoin/dmsg/noise" "github.com/skycoin/skywire/pkg/app" "github.com/skycoin/skywire/pkg/routing" ) diff --git a/pkg/router/router.go b/pkg/router/router.go index 9f5e58cdaa..16533cdc68 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -14,7 +14,7 @@ import ( "github.com/skycoin/dmsg/cipher" "github.com/skycoin/skycoin/src/util/logging" - "github.com/skycoin/skywire/internal/noise" + "github.com/skycoin/dmsg/noise" "github.com/skycoin/skywire/pkg/app" routeFinder "github.com/skycoin/skywire/pkg/route-finder/client" "github.com/skycoin/skywire/pkg/routing" diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index 93f2207657..c358521648 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -16,7 +16,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/skycoin/skywire/internal/noise" + "github.com/skycoin/dmsg/noise" "github.com/skycoin/skywire/pkg/app" routeFinder "github.com/skycoin/skywire/pkg/route-finder/client" "github.com/skycoin/skywire/pkg/routing" diff --git a/vendor/modules.txt b/vendor/modules.txt index 0e0aee4464..ece3fd9d3b 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -66,8 +66,8 @@ github.com/sirupsen/logrus github.com/skycoin/dmsg/cipher github.com/skycoin/dmsg github.com/skycoin/dmsg/disc -github.com/skycoin/dmsg/ioutil github.com/skycoin/dmsg/noise +github.com/skycoin/dmsg/ioutil # github.com/skycoin/skycoin v0.26.0 github.com/skycoin/skycoin/src/util/logging github.com/skycoin/skycoin/src/cipher From 62b61e3bc9e7bd8054e173c1313a5fe2cbd1c69f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E5=BF=97=E5=AE=87?= Date: Wed, 3 Jul 2019 23:50:15 +0800 Subject: [PATCH 13/13] Simplified make test taarget. --- Makefile | 13 +------------ pkg/manager/node.go | 1 + pkg/node/node.go | 1 + pkg/router/loop_list.go | 1 + pkg/router/router.go | 1 + pkg/router/router_test.go | 1 + 6 files changed, 6 insertions(+), 12 deletions(-) diff --git a/Makefile b/Makefile index e5ad3fa044..c23af79f68 100644 --- a/Makefile +++ b/Makefile @@ -62,18 +62,7 @@ vendorcheck: ## Run vendorcheck test: ## Run tests -go clean -testcache &>/dev/null ${OPTS} go test ${TEST_OPTS} ./internal/... - #${OPTS} go test -race -tags no_ci -cover -timeout=5m ./pkg/... - ${OPTS} go test ${TEST_OPTS} ./pkg/app/... -# ${OPTS} go test ${TEST_OPTS} ./pkg/cipher/... - ${OPTS} go test ${TEST_OPTS} ./pkg/manager/... - ${OPTS} go test ${TEST_OPTS} ./pkg/node/... - ${OPTS} go test ${TEST_OPTS} ./pkg/route-finder/... - ${OPTS} go test ${TEST_OPTS} ./pkg/router/... - ${OPTS} go test ${TEST_OPTS} ./pkg/routing/... - ${OPTS} go test ${TEST_OPTS} ./pkg/setup/... - ${OPTS} go test ${TEST_OPTS} ./pkg/transport/... - ${OPTS} go test ${TEST_OPTS} ./pkg/transport-discovery/... - + ${OPTS} go test ${TEST_OPTS} ./pkg/... install-linters: ## Install linters - VERSION=1.17.1 ./ci_scripts/install-golangci-lint.sh diff --git a/pkg/manager/node.go b/pkg/manager/node.go index 5e05a145a0..c03fd31eca 100644 --- a/pkg/manager/node.go +++ b/pkg/manager/node.go @@ -20,6 +20,7 @@ import ( "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/dmsg/noise" + "github.com/skycoin/skywire/pkg/httputil" "github.com/skycoin/skywire/pkg/node" "github.com/skycoin/skywire/pkg/routing" diff --git a/pkg/node/node.go b/pkg/node/node.go index eb93c052ce..f9ae5367ac 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -22,6 +22,7 @@ import ( "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/dmsg/noise" + "github.com/skycoin/skywire/pkg/app" routeFinder "github.com/skycoin/skywire/pkg/route-finder/client" "github.com/skycoin/skywire/pkg/router" diff --git a/pkg/router/loop_list.go b/pkg/router/loop_list.go index c78f737529..2640c9de35 100644 --- a/pkg/router/loop_list.go +++ b/pkg/router/loop_list.go @@ -6,6 +6,7 @@ import ( "github.com/google/uuid" "github.com/skycoin/dmsg/noise" + "github.com/skycoin/skywire/pkg/app" "github.com/skycoin/skywire/pkg/routing" ) diff --git a/pkg/router/router.go b/pkg/router/router.go index 16533cdc68..14bad14847 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -15,6 +15,7 @@ import ( "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/dmsg/noise" + "github.com/skycoin/skywire/pkg/app" routeFinder "github.com/skycoin/skywire/pkg/route-finder/client" "github.com/skycoin/skywire/pkg/routing" diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index c358521648..5d3f9c22fe 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -17,6 +17,7 @@ import ( "github.com/stretchr/testify/require" "github.com/skycoin/dmsg/noise" + "github.com/skycoin/skywire/pkg/app" routeFinder "github.com/skycoin/skywire/pkg/route-finder/client" "github.com/skycoin/skywire/pkg/routing"