diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index 2f5e76d20..2e785f71b 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -272,20 +272,21 @@ func TestRouter_Rules(t *testing.T) { t.Run("RemoveRouteDescriptor", func(t *testing.T) { clearRules() - pk, _ := cipher.GenerateKeyPair() + localPK, _ := cipher.GenerateKeyPair() + remotePK, _ := cipher.GenerateKeyPair() id, err := r.rt.ReserveKeys(1) require.NoError(t, err) - rule := routing.ConsumeRule(10*time.Minute, id[0], pk, 2, 3) + rule := routing.ConsumeRule(10*time.Minute, id[0], localPK, remotePK, 2, 3) err = r.rt.SaveRule(rule) require.NoError(t, err) - desc := routing.NewRouteDescriptor(cipher.PubKey{}, pk, 3, 2) + desc := routing.NewRouteDescriptor(localPK, remotePK, 3, 2) r.RemoveRouteDescriptor(desc) assert.Equal(t, 1, rt.Count()) - desc = routing.NewRouteDescriptor(cipher.PubKey{}, pk, 2, 3) + desc = routing.NewRouteDescriptor(localPK, remotePK, 2, 3) r.RemoveRouteDescriptor(desc) assert.Equal(t, 0, rt.Count()) }) diff --git a/pkg/routing/rule.go b/pkg/routing/rule.go index c4833f6af..5f114dd49 100644 --- a/pkg/routing/rule.go +++ b/pkg/routing/rule.go @@ -299,7 +299,7 @@ func (rs *RuleSummary) ToRule() (Rule, error) { f := rs.ConsumeFields d := f.RouteDescriptor - return ConsumeRule(rs.KeepAlive, rs.KeyRouteID, d.DstPK, d.SrcPort, d.DstPort), nil + return ConsumeRule(rs.KeepAlive, rs.KeyRouteID, d.SrcPK, d.DstPK, d.SrcPort, d.DstPort), nil case rs.Type == RuleForward: if rs.ConsumeFields != nil || rs.ForwardFields == nil || rs.IntermediaryForwardFields != nil { return nil, errors.New("invalid routing rule summary") @@ -375,7 +375,6 @@ func ConsumeRule(keepAlive time.Duration, key RouteID, localPK, remotePK cipher. rule.setSrcPK(localPK) rule.setDstPK(remotePK) - rule.setSrcPK(cipher.PubKey{}) rule.setDstPort(remotePort) rule.setSrcPort(localPort) diff --git a/pkg/routing/rule_test.go b/pkg/routing/rule_test.go index 26c7de81f..94e4b608b 100644 --- a/pkg/routing/rule_test.go +++ b/pkg/routing/rule_test.go @@ -11,16 +11,18 @@ import ( func TestConsumeRule(t *testing.T) { keepAlive := 2 * time.Minute - pk, _ := cipher.GenerateKeyPair() + localPK, _ := cipher.GenerateKeyPair() + remotePK, _ := cipher.GenerateKeyPair() - rule := ConsumeRule(keepAlive, 1, pk, 2, 3) + rule := ConsumeRule(keepAlive, 1, localPK, remotePK, 2, 3) assert.Equal(t, keepAlive, rule.KeepAlive()) assert.Equal(t, RuleConsume, rule.Type()) assert.Equal(t, RouteID(1), rule.KeyRouteID()) rd := rule.RouteDescriptor() - assert.Equal(t, pk, rd.DstPK()) + assert.Equal(t, localPK, rd.SrcPK()) + assert.Equal(t, remotePK, rd.DstPK()) assert.Equal(t, Port(3), rd.DstPort()) assert.Equal(t, Port(2), rd.SrcPort()) diff --git a/pkg/setup/idreservoir.go b/pkg/setup/idreservoir.go index 0cd4054e8..c19acface 100644 --- a/pkg/setup/idreservoir.go +++ b/pkg/setup/idreservoir.go @@ -152,6 +152,7 @@ func (idr *idReservoir) GenerateRules(fwd, rev routing.Route) ( } desc := route.Desc + srcPK := desc.SrcPK() dstPK := desc.DstPK() srcPort := desc.SrcPort() dstPort := desc.DstPort() @@ -175,8 +176,8 @@ func (idr *idReservoir) GenerateRules(fwd, rev routing.Route) ( rID = nxtRID } - fmt.Printf("GENERATING CONSUME RULE WITH SRC %s") - rule := routing.ConsumeRule(route.KeepAlive, rID, dstPK, srcPort, dstPort) + fmt.Printf("GENERATING CONSUME RULE WITH SRC %s\n", srcPK) + rule := routing.ConsumeRule(route.KeepAlive, rID, srcPK, dstPK, srcPort, dstPort) consumeRules[dstPK] = rule } diff --git a/pkg/visor/rpc_client.go b/pkg/visor/rpc_client.go index 9f50aa2cb..cd99b7a46 100644 --- a/pkg/visor/rpc_client.go +++ b/pkg/visor/rpc_client.go @@ -277,7 +277,7 @@ func NewMockRPCClient(r *rand.Rand, maxTps int, maxRules int) (cipher.PubKey, RP if err != nil { panic(err) } - consumeRule := routing.ConsumeRule(ruleKeepAlive, appRID[0], remotePK, lp, rp) + consumeRule := routing.ConsumeRule(ruleKeepAlive, appRID[0], localPK, remotePK, lp, rp) if err := rt.SaveRule(consumeRule); err != nil { panic(err) }