diff --git a/trino/trino.go b/trino/trino.go index 1453f03..6f32316 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -617,7 +617,7 @@ func (c *Conn) roundTrip(ctx context.Context, req *http.Request) (*http.Response } } return resp, nil - case http.StatusServiceUnavailable: + case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: resp.Body.Close() timer.Reset(delay) delay = time.Duration(math.Min( diff --git a/trino/trino_test.go b/trino/trino_test.go index cc4c181..f1623a3 100644 --- a/trino/trino_test.go +++ b/trino/trino_test.go @@ -266,32 +266,62 @@ func TestRegisterCustomClientReserved(t *testing.T) { } func TestRoundTripRetryQueryError(t *testing.T) { - count := 0 - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if count == 0 { - count++ - w.WriteHeader(http.StatusServiceUnavailable) - return - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(&stmtResponse{ - Error: ErrTrino{ - ErrorName: "TEST", - }, - }) - })) + testcases := []struct { + Name string + HttpStatus int + ExpectedErrorStatus string + }{ + { + Name: "Test retry 502 Bad Gateway", + HttpStatus: http.StatusBadGateway, + ExpectedErrorStatus: "200 OK", + }, + { + Name: "Test retry 503 Service Unavailable", + HttpStatus: http.StatusServiceUnavailable, + ExpectedErrorStatus: "200 OK", + }, + { + Name: "Test retry 504 Gateway Timeout", + HttpStatus: http.StatusGatewayTimeout, + ExpectedErrorStatus: "200 OK", + }, + { + Name: "Test no retry 404 Not Found", + HttpStatus: http.StatusNotFound, + ExpectedErrorStatus: "404 Not Found", + }, + } + for _, tc := range testcases { + t.Run(tc.Name, func(t *testing.T) { + count := 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if count == 0 { + count++ + w.WriteHeader(tc.HttpStatus) + return + } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(&stmtResponse{ + Error: ErrTrino{ + ErrorName: "TEST", + }, + }) + })) - t.Cleanup(ts.Close) + t.Cleanup(ts.Close) - db, err := sql.Open("trino", ts.URL) - require.NoError(t, err) + db, err := sql.Open("trino", ts.URL) + require.NoError(t, err) - t.Cleanup(func() { - assert.NoError(t, db.Close()) - }) + t.Cleanup(func() { + assert.NoError(t, db.Close()) + }) - _, err = db.Query("SELECT 1") - assert.IsTypef(t, new(ErrQueryFailed), err, "unexpected error: %w", err) + _, err = db.Query("SELECT 1") + assert.ErrorContains(t, err, tc.ExpectedErrorStatus, "unexpected error: %w", err) + }) + } } func TestRoundTripBogusData(t *testing.T) {