Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: move source_id to path variable #1171

Merged
merged 4 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions memgpt/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,9 +446,8 @@ def delete_source(self, source_id: uuid.UUID):

def load_file_into_source(self, filename: str, source_id: uuid.UUID):
"""Load {filename} and insert into source"""
params = {"source_id": str(source_id)}
files = {"file": open(filename, "rb")}
response = requests.post(f"{self.base_url}/api/sources/upload", files=files, params=params, headers=self.headers)
response = requests.post(f"{self.base_url}/api/sources/{source_id}/upload", files=files, headers=self.headers)
return response.json()

def create_source(self, name: str) -> Source:
Expand All @@ -466,17 +465,17 @@ def create_source(self, name: str) -> Source:
embedding_model=response_json["embedding_config"]["embedding_model"],
)

def attach_source_to_agent(self, source_name: str, agent_id: uuid.UUID):
def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID):
"""Attach a source to an agent"""
params = {"source_name": source_name, "agent_id": agent_id}
response = requests.post(f"{self.base_url}/api/sources/attach", params=params, headers=self.headers)
params = {"agent_id": agent_id}
response = requests.post(f"{self.base_url}/api/sources/{source_id}/attach", params=params, headers=self.headers)
assert response.status_code == 200, f"Failed to attach source to agent: {response.text}"
return response.json()

def detach_source(self, source_name: str, agent_id: uuid.UUID):
def detach_source(self, source_id: uuid.UUID, agent_id: uuid.UUID):
"""Detach a source from an agent"""
params = {"source_name": source_name, "agent_id": str(agent_id)}
response = requests.post(f"{self.base_url}/api/sources/detach", params=params, headers=self.headers)
params = {"agent_id": str(agent_id)}
response = requests.post(f"{self.base_url}/api/sources/{source_id}/detach", params=params, headers=self.headers)
assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
return response.json()

Expand Down Expand Up @@ -615,8 +614,8 @@ def load_data(self, connector: DataConnector, source_name: str):
def create_source(self, name: str):
self.server.create_source(user_id=self.user_id, name=name)

def attach_source_to_agent(self, source_name: str, agent_id: uuid.UUID):
self.server.attach_source_to_agent(user_id=self.user_id, source_name=source_name, agent_id=agent_id)
def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID):
self.server.attach_source_to_agent(user_id=self.user_id, source_id=source_id, agent_id=agent_id)

