diff --git a/examples/messages_stream.py b/examples/messages_stream.py index 523c485e..be69a2c1 100644 --- a/examples/messages_stream.py +++ b/examples/messages_stream.py @@ -16,8 +16,12 @@ async def main() -> None: ], model="claude-3-opus-20240229", ) as stream: - async for text in stream.text_stream: - print(text, end="", flush=True) + async for event in stream: + if event.type == "text": + print(event.text, end="", flush=True) + elif event.type == "content_block_stop": + print() + print("\ncontent block finished accumulating:", event.content_block) print() # you can still get the accumulated final message outside of diff --git a/examples/messages_stream_handler.py b/examples/messages_stream_handler.py deleted file mode 100644 index 6bf98dab..00000000 --- a/examples/messages_stream_handler.py +++ /dev/null @@ -1,32 +0,0 @@ -import asyncio -from typing_extensions import override - -from anthropic import AsyncAnthropic, AsyncMessageStream -from anthropic.types import MessageStreamEvent - -client = AsyncAnthropic() - - -class MyStream(AsyncMessageStream): - @override - async def on_stream_event(self, event: MessageStreamEvent) -> None: - print("on_event fired with:", event) - - -async def main() -> None: - async with client.messages.stream( - max_tokens=1024, - messages=[ - { - "role": "user", - "content": "Say hello there!", - } - ], - model="claude-3-opus-20240229", - event_handler=MyStream, - ) as stream: - accumulated = await stream.get_final_message() - print("accumulated message: ", accumulated.to_json()) - - -asyncio.run(main()) diff --git a/src/anthropic/lib/streaming/__init__.py b/src/anthropic/lib/streaming/__init__.py index 71c5efd6..a9329d5d 100644 --- a/src/anthropic/lib/streaming/__init__.py +++ b/src/anthropic/lib/streaming/__init__.py @@ -1,3 +1,9 @@ +from ._types import ( + TextEvent as TextEvent, + MessageStopEvent as MessageStopEvent, + MessageStreamEvent as MessageStreamEvent, + ContentBlockStopEvent as ContentBlockStopEvent, +) from ._messages import ( MessageStream as MessageStream, MessageStreamT as MessageStreamT, diff --git a/src/anthropic/lib/streaming/_messages.py b/src/anthropic/lib/streaming/_messages.py index b41e6df7..0a49c819 100644 --- a/src/anthropic/lib/streaming/_messages.py +++ b/src/anthropic/lib/streaming/_messages.py @@ -7,7 +7,8 @@ import httpx -from ...types import Message, ContentBlock, MessageStreamEvent +from ._types import TextEvent, MessageStopEvent, MessageStreamEvent, ContentBlockStopEvent +from ...types import Message, ContentBlock, RawMessageStreamEvent from ..._utils import consume_sync_iterator, consume_async_iterator from ..._streaming import Stream, AsyncStream @@ -31,7 +32,7 @@ class MessageStream: def __init__( self, *, - cast_to: type[MessageStreamEvent], + cast_to: type[RawMessageStreamEvent], response: httpx.Response, client: Anthropic, ) -> None: @@ -43,7 +44,7 @@ def __init__( self.__final_message_snapshot: Message | None = None self._iterator = self.__stream__() - self._raw_stream: Stream[MessageStreamEvent] = Stream(cast_to=cast_to, response=response, client=client) + self._raw_stream: Stream[RawMessageStreamEvent] = Stream(cast_to=cast_to, response=response, client=client) def __next__(self) -> MessageStreamEvent: return self._iterator.__next__() @@ -110,7 +111,7 @@ def current_message_snapshot(self) -> Message: return self.__final_message_snapshot # event handlers - def on_stream_event(self, event: MessageStreamEvent) -> None: + def on_stream_event(self, event: RawMessageStreamEvent) -> None: """Callback that is fired for every Server-Sent-Event""" def on_message(self, message: Message) -> None: @@ -154,9 +155,10 @@ def __stream__(self) -> Iterator[MessageStreamEvent]: event=sse_event, current_snapshot=self.__final_message_snapshot, ) - self._emit_sse_event(sse_event) - yield sse_event + events_to_fire = self._emit_sse_event(sse_event) + for event in events_to_fire: + yield event except (httpx.TimeoutException, asyncio.TimeoutError) as exc: self.on_timeout() self.on_exception(exc) @@ -172,32 +174,47 @@ def __stream_text__(self) -> Iterator[str]: if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": yield chunk.delta.text - def _emit_sse_event(self, event: MessageStreamEvent) -> None: + def _emit_sse_event(self, event: RawMessageStreamEvent) -> list[MessageStreamEvent]: self.on_stream_event(event) + events_to_fire: list[MessageStreamEvent] = [] + if event.type == "message_start": - # nothing special we want to fire here - pass + events_to_fire.append(event) elif event.type == "message_delta": - # nothing special we want to fire here - pass + events_to_fire.append(event) elif event.type == "message_stop": self.on_message(self.current_message_snapshot) + events_to_fire.append(MessageStopEvent(type="message_stop", message=self.current_message_snapshot)) elif event.type == "content_block_start": - # nothing special we want to fire here - pass + events_to_fire.append(event) elif event.type == "content_block_delta": - content = self.current_message_snapshot.content[event.index] - if event.delta.type == "text_delta" and content.type == "text": - self.on_text(event.delta.text, content.text) + events_to_fire.append(event) + + content_block = self.current_message_snapshot.content[event.index] + if event.delta.type == "text_delta" and content_block.type == "text": + self.on_text(event.delta.text, content_block.text) + events_to_fire.append( + TextEvent( + type="text", + text=event.delta.text, + snapshot=content_block.text, + ) + ) elif event.type == "content_block_stop": - content = self.current_message_snapshot.content[event.index] - self.on_content_block(content) + content_block = self.current_message_snapshot.content[event.index] + self.on_content_block(content_block) + + events_to_fire.append( + ContentBlockStopEvent(type="content_block_stop", index=event.index, content_block=content_block), + ) else: # we only want exhaustive checking for linters, not at runtime if TYPE_CHECKING: # type: ignore[unreachable] assert_never(event) + return events_to_fire + MessageStreamT = TypeVar("MessageStreamT", bound=MessageStream) @@ -213,7 +230,7 @@ class MessageStreamManager(Generic[MessageStreamT]): """ def __init__( - self, api_request: Callable[[], Stream[MessageStreamEvent]], event_handler_cls: type[MessageStreamT] + self, api_request: Callable[[], Stream[RawMessageStreamEvent]], event_handler_cls: type[MessageStreamT] ) -> None: self.__event_handler: MessageStreamT | None = None self.__event_handler_cls: type[MessageStreamT] = event_handler_cls @@ -256,7 +273,7 @@ class AsyncMessageStream: def __init__( self, *, - cast_to: type[MessageStreamEvent], + cast_to: type[RawMessageStreamEvent], response: httpx.Response, client: AsyncAnthropic, ) -> None: @@ -268,7 +285,7 @@ def __init__( self.__final_message_snapshot: Message | None = None self._iterator = self.__stream__() - self._raw_stream: AsyncStream[MessageStreamEvent] = AsyncStream( + self._raw_stream: AsyncStream[RawMessageStreamEvent] = AsyncStream( cast_to=cast_to, response=response, client=client ) @@ -337,7 +354,7 @@ def current_message_snapshot(self) -> Message: return self.__final_message_snapshot # event handlers - async def on_stream_event(self, event: MessageStreamEvent) -> None: + async def on_stream_event(self, event: RawMessageStreamEvent) -> None: """Callback that is fired for every Server-Sent-Event""" async def on_message(self, message: Message) -> None: @@ -387,9 +404,10 @@ async def __stream__(self) -> AsyncIterator[MessageStreamEvent]: event=sse_event, current_snapshot=self.__final_message_snapshot, ) - await self._emit_sse_event(sse_event) - yield sse_event + events_to_fire = await self._emit_sse_event(sse_event) + for event in events_to_fire: + yield event except (httpx.TimeoutException, asyncio.TimeoutError) as exc: await self.on_timeout() await self.on_exception(exc) @@ -405,35 +423,47 @@ async def __stream_text__(self) -> AsyncIterator[str]: if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": yield chunk.delta.text - async def _emit_sse_event(self, event: MessageStreamEvent) -> None: + async def _emit_sse_event(self, event: RawMessageStreamEvent) -> list[MessageStreamEvent]: await self.on_stream_event(event) + events_to_fire: list[MessageStreamEvent] = [] + if event.type == "message_start": - # nothing special we want to fire here - pass + events_to_fire.append(event) elif event.type == "message_delta": - # nothing special we want to fire here - pass + events_to_fire.append(event) elif event.type == "message_stop": await self.on_message(self.current_message_snapshot) + events_to_fire.append(MessageStopEvent(type="message_stop", message=self.current_message_snapshot)) elif event.type == "content_block_start": - # nothing special we want to fire here - pass + events_to_fire.append(event) elif event.type == "content_block_delta": - content = self.current_message_snapshot.content[event.index] - if event.delta.type == "text_delta" and content.type == "text": - await self.on_text(event.delta.text, content.text) + events_to_fire.append(event) + + content_block = self.current_message_snapshot.content[event.index] + if event.delta.type == "text_delta" and content_block.type == "text": + await self.on_text(event.delta.text, content_block.text) + events_to_fire.append( + TextEvent( + type="text", + text=event.delta.text, + snapshot=content_block.text, + ) + ) elif event.type == "content_block_stop": - content = self.current_message_snapshot.content[event.index] - await self.on_content_block(content) + content_block = self.current_message_snapshot.content[event.index] + await self.on_content_block(content_block) - if content.type == "text": - await self.on_final_text(content.text) + events_to_fire.append( + ContentBlockStopEvent(type="content_block_stop", index=event.index, content_block=content_block), + ) else: # we only want exhaustive checking for linters, not at runtime if TYPE_CHECKING: # type: ignore[unreachable] assert_never(event) + return events_to_fire + AsyncMessageStreamT = TypeVar("AsyncMessageStreamT", bound=AsyncMessageStream) @@ -451,7 +481,7 @@ class AsyncMessageStreamManager(Generic[AsyncMessageStreamT]): """ def __init__( - self, api_request: Awaitable[AsyncStream[MessageStreamEvent]], event_handler_cls: type[AsyncMessageStreamT] + self, api_request: Awaitable[AsyncStream[RawMessageStreamEvent]], event_handler_cls: type[AsyncMessageStreamT] ) -> None: self.__event_handler: AsyncMessageStreamT | None = None self.__event_handler_cls: type[AsyncMessageStreamT] = event_handler_cls @@ -478,7 +508,7 @@ async def __aexit__( await self.__event_handler.close() -def accumulate_event(*, event: MessageStreamEvent, current_snapshot: Message | None) -> Message: +def accumulate_event(*, event: RawMessageStreamEvent, current_snapshot: Message | None) -> Message: if current_snapshot is None: if event.type == "message_start": return event.message diff --git a/src/anthropic/lib/streaming/_types.py b/src/anthropic/lib/streaming/_types.py new file mode 100644 index 00000000..ad19ae93 --- /dev/null +++ b/src/anthropic/lib/streaming/_types.py @@ -0,0 +1,47 @@ +from typing import Union +from typing_extensions import Literal + +from ...types import ( + Message, + ContentBlock, + MessageDeltaEvent as RawMessageDeltaEvent, + MessageStartEvent as RawMessageStartEvent, + RawMessageStopEvent, + ContentBlockDeltaEvent as RawContentBlockDeltaEvent, + ContentBlockStartEvent as RawContentBlockStartEvent, + RawContentBlockStopEvent, +) +from ..._models import BaseModel + + +class TextEvent(BaseModel): + type: Literal["text"] + + text: str + """The text delta""" + + snapshot: str + """The entire accumulated text""" + + +class MessageStopEvent(RawMessageStopEvent): + type: Literal["message_stop"] + + message: Message + + +class ContentBlockStopEvent(RawContentBlockStopEvent): + type: Literal["content_block_stop"] + + content_block: ContentBlock + + +MessageStreamEvent = Union[ + TextEvent, + RawMessageStartEvent, + RawMessageDeltaEvent, + MessageStopEvent, + RawContentBlockStartEvent, + RawContentBlockDeltaEvent, + ContentBlockStopEvent, +]