Add Embeddings API Support #131

Open · wants to merge 4 commits into `main`

48 changes: 48 additions & 0 deletions aisuite/client.py
@@ -24,6 +24,7 @@ def __init__(self, provider_configs: dict = {}):
        self.providers = {}
        self.provider_configs = provider_configs
        self._chat = None
        self._embeddings = None
        self._initialize_providers()

    def _initialize_providers(self):
@@ -65,6 +66,13 @@ def chat(self):
        self._chat = Chat(self)
        return self._chat

    @property
    def embeddings(self):
        """Return the embeddings API interface."""
        if not self._embeddings:
            self._embeddings = Embeddings(self)
        return self._embeddings


class Chat:
    def __init__(self, client: "Client"):
@@ -115,3 +123,43 @@ def create(self, model: str, messages: list, **kwargs):

        # Delegate the chat completion to the correct provider's implementation
        return provider.chat_completions_create(model_name, messages, **kwargs)


class Embeddings:
    def __init__(self, client: "Client"):
        self.client = client

    def create(self, model: str, input, **kwargs):
        """
        Create embeddings for the given input based on the model and any extra arguments.
        """
        # Check that the correct format is used
        if ":" not in model:
            raise ValueError(
                f"Invalid model format. Expected 'provider:model', got '{model}'"
            )

        # Extract the provider key from the model identifier, e.g., "google:gemini-xx"
        provider_key, model_name = model.split(":", 1)

        # Validate that the provider is supported
        supported_providers = ProviderFactory.get_supported_providers()
        if provider_key not in supported_providers:
            raise ValueError(
                f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. "
                "Make sure the model string is formatted correctly as 'provider:model'."
            )

        # Initialize the provider if it has not been initialized yet
        if provider_key not in self.client.providers:
            config = self.client.provider_configs.get(provider_key, {})
            self.client.providers[provider_key] = ProviderFactory.create_provider(
                provider_key, config
            )

        provider = self.client.providers.get(provider_key)
        if not provider:
            raise ValueError(f"Could not load provider for '{provider_key}'.")

        # Delegate the embeddings creation to the correct provider's implementation
        return provider.embeddings_create(model_name, input, **kwargs)
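
For context, here is a minimal usage sketch of the interface added above. It assumes a valid `OPENAI_API_KEY` in the environment; the model name is illustrative:

```python
import aisuite as ai

client = ai.Client()

# "provider:model" is split on the first ":" and routed to the matching provider.
response = client.embeddings.create(
    model="openai:text-embedding-ada-002",
    input="Hello, world!",
)

# The OpenAI provider returns the OpenAI-style response object.
print(response.data[0].embedding[:5])
```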
1 change: 1 addition & 0 deletions aisuite/framework/__init__.py
@@ -1,2 +1,3 @@
from .provider_interface import ProviderInterface
from .chat_completion_response import ChatCompletionResponse
from .create_embeddings_response import CreateEmbeddingResponse
8 changes: 8 additions & 0 deletions aisuite/framework/create_embeddings_response.py
@@ -0,0 +1,8 @@
from aisuite.framework.embedding import Embedding


class CreateEmbeddingResponse:
    """Used to conform to the response model of OpenAI"""

    def __init__(self, number: int = 1):
        self.data = [Embedding() for _ in range(number)]
3 changes: 3 additions & 0 deletions aisuite/framework/embedding.py
@@ -0,0 +1,3 @@
class Embedding:
    def __init__(self):
        self.embedding = None
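
To show how these two wrapper classes fit together, here is a short, self-contained sketch of the normalization pattern the providers use (the vectors are made up for illustration):

```python
from aisuite.framework import CreateEmbeddingResponse

# Hypothetical raw vectors, as a provider backend might return them.
raw_vectors = [[0.01, -0.02, 0.03], [0.04, 0.05, -0.06]]

# Allocate one Embedding slot per vector, then fill each slot in,
# mirroring what a provider's normalization step does.
response = CreateEmbeddingResponse(number=len(raw_vectors))
for i, vector in enumerate(raw_vectors):
    response.data[i].embedding = vector

print(response.data[1].embedding)  # [0.04, 0.05, -0.06]
```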
43 changes: 42 additions & 1 deletion aisuite/providers/ollama_provider.py
@@ -1,7 +1,7 @@
import os
import httpx
from aisuite.provider import Provider, LLMError
from aisuite.framework import ChatCompletionResponse
from aisuite.framework import ChatCompletionResponse, CreateEmbeddingResponse


class OllamaProvider(Provider):
@@ -13,6 +13,7 @@ class OllamaProvider(Provider):
    """

    _CHAT_COMPLETION_ENDPOINT = "/api/chat"
    _EMBEDDINGS_ENDPOINT = "/api/embed"
    _CONNECT_ERROR_MESSAGE = "Ollama is likely not running. Start Ollama by running `ollama serve` on your host."

    def __init__(self, **config):
@@ -54,6 +55,32 @@ def chat_completions_create(self, model, messages, **kwargs):
        # Return the normalized response
        return self._normalize_response(response.json())

    def embeddings_create(self, model, input, **kwargs):
        # Read more about the embeddings endpoint here:
        # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings

        data = {
            "model": model,
            "input": input,
            **kwargs,  # Pass any additional arguments to the API
        }

        try:
            response = httpx.post(
                self.url.rstrip("/") + self._EMBEDDINGS_ENDPOINT,
                json=data,
                timeout=self.timeout,
            )
            response.raise_for_status()
        except httpx.ConnectError:  # Handle connection errors
            raise LLMError(f"Connection failed: {self._CONNECT_ERROR_MESSAGE}")
        except httpx.HTTPStatusError as http_err:
            raise LLMError(f"Ollama request failed: {http_err}")
        except Exception as e:
            raise LLMError(f"An error occurred: {e}")

        return self._normalize_embeddings_response(response.json())

    def _normalize_response(self, response_data):
        """
        Normalize the API response to a common format (ChatCompletionResponse).
@@ -63,3 +90,17 @@ def _normalize_response(self, response_data):
            "content"
        ]
        return normalized_response

    def _normalize_embeddings_response(self, response_data):
        """
        Normalize the API response to a common format (CreateEmbeddingResponse).
        """
        normalized_response = CreateEmbeddingResponse(
            number=len(response_data["embeddings"])
        )

        # Set the embeddings in the response
        for i, embedding in enumerate(response_data["embeddings"]):
            normalized_response.data[i].embedding = embedding

        return normalized_response
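
For reference, here is a hedged sketch of the raw `/api/embed` exchange that `embeddings_create` wraps, assuming a local Ollama instance with an embedding model already pulled (the model name is illustrative):

```python
import httpx

# Direct call to Ollama's embeddings endpoint; requires `ollama serve` running.
response = httpx.post(
    "http://localhost:11434/api/embed",
    json={"model": "bge-large", "input": ["This is a test sentence."]},
    timeout=30,
)
response.raise_for_status()

payload = response.json()
# Ollama returns {"embeddings": [[...], ...]} — one vector per input string,
# which _normalize_embeddings_response copies into CreateEmbeddingResponse.data.
print(len(payload["embeddings"]), len(payload["embeddings"][0]))
```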
9 changes: 9 additions & 0 deletions aisuite/providers/openai_provider.py
@@ -31,3 +31,12 @@ def chat_completions_create(self, model, messages, **kwargs):
            messages=messages,
            **kwargs  # Pass any additional arguments to the OpenAI API
        )

    def embeddings_create(self, model, input, **kwargs):
        # Refer to the OpenAI API documentation for details on the parameters and response:
        # https://platform.openai.com/docs/api-reference/embeddings/create?lang=python
        return self.client.embeddings.create(
            model=model,
            input=input,
            **kwargs  # Pass any additional arguments to the OpenAI API
        )
57 changes: 57 additions & 0 deletions guides/ollama.md
@@ -0,0 +1,57 @@
# Ollama

To use Ollama with `aisuite`, you'll need to make sure your Ollama environment is properly configured. Ollama does not require an external API key, but it must be installed and running on your machine or on a host you can reach.

## Setup Ollama
1. **Install Ollama**: Follow the installation instructions from [Ollama's official website](https://ollama.com/) to install the Ollama CLI on your system.
2. **Start the Ollama Service**: Ensure the Ollama service is running on your machine. You can start it with the following command:
   ```shell
   ollama serve
   ```
3. **Download the Models**: Ensure that the models you plan to use are downloaded and available in your Ollama instance. You can explore the list of supported models in the [Ollama Library](https://ollama.com/library).
4. **Default API URL**: If `OLLAMA_API_URL` is not set or explicitly passed in the configuration, `aisuite` will default to `"http://localhost:11434"`; see the snippet below for pointing the client at a remote instance.
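
If your Ollama instance runs on a different host or port, one way to point `aisuite` at it is through the `OLLAMA_API_URL` environment variable mentioned above. A minimal sketch (the address is illustrative):

```python
import os

import aisuite as ai

# Set before creating the client so the provider picks it up (address is illustrative).
os.environ["OLLAMA_API_URL"] = "http://192.168.1.42:11434"

client = ai.Client()
```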

## Create a Chat Completion
The following example demonstrates how to create a chat completion using the Ollama provider:

```python
import aisuite as ai

client = ai.Client()

provider = "ollama"
model_id = "qwq" # Replace with the model name running on your device

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]

response = client.chat.completions.create(
    model=f"{provider}:{model_id}",
    messages=messages,
)

print(response.choices[0].message.content)
```

## Get Text Embeddings
You can generate text embeddings for your input using the embedding models available on your Ollama instance:

```python
import aisuite as ai

client = ai.Client()

provider = "ollama"
embedding_model_id = "bge-large"

input_texts = ["This is a test sentence."]

response = client.embeddings.create(
    model=f"{provider}:{embedding_model_id}",
    input=input_texts,
)

print(response.data[0].embedding)
```
21 changes: 21 additions & 0 deletions guides/openai.md
@@ -41,4 +41,25 @@ response = client.chat.completions.create(
print(response.choices[0].message.content)
```

## Get Text Embeddings
You can generate text embeddings using OpenAI embedding models with the `aisuite` client:
```python
import aisuite as ai

client = ai.Client()

provider = "openai"
embedding_model_id = "text-embedding-ada-002"

input_text = ["This is an example sentence for embeddings."]

response = client.embeddings.create(
    model=f"{provider}:{embedding_model_id}",
    input=input_text,
)

embedding = response.data[0].embedding
print("Embedding:", embedding)
```

Happy coding! If you’d like to contribute, please read our [Contributing Guide](CONTRIBUTING.md).
25 changes: 25 additions & 0 deletions tests/client/test_client.py
@@ -139,6 +139,31 @@ def test_client_chat_completions(
        next_compl_instance = client.chat.completions
        assert compl_instance is next_compl_instance

    @patch("aisuite.providers.openai_provider.OpenaiProvider.embeddings_create")
    def test_client_embeddings_create(self, mock_openai):
        # Mock responses from providers
        mock_openai.return_value = [0.0023064255] * 32

        # Provider configurations
        provider_configs = {
            "openai": {"api_key": "test_openai_api_key"},
        }

        # Initialize the client
        client = Client()
        client.configure(provider_configs)

        # Test OpenAI model
        open_ai_model = "openai" + ":" + "text-embedding-ada-002"
        openai_response = client.embeddings.create(open_ai_model, "Hello, world!")
        self.assertEqual(openai_response, [0.0023064255] * 32)
        mock_openai.assert_called_once()

        # Test that the `embeddings` property returns the same instance
        embeddings_instance = client.embeddings
        next_embeddings_instance = client.embeddings
        assert embeddings_instance is next_embeddings_instance

    def test_invalid_provider_in_client_config(self):
        # Testing an invalid provider name in the configuration
        invalid_provider_configs = {