From fd8c2d6eb75054471317f55f8c404a870893d10f Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Wed, 18 Oct 2023 16:26:22 +0800 Subject: [PATCH] add use sharing conversation feature --- manager/conversation/api.go | 67 ++++++++++++++++++++++++++++ manager/conversation/conversation.go | 34 +++++++------- manager/conversation/router.go | 1 + manager/conversation/shared.go | 44 +++++++++++++++--- manager/manager.go | 3 +- 5 files changed, 126 insertions(+), 23 deletions(-) diff --git a/manager/conversation/api.go b/manager/conversation/api.go index 7c0d8548..43a96ce0 100644 --- a/manager/conversation/api.go +++ b/manager/conversation/api.go @@ -6,8 +6,14 @@ import ( "github.com/gin-gonic/gin" "net/http" "strconv" + "strings" ) +type ShareForm struct { + ConversationId int64 `json:"conversation_id"` + Refs []int `json:"refs"` +} + func ListAPI(c *gin.Context) { user := auth.GetUser(c) if user == nil { @@ -94,3 +100,64 @@ func DeleteAPI(c *gin.Context) { "message": "", }) } + +func ShareAPI(c *gin.Context) { + user := auth.GetUser(c) + if user == nil { + c.JSON(http.StatusOK, gin.H{ + "status": false, + "message": "user not found", + }) + return + } + + db := utils.GetDBFromContext(c) + var form ShareForm + if err := c.ShouldBindJSON(&form); err != nil { + c.JSON(http.StatusOK, gin.H{ + "status": false, + "message": "invalid form", + }) + return + } + + if err := ShareConversation(db, user, form.ConversationId, form.Refs); err != nil { + c.JSON(http.StatusOK, gin.H{ + "status": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "status": true, + "message": "", + }) +} + +func ViewAPI(c *gin.Context) { + db := utils.GetDBFromContext(c) + hash := strings.TrimSpace(c.Query("hash")) + if hash == "" { + c.JSON(http.StatusOK, gin.H{ + "status": false, + "message": "invalid hash", + }) + return + } + + shared, err := GetSharedConversation(db, hash) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "status": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "status": true, + "message": "", + "data": shared, + }) +} diff --git a/manager/conversation/conversation.go b/manager/conversation/conversation.go index 092232f9..efe8e92a 100644 --- a/manager/conversation/conversation.go +++ b/manager/conversation/conversation.go @@ -28,29 +28,33 @@ type FormMessage struct { func NewAnonymousConversation() *Conversation { return &Conversation{ - Auth: false, - UserID: -1, - Id: -1, - Name: "anonymous", - Message: []globals.Message{}, - Model: globals.GPT3Turbo, - EnableWeb: false, + Auth: false, + UserID: -1, + Id: -1, + Name: "anonymous", + Message: []globals.Message{}, + Model: globals.GPT3Turbo, } } func NewConversation(db *sql.DB, id int64) *Conversation { return &Conversation{ - Auth: true, - UserID: id, - Id: GetConversationLengthByUserID(db, id) + 1, - Name: "new chat", - Message: []globals.Message{}, - Model: globals.GPT3Turbo, - EnableWeb: false, + Auth: true, + UserID: id, + Id: GetConversationLengthByUserID(db, id) + 1, + Name: "new chat", + Message: []globals.Message{}, + Model: globals.GPT3Turbo, } } -func ExtractConversation(db *sql.DB, user *auth.User, id int64) *Conversation { +func ExtractConversation(db *sql.DB, user *auth.User, id int64, ref string) *Conversation { + if ref != "" { + if instance := UseSharedConversation(db, user, ref); instance != nil { + return instance + } + } + if user == nil { return NewAnonymousConversation() } diff --git a/manager/conversation/router.go b/manager/conversation/router.go index 3b1e4d18..b5e9bc8d 100644 --- a/manager/conversation/router.go +++ b/manager/conversation/router.go @@ -8,5 +8,6 @@ func Register(app *gin.Engine) { router.GET("/list", ListAPI) router.GET("/load", LoadAPI) router.GET("/delete", DeleteAPI) + router.POST("/share", ShareAPI) } } diff --git a/manager/conversation/shared.go b/manager/conversation/shared.go index ce52bdea..b2c760a4 100644 --- a/manager/conversation/shared.go +++ b/manager/conversation/shared.go @@ -30,24 +30,24 @@ func GetRef(refs []int) (result string) { return strings.TrimSuffix(result, ",") } -func (c *Conversation) ShareConversation(db *sql.DB, user *auth.User, refs []int) bool { - if c.GetId() < 0 || user == nil { - return false +func ShareConversation(db *sql.DB, user *auth.User, id int64, refs []int) error { + if id < 0 || user == nil { + return nil } ref := GetRef(refs) hash := utils.Md5EncryptForm(SharedHashForm{ - Id: c.GetId(), - ConversationId: c.GetId(), + Id: user.GetID(db), + ConversationId: id, Refs: refs, }) _, err := db.Exec(` INSERT INTO sharing (hash, user_id, conversation_id, refs) VALUES (?, ?, ?, ?) ON DUPLICATE KEY UPDATE refs = ? - `, hash, user.GetID(db), c.GetId(), ref, ref) + `, hash, user.GetID(db), id, ref, ref) - return err == nil + return err } func GetSharedMessages(db *sql.DB, userId int64, conversationId int64, refs []string) []globals.Message { @@ -96,3 +96,33 @@ func GetSharedConversation(db *sql.DB, hash string) (*SharedForm, error) { return &shared, nil } + +func UseSharedConversation(db *sql.DB, user *auth.User, hash string) *Conversation { + shared, err := GetSharedConversation(db, hash) + if err != nil { + return nil + } + + if user == nil { + // anonymous + return &Conversation{ + Auth: false, + UserID: -1, + Id: -1, + Name: shared.Name, + Message: shared.Messages, + Model: globals.GPT3Turbo, + } + } + + // create new conversation + id := user.GetID(db) + return &Conversation{ + Auth: true, + Id: GetConversationLengthByUserID(db, id) + 1, + UserID: id, + Name: shared.Name, + Model: globals.GPT3Turbo, + Message: shared.Messages, + } +} diff --git a/manager/manager.go b/manager/manager.go index 13b34281..15fed160 100644 --- a/manager/manager.go +++ b/manager/manager.go @@ -14,6 +14,7 @@ import ( type WebsocketAuthForm struct { Token string `json:"token" binding:"required"` Id int64 `json:"id" binding:"required"` + Ref string `json:"ref"` } func EventHandler(conn *utils.WebSocket, instance *conversation.Conversation, user *auth.User) string { @@ -66,7 +67,7 @@ func ChatAPI(c *gin.Context) { id := auth.GetId(db, user) - instance := conversation.ExtractConversation(db, user, form.Id) + instance := conversation.ExtractConversation(db, user, form.Id, form.Ref) hash := fmt.Sprintf(":chatthread:%s", utils.Md5Encrypt(utils.Multi( authenticated, strconv.FormatInt(id, 10),