Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: bug with storing embedding info in metadata store #1101

Merged
merged 5 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions memgpt/cli/cli_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,12 @@ def load_directory(
user_id = uuid.UUID(config.anon_clientid)

ms = MetadataStore(config)
source = Source(name=name, user_id=user_id)
source = Source(
name=name,
user_id=user_id,
embedding_model=config.default_embedding_config.embedding_model,
embedding_dim=config.default_embedding_config.embedding_dim,
)
ms.create_source(source)
passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
# TODO: also get document store
Expand Down Expand Up @@ -209,7 +214,12 @@ def load_vector_database(
user_id = uuid.UUID(config.anon_clientid)

ms = MetadataStore(config)
source = Source(name=name, user_id=user_id)
source = Source(
name=name,
user_id=user_id,
embedding_model=config.default_embedding_config.embedding_model,
embedding_dim=config.default_embedding_config.embedding_dim,
)
ms.create_source(source)
passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
# TODO: also get document store
Expand Down
10 changes: 8 additions & 2 deletions memgpt/data_sources/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ def load_data(
document_store: Optional[StorageConnector] = None,
):
"""Load data from a connector (generates documents and passages) into a specified source_id, associatedw with a user_id."""
assert (
source.embedding_model == embedding_config.embedding_model
), f"Source and embedding config models must match, got: {source.embedding_model} and {embedding_config.embedding_model}"
assert (
source.embedding_dim == embedding_config.embedding_dim
), f"Source and embedding config dimensions must match, got: {source.embedding_dim} and {embedding_config.embedding_dim}."

# embedding model
embed_model = embedding_model(embedding_config)
Expand Down Expand Up @@ -55,8 +61,8 @@ def load_data(
metadata_=passage_metadata,
user_id=source.user_id,
data_source=source.name,
embedding_dim=embedding_config.embedding_dim,
embedding_model=embedding_config.embedding_model,
embedding_dim=source.embedding_dim,
embedding_model=source.embedding_model,
embedding=embedding,
)

Expand Down
11 changes: 5 additions & 6 deletions memgpt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,13 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
):
valid_options.append(source.name)
else:
# print warning about invalid sources
typer.secho(
f"Source {source.name} exists but has embedding dimentions {source.embedding_dim} from model {source.embedding_model}, while the agent uses embedding dimentions {memgpt_agent.agent_state.embedding_config.embedding_dim} and model {memgpt_agent.agent_state.embedding_config.embedding_model}",
fg=typer.colors.YELLOW,
)
invalid_options.append(source.name)

# print warning about invalid sources
typer.secho(
f"Warning: the following sources are not compatible with this agent's embedding model and dimension: {invalid_options}",
fg=typer.colors.YELLOW,
)

# prompt user for data source selection
data_source = questionary.select("Select data source", choices=valid_options).ask()

Expand Down
8 changes: 7 additions & 1 deletion memgpt/server/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from datetime import datetime
import logging
import uuid
from abc import abstractmethod
Expand Down Expand Up @@ -1023,7 +1024,12 @@ def create_api_key_for_user(self, user_id: uuid.UUID) -> Token:

def create_source(self, name: str, user_id: uuid.UUID) -> Source: # TODO: add other fields
"""Create a new data source"""
source = Source(name=name, user_id=user_id)
source = Source(
name=name,
user_id=user_id,
embedding_model=self.config.default_embedding_config.embedding_model,
embedding_dim=self.config.default_embedding_config.embedding_dim,
)
self.ms.create_source(source)
return source

Expand Down
Loading