From 401de5ace72808c69b4cf0de8f0e8de276b79a0a Mon Sep 17 00:00:00 2001 From: Deng Junhai Date: Thu, 27 Jun 2024 22:57:17 +0800 Subject: [PATCH] feat: better stop signal (#181) feat: better stop signal (#181) Co-Authored-By: Minghan Zhang <112773885+zmh-program@users.noreply.github.com> --- addition/web/call.go | 60 ++----- addition/web/parser.go | 68 -------- addition/web/search.go | 28 ---- addition/web/{duckduckgo.go => searxng.go} | 0 addition/web/utils.go | 24 --- addition/web/webpilot.go | 36 ---- manager/chat.go | 182 +++++++++++++++++---- 7 files changed, 166 insertions(+), 232 deletions(-) delete mode 100644 addition/web/parser.go rename addition/web/{duckduckgo.go => searxng.go} (100%) delete mode 100644 addition/web/utils.go delete mode 100644 addition/web/webpilot.go diff --git a/addition/web/call.go b/addition/web/call.go index e0a6af12..d94373cb 100644 --- a/addition/web/call.go +++ b/addition/web/call.go @@ -2,9 +2,9 @@ package web import ( "chat/globals" + "chat/manager/conversation" "chat/utils" "fmt" - "strings" "time" ) @@ -12,7 +12,7 @@ type Hook func(message []globals.Message, token int) (string, error) func ChatWithWeb(message []globals.Message) []globals.Message { data := utils.GetSegmentString( - SearchWebResult(GetPointByLatestMessage(message)), 2048, + SearchWebResult(message[len(message)-1].Content), 2048, ) return utils.Insert(message, 0, globals.Message{ @@ -24,52 +24,20 @@ func ChatWithWeb(message []globals.Message) []globals.Message { }) } -func StringCleaner(content string) string { - for _, replacer := range []string{",", "、", ",", "。", ":", ":", ";", ";", "!", "!", "=", "?", "?", "(", ")", "(", ")", "关键字", "空", "1+1"} { - content = strings.ReplaceAll(content, replacer, " ") - } - return strings.TrimSpace(content) -} +func UsingWebSegment(instance *conversation.Conversation, restart bool) []globals.Message { + segment := conversation.CopyMessage(instance.GetChatMessage(restart)) -func GetKeywordPoint(hook Hook, message []globals.Message) string { - resp, _ := hook([]globals.Message{{ - Role: globals.System, - Content: "If the user input content require ONLINE SEARCH to get the results, please output these keywords to refine the data Interval with space, remember not to answer other content, json format return, format {\"keyword\": \"...\" }", - }, { - Role: globals.User, - Content: "你是谁", - }, { - Role: globals.Assistant, - Content: "{\"keyword\":\"\"}", - }, { - Role: globals.User, - Content: "那fystart起始页是什么 和深能科创有什么关系", - }, { - Role: globals.Assistant, - Content: "{\"keyword\":\"fystart起始页 深能科创 关系\"}", - }, { - Role: globals.User, - Content: "1+1=?", - }, { - Role: globals.Assistant, - Content: "{\"keyword\":\"\"}", - }, { - Role: globals.User, - Content: "?", - }, { - Role: globals.Assistant, - Content: "{\"keyword\":\"\"}", - }, { - Role: globals.User, - Content: message[len(message)-1].Content, - }}, 40) - keyword := utils.UnmarshalJson[map[string]interface{}](resp) - if keyword == nil { - return "" + if instance.IsEnableWeb() { + segment = ChatWithWeb(segment) } - return StringCleaner(keyword["keyword"].(string)) + + return segment } -func GetPointByLatestMessage(message []globals.Message) string { - return StringCleaner(message[len(message)-1].Content) +func UsingWebNativeSegment(enable bool, message []globals.Message) []globals.Message { + if enable { + return ChatWithWeb(message) + } else { + return message + } } diff --git a/addition/web/parser.go b/addition/web/parser.go deleted file mode 100644 index 19daeab3..00000000 --- a/addition/web/parser.go +++ /dev/null @@ -1,68 +0,0 @@ -package web - -import ( - "chat/utils" - "golang.org/x/net/html" - "regexp" - "strings" -) - -var unexpected = []string{ - "", - "", - "
", - ""), 1) - return strings.Split(suf, "")[0] -} - -func SplitPagination(html string) string { - pre := strings.Split(html, "
  • ")[0] - return utils.TryGet(strings.Split(pre, "
    在新选项卡中打开链接
    "), 1) -} - -func GetContent(html string) []string { - re := regexp.MustCompile(`>([^<]+)<`) - matches := re.FindAllString(html, -1) - - return FilterContent(matches) -} - -func IsExpected(data string) bool { - if IsLink(data) { - return false - } - for _, str := range unexpected { - if strings.HasPrefix(data, str) { - return false - } - } - return true -} - -func IsLink(input string) bool { - re := regexp.MustCompile(`^(https?|ftp):\/\/[^\s/$.?#].\S*$`) - return re.MatchString(input) -} - -func FilterContent(matches []string) []string { - res := make([]string, 0) - - for _, match := range matches { - source := strings.TrimSpace(match[1 : len(match)-1]) - if len(source) > 0 && IsExpected(source) { - res = append(res, source) - } - } - - return res -} diff --git a/addition/web/search.go b/addition/web/search.go index 556ec0f9..95c1847f 100644 --- a/addition/web/search.go +++ b/addition/web/search.go @@ -4,36 +4,8 @@ import ( "chat/globals" "chat/utils" "fmt" - "net/url" ) -func GetBingUrl(q string) string { - return "https://bing.com/search?q=" + url.QueryEscape(q) -} - -func RequestWithUA(url string) string { - data, err := utils.GetRaw(url, map[string]string{ - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/116.0", - "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8", - }) - - if err != nil { - return "" - } - - return data -} - -func SearchReverse(q string) string { - // deprecated - uri := GetBingUrl(q) - if res := CallPilotAPI(uri); res != nil { - return utils.Marshal(res.Results) - } - data := RequestWithUA(uri) - return ParseBing(data) -} - func SearchWebResult(q string) string { res, err := CallDuckDuckGoAPI(q) if err != nil { diff --git a/addition/web/duckduckgo.go b/addition/web/searxng.go similarity index 100% rename from addition/web/duckduckgo.go rename to addition/web/searxng.go diff --git a/addition/web/utils.go b/addition/web/utils.go deleted file mode 100644 index b56d58c3..00000000 --- a/addition/web/utils.go +++ /dev/null @@ -1,24 +0,0 @@ -package web - -import ( - "chat/globals" - "chat/manager/conversation" -) - -func UsingWebSegment(instance *conversation.Conversation, restart bool) []globals.Message { - segment := conversation.CopyMessage(instance.GetChatMessage(restart)) - - if instance.IsEnableWeb() { - segment = ChatWithWeb(segment) - } - - return segment -} - -func UsingWebNativeSegment(enable bool, message []globals.Message) []globals.Message { - if enable { - return ChatWithWeb(message) - } else { - return message - } -} diff --git a/addition/web/webpilot.go b/addition/web/webpilot.go deleted file mode 100644 index 19af71fc..00000000 --- a/addition/web/webpilot.go +++ /dev/null @@ -1,36 +0,0 @@ -package web - -import ( - "chat/utils" - "github.com/google/uuid" -) - -type PilotResponseResult struct { - Title string `json:"title"` - Link string `json:"link"` - Snippet string `json:"snippet"` -} - -type PilotResponse struct { - Results []PilotResponseResult `json:"extra_search_results" required:"true"` -} - -func GenerateFriendUID() string { - return uuid.New().String() -} - -func CallPilotAPI(url string) *PilotResponse { - data, err := utils.Post("https://webreader.webpilotai.com/api/visit-web", map[string]string{ - "Content-Type": "application/json", - "WebPilot-Friend-UID": GenerateFriendUID(), - }, map[string]interface{}{ - "link": url, - "user_has_request": false, - }) - - if err != nil { - return nil - } - - return utils.MapToStruct[PilotResponse](data) -} diff --git a/manager/chat.go b/manager/chat.go index 708f4131..e43af54d 100644 --- a/manager/chat.go +++ b/manager/chat.go @@ -2,7 +2,7 @@ package manager import ( "chat/adapter" - "chat/adapter/common" + adaptercommon "chat/adapter/common" "chat/addition/web" "chat/admin" "chat/auth" @@ -10,12 +10,19 @@ import ( "chat/globals" "chat/manager/conversation" "chat/utils" + + "database/sql" + "errors" "fmt" - "github.com/gin-gonic/gin" "runtime/debug" + "strings" + + "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" ) const defaultMessage = "empty response" +const interruptMessage = "interrupted" func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncountable bool, err error) { db := utils.GetDBFromContext(c) @@ -34,6 +41,148 @@ func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncount } } +type partialChunk struct { + Chunk *globals.Chunk + End bool + Hit bool + Error error +} + +func createStopSignal(conn *Connection) chan bool { + stopSignal := make(chan bool, 1) + go func(conn *Connection, stopSignal chan bool) { + defer func() { + if r := recover(); r != nil && !strings.Contains(fmt.Sprintf("%s", r), "closed channel") { + stack := debug.Stack() + globals.Warn(fmt.Sprintf("caught panic from stop signal: %s\n%s", r, stack)) + } + }() + + for { + if conn.PeekStop() != nil { + stopSignal <- true + break + } + } + }(conn, stopSignal) + + return stopSignal +} + +func createChatTask( + conn *Connection, user *auth.User, buffer *utils.Buffer, db *sql.DB, cache *redis.Client, + model string, instance *conversation.Conversation, segment []globals.Message, plan bool, +) (hit bool, err error) { + chunkChan := make(chan partialChunk, 24) // the channel to send the chunk data + interruptSignal := make(chan error, 1) // the signal to interrupt the chat task routine + stopSignal := createStopSignal(conn) // the signal to stop from the client + + defer func() { + // close all channels + close(interruptSignal) + close(stopSignal) + close(chunkChan) + }() + + // create a new chat request routine + go func() { + defer func() { + if r := recover(); r != nil && !strings.Contains(fmt.Sprintf("%s", r), "closed channel") { + stack := debug.Stack() + globals.Warn(fmt.Sprintf("caught panic from chat request: %s\n%s", r, stack)) + } + }() + + hit, err := channel.NewChatRequestWithCache( + cache, buffer, + auth.GetGroup(db, user), + &adaptercommon.ChatProps{ + Model: model, + Message: segment, + MaxTokens: instance.GetMaxTokens(), + Temperature: instance.GetTemperature(), + TopP: instance.GetTopP(), + TopK: instance.GetTopK(), + PresencePenalty: instance.GetPresencePenalty(), + FrequencyPenalty: instance.GetFrequencyPenalty(), + RepetitionPenalty: instance.GetRepetitionPenalty(), + }, + + // the function to handle the chunk data + func(data *globals.Chunk) error { + // if interrupt signal is received + if len(interruptSignal) > 0 { + return errors.New(interruptMessage) + } + + // send the chunk data to the channel + chunkChan <- partialChunk{ + Chunk: data, + End: false, + Hit: true, + Error: nil, + } + return nil + }, + ) + + // chat request routine is done + chunkChan <- partialChunk{ + Chunk: nil, + End: true, + Hit: hit, + Error: err, + } + }() + + for { + select { + case data := <-chunkChan: + if data.Error != nil && data.Error.Error() == interruptMessage { + // skip the interrupt message + continue + } + + hit = data.Hit + err = data.Error + + if data.End { + return + } + + sendPackError := conn.SendClient(globals.ChatSegmentResponse{ + Message: buffer.WriteChunk(data.Chunk), + Quota: buffer.GetQuota(), + End: false, + Plan: plan, + }) + if sendPackError != nil { + globals.Warn(fmt.Sprintf("failed to send message to client: %s", sendPackError.Error())) + _ = conn.SendClient(globals.ChatSegmentResponse{ + Message: sendPackError.Error(), + Quota: buffer.GetQuota(), + End: true, + Plan: plan, + }) + + interruptSignal <- sendPackError + + return hit, sendPackError + } + case <-stopSignal: + globals.Info(fmt.Sprintf("client stopped the chat request (model: %s, client: %s)", model, conn.GetCtx().ClientIP())) + _ = conn.SendClient(globals.ChatSegmentResponse{ + Quota: buffer.GetQuota(), + End: true, + Plan: plan, + }) + interruptSignal <- errors.New("signal") + + return + } + } +} + func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conversation, restart bool) string { defer func() { if err := recover(); err != nil { @@ -66,34 +215,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve } buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model)) - hit, err := channel.NewChatRequestWithCache( - cache, buffer, - auth.GetGroup(db, user), - &adaptercommon.ChatProps{ - Model: model, - Message: segment, - Buffer: *buffer, - MaxTokens: instance.GetMaxTokens(), - Temperature: instance.GetTemperature(), - TopP: instance.GetTopP(), - TopK: instance.GetTopK(), - PresencePenalty: instance.GetPresencePenalty(), - FrequencyPenalty: instance.GetFrequencyPenalty(), - RepetitionPenalty: instance.GetRepetitionPenalty(), - }, - func(data *globals.Chunk) error { - if signal := conn.PeekStop(); signal != nil { - // stop signal from client - return fmt.Errorf("signal") - } - return conn.SendClient(globals.ChatSegmentResponse{ - Message: buffer.WriteChunk(data), - Quota: buffer.GetQuota(), - End: false, - Plan: plan, - }) - }, - ) + hit, err := createChatTask(conn, user, buffer, db, cache, model, instance, segment, plan) admin.AnalysisRequest(model, buffer, err) if adapter.IsAvailableError(err) {