Skip to content

Commit

Permalink
feat: add remaining Python client support for REST API routes + tests (
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders authored Mar 18, 2024
1 parent 192c08e commit b2c2d56
Show file tree
Hide file tree
Showing 9 changed files with 397 additions and 94 deletions.
302 changes: 256 additions & 46 deletions memgpt/client/client.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion memgpt/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def insert(self, memory_string, return_ids=False) -> Union[bool, List[uuid.UUID]
"""Embed and save memory string"""

if not isinstance(memory_string, str):
return TypeError("memory must be a string")
raise TypeError("memory must be a string")

try:
passages = []
Expand Down
5 changes: 5 additions & 0 deletions memgpt/models/pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ class AgentStateModel(BaseModel):
state: Optional[Dict] = Field(None, description="The state of the agent.")


class CoreMemory(BaseModel):
human: str = Field(..., description="Human element of the core memory.")
persona: str = Field(..., description="Persona element of the core memory.")


class HumanModel(SQLModel, table=True):
text: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human text.")
name: str = Field(..., description="The name of the human.")
Expand Down
38 changes: 32 additions & 6 deletions memgpt/server/rest_api/agents/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,17 @@ def get_agent_config(
This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
"""
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
attached_sources = server.list_attached_sources(agent_id=agent_id)

interface.clear()
if not server.ms.get_agent(user_id=user_id, agent_id=agent_id):
# agent does not exist
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")

agent_state = server.get_agent_config(user_id=user_id, agent_id=agent_id)
# return GetAgentResponse(agent_state=agent_state)
# get sources
attached_sources = server.list_attached_sources(agent_id=agent_id)

# configs
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))

Expand All @@ -73,8 +78,8 @@ def get_agent_config(
preset=agent_state.preset,
persona=agent_state.persona,
human=agent_state.human,
llm_config=agent_state.llm_config,
embedding_config=agent_state.embedding_config,
llm_config=llm_config,
embedding_config=embedding_config,
state=agent_state.state,
created_at=int(agent_state.created_at.timestamp()),
functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead
Expand All @@ -101,11 +106,32 @@ def update_agent_name(
interface.clear()
try:
agent_state = server.rename_agent(user_id=user_id, agent_id=agent_id, new_agent_name=valid_name)
# get sources
attached_sources = server.list_attached_sources(agent_id=agent_id)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return GetAgentResponse(agent_state=agent_state)
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))

return GetAgentResponse(
agent_state=AgentStateModel(
id=agent_state.id,
name=agent_state.name,
user_id=agent_state.user_id,
preset=agent_state.preset,
persona=agent_state.persona,
human=agent_state.human,
llm_config=llm_config,
embedding_config=embedding_config,
state=agent_state.state,
created_at=int(agent_state.created_at.timestamp()),
functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead
),
last_run_at=None, # TODO
sources=attached_sources,
)

@router.delete("/agents/{agent_id}", tags=["agents"])
def delete_agent(
Expand Down
4 changes: 1 addition & 3 deletions memgpt/server/rest_api/agents/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class GetAgentArchivalMemoryResponse(BaseModel):


class InsertAgentArchivalMemoryRequest(BaseModel):
content: str = Field(None, description="The memory contents to insert into archival memory.")
content: str = Field(..., description="The memory contents to insert into archival memory.")


class InsertAgentArchivalMemoryResponse(BaseModel):
Expand Down Expand Up @@ -87,8 +87,6 @@ def update_agent_memory(
This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID.
"""
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None

interface.clear()

new_memory_contents = {"persona": request.persona, "human": request.human}
Expand Down
5 changes: 3 additions & 2 deletions memgpt/server/rest_api/humans/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ async def list_humans(
return ListHumansResponse(humans=humans)

@router.post("/humans", tags=["humans"], response_model=HumanModel)
async def create_persona(
async def create_human(
request: CreateHumanRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
interface.clear()
new_human = HumanModel(text=request.text, name=request.name, user_id=user_id)
human_id = new_human.id
server.ms.add_human(new_human)
return new_human
return HumanModel(id=human_id, text=request.text, name=request.name, user_id=user_id)

return router
3 changes: 2 additions & 1 deletion memgpt/server/rest_api/personas/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ async def create_persona(
):
interface.clear()
new_persona = PersonaModel(text=request.text, name=request.name, user_id=user_id)
persona_id = new_persona.id
server.ms.add_persona(new_persona)
return new_persona
return PersonaModel(id=persona_id, text=request.text, name=request.name, user_id=user_id)

return router
2 changes: 1 addition & 1 deletion memgpt/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,7 +1150,7 @@ def update_agent_core_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_
"modified": modified,
}

def rename_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_agent_name: str) -> dict:
def rename_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_agent_name: str) -> AgentState:
"""Update the name of the agent in the database"""
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
Expand Down
130 changes: 96 additions & 34 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def run_server():
config.save()
credentials.save()

start_server(debug=False)
start_server(debug=True)


@pytest.fixture(scope="session", autouse=True)
Expand Down Expand Up @@ -125,8 +125,8 @@ def user_token():


# Fixture to create clients with different configurations
@pytest.fixture(params=[{"base_url": test_base_url}, {"base_url": None}], scope="module")
# @pytest.fixture(params=[{"base_url": test_base_url}], scope="module")
# @pytest.fixture(params=[{"base_url": test_base_url}, {"base_url": None}], scope="module")
@pytest.fixture(params=[{"base_url": test_base_url}], scope="module")
def client(request, user_token):
# use token or not
if request.param["base_url"]:
Expand All @@ -149,29 +149,99 @@ def agent(client):
client.delete_agent(agent_state.id)


# TODO: add back once REST API supports
# def test_create_preset(client):
#
# available_functions = load_all_function_sets(merge=True)
# functions_schema = [f_dict["json_schema"] for f_name, f_dict in available_functions.items()]
# preset = Preset(
# name=test_preset_name,
# user_id=test_user_id,
# description="A preset for testing the MemGPT client",
# system=gpt_system.get_system_text(DEFAULT_PRESET),
# functions_schema=functions_schema,
# )
# client.create_preset(preset)
def test_agent(client, agent):
# test client.rename_agent
new_name = "RenamedTestAgent"
client.rename_agent(agent_id=agent.id, new_name=new_name)
renamed_agent = client.get_agent(agent_id=str(agent.id))
assert renamed_agent.name == new_name, "Agent renaming failed"

# test client.delete_agent and client.agent_exists
delete_agent = client.create_agent(name="DeleteTestAgent", preset=test_preset_name)
assert client.agent_exists(agent_id=delete_agent.id), "Agent creation failed"
client.delete_agent(agent_id=delete_agent.id)
assert client.agent_exists(agent_id=delete_agent.id) == False, "Agent deletion failed"

# def test_create_agent(client):
# global test_agent_state
# test_agent_state = client.create_agent(
# name=test_agent_name,
# preset=test_preset_name,
# )
# print(f"\n\n[1] CREATED AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}")
# assert test_agent_state is not None

def test_memory(client, agent):
memory_response = client.get_agent_memory(agent_id=agent.id)
print("MEMORY", memory_response)

updated_memory = {"human": "Updated human memory", "persona": "Updated persona memory"}
client.update_agent_core_memory(agent_id=str(agent.id), new_memory_contents=updated_memory)
updated_memory_response = client.get_agent_memory(agent_id=agent.id)
assert (
updated_memory_response.core_memory.human == updated_memory["human"]
and updated_memory_response.core_memory.persona == updated_memory["persona"]
), "Memory update failed"


def test_agent_interactions(client, agent):
message = "Hello, agent!"
message_response = client.user_message(agent_id=str(agent.id), message=message)

command = "/memory"
command_response = client.run_command(agent_id=str(agent.id), command=command)
print("command", command_response)


def test_archival_memory(client, agent):
memory_content = "Archival memory content"
insert_response = client.insert_archival_memory(agent_id=agent.id, memory=memory_content)
assert insert_response, "Inserting archival memory failed"

archival_memory_response = client.get_agent_archival_memory(agent_id=agent.id, limit=1)
archival_memories = [memory.contents for memory in archival_memory_response.archival_memory]
assert memory_content in archival_memories, f"Retrieving archival memory failed: {archival_memories}"

memory_id_to_delete = archival_memory_response.archival_memory[0].id
client.delete_archival_memory(agent_id=agent.id, memory_id=memory_id_to_delete)

# TODO: check deletion


def test_messages(client, agent):
send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
assert send_message_response, "Sending message failed"

messages_response = client.get_messages(agent_id=agent.id, limit=1)
assert len(messages_response.messages) > 0, "Retrieving messages failed"


def test_humans_personas(client, agent):
humans_response = client.list_humans()
print("HUMANS", humans_response)

personas_response = client.list_personas()
print("PERSONAS", personas_response)

persona_name = "TestPersona"
persona = client.create_persona(name=persona_name, persona="Persona text")
assert persona.name == persona_name
assert persona.text == "Persona text", "Creating persona failed"

human_name = "TestHuman"
human = client.create_human(name=human_name, human="Human text")
assert human.name == human_name
assert human.text == "Human text", "Creating human failed"


def test_tools(client, agent):
tools_response = client.list_tools()
print("TOOLS", tools_response)

tool_name = "TestTool"
tool_response = client.create_tool(name=tool_name, source_code="print('Hello World')", source_type="python")
assert tool_response, "Creating tool failed"


def test_config(client, agent):
models_response = client.list_models()
print("MODELS", models_response)

config_response = client.get_config()
# TODO: ensure config is the same as the one in the server
print("CONFIG", config_response)


def test_sources(client, agent):
Expand All @@ -192,7 +262,7 @@ def test_sources(client, agent):
assert len(sources) == 1

# check agent archival memory size
archival_memories = client.get_agent_archival_memory(agent_id=agent.id)
archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
print(archival_memories)
assert len(archival_memories) == 0

Expand All @@ -207,7 +277,7 @@ def test_sources(client, agent):
client.attach_source_to_agent(source_name="test_source", agent_id=agent.id)

# list archival memory
archival_memories = client.get_agent_archival_memory(agent_id=agent.id)
archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
print(archival_memories)
assert len(archival_memories) == num_passages

Expand All @@ -217,11 +287,3 @@ def test_sources(client, agent):

# delete the source
client.delete_source(source.id)


# def test_user_message(client, agent):
# """Test that we can send a message through the client"""
# assert client is not None, "Run create_agent test first"
# print(f"\n\n[2] SENDING MESSAGE TO AGENT {agent.id}!!!\n\tmessages={agent.state['messages']}")
# response = client.user_message(agent_id=agent.id, message="Hello my name is Test, Client Test")
# assert response is not None and len(response) > 0

0 comments on commit b2c2d56

Please sign in to comment.