-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6174948
commit 607c393
Showing
27 changed files
with
567 additions
and
196 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
package chatgpt | ||
|
||
import "C" | ||
import ( | ||
"chat/globals" | ||
"chat/utils" | ||
"fmt" | ||
"github.com/spf13/viper" | ||
"strings" | ||
) | ||
|
||
// ChatProps bundles the parameters for a single chat request.
type ChatProps struct {
	Model   string            // upstream model identifier, compared against globals model constants
	Message []globals.Message // conversation history to send upstream
	Token   int               // max tokens for the reply; -1 means no explicit limit
}
|
||
func (c *ChatInstance) GetChatEndpoint(props *ChatProps) string { | ||
if props.Model == globals.GPT3TurboInstruct { | ||
return fmt.Sprintf("%s/v1/completions", c.GetEndpoint()) | ||
} | ||
return fmt.Sprintf("%s/v1/chat/completions", c.GetEndpoint()) | ||
} | ||
|
||
func (c *ChatInstance) GetCompletionPrompt(messages []globals.Message) string { | ||
result := "" | ||
for _, message := range messages { | ||
result += fmt.Sprintf("%s: %s\n", message.Role, message.Content) | ||
} | ||
return result | ||
} | ||
|
||
func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} { | ||
if props.Model == globals.GPT3TurboInstruct { | ||
// for completions | ||
return utils.Multi[interface{}](props.Token != -1, CompletionRequest{ | ||
Model: props.Model, | ||
Prompt: c.GetCompletionPrompt(props.Message), | ||
MaxToken: props.Token, | ||
Stream: stream, | ||
}, CompletionWithInfinity{ | ||
Model: props.Model, | ||
Prompt: c.GetCompletionPrompt(props.Message), | ||
Stream: stream, | ||
}) | ||
} | ||
|
||
if props.Token != -1 { | ||
return ChatRequest{ | ||
Model: props.Model, | ||
Messages: formatMessages(props), | ||
MaxToken: props.Token, | ||
Stream: stream, | ||
} | ||
} | ||
|
||
return ChatRequestWithInfinity{ | ||
Model: props.Model, | ||
Messages: formatMessages(props), | ||
Stream: stream, | ||
} | ||
} | ||
|
||
// CreateChatRequest is the native http request body for chatgpt | ||
func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) { | ||
res, err := utils.Post( | ||
c.GetChatEndpoint(props), | ||
c.GetHeader(), | ||
c.GetChatBody(props, false), | ||
) | ||
|
||
if err != nil || res == nil { | ||
return "", fmt.Errorf("chatgpt error: %s", err.Error()) | ||
} | ||
|
||
data := utils.MapToStruct[ChatResponse](res) | ||
if data == nil { | ||
return "", fmt.Errorf("chatgpt error: cannot parse response") | ||
} else if data.Error.Message != "" { | ||
return "", fmt.Errorf("chatgpt error: %s", data.Error.Message) | ||
} | ||
return data.Choices[0].Message.Content, nil | ||
} | ||
|
||
// CreateStreamChatRequest is the stream response body for chatgpt.
// It opens an SSE connection and forwards each decoded delta to the
// callback. Chunks that fail to parse are accumulated in buf and
// retried together with the next chunk.
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
	buf := ""
	instruct := props.Model == globals.GPT3TurboInstruct

	return utils.EventSource(
		"POST",
		c.GetChatEndpoint(props),
		c.GetHeader(),
		c.GetChatBody(props, true),
		func(data string) error {
			data, err := c.ProcessLine(instruct, buf, data)

			if err != nil {
				// errors reported by the upstream API abort the stream
				if strings.HasPrefix(err.Error(), "chatgpt error") {
					return err
				}

				// error when break line: keep the fragment buffered and
				// retry once the next chunk arrives
				buf = buf + data
				return nil
			}

			// successful parse: the buffered fragment has been consumed
			buf = ""
			if data != "" {
				if err := callback(data); err != nil {
					return err
				}
			}
			return nil
		},
	)
}
|
||
func (c *ChatInstance) Test() bool { | ||
result, err := c.CreateChatRequest(&ChatProps{ | ||
Model: globals.GPT3Turbo, | ||
Message: []globals.Message{{Role: "user", Content: "hi"}}, | ||
Token: 1, | ||
}) | ||
if err != nil { | ||
fmt.Println(fmt.Sprintf("%s: test failed (%s)", c.GetApiKey(), err.Error())) | ||
} | ||
|
||
return err == nil && len(result) > 0 | ||
} | ||
|
||
func FilterKeys(v string) []string { | ||
endpoint := viper.GetString(fmt.Sprintf("openai.%s.endpoint", v)) | ||
keys := strings.Split(viper.GetString(fmt.Sprintf("openai.%s.apikey", v)), "|") | ||
|
||
return FilterKeysNative(endpoint, keys) | ||
} | ||
|
||
func FilterKeysNative(endpoint string, keys []string) []string { | ||
stack := make(chan string, len(keys)) | ||
for _, key := range keys { | ||
go func(key string) { | ||
instance := NewChatInstance(endpoint, key) | ||
stack <- utils.Multi[string](instance.Test(), key, "") | ||
}(key) | ||
} | ||
|
||
var result []string | ||
for i := 0; i < len(keys); i++ { | ||
if res := <-stack; res != "" { | ||
result = append(result, res) | ||
} | ||
} | ||
return result | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
package chatgpt | ||
|
||
import ( | ||
"chat/globals" | ||
"chat/utils" | ||
"errors" | ||
"fmt" | ||
"strings" | ||
) | ||
|
||
// processFormat normalizes a raw SSE chunk into a JSON object string:
// the "data: " event prefix becomes a "data" key, and missing outer
// braces are added so the result can be unmarshalled.
func processFormat(data string) string {
	item := strings.ReplaceAll(data, "data: {", "\"data\": {")
	if !strings.HasPrefix(item, "{") {
		item = "{" + item
	}
	if !strings.HasSuffix(item, "}}") {
		item += "}"
	}

	return item
}
|
||
func formatMessages(props *ChatProps) []globals.Message { | ||
if props.Model == globals.GPT4Vision { | ||
base := props.Message[len(props.Message)-1].Content | ||
urls := utils.ExtractUrls(base) | ||
|
||
if len(urls) > 0 { | ||
base = fmt.Sprintf("%s %s", strings.Join(urls, " "), base) | ||
} | ||
props.Message[len(props.Message)-1].Content = base | ||
return props.Message | ||
} | ||
|
||
return props.Message | ||
} | ||
|
||
func processChatResponse(data string) *ChatStreamResponse { | ||
if strings.HasPrefix(data, "{") { | ||
var form *ChatStreamResponse | ||
if form = utils.UnmarshalForm[ChatStreamResponse](data); form != nil { | ||
return form | ||
} | ||
|
||
if form = utils.UnmarshalForm[ChatStreamResponse](data[:len(data)-1]); form != nil { | ||
return form | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func processCompletionResponse(data string) *CompletionResponse { | ||
if strings.HasPrefix(data, "{") { | ||
var form *CompletionResponse | ||
if form = utils.UnmarshalForm[CompletionResponse](data); form != nil { | ||
return form | ||
} | ||
|
||
if form = utils.UnmarshalForm[CompletionResponse](data[:len(data)-1]); form != nil { | ||
return form | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func processChatErrorResponse(data string) *ChatStreamErrorResponse { | ||
if strings.HasPrefix(data, "{") { | ||
var form *ChatStreamErrorResponse | ||
if form = utils.UnmarshalForm[ChatStreamErrorResponse](data); form != nil { | ||
return form | ||
} | ||
if form = utils.UnmarshalForm[ChatStreamErrorResponse](data + "}"); form != nil { | ||
return form | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func isDone(data string) bool { | ||
return utils.Contains[string](data, []string{ | ||
"{data: [DONE]}", "{data: [DONE]}}", | ||
"{[DONE]}", "{data:}", "{data:}}", | ||
}) | ||
} | ||
|
||
func getChoices(form *ChatStreamResponse) string { | ||
if len(form.Data.Choices) == 0 { | ||
return "" | ||
} | ||
|
||
return form.Data.Choices[0].Delta.Content | ||
} | ||
|
||
func getCompletionChoices(form *CompletionResponse) string { | ||
if len(form.Data.Choices) == 0 { | ||
return "" | ||
} | ||
|
||
return form.Data.Choices[0].Text | ||
} | ||
|
||
// ProcessLine normalizes one SSE chunk (prefixed with any buffered
// fragment) and extracts the delta text. On failure it returns either
// a fatal "chatgpt error: ..." from the upstream payload, or a
// "parser error" that tells the caller to buffer the chunk and retry.
func (c *ChatInstance) ProcessLine(instruct bool, buf, data string) (string, error) {
	item := processFormat(buf + data)
	if isDone(item) {
		// stream terminator: nothing to emit, no error
		return "", nil
	}

	if form := processChatResponse(item); form == nil {
		if instruct {
			// legacy support: instruct models reply in completion format
			if completion := processCompletionResponse(item); completion != nil {
				return getCompletionChoices(completion), nil
			}
		}

		// recursive call: retry once with the buffer folded in.
		// NOTE(review): item was built from processFormat(buf+data), so
		// buf+item repeats the buffered fragment — confirm this is intended.
		if len(buf) > 0 {
			return c.ProcessLine(instruct, "", buf+item)
		}

		// err here is an upstream error *payload*, not a Go error:
		// nil means the chunk is not an error response either
		if err := processChatErrorResponse(item); err == nil {
			globals.Warn(fmt.Sprintf("chatgpt error: cannot parse response: %s", item))
			return data, errors.New("parser error: cannot parse response")
		} else {
			return "", fmt.Errorf("chatgpt error: %s (type: %s)", err.Data.Error.Message, err.Data.Error.Type)
		}

	} else {
		return getChoices(form), nil
	}
}
Oops, something went wrong.