Feature/semantic routing #540

Status: Merged (25 commits, Jun 10, 2024)

Commits:
6cabf3a  [Semantic routing] AS | Update dependencies to include semantic routing (andy-symonds, Jun 5, 2024)
39d3f1a  [Semantic-router] | AS | Adding semantic-router to pyproject.toml and… (andy-symonds, Jun 6, 2024)
3c996dc  Merge branch 'main' into feature/semantic-routing (andy-symonds, Jun 6, 2024)
5b0f1bc  patch (andy-symonds, Jun 6, 2024)
6cca775  [semantic-router] | AS | Notebook with routes for info, gratitude, su… (andy-symonds, Jun 6, 2024)
ee1b83d  [semantic-router] | AS | Added an elasticsearch health check conditio… (andy-symonds, Jun 6, 2024)
3dd320a  [semantic-router] | AS | Updated make run command to include --wait a… (andy-symonds, Jun 6, 2024)
a1113df  now using huggingface encoder (Jun 7, 2024)
2a396da  revert makefile changes (Jun 7, 2024)
ea7a0c9  e2e test extended (Jun 7, 2024)
5a9e2d8  [semantic-routing] | AS | Addedd code for non-streaming endpoint rout… (andy-symonds, Jun 7, 2024)
f547a5a  [semantic-routing] | AS | Updated tests to reflect vannila endpoint n… (andy-symonds, Jun 7, 2024)
02675bf  [semantic-routing] | AS | Update test_chat.py to reflect vanilla chat… (andy-symonds, Jun 7, 2024)
fce683c  [semantic-routing] | AS | fixed formating (andy-symonds, Jun 7, 2024)
7d4ae58  [semantic-routing] | AS | Refactored build_chain logic and added coac… (andy-symonds, Jun 7, 2024)
6775819  [semantic-routing] | AS | Rename variables for route responses and re… (andy-symonds, Jun 7, 2024)
b6acc72  [semantic-routing] | AS | Fixed formating (andy-symonds, Jun 7, 2024)
9ce2510  [semantic-routing] | AS | Refactor semantic routing setup into semant… (andy-symonds, Jun 7, 2024)
017590a  [semantic-routing] | AS | Add test for non-streaming chat/rag endpoint (andy-symonds, Jun 7, 2024)
da4a406  [semantic-routing] | AS | Fixed formating (andy-symonds, Jun 7, 2024)
30766ad  formatting changes (Jun 8, 2024)
92a79a9  Merge branch 'main' into feature/semantic-routing (Jun 8, 2024)
41110d2  tests passing again (Jun 8, 2024)
1e51560  linting (Jun 8, 2024)
1c3a805  Merge branch 'main' into feature/semantic-routing (Jun 10, 2024)
core_api/src/app.py (4 changes: 2 additions & 2 deletions)

@@ -7,7 +7,7 @@
 from fastapi import Depends, FastAPI, Response
 from fastapi.responses import RedirectResponse

-from core_api.src import services
+from core_api.src.dependencies import get_elasticsearch_client
 from core_api.src.routes.chat import chat_app
 from core_api.src.routes.file import file_app
 from redbox.models import Settings, StatusResponse
@@ -50,7 +50,7 @@ def root():


 @app.get("/health", status_code=HTTPStatus.OK, tags=["health"])
-def health(response: Response, es: Annotated[Elasticsearch, Depends(services.elasticsearch_client)]) -> StatusResponse:
+def health(response: Response, es: Annotated[Elasticsearch, Depends(get_elasticsearch_client)]) -> StatusResponse:
     """Returns the health of the API

     Returns:
core_api/src/services.py → core_api/src/dependencies.py (16 changes: 8 additions & 8 deletions)

@@ -16,24 +16,24 @@
 log = logging.getLogger()


-async def env() -> Settings:
+async def get_env() -> Settings:
     return Settings()


-async def elasticsearch_client(env: Annotated[Settings, Depends(env)]) -> Elasticsearch:
+async def get_elasticsearch_client(env: Annotated[Settings, Depends(get_env)]) -> Elasticsearch:
     return env.elasticsearch_client()


-async def embedding_model(env: Annotated[Settings, Depends(env)]) -> Embeddings:
+async def get_embedding_model(env: Annotated[Settings, Depends(get_env)]) -> Embeddings:
     embedding_model = SentenceTransformerEmbeddings(model_name=env.embedding_model, cache_folder=MODEL_PATH)
     log.info("Loaded embedding model from environment: %s", env.embedding_model)
     return embedding_model


-async def vector_store(
-    env: Annotated[Settings, Depends(env)],
-    es: Annotated[Elasticsearch, Depends(elasticsearch_client)],
-    embedding_model: Annotated[Embeddings, Depends(embedding_model)],
+async def get_vector_store(
+    env: Annotated[Settings, Depends(get_env)],
+    es: Annotated[Elasticsearch, Depends(get_elasticsearch_client)],
+    embedding_model: Annotated[Embeddings, Depends(get_embedding_model)],
 ) -> ElasticsearchStore:
     if env.elastic.subscription_level == "basic":
         strategy = ApproxRetrievalStrategy(hybrid=False)
@@ -52,7 +52,7 @@ async def vector_store(
     )


-async def llm(env: Annotated[Settings, Depends(env)]) -> ChatLiteLLM:
+async def get_llm(env: Annotated[Settings, Depends(get_env)]) -> ChatLiteLLM:
     # Create the appropriate LLM, either openai, Azure, anthropic or bedrock
     if env.openai_api_key is not None:
         log.info("Creating OpenAI LLM Client")
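One practical payoff of the get_* renames is that each dependency is now a distinct, importable callable, which FastAPI can override in tests. A minimal sketch, not from this PR: the /llm-model endpoint and FakeLLM class are hypothetical, invented here to show the override mechanism.

# Minimal sketch (not from this PR): overriding the renamed get_llm dependency
# in a test app. The /llm-model endpoint and FakeLLM are hypothetical.
from typing import Annotated, Any

from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient

from core_api.src.dependencies import get_llm

app = FastAPI()


@app.get("/llm-model")
def llm_model(llm: Annotated[Any, Depends(get_llm)]) -> dict:
    # ChatLiteLLM carries the configured model name; the attribute is assumed here
    return {"model": getattr(llm, "model", "unknown")}


class FakeLLM:
    """Hypothetical stand-in so the test needs no real API keys."""

    model = "fake-model"


app.dependency_overrides[get_llm] = FakeLLM  # FastAPI calls FakeLLM() per request

client = TestClient(app)
assert client.get("/llm-model").json() == {"model": "fake-model"}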
core_api/src/routes/chat.py (88 changes: 68 additions & 20 deletions)

@@ -11,8 +11,14 @@
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_elasticsearch import ElasticsearchStore

-from core_api.src import services
 from core_api.src.auth import get_user_uuid, get_ws_user_uuid
+from core_api.src.dependencies import get_llm, get_vector_store
+from core_api.src.semantic_routes import (
+    ABILITY_RESPONSE,
+    COACH_RESPONSE,
+    INFO_RESPONSE,
+    route_layer,
+)
 from redbox.llm.prompts.chat import (
     CONDENSE_QUESTION_PROMPT,
     STUFF_DOCUMENT_PROMPT,
@@ -32,21 +38,30 @@
     version="0.1.0",
     openapi_tags=[
         {"name": "chat", "description": "Chat interactions with LLM and RAG backend"},
-        {"name": "embedding", "description": "Embedding interactions with SentenceTransformer"},
+        {
+            "name": "embedding",
+            "description": "Embedding interactions with SentenceTransformer",
+        },
         {"name": "llm", "description": "LLM information and parameters"},
     ],
     docs_url="/docs",
     redoc_url="/redoc",
     openapi_url="/openapi.json",
 )

+ROUTE_RESPONSES = {
+    "info": ChatPromptTemplate.from_template(INFO_RESPONSE),
+    "ability": ChatPromptTemplate.from_template(ABILITY_RESPONSE),
+    "coach": ChatPromptTemplate.from_template(COACH_RESPONSE),
+    "gratitude": ChatPromptTemplate.from_template("You're welcome!"),
+    "summarisation": ChatPromptTemplate.from_template("You are asking for summarisation - route not yet implemented"),
+    "extract": ChatPromptTemplate.from_template("You are asking to extract some information - route not yet implemented"),
+}
+

-@chat_app.post("/vanilla", tags=["chat"], response_model=ChatResponse)
-def simple_chat(
+async def build_vanilla_chain(
     chat_request: ChatRequest,
-    _user_uuid: Annotated[UUID, Depends(get_user_uuid)],
-    llm: Annotated[ChatLiteLLM, Depends(services.llm)],
-) -> ChatResponse:
+) -> ChatPromptTemplate:
     """Get a LLM response to a question history"""

     if len(chat_request.message_history) < 2:  # noqa: PLR2004
@@ -67,17 +82,14 @@ def simple_chat(
             detail="The final entry in the chat history should be a user question",
         )

-    chat_prompt = ChatPromptTemplate.from_messages((msg.role, msg.text) for msg in chat_request.message_history)
-    # Convert to LangChain style messages
-    messages = chat_prompt.format_messages()
-
-    response = llm(messages)
-    chat_response: ChatResponse = ChatResponse(output_text=response.content)
-    return chat_response
+    return ChatPromptTemplate.from_messages((msg.role, msg.text) for msg in chat_request.message_history)


 async def build_retrieval_chain(
-    chat_request: ChatRequest, user_uuid: UUID, llm: ChatLiteLLM, vector_store: ElasticsearchStore
+    chat_request: ChatRequest,
+    user_uuid: UUID,
+    llm: ChatLiteLLM,
+    vector_store: ElasticsearchStore,
 ):
     question = chat_request.message_history[-1].text
     previous_history = list(chat_request.message_history[:-1])
@@ -108,12 +120,30 @@ async def build_retrieval_chain(
     return docs_with_sources_chain, params


+async def build_chain(
+    chat_request: ChatRequest,
+    user_uuid: UUID,
+    llm: ChatLiteLLM,
+    vector_store: ElasticsearchStore,
+):
+    question = chat_request.message_history[-1].text
+    route = route_layer(question)
+
+    if route_response := ROUTE_RESPONSES.get(route.name):
+        return route_response, {}
+    # build_vanilla_chain could go here
+
+    # RAG chat
+    chain, params = await build_retrieval_chain(chat_request, user_uuid, llm, vector_store)
+    return chain, params
+
+
 @chat_app.post("/rag", tags=["chat"])
 async def rag_chat(
     chat_request: ChatRequest,
     user_uuid: Annotated[UUID, Depends(get_user_uuid)],
-    llm: Annotated[ChatLiteLLM, Depends(services.llm)],
-    vector_store: Annotated[ElasticsearchStore, Depends(services.vector_store)],
+    llm: Annotated[ChatLiteLLM, Depends(get_llm)],
+    vector_store: Annotated[ElasticsearchStore, Depends(get_vector_store)],
 ) -> ChatResponse:
     """Get a LLM response to a question history and file

@@ -123,7 +153,19 @@ async def rag_chat(
     Returns:
         StreamingResponse: a stream of the chain response
     """
+
+    question = chat_request.message_history[-1].text
+    route = route_layer(question)
+
+    if route_response := ROUTE_RESPONSES.get(route.name):
+        response = route_response.invoke({})
+        return ChatResponse(output_text=response.messages[0].content)
+
+    # build_vanilla_chain could go here
+
+    # RAG chat
     chain, params = await build_retrieval_chain(chat_request, user_uuid, llm, vector_store)
+
     result = chain(params)

     source_documents = [
@@ -140,16 +182,16 @@
 @chat_app.websocket("/rag")
 async def rag_chat_streamed(
     websocket: WebSocket,
-    llm: Annotated[ChatLiteLLM, Depends(services.llm)],
-    vector_store: Annotated[ElasticsearchStore, Depends(services.vector_store)],
+    llm: Annotated[ChatLiteLLM, Depends(get_llm)],
+    vector_store: Annotated[ElasticsearchStore, Depends(get_vector_store)],
 ):
     await websocket.accept()

     user_uuid = await get_ws_user_uuid(websocket)

     chat_request = ChatRequest.parse_raw(await websocket.receive_text())

-    chain, params = await build_retrieval_chain(chat_request, user_uuid, llm, vector_store)
+    chain, params = await build_chain(chat_request, user_uuid, llm, vector_store)

     async for event in chain.astream_events(params, version="v1"):
         kind = event["event"]
@@ -169,5 +211,11 @@ async def rag_chat_streamed(
             for document in event["data"]["chunk"].get("input_documents", [])
         ]
         await websocket.send_json({"resource_type": "documents", "data": source_documents})
+        elif kind == "on_prompt_stream":
+            try:
+                msg = event["data"]["chunk"].messages[0].content
+                await websocket.send_json({"resource_type": "text", "data": msg})
+            except (KeyError, AttributeError):
+                logging.exception("unknown message format %s", str(event["data"]["chunk"]))

     await websocket.close()
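The heart of this diff is the dispatch in build_chain: ask the semantic router first, return a pre-canned prompt on a hit, otherwise fall through to the retrieval chain. A self-contained sketch of that pattern follows; the stub route layer and canned responses are hypothetical stand-ins for the real ones in semantic_routes.py below.

# Self-contained sketch of build_chain's dispatch pattern. stub_route_layer and
# CANNED are hypothetical stand-ins for route_layer and ROUTE_RESPONSES.
from dataclasses import dataclass
from typing import Optional


@dataclass
class RouteChoice:
    name: Optional[str]  # semantic-router returns name=None when nothing matches


CANNED = {"gratitude": "You're welcome!"}


def stub_route_layer(question: str) -> RouteChoice:
    # Crude keyword stand-in for the embedding-based router
    return RouteChoice("gratitude") if "thank" in question.lower() else RouteChoice(None)


def dispatch(question: str) -> str:
    route = stub_route_layer(question)
    if canned := CANNED.get(route.name):
        return canned  # routed: answered without touching the LLM
    return "fall through to build_retrieval_chain"  # RAG path


assert dispatch("Thanks so much!") == "You're welcome!"
assert dispatch("Summarise my documents") == "fall through to build_retrieval_chain"

Because dict.get(None) simply returns None, the unmatched case falls through to retrieval without any special-casing, which is the same trick the diff relies on.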
core_api/src/semantic_routes.py (new file, 107 additions)

from semantic_router import Route
from semantic_router.encoders import HuggingFaceEncoder
from semantic_router.layer import RouteLayer

from redbox.model_db import MODEL_PATH

# === Pre-canned responses for non-LLM routes ===
INFO_RESPONSE = """
I am RedBox, an AI focused on helping UK Civil Servants, Political Advisors and
Ministers triage and summarise information from a wide variety of sources.
"""

ABILITY_RESPONSE = """
* I can help you search over selected documents and do Q&A on them.
* I can help you summarise selected documents.
* I can help you extract information from selected documents.
* I can return information in a variety of formats, such as bullet points.
"""

[Collaborator review comment on ABILITY_RESPONSE: our markdown formatter in the front end is interpreting long lines of text that contain a new-line continuation mark (\) as code, assigning them the <code> tag, which isn't what we want, so I have stripped them out.]

COACH_RESPONSE = """
I am sorry that didn't work.
You could try rephrasing your task, i.e. if you want to summarise a document please use the term,
"Summarise the selected document" or "extract all action items from the selected document."
If you want the results to be returned in a specific format, please specify the format in as much detail as possible.
"""

# === Set up the semantic router ===
info = Route(
    name="info",
    utterances=[
        "What is your name?",
        "Who are you?",
        "What is Redbox?",
    ],
)

ability = Route(
    name="ability",
    utterances=[
        "What can you do?",
        "How can you help me?",
        "What does Redbox do?",
        "What can Redbox do",
        "What don't you do",
        "Please help me",
        "Please help",
        "Help me!",
        "help",
    ],
)

coach = Route(
    name="coach",
    utterances=[
        "That is not the answer I wanted",
        "Rubbish",
        "No good",
        "That's not what I wanted",
        "How can I improve the results?",
    ],
)

gratitude = Route(
    name="gratitude",
    utterances=[
        "Thank you ever so much for your help!",
        "I'm really grateful for your assistance.",
        "Cheers for the detailed response!",
        "Thanks a lot, that was very informative.",
        "Nice one",
        "Thanks!",
    ],
)

summarisation = Route(
    name="summarisation",
    utterances=[
        "I'd like to summarise the documents I've uploaded.",
        "Can you help me with summarising these documents?",
        "Please summarise the documents with a focus on the impact on northern England",
        "Please summarise the contents of the uploaded files.",
        "I'd appreciate a summary of the documents I've just uploaded.",
        "Could you provide a summary of these uploaded documents?",
        "Summarise the documents with a focus on macro economic trends.",
    ],
)

extract = Route(
    name="extract",
    utterances=[
        "I'd like to find some information in the documents I've uploaded",
        "Can you help me identify details from these documents?",
        "Please give me all action items from this document",
        "Give me all the action items from these meeting notes",
        "Could you locate some key information in these uploaded documents?",
        "I need to obtain certain details from the documents I have uploaded, please",
        "Please extract all action items from this document",
        "Extract all the sentences with the word 'shall'",
    ],
)


routes = [info, ability, coach, gratitude, summarisation, extract]

encoder = HuggingFaceEncoder(name="sentence-transformers/paraphrase-albert-small-v2", cache_dir=MODEL_PATH)
route_layer = RouteLayer(encoder=encoder, routes=routes)
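A quick way to sanity-check the layer above, assuming the core_api package is importable and the encoder weights can be fetched into MODEL_PATH; per the semantic-router API, the returned choice's name is None when no route wins.

# Minimal sketch: exercising route_layer directly. The first call may download
# the paraphrase-albert-small-v2 weights into MODEL_PATH if not already cached.
from core_api.src.semantic_routes import route_layer

for question in ["Who are you?", "Thanks a lot!", "What is the GDP of France?"]:
    choice = route_layer(question)
    print(f"{question!r} -> {choice.name}")  # e.g. 'info', 'gratitude', None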