Implement cache embeddings (resolves #200) #208

Merged: 16 commits, Feb 15, 2024
33 changes: 30 additions & 3 deletions docs/user_guides/advanced/embedding-search-providers.md
@@ -1,10 +1,10 @@
# Embedding Search Providers

NeMo Guardrails uses embedding search (a.k.a. vector databases) for implementing the [guardrails process](../../architecture/README.md#the-guardrails-process) and for the [knowledge base](../configuration-guide.md#knowledge-base-documents) functionality.
NeMo Guardrails utilizes embedding search, also known as vector databases, for implementing the [guardrails process](../../architecture/README.md#the-guardrails-process) and for the [knowledge base](../configuration-guide.md#knowledge-base-documents) functionality.

The default embedding search uses FastEmbed for computing the embeddings (the `all-MiniLM-L6-v2` model) and Annoy for performing the search.
To enhance the efficiency of the embedding search process, NeMo Guardrails employs a caching mechanism for embeddings. This mechanism stores computed embeddings, thereby reducing the need for repeated computations and accelerating the search process.

The default configuration is the following:
The default embedding search uses FastEmbed for computing the embeddings (the `all-MiniLM-L6-v2` model) and Annoy for performing the search. The default configuration is as follows:

```yaml
core:
@@ -13,13 +13,23 @@ core:
    parameters:
      embedding_engine: FastEmbed
      embedding_model: all-MiniLM-L6-v2
      cache:
        enabled: True
        key_generator: md5
        store: filesystem
        store_config: {}

knowledge_base:
  embedding_search_provider:
    name: default
    parameters:
      embedding_engine: FastEmbed
      embedding_model: all-MiniLM-L6-v2
      cache:
        enabled: True
        key_generator: md5
        store: filesystem
        store_config: {}
```

The default embedding search provider can also work with OpenAI embeddings:
@@ -31,15 +41,28 @@ core:
    parameters:
      embedding_engine: openai
      embedding_model: text-embedding-ada-002
      cache:
        enabled: True
        key_generator: md5
        store: filesystem
        store_config: {}

knowledge_base:
  embedding_search_provider:
    name: default
    parameters:
      embedding_engine: openai
      embedding_model: text-embedding-ada-002
      cache:
        enabled: True
        key_generator: md5
        store: filesystem
        store_config: {}
```

The `cache` configuration is optional. If enabled, it uses the specified `key_generator` and `store` to cache the embeddings. The `store_config` can be used to provide additional configuration options required for the store.
The default `cache` configuration uses the `md5` key generator and the `filesystem` store. The cache is enabled by default.
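For reference, the same `cache` block can also be expressed programmatically. This is only an illustrative sketch; the `EmbeddingsCacheConfig` field names are assumed to mirror the YAML keys above.

```python
from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig

# Equivalent of the `cache` block above; field names are assumed to mirror the YAML keys.
cache_config = EmbeddingsCacheConfig(
    enabled=True,
    key_generator="md5",    # how the cache key is derived from the input text
    store="filesystem",     # where cached embeddings are persisted
    store_config={},        # extra options for the chosen store, if any
)
```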

## Custom Embedding Search Providers

You can implement your own custom embedding search provider by subclassing `EmbeddingsIndex`. For quick reference, the complete interface is included below:
@@ -52,6 +75,10 @@ class EmbeddingsIndex:
    def embedding_size(self):
        raise NotImplementedError

    @property
    def cache_config(self):
        raise NotImplementedError

    async def add_item(self, item: IndexItem):
        """Adds a new item to the index."""
        raise NotImplementedError()
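As an illustration of where the new `cache_config` property fits, here is a hypothetical, minimal subclass that implements only the interface members visible in the snippet above (the remainder of the interface is collapsed in this diff):

```python
from typing import List

from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig


class MyEmbeddingsIndex(EmbeddingsIndex):
    """Hypothetical provider; only the members shown in the interface snippet are implemented."""

    def __init__(self, embedding_size=384, cache_config=None):
        # 384 is the dimensionality of all-MiniLM-L6-v2, used here only as an example default.
        self._embedding_size = embedding_size
        self._cache_config = cache_config or EmbeddingsCacheConfig()
        self._items: List[IndexItem] = []

    @property
    def embedding_size(self):
        return self._embedding_size

    @property
    def cache_config(self):
        return self._cache_config

    async def add_item(self, item: IndexItem):
        # A real provider would also compute and index the item's embedding here.
        self._items.append(item)
```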
41 changes: 34 additions & 7 deletions nemoguardrails/embeddings/basic.py
@@ -13,11 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from typing import Any, Dict, List, Union

from annoy import AnnoyIndex

from nemoguardrails.embeddings.cache import cache_embeddings
from nemoguardrails.embeddings.index import EmbeddingModel, EmbeddingsIndex, IndexItem
from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig


class BasicEmbeddingsIndex(EmbeddingsIndex):
@@ -31,32 +33,47 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):
        embedding_engine (str): The engine for computing embeddings.
        embeddings_index (AnnoyIndex): The current embedding index.
        embedding_size (int): The size of the embeddings.
        cache_config (EmbeddingsCacheConfig): The cache configuration.
        embeddings (List[List[float]]): The computed embeddings.
    """

    def __init__(self, embedding_model=None, embedding_engine=None, index=None):
    def __init__(
        self,
        embedding_model=None,
        embedding_engine=None,
        index=None,
        cache_config: Union[EmbeddingsCacheConfig, Dict[str, Any]] = None,
    ):
"""Initialize the BasicEmbeddingsIndex.

Args:
embedding_model (str, optional): The model for computing embeddings. Defaults to None.
embedding_engine (str, optional): The engine for computing embeddings. Defaults to None.
index (AnnoyIndex, optional): The pre-existing index. Defaults to None.
cache_config (EmbeddingsCacheConfig | Dict[str, Any], optional): The cache configuration. Defaults to None.
"""
self._model = None
self._items = []
self._embeddings = []
self.embedding_model = embedding_model
self.embedding_engine = embedding_engine
self._embedding_size = 0

        if isinstance(cache_config, Dict):
            self._cache_config = EmbeddingsCacheConfig(**cache_config)
        else:
            self._cache_config = cache_config or EmbeddingsCacheConfig()

        # When the index is provided, it means it's from the cache.
        self._index = index

    @property
    def embeddings_index(self):
        """Get the current embedding index"""
        return self._index

    @property
    def cache_config(self):
        """Get the cache configuration."""
        return self._cache_config

    @property
    def embedding_size(self):
        """Get the size of the embeddings."""
@@ -78,7 +95,8 @@ def _init_model(self):
            embedding_model=self.embedding_model, embedding_engine=self.embedding_engine
        )

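    # Editorial note: `cache_embeddings` (imported from nemoguardrails.embeddings.cache, not shown
    # in this diff) appears to wrap the call so that embeddings already present in the configured
    # cache store are reused instead of being recomputed.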
    def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
    @cache_embeddings
    async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Compute embeddings for a list of texts.

Args:
@@ -181,7 +199,8 @@ def encode(self, documents: List[str]) -> List[List[float]]:
        Returns:
            List[List[float]]: The list of sentence embeddings, where each embedding is a list of floats.
        """
        return self.model.encode(documents)

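        # Editorial note: `.tolist()` presumably converts the numpy array returned by the
        # underlying model's encode() into plain Python lists, e.g. so the embeddings can be
        # serialized by the new cache.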
        return self.model.encode(documents).tolist()


class FastEmbedEmbeddingModel(EmbeddingModel):
@@ -236,8 +255,12 @@ class OpenAIEmbeddingModel(EmbeddingModel):

"""

    def __init__(self, embedding_model: str):
    def __init__(
        self,
        embedding_model: str,
    ):
        self.model = embedding_model

        self.embedding_size = len(self.encode(["test"])[0])

    def encode(self, documents: List[str]) -> List[List[float]]:
@@ -252,6 +275,10 @@ def encode(self, documents: List[str]) -> List[List[float]]:
"""
import openai

print("performing embeddings on :", len(documents))
if len(documents) == 1:
print(documents)

# Make embedding request to OpenAI API
res = openai.Embedding.create(input=documents, engine=self.model)
embeddings = [record["embedding"] for record in res["data"]]
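To close, a short usage sketch (not part of this PR) of the updated constructor, showing the two forms the new `cache_config` argument can take; the field defaults are assumed:

```python
from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex
from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig

# Passing the config object directly (assumed field defaults; here the cache is turned off).
index_a = BasicEmbeddingsIndex(
    embedding_model="all-MiniLM-L6-v2",
    embedding_engine="FastEmbed",
    cache_config=EmbeddingsCacheConfig(enabled=False),
)

# Passing a plain dict, which __init__ converts via EmbeddingsCacheConfig(**cache_config).
index_b = BasicEmbeddingsIndex(
    embedding_model="all-MiniLM-L6-v2",
    embedding_engine="FastEmbed",
    cache_config={"enabled": True, "key_generator": "md5", "store": "filesystem"},
)
```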