Skip to content

Commit

Permalink
feat: Chat completion supports multi-modal input (images and text).
Browse files Browse the repository at this point in the history
Text-to-speech supports streaming.
  • Loading branch information
icyqwq committed Jan 22, 2025
1 parent c880bff commit fbe2376
Show file tree
Hide file tree
Showing 3 changed files with 268 additions and 25 deletions.
231 changes: 210 additions & 21 deletions components/openai/OpenAI.c
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ typedef struct {
char *(*del)(const char *base_url, const char *api_key, const char *endpoint); /*!< Perform an HTTP DELETE request. */
char *(*post)(const char *base_url, const char *api_key, const char *endpoint, char *jsonBody); /*!< Perform an HTTP POST request. */
char *(*speechpost)(const char *base_url, const char *api_key, const char *endpoint, char *jsonBody, size_t *output_len); /*!< Perform an HTTP POST request for speech. */
char *(*speechpost_stream)(const char *base_url, const char *api_key, const char *endpoint, char *jsonBody, size_t *output_len, OpenAI_StreamCallback stream_callback); /*!< Perform an HTTP POST request for stream speech. */
char *(*upload)(const char *base_url, const char *api_key, const char *endpoint, const char *boundary, uint8_t *data, size_t len); /*!< Upload data using an HTTP request. */
} _OpenAI_t;

Expand Down Expand Up @@ -1057,35 +1058,106 @@ static void OpenAI_ChatCompletionClearConversation(OpenAI_ChatCompletion_t *chat
}
}

static cJSON *createChatMessage(cJSON *messages, const char *role, const char *content)
static cJSON *createContentObject(const char *type, const char *value)
{
cJSON *content_obj = cJSON_CreateObject();
if (!content_obj) {
ESP_LOGE(TAG, "Failed to create content_obj!");
return NULL;
}

if (cJSON_AddStringToObject(content_obj, "type", type) == NULL) {
cJSON_Delete(content_obj);
ESP_LOGE(TAG, "Failed to add 'type' field!");
return NULL;
}

if (strcmp(type, "text") == 0) {
if (cJSON_AddStringToObject(content_obj, "text", value) == NULL) {
cJSON_Delete(content_obj);
ESP_LOGE(TAG, "Failed to add 'text' field!");
return NULL;
}
} else if (strcmp(type, "image_url") == 0) {
cJSON *image_url_obj = cJSON_CreateObject();
if (!image_url_obj) {
cJSON_Delete(content_obj);
ESP_LOGE(TAG, "Failed to create image_url_obj!");
return NULL;
}
if (cJSON_AddStringToObject(image_url_obj, "url", value) == NULL) {
cJSON_Delete(content_obj);
cJSON_Delete(image_url_obj);
ESP_LOGE(TAG, "Failed to add 'url' field!");
return NULL;
}
if (!cJSON_AddItemToObject(content_obj, "image_url", image_url_obj)) {
cJSON_Delete(content_obj);
cJSON_Delete(image_url_obj);
ESP_LOGE(TAG, "Failed to add image_url_obj to content_obj!");
return NULL;
}
} else {
ESP_LOGW(TAG, "Unknown type: %s, skip building extra fields", type);
}

return content_obj;
}

