From f0d616968f132bf69e16702014ec15989188d0de Mon Sep 17 00:00:00 2001 From: Segev Dagan Date: Mon, 23 Dec 2024 13:28:46 +0200 Subject: [PATCH] Add Circuit Breaker Logic --- breaker.go | 159 +++++++++++++++++++++++++++++++++++++++++++++++++ client.go | 19 ++++++ client_test.go | 67 +++++++++++++++++++++ 3 files changed, 245 insertions(+) create mode 100644 breaker.go diff --git a/breaker.go b/breaker.go new file mode 100644 index 00000000..c8526e63 --- /dev/null +++ b/breaker.go @@ -0,0 +1,159 @@ +package resty + +import ( + "errors" + "net/http" + "sync/atomic" + "time" +) + +// CircuitBreaker can be in one of three states: Closed, Open, or Half-Open. +// When the circuit breaker is Closed, requests are allowed to pass through. +// If a failure count threshold is reached within a specified time-frame, the circuit breaker transitions to the Open state. +// When the circuit breaker is Open, requests are blocked. +// After a specified timeout, the circuit breaker transitions to the Half-Open state. +// When the circuit breaker is Half-Open, a single request is allowed to pass through. +// If that request fails, the circuit breaker returns to the Open state. +// If SuccessThreshold requests succeed, the circuit breaker transitions back to the Closed state. +type CircuitBreaker struct { + policies []BreakerPolicy + timeout time.Duration + failThreshold, successThreshold int + + state circuitBreakerState + failCount, successCount int + lastFail time.Time +} + +// NewCircuitBreaker creates a new CircuitBreaker with default settings. +// The default settings are: +// - Timeout: 10 seconds +// - FailThreshold: 3 +// - SuccessThreshold: 1 +// - Policies: Count5xxPolicy +func NewCircuitBreaker() *CircuitBreaker { + return &CircuitBreaker{ + policies: []BreakerPolicy{Count5xxPolicy}, + timeout: 10 * time.Second, + failThreshold: 3, + successThreshold: 1, + } +} + +// Policies sets the BreakerPolicy's that the CircuitBreaker will use to determine whether a response is a failure. +func (cb *CircuitBreaker) Policies(policies []BreakerPolicy) *CircuitBreaker { + cb.policies = policies + return cb +} + +// Timeout sets the timeout duration for the CircuitBreaker +func (cb *CircuitBreaker) Timeout(timeout time.Duration) *CircuitBreaker { + cb.timeout = timeout + return cb +} + +// FailThreshold sets the number of failures that must occur within the timeout duration for the CircuitBreaker to +// transition to the Open state. +func (cb *CircuitBreaker) FailThreshold(threshold int) *CircuitBreaker { + cb.failThreshold = threshold + return cb +} + +// SuccessThreshold sets the number of successes that must occur to transition the CircuitBreaker from the Half-Open state +// to the Closed state. +func (cb *CircuitBreaker) SuccessThreshold(threshold int) *CircuitBreaker { + cb.successThreshold = threshold + return cb +} + +// BreakerPolicy is a function that determines whether a response should trip the circuit breaker +type BreakerPolicy func(resp *http.Response) bool + +// Count5xxPolicy is a BreakerPolicy that trips the circuit breaker if the response status code is 500 or greater +func Count5xxPolicy(resp *http.Response) bool { + return resp.StatusCode > 499 +} + +var ErrBreakerOpen = errors.New("circuit breaker open") + +type circuitBreakerState uint32 + +const ( + closed circuitBreakerState = iota + open + halfOpen +) + +func (cb *CircuitBreaker) getState() circuitBreakerState { + return circuitBreakerState(atomic.LoadUint32((*uint32)(&cb.state))) +} + +func (cb *CircuitBreaker) allow() error { + if cb == nil { + return nil + } + + if cb.getState() == open { + return ErrBreakerOpen + } + + return nil +} + +func (cb *CircuitBreaker) processResponse(resp *http.Response) { + if cb == nil { + return + } + + failed := false + for _, policy := range cb.policies { + if policy(resp) { + failed = true + break + } + } + + if failed { + if cb.failCount > 0 && time.Since(cb.lastFail) > cb.timeout { + cb.failCount = 0 + } + + switch cb.getState() { + case closed: + cb.failCount++ + if cb.failCount == cb.failThreshold { + cb.open() + } else { + cb.lastFail = time.Now() + } + case halfOpen: + cb.open() + } + } else { + switch cb.getState() { + case closed: + return + case halfOpen: + cb.successCount++ + if cb.successCount == cb.successThreshold { + cb.changeState(closed) + } + } + } + + return +} + +func (cb *CircuitBreaker) open() { + cb.changeState(open) + go func() { + time.Sleep(cb.timeout) + cb.changeState(halfOpen) + }() +} + +func (cb *CircuitBreaker) changeState(state circuitBreakerState) { + cb.failCount = 0 + cb.successCount = 0 + atomic.StoreUint32((*uint32)(&cb.state), uint32(state)) +} diff --git a/client.go b/client.go index 07a7f126..85930d88 100644 --- a/client.go +++ b/client.go @@ -221,6 +221,7 @@ type Client struct { contentDecompresserKeys []string contentDecompressers map[string]ContentDecompresser certWatcherStopChan chan bool + circuitBreaker *CircuitBreaker } // CertWatcherOptions allows configuring a watcher that reloads dynamically TLS certs. @@ -939,6 +940,18 @@ func (c *Client) SetContentDecompresserKeys(keys []string) *Client { return c } +// SetCircuitBreaker method sets the Circuit Breaker instance into the client. +// It is used to prevent the client from sending requests that are likely to fail. +// For Example: To use the default Circuit Breaker: +// +// client.SetCircuitBreaker(NewCircuitBreaker()) +func (c *Client) SetCircuitBreaker(b *CircuitBreaker) *Client { + c.lock.Lock() + defer c.lock.Unlock() + c.circuitBreaker = b + return c +} + // IsDebug method returns `true` if the client is in debug mode; otherwise, it is `false`. func (c *Client) IsDebug() bool { c.lock.RLock() @@ -2066,6 +2079,10 @@ func (c *Client) executeRequestMiddlewares(req *Request) (err error) { // Executes method executes the given `Request` object and returns // response or error. func (c *Client) execute(req *Request) (*Response, error) { + if err := c.circuitBreaker.allow(); err != nil { + return nil, err + } + if err := c.executeRequestMiddlewares(req); err != nil { return nil, err } @@ -2090,6 +2107,8 @@ func (c *Client) execute(req *Request) (*Response, error) { } } if resp != nil { + c.circuitBreaker.processResponse(resp) + response.Body = resp.Body if err = response.wrapContentDecompresser(); err != nil { return response, err diff --git a/client_test.go b/client_test.go index 94e4c9d9..a8553e2f 100644 --- a/client_test.go +++ b/client_test.go @@ -1414,3 +1414,70 @@ func TestClientDebugf(t *testing.T) { assertEqual(t, "", b.String()) }) } + +func TestClientCircuitBreaker(t *testing.T) { + ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { + t.Logf("Method: %v", r.Method) + t.Logf("Path: %v", r.URL.Path) + + switch r.URL.Path { + case "/200": + w.WriteHeader(http.StatusOK) + return + case "/500": + w.WriteHeader(http.StatusInternalServerError) + return + } + }) + defer ts.Close() + + failThreshold := 2 + successThreshold := 1 + timeout := 1 * time.Second + + c := dcnl().SetCircuitBreaker( + NewCircuitBreaker(). + Timeout(timeout). + FailThreshold(failThreshold). + SuccessThreshold(successThreshold). + Policies([]BreakerPolicy{Count5xxPolicy})) + + for i := 0; i < failThreshold; i++ { + _, err := c.R().Get(ts.URL + "/500") + assertNil(t, err) + } + resp, err := c.R().Get(ts.URL + "/500") + assertErrorIs(t, ErrBreakerOpen, err) + assertNil(t, resp) + assertEqual(t, c.circuitBreaker.getState(), open) + + time.Sleep(timeout + 1*time.Millisecond) + assertEqual(t, c.circuitBreaker.getState(), halfOpen) + + resp, err = c.R().Get(ts.URL + "/500") + assertError(t, err) + assertEqual(t, c.circuitBreaker.getState(), open) + + time.Sleep(timeout + 1*time.Millisecond) + assertEqual(t, c.circuitBreaker.getState(), halfOpen) + + for i := 0; i < successThreshold; i++ { + _, err := c.R().Get(ts.URL + "/200") + assertNil(t, err) + } + assertEqual(t, c.circuitBreaker.getState(), closed) + + resp, err = c.R().Get(ts.URL + "/200") + assertNil(t, err) + assertEqual(t, http.StatusOK, resp.StatusCode()) + + resp, err = c.R().Get(ts.URL + "/500") + assertError(t, err) + assertEqual(t, c.circuitBreaker.failCount, 1) + + time.Sleep(timeout) + + resp, err = c.R().Get(ts.URL + "/500") + assertError(t, err) + assertEqual(t, c.circuitBreaker.failCount, 1) +}