diff --git a/src/anthropic/_base_client.py b/src/anthropic/_base_client.py index 852c3b6f..e3a27338 100644 --- a/src/anthropic/_base_client.py +++ b/src/anthropic/_base_client.py @@ -3,8 +3,10 @@ import json import time import uuid +import email import inspect import platform +import email.utils from types import TracebackType from random import random from typing import ( @@ -616,10 +618,22 @@ def _calculate_retry_timeout( try: # About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After # - # TODO: we may want to handle the case where the header is using the http-date syntax: "Retry-After: # ". See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After#syntax for # details. - retry_after = -1 if response_headers is None else int(response_headers.get("retry-after")) + if response_headers is not None: + retry_header = response_headers.get("retry-after") + try: + retry_after = int(retry_header) + except Exception: + retry_date_tuple = email.utils.parsedate_tz(retry_header) + if retry_date_tuple is None: + retry_after = -1 + else: + retry_date = email.utils.mktime_tz(retry_date_tuple) + retry_after = int(retry_date - time.time()) + else: + retry_after = -1 + except Exception: retry_after = -1 diff --git a/tests/test_client.py b/tests/test_client.py index fdd1e323..4b38d3ba 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -7,6 +7,7 @@ import asyncio import inspect from typing import Any, Dict, Union, cast +from unittest import mock import httpx import pytest @@ -423,6 +424,33 @@ class Model(BaseModel): response = client.get("/foo", cast_to=Model) assert isinstance(response, str) # type: ignore[unreachable] + @pytest.mark.parametrize( + "remaining_retries,retry_after,timeout", + [ + [3, "20", 20], + [3, "0", 2], + [3, "-10", 2], + [3, "60", 60], + [3, "61", 2], + [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20], + [3, "Fri, 29 Sep 2023 16:26:37 GMT", 2], + [3, "Fri, 29 Sep 2023 16:26:27 GMT", 2], + [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60], + [3, "Fri, 29 Sep 2023 16:27:38 GMT", 2], + [3, "99999999999999999999999999999999999", 2], + [3, "Zun, 29 Sep 2023 16:26:27 GMT", 2], + [3, "", 2], + ], + ) + @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) + def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None: + client = Anthropic(base_url=base_url, api_key=api_key, _strict_response_validation=True) + + headers = httpx.Headers({"retry-after": retry_after}) + options = FinalRequestOptions(method="get", url="/foo", max_retries=2) + calculated = client._calculate_retry_timeout(remaining_retries, options, headers) + assert calculated == pytest.approx(timeout, 0.6) # pyright: ignore[reportUnknownMemberType] + class TestAsyncAnthropic: client = AsyncAnthropic(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -821,3 +849,31 @@ class Model(BaseModel): response = await client.get("/foo", cast_to=Model) assert isinstance(response, str) # type: ignore[unreachable] + + @pytest.mark.parametrize( + "remaining_retries,retry_after,timeout", + [ + [3, "20", 20], + [3, "0", 2], + [3, "-10", 2], + [3, "60", 60], + [3, "61", 2], + [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20], + [3, "Fri, 29 Sep 2023 16:26:37 GMT", 2], + [3, "Fri, 29 Sep 2023 16:26:27 GMT", 2], + [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60], + [3, "Fri, 29 Sep 2023 16:27:38 GMT", 2], + [3, "99999999999999999999999999999999999", 2], + [3, "Zun, 29 Sep 2023 16:26:27 GMT", 2], + [3, "", 2], + ], + ) + @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) + @pytest.mark.asyncio + async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None: + client = AsyncAnthropic(base_url=base_url, api_key=api_key, _strict_response_validation=True) + + headers = httpx.Headers({"retry-after": retry_after}) + options = FinalRequestOptions(method="get", url="/foo", max_retries=2) + calculated = client._calculate_retry_timeout(remaining_retries, options, headers) + assert calculated == pytest.approx(timeout, 0.6) # pyright: ignore[reportUnknownMemberType]