From 6dd7c8b2168e33226fc618d90c62b72e00c93352 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Fri, 27 Dec 2019 14:03:37 +0300 Subject: [PATCH] Fix router close packet handling tests --- pkg/router/route_group.go | 10 ++-- pkg/router/router.go | 13 ++--- pkg/router/router_test.go | 120 +++++++++++++++++++++++++++++++++++--- 3 files changed, 120 insertions(+), 23 deletions(-) diff --git a/pkg/router/route_group.go b/pkg/router/route_group.go index cc5d8252a8..be5c74eb42 100644 --- a/pkg/router/route_group.go +++ b/pkg/router/route_group.go @@ -368,14 +368,12 @@ func (rg *RouteGroup) close(code routing.CloseCode) error { } } - rules := rg.rt.RulesWithDesc(rg.desc) - routeIDs := make([]routing.RouteID, 0, len(rules)) - - for _, rule := range rules { - routeIDs = append(routeIDs, rule.KeyRouteID()) + rules := make([]routing.RouteID, 0, len(rg.fwd)) + for _, r := range rg.fwd { + rules = append(rules, r.KeyRouteID()) } - rg.rt.DelRules(routeIDs) + rg.rt.DelRules(rules) rg.once.Do(func() { close(rg.done) diff --git a/pkg/router/router.go b/pkg/router/router.go index 2291a6fbb4..544bf8a396 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -390,16 +390,13 @@ func (r *router) handleClosePacket(ctx context.Context, packet routing.Packet) e return err } + defer func() { + routeIDs := []routing.RouteID{routeID} + r.rt.DelRules(routeIDs) + }() + if t := rule.Type(); t == routing.RuleIntermediaryForward { r.logger.Infoln("Handling intermediary close packet") - - // defer this only on intermediary nodes. destination node will remove - // the needed rules in the route group `Close` routine - defer func() { - routeIDs := []routing.RouteID{routeID} - r.rt.DelRules(routeIDs) - }() - return r.forwardPacket(ctx, packet, rule) } diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index 7b8e5fc8a9..1a29917ded 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -251,10 +251,18 @@ func testHandlePackets(t *testing.T, r0, r1 *router, tp1 *transport.ManagedTrans wg.Wait() wg.Add(1) - t.Run("handlePacket_close", func(t *testing.T) { + t.Run("handlePacket_close_initiator", func(t *testing.T) { defer wg.Done() - testClosePacket(t, r0, r1, pk1, pk2) + testClosePacketInitiator(t, r0, r1, pk1, pk2, tp1) + }) + wg.Wait() + + wg.Add(1) + t.Run("handlePacket_close_remote", func(t *testing.T) { + defer wg.Done() + + testClosePacketRemote(t, r0, r1, pk1, pk2, tp1) }) wg.Wait() @@ -294,22 +302,116 @@ func testKeepAlivePacket(t *testing.T, r0, r1 *router, pk1, pk2 cipher.PubKey) { require.Len(t, r0.rt.AllRules(), 0) } -func testClosePacket(t *testing.T, r0, r1 *router, pk1, pk2 cipher.PubKey) { +func testClosePacketRemote(t *testing.T, r0, r1 *router, pk1, pk2 cipher.PubKey, tp1 *transport.ManagedTransport) { defer clearRouterRules(r0, r1) defer clearRouteGroups(r0, r1) - rtIDs, err := r0.ReserveKeys(1) + // reserve FWD IDs for r0. + intFwdID, err := r0.ReserveKeys(1) require.NoError(t, err) - cnsmRule := routing.ConsumeRule(ruleKeepAlive, rtIDs[0], pk2, pk1, 0, 0) - err = r0.rt.SaveRule(cnsmRule) + // reserve FWD and CNSM IDs for r1. + r1RtIDs, err := r1.ReserveKeys(2) require.NoError(t, err) - require.Len(t, r0.rt.AllRules(), 1) - packet := routing.MakeClosePacket(rtIDs[0], routing.CloseRequested) - require.NoError(t, r0.handleTransportPacket(context.TODO(), packet)) + intFwdRule := routing.IntermediaryForwardRule(1*time.Hour, intFwdID[0], r1RtIDs[1], tp1.Entry.ID) + err = r0.rt.SaveRule(intFwdRule) + require.NoError(t, err) + + routeID := routing.RouteID(7) + fwdRule := routing.ForwardRule(ruleKeepAlive, r1RtIDs[0], routeID, tp1.Entry.ID, pk1, pk2, 0, 0) + cnsmRule := routing.ConsumeRule(ruleKeepAlive, r1RtIDs[1], pk2, pk1, 0, 0) + + err = r1.rt.SaveRule(fwdRule) + require.NoError(t, err) + + err = r1.rt.SaveRule(cnsmRule) + require.NoError(t, err) + + fwdRtDesc := fwdRule.RouteDescriptor() + + rg1 := r1.saveRouteGroupRules(routing.EdgeRules{ + Desc: fwdRtDesc.Invert(), + Forward: fwdRule, + Reverse: cnsmRule, + }) + + packet := routing.MakeClosePacket(intFwdID[0], routing.CloseRequested) + err = r0.handleTransportPacket(context.TODO(), packet) + require.NoError(t, err) + + recvPacket, err := r1.tm.ReadPacket() + require.NoError(t, err) + require.Equal(t, packet.Size(), recvPacket.Size()) + require.Equal(t, packet.Payload(), recvPacket.Payload()) + require.Equal(t, packet.Type(), recvPacket.Type()) + require.Equal(t, r1RtIDs[1], recvPacket.RouteID()) + + err = r1.handleTransportPacket(context.TODO(), recvPacket) + require.NoError(t, err) + + require.True(t, rg1.isClosed()) + require.Len(t, r1.rgs, 0) + require.Len(t, r0.rt.AllRules(), 0) + require.Len(t, r1.rt.AllRules(), 0) +} + +func testClosePacketInitiator(t *testing.T, r0, r1 *router, pk1, pk2 cipher.PubKey, tp1 *transport.ManagedTransport) { + defer clearRouterRules(r0, r1) + defer clearRouteGroups(r0, r1) + + // reserve FWD IDs for r0. + intFwdID, err := r0.ReserveKeys(1) + require.NoError(t, err) + + // reserve FWD and CNSM IDs for r1. + r1RtIDs, err := r1.ReserveKeys(2) + require.NoError(t, err) + + intFwdRule := routing.IntermediaryForwardRule(1*time.Hour, intFwdID[0], r1RtIDs[1], tp1.Entry.ID) + err = r0.rt.SaveRule(intFwdRule) + require.NoError(t, err) + + routeID := routing.RouteID(7) + fwdRule := routing.ForwardRule(ruleKeepAlive, r1RtIDs[0], routeID, tp1.Entry.ID, pk1, pk2, 0, 0) + cnsmRule := routing.ConsumeRule(ruleKeepAlive, r1RtIDs[1], pk2, pk1, 0, 0) + + err = r1.rt.SaveRule(fwdRule) + require.NoError(t, err) + + err = r1.rt.SaveRule(cnsmRule) + require.NoError(t, err) + + fwdRtDesc := fwdRule.RouteDescriptor() + + rg1 := r1.saveRouteGroupRules(routing.EdgeRules{ + Desc: fwdRtDesc.Invert(), + Forward: fwdRule, + Reverse: cnsmRule, + }) + + packet := routing.MakeClosePacket(intFwdID[0], routing.CloseRequested) + err = r0.handleTransportPacket(context.TODO(), packet) + require.NoError(t, err) + + recvPacket, err := r1.tm.ReadPacket() + require.NoError(t, err) + require.Equal(t, packet.Size(), recvPacket.Size()) + require.Equal(t, packet.Payload(), recvPacket.Payload()) + require.Equal(t, packet.Type(), recvPacket.Type()) + require.Equal(t, r1RtIDs[1], recvPacket.RouteID()) + + rg1.closeDone.Add(1) + rg1.closeInitiated = 1 + + err = r1.handleTransportPacket(context.TODO(), recvPacket) + require.NoError(t, err) + require.Len(t, r1.rgs, 0) require.Len(t, r0.rt.AllRules(), 0) + // since this is the close initiator but the close routine wasn't called, + // forward rule is left + require.Len(t, r1.rt.AllRules(), 1) } // TEST: Ensure handleTransportPacket does as expected.