Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Azure embedders support #6676

Merged
merged 13 commits into from
Jan 5, 2024
4 changes: 4 additions & 0 deletions haystack/components/embedders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
from haystack.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder
from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder
from haystack.components.embedders.azure_text_embedder import AzureOpenAITextEmbedder
from haystack.components.embedders.azure_document_embedder import AzureOpenAIDocumentEmbedder

__all__ = [
"HuggingFaceTEITextEmbedder",
Expand All @@ -12,4 +14,6 @@
"SentenceTransformersDocumentEmbedder",
"OpenAITextEmbedder",
"OpenAIDocumentEmbedder",
"AzureOpenAITextEmbedder",
"AzureOpenAIDocumentEmbedder",
]
178 changes: 178 additions & 0 deletions haystack/components/embedders/azure_document_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import os
from typing import List, Optional, Dict, Any, Tuple

from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
from tqdm import tqdm

from haystack import component, Document, default_to_dict


@component
class AzureOpenAIDocumentEmbedder:
"""
A component for computing Document embeddings using OpenAI models.
The embedding of each Document is stored in the `embedding` field of the Document.

Usage example:
```python
from haystack import Document
from haystack.components.embedders import AzureOpenAIDocumentEmbedder

doc = Document(content="I love pizza!")

document_embedder = AzureOpenAIDocumentEmbedder()
anakin87 marked this conversation as resolved.
Show resolved Hide resolved

result = document_embedder.run([doc])
print(result['documents'][0].embedding)

# [0.017020374536514282, -0.023255806416273117, ...]
```
"""

def __init__(
self,
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
azure_deployment: str = "text-embedding-ada-002",
api_key: Optional[str] = None,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
organization: Optional[str] = None,
prefix: str = "",
suffix: str = "",
batch_size: int = 32,
progress_bar: bool = True,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
):
"""
Create an AzureOpenAITextEmbedder component.

:param azure_endpoint: The endpoint of the deployed model, e.g. `https://example-resource.azure.openai.com/`
:param api_version: The version of the API to use. Defaults to 2023-05-15
:param azure_deployment: The deployment of the model, usually the model name.
:param api_key: The API key to use for authentication.
:param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
on every request.
:param organization: The Organization ID, defaults to `None`. See
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
:param batch_size: Number of Documents to encode at once.
:param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production deployments
to keep the logs clean.
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text.
:param embedding_separator: Separator used to concatenate the meta fields to the Document text.
"""
# if not provided as a parameter, azure_endpoint is read from the env var AZURE_OPENAI_ENDPOINT
azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
if not azure_endpoint:
raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

anakin87 marked this conversation as resolved.
Show resolved Hide resolved
self.api_version = api_version
self.azure_endpoint = azure_endpoint
self.azure_deployment = azure_deployment
self.organization = organization
vblagoje marked this conversation as resolved.
Show resolved Hide resolved
self.prefix = prefix
self.suffix = suffix
self.batch_size = batch_size
self.progress_bar = progress_bar
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator

self._client = AzureOpenAI(
api_version=api_version,
azure_endpoint=azure_endpoint,
azure_deployment=azure_deployment,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
organization=organization,
)

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.azure_deployment}

def to_dict(self) -> Dict[str, Any]:
"""
This method overrides the default serializer in order to avoid leaking the `api_key` value passed
to the constructor.
"""
return default_to_dict(
self,
azure_endpoint=self.azure_endpoint,
azure_deployment=self.azure_deployment,
organization=self.organization,
api_version=self.api_version,
prefix=self.prefix,
suffix=self.suffix,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
)

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
"""
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
"""
texts_to_embed = []
for doc in documents:
meta_values_to_embed = [
str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
]

text_to_embed = (
self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix
).replace("\n", " ")

texts_to_embed.append(text_to_embed)
return texts_to_embed

def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
"""
Embed a list of texts in batches.
"""

all_embeddings: List[List[float]] = []
meta: Dict[str, Any] = {"model": "", "usage": {"prompt_tokens": 0, "total_tokens": 0}}
for i in tqdm(range(0, len(texts_to_embed), batch_size), desc="Embedding Texts"):
batch = texts_to_embed[i : i + batch_size]
response = self._client.embeddings.create(model=self.azure_deployment, input=batch)

# Append embeddings to the list
all_embeddings.extend(el.embedding for el in response.data)

