From fbe2376255fd8e14987718e93fcfc3ad260e332d Mon Sep 17 00:00:00 2001 From: icyqwq Date: Wed, 22 Jan 2025 10:29:56 +0800 Subject: [PATCH] feat: Chat completion supports multi-modal input (images and text). Text-to-speech supports streaming. --- components/openai/OpenAI.c | 231 ++++++++++++++++-- components/openai/include/OpenAI.h | 20 +- .../openai/test_apps/main/test_openai.c | 42 +++- 3 files changed, 268 insertions(+), 25 deletions(-) diff --git a/components/openai/OpenAI.c b/components/openai/OpenAI.c index 04cd9d735..261254670 100644 --- a/components/openai/OpenAI.c +++ b/components/openai/OpenAI.c @@ -130,6 +130,7 @@ typedef struct { char *(*del)(const char *base_url, const char *api_key, const char *endpoint); /*!< Perform an HTTP DELETE request. */ char *(*post)(const char *base_url, const char *api_key, const char *endpoint, char *jsonBody); /*!< Perform an HTTP POST request. */ char *(*speechpost)(const char *base_url, const char *api_key, const char *endpoint, char *jsonBody, size_t *output_len); /*!< Perform an HTTP POST request for speech. */ + char *(*speechpost_stream)(const char *base_url, const char *api_key, const char *endpoint, char *jsonBody, size_t *output_len, OpenAI_StreamCallback stream_callback); /*!< Perform an HTTP POST request for speech and stream the response to a callback. */ char *(*upload)(const char *base_url, const char *api_key, const char *endpoint, const char *boundary, uint8_t *data, size_t len); /*!< Upload data using an HTTP request.
*/ } _OpenAI_t; @@ -1057,35 +1058,106 @@ static void OpenAI_ChatCompletionClearConversation(OpenAI_ChatCompletion_t *chat } } -static cJSON *createChatMessage(cJSON *messages, const char *role, const char *content) +static cJSON *createContentObject(const char *type, const char *value) +{ + cJSON *content_obj = cJSON_CreateObject(); + if (!content_obj) { + ESP_LOGE(TAG, "Failed to create content_obj!"); + return NULL; + } + + if (cJSON_AddStringToObject(content_obj, "type", type) == NULL) { + cJSON_Delete(content_obj); + ESP_LOGE(TAG, "Failed to add 'type' field!"); + return NULL; + } + + if (strcmp(type, "text") == 0) { + if (cJSON_AddStringToObject(content_obj, "text", value) == NULL) { + cJSON_Delete(content_obj); + ESP_LOGE(TAG, "Failed to add 'text' field!"); + return NULL; + } + } else if (strcmp(type, "image_url") == 0) { + cJSON *image_url_obj = cJSON_CreateObject(); + if (!image_url_obj) { + cJSON_Delete(content_obj); + ESP_LOGE(TAG, "Failed to create image_url_obj!"); + return NULL; + } + if (cJSON_AddStringToObject(image_url_obj, "url", value) == NULL) { + cJSON_Delete(content_obj); + cJSON_Delete(image_url_obj); + ESP_LOGE(TAG, "Failed to add 'url' field!"); + return NULL; + } + if (!cJSON_AddItemToObject(content_obj, "image_url", image_url_obj)) { + cJSON_Delete(content_obj); + cJSON_Delete(image_url_obj); + ESP_LOGE(TAG, "Failed to add image_url_obj to content_obj!"); + return NULL; + } + } else { + ESP_LOGW(TAG, "Unknown type: %s, skip building extra fields", type); + } + + return content_obj; +} + +static cJSON *createChatMessage(const char *role, const char *type, const char *value) { cJSON *message = cJSON_CreateObject(); - OPENAI_ERROR_CHECK(message != NULL, "cJSON_CreateObject failed!", NULL); + if (!message) { + ESP_LOGE(TAG, "Failed to create message object!"); + return NULL; + } if (cJSON_AddStringToObject(message, "role", role) == NULL) { cJSON_Delete(message); - ESP_LOGE(TAG, "cJSON_AddStringToObject failed!"); + ESP_LOGE(TAG, "Failed to 
add role field!"); + return NULL; + } + + cJSON *content_arr = cJSON_CreateArray(); + if (!content_arr) { + cJSON_Delete(message); + ESP_LOGE(TAG, "Failed to create content array!"); + return NULL; + } + + cJSON *content_obj = createContentObject(type, value); + if (!content_obj) { + cJSON_Delete(message); + cJSON_Delete(content_arr); return NULL; } - if (cJSON_AddStringToObject(message, "content", content) == NULL) { + + if (!cJSON_AddItemToArray(content_arr, content_obj)) { cJSON_Delete(message); - ESP_LOGE(TAG, "cJSON_AddStringToObject failed!"); + cJSON_Delete(content_arr); + cJSON_Delete(content_obj); + ESP_LOGE(TAG, "Failed to add content_obj to content array!"); return NULL; } - if (!cJSON_AddItemToArray(messages, message)) { + + if (!cJSON_AddItemToObject(message, "content", content_arr)) { cJSON_Delete(message); - ESP_LOGE(TAG, "cJSON_AddItemToArray failed!"); + cJSON_Delete(content_arr); + ESP_LOGE(TAG, "Failed to add content array to message!"); return NULL; } + return message; } -OpenAI_StringResponse_t *OpenAI_ChatCompletionMessage(OpenAI_ChatCompletion_t *chatCompletion, const char *p, bool save) +OpenAI_StringResponse_t *OpenAI_ChatCompletionMessage(OpenAI_ChatCompletion_t *chatCompletion, const char *type, const char *contentValue, bool save) { + const char *role = "user"; const char *endpoint = "chat/completions"; OpenAI_StringResponse_t *result = NULL; cJSON *req = cJSON_CreateObject(); OPENAI_ERROR_CHECK(req != NULL, "cJSON_CreateObject failed!", result); + _OpenAI_ChatCompletion_t *_chatCompletion = __containerof(chatCompletion, _OpenAI_ChatCompletion_t, parent); reqAddString("model", (_chatCompletion->model == NULL) ? 
"gpt-3.5-turbo" : _chatCompletion->model); @@ -1096,11 +1168,19 @@ OpenAI_StringResponse_t *OpenAI_ChatCompletionMessage(OpenAI_ChatCompletion_t *c ESP_LOGE(TAG, "cJSON_CreateArray failed!"); return result; } - if (_chatCompletion->description != NULL) { - if (createChatMessage(_messages, "system", _chatCompletion->description) == NULL) { + if (_chatCompletion->description) { + cJSON *system_msg = createChatMessage("system", "text", _chatCompletion->description); + if (!system_msg) { + cJSON_Delete(req); + cJSON_Delete(_messages); + ESP_LOGE(TAG, "Failed to create system_msg!"); + return result; + } + if (!cJSON_AddItemToArray(_messages, system_msg)) { cJSON_Delete(req); cJSON_Delete(_messages); - ESP_LOGE(TAG, "createChatMessage failed!"); + cJSON_Delete(system_msg); + ESP_LOGE(TAG, "Failed to add system_msg!"); return result; } } @@ -1118,10 +1198,18 @@ OpenAI_StringResponse_t *OpenAI_ChatCompletionMessage(OpenAI_ChatCompletion_t *c } } } - if (createChatMessage(_messages, "user", p) == NULL) { + cJSON *new_msg = createChatMessage(role, type, contentValue); + if (!new_msg) { + cJSON_Delete(req); + cJSON_Delete(_messages); + ESP_LOGE(TAG, "Failed to create new_msg!"); + return result; + } + if (!cJSON_AddItemToArray(_messages, new_msg)) { cJSON_Delete(req); cJSON_Delete(_messages); - ESP_LOGE(TAG, "createChatMessage failed!"); + cJSON_Delete(new_msg); + ESP_LOGE(TAG, "Failed to add new_msg!"); return result; } @@ -1156,12 +1244,13 @@ OpenAI_StringResponse_t *OpenAI_ChatCompletionMessage(OpenAI_ChatCompletion_t *c //add the responses to the messages here //double parsing is here as workaround OpenAI_StringResponse_t *r = OpenAI_StringResponseCreate(res); - if (r->getLen(r)) { - if (createChatMessage(_chatCompletion->messages, "user", p) == NULL) { - ESP_LOGE(TAG, "createChatMessage failed!"); - } - if (createChatMessage(_chatCompletion->messages, "assistant", r->getData(r, 0)) == NULL) { - ESP_LOGE(TAG, "createChatMessage failed!"); + if (r && r->getLen(r)) { + 
const char *assistant_text = r->getData(r, 0); + cJSON *assistant_msg = createChatMessage("assistant", "text", assistant_text); + if (assistant_msg) { + cJSON_AddItemToArray(_chatCompletion->messages, assistant_msg); + } else { + ESP_LOGE(TAG, "Failed to create assistant_msg!"); } } return r; @@ -1782,7 +1871,7 @@ static const char *audio_input_mime[] = { "audio/webm" }; -static const char *audio_speech_formats[] = {"mp3", "opus", "aac", "flac"}; +static const char *audio_speech_formats[] = {"mp3", "opus", "aac", "flac", "wav", "pcm"}; /** * @brief Gives audio from the input text. @@ -1845,7 +1934,7 @@ static void OpenAI_AudioSpeechSetSpeed(OpenAI_AudioSpeech_t *speech, float t) static void OpenAI_AudioSpeechSetResponseFormat(OpenAI_AudioSpeech_t *audioCreateSpeech, OpenAI_Audio_Output_Format rf) { _OpenAI_AudioSpeech_t *_audioCreateSpeech = __containerof(audioCreateSpeech, _OpenAI_AudioSpeech_t, parent); - if (rf >= OPENAI_AUDIO_OUTPUT_FORMAT_MP3 && rf <= OPENAI_AUDIO_OUTPUT_FORMAT_FLAC) { + if (rf >= OPENAI_AUDIO_OUTPUT_FORMAT_MP3 && rf < OPENAI_AUDIO_OUTPUT_FORMAT_MAX) { _audioCreateSpeech->response_format = rf; } } @@ -1938,6 +2027,31 @@ OpenAI_SpeechResponse_t *OpenAI_AudioSpeechMessage(OpenAI_AudioSpeech_t *audioSp return OpenAI_SpeechResponseCreate(res, dataLength); } +OpenAI_SpeechResponse_t *OpenAI_AudioSpeechMessageStream(OpenAI_AudioSpeech_t *audioSpeech, char *p, OpenAI_StreamCallback stream_callback) +{ + size_t dataLength = 0; + const char *endpoint = "audio/speech"; + OpenAI_SpeechResponse_t *result = NULL; + cJSON *req = cJSON_CreateObject(); + OPENAI_ERROR_CHECK(req != NULL, "cJSON_CreateObject failed!", NULL); + _OpenAI_AudioSpeech_t *_audioSpeech = __containerof(audioSpeech, _OpenAI_AudioSpeech_t, parent); + reqAddString("model", (_audioSpeech->model == NULL) ? "tts-1" : _audioSpeech->model); + reqAddString("input", p); + reqAddString("voice", (_audioSpeech->voice == NULL) ? 
"alloy" : _audioSpeech->voice); + if (_audioSpeech->response_format != OPENAI_AUDIO_OUTPUT_FORMAT_MP3) { + reqAddString("response_format", audio_speech_formats[_audioSpeech->response_format]); + } + if (_audioSpeech->speed != 1.0) { + reqAddNumber("speed", _audioSpeech->speed); + } + char *jsonBody = cJSON_Print(req); + ESP_LOGD(TAG, "json body for Speech Message %s", jsonBody); + cJSON_Delete(req); + char *res = _audioSpeech->oai->speechpost_stream(_audioSpeech->oai->base_url, _audioSpeech->oai->api_key, endpoint, jsonBody, &dataLength, stream_callback); + free(jsonBody); + return NULL; +} + static OpenAI_AudioSpeech_t *OpenAI_AudioSpeechCreate(OpenAI_t *openai) { _OpenAI_AudioSpeech_t *_audioCreateSpeech = (_OpenAI_AudioSpeech_t *)calloc(1, sizeof(_OpenAI_AudioSpeech_t)); @@ -2486,6 +2600,80 @@ static char *OpenAI_Speech_Post(const char *base_url, const char *api_key, const return OpenAI_Speech_Request(base_url, api_key, endpoint, "application/json", HTTP_METHOD_POST, NULL, (uint8_t *)jsonBody, strlen(jsonBody), output_len); } +static char *OpenAI_Speech_Request_Stream(const char *base_url, const char *api_key, const char *endpoint, const char *content_type, esp_http_client_method_t method, const char *boundary, uint8_t *data, size_t len, size_t *output_len, OpenAI_StreamCallback stream_callback) +{ + ESP_LOGD(TAG, "\"%s\", len=%u", endpoint, len); + char *url = NULL; + asprintf(&url, "%s%s", base_url, endpoint); + OPENAI_ERROR_CHECK(url != NULL, "Failed to allocate url!", NULL); + esp_http_client_config_t config = { + .url = url, + .method = method, + .timeout_ms = 60000, + .crt_bundle_attach = esp_crt_bundle_attach, + }; + esp_http_client_handle_t client = esp_http_client_init(&config); + char *headers = NULL; + if (boundary) { + asprintf(&headers, "%s; boundary=%s", content_type, boundary); + } else { + asprintf(&headers, "%s", content_type); + } + OPENAI_ERROR_CHECK_GOTO(headers != NULL, "Failed to allocate headers!", end); + 
esp_http_client_set_header(client, "Content-Type", headers); + ESP_LOGD(TAG, "headers:\r\n%s", headers); + free(headers); + + asprintf(&headers, "Bearer %s", api_key); + OPENAI_ERROR_CHECK_GOTO(headers != NULL, "Failed to allocate headers!", end); + esp_http_client_set_header(client, "Authorization", headers); + free(headers); + + esp_err_t err = esp_http_client_open(client, len); + ESP_LOGD(TAG, "data:\r\n%s", data); + + OPENAI_ERROR_CHECK_GOTO(err == ESP_OK, "Failed to open client!", end); + if (len > 0) { + int wlen = esp_http_client_write(client, (const char *)data, len); + OPENAI_ERROR_CHECK_GOTO(wlen >= 0, "Failed to write client!", end); + } + int content_length = esp_http_client_fetch_headers(client); + if (esp_http_client_is_chunked_response(client)) { + esp_http_client_get_chunk_length(client, &content_length); + } + ESP_LOGD(TAG, "chunk_length=%d", content_length); //4096 + OPENAI_ERROR_CHECK_GOTO(content_length > 0, "HTTP client fetch headers failed!", end); + + int read_len = 0; + *output_len = 0; + const uint32_t chunk_size = 1024 * 33; + uint8_t * chunk_data = (uint8_t *)malloc(chunk_size); + if (!chunk_data) { + ESP_LOGE(TAG, "Failed to allocate chunk_data"); + goto end; + } + do { + read_len = esp_http_client_read_response(client, (char*)chunk_data, chunk_size); + if (stream_callback) { + stream_callback(chunk_data, read_len); + } + *output_len += read_len; + ESP_LOGD(TAG, "HTTP_READ:=%d", read_len); + } while (read_len > 0); + ESP_LOGD(TAG, "output_len: %d\n", (int)*output_len); + free(chunk_data); +end: + free(url); + esp_http_client_close(client); + esp_http_client_cleanup(client); + return NULL; +} + +static char *OpenAI_Speech_Post_Stream(const char *base_url, const char *api_key, const char *endpoint, char *jsonBody, size_t *output_len, OpenAI_StreamCallback cb) +{ + return OpenAI_Speech_Request_Stream(base_url, api_key, endpoint, "application/json", HTTP_METHOD_POST, NULL, (uint8_t *)jsonBody, strlen(jsonBody), output_len, cb); +} + static 
char *OpenAI_Upload(const char *base_url, const char *api_key, const char *endpoint, const char *boundary, uint8_t *data, size_t len) { return OpenAI_Request(base_url, api_key, endpoint, "multipart/form-data", HTTP_METHOD_POST, boundary, data, len); @@ -2571,6 +2759,7 @@ OpenAI_t *OpenAICreate(const char *api_key) _oai->del = &OpenAI_Del; _oai->post = &OpenAI_Post; _oai->speechpost = &OpenAI_Speech_Post; + _oai->speechpost_stream = &OpenAI_Speech_Post_Stream; _oai->upload = &OpenAI_Upload; return &_oai->parent; } diff --git a/components/openai/include/OpenAI.h b/components/openai/include/OpenAI.h index 3c32fe900..3585c9139 100644 --- a/components/openai/include/OpenAI.h +++ b/components/openai/include/OpenAI.h @@ -50,9 +50,14 @@ typedef enum { OPENAI_AUDIO_OUTPUT_FORMAT_MP3, OPENAI_AUDIO_OUTPUT_FORMAT_OPUS, OPENAI_AUDIO_OUTPUT_FORMAT_AAC, - OPENAI_AUDIO_OUTPUT_FORMAT_FLAC + OPENAI_AUDIO_OUTPUT_FORMAT_FLAC, + OPENAI_AUDIO_OUTPUT_FORMAT_WAV, + OPENAI_AUDIO_OUTPUT_FORMAT_PCM, + OPENAI_AUDIO_OUTPUT_FORMAT_MAX, } OpenAI_Audio_Output_Format; +typedef void (*OpenAI_StreamCallback)(const uint8_t *data, size_t length); + /** * @brief Struct for Embedding data * @@ -456,11 +461,13 @@ typedef struct OpenAI_ChatCompletion { * @brief Send the message for completion. Save it with the first response if selected. 
* * @param chatCompletion[in] the point of OpenAI_ChatCompletion + * @param type[in] the content type of the message: "text" or "image_url" * @param p[in] the message for completion * @param save[in] save it with the first response if selected * @return OpenAI_StringResponse_t* */ - OpenAI_StringResponse_t *(*message)(struct OpenAI_ChatCompletion *chatCompletion, const char *p, bool save); + OpenAI_StringResponse_t *(*message)(struct OpenAI_ChatCompletion *chatCompletion, const char *type, const char *p, bool save); + } OpenAI_ChatCompletion_t; /** @@ -753,6 +760,15 @@ typedef struct OpenAI_AudioSpeech { */ OpenAI_SpeechResponse_t *(*speech)(struct OpenAI_AudioSpeech *createSpeech, char *p); + /** + * @brief Send the text for speech generation and stream the received audio to a callback. + * + * @param createSpeech[in] the point of OpenAI_AudioSpeech + * @param p[in] the message for audio generation + * @param stream_callback[in] the callback invoked with each chunk of audio data as it is received + */ + void (*speechStream)(struct OpenAI_AudioSpeech *createSpeech, char *p, OpenAI_StreamCallback stream_callback); + } OpenAI_AudioSpeech_t; /** diff --git a/components/openai/test_apps/main/test_openai.c b/components/openai/test_apps/main/test_openai.c index 680fc4631..6e4cf3fee 100644 --- a/components/openai/test_apps/main/test_openai.c +++ b/components/openai/test_apps/main/test_openai.c @@ -53,7 +53,7 @@ TEST_CASE("test ChatCompletion", "[ChatCompletion]") chatCompletion->setFrequencyPenalty(chatCompletion, 0); //float between -2.0 and 2.0. Positive values decrease the model's likelihood to repeat the same line verbatim. chatCompletion->setUser(chatCompletion, "OpenAI-ESP32"); //A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
// chinese - OpenAI_StringResponse_t *result = chatCompletion->message(chatCompletion, "给我讲一个笑话", false); + OpenAI_StringResponse_t *result = chatCompletion->message(chatCompletion, "text", "给我讲一个笑话", false); TEST_ASSERT_NOT_NULL(result); if (result->getLen(result) == 1) { ESP_LOGI(TAG, "Received message. Tokens: %"PRIu32"", result->getUsage(result)); @@ -72,7 +72,25 @@ TEST_CASE("test ChatCompletion", "[ChatCompletion]") } result->deleteResponse(result); // english - result = chatCompletion->message(chatCompletion, "tell me a joke", false); + result = chatCompletion->message(chatCompletion, "text", "tell me a joke", false); + TEST_ASSERT_NOT_NULL(result); + if (result->getLen(result) == 1) { + ESP_LOGI(TAG, "Received message. Tokens: %"PRIu32"", result->getUsage(result)); + char *response = result->getData(result, 0); + ESP_LOGI(TAG, "%s", response); + } else if (result->getLen(result) > 1) { + ESP_LOGI(TAG, "Received %"PRIu32" messages. Tokens: %"PRIu32"", result->getLen(result), result->getUsage(result)); + for (int i = 0; i < result->getLen(result); ++i) { + char *response = result->getData(result, i); + ESP_LOGI(TAG, "Message[%d]: %s", i, response); + } + } else if (result->getError(result)) { + ESP_LOGE(TAG, "Error! %s", result->getError(result)); + } else { + ESP_LOGE(TAG, "Unknown error!"); + } + // image + result = chatCompletion->message(chatCompletion, "image_url", "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", false); TEST_ASSERT_NOT_NULL(result); if (result->getLen(result) == 1) { ESP_LOGI(TAG, "Received message. 
Tokens: %"PRIu32"", result->getUsage(result)); @@ -148,6 +166,15 @@ TEST_CASE("test AudioTranscription cn", "[AudioTranscription]") vTaskDelay(1000 / portTICK_PERIOD_MS); } +uint8_t *speech_stream_data = NULL; +size_t speech_stream_len = 0; +static void on_stream(const uint8_t *data, size_t length) +{ + speech_stream_data = (uint8_t*)realloc(speech_stream_data, speech_stream_len + length); + memcpy(speech_stream_data + speech_stream_len, data, length); + speech_stream_len += length; +} + TEST_CASE("test AudioSpeech", "[AudioSpeech]") { /* This helper function configures Wi-Fi or Ethernet, as selected in menuconfig. @@ -184,6 +211,16 @@ TEST_CASE("test AudioSpeech", "[AudioSpeech]") TEST_ASSERT_NOT_NULL(finaltext); ESP_LOGI(TAG, "Final Text: %s", finaltext); TEST_ASSERT_TRUE(strcmp(giventext, finaltext) == 0); + + /*stream mode*/ + audioSpeech->speechStream(audioSpeech, giventext, on_stream); + TEST_ASSERT_NOT_NULL(speech_stream_data); + char *finaltext2 = audioTranscription->file(audioTranscription, (uint8_t *)speech_stream_data, speech_stream_len, OPENAI_AUDIO_INPUT_FORMAT_MP3); + TEST_ASSERT_NOT_NULL(finaltext2); + ESP_LOGI(TAG, "Final Text: %s", finaltext2); + TEST_ASSERT_TRUE(strcmp(giventext, finaltext2) == 0); + + free(speech_stream_data); free(giventext); free(finaltext); openai->audioTranscriptionDelete(audioTranscription); @@ -251,3 +288,4 @@ void app_main(void) unity_run_menu(); } +