diff --git a/pkg/router/route_group.go b/pkg/router/route_group.go index be5c74eb42..ad0c336f8d 100644 --- a/pkg/router/route_group.go +++ b/pkg/router/route_group.go @@ -397,7 +397,7 @@ func (rg *RouteGroup) handleClosePacket(code routing.CloseCode) error { } // TODO: use `close` with some close code if we decide that it should be different from the current one - return rg.Close() + return rg.close(code) } func (rg *RouteGroup) broadcastClosePackets(code routing.CloseCode) error { diff --git a/pkg/router/route_group_test.go b/pkg/router/route_group_test.go index 60a384d68e..744d2bcdc7 100644 --- a/pkg/router/route_group_test.go +++ b/pkg/router/route_group_test.go @@ -29,18 +29,98 @@ func TestNewRouteGroup(t *testing.T) { } func TestRouteGroup_Close(t *testing.T) { - rg := createRouteGroup() - require.NotNil(t, rg) + keys := snettest.GenKeyPairs(2) - require.False(t, rg.isClosed()) - require.NoError(t, rg.Close()) - require.True(t, rg.isClosed()) + pk1 := keys[0].PK + pk2 := keys[1].PK - rg = createRouteGroup() - require.NotNil(t, rg) + // create test env + nEnv := snettest.NewEnv(t, keys, []string{stcp.Type}) + defer nEnv.Teardown() + + tpDisc := transport.NewDiscoveryMock() + tpKeys := snettest.GenKeyPairs(2) + + m1, m2, tp1, tp2, err := transport.CreateTransportPair(tpDisc, tpKeys, nEnv, stcp.Type) + require.NoError(t, err) + require.NotNil(t, tp1) + require.NotNil(t, tp2) + require.NotNil(t, tp1.Entry) + require.NotNil(t, tp2.Entry) + + rg0 := createRouteGroup() + rg1 := createRouteGroup() + + // reserve FWD and CNSM IDs for r0. + r0RtIDs, err := rg0.rt.ReserveKeys(2) + require.NoError(t, err) + + // reserve FWD and CNSM IDs for r1. + r1RtIDs, err := rg1.rt.ReserveKeys(2) + require.NoError(t, err) + + r0FwdRule := routing.ForwardRule(ruleKeepAlive, r0RtIDs[0], r1RtIDs[1], tp1.Entry.ID, pk2, pk1, 0, 0) + r0CnsmRule := routing.ConsumeRule(ruleKeepAlive, r0RtIDs[1], pk1, pk2, 0, 0) + + err = rg0.rt.SaveRule(r0FwdRule) + require.NoError(t, err) + err = rg0.rt.SaveRule(r0CnsmRule) + require.NoError(t, err) - rg.tps = append(rg.tps, &transport.ManagedTransport{}) - require.Equal(t, ErrRuleTransportMismatch, rg.Close()) + r1FwdRule := routing.ForwardRule(ruleKeepAlive, r1RtIDs[0], r0RtIDs[1], tp2.Entry.ID, pk1, pk2, 0, 0) + r1CnsmRule := routing.ConsumeRule(ruleKeepAlive, r1RtIDs[1], pk2, pk1, 0, 0) + + err = rg1.rt.SaveRule(r1FwdRule) + require.NoError(t, err) + err = rg1.rt.SaveRule(r1CnsmRule) + require.NoError(t, err) + + r0FwdRtDesc := r0FwdRule.RouteDescriptor() + rg0.desc = r0FwdRtDesc.Invert() + rg0.tps = append(rg0.tps, tp1) + rg0.fwd = append(rg0.fwd, r0FwdRule) + + r1FwdRtDesc := r1FwdRule.RouteDescriptor() + rg1.desc = r1FwdRtDesc.Invert() + rg1.tps = append(rg1.tps, tp2) + rg1.fwd = append(rg1.fwd, r1FwdRule) + + // push close packet from transport to route group + go func() { + packet, err := m1.ReadPacket() + if err != nil { + panic(err) + } + + if packet.Type() != routing.ClosePacket { + panic("wrong packet type") + } + + if err := rg0.handleClosePacket(routing.CloseCode(packet.Payload()[0])); err != nil { + panic(err) + } + }() + + // push close packet from transport to route group + go func() { + packet, err := m2.ReadPacket() + if err != nil { + panic(err) + } + + if packet.Type() != routing.ClosePacket { + panic("wrong packet type") + } + + if err := rg1.handleClosePacket(routing.CloseCode(packet.Payload()[0])); err != nil { + panic(err) + } + }() + + err = rg0.Close() + require.NoError(t, err) + require.True(t, rg0.isClosed()) + require.True(t, rg1.isClosed()) } func TestRouteGroup_Read(t *testing.T) { @@ -485,7 +565,7 @@ func TestRouteGroup_TestConn(t *testing.T) { cancel() teardownEnv() require.NoError(t, rg1.Close()) - require.NoError(t, rg2.Close()) + //require.NoError(t, rg2.Close()) } return @@ -505,7 +585,7 @@ func pushPackets(ctx context.Context, t *testing.T, from *transport.Manager, to packet, err := from.ReadPacket() assert.NoError(t, err) - if packet.Type() != routing.DataPacket { + if packet.Type() != routing.DataPacket && packet.Type() != routing.ClosePacket { continue }