Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add remaining Python client support for REST API routes + tests #1160

Merged
merged 23 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
302 changes: 256 additions & 46 deletions memgpt/client/client.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion memgpt/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def insert(self, memory_string, return_ids=False) -> Union[bool, List[uuid.UUID]
"""Embed and save memory string"""

if not isinstance(memory_string, str):
return TypeError("memory must be a string")
raise TypeError("memory must be a string")

try:
passages = []
Expand Down
1 change: 1 addition & 0 deletions memgpt/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ def add_persona(self, persona: PersonaModel):
with self.session_maker() as session:
session.add(persona)
session.commit()
return persona
cpacker marked this conversation as resolved.
Show resolved Hide resolved

@enforce_types
def add_preset(self, preset: PresetModel):
Expand Down
38 changes: 32 additions & 6 deletions memgpt/server/rest_api/agents/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,17 @@ def get_agent_config(

This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
"""
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
attached_sources = server.list_attached_sources(agent_id=agent_id)

interface.clear()
if not server.ms.get_agent(user_id=user_id, agent_id=agent_id):
# agent does not exist
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")

agent_state = server.get_agent_config(user_id=user_id, agent_id=agent_id)
# return GetAgentResponse(agent_state=agent_state)
# get sources
attached_sources = server.list_attached_sources(agent_id=agent_id)

# configs
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))

Expand All @@ -73,8 +78,8 @@ def get_agent_config(
preset=agent_state.preset,
persona=agent_state.persona,
human=agent_state.human,
llm_config=agent_state.llm_config,
embedding_config=agent_state.embedding_config,
llm_config=llm_config,
embedding_config=embedding_config,
state=agent_state.state,
created_at=int(agent_state.created_at.timestamp()),
functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead
Expand All @@ -101,11 +106,32 @@ def update_agent_name(
interface.clear()
try:
agent_state = server.rename_agent(user_id=user_id, agent_id=agent_id, new_agent_name=valid_name)
# get sources
attached_sources = server.list_attached_sources(agent_id=agent_id)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return GetAgentResponse(agent_state=agent_state)
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))

return GetAgentResponse(
agent_state=AgentStateModel(
id=agent_state.id,
name=agent_state.name,
user_id=agent_state.user_id,
preset=agent_state.preset,
persona=agent_state.persona,
human=agent_state.human,
llm_config=llm_config,
embedding_config=embedding_config,
state=agent_state.state,
created_at=int(agent_state.created_at.timestamp()),
functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead
),
last_run_at=None, # TODO
sources=attached_sources,
)

@router.delete("/agents/{agent_id}", tags=["agents"])
def delete_agent(
Expand Down
4 changes: 1 addition & 3 deletions memgpt/server/rest_api/agents/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class GetAgentArchivalMemoryResponse(BaseModel):


class InsertAgentArchivalMemoryRequest(BaseModel):
content: str = Field(None, description="The memory contents to insert into archival memory.")
content: str = Field(..., description="The memory contents to insert into archival memory.")


class InsertAgentArchivalMemoryResponse(BaseModel):
Expand Down Expand Up @@ -87,8 +87,6 @@ def update_agent_memory(

This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID.
"""
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None

interface.clear()

new_memory_contents = {"persona": request.persona, "human": request.human}
Expand Down
5 changes: 3 additions & 2 deletions memgpt/server/rest_api/humans/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ async def list_humans(
return ListHumansResponse(humans=humans)

@router.post("/humans", tags=["humans"], response_model=HumanModel)
async def create_persona(
async def create_human(
request: CreateHumanRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
interface.clear()
new_human = HumanModel(text=request.text, name=request.name, user_id=user_id)
human_id = new_human.id
server.ms.add_human(new_human)
return new_human
return HumanModel(id=human_id, text=request.text, name=request.name, user_id=user_id)

return router
3 changes: 2 additions & 1 deletion memgpt/server/rest_api/personas/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ async def create_persona(
):
interface.clear()
new_persona = PersonaModel(text=request.text, name=request.name, user_id=user_id)
persona_id = new_persona.id
server.ms.add_persona(new_persona)
return new_persona
return PersonaModel(id=persona_id, text=request.text, name=request.name, user_id=user_id)

return router
2 changes: 1 addition & 1 deletion memgpt/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,7 @@ def update_agent_core_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_
"modified": modified,
}

def rename_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_agent_name: str) -> dict:
def rename_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_agent_name: str) -> AgentState:
"""Update the name of the agent in the database"""
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
Expand Down
130 changes: 96 additions & 34 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def run_server():
config.save()
credentials.save()

start_server(debug=False)
start_server(debug=True)


@pytest.fixture(scope="session", autouse=True)
Expand Down Expand Up @@ -125,8 +125,8 @@ def user_token():


# Fixture to create clients with different configurations
@pytest.fixture(params=[{"base_url": test_base_url}, {"base_url": None}], scope="module")
# @pytest.fixture(params=[{"base_url": test_base_url}], scope="module")
# @pytest.fixture(params=[{"base_url": test_base_url}, {"base_url": None}], scope="module")
@pytest.fixture(params=[{"base_url": test_base_url}], scope="module")
def client(request, user_token):
# use token or not
if request.param["base_url"]:
Expand All @@ -149,29 +149,99 @@ def agent(client):
client.delete_agent(agent_state.id)


# TODO: add back once REST API supports
# def test_create_preset(client):
#
# available_functions = load_all_function_sets(merge=True)
# functions_schema = [f_dict["json_schema"] for f_name, f_dict in available_functions.items()]
# preset = Preset(
# name=test_preset_name,
# user_id=test_user_id,
# description="A preset for testing the MemGPT client",
# system=gpt_system.get_system_text(DEFAULT_PRESET),
# functions_schema=functions_schema,
# )
# client.create_preset(preset)
def test_agent(client, agent):
# test client.rename_agent
new_name = "RenamedTestAgent"
client.rename_agent(agent_id=agent.id, new_name=new_name)
renamed_agent = client.get_agent(agent_id=str(agent.id))
assert renamed_agent.name == new_name, "Agent renaming failed"

# test client.delete_agent and client.agent_exists
delete_agent = client.create_agent(name="DeleteTestAgent", preset=test_preset_name)
assert client.agent_exists(agent_id=delete_agent.id), "Agent creation failed"
client.delete_agent(agent_id=delete_agent.id)
assert client.agent_exists(agent_id=delete_agent.id) == False, "Agent deletion failed"

# def test_create_agent(client):
# global test_agent_state
# test_agent_state = client.create_agent(
# name=test_agent_name,
# preset=test_preset_name,
# )
# print(f"\n\n[1] CREATED AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}")
# assert test_agent_state is not None

def test_memory(client, agent):
memory_response = client.get_agent_memory(agent_id=agent.id)
print("MEMORY", memory_response)

updated_memory = {"human": "Updated human memory", "persona": "Updated persona memory"}
client.update_agent_core_memory(agent_id=str(agent.id), new_memory_contents=updated_memory)
updated_memory_response = client.get_agent_memory(agent_id=agent.id)
assert (
updated_memory_response.core_memory.human == updated_memory["human"]
and updated_memory_response.core_memory.persona == updated_memory["persona"]
), "Memory update failed"


def test_agent_interactions(client, agent):
message = "Hello, agent!"
message_response = client.user_message(agent_id=str(agent.id), message=message)

command = "/memory"
command_response = client.run_command(agent_id=str(agent.id), command=command)
print("command", command_response)


def test_archival_memory(client, agent):
memory_content = "Archival memory content"
insert_response = client.insert_archival_memory(agent_id=agent.id, memory=memory_content)
assert insert_response, "Inserting archival memory failed"

archival_memory_response = client.get_agent_archival_memory(agent_id=agent.id, limit=1)
archival_memories = [memory.contents for memory in archival_memory_response.archival_memory]
assert memory_content in archival_memories, f"Retrieving archival memory failed: {archival_memories}"

memory_id_to_delete = archival_memory_response.archival_memory[0].id
client.delete_archival_memory(agent_id=agent.id, memory_id=memory_id_to_delete)

# TODO: check deletion


def test_messages(client, agent):
send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
assert send_message_response, "Sending message failed"

messages_response = client.get_messages(agent_id=agent.id, limit=1)
assert len(messages_response.messages) > 0, "Retrieving messages failed"


def test_humans_personas(client, agent):
humans_response = client.list_humans()
print("HUMANS", humans_response)

personas_response = client.list_personas()
print("PERSONAS", personas_response)

persona_name = "TestPersona"
persona = client.create_persona(name=persona_name, persona="Persona text")
assert persona.name == persona_name
assert persona.text == "Persona text", "Creating persona failed"

human_name = "TestHuman"
human = client.create_human(name=human_name, human="Human text")
assert human.name == human_name
assert human.text == "Human text", "Creating human failed"


def test_tools(client, agent):
tools_response = client.list_tools()
print("TOOLS", tools_response)

tool_name = "TestTool"
tool_response = client.create_tool(name=tool_name, source_code="print('Hello World')", source_type="python")
assert tool_response, "Creating tool failed"


def test_config(client, agent):
models_response = client.list_models()
print("MODELS", models_response)

config_response = client.get_config()
# TODO: ensure config is the same as the one in the server
print("CONFIG", config_response)


def test_sources(client, agent):
Expand All @@ -192,7 +262,7 @@ def test_sources(client, agent):
assert len(sources) == 1

# check agent archival memory size
archival_memories = client.get_agent_archival_memory(agent_id=agent.id)
archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
print(archival_memories)
assert len(archival_memories) == 0

Expand All @@ -207,7 +277,7 @@ def test_sources(client, agent):
client.attach_source_to_agent(source_name="test_source", agent_id=agent.id)

# list archival memory
archival_memories = client.get_agent_archival_memory(agent_id=agent.id)
archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
print(archival_memories)
assert len(archival_memories) == num_passages

Expand All @@ -217,11 +287,3 @@ def test_sources(client, agent):

# delete the source
client.delete_source(source.id)


# def test_user_message(client, agent):
# """Test that we can send a message through the client"""
# assert client is not None, "Run create_agent test first"
# print(f"\n\n[2] SENDING MESSAGE TO AGENT {agent.id}!!!\n\tmessages={agent.state['messages']}")
# response = client.user_message(agent_id=agent.id, message="Hello my name is Test, Client Test")
# assert response is not None and len(response) > 0
Loading