
Commit

Merge pull request #584 from i-dot-ai/rest-endpoint-stuff-summarisation-rebase-2

Update REST `rag_chat` endpoint to include the stuff summarisation already in streaming endpoint
andy-symonds authored Jun 14, 2024
2 parents d6f7219 + c4279d0 commit 81c2678
Showing 2 changed files with 45 additions and 24 deletions.
50 changes: 27 additions & 23 deletions core_api/src/routes/chat.py
@@ -88,31 +88,35 @@ async def rag_chat(
     user_uuid: Annotated[UUID, Depends(get_user_uuid)],
     llm: Annotated[ChatLiteLLM, Depends(get_llm)],
     vector_store: Annotated[ElasticsearchStore, Depends(get_vector_store)],
+    storage_handler: Annotated[ElasticsearchStorageHandler, Depends(get_storage_handler)],
 ) -> ChatResponse:
-    """REST endpoint. Get a LLM response to a question history and file."""
-    question = chat_request.message_history[-1].text
-    route = route_layer(question)
-    # TODO (@wpfl-dbt): will need updating - focused on streaming endpoint # noqa: TD003
-    if route_response := ROUTE_RESPONSES.get(route.name):
-        response = route_response.invoke({})
-        return ChatResponse(output_text=response.messages[0].content)
-
-    # build_vanilla_chain could go here
-
-    # RAG chat
-    chain, params = await build_retrieval_chain(chat_request, user_uuid, llm, vector_store)
+    """REST endpoint.
+    Chose the correct route based on the question.
+    Get a response to a question history and file.
+    """
+    chain, params = await semantic_router_to_chain(chat_request, user_uuid, llm, vector_store, storage_handler)

-    result = chain(params)
-
-    source_documents = [
-        SourceDocument(
-            page_content=langchain_document.page_content,
-            file_uuid=langchain_document.metadata.get("parent_doc_uuid"),
-            page_numbers=langchain_document.metadata.get("page_numbers"),
-        )
-        for langchain_document in result.get("input_documents", [])
-    ]
-    return ChatResponse(output_text=result["output_text"], source_documents=source_documents)
+    result = chain.invoke(params)
+    if isinstance(result, dict):
+        source_documents = [
+            SourceDocument(
+                page_content=document.page_content,
+                file_uuid=document.metadata.get("parent_doc_uuid"),
+                page_numbers=document.metadata.get("page_numbers"),
+            )
+            for document in result.get("input_documents", [])
+        ]
+        return ChatResponse(output_text=result["output_text"], source_documents=source_documents)
+    # stuff_summarisation route
+    elif isinstance(result, str):
+        return ChatResponse(output_text=result)
+    # hard-coded routes
+    else:
+        try:
+            msg = result.messages[0].content
+            return ChatResponse(output_text=msg)
+        except (KeyError, AttributeError):
+            logging.exception("unknown message format %s", str(result))


 @chat_app.websocket("/rag")
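For context, a minimal client-side sketch of calling the updated REST endpoint. The /chat/rag path and the request/response fields are taken from the diff above; the host, port and bearer token are assumptions for local testing, not part of this change.

import requests

payload = {
    "message_history": [
        {"role": "user", "text": "Please summarise the contents of the uploaded files."}
    ]
}
response = requests.post(
    "http://localhost:5002/chat/rag",  # assumed local core-api address
    json=payload,
    headers={"Authorization": "Bearer <token>"},  # placeholder token
    timeout=60,
)
response.raise_for_status()
body = response.json()
print(body["output_text"])
# source_documents is only populated by the retrieval (dict-result) branch
for doc in body.get("source_documents", []):
    print(doc.get("file_uuid"), doc.get("page_numbers"))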
19 changes: 18 additions & 1 deletion core_api/tests/routes/test_chat.py
@@ -55,7 +55,7 @@ def mock_get_chain():
     # assert response.status_code == status_code


-def test_rag_chat(app_client, headers):
+def test_rag_chat_rest_gratitude(app_client, headers):
     response = app_client.post(
         "/chat/rag",
         json={"message_history": [{"role": "user", "text": "Thank you"}]},
@@ -65,6 +65,23 @@ def test_rag_chat(app_client, headers):
     assert response_dict["output_text"] == "You're welcome!"


+def test_rag_chat_rest_stuff_summarise(app_client, headers):
+    response = app_client.post(
+        "/chat/rag",
+        json={
+            "message_history": [
+                {
+                    "role": "user",
+                    "text": "Please summarise the contents of the uploaded files.",
+                }
+            ]
+        },
+        headers=headers,
+    )
+    response_dict = response.json()
+    assert isinstance(response_dict["output_text"], str)
+
+
 def test_rag_chat_streamed(app_client, headers):
     # Given
     message_history = [
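A possible follow-up test, sketched in the same style as the tests above: it checks that a retrieval-type question returns the source_documents list produced by the dict branch of the endpoint. The test name and question text are hypothetical, and it assumes the same app_client and headers fixtures used in the existing tests.

def test_rag_chat_rest_retrieval_sources(app_client, headers):
    response = app_client.post(
        "/chat/rag",
        json={
            "message_history": [
                {"role": "user", "text": "What do the uploaded files say about budgets?"}
            ]
        },
        headers=headers,
    )
    response_dict = response.json()
    # output_text should always be present; source_documents may be empty but should be a list
    assert isinstance(response_dict["output_text"], str)
    assert isinstance(response_dict.get("source_documents", []), list)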
