From b732c7736674e7d5495f5cd73bee2c6255072235 Mon Sep 17 00:00:00 2001 From: Julian Spring <51660117+springjd@users.noreply.github.com> Date: Fri, 13 Dec 2024 15:13:09 -0500 Subject: [PATCH] Add retries for 502, 504 HTTP statuses Add 502 Bad Gateway and 504 Gateway Timeout HTTP status to client's retry logic Add tests for retries on 502,504 and test that no retry is done on another status (404) --- trino/trino.go | 2 +- trino/trino_test.go | 74 +++++++++++++++++++++++++++++++-------------- 2 files changed, 53 insertions(+), 23 deletions(-) diff --git a/trino/trino.go b/trino/trino.go index 1453f03..6f32316 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -617,7 +617,7 @@ func (c *Conn) roundTrip(ctx context.Context, req *http.Request) (*http.Response } } return resp, nil - case http.StatusServiceUnavailable: + case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: resp.Body.Close() timer.Reset(delay) delay = time.Duration(math.Min( diff --git a/trino/trino_test.go b/trino/trino_test.go index cc4c181..f1623a3 100644 --- a/trino/trino_test.go +++ b/trino/trino_test.go @@ -266,32 +266,62 @@ func TestRegisterCustomClientReserved(t *testing.T) { } func TestRoundTripRetryQueryError(t *testing.T) { - count := 0 - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if count == 0 { - count++ - w.WriteHeader(http.StatusServiceUnavailable) - return - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(&stmtResponse{ - Error: ErrTrino{ - ErrorName: "TEST", - }, - }) - })) + testcases := []struct { + Name string + HttpStatus int + ExpectedErrorStatus string + }{ + { + Name: "Test retry 502 Bad Gateway", + HttpStatus: http.StatusBadGateway, + ExpectedErrorStatus: "200 OK", + }, + { + Name: "Test retry 503 Service Unavailable", + HttpStatus: http.StatusServiceUnavailable, + ExpectedErrorStatus: "200 OK", + }, + { + Name: "Test retry 504 Gateway Timeout", + HttpStatus: http.StatusGatewayTimeout, + ExpectedErrorStatus: "200 OK", + }, + { + Name: "Test no retry 404 Not Found", + HttpStatus: http.StatusNotFound, + ExpectedErrorStatus: "404 Not Found", + }, + } + for _, tc := range testcases { + t.Run(tc.Name, func(t *testing.T) { + count := 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if count == 0 { + count++ + w.WriteHeader(tc.HttpStatus) + return + } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(&stmtResponse{ + Error: ErrTrino{ + ErrorName: "TEST", + }, + }) + })) - t.Cleanup(ts.Close) + t.Cleanup(ts.Close) - db, err := sql.Open("trino", ts.URL) - require.NoError(t, err) + db, err := sql.Open("trino", ts.URL) + require.NoError(t, err) - t.Cleanup(func() { - assert.NoError(t, db.Close()) - }) + t.Cleanup(func() { + assert.NoError(t, db.Close()) + }) - _, err = db.Query("SELECT 1") - assert.IsTypef(t, new(ErrQueryFailed), err, "unexpected error: %w", err) + _, err = db.Query("SELECT 1") + assert.ErrorContains(t, err, tc.ExpectedErrorStatus, "unexpected error: %w", err) + }) + } } func TestRoundTripBogusData(t *testing.T) {