Skip to content

Commit

Permalink
add use sharing conversation feature
Browse files Browse the repository at this point in the history
  • Loading branch information
zmh-program committed Oct 18, 2023
1 parent 5936da8 commit fd8c2d6
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 23 deletions.
67 changes: 67 additions & 0 deletions manager/conversation/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@ import (
"github.com/gin-gonic/gin"
"net/http"
"strconv"
"strings"
)

type ShareForm struct {
ConversationId int64 `json:"conversation_id"`
Refs []int `json:"refs"`
}

func ListAPI(c *gin.Context) {
user := auth.GetUser(c)
if user == nil {
Expand Down Expand Up @@ -94,3 +100,64 @@ func DeleteAPI(c *gin.Context) {
"message": "",
})
}

func ShareAPI(c *gin.Context) {
user := auth.GetUser(c)
if user == nil {
c.JSON(http.StatusOK, gin.H{
"status": false,
"message": "user not found",
})
return
}

db := utils.GetDBFromContext(c)
var form ShareForm
if err := c.ShouldBindJSON(&form); err != nil {
c.JSON(http.StatusOK, gin.H{
"status": false,
"message": "invalid form",
})
return
}

if err := ShareConversation(db, user, form.ConversationId, form.Refs); err != nil {
c.JSON(http.StatusOK, gin.H{
"status": false,
"message": err.Error(),
})
return
}

c.JSON(http.StatusOK, gin.H{
"status": true,
"message": "",
})
}

func ViewAPI(c *gin.Context) {
db := utils.GetDBFromContext(c)
hash := strings.TrimSpace(c.Query("hash"))
if hash == "" {
c.JSON(http.StatusOK, gin.H{
"status": false,
"message": "invalid hash",
})
return
}

shared, err := GetSharedConversation(db, hash)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"status": false,
"message": err.Error(),
})
return
}

c.JSON(http.StatusOK, gin.H{
"status": true,
"message": "",
"data": shared,
})
}
34 changes: 19 additions & 15 deletions manager/conversation/conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,33 @@ type FormMessage struct {

func NewAnonymousConversation() *Conversation {
return &Conversation{
Auth: false,
UserID: -1,
Id: -1,
Name: "anonymous",
Message: []globals.Message{},
Model: globals.GPT3Turbo,
EnableWeb: false,
Auth: false,
UserID: -1,
Id: -1,
Name: "anonymous",
Message: []globals.Message{},
Model: globals.GPT3Turbo,
}
}

func NewConversation(db *sql.DB, id int64) *Conversation {
return &Conversation{
Auth: true,
UserID: id,
Id: GetConversationLengthByUserID(db, id) + 1,
Name: "new chat",
Message: []globals.Message{},
Model: globals.GPT3Turbo,
EnableWeb: false,
Auth: true,
UserID: id,
Id: GetConversationLengthByUserID(db, id) + 1,
Name: "new chat",
Message: []globals.Message{},
Model: globals.GPT3Turbo,
}
}

func ExtractConversation(db *sql.DB, user *auth.User, id int64) *Conversation {
func ExtractConversation(db *sql.DB, user *auth.User, id int64, ref string) *Conversation {
if ref != "" {
if instance := UseSharedConversation(db, user, ref); instance != nil {
return instance
}
}

if user == nil {
return NewAnonymousConversation()
}
Expand Down
1 change: 1 addition & 0 deletions manager/conversation/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ func Register(app *gin.Engine) {
router.GET("/list", ListAPI)
router.GET("/load", LoadAPI)
router.GET("/delete", DeleteAPI)
router.POST("/share", ShareAPI)
}
}
44 changes: 37 additions & 7 deletions manager/conversation/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,24 @@ func GetRef(refs []int) (result string) {
return strings.TrimSuffix(result, ",")
}

func (c *Conversation) ShareConversation(db *sql.DB, user *auth.User, refs []int) bool {
if c.GetId() < 0 || user == nil {
return false
func ShareConversation(db *sql.DB, user *auth.User, id int64, refs []int) error {
if id < 0 || user == nil {
return nil
}

ref := GetRef(refs)
hash := utils.Md5EncryptForm(SharedHashForm{
Id: c.GetId(),
ConversationId: c.GetId(),
Id: user.GetID(db),
ConversationId: id,
Refs: refs,
})

_, err := db.Exec(`
INSERT INTO sharing (hash, user_id, conversation_id, refs) VALUES (?, ?, ?, ?)
ON DUPLICATE KEY UPDATE refs = ?
`, hash, user.GetID(db), c.GetId(), ref, ref)
`, hash, user.GetID(db), id, ref, ref)

return err == nil
return err
}

func GetSharedMessages(db *sql.DB, userId int64, conversationId int64, refs []string) []globals.Message {
Expand Down Expand Up @@ -96,3 +96,33 @@ func GetSharedConversation(db *sql.DB, hash string) (*SharedForm, error) {

return &shared, nil
}

func UseSharedConversation(db *sql.DB, user *auth.User, hash string) *Conversation {
shared, err := GetSharedConversation(db, hash)
if err != nil {
return nil
}

if user == nil {
// anonymous
return &Conversation{
Auth: false,
UserID: -1,
Id: -1,
Name: shared.Name,
Message: shared.Messages,
Model: globals.GPT3Turbo,
}
}

// create new conversation
id := user.GetID(db)
return &Conversation{
Auth: true,
Id: GetConversationLengthByUserID(db, id) + 1,
UserID: id,
Name: shared.Name,
Model: globals.GPT3Turbo,
Message: shared.Messages,
}
}
3 changes: 2 additions & 1 deletion manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
type WebsocketAuthForm struct {
Token string `json:"token" binding:"required"`
Id int64 `json:"id" binding:"required"`
Ref string `json:"ref"`
}

func EventHandler(conn *utils.WebSocket, instance *conversation.Conversation, user *auth.User) string {
Expand Down Expand Up @@ -66,7 +67,7 @@ func ChatAPI(c *gin.Context) {

id := auth.GetId(db, user)

instance := conversation.ExtractConversation(db, user, form.Id)
instance := conversation.ExtractConversation(db, user, form.Id, form.Ref)
hash := fmt.Sprintf(":chatthread:%s", utils.Md5Encrypt(utils.Multi(
authenticated,
strconv.FormatInt(id, 10),
Expand Down

0 comments on commit fd8c2d6

Please sign in to comment.