Skip to content

Commit

Permalink
Merge pull request #539 from i-dot-ai/feature/summarisation-func
Browse files Browse the repository at this point in the history
Created basic stuff document runnable creation function
  • Loading branch information
wpfl-dbt authored Jun 10, 2024
2 parents 0b3ab51 + 1f87481 commit a366439
Show file tree
Hide file tree
Showing 5 changed files with 1,230 additions and 6 deletions.
49 changes: 49 additions & 0 deletions core_api/src/format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from functools import partial, reduce
from uuid import UUID

from redbox.models.file import Chunk, Metadata
from redbox.storage import ElasticsearchStorageHandler


def format_chunks(chunks: list[Chunk]) -> str:
    """Render chunks as pseudo-XML documents separated by blank lines.

    Each chunk becomes a <Doc{parent_file_uuid}> element wrapping its text,
    so the LLM prompt can attribute text to its source file.
    """
    return "\n\n".join(
        f"<Doc{c.parent_file_uuid}>\n {c.text} \n</Doc{c.parent_file_uuid}>"
        for c in chunks
    )


def reduce_chunks_by_tokens(chunks: list[Chunk] | None, chunk: Chunk, max_tokens: int) -> list[Chunk]:
    """Reducer step: fold `chunk` into `chunks`, merging it into the last
    accumulated chunk when their combined token count fits within max_tokens.

    Mutates and returns the accumulator; designed for functools.reduce.
    """
    # First chunk seeds the accumulator.
    if not chunks:
        return [chunk]

    tail = chunks[-1]

    if tail.token_count + chunk.token_count > max_tokens:
        # Doesn't fit: start a new merged document directly after the tail.
        chunk.index = tail.index + 1
        chunks.append(chunk)
        return chunks

    # Fits: replace the tail with a merged chunk carrying combined text/metadata.
    chunks[-1] = Chunk(
        parent_file_uuid=tail.parent_file_uuid,
        index=tail.index,
        text=tail.text + chunk.text,
        metadata=Metadata.merge(tail.metadata, chunk.metadata),
        creator_user_uuid=tail.creator_user_uuid,
    )
    return chunks


def get_file_chunked_to_tokens(
    file_uuid: UUID, user_uuid: UUID, storage_handler: ElasticsearchStorageHandler, max_tokens: int | None = None
) -> list[Chunk]:
    """Gets a file as larger document-sized Chunks, splitting it by max_tokens.

    Args:
        file_uuid: the file whose chunks to fetch.
        user_uuid: owner of the file (scopes the Elasticsearch query).
        storage_handler: handler used to retrieve the stored chunks.
        max_tokens: per-document token budget; None means no limit, so all
            chunks merge into a single document.

    Returns:
        Chunks merged in index order, each at most max_tokens tokens
        (individual oversized chunks are kept whole).
    """
    # Explicit None check: `max_tokens or float("inf")` would silently treat
    # an explicit max_tokens=0 as "unlimited".
    limit = float("inf") if max_tokens is None else max_tokens

    chunks_unsorted = storage_handler.get_file_chunks(parent_file_uuid=file_uuid, user_uuid=user_uuid)
    chunks_sorted = sorted(chunks_unsorted, key=lambda c: c.index)

    # partial already has the reduce(acc, item) signature — no lambda wrapper needed.
    reduce_chunk_n = partial(reduce_chunks_by_tokens, max_tokens=limit)
    return reduce(reduce_chunk_n, chunks_sorted, [])
34 changes: 34 additions & 0 deletions core_api/src/runnables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from operator import itemgetter

from langchain.schema import StrOutputParser
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableLambda

from core_api.src.format import format_chunks


def make_stuff_document_runnable(
    system_prompt: str,
    llm: ChatLiteLLM,
) -> Runnable:
    """Takes a system prompt and LLM returns a stuff document runnable.
    Runnable takes input of a dict keyed to question, messages and documents.
    """
    # Prompt: system instructions, prior conversation, then the stuffed question.
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("placeholder", "{messages}"),
            ("user", "Question: {question}. \n\n Documents: \n\n {documents} \n\n Answer: "),
        ]
    )

    # Input mapping: pass question/messages through; render chunks to text.
    input_map = {
        "question": itemgetter("question"),
        "messages": itemgetter("messages"),
        "documents": itemgetter("documents") | RunnableLambda(format_chunks),
    }

    return input_map | prompt | llm | StrOutputParser()
