Skip to content

Commit

Permalink
Add Circuit Breaker Logic
Browse files Browse the repository at this point in the history
  • Loading branch information
segevda committed Dec 24, 2024
1 parent 8422694 commit fb9decd
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 0 deletions.
159 changes: 159 additions & 0 deletions breaker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package resty

import (
"errors"
"net/http"
"sync/atomic"
"time"
)

// CircuitBreaker can be in one of three states: Closed, Open, or Half-Open.
// When the circuit breaker is Closed, requests are allowed to pass through.
// If a failure count threshold is reached within a specified time-frame, the circuit breaker transitions to the Open state.
// When the circuit breaker is Open, requests are blocked.
// After a specified timeout, the circuit breaker transitions to the Half-Open state.
// When the circuit breaker is Half-Open, a single request is allowed to pass through.
// If that request fails, the circuit breaker returns to the Open state.
// If that request succeeds, the circuit breaker returns to the Closed state.
type CircuitBreaker struct {
policies []BreakerPolicy
timeout time.Duration
failThreshold, successThreshold int

state circuitBreakerState
failCount, successCount int
lastFail time.Time
}

// NewCircuitBreaker creates a new CircuitBreaker with default settings.
// The default settings are:
// - Timeout: 10 seconds
// - FailThreshold: 3
// - SuccessThreshold: 1
// - Policies: Count5xxPolicy
func NewCircuitBreaker() *CircuitBreaker {
return &CircuitBreaker{
policies: []BreakerPolicy{Count5xxPolicy},
timeout: 10 * time.Second,
failThreshold: 3,
successThreshold: 1,
}
}

// Policies sets the BreakerPolicy's that the CircuitBreaker will use to determine whether a response is a failure.
func (cb *CircuitBreaker) Policies(policies []BreakerPolicy) *CircuitBreaker {
cb.policies = policies
return cb
}

// Timeout sets the timeout duration for the CircuitBreaker
func (cb *CircuitBreaker) Timeout(timeout time.Duration) *CircuitBreaker {
cb.timeout = timeout
return cb
}

// FailThreshold sets the number of failures that must occur within the timeout duration for the CircuitBreaker to
// transition to the Open state.
func (cb *CircuitBreaker) FailThreshold(threshold int) *CircuitBreaker {
cb.failThreshold = threshold
return cb
}

// SuccessThreshold sets the number of successes that must occur to transition the CircuitBreaker from the Half-Open state
// to the Closed state.
func (cb *CircuitBreaker) SuccessThreshold(threshold int) *CircuitBreaker {
cb.successThreshold = threshold
return cb
}

// BreakerPolicy is a function that determines whether a response should trip the circuit breaker
type BreakerPolicy func(resp *http.Response) bool

// Count5xxPolicy is a BreakerPolicy that trips the circuit breaker if the response status code is 500 or greater
func Count5xxPolicy(resp *http.Response) bool {
return resp.StatusCode > 499
}

var ErrBreakerOpen = errors.New("circuit breaker open")

type circuitBreakerState uint32

const (
closed circuitBreakerState = iota
open
halfOpen
)

func (cb *CircuitBreaker) getState() circuitBreakerState {
return circuitBreakerState(atomic.LoadUint32((*uint32)(&cb.state)))
}

func (cb *CircuitBreaker) allow() error {
if cb == nil {
return nil
}

if cb.getState() == open {
return ErrBreakerOpen
}

return nil
}

func (cb *CircuitBreaker) processResponse(resp *http.Response) {
if cb == nil {
return
}

failed := false
for _, policy := range cb.policies {
if policy(resp) {
failed = true
break
}
}

if failed {
if cb.failCount > 0 && time.Since(cb.lastFail) > cb.timeout {
cb.failCount = 0
}

switch cb.getState() {
case closed:
cb.failCount++
if cb.failCount == cb.failThreshold {
cb.open()
} else {
cb.lastFail = time.Now()
}
case halfOpen:
cb.open()
}
} else {
switch cb.getState() {
case closed:
return
case halfOpen:
cb.successCount++
if cb.successCount == cb.successThreshold {
cb.changeState(closed)
}
}
}

return
}

func (cb *CircuitBreaker) open() {
cb.changeState(open)
go func() {
time.Sleep(cb.timeout)
cb.changeState(halfOpen)
}()
}

