From cefbf8cd2eca0ec15c596bb4a748844c21edce8a Mon Sep 17 00:00:00 2001 From: cpacker Date: Mon, 11 Mar 2024 13:38:01 -0700 Subject: [PATCH 1/5] unify all the api/agents API routes to use {agent_id} via path parameter --- memgpt/server/rest_api/agents/command.py | 10 +++++++--- memgpt/server/rest_api/agents/config.py | 15 ++++++++------- memgpt/server/rest_api/agents/message.py | 11 ++++++----- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/memgpt/server/rest_api/agents/command.py b/memgpt/server/rest_api/agents/command.py index 615c5ea329..d6520ccb63 100644 --- a/memgpt/server/rest_api/agents/command.py +++ b/memgpt/server/rest_api/agents/command.py @@ -23,8 +23,12 @@ class CommandResponse(BaseModel): def setup_agents_command_router(server: SyncServer, interface: QueuingInterface, password: str): get_current_user_with_server = partial(partial(get_current_user, server), password) - @router.post("/agents/command", tags=["agents"], response_model=CommandResponse) - def run_command(request: CommandRequest = Body(...), user_id: uuid.UUID = Depends(get_current_user_with_server)): + @router.post("/agents/{agent_id}/command", tags=["agents"], response_model=CommandResponse) + def run_command( + agent_id: uuid.UUID, + request: CommandRequest = Body(...), + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): """ Execute a command on a specified agent. @@ -34,7 +38,7 @@ def run_command(request: CommandRequest = Body(...), user_id: uuid.UUID = Depend """ interface.clear() try: - agent_id = uuid.UUID(request.agent_id) if request.agent_id else None + # agent_id = uuid.UUID(request.agent_id) if request.agent_id else None response = server.run_command(user_id=user_id, agent_id=agent_id, command=request.command) except HTTPException: raise diff --git a/memgpt/server/rest_api/agents/config.py b/memgpt/server/rest_api/agents/config.py index fa1bc3c01a..466169d5e7 100644 --- a/memgpt/server/rest_api/agents/config.py +++ b/memgpt/server/rest_api/agents/config.py @@ -51,9 +51,9 @@ def validate_agent_name(name: str) -> str: def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, password: str): get_current_user_with_server = partial(partial(get_current_user, server), password) - @router.get("/agents/config", tags=["agents"], response_model=GetAgentResponse) + @router.get("/agents/{agent_id}/config", tags=["agents"], response_model=GetAgentResponse) def get_agent_config( - agent_id: str = Query(..., description="Unique identifier of the agent whose config is requested."), + agent_id: uuid.UUID, user_id: uuid.UUID = Depends(get_current_user_with_server), ): """ @@ -63,7 +63,7 @@ def get_agent_config( """ request = GetAgentRequest(agent_id=agent_id) - agent_id = uuid.UUID(request.agent_id) if request.agent_id else None + # agent_id = uuid.UUID(request.agent_id) if request.agent_id else None attached_sources = server.list_attached_sources(agent_id=agent_id) interface.clear() @@ -90,8 +90,9 @@ def get_agent_config( sources=attached_sources, ) - @router.patch("/agents/rename", tags=["agents"], response_model=GetAgentResponse) + @router.patch("/agents/{agent_id}/rename", tags=["agents"], response_model=GetAgentResponse) def update_agent_name( + agent_id: uuid.UUID, request: AgentRenameRequest = Body(...), user_id: uuid.UUID = Depends(get_current_user_with_server), ): @@ -100,7 +101,7 @@ def update_agent_name( This changes the name of the agent in the database but does NOT edit the agent's persona. """ - agent_id = uuid.UUID(request.agent_id) if request.agent_id else None + # agent_id = uuid.UUID(request.agent_id) if request.agent_id else None valid_name = validate_agent_name(request.agent_name) @@ -115,13 +116,13 @@ def update_agent_name( @router.delete("/agents/{agent_id}", tags=["agents"]) def delete_agent( - agent_id, + agent_id: uuid.UUID, user_id: uuid.UUID = Depends(get_current_user_with_server), ): """ Delete an agent. """ - agent_id = uuid.UUID(agent_id) + # agent_id = uuid.UUID(agent_id) interface.clear() try: diff --git a/memgpt/server/rest_api/agents/message.py b/memgpt/server/rest_api/agents/message.py index 75941976c8..d058247ba3 100644 --- a/memgpt/server/rest_api/agents/message.py +++ b/memgpt/server/rest_api/agents/message.py @@ -47,9 +47,9 @@ class GetAgentMessagesResponse(BaseModel): def setup_agents_message_router(server: SyncServer, interface: QueuingInterface, password: str): get_current_user_with_server = partial(partial(get_current_user, server), password) - @router.get("/agents/message", tags=["agents"], response_model=GetAgentMessagesResponse) + @router.get("/agents/{agent_id}/message", tags=["agents"], response_model=GetAgentMessagesResponse) def get_agent_messages( - agent_id: str = Query(..., description="The unique identifier of the agent."), + agent_id: uuid.UUID, start: int = Query(..., description="Message index to start on (reverse chronological)."), count: int = Query(..., description="How many messages to retrieve."), user_id: uuid.UUID = Depends(get_current_user_with_server), @@ -59,14 +59,15 @@ def get_agent_messages( """ # Validate with the Pydantic model (optional) request = GetAgentMessagesRequest(agent_id=agent_id, start=start, count=count) - agent_id = uuid.UUID(request.agent_id) if request.agent_id else None + # agent_id = uuid.UUID(request.agent_id) if request.agent_id else None interface.clear() messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=request.start, count=request.count) return GetAgentMessagesResponse(messages=messages) - @router.post("/agents/message", tags=["agents"], response_model=UserMessageResponse) + @router.post("/agents/{agent_id}/message", tags=["agents"], response_model=UserMessageResponse) async def send_message( + agent_id: uuid.UUID, request: UserMessageRequest = Body(...), user_id: uuid.UUID = Depends(get_current_user_with_server), ): @@ -76,7 +77,7 @@ async def send_message( This endpoint accepts a message from a user and processes it through the agent. It can optionally stream the response if 'stream' is set to True. """ - agent_id = uuid.UUID(request.agent_id) if request.agent_id else None + # agent_id = uuid.UUID(request.agent_id) if request.agent_id else None if request.role == "user" or request.role is None: message_func = server.user_message From 8d64053f06697651507757d648002cee0d93a10b Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Mon, 11 Mar 2024 14:13:33 -0700 Subject: [PATCH 2/5] Update memgpt/server/rest_api/agents/message.py Co-authored-by: Robin Goetz <35136007+goetzrobin@users.noreply.github.com> --- memgpt/server/rest_api/agents/message.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/memgpt/server/rest_api/agents/message.py b/memgpt/server/rest_api/agents/message.py index d058247ba3..902e5088ca 100644 --- a/memgpt/server/rest_api/agents/message.py +++ b/memgpt/server/rest_api/agents/message.py @@ -65,7 +65,7 @@ def get_agent_messages( messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=request.start, count=request.count) return GetAgentMessagesResponse(messages=messages) - @router.post("/agents/{agent_id}/message", tags=["agents"], response_model=UserMessageResponse) + @router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=UserMessageResponse) async def send_message( agent_id: uuid.UUID, request: UserMessageRequest = Body(...), From a0ea57e8fc2f86a4f5ccf89a4afe23965e73b22b Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Mon, 11 Mar 2024 14:13:41 -0700 Subject: [PATCH 3/5] Update memgpt/server/rest_api/agents/message.py Co-authored-by: Robin Goetz <35136007+goetzrobin@users.noreply.github.com> --- memgpt/server/rest_api/agents/message.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/memgpt/server/rest_api/agents/message.py b/memgpt/server/rest_api/agents/message.py index 902e5088ca..384c098f84 100644 --- a/memgpt/server/rest_api/agents/message.py +++ b/memgpt/server/rest_api/agents/message.py @@ -47,7 +47,7 @@ class GetAgentMessagesResponse(BaseModel): def setup_agents_message_router(server: SyncServer, interface: QueuingInterface, password: str): get_current_user_with_server = partial(partial(get_current_user, server), password) - @router.get("/agents/{agent_id}/message", tags=["agents"], response_model=GetAgentMessagesResponse) + @router.get("/agents/{agent_id}/messages", tags=["agents"], response_model=GetAgentMessagesResponse) def get_agent_messages( agent_id: uuid.UUID, start: int = Query(..., description="Message index to start on (reverse chronological)."), From fd209c7666013a407f2e80596410510d9a2dc6fd Mon Sep 17 00:00:00 2001 From: cpacker Date: Mon, 11 Mar 2024 14:26:35 -0700 Subject: [PATCH 4/5] patched serialization bug in Message --- memgpt/data_types.py | 6 ++++++ memgpt/server/server.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/memgpt/data_types.py b/memgpt/data_types.py index d87f462c36..7a2fd59d90 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -124,6 +124,12 @@ def __init__( assert tool_call_id is None self.tool_call_id = tool_call_id + def to_json(self): + json_message = vars(self) + if json_message["tool_calls"] is not None: + json_message["tool_calls"] = [vars(tc) for tc in json_message["tool_calls"]] + return json_message + @staticmethod def dict_to_message( user_id: uuid.UUID, diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 723be93ed0..c8b1f51401 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -827,7 +827,7 @@ def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int messages = sorted(page, key=lambda x: x.created_at, reverse=True) # convert to json - json_messages = [vars(record) for record in messages] + json_messages = [record.to_json() for record in messages] return json_messages def get_agent_archival(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list: From f4e016acb5093f1f5a5c0f507b9a41c649b211e4 Mon Sep 17 00:00:00 2001 From: cpacker Date: Mon, 11 Mar 2024 14:26:43 -0700 Subject: [PATCH 5/5] drop dead agent_id query/body params --- memgpt/server/rest_api/agents/command.py | 1 - memgpt/server/rest_api/agents/config.py | 7 ------- memgpt/server/rest_api/agents/memory.py | 1 - memgpt/server/rest_api/agents/message.py | 2 -- 4 files changed, 11 deletions(-) diff --git a/memgpt/server/rest_api/agents/command.py b/memgpt/server/rest_api/agents/command.py index d6520ccb63..e522653a07 100644 --- a/memgpt/server/rest_api/agents/command.py +++ b/memgpt/server/rest_api/agents/command.py @@ -12,7 +12,6 @@ class CommandRequest(BaseModel): - agent_id: str = Field(..., description="Identifier of the agent on which the command will be executed.") command: str = Field(..., description="The command to be executed by the agent.") diff --git a/memgpt/server/rest_api/agents/config.py b/memgpt/server/rest_api/agents/config.py index 466169d5e7..2f2e3ce051 100644 --- a/memgpt/server/rest_api/agents/config.py +++ b/memgpt/server/rest_api/agents/config.py @@ -15,12 +15,7 @@ router = APIRouter() -class GetAgentRequest(BaseModel): - agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.") - - class AgentRenameRequest(BaseModel): - agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.") agent_name: str = Field(..., description="New name for the agent.") @@ -61,8 +56,6 @@ def get_agent_config( This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs. """ - request = GetAgentRequest(agent_id=agent_id) - # agent_id = uuid.UUID(request.agent_id) if request.agent_id else None attached_sources = server.list_attached_sources(agent_id=agent_id) diff --git a/memgpt/server/rest_api/agents/memory.py b/memgpt/server/rest_api/agents/memory.py index 022718f279..adb08f1f19 100644 --- a/memgpt/server/rest_api/agents/memory.py +++ b/memgpt/server/rest_api/agents/memory.py @@ -26,7 +26,6 @@ class GetAgentMemoryResponse(BaseModel): # NOTE not subclassing CoreMemory since in the request both field are optional class UpdateAgentMemoryRequest(BaseModel): - agent_id: str = Field(..., description="The unique identifier of the agent.") human: str = Field(None, description="Human element of the core memory.") persona: str = Field(None, description="Persona element of the core memory.") diff --git a/memgpt/server/rest_api/agents/message.py b/memgpt/server/rest_api/agents/message.py index 384c098f84..8b1f3a15f0 100644 --- a/memgpt/server/rest_api/agents/message.py +++ b/memgpt/server/rest_api/agents/message.py @@ -24,7 +24,6 @@ class MessageRoleType(str, Enum): class UserMessageRequest(BaseModel): - agent_id: str = Field(..., description="The unique identifier of the agent.") message: str = Field(..., description="The message content to be processed by the agent.") stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. Set to True for streaming.") role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')") @@ -35,7 +34,6 @@ class UserMessageResponse(BaseModel): class GetAgentMessagesRequest(BaseModel): - agent_id: str = Field(..., description="The unique identifier of the agent.") start: int = Field(..., description="Message index to start on (reverse chronological).") count: int = Field(..., description="How many messages to retrieve.")