Skip to content

Commit

Permalink
feat: move source_id to path variable (#1171)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpacker authored Mar 21, 2024
1 parent 89cc4b9 commit d664e25
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 29 deletions.
19 changes: 9 additions & 10 deletions memgpt/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,9 +446,8 @@ def delete_source(self, source_id: uuid.UUID):

def load_file_into_source(self, filename: str, source_id: uuid.UUID):
"""Load {filename} and insert into source"""
params = {"source_id": str(source_id)}
files = {"file": open(filename, "rb")}
response = requests.post(f"{self.base_url}/api/sources/upload", files=files, params=params, headers=self.headers)
response = requests.post(f"{self.base_url}/api/sources/{source_id}/upload", files=files, headers=self.headers)
return response.json()

def create_source(self, name: str) -> Source:
Expand All @@ -466,17 +465,17 @@ def create_source(self, name: str) -> Source:
embedding_model=response_json["embedding_config"]["embedding_model"],
)

def attach_source_to_agent(self, source_name: str, agent_id: uuid.UUID):
def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID):
"""Attach a source to an agent"""
params = {"source_name": source_name, "agent_id": agent_id}
response = requests.post(f"{self.base_url}/api/sources/attach", params=params, headers=self.headers)
params = {"agent_id": agent_id}
response = requests.post(f"{self.base_url}/api/sources/{source_id}/attach", params=params, headers=self.headers)
assert response.status_code == 200, f"Failed to attach source to agent: {response.text}"
return response.json()

def detach_source(self, source_name: str, agent_id: uuid.UUID):
def detach_source(self, source_id: uuid.UUID, agent_id: uuid.UUID):
"""Detach a source from an agent"""
params = {"source_name": source_name, "agent_id": str(agent_id)}
response = requests.post(f"{self.base_url}/api/sources/detach", params=params, headers=self.headers)
params = {"agent_id": str(agent_id)}
response = requests.post(f"{self.base_url}/api/sources/{source_id}/detach", params=params, headers=self.headers)
assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
return response.json()

Expand Down Expand Up @@ -615,8 +614,8 @@ def load_data(self, connector: DataConnector, source_name: str):
def create_source(self, name: str):
self.server.create_source(user_id=self.user_id, name=name)

def attach_source_to_agent(self, source_name: str, agent_id: uuid.UUID):
self.server.attach_source_to_agent(user_id=self.user_id, source_name=source_name, agent_id=agent_id)
def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID):
self.server.attach_source_to_agent(user_id=self.user_id, source_id=source_id, agent_id=agent_id)

