diff --git a/pkg/router/route_group.go b/pkg/router/route_group.go index 88a7e4623..783e7720b 100644 --- a/pkg/router/route_group.go +++ b/pkg/router/route_group.go @@ -31,6 +31,8 @@ var ( ErrNoRules = errors.New("no rules") // ErrBadTransport is returned when transport is nil. ErrBadTransport = errors.New("bad transport") + // ErrRuleTransportMismatch is returned when number of forward rules does not equal to number of transports. + ErrRuleTransportMismatch = errors.New("rule/transport mismatch") ) type timeoutError struct{} @@ -59,6 +61,7 @@ func DefaultRouteGroupConfig() *RouteGroupConfig { type RouteGroup struct { mu sync.Mutex + cfg *RouteGroupConfig logger *logging.Logger desc routing.RouteDescriptor // describes the route group rt routing.Table @@ -97,6 +100,7 @@ func NewRouteGroup(cfg *RouteGroupConfig, rt routing.Table, desc routing.RouteDe } rg := &RouteGroup{ + cfg: cfg, logger: logging.MustGetLogger(fmt.Sprintf("RouteGroup %s", desc.String())), desc: desc, rt: rt, @@ -236,7 +240,7 @@ func (r *RouteGroup) Close() error { defer r.mu.Unlock() if len(r.fwd) != len(r.tps) { - return errors.New("len(r.fwd) != len(r.tps)") + return ErrRuleTransportMismatch } for i := 0; i < len(r.tps); i++ { diff --git a/pkg/router/route_group_test.go b/pkg/router/route_group_test.go index 719a61153..d48e5b52e 100644 --- a/pkg/router/route_group_test.go +++ b/pkg/router/route_group_test.go @@ -25,6 +25,7 @@ import ( func TestNewRouteGroup(t *testing.T) { rg := createRouteGroup() require.NotNil(t, rg) + require.Equal(t, DefaultRouteGroupConfig(), rg.cfg) } func TestRouteGroup_Close(t *testing.T) { @@ -34,13 +35,22 @@ func TestRouteGroup_Close(t *testing.T) { require.False(t, rg.isClosed()) require.NoError(t, rg.Close()) require.True(t, rg.isClosed()) + + rg = createRouteGroup() + require.NotNil(t, rg) + + rg.tps = append(rg.tps, &transport.ManagedTransport{}) + require.Equal(t, ErrRuleTransportMismatch, rg.Close()) } func TestRouteGroup_Read(t *testing.T) { msg1 := []byte("hello1") msg2 := []byte("hello2") + msg3 := []byte("hello3") buf1 := make([]byte, len(msg1)) buf2 := make([]byte, len(msg2)) + buf3 := make([]byte, len(msg2)/2) + buf4 := make([]byte, len(msg2)/2) rg1 := createRouteGroup() rg2 := createRouteGroup() @@ -50,6 +60,7 @@ func TestRouteGroup_Read(t *testing.T) { rg1.readCh <- msg1 rg2.readCh <- msg2 + rg2.readCh <- msg3 n, err := rg1.Read([]byte{}) require.Equal(t, 0, n) @@ -65,6 +76,17 @@ func TestRouteGroup_Read(t *testing.T) { require.Equal(t, msg2, buf2) require.Equal(t, len(msg2), n) + // Test short reads. + n, err = rg2.Read(buf3) + require.NoError(t, err) + require.Equal(t, msg3[0:len(msg3)/2], buf3) + require.Equal(t, len(msg3)/2, n) + + n, err = rg2.Read(buf4) + require.NoError(t, err) + require.Equal(t, msg3[len(msg3)/2:], buf4) + require.Equal(t, len(msg3)/2, n) + require.NoError(t, rg1.Close()) require.NoError(t, rg2.Close()) } @@ -108,6 +130,24 @@ func TestRouteGroup_Write(t *testing.T) { require.NoError(t, err) require.Equal(t, msg1, recv.Payload()) + tpBackup := rg1.tps[0] + rg1.tps[0] = nil + _, err = rg1.Write(msg1) + require.Equal(t, ErrBadTransport, err) + rg1.tps[0] = tpBackup + + tpsBackup := rg1.tps + rg1.tps = nil + _, err = rg1.Write(msg1) + require.Equal(t, ErrNoTransports, err) + rg1.tps = tpsBackup + + fwdBackup := rg1.fwd + rg1.fwd = nil + _, err = rg1.Write(msg1) + require.Equal(t, ErrNoRules, err) + rg1.fwd = fwdBackup + require.NoError(t, rg1.Close()) require.NoError(t, rg2.Close()) } @@ -502,8 +542,7 @@ func createRouteGroup() *RouteGroup { port2 := routing.Port(2) desc := routing.NewRouteDescriptor(pk1, pk2, port1, port2) - cfg := DefaultRouteGroupConfig() - rg := NewRouteGroup(cfg, rt, desc) + rg := NewRouteGroup(nil, rt, desc) return rg } diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index eb8e641b1..da3a37346 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -3,6 +3,7 @@ package router import ( "context" "fmt" + "io" "log" "os" "testing" @@ -151,6 +152,32 @@ func Test_router_Introduce_AcceptRoutes(t *testing.T) { require.Contains(t, allRules, cnsmRule) require.NoError(t, r0.Close()) + require.Equal(t, io.ErrClosedPipe, r0.IntroduceRules(rules)) +} + +func TestRouter_Serve(t *testing.T) { + // We are generating two key pairs - one for the a `Router`, the other to send packets to `Router`. + keys := snettest.GenKeyPairs(2) + + // create test env + nEnv := snettest.NewEnv(t, keys, []string{dmsg.Type}) + defer nEnv.Teardown() + + rEnv := NewTestEnv(t, nEnv.Nets) + defer rEnv.Teardown() + + // Create routers + r0Ifc, err := New(nEnv.Nets[0], rEnv.GenRouterConfig(0)) + require.NoError(t, err) + + r0, ok := r0Ifc.(*router) + require.True(t, ok) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + require.NoError(t, r0.tm.Close()) + require.NoError(t, r0.Serve(ctx)) } // Ensure that received packets are handled properly in `(*Router).handleTransportPacket()`. diff --git a/pkg/transport/manager.go b/pkg/transport/manager.go index 0c6587f82..ac92ed6a2 100644 --- a/pkg/transport/manager.go +++ b/pkg/transport/manager.go @@ -73,6 +73,10 @@ func (tm *Manager) serve(ctx context.Context) { var listeners []*snet.Listener for _, netType := range tm.n.TransportNetworks() { + if tm.isClosing() { + return + } + lis, err := tm.n.Listen(netType, skyenv.DmsgTransportPort) if err != nil { tm.Logger.WithError(err).Fatalf("failed to listen on network '%s' of port '%d'",