From 98a0a7b9c679539c98d212b12c0a9a950fd6371d Mon Sep 17 00:00:00 2001
From: Robert Craigie
Date: Tue, 13 Aug 2024 15:38:32 -0400
Subject: [PATCH] feat(client): add streaming helpers for prompt caching

---
 src/anthropic/lib/streaming/__init__.py       |   6 +
 .../_prompt_caching_beta_messages.py          | 423 ++++++++++++++++++
 .../streaming/_prompt_caching_beta_types.py   |  32 ++
 .../resources/beta/prompt_caching/messages.py | 119 +++++
 4 files changed, 580 insertions(+)
 create mode 100644 src/anthropic/lib/streaming/_prompt_caching_beta_messages.py
 create mode 100644 src/anthropic/lib/streaming/_prompt_caching_beta_types.py

diff --git a/src/anthropic/lib/streaming/__init__.py b/src/anthropic/lib/streaming/__init__.py
index 0ab41209..fbd25b02 100644
--- a/src/anthropic/lib/streaming/__init__.py
+++ b/src/anthropic/lib/streaming/__init__.py
@@ -11,3 +11,9 @@
     MessageStreamManager as MessageStreamManager,
     AsyncMessageStreamManager as AsyncMessageStreamManager,
 )
+from ._prompt_caching_beta_messages import (
+    PromptCachingBetaMessageStream as PromptCachingBetaMessageStream,
+    AsyncPromptCachingBetaMessageStream as AsyncPromptCachingBetaMessageStream,
+    PromptCachingBetaMessageStreamManager as PromptCachingBetaMessageStreamManager,
+    AsyncPromptCachingBetaMessageStreamManager as AsyncPromptCachingBetaMessageStreamManager,
+)
diff --git a/src/anthropic/lib/streaming/_prompt_caching_beta_messages.py b/src/anthropic/lib/streaming/_prompt_caching_beta_messages.py
new file mode 100644
index 00000000..df727ea8
--- /dev/null
+++ b/src/anthropic/lib/streaming/_prompt_caching_beta_messages.py
@@ -0,0 +1,423 @@
+from __future__ import annotations
+
+from types import TracebackType
+from typing import TYPE_CHECKING, Any, Callable, cast
+from typing_extensions import Self, Iterator, Awaitable, AsyncIterator, assert_never
+
+import httpx
+
+from ...types import ContentBlock
+from ..._utils import consume_sync_iterator, consume_async_iterator
+from ..._models import build, construct_type
+from ..._streaming import Stream, AsyncStream
+from ._prompt_caching_beta_types import (
+    TextEvent,
+    InputJsonEvent,
+    MessageStopEvent,
+    ContentBlockStopEvent,
+    PromptCachingBetaMessageStreamEvent,
+)
+from ...types.beta.prompt_caching import PromptCachingBetaMessage, RawPromptCachingBetaMessageStreamEvent
+
+if TYPE_CHECKING:
+    from ..._client import Anthropic, AsyncAnthropic
+
+
+class PromptCachingBetaMessageStream:
+    text_stream: Iterator[str]
+    """Iterator over just the text deltas in the stream.
+
+    ```py
+    for text in stream.text_stream:
+        print(text, end="", flush=True)
+    print()
+    ```
+    """
+
+    response: httpx.Response
+
+    def __init__(
+        self,
+        *,
+        cast_to: type[RawPromptCachingBetaMessageStreamEvent],
+        response: httpx.Response,
+        client: Anthropic,
+    ) -> None:
+        self.response = response
+        self._cast_to = cast_to
+        self._client = client
+
+        self.text_stream = self.__stream_text__()
+        self.__final_message_snapshot: PromptCachingBetaMessage | None = None
+
+        self._iterator = self.__stream__()
+        self._raw_stream: Stream[RawPromptCachingBetaMessageStreamEvent] = Stream(
+            cast_to=cast_to, response=response, client=client
+        )
+
+    def __next__(self) -> PromptCachingBetaMessageStreamEvent:
+        return self._iterator.__next__()
+
+    def __iter__(self) -> Iterator[PromptCachingBetaMessageStreamEvent]:
+        for item in self._iterator:
+            yield item
+
+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        self.close()
+
+    def close(self) -> None:
+        """
+        Close the response and release the connection.
+
+        Automatically called if the response body is read to completion.
+        """
+        self.response.close()
+
+    def get_final_message(self) -> PromptCachingBetaMessage:
+        """Waits until the stream has been read to completion and returns
+        the accumulated `PromptCachingBetaMessage` object.
+        """
+        self.until_done()
+        assert self.__final_message_snapshot is not None
+        return self.__final_message_snapshot
+
+    def get_final_text(self) -> str:
+        """Returns all `text` content blocks concatenated together.
+
+        > [!NOTE]
+        > Currently the API will only respond with a single content block.
+
+        Will raise an error if no `text` content blocks were returned.
+        """
+        message = self.get_final_message()
+        text_blocks: list[str] = []
+        for block in message.content:
+            if block.type == "text":
+                text_blocks.append(block.text)
+
+        if not text_blocks:
+            raise RuntimeError("Expected to have received at least 1 text block")
+
+        return "".join(text_blocks)
+
+    def until_done(self) -> None:
+        """Blocks until the stream has been consumed"""
+        consume_sync_iterator(self)
+
+    # properties
+    @property
+    def current_message_snapshot(self) -> PromptCachingBetaMessage:
+        assert self.__final_message_snapshot is not None
+        return self.__final_message_snapshot
+
+    def __stream__(self) -> Iterator[PromptCachingBetaMessageStreamEvent]:
+        for sse_event in self._raw_stream:
+            self.__final_message_snapshot = accumulate_event(
+                event=sse_event,
+                current_snapshot=self.__final_message_snapshot,
+            )
+
+            events_to_fire = build_events(event=sse_event, message_snapshot=self.current_message_snapshot)
+            for event in events_to_fire:
+                yield event
+
+    def __stream_text__(self) -> Iterator[str]:
+        for chunk in self:
+            if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta":
+                yield chunk.delta.text
+
+
+class PromptCachingBetaMessageStreamManager:
+    """Wrapper over PromptCachingBetaMessageStream that is returned by `.stream()`.
+
+    ```py
+    with client.beta.prompt_caching.messages.stream(...) as stream:
+        for chunk in stream:
+            ...
+ ``` + """ + + def __init__( + self, + api_request: Callable[[], Stream[RawPromptCachingBetaMessageStreamEvent]], + ) -> None: + self.__stream: PromptCachingBetaMessageStream | None = None + self.__api_request = api_request + + def __enter__(self) -> PromptCachingBetaMessageStream: + raw_stream = self.__api_request() + + self.__stream = PromptCachingBetaMessageStream( + cast_to=raw_stream._cast_to, + response=raw_stream.response, + client=raw_stream._client, + ) + + return self.__stream + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if self.__stream is not None: + self.__stream.close() + + +class AsyncPromptCachingBetaMessageStream: + text_stream: AsyncIterator[str] + """Async iterator over just the text deltas in the stream. + + ```py + async for text in stream.text_stream: + print(text, end="", flush=True) + print() + ``` + """ + + response: httpx.Response + + def __init__( + self, + *, + cast_to: type[RawPromptCachingBetaMessageStreamEvent], + response: httpx.Response, + client: AsyncAnthropic, + ) -> None: + self.response = response + self._cast_to = cast_to + self._client = client + + self.text_stream = self.__stream_text__() + self.__final_message_snapshot: PromptCachingBetaMessage | None = None + + self._iterator = self.__stream__() + self._raw_stream: AsyncStream[RawPromptCachingBetaMessageStreamEvent] = AsyncStream( + cast_to=cast_to, response=response, client=client + ) + + async def __anext__(self) -> PromptCachingBetaMessageStreamEvent: + return await self._iterator.__anext__() + + async def __aiter__(self) -> AsyncIterator[PromptCachingBetaMessageStreamEvent]: + async for item in self._iterator: + yield item + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.close() + + async def close(self) -> None: + """ + Close the response and release the connection. + + Automatically called if the response body is read to completion. + """ + await self.response.aclose() + + async def get_final_message(self) -> PromptCachingBetaMessage: + """Waits until the stream has been read to completion and returns + the accumulated `PromptCachingBetaMessage` object. + """ + await self.until_done() + assert self.__final_message_snapshot is not None + return self.__final_message_snapshot + + async def get_final_text(self) -> str: + """Returns all `text` content blocks concatenated together. + + > [!NOTE] + > Currently the API will only respond with a single content block. + + Will raise an error if no `text` content blocks were returned. 
+ """ + message = await self.get_final_message() + text_blocks: list[str] = [] + for block in message.content: + if block.type == "text": + text_blocks.append(block.text) + + if not text_blocks: + raise RuntimeError("Expected to have received at least 1 text block") + + return "".join(text_blocks) + + async def until_done(self) -> None: + """Waits until the stream has been consumed""" + await consume_async_iterator(self) + + # properties + @property + def current_message_snapshot(self) -> PromptCachingBetaMessage: + assert self.__final_message_snapshot is not None + return self.__final_message_snapshot + + async def __stream__(self) -> AsyncIterator[PromptCachingBetaMessageStreamEvent]: + async for sse_event in self._raw_stream: + self.__final_message_snapshot = accumulate_event( + event=sse_event, + current_snapshot=self.__final_message_snapshot, + ) + + events_to_fire = build_events(event=sse_event, message_snapshot=self.current_message_snapshot) + for event in events_to_fire: + yield event + + async def __stream_text__(self) -> AsyncIterator[str]: + async for chunk in self: + if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": + yield chunk.delta.text + + +class AsyncPromptCachingBetaMessageStreamManager: + """Wrapper over AsyncMessageStream that is returned by `.stream()` + so that an async context manager can be used without `await`ing the + original client call. + + ```py + async with client.messages.stream(...) as stream: + async for chunk in stream: + ... + ``` + """ + + def __init__( + self, + api_request: Awaitable[AsyncStream[RawPromptCachingBetaMessageStreamEvent]], + ) -> None: + self.__stream: AsyncPromptCachingBetaMessageStream | None = None + self.__api_request = api_request + + async def __aenter__(self) -> AsyncPromptCachingBetaMessageStream: + raw_stream = await self.__api_request + + self.__stream = AsyncPromptCachingBetaMessageStream( + cast_to=raw_stream._cast_to, + response=raw_stream.response, + client=raw_stream._client, + ) + + return self.__stream + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if self.__stream is not None: + await self.__stream.close() + + +def build_events( + *, + event: RawPromptCachingBetaMessageStreamEvent, + message_snapshot: PromptCachingBetaMessage, +) -> list[PromptCachingBetaMessageStreamEvent]: + events_to_fire: list[PromptCachingBetaMessageStreamEvent] = [] + + if event.type == "message_start": + events_to_fire.append(event) + elif event.type == "message_delta": + events_to_fire.append(event) + elif event.type == "message_stop": + events_to_fire.append(build(MessageStopEvent, type="message_stop", message=message_snapshot)) + elif event.type == "content_block_start": + events_to_fire.append(event) + elif event.type == "content_block_delta": + events_to_fire.append(event) + + content_block = message_snapshot.content[event.index] + if event.delta.type == "text_delta" and content_block.type == "text": + events_to_fire.append( + build( + TextEvent, + type="text", + text=event.delta.text, + snapshot=content_block.text, + ) + ) + elif event.delta.type == "input_json_delta" and content_block.type == "tool_use": + events_to_fire.append( + build( + InputJsonEvent, + type="input_json", + partial_json=event.delta.partial_json, + snapshot=content_block.input, + ) + ) + elif event.type == "content_block_stop": + content_block = message_snapshot.content[event.index] + + events_to_fire.append( + build(ContentBlockStopEvent, 
type="content_block_stop", index=event.index, content_block=content_block), + ) + else: + # we only want exhaustive checking for linters, not at runtime + if TYPE_CHECKING: # type: ignore[unreachable] + assert_never(event) + + return events_to_fire + + +JSON_BUF_PROPERTY = "__json_buf" + + +def accumulate_event( + *, + event: RawPromptCachingBetaMessageStreamEvent, + current_snapshot: PromptCachingBetaMessage | None, +) -> PromptCachingBetaMessage: + if current_snapshot is None: + if event.type == "message_start": + return PromptCachingBetaMessage.construct(**cast(Any, event.message.to_dict())) + + raise RuntimeError(f'Unexpected event order, got {event.type} before "message_start"') + + if event.type == "content_block_start": + # TODO: check index + current_snapshot.content.append( + cast( + ContentBlock, + construct_type(type_=ContentBlock, value=event.content_block.model_dump()), + ), + ) + elif event.type == "content_block_delta": + content = current_snapshot.content[event.index] + if content.type == "text" and event.delta.type == "text_delta": + content.text += event.delta.text + elif content.type == "tool_use" and event.delta.type == "input_json_delta": + from jiter import from_json + + # we need to keep track of the raw JSON string as well so that we can + # re-parse it for each delta, for now we just store it as an untyped + # property on the snapshot + json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b"")) + json_buf += bytes(event.delta.partial_json, "utf-8") + + if json_buf: + content.input = from_json(json_buf, partial_mode=True) + + setattr(content, JSON_BUF_PROPERTY, json_buf) + elif event.type == "message_delta": + current_snapshot.stop_reason = event.delta.stop_reason + current_snapshot.stop_sequence = event.delta.stop_sequence + current_snapshot.usage.output_tokens = event.usage.output_tokens + + return current_snapshot diff --git a/src/anthropic/lib/streaming/_prompt_caching_beta_types.py b/src/anthropic/lib/streaming/_prompt_caching_beta_types.py new file mode 100644 index 00000000..d8fdce52 --- /dev/null +++ b/src/anthropic/lib/streaming/_prompt_caching_beta_types.py @@ -0,0 +1,32 @@ +from typing import Union +from typing_extensions import Literal + +from ._types import ( + TextEvent, + InputJsonEvent, + RawMessageDeltaEvent, + ContentBlockStopEvent, + RawContentBlockDeltaEvent, + RawContentBlockStartEvent, +) +from ...types import RawMessageStopEvent +from ...types.beta.prompt_caching import PromptCachingBetaMessage, RawPromptCachingBetaMessageStartEvent + + +class MessageStopEvent(RawMessageStopEvent): + type: Literal["message_stop"] + + message: PromptCachingBetaMessage + + +PromptCachingBetaMessageStreamEvent = Union[ + RawPromptCachingBetaMessageStartEvent, + MessageStopEvent, + # same as non-beta + TextEvent, + InputJsonEvent, + RawMessageDeltaEvent, + RawContentBlockStartEvent, + RawContentBlockDeltaEvent, + ContentBlockStopEvent, +] diff --git a/src/anthropic/resources/beta/prompt_caching/messages.py b/src/anthropic/resources/beta/prompt_caching/messages.py index c2023c18..53e94ecd 100644 --- a/src/anthropic/resources/beta/prompt_caching/messages.py +++ b/src/anthropic/resources/beta/prompt_caching/messages.py @@ -3,6 +3,7 @@ from __future__ import annotations from typing import List, Union, Iterable, overload +from functools import partial from typing_extensions import Literal import httpx @@ -21,6 +22,7 @@ from ...._constants import DEFAULT_TIMEOUT from ...._streaming import Stream, AsyncStream from ...._base_client import make_request_options +from 
....lib.streaming import PromptCachingBetaMessageStreamManager, AsyncPromptCachingBetaMessageStreamManager from ....types.model_param import ModelParam from ....types.beta.prompt_caching import message_create_params from ....types.beta.prompt_caching.prompt_caching_beta_message import PromptCachingBetaMessage @@ -885,6 +887,65 @@ def create( stream_cls=Stream[RawPromptCachingBetaMessageStreamEvent], ) + def stream( + self, + *, + max_tokens: int, + messages: Iterable[PromptCachingBetaMessageParam], + model: ModelParam, + metadata: message_create_params.Metadata | NotGiven = NOT_GIVEN, + stop_sequences: List[str] | NotGiven = NOT_GIVEN, + system: Union[str, Iterable[PromptCachingBetaTextBlockParam]] | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + tool_choice: message_create_params.ToolChoice | NotGiven = NOT_GIVEN, + tools: Iterable[PromptCachingBetaToolParam] | NotGiven = NOT_GIVEN, + top_k: int | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> PromptCachingBetaMessageStreamManager: + """Create a Message stream""" + if not is_given(timeout) and self._client.timeout == DEFAULT_TIMEOUT: + timeout = 600 + + extra_headers = { + "anthropic-beta": "prompt-caching-2024-07-31", + "X-Stainless-Stream-Helper": "beta.prompt_caching.messages", + **(extra_headers or {}), + } + request = partial( + self._post, + "/v1/messages?beta=prompt_caching", + body=maybe_transform( + { + "max_tokens": max_tokens, + "messages": messages, + "model": model, + "metadata": metadata, + "stop_sequences": stop_sequences, + "stream": True, + "system": system, + "temperature": temperature, + "tool_choice": tool_choice, + "tools": tools, + "top_k": top_k, + "top_p": top_p, + }, + message_create_params.MessageCreateParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=PromptCachingBetaMessage, + stream=True, + stream_cls=Stream[RawPromptCachingBetaMessageStreamEvent], + ) + return PromptCachingBetaMessageStreamManager(request) + class AsyncMessages(AsyncAPIResource): @cached_property @@ -1737,6 +1798,64 @@ async def create( stream_cls=AsyncStream[RawPromptCachingBetaMessageStreamEvent], ) + def stream( + self, + *, + max_tokens: int, + messages: Iterable[PromptCachingBetaMessageParam], + model: ModelParam, + metadata: message_create_params.Metadata | NotGiven = NOT_GIVEN, + stop_sequences: List[str] | NotGiven = NOT_GIVEN, + system: Union[str, Iterable[PromptCachingBetaTextBlockParam]] | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + tool_choice: message_create_params.ToolChoice | NotGiven = NOT_GIVEN, + tools: Iterable[PromptCachingBetaToolParam] | NotGiven = NOT_GIVEN, + top_k: int | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. 
+        extra_headers: Headers | None = None,
+        extra_query: Query | None = None,
+        extra_body: Body | None = None,
+        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> AsyncPromptCachingBetaMessageStreamManager:
+        """Create a Message stream"""
+        if not is_given(timeout) and self._client.timeout == DEFAULT_TIMEOUT:
+            timeout = 600
+
+        extra_headers = {
+            "anthropic-beta": "prompt-caching-2024-07-31",
+            "X-Stainless-Stream-Helper": "beta.prompt_caching.messages",
+            **(extra_headers or {}),
+        }
+        request = self._post(
+            "/v1/messages?beta=prompt_caching",
+            body=maybe_transform(
+                {
+                    "max_tokens": max_tokens,
+                    "messages": messages,
+                    "model": model,
+                    "metadata": metadata,
+                    "stop_sequences": stop_sequences,
+                    "stream": True,
+                    "system": system,
+                    "temperature": temperature,
+                    "tool_choice": tool_choice,
+                    "tools": tools,
+                    "top_k": top_k,
+                    "top_p": top_p,
+                },
+                message_create_params.MessageCreateParams,
+            ),
+            options=make_request_options(
+                extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+            ),
+            cast_to=PromptCachingBetaMessage,
+            stream=True,
+            stream_cls=AsyncStream[RawPromptCachingBetaMessageStreamEvent],
+        )
+        return AsyncPromptCachingBetaMessageStreamManager(request)
+
 
 class MessagesWithRawResponse:
     def __init__(self, messages: Messages) -> None:
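
For reference, a minimal usage sketch of the helpers this patch adds. The model ID and the prompt below are illustrative placeholders, not part of the patch; the call paths and helper attributes (`client.beta.prompt_caching.messages.stream(...)`, `text_stream`, `get_final_message()`) are the ones introduced above.

```py
import asyncio

from anthropic import Anthropic, AsyncAnthropic


def sync_example() -> None:
    client = Anthropic()

    # The manager issues the request on __enter__ and closes the stream on __exit__.
    with client.beta.prompt_caching.messages.stream(
        max_tokens=1024,
        messages=[{"role": "user", "content": "Hello"}],
        model="claude-3-5-sonnet-20240620",  # placeholder model ID
    ) as stream:
        for text in stream.text_stream:
            print(text, end="", flush=True)

        # Accumulated PromptCachingBetaMessage built from the streamed events.
        message = stream.get_final_message()
        print(message.usage)


async def async_example() -> None:
    client = AsyncAnthropic()

    # Note that `.stream(...)` itself is not awaited: the async manager holds the
    # coroutine and awaits the underlying request inside __aenter__.
    async with client.beta.prompt_caching.messages.stream(
        max_tokens=1024,
        messages=[{"role": "user", "content": "Hello"}],
        model="claude-3-5-sonnet-20240620",  # placeholder model ID
    ) as stream:
        async for text in stream.text_stream:
            print(text, end="", flush=True)


if __name__ == "__main__":
    sync_example()
    asyncio.run(async_example())
```

This is also why `AsyncPromptCachingBetaMessageStreamManager` accepts an `Awaitable` rather than a callable: it lets `async with client.beta.prompt_caching.messages.stream(...)` read naturally, with the request awaited lazily inside `__aenter__`.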