diff --git a/pkg/setup/node.go b/pkg/setup/node.go index 3cf69496f9..e5d255ed5b 100644 --- a/pkg/setup/node.go +++ b/pkg/setup/node.go @@ -339,18 +339,10 @@ func (sn *Node) closeLoop(ctx context.Context, on cipher.PubKey, ld routing.Loop func (sn *Node) setupRule(ctx context.Context, pubKey cipher.PubKey, rule routing.Rule, regIDChIn <-chan routing.RouteID, regIDChOut chan<- routing.RouteID) (routing.RouteID, error) { sn.Logger.Debugf("trying to setup setup rule: %v with %s\n", rule, pubKey) - proto, err := sn.dialAndCreateProto(ctx, pubKey) + registrationID, err := sn.requestRegistrationID(ctx, pubKey) if err != nil { return 0, err } - defer sn.closeProto(proto) - - registrationID, err := RequestRegistrationID(ctx, proto) - if err != nil { - return 0, err - } - - sn.Logger.Infof("Received route ID %d from %s", registrationID, pubKey) if regIDChOut != nil { regIDChOut <- registrationID @@ -365,7 +357,7 @@ func (sn *Node) setupRule(ctx context.Context, pubKey cipher.PubKey, rule routin sn.Logger.Debugf("dialing to %s to setup rule: %v\n", pubKey, rule) - if err := AddRule(ctx, proto, rule); err != nil { + if err := sn.addRule(ctx, pubKey, rule); err != nil { return 0, err } @@ -374,6 +366,33 @@ func (sn *Node) setupRule(ctx context.Context, pubKey cipher.PubKey, rule routin return registrationID, nil } +func (sn *Node) requestRegistrationID(ctx context.Context, pubKey cipher.PubKey) (routing.RouteID, error) { + proto, err := sn.dialAndCreateProto(ctx, pubKey) + if err != nil { + return 0, err + } + defer sn.closeProto(proto) + + registrationID, err := RequestRegistrationID(ctx, proto) + if err != nil { + return 0, err + } + + sn.Logger.Infof("Received route ID %d from %s", registrationID, pubKey) + + return registrationID, nil +} + +func (sn *Node) addRule(ctx context.Context, pubKey cipher.PubKey, rule routing.Rule) error { + proto, err := sn.dialAndCreateProto(ctx, pubKey) + if err != nil { + return err + } + defer sn.closeProto(proto) + + return AddRule(ctx, proto, rule) +} + func (sn *Node) dialAndCreateProto(ctx context.Context, pubKey cipher.PubKey) (*Protocol, error) { sn.Logger.Debugf("dialing to %s\n", pubKey) tr, err := sn.dmsgC.Dial(ctx, pubKey, snet.AwaitSetupPort)