Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/add aisettings to django #927

Merged
merged 7 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions django_app/redbox_app/redbox_core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from botocore.exceptions import ClientError
from dataclasses_json import Undefined, dataclass_json
from django.conf import settings
from django.forms.models import model_to_dict
from yarl import URL

from redbox_app.redbox_core.models import User
from redbox_app.redbox_core.models import AISettings, User

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -85,12 +86,22 @@ def upload_file(self, name: str, user: User) -> FileOperation:
response.raise_for_status()
return FileOperation.schema().loads(response.content)

def get_ai_settings(self, user: User) -> dict:
    """Serialise the user's AISettings row into a plain dict for the core API.

    Fix: the previous annotation claimed ``AISettings`` but ``model_to_dict``
    returns a ``dict``; the annotation now matches the actual return value.

    The internal "label" field (the row's unique identifier, used as the FK
    target on User.ai_settings) is excluded — it is not an AI parameter the
    core API understands.
    """
    return model_to_dict(
        user.ai_settings,
        # _meta access is intentional here (hence the noqa) — it is the
        # supported way to enumerate a model's concrete fields.
        fields=[field.name for field in user.ai_settings._meta.fields if field.name != "label"],  # noqa: SLF001
    )

def rag_chat(
self, message_history: list[dict[str, str]], selected_files: list[dict[str, str]], user: User
) -> CoreChatResponse:
response = requests.post(
self.url / "chat/rag",
json={"message_history": message_history, "selected_files": selected_files},
json={
"message_history": message_history,
"selected_files": selected_files,
"ai_settings": self.get_ai_settings(user),
},
headers={"Authorization": user.get_bearer_token()},
timeout=60,
)
Expand Down
12 changes: 11 additions & 1 deletion django_app/redbox_app/redbox_core/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
from channels.generic.websocket import AsyncWebsocketConsumer
from dataclasses_json import Undefined, dataclass_json
from django.conf import settings
from django.forms.models import model_to_dict
from django.utils import timezone
from websockets import ConnectionClosedError, WebSocketClientProtocol
from websockets.client import connect
from yarl import URL

from redbox_app.redbox_core import error_messages
from redbox_app.redbox_core.models import ChatHistory, ChatMessage, ChatRoleEnum, Citation, File, User
from redbox_app.redbox_core.models import AISettings, ChatHistory, ChatMessage, ChatRoleEnum, Citation, File, User

