diff --git a/memgpt/client/client.py b/memgpt/client/client.py index d72e532a7d..71c430b643 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -5,6 +5,7 @@ from typing import Dict, List, Union, Optional, Tuple from memgpt.data_types import AgentState, User, Preset, LLMConfig, EmbeddingConfig, Source +from memgpt.models.pydantic_models import HumanModel, PersonaModel from memgpt.cli.cli import QuickstartChoice from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice from memgpt.config import MemGPTConfig @@ -13,6 +14,23 @@ from memgpt.metadata import MetadataStore from memgpt.data_sources.connectors import DataConnector +# import pydantic response objects from memgpt.server.rest_api +from memgpt.server.rest_api.agents.command import CommandResponse +from memgpt.server.rest_api.agents.config import GetAgentResponse +from memgpt.server.rest_api.agents.memory import ( + GetAgentMemoryResponse, + GetAgentArchivalMemoryResponse, + UpdateAgentMemoryResponse, + InsertAgentArchivalMemoryResponse, +) +from memgpt.server.rest_api.agents.index import ListAgentsResponse, CreateAgentResponse +from memgpt.server.rest_api.agents.message import UserMessageResponse, GetAgentMessagesResponse +from memgpt.server.rest_api.config.index import ConfigResponse +from memgpt.server.rest_api.humans.index import ListHumansResponse +from memgpt.server.rest_api.personas.index import ListPersonasResponse +from memgpt.server.rest_api.tools.index import ListToolsResponse, CreateToolResponse +from memgpt.server.rest_api.models.index import ListModelsResponse + def create_client(base_url: Optional[str] = None, token: Optional[str] = None): if base_url is None: @@ -30,6 +48,8 @@ def __init__( self.auto_save = auto_save self.debug = debug + # agents + def list_agents(self): """List all agents associated with a given user.""" raise NotImplementedError @@ -50,18 +70,31 @@ def create_agent( """Create a new agent with the specified configuration.""" raise NotImplementedError - def create_preset(self, preset: Preset): + def rename_agent(self, agent_id: uuid.UUID, new_name: str): + """Rename the agent.""" + raise NotImplementedError + + def delete_agent(self, agent_id: uuid.UUID): + """Delete the agent.""" raise NotImplementedError def get_agent(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> AgentState: raise NotImplementedError + # presets + def create_preset(self, preset: Preset): + raise NotImplementedError + + # memory + def get_agent_memory(self, agent_id: str) -> Dict: raise NotImplementedError def update_agent_core_memory(self, agent_id: str, human: Optional[str] = None, persona: Optional[str] = None) -> Dict: raise NotImplementedError + # agent interactions + def user_message(self, agent_id: str, message: str) -> Union[List[Dict], Tuple[List[Dict], int]]: raise NotImplementedError @@ -71,6 +104,64 @@ def run_command(self, agent_id: str, command: str) -> Union[str, None]: def save(self): raise NotImplementedError + # archival memory + + def get_agent_archival_memory( + self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000 + ): + """Paginated get for the archival memory for an agent""" + raise NotImplementedError + + def insert_archival_memory(self, agent_id: uuid.UUID, memory: str): + """Insert archival memory into the agent.""" + raise NotImplementedError + + def delete_archival_memory(self, agent_id: uuid.UUID, memory_id: uuid.UUID): + """Delete archival memory from the agent.""" + raise NotImplementedError + + # messages (recall memory) + + def get_messages( + self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000 + ): + """Get messages for the agent.""" + raise NotImplementedError + + def send_message(self, agent_id: uuid.UUID, message: str, role: str, stream: Optional[bool] = False): + """Send a message to the agent.""" + raise NotImplementedError + + # humans / personas + + def list_humans(self): + """List all humans.""" + raise NotImplementedError + + def create_human(self, name: str, human: str): + """Create a human.""" + raise NotImplementedError + + def list_personas(self): + """List all personas.""" + raise NotImplementedError + + def create_persona(self, name: str, persona: str): + """Create a persona.""" + raise NotImplementedError + + # tools + + def list_tools(self): + """List all tools.""" + raise NotImplementedError + + def create_tool(self, name: str, source_code: str, source_type: str, tags: Optional[List[str]] = None): + """Create a tool.""" + raise NotImplementedError + + # data sources + def list_sources(self): """List loaded sources""" raise NotImplementedError @@ -95,10 +186,14 @@ def detach_source(self, source_id: uuid.UUID, agent_id: uuid.UUID): """Detach a source from an agent""" raise NotImplementedError - def get_agent_archival_memory( - self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000 - ): - """Paginated get for the archival memory for an agent""" + # server configuration commands + + def list_models(self): + """List all models.""" + raise NotImplementedError + + def get_config(self): + """Get server config""" raise NotImplementedError @@ -113,13 +208,23 @@ def __init__( self.base_url = base_url self.headers = {"accept": "application/json", "authorization": f"Bearer {token}"} + # agents + def list_agents(self): - response = requests.get(f"{self.base_url}/agents", headers=self.headers) - print(response.text) + response = requests.get(f"{self.base_url}/api/agents", headers=self.headers) + return ListAgentsResponse(**response.json()) def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool: - response = requests.get(f"{self.base_url}/agents/config?agent_id={str(agent_id)}", headers=self.headers) - print(response.text) + response = requests.get(f"{self.base_url}/api/agents/{str(agent_id)}/config", headers=self.headers) + print(response.text, response.status_code) + print(response) + if response.status_code == 404: + # not found error + return False + elif response.status_code == 200: + return True + else: + raise ValueError(f"Failed to check if agent exists: {response.text}") def create_agent( self, @@ -143,55 +248,163 @@ def create_agent( response = requests.post(f"{self.base_url}/api/agents", json=payload, headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to create agent: {response.text}") - response_json = response.json() - print(response_json) - llm_config = LLMConfig(**response_json["agent_state"]["llm_config"]) - embedding_config = EmbeddingConfig(**response_json["agent_state"]["embedding_config"]) + response_obj = CreateAgentResponse(**response.json()) + return self.get_agent_response_to_state(response_obj) + + def get_agent_response_to_state(self, response: Union[GetAgentResponse, CreateAgentResponse]) -> AgentState: + # TODO: eventually remove this conversion + llm_config = LLMConfig( + model=response.agent_state.llm_config.model, + model_endpoint_type=response.agent_state.llm_config.model_endpoint_type, + model_endpoint=response.agent_state.llm_config.model_endpoint, + model_wrapper=response.agent_state.llm_config.model_wrapper, + context_window=response.agent_state.llm_config.context_window, + ) + embedding_config = EmbeddingConfig( + embedding_endpoint_type=response.agent_state.embedding_config.embedding_endpoint_type, + embedding_endpoint=response.agent_state.embedding_config.embedding_endpoint, + embedding_model=response.agent_state.embedding_config.embedding_model, + embedding_dim=response.agent_state.embedding_config.embedding_dim, + embedding_chunk_size=response.agent_state.embedding_config.embedding_chunk_size, + ) agent_state = AgentState( - id=uuid.UUID(response_json["agent_state"]["id"]), - name=response_json["agent_state"]["name"], - user_id=uuid.UUID(response_json["agent_state"]["user_id"]), - preset=response_json["agent_state"]["preset"], - persona=response_json["agent_state"]["persona"], - human=response_json["agent_state"]["human"], + id=response.agent_state.id, + name=response.agent_state.name, + user_id=response.agent_state.user_id, + preset=response.agent_state.preset, + persona=response.agent_state.persona, + human=response.agent_state.human, llm_config=llm_config, embedding_config=embedding_config, - state=response_json["agent_state"]["state"], + state=response.agent_state.state, # load datetime from timestampe - created_at=datetime.datetime.fromtimestamp(response_json["agent_state"]["created_at"]), + created_at=datetime.datetime.fromtimestamp(response.agent_state.created_at), ) return agent_state + def rename_agent(self, agent_id: uuid.UUID, new_name: str): + response = requests.patch(f"{self.base_url}/api/agents/{str(agent_id)}/rename", json={"agent_name": new_name}, headers=self.headers) + assert response.status_code == 200, f"Failed to rename agent: {response.text}" + response_obj = GetAgentResponse(**response.json()) + return self.get_agent_response_to_state(response_obj) + def delete_agent(self, agent_id: uuid.UUID): + """Delete the agent.""" response = requests.delete(f"{self.base_url}/api/agents/{str(agent_id)}", headers=self.headers) assert response.status_code == 200, f"Failed to delete agent: {response.text}" + def get_agent(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> AgentState: + response = requests.get(f"{self.base_url}/api/agents/{str(agent_id)}/config", headers=self.headers) + assert response.status_code == 200, f"Failed to get agent: {response.text}" + response_obj = GetAgentResponse(**response.json()) + return self.get_agent_response_to_state(response_obj) + + # presets def create_preset(self, preset: Preset): raise NotImplementedError - def get_agent_config(self, agent_id: uuid.UUID) -> AgentState: - raise NotImplementedError + # memory - def get_agent_memory(self, agent_id: uuid.UUID) -> Dict: - raise NotImplementedError + def get_agent_memory(self, agent_id: uuid.UUID) -> GetAgentMemoryResponse: + response = requests.get(f"{self.base_url}/api/agents/{agent_id}/memory", headers=self.headers) + return GetAgentMemoryResponse(**response.json()) - def update_agent_core_memory(self, agent_id: str, new_memory_contents: Dict) -> Dict: - raise NotImplementedError + def update_agent_core_memory(self, agent_id: str, new_memory_contents: Dict) -> UpdateAgentMemoryResponse: + response = requests.post(f"{self.base_url}/api/agents/{agent_id}/memory", json=new_memory_contents, headers=self.headers) + return UpdateAgentMemoryResponse(**response.json()) + + # agent interactions def user_message(self, agent_id: str, message: str) -> Union[List[Dict], Tuple[List[Dict], int]]: - # TODO: support role? what is return_token_count? - payload = {"agent_id": str(agent_id), "message": message} - response = requests.post(f"{self.base_url}/api/agents/message", json=payload, headers=self.headers) - response_json = response.json() - print(response_json) - return response_json + return self.send_message(agent_id, message, role="user") def run_command(self, agent_id: str, command: str) -> Union[str, None]: - raise NotImplementedError + response = requests.post(f"{self.base_url}/api/agents/{str(agent_id)}/command", json={"command": command}, headers=self.headers) + return CommandResponse(**response.json()) def save(self): raise NotImplementedError + # archival memory + + def get_agent_archival_memory( + self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000 + ): + """Paginated get for the archival memory for an agent""" + params = {"limit": limit} + if before: + params["before"] = str(before) + if after: + params["after"] = str(after) + response = requests.get(f"{self.base_url}/api/agents/{str(agent_id)}/archival", params=params, headers=self.headers) + assert response.status_code == 200, f"Failed to get archival memory: {response.text}" + return GetAgentArchivalMemoryResponse(**response.json()) + + def insert_archival_memory(self, agent_id: uuid.UUID, memory: str) -> GetAgentArchivalMemoryResponse: + response = requests.post(f"{self.base_url}/api/agents/{agent_id}/archival", json={"content": memory}, headers=self.headers) + if response.status_code != 200: + raise ValueError(f"Failed to insert archival memory: {response.text}") + print(response.json()) + return InsertAgentArchivalMemoryResponse(**response.json()) + + def delete_archival_memory(self, agent_id: uuid.UUID, memory_id: uuid.UUID): + response = requests.delete(f"{self.base_url}/api/agents/{agent_id}/archival?id={memory_id}", headers=self.headers) + assert response.status_code == 200, f"Failed to delete archival memory: {response.text}" + + # messages (recall memory) + + def get_messages( + self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000 + ) -> GetAgentMessagesResponse: + params = {"before": before, "after": after, "limit": limit} + response = requests.get(f"{self.base_url}/api/agents/{agent_id}/messages-cursor", params=params, headers=self.headers) + return GetAgentMessagesResponse(**response.json()) + + def send_message(self, agent_id: uuid.UUID, message: str, role: str, stream: Optional[bool] = False) -> UserMessageResponse: + data = {"message": message, "role": role, "stream": stream} + response = requests.post(f"{self.base_url}/api/agents/{agent_id}/messages", json=data, headers=self.headers) + return UserMessageResponse(**response.json()) + + # humans / personas + + def list_humans(self) -> ListHumansResponse: + response = requests.get(f"{self.base_url}/api/humans", headers=self.headers) + return ListHumansResponse(**response.json()) + + def create_human(self, name: str, human: str) -> HumanModel: + data = {"name": name, "text": human} + response = requests.post(f"{self.base_url}/api/humans", json=data, headers=self.headers) + if response.status_code != 200: + raise ValueError(f"Failed to create human: {response.text}") + + print(response.json()) + return HumanModel(**response.json()) + + def list_personas(self) -> ListPersonasResponse: + response = requests.get(f"{self.base_url}/api/personas", headers=self.headers) + return ListPersonasResponse(**response.json()) + + def create_persona(self, name: str, persona: str) -> PersonaModel: + data = {"name": name, "text": persona} + response = requests.post(f"{self.base_url}/api/personas", json=data, headers=self.headers) + if response.status_code != 200: + raise ValueError(f"Failed to create persona: {response.text}") + print(response.json()) + return PersonaModel(**response.json()) + + # tools + + def list_tools(self) -> ListToolsResponse: + response = requests.get(f"{self.base_url}/api/tools", headers=self.headers) + return ListToolsResponse(**response.json()) + + def create_tool(self, name: str, source_code: str, source_type: str, tags: Optional[List[str]] = None) -> CreateToolResponse: + data = {"name": name, "source_code": source_code, "source_type": source_type, "tags": tags} + response = requests.post(f"{self.base_url}/api/tools", json=data, headers=self.headers) + return CreateToolResponse(**response.json()) + + # sources + def list_sources(self): """List loaded sources""" response = requests.get(f"{self.base_url}/api/sources", headers=self.headers) @@ -239,18 +452,15 @@ def detach_source(self, source_name: str, agent_id: uuid.UUID): assert response.status_code == 200, f"Failed to detach source from agent: {response.text}" return response.json() - def get_agent_archival_memory( - self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000 - ): - """Paginated get for the archival memory for an agent""" - params = {"limit": limit} - if before: - params["before"] = str(before) - if after: - params["after"] = str(after) - response = requests.get(f"{self.base_url}/api/agents/{str(agent_id)}/archival", params=params, headers=self.headers) - assert response.status_code == 200, f"Failed to get archival memory: {response.text}" - return response.json()["archival_memory"] + # server configuration commands + + def list_models(self) -> ListModelsResponse: + response = requests.get(f"{self.base_url}/api/models", headers=self.headers) + return ListModelsResponse(**response.json()) + + def get_config(self) -> ConfigResponse: + response = requests.get(f"{self.base_url}/api/config", headers=self.headers) + return ConfigResponse(**response.json()) class LocalClient(AbstractClient): diff --git a/memgpt/memory.py b/memgpt/memory.py index b95e66ba2a..e3b967635f 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -393,7 +393,7 @@ def insert(self, memory_string, return_ids=False) -> Union[bool, List[uuid.UUID] """Embed and save memory string""" if not isinstance(memory_string, str): - return TypeError("memory must be a string") + raise TypeError("memory must be a string") try: passages = [] diff --git a/memgpt/models/pydantic_models.py b/memgpt/models/pydantic_models.py index 2b1444874e..7398b6222e 100644 --- a/memgpt/models/pydantic_models.py +++ b/memgpt/models/pydantic_models.py @@ -75,6 +75,11 @@ class AgentStateModel(BaseModel): state: Optional[Dict] = Field(None, description="The state of the agent.") +class CoreMemory(BaseModel): + human: str = Field(..., description="Human element of the core memory.") + persona: str = Field(..., description="Persona element of the core memory.") + + class HumanModel(SQLModel, table=True): text: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human text.") name: str = Field(..., description="The name of the human.") diff --git a/memgpt/server/rest_api/agents/config.py b/memgpt/server/rest_api/agents/config.py index 2f2e3ce051..d0f24fbff2 100644 --- a/memgpt/server/rest_api/agents/config.py +++ b/memgpt/server/rest_api/agents/config.py @@ -56,12 +56,17 @@ def get_agent_config( This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs. """ - # agent_id = uuid.UUID(request.agent_id) if request.agent_id else None - attached_sources = server.list_attached_sources(agent_id=agent_id) interface.clear() + if not server.ms.get_agent(user_id=user_id, agent_id=agent_id): + # agent does not exist + raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.") + agent_state = server.get_agent_config(user_id=user_id, agent_id=agent_id) - # return GetAgentResponse(agent_state=agent_state) + # get sources + attached_sources = server.list_attached_sources(agent_id=agent_id) + + # configs llm_config = LLMConfigModel(**vars(agent_state.llm_config)) embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config)) @@ -73,8 +78,8 @@ def get_agent_config( preset=agent_state.preset, persona=agent_state.persona, human=agent_state.human, - llm_config=agent_state.llm_config, - embedding_config=agent_state.embedding_config, + llm_config=llm_config, + embedding_config=embedding_config, state=agent_state.state, created_at=int(agent_state.created_at.timestamp()), functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead @@ -101,11 +106,32 @@ def update_agent_name( interface.clear() try: agent_state = server.rename_agent(user_id=user_id, agent_id=agent_id, new_agent_name=valid_name) + # get sources + attached_sources = server.list_attached_sources(agent_id=agent_id) except HTTPException: raise except Exception as e: raise HTTPException(status_code=500, detail=f"{e}") - return GetAgentResponse(agent_state=agent_state) + llm_config = LLMConfigModel(**vars(agent_state.llm_config)) + embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config)) + + return GetAgentResponse( + agent_state=AgentStateModel( + id=agent_state.id, + name=agent_state.name, + user_id=agent_state.user_id, + preset=agent_state.preset, + persona=agent_state.persona, + human=agent_state.human, + llm_config=llm_config, + embedding_config=embedding_config, + state=agent_state.state, + created_at=int(agent_state.created_at.timestamp()), + functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead + ), + last_run_at=None, # TODO + sources=attached_sources, + ) @router.delete("/agents/{agent_id}", tags=["agents"]) def delete_agent( diff --git a/memgpt/server/rest_api/agents/memory.py b/memgpt/server/rest_api/agents/memory.py index 1adfcbaadb..2447ab04e7 100644 --- a/memgpt/server/rest_api/agents/memory.py +++ b/memgpt/server/rest_api/agents/memory.py @@ -46,7 +46,7 @@ class GetAgentArchivalMemoryResponse(BaseModel): class InsertAgentArchivalMemoryRequest(BaseModel): - content: str = Field(None, description="The memory contents to insert into archival memory.") + content: str = Field(..., description="The memory contents to insert into archival memory.") class InsertAgentArchivalMemoryResponse(BaseModel): @@ -87,8 +87,6 @@ def update_agent_memory( This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID. """ - agent_id = uuid.UUID(request.agent_id) if request.agent_id else None - interface.clear() new_memory_contents = {"persona": request.persona, "human": request.human} diff --git a/memgpt/server/rest_api/humans/index.py b/memgpt/server/rest_api/humans/index.py index a2e7242648..80af6b4c44 100644 --- a/memgpt/server/rest_api/humans/index.py +++ b/memgpt/server/rest_api/humans/index.py @@ -35,13 +35,14 @@ async def list_humans( return ListHumansResponse(humans=humans) @router.post("/humans", tags=["humans"], response_model=HumanModel) - async def create_persona( + async def create_human( request: CreateHumanRequest = Body(...), user_id: uuid.UUID = Depends(get_current_user_with_server), ): interface.clear() new_human = HumanModel(text=request.text, name=request.name, user_id=user_id) + human_id = new_human.id server.ms.add_human(new_human) - return new_human + return HumanModel(id=human_id, text=request.text, name=request.name, user_id=user_id) return router diff --git a/memgpt/server/rest_api/personas/index.py b/memgpt/server/rest_api/personas/index.py index dd30370a29..76cb118208 100644 --- a/memgpt/server/rest_api/personas/index.py +++ b/memgpt/server/rest_api/personas/index.py @@ -42,7 +42,8 @@ async def create_persona( ): interface.clear() new_persona = PersonaModel(text=request.text, name=request.name, user_id=user_id) + persona_id = new_persona.id server.ms.add_persona(new_persona) - return new_persona + return PersonaModel(id=persona_id, text=request.text, name=request.name, user_id=user_id) return router diff --git a/memgpt/server/server.py b/memgpt/server/server.py index af92d4c94b..dda7be3c46 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -1105,7 +1105,7 @@ def update_agent_core_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_ "modified": modified, } - def rename_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_agent_name: str) -> dict: + def rename_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_agent_name: str) -> AgentState: """Update the name of the agent in the database""" if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") diff --git a/tests/test_client.py b/tests/test_client.py index 0b9e226915..6c25091e0f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -95,7 +95,7 @@ def run_server(): config.save() credentials.save() - start_server(debug=False) + start_server(debug=True) @pytest.fixture(scope="session", autouse=True) @@ -125,8 +125,8 @@ def user_token(): # Fixture to create clients with different configurations -@pytest.fixture(params=[{"base_url": test_base_url}, {"base_url": None}], scope="module") -# @pytest.fixture(params=[{"base_url": test_base_url}], scope="module") +# @pytest.fixture(params=[{"base_url": test_base_url}, {"base_url": None}], scope="module") +@pytest.fixture(params=[{"base_url": test_base_url}], scope="module") def client(request, user_token): # use token or not if request.param["base_url"]: @@ -149,29 +149,99 @@ def agent(client): client.delete_agent(agent_state.id) -# TODO: add back once REST API supports -# def test_create_preset(client): -# -# available_functions = load_all_function_sets(merge=True) -# functions_schema = [f_dict["json_schema"] for f_name, f_dict in available_functions.items()] -# preset = Preset( -# name=test_preset_name, -# user_id=test_user_id, -# description="A preset for testing the MemGPT client", -# system=gpt_system.get_system_text(DEFAULT_PRESET), -# functions_schema=functions_schema, -# ) -# client.create_preset(preset) +def test_agent(client, agent): + # test client.rename_agent + new_name = "RenamedTestAgent" + client.rename_agent(agent_id=agent.id, new_name=new_name) + renamed_agent = client.get_agent(agent_id=str(agent.id)) + assert renamed_agent.name == new_name, "Agent renaming failed" + # test client.delete_agent and client.agent_exists + delete_agent = client.create_agent(name="DeleteTestAgent", preset=test_preset_name) + assert client.agent_exists(agent_id=delete_agent.id), "Agent creation failed" + client.delete_agent(agent_id=delete_agent.id) + assert client.agent_exists(agent_id=delete_agent.id) == False, "Agent deletion failed" -# def test_create_agent(client): -# global test_agent_state -# test_agent_state = client.create_agent( -# name=test_agent_name, -# preset=test_preset_name, -# ) -# print(f"\n\n[1] CREATED AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}") -# assert test_agent_state is not None + +def test_memory(client, agent): + memory_response = client.get_agent_memory(agent_id=agent.id) + print("MEMORY", memory_response) + + updated_memory = {"human": "Updated human memory", "persona": "Updated persona memory"} + client.update_agent_core_memory(agent_id=str(agent.id), new_memory_contents=updated_memory) + updated_memory_response = client.get_agent_memory(agent_id=agent.id) + assert ( + updated_memory_response.core_memory.human == updated_memory["human"] + and updated_memory_response.core_memory.persona == updated_memory["persona"] + ), "Memory update failed" + + +def test_agent_interactions(client, agent): + message = "Hello, agent!" + message_response = client.user_message(agent_id=str(agent.id), message=message) + + command = "/memory" + command_response = client.run_command(agent_id=str(agent.id), command=command) + print("command", command_response) + + +def test_archival_memory(client, agent): + memory_content = "Archival memory content" + insert_response = client.insert_archival_memory(agent_id=agent.id, memory=memory_content) + assert insert_response, "Inserting archival memory failed" + + archival_memory_response = client.get_agent_archival_memory(agent_id=agent.id, limit=1) + archival_memories = [memory.contents for memory in archival_memory_response.archival_memory] + assert memory_content in archival_memories, f"Retrieving archival memory failed: {archival_memories}" + + memory_id_to_delete = archival_memory_response.archival_memory[0].id + client.delete_archival_memory(agent_id=agent.id, memory_id=memory_id_to_delete) + + # TODO: check deletion + + +def test_messages(client, agent): + send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user") + assert send_message_response, "Sending message failed" + + messages_response = client.get_messages(agent_id=agent.id, limit=1) + assert len(messages_response.messages) > 0, "Retrieving messages failed" + + +def test_humans_personas(client, agent): + humans_response = client.list_humans() + print("HUMANS", humans_response) + + personas_response = client.list_personas() + print("PERSONAS", personas_response) + + persona_name = "TestPersona" + persona = client.create_persona(name=persona_name, persona="Persona text") + assert persona.name == persona_name + assert persona.text == "Persona text", "Creating persona failed" + + human_name = "TestHuman" + human = client.create_human(name=human_name, human="Human text") + assert human.name == human_name + assert human.text == "Human text", "Creating human failed" + + +def test_tools(client, agent): + tools_response = client.list_tools() + print("TOOLS", tools_response) + + tool_name = "TestTool" + tool_response = client.create_tool(name=tool_name, source_code="print('Hello World')", source_type="python") + assert tool_response, "Creating tool failed" + + +def test_config(client, agent): + models_response = client.list_models() + print("MODELS", models_response) + + config_response = client.get_config() + # TODO: ensure config is the same as the one in the server + print("CONFIG", config_response) def test_sources(client, agent): @@ -192,7 +262,7 @@ def test_sources(client, agent): assert len(sources) == 1 # check agent archival memory size - archival_memories = client.get_agent_archival_memory(agent_id=agent.id) + archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory print(archival_memories) assert len(archival_memories) == 0 @@ -207,7 +277,7 @@ def test_sources(client, agent): client.attach_source_to_agent(source_name="test_source", agent_id=agent.id) # list archival memory - archival_memories = client.get_agent_archival_memory(agent_id=agent.id) + archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory print(archival_memories) assert len(archival_memories) == num_passages @@ -217,11 +287,3 @@ def test_sources(client, agent): # delete the source client.delete_source(source.id) - - -# def test_user_message(client, agent): -# """Test that we can send a message through the client""" -# assert client is not None, "Run create_agent test first" -# print(f"\n\n[2] SENDING MESSAGE TO AGENT {agent.id}!!!\n\tmessages={agent.state['messages']}") -# response = client.user_message(agent_id=agent.id, message="Hello my name is Test, Client Test") -# assert response is not None and len(response) > 0