# Update the meta information only once if it's empty
if not meta["model"]:
meta["model"] = response.model
meta["usage"] = dict(response.usage)
else:
# Update the usage tokens
meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
meta["usage"]["total_tokens"] += response.usage.total_tokens

return all_embeddings, meta

@component.output_types(documents=List[Document], meta=Dict[str, Any])
def run(self, documents: List[Document]):
"""
Embed a list of Documents. The embedding of each Document is stored in the `embedding` field of the Document.

:param documents: A list of Documents to embed.
"""
if not (isinstance(documents, list) and all(isinstance(doc, Document) for doc in documents)):
raise TypeError("Input must be a list of Document instances. For strings, use AzureOpenAITextEmbedder.")

texts_to_embed = self._prepare_texts_to_embed(documents=documents)
embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

# Assign the corresponding embeddings to each document
for doc, emb in zip(documents, embeddings):
doc.embedding = emb

return {"documents": documents, "meta": meta}
123 changes: 123 additions & 0 deletions haystack/components/embedders/azure_text_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import os
from typing import List, Optional, Dict, Any

from openai.lib.azure import AzureADTokenProvider, AzureOpenAI

from haystack import component, default_to_dict, Document


@component
class AzureOpenAITextEmbedder:
"""
A component for embedding strings using OpenAI models.

Usage example:
```python
from haystack.components.embedders import AzureOpenAITextEmbedder

text_to_embed = "I love pizza!"

text_embedder = AzureOpenAITextEmbedder()

print(text_embedder.run(text_to_embed))

# {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
# 'meta': {'model': 'text-embedding-ada-002-v2',
# 'usage': {'prompt_tokens': 4, 'total_tokens': 4}}}
```
"""

def __init__(
self,
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
azure_deployment: str = "text-embedding-ada-002",
api_key: Optional[str] = None,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
organization: Optional[str] = None,
prefix: str = "",
suffix: str = "",
):
"""
Create an AzureOpenAITextEmbedder component.

:param azure_endpoint: The endpoint of the deployed model, e.g. `https://example-resource.azure.openai.com/`
:param api_version: The version of the API to use. Defaults to 2023-05-15
:param azure_deployment: The deployment of the model, usually the model name.
:param api_key: The API key to use for authentication.
:param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
on every request.
:param organization: The Organization ID, defaults to `None`. See
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
"""
# Why is this here?
# AzureOpenAI init is forcing us to use an init method that takes either base_url or azure_endpoint as not
# None init parameters. This way we accommodate the use case where env var AZURE_OPENAI_ENDPOINT is set instead
# of passing it as a parameter.
azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
if not azure_endpoint:
raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

self.api_version = api_version
self.azure_endpoint = azure_endpoint
self.azure_deployment = azure_deployment
self.organization = organization
self.prefix = prefix
self.suffix = suffix

self._client = AzureOpenAI(
api_version=api_version,
azure_endpoint=azure_endpoint,
azure_deployment=azure_deployment,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
organization=organization,
)

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.azure_deployment}

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
:return: The serialized component as a dictionary.
"""
return default_to_dict(
self,
azure_endpoint=self.azure_endpoint,
azure_deployment=self.azure_deployment,
organization=self.organization,
api_version=self.api_version,
prefix=self.prefix,
suffix=self.suffix,
)

@component.output_types(embedding=List[float], meta=Dict[str, Any])
def run(self, text: str):
"""Embed a string using AzureOpenAITextEmbedder."""
if not isinstance(text, str):
# Check if input is a list and all elements are instances of Document
if isinstance(text, list) and all(isinstance(elem, Document) for elem in text):
error_message = "Input must be a string. Use AzureOpenAIDocumentEmbedder for a list of Documents."
else:
error_message = "Input must be a string."
raise TypeError(error_message)

# Preprocess the text by adding prefixes/suffixes
# finally, replace newlines as recommended by OpenAI docs
processed_text = f"{self.prefix}{text}{self.suffix}".replace("\n", " ")

response = self._client.embeddings.create(model=self.azure_deployment, input=processed_text)

return {
"embedding": response.data[0].embedding,
"meta": {"model": response.model, "usage": dict(response.usage)},
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
features:
- |
Adds AzureOpenAIDocumentEmbedder and AzureOpenAITextEmbedder as new embedders. These embedders are very similar to
their OpenAI counterparts, but they use the Azure API instead of the OpenAI API.
Loading