fix: refactor send_message API (#1633)
cpacker authored and sarahwooders committed Aug 17, 2024
1 parent 191e86a commit 41358a5
Showing 10 changed files with 383 additions and 172 deletions.
1 change: 1 addition & 0 deletions memgpt/llm_api/openai.py
@@ -89,6 +89,7 @@ def openai_chat_completions_process_stream(
on the chunks received from the OpenAI-compatible server POST SSE response.
"""
assert chat_completion_request.stream == True
assert stream_inferface is not None, "Required"

# Count the prompt tokens
# TODO move to post-request?
6 changes: 6 additions & 0 deletions memgpt/schemas/enums.py
@@ -22,3 +22,9 @@ class JobStatus(str, Enum):
completed = "completed"
failed = "failed"
pending = "pending"


class MessageStreamStatus(str, Enum):
done_generation = "[DONE_GEN]"
done_step = "[DONE_STEP]"
done = "[DONE]"
78 changes: 78 additions & 0 deletions memgpt/schemas/memgpt_message.py
@@ -0,0 +1,78 @@
from datetime import datetime, timezone
from typing import Literal, Union

from pydantic import BaseModel, field_serializer

# MemGPT API style responses (intended to be easier to use vs getting true Message types)


class BaseMemGPTMessage(BaseModel):
id: str
date: datetime

@field_serializer("date")
def serialize_datetime(self, dt: datetime, _info):
return dt.now(timezone.utc).isoformat()


class InternalMonologue(BaseMemGPTMessage):
"""
{
"internal_monologue": msg,
"date": msg_obj.created_at.isoformat() if msg_obj is not None else get_utc_time().isoformat(),
"id": str(msg_obj.id) if msg_obj is not None else None,
}
"""

internal_monologue: str


class FunctionCall(BaseModel):
name: str
arguments: str


class FunctionCallMessage(BaseMemGPTMessage):
"""
{
"function_call": {
"name": function_call.function.name,
"arguments": function_call.function.arguments,
},
"id": str(msg_obj.id),
"date": msg_obj.created_at.isoformat(),
}
"""

function_call: FunctionCall


class FunctionReturn(BaseMemGPTMessage):
"""
{
"function_return": msg,
"status": "success" or "error",
"id": str(msg_obj.id),
"date": msg_obj.created_at.isoformat(),
}
"""

function_return: str
status: Literal["success", "error"]


MemGPTMessage = Union[InternalMonologue, FunctionCallMessage, FunctionReturn]


# Legacy MemGPT API had an additional type "assistant_message" and the "function_call" was a formatted string


class AssistantMessage(BaseMemGPTMessage):
assistant_message: str


class LegacyFunctionCallMessage(BaseMemGPTMessage):
function_call: str


LegacyMemGPTMessage = Union[InternalMonologue, AssistantMessage, LegacyFunctionCallMessage, FunctionReturn]
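
A brief usage sketch (not part of the commit) showing how the new API-style message types can be constructed and serialized with Pydantic v2; the IDs and message contents are made up. Note that, as written, the date field_serializer returns the current UTC time rather than the stored date value.

from datetime import datetime, timezone

from memgpt.schemas.memgpt_message import FunctionCall, FunctionCallMessage, InternalMonologue

monologue = InternalMonologue(
    id="message-1234",  # hypothetical ID
    date=datetime.now(timezone.utc),
    internal_monologue="The user greeted me; I should respond warmly.",
)
function_call = FunctionCallMessage(
    id="message-1234",
    date=datetime.now(timezone.utc),
    function_call=FunctionCall(name="send_message", arguments='{"message": "Hello!"}'),
)

print(monologue.model_dump_json())       # {"id": "message-1234", "date": "...", "internal_monologue": "..."}
print(function_call.model_dump_json())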
11 changes: 8 additions & 3 deletions memgpt/schemas/memgpt_response.py
@@ -1,12 +1,17 @@
from typing import List
from typing import List, Union

from pydantic import BaseModel, Field

from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.message import Message
from memgpt.schemas.usage import MemGPTUsageStatistics


# TODO: consider moving into own file


class MemGPTResponse(BaseModel):
messages: List[Message] = Field(..., description="The messages returned by the agent.")
# messages: List[Message] = Field(..., description="The messages returned by the agent.")
messages: Union[List[Message], List[MemGPTMessage], List[LegacyMemGPTMessage]] = Field(
..., description="The messages returned by the agent."
)
usage: MemGPTUsageStatistics = Field(..., description="The usage statistics of the agent.")
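
With the widened union, MemGPTResponse can now carry either DB-style Message objects or the new API-style message types. A small illustrative helper (the function and its name are hypothetical, not from the commit):

from typing import List

from memgpt.schemas.memgpt_message import MemGPTMessage
from memgpt.schemas.memgpt_response import MemGPTResponse
from memgpt.schemas.usage import MemGPTUsageStatistics

def build_response(messages: List[MemGPTMessage], usage: MemGPTUsageStatistics) -> MemGPTResponse:
    # `messages` may be a homogeneous list of Message, MemGPTMessage, or LegacyMemGPTMessage
    return MemGPTResponse(messages=messages, usage=usage)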
10 changes: 9 additions & 1 deletion memgpt/schemas/message.py
@@ -2,14 +2,15 @@
import json
import warnings
from datetime import datetime, timezone
from typing import List, Optional
from typing import List, Optional, Union

from pydantic import Field, field_validator

from memgpt.constants import JSON_ENSURE_ASCII, TOOL_CALL_ID_MAX_LEN
from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG
from memgpt.schemas.enums import MessageRole
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.openai.chat_completions import ToolCall
from memgpt.utils import get_utc_time, is_utc_datetime

@@ -87,6 +88,13 @@ def to_json(self):
json_message["created_at"] = self.created_at.isoformat()
return json_message

def to_memgpt_message(self) -> Union[List[MemGPTMessage], List[LegacyMemGPTMessage]]:
"""Convert message object (in DB format) to the style used by the original MemGPT API
NOTE: this may split the message into two pieces (e.g. if the assistant has inner thoughts + function call)
"""
raise NotImplementedError

@staticmethod
def dict_to_message(
user_id: str,
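The commit leaves to_memgpt_message raising NotImplementedError. Purely to illustrate the splitting behavior the docstring describes, here is a hedged sketch of what such a conversion could look like; the Message attribute names (text, tool_calls, created_at) and the ToolCall shape are assumptions, not confirmed by this diff.

from typing import List

from memgpt.schemas.memgpt_message import (
    FunctionCall,
    FunctionCallMessage,
    FunctionReturn,
    InternalMonologue,
    MemGPTMessage,
)

def to_memgpt_message_sketch(msg) -> List[MemGPTMessage]:
    """Hypothetical conversion of a DB-style Message into API-style messages."""
    out: List[MemGPTMessage] = []
    if msg.role == "assistant":
        if msg.text:
            # inner thoughts become an InternalMonologue
            out.append(InternalMonologue(id=str(msg.id), date=msg.created_at, internal_monologue=msg.text))
        for tool_call in msg.tool_calls or []:
            # each tool call becomes a FunctionCallMessage (assumes tool_call.function has .name/.arguments)
            out.append(
                FunctionCallMessage(
                    id=str(msg.id),
                    date=msg.created_at,
                    function_call=FunctionCall(name=tool_call.function.name, arguments=tool_call.function.arguments),
                )
            )
    elif msg.role == "tool":
        # function results become a FunctionReturn; deriving the status is also assumed
        out.append(FunctionReturn(id=str(msg.id), date=msg.created_at, function_return=msg.text, status="success"))
    return out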
42 changes: 35 additions & 7 deletions memgpt/server/rest_api/agents/message.py
@@ -6,13 +6,16 @@
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from starlette.responses import StreamingResponse

from memgpt.schemas.enums import MessageRole, MessageStreamStatus
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.memgpt_request import MemGPTRequest
from memgpt.schemas.memgpt_response import MemGPTResponse
from memgpt.schemas.message import Message
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface, StreamingServerInterface
from memgpt.server.rest_api.utils import sse_async_generator
from memgpt.server.server import SyncServer
from memgpt.utils import deduplicate

router = APIRouter()

@@ -23,22 +26,24 @@ async def send_message_to_agent(
server: SyncServer,
agent_id: str,
user_id: str,
role: str,
role: MessageRole,
message: str,
stream_steps: bool,
stream_tokens: bool,
chat_completion_mode: Optional[bool] = False,
timestamp: Optional[datetime] = None,
# related to whether or not we return `MemGPTMessage`s or `Message`s
return_message_object: bool = True, # Should be True for Python Client, False for REST API
) -> Union[StreamingResponse, MemGPTResponse]:
"""Split off into a separate function so that it can be imported in the /chat/completion proxy."""

# TODO: @charles is this the correct way to handle?
include_final_message = True

# determine role
if role == "user" or role is None:
if role == MessageRole.user:
message_func = server.user_message
elif role == "system":
elif role == MessageRole.system:
message_func = server.system_message
else:
raise HTTPException(status_code=500, detail=f"Bad role {role}")
@@ -73,21 +78,44 @@
)

if stream_steps:
if return_message_object:
# TODO implement returning `Message`s in a stream, not just `MemGPTMessage` format
raise NotImplementedError

# return a stream
return StreamingResponse(
sse_async_generator(streaming_interface.get_generator(), finish_message=include_final_message),
media_type="text/event-stream",
)

else:
# buffer the stream, then return the list
generated_stream = []
async for message in streaming_interface.get_generator():
assert (
isinstance(message, MemGPTMessage)
or isinstance(message, LegacyMemGPTMessage)
or isinstance(message, MessageStreamStatus)
), type(message)
generated_stream.append(message)
if "data" in message and message["data"] == "[DONE]":
if message == MessageStreamStatus.done:
break
filtered_stream = [d for d in generated_stream if d not in ["[DONE_GEN]", "[DONE_STEP]", "[DONE]"]]

# Get rid of the stream status messages
filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
usage = await task
return MemGPTResponse(messages=filtered_stream, usage=usage)

# By default the stream will be messages of type MemGPTMessage or LegacyMemGPTMessage
# If we want to convert these to Message, we can use the attached IDs
# NOTE: we will need to de-duplicate the Message IDs though (since Assistant->Inner+Func_Call)
# TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead
if return_message_object:
message_ids = [m.id for m in filtered_stream]
message_ids = deduplicate(message_ids)
message_objs = [server.get_agent_message(agent_id=agent_id, message_id=m_id) for m_id in message_ids]
return MemGPTResponse(messages=message_objs, usage=usage)
else:
return MemGPTResponse(messages=filtered_stream, usage=usage)

except HTTPException:
raise
@@ -147,8 +175,8 @@ async def send_message(
# TODO: revise to `MemGPTRequest`
# TODO: support sending multiple messages
assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}"

message = request.messages[0]

# TODO: what to do with message.name?
return await send_message_to_agent(
server=server,
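
For reference, a self-contained sketch (not from the commit) of the buffer-then-filter pattern used in the non-streaming branch above: accumulate items until the [DONE] status arrives, then drop the MessageStreamStatus markers before building the response.

from typing import List

from memgpt.schemas.enums import MessageStreamStatus

async def buffer_stream(generator) -> List:
    buffered = []
    async for item in generator:
        buffered.append(item)
        if item == MessageStreamStatus.done:  # "[DONE]" terminates the stream
            break
    # keep only real messages, dropping [DONE_GEN]/[DONE_STEP]/[DONE]
    return [m for m in buffered if not isinstance(m, MessageStreamStatus)]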