diff --git a/retrier/retrier.go b/retrier/retrier.go index 1cd8d47..bb8dc37 100644 --- a/retrier/retrier.go +++ b/retrier/retrier.go @@ -11,11 +11,12 @@ import ( // Retrier implements the "retriable" resiliency pattern, abstracting out the process of retrying a failed action // a certain number of times with an optional back-off between each retry. type Retrier struct { - backoff []time.Duration - class Classifier - jitter float64 - rand *rand.Rand - randMu sync.Mutex + backoff []time.Duration + infiniteRetry bool + class Classifier + jitter float64 + rand *rand.Rand + randMu sync.Mutex } // New constructs a Retrier with the given backoff pattern and classifier. The length of the backoff pattern @@ -34,9 +35,15 @@ func New(backoff []time.Duration, class Classifier) *Retrier { } } +// WithInfiniteRetry set the retrier to loop infinitely on the last backoff duration +func (r *Retrier) WithInfiniteRetry() *Retrier { + r.infiniteRetry = true + return r +} + // Run executes the given work function by executing RunCtx without context.Context. func (r *Retrier) Run(work func() error) error { - return r.RunCtx(context.Background(), func(ctx context.Context) error { + return r.RunCtx(context.Background(), func(ctx context.Context, retries int) error { // never use ctx return work() }) @@ -47,16 +54,16 @@ func (r *Retrier) Run(work func() error) error { // returned to the caller. If the result is Retry, then Run sleeps according to the its backoff policy // before retrying. If the total number of retries is exceeded then the return value of the work function // is returned to the caller regardless. -func (r *Retrier) RunCtx(ctx context.Context, work func(ctx context.Context) error) error { +func (r *Retrier) RunCtx(ctx context.Context, work func(ctx context.Context, retries int) error) error { retries := 0 for { - ret := work(ctx) + ret := work(ctx, retries) switch r.class.Classify(ret) { case Succeed, Fail: return ret case Retry: - if retries >= len(r.backoff) { + if !r.infiniteRetry && retries >= len(r.backoff) { return ret } @@ -84,6 +91,9 @@ func (r *Retrier) sleep(ctx context.Context, timer *time.Timer) error { } func (r *Retrier) calcSleep(i int) time.Duration { + if i >= len(r.backoff) { + i = len(r.backoff) - 1 + } // lock unsafe rand prng r.randMu.Lock() defer r.randMu.Unlock() diff --git a/retrier/retrier_test.go b/retrier/retrier_test.go index aaa1d51..afe7e4f 100644 --- a/retrier/retrier_test.go +++ b/retrier/retrier_test.go @@ -20,9 +20,9 @@ func genWork(returns []error) func() error { } } -func genWorkWithCtx() func(ctx context.Context) error { +func genWorkWithCtx() func(ctx context.Context, retries int) error { i = 0 - return func(ctx context.Context) error { + return func(ctx context.Context, retries int) error { select { case <-ctx.Done(): return errFoo @@ -33,6 +33,15 @@ func genWorkWithCtx() func(ctx context.Context) error { } } +func genWorkWithCtxError(returns []error) func(ctx context.Context, retries int) error { + return func(ctx context.Context, retries int) error { + if retries > len(returns) { + return nil + } + return returns[retries-1] + } +} + func TestRetrier(t *testing.T) { r := New([]time.Duration{0, 10 * time.Millisecond}, WhitelistClassifier{errFoo}) @@ -85,6 +94,38 @@ func TestRetrierCtx(t *testing.T) { } } +func TestRetrierCtxError(t *testing.T) { + ctx := context.Background() + r := New([]time.Duration{0, 10 * time.Millisecond}, nil) + errExpected := []error{errFoo, errFoo, errBar, errBaz} + + err := r.RunCtx(ctx, func(ctx context.Context, retries int) error { + if retries >= len(errExpected) { + return nil + } + return errExpected[retries] + }) + if err != errBar { + t.Error(err) + } +} + +func TestRetrierCtxWithInfinite(t *testing.T) { + ctx := context.Background() + r := New([]time.Duration{0, 10 * time.Millisecond}, nil).WithInfiniteRetry() + errExpected := []error{errFoo, errFoo, errFoo, errBar, errBaz} + + err := r.RunCtx(ctx, func(ctx context.Context, retries int) error { + if retries >= len(errExpected) { + return nil + } + return errExpected[retries] + }) + if err != nil { + t.Error(err) + } +} + func TestRetrierNone(t *testing.T) { r := New(nil, nil)