From d664e25cef4aa0a34c709ab5b9899bad8362ca0c Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Wed, 20 Mar 2024 21:37:52 -0700 Subject: [PATCH] feat: move `source_id` to path variable (#1171) --- memgpt/client/client.py | 19 +++++++++---------- memgpt/server/rest_api/sources/index.py | 24 ++++++++++++------------ memgpt/server/server.py | 20 ++++++++++++++++---- tests/test_client.py | 4 ++-- tests/test_server.py | 2 +- 5 files changed, 40 insertions(+), 29 deletions(-) diff --git a/memgpt/client/client.py b/memgpt/client/client.py index 61b9156eb2..c238a021ea 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -446,9 +446,8 @@ def delete_source(self, source_id: uuid.UUID): def load_file_into_source(self, filename: str, source_id: uuid.UUID): """Load {filename} and insert into source""" - params = {"source_id": str(source_id)} files = {"file": open(filename, "rb")} - response = requests.post(f"{self.base_url}/api/sources/upload", files=files, params=params, headers=self.headers) + response = requests.post(f"{self.base_url}/api/sources/{source_id}/upload", files=files, headers=self.headers) return response.json() def create_source(self, name: str) -> Source: @@ -466,17 +465,17 @@ def create_source(self, name: str) -> Source: embedding_model=response_json["embedding_config"]["embedding_model"], ) - def attach_source_to_agent(self, source_name: str, agent_id: uuid.UUID): + def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID): """Attach a source to an agent""" - params = {"source_name": source_name, "agent_id": agent_id} - response = requests.post(f"{self.base_url}/api/sources/attach", params=params, headers=self.headers) + params = {"agent_id": agent_id} + response = requests.post(f"{self.base_url}/api/sources/{source_id}/attach", params=params, headers=self.headers) assert response.status_code == 200, f"Failed to attach source to agent: {response.text}" return response.json() - def detach_source(self, source_name: str, agent_id: uuid.UUID): + def detach_source(self, source_id: uuid.UUID, agent_id: uuid.UUID): """Detach a source from an agent""" - params = {"source_name": source_name, "agent_id": str(agent_id)} - response = requests.post(f"{self.base_url}/api/sources/detach", params=params, headers=self.headers) + params = {"agent_id": str(agent_id)} + response = requests.post(f"{self.base_url}/api/sources/{source_id}/detach", params=params, headers=self.headers) assert response.status_code == 200, f"Failed to detach source from agent: {response.text}" return response.json() @@ -615,8 +614,8 @@ def load_data(self, connector: DataConnector, source_name: str): def create_source(self, name: str): self.server.create_source(user_id=self.user_id, name=name) - def attach_source_to_agent(self, source_name: str, agent_id: uuid.UUID): - self.server.attach_source_to_agent(user_id=self.user_id, source_name=source_name, agent_id=agent_id) + def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID): + self.server.attach_source_to_agent(user_id=self.user_id, source_id=source_id, agent_id=agent_id) def delete_agent(self, agent_id: uuid.UUID): self.server.delete_agent(user_id=self.user_id, agent_id=agent_id) diff --git a/memgpt/server/rest_api/sources/index.py b/memgpt/server/rest_api/sources/index.py index d7e8f69ebe..5c7ae9deeb 100644 --- a/memgpt/server/rest_api/sources/index.py +++ b/memgpt/server/rest_api/sources/index.py @@ -118,17 +118,17 @@ async def delete_source( except Exception as e: raise HTTPException(status_code=500, detail=f"{e}") - @router.post("/sources/attach", tags=["sources"], response_model=SourceModel) + @router.post("/sources/{source_id}/attach", tags=["sources"], response_model=SourceModel) async def attach_source_to_agent( + source_id: uuid.UUID, agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to attach the source to."), - source_name: str = Query(..., description="The name of the source to attach."), user_id: uuid.UUID = Depends(get_current_user_with_server), ): """Attach a data source to an existing agent.""" interface.clear() assert isinstance(agent_id, uuid.UUID), f"Expected agent_id to be a UUID, got {agent_id}" assert isinstance(user_id, uuid.UUID), f"Expected user_id to be a UUID, got {user_id}" - source = server.attach_source_to_agent(source_name=source_name, agent_id=agent_id, user_id=user_id) + source = server.attach_source_to_agent(source_id=source_id, agent_id=agent_id, user_id=user_id) return SourceModel( name=source.name, description=None, # TODO: actually store descriptions @@ -138,20 +138,20 @@ async def attach_source_to_agent( created_at=source.created_at, ) - @router.post("/sources/detach", tags=["sources"], response_model=SourceModel) + @router.post("/sources/{source_id}/detach", tags=["sources"], response_model=SourceModel) async def detach_source_from_agent( + source_id: uuid.UUID, agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to detach the source from."), - source_name: str = Query(..., description="The name of the source to detach."), user_id: uuid.UUID = Depends(get_current_user_with_server), ): """Detach a data source from an existing agent.""" - server.detach_source_from_agent(source_name=source_name, agent_id=agent_id, user_id=user_id) + server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=user_id) - @router.post("/sources/upload", tags=["sources"], response_model=UploadFileToSourceResponse) + @router.post("/sources/{source_id}/upload", tags=["sources"], response_model=UploadFileToSourceResponse) async def upload_file_to_source( # file: UploadFile = UploadFile(..., description="The file to upload."), file: UploadFile, - source_id: uuid.UUID = Query(..., description="The unique identifier of the source to attach."), + source_id: uuid.UUID, user_id: uuid.UUID = Depends(get_current_user_with_server), ): """Upload a file to a data source.""" @@ -173,18 +173,18 @@ async def upload_file_to_source( # TODO: actually return added passages/documents return UploadFileToSourceResponse(source=source, added_passages=passage_count, added_documents=document_count) - @router.get("/sources/passages ", tags=["sources"], response_model=GetSourcePassagesResponse) + @router.get("/sources/{source_id}/passages ", tags=["sources"], response_model=GetSourcePassagesResponse) async def list_passages( - source_id: uuid.UUID = Body(...), + source_id: uuid.UUID, user_id: uuid.UUID = Depends(get_current_user_with_server), ): """List all passages associated with a data source.""" passages = server.list_data_source_passages(user_id=user_id, source_id=source_id) return GetSourcePassagesResponse(passages=passages) - @router.get("/sources/documents", tags=["sources"], response_model=GetSourceDocumentsResponse) + @router.get("/sources/{source_id}/documents", tags=["sources"], response_model=GetSourceDocumentsResponse) async def list_documents( - source_id: uuid.UUID = Body(...), + source_id: uuid.UUID, user_id: uuid.UUID = Depends(get_current_user_with_server), ): """List all documents associated with a data source.""" diff --git a/memgpt/server/server.py b/memgpt/server/server.py index e424c9a1ff..dcb4430997 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -1293,11 +1293,17 @@ def load_data( passage_count, document_count = load_data(connector, source, self.config.default_embedding_config, passage_store, document_store) return passage_count, document_count - def attach_source_to_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str): + def attach_source_to_agent( + self, + user_id: uuid.UUID, + agent_id: uuid.UUID, + source_id: Optional[uuid.UUID] = None, + source_name: Optional[str] = None, + ): # attach a data source to an agent - data_source = self.ms.get_source(source_name=source_name, user_id=user_id) + data_source = self.ms.get_source(source_id=source_id, user_id=user_id, source_name=source_name) if data_source is None: - raise ValueError(f"Data source {source_name} does not exist for user_id {user_id}") + raise ValueError(f"Data source id={source_id} name={source_name} does not exist for user_id {user_id}") # get connection to data source storage source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) @@ -1310,7 +1316,13 @@ def attach_source_to_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source return data_source - def detach_source_from_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str): + def detach_source_from_agent( + self, + user_id: uuid.UUID, + agent_id: uuid.UUID, + source_id: Optional[uuid.UUID] = None, + source_name: Optional[str] = None, + ): # TODO: remove all passages coresponding to source from agent's archival memory raise NotImplementedError diff --git a/tests/test_client.py b/tests/test_client.py index 7e67759d86..25f7d1098f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -270,12 +270,12 @@ def test_sources(client, agent): # load a file into a source filename = "CONTRIBUTING.md" num_passages = 20 - response = client.load_file_into_source(filename, source.id) + response = client.load_file_into_source(filename=filename, source_id=source.id) print(response) # attach a source # TODO: make sure things run in the right order - client.attach_source_to_agent(source_name="test_source", agent_id=agent.id) + client.attach_source_to_agent(source_id=source.id, agent_id=agent.id) # list archival memory archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory diff --git a/tests/test_server.py b/tests/test_server.py index 8349055763..1f441e331b 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -157,7 +157,7 @@ def test_attach_source_to_agent(server, user_id, agent_id): assert len(passages_before) == 0 # attach source - server.attach_source_to_agent(user_id, agent_id, "test_source") + server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source") # check archival memory size passages_after = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=10000)