Skip to content

Commit

Permalink
fix: chromadb max batch size (#1087)
Browse files Browse the repository at this point in the history
  • Loading branch information
imartinez authored Oct 20, 2023
1 parent b46c108 commit f5a9bf4
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 68 deletions.
82 changes: 20 additions & 62 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

87 changes: 87 additions & 0 deletions private_gpt/components/vector_store/batched_chroma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import Any

from llama_index.schema import BaseNode, MetadataMode
from llama_index.vector_stores import ChromaVectorStore
from llama_index.vector_stores.chroma import chunk_list
from llama_index.vector_stores.utils import node_to_metadata_dict


class BatchedChromaVectorStore(ChromaVectorStore):
"""Chroma vector store, batching additions to avoid reaching the max batch limit.
In this vector store, embeddings are stored within a ChromaDB collection.
During query time, the index uses ChromaDB to query for the top
k most similar nodes.
Args:
chroma_client (from chromadb.api.API):
API instance
chroma_collection (chromadb.api.models.Collection.Collection):
ChromaDB collection instance
"""

chroma_client: Any | None

def __init__(
self,
chroma_client: Any,
chroma_collection: Any,
host: str | None = None,
port: str | None = None,
ssl: bool = False,
headers: dict[str, str] | None = None,
collection_kwargs: dict[Any, Any] | None = None,
) -> None:
super().__init__(
chroma_collection=chroma_collection,
host=host,
port=port,
ssl=ssl,
headers=headers,
collection_kwargs=collection_kwargs or {},
)
self.chroma_client = chroma_client

def add(self, nodes: list[BaseNode]) -> list[str]:
"""Add nodes to index, batching the insertion to avoid issues.
Args:
nodes: List[BaseNode]: list of nodes with embeddings
"""
if not self.chroma_client:
raise ValueError("Client not initialized")

if not self._collection:
raise ValueError("Collection not initialized")

max_chunk_size = self.chroma_client.max_batch_size
node_chunks = chunk_list(nodes, max_chunk_size)

all_ids = []
for node_chunk in node_chunks:
embeddings = []
metadatas = []
ids = []
documents = []
for node in node_chunk:
embeddings.append(node.get_embedding())
metadatas.append(
node_to_metadata_dict(
node, remove_text=True, flat_metadata=self.flat_metadata
)
)
ids.append(node.node_id)
documents.append(node.get_content(metadata_mode=MetadataMode.NONE))

self._collection.add(
embeddings=embeddings,
ids=ids,
metadatas=metadatas,
documents=documents,
)
all_ids.extend(ids)

return all_ids
10 changes: 6 additions & 4 deletions private_gpt/components/vector_store/vector_store_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from injector import inject, singleton
from llama_index import VectorStoreIndex
from llama_index.indices.vector_store import VectorIndexRetriever
from llama_index.vector_stores import ChromaVectorStore
from llama_index.vector_stores.types import VectorStore

from private_gpt.components.vector_store.batched_chroma import BatchedChromaVectorStore
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.paths import local_data_path

Expand Down Expand Up @@ -36,14 +36,16 @@ class VectorStoreComponent:

@inject
def __init__(self) -> None:
db = chromadb.PersistentClient(
chroma_client = chromadb.PersistentClient(
path=str((local_data_path / "chroma_db").absolute())
)
chroma_collection = db.get_or_create_collection(
chroma_collection = chroma_client.get_or_create_collection(
"make_this_parameterizable_per_api_call"
) # TODO

self.vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
self.vector_store = BatchedChromaVectorStore(
chroma_client=chroma_client, chroma_collection=chroma_collection
)

@staticmethod
def get_retriever(
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ description = "Private GPT"
authors = ["Zylon <[email protected]>"]

[tool.poetry.dependencies]
python = ">=3.11,<3.13"
python = ">=3.11,<3.12"
fastapi = { extras = ["all"], version = "^0.103.1" }
loguru = "^0.7.2"
boto3 = "^1.28.56"
injector = "^0.21.0"
pyyaml = "^6.0.1"
python-multipart = "^0.0.6"
pypdf = "^3.16.2"
llama-index = "v0.8.35"
llama-index = "0.8.47"
chromadb = "^0.4.13"
watchdog = "^3.0.0"
transformers = "^4.34.0"
Expand Down
27 changes: 27 additions & 0 deletions tests/server/ingest/test_ingest_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from unittest.mock import PropertyMock, patch

from llama_index import Document

from private_gpt.server.ingest.ingest_service import IngestService
from tests.fixtures.mock_injector import MockInjector


def test_save_many_nodes(injector: MockInjector) -> None:
"""This is a specific test for a local Chromadb Vector Database setup.
Extend it when we add support for other vector databases in VectorStoreComponent.
"""
with patch(
"chromadb.api.segment.SegmentAPI.max_batch_size", new_callable=PropertyMock
) as max_batch_size:
# Make max batch size of Chromadb very small
max_batch_size.return_value = 10

ingest_service = injector.get(IngestService)

documents = []
for _i in range(100):
documents.append(Document(text="This is a sentence."))

ingested_docs = ingest_service._save_docs(documents)
assert len(ingested_docs) == len(documents)

0 comments on commit f5a9bf4

Please sign in to comment.