Skip to content

Commit

Permalink
feat: client circuit breaker #448
Browse files Browse the repository at this point in the history
Reviewed-by: ccoVeille <[email protected]>
  • Loading branch information
segevda committed Dec 31, 2024
1 parent f2a67d3 commit ab9dda9
Show file tree
Hide file tree
Showing 3 changed files with 251 additions and 0 deletions.
163 changes: 163 additions & 0 deletions circuit_breaker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
package resty

import (
"errors"
"net/http"
"sync/atomic"
"time"
)

// CircuitBreaker can be in one of three states: Closed, Open, or Half-Open.
// - When the CircuitBreaker is Closed, requests are allowed to pass through.
// - If a failure count threshold is reached within a specified time-frame,
// the CircuitBreaker transitions to the Open state.
// - When the CircuitBreaker is Open, requests are blocked.
// - After a specified timeout, the CircuitBreaker transitions to the Half-Open state.
// - When the CircuitBreaker is Half-Open, a single request is allowed to pass through.
// - If that request fails, the CircuitBreaker returns to the Open state.
// - If the number of successes reaches a specified threshold,
// the CircuitBreaker transitions to the Closed state.
type CircuitBreaker struct {
policies []CircuitBreakerPolicy
timeout time.Duration
failThreshold, successThreshold uint32

state atomic.Value // circuitBreakerState
failCount, successCount atomic.Uint32
lastFail time.Time
}

// NewCircuitBreaker creates a new [CircuitBreaker] with default settings.
// The default settings are:
// - Timeout: 10 seconds
// - FailThreshold: 3
// - SuccessThreshold: 1
// - Policies: CircuitBreaker5xxPolicy
func NewCircuitBreaker() *CircuitBreaker {
cb := &CircuitBreaker{
policies: []CircuitBreakerPolicy{CircuitBreaker5xxPolicy},
timeout: 10 * time.Second,
failThreshold: 3,
successThreshold: 1,
}
cb.state.Store(circuitBreakerStateClosed)
return cb
}

// SetPolicies sets the CircuitBreakerPolicy's that the [CircuitBreaker] will use to determine whether a response is a failure.
func (cb *CircuitBreaker) SetPolicies(policies []CircuitBreakerPolicy) *CircuitBreaker {
cb.policies = policies
return cb
}

// SetTimeout sets the timeout duration for the [CircuitBreaker].
func (cb *CircuitBreaker) SetTimeout(timeout time.Duration) *CircuitBreaker {
cb.timeout = timeout
return cb
}

// SetFailThreshold sets the number of failures that must occur within the timeout duration for the [CircuitBreaker] to
// transition to the Open state.
func (cb *CircuitBreaker) SetFailThreshold(threshold uint32) *CircuitBreaker {
cb.failThreshold = threshold
return cb
}

// SetSuccessThreshold sets the number of successes that must occur to transition the [CircuitBreaker] from the Half-Open state
// to the Closed state.
func (cb *CircuitBreaker) SetSuccessThreshold(threshold uint32) *CircuitBreaker {
cb.successThreshold = threshold
return cb
}

// CircuitBreakerPolicy is a function that determines whether a response should trip the [CircuitBreaker].
type CircuitBreakerPolicy func(resp *http.Response) bool

// CircuitBreaker5xxPolicy is a [CircuitBreakerPolicy] that trips the [CircuitBreaker] if the response status code is 500 or greater.
func CircuitBreaker5xxPolicy(resp *http.Response) bool {
return resp.StatusCode > 499
}

var ErrCircuitBreakerOpen = errors.New("resty: circuit breaker open")

type circuitBreakerState uint32

const (
circuitBreakerStateClosed circuitBreakerState = iota
circuitBreakerStateOpen
circuitBreakerStateHalfOpen
)

func (cb *CircuitBreaker) getState() circuitBreakerState {
return cb.state.Load().(circuitBreakerState)
}

func (cb *CircuitBreaker) allow() error {
if cb == nil {
return nil
}

if cb.getState() == circuitBreakerStateOpen {
return ErrCircuitBreakerOpen
}

return nil
}

func (cb *CircuitBreaker) applyPolicies(resp *http.Response) {
if cb == nil {
return
}

failed := false
for _, policy := range cb.policies {
if policy(resp) {
failed = true
break
}
}

if failed {
if cb.failCount.Load() > 0 && time.Since(cb.lastFail) > cb.timeout {
cb.failCount.Store(0)
}

switch cb.getState() {
case circuitBreakerStateClosed:
failCount := cb.failCount.Add(1)
if failCount >= cb.failThreshold {
cb.open()
} else {
cb.lastFail = time.Now()
}
case circuitBreakerStateHalfOpen:
cb.open()
}
} else {
switch cb.getState() {
case circuitBreakerStateClosed:
return
case circuitBreakerStateHalfOpen:
successCount := cb.successCount.Add(1)
if successCount >= cb.successThreshold {
cb.changeState(circuitBreakerStateClosed)
}
}
}

return
}

