From fd0a90790a6a346c174f2011e3d607f35965269c Mon Sep 17 00:00:00 2001 From: cpacker Date: Sun, 11 Aug 2024 18:06:32 -0700 Subject: [PATCH 1/5] fix: fix hanging send_message by wrapping server.step in try/catch + add a timeout to the generator --- memgpt/llm_api/openai.py | 1 + memgpt/schemas/enums.py | 6 ++ memgpt/schemas/message.py | 13 +++ memgpt/server/rest_api/agents/message.py | 22 ++-- memgpt/server/rest_api/interface.py | 132 ++++++++++++++--------- memgpt/server/server.py | 111 ++++++++++--------- 6 files changed, 174 insertions(+), 111 deletions(-) diff --git a/memgpt/llm_api/openai.py b/memgpt/llm_api/openai.py index 13c6df3b33..e0915bdabf 100644 --- a/memgpt/llm_api/openai.py +++ b/memgpt/llm_api/openai.py @@ -89,6 +89,7 @@ def openai_chat_completions_process_stream( on the chunks received from the OpenAI-compatible server POST SSE response. """ assert chat_completion_request.stream == True + assert stream_inferface is not None, "Required" # Count the prompt tokens # TODO move to post-request? diff --git a/memgpt/schemas/enums.py b/memgpt/schemas/enums.py index 9752aa3f1b..c8ebae734e 100644 --- a/memgpt/schemas/enums.py +++ b/memgpt/schemas/enums.py @@ -22,3 +22,9 @@ class JobStatus(str, Enum): completed = "completed" failed = "failed" pending = "pending" + + +class MessageStreamStatus(str, Enum): + done_generation = "[DONE_GEN]" + done_step = "[DONE_STEP]" + done = "[DONE]" diff --git a/memgpt/schemas/message.py b/memgpt/schemas/message.py index fb8c16cd33..3e403c381c 100644 --- a/memgpt/schemas/message.py +++ b/memgpt/schemas/message.py @@ -48,6 +48,19 @@ class MessageCreate(BaseMessage): name: Optional[str] = Field(None, description="The name of the participant.") +""" + +Assistant messages: +-> inner thoughts +-> tool calls + +User messages: + + + +""" + + class Message(BaseMessage): """ Representation of a message sent. 
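The hang fix in this first patch boils down to two patterns: time-bound the `asyncio.Event.wait()` inside the streaming generator, and have the server call `step_yield()` from a `finally` block so the generator is always unblocked. A minimal, self-contained sketch of the first pattern follows (illustrative only; the class name and the 30-second default are assumptions mirroring the diff, not the MemGPT implementation itself):

import asyncio
from collections import deque
from typing import AsyncGenerator


class TimeoutStreamBuffer:
    """Producer/consumer buffer: a worker pushes chunks, an HTTP handler streams them out."""

    def __init__(self, timeout: float = 30.0):
        self._chunks: deque = deque()
        self._event = asyncio.Event()
        self._active = True
        self.timeout = timeout  # assumed default, mirroring the 30-second wait in the diff

    def push(self, item: dict) -> None:
        self._chunks.append(item)
        self._event.set()  # wake the consumer

    def close(self) -> None:
        self._active = False
        self._event.set()  # unblock a waiting consumer so it can exit cleanly

    async def stream(self) -> AsyncGenerator[dict, None]:
        while self._active:
            try:
                # Without the timeout, a producer that dies before calling close()
                # leaves this wait() pending forever -- the hanging send_message.
                await asyncio.wait_for(self._event.wait(), timeout=self.timeout)
            except asyncio.TimeoutError:
                break
            while self._chunks:
                yield self._chunks.popleft()
            self._event.clear()

A consumer would iterate with `async for chunk in buf.stream(): ...` while the producing side calls `push()` per chunk and `close()` when finished; in the patch below, the `finally: memgpt_agent.interface.step_yield()` added to `server._step` plays the role of `close()`, so the generator terminates even when the step raises.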
diff --git a/memgpt/server/rest_api/agents/message.py b/memgpt/server/rest_api/agents/message.py index 740df7d3d9..e3f93cdb71 100644 --- a/memgpt/server/rest_api/agents/message.py +++ b/memgpt/server/rest_api/agents/message.py @@ -6,6 +6,7 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Query from starlette.responses import StreamingResponse +from memgpt.schemas.enums import MessageRole, MessageStreamStatus from memgpt.schemas.memgpt_request import MemGPTRequest from memgpt.schemas.memgpt_response import MemGPTResponse from memgpt.schemas.message import Message @@ -23,7 +24,7 @@ async def send_message_to_agent( server: SyncServer, agent_id: str, user_id: str, - role: str, + role: MessageRole, message: str, stream_steps: bool, stream_tokens: bool, @@ -36,9 +37,9 @@ async def send_message_to_agent( include_final_message = True # determine role - if role == "user" or role is None: + if role == MessageRole.user: message_func = server.user_message - elif role == "system": + elif role == MessageRole.system: message_func = server.system_message else: raise HTTPException(status_code=500, detail=f"Bad role {role}") @@ -83,9 +84,18 @@ async def send_message_to_agent( generated_stream = [] async for message in streaming_interface.get_generator(): generated_stream.append(message) - if "data" in message and message["data"] == "[DONE]": + if "data" in message and message["data"] == MessageStreamStatus.done: break - filtered_stream = [d for d in generated_stream if d not in ["[DONE_GEN]", "[DONE_STEP]", "[DONE]"]] + filtered_stream = [ + d + for d in generated_stream + if d + not in [ + MessageStreamStatus.done_generation, + MessageStreamStatus.done_step, + MessageStreamStatus.done, + ] + ] usage = await task return MemGPTResponse(messages=filtered_stream, usage=usage) @@ -147,8 +157,8 @@ async def send_message( # TODO: revise to `MemGPTRequest` # TODO: support sending multiple messages assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}" - message = request.messages[0] + # TODO: what to do with message.name? 
return await send_message_to_agent( server=server, diff --git a/memgpt/server/rest_api/interface.py b/memgpt/server/rest_api/interface.py index 70f5055b9b..8e29495217 100644 --- a/memgpt/server/rest_api/interface.py +++ b/memgpt/server/rest_api/interface.py @@ -5,6 +5,7 @@ from typing import AsyncGenerator, Literal, Optional, Union from memgpt.interface import AgentInterface +from memgpt.schemas.enums import MessageStreamStatus from memgpt.schemas.message import Message from memgpt.schemas.openai.chat_completion_response import ChatCompletionChunkResponse from memgpt.streaming_interface import AgentChunkStreamingInterface @@ -305,14 +306,21 @@ def __init__(self, multi_step=True): # if multi_step = True, the stream ends when the agent yields # if multi_step = False, the stream ends when the step ends self.multi_step = multi_step - self.multi_step_indicator = "[DONE_STEP]" - self.multi_step_gen_indicator = "[DONE_GEN]" + self.multi_step_indicator = MessageStreamStatus.done_step + self.multi_step_gen_indicator = MessageStreamStatus.done_generation + + # extra prints + self.debug = False + self.timeout = 30 async def _create_generator(self) -> AsyncGenerator: """An asynchronous generator that yields chunks as they become available.""" while self._active: - # Wait until there is an item in the deque or the stream is deactivated - await self._event.wait() + try: + # Wait until there is an item in the deque or the stream is deactivated + await asyncio.wait_for(self._event.wait(), timeout=self.timeout) # 30 second timeout + except asyncio.TimeoutError: + break # Exit the loop if we timeout while self._chunks: yield self._chunks.popleft() @@ -320,6 +328,31 @@ async def _create_generator(self) -> AsyncGenerator: # Reset the event until a new item is pushed self._event.clear() + # while self._active: + # # Wait until there is an item in the deque or the stream is deactivated + # await self._event.wait() + + # while self._chunks: + # yield self._chunks.popleft() + + # # Reset the event until a new item is pushed + # self._event.clear() + + def get_generator(self) -> AsyncGenerator: + """Get the generator that yields processed chunks.""" + if not self._active: + # If the stream is not active, don't return a generator that would produce values + raise StopIteration("The stream has not been started or has been ended.") + return self._create_generator() + + def _push_to_buffer(self, item: Union[dict, MessageStreamStatus]): + """Add an item to the deque""" + print("zzz", item) + assert self._active, "Generator is inactive" + assert isinstance(item, dict) or isinstance(item, MessageStreamStatus), f"Wrong type: {type(item)}" + self._chunks.append(item) + self._event.set() # Signal that new data is available + def stream_start(self): """Initialize streaming by activating the generator and clearing any old chunks.""" self.streaming_chat_completion_mode_function_name = None @@ -334,8 +367,10 @@ def stream_end(self): self.streaming_chat_completion_mode_function_name = None if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode: - self._chunks.append(self.multi_step_gen_indicator) - self._event.set() # Signal that new data is available + self._push_to_buffer(self.multi_step_gen_indicator) + + # self._active = False + # self._event.set() # Unblock the generator if it's waiting to allow it to complete # if not self.multi_step: # # end the stream @@ -346,6 +381,31 @@ def stream_end(self): # self._chunks.append(self.multi_step_indicator) # self._event.set() # Signal that new data is 
available + def step_complete(self): + """Signal from the agent that one 'step' finished (step = LLM response + tool execution)""" + print("zzz step_complete") + + if not self.multi_step: + # end the stream + self._active = False + self._event.set() # Unblock the generator if it's waiting to allow it to complete + elif not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode: + # signal that a new step has started in the stream + self._push_to_buffer(self.multi_step_indicator) + + def step_yield(self): + """If multi_step, this is the true 'stream_end' function.""" + print("zzz step_yield") + + # if self.multi_step: + # end the stream + self._active = False + self._event.set() # Unblock the generator if it's waiting to allow it to complete + + @staticmethod + def clear(): + return + def _process_chunk_to_memgpt_style(self, chunk: ChatCompletionChunkResponse) -> Optional[dict]: """ Example data from non-streaming response looks like: @@ -471,15 +531,7 @@ def process_chunk(self, chunk: ChatCompletionChunkResponse, msg_obj: Optional[Me if msg_obj: processed_chunk["id"] = str(msg_obj.id) - self._chunks.append(processed_chunk) - self._event.set() # Signal that new data is available - - def get_generator(self) -> AsyncGenerator: - """Get the generator that yields processed chunks.""" - if not self._active: - # If the stream is not active, don't return a generator that would produce values - raise StopIteration("The stream has not been started or has been ended.") - return self._create_generator() + self._push_to_buffer(processed_chunk) def user_message(self, msg: str, msg_obj: Optional[Message] = None): """MemGPT receives a user message""" @@ -496,8 +548,7 @@ def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None): "id": str(msg_obj.id) if msg_obj is not None else None, } - self._chunks.append(processed_chunk) - self._event.set() # Signal that new data is available + self._push_to_buffer(processed_chunk) return @@ -539,26 +590,24 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None): # "date": "2024-06-22T23:04:32.141923+00:00" # } try: - func_args = json.loads(function_call.function["arguments"]) + func_args = json.loads(function_call.function.arguments) except: - func_args = function_call.function["arguments"] + func_args = function_call.function.arguments processed_chunk = { - "function_call": f"{function_call.function['name']}({func_args})", + "function_call": f"{function_call.function.name}({func_args})", "id": str(msg_obj.id), "date": msg_obj.created_at.isoformat(), } - self._chunks.append(processed_chunk) - self._event.set() # Signal that new data is available + self._push_to_buffer(processed_chunk) - if function_call.function["name"] == "send_message": + if function_call.function.name == "send_message": try: processed_chunk = { "assistant_message": func_args["message"], "id": str(msg_obj.id), "date": msg_obj.created_at.isoformat(), } - self._chunks.append(processed_chunk) - self._event.set() # Signal that new data is available + self._push_to_buffer(processed_chunk) except Exception as e: print(f"Failed to parse function message: {e}") @@ -567,14 +616,14 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None): processed_chunk = { "function_call": { # "id": function_call.id, - "name": function_call.function["name"], - "arguments": function_call.function["arguments"], + # "name": function_call.function["name"], + "name": function_call.function.name, + "arguments": function_call.function.arguments, }, "id": 
str(msg_obj.id), "date": msg_obj.created_at.isoformat(), } - self._chunks.append(processed_chunk) - self._event.set() # Signal that new data is available + self._push_to_buffer(processed_chunk) return else: @@ -605,27 +654,4 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None): assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at new_message["date"] = msg_obj.created_at.isoformat() - self._chunks.append(new_message) - self._event.set() # Signal that new data is available - - def step_complete(self): - """Signal from the agent that one 'step' finished (step = LLM response + tool execution)""" - if not self.multi_step: - # end the stream - self._active = False - self._event.set() # Unblock the generator if it's waiting to allow it to complete - elif not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode: - # signal that a new step has started in the stream - self._chunks.append(self.multi_step_indicator) - self._event.set() # Signal that new data is available - - def step_yield(self): - """If multi_step, this is the true 'stream_end' function.""" - if self.multi_step: - # end the stream - self._active = False - self._event.set() # Unblock the generator if it's waiting to allow it to complete - - @staticmethod - def clear(): - return + self._push_to_buffer(new_message) diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 664cced420..bbd3e03241 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -3,6 +3,7 @@ import inspect import json import os +import traceback import warnings from abc import abstractmethod from datetime import datetime @@ -387,62 +388,68 @@ def _step( self, user_id: str, agent_id: str, input_message: Union[str, Message], timestamp: Optional[datetime] ) -> MemGPTUsageStatistics: """Send the input message through the agent""" - logger.debug(f"Got input message: {input_message}") + try: - # Get the agent object (loaded in memory) - memgpt_agent = self._get_or_load_agent(agent_id=agent_id) - if memgpt_agent is None: - raise KeyError(f"Agent (user={user_id}, agent={agent_id}) is not loaded") - - # Determine whether or not to token stream based on the capability of the interface - token_streaming = memgpt_agent.interface.streaming_mode if hasattr(memgpt_agent.interface, "streaming_mode") else False - - logger.debug(f"Starting agent step") - no_verify = True - next_input_message = input_message - counter = 0 - total_usage = UsageStatistics() - step_count = 0 - while True: - new_messages, heartbeat_request, function_failed, token_warning, usage = memgpt_agent.step( - next_input_message, - first_message=False, - skip_verify=no_verify, - return_dicts=False, - stream=token_streaming, - timestamp=timestamp, - ms=self.ms, - ) - step_count += 1 - total_usage += usage - counter += 1 - memgpt_agent.interface.step_complete() - - # Chain stops - if not self.chaining: - logger.debug("No chaining, stopping after one step") - break - elif self.max_chaining_steps is not None and counter > self.max_chaining_steps: - logger.debug(f"Hit max chaining steps, stopping after {counter} steps") - break - # Chain handlers - elif token_warning: - next_input_message = system.get_token_limit_warning() - continue # always chain - elif function_failed: - next_input_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE) - continue # always chain - elif heartbeat_request: - next_input_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE) - continue # always chain - # MemGPT no-op / yield - else: - 
break + # Get the agent object (loaded in memory) + memgpt_agent = self._get_or_load_agent(agent_id=agent_id) + if memgpt_agent is None: + raise KeyError(f"Agent (user={user_id}, agent={agent_id}) is not loaded") + + # Determine whether or not to token stream based on the capability of the interface + token_streaming = memgpt_agent.interface.streaming_mode if hasattr(memgpt_agent.interface, "streaming_mode") else False + + logger.debug(f"Starting agent step") + no_verify = True + next_input_message = input_message + counter = 0 + total_usage = UsageStatistics() + step_count = 0 + while True: + new_messages, heartbeat_request, function_failed, token_warning, usage = memgpt_agent.step( + next_input_message, + first_message=False, + skip_verify=no_verify, + return_dicts=False, + stream=token_streaming, + timestamp=timestamp, + ms=self.ms, + ) + step_count += 1 + total_usage += usage + counter += 1 + memgpt_agent.interface.step_complete() + + # Chain stops + if not self.chaining: + logger.debug("No chaining, stopping after one step") + break + elif self.max_chaining_steps is not None and counter > self.max_chaining_steps: + logger.debug(f"Hit max chaining steps, stopping after {counter} steps") + break + # Chain handlers + elif token_warning: + next_input_message = system.get_token_limit_warning() + continue # always chain + elif function_failed: + next_input_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE) + continue # always chain + elif heartbeat_request: + next_input_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE) + continue # always chain + # MemGPT no-op / yield + else: + break - memgpt_agent.interface.step_yield() - logger.debug(f"Finished agent step") + except Exception as e: + logger.error(f"Error in server._step: {e}") + print(traceback.print_exc()) + raise + finally: + logger.debug("Calling step_yield()") + memgpt_agent.interface.step_yield() + logger.debug("Saving agent state") # save updated state save_agent(memgpt_agent, self.ms) From 1ac6c9ffbb0ebeaff32fa0ba3bf1b116912a460a Mon Sep 17 00:00:00 2001 From: cpacker Date: Sun, 11 Aug 2024 20:19:33 -0700 Subject: [PATCH 2/5] feat: updated MemGPTResponse to take MemGPTMessages in addition to Message --- memgpt/schemas/memgpt_message.py | 78 +++++++++++++++++++++++++++++++ memgpt/schemas/memgpt_response.py | 11 +++-- 2 files changed, 86 insertions(+), 3 deletions(-) create mode 100644 memgpt/schemas/memgpt_message.py diff --git a/memgpt/schemas/memgpt_message.py b/memgpt/schemas/memgpt_message.py new file mode 100644 index 0000000000..5b659b471a --- /dev/null +++ b/memgpt/schemas/memgpt_message.py @@ -0,0 +1,78 @@ +from datetime import datetime, timezone +from typing import Literal, Union + +from pydantic import BaseModel, field_serializer + +# MemGPT API style responses (intended to be easier to use vs getting true Message types) + + +class BaseMemGPTMessage(BaseModel): + id: str + date: datetime + + @field_serializer("date") + def serialize_datetime(self, dt: datetime, _info): + return dt.now(timezone.utc).isoformat() + + +class InternalMonologue(BaseMemGPTMessage): + """ + { + "internal_monologue": msg, + "date": msg_obj.created_at.isoformat() if msg_obj is not None else get_utc_time().isoformat(), + "id": str(msg_obj.id) if msg_obj is not None else None, + } + """ + + internal_monologue: str + + +class FunctionCall(BaseModel): + name: str + arguments: str + + +class FunctionCallMessage(BaseMemGPTMessage): + """ + { + "function_call": { + "name": function_call.function.name, + "arguments": 
function_call.function.arguments, + }, + "id": str(msg_obj.id), + "date": msg_obj.created_at.isoformat(), + } + """ + + function_call: FunctionCall + + +class FunctionReturn(BaseMemGPTMessage): + """ + { + "function_return": msg, + "status": "success" or "error", + "id": str(msg_obj.id), + "date": msg_obj.created_at.isoformat(), + } + """ + + function_return: str + status: Literal["success", "error"] + + +MemGPTMessage = Union[InternalMonologue, FunctionCallMessage, FunctionReturn] + + +# Legacy MemGPT API had an additional type "assistant_message" and the "function_call" was a formatted string + + +class AssistantMessage(BaseMemGPTMessage): + assistant_message: str + + +class LegacyFunctionCallMessage(BaseMemGPTMessage): + function_call: str + + +LegacyMemGPTMessage = Union[InternalMonologue, AssistantMessage, LegacyFunctionCallMessage, FunctionReturn] diff --git a/memgpt/schemas/memgpt_response.py b/memgpt/schemas/memgpt_response.py index b8b56cfd0d..00132e42b8 100644 --- a/memgpt/schemas/memgpt_response.py +++ b/memgpt/schemas/memgpt_response.py @@ -1,12 +1,17 @@ -from typing import List +from typing import List, Union from pydantic import BaseModel, Field +from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage from memgpt.schemas.message import Message from memgpt.schemas.usage import MemGPTUsageStatistics - # TODO: consider moving into own file + + class MemGPTResponse(BaseModel): - messages: List[Message] = Field(..., description="The messages returned by the agent.") + # messages: List[Message] = Field(..., description="The messages returned by the agent.") + messages: Union[List[Message], List[MemGPTMessage], List[LegacyMemGPTMessage]] = Field( + ..., description="The messages returned by the agent." + ) usage: MemGPTUsageStatistics = Field(..., description="The usage statistics of the agent.") From 5347b05f92b0140eec45301e5c60b6c19b694721 Mon Sep 17 00:00:00 2001 From: cpacker Date: Sun, 11 Aug 2024 20:20:48 -0700 Subject: [PATCH 3/5] refactor: change streaming interface to use new pydantic types inside the buffer --- memgpt/schemas/message.py | 10 +- memgpt/server/rest_api/agents/message.py | 42 +++++--- memgpt/server/rest_api/interface.py | 117 ++++++++++++++++------- memgpt/server/rest_api/utils.py | 43 ++++----- memgpt/server/server.py | 7 ++ memgpt/utils.py | 11 +++ 6 files changed, 157 insertions(+), 73 deletions(-) diff --git a/memgpt/schemas/message.py b/memgpt/schemas/message.py index 3e403c381c..d0f9b1ad7b 100644 --- a/memgpt/schemas/message.py +++ b/memgpt/schemas/message.py @@ -2,7 +2,7 @@ import json import warnings from datetime import datetime, timezone -from typing import List, Optional +from typing import List, Optional, Union from pydantic import Field, field_validator @@ -10,6 +10,7 @@ from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG from memgpt.schemas.enums import MessageRole from memgpt.schemas.memgpt_base import MemGPTBase +from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage from memgpt.schemas.openai.chat_completions import ToolCall from memgpt.utils import get_utc_time, is_utc_datetime @@ -100,6 +101,13 @@ def to_json(self): json_message["created_at"] = self.created_at.isoformat() return json_message + def to_memgpt_message(self) -> Union[List[MemGPTMessage], List[LegacyMemGPTMessage]]: + """Convert message object (in DB format) to the style used by the original MemGPT API + + NOTE: this may split the message into two pieces (e.g. 
if the assistant has inner thoughts + function call) + """ + raise NotImplementedError + @staticmethod def dict_to_message( user_id: str, diff --git a/memgpt/server/rest_api/agents/message.py b/memgpt/server/rest_api/agents/message.py index e3f93cdb71..df3def7b8d 100644 --- a/memgpt/server/rest_api/agents/message.py +++ b/memgpt/server/rest_api/agents/message.py @@ -7,6 +7,7 @@ from starlette.responses import StreamingResponse from memgpt.schemas.enums import MessageRole, MessageStreamStatus +from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage from memgpt.schemas.memgpt_request import MemGPTRequest from memgpt.schemas.memgpt_response import MemGPTResponse from memgpt.schemas.message import Message @@ -14,6 +15,7 @@ from memgpt.server.rest_api.interface import QueuingInterface, StreamingServerInterface from memgpt.server.rest_api.utils import sse_async_generator from memgpt.server.server import SyncServer +from memgpt.utils import deduplicate router = APIRouter() @@ -30,6 +32,8 @@ async def send_message_to_agent( stream_tokens: bool, chat_completion_mode: Optional[bool] = False, timestamp: Optional[datetime] = None, + # related to whether or not we return `MemGPTMessage`s or `Message`s + return_message_object: bool = True, # Should be True for Python Client, False for REST API ) -> Union[StreamingResponse, MemGPTResponse]: """Split off into a separate function so that it can be imported in the /chat/completion proxy.""" @@ -74,30 +78,44 @@ async def send_message_to_agent( ) if stream_steps: + if return_message_object: + # TODO implement returning `Message`s in a stream, not just `MemGPTMessage` format + raise NotImplementedError + # return a stream return StreamingResponse( sse_async_generator(streaming_interface.get_generator(), finish_message=include_final_message), media_type="text/event-stream", ) + else: # buffer the stream, then return the list generated_stream = [] async for message in streaming_interface.get_generator(): + assert ( + isinstance(message, MemGPTMessage) + or isinstance(message, LegacyMemGPTMessage) + or isinstance(message, MessageStreamStatus) + ), type(message) generated_stream.append(message) - if "data" in message and message["data"] == MessageStreamStatus.done: + if message == MessageStreamStatus.done: break - filtered_stream = [ - d - for d in generated_stream - if d - not in [ - MessageStreamStatus.done_generation, - MessageStreamStatus.done_step, - MessageStreamStatus.done, - ] - ] + + # Get rid of the stream status messages + filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)] usage = await task - return MemGPTResponse(messages=filtered_stream, usage=usage) + + # By default the stream will be messages of type MemGPTMessage or MemGPTLegacyMessage + # If we want to convert these to Message, we can use the attached IDs + # NOTE: we will need to de-duplicate the Messsage IDs though (since Assistant->Inner+Func_Call) + # TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead + if return_message_object: + message_ids = [m.id for m in filtered_stream] + message_ids = deduplicate(message_ids) + message_objs = [server.get_agent_message(agent_id=agent_id, message_id=m_id) for m_id in message_ids] + return MemGPTResponse(messages=message_objs, usage=usage) + else: + return MemGPTResponse(messages=filtered_stream, usage=usage) except HTTPException: raise diff --git a/memgpt/server/rest_api/interface.py b/memgpt/server/rest_api/interface.py index 
8e29495217..26c253a8c2 100644 --- a/memgpt/server/rest_api/interface.py +++ b/memgpt/server/rest_api/interface.py @@ -6,10 +6,20 @@ from memgpt.interface import AgentInterface from memgpt.schemas.enums import MessageStreamStatus +from memgpt.schemas.memgpt_message import ( + AssistantMessage, + FunctionCall, + FunctionCallMessage, + FunctionReturn, + InternalMonologue, + LegacyFunctionCallMessage, + LegacyMemGPTMessage, + MemGPTMessage, +) from memgpt.schemas.message import Message from memgpt.schemas.openai.chat_completion_response import ChatCompletionChunkResponse from memgpt.streaming_interface import AgentChunkStreamingInterface -from memgpt.utils import get_utc_time, is_utc_datetime +from memgpt.utils import is_utc_datetime class QueuingInterface(AgentInterface): @@ -313,7 +323,7 @@ def __init__(self, multi_step=True): self.debug = False self.timeout = 30 - async def _create_generator(self) -> AsyncGenerator: + async def _create_generator(self) -> AsyncGenerator[Union[MemGPTMessage, LegacyMemGPTMessage, MessageStreamStatus], None]: """An asynchronous generator that yields chunks as they become available.""" while self._active: try: @@ -345,11 +355,14 @@ def get_generator(self) -> AsyncGenerator: raise StopIteration("The stream has not been started or has been ended.") return self._create_generator() - def _push_to_buffer(self, item: Union[dict, MessageStreamStatus]): + def _push_to_buffer(self, item: Union[MemGPTMessage, LegacyMemGPTMessage, MessageStreamStatus]): """Add an item to the deque""" print("zzz", item) assert self._active, "Generator is inactive" - assert isinstance(item, dict) or isinstance(item, MessageStreamStatus), f"Wrong type: {type(item)}" + # assert isinstance(item, dict) or isinstance(item, MessageStreamStatus), f"Wrong type: {type(item)}" + assert ( + isinstance(item, MemGPTMessage) or isinstance(item, LegacyMemGPTMessage) or isinstance(item, MessageStreamStatus) + ), f"Wrong type: {type(item)}" self._chunks.append(item) self._event.set() # Signal that new data is available @@ -542,11 +555,16 @@ def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None): if not self.streaming_mode: # create a fake "chunk" of a stream - processed_chunk = { - "internal_monologue": msg, - "date": msg_obj.created_at.isoformat() if msg_obj is not None else get_utc_time().isoformat(), - "id": str(msg_obj.id) if msg_obj is not None else None, - } + # processed_chunk = { + # "internal_monologue": msg, + # "date": msg_obj.created_at.isoformat() if msg_obj is not None else get_utc_time().isoformat(), + # "id": str(msg_obj.id) if msg_obj is not None else None, + # } + processed_chunk = InternalMonologue( + id=msg_obj.id, + date=msg_obj.created_at, + internal_monologue=msg, + ) self._push_to_buffer(processed_chunk) @@ -593,36 +611,52 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None): func_args = json.loads(function_call.function.arguments) except: func_args = function_call.function.arguments - processed_chunk = { - "function_call": f"{function_call.function.name}({func_args})", - "id": str(msg_obj.id), - "date": msg_obj.created_at.isoformat(), - } + # processed_chunk = { + # "function_call": f"{function_call.function.name}({func_args})", + # "id": str(msg_obj.id), + # "date": msg_obj.created_at.isoformat(), + # } + processed_chunk = LegacyFunctionCallMessage( + id=msg_obj.id, + date=msg_obj.created_at, + function_call=f"{function_call.function.name}({func_args})", + ) self._push_to_buffer(processed_chunk) if function_call.function.name == 
"send_message": try: - processed_chunk = { - "assistant_message": func_args["message"], - "id": str(msg_obj.id), - "date": msg_obj.created_at.isoformat(), - } + # processed_chunk = { + # "assistant_message": func_args["message"], + # "id": str(msg_obj.id), + # "date": msg_obj.created_at.isoformat(), + # } + processed_chunk = AssistantMessage( + id=msg_obj.id, + date=msg_obj.created_at, + assistant_message=func_args["message"], + ) self._push_to_buffer(processed_chunk) except Exception as e: print(f"Failed to parse function message: {e}") else: - processed_chunk = { - "function_call": { - # "id": function_call.id, - # "name": function_call.function["name"], - "name": function_call.function.name, - "arguments": function_call.function.arguments, - }, - "id": str(msg_obj.id), - "date": msg_obj.created_at.isoformat(), - } + processed_chunk = FunctionCallMessage( + id=msg_obj.id, + date=msg_obj.created_at, + function_call=FunctionCall( + name=function_call.function.name, + arguments=function_call.function.arguments, + ), + ) + # processed_chunk = { + # "function_call": { + # "name": function_call.function.name, + # "arguments": function_call.function.arguments, + # }, + # "id": str(msg_obj.id), + # "date": msg_obj.created_at.isoformat(), + # } self._push_to_buffer(processed_chunk) return @@ -638,20 +672,33 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None): elif msg.startswith("Success: "): msg = msg.replace("Success: ", "") - new_message = {"function_return": msg, "status": "success"} + # new_message = {"function_return": msg, "status": "success"} + new_message = FunctionReturn( + id=msg_obj.id, + date=msg_obj.created_at, + function_return=msg, + status="success", + ) elif msg.startswith("Error: "): msg = msg.replace("Error: ", "") - new_message = {"function_return": msg, "status": "error"} + # new_message = {"function_return": msg, "status": "error"} + new_message = FunctionReturn( + id=msg_obj.id, + date=msg_obj.created_at, + function_return=msg, + status="error", + ) else: # NOTE: generic, should not happen + raise ValueError(msg) new_message = {"function_message": msg} # add extra metadata - if msg_obj is not None: - new_message["id"] = str(msg_obj.id) - assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at - new_message["date"] = msg_obj.created_at.isoformat() + # if msg_obj is not None: + # new_message["id"] = str(msg_obj.id) + # assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at + # new_message["date"] = msg_obj.created_at.isoformat() self._push_to_buffer(new_message) diff --git a/memgpt/server/rest_api/utils.py b/memgpt/server/rest_api/utils.py index 45a4ab3cec..0ce98f20d5 100644 --- a/memgpt/server/rest_api/utils.py +++ b/memgpt/server/rest_api/utils.py @@ -1,39 +1,24 @@ -import asyncio import json import traceback -from typing import AsyncGenerator, Generator, Union +from enum import Enum +from typing import AsyncGenerator, Union + +from pydantic import BaseModel from memgpt.constants import JSON_ENSURE_ASCII +SSE_PREFIX = "data: " +SSE_SUFFIX = "\n\n" SSE_FINISH_MSG = "[DONE]" # mimic openai SSE_ARTIFICIAL_DELAY = 0.1 def sse_formatter(data: Union[dict, str]) -> str: """Prefix with 'data: ', and always include double newlines""" + input(f"sse_formatter:\n{data}") assert type(data) in [dict, str], f"Expected type dict or str, got type {type(data)}" data_str = json.dumps(data, ensure_ascii=JSON_ENSURE_ASCII) if isinstance(data, dict) else data - return f"data: {data_str}\n\n" - - -async def sse_generator(generator: Generator[dict, None, 
None]) -> Generator[str, None, None]: - """Generator that returns 'data: dict' formatted items, e.g.: - - data: {"id":"chatcmpl-9E0PdSZ2IBzAGlQ3SEWHJ5YwzucSP","object":"chat.completion.chunk","created":1713125205,"model":"gpt-4-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"}"}}]},"logprobs":null,"finish_reason":null}]} - - data: {"id":"chatcmpl-9E0PdSZ2IBzAGlQ3SEWHJ5YwzucSP","object":"chat.completion.chunk","created":1713125205,"model":"gpt-4-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]} - - data: [DONE] - - """ - try: - for msg in generator: - yield sse_formatter(msg) - if SSE_ARTIFICIAL_DELAY: - await asyncio.sleep(SSE_ARTIFICIAL_DELAY) # Sleep to prevent a tight loop, adjust time as needed - except Exception as e: - yield sse_formatter({"error": f"{str(e)}"}) - yield sse_formatter(SSE_FINISH_MSG) # Signal that the stream is complete + return f"{SSE_PREFIX}{data_str}{SSE_SUFFIX}" async def sse_async_generator(generator: AsyncGenerator, finish_message=True): @@ -49,12 +34,20 @@ async def sse_async_generator(generator: AsyncGenerator, finish_message=True): try: async for chunk in generator: # yield f"data: {json.dumps(chunk)}\n\n" + if isinstance(chunk, BaseModel): + chunk = chunk.model_dump() + elif isinstance(chunk, Enum): + chunk = str(chunk.value) + elif not isinstance(chunk, dict): + chunk = str(chunk) yield sse_formatter(chunk) + except Exception as e: print("stream decoder hit error:", e) print(traceback.print_stack()) yield sse_formatter({"error": "stream decoder encountered an error"}) + finally: - # yield "data: [DONE]\n\n" if finish_message: - yield sse_formatter(SSE_FINISH_MSG) # Signal that the stream is complete + # Signal that the stream is complete + yield sse_formatter(SSE_FINISH_MSG) diff --git a/memgpt/server/server.py b/memgpt/server/server.py index bbd3e03241..7b6be6491d 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -1814,3 +1814,10 @@ def add_default_blocks(self, user_id: str): text = open(human_file, "r", encoding="utf-8").read() name = os.path.basename(human_file).replace(".txt", "") self.create_human(CreateHuman(name=name, value=text, template=True), user_id=user_id, update=True) + + def get_agent_message(self, agent_id: str, message_id: str) -> Message: + """Get a single message from the agent's memory""" + # Get the agent object (loaded in memory) + memgpt_agent = self._get_or_load_agent(agent_id=agent_id) + message = memgpt_agent.persistence_manager.recall_memory.storage.get(id=message_id) + return message diff --git a/memgpt/utils.py b/memgpt/utils.py index 8279d9b5d1..433b6b6204 100644 --- a/memgpt/utils.py +++ b/memgpt/utils.py @@ -468,6 +468,17 @@ ] +def deduplicate(target_list: list) -> list: + seen = set() + dedup_list = [] + for i in target_list: + if i not in seen: + seen.add(i) + dedup_list.append(i) + + return dedup_list + + def smart_urljoin(base_url: str, relative_url: str) -> str: """urljoin is stupid and wants a trailing / at the end of the endpoint address, or it will chop the suffix off""" if not base_url.endswith("/"): From 768d7ee8f73467f97d7daedb442cb98391b31b45 Mon Sep 17 00:00:00 2001 From: cpacker Date: Sun, 11 Aug 2024 20:22:49 -0700 Subject: [PATCH 4/5] chore: strip comments --- memgpt/schemas/message.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/memgpt/schemas/message.py b/memgpt/schemas/message.py index d0f9b1ad7b..cc42250bbb 100644 --- 
a/memgpt/schemas/message.py +++ b/memgpt/schemas/message.py @@ -49,19 +49,6 @@ class MessageCreate(BaseMessage): name: Optional[str] = Field(None, description="The name of the participant.") -""" - -Assistant messages: --> inner thoughts --> tool calls - -User messages: - - - -""" - - class Message(BaseMessage): """ Representation of a message sent. From 113cfb89d504f3d5b82e57608f0bc328faa53944 Mon Sep 17 00:00:00 2001 From: cpacker Date: Sun, 11 Aug 2024 20:27:10 -0700 Subject: [PATCH 5/5] chore: stray prints --- memgpt/server/rest_api/interface.py | 5 ----- memgpt/server/rest_api/utils.py | 1 - 2 files changed, 6 deletions(-) diff --git a/memgpt/server/rest_api/interface.py b/memgpt/server/rest_api/interface.py index 26c253a8c2..38ceb0553e 100644 --- a/memgpt/server/rest_api/interface.py +++ b/memgpt/server/rest_api/interface.py @@ -357,7 +357,6 @@ def get_generator(self) -> AsyncGenerator: def _push_to_buffer(self, item: Union[MemGPTMessage, LegacyMemGPTMessage, MessageStreamStatus]): """Add an item to the deque""" - print("zzz", item) assert self._active, "Generator is inactive" # assert isinstance(item, dict) or isinstance(item, MessageStreamStatus), f"Wrong type: {type(item)}" assert ( @@ -396,8 +395,6 @@ def stream_end(self): def step_complete(self): """Signal from the agent that one 'step' finished (step = LLM response + tool execution)""" - print("zzz step_complete") - if not self.multi_step: # end the stream self._active = False @@ -408,8 +405,6 @@ def step_complete(self): def step_yield(self): """If multi_step, this is the true 'stream_end' function.""" - print("zzz step_yield") - # if self.multi_step: # end the stream self._active = False diff --git a/memgpt/server/rest_api/utils.py b/memgpt/server/rest_api/utils.py index 0ce98f20d5..4f09f6fbd8 100644 --- a/memgpt/server/rest_api/utils.py +++ b/memgpt/server/rest_api/utils.py @@ -15,7 +15,6 @@ def sse_formatter(data: Union[dict, str]) -> str: """Prefix with 'data: ', and always include double newlines""" - input(f"sse_formatter:\n{data}") assert type(data) in [dict, str], f"Expected type dict or str, got type {type(data)}" data_str = json.dumps(data, ensure_ascii=JSON_ENSURE_ASCII) if isinstance(data, dict) else data return f"{SSE_PREFIX}{data_str}{SSE_SUFFIX}"
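With the `sse_formatter` changes above, every event on the wire is a single `data: ...` line followed by a blank line, status markers are sent as bare strings, and the stream ends with a literal `data: [DONE]`. A rough client-side sketch of consuming that stream is shown below; the URL and request payload are placeholders for illustration, not the actual MemGPT REST route or request schema:

import json
import requests

# Placeholder endpoint and payload shape -- adjust to the actual MemGPT REST API.
AGENT_MESSAGES_URL = "http://localhost:8283/api/agents/AGENT_ID/messages"


def stream_agent_reply(text: str):
    """Yield parsed message dicts from the SSE response, stopping at the [DONE] sentinel."""
    payload = {"messages": [{"role": "user", "text": text}], "stream_steps": True}  # assumed shape
    with requests.post(AGENT_MESSAGES_URL, json=payload, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data: "):
                continue  # skip blank separator lines between SSE events
            data = line[len("data: "):]
            if data == "[DONE]":
                break
            if data in ("[DONE_GEN]", "[DONE_STEP]"):
                continue  # MessageStreamStatus markers, not message payloads
            yield json.loads(data)


for chunk in stream_agent_reply("hello"):
    print(chunk)  # e.g. {"internal_monologue": ...} or {"assistant_message": ...}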