Turning RAG into a runnable function #554

Merged · 8 commits · Jun 12, 2024

Changes from 5 commits
95 changes: 94 additions & 1 deletion core_api/src/runnables.py
@@ -1,11 +1,18 @@
from operator import itemgetter
from typing import Any, TypedDict
from uuid import UUID

from elasticsearch import Elasticsearch
from langchain.schema import StrOutputParser
from langchain_community.chat_models import ChatLiteLLM
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_core.prompts import ChatPromptTemplate
-from langchain_core.runnables import Runnable, RunnableLambda
+from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_elasticsearch import ElasticsearchRetriever

from core_api.src.format import format_chunks
from redbox.models import Chunk


def make_stuff_document_runnable(
@@ -15,6 +22,8 @@ def make_stuff_document_runnable(
"""Takes a system prompt and LLM returns a stuff document runnable.

Runnable takes input of a dict keyed to question, messages and documents.

Runnable returns a string.
"""
chat_history = [
("system", system_prompt),
@@ -32,3 +41,87 @@
| llm
| StrOutputParser()
)


class ESQuery(TypedDict):
question: str
file_uuids: list[UUID]
user_uuid: UUID


def make_es_retriever(
es: Elasticsearch, embedding_model: SentenceTransformerEmbeddings, chunk_index_name: str
) -> ElasticsearchRetriever:
"""Creates an Elasticsearch retriever runnable.

Runnable takes input of a dict keyed to question, file_uuids and user_uuid.

Runnable returns a list of Chunks.
"""

def es_query(query: ESQuery) -> dict[str, Any]:
vector = embedding_model.embed_query(query["question"])

knn_filter = [{"term": {"creator_user_uuid.keyword": str(query["user_uuid"])}}]

if len(query["file_uuids"]) != 0:
knn_filter.append({"terms": {"parent_file_uuid.keyword": [str(uuid) for uuid in query["file_uuids"]]}})

return {
"size": 5,
"query": {
"bool": {
"must": [
{
"knn": {
"field": "embedding",
"query_vector": vector,
"num_candidates": 10,
"filter": knn_filter,
}
}
]
}
},
}

def chunk_mapper(hit: dict[str, Any]) -> Chunk:
return Chunk(**hit["_source"])

return ElasticsearchRetriever(
es_client=es, index_name=chunk_index_name, body_func=es_query, document_mapper=chunk_mapper
)


def make_rag_runnable(
system_prompt: str,
llm: ChatLiteLLM,
retriever: VectorStoreRetriever,
) -> Runnable:
"""Takes a system prompt, LLM and retriever and returns a basic RAG runnable.

Runnable takes input of a dict keyed to question, messages, file_uuids and user_uuid.

Runnable returns a dict keyed to response and sources.
"""
chat_history = [
("system", system_prompt),
("placeholder", "{messages}"),
("user", "Question: {question}. \n\n Documents: \n\n {documents} \n\n Answer: "),
]

prompt = ChatPromptTemplate.from_messages(chat_history)

return (
RunnablePassthrough()
| {
"question": itemgetter("question"),
"messages": itemgetter("messages"),
"documents": retriever | format_chunks,
"sources": retriever,
}
| {
"response": prompt | llm | StrOutputParser(),
"sources": itemgetter("sources"),
}
)
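
A minimal usage sketch of the two new factories, wiring `make_es_retriever` into `make_rag_runnable` and invoking the resulting chain. The Elasticsearch URL, embedding model name, LiteLLM model, index name and UUIDs below are illustrative placeholders, not values taken from this PR:

```python
from uuid import uuid4

from elasticsearch import Elasticsearch
from langchain_community.chat_models import ChatLiteLLM
from langchain_community.embeddings import SentenceTransformerEmbeddings

from core_api.src.runnables import make_es_retriever, make_rag_runnable

# Assumed local cluster and model names; swap in your own configuration.
es = Elasticsearch("http://localhost:9200")
embedding_model = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
llm = ChatLiteLLM(model="gpt-4o")

retriever = make_es_retriever(
    es=es, embedding_model=embedding_model, chunk_index_name="redbox-data-chunk"
)
chain = make_rag_runnable(system_prompt="Your job is Q&A.", llm=llm, retriever=retriever)

# The chain expects the four keys described in the docstrings above.
result = chain.invoke(
    input={
        "question": "What do my documents say about energy?",
        "messages": [("user", "Hi"), ("ai", "Hello")],  # prior (role, text) turns
        "file_uuids": [uuid4()],  # placeholder file UUIDs to filter on
        "user_uuid": uuid4(),  # placeholder owner UUID
    }
)
print(result["response"])  # LLM answer as a string
print([chunk.parent_file_uuid for chunk in result["sources"]])  # retrieved Chunks
```

Note that the parallel dict runs the retriever twice on the same input, once for `documents` (formatted into the prompt) and once for `sources` (returned raw), so callers get the underlying `Chunk` objects alongside the answer.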
22 changes: 21 additions & 1 deletion core_api/tests/conftest.py
@@ -1,3 +1,4 @@
import time
from pathlib import Path
from uuid import UUID, uuid4

@@ -6,10 +7,12 @@
from elasticsearch import Elasticsearch
from fastapi.testclient import TestClient
from jose import jwt
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.llms.fake import FakeListLLM

from core_api.src.app import app as application
from core_api.src.app import env
from redbox.model_db import MODEL_PATH
from redbox.models import Chunk, File
from redbox.storage import ElasticsearchStorageHandler

@@ -79,13 +82,19 @@ def stored_file(elasticsearch_storage_handler, file) -> File:


@pytest.fixture()
-def stored_file_chunks(stored_file) -> list[Chunk]:
+def embedding_model_dim(embedding_model) -> int:
return len(embedding_model.embed_query("foo"))


@pytest.fixture()
def stored_file_chunks(stored_file, embedding_model_dim) -> list[Chunk]:
chunks: list[Chunk] = []
for i in range(5):
chunks.append(
Chunk(
text="hello",
index=i,
embedding=[1] * embedding_model_dim,
parent_file_uuid=stored_file.uuid,
creator_user_uuid=stored_file.creator_user_uuid,
)
@@ -98,6 +107,7 @@ def chunked_file(elasticsearch_storage_handler, stored_file_chunks, stored_file)
for chunk in stored_file_chunks:
elasticsearch_storage_handler.write_item(chunk)
elasticsearch_storage_handler.refresh()
time.sleep(1)
return stored_file


@@ -109,3 +119,13 @@ def file_pdf_path() -> Path:
@pytest.fixture()
def mock_llm():
return FakeListLLM(responses=["<<TESTING>>"] * 128)


@pytest.fixture()
def embedding_model() -> SentenceTransformerEmbeddings:
return SentenceTransformerEmbeddings(model_name=env.embedding_model, cache_folder=MODEL_PATH)


@pytest.fixture()
def chunk_index_name():
return f"{env.elastic_root_index}-chunk"
42 changes: 41 additions & 1 deletion core_api/tests/test_runnables.py
@@ -1,7 +1,7 @@
import re

from core_api.src.format import format_chunks, get_file_chunked_to_tokens
-from core_api.src.runnables import make_stuff_document_runnable
+from core_api.src.runnables import make_es_retriever, make_rag_runnable, make_stuff_document_runnable


def test_format_chunks(stored_file_chunks):
@@ -51,3 +51,43 @@ def test_make_stuff_document_runnable(mock_llm, stored_file_chunks):
)

assert response == "<<TESTING>>"


def test_make_es_retriever(es_client, embedding_model, chunked_file, chunk_index_name):
retriever = make_es_retriever(es=es_client, embedding_model=embedding_model, chunk_index_name=chunk_index_name)

one_doc_chunks = retriever.invoke(
input={"question": "hello", "file_uuids": [chunked_file.uuid], "user_uuid": chunked_file.creator_user_uuid}
)

assert {chunked_file.uuid} == {chunk.parent_file_uuid for chunk in one_doc_chunks}

no_doc_chunks = retriever.invoke(
input={"question": "tell me about energy", "file_uuids": [], "user_uuid": chunked_file.creator_user_uuid}
)

assert len(no_doc_chunks) >= 1


def test_make_rag_runnable(es_client, embedding_model, chunk_index_name, mock_llm, chunked_file):
retriever = make_es_retriever(es=es_client, embedding_model=embedding_model, chunk_index_name=chunk_index_name)

chain = make_rag_runnable(system_prompt="Your job is Q&A.", llm=mock_llm, retriever=retriever)

previous_history = [
{"text": "Lorem ipsum dolor sit amet.", "role": "user"},
{"text": "Consectetur adipiscing elit.", "role": "ai"},
{"text": "Donec cursus nunc tortor.", "role": "user"},
]

response = chain.invoke(
input={
"question": "Who are all these people?",
"messages": [(msg["role"], msg["text"]) for msg in previous_history],
"file_uuids": [chunked_file.uuid],
"user_uuid": chunked_file.creator_user_uuid,
}
)

assert response["response"] == "<<TESTING>>"
assert {chunked_file.uuid} == {chunk.parent_file_uuid for chunk in response["sources"]}