Turning RAG into a runnable function #554

Merged 8 commits on Jun 12, 2024
168 changes: 167 additions & 1 deletion core_api/src/runnables.py
@@ -1,11 +1,18 @@
from operator import itemgetter
from typing import Any, TypedDict
from uuid import UUID

from elasticsearch import Elasticsearch
from langchain.schema import StrOutputParser
from langchain_community.chat_models import ChatLiteLLM
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_elasticsearch import ElasticsearchRetriever

from core_api.src.format import format_chunks
from redbox.models import Chunk


def make_stuff_document_runnable(
@@ -15,6 +22,8 @@ def make_stuff_document_runnable(
"""Takes a system prompt and LLM returns a stuff document runnable.

Runnable takes input of a dict keyed to question, messages and documents.

Runnable returns a string.
"""
chat_history = [
("system", system_prompt),
@@ -32,3 +41,160 @@ def make_stuff_document_runnable(
| llm
| StrOutputParser()
)
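A minimal usage sketch, assuming an existing ChatLiteLLM instance named llm and chunk text already rendered to a string (whether "documents" should be pre-formatted text or raw chunks depends on the elided chain body, so the format_chunks call here is an assumption):

    # Hypothetical caller code, not part of this diff.
    chain = make_stuff_document_runnable(system_prompt="You are a helpful assistant.", llm=llm)
    answer = chain.invoke(
        {
            "question": "What do these documents cover?",
            "messages": [("user", "Hi"), ("ai", "Hello, how can I help?")],
            "documents": format_chunks(chunks),  # assumed: list[Chunk] rendered to prompt text
        }
    )
    # answer is a plain string, thanks to the trailing StrOutputParser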


class ESQuery(TypedDict):
question: str
file_uuids: list[UUID]
user_uuid: UUID


def make_es_retriever(
es: Elasticsearch, embedding_model: SentenceTransformerEmbeddings, chunk_index_name: str
) -> ElasticsearchRetriever:
"""Creates an Elasticsearch retriever runnable.

Runnable takes input of a dict keyed to question, file_uuids and user_uuid.

Runnable returns a list of Chunks.
"""

def es_query(query: ESQuery) -> dict[str, Any]:
vector = embedding_model.embed_query(query["question"])

knn_filter = [{"term": {"creator_user_uuid.keyword": str(query["user_uuid"])}}]

if len(query["file_uuids"]) != 0:
knn_filter.append({"terms": {"parent_file_uuid.keyword": [str(uuid) for uuid in query["file_uuids"]]}})

return {
"size": 5,
"query": {
"bool": {
"must": [
{
"knn": {
"field": "embedding",
"query_vector": vector,
"num_candidates": 10,
"filter": knn_filter,
}
}
]
}
},
}

def chunk_mapper(hit: dict[str, Any]) -> Chunk:
return Chunk(**hit["_source"])

return ElasticsearchRetriever(
es_client=es, index_name=chunk_index_name, body_func=es_query, document_mapper=chunk_mapper
)
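A usage sketch for the retriever; es_client, embedding_model, env and user_uuid are assumptions standing in for objects the app already constructs (the index name follows the chunk_index_name fixture added below):

    # Hypothetical caller code, not part of this diff.
    retriever = make_es_retriever(
        es=es_client,
        embedding_model=embedding_model,
        chunk_index_name=f"{env.elastic_root_index}-chunk",
    )
    chunks = retriever.invoke(
        {
            "question": "tell me about energy",
            "file_uuids": [],  # empty list means no parent-file filter, only the user filter
            "user_uuid": user_uuid,
        }
    )
    # chunks is a list[Chunk], built by chunk_mapper from the top-5 kNN hits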


def make_rag_runnable(
system_prompt: str,
llm: ChatLiteLLM,
retriever: VectorStoreRetriever,
) -> Runnable:
"""Takes a system prompt, LLM and retriever and returns a basic RAG runnable.

Runnable takes input of a dict keyed to question, messages, file_uuids and user_uuid.

Runnable returns a dict keyed to response and sources.
"""
chat_history = [
("system", system_prompt),
("placeholder", "{messages}"),
("user", "Question: {question}. \n\n Documents: \n\n {documents} \n\n Answer: "),
]

prompt = ChatPromptTemplate.from_messages(chat_history)

return (
RunnablePassthrough()
| {
"question": itemgetter("question"),
"messages": itemgetter("messages"),
"documents": retriever | format_chunks,
"sources": retriever,
}
| {
"response": prompt | llm | StrOutputParser(),
"sources": itemgetter("sources"),
}
)
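A usage sketch combining the retriever with the RAG runnable; llm, retriever and the UUIDs are assumptions:

    # Hypothetical caller code, not part of this diff.
    rag_chain = make_rag_runnable(system_prompt="Your job is Q&A.", llm=llm, retriever=retriever)
    result = rag_chain.invoke(
        {
            "question": "Who are all these people?",
            "messages": [("user", "Lorem ipsum."), ("ai", "Consectetur adipiscing elit.")],
            "file_uuids": [file_uuid],
            "user_uuid": user_uuid,
        }
    )
    result["response"]  # str answer from the LLM
    result["sources"]   # list[Chunk] the retriever returned for the question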


def make_condense_question_runnable(llm: ChatLiteLLM) -> Runnable:
"""Takes a system prompt and LLM returns a condense question runnable.

Runnable takes input of a dict keyed to question and messages.

Runnable returns a string.
"""
condense_prompt = (
"Given the following conversation and a follow up question, "
"rephrase the follow up question to be a standalone question. \n"
"Chat history:"
)

chat_history = [
("system", condense_prompt),
("placeholder", "{messages}"),
("user", "Follow up question: {question}. \nStandalone question: "),
]

return (
{
"question": itemgetter("question"),
"messages": itemgetter("messages"),
}
| ChatPromptTemplate.from_messages(chat_history)
| llm
| StrOutputParser()
)
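A usage sketch; llm is an assumption:

    # Hypothetical caller code, not part of this diff.
    condense = make_condense_question_runnable(llm=llm)
    standalone = condense.invoke(
        {
            "question": "And what about the second one?",
            "messages": [
                ("user", "Summarise the two reports."),
                ("ai", "The first report covers staffing; the second covers budget."),
            ],
        }
    )
    # standalone is a single self-contained question string, ready to feed into retrieval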


def make_condense_rag_runnable(
system_prompt: str,
llm: ChatLiteLLM,
retriever: VectorStoreRetriever,
) -> Runnable:
"""Takes a system prompt, LLM and retriever and returns a condense RAG runnable.

This attempts to condense the chat history into a more salient question for the
LLM to answer, and doesn't pass the entire history on to RAG -- just the condensed
question.

Runnable takes input of a dict keyed to question, messages, file_uuids and user_uuid.

Runnable returns a dict keyed to response and sources.
"""
chat_history = [
("system", system_prompt),
("user", "Question: {question}. \n\n Documents: \n\n {documents} \n\n Answer: "),
]

prompt = ChatPromptTemplate.from_messages(chat_history)

condense_question_runnable = make_condense_question_runnable(llm=llm)

condense_question_chain = {
"question": itemgetter("question"),
"messages": itemgetter("messages"),
} | condense_question_runnable

return (
RunnablePassthrough()
| {
"question": condense_question_chain,
"documents": retriever | format_chunks,
"sources": retriever,
}
| {
"response": prompt | llm | StrOutputParser(),
"sources": itemgetter("sources"),
}
)
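A usage sketch; the input shape matches make_rag_runnable, but the chat history is first condensed into a standalone question before prompting. llm, retriever and the UUIDs are assumptions:

    # Hypothetical caller code, not part of this diff.
    chain = make_condense_rag_runnable(system_prompt="Your job is Q&A.", llm=llm, retriever=retriever)
    result = chain.invoke(
        {
            "question": "Who are all these people?",
            "messages": [("user", "Lorem ipsum."), ("ai", "Consectetur adipiscing elit.")],
            "file_uuids": [file_uuid],
            "user_uuid": user_uuid,
        }
    )
    result["response"]  # str answer, generated from the condensed question plus retrieved documents
    result["sources"]   # list[Chunk] used as context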
14 changes: 12 additions & 2 deletions core_api/tests/conftest.py
@@ -82,16 +82,21 @@ def stored_file_1(elasticsearch_storage_handler, file) -> File:


@pytest.fixture()
def stored_file_chunks(stored_file_1) -> list[Chunk]:
def embedding_model_dim(embedding_model) -> int:
return len(embedding_model.embed_query("foo"))


@pytest.fixture()
def stored_file_chunks(stored_file_1, embedding_model_dim) -> list[Chunk]:
chunks: list[Chunk] = []
for i in range(5):
chunks.append(
Chunk(
text="hello",
index=i,
embedding=[1] * embedding_model_dim,
parent_file_uuid=stored_file_1.uuid,
creator_user_uuid=stored_file_1.creator_user_uuid,
embedding=[1] * 768,
metadata={"parent_doc_uuid": str(stored_file_1.uuid)},
)
)
@@ -139,6 +144,11 @@ def embedding_model() -> SentenceTransformerEmbeddings:
return SentenceTransformerEmbeddings(model_name=env.embedding_model, cache_folder=MODEL_PATH)


@pytest.fixture()
def chunk_index_name():
return f"{env.elastic_root_index}-chunk"


@pytest.fixture()
def vector_store(es_client, embedding_model):
if env.elastic.subscription_level == "basic":
93 changes: 91 additions & 2 deletions core_api/tests/test_runnables.py
@@ -4,8 +4,14 @@

from core_api.src.format import format_chunks, get_file_chunked_to_tokens
from core_api.src.routes.chat import build_retrieval_chain
from core_api.src.runnables import make_stuff_document_runnable
from redbox.models.chat import ChatRequest
from core_api.src.runnables import (
make_condense_question_runnable,
make_condense_rag_runnable,
make_es_retriever,
make_rag_runnable,
make_stuff_document_runnable,
)
from redbox.models import ChatRequest


def test_format_chunks(stored_file_chunks):
@@ -74,3 +80,86 @@ async def test_build_retrieval_chain(mock_llm, chunked_file, other_stored_file_c
)

assert all(doc.metadata["parent_doc_uuid"] == str(chunked_file.uuid) for doc in params["input_documents"])


def test_make_es_retriever(es_client, embedding_model, chunked_file, chunk_index_name):
retriever = make_es_retriever(es=es_client, embedding_model=embedding_model, chunk_index_name=chunk_index_name)

one_doc_chunks = retriever.invoke(
input={"question": "hello", "file_uuids": [chunked_file.uuid], "user_uuid": chunked_file.creator_user_uuid}
)

assert {chunked_file.uuid} == {chunk.parent_file_uuid for chunk in one_doc_chunks}

no_doc_chunks = retriever.invoke(
input={"question": "tell me about energy", "file_uuids": [], "user_uuid": chunked_file.creator_user_uuid}
)

assert len(no_doc_chunks) >= 1


def test_make_rag_runnable(es_client, embedding_model, chunk_index_name, mock_llm, chunked_file):
retriever = make_es_retriever(es=es_client, embedding_model=embedding_model, chunk_index_name=chunk_index_name)

chain = make_rag_runnable(system_prompt="Your job is Q&A.", llm=mock_llm, retriever=retriever)

previous_history = [
{"text": "Lorem ipsum dolor sit amet.", "role": "user"},
{"text": "Consectetur adipiscing elit.", "role": "ai"},
{"text": "Donec cursus nunc tortor.", "role": "user"},
]

response = chain.invoke(
input={
"question": "Who are all these people?",
"messages": [(msg["role"], msg["text"]) for msg in previous_history],
"file_uuids": [chunked_file.uuid],
"user_uuid": chunked_file.creator_user_uuid,
}
)

assert response["response"] == "<<TESTING>>"
assert {chunked_file.uuid} == {chunk.parent_file_uuid for chunk in response["sources"]}


def test_make_condense_question_runnable(mock_llm):
chain = make_condense_question_runnable(llm=mock_llm)

previous_history = [
{"text": "Lorem ipsum dolor sit amet.", "role": "user"},
{"text": "Consectetur adipiscing elit.", "role": "ai"},
{"text": "Donec cursus nunc tortor.", "role": "user"},
]

response = chain.invoke(
input={
"question": "How are you today?",
"messages": [(msg["role"], msg["text"]) for msg in previous_history],
}
)

assert response == "<<TESTING>>"


def test_make_condense_rag_runnable(es_client, embedding_model, chunk_index_name, mock_llm, chunked_file):
retriever = make_es_retriever(es=es_client, embedding_model=embedding_model, chunk_index_name=chunk_index_name)

chain = make_condense_rag_runnable(system_prompt="Your job is Q&A.", llm=mock_llm, retriever=retriever)

previous_history = [
{"text": "Lorem ipsum dolor sit amet.", "role": "user"},
{"text": "Consectetur adipiscing elit.", "role": "ai"},
{"text": "Donec cursus nunc tortor.", "role": "user"},
]

response = chain.invoke(
input={
"question": "Who are all these people?",
"messages": [(msg["role"], msg["text"]) for msg in previous_history],
"file_uuids": [chunked_file.uuid],
"user_uuid": chunked_file.creator_user_uuid,
}
)

assert response["response"] == "<<TESTING>>"
assert {chunked_file.uuid} == {chunk.parent_file_uuid for chunk in response["sources"]}