Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REDBOX 337 chat file selection #556

Merged
merged 32 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
faed993
Add ChatMessage many-to-many relation to File, and make them availabl…
brunns Jun 7, 2024
3add66a
Show selected files on chats page.
brunns Jun 7, 2024
1b8e317
Save selected file, and send to core, for streamed chat version.
brunns Jun 7, 2024
f7ff925
Use checkboxes for selecting files
KevinEtchells Jun 7, 2024
3eb716f
Move "Files to use" to sidebar
KevinEtchells Jun 7, 2024
bffcb44
Tests for saving & sending selected files.
brunns Jun 10, 2024
cf63656
Setup document selection for streaming client-side
KevinEtchells Jun 10, 2024
2c913e0
Save selected file, and send to core, for streamed chat version.
brunns Jun 10, 2024
c4aa4f3
Enable chat streaming in all tests.
brunns Jun 7, 2024
8c254cc
add eval results visualisation and calculate uncertainty
esoutter Jun 7, 2024
6b14a74
remove inline outputs
esoutter Jun 7, 2024
e8562eb
Address Will's PR comments
esoutter Jun 7, 2024
f842c3e
Remove streaming demo
KevinEtchells Jun 10, 2024
26b199d
Save non-rag responses to DB - tactical fix.
brunns Jun 10, 2024
99252a2
Revert tactical fix - we are doing it properly here.
brunns Jun 10, 2024
0f1bfe6
Receive selected file list in core API for streaming.
brunns Jun 10, 2024
4fe6d67
Merge branch 'main' into feature/REDBOX-337-chat-file-selection
brunns Jun 10, 2024
6fab42f
Merge branch 'main' into feature/REDBOX-337-chat-file-selection
brunns Jun 10, 2024
0a5c6aa
Add selected files to e2e tests.
brunns Jun 10, 2024
f7d555a
Bug - ensure latest question is always the one answered.
brunns Jun 10, 2024
3f9c028
Merge branch 'main' into feature/REDBOX-337-chat-file-selection
brunns Jun 10, 2024
117b152
Unit tests not working but the core plumbing is there
Jun 10, 2024
00f8d44
Merge branch 'main' into feature/REDBOX-337-chat-file-selection
brunns Jun 11, 2024
aa66c80
Post merge formatting.
brunns Jun 11, 2024
bfb193e
wip
Jun 11, 2024
9a4a3bb
test now passing
Jun 11, 2024
28d72a2
Add some logging for debug purposes.
brunns Jun 11, 2024
6a9fc15
Merge branch 'main' into feature/REDBOX-337-chat-file-selection
brunns Jun 11, 2024
b70ff2d
must != should
Jun 11, 2024
50ba97e
remove change made in error
Jun 11, 2024
a7a658a
migration test added
Jun 11, 2024
bd368c0
no longer changing source_files
Jun 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions core_api/src/routes/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,7 @@ async def build_vanilla_chain(


async def build_retrieval_chain(
chat_request: ChatRequest,
user_uuid: UUID,
llm: ChatLiteLLM,
vector_store: ElasticsearchStore,
chat_request: ChatRequest, user_uuid: UUID, llm: ChatLiteLLM, vector_store: ElasticsearchStore
):
question = chat_request.message_history[-1].text
previous_history = list(chat_request.message_history[:-1])
Expand All @@ -109,7 +106,13 @@ async def build_retrieval_chain(

standalone_question = condense_question_chain({"question": question, "chat_history": previous_history})["text"]

search_kwargs = {"filter": {"term": {"creator_user_uuid.keyword": str(user_uuid)}}}
search_kwargs = {"filter": {"bool": {"must": [{"term": {"creator_user_uuid.keyword": str(user_uuid)}}]}}}

if chat_request.selected_files is not None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if chat_request.selected_files is not None:
if chat_request.selected_files:

selected_files should never be None - the model uses a default factory.

search_kwargs["filter"]["bool"]["should"] = [
{"term": {"parent_file_uuid.keyword": str(file.uuid)}} for file in chat_request.selected_files
]

docs = vector_store.as_retriever(search_kwargs=search_kwargs).get_relevant_documents(standalone_question)

params = {
Expand All @@ -120,12 +123,7 @@ async def build_retrieval_chain(
return docs_with_sources_chain, params


async def build_chain(
chat_request: ChatRequest,
user_uuid: UUID,
llm: ChatLiteLLM,
vector_store: ElasticsearchStore,
):
async def build_chain(chat_request: ChatRequest, user_uuid: UUID, llm: ChatLiteLLM, vector_store: ElasticsearchStore):
question = chat_request.message_history[-1].text
route = route_layer(question)

Expand Down Expand Up @@ -189,7 +187,8 @@ async def rag_chat_streamed(

user_uuid = await get_ws_user_uuid(websocket)

chat_request = ChatRequest.parse_raw(await websocket.receive_text())
request = await websocket.receive_text()
chat_request = ChatRequest.model_validate_json(request)

chain, params = await build_chain(chat_request, user_uuid, llm, vector_store)

Expand Down
59 changes: 53 additions & 6 deletions core_api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@
from elasticsearch import Elasticsearch
from fastapi.testclient import TestClient
from jose import jwt
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.llms.fake import FakeListLLM
from langchain_elasticsearch import ApproxRetrievalStrategy, ElasticsearchStore

from core_api.src.app import app as application
from core_api.src.app import env
from redbox.model_db import MODEL_PATH
from redbox.models import Chunk, File
from redbox.storage import ElasticsearchStorageHandler

Expand Down Expand Up @@ -72,33 +75,53 @@ def file(s3_client, file_pdf_path: Path, alice) -> File:


@pytest.fixture()
def stored_file_1(elasticsearch_storage_handler, file) -> File:
    """Persist the `file` record to Elasticsearch and return it.

    Refreshes the index so the written item is immediately visible to reads.
    """
    elasticsearch_storage_handler.write_item(file)
    elasticsearch_storage_handler.refresh()
    return file


@pytest.fixture()
def stored_file_chunks(stored_file_1) -> list[Chunk]:
    """Five chunks belonging to `stored_file_1`.

    The chunks are built in memory only — they are not written to
    Elasticsearch by this fixture (see `chunked_file` for that).
    """
    return [
        Chunk(
            text="hello",
            index=chunk_index,
            parent_file_uuid=stored_file_1.uuid,
            creator_user_uuid=stored_file_1.creator_user_uuid,
            embedding=[1] * 768,
            metadata={"parent_doc_uuid": str(stored_file_1.uuid)},
        )
        for chunk_index in range(5)
    ]


@pytest.fixture()
def other_stored_file_chunks(stored_file_1) -> list[Chunk]:
    """Five chunks for a *different* file owned by the same user.

    A fresh UUID stands in for the other file's id, while the creator is
    shared with `stored_file_1`, so these chunks simulate a second document
    belonging to the same user.
    """
    other_file_uuid = uuid4()
    return [
        Chunk(
            text="hello",
            index=chunk_index,
            parent_file_uuid=other_file_uuid,
            creator_user_uuid=stored_file_1.creator_user_uuid,
            embedding=[1] * 768,
            metadata={"parent_doc_uuid": str(other_file_uuid)},
        )
        for chunk_index in range(5)
    ]


@pytest.fixture()
def chunked_file(elasticsearch_storage_handler, stored_file_chunks, stored_file_1) -> File:
    """Index `stored_file_1`'s chunks into Elasticsearch and return the file."""
    for chunk in stored_file_chunks:
        elasticsearch_storage_handler.write_item(chunk)
    # Refresh so the newly indexed chunks are searchable straight away.
    elasticsearch_storage_handler.refresh()
    return stored_file_1


@pytest.fixture()
Expand All @@ -109,3 +132,27 @@ def file_pdf_path() -> Path:
@pytest.fixture()
def mock_llm():
    """LLM double that always answers with the sentinel string '<<TESTING>>'."""
    canned_responses = ["<<TESTING>>"] * 128
    return FakeListLLM(responses=canned_responses)


@pytest.fixture()
def embedding_model() -> SentenceTransformerEmbeddings:
    """Sentence-transformer embeddings configured from the app environment.

    The model name comes from `env.embedding_model`; `MODEL_PATH` is used as
    the local cache folder for the model files.
    """
    return SentenceTransformerEmbeddings(
        model_name=env.embedding_model,
        cache_folder=MODEL_PATH,
    )


@pytest.fixture()
def vector_store(es_client, embedding_model):
    """ElasticsearchStore over the chunk index.

    Hybrid retrieval is enabled only when the Elastic subscription level
    supports it ("platinum"/"enterprise"); "basic" gets approximate kNN only.

    Raises:
        ValueError: for an unrecognised subscription level.
    """
    subscription_level = env.elastic.subscription_level
    if subscription_level == "basic":
        hybrid = False
    elif subscription_level in ("platinum", "enterprise"):
        hybrid = True
    else:
        message = f"Unknown Elastic subscription level {env.elastic.subscription_level}"
        raise ValueError(message)

    return ElasticsearchStore(
        es_connection=es_client,
        index_name=f"{env.elastic_root_index}-chunk",
        embedding=embedding_model,
        strategy=ApproxRetrievalStrategy(hybrid=hybrid),
        vector_query_field="embedding",
    )
6 changes: 5 additions & 1 deletion core_api/tests/routes/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def test_rag_chat_streamed(app_client, headers):
{"text": "What can I do for you?", "role": "system"},
{"text": "Who put the ram in the rama lama ding dong?", "role": "user"},
]
selected_files = [
{"uuid": "9aa1aa15-dde0-471f-ab27-fd410612025b"},
{"uuid": "219c2e94-9877-4f83-ad6a-a59426f90171"},
]
events: Iterable[StreamEvent] = [
StreamEvent(
event="on_chat_model_stream",
Expand All @@ -97,7 +101,7 @@ def test_rag_chat_streamed(app_client, headers):
app_client.websocket_connect("/chat/rag", headers=headers) as websocket,
):
# When
websocket.send_text(json.dumps({"message_history": message_history}))
websocket.send_text(json.dumps({"message_history": message_history, "selected_files": selected_files}))

all_text, docs = [], []
while True:
Expand Down
8 changes: 4 additions & 4 deletions core_api/tests/routes/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def test_post_file_upload(s3_client, app_client, file_pdf_path: Path, head
assert response.status_code == HTTPStatus.CREATED


def test_list_files(app_client, stored_file, headers):
def test_list_files(app_client, stored_file_1, headers):
"""
Given a previously saved file
When I GET all files from /file
Expand All @@ -51,16 +51,16 @@ def test_list_files(app_client, stored_file, headers):
file_list = json.loads(response.content.decode("utf-8"))
assert len(file_list) > 0

assert str(stored_file.uuid) in [file["uuid"] for file in file_list]
assert str(stored_file_1.uuid) in [file["uuid"] for file in file_list]


def test_get_file(app_client, stored_file, headers):
def test_get_file(app_client, stored_file_1, headers):
"""
Given a previously saved file
When I GET it from /file/uuid
I Expect to receive it
"""
response = app_client.get(f"/file/{stored_file.uuid}", headers=headers)
response = app_client.get(f"/file/{stored_file_1.uuid}", headers=headers)
assert response.status_code == HTTPStatus.OK


Expand Down
4 changes: 2 additions & 2 deletions core_api/tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
({"Authorization": "Bearer " + jwt.encode({"user_uuid": str(uuid4())}, key="super-secure-private-key")}, 404),
],
)
def test_get_file_fails_auth(app_client, stored_file, malformed_headers, status_code):
def test_get_file_fails_auth(app_client, stored_file_1, malformed_headers, status_code):
"""
Given a previously saved file
When I GET it from /file/uuid with a missing/broken/correct header
I Expect get an appropriate status_code
"""
response = app_client.get(f"/file/{stored_file.uuid}", headers=malformed_headers)
response = app_client.get(f"/file/{stored_file_1.uuid}", headers=malformed_headers)
assert response.status_code == status_code
23 changes: 23 additions & 0 deletions core_api/tests/test_runnables.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import re

import pytest

from core_api.src.format import format_chunks, get_file_chunked_to_tokens
from core_api.src.routes.chat import build_retrieval_chain
from core_api.src.runnables import make_stuff_document_runnable
from redbox.models.chat import ChatRequest


def test_format_chunks(stored_file_chunks):
Expand Down Expand Up @@ -51,3 +55,22 @@ def test_make_stuff_document_runnable(mock_llm, stored_file_chunks):
)

assert response == "<<TESTING>>"


@pytest.mark.asyncio()
async def test_build_retrieval_chain(
    mock_llm, chunked_file, other_stored_file_chunks, elasticsearch_storage_handler, vector_store
):
    """build_retrieval_chain should retrieve only chunks of the selected file.

    Previously `other_stored_file_chunks` was injected but never indexed, so
    the filter had nothing to exclude and the assertion was vacuous (and
    `all()` over an empty result list is trivially True).
    """
    # Index the other file's chunks so the selected-files filter has
    # something real to exclude.
    for chunk in other_stored_file_chunks:
        elasticsearch_storage_handler.write_item(chunk)
    elasticsearch_storage_handler.refresh()

    request = {
        "message_history": [
            {"text": "hello", "role": "user"},
        ],
        "selected_files": [{"uuid": chunked_file.uuid}],
    }

    docs_with_sources_chain, params = await build_retrieval_chain(
        chat_request=ChatRequest(**request),
        user_uuid=chunked_file.creator_user_uuid,
        llm=mock_llm,
        vector_store=vector_store,
    )

    # Guard against a vacuously-true `all()` on an empty retrieval result.
    assert params["input_documents"], "expected at least one retrieved document"
    assert all(doc.metadata["parent_doc_uuid"] == str(chunked_file.uuid) for doc in params["input_documents"])
47 changes: 43 additions & 4 deletions django_app/frontend/js/chats/streaming.js
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ class ChatMessage extends HTMLElement {
/**
* Streams an LLM response
* @param {string} message
* @param {string[]} selectedDocuments
* @param {string | undefined} sessionId
* @param {string} endPoint
* @param {HTMLElement} chatControllerRef
*/
stream = (message, sessionId, endPoint, chatControllerRef) => {
stream = (message, selectedDocuments, sessionId, endPoint, chatControllerRef) => {

let responseContainer = /** @type MarkdownConverter */(this.querySelector('markdown-converter'));
let sourcesContainer = /** @type SourcesList */(this.querySelector('sources-list'));
Expand All @@ -77,7 +78,7 @@ class ChatMessage extends HTMLElement {
let sources = [];

webSocket.onopen = (event) => {
webSocket.send(JSON.stringify({message: message, sessionId: sessionId}));
webSocket.send(JSON.stringify({message: message, sessionId: sessionId, selectedFiles: selectedDocuments}));
this.dataset.status = "streaming";
};

Expand Down Expand Up @@ -123,11 +124,12 @@ class ChatController extends HTMLElement {

connectedCallback() {

const messageForm = this.querySelector('.js-message-input');
const messageForm = this.closest('form');
const textArea = /** @type {HTMLInputElement | null} */ (this.querySelector('.js-user-text'));
const messageContainer = this.querySelector('.js-message-container');
const insertPosition = this.querySelector('.js-response-feedback');
const feedbackButtons = /** @type {HTMLElement | null} */ (this.querySelector('feedback-buttons'));
let selectedDocuments = [];

messageForm?.addEventListener('submit', (evt) => {

Expand All @@ -146,7 +148,7 @@ class ChatController extends HTMLElement {
aiMessage.setAttribute('data-role', 'ai');
aiMessage.setAttribute('tabindex', '-1');
messageContainer?.insertBefore(aiMessage, insertPosition);
aiMessage.stream(userText, this.dataset.sessionId, this.dataset.streamUrl || '', this);
aiMessage.stream(userText, selectedDocuments, this.dataset.sessionId, this.dataset.streamUrl || '', this);
aiMessage.focus();

// reset UI
Expand All @@ -157,7 +159,44 @@ class ChatController extends HTMLElement {

});

document.body.addEventListener('selected-docs-change', (evt) => {
selectedDocuments = /** @type{CustomEvent} */(evt).detail;
});

}

}
customElements.define('chat-controller', ChatController);




class DocumentSelector extends HTMLElement {

    connectedCallback() {

        // NOTE: callback params are named "checkbox" (not "document") to avoid
        // shadowing the global document object.
        const checkboxes = /** @type {NodeListOf<HTMLInputElement>} */ (
            this.querySelectorAll('input[type="checkbox"]')
        );

        // Broadcast the values of all currently ticked checkboxes as a
        // 'selected-docs-change' CustomEvent on document.body.
        const announceSelection = () => {
            const selectedDocuments = [];
            checkboxes.forEach((checkbox) => {
                if (checkbox.checked) {
                    selectedDocuments.push(checkbox.value);
                }
            });
            const evt = new CustomEvent('selected-docs-change', { detail: selectedDocuments });
            document.body.dispatchEvent(evt);
        };

        // announce the initial state on page load
        announceSelection();

        // re-announce whenever any checkbox changes
        checkboxes.forEach((checkbox) => {
            checkbox.addEventListener('change', announceSelection);
        });

    }

}
customElements.define('document-selector', DocumentSelector);
5 changes: 5 additions & 0 deletions django_app/redbox_app/redbox_core/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ class FileResource(admin.ModelAdmin):
list_display = ["original_file_name", "user", "status"]


class ChatMessageResource(admin.ModelAdmin):
list_display = ["chat_history", "text", "role"]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
list_display = ["chat_history", "text", "role"]
list_display = ["chat_history", "text", "role", "created_at"]



class ChatMessageInline(admin.StackedInline):
model = models.ChatMessage
extra = 1
Expand Down Expand Up @@ -49,3 +53,4 @@ def export_as_csv(self, request, queryset): # noqa:ARG002
admin.site.register(models.User, UserResource)
admin.site.register(models.File, FileResource)
admin.site.register(models.ChatHistory, ChatHistoryAdmin)
admin.site.register(models.ChatMessage, ChatMessageResource)
Loading
Loading