From c0d7204c53ba15186405057b739bf07bd48b881e Mon Sep 17 00:00:00 2001 From: bookfere Date: Mon, 4 Nov 2024 18:26:03 +0800 Subject: [PATCH] feat: Validate a custom model for batch translation in ChatGPT. --- engines/openai.py | 30 +++++------------- tests/test_engine.py | 73 +++++++++++++++++++++++++++----------------- 2 files changed, 52 insertions(+), 51 deletions(-) diff --git a/engines/openai.py b/engines/openai.py index 3eaa7ce..ebd700a 100644 --- a/engines/openai.py +++ b/engines/openai.py @@ -116,28 +116,6 @@ def _parse_stream(self, response): class ChatgptBatchTranslate: """https://cookbook.openai.com/examples/batch_processing""" - - supported_models = [ - 'gpt-4o', - 'gpt-4-turbo', - 'gpt-4', - 'gpt-4-32k', - 'gpt-3.5-turbo', - 'gpt-3.5-turbo-16k', - 'gpt-4-turbo-preview', - 'gpt-4-vision-preview', - 'gpt-4-turbo-2024-04-09', - 'gpt-4-0314', - 'gpt-4-32k-0314', - 'gpt-4-32k-0613', - 'gpt-3.5-turbo-0301', - 'gpt-3.5-turbo-16k-0613', - 'gpt-3.5-turbo-1106', - 'gpt-3.5-turbo-0613', - 'text-embedding-3-large', - 'text-embedding-3-small', - 'text-embedding-ada-002', - ] boundary = uuid.uuid4().hex def __init__(self, translator): @@ -146,6 +124,7 @@ def __init__(self, translator): domain_name = '://'.join( urlsplit(self.translator.endpoint, 'https')[:2]) + self.model_endpint = '%s/v1/models' % domain_name self.file_endpoint = '%s/v1/files' % domain_name self.batch_endpoint = '%s/v1/batches' % domain_name @@ -166,6 +145,11 @@ def _create_multipart_form_data(self, body): data.append('--%s--' % self.boundary) return '\r\n'.join(data).encode('utf-8') + def supported_models(self): + response = request( + self.model_endpint, headers=self.translator.get_headers()) + return [item['id'] for item in json.loads(response).get('data')] + def headers(self, extra_headers={}): headers = self.translator.get_headers() headers.update(extra_headers) @@ -178,7 +162,7 @@ def upload(self, paragraphs): """Upload the original content and retrieve the file id. https://platform.openai.com/docs/api-reference/files/create """ - if self.translator.model not in self.supported_models: + if self.translator.model not in self.supported_models(): raise UnsupportedModel( 'The model "{}" does not support batch functionality.' .format(self.translator.model)) diff --git a/tests/test_engine.py b/tests/test_engine.py index 2136718..b1784dc 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -384,7 +384,7 @@ def setUp(self): def test_get_body(self): self.assertEqual(self.translator.get_body('test content'), json.dumps({ - 'model': 'gpt-3.5-turbo', + 'model': 'gpt-4o', 'messages': [ { 'role': 'system', @@ -406,7 +406,7 @@ def test_get_body_without_stream(self): self.assertEqual( self.translator.get_body('test content'), json.dumps({ - 'model': 'gpt-3.5-turbo', + 'model': 'gpt-4o', 'messages': [ { 'role': 'system', @@ -433,7 +433,7 @@ def test_translate_stream(self, mock_request, mock_et): 'only. Do not explain any term or answer any question-like ' 'content.') data = json.dumps({ - 'model': 'gpt-3.5-turbo', + 'model': 'gpt-4o', 'messages': [ {'role': 'system', 'content': prompt}, {'role': 'user', 'content': 'Hello World!'}], @@ -479,26 +479,6 @@ def setUp(self): self.batch_translator = ChatgptBatchTranslate(self.mock_translator) def test_class_object(self): - self.assertEqual(ChatgptBatchTranslate.supported_models, [ - 'gpt-4o', - 'gpt-4-turbo', - 'gpt-4', - 'gpt-4-32k', - 'gpt-3.5-turbo', - 'gpt-3.5-turbo-16k', - 'gpt-4-turbo-preview', - 'gpt-4-vision-preview', - 'gpt-4-turbo-2024-04-09', - 'gpt-4-0314', - 'gpt-4-32k-0314', - 'gpt-4-32k-0613', - 'gpt-3.5-turbo-0301', - 'gpt-3.5-turbo-16k-0613', - 'gpt-3.5-turbo-1106', - 'gpt-3.5-turbo-0613', - 'text-embedding-3-large', - 'text-embedding-3-small', - 'text-embedding-ada-002']) self.assertRegex(ChatgptBatchTranslate.boundary, r'(?a)^\w+$') def test_created_translator(self): @@ -512,7 +492,41 @@ def test_created_translator(self): self.batch_translator.batch_endpoint, 'https://api.openai.com/v1/batches') - def test_upload_with_unsupported_model(self): + @patch(module_name + '.openai.request') + def test_supportd_models(self, mock_request): + mock_request.return_value = """ +{ + "object": "list", + "data": [ + { + "id": "model-id-0", + "object": "model", + "created": 1686935002, + "owned_by": "organization-owner" + }, + { + "id": "model-id-1", + "object": "model", + "created": 1686935002, + "owned_by": "organization-owner" + }, + { + "id": "model-id-2", + "object": "model", + "created": 1686935002, + "owned_by": "openai" + } + ], + "object": "list" +} +""" + self.assertEqual( + self.batch_translator.supported_models(), + ['model-id-0', 'model-id-1', 'model-id-2']) + + @patch(module_name + '.openai.ChatgptBatchTranslate.supported_models') + def test_upload_with_unsupported_model(self, mock_suppored_models): + mock_suppored_models.return_value = ['gpt-4o'] self.mock_translator.model = 'fake-model' self.mock_translator.stream = True with self.assertRaises(UnsupportedModel) as cm: @@ -522,8 +536,9 @@ def test_upload_with_unsupported_model(self): 'The model "fake-model" does not support batch functionality.') @patch.object(ChatgptBatchTranslate, 'boundary', new='xxxxxxxxxx') + @patch(module_name + '.openai.ChatgptBatchTranslate.supported_models') @patch(module_name + '.openai.request') - def test_upload(self, mock_request): + def test_upload(self, mock_request, mock_suppored_models): mock_request.return_value = """ { "id": "test-file-id", @@ -534,6 +549,8 @@ def test_upload(self, mock_request): "purpose": "fine-tune" } """ + mock_suppored_models.return_value = ['gpt-4o'] + mock_paragraph_1 = Mock(Paragraph) mock_paragraph_1.md5 = 'abc' mock_paragraph_1.original = 'test content 1' @@ -545,7 +562,7 @@ def test_upload(self, mock_request): def mock_get_body(text): return json.dumps({ - 'model': 'gpt-3.5-turbo', + 'model': 'gpt-4o', 'messages': [ {'role': 'system', 'content': 'some prompt...'}, {'role': 'user', 'content': text}], @@ -566,13 +583,13 @@ def mock_get_body(text): 'Content-Type: application/json\r\n' '\r\n{"custom_id": "abc", "method": "POST", ' '"url": "/v1/chat/completions", ' - '"body": {"model": "gpt-3.5-turbo", ' + '"body": {"model": "gpt-4o", ' '"messages": [{"role": "system", ' '"content": "some prompt..."}, {"role": "user", ' '"content": "test content 1"}], "temperature": 1.0}}\n' '{"custom_id": "def", "method": "POST", ' '"url": "/v1/chat/completions", ' - '"body": {"model": "gpt-3.5-turbo", ' + '"body": {"model": "gpt-4o", ' '"messages": [{"role": "system", ' '"content": "some prompt..."}, {"role": "user", ' '"content": "test content 2"}], "temperature": 1.0}}\r\n'