From 33607808ff9a375a9b84b02a57abacb6db59cb6d Mon Sep 17 00:00:00 2001 From: Julian Spring <51660117+springjd@users.noreply.github.com> Date: Fri, 13 Dec 2024 15:13:09 -0500 Subject: [PATCH] Add retries for 502, 504 HTTP statuses Add 502 Bad Gateway and 504 Gateway Timeout HTTP status to client's retry logic Add tests for retries on 502,504 and test that no retry is done on another status (404) --- trino/trino.go | 2 +- trino/trino_test.go | 74 +++++++++++++++++++++++++++++++-------------- 2 files changed, 53 insertions(+), 23 deletions(-) diff --git a/trino/trino.go b/trino/trino.go index 1453f03..6f32316 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -617,7 +617,7 @@ func (c *Conn) roundTrip(ctx context.Context, req *http.Request) (*http.Response } } return resp, nil - case http.StatusServiceUnavailable: + case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: resp.Body.Close() timer.Reset(delay) delay = time.Duration(math.Min( diff --git a/trino/trino_test.go b/trino/trino_test.go index cc4c181..f1623a3 100644 --- a/trino/trino_test.go +++ b/trino/trino_test.go @@ -266,32 +266,62 @@ func TestRegisterCustomClientReserved(t *testing.T) { } func TestRoundTripRetryQueryError(t *testing.T) { - count := 0 - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if count == 0 { - count++ - w.WriteHeader(http.StatusServiceUnavailable) - return - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(&stmtResponse{ - Error: ErrTrino{ - ErrorName: "TEST", - }, - }) - })) + testcases := []struct { + Name string + HttpStatus int + ExpectedErrorStatus string + }{ + { + Name: "Test retry 502 Bad Gateway", + HttpStatus: http.StatusBadGateway, + ExpectedErrorStatus: "200 OK", + }, + { + Name: "Test retry 503 Service Unavailable", + HttpStatus: http.StatusServiceUnavailable, + ExpectedErrorStatus: "200 OK", + }, + { + Name: "Test retry 504 Gateway Timeout", + HttpStatus: http.StatusGatewayTimeout, + ExpectedErrorStatus: "200 OK", + }, + { + Name: "Test no retry 404 Not Found", + HttpStatus: http.StatusNotFound, + ExpectedErrorStatus: "404 Not Found", + }, + } + for _, tc := range testcases { + t.Run(tc.Name, func(t *testing.T) { + count := 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if count == 0 { + count++ + w.WriteHeader(tc.HttpStatus) + return + } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(&stmtResponse{ + Error: ErrTrino{ + ErrorName: "TEST", + }, + }) + })) - t.Cleanup(ts.Close) + t.Cleanup(ts.Close) - db, err := sql.Open("trino", ts.URL) - require.NoError(t, err) + db, err := sql.Open("trino", ts.URL) + require.NoError(t, err) - t.Cleanup(func() { - assert.NoError(t, db.Close()) - }) + t.Cleanup(func() { + assert.NoError(t, db.Close()) + }) - _, err = db.Query("SELECT 1") - assert.IsTypef(t, new(ErrQueryFailed), err, "unexpected error: %w", err) + _, err = db.Query("SELECT 1") + assert.ErrorContains(t, err, tc.ExpectedErrorStatus, "unexpected error: %w", err) + }) + } } func TestRoundTripBogusData(t *testing.T) {