From 2170dc8d00ff0133929c5f1462d2f4fb01ef1927 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Tue, 5 Mar 2024 21:45:41 -0800 Subject: [PATCH 1/5] Fix /attach bug --- memgpt/cli/cli_load.py | 14 ++++++++++++-- memgpt/data_sources/connectors.py | 6 ++++-- memgpt/main.py | 11 +++++------ 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/memgpt/cli/cli_load.py b/memgpt/cli/cli_load.py index 55cbac5e1f..7f7d8875f2 100644 --- a/memgpt/cli/cli_load.py +++ b/memgpt/cli/cli_load.py @@ -96,7 +96,12 @@ def load_directory( user_id = uuid.UUID(config.anon_clientid) ms = MetadataStore(config) - source = Source(name=name, user_id=user_id) + source = Source( + name=name, + user_id=user_id, + embedding_model=config.default_embedding_config.embedding_model, + embedding_dim=config.default_embedding_config.embedding_dim, + ) ms.create_source(source) passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id) # TODO: also get document store @@ -209,7 +214,12 @@ def load_vector_database( user_id = uuid.UUID(config.anon_clientid) ms = MetadataStore(config) - source = Source(name=name, user_id=user_id) + source = Source( + name=name, + user_id=user_id, + embedding_model=config.default_embedding_config.embedding_model, + embedding_dim=config.default_embedding_config.embedding_dim, + ) ms.create_source(source) passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id) # TODO: also get document store diff --git a/memgpt/data_sources/connectors.py b/memgpt/data_sources/connectors.py index eaef6cc11a..198678a8d5 100644 --- a/memgpt/data_sources/connectors.py +++ b/memgpt/data_sources/connectors.py @@ -24,6 +24,8 @@ def load_data( document_store: Optional[StorageConnector] = None, ): """Load data from a connector (generates documents and passages) into a specified source_id, associatedw with a user_id.""" + assert source.embedding_model == embedding_config.embedding_model, "Source and embedding config models must match." + assert source.embedding_dim == embedding_config.embedding_dim, "Source and embedding config dimensions must match." # embedding model embed_model = embedding_model(embedding_config) @@ -55,8 +57,8 @@ def load_data( metadata_=passage_metadata, user_id=source.user_id, data_source=source.name, - embedding_dim=embedding_config.embedding_dim, - embedding_model=embedding_config.embedding_model, + embedding_dim=source.embedding_dim, + embedding_model=source.embedding_model, embedding=embedding, ) diff --git a/memgpt/main.py b/memgpt/main.py index 72c0625c6b..010af6af67 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -131,14 +131,13 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, ): valid_options.append(source.name) else: + # print warning about invalid sources + typer.secho( + f"Source {source.name} exists but has embedding dimentions {source.embedding_dim} from model {source.embedding_model}, while the agent uses embedding dimentions {memgpt_agent.agent_state.embedding_config.embedding_dim} and model {memgpt_agent.agent_state.embedding_config.embedding_model}", + fg=typer.colors.YELLOW, + ) invalid_options.append(source.name) - # print warning about invalid sources - typer.secho( - f"Warning: the following sources are not compatible with this agent's embedding model and dimension: {invalid_options}", - fg=typer.colors.YELLOW, - ) - # prompt user for data source selection data_source = questionary.select("Select data source", choices=valid_options).ask() From a8360f3c4962e0c517956b366b9171a5604a4533 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Tue, 5 Mar 2024 21:53:51 -0800 Subject: [PATCH 2/5] update assert print --- memgpt/data_sources/connectors.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/memgpt/data_sources/connectors.py b/memgpt/data_sources/connectors.py index 198678a8d5..7e0b255c10 100644 --- a/memgpt/data_sources/connectors.py +++ b/memgpt/data_sources/connectors.py @@ -24,8 +24,12 @@ def load_data( document_store: Optional[StorageConnector] = None, ): """Load data from a connector (generates documents and passages) into a specified source_id, associatedw with a user_id.""" - assert source.embedding_model == embedding_config.embedding_model, "Source and embedding config models must match." - assert source.embedding_dim == embedding_config.embedding_dim, "Source and embedding config dimensions must match." + assert ( + source.embedding_model == embedding_config.embedding_model + ), "Source and embedding config models must match, got: {source.embedding_model} and {embedding_config.embedding_model}" + assert ( + source.embedding_dim == embedding_config.embedding_dim + ), "Source and embedding config dimensions must match, got: {source.embedding_dim} and {embedding_config.embedding_dim}." # embedding model embed_model = embedding_model(embedding_config) From bf7f9ccf93d1f941f6561e154849a88148c6f80f Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Tue, 5 Mar 2024 21:58:05 -0800 Subject: [PATCH 3/5] fix print --- memgpt/data_sources/connectors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/memgpt/data_sources/connectors.py b/memgpt/data_sources/connectors.py index 7e0b255c10..7d2b544679 100644 --- a/memgpt/data_sources/connectors.py +++ b/memgpt/data_sources/connectors.py @@ -26,10 +26,10 @@ def load_data( """Load data from a connector (generates documents and passages) into a specified source_id, associatedw with a user_id.""" assert ( source.embedding_model == embedding_config.embedding_model - ), "Source and embedding config models must match, got: {source.embedding_model} and {embedding_config.embedding_model}" + ), f"Source and embedding config models must match, got: {source.embedding_model} and {embedding_config.embedding_model}" assert ( source.embedding_dim == embedding_config.embedding_dim - ), "Source and embedding config dimensions must match, got: {source.embedding_dim} and {embedding_config.embedding_dim}." + ), f"Source and embedding config dimensions must match, got: {source.embedding_dim} and {embedding_config.embedding_dim}." # embedding model embed_model = embedding_model(embedding_config) From 6c031259b22cc99d3e54e76ad318fcd0d1e59f5f Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Tue, 5 Mar 2024 22:01:39 -0800 Subject: [PATCH 4/5] fix test --- memgpt/server/server.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 28f964ce10..aaee704852 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -1,4 +1,5 @@ import json +from datetime import datetime import logging import uuid from abc import abstractmethod @@ -1023,7 +1024,13 @@ def create_api_key_for_user(self, user_id: uuid.UUID) -> Token: def create_source(self, name: str, user_id: uuid.UUID) -> Source: # TODO: add other fields """Create a new data source""" - source = Source(name=name, user_id=user_id) + source = Source( + name=name, + user_id=user_id, + created_at=datetime.datetime.now(), + embedding_model=self.config.default_embedding_config.embedding_model, + embedding_dim=self.config.default_embedding_config.embedding_dim, + ) self.ms.create_source(source) return source From f98983a9138376bb2f69038f3413dd0be5b8b981 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Tue, 5 Mar 2024 22:05:15 -0800 Subject: [PATCH 5/5] remove datetime --- memgpt/server/server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/memgpt/server/server.py b/memgpt/server/server.py index aaee704852..99f4c98da4 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -1027,7 +1027,6 @@ def create_source(self, name: str, user_id: uuid.UUID) -> Source: # TODO: add o source = Source( name=name, user_id=user_id, - created_at=datetime.datetime.now(), embedding_model=self.config.default_embedding_config.embedding_model, embedding_dim=self.config.default_embedding_config.embedding_dim, )