-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
261e500
commit a1f4845
Showing
11 changed files
with
673 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
package azure | ||
|
||
import ( | ||
"chat/globals" | ||
"chat/utils" | ||
"fmt" | ||
"strings" | ||
) | ||
|
||
// ChatProps bundles every parameter used to build an Azure OpenAI chat
// (or legacy completion) request body.
type ChatProps struct {
	Model            string                  // model id, mapped to an Azure deployment name
	Message          []globals.Message       // full conversation transcript
	Token            *int                    // max tokens; nil means provider default
	PresencePenalty  *float32                // optional presence penalty
	FrequencyPenalty *float32                // optional frequency penalty
	Temperature      *float32                // optional sampling temperature
	TopP             *float32                // optional nucleus-sampling cutoff
	Tools            *globals.FunctionTools  // optional function/tool definitions
	ToolChoice       *interface{}            // optional tool-choice directive
	Buffer           utils.Buffer            // stream-decoding scratch buffer
}
|
||
func (c *ChatInstance) GetChatEndpoint(props *ChatProps) string { | ||
model := strings.ReplaceAll(props.Model, ".", "") | ||
if props.Model == globals.GPT3TurboInstruct { | ||
return fmt.Sprintf("%s/openai/deployments/%s/completions?api-version=%s", c.GetResource(), model, c.GetEndpoint()) | ||
} | ||
return fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", c.GetResource(), model, c.GetEndpoint()) | ||
} | ||
|
||
func (c *ChatInstance) GetCompletionPrompt(messages []globals.Message) string { | ||
result := "" | ||
for _, message := range messages { | ||
result += fmt.Sprintf("%s: %s\n", message.Role, message.Content) | ||
} | ||
return result | ||
} | ||
|
||
func (c *ChatInstance) GetLatestPrompt(props *ChatProps) string { | ||
if len(props.Message) == 0 { | ||
return "" | ||
} | ||
|
||
return props.Message[len(props.Message)-1].Content | ||
} | ||
|
||
func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} { | ||
if props.Model == globals.GPT3TurboInstruct { | ||
// for completions | ||
return CompletionRequest{ | ||
Prompt: c.GetCompletionPrompt(props.Message), | ||
MaxToken: props.Token, | ||
Stream: stream, | ||
} | ||
} | ||
|
||
return ChatRequest{ | ||
Messages: formatMessages(props), | ||
MaxToken: props.Token, | ||
Stream: stream, | ||
PresencePenalty: props.PresencePenalty, | ||
FrequencyPenalty: props.FrequencyPenalty, | ||
Temperature: props.Temperature, | ||
TopP: props.TopP, | ||
Tools: props.Tools, | ||
ToolChoice: props.ToolChoice, | ||
} | ||
} | ||
|
||
// CreateChatRequest is the native http request body for chatgpt | ||
func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) { | ||
if globals.IsDalleModel(props.Model) { | ||
return c.CreateImage(props) | ||
} | ||
|
||
res, err := utils.Post( | ||
c.GetChatEndpoint(props), | ||
c.GetHeader(), | ||
c.GetChatBody(props, false), | ||
) | ||
|
||
if err != nil || res == nil { | ||
return "", fmt.Errorf("chatgpt error: %s", err.Error()) | ||
} | ||
|
||
data := utils.MapToStruct[ChatResponse](res) | ||
if data == nil { | ||
return "", fmt.Errorf("chatgpt error: cannot parse response") | ||
} else if data.Error.Message != "" { | ||
return "", fmt.Errorf("chatgpt error: %s", data.Error.Message) | ||
} | ||
return data.Choices[0].Message.Content, nil | ||
} | ||
|
||
// CreateStreamChatRequest is the stream response body for chatgpt | ||
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error { | ||
if globals.IsDalleModel(props.Model) { | ||
if url, err := c.CreateImage(props); err != nil { | ||
return err | ||
} else { | ||
return callback(url) | ||
} | ||
} | ||
|
||
buf := "" | ||
cursor := 0 | ||
chunk := "" | ||
instruct := props.Model == globals.GPT3TurboInstruct | ||
|
||
err := utils.EventSource( | ||
"POST", | ||
c.GetChatEndpoint(props), | ||
c.GetHeader(), | ||
c.GetChatBody(props, true), | ||
func(data string) error { | ||
data, err := c.ProcessLine(props.Buffer, instruct, buf, data) | ||
chunk += data | ||
|
||
if err != nil { | ||
if strings.HasPrefix(err.Error(), "chatgpt error") { | ||
return err | ||
} | ||
|
||
// error when break line | ||
buf = buf + data | ||
return nil | ||
} | ||
|
||
buf = "" | ||
if data != "" { | ||
cursor += 1 | ||
if err := callback(data); err != nil { | ||
return err | ||
} | ||
} | ||
return nil | ||
}, | ||
) | ||
|
||
if err != nil { | ||
return err | ||
} else if len(chunk) == 0 { | ||
return fmt.Errorf("empty response") | ||
} | ||
|
||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
package azure | ||
|
||
import ( | ||
"chat/globals" | ||
"chat/utils" | ||
"fmt" | ||
"strings" | ||
) | ||
|
||
// ImageProps holds the parameters for a dalle image-generation request.
type ImageProps struct {
	Model  string    // model id, mapped to an Azure deployment name
	Prompt string    // text prompt for the image
	Size   ImageSize // requested image size
}
|
||
func (c *ChatInstance) GetImageEndpoint(model string) string { | ||
model = strings.ReplaceAll(model, ".", "") | ||
return fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", c.GetResource(), model, c.GetEndpoint()) | ||
} | ||
|
||
// CreateImageRequest will create a dalle image from prompt, return url of image and error | ||
func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) { | ||
res, err := utils.Post( | ||
c.GetImageEndpoint(props.Model), | ||
c.GetHeader(), ImageRequest{ | ||
Prompt: props.Prompt, | ||
Size: utils.Multi[ImageSize]( | ||
props.Model == globals.Dalle3, | ||
ImageSize1024, | ||
ImageSize512, | ||
), | ||
N: 1, | ||
}) | ||
if err != nil || res == nil { | ||
return "", fmt.Errorf("chatgpt error: %s", err.Error()) | ||
} | ||
|
||
data := utils.MapToStruct[ImageResponse](res) | ||
if data == nil { | ||
return "", fmt.Errorf("chatgpt error: cannot parse response") | ||
} else if data.Error.Message != "" { | ||
return "", fmt.Errorf("chatgpt error: %s", data.Error.Message) | ||
} | ||
|
||
return data.Data[0].Url, nil | ||
} | ||
|
||
// CreateImage will create a dalle image from prompt, return markdown of image | ||
func (c *ChatInstance) CreateImage(props *ChatProps) (string, error) { | ||
url, err := c.CreateImageRequest(ImageProps{ | ||
Model: props.Model, | ||
Prompt: c.GetLatestPrompt(props), | ||
}) | ||
if err != nil { | ||
if strings.Contains(err.Error(), "safety") { | ||
return err.Error(), nil | ||
} | ||
return "", err | ||
} | ||
|
||
return utils.GetImageMarkdown(url), nil | ||
} |
Oops, something went wrong.