-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
100 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,71 +1,109 @@ | ||
import unittest | ||
from unittest.mock import patch | ||
import os | ||
import logging | ||
|
||
import pytest | ||
import asyncio | ||
from dotenv import load_dotenv | ||
from datetime import datetime | ||
import os | ||
from unittest.mock import patch, Mock | ||
from typing import Dict, Union, List | ||
import openai | ||
import pinecone | ||
import backoff | ||
from pinembed import EnvConfig, OpenAIHandler, PineconeHandler, DataStreamHandler | ||
|
||
# Load environment variables from .env file | ||
load_dotenv() | ||
|
||
# Initialize logging | ||
logging.basicConfig(level=logging.INFO) | ||
logger = logging.getLogger(__name__) | ||
|
||
class EnvConfig: | ||
"""Class for handling environment variables and API keys.""" | ||
|
||
def __init__(self) -> None: | ||
"""Initialize environment variables.""" | ||
self.openai_key: str = os.getenv("OPENAI_API_KEY") | ||
self.pinecone_key: str = os.getenv("PINECONE_API_KEY") | ||
self.pinecone_environment: str = os.getenv("PINECONE_ENVIRONMENT") | ||
self.pinecone_environment: str = os.getenv("PINEDEX") | ||
@pytest.fixture | ||
def mock_env_config(): | ||
"""Fixture for setting up a mock environment configuration. | ||
Mocks os.getenv to return a test value and initializes EnvConfig. | ||
Returns: | ||
EnvConfig: Mocked environment configuration | ||
""" | ||
with patch('os.getenv', return_value="test_value"): | ||
config = EnvConfig() | ||
return config | ||
|
||
class OpenAIHandler: | ||
"""Class for handling OpenAI operations.""" | ||
@pytest.mark.parametrize("env_value, expected_value", [("test_value", "test_value"), (None, None)]) | ||
def test_EnvConfig_init(env_value, expected_value, mock_env_config): | ||
"""Test initialization of EnvConfig. | ||
Tests if the EnvConfig is correctly initialized with environment variables. | ||
Args: | ||
env_value (str or None): Mock environment variable value | ||
expected_value (str or None): Expected value for EnvConfig attributes | ||
mock_env_config (EnvConfig): Mocked environment configuration | ||
""" | ||
assert mock_env_config.openai_key == expected_value | ||
assert mock_env_config.pinecone_key == expected_value | ||
|
||
def __init__(self, config: EnvConfig) -> None: | ||
"""Initialize OpenAI API key.""" | ||
openai.api_key = config.openai_key | ||
|
||
@backoff.on_exception(backoff.expo, Exception, max_tries=3) | ||
async def create_embedding(self, input_text: str) -> Dict[str, Union[int, List[float]]]: | ||
""" | ||
Create an embedding using OpenAI. | ||
@pytest.mark.asyncio | ||
@pytest.mark.parallel | ||
async def test_OpenAIHandler_create_embedding(mock_env_config): | ||
"""Asynchronous test for creating embeddings via OpenAIHandler. | ||
Tests if OpenAIHandler.create_embedding method correctly returns mock response. | ||
Args: | ||
mock_env_config (EnvConfig): Mocked environment configuration | ||
""" | ||
handler = OpenAIHandler(mock_env_config) | ||
mock_response = {"id": 1, "values": [0.1, 0.2, 0.3]} | ||
|
||
with patch.object(handler.openai.Embedding, 'create', return_value=mock_response): | ||
response = await handler.create_embedding("test_text") | ||
|
||
Parameters: | ||
input_text (str): The text to be embedded. | ||
Returns: | ||
Dict[str, Union[int, List[float]]]: The embedding response. | ||
""" | ||
response = openai.Embedding.create( | ||
model="text-embedding-ada-002",engine="ada", | ||
text=input_text, | ||
) | ||
return response | ||
assert response == mock_response | ||
|
||
# Create test class | ||
class TestOpenAIHandler(unittest.TestCase): | ||
# Set up test environment | ||
def setUp(self): | ||
self.config = EnvConfig() | ||
self.openai_handler = OpenAIHandler(self.config) | ||
@pytest.mark.parallel | ||
def test_PineconeHandler_init(mock_env_config): | ||
"""Test initialization of PineconeHandler. | ||
Tests if PineconeHandler is correctly initialized with environment variables. | ||
Args: | ||
mock_env_config (EnvConfig): Mocked environment configuration | ||
""" | ||
handler = PineconeHandler(mock_env_config) | ||
handler.pinecone.init.assert_called_with(api_key="test_value", environment="test_value") | ||
assert handler.index_name == "test_value" | ||
|
||
# Test create_embedding method | ||
@patch('openai.Embedding.create') | ||
def test_create_embedding(self, mock_create): | ||
input_text = 'This is a test' | ||
expected_response = {'id': 12345, 'embedding': [1.0, 2.0, 3.0]} | ||
mock_create.return_value = expected_response | ||
response = self.openai_handler.create_embedding(input_text) | ||
self.assertEqual(response, expected_response) | ||
@pytest.mark.asyncio | ||
@pytest.mark.parallel | ||
async def test_PineconeHandler_upload_embedding(mock_env_config): | ||
"""Asynchronous test for uploading embeddings via PineconeHandler. | ||
Tests if PineconeHandler.upload_embedding method correctly calls pinecone.Index.upsert. | ||
Args: | ||
mock_env_config (EnvConfig): Mocked environment configuration | ||
""" | ||
handler = PineconeHandler(mock_env_config) | ||
mock_embedding = { | ||
"id": "1", | ||
"values": [0.1, 0.2, 0.3], | ||
"metadata": {}, | ||
"sparse_values": {} | ||
} | ||
|
||
with patch.object(handler.pinecone.Index, 'upsert', return_value=None): | ||
await handler.upload_embedding(mock_embedding) | ||
|
||
handler.pinecone.Index.assert_called_with("test_value") | ||
|
||
if __name__ == "__main__": | ||
unittest.main() | ||
@pytest.mark.asyncio | ||
@pytest.mark.parallel | ||
async def test_DataStreamHandler_process_data(mock_env_config): | ||
"""Asynchronous test for processing data via DataStreamHandler. | ||
Tests if DataStreamHandler.process_data method correctly calls methods of OpenAIHandler and PineconeHandler. | ||
Args: | ||
mock_env_config (EnvConfig): Mocked environment configuration | ||
""" | ||
openai_handler = OpenAIHandler(mock_env_config) | ||
pinecone_handler = PineconeHandler(mock_env_config) | ||
handler = DataStreamHandler(openai_handler, pinecone_handler) | ||
|
||
mock_data = "test_data" | ||
mock_embedding = {"id": 1, "values": [0.1, 0.2, 0.3]} | ||
|
||
with patch.object(OpenAIHandler, 'create_embedding', return_value=mock_embedding): | ||
with patch.object(PineconeHandler, 'upload_embedding', return_value=None): | ||
await handler.process_data(mock_data) |