Created basic stuff document runnable creation function #539

Merged 6 commits on Jun 10, 2024

Changes from all commits
core_api/src/format.py: 49 additions, 0 deletions
@@ -0,0 +1,49 @@
from functools import partial, reduce
from uuid import UUID

from redbox.models.file import Chunk, Metadata
from redbox.storage import ElasticsearchStorageHandler


def format_chunks(chunks: list[Chunk]) -> str:
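    """Formats chunks for prompting: wraps each chunk's text in
    <Doc{parent_file_uuid}> tags and joins the results with blank lines."""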
    formatted: list[str] = []

    for chunk in chunks:
        doc_xml = f"<Doc{chunk.parent_file_uuid}>\n {chunk.text} \n</Doc{chunk.parent_file_uuid}>"
        formatted.append(doc_xml)

    return "\n\n".join(formatted)


def reduce_chunks_by_tokens(chunks: list[Chunk] | None, chunk: Chunk, max_tokens: int) -> list[Chunk]:
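    """Reducer that merges chunk into the last accumulated chunk while their
    combined token_count fits within max_tokens, starting a new chunk otherwise."""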
    if not chunks:
        return [chunk]

    last_chunk = chunks[-1]

    if chunk.token_count + last_chunk.token_count <= max_tokens:
        chunks[-1] = Chunk(
            parent_file_uuid=last_chunk.parent_file_uuid,
            index=last_chunk.index,
            text=last_chunk.text + chunk.text,
            metadata=Metadata.merge(last_chunk.metadata, chunk.metadata),
            creator_user_uuid=last_chunk.creator_user_uuid,
        )
    else:
        chunk.index = last_chunk.index + 1
        chunks.append(chunk)

    return chunks


def get_file_chunked_to_tokens(
    file_uuid: UUID, user_uuid: UUID, storage_handler: ElasticsearchStorageHandler, max_tokens: int | None = None
) -> list[Chunk]:
    """Gets a file's chunks merged into larger, document-sized Chunks, each capped at max_tokens."""
    n = max_tokens or float("inf")
    chunks_unsorted = storage_handler.get_file_chunks(parent_file_uuid=file_uuid, user_uuid=user_uuid)
    chunks_sorted = sorted(chunks_unsorted, key=lambda x: x.index)

    reduce_chunk_n = partial(reduce_chunks_by_tokens, max_tokens=n)

    return reduce(reduce_chunk_n, chunks_sorted, [])
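
For reference, a minimal sketch of how these helpers compose (illustrative only: it assumes Chunk computes token_count from its text, which reduce_chunks_by_tokens relies on):

from functools import partial, reduce
from uuid import uuid4

from redbox.models.file import Chunk

from core_api.src.format import format_chunks, reduce_chunks_by_tokens

file_uuid, user_uuid = uuid4(), uuid4()
chunks = [
    Chunk(text="hello", index=i, parent_file_uuid=file_uuid, creator_user_uuid=user_uuid)
    for i in range(5)
]

# A small token budget yields several merged documents; an unlimited budget yields one.
many = reduce(partial(reduce_chunks_by_tokens, max_tokens=2), chunks, [])
one = reduce(partial(reduce_chunks_by_tokens, max_tokens=float("inf")), chunks, [])
print(len(many), len(one))  # e.g. 3 and 1, if each "hello" counts as one token
print(format_chunks(one))  # a single <Doc{file_uuid}>...</Doc{file_uuid}> block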
core_api/src/runnables.py: 34 additions, 0 deletions
@@ -0,0 +1,34 @@
from operator import itemgetter

from langchain.schema import StrOutputParser
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableLambda

from core_api.src.format import format_chunks


def make_stuff_document_runnable(
    system_prompt: str,
    llm: ChatLiteLLM,
) -> Runnable:
    """Takes a system prompt and an LLM and returns a stuff-document runnable.

    The runnable expects a dict with "question", "messages" and "documents" keys.
    """
    chat_history = [
        ("system", system_prompt),
        ("placeholder", "{messages}"),
        ("user", "Question: {question}. \n\n Documents: \n\n {documents} \n\n Answer: "),
    ]

    return (
        {
            "question": itemgetter("question"),
            "messages": itemgetter("messages"),
            "documents": itemgetter("documents") | RunnableLambda(format_chunks),
        }
        | ChatPromptTemplate.from_messages(chat_history)
        | llm
        | StrOutputParser()
    )
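
A hypothetical usage sketch (the model name, question and stub documents below are placeholders, not part of this diff):

from uuid import uuid4

from langchain_community.chat_models import ChatLiteLLM

from redbox.models.file import Chunk

from core_api.src.runnables import make_stuff_document_runnable

llm = ChatLiteLLM(model="gpt-3.5-turbo")  # assumes LiteLLM credentials are configured
chain = make_stuff_document_runnable(system_prompt="Your job is summarisation.", llm=llm)

# In practice, documents would come from get_file_chunked_to_tokens; stub one chunk here.
file_uuid, user_uuid = uuid4(), uuid4()
documents = [
    Chunk(text="The Cabinet Office supports the Prime Minister.", index=0,
          parent_file_uuid=file_uuid, creator_user_uuid=user_uuid)
]

answer = chain.invoke(
    input={
        "question": "What does the Cabinet Office do?",
        "messages": [("user", "Hello"), ("ai", "Hello! How can I help?")],
        "documents": documents,
    }
)
print(answer)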
core_api/tests/conftest.py: 21 additions, 6 deletions
@@ -6,6 +6,7 @@
 from elasticsearch import Elasticsearch
 from fastapi.testclient import TestClient
 from jose import jwt
+from langchain_community.llms.fake import FakeListLLM
 
 from core_api.src.app import app as application
 from core_api.src.app import env
@@ -78,14 +79,23 @@ def stored_file(elasticsearch_storage_handler, file) -> File:
 
 
 @pytest.fixture()
-def chunked_file(elasticsearch_storage_handler, stored_file) -> File:
+def stored_file_chunks(stored_file) -> list[Chunk]:
+    chunks: list[Chunk] = []
     for i in range(5):
-        chunk = Chunk(
-            text="hello",
-            index=i,
-            parent_file_uuid=stored_file.uuid,
-            creator_user_uuid=stored_file.creator_user_uuid,
+        chunks.append(
+            Chunk(
+                text="hello",
+                index=i,
+                parent_file_uuid=stored_file.uuid,
+                creator_user_uuid=stored_file.creator_user_uuid,
+            )
         )
+    return chunks
+
+
+@pytest.fixture()
+def chunked_file(elasticsearch_storage_handler, stored_file_chunks, stored_file) -> File:
+    for chunk in stored_file_chunks:
         elasticsearch_storage_handler.write_item(chunk)
     elasticsearch_storage_handler.refresh()
     return stored_file
@@ -94,3 +104,8 @@ def chunked_file(elasticsearch_storage_handler, stored_file) -> File:
 @pytest.fixture()
 def file_pdf_path() -> Path:
     return Path(__file__).parents[2] / "tests" / "data" / "pdf" / "Cabinet Office - Wikipedia.pdf"
+
+
+@pytest.fixture()
+def mock_llm():
+    return FakeListLLM(responses=["<<TESTING>>"] * 128)
core_api/tests/test_runnables.py: 53 additions, 0 deletions
@@ -0,0 +1,53 @@
import re

from core_api.src.format import format_chunks, get_file_chunked_to_tokens
from core_api.src.runnables import make_stuff_document_runnable


def test_format_chunks(stored_file_chunks):
    formatted_documents = format_chunks(chunks=stored_file_chunks)

    assert isinstance(formatted_documents, str)
    assert len(list(re.finditer("hello", formatted_documents))) == len(stored_file_chunks)


def test_get_file_chunked_to_tokens(chunked_file, elasticsearch_storage_handler):
    one_document = get_file_chunked_to_tokens(
        file_uuid=chunked_file.uuid,
        user_uuid=chunked_file.creator_user_uuid,
        storage_handler=elasticsearch_storage_handler,
    )

    assert len(one_document) == 1

    many_documents = get_file_chunked_to_tokens(
        file_uuid=chunked_file.uuid,
        user_uuid=chunked_file.creator_user_uuid,
        storage_handler=elasticsearch_storage_handler,
        max_tokens=2,
    )

    assert len(many_documents) > 1


def test_make_stuff_document_runnable(mock_llm, stored_file_chunks):
    chain = make_stuff_document_runnable(
        system_prompt="Your job is summarisation.",
        llm=mock_llm,
    )

    previous_history = [
        {"text": "Lorem ipsum dolor sit amet.", "role": "user"},
        {"text": "Consectetur adipiscing elit.", "role": "ai"},
        {"text": "Donec cursus nunc tortor.", "role": "user"},
    ]

    response = chain.invoke(
        input={
            "question": "Who are all these people?",
            "documents": stored_file_chunks,
            "messages": [(msg["role"], msg["text"]) for msg in previous_history],
        }
    )

    assert response == "<<TESTING>>"