diff --git a/pkg/router/route_manager.go b/pkg/router/route_manager.go index db23f3d049..788c7ed9bf 100644 --- a/pkg/router/route_manager.go +++ b/pkg/router/route_manager.go @@ -127,7 +127,7 @@ func (rm *routeManager) handleSetupConn(conn net.Conn) error { case setup.PacketLoopClosed: err = rm.loopClosed(body) case setup.PacketRequestRouteID: - respBody, err = rm.occupyRouteID() + respBody, err = rm.occupyRouteID(body) default: err = errors.New("unknown foundation packet") } @@ -314,12 +314,20 @@ func (rm *routeManager) loopClosed(data []byte) error { return rm.conf.OnLoopClosed(ld.Loop) } -func (rm *routeManager) occupyRouteID() ([]routing.RouteID, error) { - rule := routing.IntermediaryForwardRule(DefaultRouteKeepAlive, 0, 0, uuid.UUID{}) - routeID, err := rm.rt.AddRule(rule) - if err != nil { +func (rm *routeManager) occupyRouteID(data []byte) ([]routing.RouteID, error) { + var n uint8 + if err := json.Unmarshal(data, &n); err != nil { return nil, err } - return []routing.RouteID{routeID}, nil + var ids = make([]routing.RouteID, n) + for i := range ids { + rule := routing.IntermediaryForwardRule(DefaultRouteKeepAlive, 0, 0, uuid.UUID{}) + routeID, err := rm.rt.AddRule(rule) + if err != nil { + return nil, err + } + ids[i] = routeID + } + return ids, nil } diff --git a/pkg/router/route_manager_test.go b/pkg/router/route_manager_test.go index 8546fa5578..d24af90360 100644 --- a/pkg/router/route_manager_test.go +++ b/pkg/router/route_manager_test.go @@ -113,26 +113,26 @@ func TestNewRouteManager(t *testing.T) { }() // Emulate SetupNode sending RequestRegistrationID request. - id, err := setup.RequestRouteID(context.TODO(), setup.NewSetupProtocol(requestIDIn)) + ids, err := setup.RequestRouteIDs(context.TODO(), setup.NewSetupProtocol(requestIDIn), 1) require.NoError(t, err) // Emulate SetupNode sending AddRule request. - rule := routing.IntermediaryForwardRule(10*time.Minute, id, 3, uuid.New()) - err = setup.AddRule(context.TODO(), setup.NewSetupProtocol(addIn), rule) + rule := routing.IntermediaryForwardRule(10*time.Minute, ids[0], 3, uuid.New()) + err = setup.AddRules(context.TODO(), setup.NewSetupProtocol(addIn), []routing.Rule{rule}) require.NoError(t, err) // Check routing table state after AddRule. assert.Equal(t, 1, rt.Count()) - r, err := rt.Rule(id) + r, err := rt.Rule(ids[0]) require.NoError(t, err) assert.Equal(t, rule, r) // Emulate SetupNode sending RemoveRule request. - require.NoError(t, setup.DeleteRule(context.TODO(), setup.NewSetupProtocol(delIn), id)) + require.NoError(t, setup.DeleteRule(context.TODO(), setup.NewSetupProtocol(delIn), ids[0])) // Check routing table state after DeleteRule. 
assert.Equal(t, 0, rt.Count()) - r, err = rt.Rule(id) + r, err = rt.Rule(ids[0]) assert.Error(t, err) assert.Nil(t, r) } diff --git a/pkg/routing/rule.go b/pkg/routing/rule.go index 7ca7f79d55..a6c33eb4a9 100644 --- a/pkg/routing/rule.go +++ b/pkg/routing/rule.go @@ -272,15 +272,16 @@ func (d RouteDescriptor) DstPort() Port { func (r Rule) String() string { switch t := r.Type(); t { case RuleConsume: - return fmt.Sprintf("App: ", - r.RouteDescriptor().DstPK(), r.RouteDescriptor().DstPort(), r.RouteDescriptor().SrcPK()) + rd := r.RouteDescriptor() + return fmt.Sprintf("APP(keyRtID:%d, resRtID:%d, rPK:%s, rPort:%d, lPort:%d)", + r.KeyRouteID(), r.NextRouteID(), rd.DstPK(), rd.DstPort(), rd.SrcPort()) case RuleForward: - return fmt.Sprintf("Forward: ", - r.NextRouteID(), r.NextTransportID(), - r.RouteDescriptor().DstPK(), r.RouteDescriptor().DstPort(), r.RouteDescriptor().SrcPK()) + rd := r.RouteDescriptor() + return fmt.Sprintf("FWD(keyRtID:%d, nxtRtID:%d, nxtTpID:%s, rPK:%s, rPort:%d, lPort:%d)", + r.KeyRouteID(), r.NextRouteID(), r.NextTransportID(), rd.DstPK(), rd.DstPort(), rd.SrcPort()) case RuleIntermediaryForward: - return fmt.Sprintf("IntermediaryForward: ", - r.NextRouteID(), r.NextTransportID()) + return fmt.Sprintf("IFWD(keyRtID:%d, nxtRtID:%d, nxtTpID:%s)", + r.KeyRouteID(), r.NextRouteID(), r.NextTransportID()) default: panic(fmt.Sprintf("%v: %v", invalidRule, t.String())) } @@ -294,6 +295,10 @@ type RouteDescriptorFields struct { SrcPort Port `json:"src_port"` } +//func (r Rule) MarshalJSON() ([]byte, error) { +// return json.Marshal(r.String()) +//} + // RuleConsumeFields summarizes consume fields of a RoutingRule. type RuleConsumeFields struct { RouteDescriptor RouteDescriptorFields `json:"route_descriptor"` diff --git a/pkg/setup/config.go b/pkg/setup/config.go index ccb4ddad30..e30becc02c 100644 --- a/pkg/setup/config.go +++ b/pkg/setup/config.go @@ -8,8 +8,8 @@ import ( // Various timeouts for setup node. const ( - ServeTransportTimeout = time.Second * 30 - ReadTimeout = time.Second * 10 + RequestTimeout = time.Second * 30 + ReadTimeout = time.Second * 10 ) // Config defines configuration parameters for setup Node.
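Aside: the route_manager.go hunk above and the matching RequestRouteIDs change in pkg/setup/protocol.go switch route-ID reservation from one ID per request to a batch: the setup node sends a uint8 count in the packet body and the visor replies with exactly that many freshly occupied route IDs. The following standalone Go sketch (not part of the diff; RouteID and routeTable are simplified stand-ins for the skywire types) illustrates the request/response shape under those assumptions:

package main

import (
	"encoding/json"
	"fmt"
)

// RouteID is a local stand-in for routing.RouteID.
type RouteID uint32

// routeTable is a hypothetical stand-in for the visor's routing table;
// addRule occupies the next free ID, much like rt.AddRule does in the diff.
type routeTable struct{ next RouteID }

func (rt *routeTable) addRule() RouteID {
	rt.next++
	return rt.next
}

// occupyRouteIDs mirrors the reworked handler: decode the requested count
// from the packet body, then reserve that many route IDs in one pass.
func occupyRouteIDs(body []byte, rt *routeTable) ([]RouteID, error) {
	var n uint8
	if err := json.Unmarshal(body, &n); err != nil {
		return nil, err
	}
	ids := make([]RouteID, n)
	for i := range ids {
		ids[i] = rt.addRule()
	}
	return ids, nil
}

func main() {
	// The setup node now asks for several IDs in a single round trip.
	body, _ := json.Marshal(uint8(3))
	ids, err := occupyRouteIDs(body, &routeTable{})
	fmt.Println(ids, err) // [1 2 3] <nil>
}

Batching matters because handleCreateLoop now reserves IDs for every visor of both routes up front (via the idReservoir introduced below) instead of performing one dial per rule.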
diff --git a/pkg/setup/idreservoir.go b/pkg/setup/idreservoir.go new file mode 100644 index 0000000000..5cfe3f1da7 --- /dev/null +++ b/pkg/setup/idreservoir.go @@ -0,0 +1,164 @@ +package setup + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync" + "time" + + "github.com/skycoin/dmsg/cipher" + + "github.com/skycoin/skywire/pkg/routing" +) + +type idReservoir struct { + rec map[cipher.PubKey]uint8 + ids map[cipher.PubKey][]routing.RouteID + mx sync.Mutex +} + +func newIDReservoir(routes ...routing.Route) (*idReservoir, int) { + rec := make(map[cipher.PubKey]uint8) + var total int + + for _, rt := range routes { + if len(rt) == 0 { + continue + } + rec[rt[0].From]++ + for _, hop := range rt { + rec[hop.To]++ + } + total += len(rt) + 1 + } + + return &idReservoir{ + rec: rec, + ids: make(map[cipher.PubKey][]routing.RouteID), + }, total +} + +func (idr *idReservoir) ReserveIDs(ctx context.Context, reserve func(ctx context.Context, pk cipher.PubKey, n uint8) ([]routing.RouteID, error)) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + errCh := make(chan error, len(idr.rec)) + defer close(errCh) + + for pk, n := range idr.rec { + pk, n := pk, n + go func() { + ids, err := reserve(ctx, pk, n) + if err != nil { + errCh <- fmt.Errorf("reserve routeID from %s failed: %v", pk, err) + return + } + idr.mx.Lock() + idr.ids[pk] = ids + idr.mx.Unlock() + errCh <- nil + }() + } + + return finalError(len(idr.rec), errCh) + +func (idr *idReservoir) PopID(pk cipher.PubKey) (routing.RouteID, bool) { + idr.mx.Lock() + defer idr.mx.Unlock() + + ids, ok := idr.ids[pk] + if !ok || len(ids) == 0 { + return 0, false + } + + idr.ids[pk] = ids[1:] + return ids[0], true } + +// RulesMap associates a slice of rules to a visor's public key. +type RulesMap map[cipher.PubKey][]routing.Rule + +func (rm RulesMap) String() string { + out := make(map[cipher.PubKey][]string, len(rm)) + for pk, rules := range rm { + str := make([]string, len(rules)) + for i, rule := range rules { + str[i] = rule.String() + } + out[pk] = str + } + jb, err := json.MarshalIndent(out, "", "\t") + if err != nil { + panic(err) + } + return string(jb) } + +// GenerateRules generates rules for a given LoopDescriptor. +// The outputs are as follows: +// - rules: a map that relates a slice of routing rules to a given visor's public key. +// - srcFwdRID: the initiating visor's route ID that references the FWD rule. +// - dstFwdRID: the responding visor's route ID that references the FWD rule. +// - err: an error (if any). +func GenerateRules(idc *idReservoir, ld routing.LoopDescriptor) (rules RulesMap, srcFwdRID, dstFwdRID routing.RouteID, err error) { + rules = make(RulesMap) + src, dst := ld.Loop.Local, ld.Loop.Remote + + firstFwdRID, _, err := SaveForwardRules(rules, idc, ld.KeepAlive, ld.Forward) + if err != nil { + return nil, 0, 0, err + } + firstRevRID, _, err := SaveForwardRules(rules, idc, ld.KeepAlive, ld.Reverse) + if err != nil { + return nil, 0, 0, err + } + + rules[src.PubKey] = append(rules[src.PubKey], + routing.ConsumeRule(ld.KeepAlive, firstRevRID, dst.PubKey, src.Port, dst.Port)) + rules[dst.PubKey] = append(rules[dst.PubKey], + routing.ConsumeRule(ld.KeepAlive, firstFwdRID, src.PubKey, dst.Port, src.Port)) + + return rules, firstFwdRID, firstRevRID, nil } + +// SaveForwardRules creates the rules of the given route, and saves them in the 'rules' input. +// Note that the last rule for the route is always an APP rule, and so is not created here.
+// The outputs are as follows: +// - firstRID: the first visor's route ID. +// - lastRID: the last visor's route ID (note that there is no rule set for this ID yet). +// - err: an error (if any). +func SaveForwardRules(rules RulesMap, idc *idReservoir, keepAlive time.Duration, route routing.Route) (firstRID, lastRID routing.RouteID, err error) { + + // 'firstRID' is the first visor's key routeID - this is to be returned. + var ok bool + if firstRID, ok = idc.PopID(route[0].From); !ok { + return 0, 0, errors.New("not enough route IDs were reserved for the route's first visor") + } + + var rID = firstRID + for _, hop := range route { + nxtRID, ok := idc.PopID(hop.To) + if !ok { + return 0, 0, errors.New("not enough route IDs were reserved for the next visor in the route") + } + rule := routing.IntermediaryForwardRule(keepAlive, rID, nxtRID, hop.Transport) + rules[hop.From] = append(rules[hop.From], rule) + + rID = nxtRID + } + + return firstRID, rID, nil +} + +func finalError(n int, errCh <-chan error) error { + var finalErr error + for i := 0; i < n; i++ { + if err := <-errCh; err != nil { + finalErr = err + } + } + return finalErr +} diff --git a/pkg/setup/node.go b/pkg/setup/node.go index 1de20870c3..8eb0454e9d 100644 --- a/pkg/setup/node.go +++ b/pkg/setup/node.go @@ -5,21 +5,16 @@ import ( "encoding/json" "errors" "fmt" - "sync" "time" - "github.com/google/uuid" - - "github.com/skycoin/skywire/pkg/snet" - "github.com/skycoin/dmsg" - "github.com/skycoin/dmsg/cipher" "github.com/skycoin/dmsg/disc" "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/skywire/pkg/metrics" "github.com/skycoin/skywire/pkg/routing" + "github.com/skycoin/skywire/pkg/snet" ) // Node performs routes setup operations over messaging channel. @@ -68,6 +63,14 @@ func NewNode(conf *Config, metrics metrics.Recorder) (*Node, error) { }, nil } +// Close closes underlying dmsg client. +func (sn *Node) Close() error { + if sn == nil { + return nil + } + return sn.dmsgC.Close() +} + // Serve starts transport listening loop.
func (sn *Node) Serve(ctx context.Context) error { sn.Logger.Info("serving setup node") @@ -78,15 +81,15 @@ func (sn *Node) Serve(ctx context.Context) error { return err } go func(conn *dmsg.Transport) { - if err := sn.serveTransport(ctx, conn); err != nil { + if err := sn.handleRequest(ctx, conn); err != nil { sn.Logger.Warnf("Failed to serve Transport: %s", err) } }(conn) } } -func (sn *Node) serveTransport(ctx context.Context, tr *dmsg.Transport) error { - ctx, cancel := context.WithTimeout(ctx, ServeTransportTimeout) +func (sn *Node) handleRequest(ctx context.Context, tr *dmsg.Transport) error { + ctx, cancel := context.WithTimeout(ctx, RequestTimeout) defer cancel() proto := NewSetupProtocol(tr) @@ -95,265 +98,152 @@ func (sn *Node) serveTransport(ctx context.Context, tr *dmsg.Transport) error { return err } - sn.Logger.Infof("Got new Setup request with type %s: %s", sp, string(data)) - defer sn.Logger.Infof("Completed Setup request with type %s: %s", sp, string(data)) + log := sn.Logger.WithField("requester", tr.RemotePK()).WithField("reqType", sp) + log.Infof("Received request.") startTime := time.Now() + switch sp { case PacketCreateLoop: var ld routing.LoopDescriptor - if err = json.Unmarshal(data, &ld); err == nil { - err = sn.createLoop(ctx, ld) + if err = json.Unmarshal(data, &ld); err != nil { + break } + ldJSON, jErr := json.MarshalIndent(ld, "", "\t") + if jErr != nil { + panic(jErr) + } + log.Infof("CreateLoop loop descriptor: %s", string(ldJSON)) + err = sn.handleCreateLoop(ctx, ld) + case PacketCloseLoop: var ld routing.LoopData - if err = json.Unmarshal(data, &ld); err == nil { - err = sn.closeLoop(ctx, ld.Loop.Remote.PubKey, routing.LoopData{ - Loop: routing.Loop{ - Remote: ld.Loop.Local, - Local: ld.Loop.Remote, - }, - }) + if err = json.Unmarshal(data, &ld); err != nil { + break } + err = sn.handleCloseLoop(ctx, ld.Loop.Remote.PubKey, routing.LoopData{ + Loop: routing.Loop{ + Remote: ld.Loop.Local, + Local: ld.Loop.Remote, + }, + }) + default: err = errors.New("unknown foundation packet") } sn.metrics.Record(time.Since(startTime), err != nil) if err != nil { - sn.Logger.Infof("Setup request with type %s failed: %s", sp, err) + log.WithError(err).Warnf("Request completed with error.") return proto.WritePacket(RespFailure, err) } + log.Infof("Request completed successfully.") return proto.WritePacket(RespSuccess, nil) } -func (sn *Node) createLoop(ctx context.Context, ld routing.LoopDescriptor) error { - sn.Logger.Infof("Creating new Loop %s", ld) - rRouteID, err := sn.createRoute(ctx, ld.KeepAlive, ld.Reverse, ld.Loop.Local.Port, ld.Loop.Remote.Port) +func (sn *Node) handleCreateLoop(ctx context.Context, ld routing.LoopDescriptor) error { + src := ld.Loop.Local + dst := ld.Loop.Remote + + // Reserve route IDs from visors. + idr, err := sn.reserveRouteIDs(ctx, ld.Forward, ld.Reverse) if err != nil { return err } - fRouteID, err := sn.createRoute(ctx, ld.KeepAlive, ld.Forward, ld.Loop.Remote.Port, ld.Loop.Local.Port) + // Determine the rules to send to visors using loop descriptor and reserved route IDs. 
+ rulesMap, srcFwdRID, dstFwdRID, err := GenerateRules(idr, ld) if err != nil { return err } + sn.Logger.Infof("generated rules: %v", rulesMap) - if len(ld.Forward) == 0 || len(ld.Reverse) == 0 { - return nil - } - - initiator := ld.Initiator() - responder := ld.Responder() - - ldR := routing.LoopData{ - Loop: routing.Loop{ - Remote: routing.Addr{ - PubKey: initiator, - Port: ld.Loop.Local.Port, - }, - Local: routing.Addr{ - PubKey: responder, - Port: ld.Loop.Remote.Port, - }, - }, - RouteID: rRouteID, - } - if err := sn.connectLoop(ctx, responder, ldR); err != nil { - sn.Logger.Warnf("Failed to confirm loop with responder: %s", err) - return fmt.Errorf("loop connect: %s", err) - } - - ldI := routing.LoopData{ - Loop: routing.Loop{ - Remote: routing.Addr{ - PubKey: responder, - Port: ld.Loop.Remote.Port, - }, - Local: routing.Addr{ - PubKey: initiator, - Port: ld.Loop.Local.Port, - }, - }, - RouteID: fRouteID, - } - if err := sn.connectLoop(ctx, initiator, ldI); err != nil { - sn.Logger.Warnf("Failed to confirm loop with initiator: %s", err) - if err := sn.closeLoop(ctx, responder, ldR); err != nil { - sn.Logger.Warnf("Failed to close loop: %s", err) - } - return fmt.Errorf("loop connect: %s", err) - } - - sn.Logger.Infof("Created Loop %s", ld) - return nil -} - -// createRoute setups the route. Route setup involves applying routing rules to each visor node along the route. -// Each rule applying procedure consists of two steps: -// 1. Request free route ID from the visor node -// 2. Apply the rule, using route ID from the step 1 to register this rule inside the visor node -// -// Route ID received as a response after 1st step is used in two rules. 1st, it's used in the rule being applied -// to the current visor node as a route ID to register this rule within the visor node. -// 2nd, it's used in the rule being applied to the previous visor node along the route as a `respRouteID/nextRouteID`. -// For this reason, each 2nd step must wait for completion of its 1st step and the 1st step of the next visor node -// along the route to be able to obtain route ID from there. IDs serving as `respRouteID/nextRouteID` are being -// passed in a fan-like fashion. -// -// Example. Let's say, we have N visor nodes along the route. Visor[0] is the initiator. Setup node sends N requests to -// each visor along the route according to the 1st step and gets N route IDs in response. Then we assemble N rules to -// be applied. We construct each rule as the following: -// - Rule[0..N-1] are of type `ForwardRule`; -// - Rule[N] is of type `AppRule`; -// - For i = 0..N-1 rule[i] takes `nextTransportID` from the rule[i+1]; -// - For i = 0..N-1 rule[i] takes `respRouteID/nextRouteID` from rule[i+1] (after [i+1] request for free route ID -// completes; -// - Rule[N] has `respRouteID/nextRouteID` equal to 0; -// Rule[0..N] use their route ID retrieved from the 1st step to be registered within the corresponding visor node. -// -// During the setup process each error received along the way causes all the procedure to be canceled. RouteID received -// from the 1st step connecting to the initiating node is used as the ID for the overall rule, thus being returned. -func (sn *Node) createRoute(ctx context.Context, keepAlive time.Duration, route routing.Route, - rport, lport routing.Port) (routing.RouteID, error) { - if len(route) == 0 { - return 0, nil - } + // Add rules to visors. 
+ errCh := make(chan error, len(rulesMap)) + defer close(errCh) + for pk, rules := range rulesMap { + pk, rules := pk, rules + go func() { + log := sn.Logger.WithField("remote", pk) - sn.Logger.Infof("Creating new Route %s", route) - - // add the initiating node to the start of the route. We need to loop over all the visor nodes - // along the route to apply rules including the initiating one - r := make(routing.Route, len(route)+1) - r[0] = &routing.Hop{ - Transport: route[0].Transport, - To: route[0].From, - } - copy(r[1:], route) - - init := route[0].From - - // indicate errors occurred during rules setup - rulesSetupErrs := make(chan error, len(r)) - // reqIDsCh is an array of chans used to pass the requested route IDs around the goroutines. - // We do it in a fan fashion here. We create as many goroutines as there are rules to be applied. - // Goroutine[i] requests visor node for a free route ID. It passes this route ID through a chan to - // a goroutine[i-1]. In turn, goroutine[i-1] waits for a route ID from chan[i]. - // Thus, goroutine[len(r)] doesn't get a route ID and uses 0 instead, goroutine[0] doesn't pass - // its route ID to anyone - reqIDsCh := make([]chan routing.RouteID, 0, len(r)) - for range r { - reqIDsCh = append(reqIDsCh, make(chan routing.RouteID, 2)) - } - - // chan to receive the resulting route ID from a goroutine - resultingRouteIDCh := make(chan routing.RouteID, 2) - - // context to cancel rule setup in case of errors - cancellableCtx, cancel := context.WithCancel(ctx) - for i := len(r) - 1; i >= 0; i-- { - var reqIDChIn, reqIDChOut chan routing.RouteID - // goroutine[0] doesn't need to pass the route ID from the 1st step to anyone - if i > 0 { - reqIDChOut = reqIDsCh[i-1] - } - var ( - nextTpID uuid.UUID - rule routing.Rule - ) - // goroutine[len(r)-1] uses 0 as the route ID from the 1st step - if i != len(r)-1 { - reqIDChIn = reqIDsCh[i] - nextTpID = r[i+1].Transport - rule = routing.IntermediaryForwardRule(keepAlive, 0, 0, nextTpID) - } else { - rule = routing.ConsumeRule(keepAlive, 0, init, lport, rport) - } - - go func(i int, pk cipher.PubKey, rule routing.Rule, reqIDChIn <-chan routing.RouteID, - reqIDChOut chan<- routing.RouteID) { - routeID, err := sn.setupRule(cancellableCtx, pk, rule, reqIDChIn, reqIDChOut) - // adding rule for initiator must result with a route ID for the overall route - // it doesn't matter for now if there was an error, this result will be fetched only if there wasn't one - if i == 0 { - resultingRouteIDCh <- routeID - } + proto, err := sn.dialAndCreateProto(ctx, pk) if err != nil { - // filter out context cancellation errors - if err == context.Canceled { - rulesSetupErrs <- err - } else { - rulesSetupErrs <- fmt.Errorf("rule setup: %s", err) - } - + log.WithError(err).Warn("failed to create proto") + errCh <- err return } + defer sn.closeProto(proto) + log.Debug("proto created successfully") - rulesSetupErrs <- nil - }(i, r[i].To, rule, reqIDChIn, reqIDChOut) + if err := AddRules(ctx, proto, rules); err != nil { + log.WithError(err).Warn("failed to add rules") + errCh <- err + return + } + log.Debug("rules added") + errCh <- nil + }() + } + if err := finalError(len(rulesMap), errCh); err != nil { + return err } - var rulesSetupErr error - var cancelOnce sync.Once - // check for any errors occurred so far - for range r { - // filter out context cancellation errors - if err := <-rulesSetupErrs; err != nil && err != context.Canceled { - // rules setup failed, cancel further setup - cancelOnce.Do(cancel) - rulesSetupErr = err + // 
Confirm loop with responding visor. + err = func() error { + proto, err := sn.dialAndCreateProto(ctx, dst.PubKey) + if err != nil { + return err } - } - cancelOnce.Do(cancel) + defer sn.closeProto(proto) - // close chan to avoid leaks - close(rulesSetupErrs) - for _, ch := range reqIDsCh { - close(ch) - } - if rulesSetupErr != nil { - return 0, rulesSetupErr + data := routing.LoopData{Loop: routing.Loop{Local: dst, Remote: src}, RouteID: dstFwdRID} + return ConfirmLoop(ctx, proto, data) + }() + if err != nil { + return fmt.Errorf("failed to confirm loop with destination visor: %v", err) } - // value gets passed to the chan only if no errors occurred during the route establishment - // errors are being filtered above, so at the moment we get to this part, the value is - // guaranteed to be in the channel - routeID := <-resultingRouteIDCh - close(resultingRouteIDCh) - - return routeID, nil -} + // Confirm loop with initiating visor. + err = func() error { + proto, err := sn.dialAndCreateProto(ctx, src.PubKey) + if err != nil { + return err + } + defer sn.closeProto(proto) -func (sn *Node) connectLoop(ctx context.Context, on cipher.PubKey, ld routing.LoopData) error { - proto, err := sn.dialAndCreateProto(ctx, on) + data := routing.LoopData{Loop: routing.Loop{Local: src, Remote: dst}, RouteID: srcFwdRID} + return ConfirmLoop(ctx, proto, data) + }() if err != nil { - return err - } - defer sn.closeProto(proto) - - if err := ConfirmLoop(ctx, proto, ld); err != nil { - return err + return fmt.Errorf("failed to confirm loop with initiating visor: %v", err) } - sn.Logger.Infof("Confirmed loop on %s with %s. RemotePort: %d. LocalPort: %d", on, ld.Loop.Remote.PubKey, ld.Loop.Remote.Port, ld.Loop.Local.Port) return nil } -// Close closes underlying dmsg client.
-func (sn *Node) Close() error { - if sn == nil { - return nil +func (sn *Node) reserveRouteIDs(ctx context.Context, fwd, rev routing.Route) (*idReservoir, error) { + idc, total := newIDReservoir(fwd, rev) + sn.Logger.Infof("There are %d route IDs to reserve.", total) + + err := idc.ReserveIDs(ctx, func(ctx context.Context, pk cipher.PubKey, n uint8) ([]routing.RouteID, error) { + proto, err := sn.dialAndCreateProto(ctx, pk) + if err != nil { + return nil, err + } + defer sn.closeProto(proto) + return RequestRouteIDs(ctx, proto, n) + }) + if err != nil { + sn.Logger.WithError(err).Warnf("Failed to reserve route IDs.") + return nil, err } - return sn.dmsgC.Close() + sn.Logger.Infof("Successfully reserved route IDs.") + return idc, err } -func (sn *Node) closeLoop(ctx context.Context, on cipher.PubKey, ld routing.LoopData) error { - fmt.Printf(">>> BEGIN: closeLoop(%s, ld)\n", on) - defer fmt.Printf(">>> END: closeLoop(%s, ld)\n", on) - +func (sn *Node) handleCloseLoop(ctx context.Context, on cipher.PubKey, ld routing.LoopData) error { proto, err := sn.dialAndCreateProto(ctx, on) - fmt.Println(">>> *****: closeLoop() dialed:", err) if err != nil { return err } @@ -367,65 +257,7 @@ func (sn *Node) closeLoop(ctx context.Context, on cipher.PubKey, ld routing.Loop return nil } -func (sn *Node) setupRule(ctx context.Context, pk cipher.PubKey, rule routing.Rule, - reqIDChIn <-chan routing.RouteID, reqIDChOut chan<- routing.RouteID) (routing.RouteID, error) { - sn.Logger.Debugf("trying to setup setup rule: %v with %s\n", rule, pk) - requestRouteID, err := sn.requestRouteID(ctx, pk) - if err != nil { - return 0, err - } - - if reqIDChOut != nil { - reqIDChOut <- requestRouteID - } - var nextRouteID routing.RouteID - if reqIDChIn != nil { - nextRouteID = <-reqIDChIn - rule.SetNextRouteID(nextRouteID) - } - - rule.SetKeyRouteID(requestRouteID) - - sn.Logger.Debugf("dialing to %s to setup rule: %v\n", pk, rule) - - if err := sn.addRule(ctx, pk, rule); err != nil { - return 0, err - } - - sn.Logger.Infof("Set rule of type %s on %s", rule.Type(), pk) - - return requestRouteID, nil -} - -func (sn *Node) requestRouteID(ctx context.Context, pk cipher.PubKey) (routing.RouteID, error) { - proto, err := sn.dialAndCreateProto(ctx, pk) - if err != nil { - return 0, err - } - defer sn.closeProto(proto) - - requestRouteID, err := RequestRouteID(ctx, proto) - if err != nil { - return 0, err - } - - sn.Logger.Infof("Received route ID %d from %s", requestRouteID, pk) - - return requestRouteID, nil -} - -func (sn *Node) addRule(ctx context.Context, pk cipher.PubKey, rule routing.Rule) error { - proto, err := sn.dialAndCreateProto(ctx, pk) - if err != nil { - return err - } - defer sn.closeProto(proto) - - return AddRule(ctx, proto, rule) -} - func (sn *Node) dialAndCreateProto(ctx context.Context, pk cipher.PubKey) (*Protocol, error) { - sn.Logger.Debugf("dialing to %s\n", pk) tr, err := sn.dmsgC.Dial(ctx, pk, snet.AwaitSetupPort) if err != nil { return nil, fmt.Errorf("transport: %s", err) diff --git a/pkg/setup/node_test.go b/pkg/setup/node_test.go index 78b5567fc0..7f765e2e3d 100644 --- a/pkg/setup/node_test.go +++ b/pkg/setup/node_test.go @@ -3,9 +3,24 @@ package setup import ( + "context" + "encoding/json" + "errors" + "fmt" "log" "os" "testing" + "time" + + "github.com/skycoin/dmsg" + "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/dmsg/disc" + "github.com/stretchr/testify/require" + "golang.org/x/net/nettest" + + "github.com/skycoin/skywire/pkg/metrics" + "github.com/skycoin/skywire/pkg/routing" + 
"github.com/skycoin/skywire/pkg/snet" "github.com/skycoin/skycoin/src/util/logging" ) @@ -40,324 +55,324 @@ func TestMain(m *testing.M) { // 3. Hanging may be not the problem of the DMSG. Probably some of the communication part here is wrong. // The reason I think so is that - if we ensure read timeouts, why doesn't this test constantly fail? // Maybe some wrapper for DMSG is wrong, or some internal operations before the actual communication behave bad -//func TestNode(t *testing.T) { -// // Prepare mock dmsg discovery. -// discovery := disc.NewMock() -// -// // Prepare dmsg server. -// server, serverErr := createServer(t, discovery) -// defer func() { -// require.NoError(t, server.Close()) -// require.NoError(t, errWithTimeout(serverErr)) -// }() -// -// type clientWithDMSGAddrAndListener struct { -// *dmsg.Client -// Addr dmsg.Addr -// Listener *dmsg.Listener -// } -// -// // CLOSURE: sets up dmsg clients. -// prepClients := func(n int) ([]clientWithDMSGAddrAndListener, func()) { -// clients := make([]clientWithDMSGAddrAndListener, n) -// for i := 0; i < n; i++ { -// var port uint16 -// // setup node -// if i == 0 { -// port = snet.SetupPort -// } else { -// port = snet.AwaitSetupPort -// } -// pk, sk, err := cipher.GenerateDeterministicKeyPair([]byte{byte(i)}) -// require.NoError(t, err) -// t.Logf("client[%d] PK: %s\n", i, pk) -// c := dmsg.NewClient(pk, sk, discovery, dmsg.SetLogger(logging.MustGetLogger(fmt.Sprintf("client_%d:%s:%d", i, pk, port)))) -// require.NoError(t, c.InitiateServerConnections(context.TODO(), 1)) -// listener, err := c.Listen(port) -// require.NoError(t, err) -// clients[i] = clientWithDMSGAddrAndListener{ -// Client: c, -// Addr: dmsg.Addr{ -// PK: pk, -// Port: port, -// }, -// Listener: listener, -// } -// } -// return clients, func() { -// for _, c := range clients { -// //require.NoError(t, c.Listener.Close()) -// require.NoError(t, c.Close()) -// } -// } -// } -// -// // CLOSURE: sets up setup node. -// prepSetupNode := func(c *dmsg.Client, listener *dmsg.Listener) (*Node, func()) { -// sn := &Node{ -// Logger: logging.MustGetLogger("setup_node"), -// dmsgC: c, -// dmsgL: listener, -// metrics: metrics.NewDummy(), -// } -// go func() { -// if err := sn.Serve(context.TODO()); err != nil { -// sn.Logger.WithError(err).Error("Failed to serve") -// } -// }() -// return sn, func() { -// require.NoError(t, sn.Close()) -// } -// } -// -// // TEST: Emulates the communication between 4 visor nodes and a setup node, -// // where the first client node initiates a loop to the last. -// t.Run("CreateLoop", func(t *testing.T) { -// // client index 0 is for setup node. -// // clients index 1 to 4 are for visor nodes. -// clients, closeClients := prepClients(5) -// defer closeClients() -// -// // prepare and serve setup node (using client 0). -// _, closeSetup := prepSetupNode(clients[0].Client, clients[0].Listener) -// setupPK := clients[0].Addr.PK -// setupPort := clients[0].Addr.Port -// defer closeSetup() -// -// // prepare loop creation (client_1 will use this to request loop creation with setup node). 
-// ld := routing.LoopDescriptor{ -// Loop: routing.Loop{ -// Local: routing.Addr{PubKey: clients[1].Addr.PK, Port: 1}, -// Remote: routing.Addr{PubKey: clients[4].Addr.PK, Port: 1}, -// }, -// Reverse: routing.Route{ -// &routing.Hop{From: clients[1].Addr.PK, To: clients[2].Addr.PK, Transport: uuid.New()}, -// &routing.Hop{From: clients[2].Addr.PK, To: clients[3].Addr.PK, Transport: uuid.New()}, -// &routing.Hop{From: clients[3].Addr.PK, To: clients[4].Addr.PK, Transport: uuid.New()}, -// }, -// Forward: routing.Route{ -// &routing.Hop{From: clients[4].Addr.PK, To: clients[3].Addr.PK, Transport: uuid.New()}, -// &routing.Hop{From: clients[3].Addr.PK, To: clients[2].Addr.PK, Transport: uuid.New()}, -// &routing.Hop{From: clients[2].Addr.PK, To: clients[1].Addr.PK, Transport: uuid.New()}, -// }, -// KeepAlive: 1 * time.Hour, -// } -// -// // client_1 initiates loop creation with setup node. -// iTp, err := clients[1].Dial(context.TODO(), setupPK, setupPort) -// require.NoError(t, err) -// iTpErrs := make(chan error, 2) -// go func() { -// iTpErrs <- CreateLoop(context.TODO(), NewSetupProtocol(iTp), ld) -// iTpErrs <- iTp.Close() -// close(iTpErrs) -// }() -// defer func() { -// i := 0 -// for err := range iTpErrs { -// require.NoError(t, err, i) -// i++ -// } -// }() -// -// var addRuleDone sync.WaitGroup -// var nextRouteID uint32 -// // CLOSURE: emulates how a visor node should react when expecting an AddRules packet. -// expectAddRules := func(client int, expRule routing.RuleType) { -// conn, err := clients[client].Listener.Accept() -// require.NoError(t, err) -// -// fmt.Printf("client %v:%v accepted\n", client, clients[client].Addr) -// -// proto := NewSetupProtocol(conn) -// -// pt, _, err := proto.ReadPacket() -// require.NoError(t, err) -// require.Equal(t, PacketRequestRouteID, pt) -// -// fmt.Printf("client %v:%v got PacketRequestRouteID\n", client, clients[client].Addr) -// -// routeID := atomic.AddUint32(&nextRouteID, 1) -// -// // TODO: This error is not checked due to a bug in dmsg. -// _ = proto.WritePacket(RespSuccess, []routing.RouteID{routing.RouteID(routeID)}) // nolint:errcheck -// require.NoError(t, err) -// -// fmt.Printf("client %v:%v responded to with registration ID: %v\n", client, clients[client].Addr, routeID) -// -// require.NoError(t, conn.Close()) -// -// conn, err = clients[client].Listener.Accept() -// require.NoError(t, err) -// -// fmt.Printf("client %v:%v accepted 2nd time\n", client, clients[client].Addr) -// -// proto = NewSetupProtocol(conn) -// -// pt, pp, err := proto.ReadPacket() -// require.NoError(t, err) -// require.Equal(t, PacketAddRules, pt) -// -// fmt.Printf("client %v:%v got PacketAddRules\n", client, clients[client].Addr) -// -// var rs []routing.Rule -// require.NoError(t, json.Unmarshal(pp, &rs)) -// -// for _, r := range rs { -// require.Equal(t, expRule, r.Type()) -// } -// -// // TODO: This error is not checked due to a bug in dmsg. -// err = proto.WritePacket(RespSuccess, nil) -// _ = err -// -// fmt.Printf("client %v:%v responded for PacketAddRules\n", client, clients[client].Addr) -// -// require.NoError(t, conn.Close()) -// -// addRuleDone.Done() -// } -// -// // CLOSURE: emulates how a visor node should react when expecting an OnConfirmLoop packet. 
-// expectConfirmLoop := func(client int) { -// tp, err := clients[client].Listener.AcceptTransport() -// require.NoError(t, err) -// -// proto := NewSetupProtocol(tp) -// -// pt, pp, err := proto.ReadPacket() -// require.NoError(t, err) -// require.Equal(t, PacketConfirmLoop, pt) -// -// var d routing.LoopData -// require.NoError(t, json.Unmarshal(pp, &d)) -// -// switch client { -// case 1: -// require.Equal(t, ld.Loop, d.Loop) -// case 4: -// require.Equal(t, ld.Loop.Local, d.Loop.Remote) -// require.Equal(t, ld.Loop.Remote, d.Loop.Local) -// default: -// t.Fatalf("We shouldn't be receiving a OnConfirmLoop packet from client %d", client) -// } -// -// // TODO: This error is not checked due to a bug in dmsg. -// err = proto.WritePacket(RespSuccess, nil) -// _ = err -// -// require.NoError(t, tp.Close()) -// } -// -// // since the route establishment is asynchronous, -// // we must expect all the messages in parallel -// addRuleDone.Add(4) -// go expectAddRules(4, routing.RuleApp) -// go expectAddRules(3, routing.RuleForward) -// go expectAddRules(2, routing.RuleForward) -// go expectAddRules(1, routing.RuleForward) -// addRuleDone.Wait() -// fmt.Println("FORWARD ROUTE DONE") -// addRuleDone.Add(4) -// go expectAddRules(1, routing.RuleApp) -// go expectAddRules(2, routing.RuleForward) -// go expectAddRules(3, routing.RuleForward) -// go expectAddRules(4, routing.RuleForward) -// addRuleDone.Wait() -// fmt.Println("REVERSE ROUTE DONE") -// expectConfirmLoop(1) -// expectConfirmLoop(4) -// }) -// -// // TEST: Emulates the communication between 2 visor nodes and a setup nodes, -// // where a route is already established, -// // and the first client attempts to tear it down. -// t.Run("CloseLoop", func(t *testing.T) { -// // client index 0 is for setup node. -// // clients index 1 and 2 are for visor nodes. -// clients, closeClients := prepClients(3) -// defer closeClients() -// -// // prepare and serve setup node. -// _, closeSetup := prepSetupNode(clients[0].Client, clients[0].Listener) -// setupPK := clients[0].Addr.PK -// setupPort := clients[0].Addr.Port -// defer closeSetup() -// -// // prepare loop data describing the loop that is to be closed. -// ld := routing.LoopData{ -// Loop: routing.Loop{ -// Local: routing.Addr{ -// PubKey: clients[1].Addr.PK, -// Port: 1, -// }, -// Remote: routing.Addr{ -// PubKey: clients[2].Addr.PK, -// Port: 2, -// }, -// }, -// RouteID: 3, -// } -// -// // client_1 initiates close loop with setup node. -// iTp, err := clients[1].Dial(context.TODO(), setupPK, setupPort) -// require.NoError(t, err) -// iTpErrs := make(chan error, 2) -// go func() { -// iTpErrs <- CloseLoop(context.TODO(), NewSetupProtocol(iTp), ld) -// iTpErrs <- iTp.Close() -// close(iTpErrs) -// }() -// defer func() { -// i := 0 -// for err := range iTpErrs { -// require.NoError(t, err, i) -// i++ -// } -// }() -// -// // client_2 accepts close request. -// tp, err := clients[2].Listener.AcceptTransport() -// require.NoError(t, err) -// defer func() { require.NoError(t, tp.Close()) }() -// -// proto := NewSetupProtocol(tp) -// -// pt, pp, err := proto.ReadPacket() -// require.NoError(t, err) -// require.Equal(t, PacketLoopClosed, pt) -// -// var d routing.LoopData -// require.NoError(t, json.Unmarshal(pp, &d)) -// require.Equal(t, ld.Loop.Remote, d.Loop.Local) -// require.Equal(t, ld.Loop.Local, d.Loop.Remote) -// -// // TODO: This error is not checked due to a bug in dmsg. 
-// err = proto.WritePacket(RespSuccess, nil) -// _ = err -// }) -//} -// -//func createServer(t *testing.T, dc disc.APIClient) (srv *dmsg.Server, srvErr <-chan error) { -// pk, sk, err := cipher.GenerateDeterministicKeyPair([]byte("s")) -// require.NoError(t, err) -// l, err := nettest.NewLocalListener("tcp") -// require.NoError(t, err) -// srv, err = dmsg.NewServer(pk, sk, "", l, dc) -// require.NoError(t, err) -// errCh := make(chan error, 1) -// go func() { -// errCh <- srv.Serve() -// close(errCh) -// }() -// return srv, errCh -//} -// -//func errWithTimeout(ch <-chan error) error { -// select { -// case err := <-ch: -// return err -// case <-time.After(5 * time.Second): -// return errors.New("timeout") -// } -//} +func TestNode(t *testing.T) { + // Prepare mock dmsg discovery. + discovery := disc.NewMock() + + // Prepare dmsg server. + server, serverErr := createServer(t, discovery) + defer func() { + require.NoError(t, server.Close()) + require.NoError(t, errWithTimeout(serverErr)) + }() + + type clientWithDMSGAddrAndListener struct { + *dmsg.Client + Addr dmsg.Addr + Listener *dmsg.Listener + } + + // CLOSURE: sets up dmsg clients. + prepClients := func(n int) ([]clientWithDMSGAddrAndListener, func()) { + clients := make([]clientWithDMSGAddrAndListener, n) + for i := 0; i < n; i++ { + var port uint16 + // setup node + if i == 0 { + port = snet.SetupPort + } else { + port = snet.AwaitSetupPort + } + pk, sk, err := cipher.GenerateDeterministicKeyPair([]byte{byte(i)}) + require.NoError(t, err) + t.Logf("client[%d] PK: %s\n", i, pk) + c := dmsg.NewClient(pk, sk, discovery, dmsg.SetLogger(logging.MustGetLogger(fmt.Sprintf("client_%d:%s:%d", i, pk, port)))) + require.NoError(t, c.InitiateServerConnections(context.TODO(), 1)) + listener, err := c.Listen(port) + require.NoError(t, err) + clients[i] = clientWithDMSGAddrAndListener{ + Client: c, + Addr: dmsg.Addr{ + PK: pk, + Port: port, + }, + Listener: listener, + } + } + return clients, func() { + for _, c := range clients { + //require.NoError(t, c.Listener.Close()) + require.NoError(t, c.Close()) + } + } + } + + // CLOSURE: sets up setup node. + prepSetupNode := func(c *dmsg.Client, listener *dmsg.Listener) (*Node, func()) { + sn := &Node{ + Logger: logging.MustGetLogger("setup_node"), + dmsgC: c, + dmsgL: listener, + metrics: metrics.NewDummy(), + } + go func() { + if err := sn.Serve(context.TODO()); err != nil { + sn.Logger.WithError(err).Error("Failed to serve") + } + }() + return sn, func() { + require.NoError(t, sn.Close()) + } + } + + //// TEST: Emulates the communication between 4 visor nodes and a setup node, + //// where the first client node initiates a loop to the last. + //t.Run("CreateLoop", func(t *testing.T) { + // // client index 0 is for setup node. + // // clients index 1 to 4 are for visor nodes. + // clients, closeClients := prepClients(5) + // defer closeClients() + // + // // prepare and serve setup node (using client 0). + // _, closeSetup := prepSetupNode(clients[0].Client, clients[0].Listener) + // setupPK := clients[0].Addr.PK + // setupPort := clients[0].Addr.Port + // defer closeSetup() + // + // // prepare loop creation (client_1 will use this to request loop creation with setup node). 
+ // ld := routing.LoopDescriptor{ + // Loop: routing.Loop{ + // Local: routing.Addr{PubKey: clients[1].Addr.PK, Port: 1}, + // Remote: routing.Addr{PubKey: clients[4].Addr.PK, Port: 1}, + // }, + // Reverse: routing.Route{ + // &routing.Hop{From: clients[1].Addr.PK, To: clients[2].Addr.PK, Transport: uuid.New()}, + // &routing.Hop{From: clients[2].Addr.PK, To: clients[3].Addr.PK, Transport: uuid.New()}, + // &routing.Hop{From: clients[3].Addr.PK, To: clients[4].Addr.PK, Transport: uuid.New()}, + // }, + // Forward: routing.Route{ + // &routing.Hop{From: clients[4].Addr.PK, To: clients[3].Addr.PK, Transport: uuid.New()}, + // &routing.Hop{From: clients[3].Addr.PK, To: clients[2].Addr.PK, Transport: uuid.New()}, + // &routing.Hop{From: clients[2].Addr.PK, To: clients[1].Addr.PK, Transport: uuid.New()}, + // }, + // KeepAlive: 1 * time.Hour, + // } + // + // // client_1 initiates loop creation with setup node. + // iTp, err := clients[1].Dial(context.TODO(), setupPK, setupPort) + // require.NoError(t, err) + // iTpErrs := make(chan error, 2) + // go func() { + // iTpErrs <- CreateLoop(context.TODO(), NewSetupProtocol(iTp), ld) + // iTpErrs <- iTp.Close() + // close(iTpErrs) + // }() + // defer func() { + // i := 0 + // for err := range iTpErrs { + // require.NoError(t, err, i) + // i++ + // } + // }() + // + // var addRuleDone sync.WaitGroup + // var nextRouteID uint32 + // // CLOSURE: emulates how a visor node should react when expecting an AddRules packet. + // expectAddRules := func(client int, expRule routing.RuleType) { + // conn, err := clients[client].Listener.Accept() + // require.NoError(t, err) + // + // fmt.Printf("client %v:%v accepted\n", client, clients[client].Addr) + // + // proto := NewSetupProtocol(conn) + // + // pt, _, err := proto.ReadPacket() + // require.NoError(t, err) + // require.Equal(t, PacketRequestRouteID, pt) + // + // fmt.Printf("client %v:%v got PacketRequestRouteID\n", client, clients[client].Addr) + // + // routeID := atomic.AddUint32(&nextRouteID, 1) + // + // // TODO: This error is not checked due to a bug in dmsg. + // _ = proto.WritePacket(RespSuccess, []routing.RouteID{routing.RouteID(routeID)}) // nolint:errcheck + // require.NoError(t, err) + // + // fmt.Printf("client %v:%v responded to with registration ID: %v\n", client, clients[client].Addr, routeID) + // + // require.NoError(t, conn.Close()) + // + // conn, err = clients[client].Listener.Accept() + // require.NoError(t, err) + // + // fmt.Printf("client %v:%v accepted 2nd time\n", client, clients[client].Addr) + // + // proto = NewSetupProtocol(conn) + // + // pt, pp, err := proto.ReadPacket() + // require.NoError(t, err) + // require.Equal(t, PacketAddRules, pt) + // + // fmt.Printf("client %v:%v got PacketAddRules\n", client, clients[client].Addr) + // + // var rs []routing.Rule + // require.NoError(t, json.Unmarshal(pp, &rs)) + // + // for _, r := range rs { + // require.Equal(t, expRule, r.Type()) + // } + // + // // TODO: This error is not checked due to a bug in dmsg. + // err = proto.WritePacket(RespSuccess, nil) + // _ = err + // + // fmt.Printf("client %v:%v responded for PacketAddRules\n", client, clients[client].Addr) + // + // require.NoError(t, conn.Close()) + // + // addRuleDone.Done() + // } + // + // // CLOSURE: emulates how a visor node should react when expecting an OnConfirmLoop packet. 
+ // expectConfirmLoop := func(client int) { + // tp, err := clients[client].Listener.AcceptTransport() + // require.NoError(t, err) + // + // proto := NewSetupProtocol(tp) + // + // pt, pp, err := proto.ReadPacket() + // require.NoError(t, err) + // require.Equal(t, PacketConfirmLoop, pt) + // + // var d routing.LoopData + // require.NoError(t, json.Unmarshal(pp, &d)) + // + // switch client { + // case 1: + // require.Equal(t, ld.Loop, d.Loop) + // case 4: + // require.Equal(t, ld.Loop.Local, d.Loop.Remote) + // require.Equal(t, ld.Loop.Remote, d.Loop.Local) + // default: + // t.Fatalf("We shouldn't be receiving a OnConfirmLoop packet from client %d", client) + // } + // + // // TODO: This error is not checked due to a bug in dmsg. + // err = proto.WritePacket(RespSuccess, nil) + // _ = err + // + // require.NoError(t, tp.Close()) + // } + // + // // since the route establishment is asynchronous, + // // we must expect all the messages in parallel + // addRuleDone.Add(4) + // go expectAddRules(4, routing.RuleApp) + // go expectAddRules(3, routing.RuleForward) + // go expectAddRules(2, routing.RuleForward) + // go expectAddRules(1, routing.RuleForward) + // addRuleDone.Wait() + // fmt.Println("FORWARD ROUTE DONE") + // addRuleDone.Add(4) + // go expectAddRules(1, routing.RuleApp) + // go expectAddRules(2, routing.RuleForward) + // go expectAddRules(3, routing.RuleForward) + // go expectAddRules(4, routing.RuleForward) + // addRuleDone.Wait() + // fmt.Println("REVERSE ROUTE DONE") + // expectConfirmLoop(1) + // expectConfirmLoop(4) + //}) + + // TEST: Emulates the communication between 2 visor nodes and a setup nodes, + // where a route is already established, + // and the first client attempts to tear it down. + t.Run("CloseLoop", func(t *testing.T) { + // client index 0 is for setup node. + // clients index 1 and 2 are for visor nodes. + clients, closeClients := prepClients(3) + defer closeClients() + + // prepare and serve setup node. + _, closeSetup := prepSetupNode(clients[0].Client, clients[0].Listener) + setupPK := clients[0].Addr.PK + setupPort := clients[0].Addr.Port + defer closeSetup() + + // prepare loop data describing the loop that is to be closed. + ld := routing.LoopData{ + Loop: routing.Loop{ + Local: routing.Addr{ + PubKey: clients[1].Addr.PK, + Port: 1, + }, + Remote: routing.Addr{ + PubKey: clients[2].Addr.PK, + Port: 2, + }, + }, + RouteID: 3, + } + + // client_1 initiates close loop with setup node. + iTp, err := clients[1].Dial(context.TODO(), setupPK, setupPort) + require.NoError(t, err) + iTpErrs := make(chan error, 2) + go func() { + iTpErrs <- CloseLoop(context.TODO(), NewSetupProtocol(iTp), ld) + iTpErrs <- iTp.Close() + close(iTpErrs) + }() + defer func() { + i := 0 + for err := range iTpErrs { + require.NoError(t, err, i) + i++ + } + }() + + // client_2 accepts close request. + tp, err := clients[2].Listener.AcceptTransport() + require.NoError(t, err) + defer func() { require.NoError(t, tp.Close()) }() + + proto := NewSetupProtocol(tp) + + pt, pp, err := proto.ReadPacket() + require.NoError(t, err) + require.Equal(t, PacketLoopClosed, pt) + + var d routing.LoopData + require.NoError(t, json.Unmarshal(pp, &d)) + require.Equal(t, ld.Loop.Remote, d.Loop.Local) + require.Equal(t, ld.Loop.Local, d.Loop.Remote) + + // TODO: This error is not checked due to a bug in dmsg. 
+ err = proto.WritePacket(RespSuccess, nil) + _ = err + }) +} + +func createServer(t *testing.T, dc disc.APIClient) (srv *dmsg.Server, srvErr <-chan error) { + pk, sk, err := cipher.GenerateDeterministicKeyPair([]byte("s")) + require.NoError(t, err) + l, err := nettest.NewLocalListener("tcp") + require.NoError(t, err) + srv, err = dmsg.NewServer(pk, sk, "", l, dc) + require.NoError(t, err) + errCh := make(chan error, 1) + go func() { + errCh <- srv.Serve() + close(errCh) + }() + return srv, errCh +} + +func errWithTimeout(ch <-chan error) error { + select { + case err := <-ch: + return err + case <-time.After(5 * time.Second): + return errors.New("timeout") + } +} diff --git a/pkg/setup/protocol.go b/pkg/setup/protocol.go index 8167c27beb..8421406d97 100644 --- a/pkg/setup/protocol.go +++ b/pkg/setup/protocol.go @@ -34,7 +34,7 @@ func (sp PacketType) String() string { case RespFailure: return "Failure" case PacketRequestRouteID: - return "RequestRouteID" + return "RequestRouteIDs" } return fmt.Sprintf("Unknown(%d)", sp) } @@ -52,7 +52,7 @@ const ( PacketCloseLoop // PacketLoopClosed represents OnLoopClosed foundation packet. PacketLoopClosed - // PacketRequestRouteID represents RequestRouteID foundation packet. + // PacketRequestRouteID represents RequestRouteIDs foundation packet. PacketRequestRouteID // RespFailure represents failure response for a foundation packet. @@ -113,24 +113,24 @@ func (p *Protocol) Close() error { return nil } -// RequestRouteID sends RequestRouteID request. -func RequestRouteID(ctx context.Context, p *Protocol) (routing.RouteID, error) { - if err := p.WritePacket(PacketRequestRouteID, nil); err != nil { - return 0, err +// RequestRouteIDs sends RequestRouteIDs request. +func RequestRouteIDs(ctx context.Context, p *Protocol, n uint8) ([]routing.RouteID, error) { + if err := p.WritePacket(PacketRequestRouteID, n); err != nil { + return nil, err } var res []routing.RouteID if err := readAndDecodePacketWithTimeout(ctx, p, &res); err != nil { - return 0, err + return nil, err } - if len(res) == 0 { - return 0, errors.New("empty response") + if len(res) != int(n) { + return nil, errors.New("invalid response: wrong number of routeIDs") } - return res[0], nil + return res, nil } -// AddRule sends AddRule setup request. -func AddRule(ctx context.Context, p *Protocol, rule routing.Rule) error { - if err := p.WritePacket(PacketAddRules, []routing.Rule{rule}); err != nil { +// AddRules sends AddRule setup request. +func AddRules(ctx context.Context, p *Protocol, rules []routing.Rule) error { + if err := p.WritePacket(PacketAddRules, rules); err != nil { return err } return readAndDecodePacketWithTimeout(ctx, p, nil) @@ -197,6 +197,9 @@ func readAndDecodePacketWithTimeout(ctx context.Context, p *Protocol, v interfac case <-ctx.Done(): return ctx.Err() case <-done: + if err == io.ErrClosedPipe { + return nil + } return err } } diff --git a/pkg/transport/managed_transport.go b/pkg/transport/managed_transport.go index 6f9f7db573..9fe99ac0fe 100644 --- a/pkg/transport/managed_transport.go +++ b/pkg/transport/managed_transport.go @@ -107,6 +107,7 @@ func (mt *ManagedTransport) Serve(readCh chan<- routing.Packet, done <-chan stru mt.connMx.Unlock() }() + // Read loop. go func() { defer func() { mt.log.Infof("closed readPacket loop.") @@ -132,6 +133,7 @@ func (mt *ManagedTransport) Serve(readCh chan<- routing.Packet, done <-chan stru } }() + // Redial loop. 
for { select { case <-mt.done: @@ -224,7 +226,6 @@ func (mt *ManagedTransport) Dial(ctx context.Context) error { return mt.dial(ctx) } -// TODO: Figure out where this fella is called. func (mt *ManagedTransport) dial(ctx context.Context) error { tp, err := mt.n.Dial(mt.netName, mt.rPK, snet.TransportPort) if err != nil { diff --git a/pkg/transport/manager.go b/pkg/transport/manager.go index 0f8a694bfa..cfedc08dd1 100644 --- a/pkg/transport/manager.go +++ b/pkg/transport/manager.go @@ -99,6 +99,8 @@ func (tm *Manager) serve(ctx context.Context) { } }() } + + tm.initTransports(ctx) tm.Logger.Info("transport manager is serving.") // closing logic @@ -115,26 +117,25 @@ func (tm *Manager) serve(ctx context.Context) { } } -// TODO(nkryuchkov): either use or remove if unused -// func (tm *Manager) initTransports(ctx context.Context) { -// tm.mx.Lock() -// defer tm.mx.Unlock() -// -// entries, err := tm.conf.DiscoveryClient.GetTransportsByEdge(ctx, tm.conf.PubKey) -// if err != nil { -// log.Warnf("No transports found for local node: %v", err) -// } -// for _, entry := range entries { -// var ( -// tpType = entry.Entry.Type -// remote = entry.Entry.RemoteEdge(tm.conf.PubKey) -// tpID = entry.Entry.ID -// ) -// if _, err := tm.saveTransport(remote, tpType); err != nil { -// tm.Logger.Warnf("INIT: failed to init tp: type(%s) remote(%s) tpID(%s)", tpType, remote, tpID) -// } -// } -// } +func (tm *Manager) initTransports(ctx context.Context) { + tm.mx.Lock() + defer tm.mx.Unlock() + + entries, err := tm.conf.DiscoveryClient.GetTransportsByEdge(ctx, tm.conf.PubKey) + if err != nil { + log.Warnf("No transports found for local node: %v", err) + } + for _, entry := range entries { + var ( + tpType = entry.Entry.Type + remote = entry.Entry.RemoteEdge(tm.conf.PubKey) + tpID = entry.Entry.ID + ) + if _, err := tm.saveTransport(remote, tpType); err != nil { + tm.Logger.Warnf("INIT: failed to init tp: type(%s) remote(%s) tpID(%s)", tpType, remote, tpID) + } + } +} func (tm *Manager) acceptTransport(ctx context.Context, lis *snet.Listener) error { conn, err := lis.AcceptConn() // TODO: tcp panic.
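Aside: the manager.go hunk above re-enables initTransports, which runs when the manager starts serving and re-establishes every transport that transport discovery has recorded against this visor, logging individual failures instead of aborting. A rough standalone sketch of that best-effort pattern (not part of the diff; Entry, Discovery, and the save callback are hypothetical stand-ins for the real skywire interfaces):

package main

import (
	"context"
	"fmt"
	"log"
)

// Entry is a simplified stand-in for a transport-discovery entry.
type Entry struct {
	Type   string
	Remote string
	ID     string
}

// Discovery is a hypothetical stand-in for the discovery client.
type Discovery interface {
	GetTransportsByEdge(ctx context.Context, pk string) ([]Entry, error)
}

type memDiscovery []Entry

func (d memDiscovery) GetTransportsByEdge(context.Context, string) ([]Entry, error) {
	return d, nil
}

// initTransports mirrors the re-enabled manager logic: best-effort re-init
// of each previously registered transport, logging failures and moving on.
func initTransports(ctx context.Context, dc Discovery, pk string, save func(remote, tpType string) error) {
	entries, err := dc.GetTransportsByEdge(ctx, pk)
	if err != nil {
		log.Printf("No transports found for local node: %v", err)
	}
	for _, e := range entries {
		if err := save(e.Remote, e.Type); err != nil {
			log.Printf("INIT: failed to init tp: type(%s) remote(%s) tpID(%s)", e.Type, e.Remote, e.ID)
		}
	}
}

func main() {
	dc := memDiscovery{{Type: "dmsg", Remote: "remote-pk", ID: "tp-1"}}
	initTransports(context.Background(), dc, "local-pk", func(remote, tpType string) error {
		fmt.Printf("re-established %s transport to %s\n", tpType, remote)
		return nil
	})
}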