diff --git a/pkg/router/router.go b/pkg/router/router.go index fc223a38ed..d21b31fbb2 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -69,6 +69,9 @@ type Router struct { wg sync.WaitGroup mx sync.Mutex + + OnConfirmLoop func(loop routing.Loop, rule routing.Rule) (err error) + OnLoopClosed func(loop routing.Loop) error } // New constructs a new Router. @@ -91,6 +94,9 @@ func New(n *snet.Network, config *Config) (*Router, error) { staticPorts: make(map[routing.Port]struct{}), } + r.OnConfirmLoop = r.confirmLoop + r.OnLoopClosed = r.loopClosed + return r, nil } @@ -264,7 +270,7 @@ func (r *Router) confirmLoopWrapper(data []byte) error { return errors.New("reverse rule is not forward") } - if err = r.confirmLoop(ld.Loop, rule); err != nil { + if err = r.OnConfirmLoop(ld.Loop, rule); err != nil { return fmt.Errorf("confirm: %s", err) } @@ -284,7 +290,7 @@ func (r *Router) loopClosedWrapper(data []byte) error { return err } - return r.loopClosed(ld.Loop) + return r.OnLoopClosed(ld.Loop) } func (r *Router) occupyRouteID(data []byte) ([]routing.RouteID, error) { diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index 725972e805..ddab1f03ba 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -334,110 +334,110 @@ func TestRouter_Rules(t *testing.T) { require.NoError(t, setup.DeleteRule(context.TODO(), proto, id)) assert.Equal(t, 0, rt.Count()) }) - // - // // TEST: Ensure ConfirmLoop request from SetupNode is handled properly. - // t.Run("ConfirmLoop", func(t *testing.T) { - // clearRules() - // - // var inLoop routing.Loop - // var inRule routing.Rule - // - // r.conf = func(loop routing.Loop, rule routing.Rule) (err error) { - // inLoop = loop - // inRule = rule - // return nil - // } - // defer func() { r.conf.OnConfirmLoop = nil }() - // - // in, out := net.Pipe() - // errCh := make(chan error, 1) - // go func() { - // errCh <- r.handleSetupConn(out) - // close(errCh) - // }() - // defer func() { - // require.NoError(t, in.Close()) - // require.NoError(t, <-errCh) - // }() - // - // proto := setup.NewSetupProtocol(in) - // pk, _ := cipher.GenerateKeyPair() - // - // rule := routing.ConsumeRule(10*time.Minute, 2, pk, 2, 3) - // require.NoError(t, rt.SaveRule(rule)) - // - // rule = routing.IntermediaryForwardRule(10*time.Minute, 1, 3, uuid.New()) - // require.NoError(t, rt.SaveRule(rule)) - // - // ld := routing.LoopData{ - // Loop: routing.Loop{ - // Remote: routing.Addr{ - // PubKey: pk, - // Port: 3, - // }, - // Local: routing.Addr{ - // Port: 2, - // }, - // }, - // RouteID: 1, - // } - // err := setup.ConfirmLoop(context.TODO(), proto, ld) - // require.NoError(t, err) - // assert.Equal(t, rule, inRule) - // assert.Equal(t, routing.Port(2), inLoop.Local.Port) - // assert.Equal(t, routing.Port(3), inLoop.Remote.Port) - // assert.Equal(t, pk, inLoop.Remote.PubKey) - // }) - // - // // TEST: Ensure LoopClosed request from SetupNode is handled properly. - // t.Run("LoopClosed", func(t *testing.T) { - // clearRules() - // - // var inLoop routing.Loop - // - // r.conf.OnLoopClosed = func(loop routing.Loop) error { - // inLoop = loop - // return nil - // } - // defer func() { r.conf.OnLoopClosed = nil }() - // - // in, out := net.Pipe() - // errCh := make(chan error, 1) - // go func() { - // errCh <- r.handleSetupConn(out) - // close(errCh) - // }() - // defer func() { - // require.NoError(t, in.Close()) - // require.NoError(t, <-errCh) - // }() - // - // proto := setup.NewSetupProtocol(in) - // pk, _ := cipher.GenerateKeyPair() - // - // rule := routing.ConsumeRule(10*time.Minute, 2, pk, 2, 3) - // require.NoError(t, rt.SaveRule(rule)) - // - // rule = routing.IntermediaryForwardRule(10*time.Minute, 1, 3, uuid.New()) - // require.NoError(t, rt.SaveRule(rule)) - // - // ld := routing.LoopData{ - // Loop: routing.Loop{ - // Remote: routing.Addr{ - // PubKey: pk, - // Port: 3, - // }, - // Local: routing.Addr{ - // Port: 2, - // }, - // }, - // RouteID: 1, - // } - // require.NoError(t, setup.LoopClosed(context.TODO(), proto, ld)) - // assert.Equal(t, routing.Port(2), inLoop.Local.Port) - // assert.Equal(t, routing.Port(3), inLoop.Remote.Port) - // assert.Equal(t, pk, inLoop.Remote.PubKey) - // }) + + // TEST: Ensure ConfirmLoop request from SetupNode is handled properly. + t.Run("ConfirmLoop", func(t *testing.T) { + clearRules() + + var inLoop routing.Loop + var inRule routing.Rule + + r.OnConfirmLoop = func(loop routing.Loop, rule routing.Rule) (err error) { + inLoop = loop + inRule = rule + return nil + } + defer func() { r.OnConfirmLoop = nil }() + + in, out := net.Pipe() + errCh := make(chan error, 1) + go func() { + errCh <- r.handleSetupConn(out) + close(errCh) + }() + defer func() { + require.NoError(t, in.Close()) + require.NoError(t, <-errCh) + }() + + proto := setup.NewSetupProtocol(in) + pk, _ := cipher.GenerateKeyPair() + + rule := routing.ConsumeRule(10*time.Minute, 2, pk, 2, 3) + require.NoError(t, rt.SaveRule(rule)) + + rule = routing.IntermediaryForwardRule(10*time.Minute, 1, 3, uuid.New()) + require.NoError(t, rt.SaveRule(rule)) + + ld := routing.LoopData{ + Loop: routing.Loop{ + Remote: routing.Addr{ + PubKey: pk, + Port: 3, + }, + Local: routing.Addr{ + Port: 2, + }, + }, + RouteID: 1, + } + err := setup.ConfirmLoop(context.TODO(), proto, ld) + require.NoError(t, err) + assert.Equal(t, rule, inRule) + assert.Equal(t, routing.Port(2), inLoop.Local.Port) + assert.Equal(t, routing.Port(3), inLoop.Remote.Port) + assert.Equal(t, pk, inLoop.Remote.PubKey) + }) + + // TEST: Ensure LoopClosed request from SetupNode is handled properly. + t.Run("LoopClosed", func(t *testing.T) { + clearRules() + + var inLoop routing.Loop + + r.OnLoopClosed = func(loop routing.Loop) error { + inLoop = loop + return nil + } + defer func() { r.OnLoopClosed = nil }() + + in, out := net.Pipe() + errCh := make(chan error, 1) + go func() { + errCh <- r.handleSetupConn(out) + close(errCh) + }() + defer func() { + require.NoError(t, in.Close()) + require.NoError(t, <-errCh) + }() + + proto := setup.NewSetupProtocol(in) + pk, _ := cipher.GenerateKeyPair() + + rule := routing.ConsumeRule(10*time.Minute, 2, pk, 2, 3) + require.NoError(t, rt.SaveRule(rule)) + + rule = routing.IntermediaryForwardRule(10*time.Minute, 1, 3, uuid.New()) + require.NoError(t, rt.SaveRule(rule)) + + ld := routing.LoopData{ + Loop: routing.Loop{ + Remote: routing.Addr{ + PubKey: pk, + Port: 3, + }, + Local: routing.Addr{ + Port: 2, + }, + }, + RouteID: 1, + } + require.NoError(t, setup.LoopClosed(context.TODO(), proto, ld)) + assert.Equal(t, routing.Port(2), inLoop.Local.Port) + assert.Equal(t, routing.Port(3), inLoop.Remote.Port) + assert.Equal(t, pk, inLoop.Remote.PubKey) + }) } type TestEnv struct {