diff --git a/pkg/routing/rule.go b/pkg/routing/rule.go index a9a6012e6a..44eee2e547 100644 --- a/pkg/routing/rule.go +++ b/pkg/routing/rule.go @@ -96,6 +96,12 @@ func (r Rule) LocalPort() Port { return Port(binary.BigEndian.Uint16(r[48:])) } +// RegistrationID returns route ID which will be used to register this rule within +// the visor node. +func (r Rule) RegistrationID() RouteID { + return RouteID(binary.BigEndian.Uint32(r[50:])) +} + func (r Rule) String() string { if r.Type() == RuleApp { return fmt.Sprintf("App: ", @@ -121,21 +127,22 @@ type RuleForwardFields struct { // RuleSummary provides a summary of a RoutingRule. type RuleSummary struct { - ExpireAt time.Time `json:"expire_at"` - Type RuleType `json:"rule_type"` - AppFields *RuleAppFields `json:"app_fields,omitempty"` - ForwardFields *RuleForwardFields `json:"forward_fields,omitempty"` + ExpireAt time.Time `json:"expire_at"` + Type RuleType `json:"rule_type"` + AppFields *RuleAppFields `json:"app_fields,omitempty"` + ForwardFields *RuleForwardFields `json:"forward_fields,omitempty"` + RegistrationID RouteID `json:"registration_id"` } // ToRule converts RoutingRuleSummary to RoutingRule. func (rs *RuleSummary) ToRule() (Rule, error) { if rs.Type == RuleApp && rs.AppFields != nil && rs.ForwardFields == nil { f := rs.AppFields - return AppRule(rs.ExpireAt, f.RespRID, f.RemotePK, f.RemotePort, f.LocalPort), nil + return AppRule(rs.ExpireAt, f.RespRID, f.RemotePK, f.RemotePort, f.LocalPort, rs.RegistrationID), nil } if rs.Type == RuleForward && rs.AppFields == nil && rs.ForwardFields != nil { f := rs.ForwardFields - return ForwardRule(rs.ExpireAt, f.NextRID, f.NextTID), nil + return ForwardRule(rs.ExpireAt, f.NextRID, f.NextTID, rs.RegistrationID), nil } return nil, errors.New("invalid routing rule summary") } @@ -143,8 +150,9 @@ func (rs *RuleSummary) ToRule() (Rule, error) { // Summary returns the RoutingRule's summary. func (r Rule) Summary() *RuleSummary { summary := RuleSummary{ - ExpireAt: r.Expiry(), - Type: r.Type(), + ExpireAt: r.Expiry(), + Type: r.Type(), + RegistrationID: r.RegistrationID(), } if summary.Type == RuleApp { summary.AppFields = &RuleAppFields{ @@ -163,7 +171,8 @@ func (r Rule) Summary() *RuleSummary { } // AppRule constructs a new consume RoutingRule. -func AppRule(expireAt time.Time, respRoute RouteID, remotePK cipher.PubKey, remotePort, localPort Port) Rule { +func AppRule(expireAt time.Time, respRoute RouteID, remotePK cipher.PubKey, remotePort, localPort Port, + registrationID RouteID) Rule { rule := make([]byte, 13) if expireAt.Unix() <= time.Now().Unix() { binary.BigEndian.PutUint64(rule[0:], 0) @@ -177,11 +186,12 @@ func AppRule(expireAt time.Time, respRoute RouteID, remotePK cipher.PubKey, remo rule = append(rule, 0, 0, 0, 0) binary.BigEndian.PutUint16(rule[46:], uint16(remotePort)) binary.BigEndian.PutUint16(rule[48:], uint16(localPort)) + binary.BigEndian.PutUint32(rule[50:], uint32(registrationID)) return Rule(rule) } // ForwardRule constructs a new forward RoutingRule. -func ForwardRule(expireAt time.Time, nextRoute RouteID, nextTrID uuid.UUID) Rule { +func ForwardRule(expireAt time.Time, nextRoute RouteID, nextTrID uuid.UUID, registrationID RouteID) Rule { rule := make([]byte, 13) if expireAt.Unix() <= time.Now().Unix() { binary.BigEndian.PutUint64(rule[0:], 0) @@ -192,5 +202,7 @@ func ForwardRule(expireAt time.Time, nextRoute RouteID, nextTrID uuid.UUID) Rule rule[8] = byte(RuleForward) binary.BigEndian.PutUint32(rule[9:], uint32(nextRoute)) rule = append(rule, nextTrID[:]...) + rule = append(rule, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + binary.BigEndian.PutUint32(rule[50:], uint32(registrationID)) return Rule(rule) } diff --git a/pkg/setup/node.go b/pkg/setup/node.go index 0d54e99a27..246a04a054 100644 --- a/pkg/setup/node.go +++ b/pkg/setup/node.go @@ -198,16 +198,14 @@ func (sn *Node) createRoute(expireAt time.Time, route routing.Route, rport, lpor for idx := len(r) - 1; idx >= 0; idx-- { hop := &Hop{Hop: route[idx]} r[idx] = hop - var rule routing.Rule - if idx == len(r)-1 { - rule = routing.AppRule(expireAt, 0, initiator, lport, rport) - } else { - nextHop := r[idx+1] - rule = routing.ForwardRule(expireAt, nextHop.routeID, nextHop.Transport) + + var nextHop *Hop + if idx != len(r)-1 { + nextHop = r[idx+1] } - go func(ctx context.Context, hop *Hop, rule routing.Rule) { - routeID, err := sn.setupRule(ctx, hop.To, rule) + go func(idx int, hop, nextHop *Hop) { + routeID, err := sn.requestRouteID(ctx, hop.To) if err != nil { // filter out context cancellation errors if err == context.Canceled { @@ -220,9 +218,27 @@ func (sn *Node) createRoute(expireAt time.Time, route routing.Route, rport, lpor hop.routeID = routeID + var rule routing.Rule + if nextHop == nil { + rule = routing.AppRule(expireAt, 0, initiator, lport, rport, routeID) + } else { + rule = routing.ForwardRule(expireAt, nextHop.routeID, nextHop.Transport, routeID) + } + + err = sn.setupRule(ctx, hop.To, rule) + if err != nil { + // filter out context cancellation errors + if err == context.Canceled { + rulesSetupErrs <- err + } else { + rulesSetupErrs <- fmt.Errorf("rule setup: %s", err) + } + return + } + // put nil to avoid block rulesSetupErrs <- nil - }(ctx, hop, rule) + }(idx, hop, nextHop) } var rulesSetupErr error @@ -243,9 +259,13 @@ func (sn *Node) createRoute(expireAt time.Time, route routing.Route, rport, lpor return 0, rulesSetupErr } - rule := routing.ForwardRule(expireAt, r[0].routeID, r[0].Transport) - routeID, err := sn.setupRule(context.Background(), initiator, rule) + routeID, err := sn.requestRouteID(context.Background(), initiator) if err != nil { + return 0, fmt.Errorf("request route id: %s", err) + } + + rule := routing.ForwardRule(expireAt, r[0].routeID, r[0].Transport, routeID) + if err := sn.setupRule(context.Background(), initiator, rule); err != nil { return 0, fmt.Errorf("rule setup: %s", err) } @@ -317,7 +337,7 @@ func (sn *Node) closeLoop(on cipher.PubKey, ld routing.LoopData) error { return nil } -func (sn *Node) requestRouteID(ctx context.Context, pubKey cipher.PubKey) (uint32, error) { +func (sn *Node) requestRouteID(ctx context.Context, pubKey cipher.PubKey) (routing.RouteID, error) { sn.Logger.Debugf("dialing to %s to request route ID\n", pubKey) tr, err := sn.messenger.Dial(ctx, pubKey) if err != nil { @@ -339,13 +359,11 @@ func (sn *Node) requestRouteID(ctx context.Context, pubKey cipher.PubKey) (uint3 return routeID, nil } -func (sn *Node) setupRule(ctx context.Context, pubKey cipher.PubKey, - rule routing.Rule) (routeID routing.RouteID, err error) { +func (sn *Node) setupRule(ctx context.Context, pubKey cipher.PubKey, rule routing.Rule) error { sn.Logger.Debugf("dialing to %s to setup rule: %v\n", pubKey, rule) tr, err := sn.messenger.Dial(ctx, pubKey) if err != nil { - err = fmt.Errorf("transport: %s", err) - return + return fmt.Errorf("transport: %s", err) } defer func() { if err := tr.Close(); err != nil { @@ -354,11 +372,10 @@ func (sn *Node) setupRule(ctx context.Context, pubKey cipher.PubKey, }() proto := NewSetupProtocol(tr) - routeID, err = AddRule(proto, rule) - if err != nil { - return + if err := AddRule(proto, rule); err != nil { + return err } - sn.Logger.Infof("Set rule of type %s on %s with ID %d", rule.Type(), pubKey, routeID) - return routeID, nil + sn.Logger.Infof("Set rule of type %s on %s", rule.Type(), pubKey) + return nil } diff --git a/pkg/setup/protocol.go b/pkg/setup/protocol.go index e7163ce466..0bd59df304 100644 --- a/pkg/setup/protocol.go +++ b/pkg/setup/protocol.go @@ -102,11 +102,11 @@ func (p *Protocol) WritePacket(t PacketType, body interface{}) error { } // RequestRouteID sends RequestRouteID request. -func RequestRouteID(p *Protocol) (uint32, error) { +func RequestRouteID(p *Protocol) (routing.RouteID, error) { if err := p.WritePacket(PacketRequestRouteID, nil); err != nil { return 0, err } - var res []uint32 + var res []routing.RouteID if err := readAndDecodePacket(p, &res); err != nil { return 0, err } @@ -117,18 +117,11 @@ func RequestRouteID(p *Protocol) (uint32, error) { } // AddRule sends AddRule setup request. -func AddRule(p *Protocol, rule routing.Rule) (routeID routing.RouteID, err error) { - if err = p.WritePacket(PacketAddRules, []routing.Rule{rule}); err != nil { - return 0, err - } - var res []routing.RouteID - if err = readAndDecodePacket(p, &res); err != nil { - return 0, err - } - if len(res) == 0 { - return 0, errors.New("empty response") +func AddRule(p *Protocol, rule routing.Rule) error { + if err := p.WritePacket(PacketAddRules, []routing.Rule{rule}); err != nil { + return err } - return res[0], nil + return readAndDecodePacket(p, nil) } // DeleteRule sends DeleteRule setup request.