diff --git a/pkg/app2/network/networker_test.go b/pkg/app2/network/networker_test.go index 526801daf9..c03bc02bcf 100644 --- a/pkg/app2/network/networker_test.go +++ b/pkg/app2/network/networker_test.go @@ -15,7 +15,7 @@ import ( func TestAddNetworker(t *testing.T) { clearNetworkers() - nType := Type(TypeDMSG) + nType := TypeDMSG var n Networker err := AddNetworker(nType, n) @@ -28,7 +28,7 @@ func TestAddNetworker(t *testing.T) { func TestResolveNetworker(t *testing.T) { clearNetworkers() - nType := Type(TypeDMSG) + nType := TypeDMSG var n Networker n, err := ResolveNetworker(nType) diff --git a/pkg/app2/network/type.go b/pkg/app2/network/type.go index d91c6a0dcc..c91f9128d2 100644 --- a/pkg/app2/network/type.go +++ b/pkg/app2/network/type.go @@ -5,7 +5,7 @@ type Type string const ( // TypeDMSG is a network type for DMSG communication. - TypeDMSG = "dmsg" + TypeDMSG Type = "dmsg" ) // IsValid checks whether the network contains valid value for the type. diff --git a/pkg/app2/rpc_gateway.go b/pkg/app2/rpc_gateway.go index e5e9c7ec0b..d85cc50946 100644 --- a/pkg/app2/rpc_gateway.go +++ b/pkg/app2/rpc_gateway.go @@ -28,7 +28,7 @@ func newRPCGateway(log *logging.Logger) *RPCGateway { // Dial dials to the remote. func (r *RPCGateway) Dial(remote *network.Addr, connID *uint16) error { - connID, err := r.cm.nextKey() + reservedConnID, err := r.cm.nextKey() if err != nil { return err } @@ -38,7 +38,7 @@ func (r *RPCGateway) Dial(remote *network.Addr, connID *uint16) error { return err } - if err := r.cm.set(*connID, conn); err != nil { + if err := r.cm.set(*reservedConnID, conn); err != nil { if err := conn.Close(); err != nil { r.log.WithError(err).Error("error closing conn") } @@ -46,6 +46,8 @@ func (r *RPCGateway) Dial(remote *network.Addr, connID *uint16) error { return err } + *connID = *reservedConnID + return nil } diff --git a/pkg/app2/rpc_gateway_test.go b/pkg/app2/rpc_gateway_test.go new file mode 100644 index 0000000000..d6c89de2e6 --- /dev/null +++ b/pkg/app2/rpc_gateway_test.go @@ -0,0 +1,49 @@ +package app2 + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/skycoin/dmsg" + + "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/skywire/pkg/routing" + + "github.com/skycoin/skywire/pkg/app2/network" + + "github.com/skycoin/skycoin/src/util/logging" +) + +func TestRPCGateway_Dial(t *testing.T) { + l := logging.MustGetLogger("rpc_gateway") + nType := network.TypeDMSG + + dialCtx := context.Background() + dialAddrPK, _ := cipher.GenerateKeyPair() + dialAddrPort := routing.Port(100) + dialAddr := network.Addr{ + Net: nType, + PubKey: dialAddrPK, + Port: dialAddrPort, + } + dialConn := &dmsg.Transport{} + var dialErr error + + n := &network.MockNetworker{} + n.On("DialContext", dialCtx, dialAddr).Return(dialConn, dialErr) + + err := network.AddNetworker(nType, n) + require.NoError(t, err) + + rpc := newRPCGateway(l) + + t.Run("ok", func(t *testing.T) { + var connID uint16 + + err := rpc.Dial(&dialAddr, &connID) + require.NoError(t, err) + require.Equal(t, connID, uint16(1)) + }) +}