Skip to content

Commit

Permalink
fix: fix midjourney chunk stacking problem (#156) and stop signal can…
Browse files Browse the repository at this point in the history
…not trigger in some channel formats issue

Co-Authored-By: Minghan Zhang <[email protected]>
  • Loading branch information
Sh1n3zZ and zmh-program committed Mar 31, 2024
1 parent 6615cd9 commit f66eb11
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 33 deletions.
45 changes: 28 additions & 17 deletions adapter/midjourney/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@ import (
"chat/utils"
"fmt"
"strings"
"time"
)

const maxTimeout = 30 * time.Minute // 30 min timeout

func getStatusCode(action string, response *CommonResponse) error {
code := response.Code
switch code {
Expand Down Expand Up @@ -94,6 +97,9 @@ func (c *ChatInstance) CreateStreamTask(props *adaptercommon.ChatProps, action s
task := res.Result
progress := -1

ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()

for {
utils.Sleep(50)
form := getStorage(task)
Expand All @@ -103,25 +109,30 @@ func (c *ChatInstance) CreateStreamTask(props *adaptercommon.ChatProps, action s
return nil, err
}

continue
}

switch form.Status {
case Success:
if err := hook(form, 100); err != nil {
return nil, err
}
return form, nil
case Failure:
return nil, fmt.Errorf("task failed: %s", form.FailReason)
case InProgress:
current := getProgress(form.Progress)
if progress != current {
if err := hook(form, current); err != nil {
switch form.Status {
case Success:
if err := hook(form, 100); err != nil {
return nil, err
}
return form, nil
case Failure:
return nil, fmt.Errorf("task failed: %s", form.FailReason)
case InProgress:
current := getProgress(form.Progress)
if progress != current {
if err := hook(form, current); err != nil {
return nil, err
}
progress = current
}
default:
// ping
if err := hook(form, -1); err != nil {
return nil, err
}
progress = current
}
case <-time.After(maxTimeout):

Check failure on line 134 in adapter/midjourney/handler.go

View workflow job for this annotation

GitHub Actions / release (18.x)

syntax error: unexpected case, expected }
return nil, fmt.Errorf("task timeout")
}
}
}
}

Check failure on line 138 in adapter/midjourney/handler.go

View workflow job for this annotation

GitHub Actions / release (18.x)

syntax error: non-declaration statement outside function body
4 changes: 2 additions & 2 deletions adapter/midjourney/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ func setStorage(task string, form StorageForm) error {
return utils.SetJson(connection.Cache, getTaskName(task), form, 60*60)
}

func getStorage(task string) *StorageForm {
return utils.GetJson[StorageForm](connection.Cache, getTaskName(task))
func getNotifyStorage(task string) *StorageForm {
return utils.GetCacheStore[StorageForm](connection.Cache, getTaskName(task))
}
9 changes: 7 additions & 2 deletions adapter/request.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package adapter

import (
"chat/adapter/common"
adaptercommon "chat/adapter/common"
"chat/globals"
"chat/utils"
"fmt"
Expand All @@ -10,7 +10,11 @@ import (
)

func IsAvailableError(err error) bool {
return err != nil && err.Error() != "signal"
return err != nil && (err.Error() != "signal" && !strings.Contains(err.Error(), "signal"))
}

func IsSkipError(err error) bool {
return err == nil || (err.Error() == "signal" || strings.Contains(err.Error(), "signal"))
}

func isQPSOverLimit(model string, err error) bool {
Expand All @@ -26,6 +30,7 @@ func NewChatRequest(conf globals.ChannelConfig, props *adaptercommon.ChatProps,
retries := conf.GetRetry()
props.Current++

fmt.Println(IsAvailableError(err))
if IsAvailableError(err) {
if isQPSOverLimit(props.OriginalModel, err) {
// sleep for 0.5s to avoid qps limit
Expand Down
6 changes: 4 additions & 2 deletions admin/statistic.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package admin

import (
"chat/adapter"
"chat/connection"
"chat/utils"
"github.com/go-redis/redis/v8"
"time"

"github.com/go-redis/redis/v8"
)

func IncrErrorRequest(cache *redis.Client) {
Expand All @@ -27,7 +29,7 @@ func IncrModelRequest(cache *redis.Client, model string, tokens int64) {
func AnalysisRequest(model string, buffer *utils.Buffer, err error) {
instance := connection.Cache

if err != nil && err.Error() != "signal" {
if adapter.IsAvailableError(err) {
IncrErrorRequest(instance)
return
}
Expand Down
13 changes: 9 additions & 4 deletions channel/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@ package channel

import (
"chat/adapter"
"chat/adapter/common"
adaptercommon "chat/adapter/common"
"chat/globals"
"chat/utils"
"fmt"
"github.com/go-redis/redis/v8"
"time"

"github.com/go-redis/redis/v8"
)

func NewChatRequest(group string, props *adaptercommon.ChatProps, hook globals.Hook) error {
if err := AuditContent(props); err != nil {
return err
}

ticker := ConduitInstance.GetTicker(props.OriginalModel, group)
if ticker == nil || ticker.IsEmpty() {
return fmt.Errorf("cannot find channel for model %s", props.OriginalModel)
Expand All @@ -20,8 +25,8 @@ func NewChatRequest(group string, props *adaptercommon.ChatProps, hook globals.H
for !ticker.IsDone() {
if channel := ticker.Next(); channel != nil {
props.MaxRetries = utils.ToPtr(channel.GetRetry())
if err = adapter.NewChatRequest(channel, props, hook); err == nil || err.Error() == "signal" {
return nil
if err = adapter.NewChatRequest(channel, props, hook); adapter.IsSkipError(err) {
return err
}

globals.Warn(fmt.Sprintf("[channel] caught error %s for model %s at channel %s", err.Error(), props.OriginalModel, channel.GetName()))
Expand Down
2 changes: 1 addition & 1 deletion globals/variables.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func OriginIsAllowed(uri string) bool {
}

func OriginIsOpen(c *gin.Context) bool {
return strings.HasPrefix(c.Request.URL.Path, "/v1") || strings.HasPrefix(c.Request.URL.Path, "/dashboard")
return strings.HasPrefix(c.Request.URL.Path, "/v1") || strings.HasPrefix(c.Request.URL.Path, "/dashboard") || strings.HasPrefix(c.Request.URL.Path, "/mj")
}

const (
Expand Down
7 changes: 4 additions & 3 deletions manager/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package manager

import (
"chat/adapter"
"chat/adapter/common"
adaptercommon "chat/adapter/common"
"chat/addition/web"
"chat/admin"
"chat/auth"
Expand All @@ -11,8 +11,9 @@ import (
"chat/manager/conversation"
"chat/utils"
"fmt"
"github.com/gin-gonic/gin"
"runtime/debug"

"github.com/gin-gonic/gin"
)

const defaultMessage = "empty response"
Expand Down Expand Up @@ -96,7 +97,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
)

admin.AnalysisRequest(model, buffer, err)
if err != nil && err.Error() != "signal" {
if adapter.IsAvailableError(err) {
globals.Warn(fmt.Sprintf("%s (model: %s, client: %s)", err, model, conn.GetCtx().ClientIP()))

auth.RevertSubscriptionUsage(db, cache, user, model)
Expand Down
5 changes: 3 additions & 2 deletions utils/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ import (
"context"
"errors"
"fmt"
"github.com/go-redis/redis/v8"
"time"

"github.com/go-redis/redis/v8"
)

func Incr(cache *redis.Client, key string, delta int64) (int64, error) {
Expand Down Expand Up @@ -37,7 +38,7 @@ func SetJson(cache *redis.Client, key string, value interface{}, expiration int6
return err
}

func GetJson[T any](cache *redis.Client, key string) *T {
func GetCacheStore[T any](cache *redis.Client, key string) *T {
val, err := cache.Get(context.Background(), key).Result()
if err != nil {
return nil
Expand Down

0 comments on commit f66eb11

Please sign in to comment.