Skip to content

Commit

Permalink
Chroma storage integration (#285)
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders authored Dec 6, 2023
1 parent aa75fa1 commit 9c2e6b7
Show file tree
Hide file tree
Showing 12 changed files with 293 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
PGVECTOR_TEST_DB_URL: ${{ secrets.PGVECTOR_TEST_DB_URL }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry install -E dev -E postgres -E local
poetry install -E dev -E postgres -E local -E chroma -E lancedb
- name: Set Poetry config
env:
Expand Down
35 changes: 23 additions & 12 deletions docs/storage.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,34 @@ pip install 'pymemgpt[postgres]'
### Running Postgres
You will need to have a URI to a Postgres database which support [pgvector](https://github.com/pgvector/pgvector). You can either use a [hosted provider](https://github.com/pgvector/pgvector/issues/54) or [install pgvector](https://github.com/pgvector/pgvector#installation).

## Chroma
To enable the Chroma storage backend, install the dependencies with:
```
pip install 'pymemgpt[chroma]'
```
You can configure Chroma with both the HTTP and persistent storage client via `memgpt configure`. You will need to specify either a persistent storage path or host/port depending on your client choice. The example below shows how to configure Chroma with local persistent storage:
```
? Select LLM inference provider: openai
? Override default endpoint: https://api.openai.com/v1
? Select default model (recommended: gpt-4): gpt-4
? Select embedding provider: openai
? Select default preset: memgpt_chat
? Select default persona: sam_pov
? Select default human: cs_phd
? Select storage backend for archival data: chroma
? Select chroma backend: persistent
? Enter persistent storage location: /Users/sarahwooders/.memgpt/config/chroma
```

## LanceDB
In order to use the LanceDB backend.

You have to enable the LanceDB backend by running

```
memgpt configure
```
and selecting `lancedb` for archival storage, and database URI (e.g. `./.lancedb`"), Empty archival uri is also handled and default uri is set at `./.lancedb`.

To enable the LanceDB backend, make sure to install the required dependencies with:
```
pip install 'pymemgpt[lancedb]'
```
for more checkout [lancedb docs](https://lancedb.github.io/lancedb/)
You have to enable the LanceDB backend by running
```
memgpt configure
```
and selecting `lancedb` for archival storage, then providing a database URI (e.g. `./.lancedb`). An empty archival URI is also handled, and the default URI is set to `./.lancedb`. For more, check out the [LanceDB docs](https://lancedb.github.io/lancedb/).


## Chroma
(Coming soon)
25 changes: 21 additions & 4 deletions memgpt/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,24 +241,40 @@ def configure_cli(config: MemGPTConfig):

def configure_archival_storage(config: MemGPTConfig):
# Configure archival storage backend
archival_storage_options = ["local", "lancedb", "postgres"]
archival_storage_options = ["local", "lancedb", "postgres", "chroma"]
archival_storage_type = questionary.select(
"Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type
).ask()
archival_storage_uri = None
archival_storage_uri, archival_storage_path = None, None

# configure postgres
if archival_storage_type == "postgres":
archival_storage_uri = questionary.text(
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
default=config.archival_storage_uri if config.archival_storage_uri else "",
).ask()

# configure lancedb
if archival_storage_type == "lancedb":
archival_storage_uri = questionary.text(
"Enter lanncedb connection string (e.g. ./.lancedb",
default=config.archival_storage_uri if config.archival_storage_uri else "./.lancedb",
).ask()

return archival_storage_type, archival_storage_uri
# configure chroma
if archival_storage_type == "chroma":
chroma_type = questionary.select("Select chroma backend:", ["http", "persistent"], default="http").ask()
if chroma_type == "http":
archival_storage_uri = questionary.text("Enter chroma ip (e.g. localhost:8000):", default="localhost:8000").ask()
if chroma_type == "persistent":
print(config.config_path, config.archival_storage_path)
default_archival_storage_path = (
config.archival_storage_path if config.archival_storage_path else os.path.join(config.config_path, "chroma")
)
print(default_archival_storage_path)
archival_storage_path = questionary.text("Enter persistent storage location:", default=default_archival_storage_path).ask()

return archival_storage_type, archival_storage_uri, archival_storage_path

# TODO: allow configuring embedding model

Expand All @@ -275,7 +291,7 @@ def configure():
model, model_wrapper, context_window = configure_model(config, model_endpoint_type)
embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config)
default_preset, default_persona, default_human, default_agent = configure_cli(config)
archival_storage_type, archival_storage_uri = configure_archival_storage(config)
archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage(config)

# check credentials
azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment = get_azure_credentials()
Expand Down Expand Up @@ -322,6 +338,7 @@ def configure():
# storage
archival_storage_type=archival_storage_type,
archival_storage_uri=archival_storage_uri,
archival_storage_path=archival_storage_path,
)
print(f"Saving config to {config.config_path}")
config.save()
Expand Down
2 changes: 0 additions & 2 deletions memgpt/cli/cli_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@ def load_directory(
reader = SimpleDirectoryReader(input_files=input_files)

# load docs
print("loading data")
docs = reader.load_data()
print("done loading data")
store_docs(name, docs)


Expand Down
125 changes: 125 additions & 0 deletions memgpt/connectors/chroma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import chromadb
import json
import re
from typing import Optional, List, Iterator
from memgpt.connectors.storage import StorageConnector, Passage
from memgpt.utils import printd
from memgpt.config import AgentConfig, MemGPTConfig


def create_chroma_client():
    """Build a Chroma client from the saved MemGPT config.

    A persistent (on-disk) client is used when ``archival_storage_path`` is
    set; otherwise ``archival_storage_uri`` is interpreted as ``{ip}:{port}``
    and an HTTP client is created.
    """
    config = MemGPTConfig.load()
    if config.archival_storage_path:
        return chromadb.PersistentClient(config.archival_storage_path)
    # no persistent path configured: assume uri={ip}:{port}
    parts = config.archival_storage_uri.split(":")
    return chromadb.HttpClient(host=parts[0], port=parts[1])


class ChromaStorageConnector(StorageConnector):
    """Archival storage backed by a Chroma collection.

    One collection is used per loaded data source (``memgpt_<name>``) or per
    agent (``memgpt_agent_<name>``); names are normalized by
    :meth:`sanitize_table_name`.

    WARNING: This is not thread safe. Do NOT do concurrent access to the same collection.
    """

    def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
        """Open (or create) the collection for *name* or *agent_config*.

        Exactly one of ``name`` / ``agent_config`` must be provided.

        Raises:
            ValueError: if neither argument is given.
        """
        # determine table name
        if agent_config:
            assert name is None, f"Cannot specify both agent config and name {name}"
            self.table_name = self.generate_table_name_agent(agent_config)
        elif name:
            assert agent_config is None, f"Cannot specify both agent config and name {name}"
            self.table_name = self.generate_table_name(name)
        else:
            raise ValueError("Must specify either agent config or name")

        printd(f"Using table name {self.table_name}")

        # create client
        self.client = create_chroma_client()

        # get a collection or create if it doesn't exist already
        self.collection = self.client.get_or_create_collection(self.table_name)

    def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]:
        """Yield Passages in pages of at most ``page_size``.

        Bug fix: ``Collection.get`` returns a dict of parallel lists
        (``ids`` / ``documents`` / ``embeddings``), not a list of record
        objects, so the original per-record attribute access (``p.text`` etc.)
        could never work — and the emptiness test on the dict itself was
        always truthy, so the loop never terminated.
        """
        offset = 0
        while True:
            # Retrieve a chunk of records with the given page_size
            results = self.collection.get(offset=offset, limit=page_size, include=["embeddings", "documents"])

            # An empty "ids" list means all records have been retrieved
            # (test the list, not the dict: the dict is always truthy)
            if not results["ids"]:
                break

            # NOTE(review): doc_id is not stored at insert time, so it cannot
            # be recovered here — persist it as Chroma metadata if callers
            # need it back.
            yield [
                Passage(text=text, embedding=embedding)
                for text, embedding in zip(results["documents"], results["embeddings"])
            ]

            # Increment the offset to get the next chunk in the next iteration
            offset += page_size

    def get_all(self) -> List[Passage]:
        """Return every passage in the collection."""
        results = self.collection.get(include=["embeddings", "documents"])
        return [Passage(text=text, embedding=embedding) for (text, embedding) in zip(results["documents"], results["embeddings"])]

    def get(self, id: str) -> Optional[Passage]:
        """Return the passage with the given id, or None if it does not exist.

        Bug fix: the original returned a (possibly empty) list, contradicting
        the declared ``Optional[Passage]`` return type.
        """
        results = self.collection.get(ids=[id], include=["embeddings", "documents"])
        if not results["ids"]:
            return None
        return Passage(text=results["documents"][0], embedding=results["embeddings"][0])

    def insert(self, passage: Passage):
        """Add a single passage to the collection.

        NOTE(review): using ``count()`` as the next id collides after
        deletions or with concurrent writers — consider uuid4 ids. Kept as-is
        for backward compatibility with existing collections.
        """
        self.collection.add(documents=[passage.text], embeddings=[passage.embedding], ids=[str(self.collection.count())])

    def insert_many(self, passages: List[Passage], show_progress=True):
        """Add many passages in one call; ids continue from the current count.

        ``show_progress`` is accepted for interface compatibility but unused.
        """
        if not passages:
            # Robustness: Chroma rejects an add() with empty lists
            return
        count = self.collection.count()
        ids = [str(count + i) for i in range(len(passages))]
        self.collection.add(
            documents=[passage.text for passage in passages], embeddings=[passage.embedding for passage in passages], ids=ids
        )

    def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Passage]:
        """Return the ``top_k`` passages nearest to ``query_vec``.

        The raw ``query`` text is unused: Chroma is queried by embedding only.
        """
        results = self.collection.query(query_embeddings=[query_vec], n_results=top_k, include=["embeddings", "documents"])
        # get index [0] since query is passed as list
        return [Passage(text=text, embedding=embedding) for (text, embedding) in zip(results["documents"][0], results["embeddings"][0])]

    def delete(self):
        """Drop this connector's collection entirely."""
        self.client.delete_collection(name=self.table_name)

    def save(self):
        # save to persistence file (nothing needs to be done: Chroma persists
        # on write for both the persistent and HTTP clients)
        printd("Saving chroma")

    @staticmethod
    def list_loaded_data():
        """Return the names of user-loaded data sources (``memgpt_`` prefix
        stripped is NOT applied here; agent collections are excluded)."""
        client = create_chroma_client()
        collections = client.list_collections()
        collections = [c.name for c in collections if c.name.startswith("memgpt_") and not c.name.startswith("memgpt_agent_")]
        return collections

    def sanitize_table_name(self, name: str) -> str:
        """Normalize *name* into a safe, lowercase identifier (max 63 chars)."""
        # Remove leading and trailing whitespace
        name = name.strip()

        # Replace spaces and invalid characters with underscores
        name = re.sub(r"\s+|\W+", "_", name)

        # Truncate to the maximum identifier length (e.g., 63 for PostgreSQL)
        max_length = 63
        if len(name) > max_length:
            name = name[:max_length].rstrip("_")

        # Convert to lowercase
        name = name.lower()

        return name

    def generate_table_name_agent(self, agent_config: AgentConfig):
        """Collection name for an agent's archival memory."""
        return f"memgpt_agent_{self.sanitize_table_name(agent_config.name)}"

    def generate_table_name(self, name: str):
        """Collection name for a user-loaded data source."""
        return f"memgpt_{self.sanitize_table_name(name)}"

    def size(self) -> int:
        """Number of passages currently stored."""
        return self.collection.count()
6 changes: 4 additions & 2 deletions memgpt/connectors/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ def list_loaded_data():
inspector = inspect(engine)
tables = inspector.get_table_names()
tables = [table for table in tables if table.startswith("memgpt_") and not table.startswith("memgpt_agent_")]
tables = [table.replace("memgpt_", "") for table in tables]
start_chars = len("memgpt_")
tables = [table[start_chars:] for table in tables]
return tables

def sanitize_table_name(self, name: str) -> str:
Expand Down Expand Up @@ -300,7 +301,8 @@ def list_loaded_data():

tables = db.table_names()
tables = [table for table in tables if table.startswith("memgpt_")]
tables = [table.replace("memgpt_", "") for table in tables]
start_chars = len("memgpt_")
tables = [table[start_chars:] for table in tables]
return tables

def sanitize_table_name(self, name: str) -> str:
Expand Down
10 changes: 9 additions & 1 deletion memgpt/connectors/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from memgpt.config import AgentConfig, MemGPTConfig


from memgpt.config import AgentConfig, MemGPTConfig


class Passage:
"""A passage is a single unit of memory, and a standard format accross all storage backends.
Expand Down Expand Up @@ -47,12 +50,14 @@ def get_storage_connector(name: Optional[str] = None, agent_config: Optional[Age
from memgpt.connectors.db import PostgresStorageConnector

return PostgresStorageConnector(name=name, agent_config=agent_config)
elif storage_type == "chroma":
from memgpt.connectors.chroma import ChromaStorageConnector

return ChromaStorageConnector(name=name, agent_config=agent_config)
elif storage_type == "lancedb":
from memgpt.connectors.db import LanceDBConnector

return LanceDBConnector(name=name, agent_config=agent_config)

else:
raise NotImplementedError(f"Storage type {storage_type} not implemented")

Expand All @@ -67,7 +72,10 @@ def list_loaded_data():
from memgpt.connectors.db import PostgresStorageConnector

return PostgresStorageConnector.list_loaded_data()
elif storage_type == "chroma":
from memgpt.connectors.chroma import ChromaStorageConnector

return ChromaStorageConnector.list_loaded_data()
elif storage_type == "lancedb":
from memgpt.connectors.db import LanceDBConnector

Expand Down
3 changes: 0 additions & 3 deletions memgpt/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@
from llama_index.node_parser import SimpleNodeParser
from llama_index.node_parser import SimpleNodeParser

from memgpt.embeddings import embedding_model
from memgpt.config import MemGPTConfig


class CoreMemory(object):
"""Held in-context inside the system message
Expand Down
15 changes: 13 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ torch = {version = ">=2.0.0, !=2.0.1, !=2.1.0", optional = true}
websockets = "^12.0"
docstring-parser = "^0.15"
lancedb = {version = "^0.3.3", optional = true}
chroma = {version = "^0.2.0", optional = true}
httpx = "^0.25.2"
numpy = "^1.26.2"
demjson3 = "^3.0.6"
Expand All @@ -54,6 +55,7 @@ pyyaml = "^6.0.1"
local = ["torch", "huggingface-hub", "transformers"]
lancedb = ["lancedb"]
postgres = ["pgvector", "psycopg", "psycopg-binary", "psycopg2-binary", "pg8000"]
chroma = ["chroma"]
dev = ["pytest", "black", "pre-commit", "datasets"]

[build-system]
Expand Down
Loading

0 comments on commit 9c2e6b7

Please sign in to comment.