Skip to content

Commit

Permalink
feat(streaming/messages): refactor to event iterator structure
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie committed May 30, 2024
1 parent bb62980 commit 997af69
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 74 deletions.
8 changes: 6 additions & 2 deletions examples/messages_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ async def main() -> None:
],
model="claude-3-opus-20240229",
) as stream:
async for text in stream.text_stream:
print(text, end="", flush=True)
async for event in stream:
if event.type == "text":
print(event.text, end="", flush=True)
elif event.type == "content_block_stop":
print()
print("\ncontent block finished accumulating:", event.content_block)
print()

# you can still get the accumulated final message outside of
Expand Down
32 changes: 0 additions & 32 deletions examples/messages_stream_handler.py

This file was deleted.

6 changes: 6 additions & 0 deletions src/anthropic/lib/streaming/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
from ._types import (
TextEvent as TextEvent,
MessageStopEvent as MessageStopEvent,
MessageStreamEvent as MessageStreamEvent,
ContentBlockStopEvent as ContentBlockStopEvent,
)
from ._messages import (
MessageStream as MessageStream,
MessageStreamT as MessageStreamT,
Expand Down
110 changes: 70 additions & 40 deletions src/anthropic/lib/streaming/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

import httpx

from ...types import Message, ContentBlock, MessageStreamEvent
from ._types import TextEvent, MessageStopEvent, MessageStreamEvent, ContentBlockStopEvent
from ...types import Message, ContentBlock, RawMessageStreamEvent
from ..._utils import consume_sync_iterator, consume_async_iterator
from ..._streaming import Stream, AsyncStream

Expand All @@ -31,7 +32,7 @@ class MessageStream:
def __init__(
self,
*,
cast_to: type[MessageStreamEvent],
cast_to: type[RawMessageStreamEvent],
response: httpx.Response,
client: Anthropic,
) -> None:
Expand All @@ -43,7 +44,7 @@ def __init__(
self.__final_message_snapshot: Message | None = None

self._iterator = self.__stream__()
self._raw_stream: Stream[MessageStreamEvent] = Stream(cast_to=cast_to, response=response, client=client)
self._raw_stream: Stream[RawMessageStreamEvent] = Stream(cast_to=cast_to, response=response, client=client)

def __next__(self) -> MessageStreamEvent:
return self._iterator.__next__()
Expand Down Expand Up @@ -110,7 +111,7 @@ def current_message_snapshot(self) -> Message:
return self.__final_message_snapshot

# event handlers
def on_stream_event(self, event: MessageStreamEvent) -> None:
def on_stream_event(self, event: RawMessageStreamEvent) -> None:
"""Callback that is fired for every Server-Sent-Event"""

def on_message(self, message: Message) -> None:
Expand Down Expand Up @@ -154,9 +155,10 @@ def __stream__(self) -> Iterator[MessageStreamEvent]:
event=sse_event,
current_snapshot=self.__final_message_snapshot,
)
self._emit_sse_event(sse_event)

yield sse_event
events_to_fire = self._emit_sse_event(sse_event)
for event in events_to_fire:
yield event
except (httpx.TimeoutException, asyncio.TimeoutError) as exc:
self.on_timeout()
self.on_exception(exc)
Expand All @@ -172,32 +174,47 @@ def __stream_text__(self) -> Iterator[str]:
if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta":
yield chunk.delta.text

def _emit_sse_event(self, event: MessageStreamEvent) -> None:
def _emit_sse_event(self, event: RawMessageStreamEvent) -> list[MessageStreamEvent]:
self.on_stream_event(event)

events_to_fire: list[MessageStreamEvent] = []

if event.type == "message_start":
# nothing special we want to fire here
pass
events_to_fire.append(event)
elif event.type == "message_delta":
# nothing special we want to fire here
pass
events_to_fire.append(event)
elif event.type == "message_stop":
self.on_message(self.current_message_snapshot)
events_to_fire.append(MessageStopEvent(type="message_stop", message=self.current_message_snapshot))
elif event.type == "content_block_start":
# nothing special we want to fire here
pass
events_to_fire.append(event)
elif event.type == "content_block_delta":
content = self.current_message_snapshot.content[event.index]
if event.delta.type == "text_delta" and content.type == "text":
self.on_text(event.delta.text, content.text)
events_to_fire.append(event)

content_block = self.current_message_snapshot.content[event.index]
if event.delta.type == "text_delta" and content_block.type == "text":
self.on_text(event.delta.text, content_block.text)
events_to_fire.append(
TextEvent(
type="text",
text=event.delta.text,
snapshot=content_block.text,
)
)
elif event.type == "content_block_stop":
content = self.current_message_snapshot.content[event.index]
self.on_content_block(content)
content_block = self.current_message_snapshot.content[event.index]
self.on_content_block(content_block)

events_to_fire.append(
ContentBlockStopEvent(type="content_block_stop", index=event.index, content_block=content_block),
)
else:
# we only want exhaustive checking for linters, not at runtime
if TYPE_CHECKING: # type: ignore[unreachable]
assert_never(event)

return events_to_fire


MessageStreamT = TypeVar("MessageStreamT", bound=MessageStream)

Expand All @@ -213,7 +230,7 @@ class MessageStreamManager(Generic[MessageStreamT]):
"""

def __init__(
self, api_request: Callable[[], Stream[MessageStreamEvent]], event_handler_cls: type[MessageStreamT]
self, api_request: Callable[[], Stream[RawMessageStreamEvent]], event_handler_cls: type[MessageStreamT]
) -> None:
self.__event_handler: MessageStreamT | None = None
self.__event_handler_cls: type[MessageStreamT] = event_handler_cls
Expand Down Expand Up @@ -256,7 +273,7 @@ class AsyncMessageStream:
def __init__(
self,
*,
cast_to: type[MessageStreamEvent],
cast_to: type[RawMessageStreamEvent],
response: httpx.Response,
client: AsyncAnthropic,
) -> None:
Expand All @@ -268,7 +285,7 @@ def __init__(
self.__final_message_snapshot: Message | None = None

self._iterator = self.__stream__()
self._raw_stream: AsyncStream[MessageStreamEvent] = AsyncStream(
self._raw_stream: AsyncStream[RawMessageStreamEvent] = AsyncStream(
cast_to=cast_to, response=response, client=client
)

Expand Down Expand Up @@ -337,7 +354,7 @@ def current_message_snapshot(self) -> Message:
return self.__final_message_snapshot

# event handlers
async def on_stream_event(self, event: MessageStreamEvent) -> None:
async def on_stream_event(self, event: RawMessageStreamEvent) -> None:
"""Callback that is fired for every Server-Sent-Event"""

async def on_message(self, message: Message) -> None:
Expand Down Expand Up @@ -387,9 +404,10 @@ async def __stream__(self) -> AsyncIterator[MessageStreamEvent]:
event=sse_event,
current_snapshot=self.__final_message_snapshot,
)
await self._emit_sse_event(sse_event)

yield sse_event
events_to_fire = await self._emit_sse_event(sse_event)
for event in events_to_fire:
yield event
except (httpx.TimeoutException, asyncio.TimeoutError) as exc:
await self.on_timeout()
await self.on_exception(exc)
Expand All @@ -405,35 +423,47 @@ async def __stream_text__(self) -> AsyncIterator[str]:
if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta":
yield chunk.delta.text

async def _emit_sse_event(self, event: MessageStreamEvent) -> None:
async def _emit_sse_event(self, event: RawMessageStreamEvent) -> list[MessageStreamEvent]:
await self.on_stream_event(event)

events_to_fire: list[MessageStreamEvent] = []

if event.type == "message_start":
# nothing special we want to fire here
pass
events_to_fire.append(event)
elif event.type == "message_delta":
# nothing special we want to fire here
pass
events_to_fire.append(event)
elif event.type == "message_stop":
await self.on_message(self.current_message_snapshot)
events_to_fire.append(MessageStopEvent(type="message_stop", message=self.current_message_snapshot))
elif event.type == "content_block_start":
# nothing special we want to fire here
pass
events_to_fire.append(event)
elif event.type == "content_block_delta":
content = self.current_message_snapshot.content[event.index]
if event.delta.type == "text_delta" and content.type == "text":
await self.on_text(event.delta.text, content.text)
events_to_fire.append(event)

content_block = self.current_message_snapshot.content[event.index]
if event.delta.type == "text_delta" and content_block.type == "text":
await self.on_text(event.delta.text, content_block.text)
events_to_fire.append(
TextEvent(
type="text",
text=event.delta.text,
snapshot=content_block.text,
)
)
elif event.type == "content_block_stop":
content = self.current_message_snapshot.content[event.index]
await self.on_content_block(content)
content_block = self.current_message_snapshot.content[event.index]
await self.on_content_block(content_block)

if content.type == "text":
await self.on_final_text(content.text)
events_to_fire.append(
ContentBlockStopEvent(type="content_block_stop", index=event.index, content_block=content_block),
)
else:
# we only want exhaustive checking for linters, not at runtime
if TYPE_CHECKING: # type: ignore[unreachable]
assert_never(event)

return events_to_fire


AsyncMessageStreamT = TypeVar("AsyncMessageStreamT", bound=AsyncMessageStream)

Expand All @@ -451,7 +481,7 @@ class AsyncMessageStreamManager(Generic[AsyncMessageStreamT]):
"""

def __init__(
self, api_request: Awaitable[AsyncStream[MessageStreamEvent]], event_handler_cls: type[AsyncMessageStreamT]
self, api_request: Awaitable[AsyncStream[RawMessageStreamEvent]], event_handler_cls: type[AsyncMessageStreamT]
) -> None:
self.__event_handler: AsyncMessageStreamT | None = None
self.__event_handler_cls: type[AsyncMessageStreamT] = event_handler_cls
Expand All @@ -478,7 +508,7 @@ async def __aexit__(
await self.__event_handler.close()


def accumulate_event(*, event: MessageStreamEvent, current_snapshot: Message | None) -> Message:
def accumulate_event(*, event: RawMessageStreamEvent, current_snapshot: Message | None) -> Message:
if current_snapshot is None:
if event.type == "message_start":
return event.message
Expand Down
47 changes: 47 additions & 0 deletions src/anthropic/lib/streaming/_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Union
from typing_extensions import Literal

from ...types import (
Message,
ContentBlock,
MessageDeltaEvent as RawMessageDeltaEvent,
MessageStartEvent as RawMessageStartEvent,
RawMessageStopEvent,
ContentBlockDeltaEvent as RawContentBlockDeltaEvent,
ContentBlockStartEvent as RawContentBlockStartEvent,
RawContentBlockStopEvent,
)
from ..._models import BaseModel


class TextEvent(BaseModel):
type: Literal["text"]

text: str
"""The text delta"""

snapshot: str
"""The entire accumulated text"""


class MessageStopEvent(RawMessageStopEvent):
type: Literal["message_stop"]

message: Message


class ContentBlockStopEvent(RawContentBlockStopEvent):
type: Literal["content_block_stop"]

content_block: ContentBlock


MessageStreamEvent = Union[
TextEvent,
RawMessageStartEvent,
RawMessageDeltaEvent,
MessageStopEvent,
RawContentBlockStartEvent,
RawContentBlockDeltaEvent,
ContentBlockStopEvent,
]

0 comments on commit 997af69

Please sign in to comment.