Skip to content

Commit

Permalink
Add loop callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
nkryuchkov committed Sep 14, 2019
1 parent 08cb296 commit 5ba16e2
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 106 deletions.
10 changes: 8 additions & 2 deletions pkg/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ type Router struct {

wg sync.WaitGroup
mx sync.Mutex

OnConfirmLoop func(loop routing.Loop, rule routing.Rule) (err error)
OnLoopClosed func(loop routing.Loop) error
}

// New constructs a new Router.
Expand All @@ -91,6 +94,9 @@ func New(n *snet.Network, config *Config) (*Router, error) {
staticPorts: make(map[routing.Port]struct{}),
}

r.OnConfirmLoop = r.confirmLoop
r.OnLoopClosed = r.loopClosed

return r, nil
}

Expand Down Expand Up @@ -264,7 +270,7 @@ func (r *Router) confirmLoopWrapper(data []byte) error {
return errors.New("reverse rule is not forward")
}

if err = r.confirmLoop(ld.Loop, rule); err != nil {
if err = r.OnConfirmLoop(ld.Loop, rule); err != nil {
return fmt.Errorf("confirm: %s", err)
}

Expand All @@ -284,7 +290,7 @@ func (r *Router) loopClosedWrapper(data []byte) error {
return err
}

return r.loopClosed(ld.Loop)
return r.OnLoopClosed(ld.Loop)
}

func (r *Router) occupyRouteID(data []byte) ([]routing.RouteID, error) {
Expand Down
208 changes: 104 additions & 104 deletions pkg/router/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,110 +334,110 @@ func TestRouter_Rules(t *testing.T) {
require.NoError(t, setup.DeleteRule(context.TODO(), proto, id))
assert.Equal(t, 0, rt.Count())
})
//
// // TEST: Ensure ConfirmLoop request from SetupNode is handled properly.
// t.Run("ConfirmLoop", func(t *testing.T) {
// clearRules()
//
// var inLoop routing.Loop
// var inRule routing.Rule
//
// r.conf = func(loop routing.Loop, rule routing.Rule) (err error) {
// inLoop = loop
// inRule = rule
// return nil
// }
// defer func() { r.conf.OnConfirmLoop = nil }()
//
// in, out := net.Pipe()
// errCh := make(chan error, 1)
// go func() {
// errCh <- r.handleSetupConn(out)
// close(errCh)
// }()
// defer func() {
// require.NoError(t, in.Close())
// require.NoError(t, <-errCh)
// }()
//
// proto := setup.NewSetupProtocol(in)
// pk, _ := cipher.GenerateKeyPair()
//
// rule := routing.ConsumeRule(10*time.Minute, 2, pk, 2, 3)
// require.NoError(t, rt.SaveRule(rule))
//
// rule = routing.IntermediaryForwardRule(10*time.Minute, 1, 3, uuid.New())
// require.NoError(t, rt.SaveRule(rule))
//
// ld := routing.LoopData{
// Loop: routing.Loop{
// Remote: routing.Addr{
// PubKey: pk,
// Port: 3,
// },
// Local: routing.Addr{
// Port: 2,
// },
// },
// RouteID: 1,
// }
// err := setup.ConfirmLoop(context.TODO(), proto, ld)
// require.NoError(t, err)
// assert.Equal(t, rule, inRule)
// assert.Equal(t, routing.Port(2), inLoop.Local.Port)
// assert.Equal(t, routing.Port(3), inLoop.Remote.Port)
// assert.Equal(t, pk, inLoop.Remote.PubKey)
// })
//
// // TEST: Ensure LoopClosed request from SetupNode is handled properly.
// t.Run("LoopClosed", func(t *testing.T) {
// clearRules()
//
// var inLoop routing.Loop
//
// r.conf.OnLoopClosed = func(loop routing.Loop) error {
// inLoop = loop
// return nil
// }
// defer func() { r.conf.OnLoopClosed = nil }()
//
// in, out := net.Pipe()
// errCh := make(chan error, 1)
// go func() {
// errCh <- r.handleSetupConn(out)
// close(errCh)
// }()
// defer func() {
// require.NoError(t, in.Close())
// require.NoError(t, <-errCh)
// }()
//
// proto := setup.NewSetupProtocol(in)
// pk, _ := cipher.GenerateKeyPair()
//
// rule := routing.ConsumeRule(10*time.Minute, 2, pk, 2, 3)
// require.NoError(t, rt.SaveRule(rule))
//
// rule = routing.IntermediaryForwardRule(10*time.Minute, 1, 3, uuid.New())
// require.NoError(t, rt.SaveRule(rule))
//
// ld := routing.LoopData{
// Loop: routing.Loop{
// Remote: routing.Addr{
// PubKey: pk,
// Port: 3,
// },
// Local: routing.Addr{
// Port: 2,
// },
// },
// RouteID: 1,
// }
// require.NoError(t, setup.LoopClosed(context.TODO(), proto, ld))
// assert.Equal(t, routing.Port(2), inLoop.Local.Port)
// assert.Equal(t, routing.Port(3), inLoop.Remote.Port)
// assert.Equal(t, pk, inLoop.Remote.PubKey)
// })