OptFileSeq = Sequence[File] | None
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -79,6 +80,7 @@ async def llm_conversation(
message = {
"message_history": message_history,
"selected_files": [{"uuid": f.core_file_uuid} for f in selected_files],
"ai_settings": await self.get_ai_settings(user),
}
await self.send_to_server(core_websocket, message)
await self.send_to_client("session-id", session.id)
Expand Down Expand Up @@ -219,6 +221,14 @@ def get_sources_with_files(
def file_save(file):
return file.save()

@staticmethod
@database_sync_to_async
def get_ai_settings(user: User) -> dict:
    """Serialise the user's AISettings row into a plain dict for the core API.

    Wrapped in ``database_sync_to_async`` because reading ``user.ai_settings``
    performs a synchronous ORM query, which must not run directly on the
    websocket consumer's event loop.

    The internal "label" field (the row's unique identifier, used as the FK
    target on User.ai_settings) is excluded from the payload.

    NOTE(review): this duplicates CoreApiClient.get_ai_settings in client.py —
    per the PR discussion, one serves streamed (websocket) requests and the
    other non-streamed HTTP requests; consider extracting a shared helper.

    Fix: return annotation corrected from ``AISettings`` to ``dict`` to match
    what ``model_to_dict`` actually returns; embedded review-page text that
    made this span invalid Python has been removed.
    """
    return model_to_dict(
        user.ai_settings,
        # _meta access is intentional here (hence the noqa) — it is the
        # supported way to enumerate a model's concrete fields.
        fields=[field.name for field in user.ai_settings._meta.fields if field.name != "label"],  # noqa: SLF001
    )


class CoreError(Exception):
message: str
131 changes: 131 additions & 0 deletions django_app/redbox_app/redbox_core/migrations/0028_aisettings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Generated by Django 5.0.7 on 2024-08-07 12:33

import uuid
from django.db import migrations, models
import django.db.models.deletion


def create_default_ai_settings(apps, schema_editor):
    """Data-migration forward step: seed the single "default" AISettings row.

    The User.ai_settings foreign key added later in this migration uses
    default="default" against to_field="label", so this row must exist
    before any user row relies on the default.
    """
    AISettings = apps.get_model("redbox_core", "AISettings")
    AISettings.objects.create(label="default")


class Migration(migrations.Migration):
    """Create the AISettings model, seed a "default" row, and add the
    User.ai_settings foreign key targeting AISettings.label.

    Fix: removed GitHub review-page text that had been pasted into the middle
    of the AddField operation, which made this module invalid Python. All
    field definitions and defaults are otherwise unchanged (generated code).
    """

    dependencies = [
        ("redbox_core", "0027_alter_file_status"),
    ]

    operations = [
        migrations.CreateModel(
            name="AISettings",
            fields=[
                (
                    "id",
                    models.UUIDField(
                        default=uuid.uuid4,
                        editable=False,
                        primary_key=True,
                        serialize=False,
                    ),
                ),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("modified_at", models.DateTimeField(auto_now=True)),
                ("label", models.CharField(max_length=50, unique=True)),
                ("context_window_size", models.PositiveIntegerField(default=8000)),
                ("rag_k", models.PositiveIntegerField(default=30)),
                ("rag_num_candidates", models.PositiveIntegerField(default=10)),
                ("rag_desired_chunk_size", models.PositiveIntegerField(default=300)),
                ("elbow_filter_enabled", models.BooleanField(default=False)),
                (
                    "chat_system_prompt",
                    models.TextField(
                        default="You are an AI assistant called Redbox tasked with answering questions and providing information objectively."
                    ),
                ),
                (
                    "chat_question_prompt",
                    models.TextField(default="{question}\n=========\n Response: "),
                ),
                ("stuff_chunk_context_ratio", models.FloatField(default=0.75)),
                (
                    "chat_with_docs_system_prompt",
                    models.TextField(
                        default="You are an AI assistant called Redbox tasked with answering questions on user provided documents and providing information objectively."
                    ),
                ),
                (
                    "chat_with_docs_question_prompt",
                    models.TextField(
                        default="Question: {question}. \n\n Documents: \n\n {formatted_documents} \n\n Answer: "
                    ),
                ),
                (
                    "chat_with_docs_reduce_system_prompt",
                    models.TextField(
                        default="You are an AI assistant tasked with answering questions on user provided documents. Your goal is to answer the user question based on list of summaries in a coherent manner.Please follow these guidelines while answering the question: \n1) Identify and highlight key points,\n2) Avoid repetition,\n3) Ensure the answer is easy to understand,\n4) Maintain the original context and meaning.\n"
                    ),
                ),
                (
                    "retrieval_system_prompt",
                    models.TextField(
                        default="Given the following conversation and extracted parts of a long document and a question, create a final answer. \nIf you don't know the answer, just say that you don't know. Don't try to make up an answer. If a user asks for a particular format to be returned, such as bullet points, then please use that format. If a user asks for bullet points you MUST give bullet points. If the user asks for a specific number or range of bullet points you MUST give that number of bullet points. \nUse **bold** to highlight the most question relevant parts in your response. If dealing dealing with lots of data return it in markdown table format. "
                    ),
                ),
                (
                    "retrieval_question_prompt",
                    models.TextField(
                        default="{question} \n=========\n{formatted_documents}\n=========\nFINAL ANSWER: "
                    ),
                ),
                (
                    "condense_system_prompt",
                    models.TextField(
                        default="Given the following conversation and a follow up question, generate a follow up question to be a standalone question. You are only allowed to generate one question in response. Include sources from the chat history in the standalone question created, when they are available. If you don't know the answer, just say that you don't know, don't try to make up an answer. \n"
                    ),
                ),
                (
                    "condense_question_prompt",
                    models.TextField(
                        default="{question}\n=========\n Standalone question: "
                    ),
                ),
                ("map_max_concurrency", models.PositiveIntegerField(default=128)),
                (
                    "chat_map_system_prompt",
                    models.TextField(
                        default="You are an AI assistant tasked with summarizing documents. Your goal is to extract the most important information and present it in a concise and coherent manner. Please follow these guidelines while summarizing: \n1) Identify and highlight key points,\n2) Avoid repetition,\n3) Ensure the summary is easy to understand,\n4) Maintain the original context and meaning.\n"
                    ),
                ),
                (
                    "chat_map_question_prompt",
                    models.TextField(
                        default="Question: {question}. \n Documents: \n {formatted_documents} \n\n Answer: "
                    ),
                ),
                (
                    "reduce_system_prompt",
                    models.TextField(
                        default="You are an AI assistant tasked with summarizing documents. Your goal is to write a concise summary of list of summaries from a list of summaries in a concise and coherent manner. Please follow these guidelines while summarizing: \n1) Identify and highlight key points,\n2) Avoid repetition,\n3) Ensure the summary is easy to understand,\n4) Maintain the original context and meaning.\n"
                    ),
                ),
                ("llm_max_tokens", models.PositiveIntegerField(default=1024)),
                ("match_boost", models.PositiveIntegerField(default=1)),
                ("knn_boost", models.PositiveIntegerField(default=1)),
                ("similarity_threshold", models.PositiveIntegerField(default=0)),
            ],
            options={
                "abstract": False,
            },
        ),
        # Seed the "default" row before the FK below relies on it; the
        # reverse step is a no-op (the row is left in place on rollback).
        migrations.RunPython(create_default_ai_settings, migrations.RunPython.noop),
        migrations.AddField(
            model_name="user",
            name="ai_settings",
            field=models.ForeignKey(
                # default="default" matches the seeded row's unique label,
                # which is the FK target via to_field below.
                default="default",
                on_delete=django.db.models.deletion.SET_DEFAULT,
                to="redbox_core.aisettings",
                to_field="label",
            ),
        ),
    ]
32 changes: 32 additions & 0 deletions django_app/redbox_app/redbox_core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from jose import jwt
from yarl import URL

from redbox_app.redbox_core import prompts
from redbox_app.redbox_core.utils import get_date_group

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -44,6 +45,36 @@ def sanitise_string(string: str | None) -> str | None:
return string.replace("\x00", "\ufffd") if string else string


class AISettings(UUIDPrimaryKeyBase, TimeStampedModel):
    """Per-user tunable AI parameters sent to the core API.

    Users reference a row through the User.ai_settings foreign key, which
    targets the unique ``label`` field; a "default" row is seeded by
    migration 0028. When a row is serialised for the core API (see
    ``get_ai_settings`` in client.py / consumers.py), ``label`` is excluded.
    """

    # Unique human-readable identifier; FK target (to_field) on User.ai_settings.
    label = models.CharField(max_length=50, unique=True)
    # Retrieval / chunking parameters.
    context_window_size = models.PositiveIntegerField(default=8_000)
    rag_k = models.PositiveIntegerField(default=30)
    rag_num_candidates = models.PositiveIntegerField(default=10)
    rag_desired_chunk_size = models.PositiveIntegerField(default=300)
    elbow_filter_enabled = models.BooleanField(default=False)
    # Prompt templates — defaults come from the canonical texts in prompts.py,
    # so the model defaults and the 0028 migration defaults stay in sync.
    chat_system_prompt = models.TextField(default=prompts.CHAT_SYSTEM_PROMPT)
    chat_question_prompt = models.TextField(default=prompts.CHAT_QUESTION_PROMPT)
    stuff_chunk_context_ratio = models.FloatField(default=0.75)
    chat_with_docs_system_prompt = models.TextField(default=prompts.CHAT_WITH_DOCS_SYSTEM_PROMPT)
    chat_with_docs_question_prompt = models.TextField(default=prompts.CHAT_WITH_DOCS_QUESTION_PROMPT)
    chat_with_docs_reduce_system_prompt = models.TextField(default=prompts.CHAT_WITH_DOCS_REDUCE_SYSTEM_PROMPT)
    retrieval_system_prompt = models.TextField(default=prompts.RETRIEVAL_SYSTEM_PROMPT)
    retrieval_question_prompt = models.TextField(default=prompts.RETRIEVAL_QUESTION_PROMPT)
    condense_system_prompt = models.TextField(default=prompts.CONDENSE_SYSTEM_PROMPT)
    condense_question_prompt = models.TextField(default=prompts.CONDENSE_QUESTION_PROMPT)
    map_max_concurrency = models.PositiveIntegerField(default=128)
    chat_map_system_prompt = models.TextField(default=prompts.CHAT_MAP_SYSTEM_PROMPT)
    chat_map_question_prompt = models.TextField(default=prompts.CHAT_MAP_QUESTION_PROMPT)
    reduce_system_prompt = models.TextField(default=prompts.REDUCE_SYSTEM_PROMPT)
    # LLM / search scoring parameters.
    llm_max_tokens = models.PositiveIntegerField(default=1024)
    match_boost = models.PositiveIntegerField(default=1)
    knn_boost = models.PositiveIntegerField(default=1)
    similarity_threshold = models.PositiveIntegerField(default=0)

    def __str__(self) -> str:
        return str(self.label)


class BusinessUnit(UUIDPrimaryKeyBase):
name = models.TextField(max_length=64, null=False, blank=False, unique=True)

Expand Down Expand Up @@ -129,6 +160,7 @@ class AIExperienceLevel(models.TextChoices):
name = models.CharField(null=True, blank=True)
ai_experience = models.CharField(null=True, blank=True, max_length=25, choices=AIExperienceLevel)
profession = models.CharField(null=True, blank=True, max_length=4, choices=Profession)
ai_settings = models.ForeignKey(AISettings, on_delete=models.SET_DEFAULT, default="default", to_field="label")
objects = BaseUserManager()

def __str__(self) -> str: # pragma: no cover
Expand Down
68 changes: 68 additions & 0 deletions django_app/redbox_app/redbox_core/prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Canonical default prompt texts for the AISettings model fields.

These constants are used as Django field defaults in models.AISettings and
were copied verbatim into migration 0028_aisettings. Do not edit a string
here without considering that existing rows and the migration snapshot keep
the old text.
"""

# --- System prompts (general chat) -----------------------------------------

CHAT_SYSTEM_PROMPT = (
    "You are an AI assistant called Redbox tasked with answering questions and providing information objectively."
)

CHAT_WITH_DOCS_SYSTEM_PROMPT = (
    "You are an AI assistant called Redbox tasked with answering questions on user provided documents and "
    "providing information objectively."
)

CHAT_WITH_DOCS_REDUCE_SYSTEM_PROMPT = (
    "You are an AI assistant tasked with answering questions on user provided documents. "
    "Your goal is to answer the user question based on list of summaries in a coherent manner."
    "Please follow these guidelines while answering the question: \n"
    "1) Identify and highlight key points,\n"
    "2) Avoid repetition,\n"
    "3) Ensure the answer is easy to understand,\n"
    "4) Maintain the original context and meaning.\n"
)

# NOTE(review): "dealing dealing" below looks like a duplicated-word typo —
# left as-is because the same text is frozen in migration 0028; fixing it
# would require a follow-up migration to keep model defaults in sync.
RETRIEVAL_SYSTEM_PROMPT = (
    "Given the following conversation and extracted parts of a long document and a question, create a final answer. \n"
    "If you don't know the answer, just say that you don't know. Don't try to make up an answer. "
    "If a user asks for a particular format to be returned, such as bullet points, then please use that format. "
    "If a user asks for bullet points you MUST give bullet points. "
    "If the user asks for a specific number or range of bullet points you MUST give that number of bullet points. \n"
    "Use **bold** to highlight the most question relevant parts in your response. "
    "If dealing dealing with lots of data return it in markdown table format. "
)

# --- System prompts (map/reduce summarisation) ------------------------------

CHAT_MAP_SYSTEM_PROMPT = (
    "You are an AI assistant tasked with summarizing documents. "
    "Your goal is to extract the most important information and present it in "
    "a concise and coherent manner. Please follow these guidelines while summarizing: \n"
    "1) Identify and highlight key points,\n"
    "2) Avoid repetition,\n"
    "3) Ensure the summary is easy to understand,\n"
    "4) Maintain the original context and meaning.\n"
)

REDUCE_SYSTEM_PROMPT = (
    "You are an AI assistant tasked with summarizing documents. "
    "Your goal is to write a concise summary of list of summaries from a list of summaries in "
    "a concise and coherent manner. Please follow these guidelines while summarizing: \n"
    "1) Identify and highlight key points,\n"
    "2) Avoid repetition,\n"
    "3) Ensure the summary is easy to understand,\n"
    "4) Maintain the original context and meaning.\n"
)

CONDENSE_SYSTEM_PROMPT = (
    "Given the following conversation and a follow up question, generate a follow "
    "up question to be a standalone question. "
    "You are only allowed to generate one question in response. "
    "Include sources from the chat history in the standalone question created, "
    "when they are available. "
    "If you don't know the answer, just say that you don't know, "
    "don't try to make up an answer. \n"
)

# --- Question templates -----------------------------------------------------
# Placeholders ({question}, {formatted_documents}) are filled in by the core
# API; keep brace names intact.

CHAT_QUESTION_PROMPT = "{question}\n=========\n Response: "

CHAT_WITH_DOCS_QUESTION_PROMPT = "Question: {question}. \n\n Documents: \n\n {formatted_documents} \n\n Answer: "

RETRIEVAL_QUESTION_PROMPT = "{question} \n=========\n{formatted_documents}\n=========\nFINAL ANSWER: "

CHAT_MAP_QUESTION_PROMPT = "Question: {question}. \n Documents: \n {formatted_documents} \n\n Answer: "

CONDENSE_QUESTION_PROMPT = "{question}\n=========\n Standalone question: "
3 changes: 3 additions & 0 deletions django_app/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from redbox_app.redbox_core import client
from redbox_app.redbox_core.models import (
AISettings,
BusinessUnit,
ChatHistory,
ChatMessage,
Expand All @@ -34,6 +35,8 @@ def _collect_static():

@pytest.fixture()
def create_user():
AISettings.objects.get_or_create(label="default")

def _create_user(email, date_joined_iso, is_staff=False):
date_joined = datetime.fromisoformat(date_joined_iso).astimezone(UTC)
return User.objects.create_user(email=email, date_joined=date_joined, is_staff=is_staff)
Expand Down
25 changes: 25 additions & 0 deletions django_app/tests/test_consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from redbox_app.redbox_core import error_messages
from redbox_app.redbox_core.consumers import ChatConsumer
from redbox_app.redbox_core.models import ChatHistory, ChatMessage, ChatRoleEnum, File, User
from redbox_app.redbox_core.prompts import CHAT_MAP_QUESTION_PROMPT

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -257,6 +258,8 @@ async def test_chat_consumer_with_selected_files(

# TODO (@brunns): Assert selected files sent to core.
# Requires fix for https://github.com/django/channels/issues/1091
# fixed now merged in https://github.com/django/channels/pull/2101, but not released
# Retry this when a version of Channels after 4.1.0 is released
mocked_websocket = mocked_connect_with_several_files.return_value.__aenter__.return_value
expected = json.dumps(
{
Expand All @@ -268,6 +271,7 @@ async def test_chat_consumer_with_selected_files(
{"role": "user", "text": "Third question, with selected files?"},
],
"selected_files": selected_file_core_uuids,
"ai_settings": await ChatConsumer.get_ai_settings(alice),
}
)
mocked_websocket.send.assert_called_with(expected)
Expand Down Expand Up @@ -354,6 +358,27 @@ async def test_chat_consumer_with_explicit_no_document_selected_error(
await communicator.disconnect()


@pytest.mark.django_db()
@pytest.mark.asyncio()
async def test_chat_consumer_get_ai_settings(
    alice: User, mocked_connect_with_explicit_no_document_selected_error: Connect
):
    """ChatConsumer.get_ai_settings serialises the user's settings as a dict.

    Verifies that model defaults (here chat_map_question_prompt) survive
    serialisation and that the internal "label" identifier is excluded
    from the payload sent to the core API.
    """
    with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect_with_explicit_no_document_selected_error):
        # Given: a connected chat websocket for alice.
        communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/")
        communicator.scope["user"] = alice
        connected, _ = await communicator.connect()
        assert connected

        # When: serialising alice's AI settings.
        ai_settings = await ChatConsumer.get_ai_settings(alice)

        # Then: defaults come through and "label" is absent.
        assert ai_settings["chat_map_question_prompt"] == CHAT_MAP_QUESTION_PROMPT
        with pytest.raises(KeyError):
            ai_settings["label"]

        # Close
        await communicator.disconnect()


@database_sync_to_async
def get_chat_messages(user: User) -> Sequence[ChatMessage]:
return list(
Expand Down
Loading
Loading