func (cb *CircuitBreaker) open() {
cb.changeState(circuitBreakerStateOpen)
go func() {
time.Sleep(cb.timeout)
cb.changeState(circuitBreakerStateHalfOpen)
}()
}

func (cb *CircuitBreaker) changeState(state circuitBreakerState) {
cb.failCount.Store(0)
cb.successCount.Store(0)
cb.state.Store(state)
}
19 changes: 19 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ type Client struct {
contentDecompresserKeys []string
contentDecompressers map[string]ContentDecompresser
certWatcherStopChan chan bool
circuitBreaker *CircuitBreaker
}

// CertWatcherOptions allows configuring a watcher that reloads dynamically TLS certs.
Expand Down Expand Up @@ -942,6 +943,18 @@ func (c *Client) SetContentDecompresserKeys(keys []string) *Client {
return c
}

// SetCircuitBreaker method sets the Circuit Breaker instance into the client.
// It is used to prevent the client from sending requests that are likely to fail.
// For Example: To use the default Circuit Breaker:
//
// client.SetCircuitBreaker(NewCircuitBreaker())
func (c *Client) SetCircuitBreaker(b *CircuitBreaker) *Client {
c.lock.Lock()
defer c.lock.Unlock()
c.circuitBreaker = b
return c
}

// IsDebug method returns `true` if the client is in debug mode; otherwise, it is `false`.
func (c *Client) IsDebug() bool {
c.lock.RLock()
Expand Down Expand Up @@ -2094,6 +2107,10 @@ func (c *Client) executeRequestMiddlewares(req *Request) (err error) {
// Executes method executes the given `Request` object and returns
// response or error.
func (c *Client) execute(req *Request) (*Response, error) {
if err := c.circuitBreaker.allow(); err != nil {
return nil, err
}

if err := c.executeRequestMiddlewares(req); err != nil {
return nil, err
}
Expand All @@ -2118,6 +2135,8 @@ func (c *Client) execute(req *Request) (*Response, error) {
}
}
if resp != nil {
c.circuitBreaker.applyPolicies(resp)

response.Body = resp.Body
if err = response.wrapContentDecompresser(); err != nil {
return response, err
Expand Down
69 changes: 69 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1421,3 +1421,72 @@ func TestClientDebugf(t *testing.T) {
assertEqual(t, "", b.String())
})
}

var _ CircuitBreakerPolicy = CircuitBreaker5xxPolicy

func TestClientCircuitBreaker(t *testing.T) {
ts := createTestServer(func(w http.ResponseWriter, r *http.Request) {
t.Logf("Method: %v", r.Method)
t.Logf("Path: %v", r.URL.Path)

switch r.URL.Path {
case "/200":
w.WriteHeader(http.StatusOK)
return
case "/500":
w.WriteHeader(http.StatusInternalServerError)
return
}
})
defer ts.Close()

failThreshold := uint32(2)
successThreshold := uint32(1)
timeout := 1 * time.Second

c := dcnl().SetCircuitBreaker(
NewCircuitBreaker().
SetTimeout(timeout).
SetFailThreshold(failThreshold).
SetSuccessThreshold(successThreshold).
SetPolicies([]CircuitBreakerPolicy{CircuitBreaker5xxPolicy}))

for i := uint32(0); i < failThreshold; i++ {
_, err := c.R().Get(ts.URL + "/500")
assertNil(t, err)
}
resp, err := c.R().Get(ts.URL + "/500")
assertErrorIs(t, ErrCircuitBreakerOpen, err)
assertNil(t, resp)
assertEqual(t, circuitBreakerStateOpen, c.circuitBreaker.getState())

time.Sleep(timeout + 1*time.Millisecond)
assertEqual(t, circuitBreakerStateHalfOpen, c.circuitBreaker.getState())

resp, err = c.R().Get(ts.URL + "/500")
assertError(t, err)
assertEqual(t, circuitBreakerStateOpen, c.circuitBreaker.getState())

time.Sleep(timeout + 1*time.Millisecond)
assertEqual(t, circuitBreakerStateHalfOpen, c.circuitBreaker.getState())

for i := uint32(0); i < successThreshold; i++ {
_, err := c.R().Get(ts.URL + "/200")
assertNil(t, err)
}
assertEqual(t, circuitBreakerStateClosed, c.circuitBreaker.getState())

resp, err = c.R().Get(ts.URL + "/200")
assertNil(t, err)
assertEqual(t, http.StatusOK, resp.StatusCode())

resp, err = c.R().Get(ts.URL + "/500")
assertError(t, err)
assertEqual(t, uint32(1), c.circuitBreaker.failCount.Load())

time.Sleep(timeout)

resp, err = c.R().Get(ts.URL + "/500")
assertError(t, err)
assertEqual(t, uint32(1), c.circuitBreaker.failCount.Load())
}

0 comments on commit ab9dda9

Please sign in to comment.