def delete_agent(self, agent_id: uuid.UUID):
self.server.delete_agent(user_id=self.user_id, agent_id=agent_id)
Expand Down
24 changes: 12 additions & 12 deletions memgpt/server/rest_api/sources/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,17 @@ async def delete_source(
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")

@router.post("/sources/attach", tags=["sources"], response_model=SourceModel)
@router.post("/sources/{source_id}/attach", tags=["sources"], response_model=SourceModel)
async def attach_source_to_agent(
source_id: uuid.UUID,
agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to attach the source to."),
source_name: str = Query(..., description="The name of the source to attach."),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Attach a data source to an existing agent."""
interface.clear()
assert isinstance(agent_id, uuid.UUID), f"Expected agent_id to be a UUID, got {agent_id}"
assert isinstance(user_id, uuid.UUID), f"Expected user_id to be a UUID, got {user_id}"
source = server.attach_source_to_agent(source_name=source_name, agent_id=agent_id, user_id=user_id)
source = server.attach_source_to_agent(source_id=source_id, agent_id=agent_id, user_id=user_id)
return SourceModel(
name=source.name,
description=None, # TODO: actually store descriptions
Expand All @@ -138,20 +138,20 @@ async def attach_source_to_agent(
created_at=source.created_at,
)

@router.post("/sources/detach", tags=["sources"], response_model=SourceModel)
@router.post("/sources/{source_id}/detach", tags=["sources"], response_model=SourceModel)
async def detach_source_from_agent(
source_id: uuid.UUID,
agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to detach the source from."),
source_name: str = Query(..., description="The name of the source to detach."),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Detach a data source from an existing agent."""
server.detach_source_from_agent(source_name=source_name, agent_id=agent_id, user_id=user_id)
server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=user_id)

@router.post("/sources/upload", tags=["sources"], response_model=UploadFileToSourceResponse)
@router.post("/sources/{source_id}/upload", tags=["sources"], response_model=UploadFileToSourceResponse)
async def upload_file_to_source(
# file: UploadFile = UploadFile(..., description="The file to upload."),
file: UploadFile,
source_id: uuid.UUID = Query(..., description="The unique identifier of the source to attach."),
source_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Upload a file to a data source."""
Expand All @@ -173,18 +173,18 @@ async def upload_file_to_source(
# TODO: actually return added passages/documents
return UploadFileToSourceResponse(source=source, added_passages=passage_count, added_documents=document_count)

@router.get("/sources/passages ", tags=["sources"], response_model=GetSourcePassagesResponse)
@router.get("/sources/{source_id}/passages ", tags=["sources"], response_model=GetSourcePassagesResponse)
async def list_passages(
source_id: uuid.UUID = Body(...),
source_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""List all passages associated with a data source."""
passages = server.list_data_source_passages(user_id=user_id, source_id=source_id)
return GetSourcePassagesResponse(passages=passages)

@router.get("/sources/documents", tags=["sources"], response_model=GetSourceDocumentsResponse)
@router.get("/sources/{source_id}/documents", tags=["sources"], response_model=GetSourceDocumentsResponse)
async def list_documents(
source_id: uuid.UUID = Body(...),
source_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""List all documents associated with a data source."""
Expand Down
20 changes: 16 additions & 4 deletions memgpt/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1293,11 +1293,17 @@ def load_data(
passage_count, document_count = load_data(connector, source, self.config.default_embedding_config, passage_store, document_store)
return passage_count, document_count

def attach_source_to_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str):
def attach_source_to_agent(
self,
user_id: uuid.UUID,
agent_id: uuid.UUID,
source_id: Optional[uuid.UUID] = None,
source_name: Optional[str] = None,
):
# attach a data source to an agent
data_source = self.ms.get_source(source_name=source_name, user_id=user_id)
data_source = self.ms.get_source(source_id=source_id, user_id=user_id, source_name=source_name)
if data_source is None:
raise ValueError(f"Data source {source_name} does not exist for user_id {user_id}")
raise ValueError(f"Data source id={source_id} name={source_name} does not exist for user_id {user_id}")

# get connection to data source storage
source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
Expand All @@ -1310,7 +1316,13 @@ def attach_source_to_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source

return data_source

def detach_source_from_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str):
def detach_source_from_agent(
self,
user_id: uuid.UUID,
agent_id: uuid.UUID,
source_id: Optional[uuid.UUID] = None,
source_name: Optional[str] = None,
):
# TODO: remove all passages coresponding to source from agent's archival memory
raise NotImplementedError

Expand Down
4 changes: 2 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,12 +270,12 @@ def test_sources(client, agent):
# load a file into a source
filename = "CONTRIBUTING.md"
num_passages = 20
response = client.load_file_into_source(filename, source.id)
response = client.load_file_into_source(filename=filename, source_id=source.id)
print(response)

# attach a source
# TODO: make sure things run in the right order
client.attach_source_to_agent(source_name="test_source", agent_id=agent.id)
client.attach_source_to_agent(source_id=source.id, agent_id=agent.id)

# list archival memory
archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
Expand Down
2 changes: 1 addition & 1 deletion tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_attach_source_to_agent(server, user_id, agent_id):
assert len(passages_before) == 0

# attach source
server.attach_source_to_agent(user_id, agent_id, "test_source")
server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source")

# check archival memory size
passages_after = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=10000)
Expand Down
Loading