27 changes: 21 additions & 6 deletions core_api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from elasticsearch import Elasticsearch
from fastapi.testclient import TestClient
from jose import jwt
from langchain_community.llms.fake import FakeListLLM

from core_api.src.app import app as application
from core_api.src.app import env
Expand Down Expand Up @@ -78,14 +79,23 @@ def stored_file(elasticsearch_storage_handler, file) -> File:


@pytest.fixture()
def stored_file_chunks(stored_file) -> list[Chunk]:
    """Five small 'hello' chunks belonging to stored_file.

    NOTE: built in memory only; the chunked_file fixture is responsible for
    persisting them to Elasticsearch.
    """
    return [
        Chunk(
            text="hello",
            index=i,
            parent_file_uuid=stored_file.uuid,
            creator_user_uuid=stored_file.creator_user_uuid,
        )
        for i in range(5)
    ]


@pytest.fixture()
def chunked_file(elasticsearch_storage_handler, stored_file_chunks, stored_file) -> File:
    """Persist the pre-built chunks to Elasticsearch and return their parent file."""
    for item in stored_file_chunks:
        elasticsearch_storage_handler.write_item(item)
    # Refresh so the chunks are immediately visible to subsequent queries.
    elasticsearch_storage_handler.refresh()
    return stored_file
Expand All @@ -94,3 +104,8 @@ def chunked_file(elasticsearch_storage_handler, stored_file) -> File:
@pytest.fixture()
def file_pdf_path() -> Path:
    """Path of the Wikipedia PDF fixture under the repository's tests/data/pdf."""
    repo_root = Path(__file__).parents[2]
    return repo_root / "tests" / "data" / "pdf" / "Cabinet Office - Wikipedia.pdf"


@pytest.fixture()
def mock_llm():
    """Deterministic stand-in LLM: every call answers "<<TESTING>>"."""
    canned = "<<TESTING>>"
    # 128 responses so the fake never runs dry within a test.
    return FakeListLLM(responses=128 * [canned])
53 changes: 53 additions & 0 deletions core_api/tests/test_runnables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import re

from core_api.src.format import format_chunks, get_file_chunked_to_tokens
from core_api.src.runnables import make_stuff_document_runnable


def test_format_chunks(stored_file_chunks):
    """format_chunks returns a string containing one 'hello' per input chunk."""
    rendered = format_chunks(chunks=stored_file_chunks)

    assert isinstance(rendered, str)
    assert rendered.count("hello") == len(stored_file_chunks)


def test_get_file_chunked_to_tokens(chunked_file, elasticsearch_storage_handler):
    """No token limit collapses everything to one document; a tiny limit splits it."""
    common_args = {
        "file_uuid": chunked_file.uuid,
        "user_uuid": chunked_file.creator_user_uuid,
        "storage_handler": elasticsearch_storage_handler,
    }

    unlimited = get_file_chunked_to_tokens(**common_args)
    assert len(unlimited) == 1

    limited = get_file_chunked_to_tokens(**common_args, max_tokens=2)
    assert len(limited) > 1


def test_make_stuff_document_runnable(mock_llm, stored_file_chunks):
    """The chain accepts question/documents/messages and surfaces the LLM output."""
    chain = make_stuff_document_runnable(
        system_prompt="Your job is summarisation.",
        llm=mock_llm,
    )

    # Prior conversation as (role, text) tuples, as the placeholder slot expects.
    history = [
        ("user", "Lorem ipsum dolor sit amet."),
        ("ai", "Consectetur adipiscing elit."),
        ("user", "Donec cursus nunc tortor."),
    ]

    response = chain.invoke(
        input={
            "question": "Who are all these people?",
            "documents": stored_file_chunks,
            "messages": history,
        }
    )

    assert response == "<<TESTING>>"
Loading

0 comments on commit a366439

Please sign in to comment.