Skip to content

Commit

Permalink
fix: nil pointer dereference error when carrying an image to a conv…
Browse files Browse the repository at this point in the history
…ersation (#221)
  • Loading branch information
Sh1n3zZ committed Jul 2, 2024
1 parent 980bb2c commit 140bed5
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 41 deletions.
45 changes: 27 additions & 18 deletions adapter/common/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,37 @@ import (
)

type RequestProps struct {
MaxRetries *int
Current int
Group string

Proxy globals.ProxyConfig
MaxRetries *int `json:"-"`
Current int `json:"-"`
Group string `json:"-"`
Proxy globals.ProxyConfig `json:"-"`
}

type ChatProps struct {
RequestProps

Model string
OriginalModel string
Model string `json:"model,omitempty"`
OriginalModel string `json:"-"`

Message []globals.Message `json:"messages,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"`
PresencePenalty *float32 `json:"presence_penalty,omitempty"`
FrequencyPenalty *float32 `json:"frequency_penalty,omitempty"`
RepetitionPenalty *float32 `json:"repetition_penalty,omitempty"`
Temperature *float32 `json:"temperature,omitempty"`
TopP *float32 `json:"top_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
Tools *globals.FunctionTools `json:"tools,omitempty"`
ToolChoice *interface{} `json:"tool_choice,omitempty"`
Buffer *utils.Buffer `json:"-"`
}

func (c *ChatProps) SetupBuffer(buf *utils.Buffer) {
buf.SetPrompts(c)
c.Buffer = buf
}

Message []globals.Message
MaxTokens *int
PresencePenalty *float32
FrequencyPenalty *float32
RepetitionPenalty *float32
Temperature *float32
TopP *float32
TopK *int
Tools *globals.FunctionTools
ToolChoice *interface{}
Buffer utils.Buffer
func CreateChatProps(props *ChatProps, buffer *utils.Buffer) *ChatProps {
props.SetupBuffer(buffer)
return props
}
9 changes: 4 additions & 5 deletions addition/generation/prompt.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package generation

import (
"chat/adapter/common"
adaptercommon "chat/adapter/common"
"chat/admin"
"chat/channel"
"chat/globals"
Expand All @@ -17,17 +17,16 @@ func CreateGeneration(group, model, prompt, path string, hook func(buffer *utils
message := GenerateMessage(prompt)
buffer := utils.NewBuffer(model, message, channel.ChargeInstance.GetCharge(model))

err := channel.NewChatRequest(group, &adaptercommon.ChatProps{
err := channel.NewChatRequest(group, adaptercommon.CreateChatProps(&adaptercommon.ChatProps{
OriginalModel: model,
Message: message,
Buffer: *buffer,
}, func(data *globals.Chunk) error {
}, buffer), func(data *globals.Chunk) error {
buffer.WriteChunk(data)
hook(buffer, data.Content)
return nil
})

admin.AnalysisRequest(model, buffer, err)
admin.AnalyseRequest(model, buffer, err)
if err != nil {
return err
}
Expand Down
5 changes: 3 additions & 2 deletions admin/statistic.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ import (
"chat/adapter"
"chat/connection"
"chat/utils"
"github.com/go-redis/redis/v8"
"time"

"github.com/go-redis/redis/v8"
)

func IncrErrorRequest(cache *redis.Client) {
Expand All @@ -25,7 +26,7 @@ func IncrModelRequest(cache *redis.Client, model string, tokens int64) {
utils.IncrWithExpire(cache, getModelFormat(getDay(), model), tokens, time.Hour*24*7*2)
}

func AnalysisRequest(model string, buffer *utils.Buffer, err error) {
func AnalyseRequest(model string, buffer *utils.Buffer, err error) {
instance := connection.Cache

if adapter.IsAvailableError(err) {
Expand Down
7 changes: 4 additions & 3 deletions manager/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func createChatTask(
hit, err := channel.NewChatRequestWithCache(
cache, buffer,
auth.GetGroup(db, user),
&adaptercommon.ChatProps{
adaptercommon.CreateChatProps(&adaptercommon.ChatProps{
Model: model,
Message: segment,
MaxTokens: instance.GetMaxTokens(),
Expand All @@ -114,7 +114,7 @@ func createChatTask(
PresencePenalty: instance.GetPresencePenalty(),
FrequencyPenalty: instance.GetFrequencyPenalty(),
RepetitionPenalty: instance.GetRepetitionPenalty(),
},
}, buffer),

// the function to handle the chunk data
func(data *globals.Chunk) error {
Expand Down Expand Up @@ -168,6 +168,7 @@ func createChatTask(
interruptSignal <- err
return hit, nil
}

case signal := <-stopSignal:
// if stop signal is received
if signal {
Expand Down Expand Up @@ -219,7 +220,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model))
hit, err := createChatTask(conn, user, buffer, db, cache, model, instance, segment, plan)

admin.AnalysisRequest(model, buffer, err)
admin.AnalyseRequest(model, buffer, err)
if adapter.IsAvailableError(err) {
globals.Warn(fmt.Sprintf("%s (model: %s, client: %s)", err, model, conn.GetCtx().ClientIP()))

Expand Down
9 changes: 4 additions & 5 deletions manager/chat_completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func ChatRelayAPI(c *gin.Context) {
}

func getChatProps(form RelayForm, messages []globals.Message, buffer *utils.Buffer) *adaptercommon.ChatProps {
return &adaptercommon.ChatProps{
return adaptercommon.CreateChatProps(&adaptercommon.ChatProps{
Model: form.Model,
Message: messages,
MaxTokens: form.MaxTokens,
Expand All @@ -106,8 +106,7 @@ func getChatProps(form RelayForm, messages []globals.Message, buffer *utils.Buff
TopK: form.TopK,
Tools: form.Tools,
ToolChoice: form.ToolChoice,
Buffer: *buffer,
}
}, buffer)
}

func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals.Message, id string, created int64, user *auth.User, plan bool) {
Expand All @@ -120,7 +119,7 @@ func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals
return nil
})

admin.AnalysisRequest(form.Model, buffer, err)
admin.AnalyseRequest(form.Model, buffer, err)
if err != nil {
auth.RevertSubscriptionUsage(db, cache, user, form.Model)
globals.Warn(fmt.Sprintf("error from chat request api: %s (instance: %s, client: %s)", err, form.Model, c.ClientIP()))
Expand Down Expand Up @@ -235,7 +234,7 @@ func sendStreamTranshipmentResponse(c *gin.Context, form RelayForm, messages []g
},
)

admin.AnalysisRequest(form.Model, buffer, err)
admin.AnalyseRequest(form.Model, buffer, err)
if err != nil {
auth.RevertSubscriptionUsage(db, cache, user, form.Model)
globals.Warn(fmt.Sprintf("error from chat request api: %s (instance: %s, client: %s)", err.Error(), form.Model, c.ClientIP()))
Expand Down
7 changes: 3 additions & 4 deletions manager/completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,17 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []
hit, err := channel.NewChatRequestWithCache(
cache, buffer,
auth.GetGroup(db, user),
&adaptercommon.ChatProps{
adaptercommon.CreateChatProps(&adaptercommon.ChatProps{
Model: model,
Message: segment,
Buffer: *buffer,
},
}, buffer),
func(resp *globals.Chunk) error {
buffer.WriteChunk(resp)
return nil
},
)

admin.AnalysisRequest(model, buffer, err)
admin.AnalyseRequest(model, buffer, err)
if err != nil {
auth.RevertSubscriptionUsage(db, cache, user, model)
return err.Error(), 0
Expand Down
7 changes: 3 additions & 4 deletions manager/images.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,11 @@ func ImagesRelayAPI(c *gin.Context) {
}

func getImageProps(form RelayImageForm, messages []globals.Message, buffer *utils.Buffer) *adaptercommon.ChatProps {
return &adaptercommon.ChatProps{
return adaptercommon.CreateChatProps(&adaptercommon.ChatProps{
Model: form.Model,
Message: messages,
MaxTokens: utils.ToPtr(-1),
Buffer: *buffer,
}
}, buffer)
}

func getUrlFromBuffer(buffer *utils.Buffer) string {
Expand Down Expand Up @@ -100,7 +99,7 @@ func createRelayImageObject(c *gin.Context, form RelayImageForm, prompt string,
return nil
})

admin.AnalysisRequest(form.Model, buffer, err)
admin.AnalyseRequest(form.Model, buffer, err)
if err != nil {
auth.RevertSubscriptionUsage(db, cache, user, form.Model)
globals.Warn(fmt.Sprintf("error from chat request api: %s (instance: %s, client: %s)", err, form.Model, c.ClientIP()))
Expand Down

0 comments on commit 140bed5

Please sign in to comment.