Skip to content

Commit

Permalink
Merge branch 'main' into feature/run-e2e-in-cd-pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
brunns committed May 31, 2024
2 parents c84a0d4 + 0c17d64 commit 66ef2be
Show file tree
Hide file tree
Showing 15 changed files with 130 additions and 109 deletions.
11 changes: 10 additions & 1 deletion core_api/src/auth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated
from uuid import UUID

from fastapi import Depends, HTTPException
from fastapi import Depends, HTTPException, WebSocket, WebSocketException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt
from starlette import status
Expand All @@ -22,3 +22,12 @@ async def get_user_uuid(token: Annotated[HTTPAuthorizationCredentials, Depends(h
headers={"WWW-Authenticate": "Bearer"},
)
raise credentials_exception from e


async def get_ws_user_uuid(websocket: WebSocket) -> UUID:
    """Read the caller's UUID from the bearer token sent with a WebSocket.

    The ``authorization`` header is split on the first space and the last
    part is decoded as a JWT; its ``user_uuid`` claim is parsed as a UUID.

    Args:
        websocket: the connection whose headers carry the token.

    Returns:
        The UUID held in the token's ``user_uuid`` claim.

    Raises:
        WebSocketException: with close code 1008 (policy violation) when the
            header is missing, the token cannot be decoded, the claim is
            absent, or the claim is not a well-formed UUID.
    """
    try:
        authorization = websocket.headers["authorization"]
        # NOTE(review): claims are read WITHOUT verifying the token signature;
        # confirm that verification happens upstream before trusting this id.
        claims = jwt.get_unverified_claims(authorization.split(" ", 1)[-1])
        return UUID(claims["user_uuid"])
    except (KeyError, ValueError, JWTError) as e:
        # 403 is an HTTP status, not a WebSocket close code; close codes must
        # be in the 1000-4999 range, and 1008 is the one for policy violations.
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION, reason="unauthorized") from e
105 changes: 34 additions & 71 deletions core_api/src/routes/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,21 @@

from fastapi import Depends, FastAPI, HTTPException, WebSocket
from fastapi.encoders import jsonable_encoder
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.llm import LLMChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.chains.retrieval import create_retrieval_chain
from langchain_community.chat_models import ChatLiteLLM
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable
from langchain_core.prompts import ChatPromptTemplate
from langchain_elasticsearch import ApproxRetrievalStrategy, ElasticsearchStore

from core_api.src.auth import get_user_uuid
from core_api.src.auth import get_user_uuid, get_ws_user_uuid
from redbox.llm.prompts.chat import (
CONDENSE_QUESTION_PROMPT,
STUFF_DOCUMENT_PROMPT,
WITH_SOURCES_PROMPT,
)
from redbox.model_db import MODEL_PATH
from redbox.models import EmbeddingModelInfo, Settings
from redbox.models import Settings
from redbox.models.chat import ChatRequest, ChatResponse, SourceDocument

# === Logging ===
Expand All @@ -50,23 +45,10 @@
openapi_url="/openapi.json",
)

log.info("Loading embedding model from environment: %s", env.embedding_model)
embedding_model = SentenceTransformerEmbeddings(model_name=env.embedding_model, cache_folder=MODEL_PATH)
log.info("Loaded embedding model from environment: %s", env.embedding_model)


def populate_embedding_model_info() -> EmbeddingModelInfo:
    """Describe the configured embedding model by embedding one probe sentence.

    Returns:
        EmbeddingModelInfo: the model name taken from ``env.embedding_model``
        and the length of the vector that the module-level ``embedding_model``
        produced for the probe sentence.
    """
    probe_vectors = embedding_model.embed_documents(["This is a test sentence."])
    vector_size = len(probe_vectors[0])
    return EmbeddingModelInfo(embedding_model=env.embedding_model, vector_size=vector_size)


embedding_model_info = populate_embedding_model_info()


# === LLM setup ===

