Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: patch messages route + unify all the api/agents API routes to use {agent_id} via path parameter #1129

Merged
merged 5 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions memgpt/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ def __init__(
assert tool_call_id is None
self.tool_call_id = tool_call_id

def to_json(self):
json_message = vars(self)
if json_message["tool_calls"] is not None:
json_message["tool_calls"] = [vars(tc) for tc in json_message["tool_calls"]]
return json_message

@staticmethod
def dict_to_message(
user_id: uuid.UUID,
Expand Down
11 changes: 7 additions & 4 deletions memgpt/server/rest_api/agents/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@


class CommandRequest(BaseModel):
agent_id: str = Field(..., description="Identifier of the agent on which the command will be executed.")
command: str = Field(..., description="The command to be executed by the agent.")


Expand All @@ -23,8 +22,12 @@ class CommandResponse(BaseModel):
def setup_agents_command_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)

@router.post("/agents/command", tags=["agents"], response_model=CommandResponse)
def run_command(request: CommandRequest = Body(...), user_id: uuid.UUID = Depends(get_current_user_with_server)):
@router.post("/agents/{agent_id}/command", tags=["agents"], response_model=CommandResponse)
def run_command(
agent_id: uuid.UUID,
request: CommandRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Execute a command on a specified agent.

Expand All @@ -34,7 +37,7 @@ def run_command(request: CommandRequest = Body(...), user_id: uuid.UUID = Depend
"""
interface.clear()
try:
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
response = server.run_command(user_id=user_id, agent_id=agent_id, command=request.command)
except HTTPException:
raise
Expand Down
22 changes: 8 additions & 14 deletions memgpt/server/rest_api/agents/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@
router = APIRouter()


class GetAgentRequest(BaseModel):
agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.")


class AgentRenameRequest(BaseModel):
agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.")
agent_name: str = Field(..., description="New name for the agent.")


Expand Down Expand Up @@ -51,19 +46,17 @@ def validate_agent_name(name: str) -> str:
def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)

@router.get("/agents/config", tags=["agents"], response_model=GetAgentResponse)
@router.get("/agents/{agent_id}/config", tags=["agents"], response_model=GetAgentResponse)
def get_agent_config(
agent_id: str = Query(..., description="Unique identifier of the agent whose config is requested."),
agent_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Retrieve the configuration for a specific agent.

This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
"""
request = GetAgentRequest(agent_id=agent_id)

agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
attached_sources = server.list_attached_sources(agent_id=agent_id)

interface.clear()
Expand All @@ -90,8 +83,9 @@ def get_agent_config(
sources=attached_sources,
)

@router.patch("/agents/rename", tags=["agents"], response_model=GetAgentResponse)
@router.patch("/agents/{agent_id}/rename", tags=["agents"], response_model=GetAgentResponse)
def update_agent_name(
agent_id: uuid.UUID,
request: AgentRenameRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
Expand All @@ -100,7 +94,7 @@ def update_agent_name(

This changes the name of the agent in the database but does NOT edit the agent's persona.
"""
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None

valid_name = validate_agent_name(request.agent_name)

Expand All @@ -115,13 +109,13 @@ def update_agent_name(

@router.delete("/agents/{agent_id}", tags=["agents"])
def delete_agent(
agent_id,
agent_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Delete an agent.
"""
agent_id = uuid.UUID(agent_id)
# agent_id = uuid.UUID(agent_id)

interface.clear()
try:
Expand Down
1 change: 0 additions & 1 deletion memgpt/server/rest_api/agents/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class GetAgentMemoryResponse(BaseModel):

# NOTE not subclassing CoreMemory since in the request both field are optional
class UpdateAgentMemoryRequest(BaseModel):
agent_id: str = Field(..., description="The unique identifier of the agent.")
human: str = Field(None, description="Human element of the core memory.")
persona: str = Field(None, description="Persona element of the core memory.")

Expand Down
13 changes: 6 additions & 7 deletions memgpt/server/rest_api/agents/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ class MessageRoleType(str, Enum):


class UserMessageRequest(BaseModel):
agent_id: str = Field(..., description="The unique identifier of the agent.")
message: str = Field(..., description="The message content to be processed by the agent.")
stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. Set to True for streaming.")
role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')")
Expand All @@ -35,7 +34,6 @@ class UserMessageResponse(BaseModel):


class GetAgentMessagesRequest(BaseModel):
agent_id: str = Field(..., description="The unique identifier of the agent.")
start: int = Field(..., description="Message index to start on (reverse chronological).")
count: int = Field(..., description="How many messages to retrieve.")

Expand All @@ -47,9 +45,9 @@ class GetAgentMessagesResponse(BaseModel):
def setup_agents_message_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)

@router.get("/agents/message", tags=["agents"], response_model=GetAgentMessagesResponse)
@router.get("/agents/{agent_id}/messages", tags=["agents"], response_model=GetAgentMessagesResponse)
def get_agent_messages(
agent_id: str = Query(..., description="The unique identifier of the agent."),
agent_id: uuid.UUID,
start: int = Query(..., description="Message index to start on (reverse chronological)."),
count: int = Query(..., description="How many messages to retrieve."),
user_id: uuid.UUID = Depends(get_current_user_with_server),
Expand All @@ -59,14 +57,15 @@ def get_agent_messages(
"""
# Validate with the Pydantic model (optional)
request = GetAgentMessagesRequest(agent_id=agent_id, start=start, count=count)
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None

interface.clear()
messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=request.start, count=request.count)
return GetAgentMessagesResponse(messages=messages)

@router.post("/agents/message", tags=["agents"], response_model=UserMessageResponse)
@router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=UserMessageResponse)
async def send_message(
agent_id: uuid.UUID,
request: UserMessageRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
Expand All @@ -76,7 +75,7 @@ async def send_message(
This endpoint accepts a message from a user and processes it through the agent.
It can optionally stream the response if 'stream' is set to True.
"""
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None

if request.role == "user" or request.role is None:
message_func = server.user_message
Expand Down
2 changes: 1 addition & 1 deletion memgpt/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,7 @@ def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int
messages = sorted(page, key=lambda x: x.created_at, reverse=True)

# convert to json
json_messages = [vars(record) for record in messages]
json_messages = [record.to_json() for record in messages]
return json_messages

def get_agent_archival(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
Expand Down
Loading