static cJSON *createChatMessage(const char *role, const char *type, const char *value)
{
cJSON *message = cJSON_CreateObject();
OPENAI_ERROR_CHECK(message != NULL, "cJSON_CreateObject failed!", NULL);
if (!message) {
ESP_LOGE(TAG, "Failed to create message object!");
return NULL;
}
if (cJSON_AddStringToObject(message, "role", role) == NULL) {
cJSON_Delete(message);
ESP_LOGE(TAG, "cJSON_AddStringToObject failed!");
ESP_LOGE(TAG, "Failed to add role field!");
return NULL;
}

cJSON *content_arr = cJSON_CreateArray();
if (!content_arr) {
cJSON_Delete(message);
ESP_LOGE(TAG, "Failed to create content array!");
return NULL;
}

cJSON *content_obj = createContentObject(type, value);
if (!content_obj) {
cJSON_Delete(message);
cJSON_Delete(content_arr);
return NULL;
}
if (cJSON_AddStringToObject(message, "content", content) == NULL) {

if (!cJSON_AddItemToArray(content_arr, content_obj)) {
cJSON_Delete(message);
ESP_LOGE(TAG, "cJSON_AddStringToObject failed!");
cJSON_Delete(content_arr);
cJSON_Delete(content_obj);
ESP_LOGE(TAG, "Failed to add content_obj to content array!");
return NULL;
}
if (!cJSON_AddItemToArray(messages, message)) {

if (!cJSON_AddItemToObject(message, "content", content_arr)) {
cJSON_Delete(message);
ESP_LOGE(TAG, "cJSON_AddItemToArray failed!");
cJSON_Delete(content_arr);
ESP_LOGE(TAG, "Failed to add content array to message!");
return NULL;
}

return message;
}

OpenAI_StringResponse_t *OpenAI_ChatCompletionMessage(OpenAI_ChatCompletion_t *chatCompletion, const char *p, bool save)
OpenAI_StringResponse_t *OpenAI_ChatCompletionMessage(OpenAI_ChatCompletion_t *chatCompletion, const char *type, const char *contentValue, bool save)
{
const char *role = "user";
const char *endpoint = "chat/completions";
OpenAI_StringResponse_t *result = NULL;

cJSON *req = cJSON_CreateObject();
OPENAI_ERROR_CHECK(req != NULL, "cJSON_CreateObject failed!", result);

_OpenAI_ChatCompletion_t *_chatCompletion = __containerof(chatCompletion, _OpenAI_ChatCompletion_t, parent);
reqAddString("model", (_chatCompletion->model == NULL) ? "gpt-3.5-turbo" : _chatCompletion->model);

Expand All @@ -1096,11 +1168,19 @@ OpenAI_StringResponse_t *OpenAI_ChatCompletionMessage(OpenAI_ChatCompletion_t *c
ESP_LOGE(TAG, "cJSON_CreateArray failed!");
return result;
}
if (_chatCompletion->description != NULL) {
if (createChatMessage(_messages, "system", _chatCompletion->description) == NULL) {
if (_chatCompletion->description) {
cJSON *system_msg = createChatMessage("system", "text", _chatCompletion->description);
if (!system_msg) {
cJSON_Delete(req);
cJSON_Delete(_messages);
ESP_LOGE(TAG, "Failed to create system_msg!");
return result;
}
if (!cJSON_AddItemToArray(_messages, system_msg)) {
cJSON_Delete(req);
cJSON_Delete(_messages);
ESP_LOGE(TAG, "createChatMessage failed!");
cJSON_Delete(system_msg);
ESP_LOGE(TAG, "Failed to add system_msg!");
return result;
}
}
Expand All @@ -1118,10 +1198,18 @@ OpenAI_StringResponse_t *OpenAI_ChatCompletionMessage(OpenAI_ChatCompletion_t *c
}
}
}
if (createChatMessage(_messages, "user", p) == NULL) {
cJSON *new_msg = createChatMessage(role, type, contentValue);
if (!new_msg) {
cJSON_Delete(req);
cJSON_Delete(_messages);
ESP_LOGE(TAG, "Failed to create new_msg!");
return result;
}
if (!cJSON_AddItemToArray(_messages, new_msg)) {
cJSON_Delete(req);
cJSON_Delete(_messages);
ESP_LOGE(TAG, "createChatMessage failed!");
cJSON_Delete(new_msg);
ESP_LOGE(TAG, "Failed to add new_msg!");
return result;
}

Expand Down Expand Up @@ -1156,12 +1244,13 @@ OpenAI_StringResponse_t *OpenAI_ChatCompletionMessage(OpenAI_ChatCompletion_t *c
//add the responses to the messages here
//double parsing is here as workaround
OpenAI_StringResponse_t *r = OpenAI_StringResponseCreate(res);
if (r->getLen(r)) {
if (createChatMessage(_chatCompletion->messages, "user", p) == NULL) {
ESP_LOGE(TAG, "createChatMessage failed!");
}
if (createChatMessage(_chatCompletion->messages, "assistant", r->getData(r, 0)) == NULL) {
ESP_LOGE(TAG, "createChatMessage failed!");
if (r && r->getLen(r)) {
const char *assistant_text = r->getData(r, 0);
cJSON *assistant_msg = createChatMessage("assistant", "text", assistant_text);
if (assistant_msg) {
cJSON_AddItemToArray(_chatCompletion->messages, assistant_msg);
} else {
ESP_LOGE(TAG, "Failed to create assistant_msg!");
}
}
return r;
Expand Down Expand Up @@ -1782,7 +1871,7 @@ static const char *audio_input_mime[] = {
"audio/webm"
};

static const char *audio_speech_formats[] = {"mp3", "opus", "aac", "flac"};
static const char *audio_speech_formats[] = {"mp3", "opus", "aac", "flac", "wav", "pcm"};

/**
* @brief Gives audio from the input text.
Expand Down Expand Up @@ -1845,7 +1934,7 @@ static void OpenAI_AudioSpeechSetSpeed(OpenAI_AudioSpeech_t *speech, float t)
static void OpenAI_AudioSpeechSetResponseFormat(OpenAI_AudioSpeech_t *audioCreateSpeech, OpenAI_Audio_Output_Format rf)
{
_OpenAI_AudioSpeech_t *_audioCreateSpeech = __containerof(audioCreateSpeech, _OpenAI_AudioSpeech_t, parent);
if (rf >= OPENAI_AUDIO_OUTPUT_FORMAT_MP3 && rf <= OPENAI_AUDIO_OUTPUT_FORMAT_FLAC) {
if (rf >= OPENAI_AUDIO_OUTPUT_FORMAT_MP3 && rf < OPENAI_AUDIO_OUTPUT_FORMAT_MAX) {
_audioCreateSpeech->response_format = rf;
}
}
Expand Down Expand Up @@ -1938,6 +2027,31 @@ OpenAI_SpeechResponse_t *OpenAI_AudioSpeechMessage(OpenAI_AudioSpeech_t *audioSp
return OpenAI_SpeechResponseCreate(res, dataLength);
}

OpenAI_SpeechResponse_t *OpenAI_AudioSpeechMessageStream(OpenAI_AudioSpeech_t *audioSpeech, char *p, OpenAI_StreamCallback stream_callback)
{
size_t dataLength = 0;
const char *endpoint = "audio/speech";
OpenAI_SpeechResponse_t *result = NULL;
cJSON *req = cJSON_CreateObject();
OPENAI_ERROR_CHECK(req != NULL, "cJSON_CreateObject failed!", NULL);
_OpenAI_AudioSpeech_t *_audioSpeech = __containerof(audioSpeech, _OpenAI_AudioSpeech_t, parent);
reqAddString("model", (_audioSpeech->model == NULL) ? "tts-1" : _audioSpeech->model);
reqAddString("input", p);
reqAddString("voice", (_audioSpeech->voice == NULL) ? "alloy" : _audioSpeech->voice);
if (_audioSpeech->response_format != OPENAI_AUDIO_OUTPUT_FORMAT_MP3) {
reqAddString("response_format", audio_speech_formats[_audioSpeech->response_format]);
}
if (_audioSpeech->speed != 1.0) {
reqAddNumber("speed", _audioSpeech->speed);
}
char *jsonBody = cJSON_Print(req);
ESP_LOGD(TAG, "json body for Speech Message %s", jsonBody);
cJSON_Delete(req);
char *res = _audioSpeech->oai->speechpost_stream(_audioSpeech->oai->base_url, _audioSpeech->oai->api_key, endpoint, jsonBody, &dataLength, stream_callback);
free(jsonBody);
return NULL;
}

static OpenAI_AudioSpeech_t *OpenAI_AudioSpeechCreate(OpenAI_t *openai)
{
_OpenAI_AudioSpeech_t *_audioCreateSpeech = (_OpenAI_AudioSpeech_t *)calloc(1, sizeof(_OpenAI_AudioSpeech_t));
Expand Down Expand Up @@ -2486,6 +2600,80 @@ static char *OpenAI_Speech_Post(const char *base_url, const char *api_key, const
return OpenAI_Speech_Request(base_url, api_key, endpoint, "application/json", HTTP_METHOD_POST, NULL, (uint8_t *)jsonBody, strlen(jsonBody), output_len);
}

static char *OpenAI_Speech_Request_Stream(const char *base_url, const char *api_key, const char *endpoint, const char *content_type, esp_http_client_method_t method, const char *boundary, uint8_t *data, size_t len, size_t *output_len, OpenAI_StreamCallback stream_callback)
{
ESP_LOGD(TAG, "\"%s\", len=%u", endpoint, len);
char *url = NULL;
asprintf(&url, "%s%s", base_url, endpoint);
OPENAI_ERROR_CHECK(url != NULL, "Failed to allocate url!", NULL);
esp_http_client_config_t config = {
.url = url,
.method = method,
.timeout_ms = 60000,
.crt_bundle_attach = esp_crt_bundle_attach,
};
esp_http_client_handle_t client = esp_http_client_init(&config);
char *headers = NULL;
if (boundary) {
asprintf(&headers, "%s; boundary=%s", content_type, boundary);
} else {
asprintf(&headers, "%s", content_type);
}
OPENAI_ERROR_CHECK_GOTO(headers != NULL, "Failed to allocate headers!", end);
esp_http_client_set_header(client, "Content-Type", headers);
ESP_LOGD(TAG, "headers:\r\n%s", headers);
free(headers);

asprintf(&headers, "Bearer %s", api_key);
OPENAI_ERROR_CHECK_GOTO(headers != NULL, "Failed to allocate headers!", end);
esp_http_client_set_header(client, "Authorization", headers);
free(headers);

esp_err_t err = esp_http_client_open(client, len);
ESP_LOGD(TAG, "data:\r\n%s", data);

OPENAI_ERROR_CHECK_GOTO(err == ESP_OK, "Failed to open client!", end);
if (len > 0) {
int wlen = esp_http_client_write(client, (const char *)data, len);
OPENAI_ERROR_CHECK_GOTO(wlen >= 0, "Failed to write client!", end);
}
int content_length = esp_http_client_fetch_headers(client);
if (esp_http_client_is_chunked_response(client)) {
esp_http_client_get_chunk_length(client, &content_length);
}
ESP_LOGD(TAG, "chunk_length=%d", content_length); //4096
OPENAI_ERROR_CHECK_GOTO(content_length > 0, "HTTP client fetch headers failed!", end);

int read_len = 0;
*output_len = 0;
const uint32_t chunk_size = 1024 * 33;
uint8_t * chunk_data = (uint8_t *)malloc(chunk_size);
if (!chunk_data) {
ESP_LOGE(TAG, "Failed to allocate chunk_data");
goto end;
}
do {
read_len = esp_http_client_read_response(client, (char*)chunk_data, chunk_size);
if (stream_callback) {
stream_callback(chunk_data, read_len);
}
*output_len += read_len;
ESP_LOGD(TAG, "HTTP_READ:=%d", read_len);
} while (read_len > 0);
ESP_LOGD(TAG, "output_len: %d\n", (int)*output_len);
free(chunk_data);
end:
free(url);
esp_http_client_close(client);
esp_http_client_cleanup(client);
return NULL;
}

static char *OpenAI_Speech_Post_Stream(const char *base_url, const char *api_key, const char *endpoint, char *jsonBody, size_t *output_len, OpenAI_StreamCallback cb)
{
return OpenAI_Speech_Request_Stream(base_url, api_key, endpoint, "application/json", HTTP_METHOD_POST, NULL, (uint8_t *)jsonBody, strlen(jsonBody), output_len, cb);
}

static char *OpenAI_Upload(const char *base_url, const char *api_key, const char *endpoint, const char *boundary, uint8_t *data, size_t len)
{
return OpenAI_Request(base_url, api_key, endpoint, "multipart/form-data", HTTP_METHOD_POST, boundary, data, len);
Expand Down Expand Up @@ -2571,6 +2759,7 @@ OpenAI_t *OpenAICreate(const char *api_key)
_oai->del = &OpenAI_Del;
_oai->post = &OpenAI_Post;
_oai->speechpost = &OpenAI_Speech_Post;
_oai->speechpost_stream = &OpenAI_Speech_Post_Stream;
_oai->upload = &OpenAI_Upload;
return &_oai->parent;
}
20 changes: 18 additions & 2 deletions components/openai/include/OpenAI.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,14 @@ typedef enum {
OPENAI_AUDIO_OUTPUT_FORMAT_MP3,
OPENAI_AUDIO_OUTPUT_FORMAT_OPUS,
OPENAI_AUDIO_OUTPUT_FORMAT_AAC,
OPENAI_AUDIO_OUTPUT_FORMAT_FLAC
OPENAI_AUDIO_OUTPUT_FORMAT_FLAC,
OPENAI_AUDIO_OUTPUT_FORMAT_WAV,
OPENAI_AUDIO_OUTPUT_FORMAT_PCM,
OPENAI_AUDIO_OUTPUT_FORMAT_MAX,
} OpenAI_Audio_Output_Format;

typedef void (*OpenAI_StreamCallback)(const uint8_t *data, size_t length);

/**
* @brief Struct for Embedding data
*
Expand Down Expand Up @@ -456,11 +461,13 @@ typedef struct OpenAI_ChatCompletion {
* @brief Send the message for completion. Save it with the first response if selected.
*
* @param chatCompletion[in] the point of OpenAI_ChatCompletion
* @param type[in] the type of the message for completion
* @param p[in] the message for completion
* @param save[in] save it with the first response if selected
* @return OpenAI_StringResponse_t*
*/
OpenAI_StringResponse_t *(*message)(struct OpenAI_ChatCompletion *chatCompletion, const char *p, bool save);
OpenAI_StringResponse_t *(*message)(struct OpenAI_ChatCompletion *chatCompletion, const char *type, const char *p, bool save);

} OpenAI_ChatCompletion_t;

/**
Expand Down Expand Up @@ -753,6 +760,15 @@ typedef struct OpenAI_AudioSpeech {
*/
OpenAI_SpeechResponse_t *(*speech)(struct OpenAI_AudioSpeech *createSpeech, char *p);

/**
* @brief Send the message for completion. Save it with the first response if selected.
*
* @param createSpeech[in] the point of OpenAI_SpeechResponse_t
* @param p[in] the message for audio generation
* @param stream_callback[in] the callback function for audio stream
*/
void (*speechStream)(struct OpenAI_AudioSpeech *createSpeech, char *p, OpenAI_StreamCallback stream_callback);

} OpenAI_AudioSpeech_t;

/**
Expand Down
Loading

0 comments on commit fbe2376

Please sign in to comment.