func (cb *CircuitBreaker) changeState(state circuitBreakerState) {
cb.failCount = 0
cb.successCount = 0
atomic.StoreUint32((*uint32)(&cb.state), uint32(state))
}
19 changes: 19 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ type Client struct {
contentDecompresserKeys []string
contentDecompressers map[string]ContentDecompresser
certWatcherStopChan chan bool
circuitBreaker *CircuitBreaker
}

// CertWatcherOptions allows configuring a watcher that reloads dynamically TLS certs.
Expand Down Expand Up @@ -939,6 +940,18 @@ func (c *Client) SetContentDecompresserKeys(keys []string) *Client {
return c
}

// SetCircuitBreaker method sets the Circuit Breaker instance into the client.
// It is used to prevent the client from sending requests that are likely to fail.
// For Example: To use the default Circuit Breaker:
//
// client.SetCircuitBreaker(NewCircuitBreaker())
func (c *Client) SetCircuitBreaker(b *CircuitBreaker) *Client {
c.lock.Lock()
defer c.lock.Unlock()
c.circuitBreaker = b
return c
}

// IsDebug method returns `true` if the client is in debug mode; otherwise, it is `false`.
func (c *Client) IsDebug() bool {
c.lock.RLock()
Expand Down Expand Up @@ -2066,6 +2079,10 @@ func (c *Client) executeRequestMiddlewares(req *Request) (err error) {
// Executes method executes the given `Request` object and returns
// response or error.
func (c *Client) execute(req *Request) (*Response, error) {
if err := c.circuitBreaker.allow(); err != nil {
return nil, err
}

if err := c.executeRequestMiddlewares(req); err != nil {
return nil, err
}
Expand All @@ -2090,6 +2107,8 @@ func (c *Client) execute(req *Request) (*Response, error) {
}
}
if resp != nil {
c.circuitBreaker.processResponse(resp)

response.Body = resp.Body
if err = response.wrapContentDecompresser(); err != nil {
return response, err
Expand Down
67 changes: 67 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1414,3 +1414,70 @@ func TestClientDebugf(t *testing.T) {
assertEqual(t, "", b.String())
})
}

func TestClientCircuitBreaker(t *testing.T) {
ts := createTestServer(func(w http.ResponseWriter, r *http.Request) {
t.Logf("Method: %v", r.Method)
t.Logf("Path: %v", r.URL.Path)

switch r.URL.Path {
case "/200":
w.WriteHeader(http.StatusOK)
return
case "/500":
w.WriteHeader(http.StatusInternalServerError)
return
}
})
defer ts.Close()

failThreshold := 2
successThreshold := 1
timeout := 1 * time.Second

c := dcnl().SetCircuitBreaker(
NewCircuitBreaker().
Timeout(timeout).
FailThreshold(failThreshold).
SuccessThreshold(successThreshold).
Policies([]BreakerPolicy{Count5xxPolicy}))

for i := 0; i < failThreshold; i++ {
_, err := c.R().Get(ts.URL + "/500")
assertNil(t, err)
}
resp, err := c.R().Get(ts.URL + "/500")
assertErrorIs(t, ErrBreakerOpen, err)
assertNil(t, resp)
assertEqual(t, c.circuitBreaker.getState(), open)

time.Sleep(timeout + 1*time.Millisecond)
assertEqual(t, c.circuitBreaker.getState(), halfOpen)

resp, err = c.R().Get(ts.URL + "/500")
assertError(t, err)
assertEqual(t, c.circuitBreaker.getState(), open)

time.Sleep(timeout + 1*time.Millisecond)
assertEqual(t, c.circuitBreaker.getState(), halfOpen)

for i := 0; i < successThreshold; i++ {
_, err := c.R().Get(ts.URL + "/200")
assertNil(t, err)
}
assertEqual(t, c.circuitBreaker.getState(), closed)

resp, err = c.R().Get(ts.URL + "/200")
assertNil(t, err)
assertEqual(t, http.StatusOK, resp.StatusCode())

resp, err = c.R().Get(ts.URL + "/500")
assertError(t, err)
assertEqual(t, c.circuitBreaker.failCount, 1)

time.Sleep(timeout)

resp, err = c.R().Get(ts.URL + "/500")
assertError(t, err)
assertEqual(t, c.circuitBreaker.failCount, 1)
}

0 comments on commit fb9decd

Please sign in to comment.