diff --git a/README.md b/README.md index 3d607542..836dc982 100644 --- a/README.md +++ b/README.md @@ -360,6 +360,7 @@ import httpx from anthropic import Anthropic client = Anthropic( + # Or use the `ANTHROPIC_BASE_URL` env var base_url="http://my.test.server.example.com:8083", http_client=httpx.Client( proxies="http://my.test.proxy.example.com", diff --git a/src/anthropic/_client.py b/src/anthropic/_client.py index 62f80393..ab7a5cc6 100644 --- a/src/anthropic/_client.py +++ b/src/anthropic/_client.py @@ -103,6 +103,8 @@ def __init__( auth_token = os.environ.get("ANTHROPIC_AUTH_TOKEN") self.auth_token = auth_token + if base_url is None: + base_url = os.environ.get("ANTHROPIC_BASE_URL") if base_url is None: base_url = f"https://api.anthropic.com" @@ -362,6 +364,8 @@ def __init__( auth_token = os.environ.get("ANTHROPIC_AUTH_TOKEN") self.auth_token = auth_token + if base_url is None: + base_url = os.environ.get("ANTHROPIC_BASE_URL") if base_url is None: base_url = f"https://api.anthropic.com" diff --git a/tests/test_client.py b/tests/test_client.py index 8cfa59ab..7d03435d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -27,6 +27,8 @@ make_request_options, ) +from .utils import update_env + base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") api_key = "my-anthropic-api-key" @@ -408,6 +410,11 @@ class Model2(BaseModel): assert isinstance(response, Model1) assert response.foo == 1 + def test_base_url_env(self) -> None: + with update_env(ANTHROPIC_BASE_URL="http://localhost:5000/from/env"): + client = Anthropic(api_key=api_key, _strict_response_validation=True) + assert client.base_url == "http://localhost:5000/from/env/" + @pytest.mark.parametrize( "client", [ @@ -1036,6 +1043,11 @@ class Model2(BaseModel): assert isinstance(response, Model1) assert response.foo == 1 + def test_base_url_env(self) -> None: + with update_env(ANTHROPIC_BASE_URL="http://localhost:5000/from/env"): + client = AsyncAnthropic(api_key=api_key, _strict_response_validation=True) + assert client.base_url == "http://localhost:5000/from/env/" + @pytest.mark.parametrize( "client", [ diff --git a/tests/utils.py b/tests/utils.py index 0a1733f6..348363a5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,7 +1,9 @@ from __future__ import annotations +import os import traceback -from typing import Any, TypeVar, cast +import contextlib +from typing import Any, TypeVar, Iterator, cast from datetime import date, datetime from typing_extensions import Literal, get_args, get_origin, assert_type @@ -103,3 +105,16 @@ def _assert_list_type(type_: type[object], value: object) -> None: inner_type = get_args(type_)[0] for entry in value: assert_type(inner_type, entry) # type: ignore + + +@contextlib.contextmanager +def update_env(**new_env: str) -> Iterator[None]: + old = os.environ.copy() + + try: + os.environ.update(new_env) + + yield None + finally: + os.environ.clear() + os.environ.update(old)