Add CentML Provider #128

Open · wants to merge 4 commits into main
2 changes: 2 additions & 0 deletions .env.sample
@@ -36,3 +36,5 @@ XAI_API_KEY=

# Sambanova
SAMBANOVA_API_KEY=
# CentML
CENTML_API_KEY=
4 changes: 2 additions & 2 deletions README.md
@@ -8,7 +8,7 @@ Simple, unified interface to multiple Generative AI providers.
`aisuite` makes it easy for developers to use multiple LLMs through a standardized interface. Using an interface similar to OpenAI's, `aisuite` makes it easy to interact with the most popular LLMs and compare the results. It is a thin wrapper around Python client libraries, and allows creators to seamlessly swap out and test responses from different LLM providers without changing their code. Today, the library is primarily focused on chat completions. We will expand it to cover more use cases in the near future.

Currently supported providers are -
OpenAI, Anthropic, Azure, Google, AWS, Groq, Mistral, HuggingFace Ollama, Sambanova and Watsonx.
OpenAI, Anthropic, Azure, Google, AWS, Groq, Mistral, HuggingFace, Ollama, Sambanova, Watsonx, CentML.
To maximize stability, `aisuite` uses either the HTTP endpoint or the SDK for making calls to the provider.

## Installation
@@ -108,7 +108,7 @@ We follow a convention-based approach for loading providers, which relies on str
```python
class HuggingfaceProvider(Provider):
```

in providers/huggingface_provider.py.

- **OpenAI**:
The provider class should be defined as:

```python
class OpenaiProvider(Provider):
```

in providers/openai_provider.py.
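For the provider added in this PR, the same naming convention works out as follows (a sketch applying the rule stated above; the class name and path are taken from this diff):

```python
class CentmlProvider(Provider):
```

in providers/centml_provider.py, addressed from the client as `centml:<model-id>`.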
65 changes: 65 additions & 0 deletions aisuite/providers/centml_provider.py
@@ -0,0 +1,65 @@
import os
import httpx
from aisuite.provider import Provider, LLMError
from aisuite.framework import ChatCompletionResponse


class CentmlProvider(Provider):
    """
    CentML Provider using httpx for direct API calls.
    """

    BASE_URL = "https://api.centml.com/openai/v1/chat/completions"

    def __init__(self, **config):
        """
        Initialize the CentML provider with the given configuration.
        The API key is fetched from the config or environment variables.
        """
        self.api_key = config.get("api_key", os.getenv("CENTML_API_KEY"))
        if not self.api_key:
            raise ValueError(
                "CentML API key is missing. Please provide it in the config or set the CENTML_API_KEY environment variable."
            )

        # Optionally set a custom timeout (defaults to 30s).
        self.timeout = config.get("timeout", 30)

    def chat_completions_create(self, model, messages, **kwargs):
        """
        Makes a request to the CentML chat completions endpoint using httpx.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        data = {
            "model": model,
            "messages": messages,
            **kwargs,  # Pass any additional arguments to the API.
        }

        try:
            # Make the request to the CentML endpoint.
            response = httpx.post(
                self.BASE_URL, json=data, headers=headers, timeout=self.timeout
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as http_err:
            raise LLMError(f"CentML request failed: {http_err}") from http_err
        except Exception as e:
            raise LLMError(f"An error occurred: {e}") from e

        # Return the normalized response.
        return self._normalize_response(response.json())

    def _normalize_response(self, response_data):
        """
        Normalize the response to a common format (ChatCompletionResponse).
        """
        normalized_response = ChatCompletionResponse()
        normalized_response.choices[0].message.content = response_data["choices"][0][
            "message"
        ]["content"]
        return normalized_response
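For quick manual verification, a minimal sketch that exercises the new provider directly (assuming `CENTML_API_KEY` is exported; the model id is the one used in the guide below, and the timeout value is illustrative):

```python
from aisuite.providers.centml_provider import CentmlProvider

# Illustrative only: the timeout override and prompt are assumptions, not part of this diff.
provider = CentmlProvider(timeout=60)
response = provider.chat_completions_create(
    model="meta-llama/Llama-3.3-70B-Instruct",
    messages=[{"role": "user", "content": "Say hello."}],
    temperature=0.7,
)
print(response.choices[0].message.content)
```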
30 changes: 30 additions & 0 deletions guides/centml.md
@@ -0,0 +1,30 @@
# CentML Platform

To use CentML Platform with the `aisuite` library, you'll need a [CentML Platform](https://app.centml.com) account. After logging in, go to the [User credentials](https://app.centml.com/user/credentials) section in your account settings and generate a new API key. Once you have your key, add it to your environment as follows:

```shell
export CENTML_API_KEY="your-centml-api-key"
```

In your code:
```python
import aisuite as ai
client = ai.Client()

provider = "centml"
model_id = "meta-llama/Llama-3.3-70B-Instruct"

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What’s the weather like in San Francisco?"},
]

response = client.chat.completions.create(
    model=f"{provider}:{model_id}",
    messages=messages,
)

print(response.choices[0].message.content)
```
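Request failures surface as `LLMError`, which the provider in this PR raises on HTTP errors. A minimal sketch of defensive usage, reusing the `provider`, `model_id`, and `messages` defined above:

```python
from aisuite.provider import LLMError

try:
    response = client.chat.completions.create(
        model=f"{provider}:{model_id}",
        messages=messages,
    )
    print(response.choices[0].message.content)
except LLMError as err:
    # Raised on non-2xx responses or other request failures.
    print(f"CentML request failed: {err}")
```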

Happy coding! If you’d like to contribute, please read our [Contributing Guide](CONTRIBUTING.md).
55 changes: 55 additions & 0 deletions tests/providers/test_centml_provider.py
@@ -0,0 +1,55 @@
import os
import pytest

from unittest.mock import patch, MagicMock

from aisuite.providers.centml_provider import CentmlProvider


@pytest.fixture(autouse=True)
def set_api_key_env_var(monkeypatch):
    """Fixture to set environment variables for tests."""
    monkeypatch.setenv("CENTML_API_KEY", "test-api-key")


def test_centml_provider():
    """High-level test that the provider is initialized and chat completions are requested successfully."""

    user_greeting = "Hello!"
    message_history = [{"role": "user", "content": user_greeting}]
    selected_model = "our-favorite-model"
    chosen_temperature = 0.75
    response_text_content = "mocked-text-response-from-model"

    headers = {
        "Authorization": f"Bearer {os.getenv('CENTML_API_KEY')}",
        "Content-Type": "application/json",
    }

    provider = CentmlProvider()

    # Create a dictionary that matches the expected JSON response structure.
    mock_json_response = {"choices": [{"message": {"content": response_text_content}}]}

    with patch(
        "httpx.post",
        return_value=MagicMock(status_code=200, json=lambda: mock_json_response),
    ) as mock_post:
        response = provider.chat_completions_create(
            messages=message_history,
            model=selected_model,
            temperature=chosen_temperature,
        )

    mock_post.assert_called_once_with(
        provider.BASE_URL,
        json={
            "model": selected_model,
            "messages": message_history,
            "temperature": chosen_temperature,
        },
        timeout=30,
        headers=headers,
    )

    assert response.choices[0].message.content == response_text_content
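A natural companion test, sketched here as a suggestion rather than part of this diff, would cover the missing-key path, since the provider's `__init__` raises `ValueError` when no key is configured:

```python
def test_centml_provider_missing_api_key(monkeypatch):
    """Provider construction should fail loudly when no API key is configured."""
    # Undo the autouse fixture's env var so the fallback lookup finds nothing.
    monkeypatch.delenv("CENTML_API_KEY", raising=False)
    with pytest.raises(ValueError, match="CentML API key is missing"):
        CentmlProvider()
```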