diff --git a/memgpt/data_types.py b/memgpt/data_types.py index d87f462c36..7a2fd59d90 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -124,6 +124,12 @@ def __init__( assert tool_call_id is None self.tool_call_id = tool_call_id + def to_json(self): + json_message = vars(self) + if json_message["tool_calls"] is not None: + json_message["tool_calls"] = [vars(tc) for tc in json_message["tool_calls"]] + return json_message + @staticmethod def dict_to_message( user_id: uuid.UUID, diff --git a/memgpt/server/rest_api/agents/command.py b/memgpt/server/rest_api/agents/command.py index 615c5ea329..e522653a07 100644 --- a/memgpt/server/rest_api/agents/command.py +++ b/memgpt/server/rest_api/agents/command.py @@ -12,7 +12,6 @@ class CommandRequest(BaseModel): - agent_id: str = Field(..., description="Identifier of the agent on which the command will be executed.") command: str = Field(..., description="The command to be executed by the agent.") @@ -23,8 +22,12 @@ class CommandResponse(BaseModel): def setup_agents_command_router(server: SyncServer, interface: QueuingInterface, password: str): get_current_user_with_server = partial(partial(get_current_user, server), password) - @router.post("/agents/command", tags=["agents"], response_model=CommandResponse) - def run_command(request: CommandRequest = Body(...), user_id: uuid.UUID = Depends(get_current_user_with_server)): + @router.post("/agents/{agent_id}/command", tags=["agents"], response_model=CommandResponse) + def run_command( + agent_id: uuid.UUID, + request: CommandRequest = Body(...), + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): """ Execute a command on a specified agent. @@ -34,7 +37,7 @@ def run_command(request: CommandRequest = Body(...), user_id: uuid.UUID = Depend """ interface.clear() try: - agent_id = uuid.UUID(request.agent_id) if request.agent_id else None + # agent_id = uuid.UUID(request.agent_id) if request.agent_id else None response = server.run_command(user_id=user_id, agent_id=agent_id, command=request.command) except HTTPException: raise diff --git a/memgpt/server/rest_api/agents/config.py b/memgpt/server/rest_api/agents/config.py index fa1bc3c01a..2f2e3ce051 100644 --- a/memgpt/server/rest_api/agents/config.py +++ b/memgpt/server/rest_api/agents/config.py @@ -15,12 +15,7 @@ router = APIRouter() -class GetAgentRequest(BaseModel): - agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.") - - class AgentRenameRequest(BaseModel): - agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.") agent_name: str = Field(..., description="New name for the agent.") @@ -51,9 +46,9 @@ def validate_agent_name(name: str) -> str: def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, password: str): get_current_user_with_server = partial(partial(get_current_user, server), password) - @router.get("/agents/config", tags=["agents"], response_model=GetAgentResponse) + @router.get("/agents/{agent_id}/config", tags=["agents"], response_model=GetAgentResponse) def get_agent_config( - agent_id: str = Query(..., description="Unique identifier of the agent whose config is requested."), + agent_id: uuid.UUID, user_id: uuid.UUID = Depends(get_current_user_with_server), ): """ @@ -61,9 +56,7 @@ def get_agent_config( This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs. """ - request = GetAgentRequest(agent_id=agent_id) - - agent_id = uuid.UUID(request.agent_id) if request.agent_id else None + # agent_id = uuid.UUID(request.agent_id) if request.agent_id else None attached_sources = server.list_attached_sources(agent_id=agent_id) interface.clear() @@ -90,8 +83,9 @@ def get_agent_config( sources=attached_sources, ) - @router.patch("/agents/rename", tags=["agents"], response_model=GetAgentResponse) + @router.patch("/agents/{agent_id}/rename", tags=["agents"], response_model=GetAgentResponse) def update_agent_name( + agent_id: uuid.UUID, request: AgentRenameRequest = Body(...), user_id: uuid.UUID = Depends(get_current_user_with_server), ): @@ -100,7 +94,7 @@ def update_agent_name( This changes the name of the agent in the database but does NOT edit the agent's persona. """ - agent_id = uuid.UUID(request.agent_id) if request.agent_id else None + # agent_id = uuid.UUID(request.agent_id) if request.agent_id else None valid_name = validate_agent_name(request.agent_name) @@ -115,13 +109,13 @@ def update_agent_name( @router.delete("/agents/{agent_id}", tags=["agents"]) def delete_agent( - agent_id, + agent_id: uuid.UUID, user_id: uuid.UUID = Depends(get_current_user_with_server), ): """ Delete an agent. """ - agent_id = uuid.UUID(agent_id) + # agent_id = uuid.UUID(agent_id) interface.clear() try: diff --git a/memgpt/server/rest_api/agents/memory.py b/memgpt/server/rest_api/agents/memory.py index 022718f279..adb08f1f19 100644 --- a/memgpt/server/rest_api/agents/memory.py +++ b/memgpt/server/rest_api/agents/memory.py @@ -26,7 +26,6 @@ class GetAgentMemoryResponse(BaseModel): # NOTE not subclassing CoreMemory since in the request both field are optional class UpdateAgentMemoryRequest(BaseModel): - agent_id: str = Field(..., description="The unique identifier of the agent.") human: str = Field(None, description="Human element of the core memory.") persona: str = Field(None, description="Persona element of the core memory.") diff --git a/memgpt/server/rest_api/agents/message.py b/memgpt/server/rest_api/agents/message.py index 75941976c8..8b1f3a15f0 100644 --- a/memgpt/server/rest_api/agents/message.py +++ b/memgpt/server/rest_api/agents/message.py @@ -24,7 +24,6 @@ class MessageRoleType(str, Enum): class UserMessageRequest(BaseModel): - agent_id: str = Field(..., description="The unique identifier of the agent.") message: str = Field(..., description="The message content to be processed by the agent.") stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. Set to True for streaming.") role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')") @@ -35,7 +34,6 @@ class UserMessageResponse(BaseModel): class GetAgentMessagesRequest(BaseModel): - agent_id: str = Field(..., description="The unique identifier of the agent.") start: int = Field(..., description="Message index to start on (reverse chronological).") count: int = Field(..., description="How many messages to retrieve.") @@ -47,9 +45,9 @@ class GetAgentMessagesResponse(BaseModel): def setup_agents_message_router(server: SyncServer, interface: QueuingInterface, password: str): get_current_user_with_server = partial(partial(get_current_user, server), password) - @router.get("/agents/message", tags=["agents"], response_model=GetAgentMessagesResponse) + @router.get("/agents/{agent_id}/messages", tags=["agents"], response_model=GetAgentMessagesResponse) def get_agent_messages( - agent_id: str = Query(..., description="The unique identifier of the agent."), + agent_id: uuid.UUID, start: int = Query(..., description="Message index to start on (reverse chronological)."), count: int = Query(..., description="How many messages to retrieve."), user_id: uuid.UUID = Depends(get_current_user_with_server), @@ -59,14 +57,15 @@ def get_agent_messages( """ # Validate with the Pydantic model (optional) request = GetAgentMessagesRequest(agent_id=agent_id, start=start, count=count) - agent_id = uuid.UUID(request.agent_id) if request.agent_id else None + # agent_id = uuid.UUID(request.agent_id) if request.agent_id else None interface.clear() messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=request.start, count=request.count) return GetAgentMessagesResponse(messages=messages) - @router.post("/agents/message", tags=["agents"], response_model=UserMessageResponse) + @router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=UserMessageResponse) async def send_message( + agent_id: uuid.UUID, request: UserMessageRequest = Body(...), user_id: uuid.UUID = Depends(get_current_user_with_server), ): @@ -76,7 +75,7 @@ async def send_message( This endpoint accepts a message from a user and processes it through the agent. It can optionally stream the response if 'stream' is set to True. """ - agent_id = uuid.UUID(request.agent_id) if request.agent_id else None + # agent_id = uuid.UUID(request.agent_id) if request.agent_id else None if request.role == "user" or request.role is None: message_func = server.user_message diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 723be93ed0..c8b1f51401 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -827,7 +827,7 @@ def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int messages = sorted(page, key=lambda x: x.created_at, reverse=True) # convert to json - json_messages = [vars(record) for record in messages] + json_messages = [record.to_json() for record in messages] return json_messages def get_agent_archival(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list: