From f46cc3b15bd73503483db008fed9e82ad46a136b Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Sun, 5 Nov 2023 12:48:22 -0800 Subject: [PATCH] Remove embeddings as argument in archival_memory.insert (#284) --- memgpt/agent.py | 8 ++++---- memgpt/memory.py | 34 ++++++++++++++-------------------- 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/memgpt/agent.py b/memgpt/agent.py index d56876976f..7c1383ef68 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -800,8 +800,8 @@ def recall_memory_search_date(self, start_date, end_date, count=5, page=0): results_str = f"{results_pref} {json.dumps(results_formatted)}" return results_str - def archival_memory_insert(self, content, embedding=None): - self.persistence_manager.archival_memory.insert(content, embedding=None) + def archival_memory_insert(self, content): + self.persistence_manager.archival_memory.insert(content) return None def archival_memory_search(self, query, count=5, page=0): @@ -1245,8 +1245,8 @@ async def recall_memory_search_date(self, start_date, end_date, count=5, page=0) results_str = f"{results_pref} {json.dumps(results_formatted)}" return results_str - async def archival_memory_insert(self, content, embedding=None): - await self.persistence_manager.archival_memory.a_insert(content, embedding=None) + async def archival_memory_insert(self, content): + await self.persistence_manager.archival_memory.a_insert(content) return None async def archival_memory_search(self, query, count=5, page=0): diff --git a/memgpt/memory.py b/memgpt/memory.py index e9fc47aaa4..13678e2ea3 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -231,9 +231,7 @@ def __repr__(self) -> str: memory_str = "\n".join([d["content"] for d in self._archive]) return f"\n### ARCHIVAL MEMORY ###" + f"\n{memory_str}" - def insert(self, memory_string, embedding=None): - if embedding is not None: - raise ValueError("Basic text-based archival memory does not support embeddings") + def insert(self, memory_string): self._archive.append( { # can eventually upgrade to adding semantic tags, etc @@ -242,8 +240,8 @@ def insert(self, memory_string, embedding=None): } ) - async def a_insert(self, memory_string, embedding=None): - return self.insert(memory_string, embedding) + async def a_insert(self, memory_string): + return self.insert(memory_string) def search(self, query_string, count=None, start=None): """Simple text-based search""" @@ -293,14 +291,12 @@ def _insert(self, memory_string, embedding): } ) - def insert(self, memory_string, embedding=None): - if embedding is None: - embedding = get_embedding_with_backoff(memory_string, model=self.embedding_model) + def insert(self, memory_string): + embedding = get_embedding_with_backoff(memory_string, model=self.embedding_model) return self._insert(memory_string, embedding) - async def a_insert(self, memory_string, embedding=None): - if embedding is None: - embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model) + async def a_insert(self, memory_string): + embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model) return self._insert(memory_string, embedding) def _search(self, query_embedding, query_string, count, start): @@ -382,16 +378,14 @@ def _insert(self, memory_string, embedding): embedding = np.array([embedding]).astype("float32") self.index.add(embedding) - def insert(self, memory_string, embedding=None): - if embedding is None: - # Get the embedding - embedding = get_embedding_with_backoff(memory_string, model=self.embedding_model) + def insert(self, memory_string): + # Get the embedding + embedding = get_embedding_with_backoff(memory_string, model=self.embedding_model) return self._insert(memory_string, embedding) - async def a_insert(self, memory_string, embedding=None): - if embedding is None: - # Get the embedding - embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model) + async def a_insert(self, memory_string): + # Get the embedding + embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model) return self._insert(memory_string, embedding) def _search(self, query_embedding, query_string, count=None, start=None): @@ -814,7 +808,7 @@ def search(self, query_string, count=None, start=None): async def a_search(self, query_string, count=None, start=None): return self.search(query_string, count, start) - async def a_insert(self, memory_string, embedding=None): + async def a_insert(self, memory_string): return self.insert(memory_string) def __repr__(self) -> str: