diff --git a/pkg/app2/conn.go b/pkg/app2/conn.go index 7c0c8d2a9..e6473a9ea 100644 --- a/pkg/app2/conn.go +++ b/pkg/app2/conn.go @@ -18,14 +18,11 @@ type Conn struct { } func (c *Conn) Read(b []byte) (int, error) { - n, readBytes, err := c.rpc.Read(c.id, b) + n, err := c.rpc.Read(c.id, b) if err != nil { return 0, err } - // TODO: check for slice border - copy(b[:n], readBytes[:n]) - return n, err } diff --git a/pkg/app2/conn_test.go b/pkg/app2/conn_test.go index ef860a698..2b185065c 100644 --- a/pkg/app2/conn_test.go +++ b/pkg/app2/conn_test.go @@ -11,32 +11,27 @@ func TestConn_Read(t *testing.T) { connID := uint16(1) tt := []struct { - name string - readBuff []byte - readN int - readBytes []byte - readErr error - wantBuff []byte + name string + readBuff []byte + readN int + readErr error }{ { - name: "ok", - readBuff: make([]byte, 10), - readN: 2, - readBytes: []byte{1, 1, 0, 0, 0, 0, 0, 0, 0, 0}, - wantBuff: []byte{1, 1, 0, 0, 0, 0, 0, 0, 0, 0}, + name: "ok", + readBuff: make([]byte, 10), + readN: 2, }, { name: "read error", readBuff: make([]byte, 10), readErr: errors.New("read error"), - wantBuff: make([]byte, 10), }, } for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { rpc := &MockRPCClient{} - rpc.On("Read", connID, tc.readBuff).Return(tc.readN, tc.readBytes, tc.readErr) + rpc.On("Read", connID, tc.readBuff).Return(tc.readN, tc.readErr) conn := &Conn{ id: connID, @@ -46,7 +41,6 @@ func TestConn_Read(t *testing.T) { n, err := conn.Read(tc.readBuff) require.Equal(t, tc.readErr, err) require.Equal(t, tc.readN, n) - require.Equal(t, tc.wantBuff, tc.readBuff) }) } } diff --git a/pkg/app2/mock_rpc_client.go b/pkg/app2/mock_rpc_client.go index 7caecc6aa..bbf373f93 100644 --- a/pkg/app2/mock_rpc_client.go +++ b/pkg/app2/mock_rpc_client.go @@ -117,7 +117,7 @@ func (_m *MockRPCClient) Listen(local network.Addr) (uint16, error) { } // Read provides a mock function with given fields: connID, b -func (_m *MockRPCClient) Read(connID uint16, b []byte) (int, []byte, error) { +func (_m *MockRPCClient) Read(connID uint16, b []byte) (int, error) { ret := _m.Called(connID, b) var r0 int @@ -127,23 +127,14 @@ func (_m *MockRPCClient) Read(connID uint16, b []byte) (int, []byte, error) { r0 = ret.Get(0).(int) } - var r1 []byte - if rf, ok := ret.Get(1).(func(uint16, []byte) []byte); ok { + var r1 error + if rf, ok := ret.Get(1).(func(uint16, []byte) error); ok { r1 = rf(connID, b) } else { - if ret.Get(1) != nil { - r1 = ret.Get(1).([]byte) - } - } - - var r2 error - if rf, ok := ret.Get(2).(func(uint16, []byte) error); ok { - r2 = rf(connID, b) - } else { - r2 = ret.Error(2) + r1 = ret.Error(1) } - return r0, r1, r2 + return r0, r1 } // Write provides a mock function with given fields: connID, b diff --git a/pkg/app2/network/dmsg_networker.go b/pkg/app2/network/dmsg_networker.go index e87ef5512..424b0df46 100644 --- a/pkg/app2/network/dmsg_networker.go +++ b/pkg/app2/network/dmsg_networker.go @@ -26,12 +26,7 @@ func (n *DMSGNetworker) Dial(addr Addr) (net.Conn, error) { // DialContext dials remote `addr` via dmsg network with context. func (n *DMSGNetworker) DialContext(ctx context.Context, addr Addr) (net.Conn, error) { - tp, err := n.dmsgC.Dial(ctx, addr.PubKey, uint16(addr.Port)) - if err != nil { - return nil, err - } - - return WrapConn(tp) + return n.dmsgC.Dial(ctx, addr.PubKey, uint16(addr.Port)) } // Listen starts listening on local `addr` in the dmsg network. diff --git a/pkg/app2/rpc_client.go b/pkg/app2/rpc_client.go index 6fc7b240b..a97773f27 100644 --- a/pkg/app2/rpc_client.go +++ b/pkg/app2/rpc_client.go @@ -15,7 +15,7 @@ type RPCClient interface { Listen(local network.Addr) (uint16, error) Accept(lisID uint16) (connID uint16, remote network.Addr, err error) Write(connID uint16, b []byte) (int, error) - Read(connID uint16, b []byte) (int, []byte, error) + Read(connID uint16, b []byte) (int, error) CloseConn(id uint16) error CloseListener(id uint16) error } @@ -35,7 +35,7 @@ func NewRPCClient(rpc *rpc.Client) RPCClient { // Dial sends `Dial` command to the server. func (c *rpcCLient) Dial(remote network.Addr) (connID uint16, localPort routing.Port, err error) { var resp DialResp - if err := c.rpc.Call("Dial", &remote, &resp); err != nil { + if err := c.rpc.Call("RPCGateway.Dial", &remote, &resp); err != nil { return 0, 0, err } @@ -45,7 +45,7 @@ func (c *rpcCLient) Dial(remote network.Addr) (connID uint16, localPort routing. // Listen sends `Listen` command to the server. func (c *rpcCLient) Listen(local network.Addr) (uint16, error) { var lisID uint16 - if err := c.rpc.Call("Listen", &local, &lisID); err != nil { + if err := c.rpc.Call("RPCGateway.Listen", &local, &lisID); err != nil { return 0, err } @@ -55,7 +55,7 @@ func (c *rpcCLient) Listen(local network.Addr) (uint16, error) { // Accept sends `Accept` command to the server. func (c *rpcCLient) Accept(lisID uint16) (connID uint16, remote network.Addr, err error) { var acceptResp AcceptResp - if err := c.rpc.Call("Accept", &lisID, &acceptResp); err != nil { + if err := c.rpc.Call("RPCGateway.Accept", &lisID, &acceptResp); err != nil { return 0, network.Addr{}, err } @@ -70,7 +70,7 @@ func (c *rpcCLient) Write(connID uint16, b []byte) (int, error) { } var n int - if err := c.rpc.Call("Write", &req, &n); err != nil { + if err := c.rpc.Call("RPCGateway.Write", &req, &n); err != nil { return n, err } @@ -78,28 +78,28 @@ func (c *rpcCLient) Write(connID uint16, b []byte) (int, error) { } // Read sends `Read` command to the server. -func (c *rpcCLient) Read(connID uint16, b []byte) (int, []byte, error) { +func (c *rpcCLient) Read(connID uint16, b []byte) (int, error) { req := ReadReq{ ConnID: connID, BufLen: len(b), } var resp ReadResp - if err := c.rpc.Call("Read", &req, &resp); err != nil { - return 0, nil, err + if err := c.rpc.Call("RPCGateway.Read", &req, &resp); err != nil { + return 0, err } copy(b[:resp.N], resp.B[:resp.N]) - return resp.N, resp.B, nil + return resp.N, nil } // CloseConn sends `CloseConn` command to the server. func (c *rpcCLient) CloseConn(id uint16) error { - return c.rpc.Call("CloseConn", &id, nil) + return c.rpc.Call("RPCGateway.CloseConn", &id, nil) } // CloseListener sends `CloseListener` command to the server. func (c *rpcCLient) CloseListener(id uint16) error { - return c.rpc.Call("CloseListener", &id, nil) + return c.rpc.Call("RPCGateway.CloseListener", &id, nil) } diff --git a/pkg/app2/rpc_client_test.go b/pkg/app2/rpc_client_test.go new file mode 100644 index 000000000..2d3476780 --- /dev/null +++ b/pkg/app2/rpc_client_test.go @@ -0,0 +1,494 @@ +package app2 + +import ( + "context" + "net" + "net/rpc" + "testing" + + "github.com/pkg/errors" + "github.com/skycoin/dmsg" + "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/skycoin/src/util/logging" + "github.com/stretchr/testify/require" + "golang.org/x/net/nettest" + + "github.com/skycoin/skywire/pkg/app2/network" + "github.com/skycoin/skywire/pkg/routing" +) + +func TestRPCClient_Dial(t *testing.T) { + t.Run("ok", func(t *testing.T) { + s := prepRPCServer(t, prepGateway()) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + remoteNet := network.TypeDMSG + remotePK, _ := cipher.GenerateKeyPair() + remotePort := routing.Port(100) + remote := network.Addr{ + Net: remoteNet, + PubKey: remotePK, + Port: remotePort, + } + + localPK, _ := cipher.GenerateKeyPair() + dmsgLocal := dmsg.Addr{ + PK: localPK, + Port: 101, + } + dmsgRemote := dmsg.Addr{ + PK: remotePK, + Port: uint16(remotePort), + } + + dialCtx := context.Background() + dialConn := dmsg.NewTransport(&MockConn{}, logging.MustGetLogger("dmsg_tp"), + dmsgLocal, dmsgRemote, 0, func() {}) + var noErr error + + n := &network.MockNetworker{} + n.On("DialContext", dialCtx, remote).Return(dialConn, noErr) + + network.ClearNetworkers() + err := network.AddNetworker(remoteNet, n) + require.NoError(t, err) + + connID, localPort, err := cl.Dial(remote) + require.NoError(t, err) + require.Equal(t, connID, uint16(1)) + require.Equal(t, localPort, routing.Port(dmsgLocal.Port)) + + }) + + t.Run("dial error", func(t *testing.T) { + s := prepRPCServer(t, prepGateway()) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + remoteNet := network.TypeDMSG + remotePK, _ := cipher.GenerateKeyPair() + remotePort := routing.Port(100) + remote := network.Addr{ + Net: remoteNet, + PubKey: remotePK, + Port: remotePort, + } + + dialCtx := context.Background() + var dialConn net.Conn + dialErr := errors.New("dial error") + + n := &network.MockNetworker{} + n.On("DialContext", dialCtx, remote).Return(dialConn, dialErr) + + network.ClearNetworkers() + err := network.AddNetworker(remoteNet, n) + require.NoError(t, err) + + connID, localPort, err := cl.Dial(remote) + require.Error(t, err) + require.Equal(t, err.Error(), dialErr.Error()) + require.Equal(t, connID, uint16(0)) + require.Equal(t, localPort, routing.Port(0)) + }) +} + +func TestRPCClient_Listen(t *testing.T) { + t.Run("ok", func(t *testing.T) { + s := prepRPCServer(t, prepGateway()) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + localNet := network.TypeDMSG + localPK, _ := cipher.GenerateKeyPair() + localPort := routing.Port(100) + local := network.Addr{ + Net: localNet, + PubKey: localPK, + Port: localPort, + } + + listenCtx := context.Background() + var listenLis net.Listener + var noErr error + + n := &network.MockNetworker{} + n.On("ListenContext", listenCtx, local).Return(listenLis, noErr) + + network.ClearNetworkers() + err := network.AddNetworker(localNet, n) + require.NoError(t, err) + + lisID, err := cl.Listen(local) + require.NoError(t, err) + require.Equal(t, lisID, uint16(1)) + }) + + t.Run("listen error", func(t *testing.T) { + s := prepRPCServer(t, prepGateway()) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + localNet := network.TypeDMSG + localPK, _ := cipher.GenerateKeyPair() + localPort := routing.Port(100) + local := network.Addr{ + Net: localNet, + PubKey: localPK, + Port: localPort, + } + + listenCtx := context.Background() + var listenLis net.Listener + listenErr := errors.New("listen error") + + n := &network.MockNetworker{} + n.On("ListenContext", listenCtx, local).Return(listenLis, listenErr) + + network.ClearNetworkers() + err := network.AddNetworker(localNet, n) + require.NoError(t, err) + + lisID, err := cl.Listen(local) + require.Error(t, err) + require.Equal(t, err.Error(), listenErr.Error()) + require.Equal(t, lisID, uint16(0)) + }) +} + +func TestRPCClient_Accept(t *testing.T) { + t.Run("ok", func(t *testing.T) { + gateway := prepGateway() + + localPK, _ := cipher.GenerateKeyPair() + localPort := uint16(100) + dmsgLocal := dmsg.Addr{ + PK: localPK, + Port: localPort, + } + remotePK, _ := cipher.GenerateKeyPair() + remotePort := uint16(101) + dmsgRemote := dmsg.Addr{ + PK: remotePK, + Port: remotePort, + } + lisConn := dmsg.NewTransport(&MockConn{}, logging.MustGetLogger("dmsg_tp"), + dmsgLocal, dmsgRemote, 0, func() {}) + var noErr error + + lis := &MockListener{} + lis.On("Accept").Return(lisConn, noErr) + + lisID := uint16(1) + + _, err := gateway.lm.add(lisID, lis) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + wantRemote := network.Addr{ + Net: network.TypeDMSG, + PubKey: remotePK, + Port: routing.Port(remotePort), + } + + connID, remote, err := cl.Accept(lisID) + require.NoError(t, err) + require.Equal(t, connID, uint16(1)) + require.Equal(t, remote, wantRemote) + }) + + t.Run("accept error", func(t *testing.T) { + gateway := prepGateway() + + var lisConn net.Conn + listenErr := errors.New("accept error") + + lis := &MockListener{} + lis.On("Accept").Return(lisConn, listenErr) + + lisID := uint16(1) + + _, err := gateway.lm.add(lisID, lis) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + connID, remote, err := cl.Accept(lisID) + require.Error(t, err) + require.Equal(t, err.Error(), listenErr.Error()) + require.Equal(t, connID, uint16(0)) + require.Equal(t, remote, network.Addr{}) + }) +} + +func TestRPCClient_Write(t *testing.T) { + t.Run("ok", func(t *testing.T) { + gateway := prepGateway() + + writeBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + writeN := 10 + var noErr error + + conn := &MockConn{} + conn.On("Write", writeBuf).Return(writeN, noErr) + + connID := uint16(1) + + _, err := gateway.cm.add(connID, conn) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + n, err := cl.Write(connID, writeBuf) + require.NoError(t, err) + require.Equal(t, n, writeN) + }) + + t.Run("write error", func(t *testing.T) { + gateway := prepGateway() + + writeBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + writeN := 0 + writeErr := errors.New("write error") + + conn := &MockConn{} + conn.On("Write", writeBuf).Return(writeN, writeErr) + + connID := uint16(1) + + _, err := gateway.cm.add(connID, conn) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + n, err := cl.Write(connID, writeBuf) + require.Error(t, err) + require.Equal(t, err.Error(), writeErr.Error()) + require.Equal(t, n, 0) + }) +} + +func TestRPCClient_Read(t *testing.T) { + t.Run("ok", func(t *testing.T) { + gateway := prepGateway() + + readBufLen := 10 + readBuf := make([]byte, readBufLen) + readN := 5 + var noErr error + + conn := &MockConn{} + conn.On("Read", readBuf).Return(readN, noErr) + + connID := uint16(1) + + _, err := gateway.cm.add(connID, conn) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + n, err := cl.Read(connID, readBuf) + require.NoError(t, err) + require.Equal(t, n, readN) + }) + + t.Run("read error", func(t *testing.T) { + gateway := prepGateway() + + readBufLen := 10 + readBuf := make([]byte, readBufLen) + readN := 0 + readErr := errors.New("read error") + + conn := &MockConn{} + conn.On("Read", readBuf).Return(readN, readErr) + + connID := uint16(1) + + _, err := gateway.cm.add(connID, conn) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + n, err := cl.Read(connID, readBuf) + require.Error(t, err) + require.Equal(t, err.Error(), readErr.Error()) + require.Equal(t, n, readN) + }) +} + +func TestRPCClient_CloseConn(t *testing.T) { + t.Run("ok", func(t *testing.T) { + gateway := prepGateway() + + var noErr error + + conn := &MockConn{} + conn.On("Close").Return(noErr) + + connID := uint16(1) + + _, err := gateway.cm.add(connID, conn) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + err = cl.CloseConn(connID) + require.NoError(t, err) + }) + + t.Run("close error", func(t *testing.T) { + gateway := prepGateway() + + closeErr := errors.New("close error") + + conn := &MockConn{} + conn.On("Close").Return(closeErr) + + connID := uint16(1) + + _, err := gateway.cm.add(connID, conn) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + err = cl.CloseConn(connID) + require.Error(t, err) + require.Equal(t, err.Error(), closeErr.Error()) + }) +} + +func TestRPCClient_CloseListener(t *testing.T) { + t.Run("ok", func(t *testing.T) { + gateway := prepGateway() + + var noErr error + + lis := &MockListener{} + lis.On("Close").Return(noErr) + + lisID := uint16(1) + + _, err := gateway.lm.add(lisID, lis) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + err = cl.CloseListener(lisID) + require.NoError(t, err) + }) + + t.Run("close error", func(t *testing.T) { + gateway := prepGateway() + + closeErr := errors.New("close error") + + lis := &MockListener{} + lis.On("Close").Return(closeErr) + + lisID := uint16(1) + + _, err := gateway.lm.add(lisID, lis) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + err = cl.CloseListener(lisID) + require.Error(t, err) + require.Equal(t, err.Error(), closeErr.Error()) + }) +} + +func prepGateway() *RPCGateway { + l := logging.MustGetLogger("rpc_gateway") + return newRPCGateway(l) +} + +func prepRPCServer(t *testing.T, gateway *RPCGateway) *rpc.Server { + s := rpc.NewServer() + err := s.Register(gateway) + require.NoError(t, err) + + return s +} + +func prepListener(t *testing.T) (lis net.Listener, cleanup func()) { + lis, err := nettest.NewLocalListener("tcp") + require.NoError(t, err) + + return lis, func() { + err := lis.Close() + require.NoError(t, err) + } +} + +func prepClient(t *testing.T, network, addr string) RPCClient { + rpcCl, err := rpc.Dial(network, addr) + require.NoError(t, err) + + return NewRPCClient(rpcCl) +}