diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 15107b51ff..14a8a1adb7 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -42,7 +42,7 @@ jobs: PGVECTOR_TEST_DB_URL: ${{ secrets.PGVECTOR_TEST_DB_URL }} OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: | - poetry install -E dev -E postgres -E local + poetry install -E dev -E postgres -E local -E chroma -E lancedb - name: Set Poetry config env: diff --git a/docs/storage.md b/docs/storage.md index 5a4572c529..2a8df51304 100644 --- a/docs/storage.md +++ b/docs/storage.md @@ -18,23 +18,34 @@ pip install 'pymemgpt[postgres]' ### Running Postgres You will need to have a URI to a Postgres database which support [pgvector](https://github.com/pgvector/pgvector). You can either use a [hosted provider](https://github.com/pgvector/pgvector/issues/54) or [install pgvector](https://github.com/pgvector/pgvector#installation). +## Chroma +To enable the Chroma storage backend, install the dependencies with: +``` +pip install 'pymemgpt[chroma]' +``` +You can configure Chroma with both the HTTP and persistent storage client via `memgpt configure`. You will need to specify either a persistent storage path or host/port depending on your client choice. The example below shows how to configure Chroma with local persistent storage: +``` +? Select LLM inference provider: openai +? Override default endpoint: https://api.openai.com/v1 +? Select default model (recommended: gpt-4): gpt-4 +? Select embedding provider: openai +? Select default preset: memgpt_chat +? Select default persona: sam_pov +? Select default human: cs_phd +? Select storage backend for archival data: chroma +? Select chroma backend: persistent +? Enter persistent storage location: /Users/sarahwooders/.memgpt/config/chroma +``` ## LanceDB -In order to use the LanceDB backend. 
- - You have to enable the LanceDB backend by running - - ``` - memgpt configure - ``` - and selecting `lancedb` for archival storage, and database URI (e.g. `./.lancedb`"), Empty archival uri is also handled and default uri is set at `./.lancedb`. - To enable the LanceDB backend, make sure to install the required dependencies with: ``` pip install 'pymemgpt[lancedb]' ``` -for more checkout [lancedb docs](https://lancedb.github.io/lancedb/) +You have to enable the LanceDB backend by running + ``` + memgpt configure + ``` +and selecting `lancedb` for archival storage, and a database URI (e.g. `./.lancedb`). An empty archival URI is also handled, and the default URI is set to `./.lancedb`. For more, check out the [lancedb docs](https://lancedb.github.io/lancedb/) -## Chroma -(Coming soon) diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index dfd9ba5e4a..1d07e6e1ce 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -241,24 +241,40 @@ def configure_cli(config: MemGPTConfig): def configure_archival_storage(config: MemGPTConfig): # Configure archival storage backend - archival_storage_options = ["local", "lancedb", "postgres"] + archival_storage_options = ["local", "lancedb", "postgres", "chroma"] archival_storage_type = questionary.select( "Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type ).ask() - archival_storage_uri = None + archival_storage_uri, archival_storage_path = None, None + + # configure postgres if archival_storage_type == "postgres": archival_storage_uri = questionary.text( "Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):", default=config.archival_storage_uri if config.archival_storage_uri else "", ).ask() + # configure lancedb if archival_storage_type == "lancedb": archival_storage_uri = questionary.text( "Enter lanncedb connection string (e.g. 
./.lancedb", default=config.archival_storage_uri if config.archival_storage_uri else "./.lancedb", ).ask() - return archival_storage_type, archival_storage_uri + # configure chroma + if archival_storage_type == "chroma": + chroma_type = questionary.select("Select chroma backend:", ["http", "persistent"], default="http").ask() + if chroma_type == "http": + archival_storage_uri = questionary.text("Enter chroma ip (e.g. localhost:8000):", default="localhost:8000").ask() + if chroma_type == "persistent": + print(config.config_path, config.archival_storage_path) + default_archival_storage_path = ( + config.archival_storage_path if config.archival_storage_path else os.path.join(config.config_path, "chroma") + ) + print(default_archival_storage_path) + archival_storage_path = questionary.text("Enter persistent storage location:", default=default_archival_storage_path).ask() + + return archival_storage_type, archival_storage_uri, archival_storage_path # TODO: allow configuring embedding model @@ -275,7 +291,7 @@ def configure(): model, model_wrapper, context_window = configure_model(config, model_endpoint_type) embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config) default_preset, default_persona, default_human, default_agent = configure_cli(config) - archival_storage_type, archival_storage_uri = configure_archival_storage(config) + archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage(config) # check credentials azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment = get_azure_credentials() @@ -322,6 +338,7 @@ def configure(): # storage archival_storage_type=archival_storage_type, archival_storage_uri=archival_storage_uri, + archival_storage_path=archival_storage_path, ) print(f"Saving config to {config.config_path}") config.save() diff --git a/memgpt/cli/cli_load.py b/memgpt/cli/cli_load.py index 2bdbc7e458..780a28a0fe 100644 --- 
a/memgpt/cli/cli_load.py +++ b/memgpt/cli/cli_load.py @@ -102,9 +102,7 @@ def load_directory( reader = SimpleDirectoryReader(input_files=input_files) # load docs - print("loading data") docs = reader.load_data() - print("done loading data") store_docs(name, docs) diff --git a/memgpt/connectors/chroma.py b/memgpt/connectors/chroma.py new file mode 100644 index 0000000000..8db7aa2ee2 --- /dev/null +++ b/memgpt/connectors/chroma.py @@ -0,0 +1,125 @@ +import chromadb +import json +import re +from typing import Optional, List, Iterator +from memgpt.connectors.storage import StorageConnector, Passage +from memgpt.utils import printd +from memgpt.config import AgentConfig, MemGPTConfig + + +def create_chroma_client(): + config = MemGPTConfig.load() + # create chroma client + if config.archival_storage_path: + client = chromadb.PersistentClient(config.archival_storage_path) + else: + # assume uri={ip}:{port} + ip = config.archival_storage_uri.split(":")[0] + port = config.archival_storage_uri.split(":")[1] + client = chromadb.HttpClient(host=ip, port=port) + return client + + +class ChromaStorageConnector(StorageConnector): + """Storage via Chroma""" + + # WARNING: This is not thread safe. Do NOT do concurrent access to the same collection. 
+ + def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None): + # determine table name + if agent_config: + assert name is None, f"Cannot specify both agent config and name {name}" + self.table_name = self.generate_table_name_agent(agent_config) + elif name: + assert agent_config is None, f"Cannot specify both agent config and name {name}" + self.table_name = self.generate_table_name(name) + else: + raise ValueError("Must specify either agent config or name") + + printd(f"Using table name {self.table_name}") + + # create client + self.client = create_chroma_client() + + # get a collection or create if it doesn't exist already + self.collection = self.client.get_or_create_collection(self.table_name) + + def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]: + offset = 0 + while True: + # Retrieve a chunk of records with the given page_size + db_passages_chunk = self.collection.get(offset=offset, limit=page_size, include=["embeddings", "documents"]) + + # If the chunk is empty, we've retrieved all records + if not db_passages_chunk: + break + + # Yield a list of Passage objects converted from the chunk + yield [Passage(text=p.text, embedding=p.embedding, doc_id=p.doc_id, passage_id=p.id) for p in db_passages_chunk] + + # Increment the offset to get the next chunk in the next iteration + offset += page_size + + def get_all(self) -> List[Passage]: + results = self.collection.get(include=["embeddings", "documents"]) + return [Passage(text=text, embedding=embedding) for (text, embedding) in zip(results["documents"], results["embeddings"])] + + def get(self, id: str) -> Optional[Passage]: + results = self.collection.get(ids=[id]) + return [Passage(text=text, embedding=embedding) for (text, embedding) in zip(results["documents"], results["embeddings"])] + + def insert(self, passage: Passage): + self.collection.add(documents=[passage.text], embeddings=[passage.embedding], ids=[str(self.collection.count())]) + + def 
insert_many(self, passages: List[Passage], show_progress=True): + count = self.collection.count() + ids = [str(count + i) for i in range(len(passages))] + self.collection.add( + documents=[passage.text for passage in passages], embeddings=[passage.embedding for passage in passages], ids=ids + ) + + def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Passage]: + results = self.collection.query(query_embeddings=[query_vec], n_results=top_k, include=["embeddings", "documents"]) + # get index [0] since query is passed as list + return [Passage(text=text, embedding=embedding) for (text, embedding) in zip(results["documents"][0], results["embeddings"][0])] + + def delete(self): + self.client.delete_collection(name=self.table_name) + + def save(self): + # save to persistence file (nothing needs to be done) + printd("Saving chroma") + pass + + @staticmethod + def list_loaded_data(): + client = create_chroma_client() + collections = client.list_collections() + collections = [c.name for c in collections if c.name.startswith("memgpt_") and not c.name.startswith("memgpt_agent_")] + return collections + + def sanitize_table_name(self, name: str) -> str: + # Remove leading and trailing whitespace + name = name.strip() + + # Replace spaces and invalid characters with underscores + name = re.sub(r"\s+|\W+", "_", name) + + # Truncate to the maximum identifier length (e.g., 63 for PostgreSQL) + max_length = 63 + if len(name) > max_length: + name = name[:max_length].rstrip("_") + + # Convert to lowercase + name = name.lower() + + return name + + def generate_table_name_agent(self, agent_config: AgentConfig): + return f"memgpt_agent_{self.sanitize_table_name(agent_config.name)}" + + def generate_table_name(self, name: str): + return f"memgpt_{self.sanitize_table_name(name)}" + + def size(self) -> int: + return self.collection.count() diff --git a/memgpt/connectors/db.py b/memgpt/connectors/db.py index 65f9aacb5b..ac09e4ddcd 100644 --- a/memgpt/connectors/db.py 
+++ b/memgpt/connectors/db.py @@ -157,7 +157,8 @@ def list_loaded_data(): inspector = inspect(engine) tables = inspector.get_table_names() tables = [table for table in tables if table.startswith("memgpt_") and not table.startswith("memgpt_agent_")] - tables = [table.replace("memgpt_", "") for table in tables] + start_chars = len("memgpt_") + tables = [table[start_chars:] for table in tables] return tables def sanitize_table_name(self, name: str) -> str: @@ -300,7 +301,8 @@ def list_loaded_data(): tables = db.table_names() tables = [table for table in tables if table.startswith("memgpt_")] - tables = [table.replace("memgpt_", "") for table in tables] + start_chars = len("memgpt_") + tables = [table[start_chars:] for table in tables] return tables def sanitize_table_name(self, name: str) -> str: diff --git a/memgpt/connectors/storage.py b/memgpt/connectors/storage.py index fcd415fbd7..dc9089fce8 100644 --- a/memgpt/connectors/storage.py +++ b/memgpt/connectors/storage.py @@ -17,6 +17,9 @@ from memgpt.config import AgentConfig, MemGPTConfig +from memgpt.config import AgentConfig, MemGPTConfig + + class Passage: """A passage is a single unit of memory, and a standard format accross all storage backends. 
@@ -47,12 +50,14 @@ def get_storage_connector(name: Optional[str] = None, agent_config: Optional[Age from memgpt.connectors.db import PostgresStorageConnector return PostgresStorageConnector(name=name, agent_config=agent_config) + elif storage_type == "chroma": + from memgpt.connectors.chroma import ChromaStorageConnector + return ChromaStorageConnector(name=name, agent_config=agent_config) elif storage_type == "lancedb": from memgpt.connectors.db import LanceDBConnector return LanceDBConnector(name=name, agent_config=agent_config) - else: raise NotImplementedError(f"Storage type {storage_type} not implemented") @@ -67,7 +72,10 @@ def list_loaded_data(): from memgpt.connectors.db import PostgresStorageConnector return PostgresStorageConnector.list_loaded_data() + elif storage_type == "chroma": + from memgpt.connectors.chroma import ChromaStorageConnector + return ChromaStorageConnector.list_loaded_data() elif storage_type == "lancedb": from memgpt.connectors.db import LanceDBConnector diff --git a/memgpt/memory.py b/memgpt/memory.py index ff68150853..ada5aa961f 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -11,9 +11,6 @@ from llama_index.node_parser import SimpleNodeParser from llama_index.node_parser import SimpleNodeParser -from memgpt.embeddings import embedding_model -from memgpt.config import MemGPTConfig - class CoreMemory(object): """Held in-context inside the system message diff --git a/poetry.lock b/poetry.lock index cd64f84871..bea8603ce9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. 
[[package]] name = "aiohttp" @@ -388,6 +388,16 @@ files = [ {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"}, ] +[[package]] +name = "chroma" +version = "0.2.0" +description = "Color handling made simple." +optional = true +python-versions = "*" +files = [ + {file = "Chroma-0.2.0.tar.gz", hash = "sha256:e265bcd503e2b35c4448b83257467166c252ecf3ab610492432780691cdfb286"}, +] + [[package]] name = "click" version = "8.1.7" @@ -3663,6 +3673,7 @@ idna = ">=2.0" multidict = ">=4.0" [extras] +chroma = ["chroma"] dev = ["black", "datasets", "pre-commit", "pytest"] lancedb = ["lancedb"] local = ["huggingface-hub", "torch", "transformers"] @@ -3671,4 +3682,4 @@ postgres = ["pg8000", "pgvector", "psycopg", "psycopg-binary", "psycopg2-binary" [metadata] lock-version = "2.0" python-versions = "<3.12,>=3.9" -content-hash = "581c188b69fb076a077c85cfedd57c91742d2d63bf04af2d5235fbb87c604ecd" +content-hash = "39dc4a9aee8fb019c1d3eeab1e3091ae6f39d3ae670f34bcd63e50c14abc351e" diff --git a/pyproject.toml b/pyproject.toml index 30a8755bfc..cbee3e49b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ torch = {version = ">=2.0.0, !=2.0.1, !=2.1.0", optional = true} websockets = "^12.0" docstring-parser = "^0.15" lancedb = {version = "^0.3.3", optional = true} +chroma = {version = "^0.2.0", optional = true} httpx = "^0.25.2" numpy = "^1.26.2" demjson3 = "^3.0.6" @@ -54,6 +55,7 @@ pyyaml = "^6.0.1" local = ["torch", "huggingface-hub", "transformers"] lancedb = ["lancedb"] postgres = ["pgvector", "psycopg", "psycopg-binary", "psycopg2-binary", "pg8000"] +chroma = ["chroma"] dev = ["pytest", "black", "pre-commit", "datasets"] [build-system] diff --git a/tests/test_load_archival.py b/tests/test_load_archival.py index a303279f0c..f8265ec918 100644 --- a/tests/test_load_archival.py +++ b/tests/test_load_archival.py @@ -1,6 +1,6 @@ # import tempfile # import asyncio -# import os +import os # 
import asyncio # from datasets import load_dataset @@ -111,6 +111,62 @@ def test_chroma(): recursive=True, ) + +def test_postgres(): + # override config path with environment variable + # TODO: make into temporary file + os.environ["MEMGPT_CONFIG_PATH"] = "/Users/sarahwooders/repos/MemGPT/test_config.cfg" + print("env", os.getenv("MEMGPT_CONFIG_PATH")) + config = memgpt.config.MemGPTConfig(archival_storage_type="postgres", config_path=os.getenv("MEMGPT_CONFIG_PATH")) + print(config) + config.save() + # exit() + + name = "tmp_hf_dataset2" + + dataset = load_dataset("MemGPT/example_short_stories") + + cache_dir = os.getenv("HF_DATASETS_CACHE") + if cache_dir is None: + # Construct the default path if the environment variable is not set. + cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets") + + load_directory( + name=name, + input_dir=cache_dir, + recursive=True, + ) + + +def test_chroma(): + import chromadb + + # override config path with environment variable + # TODO: make into temporary file + os.environ["MEMGPT_CONFIG_PATH"] = "/Users/sarahwooders/repos/MemGPT/test_config.cfg" + print("env", os.getenv("MEMGPT_CONFIG_PATH")) + config = memgpt.config.MemGPTConfig(archival_storage_type="chroma", config_path=os.getenv("MEMGPT_CONFIG_PATH")) + print(config) + config.save() + # exit() + + name = "tmp_hf_dataset" + + dataset = load_dataset("MemGPT/example_short_stories") + + cache_dir = os.getenv("HF_DATASETS_CACHE") + if cache_dir is None: + # Construct the default path if the environment variable is not set. 
+ cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets") + + config = memgpt.config.MemGPTConfig(archival_storage_type="chroma") + + load_directory( + name=name, + input_dir=cache_dir, + recursive=True, + ) + # index = memgpt.embeddings.Index(name) ## query chroma diff --git a/tests/test_storage.py b/tests/test_storage.py index 3ace665027..fc941fa13b 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -11,6 +11,7 @@ import pgvector # Try to import again after installing from memgpt.connectors.storage import StorageConnector, Passage +from memgpt.connectors.chroma import ChromaStorageConnector from memgpt.connectors.db import PostgresStorageConnector, LanceDBConnector from memgpt.embeddings import embedding_model from memgpt.config import MemGPTConfig, AgentConfig @@ -59,6 +60,43 @@ def test_postgres_openai(): # print("...finished") +@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Missing OpenAI API key") +def test_chroma_openai(): + if not os.getenv("OPENAI_API_KEY"): + return # soft pass + + config = MemGPTConfig( + archival_storage_type="chroma", + archival_storage_path="./test_chroma", + embedding_endpoint_type="openai", + embedding_dim=1536, + model="gpt4", + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + ) + config.save() + embed_model = embedding_model() + + passage = ["This is a test passage", "This is another test passage", "Cinderella wept"] + + db = ChromaStorageConnector(name="test-openai") + + for passage in passage: + db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage))) + + query = "why was she crying" + query_vec = embed_model.get_text_embedding(query) + res = db.query(query, query_vec, top_k=2) + + assert len(res) == 2, f"Expected 2 results, got {len(res)}" + assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}" + + print(res[0].text) + + print("deleting") + db.delete() + + @pytest.mark.skipif( not 
os.getenv("LANCEDB_TEST_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing LANCEDB URI and/or OpenAI API key" )