fix: refactor send_message API #1633

Merged 6 commits on Aug 12, 2024

Changes from all commits
1 change: 1 addition & 0 deletions memgpt/llm_api/openai.py
@@ -89,6 +89,7 @@ def openai_chat_completions_process_stream(
    on the chunks received from the OpenAI-compatible server POST SSE response.
    """
    assert chat_completion_request.stream == True
    assert stream_inferface is not None, "Required"

    # Count the prompt tokens
    # TODO move to post-request?
6 changes: 6 additions & 0 deletions memgpt/schemas/enums.py
@@ -22,3 +22,9 @@ class JobStatus(str, Enum):
    completed = "completed"
    failed = "failed"
    pending = "pending"


class MessageStreamStatus(str, Enum):
    done_generation = "[DONE_GEN]"
    done_step = "[DONE_STEP]"
    done = "[DONE]"
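These enum members are the sentinels that terminate the SSE stream (end of generation, end of a step, end of the whole response). A minimal sketch of how a consumer might branch on them, assuming an iterable of stream items as produced by this PR's interface:

```python
from memgpt.schemas.enums import MessageStreamStatus


def consume(stream):
    """Hypothetical consumer: collect messages until the final [DONE] sentinel."""
    collected = []
    for item in stream:
        if isinstance(item, MessageStreamStatus):
            if item == MessageStreamStatus.done:
                break  # [DONE]: the whole response is finished
            continue  # [DONE_GEN] / [DONE_STEP] are intermediate markers, skip them
        collected.append(item)
    return collected
```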
78 changes: 78 additions & 0 deletions memgpt/schemas/memgpt_message.py
@@ -0,0 +1,78 @@
from datetime import datetime, timezone
from typing import Literal, Union

from pydantic import BaseModel, field_serializer

# MemGPT API-style responses (intended to be easier to use than the raw Message types)


class BaseMemGPTMessage(BaseModel):
    id: str
    date: datetime

    @field_serializer("date")
    def serialize_datetime(self, dt: datetime, _info):
        # Serialize the message's own timestamp in UTC (not datetime.now(),
        # which would overwrite the stored date with the current time)
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        return dt.astimezone(timezone.utc).isoformat()


class InternalMonologue(BaseMemGPTMessage):
    """
    {
        "internal_monologue": msg,
        "date": msg_obj.created_at.isoformat() if msg_obj is not None else get_utc_time().isoformat(),
        "id": str(msg_obj.id) if msg_obj is not None else None,
    }
    """

    internal_monologue: str


class FunctionCall(BaseModel):
    name: str
    arguments: str


class FunctionCallMessage(BaseMemGPTMessage):
    """
    {
        "function_call": {
            "name": function_call.function.name,
            "arguments": function_call.function.arguments,
        },
        "id": str(msg_obj.id),
        "date": msg_obj.created_at.isoformat(),
    }
    """

    function_call: FunctionCall


class FunctionReturn(BaseMemGPTMessage):
    """
    {
        "function_return": msg,
        "status": "success" or "error",
        "id": str(msg_obj.id),
        "date": msg_obj.created_at.isoformat(),
    }
    """

    function_return: str
    status: Literal["success", "error"]


MemGPTMessage = Union[InternalMonologue, FunctionCallMessage, FunctionReturn]


# The legacy MemGPT API had an additional "assistant_message" type, and its "function_call" was a formatted string


class AssistantMessage(BaseMemGPTMessage):
    assistant_message: str


class LegacyFunctionCallMessage(BaseMemGPTMessage):
    function_call: str


LegacyMemGPTMessage = Union[InternalMonologue, AssistantMessage, LegacyFunctionCallMessage, FunctionReturn]
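For illustration, here is roughly how these API-style messages are constructed and serialized (the ids and contents below are made up):

```python
from datetime import datetime, timezone

from memgpt.schemas.memgpt_message import FunctionCall, FunctionCallMessage, InternalMonologue

# Hypothetical ids and contents, for illustration only
now = datetime.now(timezone.utc)
monologue = InternalMonologue(
    id="message-0001",
    date=now,
    internal_monologue="The user asked for a summary; I should call send_message.",
)
call = FunctionCallMessage(
    id="message-0001",  # one underlying Message can yield several API-style pieces
    date=now,
    function_call=FunctionCall(name="send_message", arguments='{"message": "Here it is."}'),
)
print(monologue.model_dump_json())
print(call.model_dump_json())
```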
11 changes: 8 additions & 3 deletions memgpt/schemas/memgpt_response.py
@@ -1,12 +1,17 @@
from typing import List
from typing import List, Union

from pydantic import BaseModel, Field

from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.message import Message
from memgpt.schemas.usage import MemGPTUsageStatistics


# TODO: consider moving into own file


class MemGPTResponse(BaseModel):
    messages: List[Message] = Field(..., description="The messages returned by the agent.")
    # messages: List[Message] = Field(..., description="The messages returned by the agent.")
    messages: Union[List[Message], List[MemGPTMessage], List[LegacyMemGPTMessage]] = Field(
        ..., description="The messages returned by the agent."
    )
    usage: MemGPTUsageStatistics = Field(..., description="The usage statistics of the agent.")
10 changes: 9 additions & 1 deletion memgpt/schemas/message.py
@@ -2,14 +2,15 @@
import json
import warnings
from datetime import datetime, timezone
from typing import List, Optional
from typing import List, Optional, Union

from pydantic import Field, field_validator

from memgpt.constants import JSON_ENSURE_ASCII, TOOL_CALL_ID_MAX_LEN
from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG
from memgpt.schemas.enums import MessageRole
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.openai.chat_completions import ToolCall
from memgpt.utils import get_utc_time, is_utc_datetime

@@ -87,6 +88,13 @@ def to_json(self):
json_message["created_at"] = self.created_at.isoformat()
return json_message

def to_memgpt_message(self) -> Union[List[MemGPTMessage], List[LegacyMemGPTMessage]]:
"""Convert message object (in DB format) to the style used by the original MemGPT API

NOTE: this may split the message into two pieces (e.g. if the assistant has inner thoughts + function call)
"""
raise NotImplementedError

    @staticmethod
    def dict_to_message(
        user_id: str,
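`to_memgpt_message` is stubbed out with `NotImplementedError` in this PR. Purely as a sketch of the splitting behavior the docstring describes (assuming `Message` carries `text`, `tool_calls`, and `created_at` attributes; this is not the merged implementation), the method body might look like:

```python
def to_memgpt_message(self) -> Union[List[MemGPTMessage], List[LegacyMemGPTMessage]]:
    # Sketch only: fan one DB-format Message out into API-style pieces
    pieces: List[MemGPTMessage] = []
    if self.text:  # inner thoughts become an InternalMonologue
        pieces.append(InternalMonologue(id=self.id, date=self.created_at, internal_monologue=self.text))
    for tool_call in self.tool_calls or []:  # each tool call becomes a FunctionCallMessage
        pieces.append(
            FunctionCallMessage(
                id=self.id,  # note: the pieces share the original Message id
                date=self.created_at,
                function_call=FunctionCall(name=tool_call.function.name, arguments=tool_call.function.arguments),
            )
        )
    return pieces
```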
42 changes: 35 additions & 7 deletions memgpt/server/rest_api/agents/message.py
@@ -6,13 +6,16 @@
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from starlette.responses import StreamingResponse

from memgpt.schemas.enums import MessageRole, MessageStreamStatus
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.memgpt_request import MemGPTRequest
from memgpt.schemas.memgpt_response import MemGPTResponse
from memgpt.schemas.message import Message
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface, StreamingServerInterface
from memgpt.server.rest_api.utils import sse_async_generator
from memgpt.server.server import SyncServer
from memgpt.utils import deduplicate

router = APIRouter()

@@ -23,21 +26,23 @@ async def send_message_to_agent(
    server: SyncServer,
    agent_id: str,
    user_id: str,
    role: str,
    role: MessageRole,
    message: str,
    stream_steps: bool,
    stream_tokens: bool,
    chat_completion_mode: Optional[bool] = False,
    timestamp: Optional[datetime] = None,
    # related to whether or not we return `MemGPTMessage`s or `Message`s
    return_message_object: bool = True,  # Should be True for Python Client, False for REST API
Review comment (Collaborator Author): @sarahwooders this isn't properly connected anywhere yet - probably this should be set to False by default, and the Python + REST clients should set it to True manually

) -> Union[StreamingResponse, MemGPTResponse]:
    """Split off into a separate function so that it can be imported in the /chat/completion proxy."""
    # TODO: @charles is this the correct way to handle?
    include_final_message = True

    # determine role
    if role == "user" or role is None:
    if role == MessageRole.user:
        message_func = server.user_message
    elif role == "system":
    elif role == MessageRole.system:
        message_func = server.system_message
    else:
        raise HTTPException(status_code=500, detail=f"Bad role {role}")
@@ -72,21 +77,44 @@
)

    if stream_steps:
        if return_message_object:
            # TODO implement returning `Message`s in a stream, not just `MemGPTMessage` format
            raise NotImplementedError

        # return a stream
        return StreamingResponse(
            sse_async_generator(streaming_interface.get_generator(), finish_message=include_final_message),
            media_type="text/event-stream",
        )

    else:
        # buffer the stream, then return the list
        generated_stream = []
        async for message in streaming_interface.get_generator():
            assert (
                isinstance(message, MemGPTMessage)
                or isinstance(message, LegacyMemGPTMessage)
                or isinstance(message, MessageStreamStatus)
            ), type(message)
            generated_stream.append(message)
            if "data" in message and message["data"] == "[DONE]":
            if message == MessageStreamStatus.done:
                break
        filtered_stream = [d for d in generated_stream if d not in ["[DONE_GEN]", "[DONE_STEP]", "[DONE]"]]

        # Get rid of the stream status messages
        filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
        usage = await task
        return MemGPTResponse(messages=filtered_stream, usage=usage)

        # By default the stream will be messages of type MemGPTMessage or MemGPTLegacyMessage
        # If we want to convert these to Message, we can use the attached IDs
        # NOTE: we will need to de-duplicate the Message IDs though (since Assistant -> Inner + Func_Call)
        # TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead
        if return_message_object:
            message_ids = [m.id for m in filtered_stream]
            message_ids = deduplicate(message_ids)
            message_objs = [server.get_agent_message(agent_id=agent_id, message_id=m_id) for m_id in message_ids]
            return MemGPTResponse(messages=message_objs, usage=usage)
        else:
            return MemGPTResponse(messages=filtered_stream, usage=usage)

    except HTTPException:
        raise
@@ -146,8 +174,8 @@ async def send_message(
    # TODO: revise to `MemGPTRequest`
    # TODO: support sending multiple messages
    assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}"

    message = request.messages[0]

    # TODO: what to do with message.name?
    return await send_message_to_agent(
        server=server,
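The de-duplication step above matters because a single assistant Message fans out into an InternalMonologue plus a FunctionCallMessage that share one id. An order-preserving helper along these lines would do the job (the actual memgpt.utils.deduplicate may differ):

```python
def deduplicate(seq):
    """Drop repeated items while preserving first-seen order."""
    seen = set()
    # set.add returns None, so the `or` keeps exactly the items not yet seen
    return [x for x in seq if not (x in seen or seen.add(x))]
```

For example, `deduplicate(["msg-1", "msg-1", "msg-2"])` yields `["msg-1", "msg-2"]`, so each underlying Message is fetched from the server exactly once.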