Skip to content

Commit

Permalink
REDBOX 337 chat file selection (#556)
Browse files Browse the repository at this point in the history
* Add ChatMessage many-to-many relation to File, and make them available through the chat view.

* Show selected files on chats page.

* Save selected file, and send to core, for streamed chat version.

* Use checkboxes for selecting files

* Move "Files to use" to sidebar

* Tests for saving & sending selected files.

* Setup document selection for streaming client-side

* Save selected file, and send to core, for streamed chat version.

* Enable chat streaming in all tests.

* add eval results visualisation and calculate uncertainty

* remove inline outputs

* Address Will's PR comments

* Remove streaming demo

* Save non-rag responses to DB - tactical fix.

* Revert tactical fix - we are doing it properly here.

* Receive selected file list in core API for streaming.

* Add selected files to e2e tests.

* Bug - ensure latest question is always the one answered.

* Unit tests not working but the core plumbing is there

* Post merge formatting.

* wip

* test now passing

* Add some logging for debug purposes.

* must != should

* remove change made in error

* migration test added

* no longer changing source_files

---------

Co-authored-by: Kevin Etchells <[email protected]>
Co-authored-by: esoutter <[email protected]>
Co-authored-by: Will Langdale <[email protected]>
Co-authored-by: George Burton <[email protected]>
  • Loading branch information
5 people authored Jun 12, 2024
1 parent cfe9459 commit a1d930c
Show file tree
Hide file tree
Showing 24 changed files with 617 additions and 189 deletions.
2 changes: 1 addition & 1 deletion core_api/src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# === Logging ===

logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()

env = Settings()
Expand Down
34 changes: 19 additions & 15 deletions core_api/src/routes/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@

# === Logging ===

logging.basicConfig(level=logging.INFO)
log = logging.getLogger()
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()


chat_app = FastAPI(
Expand Down Expand Up @@ -86,10 +86,7 @@ async def build_vanilla_chain(


async def build_retrieval_chain(
chat_request: ChatRequest,
user_uuid: UUID,
llm: ChatLiteLLM,
vector_store: ElasticsearchStore,
chat_request: ChatRequest, user_uuid: UUID, llm: ChatLiteLLM, vector_store: ElasticsearchStore
):
question = chat_request.message_history[-1].text
previous_history = list(chat_request.message_history[:-1])
Expand All @@ -109,7 +106,14 @@ async def build_retrieval_chain(

standalone_question = condense_question_chain({"question": question, "chat_history": previous_history})["text"]

search_kwargs = {"filter": {"term": {"creator_user_uuid.keyword": str(user_uuid)}}}
search_kwargs = {"filter": {"bool": {"must": [{"term": {"creator_user_uuid.keyword": str(user_uuid)}}]}}}

if chat_request.selected_files:
logging.info("chat_request.selected_files: %s", str(chat_request.selected_files))
search_kwargs["filter"]["bool"]["must"] = [
{"term": {"parent_file_uuid.keyword": str(file.uuid)}} for file in chat_request.selected_files
]

docs = vector_store.as_retriever(search_kwargs=search_kwargs).get_relevant_documents(standalone_question)

params = {
Expand All @@ -120,12 +124,7 @@ async def build_retrieval_chain(
return docs_with_sources_chain, params


async def build_chain(
chat_request: ChatRequest,
user_uuid: UUID,
llm: ChatLiteLLM,
vector_store: ElasticsearchStore,
):
async def build_chain(chat_request: ChatRequest, user_uuid: UUID, llm: ChatLiteLLM, vector_store: ElasticsearchStore):
question = chat_request.message_history[-1].text
route = route_layer(question)

Expand Down Expand Up @@ -154,6 +153,8 @@ async def rag_chat(
StreamingResponse: a stream of the chain response
"""

logging.info("chat_request: %s", chat_request)

question = chat_request.message_history[-1].text
route = route_layer(question)

Expand Down Expand Up @@ -189,7 +190,10 @@ async def rag_chat_streamed(

user_uuid = await get_ws_user_uuid(websocket)

chat_request = ChatRequest.parse_raw(await websocket.receive_text())
request = await websocket.receive_text()
logger.debug("raw request from django-app: %s", request)
chat_request = ChatRequest.model_validate_json(request)
logger.debug("chat request from django-app: %s", chat_request)

chain, params = await build_chain(chat_request, user_uuid, llm, vector_store)

Expand All @@ -216,6 +220,6 @@ async def rag_chat_streamed(
msg = event["data"]["chunk"].messages[0].content
await websocket.send_json({"resource_type": "text", "data": msg})
except (KeyError, AttributeError):
logging.exception("unknown message format %s", str(event["data"]["chunk"]))
logger.exception("unknown message format %s", str(event["data"]["chunk"]))

await websocket.close()
59 changes: 53 additions & 6 deletions core_api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@
from elasticsearch import Elasticsearch
from fastapi.testclient import TestClient
from jose import jwt
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.llms.fake import FakeListLLM
from langchain_elasticsearch import ApproxRetrievalStrategy, ElasticsearchStore

from core_api.src.app import app as application
from core_api.src.app import env
from redbox.model_db import MODEL_PATH
from redbox.models import Chunk, File
from redbox.storage import ElasticsearchStorageHandler

Expand Down Expand Up @@ -72,33 +75,53 @@ def file(s3_client, file_pdf_path: Path, alice) -> File:


@pytest.fixture()
def stored_file(elasticsearch_storage_handler, file) -> File:
def stored_file_1(elasticsearch_storage_handler, file) -> File:
elasticsearch_storage_handler.write_item(file)
elasticsearch_storage_handler.refresh()
return file


@pytest.fixture()
def stored_file_chunks(stored_file) -> list[Chunk]:
def stored_file_chunks(stored_file_1) -> list[Chunk]:
chunks: list[Chunk] = []
for i in range(5):
chunks.append(
Chunk(
text="hello",
index=i,
parent_file_uuid=stored_file.uuid,
creator_user_uuid=stored_file.creator_user_uuid,
parent_file_uuid=stored_file_1.uuid,
creator_user_uuid=stored_file_1.creator_user_uuid,
embedding=[1] * 768,
metadata={"parent_doc_uuid": str(stored_file_1.uuid)},
)
)
return chunks


@pytest.fixture()
def chunked_file(elasticsearch_storage_handler, stored_file_chunks, stored_file) -> File:
def other_stored_file_chunks(stored_file_1) -> list[Chunk]:
new_uuid = uuid4()
chunks: list[Chunk] = []
for i in range(5):
chunks.append(
Chunk(
text="hello",
index=i,
parent_file_uuid=new_uuid,
creator_user_uuid=stored_file_1.creator_user_uuid,
embedding=[1] * 768,
metadata={"parent_doc_uuid": str(new_uuid)},
)
)
return chunks


@pytest.fixture()
def chunked_file(elasticsearch_storage_handler, stored_file_chunks, stored_file_1) -> File:
for chunk in stored_file_chunks:
elasticsearch_storage_handler.write_item(chunk)
elasticsearch_storage_handler.refresh()
return stored_file
return stored_file_1


@pytest.fixture()
Expand All @@ -109,3 +132,27 @@ def file_pdf_path() -> Path:
@pytest.fixture()
def mock_llm():
return FakeListLLM(responses=["<<TESTING>>"] * 128)


@pytest.fixture()
def embedding_model() -> SentenceTransformerEmbeddings:
    """Sentence-transformer embedding model configured from app settings.

    Weights are cached under MODEL_PATH so repeated test runs avoid re-downloading.
    """
    model = SentenceTransformerEmbeddings(model_name=env.embedding_model, cache_folder=MODEL_PATH)
    return model


@pytest.fixture()
def vector_store(es_client, embedding_model):
    """ElasticsearchStore over the chunk index, with a strategy matching the cluster licence.

    Hybrid (BM25 + kNN) retrieval requires a platinum/enterprise Elastic
    subscription; the basic tier only supports approximate kNN, so we pick the
    strategy from the configured subscription level and fail fast on anything
    unrecognised.

    Raises:
        ValueError: if ``env.elastic.subscription_level`` is not one of
            "basic", "platinum" or "enterprise".
    """
    if env.elastic.subscription_level == "basic":
        strategy = ApproxRetrievalStrategy(hybrid=False)
    elif env.elastic.subscription_level in ("platinum", "enterprise"):
        strategy = ApproxRetrievalStrategy(hybrid=True)
    else:
        message = f"Unknown Elastic subscription level {env.elastic.subscription_level}"
        raise ValueError(message)

    return ElasticsearchStore(
        es_connection=es_client,
        index_name=f"{env.elastic_root_index}-chunk",
        embedding=embedding_model,
        strategy=strategy,
        vector_query_field="embedding",
    )
6 changes: 5 additions & 1 deletion core_api/tests/routes/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def test_rag_chat_streamed(app_client, headers):
{"text": "What can I do for you?", "role": "system"},
{"text": "Who put the ram in the rama lama ding dong?", "role": "user"},
]
selected_files = [
{"uuid": "9aa1aa15-dde0-471f-ab27-fd410612025b"},
{"uuid": "219c2e94-9877-4f83-ad6a-a59426f90171"},
]
events: Iterable[StreamEvent] = [
StreamEvent(
event="on_chat_model_stream",
Expand All @@ -97,7 +101,7 @@ def test_rag_chat_streamed(app_client, headers):
app_client.websocket_connect("/chat/rag", headers=headers) as websocket,
):
# When
websocket.send_text(json.dumps({"message_history": message_history}))
websocket.send_text(json.dumps({"message_history": message_history, "selected_files": selected_files}))

all_text, docs = [], []
while True:
Expand Down
8 changes: 4 additions & 4 deletions core_api/tests/routes/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def test_post_file_upload(s3_client, app_client, file_pdf_path: Path, head
assert response.status_code == HTTPStatus.CREATED


def test_list_files(app_client, stored_file, headers):
def test_list_files(app_client, stored_file_1, headers):
"""
Given a previously saved file
When I GET all files from /file
Expand All @@ -51,16 +51,16 @@ def test_list_files(app_client, stored_file, headers):
file_list = json.loads(response.content.decode("utf-8"))
assert len(file_list) > 0

assert str(stored_file.uuid) in [file["uuid"] for file in file_list]
assert str(stored_file_1.uuid) in [file["uuid"] for file in file_list]


def test_get_file(app_client, stored_file, headers):
def test_get_file(app_client, stored_file_1, headers):
"""
Given a previously saved file
When I GET it from /file/uuid
I Expect to receive it
"""
response = app_client.get(f"/file/{stored_file.uuid}", headers=headers)
response = app_client.get(f"/file/{stored_file_1.uuid}", headers=headers)
assert response.status_code == HTTPStatus.OK


Expand Down
4 changes: 2 additions & 2 deletions core_api/tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
({"Authorization": "Bearer " + jwt.encode({"user_uuid": str(uuid4())}, key="super-secure-private-key")}, 404),
],
)
def test_get_file_fails_auth(app_client, stored_file, malformed_headers, status_code):
def test_get_file_fails_auth(app_client, stored_file_1, malformed_headers, status_code):
"""
Given a previously saved file
When I GET it from /file/uuid with a missing/broken/correct header
I Expect get an appropriate status_code
"""
response = app_client.get(f"/file/{stored_file.uuid}", headers=malformed_headers)
response = app_client.get(f"/file/{stored_file_1.uuid}", headers=malformed_headers)
assert response.status_code == status_code
23 changes: 23 additions & 0 deletions core_api/tests/test_runnables.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import re

import pytest

from core_api.src.format import format_chunks, get_file_chunked_to_tokens
from core_api.src.routes.chat import build_retrieval_chain
from core_api.src.runnables import make_stuff_document_runnable
from redbox.models.chat import ChatRequest


def test_format_chunks(stored_file_chunks):
Expand Down Expand Up @@ -51,3 +55,22 @@ def test_make_stuff_document_runnable(mock_llm, stored_file_chunks):
)

assert response == "<<TESTING>>"


@pytest.mark.asyncio()
async def test_build_retrieval_chain(mock_llm, chunked_file, other_stored_file_chunks, vector_store):  # noqa: ARG001
    """Selected-file filtering: only chunks of the selected file reach the chain inputs.

    ``other_stored_file_chunks`` is requested (but unused) so that chunks from a
    different file exist in the index and the filter has something to exclude.
    """
    request = {
        "message_history": [
            {"text": "hello", "role": "user"},
        ],
        "selected_files": [{"uuid": chunked_file.uuid}],
    }

    _chain, params = await build_retrieval_chain(
        chat_request=ChatRequest(**request),
        user_uuid=chunked_file.creator_user_uuid,
        llm=mock_llm,
        vector_store=vector_store,
    )

    # Guard against a vacuous pass: all() over an empty list is True, which
    # would make the assertion below succeed even if retrieval returned nothing.
    assert params["input_documents"]
    assert all(doc.metadata["parent_doc_uuid"] == str(chunked_file.uuid) for doc in params["input_documents"])
47 changes: 43 additions & 4 deletions django_app/frontend/js/chats/streaming.js
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ class ChatMessage extends HTMLElement {
/**
* Streams an LLM response
* @param {string} message
* @param {string[]} selectedDocuments
* @param {string | undefined} sessionId
* @param {string} endPoint
* @param {HTMLElement} chatControllerRef
*/
stream = (message, sessionId, endPoint, chatControllerRef) => {
stream = (message, selectedDocuments, sessionId, endPoint, chatControllerRef) => {

let responseContainer = /** @type MarkdownConverter */(this.querySelector('markdown-converter'));
let sourcesContainer = /** @type SourcesList */(this.querySelector('sources-list'));
Expand All @@ -77,7 +78,7 @@ class ChatMessage extends HTMLElement {
let sources = [];

webSocket.onopen = (event) => {
webSocket.send(JSON.stringify({message: message, sessionId: sessionId}));
webSocket.send(JSON.stringify({message: message, sessionId: sessionId, selectedFiles: selectedDocuments}));
this.dataset.status = "streaming";
};

Expand Down Expand Up @@ -123,11 +124,12 @@ class ChatController extends HTMLElement {

connectedCallback() {

const messageForm = this.querySelector('.js-message-input');
const messageForm = this.closest('form');
const textArea = /** @type {HTMLInputElement | null} */ (this.querySelector('.js-user-text'));
const messageContainer = this.querySelector('.js-message-container');
const insertPosition = this.querySelector('.js-response-feedback');
const feedbackButtons = /** @type {HTMLElement | null} */ (this.querySelector('feedback-buttons'));
let selectedDocuments = [];

messageForm?.addEventListener('submit', (evt) => {

Expand All @@ -146,7 +148,7 @@ class ChatController extends HTMLElement {
aiMessage.setAttribute('data-role', 'ai');
aiMessage.setAttribute('tabindex', '-1');
messageContainer?.insertBefore(aiMessage, insertPosition);
aiMessage.stream(userText, this.dataset.sessionId, this.dataset.streamUrl || '', this);
aiMessage.stream(userText, selectedDocuments, this.dataset.sessionId, this.dataset.streamUrl || '', this);
aiMessage.focus();

// reset UI
Expand All @@ -157,7 +159,44 @@ class ChatController extends HTMLElement {

});

document.body.addEventListener('selected-docs-change', (evt) => {
selectedDocuments = /** @type{CustomEvent} */(evt).detail;
});

}

}
customElements.define('chat-controller', ChatController);




/**
 * Tracks the file-selection checkboxes in the sidebar and broadcasts the
 * currently selected document values as a `selected-docs-change` CustomEvent
 * on document.body (consumed by the chat controller).
 */
class DocumentSelector extends HTMLElement {

    connectedCallback() {

        const checkboxes = /** @type {NodeListOf<HTMLInputElement>} */ (this.querySelectorAll('input[type="checkbox"]'));

        const announceSelection = () => {
            const selectedDocuments = [];
            // NOTE: the callback parameter must not be named `document` — the
            // original code shadowed the global `document`, which is an easy
            // source of bugs if the callback ever touches the DOM.
            checkboxes.forEach((checkbox) => {
                if (checkbox.checked) {
                    selectedDocuments.push(checkbox.value);
                }
            });
            const evt = new CustomEvent('selected-docs-change', { detail: selectedDocuments });
            document.body.dispatchEvent(evt);
        };

        // broadcast the initial selection on page load
        announceSelection();

        // re-broadcast whenever any checkbox changes
        checkboxes.forEach((checkbox) => {
            checkbox.addEventListener('change', announceSelection);
        });

    }

}
customElements.define('document-selector', DocumentSelector);
Loading

0 comments on commit a1d930c

Please sign in to comment.