def delete_agent(self, agent_id: uuid.UUID):
self.server.delete_agent(user_id=self.user_id, agent_id=agent_id)
Expand Down
24 changes: 12 additions & 12 deletions memgpt/server/rest_api/sources/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,17 @@ async def delete_source(
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")

@router.post("/sources/attach", tags=["sources"], response_model=SourceModel)
@router.post("/sources/{source_id}/attach", tags=["sources"], response_model=SourceModel)
async def attach_source_to_agent(
source_id: uuid.UUID,
agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to attach the source to."),
source_name: str = Query(..., description="The name of the source to attach."),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Attach a data source to an existing agent."""
interface.clear()
assert isinstance(agent_id, uuid.UUID), f"Expected agent_id to be a UUID, got {agent_id}"
assert isinstance(user_id, uuid.UUID), f"Expected user_id to be a UUID, got {user_id}"
source = server.attach_source_to_agent(source_name=source_name, agent_id=agent_id, user_id=user_id)
source = server.attach_source_to_agent(source_id=source_id, agent_id=agent_id, user_id=user_id)
return SourceModel(
name=source.name,
description=None, # TODO: actually store descriptions
Expand All @@ -138,20 +138,20 @@ async def attach_source_to_agent(
created_at=source.created_at,
)

@router.post("/sources/detach", tags=["sources"], response_model=SourceModel)
@router.post("/sources/{source_id}/detach", tags=["sources"], response_model=SourceModel)
async def detach_source_from_agent(
source_id: uuid.UUID,
agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to detach the source from."),
source_name: str = Query(..., description="The name of the source to detach."),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Detach a data source from an existing agent."""
server.detach_source_from_agent(source_name=source_name, agent_id=agent_id, user_id=user_id)
server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=user_id)

@router.post("/sources/upload", tags=["sources"], response_model=UploadFileToSourceResponse)
@router.post("/sources/{source_id}/upload", tags=["sources"], response_model=UploadFileToSourceResponse)
async def upload_file_to_source(
# file: UploadFile = UploadFile(..., description="The file to upload."),
file: UploadFile,
source_id: uuid.UUID = Query(..., description="The unique identifier of the source to attach."),
source_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Upload a file to a data source."""
Expand All @@ -173,18 +173,18 @@ async def upload_file_to_source(
# TODO: actually return added passages/documents
return UploadFileToSourceResponse(source=source, added_passages=passage_count, added_documents=document_count)

@router.get("/sources/passages ", tags=["sources"], response_model=GetSourcePassagesResponse)
@router.get("/sources/{source_id}/passages ", tags=["sources"], response_model=GetSourcePassagesResponse)
async def list_passages(
source_id: uuid.UUID = Body(...),
source_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""List all passages associated with a data source."""
passages = server.list_data_source_passages(user_id=user_id, source_id=source_id)
return GetSourcePassagesResponse(passages=passages)

@router.get("/sources/documents", tags=["sources"], response_model=GetSourceDocumentsResponse)
@router.get("/sources/{source_id}/documents", tags=["sources"], response_model=GetSourceDocumentsResponse)
async def list_documents(
source_id: uuid.UUID = Body(...),
source_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""List all documents associated with a data source."""
Expand Down
20 changes: 16 additions & 4 deletions memgpt/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1293,11 +1293,17 @@ def load_data(
passage_count, document_count = load_data(connector, source, self.config.default_embedding_config, passage_store, document_store)
return passage_count, document_count

def attach_source_to_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str):
def attach_source_to_agent(
self,
user_id: uuid.UUID,
agent_id: uuid.UUID,
source_id: Optional[uuid.UUID] = None,
source_name: Optional[str] = None,
):
# attach a data source to an agent
data_source = self.ms.get_source(source_name=source_name, user_id=user_id)
data_source = self.ms.get_source(source_id=source_id, user_id=user_id, source_name=source_name)
if data_source is None:
raise ValueError(f"Data source {source_name} does not exist for user_id {user_id}")
raise ValueError(f"Data source id={source_id} name={source_name} does not exist for user_id {user_id}")

# get connection to data source storage
source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
Expand All @@ -1310,7 +1316,13 @@ def attach_source_to_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source

return data_source

def detach_source_from_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str):
def detach_source_from_agent(
self,
user_id: uuid.UUID,
agent_id: uuid.UUID,
source_id: Optional[uuid.UUID] = None,
source_name: Optional[str] = None,
):
# TODO: remove all passages coresponding to source from agent's archival memory
raise NotImplementedError

Expand Down
4 changes: 2 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,12 +270,12 @@ def test_sources(client, agent):
# load a file into a source
filename = "CONTRIBUTING.md"
num_passages = 20
response = client.load_file_into_source(filename, source.id)
response = client.load_file_into_source(filename=filename, source_id=source.id)
print(response)

# attach a source
# TODO: make sure things run in the right order
client.attach_source_to_agent(source_name="test_source", agent_id=agent.id)
client.attach_source_to_agent(source_id=source.id, agent_id=agent.id)

# list archival memory
archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
Expand Down
2 changes: 1 addition & 1 deletion tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_attach_source_to_agent(server, user_id, agent_id):
assert len(passages_before) == 0

# attach source
server.attach_source_to_agent(user_id, agent_id, "test_source")
server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source")

# check archival memory size
passages_after = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=10000)
Expand Down

0 comments on commit d664e25

Please sign in to comment.