Skip to content

Commit

Permalink
Add proper handling of Close packets by the router
Browse files Browse the repository at this point in the history
  • Loading branch information
Darkren committed Dec 25, 2019
1 parent 61b5506 commit f45ffec
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 44 deletions.
75 changes: 40 additions & 35 deletions pkg/router/route_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,42 +242,9 @@ func (rg *RouteGroup) tp() (*transport.ManagedTransport, error) {
return tp, nil
}

// Close closes a RouteGroup:
// - Send Close packet for all ForwardRules.
// - Delete all rules (ForwardRules and ConsumeRules) from routing table.
// - Close all go channels.
// Close closes a RouteGroup.
func (rg *RouteGroup) Close() error {
rg.mu.Lock()
defer rg.mu.Unlock()

if len(rg.fwd) != len(rg.tps) {
return ErrRuleTransportMismatch
}

for i := 0; i < len(rg.tps); i++ {
packet := routing.MakeClosePacket(rg.fwd[i].KeyRouteID(), routing.CloseRequested)
if err := rg.tps[i].WritePacket(context.Background(), packet); err != nil {
return err
}
}

rules := rg.rt.RulesWithDesc(rg.desc)
routeIDs := make([]routing.RouteID, 0, len(rules))

for _, rule := range rules {
routeIDs = append(routeIDs, rule.KeyRouteID())
}

rg.rt.DelRules(routeIDs)

rg.once.Do(func() {
close(rg.done)
rg.readChMu.Lock()
close(rg.readCh)
rg.readChMu.Unlock()
})

return nil
return rg.close(routing.CloseRequested)
}

// LocalAddr returns destination address of underlying RouteDescriptor.
Expand Down Expand Up @@ -352,6 +319,44 @@ func (rg *RouteGroup) sendKeepAlive() error {
return nil
}

// Close closes a RouteGroup with the specified close `code`:
// - Send Close packet for all ForwardRules with the code `code`.
// - Delete all rules (ForwardRules and ConsumeRules) from routing table.
// - Close all go channels.
func (rg *RouteGroup) close(code routing.CloseCode) error {
rg.mu.Lock()
defer rg.mu.Unlock()

if len(rg.fwd) != len(rg.tps) {
return ErrRuleTransportMismatch
}

for i := 0; i < len(rg.tps); i++ {
packet := routing.MakeClosePacket(rg.fwd[i].KeyRouteID(), code)
if err := rg.tps[i].WritePacket(context.Background(), packet); err != nil {
return err
}
}

rules := rg.rt.RulesWithDesc(rg.desc)
routeIDs := make([]routing.RouteID, 0, len(rules))

for _, rule := range rules {
routeIDs = append(routeIDs, rule.KeyRouteID())
}

rg.rt.DelRules(routeIDs)

rg.once.Do(func() {
close(rg.done)
rg.readChMu.Lock()
close(rg.readCh)
rg.readChMu.Unlock()
})

return nil
}

func (rg *RouteGroup) isClosed() bool {
select {
case <-rg.done:
Expand Down
64 changes: 55 additions & 9 deletions pkg/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,6 @@ func (r *router) handleTransportPacket(ctx context.Context, packet routing.Packe

func (r *router) handleDataPacket(ctx context.Context, packet routing.Packet) error {
rule, err := r.GetRule(packet.RouteID())

if err != nil {
return err
}
Expand Down Expand Up @@ -370,7 +369,6 @@ func (r *router) handleDataPacket(ctx context.Context, packet routing.Packet) er
r.logger.Infof("Packet contents (len = %d): %v", len(packet.Payload()), packet.Payload())

if rg.isClosed() {
r.logger.Infoln("RG IS CLOSED")
return io.ErrClosedPipe
}

Expand All @@ -379,23 +377,69 @@ func (r *router) handleDataPacket(ctx context.Context, packet routing.Packet) er

select {
case <-rg.done:
r.logger.Infof("RG IS DONE")
return io.ErrClosedPipe
case rg.readCh <- packet.Payload():
r.logger.Infof("PUT PAYLOAD INTO RG READ CHAN")
return nil
}
}

func (r *router) handleClosePacket(_ context.Context, packet routing.Packet) error {
func (r *router) handleClosePacket(ctx context.Context, packet routing.Packet) error {
routeID := packet.RouteID()

r.logger.Infof("Received keepalive packet for route ID %v", routeID)
r.logger.Infof("Received close packet for route ID %v", routeID)

rules := []routing.RouteID{routeID}
r.rt.DelRules(rules)
rule, err := r.GetRule(routeID)
if err != nil {
return err
}

return nil
if t := rule.Type(); t == routing.RuleIntermediaryForward {
r.logger.Infoln("Handling intermediary close packet")

// defer this only on intermediary nodes. destination node will remove
// the needed rules in the route group `Close` routine
defer func() {
routeIDs := []routing.RouteID{routeID}
r.rt.DelRules(routeIDs)
}()

return r.forwardPacket(ctx, packet, rule)
}

desc := rule.RouteDescriptor()
rg, ok := r.routeGroup(desc)

r.logger.Infof("Handling close packet with descriptor %s", &desc)

if !ok {
r.logger.Infof("Descriptor not found for rule with type %s, descriptor: %s", rule.Type(), &desc)
return errors.New("route descriptor does not exist")
}

if rg == nil {
return errors.New("RouteGroup is nil")
}

r.logger.Infof("Got new remote close packet with route ID %d. Using rule: %s", packet.RouteID(), rule)
r.logger.Infof("Packet contents (len = %d): %v", len(packet.Payload()), packet.Payload())

if rg.isClosed() {
return io.ErrClosedPipe
}

rg.mu.Lock()
defer rg.mu.Unlock()

select {
case <-rg.done:
return io.ErrClosedPipe
default:
if err := rg.Close(); err != nil {
return fmt.Errorf("error closing route group with descriptor %s: %w", &desc, err)
}

return nil
}
}

func (r *router) handleKeepAlivePacket(ctx context.Context, packet routing.Packet) error {
Expand Down Expand Up @@ -477,6 +521,8 @@ func (r *router) forwardPacket(ctx context.Context, packet routing.Packet, rule
p = routing.MakeDataPacket(rule.NextRouteID(), packet.Payload())
case routing.KeepAlivePacket:
p = routing.MakeKeepAlivePacket(rule.NextRouteID())
case routing.ClosePacket:
p = routing.MakeClosePacket(rule.NextRouteID(), routing.CloseCode(packet.Payload()[0]))
default:
return fmt.Errorf("packet of type %s can't be forwarded", packet.Type())
}
Expand Down

0 comments on commit f45ffec

Please sign in to comment.