Remove embeddings as argument in archival_memory.insert (#284)
cpacker authored Nov 5, 2023
1 parent 665ba54 commit f46cc3b
Showing 2 changed files with 18 additions and 24 deletions.
8 changes: 4 additions & 4 deletions memgpt/agent.py
@@ -800,8 +800,8 @@ def recall_memory_search_date(self, start_date, end_date, count=5, page=0):
         results_str = f"{results_pref} {json.dumps(results_formatted)}"
         return results_str

-    def archival_memory_insert(self, content, embedding=None):
-        self.persistence_manager.archival_memory.insert(content, embedding=None)
+    def archival_memory_insert(self, content):
+        self.persistence_manager.archival_memory.insert(content)
         return None

     def archival_memory_search(self, query, count=5, page=0):
@@ -1245,8 +1245,8 @@ async def recall_memory_search_date(self, start_date, end_date, count=5, page=0):
         results_str = f"{results_pref} {json.dumps(results_formatted)}"
         return results_str

-    async def archival_memory_insert(self, content, embedding=None):
-        await self.persistence_manager.archival_memory.a_insert(content, embedding=None)
+    async def archival_memory_insert(self, content):
+        await self.persistence_manager.archival_memory.a_insert(content)
         return None

     async def archival_memory_search(self, query, count=5, page=0):
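To make the agent-side change concrete, here is a minimal, self-contained sketch of the call path after this commit. The _StubArchivalMemory, _StubPersistenceManager, and AgentSketch classes below are illustrative stand-ins, not MemGPT's actual classes; the point is simply that the agent method now forwards only the content string, and any embedding work happens inside the memory backend.

class _StubArchivalMemory:
    """Illustrative stand-in for a MemGPT archival memory backend."""

    def __init__(self):
        self.items = []

    def insert(self, content: str) -> None:
        # After this commit, callers never pass an embedding; if one is
        # needed, the memory class computes it internally on its own.
        self.items.append(content)


class _StubPersistenceManager:
    def __init__(self):
        self.archival_memory = _StubArchivalMemory()


class AgentSketch:
    def __init__(self):
        self.persistence_manager = _StubPersistenceManager()

    def archival_memory_insert(self, content: str) -> None:
        # Post-change signature: content only, no `embedding` argument.
        self.persistence_manager.archival_memory.insert(content)
        return None


agent = AgentSketch()
agent.archival_memory_insert("The user prefers metric units.")
print(agent.persistence_manager.archival_memory.items)  # ['The user prefers metric units.']
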
34 changes: 14 additions & 20 deletions memgpt/memory.py
@@ -231,9 +231,7 @@ def __repr__(self) -> str:
         memory_str = "\n".join([d["content"] for d in self._archive])
         return f"\n### ARCHIVAL MEMORY ###" + f"\n{memory_str}"

-    def insert(self, memory_string, embedding=None):
-        if embedding is not None:
-            raise ValueError("Basic text-based archival memory does not support embeddings")
+    def insert(self, memory_string):
         self._archive.append(
             {
                 # can eventually upgrade to adding semantic tags, etc
@@ -242,8 +240,8 @@ def insert(self, memory_string, embedding=None):
             }
         )

-    async def a_insert(self, memory_string, embedding=None):
-        return self.insert(memory_string, embedding)
+    async def a_insert(self, memory_string):
+        return self.insert(memory_string)

     def search(self, query_string, count=None, start=None):
         """Simple text-based search"""
@@ -293,14 +291,12 @@ def _insert(self, memory_string, embedding):
             }
         )

-    def insert(self, memory_string, embedding=None):
-        if embedding is None:
-            embedding = get_embedding_with_backoff(memory_string, model=self.embedding_model)
+    def insert(self, memory_string):
+        embedding = get_embedding_with_backoff(memory_string, model=self.embedding_model)
         return self._insert(memory_string, embedding)

-    async def a_insert(self, memory_string, embedding=None):
-        if embedding is None:
-            embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model)
+    async def a_insert(self, memory_string):
+        embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model)
         return self._insert(memory_string, embedding)

     def _search(self, query_embedding, query_string, count, start):
@@ -382,16 +378,14 @@ def _insert(self, memory_string, embedding):
         embedding = np.array([embedding]).astype("float32")
         self.index.add(embedding)

-    def insert(self, memory_string, embedding=None):
-        if embedding is None:
-            # Get the embedding
-            embedding = get_embedding_with_backoff(memory_string, model=self.embedding_model)
+    def insert(self, memory_string):
+        # Get the embedding
+        embedding = get_embedding_with_backoff(memory_string, model=self.embedding_model)
         return self._insert(memory_string, embedding)

-    async def a_insert(self, memory_string, embedding=None):
-        if embedding is None:
-            # Get the embedding
-            embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model)
+    async def a_insert(self, memory_string):
+        # Get the embedding
+        embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model)
         return self._insert(memory_string, embedding)

     def _search(self, query_embedding, query_string, count=None, start=None):
@@ -814,7 +808,7 @@ def search(self, query_string, count=None, start=None):
     async def a_search(self, query_string, count=None, start=None):
         return self.search(query_string, count, start)

-    async def a_insert(self, memory_string, embedding=None):
+    async def a_insert(self, memory_string):
         return self.insert(memory_string)

     def __repr__(self) -> str:
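Since the diff touches several memory backends, here is a minimal, runnable sketch of the pattern the commit converges on for the embedding-based archival memories: insert() takes only the text, and the class computes its own embedding before delegating to _insert(). The toy_embedding helper and EmbeddingArchivalMemorySketch class are hypothetical stand-ins for get_embedding_with_backoff and MemGPT's real classes, used here only to illustrate the interface.

from typing import Dict, List


def toy_embedding(text: str) -> List[float]:
    # Stand-in for get_embedding_with_backoff(memory_string, model=...):
    # returns a small deterministic vector derived from the text.
    return [float(ord(ch) % 7) for ch in text[:8]]


class EmbeddingArchivalMemorySketch:
    """Sketch of the post-change interface: no embedding argument on insert()."""

    def __init__(self):
        self._archive: List[Dict] = []

    def _insert(self, memory_string: str, embedding: List[float]) -> None:
        self._archive.append({"content": memory_string, "embedding": embedding})

    def insert(self, memory_string: str) -> None:
        # The embedding is always computed here, never supplied by the caller.
        embedding = toy_embedding(memory_string)
        return self._insert(memory_string, embedding)


memory = EmbeddingArchivalMemorySketch()
memory.insert("Remember: the demo is on Friday.")
print(len(memory._archive))  # 1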
