Skip to content

Commit

Permalink
feat(bedrock): add messages API (#362)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot authored Mar 4, 2024
1 parent 4895381 commit 5409be9
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 24 deletions.
53 changes: 30 additions & 23 deletions examples/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,44 @@
# Note: you must have installed `anthropic` with the `bedrock` extra
# e.g. `pip install -U anthropic[bedrock]`

from anthropic import AI_PROMPT, HUMAN_PROMPT, AnthropicBedrock
from anthropic import AnthropicBedrock

# Note: this assumes you have AWS credentials configured.
#
# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
client = AnthropicBedrock()

print("------ standard response ------")
completion = client.completions.create(
model="anthropic.claude-instant-v1",
prompt=f"{HUMAN_PROMPT} hey!{AI_PROMPT}",
stop_sequences=[HUMAN_PROMPT],
max_tokens_to_sample=500,
temperature=0.5,
top_k=250,
top_p=0.5,
message = client.messages.create(
max_tokens=1024,
messages=[
{
"role": "user",
"content": "Hello!",
}
],
model="anthropic.claude-3-sonnet-20240229-v1:0",
)
print(completion.completion)
print(message.model_dump_json(indent=2))

print("------ streamed response ------")

question = """
Hey Claude! How can I recursively list all files in a directory in Python?
"""
with client.messages.stream(
max_tokens=1024,
messages=[
{
"role": "user",
"content": "Say hello there!",
}
],
model="anthropic.claude-3-sonnet-20240229-v1:0",
) as stream:
for text in stream.text_stream:
print(text, end="", flush=True)
print()

print("------ streamed response ------")
stream = client.completions.create(
model="anthropic.claude-instant-v1",
prompt=f"{HUMAN_PROMPT} {question}{AI_PROMPT}",
max_tokens_to_sample=500,
stream=True,
)
for item in stream:
print(item.completion, end="")
print()
# you can still get the accumulated final message outside of
# the context manager, as long as the entire stream was consumed
# inside of the context manager
accumulated = stream.get_final_message()
print("accumulated message: ", accumulated.model_dump_json(indent=2))
7 changes: 6 additions & 1 deletion src/anthropic/lib/bedrock/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ..._exceptions import APIStatusError
from ..._base_client import DEFAULT_MAX_RETRIES, BaseClient, SyncAPIClient, AsyncAPIClient, FinalRequestOptions
from ._stream_decoder import AWSEventStreamDecoder
from ...resources.messages import Messages, AsyncMessages
from ...resources.completions import Completions, AsyncCompletions

DEFAULT_VERSION = "bedrock-2023-05-31"
Expand All @@ -31,7 +32,7 @@ def _build_request(
if is_dict(options.json_data):
options.json_data.setdefault("anthropic_version", DEFAULT_VERSION)

if options.url == "/v1/complete" and options.method == "post":
if options.url in {"/v1/complete", "/v1/messages"} and options.method == "post":
if not is_dict(options.json_data):
raise RuntimeError("Expected dictionary json_data for post /completions endpoint")

Expand Down Expand Up @@ -79,6 +80,7 @@ def _make_status_error(


class AnthropicBedrock(BaseBedrockClient[httpx.Client, Stream[Any]], SyncAPIClient):
messages: Messages
completions: Completions

def __init__(
Expand Down Expand Up @@ -130,6 +132,7 @@ def __init__(
_strict_response_validation=_strict_response_validation,
)

self.messages = Messages(self)
self.completions = Completions(self)

@override
Expand All @@ -156,6 +159,7 @@ def _prepare_request(self, request: httpx.Request) -> None:


class AsyncAnthropicBedrock(BaseBedrockClient[httpx.AsyncClient, AsyncStream[Any]], AsyncAPIClient):
messages: AsyncMessages
completions: AsyncCompletions

def __init__(
Expand Down Expand Up @@ -207,6 +211,7 @@ def __init__(
_strict_response_validation=_strict_response_validation,
)

self.messages = AsyncMessages(self)
self.completions = AsyncCompletions(self)

@override
Expand Down

0 comments on commit 5409be9

Please sign in to comment.