Chroma storage integration #285

Merged
merged 70 commits on Dec 6, 2023
Changes from 68 commits
70 commits
89cf976
mark depricated API section
sarahwooders Oct 30, 2023
be6212c
add readme
sarahwooders Oct 31, 2023
b011380
add readme
sarahwooders Oct 31, 2023
59f7b71
add readme
sarahwooders Oct 31, 2023
176538b
add readme
sarahwooders Oct 31, 2023
9905266
add readme
sarahwooders Oct 31, 2023
3606959
add readme
sarahwooders Oct 31, 2023
c48803c
add readme
sarahwooders Oct 31, 2023
40cdb23
add readme
sarahwooders Oct 31, 2023
ff43c98
add readme
sarahwooders Oct 31, 2023
01db319
CLI bug fixes for azure
sarahwooders Oct 31, 2023
a11cef9
check azure before running
sarahwooders Oct 31, 2023
a47d49e
Merge branch 'cpacker:main' into main
sarahwooders Oct 31, 2023
fbe2482
Update README.md
sarahwooders Oct 31, 2023
446a1a1
Update README.md
sarahwooders Oct 31, 2023
1541482
bug fix with persona loading
sarahwooders Oct 31, 2023
5776e30
Merge branch 'main' of github.com:sarahwooders/MemGPT
sarahwooders Oct 31, 2023
d48cf23
Merge branch 'cpacker:main' into main
sarahwooders Oct 31, 2023
7a8eb80
remove print
sarahwooders Oct 31, 2023
9a5ece0
Merge branch 'main' of github.com:sarahwooders/MemGPT
sarahwooders Oct 31, 2023
d3370b3
merge
sarahwooders Nov 3, 2023
c19c2ce
Merge branch 'cpacker:main' into main
sarahwooders Nov 3, 2023
aa6ee71
Merge branch 'cpacker:main' into main
sarahwooders Nov 3, 2023
36bb04d
make errors for cli flags more clear
sarahwooders Nov 3, 2023
6f50db1
format
sarahwooders Nov 3, 2023
4c91a41
Merge branch 'cpacker:main' into main
sarahwooders Nov 3, 2023
31282f5
add initial postgres implementation
sarahwooders Oct 31, 2023
d9e137c
working chroma loading
sarahwooders Nov 1, 2023
25b45f1
add postgres tests
sarahwooders Nov 1, 2023
54bd66d
working initial load into postgres and chroma
sarahwooders Nov 1, 2023
3632c3f
add load index command
sarahwooders Nov 1, 2023
fe25682
semi working load index
sarahwooders Nov 1, 2023
732e732
disgusting import code thanks to llama index's nasty APIs
sarahwooders Nov 1, 2023
ac6638c
add postgres connector
sarahwooders Nov 2, 2023
8787936
working postgres integration
sarahwooders Nov 2, 2023
6e6b3e1
working local storage (changed saving)
sarahwooders Nov 3, 2023
b409eac
implement /attach
sarahwooders Nov 3, 2023
b7842ec
remove old code
sarahwooders Nov 3, 2023
c5ab594
split up storage conenctors into multiple files
sarahwooders Nov 3, 2023
b6eeb2f
remove unused code
sarahwooders Nov 3, 2023
70e5a5f
cleanup
sarahwooders Nov 3, 2023
da2a208
implement vector db loading
sarahwooders Nov 3, 2023
88b5e18
cleanup state savign
sarahwooders Nov 3, 2023
82c0cbf
add chroma
sarahwooders Nov 3, 2023
29446cd
merge
sarahwooders Nov 21, 2023
25053f7
minor fix
sarahwooders Nov 21, 2023
5f9c7ef
fix up chroma integration
sarahwooders Nov 21, 2023
1edb6b6
fix list error
sarahwooders Nov 21, 2023
53792e3
update dependencies
sarahwooders Nov 21, 2023
9ff32a8
update docs
sarahwooders Nov 21, 2023
3368d5c
format
sarahwooders Nov 21, 2023
4453fd4
cleanup
sarahwooders Nov 22, 2023
96004d9
Merge branch 'cpacker:main' into main
sarahwooders Nov 23, 2023
cbeefd1
Merge branch 'cpacker:main' into main
sarahwooders Nov 27, 2023
5ebc42b
Merge branch 'cpacker:main' into main
sarahwooders Nov 29, 2023
c134f4a
Merge branch 'cpacker:main' into main
sarahwooders Nov 30, 2023
2ca083d
Merge branch 'cpacker:main' into main
sarahwooders Nov 30, 2023
c7354b3
Merge branch 'cpacker:main' into main
sarahwooders Nov 30, 2023
771624d
Merge branch 'cpacker:main' into main
sarahwooders Nov 30, 2023
e3f6ae1
Merge branch 'cpacker:main' into main
sarahwooders Dec 1, 2023
937cf05
Merge branch 'cpacker:main' into main
sarahwooders Dec 1, 2023
36d41bf
forgot to add embedding file
sarahwooders Dec 2, 2023
c982b55
upgrade llama index
sarahwooders Dec 2, 2023
8c3aab7
Merge branch 'cpacker:main' into main
sarahwooders Dec 2, 2023
362697b
fix data source naming bug
sarahwooders Dec 2, 2023
9e97c03
Merge branch 'cpacker:main' into main
sarahwooders Dec 2, 2023
2642930
Merge branch 'cpacker:main' into main
sarahwooders Dec 6, 2023
1f7cbe8
merge
sarahwooders Dec 6, 2023
d4dbc33
remove legacy
sarahwooders Dec 6, 2023
3982e3b
os import
sarahwooders Dec 6, 2023
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -42,7 +42,7 @@ jobs:
PGVECTOR_TEST_DB_URL: ${{ secrets.PGVECTOR_TEST_DB_URL }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
poetry install -E dev -E postgres -E local
poetry install -E dev -E postgres -E local -E legacy -E chroma -E lancedb

- name: Set Poetry config
env:
35 changes: 23 additions & 12 deletions docs/storage.md
@@ -18,23 +18,34 @@ pip install 'pymemgpt[postgres]'
### Running Postgres
You will need a URI to a Postgres database that supports [pgvector](https://github.com/pgvector/pgvector). You can either use a [hosted provider](https://github.com/pgvector/pgvector/issues/54) or [install pgvector](https://github.com/pgvector/pgvector#installation).
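
The configurator expects a connection string of the form `postgresql+pg8000://{user}:{password}@{ip}:5432/{database}`. As a quick sanity check before pointing MemGPT at the database, you can confirm that pgvector can be enabled there. This is a minimal sketch; the credentials below are placeholders for your own:
```
# Minimal sketch -- replace the placeholder credentials with your own.
import pg8000.dbapi

conn = pg8000.dbapi.connect(
    user="memgpt", password="password", host="localhost", port=5432, database="memgpt"
)
cur = conn.cursor()
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")  # fails if pgvector is not installed server-side
conn.commit()
conn.close()
```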

## Chroma
To enable the Chroma storage backend, install the dependencies with:
```
pip install 'pymemgpt[chroma]'
```
You can configure Chroma with either the HTTP or the persistent storage client via `memgpt configure`. You will need to specify either a persistent storage path or a host/port, depending on your client choice. The example below shows how to configure Chroma with local persistent storage:
```
? Select LLM inference provider: openai
? Override default endpoint: https://api.openai.com/v1
? Select default model (recommended: gpt-4): gpt-4
? Select embedding provider: openai
? Select default preset: memgpt_chat
? Select default persona: sam_pov
? Select default human: cs_phd
? Select storage backend for archival data: chroma
? Select chroma backend: persistent
? Enter persistent storage location: /Users/sarahwooders/.memgpt/config/chroma
```
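
Under the hood, those answers map onto a Chroma client roughly as in `create_chroma_client` from `memgpt/connectors/chroma.py`. This is a sketch; the path and port below are illustrative:
```
import chromadb

# Persistent backend: pass the storage location chosen above.
client = chromadb.PersistentClient("/Users/sarahwooders/.memgpt/config/chroma")

# HTTP backend: pass the host and port entered at the prompt instead.
# client = chromadb.HttpClient(host="localhost", port=8000)
```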

## LanceDB
In order to use the LanceDB backend.

You have to enable the LanceDB backend by running

```
memgpt configure
```
and selecting `lancedb` for archival storage, and database URI (e.g. `./.lancedb`"), Empty archival uri is also handled and default uri is set at `./.lancedb`.

To enable the LanceDB backend, make sure to install the required dependencies with:
```
pip install 'pymemgpt[lancedb]'
```
for more checkout [lancedb docs](https://lancedb.github.io/lancedb/)
You have to enable the LanceDB backend by running
```
memgpt configure
```
and selecting `lancedb` for archival storage, along with a database URI (e.g. `./.lancedb`). An empty archival URI is also handled; the default URI is `./.lancedb`. For more, check out the [LanceDB docs](https://lancedb.github.io/lancedb/).
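
As a quick check (the call below is a sketch, not part of `memgpt` itself), the URI you enter should be openable directly with LanceDB:
```
import lancedb

db = lancedb.connect("./.lancedb")  # same URI as entered during `memgpt configure`
print(db.table_names())             # existing tables, e.g. the memgpt-prefixed ones
```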


## Chroma
(Coming soon)
25 changes: 21 additions & 4 deletions memgpt/cli/cli_config.py
@@ -241,24 +241,40 @@ def configure_cli(config: MemGPTConfig):

def configure_archival_storage(config: MemGPTConfig):
# Configure archival storage backend
archival_storage_options = ["local", "lancedb", "postgres"]
archival_storage_options = ["local", "lancedb", "postgres", "chroma"]
archival_storage_type = questionary.select(
"Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type
).ask()
archival_storage_uri = None
archival_storage_uri, archival_storage_path = None, None

# configure postgres
if archival_storage_type == "postgres":
archival_storage_uri = questionary.text(
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
default=config.archival_storage_uri if config.archival_storage_uri else "",
).ask()

# configure lancedb
if archival_storage_type == "lancedb":
archival_storage_uri = questionary.text(
"Enter lanncedb connection string (e.g. ./.lancedb",
default=config.archival_storage_uri if config.archival_storage_uri else "./.lancedb",
).ask()

return archival_storage_type, archival_storage_uri
# configure chroma
if archival_storage_type == "chroma":
chroma_type = questionary.select("Select chroma backend:", ["http", "persistent"], default="http").ask()
if chroma_type == "http":
archival_storage_uri = questionary.text("Enter chroma ip (e.g. localhost:8000):", default="localhost:8000").ask()
if chroma_type == "persistent":
print(config.config_path, config.archival_storage_path)
default_archival_storage_path = (
config.archival_storage_path if config.archival_storage_path else os.path.join(config.config_path, "chroma")
)
print(default_archival_storage_path)
archival_storage_path = questionary.text("Enter persistent storage location:", default=default_archival_storage_path).ask()

return archival_storage_type, archival_storage_uri, archival_storage_path

# TODO: allow configuring embedding model

@@ -275,7 +291,7 @@ def configure():
model, model_wrapper, context_window = configure_model(config, model_endpoint_type)
embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config)
default_preset, default_persona, default_human, default_agent = configure_cli(config)
archival_storage_type, archival_storage_uri = configure_archival_storage(config)
archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage(config)

# check credentials
azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment = get_azure_credentials()
@@ -322,6 +338,7 @@ def configure():
# storage
archival_storage_type=archival_storage_type,
archival_storage_uri=archival_storage_uri,
archival_storage_path=archival_storage_path,
)
print(f"Saving config to {config.config_path}")
config.save()
2 changes: 0 additions & 2 deletions memgpt/cli/cli_load.py
@@ -102,9 +102,7 @@ def load_directory(
reader = SimpleDirectoryReader(input_files=input_files)

# load docs
print("loading data")
docs = reader.load_data()
print("done loading data")
store_docs(name, docs)


125 changes: 125 additions & 0 deletions memgpt/connectors/chroma.py
@@ -0,0 +1,125 @@
import chromadb
import json
import re
from typing import Optional, List, Iterator
from memgpt.connectors.storage import StorageConnector, Passage
from memgpt.utils import printd
from memgpt.config import AgentConfig, MemGPTConfig


def create_chroma_client():
config = MemGPTConfig.load()
# create chroma client
if config.archival_storage_path:
client = chromadb.PersistentClient(config.archival_storage_path)
else:
# assume uri={ip}:{port}
ip = config.archival_storage_uri.split(":")[0]
port = config.archival_storage_uri.split(":")[1]
client = chromadb.HttpClient(host=ip, port=port)
return client


class ChromaStorageConnector(StorageConnector):
"""Storage via Chroma"""

# WARNING: This is not thread safe. Do NOT do concurrent access to the same collection.

def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
# determine table name
if agent_config:
assert name is None, f"Cannot specify both agent config and name {name}"
self.table_name = self.generate_table_name_agent(agent_config)
elif name:
assert agent_config is None, f"Cannot specify both agent config and name {name}"
self.table_name = self.generate_table_name(name)
else:
raise ValueError("Must specify either agent config or name")

printd(f"Using table name {self.table_name}")

# create client
self.client = create_chroma_client()

# get a collection or create if it doesn't exist already
self.collection = self.client.get_or_create_collection(self.table_name)

def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]:
offset = 0
while True:
# Retrieve a chunk of records with the given page_size
db_passages_chunk = self.collection.get(offset=offset, limit=page_size, include=["embeddings", "documents"])

# If the chunk is empty, we've retrieved all records
if not db_passages_chunk:
break

# Yield a list of Passage objects converted from the chunk
yield [Passage(text=p.text, embedding=p.embedding, doc_id=p.doc_id, passage_id=p.id) for p in db_passages_chunk]

# Increment the offset to get the next chunk in the next iteration
offset += page_size

def get_all(self) -> List[Passage]:
results = self.collection.get(include=["embeddings", "documents"])
return [Passage(text=text, embedding=embedding) for (text, embedding) in zip(results["documents"], results["embeddings"])]

def get(self, id: str) -> Optional[Passage]:
results = self.collection.get(ids=[id])
return [Passage(text=text, embedding=embedding) for (text, embedding) in zip(results["documents"], results["embeddings"])]

def insert(self, passage: Passage):
self.collection.add(documents=[passage.text], embeddings=[passage.embedding], ids=[str(self.collection.count())])

def insert_many(self, passages: List[Passage], show_progress=True):
count = self.collection.count()
ids = [str(count + i) for i in range(len(passages))]
self.collection.add(
documents=[passage.text for passage in passages], embeddings=[passage.embedding for passage in passages], ids=ids
)

def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Passage]:
results = self.collection.query(query_embeddings=[query_vec], n_results=top_k, include=["embeddings", "documents"])
# get index [0] since query is passed as list
return [Passage(text=text, embedding=embedding) for (text, embedding) in zip(results["documents"][0], results["embeddings"][0])]

def delete(self):
self.client.delete_collection(name=self.table_name)

def save(self):
# save to persistence file (nothing needs to be done)
printd("Saving chroma")
pass

@staticmethod
def list_loaded_data():
client = create_chroma_client()
collections = client.list_collections()
collections = [c.name for c in collections if c.name.startswith("memgpt_") and not c.name.startswith("memgpt_agent_")]
return collections

def sanitize_table_name(self, name: str) -> str:
# Remove leading and trailing whitespace
name = name.strip()

# Replace spaces and invalid characters with underscores
name = re.sub(r"\s+|\W+", "_", name)

# Truncate to the maximum identifier length (e.g., 63 for PostgreSQL)
max_length = 63
if len(name) > max_length:
name = name[:max_length].rstrip("_")

# Convert to lowercase
name = name.lower()

return name

def generate_table_name_agent(self, agent_config: AgentConfig):
return f"memgpt_agent_{self.sanitize_table_name(agent_config.name)}"

def generate_table_name(self, name: str):
return f"memgpt_{self.sanitize_table_name(name)}"

def size(self) -> int:
return self.collection.count()
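
For context, a hypothetical usage of the new connector. The data source name and embedding values are made up, and the embedding length must match the configured embedding model:
```
from memgpt.connectors.chroma import ChromaStorageConnector
from memgpt.connectors.storage import Passage

conn = ChromaStorageConnector(name="wiki_sample")  # backs onto a "memgpt_wiki_sample" collection
conn.insert(Passage(text="Chroma is a vector database.", embedding=[0.1] * 1536))
results = conn.query(query="what is chroma?", query_vec=[0.1] * 1536, top_k=1)
print(conn.size(), results[0].text)
```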
6 changes: 4 additions & 2 deletions memgpt/connectors/db.py
@@ -157,7 +157,8 @@ def list_loaded_data():
inspector = inspect(engine)
tables = inspector.get_table_names()
tables = [table for table in tables if table.startswith("memgpt_") and not table.startswith("memgpt_agent_")]
tables = [table.replace("memgpt_", "") for table in tables]
start_chars = len("memgpt_")
tables = [table[start_chars:] for table in tables]
return tables

def sanitize_table_name(self, name: str) -> str:
@@ -300,7 +301,8 @@ def list_loaded_data():

tables = db.table_names()
tables = [table for table in tables if table.startswith("memgpt_")]
tables = [table.replace("memgpt_", "") for table in tables]
start_chars = len("memgpt_")
tables = [table[start_chars:] for table in tables]
return tables

def sanitize_table_name(self, name: str) -> str:
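
The switch from `str.replace` to prefix slicing matters because `replace` strips every occurrence of `memgpt_`, not just the leading one. A quick illustration with a contrived table name:
```
table = "memgpt_memgpt_notes"        # a data source literally named "memgpt_notes"
print(table.replace("memgpt_", ""))  # "notes"         (old behavior: inner prefix lost too)
print(table[len("memgpt_"):])        # "memgpt_notes"  (new behavior: only the leading prefix removed)
```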
10 changes: 9 additions & 1 deletion memgpt/connectors/storage.py
@@ -17,6 +17,9 @@
from memgpt.config import AgentConfig, MemGPTConfig


from memgpt.config import AgentConfig, MemGPTConfig


class Passage:
"""A passage is a single unit of memory, and a standard format accross all storage backends.

@@ -47,12 +50,14 @@ def get_storage_connector(name: Optional[str] = None, agent_config: Optional[Age
from memgpt.connectors.db import PostgresStorageConnector

return PostgresStorageConnector(name=name, agent_config=agent_config)
elif storage_type == "chroma":
from memgpt.connectors.chroma import ChromaStorageConnector

return ChromaStorageConnector(name=name, agent_config=agent_config)
elif storage_type == "lancedb":
from memgpt.connectors.db import LanceDBConnector

return LanceDBConnector(name=name, agent_config=agent_config)

else:
raise NotImplementedError(f"Storage type {storage_type} not implemented")

@@ -67,7 +72,10 @@ def list_loaded_data():
from memgpt.connectors.db import PostgresStorageConnector

return PostgresStorageConnector.list_loaded_data()
elif storage_type == "chroma":
from memgpt.connectors.chroma import ChromaStorageConnector

return ChromaStorageConnector.list_loaded_data()
elif storage_type == "lancedb":
from memgpt.connectors.db import LanceDBConnector

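Dispatch is driven by the configured `archival_storage_type`. A hedged usage sketch, assuming these helpers remain staticmethods on `StorageConnector` and that a data source named `wiki_sample` has already been loaded:
```
from memgpt.connectors.storage import StorageConnector

print(StorageConnector.list_loaded_data())  # names of loaded data sources
conn = StorageConnector.get_storage_connector(name="wiki_sample")
print(conn.size())                          # number of stored passages
```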
3 changes: 0 additions & 3 deletions memgpt/memory.py
@@ -11,9 +11,6 @@
from llama_index.node_parser import SimpleNodeParser
from llama_index.node_parser import SimpleNodeParser

from memgpt.embeddings import embedding_model
from memgpt.config import MemGPTConfig


class CoreMemory(object):
"""Held in-context inside the system message
15 changes: 13 additions & 2 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -42,6 +42,7 @@ torch = {version = ">=2.0.0, !=2.0.1, !=2.1.0", optional = true}
websockets = "^12.0"
docstring-parser = "^0.15"
lancedb = {version = "^0.3.3", optional = true}
chroma = {version = "^0.2.0", optional = true}
httpx = "^0.25.2"
numpy = "^1.26.2"
demjson3 = "^3.0.6"
@@ -54,6 +55,7 @@ pyyaml = "^6.0.1"
local = ["torch", "huggingface-hub", "transformers"]
lancedb = ["lancedb"]
postgres = ["pgvector", "psycopg", "psycopg-binary", "psycopg2-binary", "pg8000"]
chroma = ["chroma"]
dev = ["pytest", "black", "pre-commit", "datasets"]

[build-system]