Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add data loading and attaching to server #1051

Merged
merged 5 commits into from
Feb 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion memgpt/cli/cli_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def load_directory(
embedding_config=config.default_embedding_config,
document_store=None,
passage_store=passage_storage,
chunk_size=1000,
)
print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")

Expand Down
4 changes: 1 addition & 3 deletions memgpt/data_sources/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def load_data(
embedding_config: EmbeddingConfig,
passage_store: StorageConnector,
document_store: Optional[StorageConnector] = None,
chunk_size: int = 1000,
):
"""Load data from a connector (generates documents and passages) into a specified source_id, associatedw with a user_id."""

Expand All @@ -49,7 +48,6 @@ def load_data(

# generate passages
for passage_text, passage_metadata in connector.generate_passages([document]):
print("passage", passage_text, passage_metadata)
embedding = embed_model.get_text_embedding(passage_text)
passage = Passage(
id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
Expand All @@ -64,7 +62,7 @@ def load_data(
)

passages.append(passage)
if len(passages) >= chunk_size:
if len(passages) >= embedding_config.embedding_chunk_size:
# insert passages into passage store
passage_store.insert_many(passages)

Expand Down
57 changes: 46 additions & 11 deletions memgpt/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@

# from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
from memgpt.cli.cli_config import get_model_options

# from memgpt.agent_store.storage import StorageConnector
from memgpt.data_sources.connectors import DataConnector, load_data
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.metadata import MetadataStore
import memgpt.presets.presets as presets
import memgpt.utils as utils
import memgpt.server.utils as server_utils
from memgpt.data_types import (
User,
Source,
Passage,
AgentState,
LLMConfig,
Expand Down Expand Up @@ -969,6 +970,10 @@ def rename_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_agent_name:

return memgpt_agent.agent_state

def delete_user(self, user_id: uuid.UUID):
    """Delete a user from the database.

    Not yet implemented: currently a no-op stub.
    """
    # TODO: delete user
    pass

def delete_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID):
"""Delete an agent in the database"""
if self.ms.get_user(user_id=user_id) is None:
Expand Down Expand Up @@ -1015,15 +1020,45 @@ def create_api_key_for_user(self, user_id: uuid.UUID) -> Token:
token = self.ms.create_api_key(user_id=user_id)
return token

def create_source(self, name: str): # TODO: add other fields
# create a data source
pass
def create_source(self, name: str, user_id: uuid.UUID) -> Source:  # TODO: add other fields
    """Create and persist a new data source owned by the given user.

    Returns the newly constructed Source record.
    """
    new_source = Source(name=name, user_id=user_id)
    self.ms.create_source(new_source)
    return new_source

def load_passages(self, source_id: uuid.UUID, passages: List[Passage]):
# load a list of passages into a data source
pass
def load_data(
    self,
    user_id: uuid.UUID,
    connector: DataConnector,
    source_name: str,  # fixed: this is the source's name (a string), not a Source object
):
    """Load data from a DataConnector into the named source for a specified user_id.

    Raises:
        ValueError: if no source named `source_name` exists for `user_id`.
    """
    # TODO: this should be implemented as a batch job or at least async, since it may take a long time

    # look up the target data source; it must already have been created for this user
    source = self.ms.get_source(source_name=source_name, user_id=user_id)
    if source is None:
        raise ValueError(f"Data source {source_name} does not exist for user {user_id}")

    # get the data connectors (passage storage scoped to this user)
    passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
    # TODO: add document store support
    document_store = None  # StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id)

    # generate documents/passages from the connector and insert them into storage
    # NOTE: this calls the module-level load_data imported from memgpt.data_sources.connectors,
    # not this method (different signature), despite the shared name
    load_data(connector, source, self.config.default_embedding_config, passage_store, document_store)

def attach_source_to_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str):
    """Attach an existing data source to an agent.

    Raises:
        ValueError: if no source named `source_name` exists for `user_id`.
    """
    # attach a data source to an agent
    # TODO: insert passages into agent archival memory
    pass  # NOTE(review): no-op left over from the previous stub — candidate for removal
    data_source = self.ms.get_source(source_name=source_name, user_id=user_id)
    if data_source is None:
        raise ValueError(f"Data source {source_name} does not exist")

    # get connection to data source storage (passages table scoped to this user)
    source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)

    # load agent (from cache or persisted state)
    agent = self._get_or_load_agent(user_id, agent_id)

    # attach source to agent; the agent is given the storage connector to read the source's passages
    agent.attach_source(data_source.name, source_connector, self.ms)
Loading
Loading