diff --git a/examples/vertex.py b/examples/vertex.py new file mode 100644 index 00000000..825cf95c --- /dev/null +++ b/examples/vertex.py @@ -0,0 +1,43 @@ +import asyncio + +from anthropic import AnthropicVertex, AsyncAnthropicVertex + + +def sync_client() -> None: + print("------ Sync Vertex ------") + + client = AnthropicVertex() + + message = client.beta.messages.create( + model="claude-instant-1p2", + max_tokens=100, + messages=[ + { + "role": "user", + "content": "Say hello there!", + } + ], + ) + print(message.model_dump_json(indent=2)) + + +async def async_client() -> None: + print("------ Async Vertex ------") + + client = AsyncAnthropicVertex() + + message = await client.beta.messages.create( + model="claude-instant-1p2", + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "Say hello there!", + } + ], + ) + print(message.model_dump_json(indent=2)) + + +sync_client() +asyncio.run(async_client()) diff --git a/pyproject.toml b/pyproject.toml index bdef10d0..806feea4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,8 @@ classifiers = [ "License :: OSI Approved :: MIT License" ] - +[project.optional-dependencies] +vertex = ["google-auth >=2, <3"] [project.urls] Homepage = "https://github.com/anthropics/anthropic-sdk-python" diff --git a/requirements-dev.lock b/requirements-dev.lock index 8ff81764..07c3326f 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -11,6 +11,7 @@ annotated-types==0.6.0 anyio==4.1.0 argcomplete==3.1.2 attrs==23.1.0 +cachetools==5.3.2 certifi==2023.7.22 charset-normalizer==3.3.2 colorlog==6.7.0 @@ -20,6 +21,7 @@ distro==1.8.0 exceptiongroup==1.1.3 filelock==3.12.4 fsspec==2023.12.2 +google-auth==2.26.2 h11==0.14.0 httpcore==1.0.2 httpx==0.25.2 @@ -35,6 +37,8 @@ packaging==23.2 platformdirs==3.11.0 pluggy==1.3.0 py==1.11.0 +pyasn1==0.5.1 +pyasn1-modules==0.3.0 pydantic==2.4.2 pydantic-core==2.10.1 pyright==1.1.332 @@ -45,6 +49,7 @@ pytz==2023.3.post1 pyyaml==6.0.1 requests==2.31.0 respx==0.20.2 +rsa==4.9 ruff==0.1.9 six==1.16.0 sniffio==1.3.0 diff --git a/requirements.lock b/requirements.lock index 54d268da..e4692c10 100644 --- a/requirements.lock +++ b/requirements.lock @@ -9,22 +9,27 @@ -e file:. annotated-types==0.6.0 anyio==4.1.0 +cachetools==5.3.2 certifi==2023.7.22 charset-normalizer==3.3.2 distro==1.8.0 exceptiongroup==1.1.3 filelock==3.13.1 fsspec==2023.12.0 +google-auth==2.26.2 h11==0.14.0 httpcore==1.0.2 httpx==0.25.2 huggingface-hub==0.16.4 idna==3.4 packaging==23.2 +pyasn1==0.5.1 +pyasn1-modules==0.3.0 pydantic==2.4.2 pydantic-core==2.10.1 pyyaml==6.0.1 requests==2.31.0 +rsa==4.9 sniffio==1.3.0 tokenizers==0.14.0 tqdm==4.66.1 diff --git a/src/anthropic/__init__.py b/src/anthropic/__init__.py index 09afc2a1..3dfa40fc 100644 --- a/src/anthropic/__init__.py +++ b/src/anthropic/__init__.py @@ -69,6 +69,7 @@ "AI_PROMPT", ] +from .lib.vertex import * from .lib.streaming import * _setup_logging() diff --git a/src/anthropic/lib/_extras/__init__.py b/src/anthropic/lib/_extras/__init__.py new file mode 100644 index 00000000..4e3037ee --- /dev/null +++ b/src/anthropic/lib/_extras/__init__.py @@ -0,0 +1 @@ +from ._google_auth import google_auth as google_auth diff --git a/src/anthropic/lib/_extras/_common.py b/src/anthropic/lib/_extras/_common.py new file mode 100644 index 00000000..5d2b7f6a --- /dev/null +++ b/src/anthropic/lib/_extras/_common.py @@ -0,0 +1,13 @@ +from ..._exceptions import AnthropicError + +INSTRUCTIONS = """ + +Anthropic error: missing required dependency `{library}`. + + $ pip install anthropic[{extra}] +""" + + +class MissingDependencyError(AnthropicError): + def __init__(self, *, library: str, extra: str) -> None: + super().__init__(INSTRUCTIONS.format(library=library, extra=extra)) diff --git a/src/anthropic/lib/_extras/_google_auth.py b/src/anthropic/lib/_extras/_google_auth.py new file mode 100644 index 00000000..16cc7909 --- /dev/null +++ b/src/anthropic/lib/_extras/_google_auth.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +from typing_extensions import ClassVar, override + +from ._common import MissingDependencyError +from ..._utils import LazyProxy + +if TYPE_CHECKING: + import google.auth # type: ignore + + google_auth = google.auth + + +class GoogleAuthProxy(LazyProxy[Any]): + should_cache: ClassVar[bool] = True + + @override + def __load__(self) -> Any: + try: + import google.auth # type: ignore + except ImportError as err: + raise MissingDependencyError(extra="vertex", library="google-auth") from err + + return google.auth + + +if not TYPE_CHECKING: + google_auth = GoogleAuthProxy() diff --git a/src/anthropic/lib/vertex/__init__.py b/src/anthropic/lib/vertex/__init__.py new file mode 100644 index 00000000..45b6301e --- /dev/null +++ b/src/anthropic/lib/vertex/__init__.py @@ -0,0 +1 @@ +from ._client import AnthropicVertex as AnthropicVertex, AsyncAnthropicVertex as AsyncAnthropicVertex diff --git a/src/anthropic/lib/vertex/_auth.py b/src/anthropic/lib/vertex/_auth.py new file mode 100644 index 00000000..befc3dc6 --- /dev/null +++ b/src/anthropic/lib/vertex/_auth.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .._extras import google_auth + +if TYPE_CHECKING: + from google.auth.credentials import Credentials # type: ignore[import-untyped] + +# pyright: reportMissingTypeStubs=false, reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false +# google libraries don't provide types :/ + +# Note: these functions are blocking as they make HTTP requests, the async +# client runs these functions in a separate thread to ensure they do not +# cause synchronous blocking issues. + + +def load_auth() -> tuple[Credentials, str]: + from google.auth.transport.requests import Request # type: ignore[import-untyped] + + credentials, project_id = google_auth.default() + credentials.refresh(Request()) + + if not project_id: + raise ValueError("Could not resolve project_id") + + if not isinstance(project_id, str): + raise TypeError(f"Expected project_id to be a str but got {type(project_id)}") + + return credentials, project_id + + +def refresh_auth(credentials: Credentials) -> None: + from google.auth.transport.requests import Request # type: ignore[import-untyped] + + credentials.refresh(Request()) diff --git a/src/anthropic/lib/vertex/_client.py b/src/anthropic/lib/vertex/_client.py new file mode 100644 index 00000000..5d7029fe --- /dev/null +++ b/src/anthropic/lib/vertex/_client.py @@ -0,0 +1,288 @@ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Any, Union, Mapping, TypeVar +from typing_extensions import override + +import httpx + +from ... import _exceptions +from ._auth import load_auth, refresh_auth +from ..._types import NOT_GIVEN, NotGiven, Transport, ProxiesTypes, AsyncTransport +from ..._utils import is_dict, asyncify, is_given +from ..._compat import typed_cached_property +from ..._models import FinalRequestOptions +from ..._version import __version__ +from ..._streaming import Stream, AsyncStream +from ..._exceptions import APIStatusError +from ..._base_client import DEFAULT_MAX_RETRIES, BaseClient, SyncAPIClient, AsyncAPIClient +from ...resources.beta.beta import Beta, AsyncBeta + +if TYPE_CHECKING: + from google.auth.credentials import Credentials as GoogleCredentials # type: ignore + + +DEFAULT_VERSION = "vertex-2023-10-16" +DEFAULT_BETA_TYPES = ["private-messages-testing"] + +_HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient]) +_DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]]) + + +class BaseVertexClient(BaseClient[_HttpxClientT, _DefaultStreamT]): + @override + def _build_request( + self, + options: FinalRequestOptions, + ) -> httpx.Request: + if is_dict(options.json_data): + options.json_data.setdefault("anthropic_version", DEFAULT_VERSION) + options.json_data.setdefault("anthropic_beta", DEFAULT_BETA_TYPES) + + if options.url == "/v1/messages" and options.method == "post": + project_id = self.project_id + if project_id is None: + raise RuntimeError( + "No project_id was given and it could not be resolved from credentials. The client should be instantiated with the `project_id` argument or the `ANTHROPIC_VERTEX_PROJECT_ID` environment variable should be set." + ) + + if not is_dict(options.json_data): + raise RuntimeError("Expected json data to be a dictionary for post /v1/messages") + + model = options.json_data.pop("model") + stream = options.json_data.get("stream", False) + specifier = "streamRawPredict" if stream else "rawPredict" + + options.url = ( + f"/projects/{self.project_id}/locations/{self.region}/publishers/anthropic/models/{model}:{specifier}" + ) + + if is_dict(options.json_data): + options.json_data.pop("model", None) + + return super()._build_request(options) + + @typed_cached_property + def region(self) -> str: + raise RuntimeError("region not set") + + @typed_cached_property + def project_id(self) -> str | None: + project_id = os.environ.get("ANTHROPIC_VERTEX_PROJECT_ID") + if project_id: + return project_id + + return None + + @override + def _make_status_error( + self, + err_msg: str, + *, + body: object, + response: httpx.Response, + ) -> APIStatusError: + if response.status_code == 400: + return _exceptions.BadRequestError(err_msg, response=response, body=body) + + if response.status_code == 401: + return _exceptions.AuthenticationError(err_msg, response=response, body=body) + + if response.status_code == 403: + return _exceptions.PermissionDeniedError(err_msg, response=response, body=body) + + if response.status_code == 404: + return _exceptions.NotFoundError(err_msg, response=response, body=body) + + if response.status_code == 409: + return _exceptions.ConflictError(err_msg, response=response, body=body) + + if response.status_code == 422: + return _exceptions.UnprocessableEntityError(err_msg, response=response, body=body) + + if response.status_code == 429: + return _exceptions.RateLimitError(err_msg, response=response, body=body) + + if response.status_code >= 500: + return _exceptions.InternalServerError(err_msg, response=response, body=body) + return APIStatusError(err_msg, response=response, body=body) + + +class AnthropicVertex(BaseVertexClient[httpx.Client, Stream[Any]], SyncAPIClient): + beta: Beta + + def __init__( + self, + *, + region: str | NotGiven = NOT_GIVEN, + project_id: str | NotGiven = NOT_GIVEN, + access_token: str | None = None, + base_url: str | httpx.URL | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + max_retries: int = DEFAULT_MAX_RETRIES, + default_headers: Mapping[str, str] | None = None, + default_query: Mapping[str, object] | None = None, + # Configure a custom httpx client. See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details. + http_client: httpx.Client | None = None, + # See httpx documentation for [custom transports](https://www.python-httpx.org/advanced/#custom-transports) + transport: Transport | None = None, + # See httpx documentation for [proxies](https://www.python-httpx.org/advanced/#http-proxying) + proxies: ProxiesTypes | None = None, + # See httpx documentation for [limits](https://www.python-httpx.org/advanced/#pool-limit-configuration) + connection_pool_limits: httpx.Limits | None = None, + _strict_response_validation: bool = False, + ) -> None: + if not is_given(region): + region = os.environ.get("CLOUD_ML_REGION", NOT_GIVEN) + if not is_given(region): + raise ValueError( + "No region was given. The client should be instantiated with the `region` argument or the `CLOUD_ML_REGION` environment variable should be set." + ) + + if base_url is None: + base_url = os.environ.get("ANTHROPIC_VERTEX_BASE_URL") + if base_url is None: + base_url = f"https://{region}-aiplatform.googleapis.com/v1" + + super().__init__( + version=__version__, + base_url=base_url, + timeout=timeout, + max_retries=max_retries, + custom_headers=default_headers, + custom_query=default_query, + http_client=http_client, + transport=transport, + proxies=proxies, + limits=connection_pool_limits, + _strict_response_validation=_strict_response_validation, + ) + + if is_given(project_id): + self.project_id = project_id + + self.region = region + self.access_token = access_token + self._credentials: GoogleCredentials | None = None + + self.beta = Beta( + # TODO: fix types here + self # type: ignore + ) + + @override + def _prepare_request(self, request: httpx.Request) -> None: + access_token = self._ensure_access_token() + + if request.headers.get("Authorization"): + # already authenticated, nothing for us to do + return + + request.headers["Authorization"] = f"Bearer {access_token}" + + def _ensure_access_token(self) -> str: + if self.access_token is not None: + return self.access_token + + if not self._credentials: + self._credentials, project_id = load_auth() + if not self.project_id: + self.project_id = project_id + else: + refresh_auth(self._credentials) + + if not self._credentials.token: + raise RuntimeError("Could not resolve API token from the environment") + + assert isinstance(self._credentials.token, str) + return self._credentials.token + + +class AsyncAnthropicVertex(BaseVertexClient[httpx.AsyncClient, AsyncStream[Any]], AsyncAPIClient): + beta: AsyncBeta + + def __init__( + self, + *, + region: str | NotGiven = NOT_GIVEN, + project_id: str | NotGiven = NOT_GIVEN, + access_token: str | None = None, + base_url: str | httpx.URL | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + max_retries: int = DEFAULT_MAX_RETRIES, + default_headers: Mapping[str, str] | None = None, + default_query: Mapping[str, object] | None = None, + # Configure a custom httpx client. See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details. + http_client: httpx.AsyncClient | None = None, + # See httpx documentation for [custom transports](https://www.python-httpx.org/advanced/#custom-transports) + transport: AsyncTransport | None = None, + # See httpx documentation for [proxies](https://www.python-httpx.org/advanced/#http-proxying) + proxies: ProxiesTypes | None = None, + # See httpx documentation for [limits](https://www.python-httpx.org/advanced/#pool-limit-configuration) + connection_pool_limits: httpx.Limits | None = None, + _strict_response_validation: bool = False, + ) -> None: + if not is_given(region): + region = os.environ.get("CLOUD_ML_REGION", NOT_GIVEN) + if not is_given(region): + raise ValueError( + "No region was given. The client should be instantiated with the `region` argument or the `CLOUD_ML_REGION` environment variable should be set." + ) + + if base_url is None: + base_url = os.environ.get("ANTHROPIC_VERTEX_BASE_URL") + if base_url is None: + base_url = f"https://{region}-aiplatform.googleapis.com/v1" + + super().__init__( + version=__version__, + base_url=base_url, + timeout=timeout, + max_retries=max_retries, + custom_headers=default_headers, + custom_query=default_query, + http_client=http_client, + transport=transport, + proxies=proxies, + limits=connection_pool_limits, + _strict_response_validation=_strict_response_validation, + ) + + if is_given(project_id): + self.project_id = project_id + + self.region = region + self.access_token = access_token + self._credentials: GoogleCredentials | None = None + + self.beta = AsyncBeta( + # TODO: fix types here + self # type: ignore + ) + + @override + async def _prepare_request(self, request: httpx.Request) -> None: + access_token = await self._ensure_access_token() + + if request.headers.get("Authorization"): + # already authenticated, nothing for us to do + return + + request.headers["Authorization"] = f"Bearer {access_token}" + + async def _ensure_access_token(self) -> str: + if self.access_token is not None: + return self.access_token + + if not self._credentials: + self._credentials, project_id = await asyncify(load_auth)() + if not self.project_id: + self.project_id = project_id + else: + await asyncify(refresh_auth)(self._credentials) + + if not self._credentials.token: + raise RuntimeError("Could not resolve API token from the environment") + + assert isinstance(self._credentials.token, str) + return self._credentials.token