Skip to content

Commit

Permalink
update dalle2 model
Browse files Browse the repository at this point in the history
  • Loading branch information
zmh-program committed Nov 3, 2023
1 parent 6174948 commit 607c393
Show file tree
Hide file tree
Showing 27 changed files with 567 additions and 196 deletions.
20 changes: 20 additions & 0 deletions adapter/chatgpt/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ func (c *ChatInstance) GetCompletionPrompt(messages []globals.Message) string {
return result
}

// GetLatestPrompt returns the content of the most recent message in the
// conversation, or an empty string when the history is empty.
func (c *ChatInstance) GetLatestPrompt(props *ChatProps) string {
	messages := props.Message
	if n := len(messages); n > 0 {
		return messages[n-1].Content
	}
	return ""
}

func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} {
if props.Model == globals.GPT3TurboInstruct {
// for completions
Expand Down Expand Up @@ -63,6 +71,10 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} {

// CreateChatRequest is the native http request body for chatgpt
func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
if props.Model == globals.Dalle2 {
return c.CreateImage(props)
}

res, err := utils.Post(
c.GetChatEndpoint(props),
c.GetHeader(),
Expand All @@ -84,6 +96,14 @@ func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {

// CreateStreamChatRequest is the stream response body for chatgpt
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
if props.Model == globals.Dalle2 {
if url, err := c.CreateImage(props); err != nil {
return err
} else {
return callback(url)
}
}

buf := ""
instruct := props.Model == globals.GPT3TurboInstruct

Expand Down
20 changes: 18 additions & 2 deletions adapter/chatgpt/dalle.go → adapter/chatgpt/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package chatgpt
import (
"chat/utils"
"fmt"
"strings"
)

type ImageProps struct {
Expand All @@ -14,8 +15,8 @@ func (c *ChatInstance) GetImageEndpoint() string {
return fmt.Sprintf("%s/v1/images/generations", c.GetEndpoint())
}

// CreateImage will create a dalle image from prompt, return url of image and error
func (c *ChatInstance) CreateImage(props ImageProps) (string, error) {
// CreateImageRequest will create a dalle image from prompt, return url of image and error
func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) {
res, err := utils.Post(
c.GetImageEndpoint(),
c.GetHeader(), ImageRequest{
Expand All @@ -36,3 +37,18 @@ func (c *ChatInstance) CreateImage(props ImageProps) (string, error) {

return data.Data[0].Url, nil
}

// CreateImage will create a dalle image from prompt, return markdown of image
func (c *ChatInstance) CreateImage(props *ChatProps) (string, error) {
	url, err := c.CreateImageRequest(ImageProps{
		Prompt: c.GetLatestPrompt(props),
	})
	if err == nil {
		return utils.GetImageMarkdown(url), nil
	}

	// content-policy ("safety") rejections are surfaced to the user as the
	// reply text rather than treated as a hard failure
	if strings.Contains(err.Error(), "safety") {
		return err.Error(), nil
	}

	return "", err
}
2 changes: 2 additions & 0 deletions adapter/chatgpt/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ func NewChatInstanceFromModel(props *InstanceProps) *ChatInstance {
return NewChatInstanceFromConfig("subscribe")
}
return NewChatInstanceFromConfig("gpt3")
case globals.Dalle2:
return NewChatInstanceFromConfig("gpt3")
default:
return NewChatInstanceFromConfig("gpt3")
}
Expand Down
155 changes: 155 additions & 0 deletions adapter/oneapi/chat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package chatgpt

import "C"
import (
"chat/globals"
"chat/utils"
"fmt"
"github.com/spf13/viper"
"strings"
)

// ChatProps carries the per-request parameters for one chat completion call.
type ChatProps struct {
	Model   string            // target model name, compared against the globals model constants
	Message []globals.Message // conversation history; the last entry is the newest message
	Token   int               // max tokens for the reply; -1 selects the "infinity" request shape
}

// GetChatEndpoint picks the API path for the model: the legacy
// /v1/completions endpoint for the instruct model, and /v1/chat/completions
// for everything else.
func (c *ChatInstance) GetChatEndpoint(props *ChatProps) string {
	path := "/v1/chat/completions"
	if props.Model == globals.GPT3TurboInstruct {
		path = "/v1/completions"
	}
	return fmt.Sprintf("%s%s", c.GetEndpoint(), path)
}

// GetCompletionPrompt flattens the chat history into a single legacy
// completion prompt, one "role: content" line per message.
func (c *ChatInstance) GetCompletionPrompt(messages []globals.Message) string {
	// strings.Builder avoids the quadratic cost of += concatenation in a loop
	var b strings.Builder
	for _, message := range messages {
		b.WriteString(message.Role)
		b.WriteString(": ")
		b.WriteString(message.Content)
		b.WriteByte('\n')
	}
	return b.String()
}

// GetChatBody builds the JSON-serializable request body for the model.
//
// Instruct models use the legacy completion shape; all other models use the
// chat shape. When props.Token is -1 the "*WithInfinity" variants are used,
// which omit the max-token field entirely.
func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} {
	if props.Model == globals.GPT3TurboInstruct {
		// build the prompt once: the previous utils.Multi form evaluated both
		// struct literals eagerly and called GetCompletionPrompt twice
		prompt := c.GetCompletionPrompt(props.Message)
		if props.Token == -1 {
			return CompletionWithInfinity{
				Model:  props.Model,
				Prompt: prompt,
				Stream: stream,
			}
		}
		return CompletionRequest{
			Model:    props.Model,
			Prompt:   prompt,
			MaxToken: props.Token,
			Stream:   stream,
		}
	}

	messages := formatMessages(props)
	if props.Token != -1 {
		return ChatRequest{
			Model:    props.Model,
			Messages: messages,
			MaxToken: props.Token,
			Stream:   stream,
		}
	}

	return ChatRequestWithInfinity{
		Model:    props.Model,
		Messages: messages,
		Stream:   stream,
	}
}

// CreateChatRequest is the native http request body for chatgpt
// (non-streaming); it returns the content of the first choice.
func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
	res, err := utils.Post(
		c.GetChatEndpoint(props),
		c.GetHeader(),
		c.GetChatBody(props, false),
	)
	if err != nil {
		return "", fmt.Errorf("chatgpt error: %s", err.Error())
	}
	// previously `err != nil || res == nil` fell into a branch that called
	// err.Error(): a nil response with a nil error would panic
	if res == nil {
		return "", fmt.Errorf("chatgpt error: empty response")
	}

	data := utils.MapToStruct[ChatResponse](res)
	if data == nil {
		return "", fmt.Errorf("chatgpt error: cannot parse response")
	}
	if data.Error.Message != "" {
		return "", fmt.Errorf("chatgpt error: %s", data.Error.Message)
	}
	// guard against an empty choices array to avoid an index panic
	if len(data.Choices) == 0 {
		return "", fmt.Errorf("chatgpt error: no choices returned")
	}
	return data.Choices[0].Message.Content, nil
}

// CreateStreamChatRequest is the stream response body for chatgpt.
// It opens an SSE connection and forwards each decoded text delta to the
// callback hook until the stream ends or an error aborts it.
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
	// buf accumulates a partial SSE fragment that failed to parse on its own,
	// so it can be merged with the next chunk by ProcessLine
	buf := ""
	instruct := props.Model == globals.GPT3TurboInstruct

	return utils.EventSource(
		"POST",
		c.GetChatEndpoint(props),
		c.GetHeader(),
		c.GetChatBody(props, true),
		func(data string) error {
			data, err := c.ProcessLine(instruct, buf, data)

			if err != nil {
				if strings.HasPrefix(err.Error(), "chatgpt error") {
					// explicit upstream error: abort the whole stream
					return err
				}

				// error when break line: the fragment was split across
				// events — keep it buffered and retry on the next chunk
				buf = buf + data
				return nil
			}

			// chunk parsed successfully: drop any pending fragment
			buf = ""
			if data != "" {
				if err := callback(data); err != nil {
					return err
				}
			}
			return nil
		},
	)
}

// Test performs a one-token smoke request against the upstream to verify
// that this instance's API key works; it reports success as a bool.
func (c *ChatInstance) Test() bool {
	result, err := c.CreateChatRequest(&ChatProps{
		Model:   globals.GPT3Turbo,
		Message: []globals.Message{{Role: "user", Content: "hi"}},
		Token:   1,
	})
	if err != nil {
		// Printf replaces the redundant Println(Sprintf(...)) pattern
		fmt.Printf("%s: test failed (%s)\n", c.GetApiKey(), err.Error())
	}

	return err == nil && len(result) > 0
}

// FilterKeys reads the endpoint and the "|"-separated api keys for the given
// openai config section and returns only the keys that pass a live test.
func FilterKeys(v string) []string {
	endpoint := viper.GetString(fmt.Sprintf("openai.%s.endpoint", v))
	raw := viper.GetString(fmt.Sprintf("openai.%s.apikey", v))

	return FilterKeysNative(endpoint, strings.Split(raw, "|"))
}

// FilterKeysNative tests every key against the endpoint concurrently and
// returns the subset that passed; result order is not guaranteed.
func FilterKeysNative(endpoint string, keys []string) []string {
	results := make(chan string, len(keys))
	for _, key := range keys {
		go func(k string) {
			instance := NewChatInstance(endpoint, k)
			// a failing key is reported as "" so the receiver can skip it
			results <- utils.Multi[string](instance.Test(), k, "")
		}(key)
	}

	var passed []string
	for range keys {
		if key := <-results; key != "" {
			passed = append(passed, key)
		}
	}
	return passed
}
138 changes: 138 additions & 0 deletions adapter/oneapi/processor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package chatgpt

import (
"chat/globals"
"chat/utils"
"errors"
"fmt"
"strings"
)

// processFormat normalizes a raw SSE line into a JSON-ish object: the
// "data: {" event prefix becomes a quoted "data" key, and missing outer
// braces are patched on.
func processFormat(data string) string {
	item := strings.ReplaceAll(data, "data: {", "\"data\": {")
	if !strings.HasPrefix(item, "{") {
		item = "{" + item
	}
	if !strings.HasSuffix(item, "}}") {
		item += "}"
	}

	return item
}

// formatMessages prepares the outgoing message list for the request.
//
// For the vision model it rewrites the newest message so that any URLs it
// contains are duplicated at the front of the content; all other models pass
// through untouched. Note: the rewrite mutates props.Message in place.
func formatMessages(props *ChatProps) []globals.Message {
	// guard the empty history: indexing len-1 previously panicked
	if props.Model != globals.GPT4Vision || len(props.Message) == 0 {
		return props.Message
	}

	last := len(props.Message) - 1
	content := props.Message[last].Content
	if urls := utils.ExtractUrls(content); len(urls) > 0 {
		content = fmt.Sprintf("%s %s", strings.Join(urls, " "), content)
	}
	props.Message[last].Content = content
	return props.Message
}

// processChatResponse tries to decode a chat stream chunk: first as-is, then
// with the final character trimmed (a trailing-brace artifact of
// processFormat). It returns nil when neither attempt parses.
func processChatResponse(data string) *ChatStreamResponse {
	if !strings.HasPrefix(data, "{") {
		return nil
	}

	if form := utils.UnmarshalForm[ChatStreamResponse](data); form != nil {
		return form
	}
	return utils.UnmarshalForm[ChatStreamResponse](data[:len(data)-1])
}

// processCompletionResponse tries to decode a legacy completion chunk: first
// as-is, then with the final character trimmed (a trailing-brace artifact of
// processFormat). It returns nil when neither attempt parses.
func processCompletionResponse(data string) *CompletionResponse {
	if !strings.HasPrefix(data, "{") {
		return nil
	}

	if form := utils.UnmarshalForm[CompletionResponse](data); form != nil {
		return form
	}
	return utils.UnmarshalForm[CompletionResponse](data[:len(data)-1])
}

// processChatErrorResponse tries to decode an upstream error payload: first
// as-is, then with an extra closing brace appended (the mirror of the
// trimming done for regular chunks). It returns nil when neither parses.
func processChatErrorResponse(data string) *ChatStreamErrorResponse {
	if !strings.HasPrefix(data, "{") {
		return nil
	}

	if form := utils.UnmarshalForm[ChatStreamErrorResponse](data); form != nil {
		return form
	}
	return utils.UnmarshalForm[ChatStreamErrorResponse](data + "}")
}

// isDone reports whether the (already processFormat-normalized) chunk is one
// of the terminator frames that mark the end of the SSE stream.
func isDone(data string) bool {
	terminators := []string{
		"{data: [DONE]}", "{data: [DONE]}}",
		"{[DONE]}", "{data:}", "{data:}}",
	}
	return utils.Contains[string](data, terminators)
}

// getChoices extracts the delta content of the first choice, or "" when the
// chunk carries no choices.
func getChoices(form *ChatStreamResponse) string {
	choices := form.Data.Choices
	if len(choices) > 0 {
		return choices[0].Delta.Content
	}
	return ""
}

// getCompletionChoices extracts the text of the first completion choice, or
// "" when the chunk carries no choices.
func getCompletionChoices(form *CompletionResponse) string {
	choices := form.Data.Choices
	if len(choices) > 0 {
		return choices[0].Text
	}
	return ""
}

// ProcessLine normalizes and decodes one SSE chunk, returning the extracted
// text delta. buf holds a previously unparseable fragment to be merged with
// this chunk. On a parse failure with an empty buf it returns the raw data
// together with a "parser error" so the caller can buffer it; explicit
// upstream API errors come back with the "chatgpt error" prefix.
func (c *ChatInstance) ProcessLine(instruct bool, buf, data string) (string, error) {
	item := processFormat(buf + data)
	if isDone(item) {
		return "", nil
	}

	if form := processChatResponse(item); form != nil {
		return getChoices(form), nil
	}

	if instruct {
		// legacy support: instruct models answer in the completion shape
		if completion := processCompletionResponse(item); completion != nil {
			return getCompletionChoices(completion), nil
		}
	}

	// recursive call: retry once with the buffered fragment folded in.
	// NOTE(review): item was already built from buf+data, so buf+item repeats
	// the buffered prefix — confirm this doubling is intentional.
	if len(buf) > 0 {
		return c.ProcessLine(instruct, "", buf+item)
	}

	// the chunk may be an explicit upstream error payload; the previous code
	// named this value `err` despite it being a response struct, with an
	// inverted nil-check that made the branches read backwards
	if errForm := processChatErrorResponse(item); errForm != nil {
		return "", fmt.Errorf("chatgpt error: %s (type: %s)", errForm.Data.Error.Message, errForm.Data.Error.Type)
	}

	globals.Warn(fmt.Sprintf("chatgpt error: cannot parse response: %s", item))
	return data, errors.New("parser error: cannot parse response")
}
Loading

0 comments on commit 607c393

Please sign in to comment.