From 85e9911ae2a257044e9d8efeff1433df93cc1466 Mon Sep 17 00:00:00 2001 From: Akbar Nurlybayev Date: Mon, 2 Dec 2024 15:26:08 -0700 Subject: [PATCH 1/4] Add CentML provider --- aisuite/providers/centml_provider.py | 65 ++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 aisuite/providers/centml_provider.py diff --git a/aisuite/providers/centml_provider.py b/aisuite/providers/centml_provider.py new file mode 100644 index 00000000..af4697ee --- /dev/null +++ b/aisuite/providers/centml_provider.py @@ -0,0 +1,65 @@ +import os +import httpx +from aisuite.provider import Provider, LLMError +from aisuite.framework import ChatCompletionResponse + + +class CentmlProvider(Provider): + """ + CentML Provider using httpx for direct API calls. + """ + + BASE_URL = "https://api.centml.com/openai/v1/chat/completions" + + def __init__(self, **config): + """ + Initialize the CentML provider with the given configuration. + The API key is fetched from the config or environment variables. + """ + self.api_key = config.get("api_key", os.getenv("CENTML_API_KEY")) + if not self.api_key: + raise ValueError( + "CentML API key is missing. Please provide it in the config or set the CENTML_API_KEY environment variable." + ) + + # Optionally set a custom timeout (default to 30s) + self.timeout = config.get("timeout", 30) + + def chat_completions_create(self, model, messages, **kwargs): + """ + Makes a request to the CentML chat completions endpoint using httpx. + """ + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + data = { + "model": model, + "messages": messages, + **kwargs, # Pass any additional arguments to the API + } + + try: + # Make the request to CentML endpoint. + response = httpx.post( + self.BASE_URL, json=data, headers=headers, timeout=self.timeout + ) + response.raise_for_status() + except httpx.HTTPStatusError as http_err: + raise LLMError(f"CentML request failed: {http_err}") + except Exception as e: + raise LLMError(f"An error occurred: {e}") + + # Return the normalized response + return self._normalize_response(response.json()) + + def _normalize_response(self, response_data): + """ + Normalize the response to a common format (ChatCompletionResponse). + """ + normalized_response = ChatCompletionResponse() + normalized_response.choices[0].message.content = response_data["choices"][0][ + "message" + ]["content"] + return normalized_response From c5504de9536a82142499e4c2a97c03b315d75f3f Mon Sep 17 00:00:00 2001 From: Akbar Nurlybayev Date: Mon, 2 Dec 2024 16:25:18 -0700 Subject: [PATCH 2/4] Add documentation --- .env.sample | 2 ++ README.md | 4 ++-- guides/centml.md | 30 ++++++++++++++++++++++++++++++ 3 files changed, 34 insertions(+), 2 deletions(-) create mode 100644 guides/centml.md diff --git a/.env.sample b/.env.sample index cc933d9b..e88d6647 100644 --- a/.env.sample +++ b/.env.sample @@ -36,3 +36,5 @@ XAI_API_KEY= # Sambanova SAMBANOVA_API_KEY= +# CentML +CENTML_API_KEY= diff --git a/README.md b/README.md index add8b851..79c228ea 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Simple, unified interface to multiple Generative AI providers. `aisuite` makes it easy for developers to use multiple LLM through a standardized interface. Using an interface similar to OpenAI's, `aisuite` makes it easy to interact with the most popular LLMs and compare the results. It is a thin wrapper around python client libraries, and allows creators to seamlessly swap out and test responses from different LLM providers without changing their code. Today, the library is primarily focussed on chat completions. We will expand it cover more use cases in near future. Currently supported providers are - -OpenAI, Anthropic, Azure, Google, AWS, Groq, Mistral, HuggingFace Ollama, Sambanova and Watsonx. +OpenAI, Anthropic, Azure, Google, AWS, Groq, Mistral, HuggingFace Ollama, Sambanova, Watsonx, CentML. To maximize stability, `aisuite` uses either the HTTP endpoint or the SDK for making calls to the provider. ## Installation @@ -108,7 +108,7 @@ We follow a convention-based approach for loading providers, which relies on str ``` in providers/huggingface_provider.py. - + - **OpenAI**: The provider class should be defined as: diff --git a/guides/centml.md b/guides/centml.md new file mode 100644 index 00000000..10a47e12 --- /dev/null +++ b/guides/centml.md @@ -0,0 +1,30 @@ +# CentML Platform + +To use CentML Platform with the `aisuite` library, you'll need a [CentML Platform](https://app.centml.com) account. After logging in, go to the [User credentials](https://app.centml.com/user/credentials) section in your account settings and generate a new API key. Once you have your key, add it to your environment as follows: + +```shell +export CENTML_API_KEY="your-centml-api-key" +``` + +In your code: +```python +import aisuite as ai +client = ai.Client() + +provider = "centml" +model_id = "meta-llama/Llama-3.1-405B-Instruct-FP8" + +messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What’s the weather like in San Francisco?"}, +] + +response = client.chat.completions.create( + model=f"{provider}:{model_id}", + messages=messages, +) + +print(response.choices[0].message.content) +``` + +Happy coding! If you’d like to contribute, please read our [Contributing Guide](CONTRIBUTING.md). From d248a749593fbf9d2184157ebb930aa1b1d88fe9 Mon Sep 17 00:00:00 2001 From: Akbar Nurlybayev Date: Mon, 2 Dec 2024 17:02:28 -0700 Subject: [PATCH 3/4] Add provider tests --- tests/providers/test_centml_provider.py | 55 +++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 tests/providers/test_centml_provider.py diff --git a/tests/providers/test_centml_provider.py b/tests/providers/test_centml_provider.py new file mode 100644 index 00000000..e0b38da7 --- /dev/null +++ b/tests/providers/test_centml_provider.py @@ -0,0 +1,55 @@ +import os +import pytest + +from unittest.mock import patch, MagicMock + +from aisuite.providers.centml_provider import CentmlProvider + + +@pytest.fixture(autouse=True) +def set_api_key_env_var(monkeypatch): + """Fixture to set environment variables for tests.""" + monkeypatch.setenv("CENTML_API_KEY", "test-api-key") + + +def test_centml_provider(): + """High-level test that the provider is initialized and chat completions are requested successfully.""" + + user_greeting = "Hello!" + message_history = [{"role": "user", "content": user_greeting}] + selected_model = "our-favorite-model" + chosen_temperature = 0.75 + response_text_content = "mocked-text-response-from-model" + + headers = { + "Authorization": f"Bearer {os.getenv('CENTML_API_KEY')}", + "Content-Type": "application/json", + } + + provider = CentmlProvider() + + # Create a dictionary that matches the expected JSON response structure + mock_json_response = {"choices": [{"message": {"content": response_text_content}}]} + + with patch( + "httpx.post", + return_value=MagicMock(status_code=200, json=lambda: mock_json_response), + ) as mock_post: + response = provider.chat_completions_create( + messages=message_history, + model=selected_model, + temperature=chosen_temperature, + ) + + mock_post.assert_called_once_with( + provider.BASE_URL, + json={ + "model": selected_model, + "messages": message_history, + "temperature": chosen_temperature, + }, + timeout=30, + headers=headers, + ) + + assert response.choices[0].message.content == response_text_content From 8ebd4d6ce43d5ec90547d759b59faf50fcce506c Mon Sep 17 00:00:00 2001 From: Akbar Nurlybayev Date: Sat, 4 Jan 2025 16:36:54 -0700 Subject: [PATCH 4/4] Replace deprecated Llama 405B with 3.3 70B --- guides/centml.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guides/centml.md b/guides/centml.md index 10a47e12..8914c2cc 100644 --- a/guides/centml.md +++ b/guides/centml.md @@ -12,7 +12,7 @@ import aisuite as ai client = ai.Client() provider = "centml" -model_id = "meta-llama/Llama-3.1-405B-Instruct-FP8" +model_id = "meta-llama/Llama-3.3-70B-Instruct" messages = [ {"role": "system", "content": "You are a helpful assistant."},