Skip to content

Commit

Permalink
feat: Add data loading and attaching to server (#1051)
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders authored Feb 25, 2024
1 parent 243d3ed commit b091333
Show file tree
Hide file tree
Showing 7 changed files with 636 additions and 422 deletions.
1 change: 0 additions & 1 deletion memgpt/cli/cli_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def load_directory(
embedding_config=config.default_embedding_config,
document_store=None,
passage_store=passage_storage,
chunk_size=1000,
)
print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")

Expand Down
4 changes: 1 addition & 3 deletions memgpt/data_sources/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def load_data(
embedding_config: EmbeddingConfig,
passage_store: StorageConnector,
document_store: Optional[StorageConnector] = None,
chunk_size: int = 1000,
):
"""Load data from a connector (generates documents and passages) into a specified source_id, associatedw with a user_id."""

Expand All @@ -49,7 +48,6 @@ def load_data(

# generate passages
for passage_text, passage_metadata in connector.generate_passages([document]):
print("passage", passage_text, passage_metadata)
embedding = embed_model.get_text_embedding(passage_text)
passage = Passage(
id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
Expand All @@ -64,7 +62,7 @@ def load_data(
)

passages.append(passage)
if len(passages) >= chunk_size:
if len(passages) >= embedding_config.embedding_chunk_size:
# insert passages into passage store
passage_store.insert_many(passages)

Expand Down
57 changes: 46 additions & 11 deletions memgpt/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@

# from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
from memgpt.cli.cli_config import get_model_options

# from memgpt.agent_store.storage import StorageConnector
from memgpt.data_sources.connectors import DataConnector, load_data
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.metadata import MetadataStore
import memgpt.presets.presets as presets
import memgpt.utils as utils
import memgpt.server.utils as server_utils
from memgpt.data_types import (
User,
Source,
Passage,
AgentState,
LLMConfig,
Expand Down Expand Up @@ -969,6 +970,10 @@ def rename_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_agent_name:

return memgpt_agent.agent_state

def delete_user(self, user_id: uuid.UUID):
    """Delete a user and their associated data from the metadata store.

    NOTE(review): not yet implemented — currently a no-op stub; callers get
    no error and no deletion happens.
    """
    # TODO: delete user
    pass

def delete_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID):
"""Delete an agent in the database"""
if self.ms.get_user(user_id=user_id) is None:
def create_api_key_for_user(self, user_id: uuid.UUID) -> Token:
    """Create, persist, and return a new API key token for `user_id`."""
    # the metadata store both generates and records the key
    return self.ms.create_api_key(user_id=user_id)

def create_source(self, name: str): # TODO: add other fields
# create a data source
pass
def create_source(self, name: str, user_id: uuid.UUID) -> Source:  # TODO: add other fields
    """Create a new data source owned by `user_id`, persist it, and return it."""
    new_source = Source(name=name, user_id=user_id)
    # record the source in the metadata store so it can be looked up by name later
    self.ms.create_source(new_source)
    return new_source

def load_passages(self, source_id: uuid.UUID, passages: List[Passage]):
# load a list of passages into a data source
pass
def load_data(
    self,
    user_id: uuid.UUID,
    connector: DataConnector,
    source_name: str,  # fixed annotation: this is the source *name*, not a Source object
):
    """Load data from a DataConnector into an existing source for `user_id`.

    Generates passages (with embeddings) from the connector and inserts them
    into the user's passage storage.

    Raises:
        ValueError: if no source named `source_name` exists for `user_id`.
    """
    # TODO: this should be implemented as a batch job or at least async, since it may take a long time

    # the source must already have been created (see create_source)
    source = self.ms.get_source(source_name=source_name, user_id=user_id)
    if source is None:
        raise ValueError(f"Data source {source_name} does not exist for user {user_id}")

    # storage backend for passages (embedded chunks)
    passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
    # TODO: add document store support
    document_store = None  # StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id)

    # NOTE: resolves to the module-level load_data imported from
    # memgpt.data_sources.connectors, not this method (class scope is not
    # visible inside method bodies)
    load_data(connector, source, self.config.default_embedding_config, passage_store, document_store)

def attach_source_to_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str):
    """Attach an existing data source to an agent so its passages become
    available in the agent's archival memory."""
    # the source must exist before it can be attached
    source = self.ms.get_source(source_name=source_name, user_id=user_id)
    if source is None:
        raise ValueError(f"Data source {source_name} does not exist")

    # storage connector backing the user's passages
    passage_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)

    # materialize the agent, then hand it the source's passages
    loaded_agent = self._get_or_load_agent(user_id, agent_id)
    loaded_agent.attach_source(source.name, passage_connector, self.ms)
Loading

0 comments on commit b091333

Please sign in to comment.