# Create the appropriate LLM, either openai, Azure, anthropic or bedrock
Expand Down Expand Up @@ -153,16 +135,10 @@ def simple_chat(chat_request: ChatRequest, _user_uuid: Annotated[UUID, Depends(g
return chat_response


@chat_app.post("/rag", tags=["chat"])
def rag_chat(chat_request: ChatRequest, user_uuid: Annotated[UUID, Depends(get_user_uuid)]) -> ChatResponse:
"""Get a LLM response to a question history and file
Args:
Returns:
StreamingResponse: a stream of the chain response
"""
async def build_retrieval_chain(
chat_request: ChatRequest,
user_uuid: UUID,
):
question = chat_request.message_history[-1].text
previous_history = list(chat_request.message_history[:-1])
previous_history = ChatPromptTemplate.from_messages(
Expand All @@ -181,16 +157,29 @@ def rag_chat(chat_request: ChatRequest, user_uuid: Annotated[UUID, Depends(get_u

standalone_question = condense_question_chain({"question": question, "chat_history": previous_history})["text"]

docs = vector_store.as_retriever(
search_kwargs={"filter": {"term": {"creator_user_uuid.keyword": str(user_uuid)}}}
).get_relevant_documents(standalone_question)
search_kwargs = {"filter": {"term": {"creator_user_uuid.keyword": str(user_uuid)}}}
docs = vector_store.as_retriever(search_kwargs=search_kwargs).get_relevant_documents(standalone_question)

params = {
"question": standalone_question,
"input_documents": docs,
}

return docs_with_sources_chain, params

result = docs_with_sources_chain(
{
"question": standalone_question,
"input_documents": docs,
},
)

@chat_app.post("/rag", tags=["chat"])
async def rag_chat(chat_request: ChatRequest, user_uuid: Annotated[UUID, Depends(get_user_uuid)]) -> ChatResponse:
"""Get a LLM response to a question history and file
Args:
Returns:
StreamingResponse: a stream of the chain response
"""
chain, params = await build_retrieval_chain(chat_request, user_uuid)
result = chain(params)

source_documents = [
SourceDocument(
Expand All @@ -207,15 +196,13 @@ def rag_chat(chat_request: ChatRequest, user_uuid: Annotated[UUID, Depends(get_u
async def rag_chat_streamed(websocket: WebSocket):
await websocket.accept()

retrieval_chain = await build_retrieval_chain()
user_uuid = await get_ws_user_uuid(websocket)

chat_request = ChatRequest.parse_raw(await websocket.receive_text())
chat_history = [
HumanMessage(content=x.text) if x.role == "user" else AIMessage(content=x.text)
for x in chat_request.message_history[:-1]
]
chat = {"chat_history": chat_history, "input": chat_request.message_history[-1].text}
async for event in retrieval_chain.astream_events(chat, version="v1"):

chain, params = await build_retrieval_chain(chat_request, user_uuid)

async for event in chain.astream_events(params, version="v1"):
kind = event["event"]
if kind == "on_chat_model_stream":
await websocket.send_json({"resource_type": "text", "data": event["data"]["chunk"].content})
Expand All @@ -235,27 +222,3 @@ async def rag_chat_streamed(websocket: WebSocket):
await websocket.send_json({"resource_type": "documents", "data": source_documents})

await websocket.close()


async def build_retrieval_chain() -> Runnable:
    """Assemble the history-aware retrieval chain.

    Two prompts are built: one that turns the chat history plus the latest
    input into a search query, and one that answers from the retrieved
    context. The first is combined with ``llm`` and the vector store's
    retriever; the second with ``llm`` in a stuff-documents chain.

    Returns:
        Runnable: the result of ``create_retrieval_chain`` over the
        history-aware retriever and the document chain.
    """
    search_query_messages = [
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
        (
            "user",
            "Given the above conversation, generate a search query to look up to get information relevant to the "
            "conversation",
        ),
    ]
    search_query_prompt = ChatPromptTemplate.from_messages(search_query_messages)
    history_aware_retriever = create_history_aware_retriever(
        llm,
        vector_store.as_retriever(),
        search_query_prompt,
    )

    answer_messages = [
        ("system", "Answer the user's questions based on the below context:\\n\\n{context}"),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
    ]
    answer_prompt = ChatPromptTemplate.from_messages(answer_messages)
    answer_chain = create_stuff_documents_chain(llm, answer_prompt)

    return create_retrieval_chain(history_aware_retriever, answer_chain)
2 changes: 1 addition & 1 deletion core_api/tests/routes/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def mock_build_retrieval_chain(events):
retrieval_chain = AsyncMock(spec=Runnable, name="retrieval_chain")
retrieval_chain.astream_events = astream_events

return AsyncMock(name="build_retrieval_chain", return_value=retrieval_chain)
return AsyncMock(name="build_retrieval_chain", return_value=(retrieval_chain, None))


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion django_app/redbox_app/jinja2.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def to_user_timezone(value):

def environment(**options):
extra_options = {}
env = jinja2.Environment( # nosec B701 # noqa S701
env = jinja2.Environment( # nosec: B701 # noqa: S701
**{
"autoescape": True,
"extensions": [CompressorExtension],
Expand Down
5 changes: 4 additions & 1 deletion django_app/redbox_app/redbox_core/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ async def receive(self, text_data):
message_history: list[dict[str, str]] = [
{"role": message.role, "text": message.text} for message in session_messages
]
url = URL.build(scheme="ws", host=settings.CORE_API_HOST, port=settings.CORE_API_PORT) / "chat/rag"
url = (
URL.build(scheme=settings.WEBSOCKET_SCHEME, host=settings.CORE_API_HOST, port=settings.CORE_API_PORT)
/ "chat/rag"
)
async with connect(str(url), extra_headers={"Authorization": user.get_bearer_token()}) as websocket:
await websocket.send(json.dumps({"message_history": message_history}))
await self.send_json({"type": "session-id", "data": str(session.id)})
Expand Down
2 changes: 1 addition & 1 deletion django_app/redbox_app/redbox_core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def remove_doc_view(request, doc_id: uuid):


@login_required
def chats_view(request: HttpRequest, chat_id: uuid = None):
def chats_view(request: HttpRequest, chat_id: uuid.UUID | None = None):
chat_history = ChatHistory.objects.filter(users=request.user).order_by("-created_at")

messages = []
Expand Down
13 changes: 7 additions & 6 deletions django_app/redbox_app/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

SECRET_KEY = env.str("DJANGO_SECRET_KEY")
ENVIRONMENT = env.str("ENVIRONMENT")
WEBSOCKET_SCHEME = env.str("WEBSOCKET_SCHEME", default="ws")

# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = env.bool("DEBUG")
Expand Down Expand Up @@ -220,7 +221,7 @@
ALLOWED_HOSTS = [
"localhost",
"127.0.0.1",
"0.0.0.0", # noqa S104
"0.0.0.0", # noqa: S104
] # nosec B104 - don't do this on server!
else:
STORAGES = {
Expand Down Expand Up @@ -326,15 +327,15 @@
FILE_EXPIRY_IN_SECONDS = env.int("FILE_EXPIRY_IN_DAYS") * 24 * 60 * 60
SUPERUSER_EMAIL = env.str("SUPERUSER_EMAIL", None)

# Security classifications
# https://www.gov.uk/government/publications/government-security-classifications/


class Classification(StrEnum):
"""Security classifications
https://www.gov.uk/government/publications/government-security-classifications/"""

OFFICIAL = "Official"
OFFICIAL_SENSITIVE = "Official Sensitive"
SECRET = "Secret" # noqa S105
TOP_SECRET = "Top Secret" # noqa S105
SECRET = "Secret" # noqa: S105
TOP_SECRET = "Top Secret" # noqa: S105


MAX_SECURITY_CLASSIFICATION = Classification[env.str("MAX_SECURITY_CLASSIFICATION")]
7 changes: 4 additions & 3 deletions django_app/tests_playwright/pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import ABCMeta, abstractmethod
from itertools import islice
from pathlib import Path
from typing import NamedTuple
from typing import Any, ClassVar, NamedTuple

from _settings import BASE_URL
from axe_playwright_python.sync_playwright import Axe
Expand All @@ -16,7 +16,7 @@
class BasePage(metaclass=ABCMeta):
# All available rules/categories can be found at https://github.com/dequelabs/axe-core/blob/develop/doc/rule-descriptions.md
# Can't include all as gov.uk design system violates the "region" rule
AXE_OPTIONS = {
AXE_OPTIONS: ClassVar[dict[str, Any]] = {
"runOnly": {
"type": "tag",
"values": [
Expand Down Expand Up @@ -219,7 +219,8 @@ def get_expected_page_title(self) -> str:


def batched(iterable, n):
# TODO: Use library version in Python 3.12: https://docs.python.org/3/library/itertools.html#itertools.batched
# TODO (@brunns): Use library version when we upgrade to Python 3.12.
# https://docs.python.org/3/library/itertools.html#itertools.batched
if n < 1:
message = "n must be at least one"
raise ValueError(message)
Expand Down
3 changes: 2 additions & 1 deletion infrastructure/aws/ecs.tf
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ locals {
"COMPRESSION_ENABLED" : true,
"CONTACT_EMAIL": var.contact_email,
"FILE_EXPIRY_IN_DAYS": 30,
"MAX_SECURITY_CLASSIFICATION": "OFFICIAL_SENSITIVE"
"MAX_SECURITY_CLASSIFICATION": "OFFICIAL_SENSITIVE",
"WEBSOCKET_SCHEME": "wss"
}
}

Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 7 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ pytest-asyncio = "^0.23.6"
boto3-stubs = "^1.34.106"
moto = {extras = ["s3"], version = "^5.0.5"}
httpx = "^0.27.0"
websockets = "^12.0"
playwright = "^1.43"
pytest-playwright = "^0.5"
axe-playwright-python = "^0.1"
Expand Down Expand Up @@ -108,23 +109,22 @@ select = [
"COM",
"DJ",
"DTZ",
# "EM",
"EM",
"EXE",
#* "FUTB",
"ICN",
# "ISC",
"ISC",
"LOG",
"NPY",
"PD",
# "PGH",
"PGH",
"PIE",
#* "PL",
"PT",
"PTH",
"PYI",
"RET",
"RSE",
#* "RUF",
"RUF",
"SIM",
"SLF",
"TCH",
Expand All @@ -144,3 +144,5 @@ ignore = ["COM812", "DJ001", "RET505", "RET508"]
"worker/*" = ["B008"]
"redbox/*" = ["TCH003"]
"redbox/tests/*" = ["ARG001"]
"*/admin.py" = ["RUF012"]
"*/models.py" = ["RUF012"]
Loading

0 comments on commit 66ef2be

Please sign in to comment.