diff --git a/internal/testhelpers/testhelpers.go b/internal/testhelpers/testhelpers.go index 4cfbfcb708..18c1ed661f 100644 --- a/internal/testhelpers/testhelpers.go +++ b/internal/testhelpers/testhelpers.go @@ -2,6 +2,7 @@ package testhelpers import ( + "errors" "testing" "time" @@ -13,6 +14,9 @@ const timeout = 5 * time.Second // NoErr is used with the mock interface to return from its methods. var NoErr error +// Err is used with the mock interface to return some error from its methods. +var Err = errors.New("error") + // WithinTimeout tries to read an error from error channel within timeout and returns it. // If timeout exceeds, nil value is returned. func WithinTimeout(ch <-chan error) error { diff --git a/pkg/router/rpc_gateway_test.go b/pkg/router/rpc_gateway_test.go new file mode 100644 index 0000000000..6dbc2b5f75 --- /dev/null +++ b/pkg/router/rpc_gateway_test.go @@ -0,0 +1,138 @@ +package router + +import ( + "testing" + + "github.com/SkycoinProject/dmsg/cipher" + "github.com/stretchr/testify/require" + + "github.com/SkycoinProject/skywire-mainnet/internal/testhelpers" + "github.com/SkycoinProject/skywire-mainnet/pkg/routing" +) + +func TestRPCGateway_AddEdgeRules(t *testing.T) { + srcPK, _ := cipher.GenerateKeyPair() + dstPK, _ := cipher.GenerateKeyPair() + var srcPort, dstPort routing.Port = 100, 110 + + desc := routing.NewRouteDescriptor(srcPK, dstPK, srcPort, dstPort) + + rules := routing.EdgeRules{ + Desc: desc, + Forward: routing.Rule{0, 0, 0}, + Reverse: routing.Rule{0, 0, 0}, + } + + t.Run("ok", func(t *testing.T) { + r := &MockRouter{} + r.On("IntroduceRules", rules).Return(testhelpers.NoErr) + r.On("SaveRoutingRules", rules.Forward, rules.Reverse).Return(testhelpers.NoErr) + + gateway := NewRPCGateway(r) + + var ok bool + err := gateway.AddEdgeRules(rules, &ok) + require.NoError(t, err) + require.True(t, ok) + }) + + t.Run("fail introducing rules", func(t *testing.T) { + r := &MockRouter{} + r.On("IntroduceRules", rules).Return(testhelpers.Err) + + gateway := NewRPCGateway(r) + + var ok bool + err := gateway.AddEdgeRules(rules, &ok) + require.Equal(t, testhelpers.Err, err) + require.False(t, ok) + }) + + t.Run("fail saving rules", func(t *testing.T) { + r := &MockRouter{} + r.On("IntroduceRules", rules).Return(testhelpers.NoErr) + r.On("SaveRoutingRules", rules.Forward, rules.Reverse).Return(testhelpers.Err) + + gateway := NewRPCGateway(r) + + wantErr := routing.Failure{ + Code: routing.FailureAddRules, + Msg: testhelpers.Err.Error(), + } + + var ok bool + err := gateway.AddEdgeRules(rules, &ok) + require.Equal(t, wantErr, err) + require.False(t, ok) + }) +} + +func TestRPCGateway_AddIntermediaryRules(t *testing.T) { + rule1 := routing.Rule{0, 0, 0} + rule2 := routing.Rule{0, 0, 0} + rulesIfc := []interface{}{rule1, rule2} + rules := []routing.Rule{rule1, rule2} + + t.Run("ok", func(t *testing.T) { + r := &MockRouter{} + r.On("SaveRoutingRules", rulesIfc...).Return(testhelpers.NoErr) + + gateway := NewRPCGateway(r) + + var ok bool + err := gateway.AddIntermediaryRules(rules, &ok) + require.NoError(t, err) + require.True(t, ok) + }) + + t.Run("fail saving rules", func(t *testing.T) { + r := &MockRouter{} + r.On("SaveRoutingRules", rulesIfc...).Return(testhelpers.Err) + + gateway := NewRPCGateway(r) + + wantErr := routing.Failure{ + Code: routing.FailureAddRules, + Msg: testhelpers.Err.Error(), + } + + var ok bool + err := gateway.AddIntermediaryRules(rules, &ok) + require.Equal(t, wantErr, err) + require.False(t, ok) + }) +} + +func TestRPCGateway_ReserveIDs(t *testing.T) { + n := 5 + ids := []routing.RouteID{1, 2, 3, 4, 5} + + t.Run("ok", func(t *testing.T) { + r := &MockRouter{} + r.On("ReserveKeys", n).Return(ids, testhelpers.NoErr) + + gateway := NewRPCGateway(r) + + var gotIds []routing.RouteID + err := gateway.ReserveIDs(uint8(n), &gotIds) + require.NoError(t, err) + require.Equal(t, ids, gotIds) + }) + + t.Run("fail reserving keys", func(t *testing.T) { + r := &MockRouter{} + r.On("ReserveKeys", n).Return(nil, testhelpers.Err) + + gateway := NewRPCGateway(r) + + wantErr := routing.Failure{ + Code: routing.FailureReserveRtIDs, + Msg: testhelpers.Err.Error(), + } + + var gotIds []routing.RouteID + err := gateway.ReserveIDs(uint8(n), &gotIds) + require.Equal(t, wantErr, err) + require.Nil(t, gotIds) + }) +}