// TEST: Ensure ConfirmLoop request from SetupNode is handled properly.
t.Run("ConfirmLoop", func(t *testing.T) {
clearRules()

var inLoop routing.Loop
var inRule routing.Rule

r.OnConfirmLoop = func(loop routing.Loop, rule routing.Rule) (err error) {
inLoop = loop
inRule = rule
return nil
}
defer func() { r.OnConfirmLoop = nil }()

in, out := net.Pipe()
errCh := make(chan error, 1)
go func() {
errCh <- r.handleSetupConn(out)
close(errCh)
}()
defer func() {
require.NoError(t, in.Close())
require.NoError(t, <-errCh)
}()

proto := setup.NewSetupProtocol(in)
pk, _ := cipher.GenerateKeyPair()

rule := routing.ConsumeRule(10*time.Minute, 2, pk, 2, 3)
require.NoError(t, rt.SaveRule(rule))

rule = routing.IntermediaryForwardRule(10*time.Minute, 1, 3, uuid.New())
require.NoError(t, rt.SaveRule(rule))

ld := routing.LoopData{
Loop: routing.Loop{
Remote: routing.Addr{
PubKey: pk,
Port: 3,
},
Local: routing.Addr{
Port: 2,
},
},
RouteID: 1,
}
err := setup.ConfirmLoop(context.TODO(), proto, ld)
require.NoError(t, err)
assert.Equal(t, rule, inRule)
assert.Equal(t, routing.Port(2), inLoop.Local.Port)
assert.Equal(t, routing.Port(3), inLoop.Remote.Port)
assert.Equal(t, pk, inLoop.Remote.PubKey)
})

// TEST: Ensure LoopClosed request from SetupNode is handled properly.
t.Run("LoopClosed", func(t *testing.T) {
clearRules()

var inLoop routing.Loop

r.OnLoopClosed = func(loop routing.Loop) error {
inLoop = loop
return nil
}
defer func() { r.OnLoopClosed = nil }()

in, out := net.Pipe()
errCh := make(chan error, 1)
go func() {
errCh <- r.handleSetupConn(out)
close(errCh)
}()
defer func() {
require.NoError(t, in.Close())
require.NoError(t, <-errCh)
}()

proto := setup.NewSetupProtocol(in)
pk, _ := cipher.GenerateKeyPair()

rule := routing.ConsumeRule(10*time.Minute, 2, pk, 2, 3)
require.NoError(t, rt.SaveRule(rule))

rule = routing.IntermediaryForwardRule(10*time.Minute, 1, 3, uuid.New())
require.NoError(t, rt.SaveRule(rule))

ld := routing.LoopData{
Loop: routing.Loop{
Remote: routing.Addr{
PubKey: pk,
Port: 3,
},
Local: routing.Addr{
Port: 2,
},
},
RouteID: 1,
}
require.NoError(t, setup.LoopClosed(context.TODO(), proto, ld))
assert.Equal(t, routing.Port(2), inLoop.Local.Port)
assert.Equal(t, routing.Port(3), inLoop.Remote.Port)
assert.Equal(t, pk, inLoop.Remote.PubKey)
})
}

type TestEnv struct {
Expand Down

0 comments on commit 5ba16e2

Please sign in to comment.