From 8ec4e2a73d65668ce058d3910733144a083c5b4e Mon Sep 17 00:00:00 2001 From: Nikita Kryuchkov Date: Mon, 18 Nov 2019 19:47:09 +0300 Subject: [PATCH] Add config to RouteGroup --- pkg/router/route_group.go | 52 +++++++++++++++++++++------------- pkg/router/route_group_test.go | 4 ++- pkg/router/router.go | 18 ++++++------ 3 files changed, 44 insertions(+), 30 deletions(-) diff --git a/pkg/router/route_group.go b/pkg/router/route_group.go index e3b93db602..3ab1a77952 100644 --- a/pkg/router/route_group.go +++ b/pkg/router/route_group.go @@ -19,7 +19,8 @@ import ( ) const ( - readChBufSize = 1024 + defaultRouteGroupKeepAliveInterval = 1 * time.Minute + defaultReadChBufSize = 1024 ) var ( @@ -31,15 +32,26 @@ var ( ErrBadTransport = errors.New("bad transport") ) +type RouteGroupConfig struct { + ReadChBufSize int + KeepAliveInterval time.Duration +} + +func DefaultRouteGroupConfig() *RouteGroupConfig { + return &RouteGroupConfig{ + KeepAliveInterval: defaultRouteGroupKeepAliveInterval, + ReadChBufSize: defaultReadChBufSize, + } +} + // RouteGroup should implement 'io.ReadWriteCloser'. // It implements 'net.Conn'. type RouteGroup struct { mu sync.RWMutex logger *logging.Logger - - desc routing.RouteDescriptor // describes the route group - rt routing.Table + desc routing.RouteDescriptor // describes the route group + rt routing.Table // 'tps' is transports used for writing/forward rules. // It should have the same number of elements as 'fwd' @@ -64,7 +76,11 @@ type RouteGroup struct { once sync.Once } -func NewRouteGroup(rt routing.Table, desc routing.RouteDescriptor) *RouteGroup { +func NewRouteGroup(cfg *RouteGroupConfig, rt routing.Table, desc routing.RouteDescriptor) *RouteGroup { + if cfg == nil { + cfg = DefaultRouteGroupConfig() + } + rg := &RouteGroup{ logger: logging.MustGetLogger(fmt.Sprintf("RouteGroup %v", desc)), desc: desc, @@ -72,12 +88,12 @@ func NewRouteGroup(rt routing.Table, desc routing.RouteDescriptor) *RouteGroup { tps: make([]*transport.ManagedTransport, 0), fwd: make([]routing.Rule, 0), rvs: make([]routing.Rule, 0), - readCh: make(chan []byte, readChBufSize), + readCh: make(chan []byte, cfg.ReadChBufSize), readBuf: bytes.Buffer{}, done: make(chan struct{}), } - go rg.keepAliveLoop() + go rg.keepAliveLoop(cfg.KeepAliveInterval) return rg } @@ -104,7 +120,7 @@ func (r *RouteGroup) Read(p []byte) (n int, err error) { // Write writes payload to a RouteGroup // For the first version, only the first ForwardRule (fwd[0]) is used for writing. func (r *RouteGroup) Write(p []byte) (n int, err error) { - if r.isClosing() { + if r.isClosed() { return 0, io.ErrClosedPipe } @@ -192,16 +208,14 @@ func (r *RouteGroup) SetWriteDeadline(t time.Time) error { return nil } -func (r *RouteGroup) keepAliveLoop() { - keepAlive := 1 * time.Minute // TODO: proper value - - ticker := time.NewTicker(keepAlive / 2) +func (r *RouteGroup) keepAliveLoop(interval time.Duration) { + ticker := time.NewTicker(interval) defer ticker.Stop() for range ticker.C { lastSent := time.Unix(0, atomic.LoadInt64(&r.lastSent)) - if time.Since(lastSent) < keepAlive/2 { + if time.Since(lastSent) < interval { continue } @@ -215,18 +229,16 @@ func (r *RouteGroup) sendKeepAlive() error { r.mu.Lock() defer r.mu.Unlock() - if len(r.tps) == 0 { - return ErrNoTransports - } - if len(r.fwd) == 0 { - return ErrNoRules + if len(r.tps) == 0 || len(r.fwd) == 0 { + // if no transports, no rules, then no keepalive + return nil } tp := r.tps[0] rule := r.fwd[0] if tp == nil { - return errors.New("unknown transport") + return ErrBadTransport } packet := routing.MakeKeepAlivePacket(rule.KeyRouteID()) @@ -236,7 +248,7 @@ func (r *RouteGroup) sendKeepAlive() error { return nil } -func (r *RouteGroup) isClosing() bool { +func (r *RouteGroup) isClosed() bool { select { case <-r.done: return true diff --git a/pkg/router/route_group_test.go b/pkg/router/route_group_test.go index da1c7bc3a6..f7ec71fa24 100644 --- a/pkg/router/route_group_test.go +++ b/pkg/router/route_group_test.go @@ -18,8 +18,10 @@ func TestNewRouteGroup(t *testing.T) { port2 := routing.Port(2) desc := routing.NewRouteDescriptor(pk1, pk2, port1, port2) - rg := NewRouteGroup(rt, desc) + rg := NewRouteGroup(DefaultRouteGroupConfig(), rt, desc) require.NotNil(t, rg) + require.False(t, rg.isClosed()) + require.NoError(t, rg.Close()) } diff --git a/pkg/router/router.go b/pkg/router/router.go index 3d75070f5b..e31fa78ec7 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -25,13 +25,18 @@ import ( const ( // DefaultRouteKeepAlive is the default expiration interval for routes - DefaultRouteKeepAlive = 2 * time.Hour // TODO(nkryuchkov): change + DefaultRouteKeepAlive = 2 * time.Minute acceptSize = 1024 minHops = 0 maxHops = 50 ) +var ( + // ErrUnknownPacketType is returned when packet type is unknown. + ErrUnknownPacketType = errors.New("unknown packet type") +) + var log = logging.MustGetLogger("router") // Config configures Router. @@ -88,15 +93,10 @@ type Router interface { // - Save to routing.Table and internal RouteGroup map. // - Return the RoutingGroup. AcceptRoutes(context.Context) (*RouteGroup, error) - SaveRoutingRules(rules ...routing.Rule) error - ReserveKeys(n int) ([]routing.RouteID, error) - IntroduceRules(rules routing.EdgeRules) error - Serve(context.Context) error - SetupIsTrusted(cipher.PubKey) bool } @@ -276,7 +276,7 @@ func (r *router) saveRouteGroupRules(rules routing.EdgeRules) *RouteGroup { rg, ok := r.rgs[rules.Desc] if !ok || rg == nil { - rg = NewRouteGroup(r.rt, rules.Desc) + rg = NewRouteGroup(DefaultRouteGroupConfig(), r.rt, rules.Desc) r.rgs[rules.Desc] = rg } @@ -298,7 +298,7 @@ func (r *router) handleTransportPacket(ctx context.Context, packet routing.Packe case routing.KeepAlivePacket: return r.handleKeepAlivePacket(ctx, packet) default: - return errors.New("unknown packet type") + return ErrUnknownPacketType } } @@ -323,7 +323,7 @@ func (r *router) handleDataPacket(ctx context.Context, packet routing.Packet) er return r.forwardPacket(ctx, packet.Payload(), rule) } - if rg.isClosing() { + if rg.isClosed() { return io.ErrClosedPipe }