Skip to content

Commit

Permalink
Merge pull request #32 from klipach/fix-chat-search
Browse files Browse the repository at this point in the history
Fix chat history search
  • Loading branch information
klipach authored Jan 23, 2025
2 parents d85f30e + bdc05ce commit a3c7eb6
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 9 deletions.
33 changes: 24 additions & 9 deletions chat/history.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,18 @@ const (
)

type firestoreUser struct {
DisplayName string `firestore:"display_name"`
Chats []struct {
Messages []struct {
From string `firestore:"from"`
Message string `firestore:"message"`
} `firestore:"messages"`
} `firestore:"chats"`
DisplayName string `firestore:"display_name"`
Chats []*firestoreChat `firestore:"chats"`
}

type firestoreMessage struct {
From string `firestore:"from"`
Message string `firestore:"message"`
}

type firestoreChat struct {
ChatID int `firestore:"chat_id"`
Messages []*firestoreMessage `firestore:"messages"`
}

func LoadHistory(ctx context.Context, userID string, chatID int) ([]llms.MessageContent, error) {
Expand Down Expand Up @@ -54,12 +59,13 @@ func LoadHistory(ctx context.Context, userID string, chatID int) ([]llms.Message
user := firestoreUser{}
userDoc.DataTo(&user)

if chatID >= len(user.Chats) {
messages := findChatMessages(user.Chats, chatID)
if messages == nil {
logger.Printf("chat not found: %d", chatID)
return chatHistory, nil
}

for _, m := range user.Chats[chatID].Messages {
for _, m := range messages {
switch m.From {
case fromUser:
chatHistory = append(chatHistory, llms.TextParts(llms.ChatMessageTypeHuman, m.Message))
Expand All @@ -71,3 +77,12 @@ func LoadHistory(ctx context.Context, userID string, chatID int) ([]llms.Message
}
return chatHistory, nil
}

func findChatMessages(chats []*firestoreChat, chatID int) []*firestoreMessage {
for _, chat := range chats {
if chat.ChatID == chatID {
return chat.Messages
}
}
return nil
}
56 changes: 56 additions & 0 deletions chat/history_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package chat

import (
"reflect"
"testing"
)

func TestFindChatMessages(t *testing.T) {
tests := []struct {
name string
chats []*firestoreChat
chatID int
expected []*firestoreMessage
}{
{
name: "Chat found",
chats: []*firestoreChat{
{ChatID: 1, Messages: []*firestoreMessage{{}, {}}},
{ChatID: 2, Messages: []*firestoreMessage{{}}},
},
chatID: 1,
expected: []*firestoreMessage{{}, {}},
},
{
name: "Chat not found",
chats: []*firestoreChat{
{ChatID: 1, Messages: []*firestoreMessage{{}, {}}},
{ChatID: 2, Messages: []*firestoreMessage{{}}},
},
chatID: 3,
expected: nil,
},
{
name: "Empty chats",
chats: []*firestoreChat{},
chatID: 1,
expected: nil,
},
{
name: "nil chats",
chats: nil,
chatID: 1,
expected: nil,
},

}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := findChatMessages(test.chats, test.chatID)
if !reflect.DeepEqual(result, test.expected) {
t.Errorf("findChatMessages(%v, %d) = %v; want %v", test.chats, test.chatID, result, test.expected)
}
})
}
}

0 comments on commit a3c7eb6

Please sign in to comment.