Skip to content

Commit

Permalink
feat: move agent_id from query param to path variable and remove unus… (
Browse files Browse the repository at this point in the history
#1094)

Co-authored-by: cpacker <[email protected]>
  • Loading branch information
goetzrobin and cpacker authored Mar 5, 2024
1 parent 3fd568e commit 575d825
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 27 deletions.
5 changes: 3 additions & 2 deletions memgpt/local_llm/chat_completion_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,14 @@ def get_chat_completion(
# if hasattr(llm_wrapper, "supports_first_message"):
if hasattr(llm_wrapper, "supports_first_message") and llm_wrapper.supports_first_message:
prompt = llm_wrapper.chat_completion_to_prompt(
messages, functions if functions else [], first_message=first_message, function_documentation=documentation
messages=messages, functions=functions, first_message=first_message, function_documentation=documentation
)
else:
prompt = llm_wrapper.chat_completion_to_prompt(messages, functions if functions else [], function_documentation=documentation)
prompt = llm_wrapper.chat_completion_to_prompt(messages=messages, functions=functions, function_documentation=documentation)

printd(prompt)
except Exception as e:
print(e)
raise LocalLLMError(
f"Failed to convert ChatCompletion messages into prompt string with wrapper {str(llm_wrapper)} - error: {str(e)}"
)
Expand Down
18 changes: 8 additions & 10 deletions memgpt/server/rest_api/agents/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import re
import uuid
from functools import partial
from typing import List, Optional

from fastapi import APIRouter, Body, Depends, Query, HTTPException, status
from fastapi import APIRouter, Body, Depends, HTTPException, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from typing import List, Optional

from memgpt.models.pydantic_models import AgentStateModel, LLMConfigModel, EmbeddingConfigModel
from memgpt.server.rest_api.auth_token import get_current_user
Expand All @@ -20,7 +20,6 @@ class GetAgentRequest(BaseModel):


class AgentRenameRequest(BaseModel):
agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.")
agent_name: str = Field(..., description="New name for the agent.")


Expand Down Expand Up @@ -51,9 +50,9 @@ def validate_agent_name(name: str) -> str:
def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)

@router.get("/agents", tags=["agents"], response_model=GetAgentResponse)
@router.get("/agents/{agent_id}", tags=["agents"], response_model=GetAgentResponse)
def get_agent_config(
agent_id: str = Query(..., description="Unique identifier of the agent whose config is requested."),
agent_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Expand Down Expand Up @@ -90,8 +89,9 @@ def get_agent_config(
sources=attached_sources,
)

@router.patch("/agents/rename", tags=["agents"], response_model=GetAgentResponse)
@router.patch("/agents/{agent_id}/rename", tags=["agents"], response_model=GetAgentResponse)
def update_agent_name(
agent_id: uuid.UUID,
request: AgentRenameRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
Expand All @@ -100,8 +100,6 @@ def update_agent_name(
This changes the name of the agent in the database but does NOT edit the agent's persona.
"""
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None

valid_name = validate_agent_name(request.agent_name)

interface.clear()
Expand All @@ -113,9 +111,9 @@ def update_agent_name(
raise HTTPException(status_code=500, detail=f"{e}")
return GetAgentResponse(agent_state=agent_state)

@router.delete("/agents", tags=["agents"])
@router.delete("/agents/{agent_id}", tags=["agents"])
def delete_agent(
agent_id: str = Query(..., description="Unique identifier of the agent to be deleted."),
agent_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Expand Down
18 changes: 5 additions & 13 deletions memgpt/server/rest_api/agents/memory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import uuid
from functools import partial

from fastapi import APIRouter, Depends, Body, Query
from fastapi import APIRouter, Depends, Body
from pydantic import BaseModel, Field

from memgpt.server.rest_api.auth_token import get_current_user
Expand All @@ -16,10 +16,6 @@ class CoreMemory(BaseModel):
persona: str | None = Field(None, description="Persona element of the core memory.")


class GetAgentMemoryRequest(BaseModel):
agent_id: str = Field(..., description="The unique identifier of the agent.")


class GetAgentMemoryResponse(BaseModel):
core_memory: CoreMemory = Field(..., description="The state of the agent's core memory.")
recall_memory: int = Field(..., description="Size of the agent's recall memory.")
Expand All @@ -41,27 +37,23 @@ class UpdateAgentMemoryResponse(BaseModel):
def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)

@router.get("/agents/memory", tags=["agents"], response_model=GetAgentMemoryResponse)
@router.get("/agents/{agent_id}/memory", tags=["agents"], response_model=GetAgentMemoryResponse)
def get_agent_memory(
agent_id: str = Query(..., description="The unique identifier of the agent."),
agent_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Retrieve the memory state of a specific agent.
This endpoint fetches the current memory state of the agent identified by the user ID and agent ID.
"""
# Validate with the Pydantic model (optional)
request = GetAgentMemoryRequest(agent_id=agent_id)

agent_id = uuid.UUID(request.agent_id) if request.agent_id else None

interface.clear()
memory = server.get_agent_memory(user_id=user_id, agent_id=agent_id)
return GetAgentMemoryResponse(**memory)

@router.post("/agents/memory", tags=["agents"], response_model=UpdateAgentMemoryResponse)
@router.post("/agents/{agent_id}/memory", tags=["agents"], response_model=UpdateAgentMemoryResponse)
def update_agent_memory(
agent_id: uuid.UUID,
request: UpdateAgentMemoryRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
Expand Down
6 changes: 5 additions & 1 deletion tests/test_migrate.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import os
from memgpt.migrate import migrate_all_agents, migrate_all_sources
from memgpt.config import MemGPTConfig
from .utils import wipe_config
from .utils import wipe_config, create_config
from memgpt.server.server import SyncServer
import shutil
import uuid


def test_migrate_0211():
wipe_config()
if os.getenv("OPENAI_API_KEY"):
create_config("openai")
else:
create_config("memgpt_hosted")

data_dir = "tests/data/memgpt-0.2.11"
tmp_dir = f"tmp_{str(uuid.uuid4())}"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
def create_test_agent():
"""Create a test agent that we can call functions on"""
wipe_config()
global client
if os.getenv("OPENAI_API_KEY"):
create_config("openai")
else:
create_config("memgpt_hosted")

global client
client = create_client()
agent_state = client.create_agent(
name=test_agent_name,
Expand Down

0 comments on commit 575d825

Please sign in to comment.